Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 78 additions & 0 deletions pkg/config/cache_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
package config

import (
"path/filepath"
"testing"

"github.com/stretchr/testify/require"
)

// TestSplitDigest covers the digest parser used to build cache paths.
func TestSplitDigest(t *testing.T) {
	for _, tc := range []struct {
		input string
		algo  string
		hex   string
	}{
		{input: "sha256:abc123", algo: "sha256", hex: "abc123"},
		{input: "sha512:deadbeef", algo: "sha512", hex: "deadbeef"},
		// No colon: the algorithm defaults to sha256.
		{input: "abc123", algo: "sha256", hex: "abc123"},
		// A trailing colon still splits, leaving an empty hex part.
		{input: "sha256:", algo: "sha256", hex: ""},
	} {
		gotAlgo, gotHex := splitDigest(tc.input)
		require.Equal(t, tc.algo, gotAlgo, "algo for %q", tc.input)
		require.Equal(t, tc.hex, gotHex, "hex for %q", tc.input)
	}
}

// TestGetCacheDirs exercises every cache path helper at once and checks the
// two parallel trees (content/ and refs/) share the same algo/hex suffix for
// any given digest.
func TestGetCacheDirs(t *testing.T) {
	const root = "/var/lib/model-csi"
	cfg := &RawConfig{RootDir: root}

	require.Equal(t, "/var/lib/model-csi/cache", cfg.GetCacheDir())
	require.Equal(t, filepath.Join(root, "cache", "content"), cfg.GetCacheContentRootDir())
	require.Equal(t, filepath.Join(root, "cache", "refs"), cfg.GetCacheRefsRootDir())

	// Canonical sha256 digest.
	const hex = "69a0c4d9505eb64e2454444baac2f5273c12450942f5a117e83557d161fb1206"
	const digest = "sha256:" + hex
	require.Equal(t, filepath.Join(root, "cache", "content", "sha256", hex), cfg.GetCacheContentDir(digest))
	require.Equal(t, filepath.Join(root, "cache", "refs", "sha256", hex), cfg.GetCacheRefsDir(digest))

	// Non-sha256 algo is respected end to end.
	require.Equal(t, filepath.Join(root, "cache", "content", "sha512", "deadbeef"), cfg.GetCacheContentDir("sha512:deadbeef"))
	require.Equal(t, filepath.Join(root, "cache", "refs", "sha512", "deadbeef"), cfg.GetCacheRefsDir("sha512:deadbeef"))

	// Digest without algo prefix falls back to sha256.
	require.Equal(t, filepath.Join(root, "cache", "content", "sha256", "abc123"), cfg.GetCacheContentDir("abc123"))

	// content/ and refs/ must share the same algo/hex suffix for a digest.
	contentSuffix, err := filepath.Rel(cfg.GetCacheContentRootDir(), cfg.GetCacheContentDir(digest))
	require.NoError(t, err)
	refsSuffix, err := filepath.Rel(cfg.GetCacheRefsRootDir(), cfg.GetCacheRefsDir(digest))
	require.NoError(t, err)
	require.Equal(t, contentSuffix, refsSuffix)
}
38 changes: 38 additions & 0 deletions pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,44 @@ func (cfg *RawConfig) GetCSISockDirForDynamic(volumeName string) string {
return filepath.Join(cfg.GetVolumeDirForDynamic(volumeName), "csi")
}

// GetCacheDir returns the node-level cache root directory, e.g.
// /var/lib/dragonfly/model-csi/cache.
func (cfg *RawConfig) GetCacheDir() string {
	const cacheSubdir = "cache"
	return filepath.Join(cfg.RootDir, cacheSubdir)
}

// GetCacheContentRootDir returns the root of the cached-content tree, e.g.
// /var/lib/dragonfly/model-csi/cache/content.
func (cfg *RawConfig) GetCacheContentRootDir() string {
	contentRoot := filepath.Join(cfg.GetCacheDir(), "content")
	return contentRoot
}

// GetCacheRefsRootDir returns the root of the cache refs tree, e.g.
// /var/lib/dragonfly/model-csi/cache/refs.
func (cfg *RawConfig) GetCacheRefsRootDir() string {
	refsRoot := filepath.Join(cfg.GetCacheDir(), "refs")
	return refsRoot
}

