Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
178 changes: 178 additions & 0 deletions pkg/distribution/internal/store/store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -796,6 +796,184 @@ func newTestModelWithMultimodalProjector(t *testing.T) types.ModelArtifact {
}

// TestResetStore verifies that Reset empties the store, both when it already
// holds models and when it is empty, and that the store can be written to and
// read from again afterwards.
func TestResetStore(t *testing.T) {
	tests := []struct {
		name        string
		setupModels int // number of distinct models written before Reset is called
	}{
		{
			name:        "reset with multiple models in store",
			setupModels: 3,
		},
		{
			name:        "reset empty store",
			setupModels: 0,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// t.TempDir hands every subtest its own directory and removes it
			// when the subtest finishes, reporting a failure if removal fails.
			tempDir := t.TempDir()

			// Create store
			storePath := filepath.Join(tempDir, "reset-model-store")
			s, err := store.New(store.Options{
				RootPath: storePath,
			})
			if err != nil {
				t.Fatalf("Failed to create store: %v", err)
			}

			// Paths that are expected on disk before Reset and gone after it.
			var blobPaths []string
			var manifestPaths []string

			// Setup models based on test case
			if tt.setupModels > 0 {
				for i := 0; i < tt.setupModels; i++ {
					// Distinct content per iteration so every model has its
					// own digest and therefore its own blob and manifest.
					modelContent := []byte(fmt.Sprintf("unique model content %d", i))
					modelPath := filepath.Join(tempDir, fmt.Sprintf("model-%d.gguf", i))
					if err := os.WriteFile(modelPath, modelContent, 0644); err != nil {
						t.Fatalf("Failed to create model file: %v", err)
					}

					mdl, err := gguf.NewModel(modelPath)
					if err != nil {
						t.Fatalf("Failed to create model: %v", err)
					}

					tag := fmt.Sprintf("test-model-%d:latest", i)
					if err := s.Write(mdl, []string{tag}, nil); err != nil {
						t.Fatalf("Failed to write model %d: %v", i, err)
					}

					// Collect blob paths
					layers, err := mdl.Layers()
					if err != nil {
						t.Fatalf("Failed to get layers: %v", err)
					}
					for _, layer := range layers {
						layerDigest, err := layer.Digest()
						if err != nil {
							t.Fatalf("Failed to get layer digest: %v", err)
						}
						blobPaths = append(blobPaths,
							filepath.Join(storePath, "blobs", layerDigest.Algorithm, layerDigest.Hex))
					}

					// Collect config blob path
					configName, err := mdl.ConfigName()
					if err != nil {
						t.Fatalf("Failed to get config name: %v", err)
					}
					blobPaths = append(blobPaths,
						filepath.Join(storePath, "blobs", configName.Algorithm, configName.Hex))

					// Collect manifest path
					manifestDigest, err := mdl.Digest()
					if err != nil {
						t.Fatalf("Failed to get digest: %v", err)
					}
					manifestPaths = append(manifestPaths,
						filepath.Join(storePath, "manifests", manifestDigest.Algorithm, manifestDigest.Hex))
				}

				// Verify models exist before reset
				models, err := s.List()
				if err != nil {
					t.Fatalf("Failed to list models before reset: %v", err)
				}
				if len(models) != tt.setupModels {
					t.Fatalf("Expected %d models before reset, got %d", tt.setupModels, len(models))
				}

				// Verify blobs exist before reset. Any Stat error counts as a
				// failure: a file that cannot be stat'ed is not known to exist.
				for _, blobPath := range blobPaths {
					if _, err := os.Stat(blobPath); err != nil {
						t.Errorf("Blob file should exist before reset: %s (%v)", blobPath, err)
					}
				}

				// Verify manifests exist before reset
				for _, manifestPath := range manifestPaths {
					if _, err := os.Stat(manifestPath); err != nil {
						t.Errorf("Manifest file should exist before reset: %s (%v)", manifestPath, err)
					}
				}
			}

			// Call Reset
			if err := s.Reset(); err != nil {
				t.Fatalf("Reset failed: %v", err)
			}

			// Verify store is empty after reset
			models, err := s.List()
			if err != nil {
				t.Fatalf("Failed to list models after reset: %v", err)
			}
			if len(models) != 0 {
				t.Errorf("Expected empty store after reset, got %d models", len(models))
			}

			// Verify all blobs are deleted
			for _, blobPath := range blobPaths {
				if _, err := os.Stat(blobPath); !os.IsNotExist(err) {
					t.Errorf("Blob file should be deleted after reset: %s", blobPath)
				}
			}

			// Verify all manifests are deleted
			for _, manifestPath := range manifestPaths {
				if _, err := os.Stat(manifestPath); !os.IsNotExist(err) {
					t.Errorf("Manifest file should be deleted after reset: %s", manifestPath)
				}
			}

			// Verify store root directory still exists
			if _, err := os.Stat(storePath); err != nil {
				t.Errorf("Store directory should still exist after reset: %v", err)
			}

			// Note: blobs and manifests directories are created on-demand,
			// so they won't exist after reset until models are written

			// Verify store is functional after reset by writing a new model
			newModel := newTestModel(t)
			if err := s.Write(newModel, []string{"post-reset:latest"}, nil); err != nil {
				t.Fatalf("Failed to write model after reset: %v", err)
			}

			// Verify the new model can be read
			readModel, err := s.Read("post-reset:latest")
			if err != nil {
				t.Fatalf("Failed to read model after reset: %v", err)
			}

			readDigest, err := readModel.Digest()
			if err != nil {
				t.Fatalf("Failed to get digest: %v", err)
			}

			newDigest, err := newModel.Digest()
			if err != nil {
				t.Fatalf("Failed to get new digest: %v", err)
			}

			if readDigest.String() != newDigest.String() {
				t.Errorf("Model written after reset doesn't match: read %s, wrote %s",
					readDigest.String(), newDigest.String())
			}
		})
	}
}

