diff --git a/commands/root.go b/commands/root.go index 474e1979..522b0e13 100644 --- a/commands/root.go +++ b/commands/root.go @@ -110,6 +110,7 @@ func NewRootCmd(cli *command.DockerCli) *cobra.Command { newUninstallRunner(), newPSCmd(), newDFCmd(), + newUnloadCmd(), ) return rootCmd } diff --git a/commands/unload.go b/commands/unload.go new file mode 100644 index 00000000..20c6f280 --- /dev/null +++ b/commands/unload.go @@ -0,0 +1,66 @@ +package commands + +import ( + "fmt" + + "github.com/docker/model-cli/commands/completion" + "github.com/docker/model-cli/desktop" + "github.com/spf13/cobra" +) + +func newUnloadCmd() *cobra.Command { + var all bool + var backend string + + cmdArgs := "(MODEL [--backend BACKEND] | --all)" + c := &cobra.Command{ + Use: "unload " + cmdArgs, + Short: "Unload running models", + RunE: func(cmd *cobra.Command, args []string) error { + model := args[0] + unloadResp, err := desktopClient.Unload(desktop.UnloadRequest{All: all, Backend: backend, Model: model}) + if err != nil { + err = handleClientError(err, "Failed to unload models") + return handleNotRunningError(err) + } + unloaded := unloadResp.UnloadedRunners + if unloaded == 0 { + if all { + cmd.Println("No models are running.") + } else { + cmd.Println("No such model(s) running.") + } + } else { + cmd.Printf("Unloaded %d model(s).\n", unloaded) + } + return nil + }, + ValidArgsFunction: completion.NoComplete, + } + c.Args = func(cmd *cobra.Command, args []string) error { + if all { + if len(args) > 0 { + return fmt.Errorf( + "'docker model unload' does not take MODEL when --all is specified.\n\n" + + "Usage: docker model unload " + cmdArgs + "\n\n" + + "See 'docker model unload --help' for more information.", + ) + } + return nil + } + if len(args) < 1 { + return fmt.Errorf( + "'docker model unload' requires MODEL unless --all is specified.\n\n" + + "Usage: docker model unload " + cmdArgs + "\n\n" + + "See 'docker model unload --help' for more information.", + ) + } + if len(args) > 1 { + return fmt.Errorf("too many arguments, expected " + cmdArgs) + } + return nil + } + c.Flags().BoolVar(&all, "all", false, "Unload all running models") + c.Flags().StringVar(&backend, "backend", "", "Optional backend to target") + return c +} diff --git a/desktop/desktop.go b/desktop/desktop.go index 23f3fb74..d2a0dbb0 100644 --- a/desktop/desktop.go +++ b/desktop/desktop.go @@ -499,6 +499,49 @@ func (c *Client) DF() (DiskUsage, error) { return df, nil } +// UnloadRequest to be imported from docker/model-runner when https://github.com/docker/model-runner/pull/46 is merged. +type UnloadRequest struct { + All bool `json:"all"` + Backend string `json:"backend"` + Model string `json:"model"` +} + +// UnloadResponse to be imported from docker/model-runner when https://github.com/docker/model-runner/pull/46 is merged. +type UnloadResponse struct { + UnloadedRunners int `json:"unloaded_runners"` +} + +func (c *Client) Unload(req UnloadRequest) (UnloadResponse, error) { + unloadPath := inference.InferencePrefix + "/unload" + jsonData, err := json.Marshal(req) + if err != nil { + return UnloadResponse{}, fmt.Errorf("error marshaling request: %w", err) + } + + resp, err := c.doRequest(http.MethodPost, unloadPath, bytes.NewReader(jsonData)) + if err != nil { + return UnloadResponse{}, c.handleQueryError(err, unloadPath) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return UnloadResponse{}, fmt.Errorf("unloading failed with status %s: %s", resp.Status, string(body)) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return UnloadResponse{}, fmt.Errorf("failed to read response body: %w", err) + } + + var unloadResp UnloadResponse + if err := json.Unmarshal(body, &unloadResp); err != nil { + return UnloadResponse{}, fmt.Errorf("failed to unmarshal response body: %w", err) + } + + return unloadResp, nil +} + // doRequest is a helper function that performs HTTP requests and handles 503 responses func (c *Client) doRequest(method, path string, body io.Reader) (*http.Response, error) { req, err := http.NewRequest(method, c.modelRunner.URL(path), body)