From 630e8c750ff9534aa7bd9e72d7113c8970c70dd0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ignacio=20L=C3=B3pez=20Luna?= Date: Wed, 14 Jan 2026 13:41:57 +0100 Subject: [PATCH 01/20] feat(diffusers): implement diffusers backend for image generation --- Dockerfile | 50 +++ diffusers_backend.go | 26 ++ diffusers_backend_stub.go | 17 + main.go | 7 + pkg/inference/backend.go | 7 + pkg/inference/backends/diffusers/diffusers.go | 220 +++++++++++++ .../backends/diffusers/diffusers_config.go | 48 +++ pkg/inference/platform/platform.go | 6 + pkg/inference/scheduling/api.go | 3 + pkg/inference/scheduling/http_handler.go | 3 + python/diffusers_server/__init__.py | 2 + python/diffusers_server/server.py | 310 ++++++++++++++++++ 12 files changed, 699 insertions(+) create mode 100644 diffusers_backend.go create mode 100644 diffusers_backend_stub.go create mode 100644 pkg/inference/backends/diffusers/diffusers.go create mode 100644 pkg/inference/backends/diffusers/diffusers_config.go create mode 100644 python/diffusers_server/__init__.py create mode 100644 python/diffusers_server/server.py diff --git a/Dockerfile b/Dockerfile index b4a8160c2..ffbd97d62 100644 --- a/Dockerfile +++ b/Dockerfile @@ -144,6 +144,52 @@ RUN curl -LsSf https://astral.sh/uv/install.sh | sh \ && ~/.local/bin/uv pip install --python /opt/sglang-env/bin/python "sglang==${SGLANG_VERSION}" RUN /opt/sglang-env/bin/python -c "import sglang; print(sglang.__version__)" > /opt/sglang-env/version + +# --- Diffusers variant --- +FROM llamacpp AS diffusers + +ARG DIFFUSERS_VERSION=0.36.0 +ARG TORCH_VERSION=2.9.1 + +USER root + +RUN apt update && apt install -y \ + python3 python3-venv python3-dev \ + curl ca-certificates build-essential \ + && rm -rf /var/lib/apt/lists/* + +RUN mkdir -p /opt/diffusers-env && chown -R modelrunner:modelrunner /opt/diffusers-env + +USER modelrunner + +# Install uv and diffusers as modelrunner user +RUN curl -LsSf https://astral.sh/uv/install.sh | sh \ + && ~/.local/bin/uv venv --python /usr/bin/python3 /opt/diffusers-env \ + && ~/.local/bin/uv pip install --python /opt/diffusers-env/bin/python \ + "diffusers==${DIFFUSERS_VERSION}" \ + "torch==${TORCH_VERSION}" \ + "transformers" \ + "accelerate" \ + "safetensors" \ + "fastapi" \ + "uvicorn[standard]" \ + "pillow" + +# Determine the Python site-packages directory dynamically and copy the Python server code +RUN PYTHON_SITE_PACKAGES=$(/opt/diffusers-env/bin/python -c "import site; print(site.getsitepackages()[0])") && \ + mkdir -p "$PYTHON_SITE_PACKAGES/diffusers_server" + +# Copy Python server code (needs to be done as root then chown, or use a multi-stage approach) +USER root +COPY python/diffusers_server /tmp/diffusers_server/ +RUN PYTHON_SITE_PACKAGES=$(/opt/diffusers-env/bin/python -c "import site; print(site.getsitepackages()[0])") && \ + cp -r /tmp/diffusers_server/* "$PYTHON_SITE_PACKAGES/diffusers_server/" && \ + chown -R modelrunner:modelrunner "$PYTHON_SITE_PACKAGES/diffusers_server/" && \ + rm -rf /tmp/diffusers_server +USER modelrunner + +RUN /opt/diffusers-env/bin/python -c "import diffusers; print(diffusers.__version__)" > /opt/diffusers-env/version + FROM llamacpp AS final-llamacpp # Copy the built binary from builder COPY --from=builder /app/model-runner /app/model-runner @@ -155,3 +201,7 @@ COPY --from=builder /app/model-runner /app/model-runner FROM sglang AS final-sglang # Copy the built binary from builder-sglang (without vLLM) COPY --from=builder-sglang /app/model-runner /app/model-runner + +FROM diffusers AS final-diffusers +# Copy the built 
binary from builder (with diffusers support) +COPY --from=builder /app/model-runner /app/model-runner diff --git a/diffusers_backend.go b/diffusers_backend.go new file mode 100644 index 000000000..eb5161ff4 --- /dev/null +++ b/diffusers_backend.go @@ -0,0 +1,26 @@ +//go:build !nodiffusers + +package main + +import ( + "github.com/docker/model-runner/pkg/inference" + "github.com/docker/model-runner/pkg/inference/backends/diffusers" + "github.com/docker/model-runner/pkg/inference/models" + "github.com/sirupsen/logrus" +) + +func initDiffusersBackend(log *logrus.Logger, modelManager *models.Manager, customPythonPath string) (inference.Backend, error) { + return diffusers.New( + log, + modelManager, + log.WithFields(logrus.Fields{"component": diffusers.Name}), + nil, + customPythonPath, + ) +} + +func registerDiffusersBackend(backends map[string]inference.Backend, backend inference.Backend) { + if backend != nil { + backends[diffusers.Name] = backend + } +} diff --git a/diffusers_backend_stub.go b/diffusers_backend_stub.go new file mode 100644 index 000000000..62b25fa0d --- /dev/null +++ b/diffusers_backend_stub.go @@ -0,0 +1,17 @@ +//go:build nodiffusers + +package main + +import ( + "github.com/docker/model-runner/pkg/inference" + "github.com/docker/model-runner/pkg/inference/models" + "github.com/sirupsen/logrus" +) + +func initDiffusersBackend(log *logrus.Logger, modelManager *models.Manager, customPythonPath string) (inference.Backend, error) { + return nil, nil // Diffusers backend is disabled +} + +func registerDiffusersBackend(backends map[string]inference.Backend, backend inference.Backend) { + // Diffusers backend is disabled, do nothing +} diff --git a/main.go b/main.go index fea431ca5..31e91cd91 100644 --- a/main.go +++ b/main.go @@ -76,6 +76,7 @@ func main() { vllmServerPath := os.Getenv("VLLM_SERVER_PATH") sglangServerPath := os.Getenv("SGLANG_SERVER_PATH") mlxServerPath := os.Getenv("MLX_SERVER_PATH") + diffusersServerPath := os.Getenv("DIFFUSERS_SERVER_PATH") // Create a proxy-aware HTTP transport // Use a safe type assertion with fallback, and explicitly set Proxy to http.ProxyFromEnvironment @@ -156,12 +157,18 @@ func main() { log.Fatalf("unable to initialize %s backend: %v", sglang.Name, err) } + diffusersBackend, err := initDiffusersBackend(log, modelManager, diffusersServerPath) + if err != nil { + log.Fatalf("unable to initialize diffusers backend: %v", err) + } + backends := map[string]inference.Backend{ llamacpp.Name: llamaCppBackend, mlx.Name: mlxBackend, sglang.Name: sglangBackend, } registerVLLMBackend(backends, vllmBackend) + registerDiffusersBackend(backends, diffusersBackend) scheduler := scheduling.NewScheduler( log, diff --git a/pkg/inference/backend.go b/pkg/inference/backend.go index c4c9dfb88..e83ff5f9d 100644 --- a/pkg/inference/backend.go +++ b/pkg/inference/backend.go @@ -18,6 +18,9 @@ const ( // mode. BackendModeEmbedding BackendModeReranking + // BackendModeImageGeneration indicates that the backend should run in + // image generation mode. 
+	BackendModeImageGeneration
 )
 
 type ErrGGUFParse struct {
@@ -37,6 +40,8 @@ func (m BackendMode) String() string {
 		return "embedding"
 	case BackendModeReranking:
 		return "reranking"
+	case BackendModeImageGeneration:
+		return "image-generation"
 	default:
 		return "unknown"
 	}
@@ -72,6 +77,8 @@ func ParseBackendMode(mode string) (BackendMode, bool) {
 		return BackendModeEmbedding, true
 	case "reranking":
 		return BackendModeReranking, true
+	case "image-generation":
+		return BackendModeImageGeneration, true
 	default:
 		return BackendModeCompletion, false
 	}
diff --git a/pkg/inference/backends/diffusers/diffusers.go b/pkg/inference/backends/diffusers/diffusers.go
new file mode 100644
index 000000000..c6387e41f
--- /dev/null
+++ b/pkg/inference/backends/diffusers/diffusers.go
@@ -0,0 +1,220 @@
+package diffusers
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"net/http"
+	"os"
+	"os/exec"
+	"path/filepath"
+	"strings"
+
+	"github.com/docker/model-runner/pkg/diskusage"
+	"github.com/docker/model-runner/pkg/inference"
+	"github.com/docker/model-runner/pkg/inference/backends"
+	"github.com/docker/model-runner/pkg/inference/models"
+	"github.com/docker/model-runner/pkg/inference/platform"
+	"github.com/docker/model-runner/pkg/logging"
+)
+
+const (
+	// Name is the backend name.
+	Name         = "diffusers"
+	diffusersDir = "/opt/diffusers-env"
+)
+
+var (
+	ErrNotImplemented    = errors.New("not implemented")
+	ErrDiffusersNotFound = errors.New("diffusers package not installed")
+	ErrPythonNotFound    = errors.New("python3 not found in PATH")
+)
+
+// diffusers is the diffusers-based backend implementation for image generation.
+type diffusers struct {
+	// log is the associated logger.
+	log logging.Logger
+	// modelManager is the shared model manager.
+	modelManager *models.Manager
+	// serverLog is the logger to use for the diffusers server process.
+	serverLog logging.Logger
+	// config is the configuration for the diffusers backend.
+	config *Config
+	// status is the current state of the diffusers backend.
+	status string
+	// pythonPath is the path to the python3 binary.
+	pythonPath string
+	// customPythonPath is an optional custom path to the python3 binary.
+	customPythonPath string
+}
+
+// New creates a new diffusers-based backend for image generation.
+// customPythonPath is an optional path to a custom python3 binary; if empty, the default path is used.
+func New(log logging.Logger, modelManager *models.Manager, serverLog logging.Logger, conf *Config, customPythonPath string) (inference.Backend, error) {
+	// If no config is provided, use the default configuration
+	if conf == nil {
+		conf = NewDefaultConfig()
+	}
+
+	return &diffusers{
+		log:              log,
+		modelManager:     modelManager,
+		serverLog:        serverLog,
+		config:           conf,
+		status:           "not installed",
+		customPythonPath: customPythonPath,
+	}, nil
+}
+
+// Name implements inference.Backend.Name.
+func (d *diffusers) Name() string {
+	return Name
+}
+
+// UsesExternalModelManagement implements inference.Backend.UsesExternalModelManagement.
+// Diffusers uses the shared model manager but also supports loading models directly from HuggingFace.
+func (d *diffusers) UsesExternalModelManagement() bool {
+	return true // For now, we'll use external model management (HuggingFace downloads)
+}
+
+// UsesTCP implements inference.Backend.UsesTCP.
+// Diffusers uses TCP for communication, like SGLang.
+func (d *diffusers) UsesTCP() bool {
+	return true
+}
+
+// Install implements inference.Backend.Install.
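+// It resolves a Python interpreter (custom path, bundled virtual environment,
+// or system python3) and verifies that the diffusers package is importable,
+// recording the detected version in the backend status.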
+func (d *diffusers) Install(_ context.Context, _ *http.Client) error { + if !platform.SupportsDiffusers() { + return ErrNotImplemented + } + + var pythonPath string + + // Use custom python path if specified + if d.customPythonPath != "" { + pythonPath = d.customPythonPath + } else { + venvPython := filepath.Join(diffusersDir, "bin", "python3") + pythonPath = venvPython + + if _, err := os.Stat(venvPython); err != nil { + // Fall back to system Python + systemPython, err := exec.LookPath("python3") + if err != nil { + d.status = ErrPythonNotFound.Error() + return ErrPythonNotFound + } + pythonPath = systemPython + } + } + + d.pythonPath = pythonPath + + // Check if diffusers is installed + if err := d.pythonCmd("-c", "import diffusers").Run(); err != nil { + d.status = "diffusers package not installed" + d.log.Warnf("diffusers package not found. Install with: uv pip install diffusers torch") + return ErrDiffusersNotFound + } + + // Get version + output, err := d.pythonCmd("-c", "import diffusers; print(diffusers.__version__)").Output() + if err != nil { + d.log.Warnf("could not get diffusers version: %v", err) + d.status = "running diffusers version: unknown" + } else { + d.status = fmt.Sprintf("running diffusers version: %s", strings.TrimSpace(string(output))) + } + + return nil +} + +// Run implements inference.Backend.Run. +func (d *diffusers) Run(ctx context.Context, socket, model string, modelRef string, mode inference.BackendMode, backendConfig *inference.BackendConfiguration) error { + if !platform.SupportsDiffusers() { + d.log.Warn("diffusers backend is not yet supported on this platform") + return ErrNotImplemented + } + + // For diffusers, we support image generation mode + if mode != inference.BackendModeImageGeneration { + return fmt.Errorf("diffusers backend only supports image-generation mode, got %s", mode) + } + + args, err := d.config.GetArgs(model, socket, mode, backendConfig) + if err != nil { + return fmt.Errorf("failed to get diffusers arguments: %w", err) + } + + // Add served model name + if model != "" { + // Replace colons with underscores to sanitize the model name + sanitizedModel := strings.ReplaceAll(model, ":", "_") + args = append(args, "--served-model-name", sanitizedModel) + } + + if d.pythonPath == "" { + return fmt.Errorf("diffusers: python runtime not configured; did you forget to call Install?") + } + + sandboxPath := "" + if _, err := os.Stat(diffusersDir); err == nil { + sandboxPath = diffusersDir + } + + return backends.RunBackend(ctx, backends.RunnerConfig{ + BackendName: "Diffusers", + Socket: socket, + BinaryPath: d.pythonPath, + SandboxPath: sandboxPath, + SandboxConfig: "", + Args: args, + Logger: d.log, + ServerLogWriter: d.serverLog.Writer(), + }) +} + +// Status implements inference.Backend.Status. +func (d *diffusers) Status() string { + return d.status +} + +// GetDiskUsage implements inference.Backend.GetDiskUsage. +func (d *diffusers) GetDiskUsage() (int64, error) { + // Check if Docker installation exists + if _, err := os.Stat(diffusersDir); err == nil { + size, err := diskusage.Size(diffusersDir) + if err != nil { + return 0, fmt.Errorf("error while getting diffusers dir size: %w", err) + } + return size, nil + } + // Python installation doesn't have a dedicated installation directory + // It's installed via pip in the system Python environment + return 0, nil +} + +// GetRequiredMemoryForModel returns the estimated memory requirements for a model. 
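+// The returned figures are static estimates; the model file is not inspected,
+// so actual requirements vary by pipeline (e.g. SD 1.5 vs. SDXL) and resolution.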
+func (d *diffusers) GetRequiredMemoryForModel(_ context.Context, _ string, _ *inference.BackendConfiguration) (inference.RequiredMemory, error) { + if !platform.SupportsDiffusers() { + return inference.RequiredMemory{}, ErrNotImplemented + } + + // Stable Diffusion models typically require significant VRAM + // SD 1.5: ~4GB VRAM, SD 2.1: ~5GB VRAM, SDXL: ~8GB VRAM + return inference.RequiredMemory{ + RAM: 4 * 1024 * 1024 * 1024, // 4GB RAM + VRAM: 6 * 1024 * 1024 * 1024, // 6GB VRAM (average estimate) + }, nil +} + +// pythonCmd creates an exec.Cmd that runs python with the given arguments. +// It uses the configured pythonPath if available, otherwise falls back to "python3". +func (d *diffusers) pythonCmd(args ...string) *exec.Cmd { + pythonBinary := "python3" + if d.pythonPath != "" { + pythonBinary = d.pythonPath + } + return exec.Command(pythonBinary, args...) +} diff --git a/pkg/inference/backends/diffusers/diffusers_config.go b/pkg/inference/backends/diffusers/diffusers_config.go new file mode 100644 index 000000000..0fc60b16f --- /dev/null +++ b/pkg/inference/backends/diffusers/diffusers_config.go @@ -0,0 +1,48 @@ +package diffusers + +import ( + "fmt" + "net" + + "github.com/docker/model-runner/pkg/inference" +) + +// Config is the configuration for the diffusers backend. +type Config struct { + // Args are the base arguments that are always included. + Args []string +} + +// NewDefaultConfig creates a new Config with default values. +func NewDefaultConfig() *Config { + return &Config{} +} + +// GetArgs implements BackendConfig.GetArgs for the diffusers backend. +func (c *Config) GetArgs(model string, socket string, mode inference.BackendMode, config *inference.BackendConfiguration) ([]string, error) { + // Start with the arguments from Config + args := append([]string{}, c.Args...) + + // Diffusers uses Python module: python -m diffusers_server.server + args = append(args, "-m", "diffusers_server.server") + + // Add model path - for diffusers this can be a HuggingFace model ID or local path + args = append(args, "--model-path", model) + + // Parse host:port from socket + host, port, err := net.SplitHostPort(socket) + if err != nil { + return nil, fmt.Errorf("failed to parse host:port from %q: %w", socket, err) + } + args = append(args, "--host", host, "--port", port) + + // Add mode-specific arguments + switch mode { + case inference.BackendModeImageGeneration: + // Default mode for diffusers - image generation + default: + return nil, fmt.Errorf("unsupported backend mode %q for diffusers", mode) + } + + return args, nil +} diff --git a/pkg/inference/platform/platform.go b/pkg/inference/platform/platform.go index 49bffb75e..3b0b6db26 100644 --- a/pkg/inference/platform/platform.go +++ b/pkg/inference/platform/platform.go @@ -17,3 +17,9 @@ func SupportsMLX() bool { func SupportsSGLang() bool { return runtime.GOOS == "linux" } + +// SupportsDiffusers returns true if diffusers is supported on the current platform. +// Diffusers is supported on Linux (for Docker/CUDA) and macOS (for MPS/Apple Silicon). 
+func SupportsDiffusers() bool { + return runtime.GOOS == "linux" || runtime.GOOS == "darwin" +} diff --git a/pkg/inference/scheduling/api.go b/pkg/inference/scheduling/api.go index f9460dc21..7cd444a4b 100644 --- a/pkg/inference/scheduling/api.go +++ b/pkg/inference/scheduling/api.go @@ -41,6 +41,9 @@ func backendModeForRequest(path string) (inference.BackendMode, bool) { } else if strings.HasSuffix(path, "/v1/messages") || strings.HasSuffix(path, "/v1/messages/count_tokens") { // Anthropic Messages API - treated as completion mode return inference.BackendModeCompletion, true + } else if strings.HasSuffix(path, "/v1/images/generations") { + // OpenAI Images API - image generation mode + return inference.BackendModeImageGeneration, true } return inference.BackendMode(0), false } diff --git a/pkg/inference/scheduling/http_handler.go b/pkg/inference/scheduling/http_handler.go index e36b6b4ad..4f1378793 100644 --- a/pkg/inference/scheduling/http_handler.go +++ b/pkg/inference/scheduling/http_handler.go @@ -66,6 +66,9 @@ func (h *HTTPHandler) routeHandlers() map[string]http.HandlerFunc { "POST " + inference.InferencePrefix + "/rerank", "POST " + inference.InferencePrefix + "/{backend}/score", "POST " + inference.InferencePrefix + "/score", + // Image generation routes + "POST " + inference.InferencePrefix + "/{backend}/v1/images/generations", + "POST " + inference.InferencePrefix + "/v1/images/generations", } // Anthropic Messages API routes diff --git a/python/diffusers_server/__init__.py b/python/diffusers_server/__init__.py new file mode 100644 index 000000000..50af3f434 --- /dev/null +++ b/python/diffusers_server/__init__.py @@ -0,0 +1,2 @@ +# Diffusers Server for Docker Model Runner +# Provides OpenAI Images API compatible endpoint for Stable Diffusion models diff --git a/python/diffusers_server/server.py b/python/diffusers_server/server.py new file mode 100644 index 000000000..b2da9b2ae --- /dev/null +++ b/python/diffusers_server/server.py @@ -0,0 +1,310 @@ +""" +Diffusers Server for Docker Model Runner + +A FastAPI-based server that provides OpenAI Images API compatible endpoints +for Stable Diffusion and other diffusion models using the Hugging Face diffusers library. 
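+It exposes OpenAI-compatible /v1/images/generations and /v1/models endpoints
+plus a /health probe.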
+""" + +import argparse +import base64 +import io +import logging +import time +from typing import Optional, List, Literal + +import torch +from diffusers import DiffusionPipeline, StableDiffusionPipeline, AutoPipelineForText2Image +from fastapi import FastAPI, HTTPException +from fastapi.responses import JSONResponse +from pydantic import BaseModel, Field +import uvicorn + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +app = FastAPI(title="Diffusers Server", description="OpenAI Images API compatible server for diffusion models") + +# Global pipeline instance +pipeline: Optional[DiffusionPipeline] = None +current_model: Optional[str] = None +served_model_name: Optional[str] = None + + +class ImageGenerationRequest(BaseModel): + """Request model for image generation (OpenAI Images API compatible)""" + model: str = Field(..., description="The model to use for image generation") + prompt: str = Field(..., description="A text description of the desired image(s)") + n: int = Field(default=1, ge=1, le=10, description="The number of images to generate") + size: str = Field(default="512x512", description="The size of the generated images") + response_format: Literal["url", "b64_json"] = Field(default="b64_json", description="The format of the generated images") + quality: Optional[str] = Field(default="standard", description="The quality of the image") + style: Optional[str] = Field(default=None, description="The style of the generated images") + negative_prompt: Optional[str] = Field(default=None, description="Text to avoid in generation") + num_inference_steps: int = Field(default=50, ge=1, le=150, description="Number of denoising steps") + guidance_scale: float = Field(default=7.5, ge=1.0, le=20.0, description="Guidance scale for generation") + seed: Optional[int] = Field(default=None, description="Random seed for reproducibility") + + +class ImageData(BaseModel): + """Single image in the response""" + b64_json: Optional[str] = None + url: Optional[str] = None + revised_prompt: Optional[str] = None + + +class ImageGenerationResponse(BaseModel): + """Response model for image generation (OpenAI Images API compatible)""" + created: int + data: List[ImageData] + + +def parse_size(size: str) -> tuple[int, int]: + """Parse size string like '512x512' into (width, height) tuple""" + try: + parts = size.lower().split('x') + if len(parts) != 2: + raise ValueError(f"Invalid size format: {size}") + width = int(parts[0]) + height = int(parts[1]) + return width, height + except (ValueError, IndexError) as e: + raise ValueError(f"Invalid size format '{size}'. 
Expected format like '512x512': {e}") + + +def load_model(model_path: str) -> DiffusionPipeline: + """Load a diffusion model from the given path or HuggingFace model ID""" + global pipeline, current_model + + if pipeline is not None and current_model == model_path: + logger.info(f"Model {model_path} already loaded") + return pipeline + + logger.info(f"Loading model: {model_path}") + + # Determine device + if torch.cuda.is_available(): + device = "cuda" + dtype = torch.float16 + logger.info("Using CUDA device with float16") + elif torch.backends.mps.is_available(): + device = "mps" + dtype = torch.float16 + logger.info("Using MPS device (Apple Silicon) with float16") + else: + device = "cpu" + dtype = torch.float32 + logger.info("Using CPU device with float32") + + try: + # Try to load using AutoPipelineForText2Image which handles most model types + pipeline = AutoPipelineForText2Image.from_pretrained( + model_path, + torch_dtype=dtype, + safety_checker=None, # Disable safety checker for performance + requires_safety_checker=False, + ) + except Exception as e: + logger.warning(f"AutoPipelineForText2Image failed: {e}, trying StableDiffusionPipeline") + try: + pipeline = StableDiffusionPipeline.from_pretrained( + model_path, + torch_dtype=dtype, + safety_checker=None, + requires_safety_checker=False, + ) + except Exception as e2: + logger.warning(f"StableDiffusionPipeline failed: {e2}, trying generic DiffusionPipeline") + pipeline = DiffusionPipeline.from_pretrained( + model_path, + torch_dtype=dtype, + ) + + pipeline = pipeline.to(device) + + # Enable memory efficient attention if available + if hasattr(pipeline, 'enable_attention_slicing'): + pipeline.enable_attention_slicing() + + current_model = model_path + logger.info(f"Model loaded successfully on {device}") + return pipeline + + +def generate_images( + prompt: str, + n: int = 1, + width: int = 512, + height: int = 512, + negative_prompt: Optional[str] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + seed: Optional[int] = None, +) -> List[bytes]: + """Generate images using the loaded pipeline""" + global pipeline + + if pipeline is None: + raise RuntimeError("No model loaded") + + # Set seed for reproducibility + generator = None + if seed is not None: + if torch.cuda.is_available(): + generator = torch.Generator(device="cuda").manual_seed(seed) + elif torch.backends.mps.is_available(): + generator = torch.Generator(device="mps").manual_seed(seed) + else: + generator = torch.Generator().manual_seed(seed) + + logger.info(f"Generating {n} image(s) with prompt: {prompt[:100]}...") + + # Generate images + images = [] + for i in range(n): + # If we have a seed, increment it for each image to get different but reproducible results + current_generator = None + if generator is not None and seed is not None: + if torch.cuda.is_available(): + current_generator = torch.Generator(device="cuda").manual_seed(seed + i) + elif torch.backends.mps.is_available(): + current_generator = torch.Generator(device="mps").manual_seed(seed + i) + else: + current_generator = torch.Generator().manual_seed(seed + i) + + result = pipeline( + prompt=prompt, + negative_prompt=negative_prompt, + width=width, + height=height, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + generator=current_generator, + ) + + image = result.images[0] + + # Convert to PNG bytes + buffer = io.BytesIO() + image.save(buffer, format="PNG") + images.append(buffer.getvalue()) + + logger.info(f"Generated {len(images)} image(s)") + return 
images + + +@app.get("/health") +async def health(): + """Health check endpoint""" + return {"status": "healthy", "model_loaded": current_model is not None} + + +@app.get("/v1/models") +async def list_models(): + """List available models (OpenAI API compatible)""" + models = [] + if served_model_name: + models.append({ + "id": served_model_name, + "object": "model", + "created": int(time.time()), + "owned_by": "diffusers", + }) + if current_model and current_model != served_model_name: + models.append({ + "id": current_model, + "object": "model", + "created": int(time.time()), + "owned_by": "diffusers", + }) + return {"object": "list", "data": models} + + +@app.post("/v1/images/generations", response_model=ImageGenerationResponse) +async def create_image(request: ImageGenerationRequest): + """Generate images from a prompt (OpenAI Images API compatible)""" + global pipeline + + # Check if the requested model matches + requested_model = request.model + if served_model_name and requested_model != served_model_name and requested_model != current_model: + raise HTTPException( + status_code=421, + detail=f"Model '{requested_model}' not loaded. Current model: {served_model_name or current_model}" + ) + + if pipeline is None: + raise HTTPException(status_code=503, detail="No model loaded. Server is not ready.") + + try: + # Parse size + width, height = parse_size(request.size) + + # Generate images + image_bytes_list = generate_images( + prompt=request.prompt, + n=request.n, + width=width, + height=height, + negative_prompt=request.negative_prompt, + num_inference_steps=request.num_inference_steps, + guidance_scale=request.guidance_scale, + seed=request.seed, + ) + + # Format response + data = [] + for img_bytes in image_bytes_list: + if request.response_format == "b64_json": + b64_str = base64.b64encode(img_bytes).decode("utf-8") + data.append(ImageData(b64_json=b64_str)) + else: + # URL format not supported in this implementation + raise HTTPException( + status_code=400, + detail="URL response format is not supported. Use 'b64_json' instead." 
+ ) + + return ImageGenerationResponse( + created=int(time.time()), + data=data + ) + + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + except Exception as e: + logger.exception("Error generating image") + raise HTTPException(status_code=500, detail=f"Image generation failed: {str(e)}") + + +@app.on_event("startup") +async def startup_event(): + """Startup event handler""" + logger.info("Diffusers server starting up...") + if current_model: + logger.info(f"Model path: {current_model}") + + +def main(): + """Main entry point for the diffusers server""" + parser = argparse.ArgumentParser(description="Diffusers Server - OpenAI Images API compatible server") + parser.add_argument("--model-path", type=str, required=True, help="Path to the diffusion model or HuggingFace model ID") + parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to") + parser.add_argument("--port", type=int, default=8000, help="Port to bind to") + parser.add_argument("--served-model-name", type=str, default=None, help="Name to serve the model as") + + args = parser.parse_args() + + global served_model_name + served_model_name = args.served_model_name or args.model_path + + # Load the model at startup + load_model(args.model_path) + + # Start the server + logger.info(f"Starting server on {args.host}:{args.port}") + uvicorn.run(app, host=args.host, port=args.port, log_level="info") + + +if __name__ == "__main__": + main() From 199337214db0093c7e68557bd9a04be99d54141e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ignacio=20L=C3=B3pez=20Luna?= Date: Wed, 14 Jan 2026 21:08:56 +0100 Subject: [PATCH 02/20] feat(diffusers): add support for DDUF (Diffusers Unified Format) file handling --- cmd/cli/commands/package.go | 40 ++++++++-- pkg/distribution/files/classify.go | 9 +++ pkg/distribution/format/format.go | 2 + pkg/distribution/internal/bundle/bundle.go | 9 +++ pkg/distribution/internal/bundle/unpack.go | 32 ++++++++ pkg/distribution/internal/partial/partial.go | 4 + pkg/distribution/internal/store/model.go | 4 + pkg/distribution/modelpack/types.go | 2 + pkg/distribution/types/config.go | 5 ++ pkg/distribution/types/model.go | 2 + pkg/inference/backends/diffusers/diffusers.go | 77 +++++++++++++++++-- .../backends/llamacpp/llamacpp_config_test.go | 4 + .../backends/sglang/sglang_config_test.go | 4 + .../backends/vllm/vllm_config_test.go | 4 + python/diffusers_server/server.py | 59 +++++++++++++- 15 files changed, 242 insertions(+), 15 deletions(-) diff --git a/cmd/cli/commands/package.go b/cmd/cli/commands/package.go index 1a8d23ad7..e34598b06 100644 --- a/cmd/cli/commands/package.go +++ b/cmd/cli/commands/package.go @@ -38,11 +38,12 @@ func newPackagedCmd() *cobra.Command { var opts packageOptions c := &cobra.Command{ - Use: "package (--gguf | --safetensors-dir | --from ) [--license ...] [--mmproj ] [--context-size ] [--push] MODEL", - Short: "Package a GGUF file, Safetensors directory, or existing model into a Docker model OCI artifact.", - Long: "Package a GGUF file, Safetensors directory, or existing model into a Docker model OCI artifact, with optional licenses and multimodal projector. The package is sent to the model-runner, unless --push is specified.\n" + + Use: "package (--gguf | --safetensors-dir | --dduf | --from ) [--license ...] 
[--mmproj ] [--context-size ] [--push] MODEL", + Short: "Package a GGUF file, Safetensors directory, DDUF file, or existing model into a Docker model OCI artifact.", + Long: "Package a GGUF file, Safetensors directory, DDUF file, or existing model into a Docker model OCI artifact, with optional licenses and multimodal projector. The package is sent to the model-runner, unless --push is specified.\n" + "When packaging a sharded GGUF model, --gguf should point to the first shard. All shard files should be siblings and should include the index in the file name (e.g. model-00001-of-00015.gguf).\n" + "When packaging a Safetensors model, --safetensors-dir should point to a directory containing .safetensors files and config files (*.json, merges.txt). All files will be auto-discovered and config files will be packaged into a tar archive.\n" + + "When packaging a DDUF file (Diffusers Unified Format), --dduf should point to a .dduf archive file.\n" + "When packaging from an existing model using --from, you can modify properties like context size to create a variant of the original model.\n" + "For multimodal models, use --mmproj to include a multimodal projector file.", Args: func(cmd *cobra.Command, args []string) error { @@ -50,7 +51,7 @@ func newPackagedCmd() *cobra.Command { return err } - // Validate that exactly one of --gguf, --safetensors-dir, or --from is provided (mutually exclusive) + // Validate that exactly one of --gguf, --safetensors-dir, --dduf, or --from is provided (mutually exclusive) sourcesProvided := 0 if opts.ggufPath != "" { sourcesProvided++ @@ -58,19 +59,22 @@ func newPackagedCmd() *cobra.Command { if opts.safetensorsDir != "" { sourcesProvided++ } + if opts.ddufPath != "" { + sourcesProvided++ + } if opts.fromModel != "" { sourcesProvided++ } if sourcesProvided == 0 { return fmt.Errorf( - "One of --gguf, --safetensors-dir, or --from is required.\n\n" + + "One of --gguf, --safetensors-dir, --dduf, or --from is required.\n\n" + "See 'docker model package --help' for more information", ) } if sourcesProvided > 1 { return fmt.Errorf( - "Cannot specify more than one of --gguf, --safetensors-dir, or --from. Please use only one source.\n\n" + + "Cannot specify more than one of --gguf, --safetensors-dir, --dduf, or --from. 
Please use only one source.\n\n" + "See 'docker model package --help' for more information", ) } @@ -141,6 +145,15 @@ func newPackagedCmd() *cobra.Command { } } + // Validate DDUF path if provided + if opts.ddufPath != "" { + var err error + opts.ddufPath, err = validateAbsolutePath(opts.ddufPath, "DDUF") + if err != nil { + return err + } + } + // Validate dir-tar paths are relative (not absolute) for _, dirPath := range opts.dirTarPaths { if filepath.IsAbs(dirPath) { @@ -167,6 +180,7 @@ func newPackagedCmd() *cobra.Command { c.Flags().StringVar(&opts.ggufPath, "gguf", "", "absolute path to gguf file") c.Flags().StringVar(&opts.safetensorsDir, "safetensors-dir", "", "absolute path to directory containing safetensors files and config") + c.Flags().StringVar(&opts.ddufPath, "dduf", "", "absolute path to DDUF archive file (Diffusers Unified Format)") c.Flags().StringVar(&opts.fromModel, "from", "", "reference to an existing model to repackage") c.Flags().StringVar(&opts.chatTemplatePath, "chat-template", "", "absolute path to chat template file (must be Jinja format)") c.Flags().StringArrayVarP(&opts.licensePaths, "license", "l", nil, "absolute path to a license file") @@ -182,6 +196,7 @@ type packageOptions struct { contextSize uint64 ggufPath string safetensorsDir string + ddufPath string fromModel string licensePaths []string dirTarPaths []string @@ -197,7 +212,7 @@ type builderInitResult struct { cleanupFunc func() // Optional cleanup function for temporary files } -// initializeBuilder creates a package builder from GGUF, Safetensors, or existing model +// initializeBuilder creates a package builder from GGUF, Safetensors, DDUF, or existing model func initializeBuilder(cmd *cobra.Command, opts packageOptions) (*builderInitResult, error) { result := &builderInitResult{} @@ -246,7 +261,14 @@ func initializeBuilder(cmd *cobra.Command, opts packageOptions) (*builderInitRes return nil, fmt.Errorf("add gguf file: %w", err) } result.builder = pkg - } else { + } else if opts.ddufPath != "" { + cmd.PrintErrf("Adding DDUF file from %q\n", opts.ddufPath) + pkg, err := builder.FromPath(opts.ddufPath) + if err != nil { + return nil, fmt.Errorf("add dduf file: %w", err) + } + result.builder = pkg + } else if opts.safetensorsDir != "" { // Safetensors model from directory cmd.PrintErrf("Scanning directory %q for safetensors model...\n", opts.safetensorsDir) safetensorsPaths, tempConfigArchive, err := packaging.PackageFromDirectory(opts.safetensorsDir) @@ -276,6 +298,8 @@ func initializeBuilder(cmd *cobra.Command, opts packageOptions) (*builderInitRes } } result.builder = pkg + } else { + return nil, fmt.Errorf("no model source specified") } return result, nil diff --git a/pkg/distribution/files/classify.go b/pkg/distribution/files/classify.go index 9e4b24b66..d28fa5866 100644 --- a/pkg/distribution/files/classify.go +++ b/pkg/distribution/files/classify.go @@ -17,6 +17,8 @@ const ( FileTypeGGUF // FileTypeSafetensors is a safetensors model weight file FileTypeSafetensors + // FileTypeDDUF is a DDUF (Diffusers Unified Format) file + FileTypeDDUF // FileTypeConfig is a configuration file (json, txt, etc.) 
FileTypeConfig // FileTypeLicense is a license file @@ -32,6 +34,8 @@ func (ft FileType) String() string { return "gguf" case FileTypeSafetensors: return "safetensors" + case FileTypeDDUF: + return "dduf" case FileTypeConfig: return "config" case FileTypeLicense: @@ -74,6 +78,11 @@ func Classify(path string) FileType { return FileTypeSafetensors } + // Check for DDUF files (Diffusers Unified Format) + if strings.HasSuffix(lower, ".dduf") { + return FileTypeDDUF + } + // Check for chat template files (before generic config check) for _, ext := range ChatTemplateExtensions { if strings.HasSuffix(lower, ext) { diff --git a/pkg/distribution/format/format.go b/pkg/distribution/format/format.go index 6b4b5c40e..dafd17e49 100644 --- a/pkg/distribution/format/format.go +++ b/pkg/distribution/format/format.go @@ -60,6 +60,8 @@ func DetectFromPath(path string) (Format, error) { return Get(types.FormatGGUF) case files.FileTypeSafetensors: return Get(types.FormatSafetensors) + case files.FileTypeDDUF: + return Get(types.FormatDiffusers) case files.FileTypeUnknown, files.FileTypeConfig, files.FileTypeLicense, files.FileTypeChatTemplate: return nil, fmt.Errorf("unable to detect format from path: %s (file type: %s)", utils.SanitizeForLog(path), ft) } diff --git a/pkg/distribution/internal/bundle/bundle.go b/pkg/distribution/internal/bundle/bundle.go index b17b241d0..e9d82a8b6 100644 --- a/pkg/distribution/internal/bundle/bundle.go +++ b/pkg/distribution/internal/bundle/bundle.go @@ -17,6 +17,7 @@ type Bundle struct { mmprojPath string ggufFile string // path to GGUF file (first shard when model is split among files) safetensorsFile string // path to safetensors file (first shard when model is split among files) + ddufFile string // path to DDUF file (Diffusers Unified Format) runtimeConfig types.ModelConfig chatTemplatePath string } @@ -59,6 +60,14 @@ func (b *Bundle) SafetensorsPath() string { return filepath.Join(b.dir, ModelSubdir, b.safetensorsFile) } +// DDUFPath returns the path to the DDUF file (Diffusers Unified Format) or "" if none is present. +func (b *Bundle) DDUFPath() string { + if b.ddufFile == "" { + return "" + } + return filepath.Join(b.dir, ModelSubdir, b.ddufFile) +} + // RuntimeConfig returns config that should be respected by the backend at runtime. // Can return either Docker format (*types.Config) or ModelPack format (*modelpack.Model). func (b *Bundle) RuntimeConfig() types.ModelConfig { diff --git a/pkg/distribution/internal/bundle/unpack.go b/pkg/distribution/internal/bundle/unpack.go index b2e67d523..686534bca 100644 --- a/pkg/distribution/internal/bundle/unpack.go +++ b/pkg/distribution/internal/bundle/unpack.go @@ -38,6 +38,10 @@ func Unpack(dir string, model types.Model) (*Bundle, error) { if err := unpackSafetensors(bundle, model); err != nil { return nil, fmt.Errorf("unpack safetensors files: %w", err) } + case types.FormatDiffusers: + if err := unpackDDUF(bundle, model); err != nil { + return nil, fmt.Errorf("unpack DDUF file: %w", err) + } default: return nil, fmt.Errorf("no supported model weights found (neither GGUF nor safetensors)") } @@ -88,9 +92,37 @@ func detectModelFormat(model types.Model) types.Format { return types.FormatSafetensors } + // Check for DDUF files + ddufPaths, err := model.DDUFPaths() + if err == nil && len(ddufPaths) > 0 { + return types.FormatDiffusers + } + return "" } +// unpackDDUF unpacks a DDUF (Diffusers Unified Format) file to the bundle. 
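+// DDUF is a single-file archive, so only the first discovered path is copied
+// into the bundle's model directory.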
+func unpackDDUF(bundle *Bundle, mdl types.Model) error { + ddufPaths, err := mdl.DDUFPaths() + if err != nil { + return fmt.Errorf("get DDUF files for model: %w", err) + } + + if len(ddufPaths) == 0 { + return fmt.Errorf("no DDUF files found") + } + + modelDir := filepath.Join(bundle.dir, ModelSubdir) + + // DDUF is a single-file format + ddufFilename := filepath.Base(ddufPaths[0]) + if err := unpackFile(filepath.Join(modelDir, ddufFilename), ddufPaths[0]); err != nil { + return err + } + bundle.ddufFile = ddufFilename + return nil +} + // hasLayerWithMediaType checks if the model contains a layer with the specified media type func hasLayerWithMediaType(model types.Model, targetMediaType oci.MediaType) bool { // Check specific media types using the model's methods diff --git a/pkg/distribution/internal/partial/partial.go b/pkg/distribution/internal/partial/partial.go index fdea593a1..b1750cb6d 100644 --- a/pkg/distribution/internal/partial/partial.go +++ b/pkg/distribution/internal/partial/partial.go @@ -124,6 +124,10 @@ func SafetensorsPaths(i WithLayers) ([]string, error) { return layerPathsByMediaType(i, types.MediaTypeSafetensors) } +func DDUFPaths(i WithLayers) ([]string, error) { + return layerPathsByMediaType(i, types.MediaTypeDDUF) +} + func ConfigArchivePath(i WithLayers) (string, error) { paths, err := layerPathsByMediaType(i, types.MediaTypeVLLMConfigArchive) if err != nil { diff --git a/pkg/distribution/internal/store/model.go b/pkg/distribution/internal/store/model.go index 0335c71dd..4fa53acf2 100644 --- a/pkg/distribution/internal/store/model.go +++ b/pkg/distribution/internal/store/model.go @@ -157,6 +157,10 @@ func (m *Model) SafetensorsPaths() ([]string, error) { return mdpartial.SafetensorsPaths(m) } +func (m *Model) DDUFPaths() ([]string, error) { + return mdpartial.DDUFPaths(m) +} + func (m *Model) ConfigArchivePath() (string, error) { return mdpartial.ConfigArchivePath(m) } diff --git a/pkg/distribution/modelpack/types.go b/pkg/distribution/modelpack/types.go index d6d1b5e8d..af4345afc 100644 --- a/pkg/distribution/modelpack/types.go +++ b/pkg/distribution/modelpack/types.go @@ -158,6 +158,8 @@ func (m *Model) GetFormat() types.Format { return types.FormatGGUF case "safetensors": return types.FormatSafetensors + case "diffusers": + return types.FormatDiffusers default: return types.Format(f) } diff --git a/pkg/distribution/types/config.go b/pkg/distribution/types/config.go index 5b910585a..ec4dab3d5 100644 --- a/pkg/distribution/types/config.go +++ b/pkg/distribution/types/config.go @@ -25,6 +25,9 @@ const ( // MediaTypeDirTar indicates a tar archive containing a directory with its structure preserved. MediaTypeDirTar MediaType = "application/vnd.docker.ai.dir.tar" + // MediaTypeDDUF indicates a file in DDUF format (Diffusers Unified Format). 
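+	// A DDUF archive packages all diffusion pipeline components into a single ZIP-based file.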
+ MediaTypeDDUF MediaType = "application/vnd.docker.ai.dduf" + // MediaTypeLicense indicates a plain text file containing a license MediaTypeLicense MediaType = "application/vnd.docker.ai.license" @@ -36,6 +39,7 @@ const ( FormatGGUF = Format("gguf") FormatSafetensors = Format("safetensors") + FormatDiffusers = Format("diffusers") // OCI Annotation keys for model layers // See https://github.com/opencontainers/image-spec/blob/main/annotations.md @@ -82,6 +86,7 @@ type Config struct { Size string `json:"size,omitempty"` GGUF map[string]string `json:"gguf,omitempty"` Safetensors map[string]string `json:"safetensors,omitempty"` + Diffusers map[string]string `json:"diffusers,omitempty"` ContextSize *int32 `json:"context_size,omitempty"` } diff --git a/pkg/distribution/types/model.go b/pkg/distribution/types/model.go index 8fe2956d5..350f49757 100644 --- a/pkg/distribution/types/model.go +++ b/pkg/distribution/types/model.go @@ -8,6 +8,7 @@ type Model interface { ID() (string, error) GGUFPaths() ([]string, error) SafetensorsPaths() ([]string, error) + DDUFPaths() ([]string, error) ConfigArchivePath() (string, error) MMPROJPath() (string, error) Config() (ModelConfig, error) @@ -27,6 +28,7 @@ type ModelBundle interface { RootDir() string GGUFPath() string SafetensorsPath() string + DDUFPath() string ChatTemplatePath() string MMPROJPath() string RuntimeConfig() ModelConfig diff --git a/pkg/inference/backends/diffusers/diffusers.go b/pkg/inference/backends/diffusers/diffusers.go index c6387e41f..89ea42706 100644 --- a/pkg/inference/backends/diffusers/diffusers.go +++ b/pkg/inference/backends/diffusers/diffusers.go @@ -11,6 +11,7 @@ import ( "strings" "github.com/docker/model-runner/pkg/diskusage" + "github.com/docker/model-runner/pkg/distribution/types" "github.com/docker/model-runner/pkg/inference" "github.com/docker/model-runner/pkg/inference/backends" "github.com/docker/model-runner/pkg/inference/models" @@ -28,6 +29,7 @@ var ( ErrNotImplemented = errors.New("not implemented") ErrDiffusersNotFound = errors.New("diffusers package not installed") ErrPythonNotFound = errors.New("python3 not found in PATH") + ErrNoDDUFFile = errors.New("no DDUF file found in model bundle") ) // diffusers is the diffusers-based backend implementation for image generation. @@ -72,9 +74,9 @@ func (d *diffusers) Name() string { } // UsesExternalModelManagement implements inference.Backend.UsesExternalModelManagement. -// Diffusers uses the shared model manager but also supports loading models directly from HuggingFace. +// Diffusers uses the shared model manager with bundled DDUF files. func (d *diffusers) UsesExternalModelManagement() bool { - return true // For now, we'll use external model management (HuggingFace downloads) + return false // Use the bundle system for DDUF files } // UsesTCP implements inference.Backend.UsesTCP. @@ -131,7 +133,7 @@ func (d *diffusers) Install(_ context.Context, _ *http.Client) error { } // Run implements inference.Backend.Run. 
-func (d *diffusers) Run(ctx context.Context, socket, model string, modelRef string, mode inference.BackendMode, backendConfig *inference.BackendConfiguration) error { +func (d *diffusers) Run(ctx context.Context, socket, model string, _ string, mode inference.BackendMode, backendConfig *inference.BackendConfiguration) error { if !platform.SupportsDiffusers() { d.log.Warn("diffusers backend is not yet supported on this platform") return ErrNotImplemented @@ -142,7 +144,21 @@ func (d *diffusers) Run(ctx context.Context, socket, model string, modelRef stri return fmt.Errorf("diffusers backend only supports image-generation mode, got %s", mode) } - args, err := d.config.GetArgs(model, socket, mode, backendConfig) + // Get the model bundle to find the DDUF file path + bundle, err := d.modelManager.GetBundle(model) + if err != nil { + return fmt.Errorf("failed to get model bundle for %s: %w", model, err) + } + + // Get the DDUF file path from the bundle + ddufPath := bundle.DDUFPath() + if ddufPath == "" { + return fmt.Errorf("%w: model %s", ErrNoDDUFFile, model) + } + + d.log.Infof("Loading DDUF file from: %s", ddufPath) + + args, err := d.config.GetArgs(ddufPath, socket, mode, backendConfig) if err != nil { return fmt.Errorf("failed to get diffusers arguments: %w", err) } @@ -154,8 +170,59 @@ func (d *diffusers) Run(ctx context.Context, socket, model string, modelRef stri args = append(args, "--served-model-name", sanitizedModel) } + d.log.Infof("Diffusers args: %v", args) + + if d.pythonPath == "" { + return fmt.Errorf("diffusers: python runtime not configured; did you forget to call Install") + } + + sandboxPath := "" + if _, err := os.Stat(diffusersDir); err == nil { + sandboxPath = diffusersDir + } + + return backends.RunBackend(ctx, backends.RunnerConfig{ + BackendName: "Diffusers", + Socket: socket, + BinaryPath: d.pythonPath, + SandboxPath: sandboxPath, + SandboxConfig: "", + Args: args, + Logger: d.log, + ServerLogWriter: d.serverLog.Writer(), + }) +} + +// RunWithBundle implements inference.BackendWithBundle.RunWithBundle. +// This method is called when the backend uses the bundle system. 
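+// Unlike Run, the scheduler resolves the model bundle and passes it in
+// directly, so no model-manager lookup is needed here.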
+func (d *diffusers) RunWithBundle(ctx context.Context, socket string, bundle types.ModelBundle, mode inference.BackendMode, backendConfig *inference.BackendConfiguration) error { + if !platform.SupportsDiffusers() { + d.log.Warn("diffusers backend is not yet supported on this platform") + return ErrNotImplemented + } + + // For diffusers, we support image generation mode + if mode != inference.BackendModeImageGeneration { + return fmt.Errorf("diffusers backend only supports image-generation mode, got %s", mode) + } + + // Get the DDUF file path from the bundle + ddufPath := bundle.DDUFPath() + if ddufPath == "" { + return ErrNoDDUFFile + } + + d.log.Infof("Loading DDUF file from bundle: %s", ddufPath) + + args, err := d.config.GetArgs(ddufPath, socket, mode, backendConfig) + if err != nil { + return fmt.Errorf("failed to get diffusers arguments: %w", err) + } + + d.log.Infof("Diffusers args: %v", args) + if d.pythonPath == "" { - return fmt.Errorf("diffusers: python runtime not configured; did you forget to call Install?") + return fmt.Errorf("diffusers: python runtime not configured; did you forget to call Install") } sandboxPath := "" diff --git a/pkg/inference/backends/llamacpp/llamacpp_config_test.go b/pkg/inference/backends/llamacpp/llamacpp_config_test.go index 1a53a1c85..ee8223c15 100644 --- a/pkg/inference/backends/llamacpp/llamacpp_config_test.go +++ b/pkg/inference/backends/llamacpp/llamacpp_config_test.go @@ -448,6 +448,10 @@ func (f *fakeBundle) SafetensorsPath() string { return "" } +func (f *fakeBundle) DDUFPath() string { + return "" +} + func (f *fakeBundle) RuntimeConfig() types.ModelConfig { if f.config == nil { return nil diff --git a/pkg/inference/backends/sglang/sglang_config_test.go b/pkg/inference/backends/sglang/sglang_config_test.go index e4aed9255..2a96b0bc8 100644 --- a/pkg/inference/backends/sglang/sglang_config_test.go +++ b/pkg/inference/backends/sglang/sglang_config_test.go @@ -35,6 +35,10 @@ func (m *mockModelBundle) RuntimeConfig() types.ModelConfig { return m.runtimeConfig } +func (m *mockModelBundle) DDUFPath() string { + return "" +} + func (m *mockModelBundle) RootDir() string { return "/path/to/bundle" } diff --git a/pkg/inference/backends/vllm/vllm_config_test.go b/pkg/inference/backends/vllm/vllm_config_test.go index c52d65e19..d64985346 100644 --- a/pkg/inference/backends/vllm/vllm_config_test.go +++ b/pkg/inference/backends/vllm/vllm_config_test.go @@ -35,6 +35,10 @@ func (m *mockModelBundle) RuntimeConfig() types.ModelConfig { return m.runtimeConfig } +func (m *mockModelBundle) DDUFPath() string { + return "" +} + func (m *mockModelBundle) RootDir() string { return "/path/to/bundle" } diff --git a/python/diffusers_server/server.py b/python/diffusers_server/server.py index b2da9b2ae..66e4ee16b 100644 --- a/python/diffusers_server/server.py +++ b/python/diffusers_server/server.py @@ -9,6 +9,7 @@ import base64 import io import logging +import os import time from typing import Optional, List, Literal @@ -72,8 +73,52 @@ def parse_size(size: str) -> tuple[int, int]: raise ValueError(f"Invalid size format '{size}'. 
Expected format like '512x512': {e}") +def is_dduf_file(path: str) -> bool: + """Check if the given path is a DDUF file""" + return path.lower().endswith('.dduf') and os.path.isfile(path) + + +def load_model_from_dduf(dduf_path: str, device: str, dtype: torch.dtype) -> DiffusionPipeline: + """Load a diffusion model from a DDUF (Diffusers Unified Format) file""" + logger.info(f"Loading model from DDUF file: {dduf_path}") + + try: + # Try importing DDUFFile - available in diffusers >= 0.32.0 + from huggingface_hub import DDUFFile + + # Open the DDUF file + dduf_file = DDUFFile(dduf_path) + + # Load the pipeline from the DDUF file + # The DDUF file contains everything needed for the pipeline + pipe = DiffusionPipeline.from_pretrained( + dduf_path, + dduf_file=dduf_file, + torch_dtype=dtype, + ) + + pipe = pipe.to(device) + logger.info(f"Model loaded successfully from DDUF on {device}") + return pipe + + except ImportError: + logger.warning("DDUFFile not available. Trying alternative loading method...") + # Fall back to trying to load directly + # Some versions of diffusers support loading DDUF directly + try: + pipe = DiffusionPipeline.from_pretrained( + dduf_path, + torch_dtype=dtype, + ) + pipe = pipe.to(device) + logger.info(f"Model loaded successfully from DDUF (direct) on {device}") + return pipe + except Exception as e: + raise RuntimeError(f"Failed to load DDUF file: {e}. Please ensure diffusers >= 0.32.0 is installed.") + + def load_model(model_path: str) -> DiffusionPipeline: - """Load a diffusion model from the given path or HuggingFace model ID""" + """Load a diffusion model from the given path, DDUF file, or HuggingFace model ID""" global pipeline, current_model if pipeline is not None and current_model == model_path: @@ -96,6 +141,16 @@ def load_model(model_path: str) -> DiffusionPipeline: dtype = torch.float32 logger.info("Using CPU device with float32") + # Check if this is a DDUF file + if is_dduf_file(model_path): + pipeline = load_model_from_dduf(model_path, device, dtype) + current_model = model_path + return pipeline + + # Check if this is a directory containing a model + if os.path.isdir(model_path): + logger.info(f"Loading model from directory: {model_path}") + try: # Try to load using AutoPipelineForText2Image which handles most model types pipeline = AutoPipelineForText2Image.from_pretrained( @@ -288,7 +343,7 @@ async def startup_event(): def main(): """Main entry point for the diffusers server""" parser = argparse.ArgumentParser(description="Diffusers Server - OpenAI Images API compatible server") - parser.add_argument("--model-path", type=str, required=True, help="Path to the diffusion model or HuggingFace model ID") + parser.add_argument("--model-path", type=str, required=True, help="Path to the diffusion model, DDUF file, or HuggingFace model ID") parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to") parser.add_argument("--port", type=int, default=8000, help="Port to bind to") parser.add_argument("--served-model-name", type=str, default=None, help="Name to serve the model as") From c886d003fa91e348fe9a22d98cced0de5e3fd967 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ignacio=20L=C3=B3pez=20Luna?= Date: Thu, 15 Jan 2026 10:08:06 +0100 Subject: [PATCH 03/20] feat(dduf): implement DDUF format support and enhance model loading --- Dockerfile | 2 + pkg/distribution/format/dduf.go | 60 +++++++++++++++++++ pkg/distribution/internal/bundle/unpack.go | 4 ++ pkg/inference/backends/diffusers/diffusers.go | 12 ++-- python/diffusers_server/server.py | 35 
+++++------ 5 files changed, 84 insertions(+), 29 deletions(-) create mode 100644 pkg/distribution/format/dduf.go diff --git a/Dockerfile b/Dockerfile index ffbd97d62..f66b2d503 100644 --- a/Dockerfile +++ b/Dockerfile @@ -171,6 +171,8 @@ RUN curl -LsSf https://astral.sh/uv/install.sh | sh \ "transformers" \ "accelerate" \ "safetensors" \ + "huggingface_hub>=0.27.0" \ + "bitsandbytes" \ "fastapi" \ "uvicorn[standard]" \ "pillow" diff --git a/pkg/distribution/format/dduf.go b/pkg/distribution/format/dduf.go new file mode 100644 index 000000000..46a5f6c6f --- /dev/null +++ b/pkg/distribution/format/dduf.go @@ -0,0 +1,60 @@ +package format + +import ( + "path/filepath" + + "github.com/docker/model-runner/pkg/distribution/oci" + "github.com/docker/model-runner/pkg/distribution/types" +) + +// DDUFFormat implements the Format interface for DDUF (Diffusers Unified Format) model files. +// DDUF is a single-file archive format for diffusion models used by HuggingFace Diffusers. +type DDUFFormat struct{} + +// init registers the DDUF format implementation. +func init() { + Register(&DDUFFormat{}) +} + +// Name returns the format identifier for DDUF. +func (d *DDUFFormat) Name() types.Format { + return types.FormatDiffusers +} + +// MediaType returns the OCI media type for DDUF layers. +func (d *DDUFFormat) MediaType() oci.MediaType { + return types.MediaTypeDDUF +} + +// DiscoverShards finds all DDUF shard files for a model. +// DDUF is a single-file format, so this always returns a slice containing only the input path. +func (d *DDUFFormat) DiscoverShards(path string) ([]string, error) { + // DDUF files are single archives, not sharded + return []string{path}, nil +} + +// ExtractConfig parses DDUF file(s) and extracts model configuration metadata. +// DDUF files are zip archives containing model config, so we extract what we can. +func (d *DDUFFormat) ExtractConfig(paths []string) (types.Config, error) { + if len(paths) == 0 { + return types.Config{Format: types.FormatDiffusers}, nil + } + + // Extract the filename for metadata + ddufFile := "" + if len(paths) > 0 { + ddufFile = filepath.Base(paths[0]) + } + + // Return config with diffusers-specific metadata + // In the future, we could extract model_index.json from the DDUF archive + // to get architecture details, etc. + return types.Config{ + Format: types.FormatDiffusers, + Architecture: "diffusers", + Diffusers: map[string]string{ + "layout": "dduf", + "dduf_file": ddufFile, + }, + }, nil +} diff --git a/pkg/distribution/internal/bundle/unpack.go b/pkg/distribution/internal/bundle/unpack.go index 686534bca..5f56c17ec 100644 --- a/pkg/distribution/internal/bundle/unpack.go +++ b/pkg/distribution/internal/bundle/unpack.go @@ -116,6 +116,10 @@ func unpackDDUF(bundle *Bundle, mdl types.Model) error { // DDUF is a single-file format ddufFilename := filepath.Base(ddufPaths[0]) + // Ensure the filename has the .dduf extension for proper detection by diffusers server + if !strings.HasSuffix(strings.ToLower(ddufFilename), ".dduf") { + ddufFilename = ddufFilename + ".dduf" + } if err := unpackFile(filepath.Join(modelDir, ddufFilename), ddufPaths[0]); err != nil { return err } diff --git a/pkg/inference/backends/diffusers/diffusers.go b/pkg/inference/backends/diffusers/diffusers.go index 89ea42706..11f840d1a 100644 --- a/pkg/inference/backends/diffusers/diffusers.go +++ b/pkg/inference/backends/diffusers/diffusers.go @@ -133,7 +133,7 @@ func (d *diffusers) Install(_ context.Context, _ *http.Client) error { } // Run implements inference.Backend.Run. 
-func (d *diffusers) Run(ctx context.Context, socket, model string, _ string, mode inference.BackendMode, backendConfig *inference.BackendConfiguration) error { +func (d *diffusers) Run(ctx context.Context, socket, model string, modelRef string, mode inference.BackendMode, backendConfig *inference.BackendConfiguration) error { if !platform.SupportsDiffusers() { d.log.Warn("diffusers backend is not yet supported on this platform") return ErrNotImplemented @@ -163,11 +163,9 @@ func (d *diffusers) Run(ctx context.Context, socket, model string, _ string, mod return fmt.Errorf("failed to get diffusers arguments: %w", err) } - // Add served model name - if model != "" { - // Replace colons with underscores to sanitize the model name - sanitizedModel := strings.ReplaceAll(model, ":", "_") - args = append(args, "--served-model-name", sanitizedModel) + // Add served model name using the human-readable model reference + if modelRef != "" { + args = append(args, "--served-model-name", modelRef) } d.log.Infof("Diffusers args: %v", args) @@ -219,8 +217,6 @@ func (d *diffusers) RunWithBundle(ctx context.Context, socket string, bundle typ return fmt.Errorf("failed to get diffusers arguments: %w", err) } - d.log.Infof("Diffusers args: %v", args) - if d.pythonPath == "" { return fmt.Errorf("diffusers: python runtime not configured; did you forget to call Install") } diff --git a/python/diffusers_server/server.py b/python/diffusers_server/server.py index 66e4ee16b..3897390a4 100644 --- a/python/diffusers_server/server.py +++ b/python/diffusers_server/server.py @@ -83,17 +83,21 @@ def load_model_from_dduf(dduf_path: str, device: str, dtype: torch.dtype) -> Dif logger.info(f"Loading model from DDUF file: {dduf_path}") try: - # Try importing DDUFFile - available in diffusers >= 0.32.0 - from huggingface_hub import DDUFFile + # Get the directory and filename from the DDUF path + # DiffusionPipeline.from_pretrained() expects: + # - First arg: directory containing the DDUF file (or repo ID for HF Hub) + # - dduf_file: the filename (string) of the DDUF file within that directory + dduf_dir = os.path.dirname(dduf_path) + dduf_filename = os.path.basename(dduf_path) - # Open the DDUF file - dduf_file = DDUFFile(dduf_path) + logger.info(f"Using directory: {dduf_dir}") + logger.info(f"Using DDUF filename: {dduf_filename}") # Load the pipeline from the DDUF file - # The DDUF file contains everything needed for the pipeline + # The diffusers library will internally read the DDUF file and extract components pipe = DiffusionPipeline.from_pretrained( - dduf_path, - dduf_file=dduf_file, + dduf_dir, + dduf_file=dduf_filename, torch_dtype=dtype, ) @@ -101,20 +105,9 @@ def load_model_from_dduf(dduf_path: str, device: str, dtype: torch.dtype) -> Dif logger.info(f"Model loaded successfully from DDUF on {device}") return pipe - except ImportError: - logger.warning("DDUFFile not available. Trying alternative loading method...") - # Fall back to trying to load directly - # Some versions of diffusers support loading DDUF directly - try: - pipe = DiffusionPipeline.from_pretrained( - dduf_path, - torch_dtype=dtype, - ) - pipe = pipe.to(device) - logger.info(f"Model loaded successfully from DDUF (direct) on {device}") - return pipe - except Exception as e: - raise RuntimeError(f"Failed to load DDUF file: {e}. 
Please ensure diffusers >= 0.32.0 is installed.") + except Exception as e: + logger.exception("Error loading DDUF file") + raise RuntimeError(f"Failed to load DDUF file: {e}") def load_model(model_path: str) -> DiffusionPipeline: From 3f7394e36c56f567688fefb474a308898fc3987b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ignacio=20L=C3=B3pez=20Luna?= Date: Thu, 15 Jan 2026 10:50:40 +0100 Subject: [PATCH 04/20] feat(dduf): calculate total size of files and add human-readable size format --- pkg/distribution/format/dduf.go | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/pkg/distribution/format/dduf.go b/pkg/distribution/format/dduf.go index 46a5f6c6f..f87e34d41 100644 --- a/pkg/distribution/format/dduf.go +++ b/pkg/distribution/format/dduf.go @@ -1,8 +1,11 @@ package format import ( + "fmt" + "os" "path/filepath" + "github.com/docker/go-units" "github.com/docker/model-runner/pkg/distribution/oci" "github.com/docker/model-runner/pkg/distribution/types" ) @@ -40,21 +43,35 @@ func (d *DDUFFormat) ExtractConfig(paths []string) (types.Config, error) { return types.Config{Format: types.FormatDiffusers}, nil } - // Extract the filename for metadata - ddufFile := "" - if len(paths) > 0 { - ddufFile = filepath.Base(paths[0]) + // Calculate total size across all files + var totalSize int64 + for _, path := range paths { + info, err := os.Stat(path) + if err != nil { + return types.Config{}, fmt.Errorf("failed to stat file %s: %w", path, err) + } + totalSize += info.Size() } + // Extract the filename for metadata + ddufFile := filepath.Base(paths[0]) + // Return config with diffusers-specific metadata // In the future, we could extract model_index.json from the DDUF archive // to get architecture details, etc. return types.Config{ Format: types.FormatDiffusers, Architecture: "diffusers", + Size: formatDDUFSize(totalSize), Diffusers: map[string]string{ "layout": "dduf", "dduf_file": ddufFile, }, }, nil } + +// formatDDUFSize converts bytes to human-readable format matching Docker's style +// Returns format like "256MB" (decimal units, no space, matching `docker images`) +func formatDDUFSize(bytes int64) string { + return units.CustomSize("%.2f%s", float64(bytes), 1000.0, []string{"B", "kB", "MB", "GB", "TB", "PB", "EB"}) +} From 4c226781dd1bd3056dd5efb109ed71feb5e48468 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ignacio=20L=C3=B3pez=20Luna?= Date: Thu, 15 Jan 2026 11:18:36 +0100 Subject: [PATCH 05/20] feat(platform): restrict Diffusers support to Linux only until macOS distribution is designed --- pkg/inference/platform/platform.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pkg/inference/platform/platform.go b/pkg/inference/platform/platform.go index 3b0b6db26..76b422782 100644 --- a/pkg/inference/platform/platform.go +++ b/pkg/inference/platform/platform.go @@ -21,5 +21,6 @@ func SupportsSGLang() bool { // SupportsDiffusers returns true if diffusers is supported on the current platform. // Diffusers is supported on Linux (for Docker/CUDA) and macOS (for MPS/Apple Silicon). func SupportsDiffusers() bool { - return runtime.GOOS == "linux" || runtime.GOOS == "darwin" + //return runtime.GOOS == "linux" || runtime.GOOS == "darwin" + return runtime.GOOS == "linux" // Support for macOS disabled for now until we design a solution to distribute it via Docker Desktop. 
} From c89429a4b16d46f05886a54b2a0c61f1b3309ab8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ignacio=20L=C3=B3pez=20Luna?= Date: Thu, 15 Jan 2026 11:39:26 +0100 Subject: [PATCH 06/20] feat(diffusers): add support for DDUF file type handling in repository and config files --- pkg/distribution/huggingface/repository.go | 2 +- pkg/distribution/packaging/safetensors.go | 2 +- pkg/inference/backends/diffusers/diffusers_config.go | 2 +- pkg/inference/backends/llamacpp/llamacpp_config.go | 2 +- pkg/inference/backends/mlx/mlx_config.go | 2 +- pkg/inference/backends/sglang/sglang_config.go | 3 ++- pkg/inference/backends/vllm/vllm_config.go | 7 ++++--- pkg/inference/platform/platform.go | 2 +- 8 files changed, 12 insertions(+), 10 deletions(-) diff --git a/pkg/distribution/huggingface/repository.go b/pkg/distribution/huggingface/repository.go index c0d69a959..1c7bb3f86 100644 --- a/pkg/distribution/huggingface/repository.go +++ b/pkg/distribution/huggingface/repository.go @@ -55,7 +55,7 @@ func FilterModelFiles(repoFiles []RepoFile) (weights []RepoFile, configs []RepoF weights = append(weights, f) case files.FileTypeConfig, files.FileTypeChatTemplate: configs = append(configs, f) - case files.FileTypeUnknown, files.FileTypeLicense: + case files.FileTypeUnknown, files.FileTypeLicense, files.FileTypeDDUF: // Skip these file types } } diff --git a/pkg/distribution/packaging/safetensors.go b/pkg/distribution/packaging/safetensors.go index 2a5be421d..a4348f202 100644 --- a/pkg/distribution/packaging/safetensors.go +++ b/pkg/distribution/packaging/safetensors.go @@ -40,7 +40,7 @@ func PackageFromDirectory(dirPath string) (safetensorsPaths []string, tempConfig safetensorsPaths = append(safetensorsPaths, fullPath) case files.FileTypeConfig, files.FileTypeChatTemplate: configFiles = append(configFiles, fullPath) - case files.FileTypeUnknown, files.FileTypeGGUF, files.FileTypeLicense: + case files.FileTypeUnknown, files.FileTypeGGUF, files.FileTypeLicense, files.FileTypeDDUF: // Skip these file types } } diff --git a/pkg/inference/backends/diffusers/diffusers_config.go b/pkg/inference/backends/diffusers/diffusers_config.go index 0fc60b16f..010445e65 100644 --- a/pkg/inference/backends/diffusers/diffusers_config.go +++ b/pkg/inference/backends/diffusers/diffusers_config.go @@ -40,7 +40,7 @@ func (c *Config) GetArgs(model string, socket string, mode inference.BackendMode switch mode { case inference.BackendModeImageGeneration: // Default mode for diffusers - image generation - default: + case inference.BackendModeCompletion, inference.BackendModeEmbedding, inference.BackendModeReranking: return nil, fmt.Errorf("unsupported backend mode %q for diffusers", mode) } diff --git a/pkg/inference/backends/llamacpp/llamacpp_config.go b/pkg/inference/backends/llamacpp/llamacpp_config.go index f0ed4106f..87816410c 100644 --- a/pkg/inference/backends/llamacpp/llamacpp_config.go +++ b/pkg/inference/backends/llamacpp/llamacpp_config.go @@ -65,7 +65,7 @@ func (c *Config) GetArgs(bundle types.ModelBundle, socket string, mode inference args = append(args, "--embeddings") case inference.BackendModeReranking: args = append(args, "--embeddings", "--reranking") - default: + case inference.BackendModeImageGeneration: return nil, fmt.Errorf("unsupported backend mode %q", mode) } diff --git a/pkg/inference/backends/mlx/mlx_config.go b/pkg/inference/backends/mlx/mlx_config.go index bc4f605c6..29f98638f 100644 --- a/pkg/inference/backends/mlx/mlx_config.go +++ b/pkg/inference/backends/mlx/mlx_config.go @@ -49,7 +49,7 @@ func (c *Config) 
GetArgs(bundle types.ModelBundle, socket string, mode inference case inference.BackendModeReranking: // MLX may not support reranking mode return nil, fmt.Errorf("reranking mode not supported by MLX backend") - default: + case inference.BackendModeImageGeneration: return nil, fmt.Errorf("unsupported backend mode %q", mode) } diff --git a/pkg/inference/backends/sglang/sglang_config.go b/pkg/inference/backends/sglang/sglang_config.go index 4d220d96c..814a516f2 100644 --- a/pkg/inference/backends/sglang/sglang_config.go +++ b/pkg/inference/backends/sglang/sglang_config.go @@ -50,7 +50,8 @@ func (c *Config) GetArgs(bundle types.ModelBundle, socket string, mode inference case inference.BackendModeEmbedding: args = append(args, "--is-embedding") case inference.BackendModeReranking: - default: + // SGLang does not have a specific flag for reranking + case inference.BackendModeImageGeneration: return nil, fmt.Errorf("unsupported backend mode %q", mode) } diff --git a/pkg/inference/backends/vllm/vllm_config.go b/pkg/inference/backends/vllm/vllm_config.go index b172637f2..b3ad0d2dd 100644 --- a/pkg/inference/backends/vllm/vllm_config.go +++ b/pkg/inference/backends/vllm/vllm_config.go @@ -45,10 +45,11 @@ func (c *Config) GetArgs(bundle types.ModelBundle, socket string, mode inference case inference.BackendModeCompletion: // Default mode for vLLM case inference.BackendModeEmbedding: - // vLLM doesn't have a specific embedding flag like llama.cpp - // Embedding models are detected automatically + // vLLM doesn't have a specific embedding flag like llama.cpp + // Embedding models are detected automatically case inference.BackendModeReranking: - default: + // vLLM does not have a specific flag for reranking + case inference.BackendModeImageGeneration: return nil, fmt.Errorf("unsupported backend mode %q", mode) } diff --git a/pkg/inference/platform/platform.go b/pkg/inference/platform/platform.go index 76b422782..8f1e1ed1f 100644 --- a/pkg/inference/platform/platform.go +++ b/pkg/inference/platform/platform.go @@ -21,6 +21,6 @@ func SupportsSGLang() bool { // SupportsDiffusers returns true if diffusers is supported on the current platform. // Diffusers is supported on Linux (for Docker/CUDA) and macOS (for MPS/Apple Silicon). func SupportsDiffusers() bool { - //return runtime.GOOS == "linux" || runtime.GOOS == "darwin" + // return runtime.GOOS == "linux" || runtime.GOOS == "darwin" return runtime.GOOS == "linux" // Support for macOS disabled for now until we design a solution to distribute it via Docker Desktop. 
} From a749093bf80980910cde4089e239925c6416ee80 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ignacio=20L=C3=B3pez=20Luna?= Date: Thu, 15 Jan 2026 11:50:42 +0100 Subject: [PATCH 07/20] feat(diffusers): sanitize log output for Diffusers arguments --- pkg/inference/backends/diffusers/diffusers.go | 51 +------------------ 1 file changed, 2 insertions(+), 49 deletions(-) diff --git a/pkg/inference/backends/diffusers/diffusers.go b/pkg/inference/backends/diffusers/diffusers.go index 11f840d1a..1eb0a315e 100644 --- a/pkg/inference/backends/diffusers/diffusers.go +++ b/pkg/inference/backends/diffusers/diffusers.go @@ -11,11 +11,11 @@ import ( "strings" "github.com/docker/model-runner/pkg/diskusage" - "github.com/docker/model-runner/pkg/distribution/types" "github.com/docker/model-runner/pkg/inference" "github.com/docker/model-runner/pkg/inference/backends" "github.com/docker/model-runner/pkg/inference/models" "github.com/docker/model-runner/pkg/inference/platform" + "github.com/docker/model-runner/pkg/internal/utils" "github.com/docker/model-runner/pkg/logging" ) @@ -168,54 +168,7 @@ func (d *diffusers) Run(ctx context.Context, socket, model string, modelRef stri args = append(args, "--served-model-name", modelRef) } - d.log.Infof("Diffusers args: %v", args) - - if d.pythonPath == "" { - return fmt.Errorf("diffusers: python runtime not configured; did you forget to call Install") - } - - sandboxPath := "" - if _, err := os.Stat(diffusersDir); err == nil { - sandboxPath = diffusersDir - } - - return backends.RunBackend(ctx, backends.RunnerConfig{ - BackendName: "Diffusers", - Socket: socket, - BinaryPath: d.pythonPath, - SandboxPath: sandboxPath, - SandboxConfig: "", - Args: args, - Logger: d.log, - ServerLogWriter: d.serverLog.Writer(), - }) -} - -// RunWithBundle implements inference.BackendWithBundle.RunWithBundle. -// This method is called when the backend uses the bundle system. 
-func (d *diffusers) RunWithBundle(ctx context.Context, socket string, bundle types.ModelBundle, mode inference.BackendMode, backendConfig *inference.BackendConfiguration) error { - if !platform.SupportsDiffusers() { - d.log.Warn("diffusers backend is not yet supported on this platform") - return ErrNotImplemented - } - - // For diffusers, we support image generation mode - if mode != inference.BackendModeImageGeneration { - return fmt.Errorf("diffusers backend only supports image-generation mode, got %s", mode) - } - - // Get the DDUF file path from the bundle - ddufPath := bundle.DDUFPath() - if ddufPath == "" { - return ErrNoDDUFFile - } - - d.log.Infof("Loading DDUF file from bundle: %s", ddufPath) - - args, err := d.config.GetArgs(ddufPath, socket, mode, backendConfig) - if err != nil { - return fmt.Errorf("failed to get diffusers arguments: %w", err) - } + d.log.Infof("Diffusers args: %v", utils.SanitizeForLog(strings.Join(args, " "))) if d.pythonPath == "" { return fmt.Errorf("diffusers: python runtime not configured; did you forget to call Install") From e916b406d5c6564f69d52ff44ff5aad53d0b5df6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ignacio=20L=C3=B3pez=20Luna?= Date: Thu, 15 Jan 2026 11:50:52 +0100 Subject: [PATCH 08/20] feat(docker): streamline Python server code copying in Dockerfile --- Dockerfile | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/Dockerfile b/Dockerfile index f66b2d503..aedf28b64 100644 --- a/Dockerfile +++ b/Dockerfile @@ -177,14 +177,11 @@ RUN curl -LsSf https://astral.sh/uv/install.sh | sh \ "uvicorn[standard]" \ "pillow" -# Determine the Python site-packages directory dynamically and copy the Python server code -RUN PYTHON_SITE_PACKAGES=$(/opt/diffusers-env/bin/python -c "import site; print(site.getsitepackages()[0])") && \ - mkdir -p "$PYTHON_SITE_PACKAGES/diffusers_server" - -# Copy Python server code (needs to be done as root then chown, or use a multi-stage approach) +# Copy Python server code USER root COPY python/diffusers_server /tmp/diffusers_server/ RUN PYTHON_SITE_PACKAGES=$(/opt/diffusers-env/bin/python -c "import site; print(site.getsitepackages()[0])") && \ + mkdir -p "$PYTHON_SITE_PACKAGES/diffusers_server" && \ cp -r /tmp/diffusers_server/* "$PYTHON_SITE_PACKAGES/diffusers_server/" && \ chown -R modelrunner:modelrunner "$PYTHON_SITE_PACKAGES/diffusers_server/" && \ rm -rf /tmp/diffusers_server From d54ea4a74ce03b0f1dc67a72f14fd2533f7b1755 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ignacio=20L=C3=B3pez=20Luna?= Date: Thu, 15 Jan 2026 12:19:25 +0100 Subject: [PATCH 09/20] feat(docker): specify exact versions for Python packages in Dockerfile --- Dockerfile | 27 ++++++++++++++++++--------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/Dockerfile b/Dockerfile index aedf28b64..4be350728 100644 --- a/Dockerfile +++ b/Dockerfile @@ -75,7 +75,7 @@ ENV MODEL_RUNNER_PORT=12434 ENV LLAMA_SERVER_PATH=/app/bin ENV HOME=/home/modelrunner ENV MODELS_PATH=/models -ENV LD_LIBRARY_PATH=/app/lib:$LD_LIBRARY_PATH +ENV LD_LIBRARY_PATH=/app/lib # Label the image so that it's hidden on cloud engines. 
LABEL com.docker.desktop.service="model-runner" @@ -148,8 +148,17 @@ RUN /opt/sglang-env/bin/python -c "import sglang; print(sglang.__version__)" > / # --- Diffusers variant --- FROM llamacpp AS diffusers +# Python package versions for reproducible builds ARG DIFFUSERS_VERSION=0.36.0 ARG TORCH_VERSION=2.9.1 +ARG TRANSFORMERS_VERSION=4.57.5 +ARG ACCELERATE_VERSION=1.3.0 +ARG SAFETENSORS_VERSION=0.5.2 +ARG HUGGINGFACE_HUB_VERSION=0.34.0 +ARG BITSANDBYTES_VERSION=0.45.4 +ARG FASTAPI_VERSION=0.115.12 +ARG UVICORN_VERSION=0.34.1 +ARG PILLOW_VERSION=11.2.1 USER root @@ -168,14 +177,14 @@ RUN curl -LsSf https://astral.sh/uv/install.sh | sh \ && ~/.local/bin/uv pip install --python /opt/diffusers-env/bin/python \ "diffusers==${DIFFUSERS_VERSION}" \ "torch==${TORCH_VERSION}" \ - "transformers" \ - "accelerate" \ - "safetensors" \ - "huggingface_hub>=0.27.0" \ - "bitsandbytes" \ - "fastapi" \ - "uvicorn[standard]" \ - "pillow" + "transformers==${TRANSFORMERS_VERSION}" \ + "accelerate==${ACCELERATE_VERSION}" \ + "safetensors==${SAFETENSORS_VERSION}" \ + "huggingface_hub==${HUGGINGFACE_HUB_VERSION}" \ + "bitsandbytes==${BITSANDBYTES_VERSION}" \ + "fastapi==${FASTAPI_VERSION}" \ + "uvicorn[standard]==${UVICORN_VERSION}" \ + "pillow==${PILLOW_VERSION}" # Copy Python server code USER root From 2daa296f0427c08863e80027b2c6ef81ef34dbb3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ignacio=20L=C3=B3pez=20Luna?= Date: Thu, 15 Jan 2026 13:40:04 +0100 Subject: [PATCH 10/20] feat(model): add DDUF file support to packaging command and documentation --- .../docs/reference/docker_model_package.yaml | 16 ++++-- cmd/cli/docs/reference/model.md | 52 +++++++++---------- cmd/cli/docs/reference/model_package.md | 4 +- 3 files changed, 42 insertions(+), 30 deletions(-) diff --git a/cmd/cli/docs/reference/docker_model_package.yaml b/cmd/cli/docs/reference/docker_model_package.yaml index d59835ce1..79d608997 100644 --- a/cmd/cli/docs/reference/docker_model_package.yaml +++ b/cmd/cli/docs/reference/docker_model_package.yaml @@ -1,13 +1,14 @@ command: docker model package short: | - Package a GGUF file, Safetensors directory, or existing model into a Docker model OCI artifact. + Package a GGUF file, Safetensors directory, DDUF file, or existing model into a Docker model OCI artifact. long: |- - Package a GGUF file, Safetensors directory, or existing model into a Docker model OCI artifact, with optional licenses and multimodal projector. The package is sent to the model-runner, unless --push is specified. + Package a GGUF file, Safetensors directory, DDUF file, or existing model into a Docker model OCI artifact, with optional licenses and multimodal projector. The package is sent to the model-runner, unless --push is specified. When packaging a sharded GGUF model, --gguf should point to the first shard. All shard files should be siblings and should include the index in the file name (e.g. model-00001-of-00015.gguf). When packaging a Safetensors model, --safetensors-dir should point to a directory containing .safetensors files and config files (*.json, merges.txt). All files will be auto-discovered and config files will be packaged into a tar archive. + When packaging a DDUF file (Diffusers Unified Format), --dduf should point to a .dduf archive file. When packaging from an existing model using --from, you can modify properties like context size to create a variant of the original model. For multimodal models, use --mmproj to include a multimodal projector file. 
-usage: docker model package (--gguf | --safetensors-dir | --from ) [--license ...] [--mmproj ] [--context-size ] [--push] MODEL +usage: docker model package (--gguf | --safetensors-dir | --dduf | --from ) [--license ...] [--mmproj ] [--context-size ] [--push] MODEL pname: docker model plink: docker_model.yaml options: @@ -30,6 +31,15 @@ options: experimentalcli: false kubernetes: false swarm: false + - option: dduf + value_type: string + description: absolute path to DDUF archive file (Diffusers Unified Format) + deprecated: false + hidden: false + experimental: false + experimentalcli: false + kubernetes: false + swarm: false - option: dir-tar value_type: stringArray default_value: '[]' diff --git a/cmd/cli/docs/reference/model.md b/cmd/cli/docs/reference/model.md index e139fc45d..d47df7f81 100644 --- a/cmd/cli/docs/reference/model.md +++ b/cmd/cli/docs/reference/model.md @@ -5,32 +5,32 @@ Docker Model Runner ### Subcommands -| Name | Description | -|:------------------------------------------------|:------------------------------------------------------------------------------------------------| -| [`bench`](model_bench.md) | Benchmark a model's performance at different concurrency levels | -| [`df`](model_df.md) | Show Docker Model Runner disk usage | -| [`inspect`](model_inspect.md) | Display detailed information on one model | -| [`install-runner`](model_install-runner.md) | Install Docker Model Runner (Docker Engine only) | -| [`list`](model_list.md) | List the models pulled to your local environment | -| [`logs`](model_logs.md) | Fetch the Docker Model Runner logs | -| [`package`](model_package.md) | Package a GGUF file, Safetensors directory, or existing model into a Docker model OCI artifact. | -| [`ps`](model_ps.md) | List running models | -| [`pull`](model_pull.md) | Pull a model from Docker Hub or HuggingFace to your local environment | -| [`purge`](model_purge.md) | Remove all models | -| [`push`](model_push.md) | Push a model to Docker Hub | -| [`reinstall-runner`](model_reinstall-runner.md) | Reinstall Docker Model Runner (Docker Engine only) | -| [`requests`](model_requests.md) | Fetch requests+responses from Docker Model Runner | -| [`restart-runner`](model_restart-runner.md) | Restart Docker Model Runner (Docker Engine only) | -| [`rm`](model_rm.md) | Remove local models downloaded from Docker Hub | -| [`run`](model_run.md) | Run a model and interact with it using a submitted prompt or chat mode | -| [`search`](model_search.md) | Search for models on Docker Hub and HuggingFace | -| [`start-runner`](model_start-runner.md) | Start Docker Model Runner (Docker Engine only) | -| [`status`](model_status.md) | Check if the Docker Model Runner is running | -| [`stop-runner`](model_stop-runner.md) | Stop Docker Model Runner (Docker Engine only) | -| [`tag`](model_tag.md) | Tag a model | -| [`uninstall-runner`](model_uninstall-runner.md) | Uninstall Docker Model Runner (Docker Engine only) | -| [`unload`](model_unload.md) | Unload running models | -| [`version`](model_version.md) | Show the Docker Model Runner version | +| Name | Description | +|:------------------------------------------------|:-----------------------------------------------------------------------------------------------------------| +| [`bench`](model_bench.md) | Benchmark a model's performance at different concurrency levels | +| [`df`](model_df.md) | Show Docker Model Runner disk usage | +| [`inspect`](model_inspect.md) | Display detailed information on one model | +| 
[`install-runner`](model_install-runner.md) | Install Docker Model Runner (Docker Engine only) | +| [`list`](model_list.md) | List the models pulled to your local environment | +| [`logs`](model_logs.md) | Fetch the Docker Model Runner logs | +| [`package`](model_package.md) | Package a GGUF file, Safetensors directory, DDUF file, or existing model into a Docker model OCI artifact. | +| [`ps`](model_ps.md) | List running models | +| [`pull`](model_pull.md) | Pull a model from Docker Hub or HuggingFace to your local environment | +| [`purge`](model_purge.md) | Remove all models | +| [`push`](model_push.md) | Push a model to Docker Hub | +| [`reinstall-runner`](model_reinstall-runner.md) | Reinstall Docker Model Runner (Docker Engine only) | +| [`requests`](model_requests.md) | Fetch requests+responses from Docker Model Runner | +| [`restart-runner`](model_restart-runner.md) | Restart Docker Model Runner (Docker Engine only) | +| [`rm`](model_rm.md) | Remove local models downloaded from Docker Hub | +| [`run`](model_run.md) | Run a model and interact with it using a submitted prompt or chat mode | +| [`search`](model_search.md) | Search for models on Docker Hub and HuggingFace | +| [`start-runner`](model_start-runner.md) | Start Docker Model Runner (Docker Engine only) | +| [`status`](model_status.md) | Check if the Docker Model Runner is running | +| [`stop-runner`](model_stop-runner.md) | Stop Docker Model Runner (Docker Engine only) | +| [`tag`](model_tag.md) | Tag a model | +| [`uninstall-runner`](model_uninstall-runner.md) | Uninstall Docker Model Runner (Docker Engine only) | +| [`unload`](model_unload.md) | Unload running models | +| [`version`](model_version.md) | Show the Docker Model Runner version | diff --git a/cmd/cli/docs/reference/model_package.md b/cmd/cli/docs/reference/model_package.md index eaf3da293..062f15815 100644 --- a/cmd/cli/docs/reference/model_package.md +++ b/cmd/cli/docs/reference/model_package.md @@ -1,9 +1,10 @@ # docker model package -Package a GGUF file, Safetensors directory, or existing model into a Docker model OCI artifact, with optional licenses and multimodal projector. The package is sent to the model-runner, unless --push is specified. +Package a GGUF file, Safetensors directory, DDUF file, or existing model into a Docker model OCI artifact, with optional licenses and multimodal projector. The package is sent to the model-runner, unless --push is specified. When packaging a sharded GGUF model, --gguf should point to the first shard. All shard files should be siblings and should include the index in the file name (e.g. model-00001-of-00015.gguf). When packaging a Safetensors model, --safetensors-dir should point to a directory containing .safetensors files and config files (*.json, merges.txt). All files will be auto-discovered and config files will be packaged into a tar archive. +When packaging a DDUF file (Diffusers Unified Format), --dduf should point to a .dduf archive file. When packaging from an existing model using --from, you can modify properties like context size to create a variant of the original model. For multimodal models, use --mmproj to include a multimodal projector file. @@ -13,6 +14,7 @@ For multimodal models, use --mmproj to include a multimodal projector file. 
|:--------------------|:--------------|:--------|:---------------------------------------------------------------------------------------| | `--chat-template` | `string` | | absolute path to chat template file (must be Jinja format) | | `--context-size` | `uint64` | `0` | context size in tokens | +| `--dduf` | `string` | | absolute path to DDUF archive file (Diffusers Unified Format) | | `--dir-tar` | `stringArray` | | relative path to directory to package as tar (can be specified multiple times) | | `--from` | `string` | | reference to an existing model to repackage | | `--gguf` | `string` | | absolute path to gguf file | From cd090d5e94ee14c083263e0400f7b98984e3ae42 Mon Sep 17 00:00:00 2001 From: Ignasi Date: Thu, 15 Jan 2026 14:02:49 +0100 Subject: [PATCH 11/20] Update pkg/distribution/internal/bundle/unpack.go Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com> --- pkg/distribution/internal/bundle/unpack.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/distribution/internal/bundle/unpack.go b/pkg/distribution/internal/bundle/unpack.go index 5f56c17ec..86e39eaf7 100644 --- a/pkg/distribution/internal/bundle/unpack.go +++ b/pkg/distribution/internal/bundle/unpack.go @@ -43,7 +43,7 @@ func Unpack(dir string, model types.Model) (*Bundle, error) { return nil, fmt.Errorf("unpack DDUF file: %w", err) } default: - return nil, fmt.Errorf("no supported model weights found (neither GGUF nor safetensors)") + return nil, fmt.Errorf("no supported model weights found (expected GGUF, safetensors, or diffusers/DDUF)") } // Unpack optional components based on their presence From 400f1a1b857ffa173e65a0422a80aa0257083df5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ignacio=20L=C3=B3pez=20Luna?= Date: Thu, 15 Jan 2026 15:26:05 +0100 Subject: [PATCH 12/20] refactor(dduf): replace formatDDUFSize with formatSize and clean up unused code --- pkg/distribution/format/dduf.go | 9 +-------- pkg/distribution/format/format.go | 13 +++++++++++++ pkg/distribution/format/safetensors.go | 13 ------------- 3 files changed, 14 insertions(+), 21 deletions(-) diff --git a/pkg/distribution/format/dduf.go b/pkg/distribution/format/dduf.go index f87e34d41..d4fe6e838 100644 --- a/pkg/distribution/format/dduf.go +++ b/pkg/distribution/format/dduf.go @@ -5,7 +5,6 @@ import ( "os" "path/filepath" - "github.com/docker/go-units" "github.com/docker/model-runner/pkg/distribution/oci" "github.com/docker/model-runner/pkg/distribution/types" ) @@ -62,16 +61,10 @@ func (d *DDUFFormat) ExtractConfig(paths []string) (types.Config, error) { return types.Config{ Format: types.FormatDiffusers, Architecture: "diffusers", - Size: formatDDUFSize(totalSize), + Size: formatSize(totalSize), Diffusers: map[string]string{ "layout": "dduf", "dduf_file": ddufFile, }, }, nil } - -// formatDDUFSize converts bytes to human-readable format matching Docker's style -// Returns format like "256MB" (decimal units, no space, matching `docker images`) -func formatDDUFSize(bytes int64) string { - return units.CustomSize("%.2f%s", float64(bytes), 1000.0, []string{"B", "kB", "MB", "GB", "TB", "PB", "EB"}) -} diff --git a/pkg/distribution/format/format.go b/pkg/distribution/format/format.go index dafd17e49..b905e57b0 100644 --- a/pkg/distribution/format/format.go +++ b/pkg/distribution/format/format.go @@ -6,6 +6,7 @@ package format import ( "fmt" + "github.com/docker/go-units" "github.com/docker/model-runner/pkg/distribution/files" "github.com/docker/model-runner/pkg/distribution/oci" 
"github.com/docker/model-runner/pkg/distribution/types" @@ -95,3 +96,15 @@ func DetectFromPaths(paths []string) (Format, error) { return format, nil } + +// formatParameters converts parameter count to human-readable format +// Returns format like "361.82M" or "1.5B" (no space before unit, base 1000, where B = Billion) +func formatParameters(params int64) string { + return units.CustomSize("%.2f%s", float64(params), 1000.0, []string{"", "K", "M", "B", "T"}) +} + +// formatSize converts bytes to human-readable format matching Docker's style +// Returns format like "256MB" (decimal units, no space, matching `docker images`) +func formatSize(bytes int64) string { + return units.CustomSize("%.2f%s", float64(bytes), 1000.0, []string{"B", "kB", "MB", "GB", "TB", "PB", "EB"}) +} diff --git a/pkg/distribution/format/safetensors.go b/pkg/distribution/format/safetensors.go index 231625770..fa6563799 100644 --- a/pkg/distribution/format/safetensors.go +++ b/pkg/distribution/format/safetensors.go @@ -11,7 +11,6 @@ import ( "sort" "strconv" - "github.com/docker/go-units" "github.com/docker/model-runner/pkg/distribution/oci" "github.com/docker/model-runner/pkg/distribution/types" ) @@ -291,15 +290,3 @@ func (h *safetensorsHeader) extractMetadata() map[string]string { return metadata } - -// formatParameters converts parameter count to human-readable format -// Returns format like "361.82M" or "1.5B" (no space before unit, base 1000, where B = Billion) -func formatParameters(params int64) string { - return units.CustomSize("%.2f%s", float64(params), 1000.0, []string{"", "K", "M", "B", "T"}) -} - -// formatSize converts bytes to human-readable format matching Docker's style -// Returns format like "256MB" (decimal units, no space, matching `docker images`) -func formatSize(bytes int64) string { - return units.CustomSize("%.2f%s", float64(bytes), 1000.0, []string{"B", "kB", "MB", "GB", "TB", "PB", "EB"}) -} From f62e4e5ba3a7cc15c65585a742e48f4e415d8ac3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ignacio=20L=C3=B3pez=20Luna?= Date: Thu, 15 Jan 2026 15:34:44 +0100 Subject: [PATCH 13/20] feat(docker): add support for building and running Diffusers Docker images --- Makefile | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/Makefile b/Makefile index 4578f8742..5d6ecb41d 100644 --- a/Makefile +++ b/Makefile @@ -8,6 +8,7 @@ VLLM_BASE_IMAGE := nvidia/cuda:13.0.2-runtime-ubuntu24.04 DOCKER_IMAGE := docker/model-runner:latest DOCKER_IMAGE_VLLM := docker/model-runner:latest-vllm-cuda DOCKER_IMAGE_SGLANG := docker/model-runner:latest-sglang +DOCKER_IMAGE_DIFFUSERS := docker/model-runner:latest-diffusers DOCKER_TARGET ?= final-llamacpp PORT := 8080 MODELS_PATH := $(shell pwd)/models-store @@ -32,7 +33,7 @@ LICENSE ?= BUILD_DMR ?= 1 # Main targets -.PHONY: build run clean test integration-tests test-docker-ce-installation docker-build docker-build-multiplatform docker-run docker-build-vllm docker-run-vllm docker-build-sglang docker-run-sglang docker-run-impl help validate lint model-distribution-tool +.PHONY: build run clean test integration-tests test-docker-ce-installation docker-build docker-build-multiplatform docker-run docker-build-vllm docker-run-vllm docker-build-sglang docker-run-sglang docker-build-diffusers docker-run-diffusers docker-run-impl help validate lint model-distribution-tool # Default target .DEFAULT_GOAL := build @@ -129,6 +130,16 @@ docker-build-sglang: docker-run-sglang: docker-build-sglang @$(MAKE) -s docker-run-impl DOCKER_IMAGE=$(DOCKER_IMAGE_SGLANG) +# Build 
Diffusers Docker image +docker-build-diffusers: + @$(MAKE) docker-build \ + DOCKER_TARGET=final-diffusers \ + DOCKER_IMAGE=$(DOCKER_IMAGE_DIFFUSERS) + +# Run Diffusers Docker container with TCP port access and mounted model storage +docker-run-diffusers: docker-build-diffusers + @$(MAKE) -s docker-run-impl DOCKER_IMAGE=$(DOCKER_IMAGE_DIFFUSERS) + # Common implementation for running Docker container docker-run-impl: @echo "" @@ -193,6 +204,8 @@ help: @echo " docker-run-vllm - Run vLLM Docker container" @echo " docker-build-sglang - Build SGLang Docker image" @echo " docker-run-sglang - Run SGLang Docker container" + @echo " docker-build-diffusers - Build Diffusers Docker image" + @echo " docker-run-diffusers - Run Diffusers Docker container" @echo " help - Show this help message" @echo "" @echo "Model distribution tool targets:" From a4f76e99aa0fdd8ea07cd0bd9fc14523b4df0344 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ignacio=20L=C3=B3pez=20Luna?= Date: Thu, 15 Jan 2026 15:40:48 +0100 Subject: [PATCH 14/20] feat(client): add support for Diffusers format in GetSupportedFormats function --- pkg/distribution/distribution/client.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pkg/distribution/distribution/client.go b/pkg/distribution/distribution/client.go index d153e96dd..e1815d151 100644 --- a/pkg/distribution/distribution/client.go +++ b/pkg/distribution/distribution/client.go @@ -638,9 +638,9 @@ func (c *Client) GetBundle(ref string) (types.ModelBundle, error) { func GetSupportedFormats() []types.Format { if platform.SupportsVLLM() { - return []types.Format{types.FormatGGUF, types.FormatSafetensors} + return []types.Format{types.FormatGGUF, types.FormatSafetensors, types.FormatDiffusers} } - return []types.Format{types.FormatGGUF} + return []types.Format{types.FormatGGUF, types.FormatDiffusers} } func checkCompat(image types.ModelArtifact, log *logrus.Entry, reference string, progressWriter io.Writer) error { From 237b9ddfb5ac974c02bea1166e36d3d272a3f2b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ignacio=20L=C3=B3pez=20Luna?= Date: Thu, 15 Jan 2026 15:49:00 +0100 Subject: [PATCH 15/20] feat(docker): enhance GPU support for additional Docker image variants --- scripts/docker-run.sh | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/scripts/docker-run.sh b/scripts/docker-run.sh index a8b064b9d..4d0faa046 100755 --- a/scripts/docker-run.sh +++ b/scripts/docker-run.sh @@ -1,8 +1,10 @@ #!/bin/bash add_accelerators() { - # Add NVIDIA GPU support for CUDA variants - if [[ "${DOCKER_IMAGE-}" == *"-cuda" ]]; then + # Add NVIDIA GPU support for CUDA variants and GPU-accelerated backends + if [[ "${DOCKER_IMAGE-}" == *"-cuda" ]] || \ + [[ "${DOCKER_IMAGE-}" == *"-diffusers" ]] || \ + [[ "${DOCKER_IMAGE-}" == *"-sglang" ]]; then args+=("--gpus" "all" "--runtime=nvidia") fi @@ -79,4 +81,3 @@ main() { } main "$@" - From 200f84bd760ac74516236962fa2935bc6a95a193 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ignacio=20L=C3=B3pez=20Luna?= Date: Thu, 15 Jan 2026 19:23:13 +0100 Subject: [PATCH 16/20] feat: add support for image-generation mode in backend operations --- cmd/cli/commands/compose_test.go | 6 ++++++ cmd/cli/commands/configure_flags.go | 6 ++++-- cmd/cli/docs/reference/docker_model_configure.yaml | 3 ++- pkg/inference/scheduling/loader.go | 1 + 4 files changed, 13 insertions(+), 3 deletions(-) diff --git a/cmd/cli/commands/compose_test.go b/cmd/cli/commands/compose_test.go index d4a8b59ee..f0960503f 100644 --- a/cmd/cli/commands/compose_test.go +++ 
b/cmd/cli/commands/compose_test.go @@ -45,6 +45,12 @@ func TestParseBackendMode(t *testing.T) { expected: inference.BackendModeReranking, expectError: false, }, + { + name: "image-generation mode", + input: "image-generation", + expected: inference.BackendModeImageGeneration, + expectError: false, + }, { name: "invalid mode", input: "invalid", diff --git a/cmd/cli/commands/configure_flags.go b/cmd/cli/commands/configure_flags.go index 769916793..11e17cfdb 100644 --- a/cmd/cli/commands/configure_flags.go +++ b/cmd/cli/commands/configure_flags.go @@ -146,7 +146,7 @@ func (f *ConfigureFlags) RegisterFlags(cmd *cobra.Command) { cmd.Flags().StringVar(&f.HFOverrides, "hf_overrides", "", "HuggingFace model config overrides (JSON) - vLLM only") cmd.Flags().Var(NewFloat64PtrValue(&f.GPUMemoryUtilization), "gpu-memory-utilization", "fraction of GPU memory to use for the model executor (0.0-1.0) - vLLM only") cmd.Flags().Var(NewBoolPtrValue(&f.Think), "think", "enable reasoning mode for thinking models") - cmd.Flags().StringVar(&f.Mode, "mode", "", "backend operation mode (completion, embedding, reranking)") + cmd.Flags().StringVar(&f.Mode, "mode", "", "backend operation mode (completion, embedding, reranking, image-generation)") } // BuildConfigureRequest builds a scheduling.ConfigureRequest from the flags. @@ -243,7 +243,9 @@ func parseBackendMode(mode string) (inference.BackendMode, error) { return inference.BackendModeEmbedding, nil case "reranking": return inference.BackendModeReranking, nil + case "image-generation": + return inference.BackendModeImageGeneration, nil default: - return inference.BackendModeCompletion, fmt.Errorf("invalid mode %q: must be one of completion, embedding, reranking", mode) + return inference.BackendModeCompletion, fmt.Errorf("invalid mode %q: must be one of completion, embedding, reranking, image-generation", mode) } } diff --git a/cmd/cli/docs/reference/docker_model_configure.yaml b/cmd/cli/docs/reference/docker_model_configure.yaml index ce7ac0158..a8bfb16a2 100644 --- a/cmd/cli/docs/reference/docker_model_configure.yaml +++ b/cmd/cli/docs/reference/docker_model_configure.yaml @@ -40,7 +40,8 @@ options: swarm: false - option: mode value_type: string - description: backend operation mode (completion, embedding, reranking) + description: | + backend operation mode (completion, embedding, reranking, image-generation) deprecated: false hidden: false experimental: false diff --git a/pkg/inference/scheduling/loader.go b/pkg/inference/scheduling/loader.go index 4a40ee095..538027d90 100644 --- a/pkg/inference/scheduling/loader.go +++ b/pkg/inference/scheduling/loader.go @@ -288,6 +288,7 @@ func (l *loader) Unload(ctx context.Context, unload UnloadRequest) int { l.evictRunner(unload.Backend, modelID, inference.BackendModeCompletion) l.evictRunner(unload.Backend, modelID, inference.BackendModeEmbedding) l.evictRunner(unload.Backend, modelID, inference.BackendModeReranking) + l.evictRunner(unload.Backend, modelID, inference.BackendModeImageGeneration) } return len(l.runners) } From 3562f8ae1240d99b4d48098f97da73b33910541c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ignacio=20L=C3=B3pez=20Luna?= Date: Fri, 16 Jan 2026 13:28:49 +0100 Subject: [PATCH 17/20] feat(loader): support fallback for image-generation mode in runner config --- pkg/inference/scheduling/loader.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pkg/inference/scheduling/loader.go b/pkg/inference/scheduling/loader.go index 538027d90..ddfe582fd 100644 --- a/pkg/inference/scheduling/loader.go 
+++ b/pkg/inference/scheduling/loader.go @@ -426,8 +426,8 @@ func (l *loader) load(ctx context.Context, backendName, modelID, modelRef string if runnerConfig.Speculative != nil && runnerConfig.Speculative.DraftModel != "" { draftModelID = l.modelManager.ResolveID(runnerConfig.Speculative.DraftModel) } - } else if mode == inference.BackendModeReranking { - // For reranking mode, fallback to completion config if specific config is not found. + } else if (mode == inference.BackendModeReranking) || (mode == inference.BackendModeImageGeneration) { + // For reranking or image-generation mode, fallback to completion config if specific config is not found. if rc, ok := l.runnerConfigs[makeConfigKey(backendName, modelID, inference.BackendModeCompletion)]; ok { runnerConfig = &rc if runnerConfig.Speculative != nil && runnerConfig.Speculative.DraftModel != "" { From cf31231091c171a790a54fe147c294d0ae129f58 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ignacio=20L=C3=B3pez=20Luna?= Date: Fri, 16 Jan 2026 13:31:09 +0100 Subject: [PATCH 18/20] feat(diffusers): initialize Diffusers backend in main.go --- diffusers_backend.go | 26 -------------------------- diffusers_backend_stub.go | 17 ----------------- main.go | 18 +++++++++++++----- 3 files changed, 13 insertions(+), 48 deletions(-) delete mode 100644 diffusers_backend.go delete mode 100644 diffusers_backend_stub.go diff --git a/diffusers_backend.go b/diffusers_backend.go deleted file mode 100644 index eb5161ff4..000000000 --- a/diffusers_backend.go +++ /dev/null @@ -1,26 +0,0 @@ -//go:build !nodiffusers - -package main - -import ( - "github.com/docker/model-runner/pkg/inference" - "github.com/docker/model-runner/pkg/inference/backends/diffusers" - "github.com/docker/model-runner/pkg/inference/models" - "github.com/sirupsen/logrus" -) - -func initDiffusersBackend(log *logrus.Logger, modelManager *models.Manager, customPythonPath string) (inference.Backend, error) { - return diffusers.New( - log, - modelManager, - log.WithFields(logrus.Fields{"component": diffusers.Name}), - nil, - customPythonPath, - ) -} - -func registerDiffusersBackend(backends map[string]inference.Backend, backend inference.Backend) { - if backend != nil { - backends[diffusers.Name] = backend - } -} diff --git a/diffusers_backend_stub.go b/diffusers_backend_stub.go deleted file mode 100644 index 62b25fa0d..000000000 --- a/diffusers_backend_stub.go +++ /dev/null @@ -1,17 +0,0 @@ -//go:build nodiffusers - -package main - -import ( - "github.com/docker/model-runner/pkg/inference" - "github.com/docker/model-runner/pkg/inference/models" - "github.com/sirupsen/logrus" -) - -func initDiffusersBackend(log *logrus.Logger, modelManager *models.Manager, customPythonPath string) (inference.Backend, error) { - return nil, nil // Diffusers backend is disabled -} - -func registerDiffusersBackend(backends map[string]inference.Backend, backend inference.Backend) { - // Diffusers backend is disabled, do nothing -} diff --git a/main.go b/main.go index 31e91cd91..17826d415 100644 --- a/main.go +++ b/main.go @@ -13,6 +13,7 @@ import ( "github.com/docker/model-runner/pkg/anthropic" "github.com/docker/model-runner/pkg/inference" + "github.com/docker/model-runner/pkg/inference/backends/diffusers" "github.com/docker/model-runner/pkg/inference/backends/llamacpp" "github.com/docker/model-runner/pkg/inference/backends/mlx" "github.com/docker/model-runner/pkg/inference/backends/sglang" @@ -157,18 +158,25 @@ func main() { log.Fatalf("unable to initialize %s backend: %v", sglang.Name, err) } - diffusersBackend, err 
:= initDiffusersBackend(log, modelManager, diffusersServerPath) + diffusersBackend, err := diffusers.New( + log, + modelManager, + log.WithFields(logrus.Fields{"component": diffusers.Name}), + nil, + diffusersServerPath, + ) + if err != nil { log.Fatalf("unable to initialize diffusers backend: %v", err) } backends := map[string]inference.Backend{ - llamacpp.Name: llamaCppBackend, - mlx.Name: mlxBackend, - sglang.Name: sglangBackend, + llamacpp.Name: llamaCppBackend, + mlx.Name: mlxBackend, + sglang.Name: sglangBackend, + diffusers.Name: diffusersBackend, } registerVLLMBackend(backends, vllmBackend) - registerDiffusersBackend(backends, diffusersBackend) scheduler := scheduling.NewScheduler( log, From 587cdadcea2c579de25fa67e0d4a4bbd61d8459b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ignacio=20L=C3=B3pez=20Luna?= Date: Fri, 16 Jan 2026 15:03:38 +0100 Subject: [PATCH 19/20] feat(diffusers): add error transformation for Python output and enhance backend error handling --- Dockerfile | 2 +- pkg/inference/backends/diffusers/diffusers.go | 17 ++-- pkg/inference/backends/diffusers/errors.go | 65 +++++++++++++ .../backends/diffusers/errors_test.go | 97 +++++++++++++++++++ pkg/inference/backends/runner.go | 15 ++- python/diffusers_server/server.py | 28 ++++-- 6 files changed, 208 insertions(+), 16 deletions(-) create mode 100644 pkg/inference/backends/diffusers/errors.go create mode 100644 pkg/inference/backends/diffusers/errors_test.go diff --git a/Dockerfile b/Dockerfile index 4be350728..e6738a693 100644 --- a/Dockerfile +++ b/Dockerfile @@ -155,7 +155,7 @@ ARG TRANSFORMERS_VERSION=4.57.5 ARG ACCELERATE_VERSION=1.3.0 ARG SAFETENSORS_VERSION=0.5.2 ARG HUGGINGFACE_HUB_VERSION=0.34.0 -ARG BITSANDBYTES_VERSION=0.45.4 +ARG BITSANDBYTES_VERSION=0.49.1 ARG FASTAPI_VERSION=0.115.12 ARG UVICORN_VERSION=0.34.1 ARG PILLOW_VERSION=11.2.1 diff --git a/pkg/inference/backends/diffusers/diffusers.go b/pkg/inference/backends/diffusers/diffusers.go index 1eb0a315e..a966c6678 100644 --- a/pkg/inference/backends/diffusers/diffusers.go +++ b/pkg/inference/backends/diffusers/diffusers.go @@ -180,14 +180,15 @@ func (d *diffusers) Run(ctx context.Context, socket, model string, modelRef stri } return backends.RunBackend(ctx, backends.RunnerConfig{ - BackendName: "Diffusers", - Socket: socket, - BinaryPath: d.pythonPath, - SandboxPath: sandboxPath, - SandboxConfig: "", - Args: args, - Logger: d.log, - ServerLogWriter: d.serverLog.Writer(), + BackendName: "Diffusers", + Socket: socket, + BinaryPath: d.pythonPath, + SandboxPath: sandboxPath, + SandboxConfig: "", + Args: args, + Logger: d.log, + ServerLogWriter: d.serverLog.Writer(), + ErrorTransformer: ExtractPythonError, }) } diff --git a/pkg/inference/backends/diffusers/errors.go b/pkg/inference/backends/diffusers/errors.go new file mode 100644 index 000000000..e3bb68149 --- /dev/null +++ b/pkg/inference/backends/diffusers/errors.go @@ -0,0 +1,65 @@ +package diffusers + +import ( + "fmt" + "regexp" + "strings" +) + +// pythonErrorPatterns contains regex patterns to extract meaningful error messages +// from Python tracebacks. The patterns are tried in order, and the first match wins. +var pythonErrorPatterns = []*regexp.Regexp{ + // Custom error marker from our Python server (highest priority) + regexp.MustCompile(`(?m)^DIFFUSERS_ERROR:\s*(.+)$`), + // Python RuntimeError, ValueError, etc. 
+ regexp.MustCompile(`(?m)^(RuntimeError|ValueError|TypeError|OSError|ImportError|ModuleNotFoundError):\s*(.+)$`), + // CUDA/GPU related errors + regexp.MustCompile(`(?mi)(CUDA|GPU|out of memory|OOM|No GPU found)[^.]*\.?`), + // Generic Python Exception with message + regexp.MustCompile(`(?m)^(\w+Error):\s*(.+)$`), +} + +// ExtractPythonError attempts to extract a meaningful error message from Python output. +// It looks for common error patterns and returns a cleaner, more user-friendly message. +// If no recognizable pattern is found, it returns the original output. +func ExtractPythonError(output string) string { + // Try each pattern in order + for i, pattern := range pythonErrorPatterns { + matches := pattern.FindStringSubmatch(output) + if len(matches) > 0 { + switch i { + case 0: + // Custom error marker: return just the message + return strings.TrimSpace(matches[1]) + case 1: + // Standard Python errors: "ErrorType: message" + return fmt.Sprintf("%s: %s", matches[1], strings.TrimSpace(matches[2])) + case 2: + // GPU/CUDA related errors + return strings.TrimSpace(matches[0]) + case 3: + // Generic Python errors + return fmt.Sprintf("%s: %s", matches[1], strings.TrimSpace(matches[2])) + } + } + } + + // No pattern matched - return original but try to trim some noise + // Take only the last few meaningful lines + lines := strings.Split(strings.TrimSpace(output), "\n") + if len(lines) > 5 { + // Return the last 5 non-empty, non-indented lines (check indentation on the raw line, before trimming) + var meaningful []string + for i := len(lines) - 1; i >= 0 && len(meaningful) < 5; i-- { + line := strings.TrimSpace(lines[i]) + if line != "" && !strings.HasPrefix(lines[i], " ") { + meaningful = append([]string{line}, meaningful...) + } + } + if len(meaningful) > 0 { + return strings.Join(meaningful, "\n") + } + } + + return output +} diff --git a/pkg/inference/backends/diffusers/errors_test.go b/pkg/inference/backends/diffusers/errors_test.go new file mode 100644 index 000000000..69f84d0ff --- /dev/null +++ b/pkg/inference/backends/diffusers/errors_test.go @@ -0,0 +1,97 @@ +package diffusers + +import ( + "testing" +) + +func TestExtractPythonError(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "custom diffusers error marker", + input: "DIFFUSERS_ERROR: No GPU found. A GPU is needed for quantization.", + expected: "No GPU found. A GPU is needed for quantization.", + }, + { + name: "custom error marker in traceback", + input: `Traceback (most recent call last): + File "server.py", line 350, in main + load_model(args.model_path) +RuntimeError: Failed to load DDUF file: No GPU found +DIFFUSERS_ERROR: No GPU found. A GPU is needed for quantization.`, + expected: "No GPU found. A GPU is needed for quantization.", + }, + { + name: "python runtime error", + input: `RuntimeError: Failed to load DDUF file: No GPU found. A GPU is needed for quantization. +RuntimeError: No GPU found. A GPU is needed for quantization.`, + expected: "RuntimeError: Failed to load DDUF file: No GPU found. A GPU is needed for quantization.", + }, + { + name: "full python traceback", + input: ` raise RuntimeError(f"Failed to load DDUF file: {e}") +RuntimeError: Failed to load DDUF file: No GPU found. A GPU is needed for quantization. +RuntimeError: No GPU found. A GPU is needed for quantization.
+ +During handling of the above exception, another exception occurred: + +Traceback (most recent call last): + File "<frozen runpy>", line 198, in _run_module_as_main + File "<frozen runpy>", line 88, in _run_code + File "/opt/diffusers-env/lib/python3.12/site-packages/diffusers_server/server.py", line 358, in <module> + main() + File "/opt/diffusers-env/lib/python3.12/site-packages/diffusers_server/server.py", line 350, in main + load_model(args.model_path) + File "/opt/diffusers-env/lib/python3.12/site-packages/diffusers_server/server.py", line 139, in load_model + pipeline = load_model_from_dduf(model_path, device, dtype)`, + expected: "RuntimeError: Failed to load DDUF file: No GPU found. A GPU is needed for quantization.", + }, + { + name: "GPU not found error", + input: "Some log output\nNo GPU found. A GPU is needed for quantization.\nMore logs", + expected: "No GPU found.", + }, + { + name: "CUDA out of memory error", + input: "CUDA out of memory. Tried to allocate 2.00 GiB", + expected: "CUDA out of memory.", + }, + { + name: "import error", + input: "ImportError: No module named 'torch'", + expected: "ImportError: No module named 'torch'", + }, + { + name: "module not found error", + input: "ModuleNotFoundError: No module named 'diffusers'", + expected: "ModuleNotFoundError: No module named 'diffusers'", + }, + { + name: "value error", + input: "ValueError: Invalid model path", + expected: "ValueError: Invalid model path", + }, + { + name: "short output without pattern", + input: "some random error", + expected: "some random error", + }, + { + name: "empty output", + input: "", + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ExtractPythonError(tt.input) + if result != tt.expected { + t.Errorf("ExtractPythonError() = %q, want %q", result, tt.expected) + } + }) + } +} diff --git a/pkg/inference/backends/runner.go b/pkg/inference/backends/runner.go index 244fa1c7b..186857340 100644 --- a/pkg/inference/backends/runner.go +++ b/pkg/inference/backends/runner.go @@ -16,6 +16,11 @@ import ( "github.com/docker/model-runner/pkg/tailbuffer" ) +// ErrorTransformer is a function that transforms raw error output +// into a more user-friendly message. Backends can provide their own +// implementation to customize error presentation. +type ErrorTransformer func(output string) string + // RunnerConfig holds configuration for a backend runner type RunnerConfig struct { // BackendName is the display name of the backend (e.g., "llama.cpp", "vLLM") @@ -34,6 +39,9 @@ type RunnerConfig struct { Logger Logger // ServerLogWriter provides a writer for server logs ServerLogWriter io.WriteCloser + // ErrorTransformer is an optional function to transform error output + // into a more user-friendly message. If nil, the raw output is used.
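+ // The diffusers backend wires this to ExtractPythonError (see errors.go in this patch).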
+ ErrorTransformer ErrorTransformer } // Logger interface for backend logging @@ -103,7 +111,12 @@ func RunBackend(ctx context.Context, config RunnerConfig) error { } if errOutput.String() != "" { - backendErr = fmt.Errorf("%s exit status: %w\nwith output: %s", config.BackendName, backendErr, errOutput.String()) + errorMsg := errOutput.String() + // Apply error transformer if provided + if config.ErrorTransformer != nil { + errorMsg = config.ErrorTransformer(errorMsg) + } + backendErr = fmt.Errorf("%s failed: %s", config.BackendName, errorMsg) } else { backendErr = fmt.Errorf("%s exit status: %w", config.BackendName, backendErr) } diff --git a/python/diffusers_server/server.py b/python/diffusers_server/server.py index 3897390a4..1db699b24 100644 --- a/python/diffusers_server/server.py +++ b/python/diffusers_server/server.py @@ -346,12 +346,28 @@ def main(): global served_model_name served_model_name = args.served_model_name or args.model_path - # Load the model at startup - load_model(args.model_path) - - # Start the server - logger.info(f"Starting server on {args.host}:{args.port}") - uvicorn.run(app, host=args.host, port=args.port, log_level="info") + try: + # Load the model at startup + load_model(args.model_path) + + # Start the server + logger.info(f"Starting server on {args.host}:{args.port}") + uvicorn.run(app, host=args.host, port=args.port, log_level="info") + except Exception as e: + # Extract the root cause error message for cleaner output + error_msg = str(e) + # If this is a chained exception, try to get the original cause + root_cause = e + while root_cause.__cause__ is not None: + root_cause = root_cause.__cause__ + if root_cause is not e: + error_msg = str(root_cause) + + # Print a clean, single-line error message that can be easily parsed + # This format is recognized by the Go backend for better error reporting + import sys + print(f"DIFFUSERS_ERROR: {error_msg}", file=sys.stderr) + sys.exit(1) if __name__ == "__main__": From 78d8bcd0aaad396933bfc5da271309e46433c361 Mon Sep 17 00:00:00 2001 From: Dorin Geman Date: Fri, 16 Jan 2026 16:57:23 +0200 Subject: [PATCH 20/20] fix(scripts/docker-run): conditionally add nvidia runtime flags Only add --gpus and --runtime=nvidia when the nvidia runtime is detected, allowing diffusers/sglang images to run on non-NVIDIA hosts. Signed-off-by: Dorin Geman --- scripts/docker-run.sh | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/scripts/docker-run.sh b/scripts/docker-run.sh index 4d0faa046..59fbaea69 100755 --- a/scripts/docker-run.sh +++ b/scripts/docker-run.sh @@ -5,7 +5,9 @@ add_accelerators() { if [[ "${DOCKER_IMAGE-}" == *"-cuda" ]] || \ [[ "${DOCKER_IMAGE-}" == *"-diffusers" ]] || \ [[ "${DOCKER_IMAGE-}" == *"-sglang" ]]; then - args+=("--gpus" "all" "--runtime=nvidia") + if docker info -f '{{range $k, $v := .Runtimes}}{{$k}}{{"\n"}}{{end}}' 2>/dev/null | grep -qx "nvidia"; then + args+=("--gpus" "all" "--runtime=nvidia") + fi fi # Add GPU/accelerator devices if present