Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions pkg/inference/scheduling/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package scheduling

import (
"strings"
"time"

"github.com/docker/model-runner/pkg/inference"
)
Expand Down Expand Up @@ -42,3 +43,15 @@ type OpenAIInferenceRequest struct {
// Model is the requested model name.
Model string `json:"model"`
}

// BackendStatus represents information about a running backend
type BackendStatus struct {
// BackendName is the name of the backend
BackendName string `json:"backend_name"`
// ModelName is the name of the model loaded in the backend
ModelName string `json:"model_name"`
// Mode is the mode the backend is operating in
Mode string `json:"mode"`
// LastUsed represents when this (backend, model, mode) tuple was last used
LastUsed time.Time `json:"last_used,omitempty"`
}
42 changes: 42 additions & 0 deletions pkg/inference/scheduling/scheduler.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"fmt"
"io"
"net/http"
"time"

"github.com/docker/model-distribution/distribution"
"github.com/docker/model-runner/pkg/inference"
Expand Down Expand Up @@ -81,6 +82,7 @@ func (s *Scheduler) routeHandlers() map[string]http.HandlerFunc {
m[route] = s.handleOpenAIInference
}
m["GET "+inference.InferencePrefix+"/status"] = s.GetBackendStatus
m["GET "+inference.InferencePrefix+"/ps"] = s.GetRunningBackends
return m
}

Expand Down Expand Up @@ -224,6 +226,46 @@ func (s *Scheduler) ResetInstaller(httpClient *http.Client) {
s.installer = newInstaller(s.log, s.backends, httpClient)
}

// GetRunningBackends returns information about all running backends as a
// JSON array of BackendStatus values.
func (s *Scheduler) GetRunningBackends(w http.ResponseWriter, r *http.Request) {
	runningBackends := s.getLoaderStatus()

	// Marshal before touching the ResponseWriter: once the header and body
	// have been written, http.Error cannot change the status code, so an
	// encode-after-write failure would silently produce a truncated 200.
	payload, err := json.Marshal(runningBackends)
	if err != nil {
		http.Error(w, fmt.Sprintf("Failed to encode response: %v", err), http.StatusInternalServerError)
		return
	}
	// Trailing newline matches json.Encoder's output format.
	payload = append(payload, '\n')

	w.Header().Set("Content-Type", "application/json")
	if _, err := w.Write(payload); err != nil {
		// Headers are already sent; nothing more we can do for the client.
		return
	}
}

// getLoaderStatus returns information about all running backends managed by
// the loader. It snapshots the loader's runner table under the loader lock;
// if the lock cannot be acquired, an empty (non-nil) slice is returned.
func (s *Scheduler) getLoaderStatus() []BackendStatus {
	if !s.loader.lock(context.Background()) {
		return []BackendStatus{}
	}
	defer s.loader.unlock()

	statuses := make([]BackendStatus, 0, len(s.loader.runners))
	for key, slot := range s.loader.runners {
		// Skip entries whose slot is not actually populated.
		if s.loader.slots[slot] == nil {
			continue
		}

		// Only idle slots (no active references) report a last-used time;
		// in-use slots keep the zero value.
		lastUsed := time.Time{}
		if s.loader.references[slot] == 0 {
			lastUsed = s.loader.timestamps[slot]
		}

		statuses = append(statuses, BackendStatus{
			BackendName: key.backend,
			ModelName:   key.model,
			Mode:        key.mode.String(),
			LastUsed:    lastUsed,
		})
	}

	return statuses
}

// ServeHTTP implements net/http.Handler.ServeHTTP.
func (s *Scheduler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
s.router.ServeHTTP(w, r)
Expand Down