Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 62 additions & 11 deletions cmd/cli/commands/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,16 @@ import (
"io"
"os"
"os/signal"
"strconv"
"strings"
"syscall"

"github.com/charmbracelet/glamour"
"github.com/docker/model-runner/cmd/cli/commands/completion"
"github.com/docker/model-runner/cmd/cli/desktop"
"github.com/docker/model-runner/cmd/cli/readline"
"github.com/docker/model-runner/pkg/inference"
"github.com/docker/model-runner/pkg/inference/scheduling"
"github.com/fatih/color"
"github.com/muesli/termenv"
"github.com/spf13/cobra"
Expand Down Expand Up @@ -90,11 +93,12 @@ func readMultilineInput(cmd *cobra.Command, scanner *bufio.Scanner) (string, err
func generateInteractiveWithReadline(cmd *cobra.Command, desktopClient *desktop.Client, model string) error {
usage := func() {
fmt.Fprintln(os.Stderr, "Available Commands:")
fmt.Fprintln(os.Stderr, " /set system Set or update the system message")
fmt.Fprintln(os.Stderr, " /bye Exit")
fmt.Fprintln(os.Stderr, " /set Set a session variable")
fmt.Fprintln(os.Stderr, " /?, /help Help for a command")
fmt.Fprintln(os.Stderr, " /? shortcuts Help for keyboard shortcuts")
fmt.Fprintln(os.Stderr, " /? files Help for file inclusion with @ symbol")
fmt.Fprintln(os.Stderr, " /? set Help for /set command")
fmt.Fprintln(os.Stderr, "")
fmt.Fprintln(os.Stderr, `Use """ to begin a multi-line message.`)
fmt.Fprintln(os.Stderr, "")
Expand Down Expand Up @@ -134,6 +138,13 @@ func generateInteractiveWithReadline(cmd *cobra.Command, desktopClient *desktop.
fmt.Fprintln(os.Stderr, "")
}

// usageSet writes the help text for the /set command to stderr, one line
// per entry, followed by a blank separator line.
usageSet := func() {
	helpLines := []string{
		"Available /set commands:",
		" /set system <message> Set system message for the conversation",
		" /set parameter num_ctx <n> Set context window size (in tokens)",
		"",
	}
	for _, helpLine := range helpLines {
		fmt.Fprintln(os.Stderr, helpLine)
	}
}

scanner, err := readline.New(readline.Prompt{
Prompt: "> ",
AltPrompt: ". ",
Expand Down Expand Up @@ -212,25 +223,65 @@ func generateInteractiveWithReadline(cmd *cobra.Command, desktopClient *desktop.
usageShortcuts()
case "file", "files":
usageFiles()
case "set":
usageSet()
default:
usage()
}
} else {
usage()
}
continue
case strings.HasPrefix(line, "/set system ") || line == "/set system":
	// Everything after the "/set system" keyword becomes the system prompt.
	// The keyword is stripped WITHOUT its trailing space: the case condition
	// also admits a bare "/set system", and stripping "/set system " would
	// leave that input untouched, storing the command text itself as the
	// prompt instead of clearing it.
	systemPrompt = strings.TrimSpace(strings.TrimPrefix(line, "/set system"))
	if systemPrompt == "" {
		fmt.Fprintln(os.Stderr, "Cleared system message.")
	} else {
		fmt.Fprintln(os.Stderr, "Set system message.")
	}
	continue
case strings.HasPrefix(line, "/exit"), strings.HasPrefix(line, "/bye"):
return nil
case strings.HasPrefix(line, "/set"):
	// Handle "/set <subcommand> ...". The line is tokenized on whitespace so
	// the subcommand is recognized regardless of how the words are spaced.
	args := strings.Fields(line)
	if len(args) < 2 {
		usageSet()
		continue
	}
	switch args[1] {
	case "system":
		// The system prompt is whatever follows the first two tokens.
		// Strip those tokens rather than the literal "/set system " prefix:
		// that literal does not match a bare "/set system" or input with
		// extra spacing, which would leave the command text in the prompt
		// instead of clearing or setting it correctly.
		rest := strings.TrimSpace(line)
		rest = strings.TrimSpace(strings.TrimPrefix(rest, args[0]))
		systemPrompt = strings.TrimSpace(strings.TrimPrefix(rest, args[1]))
		if systemPrompt == "" {
			fmt.Fprintln(os.Stderr, "Cleared system message.")
		} else {
			fmt.Fprintln(os.Stderr, "Set system message.")
		}
	case "parameter":
		// Expect "/set parameter <name> <value>".
		if len(args) < 4 {
			fmt.Fprintln(os.Stderr, "Usage: /set parameter <name> <value>")
			fmt.Fprintln(os.Stderr, "Available parameters: num_ctx")
			continue
		}
		paramName, paramValue := args[2], args[3]
		switch paramName {
		case "num_ctx":
			// The value must parse as a positive integer that fits in int32.
			if val, err := strconv.ParseInt(paramValue, 10, 32); err == nil && val > 0 {
				ctx := int32(val)
				// Request the new context size for this model via the client.
				if err := desktopClient.ConfigureBackend(scheduling.ConfigureRequest{
					Model: model,
					BackendConfiguration: inference.BackendConfiguration{
						ContextSize: &ctx,
					},
				}); err != nil {
					fmt.Fprintf(os.Stderr, "Failed to set num_ctx: %v\n", err)
				} else {
					fmt.Fprintf(os.Stderr, "Set num_ctx to %d\n", val)
				}
			} else {
				fmt.Fprintf(os.Stderr, "Invalid value for num_ctx: %s (must be a positive integer)\n", paramValue)
			}
		default:
			fmt.Fprintf(os.Stderr, "Unknown parameter: %s\n", paramName)
			fmt.Fprintln(os.Stderr, "Available parameters: num_ctx")
		}
	default:
		// Unrecognized subcommand: show the /set help.
		usageSet()
	}
	continue
case strings.HasPrefix(line, "/"):
fmt.Printf("Unknown command '%s'. Type /? for help\n", strings.Fields(line)[0])
continue
Expand Down
14 changes: 14 additions & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"os"
"os/signal"
"path/filepath"
"strconv"
"strings"
"syscall"
"time"
Expand Down Expand Up @@ -77,6 +78,18 @@ func main() {
sglangServerPath := os.Getenv("SGLANG_SERVER_PATH")
mlxServerPath := os.Getenv("MLX_SERVER_PATH")

// Read the optional default context length from DMR_CONTEXT_LENGTH.
// It stays nil when the variable is unset, empty, or not a positive
// integer that fits in int32.
var defaultContextLength *int32
if raw := os.Getenv("DMR_CONTEXT_LENGTH"); raw != "" {
	n, parseErr := strconv.ParseInt(raw, 10, 32)
	switch {
	case parseErr != nil || n <= 0:
		log.Warnf("Invalid DMR_CONTEXT_LENGTH: %s (must be a positive integer)", raw)
	default:
		size := int32(n)
		defaultContextLength = &size
		log.Infof("DMR_CONTEXT_LENGTH: %d", size)
	}
}

// Create a proxy-aware HTTP transport
// Use a safe type assertion with fallback, and explicitly set Proxy to http.ProxyFromEnvironment
var baseTransport *http.Transport
Expand Down Expand Up @@ -175,6 +188,7 @@ func main() {
"",
false,
),
defaultContextLength,
)

// Create the HTTP handler for the scheduler
Expand Down
27 changes: 18 additions & 9 deletions pkg/inference/scheduling/scheduler.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ type Scheduler struct {
tracker *metrics.Tracker
// openAIRecorder is used to record OpenAI API inference requests and responses.
openAIRecorder *metrics.OpenAIRecorder
// defaultContextLength is the default context length from environment variable.
defaultContextLength *int32
}

// NewScheduler creates a new inference scheduler.
Expand All @@ -50,19 +52,21 @@ func NewScheduler(
modelManager *models.Manager,
httpClient *http.Client,
tracker *metrics.Tracker,
defaultContextLength *int32,
) *Scheduler {
openAIRecorder := metrics.NewOpenAIRecorder(log.WithField("component", "openai-recorder"), modelManager)

// Create the scheduler.
s := &Scheduler{
log: log,
backends: backends,
defaultBackend: defaultBackend,
modelManager: modelManager,
installer: newInstaller(log, backends, httpClient),
loader: newLoader(log, backends, modelManager, openAIRecorder),
tracker: tracker,
openAIRecorder: openAIRecorder,
log: log,
backends: backends,
defaultBackend: defaultBackend,
modelManager: modelManager,
installer: newInstaller(log, backends, httpClient),
loader: newLoader(log, backends, modelManager, openAIRecorder),
tracker: tracker,
openAIRecorder: openAIRecorder,
defaultContextLength: defaultContextLength,
}

// Scheduler successfully initialized.
Expand Down Expand Up @@ -253,7 +257,12 @@ func (s *Scheduler) ConfigureRunner(ctx context.Context, backend inference.Backe

// Build runner configuration with shared settings
var runnerConfig inference.BackendConfiguration
runnerConfig.ContextSize = req.ContextSize
// Use request context size if provided, otherwise fall back to default from env var
if req.ContextSize != nil {
runnerConfig.ContextSize = req.ContextSize
} else if s.defaultContextLength != nil {
runnerConfig.ContextSize = s.defaultContextLength
}
runnerConfig.Speculative = req.Speculative
runnerConfig.RuntimeFlags = runtimeFlags

Expand Down
2 changes: 1 addition & 1 deletion pkg/inference/scheduling/scheduler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func TestCors(t *testing.T) {
discard := logrus.New()
discard.SetOutput(io.Discard)
log := logrus.NewEntry(discard)
s := NewScheduler(log, nil, nil, nil, nil, nil)
s := NewScheduler(log, nil, nil, nil, nil, nil, nil)
httpHandler := NewHTTPHandler(s, nil, []string{"*"})
req := httptest.NewRequest(http.MethodOptions, "http://model-runner.docker.internal"+tt.path, http.NoBody)
req.Header.Set("Origin", "docker.com")
Expand Down
Loading