// TestWriteLightweight tests the WriteLightweight method
func TestWriteLightweight(t *testing.T) {
// Create a temporary directory for the test store
tempDir, err := os.MkdirTemp("", "lightweight-write-test")
Expand Down
15 changes: 15 additions & 0 deletions pkg/inference/models/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ func (m *Manager) routeHandlers() map[string]http.HandlerFunc {
"GET " + inference.ModelsPrefix + "/{name...}": m.handleGetModel,
"DELETE " + inference.ModelsPrefix + "/{name...}": m.handleDeleteModel,
"POST " + inference.ModelsPrefix + "/{nameAndAction...}": m.handleModelAction,
"DELETE " + inference.ModelsPrefix + "/prune": m.handlePrune,
"GET " + inference.InferencePrefix + "/{backend}/v1/models": m.handleOpenAIGetModels,
"GET " + inference.InferencePrefix + "/{backend}/v1/models/{name...}": m.handleOpenAIGetModel,
"GET " + inference.InferencePrefix + "/v1/models": m.handleOpenAIGetModels,
Expand Down Expand Up @@ -611,6 +612,20 @@ func (m *Manager) handlePushModel(w http.ResponseWriter, r *http.Request, model
}
}

// handlePrune handles DELETE <inference-prefix>/models/prune requests by
// asking the distribution client to reset the model store.
//
// Responses:
//   - 503 Service Unavailable when no distribution client is configured.
//   - 500 Internal Server Error, with the error text as the body, when the
//     reset fails; the failure is also logged at warning level.
//   - Otherwise nothing is written explicitly, so net/http replies with its
//     default 200 OK and an empty body.
func (m *Manager) handlePrune(w http.ResponseWriter, _ *http.Request) {
	client := m.distributionClient
	if client == nil {
		http.Error(w, "model distribution service unavailable", http.StatusServiceUnavailable)
		return
	}

	resetErr := client.ResetStore()
	if resetErr == nil {
		return
	}

	m.log.Warnf("Failed to prune models: %v", resetErr)
	http.Error(w, resetErr.Error(), http.StatusInternalServerError)
}

// GetDiskUsage returns the disk usage of the model store.
func (m *Manager) GetDiskUsage() (int64, error, int) {
if m.distributionClient == nil {
Expand Down