Skip to content
26 changes: 20 additions & 6 deletions pkg/compose/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import (
"github.com/compose-spec/compose-go/v2/types"
"github.com/containerd/errdefs"
"github.com/docker/cli/cli-plugins/manager"
"github.com/sirupsen/logrus"
"github.com/docker/docker/api/types/versions"
"github.com/spf13/cobra"
"golang.org/x/sync/errgroup"

Expand Down Expand Up @@ -159,21 +159,22 @@ func (m *modelAPI) PullModel(ctx context.Context, model types.ModelConfig, quiet
}

func (m *modelAPI) ConfigureModel(ctx context.Context, config types.ModelConfig, events api.EventProcessor) error {
if len(config.RuntimeFlags) != 0 {
logrus.Warnf("Runtime flags are not supported and will be ignored for model %s", config.Model)
config.RuntimeFlags = nil
}
events.On(api.Resource{
ID: config.Name,
Status: api.Working,
Text: api.StatusConfiguring,
})
// configure [--context-size=<n>] MODEL
// configure [--context-size=<n>] MODEL [-- <runtime-flags...>]
args := []string{"configure"}
if config.ContextSize > 0 {
args = append(args, "--context-size", strconv.Itoa(config.ContextSize))
}
args = append(args, config.Model)
// Only append RuntimeFlags if docker model CLI version is >= v1.0.6
if len(config.RuntimeFlags) != 0 && m.supportsRuntimeFlags() {
args = append(args, "--")
args = append(args, config.RuntimeFlags...)
}
cmd := exec.CommandContext(ctx, m.path, args...)
err := m.prepare(ctx, cmd)
if err != nil {
Expand Down Expand Up @@ -278,3 +279,16 @@ func (m *modelAPI) ListModels(ctx context.Context) ([]string, error) {
}
return availableModels, nil
}

// supportsRuntimeFlags checks if the docker model version supports runtime flags
// Runtime flags are supported in version >= v1.0.6
func (m *modelAPI) supportsRuntimeFlags() bool {
// If version is not cached, don't append runtime flags to be safe
if m.version == "" {
return false
}

// Strip 'v' prefix if present (e.g., "v1.0.6" -> "1.0.6")
versionStr := strings.TrimPrefix(m.version, "v")
return !versions.LessThan(versionStr, "1.0.6")
}