diff --git a/pkg/distribution/internal/store/store_test.go b/pkg/distribution/internal/store/store_test.go index 15479017d..a9286f357 100644 --- a/pkg/distribution/internal/store/store_test.go +++ b/pkg/distribution/internal/store/store_test.go @@ -796,6 +796,184 @@ func newTestModelWithMultimodalProjector(t *testing.T) types.ModelArtifact { } // TestWriteLightweight tests the WriteLightweight method +func TestResetStore(t *testing.T) { + tests := []struct { + name string + setupModels int + }{ + { + name: "reset with multiple models in store", + setupModels: 3, + }, + { + name: "reset empty store", + setupModels: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create a temporary directory for the test store + tempDir, err := os.MkdirTemp("", "reset-store-test") + if err != nil { + t.Fatalf("Failed to create temp directory: %v", err) + } + defer os.RemoveAll(tempDir) + + // Create store + storePath := filepath.Join(tempDir, "reset-model-store") + s, err := store.New(store.Options{ + RootPath: storePath, + }) + if err != nil { + t.Fatalf("Failed to create store: %v", err) + } + + // Track blob and manifest paths for verification + var blobPaths []string + var manifestPaths []string + + // Setup models based on test case + if tt.setupModels > 0 { + for i := 0; i < tt.setupModels; i++ { + // Create a unique model file for each iteration + modelContent := []byte(fmt.Sprintf("unique model content %d", i)) + modelPath := filepath.Join(tempDir, fmt.Sprintf("model-%d.gguf", i)) + if err := os.WriteFile(modelPath, modelContent, 0644); err != nil { + t.Fatalf("Failed to create model file: %v", err) + } + + mdl, err := gguf.NewModel(modelPath) + if err != nil { + t.Fatalf("Failed to create model: %v", err) + } + + tag := fmt.Sprintf("test-model-%d:latest", i) + if err := s.Write(mdl, []string{tag}, nil); err != nil { + t.Fatalf("Failed to write model %d: %v", i, err) + } + + // Collect blob paths + layers, err := mdl.Layers() + if err != nil { + t.Fatalf("Failed to get layers: %v", err) + } + for _, layer := range layers { + digest, err := layer.Digest() + if err != nil { + t.Fatalf("Failed to get layer digest: %v", err) + } + blobPath := filepath.Join(storePath, "blobs", digest.Algorithm, digest.Hex) + blobPaths = append(blobPaths, blobPath) + } + + // Collect config blob path + configName, err := mdl.ConfigName() + if err != nil { + t.Fatalf("Failed to get config name: %v", err) + } + configPath := filepath.Join(storePath, "blobs", configName.Algorithm, configName.Hex) + blobPaths = append(blobPaths, configPath) + + // Collect manifest path + digest, err := mdl.Digest() + if err != nil { + t.Fatalf("Failed to get digest: %v", err) + } + manifestPath := filepath.Join(storePath, "manifests", digest.Algorithm, digest.Hex) + manifestPaths = append(manifestPaths, manifestPath) + } + + // Verify models exist before reset + models, err := s.List() + if err != nil { + t.Fatalf("Failed to list models before reset: %v", err) + } + if len(models) != tt.setupModels { + t.Fatalf("Expected %d models before reset, got %d", tt.setupModels, len(models)) + } + + // Verify blobs exist before reset + for _, blobPath := range blobPaths { + if _, err := os.Stat(blobPath); os.IsNotExist(err) { + t.Errorf("Blob file should exist before reset: %s", blobPath) + } + } + + // Verify manifests exist before reset + for _, manifestPath := range manifestPaths { + if _, err := os.Stat(manifestPath); os.IsNotExist(err) { + t.Errorf("Manifest file should exist before reset: %s", manifestPath) + } + } + + } + + // Call Reset + if err := s.Reset(); err != nil { + t.Fatalf("Reset failed: %v", err) + } + + // Verify store is empty after reset + models, err := s.List() + if err != nil { + t.Fatalf("Failed to list models after reset: %v", err) + } + if len(models) != 0 { + t.Errorf("Expected empty store after reset, got %d models", len(models)) + } + + // Verify all blobs are deleted + for _, blobPath := range blobPaths { + if _, err := os.Stat(blobPath); !os.IsNotExist(err) { + t.Errorf("Blob file should be deleted after reset: %s", blobPath) + } + } + + // Verify all manifests are deleted + for _, manifestPath := range manifestPaths { + if _, err := os.Stat(manifestPath); !os.IsNotExist(err) { + t.Errorf("Manifest file should be deleted after reset: %s", manifestPath) + } + } + + // Verify store root directory still exists + if _, err := os.Stat(storePath); os.IsNotExist(err) { + t.Error("Store directory should still exist after reset") + } + + // Note: blobs and manifests directories are created on-demand, + // so they won't exist after reset until models are written + + // Verify store is functional after reset by writing a new model + newModel := newTestModel(t) + if err := s.Write(newModel, []string{"post-reset:latest"}, nil); err != nil { + t.Fatalf("Failed to write model after reset: %v", err) + } + + // Verify the new model can be read + readModel, err := s.Read("post-reset:latest") + if err != nil { + t.Fatalf("Failed to read model after reset: %v", err) + } + + readDigest, err := readModel.Digest() + if err != nil { + t.Fatalf("Failed to get digest: %v", err) + } + + newDigest, err := newModel.Digest() + if err != nil { + t.Fatalf("Failed to get new digest: %v", err) + } + + if readDigest.String() != newDigest.String() { + t.Error("Model written after reset doesn't match") + } + }) + } +} + func TestWriteLightweight(t *testing.T) { // Create a temporary directory for the test store tempDir, err := os.MkdirTemp("", "lightweight-write-test") diff --git a/pkg/inference/models/manager.go b/pkg/inference/models/manager.go index ed52760a3..c43b0ea09 100644 --- a/pkg/inference/models/manager.go +++ b/pkg/inference/models/manager.go @@ -177,6 +177,7 @@ func (m *Manager) routeHandlers() map[string]http.HandlerFunc { "GET " + inference.ModelsPrefix + "/{name...}": m.handleGetModel, "DELETE " + inference.ModelsPrefix + "/{name...}": m.handleDeleteModel, "POST " + inference.ModelsPrefix + "/{nameAndAction...}": m.handleModelAction, + "DELETE " + inference.ModelsPrefix + "/prune": m.handlePrune, "GET " + inference.InferencePrefix + "/{backend}/v1/models": m.handleOpenAIGetModels, "GET " + inference.InferencePrefix + "/{backend}/v1/models/{name...}": m.handleOpenAIGetModel, "GET " + inference.InferencePrefix + "/v1/models": m.handleOpenAIGetModels, @@ -611,6 +612,20 @@ func (m *Manager) handlePushModel(w http.ResponseWriter, r *http.Request, model } } +// handlePrune handles DELETE /models/prune requests. +func (m *Manager) handlePrune(w http.ResponseWriter, _ *http.Request) { + if m.distributionClient == nil { + http.Error(w, "model distribution service unavailable", http.StatusServiceUnavailable) + return + } + + if err := m.distributionClient.ResetStore(); err != nil { + m.log.Warnf("Failed to prune models: %v", err) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } +} + // GetDiskUsage returns the disk usage of the model store. func (m *Manager) GetDiskUsage() (int64, error, int) { if m.distributionClient == nil {