diff --git a/pkg/config/cache_test.go b/pkg/config/cache_test.go new file mode 100644 index 0000000..475204d --- /dev/null +++ b/pkg/config/cache_test.go @@ -0,0 +1,78 @@ +package config + +import ( + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" +) + +// TestSplitDigest covers the digest parser used to build cache paths. +func TestSplitDigest(t *testing.T) { + cases := []struct { + in string + wantAlgo string + wantHex string + }{ + {"sha256:abc123", "sha256", "abc123"}, + {"sha512:deadbeef", "sha512", "deadbeef"}, + {"abc123", "sha256", "abc123"}, // no colon → default algo + {"sha256:", "sha256", ""}, // empty hex still splits + } + for _, c := range cases { + algo, hex := splitDigest(c.in) + require.Equal(t, c.wantAlgo, algo, "algo for %q", c.in) + require.Equal(t, c.wantHex, hex, "hex for %q", c.in) + } +} + +// TestGetCacheDirs exercises every cache path helper at once and checks the +// two parallel trees (content/ and refs/) share the same algo/hex suffix for +// any given digest. +func TestGetCacheDirs(t *testing.T) { + cfg := &RawConfig{RootDir: "/var/lib/model-csi"} + + require.Equal(t, "/var/lib/model-csi/cache", cfg.GetCacheDir()) + require.Equal(t, + filepath.Join("/var/lib/model-csi", "cache", "content"), + cfg.GetCacheContentRootDir(), + ) + require.Equal(t, + filepath.Join("/var/lib/model-csi", "cache", "refs"), + cfg.GetCacheRefsRootDir(), + ) + + // Canonical sha256 digest. + digest := "sha256:69a0c4d9505eb64e2454444baac2f5273c12450942f5a117e83557d161fb1206" + require.Equal(t, + filepath.Join("/var/lib/model-csi", "cache", "content", "sha256", "69a0c4d9505eb64e2454444baac2f5273c12450942f5a117e83557d161fb1206"), + cfg.GetCacheContentDir(digest), + ) + require.Equal(t, + filepath.Join("/var/lib/model-csi", "cache", "refs", "sha256", "69a0c4d9505eb64e2454444baac2f5273c12450942f5a117e83557d161fb1206"), + cfg.GetCacheRefsDir(digest), + ) + + // Non-sha256 algo is respected end to end. 
+ require.Equal(t, + filepath.Join("/var/lib/model-csi", "cache", "content", "sha512", "deadbeef"), + cfg.GetCacheContentDir("sha512:deadbeef"), + ) + require.Equal(t, + filepath.Join("/var/lib/model-csi", "cache", "refs", "sha512", "deadbeef"), + cfg.GetCacheRefsDir("sha512:deadbeef"), + ) + + // Digest without algo prefix falls back to sha256. + require.Equal(t, + filepath.Join("/var/lib/model-csi", "cache", "content", "sha256", "abc123"), + cfg.GetCacheContentDir("abc123"), + ) + + // content/ and refs/ must share the same algo/hex suffix for a digest. + contentRel, err := filepath.Rel(cfg.GetCacheContentRootDir(), cfg.GetCacheContentDir(digest)) + require.NoError(t, err) + refsRel, err := filepath.Rel(cfg.GetCacheRefsRootDir(), cfg.GetCacheRefsDir(digest)) + require.NoError(t, err) + require.Equal(t, contentRel, refsRel) +} diff --git a/pkg/config/config.go b/pkg/config/config.go index fad6758..d413617 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -143,6 +143,44 @@ func (cfg *RawConfig) GetCSISockDirForDynamic(volumeName string) string { return filepath.Join(cfg.GetVolumeDirForDynamic(volumeName), "csi") } +// /var/lib/dragonfly/model-csi/cache +func (cfg *RawConfig) GetCacheDir() string { + return filepath.Join(cfg.RootDir, "cache") +} + +// /var/lib/dragonfly/model-csi/cache/content +func (cfg *RawConfig) GetCacheContentRootDir() string { + return filepath.Join(cfg.GetCacheDir(), "content") +} + +// /var/lib/dragonfly/model-csi/cache/refs +func (cfg *RawConfig) GetCacheRefsRootDir() string { + return filepath.Join(cfg.GetCacheDir(), "refs") +} + +// splitDigest splits a digest like "sha256:abc..." into ("sha256", "abc..."). +// For inputs without a ":" prefix, the algorithm defaults to "sha256". 
+func splitDigest(digest string) (algo string, hex string) { + for i := 0; i < len(digest); i++ { + if digest[i] == ':' { + return digest[:i], digest[i+1:] + } + } + return "sha256", digest +} + +// /var/lib/dragonfly/model-csi/cache/content/$algo/$hex +func (cfg *RawConfig) GetCacheContentDir(digest string) string { + algo, hex := splitDigest(digest) + return filepath.Join(cfg.GetCacheContentRootDir(), algo, hex) +} + +// /var/lib/dragonfly/model-csi/cache/refs/$algo/$hex +func (cfg *RawConfig) GetCacheRefsDir(digest string) string { + algo, hex := splitDigest(digest) + return filepath.Join(cfg.GetCacheRefsRootDir(), algo, hex) +} + // /var/lib/dragonfly/model-csi/volumes/$volumeName/csi/csi.sock func (cfg *RawConfig) GetCSISockPathForDynamic(volumeName string) string { return filepath.Join(cfg.GetCSISockDirForDynamic(volumeName), "csi.sock") diff --git a/pkg/server/server_test.go b/pkg/server/server_test.go index e6eb103..efc0ed8 100644 --- a/pkg/server/server_test.go +++ b/pkg/server/server_test.go @@ -2,6 +2,7 @@ package server import ( "context" + "crypto/sha256" "fmt" "os" "os/exec" @@ -568,6 +569,13 @@ func TestServer(t *testing.T) { } } + // Avoid contacting a real registry to resolve manifest digests for the + // node-level shared cache: derive a deterministic fake digest from the + // reference string itself so tests can run fully offline. 
+ service.ResolveCacheDigest = func(_ context.Context, reference string) (string, error) { + return "sha256:" + fmt.Sprintf("%x", sha256.Sum256([]byte(reference))), nil + } + ctx := context.TODO() server, err := NewServer(cfg) require.NoError(t, err) diff --git a/pkg/service/node_static_inline.go b/pkg/service/node_static_inline.go index cc9add1..8601e1e 100644 --- a/pkg/service/node_static_inline.go +++ b/pkg/service/node_static_inline.go @@ -16,6 +16,80 @@ import ( ) func (s *Service) nodePublishVolumeStaticInlineVolume(ctx context.Context, volumeName, targetPath, reference string, excludeModelWeights bool, excludeFilePatterns []string) (*csi.NodePublishVolumeResponse, error) { + // Partial-file variants (exclude_model_weights / exclude_file_patterns) are + // intentionally not shared through the node-level cache in this path to + // keep the sharing semantics simple: only a full pull becomes a shared + // cache entry. Fall back to the legacy per-volume pull for those cases. + if excludeModelWeights || len(excludeFilePatterns) > 0 { + return s.nodePublishVolumeStaticInlineVolumeLegacy(ctx, volumeName, targetPath, reference, excludeModelWeights, excludeFilePatterns) + } + + statusPath := filepath.Join(s.cfg.Get().GetVolumeDir(volumeName), "status.json") + + // Persist an initial status so that progress / reference is observable + // even while the shared pull is still in flight. + if _, err := s.sm.Set(statusPath, modelStatus.Status{ + VolumeName: volumeName, + Reference: reference, + Inline: true, + State: modelStatus.StatePullRunning, + }); err != nil { + return nil, status.Error(codes.Internal, errors.Wrap(err, "set initial volume status").Error()) + } + + startedAt := time.Now() + digest, err := s.worker.EnsureCachedModel(ctx, reference, volumeName) + if err != nil { + // Best-effort: record failure state; the volume directory will be + // cleaned up by the unpublish path. 
+ if _, setErr := s.sm.Set(statusPath, modelStatus.Status{ + VolumeName: volumeName, + Reference: reference, + Inline: true, + State: modelStatus.StatePullFailed, + }); setErr != nil { + logger.WithContext(ctx).WithError(setErr).Warn("failed to persist pull-failed status") + } + return nil, status.Error(codes.Internal, errors.Wrap(err, "ensure cached model").Error()) + } + duration := time.Since(startedAt) + logger.WithContext(ctx).Infof("ensured cached model: %s digest=%s duration=%s", reference, digest, duration) + + sourceDir := s.cfg.Get().GetCacheContentDir(digest) + if err := mounter.Mount( + ctx, + mounter.NewBuilder(). + Bind(). + From(sourceDir). + MountPoint(targetPath), + ); err != nil { + // Roll back the ref we just registered so we don't leak an entry that + // would keep the cache alive forever. + if relErr := s.worker.ReleaseCachedModel(ctx, digest, volumeName); relErr != nil { + logger.WithContext(ctx).WithError(relErr).Warnf("release cached model after mount failure: %s", digest) + } + return nil, status.Error(codes.Internal, errors.Wrapf(err, "bind mount %s to target %s", sourceDir, targetPath).Error()) + } + + if _, err := s.sm.Set(statusPath, modelStatus.Status{ + VolumeName: volumeName, + Reference: reference, + Inline: true, + CacheDigest: digest, + State: modelStatus.StateMounted, + }); err != nil { + return nil, status.Error(codes.Internal, errors.Wrap(err, "set volume status").Error()) + } + + return &csi.NodePublishVolumeResponse{}, nil +} + +// nodePublishVolumeStaticInlineVolumeLegacy is the per-volume pull path used +// only for partial-file inline volumes (exclude_model_weights / +// exclude_file_patterns). It preserves the original behavior to avoid +// accidentally sharing a partial pull with other pods that may expect a full +// model tree. 
+func (s *Service) nodePublishVolumeStaticInlineVolumeLegacy(ctx context.Context, volumeName, targetPath, reference string, excludeModelWeights bool, excludeFilePatterns []string) (*csi.NodePublishVolumeResponse, error) { modelDir := s.cfg.Get().GetModelDir(volumeName) startedAt := time.Now() @@ -58,7 +132,28 @@ func (s *Service) nodeUnPublishVolumeStaticInlineVolume(ctx context.Context, vol } } + // Release the shared-cache reference, if any. We first try the digest + // recorded in status.json; if that is missing (e.g. older volumes or a + // crash before status was persisted), fall back to scanning the refs tree. sourceVolumeDir := s.cfg.Get().GetVolumeDir(volumeName) + statusPath := filepath.Join(sourceVolumeDir, "status.json") + digest := "" + if volumeStatus, err := s.sm.Get(statusPath); err == nil && volumeStatus != nil { + digest = volumeStatus.CacheDigest + } + if digest == "" { + if found, err := s.worker.FindCacheDigestByRef(volumeName); err != nil { + logger.WithContext(ctx).WithError(err).Warnf("scan cache refs for volume: %s", volumeName) + } else { + digest = found + } + } + if digest != "" { + if err := s.worker.ReleaseCachedModel(ctx, digest, volumeName); err != nil { + logger.WithContext(ctx).WithError(err).Warnf("release cached model: %s", digest) + } + } + if err := os.RemoveAll(sourceVolumeDir); err != nil { return nil, status.Error(codes.Internal, errors.Wrapf(err, "remove static inline volume dir").Error()) } diff --git a/pkg/service/node_static_inline_test.go b/pkg/service/node_static_inline_test.go new file mode 100644 index 0000000..23655e5 --- /dev/null +++ b/pkg/service/node_static_inline_test.go @@ -0,0 +1,520 @@ +package service + +import ( + "context" + "os" + "path/filepath" + "testing" + + "github.com/agiledragon/gomonkey/v2" + modctlbackend "github.com/modelpack/modctl/pkg/backend" + "github.com/modelpack/model-csi-driver/pkg/config" + "github.com/modelpack/model-csi-driver/pkg/config/auth" + 
"github.com/modelpack/model-csi-driver/pkg/mounter" + "github.com/modelpack/model-csi-driver/pkg/status" + "github.com/pkg/errors" + "github.com/stretchr/testify/require" +) + +// applyInlineCachePatches stubs the shared-cache publish path so it runs +// fully offline: ResolveCacheDigest returns the canned digest and the worker's +// puller writes a sentinel file into the cache content directory. When +// pullErr is non-nil the puller returns it instead. onPull (if non-nil) is +// invoked on every pull call so tests can count physical pulls. +// +// The returned *gomonkey.Patches is preserved purely for backwards +// compatibility with existing call sites that defer Reset() on it; the +// digest-resolver swap is undone via t.Cleanup so callers don't need to. +func applyInlineCachePatches(t *testing.T, svc *Service, digest string, pullErr error, onPull func()) *gomonkey.Patches { + t.Helper() + + origResolve := ResolveCacheDigest + ResolveCacheDigest = func(_ context.Context, _ string) (string, error) { + if digest == "" { + return "", errors.New("empty manifest digest") + } + return digest, nil + } + t.Cleanup(func() { ResolveCacheDigest = origResolve }) + + svc.worker.newPuller = func(context.Context, *config.PullConfig, *status.Hook, *DiskQuotaChecker) Puller { + return &fakePuller{pullFunc: func(_ context.Context, _, targetDir string) error { + if onPull != nil { + onPull() + } + if pullErr != nil { + return pullErr + } + if err := os.MkdirAll(targetDir, 0755); err != nil { + return err + } + return os.WriteFile(filepath.Join(targetDir, "model.bin"), []byte("x"), 0644) + }} + } + + return gomonkey.NewPatches() +} + +// TestNodePublishVolumeStaticInlineVolume_SharedCache covers the publish +// happy path: first mount triggers a pull, a second mount of the same +// reference reuses the shared cache without re-pulling, both volumes record +// CacheDigest in status.json, and the bind mount targets the shared content +// directory. 
+func TestNodePublishVolumeStaticInlineVolume_SharedCache(t *testing.T) { + svc, _ := newNodeService(t) + ctx := context.Background() + + digest := "sha256:sharedcache" + pulls := 0 + patches := applyInlineCachePatches(t, svc, digest, nil, func() { pulls++ }) + defer patches.Reset() + + var mountCmds []string + patchMount := gomonkey.ApplyFunc(mounter.Mount, func(_ context.Context, b mounter.Builder) error { + cmd, err := b.Build() + if err != nil { + return err + } + mountCmds = append(mountCmds, cmd.String()) + return nil + }) + defer patchMount.Reset() + + // First pod mounts. + _, err := svc.nodePublishVolumeStaticInlineVolume(ctx, "pvc-a", t.TempDir(), "r/m:v1", false, nil) + require.NoError(t, err) + + // Second pod mounts the same reference. + _, err = svc.nodePublishVolumeStaticInlineVolume(ctx, "pvc-b", t.TempDir(), "r/m:v1", false, nil) + require.NoError(t, err) + + require.Equal(t, 1, pulls, "second mount must reuse the shared cache") + + // Bind mount source must point at the shared cache content directory. + expectedSrc := svc.cfg.Get().GetCacheContentDir(digest) + require.Len(t, mountCmds, 2) + for _, cmd := range mountCmds { + require.Contains(t, cmd, "--bind") + require.Contains(t, cmd, expectedSrc) + } + + // Both refs should exist under the refs tree. + refsDir := svc.cfg.Get().GetCacheRefsDir(digest) + _, err = os.Stat(filepath.Join(refsDir, "pvc-a")) + require.NoError(t, err) + _, err = os.Stat(filepath.Join(refsDir, "pvc-b")) + require.NoError(t, err) + + // status.json must record the digest for both volumes. 
+ for _, vol := range []string{"pvc-a", "pvc-b"} { + statusPath := filepath.Join(svc.cfg.Get().GetVolumeDir(vol), "status.json") + s, err := svc.sm.Get(statusPath) + require.NoError(t, err) + require.Equal(t, digest, s.CacheDigest) + require.Equal(t, status.StateMounted, s.State) + require.True(t, s.Inline) + } +} + +// TestNodePublishVolumeStaticInlineVolume_Failures covers the two error +// branches of the publish path: +// - pull failure: status.json is persisted with PULL_FAILED and no digest. +// - bind mount failure after successful pull: the ref we just registered +// is rolled back so the cache is GC'd instead of leaking forever. +func TestNodePublishVolumeStaticInlineVolume_Failures(t *testing.T) { + t.Run("pull failure persists status", func(t *testing.T) { + svc, _ := newNodeService(t) + patches := applyInlineCachePatches(t, svc, "sha256:pullfail", errors.New("net down"), nil) + defer patches.Reset() + + volumeName := "pvc-pullfail" + _, err := svc.nodePublishVolumeStaticInlineVolume( + context.Background(), volumeName, t.TempDir(), "r/m:v1", false, nil, + ) + require.Error(t, err) + + statusPath := filepath.Join(svc.cfg.Get().GetVolumeDir(volumeName), "status.json") + s, err := svc.sm.Get(statusPath) + require.NoError(t, err) + require.Equal(t, status.StatePullFailed, s.State) + require.Equal(t, "", s.CacheDigest) + }) + + t.Run("bind mount failure rolls back ref", func(t *testing.T) { + svc, _ := newNodeService(t) + digest := "sha256:mountfail" + patches := applyInlineCachePatches(t, svc, digest, nil, nil) + defer patches.Reset() + + patchMount := gomonkey.ApplyFunc(mounter.Mount, func(_ context.Context, _ mounter.Builder) error { + return errors.New("bind failed") + }) + defer patchMount.Reset() + + volumeName := "pvc-mountfail" + _, err := svc.nodePublishVolumeStaticInlineVolume( + context.Background(), volumeName, t.TempDir(), "r/m:v1", false, nil, + ) + require.Error(t, err) + + // Ref must be rolled back, and since it was the only ref, GC removes 
+ // the entire cache entry. + _, statErr := os.Stat(svc.cfg.Get().GetCacheContentDir(digest)) + require.True(t, os.IsNotExist(statErr)) + _, statErr = os.Stat(svc.cfg.Get().GetCacheRefsDir(digest)) + require.True(t, os.IsNotExist(statErr)) + }) +} + +// TestNodeUnPublishVolumeStaticInlineVolume covers the three unpublish +// branches: +// - release by digest recorded in status.json (happy path), +// - fallback reverse scan when status.json lacks CacheDigest, +// - clean volume dir when there is neither status nor cache (degenerate). +func TestNodeUnPublishVolumeStaticInlineVolume(t *testing.T) { + patchUMount := gomonkey.ApplyFunc(mounter.UMount, func(_ context.Context, _ string, _ bool) error { + return nil + }) + defer patchUMount.Reset() + + t.Run("release by digest from status", func(t *testing.T) { + svc, _ := newNodeService(t) + digest := "sha256:unpub-digest" + patches := applyInlineCachePatches(t, svc, digest, nil, nil) + defer patches.Reset() + patchMount := gomonkey.ApplyFunc(mounter.Mount, func(_ context.Context, _ mounter.Builder) error { + return nil + }) + defer patchMount.Reset() + + volumeName := "pvc-unpub-digest" + _, err := svc.nodePublishVolumeStaticInlineVolume( + context.Background(), volumeName, t.TempDir(), "r/m:v1", false, nil, + ) + require.NoError(t, err) + + _, err = svc.nodeUnPublishVolumeStaticInlineVolume(context.Background(), volumeName, t.TempDir(), true) + require.NoError(t, err) + + // Last ref → cache GC'd, volume dir removed. + _, statErr := os.Stat(svc.cfg.Get().GetCacheContentDir(digest)) + require.True(t, os.IsNotExist(statErr)) + _, statErr = os.Stat(svc.cfg.Get().GetVolumeDir(volumeName)) + require.True(t, os.IsNotExist(statErr)) + }) + + t.Run("fallback scan when status has no digest", func(t *testing.T) { + svc, _ := newNodeService(t) + digest := "sha256:unpub-scan" + volumeName := "pvc-unpub-scan" + + // Simulate an older on-disk layout: cache + refs populated by hand, + // status.json has no CacheDigest field. 
+ contentDir := svc.cfg.Get().GetCacheContentDir(digest) + refsDir := svc.cfg.Get().GetCacheRefsDir(digest) + require.NoError(t, os.MkdirAll(contentDir, 0755)) + require.NoError(t, os.WriteFile(filepath.Join(contentDir, "model.bin"), []byte("x"), 0644)) + require.NoError(t, os.MkdirAll(refsDir, 0755)) + require.NoError(t, os.WriteFile(filepath.Join(refsDir, ".ready"), nil, 0644)) + require.NoError(t, os.WriteFile(filepath.Join(refsDir, volumeName), nil, 0644)) + + volumeDir := svc.cfg.Get().GetVolumeDir(volumeName) + require.NoError(t, os.MkdirAll(volumeDir, 0755)) + _, err := svc.sm.Set(filepath.Join(volumeDir, "status.json"), status.Status{ + VolumeName: volumeName, + Reference: "r/m:v1", + Inline: true, + State: status.StateMounted, + }) + require.NoError(t, err) + + _, err = svc.nodeUnPublishVolumeStaticInlineVolume(context.Background(), volumeName, t.TempDir(), true) + require.NoError(t, err) + + // Reverse scan must have located the digest and GC'd everything. + _, statErr := os.Stat(contentDir) + require.True(t, os.IsNotExist(statErr)) + _, statErr = os.Stat(refsDir) + require.True(t, os.IsNotExist(statErr)) + _, statErr = os.Stat(volumeDir) + require.True(t, os.IsNotExist(statErr)) + }) + + t.Run("no cache no status still cleans volume", func(t *testing.T) { + svc, _ := newNodeService(t) + volumeName := "pvc-nothing" + volumeDir := svc.cfg.Get().GetVolumeDir(volumeName) + require.NoError(t, os.MkdirAll(volumeDir, 0755)) + + _, err := svc.nodeUnPublishVolumeStaticInlineVolume(context.Background(), volumeName, t.TempDir(), false) + require.NoError(t, err) + + _, statErr := os.Stat(volumeDir) + require.True(t, os.IsNotExist(statErr)) + }) + + t.Run("umount failure propagates", func(t *testing.T) { + svc, _ := newNodeService(t) + patchUMount := gomonkey.ApplyFunc(mounter.UMount, func(_ context.Context, _ string, _ bool) error { + return errors.New("umount boom") + }) + defer patchUMount.Reset() + + _, err := 
svc.nodeUnPublishVolumeStaticInlineVolume(context.Background(), "pvc-x", t.TempDir(), true) + require.Error(t, err) + require.Contains(t, err.Error(), "unmount target path") + }) +} + +// legacyPuller is a Puller that captures the exclusion arguments passed to +// Pull so the legacy-path test can assert them. +type legacyPuller struct { + targetDir string + excludeModelWeights bool + excludeFilePatterns []string +} + +func (p *legacyPuller) Pull(_ context.Context, _, targetDir string, excludeModelWeights bool, excludeFilePatterns []string) error { + p.targetDir = targetDir + p.excludeModelWeights = excludeModelWeights + p.excludeFilePatterns = excludeFilePatterns + if err := os.MkdirAll(targetDir, 0755); err != nil { + return err + } + return os.WriteFile(filepath.Join(targetDir, "partial.bin"), []byte("x"), 0644) +} + +// TestNodePublishVolumeStaticInlineVolumeLegacy covers the fallback per-volume +// pull path used when excludeModelWeights / excludeFilePatterns is set. Those +// partial-file variants intentionally do not go through the shared cache, so +// we validate the puller is invoked with the exclusion flags and the bind +// mount targets the per-volume model dir. 
+func TestNodePublishVolumeStaticInlineVolumeLegacy(t *testing.T) { + svc, _ := newNodeService(t) + + patches := gomonkey.NewPatches() + defer patches.Reset() + patches.ApplyFunc(auth.GetKeyChainByRef, func(string) (*auth.PassKeyChain, error) { + return &auth.PassKeyChain{ServerScheme: "https"}, nil + }) + patches.ApplyFunc(modctlbackend.New, func(string) (modctlbackend.Backend, error) { + return nil, nil + }) + + captured := &legacyPuller{} + svc.worker.newPuller = func(context.Context, *config.PullConfig, *status.Hook, *DiskQuotaChecker) Puller { + return captured + } + + var mountCmd string + patchMount := gomonkey.ApplyFunc(mounter.Mount, func(_ context.Context, b mounter.Builder) error { + cmd, err := b.Build() + if err != nil { + return err + } + mountCmd = cmd.String() + return nil + }) + defer patchMount.Reset() + + volumeName := "pvc-legacy" + _, err := svc.nodePublishVolumeStaticInlineVolume( + context.Background(), volumeName, t.TempDir(), + "r/m:v1", true, []string{"*.ignore"}, + ) + require.NoError(t, err) + + // Pull target must be the per-volume model dir, not a shared cache dir. + expectedModelDir := svc.cfg.Get().GetModelDir(volumeName) + require.Equal(t, expectedModelDir, captured.targetDir) + require.True(t, captured.excludeModelWeights) + require.Equal(t, []string{"*.ignore"}, captured.excludeFilePatterns) + require.Contains(t, mountCmd, expectedModelDir) + + // status.json must be marked inline and mounted, and carry no CacheDigest. 
+ statusPath := filepath.Join(svc.cfg.Get().GetVolumeDir(volumeName), "status.json") + s, err := svc.sm.Get(statusPath) + require.NoError(t, err) + require.True(t, s.Inline) + require.Equal(t, status.StateMounted, s.State) + require.Equal(t, "", s.CacheDigest) +} + +// TestNodePublishVolumeStaticInlineVolume_StatusSetErrors covers two error +// branches that depend on sm.Set failing at different points: +// - initial PullRunning Set fails before any pull is attempted, +// - final Mounted Set fails after a successful pull + mount, in which case +// the function still returns an error. +func TestNodePublishVolumeStaticInlineVolume_StatusSetErrors(t *testing.T) { + t.Run("initial set fails", func(t *testing.T) { + svc, _ := newNodeService(t) + + // Make the very first sm.Set call fail by planting a regular file at + // the volume dir path so MkdirAll inside the status manager errors. + volumeName := "pvc-initset-fail" + volumesDir := svc.cfg.Get().GetVolumesDir() + require.NoError(t, os.MkdirAll(volumesDir, 0755)) + require.NoError(t, os.WriteFile(svc.cfg.Get().GetVolumeDir(volumeName), []byte("blocker"), 0644)) + + _, err := svc.nodePublishVolumeStaticInlineVolume( + context.Background(), volumeName, t.TempDir(), "r/m:v1", false, nil, + ) + require.Error(t, err) + require.Contains(t, err.Error(), "set initial volume status") + }) + + t.Run("final set fails after mount", func(t *testing.T) { + svc, _ := newNodeService(t) + digest := "sha256:finalset-fail" + patches := applyInlineCachePatches(t, svc, digest, nil, nil) + defer patches.Reset() + + patchMount := gomonkey.ApplyFunc(mounter.Mount, func(_ context.Context, _ mounter.Builder) error { + return nil + }) + defer patchMount.Reset() + + volumeName := "pvc-finalset-fail" + // Force the second Set (Mounted) to fail by patching sm.Set. 
+ var setCalls int + patchSet := gomonkey.ApplyMethod(svc.sm, "Set", + func(sm *status.StatusManager, statusPath string, newStatus status.Status) (*status.Status, error) { + setCalls++ + if setCalls >= 2 { + return nil, errors.New("set boom") + } + return &newStatus, nil + }) + defer patchSet.Reset() + + _, err := svc.nodePublishVolumeStaticInlineVolume( + context.Background(), volumeName, t.TempDir(), "r/m:v1", false, nil, + ) + require.Error(t, err) + require.Contains(t, err.Error(), "set volume status") + }) +} + +// TestNodePublishVolumeStaticInlineVolumeLegacy_Errors covers the error +// branches of the legacy (partial-file) publish path: pull failure, mount +// failure after a successful pull, and the status Get/Set failures at the +// tail of the function. +func TestNodePublishVolumeStaticInlineVolumeLegacy_Errors(t *testing.T) { + t.Run("pull failure", func(t *testing.T) { + svc, _ := newNodeService(t) + + patches := gomonkey.NewPatches() + defer patches.Reset() + patches.ApplyFunc(auth.GetKeyChainByRef, func(string) (*auth.PassKeyChain, error) { + return &auth.PassKeyChain{ServerScheme: "https"}, nil + }) + patches.ApplyFunc(modctlbackend.New, func(string) (modctlbackend.Backend, error) { + return nil, nil + }) + svc.worker.newPuller = func(context.Context, *config.PullConfig, *status.Hook, *DiskQuotaChecker) Puller { + return &fakePuller{pullFunc: func(context.Context, string, string) error { + return errors.New("pull boom") + }} + } + + _, err := svc.nodePublishVolumeStaticInlineVolume( + context.Background(), "pvc-legacy-pullfail", t.TempDir(), + "r/m:v1", true, nil, + ) + require.Error(t, err) + require.Contains(t, err.Error(), "pull model") + }) + + t.Run("mount failure", func(t *testing.T) { + svc, _ := newNodeService(t) + + patches := gomonkey.NewPatches() + defer patches.Reset() + patches.ApplyFunc(auth.GetKeyChainByRef, func(string) (*auth.PassKeyChain, error) { + return &auth.PassKeyChain{ServerScheme: "https"}, nil + }) + 
patches.ApplyFunc(modctlbackend.New, func(string) (modctlbackend.Backend, error) { + return nil, nil + }) + svc.worker.newPuller = func(context.Context, *config.PullConfig, *status.Hook, *DiskQuotaChecker) Puller { + return &legacyPuller{} + } + + patchMount := gomonkey.ApplyFunc(mounter.Mount, func(_ context.Context, _ mounter.Builder) error { + return errors.New("bind boom") + }) + defer patchMount.Reset() + + _, err := svc.nodePublishVolumeStaticInlineVolume( + context.Background(), "pvc-legacy-mountfail", t.TempDir(), + "r/m:v1", true, nil, + ) + require.Error(t, err) + require.Contains(t, err.Error(), "bind mount") + }) + + t.Run("status get failure", func(t *testing.T) { + svc, _ := newNodeService(t) + + patches := gomonkey.NewPatches() + defer patches.Reset() + patches.ApplyFunc(auth.GetKeyChainByRef, func(string) (*auth.PassKeyChain, error) { + return &auth.PassKeyChain{ServerScheme: "https"}, nil + }) + patches.ApplyFunc(modctlbackend.New, func(string) (modctlbackend.Backend, error) { + return nil, nil + }) + svc.worker.newPuller = func(context.Context, *config.PullConfig, *status.Hook, *DiskQuotaChecker) Puller { + return &legacyPuller{} + } + patchMount := gomonkey.ApplyFunc(mounter.Mount, func(_ context.Context, _ mounter.Builder) error { + return nil + }) + defer patchMount.Reset() + + // Force sm.Get to fail right after the mount. + patchGet := gomonkey.ApplyMethod(svc.sm, "Get", + func(*status.StatusManager, string) (*status.Status, error) { + return nil, errors.New("get boom") + }) + defer patchGet.Reset() + + _, err := svc.nodePublishVolumeStaticInlineVolume( + context.Background(), "pvc-legacy-getfail", t.TempDir(), + "r/m:v1", true, nil, + ) + require.Error(t, err) + require.Contains(t, err.Error(), "get volume status") + }) + +} + +// TestNodeUnPublishVolumeStaticInlineVolume_FindRefError verifies that when +// the reverse-scan path errors out (e.g. 
refs root is not a directory), the +// unpublish still proceeds, logs a warning, and successfully cleans up the +// volume directory. +func TestNodeUnPublishVolumeStaticInlineVolume_FindRefError(t *testing.T) { + svc, _ := newNodeService(t) + + patchUMount := gomonkey.ApplyFunc(mounter.UMount, func(_ context.Context, _ string, _ bool) error { + return nil + }) + defer patchUMount.Reset() + + volumeName := "pvc-find-err" + volumeDir := svc.cfg.Get().GetVolumeDir(volumeName) + require.NoError(t, os.MkdirAll(volumeDir, 0755)) + + // Plant a regular file where the refs root is expected so that + // FindCacheDigestByRef returns a non-IsNotExist error and the warn + // branch is exercised. + refsRoot := svc.cfg.Get().GetCacheRefsRootDir() + require.NoError(t, os.MkdirAll(filepath.Dir(refsRoot), 0755)) + require.NoError(t, os.WriteFile(refsRoot, []byte("not a dir"), 0644)) + + _, err := svc.nodeUnPublishVolumeStaticInlineVolume(context.Background(), volumeName, t.TempDir(), true) + require.NoError(t, err) + + _, statErr := os.Stat(volumeDir) + require.True(t, os.IsNotExist(statErr)) +} diff --git a/pkg/service/worker.go b/pkg/service/worker.go index 8628736..32d798e 100644 --- a/pkg/service/worker.go +++ b/pkg/service/worker.go @@ -9,7 +9,9 @@ import ( "time" "github.com/containerd/containerd/pkg/kmutex" + "github.com/modelpack/modctl/pkg/backend" "github.com/modelpack/model-csi-driver/pkg/config" + "github.com/modelpack/model-csi-driver/pkg/config/auth" "github.com/modelpack/model-csi-driver/pkg/logger" "github.com/modelpack/model-csi-driver/pkg/metrics" "github.com/modelpack/model-csi-driver/pkg/status" @@ -18,6 +20,39 @@ import ( "golang.org/x/sync/singleflight" ) +// cacheReadyMarker is the sentinel filename placed under the refs directory +// to indicate that the corresponding content directory is fully materialized. +// It is kept under refs/ (not content/) so it doesn't pollute the bind-mounted +// model content seen by pods. 
+const cacheReadyMarker = ".ready"
+
+// ResolveCacheDigest resolves a model reference to its manifest digest, used
+// as the cache key for the node-level shared cache. It is exposed as a
+// package-level variable so integration tests that inject a fake Puller can
+// also inject a fake digest resolver and avoid contacting a real registry.
+var ResolveCacheDigest = func(ctx context.Context, reference string) (string, error) {
+	keyChain, err := auth.GetKeyChainByRef(reference)
+	if err != nil {
+		return "", errors.Wrapf(err, "get auth for model: %s", reference)
+	}
+	plainHTTP := keyChain.ServerScheme == "http"
+
+	b, err := backend.New("")
+	if err != nil {
+		return "", errors.Wrap(err, "create modctl backend")
+	}
+
+	modelArtifact := NewModelArtifact(b, reference, plainHTTP)
+	artifact, err := modelArtifact.Inspect(ctx, reference)
+	if err != nil {
+		return "", errors.Wrapf(err, "inspect model: %s", reference)
+	}
+	if artifact.Digest == "" {
+		return "", errors.Errorf("empty manifest digest for model: %s", reference)
+	}
+	return artifact.Digest, nil
+}
+
 var ErrConflict = errors.New("conflict")
 
 type ContextMap struct {
@@ -78,6 +113,11 @@ func (worker *Worker) deleteModel(ctx context.Context, isStaticVolume bool, volu
 		logger.WithContext(ctx).Infof("canceled pulling request: %s", contextKey)
 	}
 	_, err, _ := worker.inflight.Do(inflightKey, func() (interface{}, error) {
+		// Intentionally use context.Background() here: deleteModel is
+		// frequently invoked as the cleanup tail of a failed/canceled
+		// PullModel, whose ctx is already past its deadline. Using that
+		// ctx would make the lock acquisition fail immediately with
+		// "context deadline exceeded" and leave the volume dir on disk.
 		if err := worker.kmutex.Lock(context.Background(), contextKey); err != nil {
 			return nil, errors.Wrapf(err, "lock context key: %s", contextKey)
 		}
@@ -160,6 +200,12 @@ func (worker *Worker) pullModel(ctx context.Context, statusPath, volumeName, mou
 	inflightKey := fmt.Sprintf("pull-%s/%s", volumeName, mountID)
 	contextKey := fmt.Sprintf("%s/%s", volumeName, mountID)
 	_, err, shared := worker.inflight.Do(inflightKey, func() (interface{}, error) {
+		// Intentionally use context.Background() here: the inflight
+		// caller's ctx may already be past its deadline (e.g. CSI client
+		// timeout) by the time another shared waiter actually gets to
+		// acquire this lock. Use a fresh ctx so the lock can always be
+		// obtained; cancellation is propagated separately via the
+		// per-key contextMap below.
 		if err := worker.kmutex.Lock(context.Background(), contextKey); err != nil {
 			return nil, errors.Wrapf(err, "lock context key: %s", contextKey)
 		}
@@ -232,6 +278,187 @@ func (worker *Worker) pullModel(ctx context.Context, statusPath, volumeName, mou
 	return nil
 }
 
+// EnsureCachedModel makes sure the model for the given reference is present in
+// the shared node-level cache, then registers a reference under
+// cache/refs/<algo>/<hex>/<refName>. The returned string is the cache digest
+// (e.g. "sha256:abc..."), which callers must remember (typically via
+// status.json) so that they can call ReleaseCachedModel later.
+//
+// Concurrent callers for the same reference dedupe on the manifest digest:
+// only one pull actually runs, the others wait and then just register their
+// own ref file. Each caller holds an exclusive kmutex on the digest while
+// checking readiness / pulling / writing its ref file, so ref registration is
+// serialized against garbage collection in ReleaseCachedModel.
+func (worker *Worker) EnsureCachedModel(ctx context.Context, reference, refName string) (string, error) {
+	start := time.Now()
+
+	digest, err := ResolveCacheDigest(ctx, reference)
+	if err != nil {
+		return "", err
+	}
+
+	contentDir := worker.cfg.Get().GetCacheContentDir(digest)
+	refsDir := worker.cfg.Get().GetCacheRefsDir(digest)
+
+	// Deduplicate concurrent pulls for the same digest. The inner function
+	// only performs the pull when not yet ready; ref registration is handled
+	// outside (and under its own lock) so each caller writes its own ref file
+	// even when sharing the pull result.
+	_, pullErr, shared := worker.inflight.Do("cache-pull-"+digest, func() (interface{}, error) {
+		// Intentionally use context.Background(): the caller's ctx may
+		// already be past its deadline (e.g. CSI client timeout) by the
+		// time a shared waiter actually gets to acquire this lock.
+		if err := worker.kmutex.Lock(context.Background(), digest); err != nil {
+			return nil, errors.Wrapf(err, "lock cache digest: %s", digest)
+		}
+		defer worker.kmutex.Unlock(digest)
+
+		// Ready marker present means a previous pull fully materialized
+		// the content dir; nothing to do.
+		if _, err := os.Stat(filepath.Join(refsDir, cacheReadyMarker)); err == nil {
+			return nil, nil
+		}
+
+		// Clean up any half-finished cache directory from previous failures
+		// before starting a fresh pull.
+		if err := os.RemoveAll(contentDir); err != nil {
+			return nil, errors.Wrapf(err, "cleanup cache content dir: %s", contentDir)
+		}
+		if err := os.MkdirAll(refsDir, 0755); err != nil {
+			return nil, errors.Wrapf(err, "create cache refs dir: %s", refsDir)
+		}
+
+		hook := status.NewHook(ctx)
+		p := worker.newPuller(ctx, &worker.cfg.Get().PullConfig, hook, nil)
+		if err := p.Pull(ctx, reference, contentDir, false, nil); err != nil {
+			_ = os.RemoveAll(contentDir)
+			return nil, errors.Wrapf(err, "pull model %s into cache", reference)
+		}
+
+		// The ready marker is written only after a fully successful pull,
+		// so its presence is the commit point for this cache entry.
+		readyPath := filepath.Join(refsDir, cacheReadyMarker)
+		if err := os.WriteFile(readyPath, []byte{}, 0644); err != nil {
+			_ = os.RemoveAll(contentDir)
+			return nil, errors.Wrapf(err, "write cache ready marker: %s", readyPath)
+		}
+		return nil, nil
+	})
+	metrics.NodeOpObserve("cache_ensure_pull", start, pullErr)
+	if pullErr != nil {
+		return "", errors.Wrapf(pullErr, "ensure cached model: %s (shared=%v)", reference, shared)
+	}
+	logger.WithContext(ctx).Infof("ensured cached model: %s digest=%s shared=%v", reference, digest, shared)
+
+	// Register this caller's reference. Held under the same per-digest lock
+	// that guards release/GC, guaranteeing that the content directory cannot
+	// disappear between readiness check and bind mount by the caller.
+	// Use context.Background() for the same reason as above.
+	if err := worker.kmutex.Lock(context.Background(), digest); err != nil {
+		return "", errors.Wrapf(err, "lock cache digest for ref: %s", digest)
+	}
+	defer worker.kmutex.Unlock(digest)
+
+	if _, err := os.Stat(filepath.Join(refsDir, cacheReadyMarker)); err != nil {
+		return "", errors.Wrapf(err, "cache content vanished before ref registration: %s", digest)
+	}
+	refFile := filepath.Join(refsDir, refName)
+	if err := os.WriteFile(refFile, []byte{}, 0644); err != nil {
+		return "", errors.Wrapf(err, "write cache ref file: %s", refFile)
+	}
+
+	return digest, nil
+}
+
+// ReleaseCachedModel removes a caller's ref file under cache/refs/<digest>/.
+// If no other callers hold a ref, the entire cache content + refs entry is
+// garbage collected. Unknown / already-gone digests are treated as success.
+func (worker *Worker) ReleaseCachedModel(ctx context.Context, digest, refName string) error {
+	start := time.Now()
+	if digest == "" {
+		return nil
+	}
+
+	contentDir := worker.cfg.Get().GetCacheContentDir(digest)
+	refsDir := worker.cfg.Get().GetCacheRefsDir(digest)
+
+	// Use context.Background(): release is typically called from an
+	// unpublish path whose ctx may have already been canceled by kubelet,
+	// but we still need to drop the ref and GC the cache reliably.
+	if err := worker.kmutex.Lock(context.Background(), digest); err != nil {
+		return errors.Wrapf(err, "lock cache digest for release: %s", digest)
+	}
+	defer worker.kmutex.Unlock(digest)
+
+	// A missing ref file is tolerated so double-release is idempotent.
+	refFile := filepath.Join(refsDir, refName)
+	if err := os.Remove(refFile); err != nil && !os.IsNotExist(err) {
+		return errors.Wrapf(err, "remove cache ref file: %s", refFile)
+	}
+
+	entries, err := os.ReadDir(refsDir)
+	if err != nil {
+		if os.IsNotExist(err) {
+			// Whole entry already GC'd — treat as success.
+			metrics.NodeOpObserve("cache_release", start, nil)
+			return nil
+		}
+		return errors.Wrapf(err, "read cache refs dir: %s", refsDir)
+	}
+	// Any entry other than the ready marker is a live user of this cache.
+	hasUser := false
+	for _, entry := range entries {
+		if entry.Name() == cacheReadyMarker {
+			continue
+		}
+		hasUser = true
+		break
+	}
+	if hasUser {
+		metrics.NodeOpObserve("cache_release", start, nil)
+		return nil
+	}
+
+	// No more users: GC the cache content and the refs entry.
+	if err := os.RemoveAll(contentDir); err != nil {
+		return errors.Wrapf(err, "remove cache content dir: %s", contentDir)
+	}
+	if err := os.RemoveAll(refsDir); err != nil {
+		return errors.Wrapf(err, "remove cache refs dir: %s", refsDir)
+	}
+	logger.WithContext(ctx).Infof("garbage collected cache entry: digest=%s", digest)
+	metrics.NodeOpObserve("cache_release", start, nil)
+	return nil
+}
+
+// FindCacheDigestByRef scans cache/refs/<algo>/<hex>/ for a ref file named
+// refName and returns the corresponding digest string ("<algo>:<hex>"). Used
+// by unpublish paths that don't have the digest at hand and want to avoid
+// re-contacting the registry just to release a reference.
+func (worker *Worker) FindCacheDigestByRef(refName string) (string, error) {
+	// Guard against degenerate ref names: an empty refName would make
+	// filepath.Join(algoDir, hex, "") point at the <hex> directory itself
+	// (which always stats successfully), and the ".ready" sentinel exists
+	// under every fully-pulled digest — either would return an arbitrary
+	// digest instead of "no match".
+	if refName == "" || refName == cacheReadyMarker {
+		return "", nil
+	}
+	refsRoot := worker.cfg.Get().GetCacheRefsRootDir()
+	algoEntries, err := os.ReadDir(refsRoot)
+	if err != nil {
+		if os.IsNotExist(err) {
+			// No cache at all yet — report "not found" rather than error.
+			return "", nil
+		}
+		return "", errors.Wrapf(err, "read cache refs root: %s", refsRoot)
+	}
+	for _, algoEntry := range algoEntries {
+		if !algoEntry.IsDir() {
+			continue
+		}
+		algoDir := filepath.Join(refsRoot, algoEntry.Name())
+		hexEntries, err := os.ReadDir(algoDir)
+		if err != nil {
+			// Stray non-directory entries are skipped, not fatal.
+			continue
+		}
+		for _, hexEntry := range hexEntries {
+			if !hexEntry.IsDir() {
+				continue
+			}
+			candidate := filepath.Join(algoDir, hexEntry.Name(), refName)
+			if _, err := os.Stat(candidate); err == nil {
+				return algoEntry.Name() + ":" + hexEntry.Name(), nil
+			}
+		}
+	}
+	return "", nil
+}
+
 func (worker *Worker) isModelExisted(ctx context.Context, reference string) bool {
 	volumesDir := worker.cfg.Get().GetVolumesDir()
 	volumeDirs, err := os.ReadDir(volumesDir)
diff --git a/pkg/service/worker_cache_test.go b/pkg/service/worker_cache_test.go
new file mode 100644
index 0000000..d586a34
--- /dev/null
+++ b/pkg/service/worker_cache_test.go
@@ -0,0 +1,395 @@
+package service
+
+import (
+	"context"
+	"os"
+	"path/filepath"
+	"sync"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"github.com/modelpack/model-csi-driver/pkg/config"
+	"github.com/modelpack/model-csi-driver/pkg/status"
+	"github.com/pkg/errors"
+	"github.com/stretchr/testify/require"
+)
+
+// fakePuller is a test double for Puller that records Pull calls and lets
+// tests simulate delays / failures without hitting a real registry.
+type fakePuller struct {
+	pullFunc func(ctx context.Context, reference, targetDir string) error
+}
+
+// Pull satisfies the Puller interface by delegating to the injected pullFunc.
+func (f *fakePuller) Pull(ctx context.Context, reference, targetDir string, excludeModelWeights bool, excludeFilePatterns []string) error {
+	return f.pullFunc(ctx, reference, targetDir)
+}
+
+// newCacheWorker builds a Worker rooted at a temp dir and stubs the package
+// level ResolveCacheDigest so tests run fully offline. The returned cleanup
+// function MUST be deferred by the caller; it restores ResolveCacheDigest to
+// its original value so tests don't leak state to each other.
+func newCacheWorker(t *testing.T, digest string) (worker *Worker, cfg *config.Config, cleanup func()) {
+	t.Helper()
+
+	cfg = config.NewWithRaw(&config.RawConfig{
+		ServiceName: "test.csi.example.com",
+		RootDir:     t.TempDir(),
+	})
+	sm, err := status.NewStatusManager()
+	require.NoError(t, err)
+	worker, err = NewWorker(cfg, sm)
+	require.NoError(t, err)
+
+	origResolve := ResolveCacheDigest
+	ResolveCacheDigest = func(_ context.Context, _ string) (string, error) {
+		if digest == "" {
+			return "", errors.New("empty manifest digest")
+		}
+		return digest, nil
+	}
+
+	// Default puller: writes a sentinel file so callers can assert that the
+	// cache content was materialized. Tests can override this.
+	worker.newPuller = func(context.Context, *config.PullConfig, *status.Hook, *DiskQuotaChecker) Puller {
+		return &fakePuller{pullFunc: func(_ context.Context, _, targetDir string) error {
+			if err := os.MkdirAll(targetDir, 0755); err != nil {
+				return err
+			}
+			return os.WriteFile(filepath.Join(targetDir, "model.bin"), []byte("x"), 0644)
+		}}
+	}
+
+	cleanup = func() { ResolveCacheDigest = origResolve }
+	return worker, cfg, cleanup
+}
+
+// TestEnsureCachedModel covers the main branches of EnsureCachedModel: digest
+// resolver failure, first pull populating the cache, and a second caller
+// reusing the existing entry without re-pulling.
+func TestEnsureCachedModel(t *testing.T) {
+	t.Run("resolver error propagates", func(t *testing.T) {
+		worker, _, cleanup := newCacheWorker(t, "")
+		defer cleanup()
+		worker.newPuller = func(context.Context, *config.PullConfig, *status.Hook, *DiskQuotaChecker) Puller {
+			return &fakePuller{pullFunc: func(context.Context, string, string) error {
+				t.Fatal("pull must not run when resolver fails")
+				return nil
+			}}
+		}
+
+		_, err := worker.EnsureCachedModel(context.Background(), "r/m:v1", "pvc-a")
+		require.Error(t, err)
+		require.Contains(t, err.Error(), "empty manifest digest")
+	})
+
+	t.Run("first pull and reuse", func(t *testing.T) {
+		digest := "sha256:reuse"
+		worker, cfg, cleanup := newCacheWorker(t, digest)
+		defer cleanup()
+
+		var pulls atomic.Int32
+		worker.newPuller = func(context.Context, *config.PullConfig, *status.Hook, *DiskQuotaChecker) Puller {
+			return &fakePuller{pullFunc: func(_ context.Context, _, targetDir string) error {
+				pulls.Add(1)
+				require.NoError(t, os.MkdirAll(targetDir, 0755))
+				return os.WriteFile(filepath.Join(targetDir, "model.bin"), []byte("x"), 0644)
+			}}
+		}
+
+		// First caller triggers the pull and registers pvc-a.
+		got, err := worker.EnsureCachedModel(context.Background(), "r/m:v1", "pvc-a")
+		require.NoError(t, err)
+		require.Equal(t, digest, got)
+
+		// Second caller must reuse the cache without re-pulling.
+		_, err = worker.EnsureCachedModel(context.Background(), "r/m:v1", "pvc-b")
+		require.NoError(t, err)
+		require.Equal(t, int32(1), pulls.Load())
+
+		// Cache content + ready marker + both refs exist.
+		_, err = os.Stat(filepath.Join(cfg.Get().GetCacheContentDir(digest), "model.bin"))
+		require.NoError(t, err)
+		refsDir := cfg.Get().GetCacheRefsDir(digest)
+		_, err = os.Stat(filepath.Join(refsDir, ".ready"))
+		require.NoError(t, err)
+		_, err = os.Stat(filepath.Join(refsDir, "pvc-a"))
+		require.NoError(t, err)
+		_, err = os.Stat(filepath.Join(refsDir, "pvc-b"))
+		require.NoError(t, err)
+	})
+}
+
+// TestEnsureCachedModel_ConcurrentDedup verifies that N concurrent callers
+// for the same reference collapse into a single physical pull while each
+// still registering its own ref file.
+func TestEnsureCachedModel_ConcurrentDedup(t *testing.T) {
+	digest := "sha256:concurrent"
+	worker, cfg, cleanup := newCacheWorker(t, digest)
+	defer cleanup()
+
+	var pulls atomic.Int32
+	worker.newPuller = func(context.Context, *config.PullConfig, *status.Hook, *DiskQuotaChecker) Puller {
+		return &fakePuller{pullFunc: func(_ context.Context, _, targetDir string) error {
+			time.Sleep(100 * time.Millisecond) // encourage overlap
+			pulls.Add(1)
+			// This callback executes on whichever spawned goroutine wins
+			// the inflight race, so never require/t.Fatal here — FailNow
+			// is only safe on the test goroutine. Return the error and let
+			// the main goroutine assert on it.
+			if err := os.MkdirAll(targetDir, 0755); err != nil {
+				return err
+			}
+			return os.WriteFile(filepath.Join(targetDir, "model.bin"), []byte("x"), 0644)
+		}}
+	}
+
+	const N = 10
+	var wg sync.WaitGroup
+	errs := make([]error, N)
+	start := make(chan struct{})
+	for i := 0; i < N; i++ {
+		wg.Add(1)
+		go func(idx int) {
+			defer wg.Done()
+			<-start
+			refName := "pvc-" + string(rune('a'+idx))
+			_, errs[idx] = worker.EnsureCachedModel(context.Background(), "r/m:v1", refName)
+		}(i)
+	}
+	close(start)
+	wg.Wait()
+
+	for _, err := range errs {
+		require.NoError(t, err)
+	}
+	require.Equal(t, int32(1), pulls.Load(), "10 callers must collapse into 1 pull")
+
+	refsDir := cfg.Get().GetCacheRefsDir(digest)
+	entries, err := os.ReadDir(refsDir)
+	require.NoError(t, err)
+	names := map[string]bool{}
+	for _, entry := range entries {
+		names[entry.Name()] = true
+	}
+	require.True(t, names[".ready"])
+	for i := 0; i < N; i++ {
+		require.True(t, names["pvc-"+string(rune('a'+i))])
+	}
+}
+
+// TestEnsureCachedModel_PullFailureCleansUp verifies that a failed pull
+// cleans up the half-populated content dir and does NOT write the ready
+// marker, so a subsequent retry can succeed.
+func TestEnsureCachedModel_PullFailureCleansUp(t *testing.T) {
+	digest := "sha256:boom"
+	worker, cfg, cleanup := newCacheWorker(t, digest)
+	defer cleanup()
+
+	// First attempt: simulate a half-populated dir plus a failure.
+	worker.newPuller = func(context.Context, *config.PullConfig, *status.Hook, *DiskQuotaChecker) Puller {
+		return &fakePuller{pullFunc: func(_ context.Context, _, targetDir string) error {
+			require.NoError(t, os.MkdirAll(targetDir, 0755))
+			_ = os.WriteFile(filepath.Join(targetDir, "half.bin"), []byte("x"), 0644)
+			return errors.New("pull exploded")
+		}}
+	}
+	_, err := worker.EnsureCachedModel(context.Background(), "r/m:v1", "pvc-a")
+	require.Error(t, err)
+
+	_, statErr := os.Stat(cfg.Get().GetCacheContentDir(digest))
+	require.True(t, os.IsNotExist(statErr), "content dir must be cleaned up after failure")
+	_, statErr = os.Stat(filepath.Join(cfg.Get().GetCacheRefsDir(digest), ".ready"))
+	require.True(t, os.IsNotExist(statErr), "ready marker must not exist after failure")
+
+	// Retry with a working puller must succeed.
+	worker.newPuller = func(context.Context, *config.PullConfig, *status.Hook, *DiskQuotaChecker) Puller {
+		return &fakePuller{pullFunc: func(_ context.Context, _, targetDir string) error {
+			require.NoError(t, os.MkdirAll(targetDir, 0755))
+			return os.WriteFile(filepath.Join(targetDir, "model.bin"), []byte("x"), 0644)
+		}}
+	}
+	got, err := worker.EnsureCachedModel(context.Background(), "r/m:v1", "pvc-a")
+	require.NoError(t, err)
+	require.Equal(t, digest, got)
+}
+
+// TestReleaseCachedModel exercises the full release state machine: empty
+// digest / unknown digest no-op, keep-if-other-refs, GC-when-last-ref, and
+// idempotent double-release.
+func TestReleaseCachedModel(t *testing.T) { + digest := "sha256:rel" + worker, cfg, cleanup := newCacheWorker(t, digest) + defer cleanup() + + // No-op paths first: empty / unknown digest must not error and must not + // touch the filesystem. + require.NoError(t, worker.ReleaseCachedModel(context.Background(), "", "pvc-a")) + require.NoError(t, worker.ReleaseCachedModel(context.Background(), "sha256:ghost", "pvc-a")) + + // Populate the cache with two refs. + _, err := worker.EnsureCachedModel(context.Background(), "r/m:v1", "pvc-a") + require.NoError(t, err) + _, err = worker.EnsureCachedModel(context.Background(), "r/m:v1", "pvc-b") + require.NoError(t, err) + + // Release pvc-a: cache must stay alive because pvc-b is still there. + require.NoError(t, worker.ReleaseCachedModel(context.Background(), digest, "pvc-a")) + _, statErr := os.Stat(cfg.Get().GetCacheContentDir(digest)) + require.NoError(t, statErr, "content must persist while pvc-b holds a ref") + _, statErr = os.Stat(filepath.Join(cfg.Get().GetCacheRefsDir(digest), "pvc-a")) + require.True(t, os.IsNotExist(statErr)) + _, statErr = os.Stat(filepath.Join(cfg.Get().GetCacheRefsDir(digest), "pvc-b")) + require.NoError(t, statErr) + + // Release pvc-b: last ref → GC. + require.NoError(t, worker.ReleaseCachedModel(context.Background(), digest, "pvc-b")) + _, statErr = os.Stat(cfg.Get().GetCacheContentDir(digest)) + require.True(t, os.IsNotExist(statErr), "content must be GC'd when the last ref is gone") + _, statErr = os.Stat(cfg.Get().GetCacheRefsDir(digest)) + require.True(t, os.IsNotExist(statErr), "refs dir must be GC'd when the last ref is gone") + + // Releasing again after GC is a no-op, not an error. + require.NoError(t, worker.ReleaseCachedModel(context.Background(), digest, "pvc-b")) +} + +// TestFindCacheDigestByRef exercises the reverse lookup: missing refs root, +// no match, a valid match, and verifies the .ready sentinel never shadows a +// real ref entry. 
+func TestFindCacheDigestByRef(t *testing.T) { + cfg := config.NewWithRaw(&config.RawConfig{ServiceName: "test", RootDir: t.TempDir()}) + sm, err := status.NewStatusManager() + require.NoError(t, err) + worker, err := NewWorker(cfg, sm) + require.NoError(t, err) + + // Missing refs root → no result, no error. + got, err := worker.FindCacheDigestByRef("pvc-a") + require.NoError(t, err) + require.Equal(t, "", got) + + // Digest1 only has the ready marker; digest2 has our ref. + digest1 := "sha256:only-ready" + digest2 := "sha256:has-ref" + require.NoError(t, os.MkdirAll(cfg.Get().GetCacheRefsDir(digest1), 0755)) + require.NoError(t, os.WriteFile(filepath.Join(cfg.Get().GetCacheRefsDir(digest1), ".ready"), nil, 0644)) + require.NoError(t, os.MkdirAll(cfg.Get().GetCacheRefsDir(digest2), 0755)) + require.NoError(t, os.WriteFile(filepath.Join(cfg.Get().GetCacheRefsDir(digest2), ".ready"), nil, 0644)) + require.NoError(t, os.WriteFile(filepath.Join(cfg.Get().GetCacheRefsDir(digest2), "pvc-target"), nil, 0644)) + + // Unknown ref → no result. + got, err = worker.FindCacheDigestByRef("pvc-missing") + require.NoError(t, err) + require.Equal(t, "", got) + + // Existing ref → digest returned. + got, err = worker.FindCacheDigestByRef("pvc-target") + require.NoError(t, err) + require.Equal(t, digest2, got) + + // Stray non-directory entries under refs root / algo dir must be skipped + // gracefully instead of aborting the scan. 
+ refsRoot := cfg.Get().GetCacheRefsRootDir() + require.NoError(t, os.WriteFile(filepath.Join(refsRoot, "stray-file"), nil, 0644)) + algoDir := filepath.Join(refsRoot, "sha256") + require.NoError(t, os.WriteFile(filepath.Join(algoDir, "stray-under-algo"), nil, 0644)) + + got, err = worker.FindCacheDigestByRef("pvc-target") + require.NoError(t, err) + require.Equal(t, digest2, got) + + // A miss-lookup after the strays are in place forces the scan to walk + // past the non-directory entries under both refsRoot and algoDir, + // exercising the "skip non-directory" branches end to end. + got, err = worker.FindCacheDigestByRef("pvc-absent") + require.NoError(t, err) + require.Equal(t, "", got) +} + +// TestFindCacheDigestByRef_ReadRootError covers the error branch where the +// refs root is not a directory (e.g. someone created a file at that path), +// which makes os.ReadDir return a non-IsNotExist error. +func TestFindCacheDigestByRef_ReadRootError(t *testing.T) { + cfg := config.NewWithRaw(&config.RawConfig{ServiceName: "test", RootDir: t.TempDir()}) + sm, err := status.NewStatusManager() + require.NoError(t, err) + worker, err := NewWorker(cfg, sm) + require.NoError(t, err) + + // Plant a regular file where refs root is expected so that ReadDir fails + // with a "not a directory" error instead of IsNotExist. + refsRoot := cfg.Get().GetCacheRefsRootDir() + require.NoError(t, os.MkdirAll(filepath.Dir(refsRoot), 0755)) + require.NoError(t, os.WriteFile(refsRoot, []byte("not a dir"), 0644)) + + got, err := worker.FindCacheDigestByRef("pvc-any") + require.Error(t, err) + require.Contains(t, err.Error(), "read cache refs root") + require.Equal(t, "", got) +} + +// TestReleaseCachedModel_RemoveRefError covers the branch where os.Remove on +// the ref file returns a non-IsNotExist error. We trigger this by making the +// "ref file" actually be a non-empty directory: os.Remove on a non-empty +// directory fails with ENOTEMPTY (not ENOENT), exercising the error path. 
+func TestReleaseCachedModel_RemoveRefError(t *testing.T) { + digest := "sha256:bad-remove" + worker, cfg, cleanup := newCacheWorker(t, digest) + defer cleanup() + + refsDir := cfg.Get().GetCacheRefsDir(digest) + require.NoError(t, os.MkdirAll(refsDir, 0755)) + + // Plant a non-empty directory at the ref file path so os.Remove fails + // with ENOTEMPTY instead of being tolerated as ENOENT. + refAsDir := filepath.Join(refsDir, "pvc-a") + require.NoError(t, os.MkdirAll(refAsDir, 0755)) + require.NoError(t, os.WriteFile(filepath.Join(refAsDir, "blocker"), nil, 0644)) + + err := worker.ReleaseCachedModel(context.Background(), digest, "pvc-a") + require.Error(t, err) + require.Contains(t, err.Error(), "remove cache ref file") +} + +// TestEnsureCachedModel_WriteRefFileError covers the branch where writing +// this caller's ref file fails. We trigger this by pre-creating a directory +// at the ref file path, so os.WriteFile fails with EISDIR. +func TestEnsureCachedModel_WriteRefFileError(t *testing.T) { + digest := "sha256:bad-ref-write" + worker, cfg, cleanup := newCacheWorker(t, digest) + defer cleanup() + + // Pre-populate the cache so the pull side returns immediately and we go + // straight to the ref-registration step. + contentDir := cfg.Get().GetCacheContentDir(digest) + refsDir := cfg.Get().GetCacheRefsDir(digest) + require.NoError(t, os.MkdirAll(contentDir, 0755)) + require.NoError(t, os.MkdirAll(refsDir, 0755)) + require.NoError(t, os.WriteFile(filepath.Join(refsDir, ".ready"), nil, 0644)) + + // Plant a directory where the ref file is expected so WriteFile fails. 
+ refName := "pvc-bad-ref" + require.NoError(t, os.MkdirAll(filepath.Join(refsDir, refName), 0755)) + + _, err := worker.EnsureCachedModel(context.Background(), "r/m:v1", refName) + require.Error(t, err) + require.Contains(t, err.Error(), "write cache ref file") +} + +// TestEnsureCachedModel_CreateRefsDirError covers the branch where the refs +// dir cannot be created because a regular file already occupies that path. +func TestEnsureCachedModel_CreateRefsDirError(t *testing.T) { + digest := "sha256:bad-mkdir" + worker, cfg, cleanup := newCacheWorker(t, digest) + defer cleanup() + + // Plant a regular file where the refs dir is expected so MkdirAll fails. + refsDir := cfg.Get().GetCacheRefsDir(digest) + require.NoError(t, os.MkdirAll(filepath.Dir(refsDir), 0755)) + require.NoError(t, os.WriteFile(refsDir, []byte("not a dir"), 0644)) + + worker.newPuller = func(context.Context, *config.PullConfig, *status.Hook, *DiskQuotaChecker) Puller { + return &fakePuller{pullFunc: func(context.Context, string, string) error { + t.Fatal("pull must not run when refs dir cannot be created") + return nil + }} + } + + _, err := worker.EnsureCachedModel(context.Background(), "r/m:v1", "pvc-a") + require.Error(t, err) + require.Contains(t, err.Error(), "create cache refs dir") +} diff --git a/pkg/status/status.go b/pkg/status/status.go index d8e10fd..870c6aa 100644 --- a/pkg/status/status.go +++ b/pkg/status/status.go @@ -57,12 +57,13 @@ func (p *Progress) String() (string, error) { } type Status struct { - VolumeName string `json:"volume_name,omitempty"` - MountID string `json:"mount_id,omitempty"` - Reference string `json:"reference,omitempty"` - State State `json:"state,omitempty"` - Inline bool `json:"inline,omitempty"` - Progress Progress `json:"progress,omitempty"` + VolumeName string `json:"volume_name,omitempty"` + MountID string `json:"mount_id,omitempty"` + Reference string `json:"reference,omitempty"` + State State `json:"state,omitempty"` + Inline bool 
`json:"inline,omitempty"` + CacheDigest string `json:"cache_digest,omitempty"` + Progress Progress `json:"progress,omitempty"` } func NewStatusManager() (*StatusManager, error) {