diff --git a/cmd/cli/commands/run.go b/cmd/cli/commands/run.go index 2f18aaee..1d1a6e9b 100644 --- a/cmd/cli/commands/run.go +++ b/cmd/cli/commands/run.go @@ -8,6 +8,7 @@ import ( "io" "os" "os/signal" + "strconv" "strings" "syscall" @@ -15,6 +16,8 @@ import ( "github.com/docker/model-runner/cmd/cli/commands/completion" "github.com/docker/model-runner/cmd/cli/desktop" "github.com/docker/model-runner/cmd/cli/readline" + "github.com/docker/model-runner/pkg/inference" + "github.com/docker/model-runner/pkg/inference/scheduling" "github.com/fatih/color" "github.com/muesli/termenv" "github.com/spf13/cobra" @@ -90,11 +93,12 @@ func readMultilineInput(cmd *cobra.Command, scanner *bufio.Scanner) (string, err func generateInteractiveWithReadline(cmd *cobra.Command, desktopClient *desktop.Client, model string) error { usage := func() { fmt.Fprintln(os.Stderr, "Available Commands:") - fmt.Fprintln(os.Stderr, " /set system Set or update the system message") fmt.Fprintln(os.Stderr, " /bye Exit") + fmt.Fprintln(os.Stderr, " /set Set a session variable") fmt.Fprintln(os.Stderr, " /?, /help Help for a command") fmt.Fprintln(os.Stderr, " /? shortcuts Help for keyboard shortcuts") fmt.Fprintln(os.Stderr, " /? files Help for file inclusion with @ symbol") + fmt.Fprintln(os.Stderr, " /? set Help for /set command") fmt.Fprintln(os.Stderr, "") fmt.Fprintln(os.Stderr, `Use """ to begin a multi-line message.`) fmt.Fprintln(os.Stderr, "") @@ -134,6 +138,13 @@ func generateInteractiveWithReadline(cmd *cobra.Command, desktopClient *desktop. fmt.Fprintln(os.Stderr, "") } + usageSet := func() { + fmt.Fprintln(os.Stderr, "Available /set commands:") + fmt.Fprintln(os.Stderr, " /set system Set system message for the conversation") + fmt.Fprintln(os.Stderr, " /set parameter num_ctx Set context window size (in tokens)") + fmt.Fprintln(os.Stderr, "") + } + scanner, err := readline.New(readline.Prompt{ Prompt: "> ", AltPrompt: ". 
", @@ -212,6 +223,8 @@ func generateInteractiveWithReadline(cmd *cobra.Command, desktopClient *desktop. usageShortcuts() case "file", "files": usageFiles() + case "set": + usageSet() default: usage() } @@ -219,18 +232,56 @@ func generateInteractiveWithReadline(cmd *cobra.Command, desktopClient *desktop. usage() } continue - case strings.HasPrefix(line, "/set system ") || line == "/set system": - // Extract the system prompt text after "/set system " - systemPrompt = strings.TrimPrefix(line, "/set system ") - systemPrompt = strings.TrimSpace(systemPrompt) - if systemPrompt == "" { - fmt.Fprintln(os.Stderr, "Cleared system message.") - } else { - fmt.Fprintln(os.Stderr, "Set system message.") - } - continue case strings.HasPrefix(line, "/exit"), strings.HasPrefix(line, "/bye"): return nil + case strings.HasPrefix(line, "/set"): + args := strings.Fields(line) + if len(args) < 2 { + usageSet() + continue + } + switch args[1] { + case "system": + // The prompt is whatever follows the first two fields, so a bare "/set system" (no trailing space) clears it + rest := strings.TrimSpace(strings.TrimPrefix(strings.TrimSpace(line), args[0])) + systemPrompt = strings.TrimSpace(strings.TrimPrefix(rest, args[1])) + if systemPrompt == "" { + fmt.Fprintln(os.Stderr, "Cleared system message.") + } else { + fmt.Fprintln(os.Stderr, "Set system message.") + } + case "parameter": + if len(args) < 4 { + fmt.Fprintln(os.Stderr, "Usage: /set parameter NAME VALUE") + fmt.Fprintln(os.Stderr, "Available parameters: num_ctx") + continue + } + paramName, paramValue := args[2], args[3] + switch paramName { + case "num_ctx": + if val, err := strconv.ParseInt(paramValue, 10, 32); err == nil && val > 0 { + ctx := int32(val) + if err := desktopClient.ConfigureBackend(scheduling.ConfigureRequest{ + Model: model, + BackendConfiguration: inference.BackendConfiguration{ + ContextSize: &ctx, + }, + }); err != nil { + fmt.Fprintf(os.Stderr, "Failed to set num_ctx: %v\n", err) + } else { + fmt.Fprintf(os.Stderr, "Set num_ctx to %d\n", val) + } + } else { + fmt.Fprintf(os.Stderr, "Invalid value for num_ctx: %s 
(must be a positive integer)\n", paramValue) + } + default: + fmt.Fprintf(os.Stderr, "Unknown parameter: %s\n", paramName) + fmt.Fprintln(os.Stderr, "Available parameters: num_ctx") + } + default: + usageSet() + } + continue case strings.HasPrefix(line, "/"): fmt.Printf("Unknown command '%s'. Type /? for help\n", strings.Fields(line)[0]) continue diff --git a/main.go b/main.go index fea431ca..54f9fb1a 100644 --- a/main.go +++ b/main.go @@ -7,6 +7,7 @@ import ( "os" "os/signal" "path/filepath" + "strconv" "strings" "syscall" "time" @@ -77,6 +78,18 @@ func main() { sglangServerPath := os.Getenv("SGLANG_SERVER_PATH") mlxServerPath := os.Getenv("MLX_SERVER_PATH") + // Parse default context length from environment + var defaultContextLength *int32 + if ctxStr := os.Getenv("DMR_CONTEXT_LENGTH"); ctxStr != "" { + if parsed, err := strconv.ParseInt(ctxStr, 10, 32); err == nil && parsed > 0 { + ctx := int32(parsed) + defaultContextLength = &ctx + log.Infof("DMR_CONTEXT_LENGTH: %d", ctx) + } else { + log.Warnf("Invalid DMR_CONTEXT_LENGTH: %s (must be a positive integer)", ctxStr) + } + } + // Create a proxy-aware HTTP transport // Use a safe type assertion with fallback, and explicitly set Proxy to http.ProxyFromEnvironment var baseTransport *http.Transport @@ -175,6 +188,7 @@ func main() { "", false, ), + defaultContextLength, ) // Create the HTTP handler for the scheduler diff --git a/pkg/inference/scheduling/scheduler.go b/pkg/inference/scheduling/scheduler.go index 631b8218..5d77ba32 100644 --- a/pkg/inference/scheduling/scheduler.go +++ b/pkg/inference/scheduling/scheduler.go @@ -40,6 +40,8 @@ type Scheduler struct { tracker *metrics.Tracker // openAIRecorder is used to record OpenAI API inference requests and responses. openAIRecorder *metrics.OpenAIRecorder + // defaultContextLength is the default context length from environment variable. + defaultContextLength *int32 } // NewScheduler creates a new inference scheduler. 
@@ -50,19 +52,21 @@ func NewScheduler( modelManager *models.Manager, httpClient *http.Client, tracker *metrics.Tracker, + defaultContextLength *int32, ) *Scheduler { openAIRecorder := metrics.NewOpenAIRecorder(log.WithField("component", "openai-recorder"), modelManager) // Create the scheduler. s := &Scheduler{ - log: log, - backends: backends, - defaultBackend: defaultBackend, - modelManager: modelManager, - installer: newInstaller(log, backends, httpClient), - loader: newLoader(log, backends, modelManager, openAIRecorder), - tracker: tracker, - openAIRecorder: openAIRecorder, + log: log, + backends: backends, + defaultBackend: defaultBackend, + modelManager: modelManager, + installer: newInstaller(log, backends, httpClient), + loader: newLoader(log, backends, modelManager, openAIRecorder), + tracker: tracker, + openAIRecorder: openAIRecorder, + defaultContextLength: defaultContextLength, } // Scheduler successfully initialized. @@ -253,7 +257,12 @@ func (s *Scheduler) ConfigureRunner(ctx context.Context, backend inference.Backe // Build runner configuration with shared settings var runnerConfig inference.BackendConfiguration - runnerConfig.ContextSize = req.ContextSize + // Use request context size if provided, otherwise fall back to default from env var + if req.ContextSize != nil { + runnerConfig.ContextSize = req.ContextSize + } else if s.defaultContextLength != nil { + runnerConfig.ContextSize = s.defaultContextLength + } runnerConfig.Speculative = req.Speculative runnerConfig.RuntimeFlags = runtimeFlags diff --git a/pkg/inference/scheduling/scheduler_test.go b/pkg/inference/scheduling/scheduler_test.go index 7e8e3fd3..0c1d4a71 100644 --- a/pkg/inference/scheduling/scheduler_test.go +++ b/pkg/inference/scheduling/scheduler_test.go @@ -33,7 +33,7 @@ func TestCors(t *testing.T) { discard := logrus.New() discard.SetOutput(io.Discard) log := logrus.NewEntry(discard) - s := NewScheduler(log, nil, nil, nil, nil, nil) + s := NewScheduler(log, nil, nil, nil, nil, 
nil, nil) httpHandler := NewHTTPHandler(s, nil, []string{"*"}) req := httptest.NewRequest(http.MethodOptions, "http://model-runner.docker.internal"+tt.path, http.NoBody) req.Header.Set("Origin", "docker.com")