// splitDigest splits a digest like "sha256:abc..." into ("sha256", "abc...").
// For inputs without a ":" prefix, the algorithm defaults to "sha256".
func splitDigest(digest string) (algo string, hex string) {
	for i, r := range digest {
		if r == ':' {
			return digest[:i], digest[i+1:]
		}
	}
	// No separator found: treat the whole input as a sha256 hex string.
	return "sha256", digest
}

// GetCacheContentDir returns the content directory for a single digest, e.g.
// /var/lib/dragonfly/model-csi/cache/content/$algo/$hex.
func (cfg *RawConfig) GetCacheContentDir(digest string) string {
	algorithm, hexPart := splitDigest(digest)
	return filepath.Join(cfg.GetCacheContentRootDir(), algorithm, hexPart)
}

// /var/lib/dragonfly/model-csi/cache/refs/$algo/$hex
func (cfg *RawConfig) GetCacheRefsDir(digest string) string {
algo, hex := splitDigest(digest)
return filepath.Join(cfg.GetCacheRefsRootDir(), algo, hex)
}

// /var/lib/dragonfly/model-csi/volumes/$volumeName/csi/csi.sock
func (cfg *RawConfig) GetCSISockPathForDynamic(volumeName string) string {
return filepath.Join(cfg.GetCSISockDirForDynamic(volumeName), "csi.sock")
Expand Down
8 changes: 8 additions & 0 deletions pkg/server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package server

import (
"context"
"crypto/sha256"
"fmt"
"os"
"os/exec"
Expand Down Expand Up @@ -568,6 +569,13 @@ func TestServer(t *testing.T) {
}
}

// Avoid contacting a real registry to resolve manifest digests for the
// node-level shared cache: derive a deterministic fake digest from the
// reference string itself so tests can run fully offline.
service.ResolveCacheDigest = func(_ context.Context, reference string) (string, error) {
return "sha256:" + fmt.Sprintf("%x", sha256.Sum256([]byte(reference))), nil
}

ctx := context.TODO()
server, err := NewServer(cfg)
require.NoError(t, err)
Expand Down
95 changes: 95 additions & 0 deletions pkg/service/node_static_inline.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,80 @@ import (
)

// nodePublishVolumeStaticInlineVolume publishes a static inline volume by
// ensuring the model exists in the node-level shared cache and bind-mounting
// the cached content directory onto targetPath.
//
// NOTE(review): the statement order encodes cleanup semantics — an initial
// status is persisted before the pull, the cache ref registered by
// EnsureCachedModel is released if the bind mount fails, and the terminal
// Mounted status (with CacheDigest) is written only after the mount succeeds.
func (s *Service) nodePublishVolumeStaticInlineVolume(ctx context.Context, volumeName, targetPath, reference string, excludeModelWeights bool, excludeFilePatterns []string) (*csi.NodePublishVolumeResponse, error) {
	// Partial-file variants (exclude_model_weights / exclude_file_patterns) are
	// intentionally not shared through the node-level cache in this path to
	// keep the sharing semantics simple: only a full pull becomes a shared
	// cache entry. Fall back to the legacy per-volume pull for those cases.
	if excludeModelWeights || len(excludeFilePatterns) > 0 {
		return s.nodePublishVolumeStaticInlineVolumeLegacy(ctx, volumeName, targetPath, reference, excludeModelWeights, excludeFilePatterns)
	}

	// status.json lives in the per-volume directory; the unpublish path reads
	// its CacheDigest field to release this volume's shared-cache ref.
	statusPath := filepath.Join(s.cfg.Get().GetVolumeDir(volumeName), "status.json")

	// Persist an initial status so that progress / reference is observable
	// even while the shared pull is still in flight.
	if _, err := s.sm.Set(statusPath, modelStatus.Status{
		VolumeName: volumeName,
		Reference:  reference,
		Inline:     true,
		State:      modelStatus.StatePullRunning,
	}); err != nil {
		return nil, status.Error(codes.Internal, errors.Wrap(err, "set initial volume status").Error())
	}

	startedAt := time.Now()
	// EnsureCachedModel pulls the model into the shared cache (or reuses an
	// existing entry) and registers a ref for this volume; the returned digest
	// keys the cache content directory mounted below.
	digest, err := s.worker.EnsureCachedModel(ctx, reference, volumeName)
	if err != nil {
		// Best-effort: record failure state; the volume directory will be
		// cleaned up by the unpublish path.
		if _, setErr := s.sm.Set(statusPath, modelStatus.Status{
			VolumeName: volumeName,
			Reference:  reference,
			Inline:     true,
			State:      modelStatus.StatePullFailed,
		}); setErr != nil {
			logger.WithContext(ctx).WithError(setErr).Warn("failed to persist pull-failed status")
		}
		return nil, status.Error(codes.Internal, errors.Wrap(err, "ensure cached model").Error())
	}
	duration := time.Since(startedAt)
	logger.WithContext(ctx).Infof("ensured cached model: %s digest=%s duration=%s", reference, digest, duration)

	// Bind-mount the shared cache content dir for this digest onto the pod's
	// target path.
	sourceDir := s.cfg.Get().GetCacheContentDir(digest)
	if err := mounter.Mount(
		ctx,
		mounter.NewBuilder().
			Bind().
			From(sourceDir).
			MountPoint(targetPath),
	); err != nil {
		// Roll back the ref we just registered so we don't leak an entry that
		// would keep the cache alive forever.
		if relErr := s.worker.ReleaseCachedModel(ctx, digest, volumeName); relErr != nil {
			logger.WithContext(ctx).WithError(relErr).Warnf("release cached model after mount failure: %s", digest)
		}
		return nil, status.Error(codes.Internal, errors.Wrapf(err, "bind mount %s to target %s", sourceDir, targetPath).Error())
	}

	// Record the digest in the terminal status so the unpublish path can
	// release the shared-cache ref without scanning the refs tree.
	if _, err := s.sm.Set(statusPath, modelStatus.Status{
		VolumeName:  volumeName,
		Reference:   reference,
		Inline:      true,
		CacheDigest: digest,
		State:       modelStatus.StateMounted,
	}); err != nil {
		return nil, status.Error(codes.Internal, errors.Wrap(err, "set volume status").Error())
	}

	return &csi.NodePublishVolumeResponse{}, nil
}

// nodePublishVolumeStaticInlineVolumeLegacy is the per-volume pull path used
// only for partial-file inline volumes (exclude_model_weights /
// exclude_file_patterns). It preserves the original behavior to avoid
// accidentally sharing a partial pull with other pods that may expect a full
// model tree.
func (s *Service) nodePublishVolumeStaticInlineVolumeLegacy(ctx context.Context, volumeName, targetPath, reference string, excludeModelWeights bool, excludeFilePatterns []string) (*csi.NodePublishVolumeResponse, error) {
modelDir := s.cfg.Get().GetModelDir(volumeName)

startedAt := time.Now()
Expand Down Expand Up @@ -58,7 +132,28 @@ func (s *Service) nodeUnPublishVolumeStaticInlineVolume(ctx context.Context, vol
}
}

// Release the shared-cache reference, if any. We first try the digest
// recorded in status.json; if that is missing (e.g. older volumes or a
// crash before status was persisted), fall back to scanning the refs tree.
sourceVolumeDir := s.cfg.Get().GetVolumeDir(volumeName)
statusPath := filepath.Join(sourceVolumeDir, "status.json")
digest := ""
if volumeStatus, err := s.sm.Get(statusPath); err == nil && volumeStatus != nil {
digest = volumeStatus.CacheDigest
}
if digest == "" {
if found, err := s.worker.FindCacheDigestByRef(volumeName); err != nil {
logger.WithContext(ctx).WithError(err).Warnf("scan cache refs for volume: %s", volumeName)
} else {
digest = found
}
}
if digest != "" {
if err := s.worker.ReleaseCachedModel(ctx, digest, volumeName); err != nil {
logger.WithContext(ctx).WithError(err).Warnf("release cached model: %s", digest)
}
}

if err := os.RemoveAll(sourceVolumeDir); err != nil {
return nil, status.Error(codes.Internal, errors.Wrapf(err, "remove static inline volume dir").Error())
}
Expand Down
Loading
Loading