diff --git a/cmd/cli/commands/utils_test.go b/cmd/cli/commands/utils_test.go index 583c97d8d..9d3b22744 100644 --- a/cmd/cli/commands/utils_test.go +++ b/cmd/cli/commands/utils_test.go @@ -42,12 +42,12 @@ func TestNormalizeModelName(t *testing.T) { { name: "huggingface model", input: "hf.co/bartowski/model", - expected: "hf.co/bartowski/model:latest", + expected: "huggingface.co/bartowski/model:latest", }, { name: "huggingface model with tag", input: "hf.co/bartowski/model:Q4_K_S", - expected: "hf.co/bartowski/model:q4_k_s", + expected: "huggingface.co/bartowski/model:q4_k_s", }, { name: "registry with model", diff --git a/cmd/cli/desktop/desktop.go b/cmd/cli/desktop/desktop.go index 58d0bae9d..4b197f9ad 100644 --- a/cmd/cli/desktop/desktop.go +++ b/cmd/cli/desktop/desktop.go @@ -9,6 +9,7 @@ import ( "io" "net/http" "net/url" + "os" "strconv" "strings" "time" @@ -107,8 +108,18 @@ func (c *Client) Status() Status { func (c *Client) Pull(model string, ignoreRuntimeMemoryCheck bool, printer standalone.StatusPrinter) (string, bool, error) { model = normalizeHuggingFaceModelName(model) + // Check if this is a Hugging Face model and if HF_TOKEN is set + var hfToken string + if strings.HasPrefix(strings.ToLower(model), "hf.co/") { + hfToken = os.Getenv("HF_TOKEN") + } + return c.withRetries("download", 3, printer, func(attempt int) (string, bool, error, bool) { - jsonData, err := json.Marshal(dmrm.ModelCreateRequest{From: model, IgnoreRuntimeMemoryCheck: ignoreRuntimeMemoryCheck}) + jsonData, err := json.Marshal(dmrm.ModelCreateRequest{ + From: model, + IgnoreRuntimeMemoryCheck: ignoreRuntimeMemoryCheck, + BearerToken: hfToken, + }) if err != nil { // Marshaling errors are not retryable return "", false, fmt.Errorf("error marshaling request: %w", err), false diff --git a/pkg/distribution/distribution/client.go b/pkg/distribution/distribution/client.go index 69508436e..872f45ca4 100644 --- a/pkg/distribution/distribution/client.go +++ b/pkg/distribution/distribution/client.go @@ -17,6 +17,7 @@ import ( "github.com/docker/model-runner/pkg/distribution/registry" "github.com/docker/model-runner/pkg/distribution/tarball" "github.com/docker/model-runner/pkg/distribution/types" + "github.com/docker/model-runner/pkg/go-containerregistry/pkg/authn" "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1/remote" "github.com/docker/model-runner/pkg/inference/platform" ) @@ -138,11 +139,19 @@ func NewClient(opts ...Option) (*Client, error) { } // PullModel pulls a model from a registry and returns the local file path -func (c *Client) PullModel(ctx context.Context, reference string, progressWriter io.Writer) error { +func (c *Client) PullModel(ctx context.Context, reference string, progressWriter io.Writer, bearerToken ...string) error { c.log.Infoln("Starting model pull:", utils.SanitizeForLog(reference)) + // Use the client's registry, or create a temporary one if bearer token is provided + registryClient := c.registry + if len(bearerToken) > 0 && bearerToken[0] != "" { + // Create a temporary registry client with bearer token authentication + auth := &authn.Bearer{Token: bearerToken[0]} + registryClient = registry.FromClient(c.registry, registry.WithAuth(auth)) + } + // First, fetch the remote model to get the manifest - remoteModel, err := c.registry.Model(ctx, reference) + remoteModel, err := registryClient.Model(ctx, reference) if err != nil { return fmt.Errorf("reading model from registry: %w", err) } @@ -214,7 +223,7 @@ func (c *Client) PullModel(ctx context.Context, reference string, progressWriter } digestReference := repository + "@" + remoteDigest.String() c.log.Infof("Re-fetching model with digest reference: %s", utils.SanitizeForLog(digestReference)) - remoteModel, err = c.registry.Model(ctx, digestReference) + remoteModel, err = registryClient.Model(ctx, digestReference) if err != nil { return fmt.Errorf("reading model from registry with resume context: %w", err) } diff --git a/pkg/distribution/registry/client.go b/pkg/distribution/registry/client.go index 58210282d..aab9e5aaf 100644 --- a/pkg/distribution/registry/client.go +++ b/pkg/distribution/registry/client.go @@ -85,6 +85,15 @@ func WithAuthConfig(username, password string) ClientOption { } } +// WithAuth sets a custom authenticator. +func WithAuth(auth authn.Authenticator) ClientOption { + return func(c *Client) { + if auth != nil { + c.auth = auth + } + } +} + func NewClient(opts ...ClientOption) *Client { client := &Client{ transport: remote.DefaultTransport, @@ -97,6 +106,21 @@ func NewClient(opts ...ClientOption) *Client { return client } +// FromClient creates a new Client by copying an existing client's configuration +// and applying optional modifications via ClientOption functions. +func FromClient(base *Client, opts ...ClientOption) *Client { + client := &Client{ + transport: base.transport, + userAgent: base.userAgent, + keychain: base.keychain, + auth: base.auth, + } + for _, opt := range opts { + opt(client) + } + return client +} + func (c *Client) Model(ctx context.Context, reference string) (types.ModelArtifact, error) { // Parse the reference ref, err := name.ParseReference(reference, GetDefaultRegistryOptions()...) diff --git a/pkg/inference/models/api.go b/pkg/inference/models/api.go index 50ccc5e7e..86efbea8d 100644 --- a/pkg/inference/models/api.go +++ b/pkg/inference/models/api.go @@ -17,6 +17,8 @@ type ModelCreateRequest struct { // IgnoreRuntimeMemoryCheck indicates whether the server should check if it has sufficient // memory to run the given model (assuming default configuration). IgnoreRuntimeMemoryCheck bool `json:"ignore-runtime-memory-check,omitempty"` + // BearerToken is an optional bearer token for authentication. + BearerToken string `json:"bearer-token,omitempty"` } // ModelPackageRequest represents a model package request, which creates a new model diff --git a/pkg/inference/models/manager.go b/pkg/inference/models/manager.go index cf231b609..f547863aa 100644 --- a/pkg/inference/models/manager.go +++ b/pkg/inference/models/manager.go @@ -17,11 +17,11 @@ import ( "github.com/docker/model-runner/pkg/distribution/distribution" "github.com/docker/model-runner/pkg/distribution/registry" "github.com/docker/model-runner/pkg/distribution/types" + v1 "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1" "github.com/docker/model-runner/pkg/inference" "github.com/docker/model-runner/pkg/inference/memory" "github.com/docker/model-runner/pkg/logging" "github.com/docker/model-runner/pkg/middleware" - v1 "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1" "github.com/sirupsen/logrus" ) @@ -141,7 +141,8 @@ func NormalizeModelName(model string) string { // Normalize HuggingFace model names (lowercase) if strings.HasPrefix(model, "hf.co/") { - model = strings.ToLower(model) + // Replace hf.co with huggingface.co to avoid losing the Authorization header on redirect. + model = "huggingface.co" + strings.ToLower(strings.TrimPrefix(model, "hf.co")) } // Check if model contains a registry (domain with dot before first slash) @@ -221,7 +222,7 @@ func (m *Manager) handleCreateModel(w http.ResponseWriter, r *http.Request) { return } } - if err := m.PullModel(request.From, r, w); err != nil { + if err := m.PullModel(request.From, request.BearerToken, r, w); err != nil { if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { m.log.Infof("Request canceled/timed out while pulling model %q", request.From) return @@ -881,15 +882,15 @@ func (m *Manager) GetModels() ([]*Model, error) { return nil, fmt.Errorf("error while listing models: %w", err) } - apiModels := make([]*Model, 0, len(models)) - for _, model := range models { - apiModel, err := ToModel(model) - if err != nil { - m.log.Warnf("error while converting model, skipping: %v", err) - continue - } - apiModels = append(apiModels, apiModel) - } + apiModels := make([]*Model, 0, len(models)) + for _, model := range models { + apiModel, err := ToModel(model) + if err != nil { + m.log.Warnf("error while converting model, skipping: %v", err) + continue + } + apiModels = append(apiModels, apiModel) + } return apiModels, nil } @@ -941,7 +942,7 @@ func (m *Manager) GetBundle(ref string) (types.ModelBundle, error) { // PullModel pulls a model to local storage. Any error it returns is suitable // for writing back to the client. -func (m *Manager) PullModel(model string, r *http.Request, w http.ResponseWriter) error { +func (m *Manager) PullModel(model string, bearerToken string, r *http.Request, w http.ResponseWriter) error { // Restrict model pull concurrency. select { case <-m.pullTokens: @@ -983,7 +984,16 @@ func (m *Manager) PullModel(model string, r *http.Request, w http.ResponseWriter // Pull the model using the Docker model distribution client m.log.Infoln("Pulling model:", model) - err := m.distributionClient.PullModel(r.Context(), model, progressWriter) + + // Use bearer token if provided + var err error + if bearerToken != "" { + m.log.Infoln("Using provided bearer token for authentication") + err = m.distributionClient.PullModel(r.Context(), model, progressWriter, bearerToken) + } else { + err = m.distributionClient.PullModel(r.Context(), model, progressWriter) + } + if err != nil { return fmt.Errorf("error while pulling model: %w", err) } diff --git a/pkg/inference/models/manager_test.go b/pkg/inference/models/manager_test.go index 5a517fe15..5e36b445d 100644 --- a/pkg/inference/models/manager_test.go +++ b/pkg/inference/models/manager_test.go @@ -135,7 +135,7 @@ func TestPullModel(t *testing.T) { } w := httptest.NewRecorder() - err = m.PullModel(tag, r, w) + err = m.PullModel(tag, "", r, w) if err != nil { t.Fatalf("Failed to pull model: %v", err) } @@ -246,7 +246,7 @@ func TestHandleGetModel(t *testing.T) { if !tt.remote && !strings.Contains(tt.modelName, "nonexistent") { r := httptest.NewRequest("POST", "/models/create", strings.NewReader(`{"from": "`+tt.modelName+`"}`)) w := httptest.NewRecorder() - err = m.PullModel(tt.modelName, r, w) + err = m.PullModel(tt.modelName, "", r, w) if err != nil { t.Fatalf("Failed to pull model: %v", err) } diff --git a/pkg/ollama/handler.go b/pkg/ollama/handler.go index 697505a00..dcabbd952 100644 --- a/pkg/ollama/handler.go +++ b/pkg/ollama/handler.go @@ -579,7 +579,7 @@ func (h *Handler) handlePull(w http.ResponseWriter, r *http.Request) { r.Header.Set("Accept", "application/json") // Call the model manager's PullModel method - if err := h.modelManager.PullModel(modelName, r, w); err != nil { + if err := h.modelManager.PullModel(modelName, "", r, w); err != nil { h.log.Errorf("Failed to pull model: %v", err) // Only write error if headers haven't been sent yet if !isHeadersSent(w) {