Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,12 @@ func main() {
modelManager,
log.WithFields(logrus.Fields{"component": "llama.cpp"}),
llamaServerPath,
func() string { wd, _ := os.Getwd(); return wd }(),
func() string {
wd, _ := os.Getwd()
d := filepath.Join(wd, "updated-inference")
_ = os.MkdirAll(d, 0o755)
return d
}(),
)
if err != nil {
log.Fatalf("unable to initialize %s backend: %v", llamacpp.Name, err)
Expand Down
24 changes: 24 additions & 0 deletions pkg/diskusage/diskusage.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package diskusage

import (
"io/fs"
"path/filepath"
)

func Size(path string) (float64, error) {
var size int64
err := filepath.WalkDir(path, func(_ string, d fs.DirEntry, err error) error {
if err != nil {
return err
}
if d.Type().IsRegular() {
info, err := d.Info()
if err != nil {
return err
}
size += info.Size()
}
return nil
})
return float64(size), err
}
2 changes: 2 additions & 0 deletions pkg/inference/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,4 +69,6 @@ type Backend interface {
Run(ctx context.Context, socket, model string, mode BackendMode) error
// Status returns a description of the backend's state.
Status() string
// GetDiskUsage returns the disk usage of the backend.
GetDiskUsage() (float64, error)
}
9 changes: 9 additions & 0 deletions pkg/inference/backends/llamacpp/llamacpp.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"runtime"
"strconv"

"github.com/docker/model-runner/pkg/diskusage"
"github.com/docker/model-runner/pkg/inference"
"github.com/docker/model-runner/pkg/inference/models"
"github.com/docker/model-runner/pkg/logging"
Expand Down Expand Up @@ -199,3 +200,11 @@ func (l *llamaCpp) Run(ctx context.Context, socket, model string, mode inference
func (l *llamaCpp) Status() string {
return l.status
}

func (l *llamaCpp) GetDiskUsage() (float64, error) {
size, err := diskusage.Size(l.updatedServerStoragePath)
if err != nil {
return 0, fmt.Errorf("error while getting store size: %v", err)
}
return size, nil
}
4 changes: 4 additions & 0 deletions pkg/inference/backends/mlx/mlx.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,7 @@ func (m *mlx) Run(ctx context.Context, socket, model string, mode inference.Back
func (m *mlx) Status() string {
return "not running"
}

func (m *mlx) GetDiskUsage() (float64, error) {
return 0, nil
}
4 changes: 4 additions & 0 deletions pkg/inference/backends/vllm/vllm.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,7 @@ func (v *vLLM) Run(ctx context.Context, socket, model string, mode inference.Bac
func (v *vLLM) Status() string {
return "not running"
}

func (v *vLLM) GetDiskUsage() (float64, error) {
return 0, nil
}
16 changes: 16 additions & 0 deletions pkg/inference/models/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/docker/model-distribution/distribution"
"github.com/docker/model-distribution/registry"
"github.com/docker/model-distribution/types"
"github.com/docker/model-runner/pkg/diskusage"
"github.com/docker/model-runner/pkg/inference"
"github.com/docker/model-runner/pkg/logging"
"github.com/sirupsen/logrus"
Expand Down Expand Up @@ -399,6 +400,21 @@ func (m *Manager) handlePushModel(w http.ResponseWriter, r *http.Request, model
}
}

// GetDiskUsage returns the disk usage of the model store.
func (m *Manager) GetDiskUsage() (float64, error, int) {
if m.distributionClient == nil {
return 0, errors.New("model distribution service unavailable"), http.StatusServiceUnavailable
}

storePath := m.distributionClient.GetStorePath()
size, err := diskusage.Size(storePath)
if err != nil {
return 0, fmt.Errorf("error while getting store size: %v", err), http.StatusInternalServerError
}

return size, nil, http.StatusOK
}

// ServeHTTP implement net/http.Handler.ServeHTTP.
func (m *Manager) ServeHTTP(w http.ResponseWriter, r *http.Request) {
m.router.ServeHTTP(w, r)
Expand Down
19 changes: 19 additions & 0 deletions pkg/inference/scheduling/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package scheduling

import (
"strings"
"time"

"github.com/docker/model-runner/pkg/inference"
)
Expand Down Expand Up @@ -42,3 +43,21 @@ type OpenAIInferenceRequest struct {
// Model is the requested model name.
Model string `json:"model"`
}

// BackendStatus represents information about a running backend
type BackendStatus struct {
// BackendName is the name of the backend
BackendName string `json:"backend_name"`
// ModelName is the name of the model loaded in the backend
ModelName string `json:"model_name"`
// Mode is the mode the backend is operating in
Mode string `json:"mode"`
// LastUsed represents when this (backend, model, mode) tuple was last used
LastUsed time.Time `json:"last_used,omitempty"`
}

// DiskUsage represents the disk usage of the models and default backend.
type DiskUsage struct {
ModelsDiskUsage float64 `json:"models_disk_usage"`
DefaultBackendDiskUsage float64 `json:"default_backend_disk_usage"`
}
65 changes: 65 additions & 0 deletions pkg/inference/scheduling/scheduler.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"fmt"
"io"
"net/http"
"time"

"github.com/docker/model-distribution/distribution"
"github.com/docker/model-runner/pkg/inference"
Expand Down Expand Up @@ -81,6 +82,8 @@ func (s *Scheduler) routeHandlers() map[string]http.HandlerFunc {
m[route] = s.handleOpenAIInference
}
m["GET "+inference.InferencePrefix+"/status"] = s.GetBackendStatus
m["GET "+inference.InferencePrefix+"/ps"] = s.GetRunningBackends
m["GET "+inference.InferencePrefix+"/df"] = s.GetDiskUsage
return m
}

Expand Down Expand Up @@ -224,6 +227,68 @@ func (s *Scheduler) ResetInstaller(httpClient *http.Client) {
s.installer = newInstaller(s.log, s.backends, httpClient)
}

// GetRunningBackends returns information about all running backends
func (s *Scheduler) GetRunningBackends(w http.ResponseWriter, r *http.Request) {
runningBackends := s.getLoaderStatus()

w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(runningBackends); err != nil {
http.Error(w, fmt.Sprintf("Failed to encode response: %v", err), http.StatusInternalServerError)
return
}
}

// getLoaderStatus returns information about all running backends managed by the loader
func (s *Scheduler) getLoaderStatus() []BackendStatus {
if !s.loader.lock(context.Background()) {
return []BackendStatus{}
}
defer s.loader.unlock()

result := make([]BackendStatus, 0, len(s.loader.runners))

for key, slot := range s.loader.runners {
if s.loader.slots[slot] != nil {
status := BackendStatus{
BackendName: key.backend,
ModelName: key.model,
Mode: key.mode.String(),
LastUsed: time.Time{},
}

if s.loader.references[slot] == 0 {
status.LastUsed = s.loader.timestamps[slot]
}

result = append(result, status)
}
}

return result
}

func (s *Scheduler) GetDiskUsage(w http.ResponseWriter, _ *http.Request) {
modelsDiskUsage, err, httpCode := s.modelManager.GetDiskUsage()
if err != nil {
http.Error(w, fmt.Sprintf("Failed to get models disk usage: %v", err), httpCode)
return
}

// TODO: Get disk usage for each backend once the backends are implemented.
defaultBackendDiskUsage, err := s.defaultBackend.GetDiskUsage()
if err != nil {
http.Error(w, fmt.Sprintf("Failed to get disk usage for %s: %v", s.defaultBackend.Name(), err), http.StatusInternalServerError)
return
}

diskUsage := DiskUsage{modelsDiskUsage, defaultBackendDiskUsage}
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(diskUsage); err != nil {
http.Error(w, fmt.Sprintf("Failed to encode response: %v", err), http.StatusInternalServerError)
return
}
}

// ServeHTTP implements net/http.Handler.ServeHTTP.
func (s *Scheduler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
s.router.ServeHTTP(w, r)
Expand Down