From d8e7c45768cee0570e0e02ffc1dc742002894241 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Mon, 16 Mar 2026 17:57:33 +0000 Subject: [PATCH 01/64] new serve file --- src/transformers/cli/serve_refactored.py | 149 +++++++++++++++++++++++ 1 file changed, 149 insertions(+) create mode 100644 src/transformers/cli/serve_refactored.py diff --git a/src/transformers/cli/serve_refactored.py b/src/transformers/cli/serve_refactored.py new file mode 100644 index 000000000000..199b46174b12 --- /dev/null +++ b/src/transformers/cli/serve_refactored.py @@ -0,0 +1,149 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +CLI entry point for `transformers serve`. +""" + +from __future__ import annotations + +import asyncio +import threading +from typing import Annotated + +import typer + +from transformers.utils import logging +from transformers.utils.import_utils import ( + is_fastapi_available, + is_openai_available, + is_pydantic_available, + is_uvicorn_available, +) + +from .serving.protocol import set_torch_seed + + +serve_dependencies_available = ( + is_pydantic_available() and is_fastapi_available() and is_uvicorn_available() and is_openai_available() +) + +logger = logging.get_logger(__name__) + + +class Serve: + def __init__( + self, + # TODO: maybe rename it to model ? + force_model: Annotated[ + str | None, typer.Argument(help="Model to preload and use for all requests.") + ] = None, + # Model options + device: Annotated[str, typer.Option(help="Device for inference; defaults to 'auto'.")] = "auto", + dtype: Annotated[str | None, typer.Option(help="Override model dtype. 'auto' derives from weights.")] = "auto", + attn_implementation: Annotated[ + str | None, typer.Option(help="Attention implementation (e.g. flash_attention_2).") + ] = None, + quantization: Annotated[ + str | None, typer.Option(help="Quantization method: 'bnb-4bit' or 'bnb-8bit'.") + ] = None, + trust_remote_code: Annotated[bool, typer.Option(help="Trust remote code when loading.")] = False, + model_timeout: Annotated[ + int, typer.Option(help="Seconds before idle model is unloaded. Ignored when model is set.") + ] = 300, + # Server options + host: Annotated[str, typer.Option(help="Server listen address.")] = "localhost", + port: Annotated[int, typer.Option(help="Server listen port.")] = 8000, + enable_cors: Annotated[bool, typer.Option(help="Enable permissive CORS.")] = False, + log_level: Annotated[str, typer.Option(help="Logging level (e.g. 'info', 'warning').")] = "info", + default_seed: Annotated[int | None, typer.Option(help="Default torch seed.")] = None, + non_blocking: Annotated[ + bool, typer.Option(hidden=True, help="Run server in a background thread. Used by tests.") + ] = False, + ) -> None: + if not serve_dependencies_available: + raise ImportError( + "Missing dependencies for serving. 
Install with `pip install transformers[serving]`" + ) + + import uvicorn + + from .serving.app import build_app + from .serving.handlers.chat_completion import ChatCompletionHandler + from .serving.model_manager import ModelManager + + # Seed + if default_seed is not None: + set_torch_seed(default_seed) + + # Logging + transformers_logger = logging.get_logger("transformers") + transformers_logger.setLevel(logging.log_levels[log_level.lower()]) + + # Preloaded models should never be auto-unloaded + if force_model: + model_timeout = -1 + + model_manager = ModelManager( + device=device, + dtype=dtype, + trust_remote_code=trust_remote_code, + attn_implementation=attn_implementation, + quantization=quantization, + model_timeout=model_timeout, + force_model=force_model, + ) + + chat_handler = ChatCompletionHandler( + model_manager=model_manager, + force_model=force_model, + ) + + app = build_app(model_manager, chat_handler, enable_cors=enable_cors) + + config = uvicorn.Config(app, host=host, port=port, log_level=log_level) + self.server = uvicorn.Server(config) + + if non_blocking: + self.start_server() + else: + self.server.run() + + def start_server(self): + def _run(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(self.server.serve()) + + self._thread = threading.Thread(target=_run, name="uvicorn-thread", daemon=False) + self._thread.start() + + def kill_server(self): + if not self._thread or not self._thread.is_alive(): + return + self.server.should_exit = True + self._thread.join(timeout=2) + + +Serve.__doc__ = """ +Run a FastAPI server to serve models on-demand with an OpenAI compatible API. +Models will be loaded and unloaded automatically based on usage and a timeout. + +\b +Endpoints: + POST /v1/chat/completions — Chat completions (streaming + non-streaming). + GET /v1/models — Lists available models. + GET /health — Health check. + +Requires FastAPI and Uvicorn: pip install transformers[serving] +""" From f2388673e9adfc70b0e207a41bc3ebdc448045ac Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Mon, 16 Mar 2026 18:33:15 +0000 Subject: [PATCH 02/64] app --- src/transformers/cli/serving/app.py | 95 +++++++++++++++++++++++++++++ 1 file changed, 95 insertions(+) create mode 100644 src/transformers/cli/serving/app.py diff --git a/src/transformers/cli/serving/app.py b/src/transformers/cli/serving/app.py new file mode 100644 index 000000000000..a519fe8ad54d --- /dev/null +++ b/src/transformers/cli/serving/app.py @@ -0,0 +1,95 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +FastAPI app factory. 
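+
+Typical wiring (illustrative sketch; assumes a configured ModelManager and
+ChatCompletionHandler, mirroring how `transformers serve` builds the server):
+
+    app = build_app(model_manager, chat_handler, enable_cors=False)
+    config = uvicorn.Config(app, host="localhost", port=8000)
+    uvicorn.Server(config).run()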
+""" + +from __future__ import annotations + +import uuid +from contextlib import asynccontextmanager + +from fastapi import FastAPI, Request +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse + +from ...utils import logging +from .handlers.chat_completion import ChatCompletionHandler +from .model_manager import ModelManager +from .protocol import X_REQUEST_ID + + +logger = logging.get_logger(__name__) + + +def build_app( + model_manager: ModelManager, + chat_handler: ChatCompletionHandler, + enable_cors: bool = False, +) -> FastAPI: + """Build and return a configured FastAPI application. + + Args: + model_manager: Handles model loading, caching, and cleanup. + chat_handler: Handles `/v1/chat/completions` requests. + enable_cors: If `True`, adds permissive CORS middleware (allow all origins). + + Returns: + A FastAPI app ready to be passed to uvicorn. + """ + + @asynccontextmanager + async def lifespan(app: FastAPI): + yield + model_manager.shutdown() + + app = FastAPI(lifespan=lifespan) + + if enable_cors: + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + logger.warning_once("CORS allow origin is set to `*`. Not recommended for production.") + + # ---- Middleware ---- + + @app.middleware("http") + async def request_id_middleware(request: Request, call_next): + """Get or set the request ID in the header.""" + request_id = request.headers.get(X_REQUEST_ID) or str(uuid.uuid4()) + request.state.request_id = request_id + response = await call_next(request) + response.headers[X_REQUEST_ID] = request_id + return response + + # ---- Routes ---- + + @app.post("/v1/chat/completions") + async def chat_completions(request: Request, body: dict): + return chat_handler.handle_request(body, request.state.request_id) + + @app.get("/v1/models") + @app.options("/v1/models") + def list_models(): + return JSONResponse({"object": "list", "data": model_manager.get_gen_models()}) + + @app.get("/health") + def health(): + return JSONResponse({"status": "ok"}) + + return app From be0291d743e5d8211737e546f6cf6dc561dd513d Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Mon, 16 Mar 2026 19:45:47 +0000 Subject: [PATCH 03/64] model_manager done --- src/transformers/cli/serving/model_manager.py | 286 ++++++++++++++++++ 1 file changed, 286 insertions(+) create mode 100644 src/transformers/cli/serving/model_manager.py diff --git a/src/transformers/cli/serving/model_manager.py b/src/transformers/cli/serving/model_manager.py new file mode 100644 index 000000000000..52fb8180884a --- /dev/null +++ b/src/transformers/cli/serving/model_manager.py @@ -0,0 +1,286 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Model loading, caching, and lifecycle management. 
+""" + +from __future__ import annotations + +import gc +import json +import threading +from functools import lru_cache +from typing import TYPE_CHECKING + +from huggingface_hub import scan_cache_dir +from tqdm import tqdm + +import transformers +from transformers import AutoTokenizer, BitsAndBytesConfig, PreTrainedTokenizerBase + +from ...utils import logging +from .protocol import Modality, reset_torch_cache + + +if TYPE_CHECKING: + from transformers import PreTrainedModel, PreTrainedTokenizerFast, ProcessorMixin + + +logger = logging.get_logger(__name__) + + +class TimedModel: + """Wraps a model + processor and auto-deletes them after a period of inactivity. + + Args: + model: The loaded model. + timeout_seconds: Seconds of inactivity before auto-deletion. Use -1 to disable. + processor: The associated processor or tokenizer. + """ + + def __init__( + self, + model: PreTrainedModel, + timeout_seconds: int, + processor: ProcessorMixin | PreTrainedTokenizerFast | None = None, + ): + self.model = model + self._name_or_path = str(model.name_or_path) + self.processor = processor + self.timeout_seconds = timeout_seconds + self._timer = threading.Timer(self.timeout_seconds, self._timeout_reached) + self._timer.start() + + def reset_timer(self) -> None: + """Reset the inactivity timer (called on each request).""" + self._timer.cancel() + self._timer = threading.Timer(self.timeout_seconds, self._timeout_reached) + self._timer.start() + + def delete_model(self) -> None: + """Delete the model and processor, free GPU memory.""" + if hasattr(self, "model") and self.model is not None: + del self.model + del self.processor + self.model = None + self.processor = None + gc.collect() + reset_torch_cache() + self._timer.cancel() + + def _timeout_reached(self) -> None: + if self.timeout_seconds > 0: + self.delete_model() + logger.info(f"{self._name_or_path} was removed from memory after {self.timeout_seconds}s of inactivity") + + def is_deleted(self) -> bool: + """Check if the model has been deleted (by timeout or manually).""" + return not hasattr(self, "model") or self.model is None + + +class ModelManager: + """Loads, caches, and manages the lifecycle of models. + + Handlers receive a reference to this and call `load_model_and_processor()` + to get a model ready for inference. + + Args: + device: Device to place models on (e.g. "auto", "cuda", "cpu"). + dtype: Torch dtype override. "auto" derives from model weights. + trust_remote_code: Whether to trust remote code when loading models. + attn_implementation: Attention implementation override (e.g. "flash_attention_2"). + quantization: Quantization method ("bnb-4bit" or "bnb-8bit"). + model_timeout: Seconds before an idle model is unloaded. -1 disables. + force_model: If set, preload this model at init time. + processor_id: Override processor/tokenizer model ID. Needed for GGUF models + whose repos don't include tokenizer files. 
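+
+    Example (illustrative sketch; the model id is a placeholder):
+
+        manager = ModelManager(device="cpu", model_timeout=60)
+        model, processor = manager.load_model_and_processor(
+            ModelManager.process_model_name("Qwen/Qwen2.5-0.5B-Instruct")
+        )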
+ """ + + def __init__( + self, + device: str = "auto", + dtype: str | None = "auto", + trust_remote_code: bool = False, + attn_implementation: str | None = None, + quantization: str | None = None, + model_timeout: int = 300, + force_model: str | None = None, + # TODO: auto-detect from GGUF base_model metadata + processor_id: str | None = None, + ): + self.device = device + self.dtype = dtype + self.trust_remote_code = trust_remote_code + self.attn_implementation = attn_implementation + self.quantization = quantization + self.model_timeout = model_timeout + self.processor_id = processor_id + + self.loaded_models: dict[str, TimedModel] = {} + + if force_model is not None: + self.load_model_and_processor(self.process_model_name(force_model)) + + @staticmethod + def process_model_name(model_id: str) -> str: + """Canonicalize to `'model_id@revision'` format. Defaults to `@main`.""" + if "@" in model_id: + return model_id + return f"{model_id}@main" + + def get_quantization_config(self) -> BitsAndBytesConfig | None: + """Return a BitsAndBytesConfig based on the `quantization` setting, or None.""" + if self.quantization == "bnb-4bit": + return BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_use_double_quant=True, + ) + elif self.quantization == "bnb-8bit": + return BitsAndBytesConfig(load_in_8bit=True) + return None + + def _load_processor(self, model_id_and_revision: str) -> ProcessorMixin | PreTrainedTokenizerFast: + """Load a processor, trying AutoProcessor first then AutoTokenizer. + + If `processor_id` was set (e.g. for GGUF models), uses that instead of `model_id`. + Expects `model_id_and_revision` in `'model_id@revision'` format (from `process_model_name`). + """ + from transformers import AutoProcessor + + if self.processor_id: + model_id, revision = self.processor_id, "main" + else: + model_id, revision = model_id_and_revision.split("@", 1) + try: + return AutoProcessor.from_pretrained(model_id, revision=revision, trust_remote_code=self.trust_remote_code) + except OSError: + try: + return AutoTokenizer.from_pretrained(model_id, revision=revision, trust_remote_code=self.trust_remote_code) + except OSError: + raise OSError(f"Failed to load processor for {model_id} with AutoProcessor and AutoTokenizer.") + + def _load_model(self, model_id_and_revision: str) -> PreTrainedModel: + """Load a model. 
GGUF files are detected by the `.gguf` extension and loaded via llama.cpp.""" + import torch + + from transformers import AutoConfig + + model_id, revision = model_id_and_revision.split("@", 1) + + if model_id.endswith(".gguf"): + from llama_cpp_transformers import LlamaCppTransformersModel + + flash_attn = True if self.attn_implementation == "flash_attention_2" else "auto" + return LlamaCppTransformersModel.from_pretrained( + model_id, revision=revision, n_gpu_layers=-1, flash_attn=flash_attn, + ) + + dtype = self.dtype if self.dtype in ["auto", None] else getattr(torch, self.dtype) + model_kwargs = { + "revision": revision, + "attn_implementation": self.attn_implementation, + "dtype": dtype, + "device_map": self.device, + "trust_remote_code": self.trust_remote_code, + "quantization_config": self.get_quantization_config(), + } + + config = AutoConfig.from_pretrained(model_id, **model_kwargs) + architecture = getattr(transformers, config.architectures[0]) + return architecture.from_pretrained(model_id, **model_kwargs) + + def load_model_and_processor( + self, model_id_and_revision: str + ) -> tuple[PreTrainedModel, ProcessorMixin | PreTrainedTokenizerFast]: + """Load a model (or return it from cache), resetting its inactivity timer.""" + if model_id_and_revision not in self.loaded_models or self.loaded_models[model_id_and_revision].is_deleted(): + processor = self._load_processor(model_id_and_revision) + model = self._load_model(model_id_and_revision) + self.loaded_models[model_id_and_revision] = TimedModel( + model, timeout_seconds=self.model_timeout, processor=processor + ) + else: + self.loaded_models[model_id_and_revision].reset_timer() + model = self.loaded_models[model_id_and_revision].model + processor = self.loaded_models[model_id_and_revision].processor + return model, processor + + def shutdown(self) -> None: + """Delete all loaded models and free resources.""" + for timed in self.loaded_models.values(): + timed.delete_model() + self.loaded_models.clear() + + @staticmethod + def get_model_modality(model: PreTrainedModel, processor=None) -> Modality: + """Detect whether a model is an LLM or VLM based on its architecture.""" + if processor is not None and isinstance(processor, PreTrainedTokenizerBase): + return Modality.LLM + + from transformers.models.auto.modeling_auto import ( + MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, + MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES, + ) + + model_classname = model.__class__.__name__ + if model_classname in MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.values(): + return Modality.VLM + elif model_classname in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values(): + return Modality.LLM + else: + raise ValueError(f"Unknown modality for: {model_classname}") + + @staticmethod + @lru_cache + def get_gen_models(cache_dir: str | None = None) -> list[dict]: + """List generative models (LLMs and VLMs) available in the HuggingFace cache.""" + from transformers.models.auto.modeling_auto import ( + MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, + MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES, + ) + + generative_models = [] + logger.warning("Scanning the cache directory for LLMs and VLMs.") + + for repo in tqdm(scan_cache_dir(cache_dir).repos): + if repo.repo_type != "model": + continue + + for ref, revision_info in repo.refs.items(): + config_path = next((f.file_path for f in revision_info.files if f.file_name == "config.json"), None) + if not config_path: + continue + + config = json.loads(config_path.open().read()) + if not (isinstance(config, dict) and "architectures" in config): + continue 
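+            # Keep only repos whose config lists a generative text architecture
+            # (causal LM or image-text-to-text); anything else is skipped below.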
+ + architectures = config["architectures"] + llms = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values() + vlms = MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.values() + + if any(arch for arch in architectures if arch in [*llms, *vlms]): + author = repo.repo_id.split("/") if "/" in repo.repo_id else "" + repo_handle = repo.repo_id + (f"@{ref}" if ref != "main" else "") + generative_models.append( + { + "owned_by": author, + "id": repo_handle, + "object": "model", + "created": repo.last_modified, + } + ) + + return generative_models From e84d82ef36acc267a654a19ffbe2753235571aa0 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Tue, 17 Mar 2026 10:38:14 +0000 Subject: [PATCH 04/64] update serve --- src/transformers/cli/serve_refactored.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/src/transformers/cli/serve_refactored.py b/src/transformers/cli/serve_refactored.py index 199b46174b12..15015d0dfaae 100644 --- a/src/transformers/cli/serve_refactored.py +++ b/src/transformers/cli/serve_refactored.py @@ -31,7 +31,7 @@ is_uvicorn_available, ) -from .serving.protocol import set_torch_seed +from .serving.utils import set_torch_seed serve_dependencies_available = ( @@ -44,10 +44,8 @@ class Serve: def __init__( self, - # TODO: maybe rename it to model ? - force_model: Annotated[ - str | None, typer.Argument(help="Model to preload and use for all requests.") - ] = None, + # TODO: maybe rename it to model ? + force_model: Annotated[str | None, typer.Argument(help="Model to preload and use for all requests.")] = None, # Model options device: Annotated[str, typer.Option(help="Device for inference; defaults to 'auto'.")] = "auto", dtype: Annotated[str | None, typer.Option(help="Override model dtype. 'auto' derives from weights.")] = "auto", @@ -58,6 +56,10 @@ def __init__( str | None, typer.Option(help="Quantization method: 'bnb-4bit' or 'bnb-8bit'.") ] = None, trust_remote_code: Annotated[bool, typer.Option(help="Trust remote code when loading.")] = False, + # TODO: auto-detect processor from GGUF base_model metadata so this flag isn't needed + processor: Annotated[ + str | None, typer.Option(help="Processor/tokenizer model ID. Needed for GGUF models.") + ] = None, model_timeout: Annotated[ int, typer.Option(help="Seconds before idle model is unloaded. Ignored when model is set.") ] = 300, @@ -72,9 +74,7 @@ def __init__( ] = False, ) -> None: if not serve_dependencies_available: - raise ImportError( - "Missing dependencies for serving. Install with `pip install transformers[serving]`" - ) + raise ImportError("Missing dependencies for serving. 
Install with `pip install transformers[serving]`") import uvicorn @@ -102,6 +102,7 @@ def __init__( quantization=quantization, model_timeout=model_timeout, force_model=force_model, + processor_id=processor, ) chat_handler = ChatCompletionHandler( From fb77305240023af2b0002dc446cb56a9c622feb3 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Tue, 17 Mar 2026 10:38:21 +0000 Subject: [PATCH 05/64] style --- src/transformers/cli/serving/model_manager.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/cli/serving/model_manager.py b/src/transformers/cli/serving/model_manager.py index 52fb8180884a..5407a27b1c49 100644 --- a/src/transformers/cli/serving/model_manager.py +++ b/src/transformers/cli/serving/model_manager.py @@ -30,7 +30,8 @@ from transformers import AutoTokenizer, BitsAndBytesConfig, PreTrainedTokenizerBase from ...utils import logging -from .protocol import Modality, reset_torch_cache +from .protocol import Modality +from .utils import reset_torch_cache if TYPE_CHECKING: From d869d62a7d03cbf2723068298909735160f09b56 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Tue, 17 Mar 2026 12:54:47 +0000 Subject: [PATCH 06/64] poc done --- src/transformers/cli/serving/__init__.py | 17 + .../cli/serving/handlers/__init__.py | 0 src/transformers/cli/serving/handlers/base.py | 55 +++ .../cli/serving/handlers/chat_completion.py | 322 ++++++++++++++++++ src/transformers/cli/serving/model_manager.py | 13 +- src/transformers/cli/serving/protocol.py | 112 ++++++ src/transformers/cli/serving/utils.py | 38 +++ 7 files changed, 554 insertions(+), 3 deletions(-) create mode 100644 src/transformers/cli/serving/__init__.py create mode 100644 src/transformers/cli/serving/handlers/__init__.py create mode 100644 src/transformers/cli/serving/handlers/base.py create mode 100644 src/transformers/cli/serving/handlers/chat_completion.py create mode 100644 src/transformers/cli/serving/protocol.py create mode 100644 src/transformers/cli/serving/utils.py diff --git a/src/transformers/cli/serving/__init__.py b/src/transformers/cli/serving/__init__.py new file mode 100644 index 000000000000..3adfb1f2f665 --- /dev/null +++ b/src/transformers/cli/serving/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .app import build_app +from .model_manager import ModelManager +from .protocol import Modality diff --git a/src/transformers/cli/serving/handlers/__init__.py b/src/transformers/cli/serving/handlers/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/src/transformers/cli/serving/handlers/base.py b/src/transformers/cli/serving/handlers/base.py new file mode 100644 index 000000000000..17db5bb09171 --- /dev/null +++ b/src/transformers/cli/serving/handlers/base.py @@ -0,0 +1,55 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Base handler for endpoint handlers. +""" + +from __future__ import annotations + +from pydantic import BaseModel + +from transformers import GenerationConfig + +from ..model_manager import ModelManager + + +class BaseHandler: + """Base class for endpoint handlers. Stores the model manager.""" + + def __init__(self, model_manager: ModelManager): + self.model_manager = model_manager + + @staticmethod + def _apply_default_generation_config(generation_config: GenerationConfig) -> None: + """Apply sensible serving defaults. Many models ship with too few max_new_tokens.""" + if generation_config.max_new_tokens is None or generation_config.max_new_tokens < 1024: + generation_config.max_new_tokens = 1024 + + @staticmethod + def chunk_to_sse(chunk: BaseModel | str) -> str: + """Format a pydantic model or string as a Server-Sent Event. + + Serializes with `exclude_none=True` — some clients (e.g. Cursor) assume + that when a field exists in the JSON, it has data. + + Args: + chunk: A pydantic BaseModel (ChatCompletionChunk, Response event, etc.) + or a pre-formatted string (error paths). + + Returns: + An SSE-formatted string: `data: {json}\\n\\n` + """ + if isinstance(chunk, str): + return chunk if chunk.startswith("data: ") else f"data: {chunk}\n\n" + return f"data: {chunk.model_dump_json(exclude_none=True)}\n\n" diff --git a/src/transformers/cli/serving/handlers/chat_completion.py b/src/transformers/cli/serving/handlers/chat_completion.py new file mode 100644 index 000000000000..95e255e22f74 --- /dev/null +++ b/src/transformers/cli/serving/handlers/chat_completion.py @@ -0,0 +1,322 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Handler for the /v1/chat/completions endpoint. + +Supports streaming (SSE via DirectStreamer) and non-streaming (JSON) responses. 
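+
+Example client call (illustrative; the model id is a placeholder and the server is
+assumed to run on the default localhost:8000):
+
+    from openai import OpenAI
+
+    client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
+    stream = client.chat.completions.create(
+        model="Qwen/Qwen2.5-0.5B-Instruct",
+        messages=[{"role": "user", "content": "Hello"}],
+        stream=True,
+    )
+    for chunk in stream:
+        print(chunk.choices[0].delta.content or "", end="")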
+""" + +from __future__ import annotations + +import asyncio +import copy +import json +import time +from collections.abc import AsyncGenerator +from threading import Thread +from typing import TYPE_CHECKING + +import torch + + +if TYPE_CHECKING: + from transformers import PreTrainedTokenizerFast, ProcessorMixin +from fastapi import HTTPException +from fastapi.responses import JSONResponse, StreamingResponse +from openai.types.chat import ChatCompletion, ChatCompletionMessage +from openai.types.chat.chat_completion import Choice +from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, ChoiceDelta +from openai.types.chat.chat_completion_chunk import Choice as ChoiceChunk +from tokenizers.decoders import DecodeStream + +from transformers import GenerationConfig, PreTrainedModel +from transformers.generation.streamers import BaseStreamer + +from ....utils import logging +from ..model_manager import ModelManager +from ..protocol import ( + UNUSED_CHAT_COMPLETION_FIELDS, + TransformersCompletionCreateParamsStreaming, + get_processor_inputs_from_messages, +) +from ..utils import _StreamError, set_torch_seed +from .base import BaseHandler + + +logger = logging.get_logger(__name__) + + +class DirectStreamer(BaseStreamer): + """Streamer that decodes tokens incrementally and pushes text to an asyncio.Queue. + + Uses the Rust `DecodeStream.step()` for O(1) per-token decode, unlike + `TextIteratorStreamer` which re-decodes the full sequence each time. + + Args: + processor: A HuggingFace processor or tokenizer (must have a `._tokenizer` attribute). + loop: The asyncio event loop to push results to. + queue: The asyncio.Queue to push decoded text chunks to. + skip_special_tokens: Whether to skip special tokens during decoding. + """ + + def __init__( + self, + processor: ProcessorMixin | PreTrainedTokenizerFast, + loop: asyncio.AbstractEventLoop, + queue: asyncio.Queue, + skip_special_tokens: bool = True, + ): + self._tokenizer = processor._tokenizer # raw tokenizers.Tokenizer + self._loop = loop + self._queue = queue + self._decode_stream = DecodeStream([], skip_special_tokens) + self._first = True + self.total_tokens = 0 + + def put(self, value: torch.Tensor) -> None: + if self._first: + self._first = False + return # skip prompt tokens + if len(value.shape) > 1: + value = value[0] + for token_id in value.tolist(): + self.total_tokens += 1 + text = self._decode_stream.step(self._tokenizer, token_id) + if text is not None: + self._loop.call_soon_threadsafe(self._queue.put_nowait, text) + + def end(self) -> None: + self._loop.call_soon_threadsafe(self._queue.put_nowait, None) + + +class ChatCompletionHandler(BaseHandler): + """Handler for the `/v1/chat/completions` endpoint. + + Supports both streaming (SSE) and non-streaming (JSON) responses. + + Args: + model_manager: The model manager to load models from. + force_model: If set, override the model field in every request. + """ + + def __init__(self, model_manager: ModelManager, force_model: str | None = None): + super().__init__(model_manager) + self.force_model = force_model + + def _validate_request(self, body: dict) -> None: + """Validate a chat completion request. 
Raises HTTPException if invalid.""" + logger.debug(f"Validating request: {body}") + + input_keys = set(body.keys()) + unexpected = input_keys - TransformersCompletionCreateParamsStreaming.__mutable_keys__ + if unexpected: + logger.error(f"Unexpected keys in the request: {unexpected}") + raise HTTPException(status_code=422, detail=f"Unexpected keys in the request: {unexpected}") + + # TODO: add back strict Pydantic validation (input_validation flag) + unused = input_keys & UNUSED_CHAT_COMPLETION_FIELDS + if unused: + logger.error(f"Unsupported fields in the request: {unused}") + raise HTTPException(status_code=422, detail=f"Unsupported fields in the request: {unused}") + + def _build_generation_config(self, body: dict, model_generation_config: GenerationConfig) -> GenerationConfig: + """Map Chat Completions API params to a GenerationConfig. + + If `body` contains a `generation_config` JSON string, it is used as baseline + (overriding the model default). Other body params are applied on top. + """ + if body.get("generation_config") is not None: + generation_config = GenerationConfig(**json.loads(body["generation_config"])) + else: + generation_config = copy.deepcopy(model_generation_config) + self._apply_default_generation_config(generation_config) + + if body.get("max_tokens") is not None: + generation_config.max_new_tokens = int(body["max_tokens"]) + if body.get("frequency_penalty") is not None: + generation_config.repetition_penalty = 1.0 + float(body["frequency_penalty"]) + if body.get("logit_bias") is not None: + generation_config.sequence_bias = {(int(k),): v for k, v in body["logit_bias"].items()} + if body.get("stop") is not None: + generation_config.stop_strings = body["stop"] + if body.get("temperature") is not None: + generation_config.temperature = float(body["temperature"]) + if float(body["temperature"]) == 0.0: + generation_config.do_sample = False + if body.get("top_p") is not None: + generation_config.top_p = float(body["top_p"]) + if body.get("seed") is not None: + set_torch_seed(body["seed"]) + + return generation_config + + def handle_request(self, body: dict, request_id: str) -> StreamingResponse | JSONResponse: + """Validate the request, load the model, and dispatch to streaming or non-streaming.""" + self._validate_request(body) + + if self.force_model is not None: + body["model"] = self.force_model + + messages = body["messages"] + + # HACK: tiny-agents sends requests ending with assistant message — skip + if messages and messages[-1]["role"] == "assistant": + return JSONResponse({}, status_code=200) + + model_id = self.model_manager.process_model_name(body["model"]) + model, processor = self.model_manager.load_model_and_processor(model_id) + + modality = self.model_manager.get_model_modality(model, processor=processor) + processor_inputs = get_processor_inputs_from_messages(messages, modality) + + inputs = processor.apply_chat_template( + processor_inputs, + add_generation_prompt=True, + tools=body.get("tools"), + return_tensors="pt", + return_dict=True, + tokenize=True, + ).to(model.device) + + gen_config = self._build_generation_config(body, model.generation_config) + + if body.get("stream"): + return self._streaming(request_id, model, processor, model_id, inputs, gen_config) + return self._non_streaming(request_id, model, processor, model_id, inputs, gen_config) + + # ----- streaming ----- + + def _streaming( + self, + request_id: str, + model: PreTrainedModel, + processor: ProcessorMixin | PreTrainedTokenizerFast, + model_id: str, + inputs: dict[str, 
torch.Tensor], + gen_config: GenerationConfig, + ) -> StreamingResponse: + """Run generation in a background thread, stream tokens as SSE via DirectStreamer.""" + loop = asyncio.get_running_loop() + queue: asyncio.Queue = asyncio.Queue() + + streamer = DirectStreamer(processor, loop, queue, skip_special_tokens=True) + gen_kwargs = {**inputs, "streamer": streamer, "generation_config": gen_config} + + def _run() -> None: + try: + model.generate(**gen_kwargs) + except Exception as e: + loop.call_soon_threadsafe(queue.put_nowait, _StreamError(str(e))) + + Thread(target=_run, daemon=True).start() + + async def sse_gen() -> AsyncGenerator[str, None]: + # First chunk: tell the client the assistant is about to speak (OpenAI protocol) + yield self._build_chunk_sse(request_id, role="assistant", model=model_id) + + # Stream tokens as they arrive from the background generate thread + while True: + text = await queue.get() + if text is None: + break # generation done + elif isinstance(text, _StreamError): + yield f'data: {{"error": "{text.msg}"}}\n\n' + return + + yield self._build_chunk_sse(request_id, content=text, model=model_id) + + # Last chunk: tell the client why generation stopped + hit_max = gen_config.max_new_tokens is not None and streamer.total_tokens >= gen_config.max_new_tokens + if hit_max: + logger.warning( + f"Generation hit max_new_tokens={gen_config.max_new_tokens}. " + "Output may be truncated. Use `max_tokens` to increase the limit." + ) + yield self._build_chunk_sse(request_id, finish_reason="length" if hit_max else "stop", model=model_id) + + return StreamingResponse(sse_gen(), media_type="text/event-stream") + + # ----- non-streaming ----- + + def _non_streaming( + self, + request_id: str, + model: PreTrainedModel, + processor: ProcessorMixin | PreTrainedTokenizerFast, + model_id: str, + inputs: dict[str, torch.Tensor], + gen_config: GenerationConfig, + ) -> JSONResponse: + """Run generation synchronously and return a JSONResponse.""" + sequences = model.generate(**inputs, generation_config=gen_config) + + input_len = inputs["input_ids"].shape[-1] + generated_ids = sequences[0, input_len:] + content = processor.decode(generated_ids, skip_special_tokens=True) + + hit_max = gen_config.max_new_tokens is not None and len(generated_ids) >= gen_config.max_new_tokens + if hit_max: + logger.warning( + f"Generation hit max_new_tokens={gen_config.max_new_tokens}. " + "Output may be truncated. Use `max_tokens` to increase the limit." 
+ ) + return JSONResponse( + self._build_completion(request_id, content, model_id, finish_reason="length" if hit_max else "stop"), + media_type="application/json", + ) + + # ----- response builders ----- + + def _build_completion(self, request_id: str, content: str, model_id: str, finish_reason: str) -> dict: + """Build a non-streaming ChatCompletion response dict.""" + result = ChatCompletion( + id=request_id, + created=int(time.time()), + object="chat.completion", + model=model_id, + choices=[ + Choice( + index=0, + message=ChatCompletionMessage(content=content, role="assistant"), + finish_reason=finish_reason, + ) + ], + ) + return result.model_dump(exclude_none=True) + + def _build_chunk_sse( + self, + request_id: str = "", + content: str | None = None, + model: str | None = None, + role: str | None = None, + finish_reason: str | None = None, + tool_calls: list | None = None, + ) -> str: + """Build a streaming ChatCompletionChunk and format it as an SSE event.""" + chunk = ChatCompletionChunk( + id=request_id, + created=int(time.time()), + model=model, + choices=[ + ChoiceChunk( + delta=ChoiceDelta(content=content, role=role, tool_calls=tool_calls), + index=0, + finish_reason=finish_reason, + ) + ], + system_fingerprint="", + object="chat.completion.chunk", + ) + return self.chunk_to_sse(chunk) diff --git a/src/transformers/cli/serving/model_manager.py b/src/transformers/cli/serving/model_manager.py index 5407a27b1c49..ef3ffbf2bd30 100644 --- a/src/transformers/cli/serving/model_manager.py +++ b/src/transformers/cli/serving/model_manager.py @@ -168,7 +168,9 @@ def _load_processor(self, model_id_and_revision: str) -> ProcessorMixin | PreTra return AutoProcessor.from_pretrained(model_id, revision=revision, trust_remote_code=self.trust_remote_code) except OSError: try: - return AutoTokenizer.from_pretrained(model_id, revision=revision, trust_remote_code=self.trust_remote_code) + return AutoTokenizer.from_pretrained( + model_id, revision=revision, trust_remote_code=self.trust_remote_code + ) except OSError: raise OSError(f"Failed to load processor for {model_id} with AutoProcessor and AutoTokenizer.") @@ -185,7 +187,10 @@ def _load_model(self, model_id_and_revision: str) -> PreTrainedModel: flash_attn = True if self.attn_implementation == "flash_attention_2" else "auto" return LlamaCppTransformersModel.from_pretrained( - model_id, revision=revision, n_gpu_layers=-1, flash_attn=flash_attn, + model_id, + revision=revision, + n_gpu_layers=-1, + flash_attn=flash_attn, ) dtype = self.dtype if self.dtype in ["auto", None] else getattr(torch, self.dtype) @@ -225,7 +230,9 @@ def shutdown(self) -> None: self.loaded_models.clear() @staticmethod - def get_model_modality(model: PreTrainedModel, processor=None) -> Modality: + def get_model_modality( + model: PreTrainedModel, processor: ProcessorMixin | PreTrainedTokenizerFast | None = None + ) -> Modality: """Detect whether a model is an LLM or VLM based on its architecture.""" if processor is not None and isinstance(processor, PreTrainedTokenizerBase): return Modality.LLM diff --git a/src/transformers/cli/serving/protocol.py b/src/transformers/cli/serving/protocol.py new file mode 100644 index 000000000000..c63bd9bd5f1b --- /dev/null +++ b/src/transformers/cli/serving/protocol.py @@ -0,0 +1,112 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +API contract: request types, constants, and format conversion between OpenAI and HF. +""" + +from __future__ import annotations + +import base64 +import enum +import re +import tempfile +from io import BytesIO + +from transformers.utils.import_utils import is_openai_available, is_vision_available + + +if is_vision_available(): + from PIL import Image + +if is_openai_available(): + from openai.types.chat.completion_create_params import CompletionCreateParamsStreaming + + class TransformersCompletionCreateParamsStreaming(CompletionCreateParamsStreaming, total=False): + generation_config: str + + +X_REQUEST_ID = "x-request-id" + +# Fields accepted by the OpenAI schema but not yet supported. +# Receiving these raises an error to avoid silent misbehaviour. +# NOTE: "stop" is NOT in this set — we map it to stop_strings. +UNUSED_CHAT_COMPLETION_FIELDS = { + "audio", + "function_call", + "functions", + "logprobs", + "max_completion_tokens", + "metadata", + "modalities", + "n", + "parallel_tool_calls", + "prediction", + "presence_penalty", + "reasoning_effort", + "response_format", + "service_tier", + "store", + "stream_options", + "tool_choice", + "top_logprobs", + "user", + "web_search_options", +} + + +class Modality(enum.Enum): + LLM = "LLM" + VLM = "VLM" + STT = "STT" + TTS = "TTS" + + +# --------------------------------------------------------------------------- +# Message preprocessing: OpenAI messages → processor-compatible format +# --------------------------------------------------------------------------- + + +def get_processor_inputs_from_messages(messages: list[dict], modality: Modality) -> list[dict]: + """Convert OpenAI-format messages to the format expected by HF processors.""" + processor_inputs = [] + + for message in messages: + parsed = {"role": message["role"], "content": []} + + if modality == Modality.LLM: + if isinstance(message["content"], str): + parsed["content"] = message["content"] + elif isinstance(message["content"], list): + texts = [c["text"] for c in message["content"] if c["type"] == "text"] + parsed["content"] = " ".join(texts) + + elif modality == Modality.VLM: + if isinstance(message["content"], str): + parsed["content"].append({"type": "text", "text": message["content"]}) + else: + for content in message["content"]: + if content["type"] == "text": + parsed["content"].append(content) + elif content["type"] == "image_url": + url = content["image_url"]["url"] + if "base64" in url: + image_data = re.sub("^data:image/.+;base64,", "", url) + image = Image.open(BytesIO(base64.b64decode(image_data))) + file = tempfile.NamedTemporaryFile(suffix=".png", delete=False) + image.save(file.name) + url = file.name + parsed["content"].append({"type": "image", "url": url}) + + processor_inputs.append(parsed) + return processor_inputs diff --git a/src/transformers/cli/serving/utils.py b/src/transformers/cli/serving/utils.py new file mode 100644 index 000000000000..f3aee7e33c8c --- /dev/null +++ b/src/transformers/cli/serving/utils.py @@ -0,0 +1,38 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Shared utilities for the serving layer. +""" + +from __future__ import annotations + + +class _StreamError: + """Sentinel to signal an error from the generate thread.""" + + def __init__(self, msg: str): + self.msg = msg + + +def set_torch_seed(seed: int) -> None: + import torch + + torch.manual_seed(seed) + + +def reset_torch_cache() -> None: + import torch + + if torch.cuda.is_available(): + torch.cuda.empty_cache() From bd734e882a4703aff1e026965a897ed4aaf70932 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Tue, 17 Mar 2026 13:37:02 +0000 Subject: [PATCH 07/64] renaming --- src/transformers/cli/serve_refactored.py | 6 +- .../serving/{handlers => }/chat_completion.py | 130 ++++++++++-------- .../cli/serving/handlers/__init__.py | 0 src/transformers/cli/serving/handlers/base.py | 55 -------- src/transformers/cli/serving/model_manager.py | 3 +- src/transformers/cli/serving/protocol.py | 112 --------------- .../cli/serving/{app.py => server.py} | 6 +- src/transformers/cli/serving/utils.py | 110 ++++++++++++++- 8 files changed, 190 insertions(+), 232 deletions(-) rename src/transformers/cli/serving/{handlers => }/chat_completion.py (92%) delete mode 100644 src/transformers/cli/serving/handlers/__init__.py delete mode 100644 src/transformers/cli/serving/handlers/base.py delete mode 100644 src/transformers/cli/serving/protocol.py rename src/transformers/cli/serving/{app.py => server.py} (96%) diff --git a/src/transformers/cli/serve_refactored.py b/src/transformers/cli/serve_refactored.py index 15015d0dfaae..66eb082dfa84 100644 --- a/src/transformers/cli/serve_refactored.py +++ b/src/transformers/cli/serve_refactored.py @@ -78,9 +78,9 @@ def __init__( import uvicorn - from .serving.app import build_app - from .serving.handlers.chat_completion import ChatCompletionHandler + from .serving.chat_completion import ChatCompletionHandler from .serving.model_manager import ModelManager + from .serving.server import build_server # Seed if default_seed is not None: @@ -110,7 +110,7 @@ def __init__( force_model=force_model, ) - app = build_app(model_manager, chat_handler, enable_cors=enable_cors) + app = build_server(model_manager, chat_handler, enable_cors=enable_cors) config = uvicorn.Config(app, host=host, port=port, log_level=log_level) self.server = uvicorn.Server(config) diff --git a/src/transformers/cli/serving/handlers/chat_completion.py b/src/transformers/cli/serving/chat_completion.py similarity index 92% rename from src/transformers/cli/serving/handlers/chat_completion.py rename to src/transformers/cli/serving/chat_completion.py index 95e255e22f74..58a1273b01a8 100644 --- a/src/transformers/cli/serving/handlers/chat_completion.py +++ b/src/transformers/cli/serving/chat_completion.py @@ -32,6 +32,7 @@ if TYPE_CHECKING: from transformers import PreTrainedTokenizerFast, ProcessorMixin + from fastapi import HTTPException from fastapi.responses import JSONResponse, StreamingResponse from openai.types.chat import ChatCompletion, ChatCompletionMessage @@ 
-43,15 +44,15 @@ from transformers import GenerationConfig, PreTrainedModel from transformers.generation.streamers import BaseStreamer -from ....utils import logging -from ..model_manager import ModelManager -from ..protocol import ( +from ...utils import logging +from .model_manager import ModelManager +from .utils import ( UNUSED_CHAT_COMPLETION_FIELDS, TransformersCompletionCreateParamsStreaming, + _StreamError, get_processor_inputs_from_messages, + set_torch_seed, ) -from ..utils import _StreamError, set_torch_seed -from .base import BaseHandler logger = logging.get_logger(__name__) @@ -100,7 +101,7 @@ def end(self) -> None: self._loop.call_soon_threadsafe(self._queue.put_nowait, None) -class ChatCompletionHandler(BaseHandler): +class ChatCompletionHandler: """Handler for the `/v1/chat/completions` endpoint. Supports both streaming (SSE) and non-streaming (JSON) responses. @@ -111,55 +112,10 @@ class ChatCompletionHandler(BaseHandler): """ def __init__(self, model_manager: ModelManager, force_model: str | None = None): - super().__init__(model_manager) + self.model_manager = model_manager self.force_model = force_model - def _validate_request(self, body: dict) -> None: - """Validate a chat completion request. Raises HTTPException if invalid.""" - logger.debug(f"Validating request: {body}") - - input_keys = set(body.keys()) - unexpected = input_keys - TransformersCompletionCreateParamsStreaming.__mutable_keys__ - if unexpected: - logger.error(f"Unexpected keys in the request: {unexpected}") - raise HTTPException(status_code=422, detail=f"Unexpected keys in the request: {unexpected}") - - # TODO: add back strict Pydantic validation (input_validation flag) - unused = input_keys & UNUSED_CHAT_COMPLETION_FIELDS - if unused: - logger.error(f"Unsupported fields in the request: {unused}") - raise HTTPException(status_code=422, detail=f"Unsupported fields in the request: {unused}") - - def _build_generation_config(self, body: dict, model_generation_config: GenerationConfig) -> GenerationConfig: - """Map Chat Completions API params to a GenerationConfig. - - If `body` contains a `generation_config` JSON string, it is used as baseline - (overriding the model default). Other body params are applied on top. 
- """ - if body.get("generation_config") is not None: - generation_config = GenerationConfig(**json.loads(body["generation_config"])) - else: - generation_config = copy.deepcopy(model_generation_config) - self._apply_default_generation_config(generation_config) - - if body.get("max_tokens") is not None: - generation_config.max_new_tokens = int(body["max_tokens"]) - if body.get("frequency_penalty") is not None: - generation_config.repetition_penalty = 1.0 + float(body["frequency_penalty"]) - if body.get("logit_bias") is not None: - generation_config.sequence_bias = {(int(k),): v for k, v in body["logit_bias"].items()} - if body.get("stop") is not None: - generation_config.stop_strings = body["stop"] - if body.get("temperature") is not None: - generation_config.temperature = float(body["temperature"]) - if float(body["temperature"]) == 0.0: - generation_config.do_sample = False - if body.get("top_p") is not None: - generation_config.top_p = float(body["top_p"]) - if body.get("seed") is not None: - set_torch_seed(body["seed"]) - - return generation_config + # ----- entry point ----- def handle_request(self, body: dict, request_id: str) -> StreamingResponse | JSONResponse: """Validate the request, load the model, and dispatch to streaming or non-streaming.""" @@ -211,7 +167,7 @@ def _streaming( queue: asyncio.Queue = asyncio.Queue() streamer = DirectStreamer(processor, loop, queue, skip_special_tokens=True) - gen_kwargs = {**inputs, "streamer": streamer, "generation_config": gen_config} + gen_kwargs = {**inputs, "streamer": streamer, "generation_config": gen_config, "tokenizer": processor} def _run() -> None: try: @@ -259,7 +215,7 @@ def _non_streaming( gen_config: GenerationConfig, ) -> JSONResponse: """Run generation synchronously and return a JSONResponse.""" - sequences = model.generate(**inputs, generation_config=gen_config) + sequences = model.generate(**inputs, generation_config=gen_config, tokenizer=processor) input_len = inputs["input_ids"].shape[-1] generated_ids = sequences[0, input_len:] @@ -276,6 +232,61 @@ def _non_streaming( media_type="application/json", ) + # ----- helpers ----- + + def _validate_request(self, body: dict) -> None: + """Validate a chat completion request. Raises HTTPException if invalid.""" + logger.debug(f"Validating request: {body}") + + input_keys = set(body.keys()) + unexpected = input_keys - TransformersCompletionCreateParamsStreaming.__mutable_keys__ + if unexpected: + logger.error(f"Unexpected keys in the request: {unexpected}") + raise HTTPException(status_code=422, detail=f"Unexpected keys in the request: {unexpected}") + + # TODO: add back strict Pydantic validation (input_validation flag) + unused = input_keys & UNUSED_CHAT_COMPLETION_FIELDS + if unused: + logger.error(f"Unsupported fields in the request: {unused}") + raise HTTPException(status_code=422, detail=f"Unsupported fields in the request: {unused}") + + @staticmethod + def _apply_default_generation_config(generation_config: GenerationConfig) -> None: + """Apply sensible serving defaults. Many models ship with too few max_new_tokens.""" + if generation_config.max_new_tokens is None or generation_config.max_new_tokens < 1024: + generation_config.max_new_tokens = 1024 + + def _build_generation_config(self, body: dict, model_generation_config: GenerationConfig) -> GenerationConfig: + """Map Chat Completions API params to a GenerationConfig. + + If `body` contains a `generation_config` JSON string, it is used as baseline + (overriding the model default). Other body params are applied on top. 
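+
+        Example (illustrative): a body of {"max_tokens": 256, "temperature": 0.0}
+        yields max_new_tokens=256 and do_sample=False on the returned config.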
+ """ + if body.get("generation_config") is not None: + generation_config = GenerationConfig(**json.loads(body["generation_config"])) + else: + generation_config = copy.deepcopy(model_generation_config) + self._apply_default_generation_config(generation_config) + + if body.get("max_tokens") is not None: + generation_config.max_new_tokens = int(body["max_tokens"]) + if body.get("frequency_penalty") is not None: + generation_config.repetition_penalty = 1.0 + float(body["frequency_penalty"]) + if body.get("logit_bias") is not None: + generation_config.sequence_bias = {(int(k),): v for k, v in body["logit_bias"].items()} + if body.get("stop") is not None: + generation_config.stop_strings = body["stop"] + if body.get("temperature") is not None: + generation_config.temperature = float(body["temperature"]) + if float(body["temperature"]) == 0.0: + generation_config.do_sample = False + if body.get("top_p") is not None: + generation_config.top_p = float(body["top_p"]) + if body.get("seed") is not None: + set_torch_seed(body["seed"]) + + return generation_config + # ----- response builders ----- def _build_completion(self, request_id: str, content: str, model_id: str, finish_reason: str) -> dict: @@ -319,4 +330,11 @@ def _build_chunk_sse( system_fingerprint="", object="chat.completion.chunk", ) - return self.chunk_to_sse(chunk) + return self._chunk_to_sse(chunk) + + @staticmethod + def _chunk_to_sse(chunk) -> str: + """Format a pydantic model or string as an SSE event.""" + if isinstance(chunk, str): + return chunk if chunk.startswith("data: ") else f"data: {chunk}\n\n" + return f"data: {chunk.model_dump_json(exclude_none=True)}\n\n" diff --git a/src/transformers/cli/serving/handlers/__init__.py b/src/transformers/cli/serving/handlers/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/src/transformers/cli/serving/handlers/base.py b/src/transformers/cli/serving/handlers/base.py deleted file mode 100644 index 17db5bb09171..000000000000 --- a/src/transformers/cli/serving/handlers/base.py +++ /dev/null @@ -1,55 +0,0 @@ -# Copyright 2025 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Base handler for endpoint handlers. -""" - -from __future__ import annotations - -from pydantic import BaseModel - -from transformers import GenerationConfig - -from ..model_manager import ModelManager - - -class BaseHandler: - """Base class for endpoint handlers. Stores the model manager.""" - - def __init__(self, model_manager: ModelManager): - self.model_manager = model_manager - - @staticmethod - def _apply_default_generation_config(generation_config: GenerationConfig) -> None: - """Apply sensible serving defaults. Many models ship with too few max_new_tokens.""" - if generation_config.max_new_tokens is None or generation_config.max_new_tokens < 1024: - generation_config.max_new_tokens = 1024 - - @staticmethod - def chunk_to_sse(chunk: BaseModel | str) -> str: - """Format a pydantic model or string as a Server-Sent Event. 
- - Serializes with `exclude_none=True` — some clients (e.g. Cursor) assume - that when a field exists in the JSON, it has data. - - Args: - chunk: A pydantic BaseModel (ChatCompletionChunk, Response event, etc.) - or a pre-formatted string (error paths). - - Returns: - An SSE-formatted string: `data: {json}\\n\\n` - """ - if isinstance(chunk, str): - return chunk if chunk.startswith("data: ") else f"data: {chunk}\n\n" - return f"data: {chunk.model_dump_json(exclude_none=True)}\n\n" diff --git a/src/transformers/cli/serving/model_manager.py b/src/transformers/cli/serving/model_manager.py index ef3ffbf2bd30..07cba3845c5d 100644 --- a/src/transformers/cli/serving/model_manager.py +++ b/src/transformers/cli/serving/model_manager.py @@ -30,8 +30,7 @@ from transformers import AutoTokenizer, BitsAndBytesConfig, PreTrainedTokenizerBase from ...utils import logging -from .protocol import Modality -from .utils import reset_torch_cache +from .utils import Modality, reset_torch_cache if TYPE_CHECKING: diff --git a/src/transformers/cli/serving/protocol.py b/src/transformers/cli/serving/protocol.py deleted file mode 100644 index c63bd9bd5f1b..000000000000 --- a/src/transformers/cli/serving/protocol.py +++ /dev/null @@ -1,112 +0,0 @@ -# Copyright 2025 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -API contract: request types, constants, and format conversion between OpenAI and HF. -""" - -from __future__ import annotations - -import base64 -import enum -import re -import tempfile -from io import BytesIO - -from transformers.utils.import_utils import is_openai_available, is_vision_available - - -if is_vision_available(): - from PIL import Image - -if is_openai_available(): - from openai.types.chat.completion_create_params import CompletionCreateParamsStreaming - - class TransformersCompletionCreateParamsStreaming(CompletionCreateParamsStreaming, total=False): - generation_config: str - - -X_REQUEST_ID = "x-request-id" - -# Fields accepted by the OpenAI schema but not yet supported. -# Receiving these raises an error to avoid silent misbehaviour. -# NOTE: "stop" is NOT in this set — we map it to stop_strings. 
-UNUSED_CHAT_COMPLETION_FIELDS = { - "audio", - "function_call", - "functions", - "logprobs", - "max_completion_tokens", - "metadata", - "modalities", - "n", - "parallel_tool_calls", - "prediction", - "presence_penalty", - "reasoning_effort", - "response_format", - "service_tier", - "store", - "stream_options", - "tool_choice", - "top_logprobs", - "user", - "web_search_options", -} - - -class Modality(enum.Enum): - LLM = "LLM" - VLM = "VLM" - STT = "STT" - TTS = "TTS" - - -# --------------------------------------------------------------------------- -# Message preprocessing: OpenAI messages → processor-compatible format -# --------------------------------------------------------------------------- - - -def get_processor_inputs_from_messages(messages: list[dict], modality: Modality) -> list[dict]: - """Convert OpenAI-format messages to the format expected by HF processors.""" - processor_inputs = [] - - for message in messages: - parsed = {"role": message["role"], "content": []} - - if modality == Modality.LLM: - if isinstance(message["content"], str): - parsed["content"] = message["content"] - elif isinstance(message["content"], list): - texts = [c["text"] for c in message["content"] if c["type"] == "text"] - parsed["content"] = " ".join(texts) - - elif modality == Modality.VLM: - if isinstance(message["content"], str): - parsed["content"].append({"type": "text", "text": message["content"]}) - else: - for content in message["content"]: - if content["type"] == "text": - parsed["content"].append(content) - elif content["type"] == "image_url": - url = content["image_url"]["url"] - if "base64" in url: - image_data = re.sub("^data:image/.+;base64,", "", url) - image = Image.open(BytesIO(base64.b64decode(image_data))) - file = tempfile.NamedTemporaryFile(suffix=".png", delete=False) - image.save(file.name) - url = file.name - parsed["content"].append({"type": "image", "url": url}) - - processor_inputs.append(parsed) - return processor_inputs diff --git a/src/transformers/cli/serving/app.py b/src/transformers/cli/serving/server.py similarity index 96% rename from src/transformers/cli/serving/app.py rename to src/transformers/cli/serving/server.py index a519fe8ad54d..e88366cb0f9c 100644 --- a/src/transformers/cli/serving/app.py +++ b/src/transformers/cli/serving/server.py @@ -25,15 +25,15 @@ from fastapi.responses import JSONResponse from ...utils import logging -from .handlers.chat_completion import ChatCompletionHandler +from .chat_completion import ChatCompletionHandler from .model_manager import ModelManager -from .protocol import X_REQUEST_ID +from .utils import X_REQUEST_ID logger = logging.get_logger(__name__) -def build_app( +def build_server( model_manager: ModelManager, chat_handler: ChatCompletionHandler, enable_cors: bool = False, diff --git a/src/transformers/cli/serving/utils.py b/src/transformers/cli/serving/utils.py index f3aee7e33c8c..ed366766535e 100644 --- a/src/transformers/cli/serving/utils.py +++ b/src/transformers/cli/serving/utils.py @@ -12,11 +12,74 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -Shared utilities for the serving layer. +Shared types, constants, and utilities for the serving layer. 
""" from __future__ import annotations +import base64 +import enum +import re +import tempfile +from io import BytesIO + +from transformers.utils.import_utils import is_openai_available, is_vision_available + + +if is_vision_available(): + from PIL import Image + +if is_openai_available(): + from openai.types.chat.completion_create_params import CompletionCreateParamsStreaming + + class TransformersCompletionCreateParamsStreaming(CompletionCreateParamsStreaming, total=False): + generation_config: str + + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- + +X_REQUEST_ID = "x-request-id" + +# Fields accepted by the OpenAI schema but not yet supported. +# Receiving these raises an error to avoid silent misbehaviour. +# NOTE: "stop" is NOT in this set — we map it to stop_strings. +UNUSED_CHAT_COMPLETION_FIELDS = { + "audio", + "function_call", + "functions", + "logprobs", + "max_completion_tokens", + "metadata", + "modalities", + "n", + "parallel_tool_calls", + "prediction", + "presence_penalty", + "reasoning_effort", + "response_format", + "service_tier", + "store", + "stream_options", + "tool_choice", + "top_logprobs", + "user", + "web_search_options", +} + + +# --------------------------------------------------------------------------- +# Types +# --------------------------------------------------------------------------- + + +class Modality(enum.Enum): + LLM = "LLM" + VLM = "VLM" + STT = "STT" + TTS = "TTS" + class _StreamError: """Sentinel to signal an error from the generate thread.""" @@ -25,6 +88,11 @@ def __init__(self, msg: str): self.msg = msg +# --------------------------------------------------------------------------- +# Torch helpers +# --------------------------------------------------------------------------- + + def set_torch_seed(seed: int) -> None: import torch @@ -36,3 +104,43 @@ def reset_torch_cache() -> None: if torch.cuda.is_available(): torch.cuda.empty_cache() + + +# --------------------------------------------------------------------------- +# Message preprocessing: OpenAI messages → processor-compatible format +# --------------------------------------------------------------------------- + + +def get_processor_inputs_from_messages(messages: list[dict], modality: Modality) -> list[dict]: + """Convert OpenAI-format messages to the format expected by HF processors.""" + processor_inputs = [] + + for message in messages: + parsed = {"role": message["role"], "content": []} + + if modality == Modality.LLM: + if isinstance(message["content"], str): + parsed["content"] = message["content"] + elif isinstance(message["content"], list): + texts = [c["text"] for c in message["content"] if c["type"] == "text"] + parsed["content"] = " ".join(texts) + + elif modality == Modality.VLM: + if isinstance(message["content"], str): + parsed["content"].append({"type": "text", "text": message["content"]}) + else: + for content in message["content"]: + if content["type"] == "text": + parsed["content"].append(content) + elif content["type"] == "image_url": + url = content["image_url"]["url"] + if "base64" in url: + image_data = re.sub("^data:image/.+;base64,", "", url) + image = Image.open(BytesIO(base64.b64decode(image_data))) + file = tempfile.NamedTemporaryFile(suffix=".png", delete=False) + image.save(file.name) + url = file.name + parsed["content"].append({"type": "image", "url": url}) + + processor_inputs.append(parsed) + return processor_inputs From 
69d3264406cae21733b33af752e1a5631e51416c Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Tue, 17 Mar 2026 15:56:42 +0000 Subject: [PATCH 08/64] fix --- src/transformers/cli/serving/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/cli/serving/__init__.py b/src/transformers/cli/serving/__init__.py index 3adfb1f2f665..118d3a9c2012 100644 --- a/src/transformers/cli/serving/__init__.py +++ b/src/transformers/cli/serving/__init__.py @@ -12,6 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .app import build_app from .model_manager import ModelManager -from .protocol import Modality +from .server import build_server +from .utils import Modality From f5afd6c4a4489bcbed69a5bc0c2db30454342174 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Tue, 17 Mar 2026 15:56:48 +0000 Subject: [PATCH 09/64] new tests --- tests/cli/test_serve_refactored.py | 562 +++++++++++++++++++++++++++++ 1 file changed, 562 insertions(+) create mode 100644 tests/cli/test_serve_refactored.py diff --git a/tests/cli/test_serve_refactored.py b/tests/cli/test_serve_refactored.py new file mode 100644 index 000000000000..972a66ec1a37 --- /dev/null +++ b/tests/cli/test_serve_refactored.py @@ -0,0 +1,562 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Tests for the refactored serving layer (Phase 1: chat completions). + +Run: pytest tests/cli/test_serve_refactored.py -x -v +Integration tests (need GPU): RUN_SLOW=1 pytest tests/cli/test_serve_refactored.py -x -v -k "Integration" +""" + +import asyncio +import json +import os +import time +import unittest +from unittest.mock import MagicMock + +from transformers.testing_utils import require_openai, slow +from transformers.utils.import_utils import is_openai_available, is_vision_available + + +if is_openai_available(): + from openai import OpenAI + +run_slow = os.environ.get("RUN_SLOW", "0") == "1" + + +# --------------------------------------------------------------------------- +# 1. CLI tests — verify CLI args reach uvicorn +# --------------------------------------------------------------------------- + + +@require_openai +def test_host_port_blocking(cli): + """CLI args --host and --port are passed to uvicorn.Config, and server.run() is called.""" + from unittest.mock import Mock, patch + + with ( + patch("uvicorn.Config") as ConfigMock, + patch("uvicorn.Server") as ServerMock, + ): + server_instance = Mock() + ServerMock.return_value = server_instance + + out = cli("serve", "--host", "0.0.0.0", "--port", "9000") + _, kwargs = ConfigMock.call_args + + assert out.exit_code == 0 + assert kwargs["host"] == "0.0.0.0" + assert kwargs["port"] == 9000 + ServerMock.assert_called_once_with(ConfigMock.return_value) + server_instance.run.assert_called_once() + + +# --------------------------------------------------------------------------- +# 2. 
Unit tests — message parsing +# --------------------------------------------------------------------------- + + +class TestProcessorInputsFromMessages(unittest.TestCase): + def test_llm_string_content(self): + from transformers.cli.serving.utils import Modality, get_processor_inputs_from_messages + + messages = [{"role": "user", "content": "Hello"}] + result = get_processor_inputs_from_messages(messages, Modality.LLM) + self.assertEqual(result, [{"role": "user", "content": "Hello"}]) + + def test_llm_list_content_text_only(self): + from transformers.cli.serving.utils import Modality, get_processor_inputs_from_messages + + messages = [{"role": "user", "content": [{"type": "text", "text": "A"}, {"type": "text", "text": "B"}]}] + result = get_processor_inputs_from_messages(messages, Modality.LLM) + self.assertEqual(result, [{"role": "user", "content": "A B"}]) + + def test_vlm_string_content_wrapped(self): + from transformers.cli.serving.utils import Modality, get_processor_inputs_from_messages + + messages = [{"role": "user", "content": "Hello"}] + result = get_processor_inputs_from_messages(messages, Modality.VLM) + self.assertEqual(result, [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}]) + + def test_vlm_text_and_image_url(self): + from transformers.cli.serving.utils import Modality, get_processor_inputs_from_messages + + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "What is this?"}, + {"type": "image_url", "image_url": {"url": "https://example.com/img.png"}}, + ], + } + ] + result = get_processor_inputs_from_messages(messages, Modality.VLM) + self.assertEqual(len(result[0]["content"]), 2) + self.assertEqual(result[0]["content"][0]["type"], "text") + self.assertEqual(result[0]["content"][1], {"type": "image", "url": "https://example.com/img.png"}) + + def test_llm_multi_turn_conversation(self): + """Multi-turn conversation with string content should pass through as-is.""" + from transformers.cli.serving.utils import Modality, get_processor_inputs_from_messages + + messages = [ + {"role": "user", "content": "How are you?"}, + {"role": "assistant", "content": "I'm great!"}, + {"role": "user", "content": "Help me write tests?"}, + ] + result = get_processor_inputs_from_messages(messages, Modality.LLM) + self.assertEqual(len(result), 3) + self.assertEqual(result[0]["content"], "How are you?") + self.assertEqual(result[1]["role"], "assistant") + self.assertEqual(result[2]["content"], "Help me write tests?") + + def test_llm_list_content_with_type(self): + """LLM messages with typed content list should extract text and join.""" + from transformers.cli.serving.utils import Modality, get_processor_inputs_from_messages + + messages = [ + {"role": "user", "content": [{"type": "text", "text": "Hello"}, {"type": "text", "text": "world"}]} + ] + result = get_processor_inputs_from_messages(messages, Modality.LLM) + self.assertEqual(result[0]["content"], "Hello world") + + @unittest.skipUnless(is_vision_available(), "Requires PIL") + def test_vlm_base64_image_creates_temp_file(self): + """Base64 image URLs should be decoded and saved to a temp file.""" + import os + + from transformers.cli.serving.utils import Modality, get_processor_inputs_from_messages + + # Minimal valid 1x1 PNG as base64 + base64_url = ( + "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR4" + "2mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==" + ) + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "What is this?"}, + {"type": 
"image_url", "image_url": {"url": base64_url}}, + ], + } + ] + result = get_processor_inputs_from_messages(messages, Modality.VLM) + image_item = result[0]["content"][1] + self.assertEqual(image_item["type"], "image") + self.assertTrue(os.path.exists(image_item["url"])) # temp file was created + + def test_vlm_multi_turn(self): + """VLM multi-turn: string content should be wrapped in text type.""" + from transformers.cli.serving.utils import Modality, get_processor_inputs_from_messages + + messages = [ + {"role": "user", "content": "Describe the image"}, + {"role": "assistant", "content": "It shows a cat"}, + {"role": "user", "content": "What color?"}, + ] + result = get_processor_inputs_from_messages(messages, Modality.VLM) + self.assertEqual(len(result), 3) + for msg in result: + self.assertIsInstance(msg["content"], list) + self.assertEqual(msg["content"][0]["type"], "text") + + +class TestGenerativeModelList(unittest.TestCase): + def test_lists_only_generative_models(self): + """Should list LLMs and VLMs but not non-generative models like BERT.""" + import tempfile + + from huggingface_hub import hf_hub_download + + from transformers.cli.serving.model_manager import ModelManager + + with tempfile.TemporaryDirectory() as cache_dir: + # Download config.json for a few models + hf_hub_download("Qwen/Qwen2.5-0.5B-Instruct", "config.json", cache_dir=cache_dir) + hf_hub_download("google-bert/bert-base-cased", "config.json", cache_dir=cache_dir) + + result = ModelManager.get_gen_models(cache_dir) + model_ids = {r["id"] for r in result} + + self.assertIn("Qwen/Qwen2.5-0.5B-Instruct", model_ids) + self.assertNotIn("google-bert/bert-base-cased", model_ids) + + +# --------------------------------------------------------------------------- +# 2. Unit tests — generation config mapping +# --------------------------------------------------------------------------- + + +@require_openai +class TestBuildGenerationConfig(unittest.TestCase): + def _make_handler(self): + from transformers.cli.serving.chat_completion import ChatCompletionHandler + + return ChatCompletionHandler(model_manager=MagicMock()) + + def test_max_tokens(self): + from transformers import GenerationConfig + + result = self._make_handler()._build_generation_config({"max_tokens": 7}, GenerationConfig()) + self.assertEqual(result.max_new_tokens, 7) + + def test_temperature_zero_disables_sampling(self): + from transformers import GenerationConfig + + result = self._make_handler()._build_generation_config({"temperature": 0.0}, GenerationConfig(do_sample=True)) + self.assertFalse(result.do_sample) + + def test_frequency_penalty(self): + from transformers import GenerationConfig + + result = self._make_handler()._build_generation_config({"frequency_penalty": 0.5}, GenerationConfig()) + self.assertAlmostEqual(result.repetition_penalty, 1.5) + + def test_logit_bias_tuple_keys(self): + from transformers import GenerationConfig + + result = self._make_handler()._build_generation_config({"logit_bias": {"42": 1.0}}, GenerationConfig()) + self.assertEqual(result.sequence_bias, {(42,): 1.0}) + + def test_stop_strings(self): + from transformers import GenerationConfig + + result = self._make_handler()._build_generation_config({"stop": [""]}, GenerationConfig()) + self.assertEqual(result.stop_strings, [""]) + + def test_generation_config_json_overrides(self): + from transformers import GenerationConfig + + custom = GenerationConfig(max_new_tokens=5, do_sample=False) + result = self._make_handler()._build_generation_config( + {"generation_config": 
custom.to_json_string()}, GenerationConfig(max_new_tokens=100) + ) + self.assertEqual(result.max_new_tokens, 5) + self.assertFalse(result.do_sample) + + def test_generation_config_json_no_defaults_applied(self): + """When generation_config JSON is passed, serving defaults should NOT be applied.""" + from transformers import GenerationConfig + + custom = GenerationConfig(max_new_tokens=10) + result = self._make_handler()._build_generation_config( + {"generation_config": custom.to_json_string()}, GenerationConfig() + ) + # Should keep 10, not bump to 1024 + self.assertEqual(result.max_new_tokens, 10) + + def test_default_bumps_short_max_new_tokens(self): + from transformers import GenerationConfig + + result = self._make_handler()._build_generation_config({}, GenerationConfig(max_new_tokens=20)) + self.assertEqual(result.max_new_tokens, 1024) + + def test_user_max_tokens_overrides_default(self): + """User's max_tokens should win over the serving default.""" + from transformers import GenerationConfig + + result = self._make_handler()._build_generation_config({"max_tokens": 50}, GenerationConfig(max_new_tokens=20)) + self.assertEqual(result.max_new_tokens, 50) + + +# --------------------------------------------------------------------------- +# 3. Unit tests — validation +# --------------------------------------------------------------------------- + + +@require_openai +class TestValidation(unittest.TestCase): + def _make_handler(self): + from transformers.cli.serving.chat_completion import ChatCompletionHandler + + return ChatCompletionHandler(model_manager=MagicMock()) + + def test_valid_request_passes(self): + from fastapi import HTTPException + + handler = self._make_handler() + # Should not raise + handler._validate_request({"model": "x", "messages": [{"role": "user", "content": "hi"}], "stream": True}) + + def test_unexpected_keys_rejected(self): + from fastapi import HTTPException + + handler = self._make_handler() + with self.assertRaises(HTTPException) as ctx: + handler._validate_request({"model": "x", "messages": [], "bogus_field": True}) + self.assertEqual(ctx.exception.status_code, 422) + self.assertIn("bogus_field", ctx.exception.detail) + + def test_unsupported_fields_rejected(self): + from fastapi import HTTPException + + handler = self._make_handler() + with self.assertRaises(HTTPException) as ctx: + handler._validate_request({"model": "x", "messages": [], "audio": {}}) + self.assertEqual(ctx.exception.status_code, 422) + self.assertIn("audio", ctx.exception.detail) + + +# --------------------------------------------------------------------------- +# 4. 
Unit tests — model manager +# --------------------------------------------------------------------------- + + +class TestModelManager(unittest.TestCase): + def test_process_model_name_adds_main(self): + from transformers.cli.serving.model_manager import ModelManager + + self.assertEqual(ModelManager.process_model_name("org/model"), "org/model@main") + + def test_process_model_name_preserves_revision(self): + from transformers.cli.serving.model_manager import ModelManager + + self.assertEqual(ModelManager.process_model_name("org/model@dev"), "org/model@dev") + + def test_quantization_config_4bit(self): + from transformers.cli.serving.model_manager import ModelManager + + mm = ModelManager(quantization="bnb-4bit") + cfg = mm.get_quantization_config() + self.assertTrue(cfg.load_in_4bit) + + def test_quantization_config_8bit(self): + from transformers.cli.serving.model_manager import ModelManager + + mm = ModelManager(quantization="bnb-8bit") + cfg = mm.get_quantization_config() + self.assertTrue(cfg.load_in_8bit) + + def test_quantization_config_none(self): + from transformers.cli.serving.model_manager import ModelManager + + mm = ModelManager() + self.assertIsNone(mm.get_quantization_config()) + + +class TestTimedModel(unittest.TestCase): + def test_delete_model(self): + from transformers.cli.serving.model_manager import TimedModel + + mock_model = MagicMock() + timed = TimedModel(mock_model, timeout_seconds=9999, processor=MagicMock()) + self.assertFalse(timed.is_deleted()) + timed.delete_model() + self.assertTrue(timed.is_deleted()) + + def test_timeout_zero_no_delete(self): + from transformers.cli.serving.model_manager import TimedModel + + mock_model = MagicMock() + timed = TimedModel(mock_model, timeout_seconds=0, processor=MagicMock()) + timed._timeout_reached() + self.assertFalse(timed.is_deleted()) + timed._timer.cancel() + + +# --------------------------------------------------------------------------- +# 5. 
Unit tests — SSE formatting +# --------------------------------------------------------------------------- + + +@require_openai +class TestChunkSSE(unittest.TestCase): + def _make_handler(self): + from transformers.cli.serving.chat_completion import ChatCompletionHandler + + return ChatCompletionHandler(model_manager=MagicMock()) + + def test_build_chunk_sse_content(self): + handler = self._make_handler() + sse = handler._build_chunk_sse(request_id="req1", content="hi", model="m") + self.assertTrue(sse.startswith("data: ")) + self.assertTrue(sse.endswith("\n\n")) + parsed = json.loads(sse[len("data: "):].strip()) + self.assertEqual(parsed["choices"][0]["delta"]["content"], "hi") + + def test_build_chunk_sse_role(self): + handler = self._make_handler() + sse = handler._build_chunk_sse(request_id="req1", role="assistant", model="m") + parsed = json.loads(sse[len("data: "):].strip()) + self.assertEqual(parsed["choices"][0]["delta"]["role"], "assistant") + self.assertNotIn("content", parsed["choices"][0]["delta"]) + + def test_build_chunk_sse_finish_reason(self): + handler = self._make_handler() + sse = handler._build_chunk_sse(request_id="req1", finish_reason="stop", model="m") + parsed = json.loads(sse[len("data: "):].strip()) + self.assertEqual(parsed["choices"][0]["finish_reason"], "stop") + + def test_chunk_to_sse_string_passthrough(self): + handler = self._make_handler() + result = handler._chunk_to_sse("data: already formatted\n\n") + self.assertEqual(result, "data: already formatted\n\n") + + def test_chunk_to_sse_wraps_plain_string(self): + handler = self._make_handler() + result = handler._chunk_to_sse("hello") + self.assertEqual(result, "data: hello\n\n") + + +# --------------------------------------------------------------------------- +# 6. 
App-level tests with ASGI test client (no real model) +# --------------------------------------------------------------------------- + + +@require_openai +class TestAppRoutes(unittest.TestCase): + @classmethod + def setUpClass(cls): + from transformers.cli.serving.chat_completion import ChatCompletionHandler + from transformers.cli.serving.model_manager import ModelManager + from transformers.cli.serving.server import build_server + + cls.model_manager = MagicMock(spec=ModelManager) + cls.model_manager.get_gen_models.return_value = [ + {"id": "test/model", "owned_by": "test", "object": "model", "created": 0} + ] + cls.chat_handler = MagicMock(spec=ChatCompletionHandler) + cls.app = build_server(cls.model_manager, cls.chat_handler) + + def _run(self, coro): + return asyncio.get_event_loop().run_until_complete(coro) + + def test_health(self): + from httpx import ASGITransport, AsyncClient + + async def _test(): + async with AsyncClient(transport=ASGITransport(app=self.app), base_url="http://test") as c: + resp = await c.get("/health") + self.assertEqual(resp.status_code, 200) + self.assertEqual(resp.json(), {"status": "ok"}) + + self._run(_test()) + + def test_models_list(self): + from httpx import ASGITransport, AsyncClient + + async def _test(): + async with AsyncClient(transport=ASGITransport(app=self.app), base_url="http://test") as c: + resp = await c.get("/v1/models") + self.assertEqual(resp.status_code, 200) + data = resp.json() + self.assertEqual(data["object"], "list") + self.assertEqual(len(data["data"]), 1) + + self._run(_test()) + + def test_request_id_generated(self): + from httpx import ASGITransport, AsyncClient + + async def _test(): + async with AsyncClient(transport=ASGITransport(app=self.app), base_url="http://test") as c: + resp = await c.get("/health") + self.assertIn("x-request-id", resp.headers) + self.assertEqual(len(resp.headers["x-request-id"]), 36) # UUID length + + self._run(_test()) + + def test_request_id_passthrough(self): + from httpx import ASGITransport, AsyncClient + + async def _test(): + async with AsyncClient(transport=ASGITransport(app=self.app), base_url="http://test") as c: + resp = await c.get("/health", headers={"x-request-id": "my-id"}) + self.assertEqual(resp.headers["x-request-id"], "my-id") + + self._run(_test()) + + +# --------------------------------------------------------------------------- +# 7. Integration tests (need GPU + model) +# Only test what requires a real model. Everything else is above with mocks. 
+# --------------------------------------------------------------------------- + + +@unittest.skipUnless(run_slow, "Set RUN_SLOW=1 to run integration tests") +@require_openai +class TestChatCompletion(unittest.TestCase): + """Integration tests for /v1/chat/completions with a real model.""" + + MODEL = "Qwen/Qwen2.5-0.5B-Instruct" + PORT = 8877 + + @classmethod + def setUpClass(cls): + from transformers.cli.serve_refactored import Serve + + cls.serve = Serve(port=cls.PORT, non_blocking=True, log_level="warning") + import requests + + for _ in range(30): + try: + if requests.get(f"http://localhost:{cls.PORT}/health", timeout=1).status_code == 200: + break + except Exception: + pass + time.sleep(2) + + cls.client = OpenAI(base_url=f"http://localhost:{cls.PORT}/v1", api_key="unused") + + @classmethod + def tearDownClass(cls): + cls.serve.kill_server() + + def test_non_streaming(self): + resp = self.client.chat.completions.create( + model=self.MODEL, messages=[{"role": "user", "content": "Say hello"}] + ) + self.assertIsNotNone(resp.choices[0].message.content) + self.assertIn(resp.choices[0].finish_reason, ("stop", "length")) + + def test_streaming(self): + text = "" + for chunk in self.client.chat.completions.create( + model=self.MODEL, messages=[{"role": "user", "content": "Say hello"}], stream=True + ): + if chunk.choices[0].delta.content: + text += chunk.choices[0].delta.content + self.assertTrue(len(text) > 0) + + def test_stop_strings(self): + resp = self.client.chat.completions.create( + model=self.MODEL, messages=[{"role": "user", "content": "Count to 10"}], stop=["5"] + ) + self.assertNotIn("6", resp.choices[0].message.content) + + def test_multi_turn(self): + resp = self.client.chat.completions.create( + model=self.MODEL, + messages=[ + {"role": "user", "content": "My name is Alice"}, + {"role": "assistant", "content": "Nice to meet you!"}, + {"role": "user", "content": "What is my name?"}, + ], + ) + self.assertIn("Alice", resp.choices[0].message.content) + + def test_multiple_models_on_demand(self): + """Load two different models via separate requests — both should work.""" + model_a = "Qwen/Qwen2.5-0.5B-Instruct" + model_b = "HuggingFaceTB/SmolLM2-135M-Instruct" + prompt = [{"role": "user", "content": "Say hello"}] + + resp_a = self.client.chat.completions.create(model=model_a, messages=prompt) + self.assertIn(model_a, resp_a.model) + self.assertIsNotNone(resp_a.choices[0].message.content) + + resp_b = self.client.chat.completions.create(model=model_b, messages=prompt) + self.assertIn(model_b, resp_b.model) + self.assertIsNotNone(resp_b.choices[0].message.content) From fedad8e585f8058f867f436ae52ac6e1fc39048a Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Wed, 18 Mar 2026 13:46:07 +0000 Subject: [PATCH 10/64] update metrics and processor --- src/transformers/cli/chat.py | 27 ++++++-- src/transformers/cli/serve_refactored.py | 1 + .../cli/serving/chat_completion.py | 65 ++++++++++++++----- src/transformers/cli/serving/model_manager.py | 30 ++++++--- src/transformers/cli/serving/utils.py | 1 + 5 files changed, 93 insertions(+), 31 deletions(-) diff --git a/src/transformers/cli/chat.py b/src/transformers/cli/chat.py index 4b13a5afab51..630e621f5605 100644 --- a/src/transformers/cli/chat.py +++ b/src/transformers/cli/chat.py @@ -109,11 +109,17 @@ async def stream_output(self, stream: AsyncIterator[ChatCompletionStreamOutput]) self._console.print(f"[bold blue]<{self.model_id}>:") with Live(console=self._console, refresh_per_second=4) as live: text = "" + completion_tokens = 0 + 
start_time = time.time() finish_reason: str | None = None async for token in await stream: outputs = token.choices[0].delta.content finish_reason = getattr(token.choices[0], "finish_reason", finish_reason) + usage = getattr(token, "usage", None) + if usage is not None: + completion_tokens = getattr(usage, "completion_tokens", completion_tokens) + if not outputs: continue @@ -149,6 +155,10 @@ async def stream_output(self, stream: AsyncIterator[ChatCompletionStreamOutput]) # Update the Live console output live.update(markdown, refresh=True) + elapsed = time.time() - start_time + if elapsed > 0 and completion_tokens > 0: + tok_per_sec = completion_tokens / elapsed + self._console.print(f"[dim]{completion_tokens} tokens in {elapsed:.1f}s ({tok_per_sec:.1f} tok/s)[/dim]") self._console.print() return text, finish_reason @@ -235,6 +245,10 @@ def __init__( help="Path to a local generation config file or to a HuggingFace repo containing a `generation_config.json` file. Other generation settings passed as CLI arguments will be applied on top of this generation config." ), ] = None, + processor: Annotated[ + str | None, + typer.Option(help="Processor/tokenizer model ID. Needed for GGUF models whose repos don't include tokenizer files."), + ] = None, ) -> None: """Chat with a model from the command line.""" self.base_url = base_url @@ -253,6 +267,7 @@ def __init__( config.update(**parse_generate_flags(generate_flags)) self.config = config + self.processor = processor self.settings = {"base_url": base_url, "model_id": model_id, "config": self.config.to_dict()} # User settings @@ -456,14 +471,18 @@ async def _inner_run(self): else: chat.append({"role": "user", "content": user_input}) + extra_body = { + "generation_config": config.to_json_string(), + "model": self.model_id, + } + if self.processor: + extra_body["processor"] = self.processor + stream = client.chat_completion( chat, stream=True, model=self.model_id, - extra_body={ - "generation_config": config.to_json_string(), - "model": self.model_id, - }, + extra_body=extra_body, ) model_output, finish_reason = await interface.stream_output(stream) diff --git a/src/transformers/cli/serve_refactored.py b/src/transformers/cli/serve_refactored.py index 66eb082dfa84..41f92748882a 100644 --- a/src/transformers/cli/serve_refactored.py +++ b/src/transformers/cli/serve_refactored.py @@ -108,6 +108,7 @@ def __init__( chat_handler = ChatCompletionHandler( model_manager=model_manager, force_model=force_model, + force_processor=processor, ) app = build_server(model_manager, chat_handler, enable_cors=enable_cors) diff --git a/src/transformers/cli/serving/chat_completion.py b/src/transformers/cli/serving/chat_completion.py index 58a1273b01a8..c56fedc01511 100644 --- a/src/transformers/cli/serving/chat_completion.py +++ b/src/transformers/cli/serving/chat_completion.py @@ -39,6 +39,7 @@ from openai.types.chat.chat_completion import Choice from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, ChoiceDelta from openai.types.chat.chat_completion_chunk import Choice as ChoiceChunk +from openai.types.completion_usage import CompletionUsage from tokenizers.decoders import DecodeStream from transformers import GenerationConfig, PreTrainedModel @@ -111,9 +112,10 @@ class ChatCompletionHandler: force_model: If set, override the model field in every request. 
""" - def __init__(self, model_manager: ModelManager, force_model: str | None = None): + def __init__(self, model_manager: ModelManager, force_model: str | None = None, force_processor: str | None = None): self.model_manager = model_manager self.force_model = force_model + self.force_processor = force_processor # ----- entry point ----- @@ -131,7 +133,11 @@ def handle_request(self, body: dict, request_id: str) -> StreamingResponse | JSO return JSONResponse({}, status_code=200) model_id = self.model_manager.process_model_name(body["model"]) - model, processor = self.model_manager.load_model_and_processor(model_id) + if self.force_processor is not None: + processor_id = self.force_processor + else: + processor_id = body.get("processor") + model, processor = self.model_manager.load_model_and_processor(model_id, processor_id=processor_id) modality = self.model_manager.get_model_modality(model, processor=processor) processor_inputs = get_processor_inputs_from_messages(messages, modality) @@ -145,7 +151,7 @@ def handle_request(self, body: dict, request_id: str) -> StreamingResponse | JSO tokenize=True, ).to(model.device) - gen_config = self._build_generation_config(body, model.generation_config) + gen_config = self._build_generation_config(body, model.generation_config, processor) if body.get("stream"): return self._streaming(request_id, model, processor, model_id, inputs, gen_config) @@ -177,6 +183,8 @@ def _run() -> None: Thread(target=_run, daemon=True).start() + input_len = inputs["input_ids"].shape[-1] + async def sse_gen() -> AsyncGenerator[str, None]: # First chunk: tell the client the assistant is about to speak (OpenAI protocol) yield self._build_chunk_sse(request_id, role="assistant", model=model_id) @@ -194,12 +202,17 @@ async def sse_gen() -> AsyncGenerator[str, None]: # Last chunk: tell the client why generation stopped hit_max = gen_config.max_new_tokens is not None and streamer.total_tokens >= gen_config.max_new_tokens - if hit_max: - logger.warning( - f"Generation hit max_new_tokens={gen_config.max_new_tokens}. " - "Output may be truncated. Use `max_tokens` to increase the limit." - ) - yield self._build_chunk_sse(request_id, finish_reason="length" if hit_max else "stop", model=model_id) + usage = CompletionUsage( + prompt_tokens=input_len, + completion_tokens=streamer.total_tokens, + total_tokens=input_len + streamer.total_tokens, + ) + yield self._build_chunk_sse( + request_id, + finish_reason="length" if hit_max else "stop", + model=model_id, + usage=usage, + ) return StreamingResponse(sse_gen(), media_type="text/event-stream") @@ -222,13 +235,18 @@ def _non_streaming( content = processor.decode(generated_ids, skip_special_tokens=True) hit_max = gen_config.max_new_tokens is not None and len(generated_ids) >= gen_config.max_new_tokens - if hit_max: - logger.warning( - f"Generation hit max_new_tokens={gen_config.max_new_tokens}. " - "Output may be truncated. Use `max_tokens` to increase the limit." 
- ) + completion_tokens = len(generated_ids) + usage = CompletionUsage( + prompt_tokens=input_len, + completion_tokens=completion_tokens, + total_tokens=input_len + completion_tokens, + ) return JSONResponse( - self._build_completion(request_id, content, model_id, finish_reason="length" if hit_max else "stop"), + self._build_completion( + request_id, content, model_id, + finish_reason="length" if hit_max else "stop", + usage=usage, + ), media_type="application/json", ) @@ -256,7 +274,7 @@ def _apply_default_generation_config(generation_config: GenerationConfig) -> Non if generation_config.max_new_tokens is None or generation_config.max_new_tokens < 1024: generation_config.max_new_tokens = 1024 - def _build_generation_config(self, body: dict, model_generation_config: GenerationConfig) -> GenerationConfig: + def _build_generation_config(self, body: dict, model_generation_config: GenerationConfig, processor=None) -> GenerationConfig: """Map Chat Completions API params to a GenerationConfig. If `body` contains a `generation_config` JSON string, it is used as baseline @@ -268,6 +286,14 @@ def _build_generation_config(self, body: dict, model_generation_config: Generati generation_config = copy.deepcopy(model_generation_config) self._apply_default_generation_config(generation_config) + # GGUF models may not have eos/pad token IDs set — sync from processor + if processor is not None: + tokenizer = processor.tokenizer if hasattr(processor, "tokenizer") else processor + if generation_config.eos_token_id is None and hasattr(tokenizer, "eos_token_id"): + generation_config.eos_token_id = tokenizer.eos_token_id + if generation_config.pad_token_id is None and hasattr(tokenizer, "pad_token_id"): + generation_config.pad_token_id = tokenizer.pad_token_id + if body.get("max_tokens") is not None: generation_config.max_new_tokens = int(body["max_tokens"]) if body.get("frequency_penalty") is not None: @@ -289,7 +315,9 @@ def _build_generation_config(self, body: dict, model_generation_config: Generati # ----- response builders ----- - def _build_completion(self, request_id: str, content: str, model_id: str, finish_reason: str) -> dict: + def _build_completion( + self, request_id: str, content: str, model_id: str, finish_reason: str, usage: CompletionUsage | None = None, + ) -> dict: """Build a non-streaming ChatCompletion response dict.""" result = ChatCompletion( id=request_id, @@ -303,6 +331,7 @@ def _build_completion(self, request_id: str, content: str, model_id: str, finish finish_reason=finish_reason, ) ], + usage=usage, ) return result.model_dump(exclude_none=True) @@ -314,6 +343,7 @@ def _build_chunk_sse( role: str | None = None, finish_reason: str | None = None, tool_calls: list | None = None, + usage: CompletionUsage | None = None, ) -> str: """Build a streaming ChatCompletionChunk and format it as an SSE event.""" chunk = ChatCompletionChunk( @@ -327,6 +357,7 @@ def _build_chunk_sse( finish_reason=finish_reason, ) ], + usage=usage, system_fingerprint="", object="chat.completion.chunk", ) diff --git a/src/transformers/cli/serving/model_manager.py b/src/transformers/cli/serving/model_manager.py index 07cba3845c5d..5aa116c3ce94 100644 --- a/src/transformers/cli/serving/model_manager.py +++ b/src/transformers/cli/serving/model_manager.py @@ -125,12 +125,11 @@ def __init__( self.attn_implementation = attn_implementation self.quantization = quantization self.model_timeout = model_timeout - self.processor_id = processor_id self.loaded_models: dict[str, TimedModel] = {} if force_model is not None: - 
self.load_model_and_processor(self.process_model_name(force_model)) + self.load_model_and_processor(self.process_model_name(force_model), processor_id=processor_id) @staticmethod def process_model_name(model_id: str) -> str: @@ -151,16 +150,20 @@ def get_quantization_config(self) -> BitsAndBytesConfig | None: return BitsAndBytesConfig(load_in_8bit=True) return None - def _load_processor(self, model_id_and_revision: str) -> ProcessorMixin | PreTrainedTokenizerFast: + def _load_processor( + self, model_id_and_revision: str, processor_id: str | None = None + ) -> ProcessorMixin | PreTrainedTokenizerFast: """Load a processor, trying AutoProcessor first then AutoTokenizer. - If `processor_id` was set (e.g. for GGUF models), uses that instead of `model_id`. - Expects `model_id_and_revision` in `'model_id@revision'` format (from `process_model_name`). + Args: + model_id_and_revision: Model ID in ``'model_id@revision'`` format. + processor_id: Override processor/tokenizer ID (e.g. for GGUF models). + Falls back to ``model_id``. """ from transformers import AutoProcessor - if self.processor_id: - model_id, revision = self.processor_id, "main" + if processor_id: + model_id, revision = processor_id, "main" else: model_id, revision = model_id_and_revision.split("@", 1) try: @@ -190,6 +193,7 @@ def _load_model(self, model_id_and_revision: str) -> PreTrainedModel: revision=revision, n_gpu_layers=-1, flash_attn=flash_attn, + n_ctx=8192, ) dtype = self.dtype if self.dtype in ["auto", None] else getattr(torch, self.dtype) @@ -207,11 +211,17 @@ def _load_model(self, model_id_and_revision: str) -> PreTrainedModel: return architecture.from_pretrained(model_id, **model_kwargs) def load_model_and_processor( - self, model_id_and_revision: str + self, model_id_and_revision: str, processor_id: str | None = None ) -> tuple[PreTrainedModel, ProcessorMixin | PreTrainedTokenizerFast]: - """Load a model (or return it from cache), resetting its inactivity timer.""" + """Load a model (or return it from cache), resetting its inactivity timer. + + Args: + model_id_and_revision: Model ID in ``'model_id@revision'`` format. + processor_id: Optional per-request processor override (takes precedence + over the instance-level ``self.processor_id``). 
+ """ if model_id_and_revision not in self.loaded_models or self.loaded_models[model_id_and_revision].is_deleted(): - processor = self._load_processor(model_id_and_revision) + processor = self._load_processor(model_id_and_revision, processor_id=processor_id) model = self._load_model(model_id_and_revision) self.loaded_models[model_id_and_revision] = TimedModel( model, timeout_seconds=self.model_timeout, processor=processor diff --git a/src/transformers/cli/serving/utils.py b/src/transformers/cli/serving/utils.py index ed366766535e..652014eab9b5 100644 --- a/src/transformers/cli/serving/utils.py +++ b/src/transformers/cli/serving/utils.py @@ -34,6 +34,7 @@ class TransformersCompletionCreateParamsStreaming(CompletionCreateParamsStreaming, total=False): generation_config: str + processor: str # --------------------------------------------------------------------------- From 9b904b1c38a8f9209c526c63d2a702256c982fa7 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Wed, 18 Mar 2026 16:26:53 +0000 Subject: [PATCH 11/64] hardcode n_batch for now --- src/transformers/cli/serving/model_manager.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/cli/serving/model_manager.py b/src/transformers/cli/serving/model_manager.py index 5aa116c3ce94..224d43c91086 100644 --- a/src/transformers/cli/serving/model_manager.py +++ b/src/transformers/cli/serving/model_manager.py @@ -194,6 +194,7 @@ def _load_model(self, model_id_and_revision: str) -> PreTrainedModel: n_gpu_layers=-1, flash_attn=flash_attn, n_ctx=8192, + n_batch=2048, ) dtype = self.dtype if self.dtype in ["auto", None] else getattr(torch, self.dtype) From 0084b910869705ee2096993381647249cf8cec16 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Thu, 19 Mar 2026 17:25:07 +0000 Subject: [PATCH 12/64] add response api + compile --- src/transformers/cli/serve_refactored.py | 24 +- .../cli/serving/chat_completion.py | 165 +------- src/transformers/cli/serving/response.py | 376 ++++++++++++++++++ src/transformers/cli/serving/server.py | 7 + src/transformers/cli/serving/utils.py | 197 +++++++++ 5 files changed, 624 insertions(+), 145 deletions(-) create mode 100644 src/transformers/cli/serving/response.py diff --git a/src/transformers/cli/serve_refactored.py b/src/transformers/cli/serve_refactored.py index 41f92748882a..c1abf8104cdd 100644 --- a/src/transformers/cli/serve_refactored.py +++ b/src/transformers/cli/serve_refactored.py @@ -69,6 +69,12 @@ def __init__( enable_cors: Annotated[bool, typer.Option(help="Enable permissive CORS.")] = False, log_level: Annotated[str, typer.Option(help="Logging level (e.g. 'info', 'warning').")] = "info", default_seed: Annotated[int | None, typer.Option(help="Default torch seed.")] = None, + compile: Annotated[ + bool, + typer.Option( + help="Enable static cache + torch.compile for faster decode (~2.6x). First request triggers compilation (~30s)." + ), + ] = False, non_blocking: Annotated[ bool, typer.Option(hidden=True, help="Run server in a background thread. Used by tests.") ] = False, @@ -80,7 +86,9 @@ def __init__( from .serving.chat_completion import ChatCompletionHandler from .serving.model_manager import ModelManager + from .serving.response import ResponseHandler from .serving.server import build_server + from .serving.utils import InferenceThread # Seed if default_seed is not None: @@ -105,13 +113,27 @@ def __init__( processor_id=processor, ) + # Single persistent thread for all generate() calls — required for + # torch.compile + CUDA graphs which use thread-local storage. 
+ inference_thread = InferenceThread() + chat_handler = ChatCompletionHandler( model_manager=model_manager, force_model=force_model, force_processor=processor, + inference_thread=inference_thread, + compile=compile, + ) + + response_handler = ResponseHandler( + model_manager=model_manager, + force_model=force_model, + force_processor=processor, + inference_thread=inference_thread, + compile=compile, ) - app = build_server(model_manager, chat_handler, enable_cors=enable_cors) + app = build_server(model_manager, chat_handler, response_handler=response_handler, enable_cors=enable_cors) config = uvicorn.Config(app, host=host, port=port, log_level=log_level) self.server = uvicorn.Server(config) diff --git a/src/transformers/cli/serving/chat_completion.py b/src/transformers/cli/serving/chat_completion.py index c56fedc01511..88def4c1d0f0 100644 --- a/src/transformers/cli/serving/chat_completion.py +++ b/src/transformers/cli/serving/chat_completion.py @@ -19,12 +19,8 @@ from __future__ import annotations -import asyncio -import copy -import json import time from collections.abc import AsyncGenerator -from threading import Thread from typing import TYPE_CHECKING import torch @@ -40,104 +36,41 @@ from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, ChoiceDelta from openai.types.chat.chat_completion_chunk import Choice as ChoiceChunk from openai.types.completion_usage import CompletionUsage -from tokenizers.decoders import DecodeStream from transformers import GenerationConfig, PreTrainedModel -from transformers.generation.streamers import BaseStreamer from ...utils import logging -from .model_manager import ModelManager from .utils import ( UNUSED_CHAT_COMPLETION_FIELDS, + BaseHandler, TransformersCompletionCreateParamsStreaming, _StreamError, get_processor_inputs_from_messages, - set_torch_seed, ) logger = logging.get_logger(__name__) -class DirectStreamer(BaseStreamer): - """Streamer that decodes tokens incrementally and pushes text to an asyncio.Queue. - - Uses the Rust `DecodeStream.step()` for O(1) per-token decode, unlike - `TextIteratorStreamer` which re-decodes the full sequence each time. - - Args: - processor: A HuggingFace processor or tokenizer (must have a `._tokenizer` attribute). - loop: The asyncio event loop to push results to. - queue: The asyncio.Queue to push decoded text chunks to. - skip_special_tokens: Whether to skip special tokens during decoding. - """ - - def __init__( - self, - processor: ProcessorMixin | PreTrainedTokenizerFast, - loop: asyncio.AbstractEventLoop, - queue: asyncio.Queue, - skip_special_tokens: bool = True, - ): - self._tokenizer = processor._tokenizer # raw tokenizers.Tokenizer - self._loop = loop - self._queue = queue - self._decode_stream = DecodeStream([], skip_special_tokens) - self._first = True - self.total_tokens = 0 - - def put(self, value: torch.Tensor) -> None: - if self._first: - self._first = False - return # skip prompt tokens - if len(value.shape) > 1: - value = value[0] - for token_id in value.tolist(): - self.total_tokens += 1 - text = self._decode_stream.step(self._tokenizer, token_id) - if text is not None: - self._loop.call_soon_threadsafe(self._queue.put_nowait, text) - - def end(self) -> None: - self._loop.call_soon_threadsafe(self._queue.put_nowait, None) - - -class ChatCompletionHandler: +class ChatCompletionHandler(BaseHandler): """Handler for the `/v1/chat/completions` endpoint. Supports both streaming (SSE) and non-streaming (JSON) responses. - - Args: - model_manager: The model manager to load models from. 
- force_model: If set, override the model field in every request. """ - def __init__(self, model_manager: ModelManager, force_model: str | None = None, force_processor: str | None = None): - self.model_manager = model_manager - self.force_model = force_model - self.force_processor = force_processor - # ----- entry point ----- def handle_request(self, body: dict, request_id: str) -> StreamingResponse | JSONResponse: """Validate the request, load the model, and dispatch to streaming or non-streaming.""" self._validate_request(body) - if self.force_model is not None: - body["model"] = self.force_model - messages = body["messages"] # HACK: tiny-agents sends requests ending with assistant message — skip if messages and messages[-1]["role"] == "assistant": return JSONResponse({}, status_code=200) - model_id = self.model_manager.process_model_name(body["model"]) - if self.force_processor is not None: - processor_id = self.force_processor - else: - processor_id = body.get("processor") - model, processor = self.model_manager.load_model_and_processor(model_id, processor_id=processor_id) + model_id, model, processor = self._resolve_model(body) modality = self.model_manager.get_model_modality(model, processor=processor) processor_inputs = get_processor_inputs_from_messages(messages, modality) @@ -168,39 +101,22 @@ def _streaming( inputs: dict[str, torch.Tensor], gen_config: GenerationConfig, ) -> StreamingResponse: - """Run generation in a background thread, stream tokens as SSE via DirectStreamer.""" - loop = asyncio.get_running_loop() - queue: asyncio.Queue = asyncio.Queue() - - streamer = DirectStreamer(processor, loop, queue, skip_special_tokens=True) - gen_kwargs = {**inputs, "streamer": streamer, "generation_config": gen_config, "tokenizer": processor} - - def _run() -> None: - try: - model.generate(**gen_kwargs) - except Exception as e: - loop.call_soon_threadsafe(queue.put_nowait, _StreamError(str(e))) - - Thread(target=_run, daemon=True).start() - + """Stream tokens as SSE via DirectStreamer.""" + queue, streamer = self._start_streaming(model, processor, inputs, gen_config) input_len = inputs["input_ids"].shape[-1] async def sse_gen() -> AsyncGenerator[str, None]: - # First chunk: tell the client the assistant is about to speak (OpenAI protocol) yield self._build_chunk_sse(request_id, role="assistant", model=model_id) - # Stream tokens as they arrive from the background generate thread while True: text = await queue.get() if text is None: - break # generation done + break elif isinstance(text, _StreamError): yield f'data: {{"error": "{text.msg}"}}\n\n' return - yield self._build_chunk_sse(request_id, content=text, model=model_id) - # Last chunk: tell the client why generation stopped hit_max = gen_config.max_new_tokens is not None and streamer.total_tokens >= gen_config.max_new_tokens usage = CompletionUsage( prompt_tokens=input_len, @@ -227,12 +143,8 @@ def _non_streaming( inputs: dict[str, torch.Tensor], gen_config: GenerationConfig, ) -> JSONResponse: - """Run generation synchronously and return a JSONResponse.""" - sequences = model.generate(**inputs, generation_config=gen_config, tokenizer=processor) - - input_len = inputs["input_ids"].shape[-1] - generated_ids = sequences[0, input_len:] - content = processor.decode(generated_ids, skip_special_tokens=True) + """Run generation and return a JSONResponse.""" + content, input_len, generated_ids = self._generate_non_streaming(model, processor, inputs, gen_config) hit_max = gen_config.max_new_tokens is not None and len(generated_ids) >= 
gen_config.max_new_tokens completion_tokens = len(generated_ids) @@ -243,7 +155,9 @@ def _non_streaming( ) return JSONResponse( self._build_completion( - request_id, content, model_id, + request_id, + content, + model_id, finish_reason="length" if hit_max else "stop", usage=usage, ), @@ -254,45 +168,18 @@ def _non_streaming( def _validate_request(self, body: dict) -> None: """Validate a chat completion request. Raises HTTPException if invalid.""" - logger.debug(f"Validating request: {body}") - input_keys = set(body.keys()) unexpected = input_keys - TransformersCompletionCreateParamsStreaming.__mutable_keys__ if unexpected: - logger.error(f"Unexpected keys in the request: {unexpected}") raise HTTPException(status_code=422, detail=f"Unexpected keys in the request: {unexpected}") - # TODO: add back strict Pydantic validation (input_validation flag) unused = input_keys & UNUSED_CHAT_COMPLETION_FIELDS if unused: - logger.error(f"Unsupported fields in the request: {unused}") raise HTTPException(status_code=422, detail=f"Unsupported fields in the request: {unused}") - @staticmethod - def _apply_default_generation_config(generation_config: GenerationConfig) -> None: - """Apply sensible serving defaults. Many models ship with too few max_new_tokens.""" - if generation_config.max_new_tokens is None or generation_config.max_new_tokens < 1024: - generation_config.max_new_tokens = 1024 - - def _build_generation_config(self, body: dict, model_generation_config: GenerationConfig, processor=None) -> GenerationConfig: - """Map Chat Completions API params to a GenerationConfig. - - If `body` contains a `generation_config` JSON string, it is used as baseline - (overriding the model default). Other body params are applied on top. - """ - if body.get("generation_config") is not None: - generation_config = GenerationConfig(**json.loads(body["generation_config"])) - else: - generation_config = copy.deepcopy(model_generation_config) - self._apply_default_generation_config(generation_config) - - # GGUF models may not have eos/pad token IDs set — sync from processor - if processor is not None: - tokenizer = processor.tokenizer if hasattr(processor, "tokenizer") else processor - if generation_config.eos_token_id is None and hasattr(tokenizer, "eos_token_id"): - generation_config.eos_token_id = tokenizer.eos_token_id - if generation_config.pad_token_id is None and hasattr(tokenizer, "pad_token_id"): - generation_config.pad_token_id = tokenizer.pad_token_id + def _build_generation_config(self, body: dict, model_generation_config, processor=None): + """Chat Completions params on top of base config.""" + generation_config = super()._build_generation_config(body, model_generation_config, processor) if body.get("max_tokens") is not None: generation_config.max_new_tokens = int(body["max_tokens"]) @@ -302,21 +189,18 @@ def _build_generation_config(self, body: dict, model_generation_config: Generati generation_config.sequence_bias = {(int(k),): v for k, v in body["logit_bias"].items()} if body.get("stop") is not None: generation_config.stop_strings = body["stop"] - if body.get("temperature") is not None: - generation_config.temperature = float(body["temperature"]) - if float(body["temperature"]) == 0.0: - generation_config.do_sample = False - if body.get("top_p") is not None: - generation_config.top_p = float(body["top_p"]) - if body.get("seed") is not None: - set_torch_seed(body["seed"]) return generation_config # ----- response builders ----- def _build_completion( - self, request_id: str, content: str, model_id: str, 
finish_reason: str, usage: CompletionUsage | None = None, + self, + request_id: str, + content: str, + model_id: str, + finish_reason: str, + usage: CompletionUsage | None = None, ) -> dict: """Build a non-streaming ChatCompletion response dict.""" result = ChatCompletion( @@ -361,11 +245,4 @@ def _build_chunk_sse( system_fingerprint="", object="chat.completion.chunk", ) - return self._chunk_to_sse(chunk) - - @staticmethod - def _chunk_to_sse(chunk) -> str: - """Format a pydantic model or string as an SSE event.""" - if isinstance(chunk, str): - return chunk if chunk.startswith("data: ") else f"data: {chunk}\n\n" - return f"data: {chunk.model_dump_json(exclude_none=True)}\n\n" + return self.chunk_to_sse(chunk) diff --git a/src/transformers/cli/serving/response.py b/src/transformers/cli/serving/response.py new file mode 100644 index 000000000000..5bae80d08815 --- /dev/null +++ b/src/transformers/cli/serving/response.py @@ -0,0 +1,376 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Handler for the /v1/responses endpoint (OpenAI Responses API). + +Supports streaming (SSE) and non-streaming (JSON) responses. +""" + +from __future__ import annotations + +import time +from collections.abc import AsyncGenerator +from typing import TYPE_CHECKING + +from fastapi import HTTPException +from fastapi.responses import JSONResponse, StreamingResponse +from openai.types.responses import ( + Response, + ResponseCompletedEvent, + ResponseContentPartAddedEvent, + ResponseContentPartDoneEvent, + ResponseCreatedEvent, + ResponseError, + ResponseErrorEvent, + ResponseFailedEvent, + ResponseInProgressEvent, + ResponseOutputItemAddedEvent, + ResponseOutputItemDoneEvent, + ResponseOutputMessage, + ResponseOutputText, + ResponseTextDeltaEvent, + ResponseTextDoneEvent, +) +from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails, ResponseUsage + +from ...utils import logging +from .utils import BaseHandler, _StreamError + + +if TYPE_CHECKING: + from transformers import GenerationConfig, PreTrainedModel, PreTrainedTokenizerFast, ProcessorMixin + + +logger = logging.get_logger(__name__) + +UNUSED_RESPONSE_FIELDS = { + "background", + "include", + "max_tool_calls", + "previous_response_id", + "prompt", + "reasoning", + "service_tier", + "store", + "text", + "tool_choice", + "top_logprobs", + "truncation", + "user", +} + + +class ResponseHandler(BaseHandler): + """Handler for the ``/v1/responses`` endpoint.""" + + # ----- entry point ----- + + def handle_request(self, body: dict, request_id: str) -> StreamingResponse | JSONResponse: + """Validate, load model, dispatch to streaming or non-streaming.""" + self._validate_request(body) + + model_id, model, processor = self._resolve_model(body) + + messages = self._input_to_messages(body) + inputs = processor.apply_chat_template( + messages, add_generation_prompt=True, return_tensors="pt", return_dict=True + ).to(model.device) + + gen_config = 
self._build_generation_config(body, model.generation_config, processor) + + if body.get("stream", True): + return self._streaming(request_id, model, processor, model_id, body, inputs, gen_config) + return self._non_streaming(request_id, model, processor, model_id, body, inputs, gen_config) + + # ----- input conversion ----- + + @staticmethod + def _input_to_messages(body: dict) -> list[dict]: + """Convert the Responses API ``input`` field to a list of chat messages.""" + inp = body["input"] + instructions = body.get("instructions") + + if isinstance(inp, str): + messages = [{"role": "system", "content": instructions}] if instructions else [] + messages.append({"role": "user", "content": inp}) + elif isinstance(inp, list): + if instructions: + if inp[0]["role"] != "system": + messages = [{"role": "system", "content": instructions}, *inp] + else: + messages = list(inp) + messages[0]["content"] = instructions + else: + messages = inp + elif isinstance(inp, dict): + messages = [{"role": "system", "content": instructions}] if instructions else [] + messages.append(inp) + else: + raise HTTPException(status_code=422, detail="'input' must be a string, list, or dict") + + return messages + + # ----- streaming ----- + + def _streaming( + self, + request_id: str, + model: PreTrainedModel, + processor: ProcessorMixin | PreTrainedTokenizerFast, + model_id: str, + body: dict, + inputs: dict, + gen_config: GenerationConfig, + ) -> StreamingResponse: + """Generate a streaming Responses API reply (SSE) using DirectStreamer.""" + queue, streamer = self._start_streaming(model, processor, inputs, gen_config) + input_len = inputs["input_ids"].shape[-1] + + seq = 0 + output_index = 0 + content_index = 0 + created_at = time.time() + resp_id = f"resp_{request_id}" + msg_id = f"msg_{request_id}" + + response_base = { + "id": resp_id, + "created_at": created_at, + "model": model_id, + "object": "response", + "tools": [], + "parallel_tool_calls": body.get("parallel_tool_calls", False), + "tool_choice": "auto", + } + + async def event_stream() -> AsyncGenerator[str, None]: + nonlocal seq + + # 1. Created + yield self.chunk_to_sse( + ResponseCreatedEvent( + type="response.created", + sequence_number=seq, + response=Response(**response_base, status="queued", output=[]), + ) + ) + seq += 1 + + # 2. In progress + yield self.chunk_to_sse( + ResponseInProgressEvent( + type="response.in_progress", + sequence_number=seq, + response=Response(**response_base, status="in_progress", output=[]), + ) + ) + seq += 1 + + # 3. Output item added + yield self.chunk_to_sse( + ResponseOutputItemAddedEvent( + type="response.output_item.added", + sequence_number=seq, + output_index=output_index, + item=ResponseOutputMessage( + id=msg_id, + type="message", + status="in_progress", + role="assistant", + content=[], + ), + ) + ) + seq += 1 + + # 4. Content part added + yield self.chunk_to_sse( + ResponseContentPartAddedEvent( + type="response.content_part.added", + item_id=msg_id, + sequence_number=seq, + output_index=output_index, + content_index=content_index, + part=ResponseOutputText(type="output_text", text="", annotations=[]), + ) + ) + seq += 1 + + # 5. 
Text deltas from DirectStreamer queue + full_text = "" + while True: + text = await queue.get() + if text is None: + break + if isinstance(text, _StreamError): + logger.error(f"Exception in response generation: {text.msg}") + yield self.chunk_to_sse(ResponseErrorEvent(type="error", sequence_number=seq, message=text.msg)) + seq += 1 + yield self.chunk_to_sse( + ResponseFailedEvent( + type="response.failed", + sequence_number=seq, + response=Response( + **response_base, + status="failed", + output=[], + error=ResponseError(code="server_error", message=text.msg), + ), + ) + ) + return + + full_text += text + yield self.chunk_to_sse( + ResponseTextDeltaEvent( + type="response.output_text.delta", + item_id=msg_id, + sequence_number=seq, + output_index=output_index, + content_index=content_index, + delta=text, + logprobs=[], + ) + ) + seq += 1 + + # 6. Text done + yield self.chunk_to_sse( + ResponseTextDoneEvent( + type="response.output_text.done", + item_id=msg_id, + sequence_number=seq, + output_index=output_index, + content_index=0, + text=full_text, + logprobs=[], + ) + ) + seq += 1 + + # 7. Content part done + output_text_part = ResponseOutputText(type="output_text", text=full_text, annotations=[]) + yield self.chunk_to_sse( + ResponseContentPartDoneEvent( + type="response.content_part.done", + item_id=msg_id, + sequence_number=seq, + output_index=output_index, + content_index=content_index, + part=output_text_part, + ) + ) + seq += 1 + + # 8. Output item done + output_item = ResponseOutputMessage( + id=msg_id, + type="message", + status="completed", + role="assistant", + content=[output_text_part], + annotations=[], + ) + yield self.chunk_to_sse( + ResponseOutputItemDoneEvent( + type="response.output_item.done", + sequence_number=seq, + output_index=output_index, + item=output_item, + ) + ) + seq += 1 + + # 9. 
Completed + usage = _make_usage(input_len, streamer.total_tokens) + yield self.chunk_to_sse( + ResponseCompletedEvent( + type="response.completed", + sequence_number=seq, + response=Response(**response_base, status="completed", output=[output_item], usage=usage), + ) + ) + seq += 1 + + return StreamingResponse(event_stream(), media_type="text/event-stream") + + # ----- non-streaming ----- + + def _non_streaming( + self, + request_id: str, + model: PreTrainedModel, + processor: ProcessorMixin | PreTrainedTokenizerFast, + model_id: str, + body: dict, + inputs: dict, + gen_config: GenerationConfig, + ) -> JSONResponse: + """Generate a non-streaming Responses API reply (single JSON).""" + full_text, input_len, generated_ids = self._generate_non_streaming(model, processor, inputs, gen_config) + + created_at = time.time() + resp_id = f"resp_{request_id}" + msg_id = f"msg_{request_id}" + output_tokens = len(generated_ids) + + output_item = ResponseOutputMessage( + id=msg_id, + type="message", + status="completed", + role="assistant", + content=[ResponseOutputText(type="output_text", text=full_text, annotations=[])], + annotations=[], + ) + usage = _make_usage(input_len, output_tokens) + response = Response( + id=resp_id, + created_at=created_at, + status="completed", + model=model_id, + output=[output_item], + object="response", + tools=[], + parallel_tool_calls=body.get("parallel_tool_calls", False), + tool_choice="auto", + usage=usage, + ) + return JSONResponse(response.model_dump(exclude_none=True)) + + # ----- helpers ----- + + def _validate_request(self, body: dict) -> None: + """Validate a Responses API request.""" + unused = set(body.keys()) & UNUSED_RESPONSE_FIELDS + if unused: + raise HTTPException(status_code=422, detail=f"Unsupported fields in the request: {unused}") + + def _build_generation_config(self, body: dict, model_generation_config, processor=None): + """Responses API params on top of base config.""" + generation_config = super()._build_generation_config(body, model_generation_config, processor) + + if body.get("max_output_tokens") is not None: + generation_config.max_new_tokens = int(body["max_output_tokens"]) + + return generation_config + + +def _make_usage(input_tokens: int, output_tokens: int) -> ResponseUsage: + return ResponseUsage( + input_tokens=input_tokens, + output_tokens=output_tokens, + total_tokens=input_tokens + output_tokens, + input_tokens_details=InputTokensDetails(cached_tokens=0), + output_tokens_details=OutputTokensDetails(reasoning_tokens=0), + ) diff --git a/src/transformers/cli/serving/server.py b/src/transformers/cli/serving/server.py index e88366cb0f9c..4730509e0dc4 100644 --- a/src/transformers/cli/serving/server.py +++ b/src/transformers/cli/serving/server.py @@ -27,6 +27,7 @@ from ...utils import logging from .chat_completion import ChatCompletionHandler from .model_manager import ModelManager +from .response import ResponseHandler from .utils import X_REQUEST_ID @@ -36,6 +37,7 @@ def build_server( model_manager: ModelManager, chat_handler: ChatCompletionHandler, + response_handler: ResponseHandler, enable_cors: bool = False, ) -> FastAPI: """Build and return a configured FastAPI application. @@ -43,6 +45,7 @@ def build_server( Args: model_manager: Handles model loading, caching, and cleanup. chat_handler: Handles `/v1/chat/completions` requests. + response_handler: Handles `/v1/responses` requests. enable_cors: If `True`, adds permissive CORS middleware (allow all origins). 
Returns: @@ -83,6 +86,10 @@ async def request_id_middleware(request: Request, call_next): async def chat_completions(request: Request, body: dict): return chat_handler.handle_request(body, request.state.request_id) + @app.post("/v1/responses") + async def responses(request: Request, body: dict): + return response_handler.handle_request(body, request.state.request_id) + @app.get("/v1/models") @app.options("/v1/models") def list_models(): diff --git a/src/transformers/cli/serving/utils.py b/src/transformers/cli/serving/utils.py index 652014eab9b5..c97c0ab2841d 100644 --- a/src/transformers/cli/serving/utils.py +++ b/src/transformers/cli/serving/utils.py @@ -17,11 +17,17 @@ from __future__ import annotations +import asyncio import base64 +import copy import enum +import json import re import tempfile +import threading +from concurrent.futures import Future from io import BytesIO +from queue import Queue from transformers.utils.import_utils import is_openai_available, is_vision_available @@ -89,6 +95,50 @@ def __init__(self, msg: str): self.msg = msg +# --------------------------------------------------------------------------- +# Streaming +# --------------------------------------------------------------------------- + + +class DirectStreamer: + """Streamer that decodes tokens incrementally and pushes text to an asyncio.Queue. + + Uses the Rust ``DecodeStream.step()`` for O(1) per-token decode, unlike + ``TextIteratorStreamer`` which re-decodes the full sequence each time. + + Args: + processor: A HuggingFace processor or tokenizer (must have a ``._tokenizer`` attribute). + loop: The asyncio event loop to push results to. + queue: The asyncio.Queue to push decoded text chunks to. + skip_special_tokens: Whether to skip special tokens during decoding. + """ + + def __init__(self, processor, loop, queue, skip_special_tokens: bool = True): + from tokenizers.decoders import DecodeStream + + self._tokenizer = processor._tokenizer # raw tokenizers.Tokenizer + self._loop = loop + self._queue = queue + self._decode_stream = DecodeStream([], skip_special_tokens) + self._first = True + self.total_tokens = 0 + + def put(self, value) -> None: + if self._first: + self._first = False + return # skip prompt tokens + if len(value.shape) > 1: + value = value[0] + for token_id in value.tolist(): + self.total_tokens += 1 + text = self._decode_stream.step(self._tokenizer, token_id) + if text is not None: + self._loop.call_soon_threadsafe(self._queue.put_nowait, text) + + def end(self) -> None: + self._loop.call_soon_threadsafe(self._queue.put_nowait, None) + + # --------------------------------------------------------------------------- # Torch helpers # --------------------------------------------------------------------------- @@ -107,6 +157,153 @@ def reset_torch_cache() -> None: torch.cuda.empty_cache() +class InferenceThread: + """A single persistent thread that runs all model.generate() calls. + + torch.compile with ``mode="reduce-overhead"`` uses CUDA graphs, which store + state in thread-local storage (TLS). If generate() is called from different + threads (e.g. a new Thread per streaming request), the CUDA graph state is + lost or corrupted — causing silent wrong output or crashes. + + This class ensures all inference runs on the **same thread**, matching what + vLLM does with its engine loop. Both streaming and non-streaming requests + submit work here. 
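
A minimal usage sketch of this thread (hedged: it mirrors how ``BaseHandler`` later in
this patch submits work, and assumes ``model``, ``inputs`` and ``gen_config`` were
already built by the caller):

    from transformers.cli.serving.utils import InferenceThread

    inference = InferenceThread()
    # every generate() call goes through the same persistent worker thread
    future = inference.submit(model.generate, **inputs, generation_config=gen_config)
    sequences = future.result()  # blocks until the shared thread is done
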
+ """ + + def __init__(self): + self._queue: Queue = Queue() + self._thread = threading.Thread(target=self._run, daemon=True) + self._thread.start() + + def _run(self): + while True: + fn, args, kwargs, future = self._queue.get() + try: + result = fn(*args, **kwargs) + future.set_result(result) + except Exception as e: + future.set_exception(e) + + def submit(self, fn, *args, **kwargs) -> Future: + """Submit a callable to run on the inference thread. Returns a Future.""" + future: Future = Future() + self._queue.put((fn, args, kwargs, future)) + return future + + +# --------------------------------------------------------------------------- +# Base handler +# --------------------------------------------------------------------------- + + +class BaseHandler: + """Shared logic for chat completion and responses handlers. + + Subclasses implement ``_streaming`` and ``_non_streaming`` for their + specific SSE / JSON formats, plus ``_validate_request``. + """ + + def __init__( + self, + model_manager, + force_model=None, + force_processor=None, + inference_thread=None, + compile=False, + ): + self.model_manager = model_manager + self.force_model = force_model + self.force_processor = force_processor + self._inference_thread = inference_thread or InferenceThread() + self._compile = compile + + @staticmethod + def chunk_to_sse(chunk) -> str: + """Format a pydantic model or string as an SSE ``data:`` line.""" + if isinstance(chunk, str): + return chunk if chunk.startswith("data: ") else f"data: {chunk}\n\n" + return f"data: {chunk.model_dump_json(exclude_none=True)}\n\n" + + def _resolve_model(self, body: dict): + """Apply force_model, load model + processor. Returns (model_id, model, processor).""" + if self.force_model is not None: + body["model"] = self.force_model + + model_id = self.model_manager.process_model_name(body["model"]) + processor_id = self.force_processor or body.get("processor") + model, processor = self.model_manager.load_model_and_processor(model_id, processor_id=processor_id) + + return model_id, model, processor + + def _build_generation_config(self, body: dict, model_generation_config, processor=None): + """Build a GenerationConfig from shared params (temperature, top_p, seed, generation_config JSON). + + Subclasses should call ``super()._build_generation_config(...)`` then apply + endpoint-specific params (``max_tokens``, ``max_output_tokens``, etc.). 
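
For illustration, a sketch of the request-side behaviour described above (the ``body``
dict is a made-up example, not a required payload; ``top_p`` and ``seed`` are read the
same way by the base implementation):

    import json

    from transformers import GenerationConfig

    body = {
        "temperature": 0.0,
        "generation_config": json.dumps({"max_new_tokens": 64, "repetition_penalty": 1.05}),
    }
    # a "generation_config" JSON string replaces the model default as the baseline config
    baseline = GenerationConfig(**json.loads(body["generation_config"]))
    # temperature == 0.0 additionally turns sampling off (do_sample=False) in the handler
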
+ """ + from transformers import GenerationConfig + + if body.get("generation_config") is not None: + generation_config = GenerationConfig(**json.loads(body["generation_config"])) + else: + generation_config = copy.deepcopy(model_generation_config) + if generation_config.max_new_tokens is None or generation_config.max_new_tokens < 1024: + generation_config.max_new_tokens = 1024 + + # GGUF models may not have eos/pad token IDs set — sync from processor + if processor is not None: + tokenizer = processor.tokenizer if hasattr(processor, "tokenizer") else processor + if generation_config.eos_token_id is None and hasattr(tokenizer, "eos_token_id"): + generation_config.eos_token_id = tokenizer.eos_token_id + if generation_config.pad_token_id is None and hasattr(tokenizer, "pad_token_id"): + generation_config.pad_token_id = tokenizer.pad_token_id + + if body.get("temperature") is not None: + generation_config.temperature = float(body["temperature"]) + if float(body["temperature"]) == 0.0: + generation_config.do_sample = False + if body.get("top_p") is not None: + generation_config.top_p = float(body["top_p"]) + if body.get("seed") is not None: + set_torch_seed(body["seed"]) + + # --compile flag: use static cache + torch.compile for faster decode + if self._compile and generation_config.cache_implementation is None: + generation_config.cache_implementation = "static" + + return generation_config + + def _start_streaming(self, model, processor, inputs, gen_config): + """Set up DirectStreamer + queue, submit generate to inference thread. + + Returns ``(queue, streamer)`` — caller reads from queue to build SSE events. + """ + loop = asyncio.get_running_loop() + queue: asyncio.Queue = asyncio.Queue() + streamer = DirectStreamer(processor, loop, queue, skip_special_tokens=True) + gen_kwargs = {**inputs, "streamer": streamer, "generation_config": gen_config, "tokenizer": processor} + + def _run(): + try: + model.generate(**gen_kwargs) + except Exception as e: + loop.call_soon_threadsafe(queue.put_nowait, _StreamError(str(e))) + + self._inference_thread.submit(_run) + return queue, streamer + + def _generate_non_streaming(self, model, processor, inputs, gen_config): + """Run generate on the inference thread, decode output. 
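
(The helper returns ``(text, input_len, generated_ids)``, as noted just below.) A short
consumption sketch, mirroring the non-streaming Responses handler earlier in this patch;
it assumes ``handler`` is a handler instance and ``model``, ``processor``, ``inputs`` and
``gen_config`` come from the surrounding code, while ``_make_usage`` is the helper from
serving/response.py:

    from transformers.cli.serving.response import _make_usage

    text, input_len, generated_ids = handler._generate_non_streaming(model, processor, inputs, gen_config)
    usage = _make_usage(input_tokens=input_len, output_tokens=len(generated_ids))
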
Returns ``(text, input_len, generated_ids)``.""" + future = self._inference_thread.submit( + model.generate, **inputs, generation_config=gen_config, tokenizer=processor + ) + sequences = future.result() + input_len = inputs["input_ids"].shape[-1] + generated_ids = sequences[0, input_len:] + text = processor.decode(generated_ids, skip_special_tokens=True) + return text, input_len, generated_ids + + # --------------------------------------------------------------------------- # Message preprocessing: OpenAI messages → processor-compatible format # --------------------------------------------------------------------------- From 1d5d1cb25fab900a5f66d9d76d3a210087c8195e Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Thu, 19 Mar 2026 17:25:44 +0000 Subject: [PATCH 13/64] more tests --- tests/cli/test_serve_refactored.py | 310 ++++++++++++++++++++++++++++- 1 file changed, 299 insertions(+), 11 deletions(-) diff --git a/tests/cli/test_serve_refactored.py b/tests/cli/test_serve_refactored.py index 972a66ec1a37..771d936d2a77 100644 --- a/tests/cli/test_serve_refactored.py +++ b/tests/cli/test_serve_refactored.py @@ -25,7 +25,7 @@ import unittest from unittest.mock import MagicMock -from transformers.testing_utils import require_openai, slow +from transformers.testing_utils import require_openai from transformers.utils.import_utils import is_openai_available, is_vision_available @@ -284,8 +284,6 @@ def _make_handler(self): return ChatCompletionHandler(model_manager=MagicMock()) def test_valid_request_passes(self): - from fastapi import HTTPException - handler = self._make_handler() # Should not raise handler._validate_request({"model": "x", "messages": [{"role": "user", "content": "hi"}], "stream": True}) @@ -383,30 +381,32 @@ def test_build_chunk_sse_content(self): sse = handler._build_chunk_sse(request_id="req1", content="hi", model="m") self.assertTrue(sse.startswith("data: ")) self.assertTrue(sse.endswith("\n\n")) - parsed = json.loads(sse[len("data: "):].strip()) + parsed = json.loads(sse[len("data: ") :].strip()) self.assertEqual(parsed["choices"][0]["delta"]["content"], "hi") def test_build_chunk_sse_role(self): handler = self._make_handler() sse = handler._build_chunk_sse(request_id="req1", role="assistant", model="m") - parsed = json.loads(sse[len("data: "):].strip()) + parsed = json.loads(sse[len("data: ") :].strip()) self.assertEqual(parsed["choices"][0]["delta"]["role"], "assistant") self.assertNotIn("content", parsed["choices"][0]["delta"]) def test_build_chunk_sse_finish_reason(self): handler = self._make_handler() sse = handler._build_chunk_sse(request_id="req1", finish_reason="stop", model="m") - parsed = json.loads(sse[len("data: "):].strip()) + parsed = json.loads(sse[len("data: ") :].strip()) self.assertEqual(parsed["choices"][0]["finish_reason"], "stop") def test_chunk_to_sse_string_passthrough(self): - handler = self._make_handler() - result = handler._chunk_to_sse("data: already formatted\n\n") + from transformers.cli.serving.utils import BaseHandler + + result = BaseHandler.chunk_to_sse("data: already formatted\n\n") self.assertEqual(result, "data: already formatted\n\n") def test_chunk_to_sse_wraps_plain_string(self): - handler = self._make_handler() - result = handler._chunk_to_sse("hello") + from transformers.cli.serving.utils import BaseHandler + + result = BaseHandler.chunk_to_sse("hello") self.assertEqual(result, "data: hello\n\n") @@ -421,6 +421,7 @@ class TestAppRoutes(unittest.TestCase): def setUpClass(cls): from transformers.cli.serving.chat_completion import 
ChatCompletionHandler from transformers.cli.serving.model_manager import ModelManager + from transformers.cli.serving.response import ResponseHandler from transformers.cli.serving.server import build_server cls.model_manager = MagicMock(spec=ModelManager) @@ -428,7 +429,8 @@ def setUpClass(cls): {"id": "test/model", "owned_by": "test", "object": "model", "created": 0} ] cls.chat_handler = MagicMock(spec=ChatCompletionHandler) - cls.app = build_server(cls.model_manager, cls.chat_handler) + cls.response_handler = MagicMock(spec=ResponseHandler) + cls.app = build_server(cls.model_manager, cls.chat_handler, cls.response_handler) def _run(self, coro): return asyncio.get_event_loop().run_until_complete(coro) @@ -560,3 +562,289 @@ def test_multiple_models_on_demand(self): resp_b = self.client.chat.completions.create(model=model_b, messages=prompt) self.assertIn(model_b, resp_b.model) self.assertIsNotNone(resp_b.choices[0].message.content) + + def test_non_streaming_usage(self): + resp = self.client.chat.completions.create( + model=self.MODEL, messages=[{"role": "user", "content": "Say hello"}] + ) + self.assertIsNotNone(resp.usage) + self.assertGreater(resp.usage.prompt_tokens, 0) + self.assertGreater(resp.usage.completion_tokens, 0) + self.assertEqual(resp.usage.total_tokens, resp.usage.prompt_tokens + resp.usage.completion_tokens) + + def test_streaming_usage(self): + chunks = list( + self.client.chat.completions.create( + model=self.MODEL, messages=[{"role": "user", "content": "Say hello"}], stream=True, + ) + ) + # Last chunk should have usage + last = chunks[-1] + self.assertIsNotNone(last.usage) + self.assertGreater(last.usage.prompt_tokens, 0) + self.assertGreater(last.usage.completion_tokens, 0) + self.assertEqual(last.usage.total_tokens, last.usage.prompt_tokens + last.usage.completion_tokens) + + +# --------------------------------------------------------------------------- +# 8. 
Unit tests — Response handler +# --------------------------------------------------------------------------- + + +@require_openai +class TestResponseInputConversion(unittest.TestCase): + def _make_handler(self): + from transformers.cli.serving.response import ResponseHandler + + return ResponseHandler(model_manager=MagicMock()) + + def test_string_input(self): + handler = self._make_handler() + msgs = handler._input_to_messages({"input": "Hello"}) + self.assertEqual(msgs, [{"role": "user", "content": "Hello"}]) + + def test_string_input_with_instructions(self): + handler = self._make_handler() + msgs = handler._input_to_messages({"input": "Hello", "instructions": "Be brief"}) + self.assertEqual(len(msgs), 2) + self.assertEqual(msgs[0], {"role": "system", "content": "Be brief"}) + self.assertEqual(msgs[1], {"role": "user", "content": "Hello"}) + + def test_list_input(self): + handler = self._make_handler() + msgs = handler._input_to_messages( + {"input": [{"role": "user", "content": "A"}, {"role": "assistant", "content": "B"}]} + ) + self.assertEqual(len(msgs), 2) + self.assertEqual(msgs[0]["content"], "A") + + def test_list_input_with_instructions_prepends_system(self): + handler = self._make_handler() + msgs = handler._input_to_messages({"input": [{"role": "user", "content": "Hi"}], "instructions": "Be helpful"}) + self.assertEqual(len(msgs), 2) + self.assertEqual(msgs[0]["role"], "system") + self.assertEqual(msgs[0]["content"], "Be helpful") + + def test_list_input_with_instructions_replaces_existing_system(self): + handler = self._make_handler() + msgs = handler._input_to_messages( + {"input": [{"role": "system", "content": "Old"}, {"role": "user", "content": "Hi"}], "instructions": "New"} + ) + self.assertEqual(len(msgs), 2) + self.assertEqual(msgs[0]["content"], "New") + + def test_dict_input(self): + handler = self._make_handler() + msgs = handler._input_to_messages({"input": {"role": "user", "content": "Test"}}) + self.assertEqual(msgs, [{"role": "user", "content": "Test"}]) + + +@require_openai +class TestResponseValidation(unittest.TestCase): + def _make_handler(self): + from transformers.cli.serving.response import ResponseHandler + + return ResponseHandler(model_manager=MagicMock()) + + def test_unsupported_fields_rejected(self): + from fastapi import HTTPException + + handler = self._make_handler() + with self.assertRaises(HTTPException) as ctx: + handler._validate_request({"model": "x", "input": "hi", "previous_response_id": "abc"}) + self.assertEqual(ctx.exception.status_code, 422) + + def test_valid_request_passes(self): + handler = self._make_handler() + # Should not raise + handler._validate_request({"model": "x", "input": "hi"}) + + +@require_openai +class TestResponseGenerationConfig(unittest.TestCase): + def _make_handler(self): + from transformers.cli.serving.response import ResponseHandler + + return ResponseHandler(model_manager=MagicMock()) + + def test_max_output_tokens(self): + from transformers import GenerationConfig + + result = self._make_handler()._build_generation_config({"max_output_tokens": 42}, GenerationConfig()) + self.assertEqual(result.max_new_tokens, 42) + + def test_default_bumps_short_max_new_tokens(self): + from transformers import GenerationConfig + + result = self._make_handler()._build_generation_config({}, GenerationConfig(max_new_tokens=20)) + self.assertEqual(result.max_new_tokens, 1024) + + +@require_openai +class TestResponseUsage(unittest.TestCase): + def test_make_usage(self): + from transformers.cli.serving.response import _make_usage + 
+ usage = _make_usage(input_tokens=100, output_tokens=50) + self.assertEqual(usage.input_tokens, 100) + self.assertEqual(usage.output_tokens, 50) + self.assertEqual(usage.total_tokens, 150) + self.assertEqual(usage.input_tokens_details.cached_tokens, 0) + self.assertEqual(usage.output_tokens_details.reasoning_tokens, 0) + + def test_usage_in_completed_response(self): + """Usage should serialize correctly inside a Response.""" + from openai.types.responses import Response + + from transformers.cli.serving.response import _make_usage + + usage = _make_usage(10, 5) + response = Response( + id="resp_test", + created_at=0, + status="completed", + model="test", + output=[], + object="response", + tools=[], + parallel_tool_calls=False, + tool_choice="auto", + usage=usage, + ) + dumped = response.model_dump(exclude_none=True) + self.assertEqual(dumped["usage"]["input_tokens"], 10) + self.assertEqual(dumped["usage"]["output_tokens"], 5) + self.assertEqual(dumped["usage"]["total_tokens"], 15) + + +@require_openai +class TestResponseSSEFormat(unittest.TestCase): + def test_sse_format(self): + from openai.types.responses import Response, ResponseCreatedEvent + + from transformers.cli.serving.utils import BaseHandler + + event = ResponseCreatedEvent( + type="response.created", + sequence_number=0, + response=Response( + id="resp_test", + created_at=0, + status="queued", + model="test", + text={"format": {"type": "text"}}, + object="response", + tools=[], + output=[], + parallel_tool_calls=False, + tool_choice="auto", + ), + ) + result = BaseHandler.chunk_to_sse(event) + self.assertTrue(result.startswith("data: ")) + self.assertTrue(result.endswith("\n\n")) + parsed = json.loads(result[len("data: ") :].strip()) + self.assertEqual(parsed["type"], "response.created") + self.assertEqual(parsed["response"]["status"], "queued") + + + +# --------------------------------------------------------------------------- +# 9. 
Integration tests — Responses API (need GPU + model) +# --------------------------------------------------------------------------- + + +@unittest.skipUnless(run_slow, "Set RUN_SLOW=1 to run integration tests") +@require_openai +class TestResponsesIntegration(unittest.TestCase): + """Integration tests for /v1/responses with a real model.""" + + MODEL = "Qwen/Qwen2.5-0.5B-Instruct" + PORT = 8878 + + @classmethod + def setUpClass(cls): + from transformers.cli.serve_refactored import Serve + + cls.serve = Serve(port=cls.PORT, non_blocking=True, log_level="warning") + import requests + + for _ in range(30): + try: + if requests.get(f"http://localhost:{cls.PORT}/health", timeout=1).status_code == 200: + break + except Exception: + pass + time.sleep(2) + + cls.client = OpenAI(base_url=f"http://localhost:{cls.PORT}/v1", api_key="unused") + + @classmethod + def tearDownClass(cls): + cls.serve.kill_server() + + def test_streaming(self): + events = list( + self.client.responses.create( + model=self.MODEL, + input="Say hello", + stream=True, + max_output_tokens=1, + ) + ) + # At least 8 events: created, in_progress, output_item_added, content_part_added, + # delta(s), text_done, content_part_done, output_item_done, completed + self.assertGreaterEqual(len(events), 8) + + # Start markers (fixed order) + self.assertEqual(events[0].type, "response.created") + self.assertEqual(events[1].type, "response.in_progress") + self.assertEqual(events[2].type, "response.output_item.added") + self.assertEqual(events[3].type, "response.content_part.added") + + # At least one delta + self.assertTrue(any(e.type == "response.output_text.delta" for e in events[4:-4])) + + # Closing markers (fixed order from the end) + self.assertEqual(events[-4].type, "response.output_text.done") + self.assertEqual(events[-3].type, "response.content_part.done") + self.assertEqual(events[-2].type, "response.output_item.done") + self.assertEqual(events[-1].type, "response.completed") + + def test_non_streaming(self): + resp = self.client.responses.create( + model=self.MODEL, + input="Say hello", + stream=False, + ) + self.assertEqual(resp.status, "completed") + self.assertTrue(len(resp.output) > 0) + self.assertTrue(len(resp.output[0].content[0].text) > 0) + + def test_non_streaming_usage(self): + resp = self.client.responses.create( + model=self.MODEL, + input="Say hello", + stream=False, + ) + self.assertIsNotNone(resp.usage) + self.assertGreater(resp.usage.input_tokens, 0) + self.assertGreater(resp.usage.output_tokens, 0) + self.assertEqual(resp.usage.total_tokens, resp.usage.input_tokens + resp.usage.output_tokens) + + def test_streaming_usage(self): + events = list( + self.client.responses.create( + model=self.MODEL, + input="Say hello", + stream=True, + max_output_tokens=5, + ) + ) + completed = events[-1] + self.assertEqual(completed.type, "response.completed") + usage = completed.response.usage + self.assertIsNotNone(usage) + self.assertGreater(usage.input_tokens, 0) + self.assertGreater(usage.output_tokens, 0) + self.assertEqual(usage.total_tokens, usage.input_tokens + usage.output_tokens) From 3d64a8cdfbe8f73379920f82b41f6e618b517280 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Thu, 19 Mar 2026 17:29:11 +0000 Subject: [PATCH 14/64] add it for now but we will move it --- tests/cli/benchmark_serve.py | 564 +++++++++++++++++++++++++++++++++++ 1 file changed, 564 insertions(+) create mode 100644 tests/cli/benchmark_serve.py diff --git a/tests/cli/benchmark_serve.py b/tests/cli/benchmark_serve.py new file mode 100644 index 
000000000000..45ce8be727a8 --- /dev/null +++ b/tests/cli/benchmark_serve.py @@ -0,0 +1,564 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Benchmark prefill and decode throughput for `transformers serve`. + +Tests: +- pp (prefill): sends a large prompt with max_tokens=1. Measures TTFT ≈ pure prefill time. + Default sizes: 256, 1024 tokens. +- tg (decode): sends a 512-token prompt (--tg-prefill) and generates many tokens. + Measures decode throughput after subtracting TTFT. Default sizes: 128, 512 tokens. + +Modes: +- bench: greedy decoding (do_sample=False, temp=0). Deterministic, best for reproducible numbers. +- chat: sampling (do_sample=True, temp=0.7). Simulates real chat usage. + +Recommended benchmarks: + + # HF model — greedy + python tests/cli/benchmark_serve.py --model Qwen/Qwen2.5-7B-Instruct --mode bench + + # HF model — sampling (simulates real chat) + python tests/cli/benchmark_serve.py --model Qwen/Qwen2.5-7B-Instruct --mode chat + + # GGUF model — greedy + python tests/cli/benchmark_serve.py \\ + --model "Qwen/Qwen2.5-7B-Instruct-GGUF/qwen2.5-7b-instruct-fp16-00001-of-00004.gguf --processor Qwen/Qwen2.5-7B-Instruct" \\ + --mode bench + + # GGUF model — sampling + python tests/cli/benchmark_serve.py \\ + --model "Qwen/Qwen2.5-7B-Instruct-GGUF/qwen2.5-7b-instruct-fp16-00001-of-00004.gguf --processor Qwen/Qwen2.5-7B-Instruct" \\ + --mode chat + + # Against an existing server + python tests/cli/benchmark_serve.py --url http://localhost:8000 --processor Qwen/Qwen2.5-7B-Instruct +""" + +import argparse +import json +import os +import statistics +import time + +# Force single GPU — must be set before any CUDA initialization +os.environ["CUDA_VISIBLE_DEVICES"] = "0" + +import requests + +from transformers import AutoTokenizer + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +_FILLER = ( + "The quick brown fox jumps over the lazy dog. " + "Pack my box with five dozen liquor jugs. " + "How vexingly quick daft zebras jump. " + "Sphinx of black quartz, judge my vow. 
" +) * 200 + +_TG_PREFILL_DEFAULT = 512 + + +def make_prompt(tokenizer, num_tokens: int) -> str: + """Build a prompt string that tokenizes to exactly `num_tokens` tokens.""" + token_ids = tokenizer.encode(_FILLER, add_special_tokens=False) + if len(token_ids) < num_tokens: + repeats = (num_tokens // len(token_ids)) + 1 + token_ids = (token_ids * repeats)[:num_tokens] + else: + token_ids = token_ids[:num_tokens] + return tokenizer.decode(token_ids) + + +def wait_for_server(base_url: str, timeout: int = 120) -> bool: + """Poll GET /health until 200 or timeout.""" + deadline = time.time() + timeout + while time.time() < deadline: + try: + if requests.get(f"{base_url}/health", timeout=2).status_code == 200: + return True + except requests.ConnectionError: + pass + time.sleep(1) + return False + + +def streaming_chat_completion( + base_url: str, messages: list, max_tokens: int, seed: int, + do_sample: bool = False, cache_implementation: str | None = None, +) -> dict: + """Send a streaming chat completion request. Returns {total, ttft, completion_tokens, text}.""" + gen_cfg = {"max_new_tokens": max_tokens, "min_new_tokens": max_tokens, "do_sample": do_sample} + if do_sample: + gen_cfg["temperature"] = 0.7 + if cache_implementation: + gen_cfg["cache_implementation"] = cache_implementation + payload = { + "messages": messages, + "stream": True, + "seed": seed, + "generation_config": json.dumps(gen_cfg), + } + + t_start = time.perf_counter() + t_first_token = None + completion_tokens = None + text_chunks = [] + + resp = requests.post(f"{base_url}/v1/chat/completions", json=payload, stream=True, timeout=300) + resp.raise_for_status() + + for line in resp.iter_lines(decode_unicode=True): + if not line or not line.startswith("data: "): + continue + data_str = line[len("data: "):] + if data_str.strip() == "[DONE]": + break + try: + chunk = json.loads(data_str) + except json.JSONDecodeError: + continue + + choices = chunk.get("choices", []) + if not choices: + continue + + content = choices[0].get("delta", {}).get("content") + if content is not None and content != "": + text_chunks.append(content) + if t_first_token is None: + t_first_token = time.perf_counter() + + if chunk.get("usage"): + completion_tokens = chunk["usage"].get("completion_tokens") + + if choices[0].get("finish_reason") is not None: + break + + t_end = time.perf_counter() + + return { + "total": t_end - t_start, + "ttft": (t_first_token - t_start) if t_first_token else None, + "completion_tokens": completion_tokens, + "text": "".join(text_chunks), + } + + +def streaming_response( + base_url: str, messages: list, max_tokens: int, seed: int, + do_sample: bool = False, cache_implementation: str | None = None, +) -> dict: + """Send a streaming responses API request. 
Returns {total, ttft, completion_tokens, text}.""" + gen_cfg = {"max_new_tokens": max_tokens, "min_new_tokens": max_tokens, "do_sample": do_sample} + if do_sample: + gen_cfg["temperature"] = 0.7 + if cache_implementation: + gen_cfg["cache_implementation"] = cache_implementation + # Convert messages to Responses API input format + input_messages = messages + payload = { + "input": input_messages, + "stream": True, + "seed": seed, + "generation_config": json.dumps(gen_cfg), + } + + t_start = time.perf_counter() + t_first_token = None + text_chunks = [] + + resp = requests.post(f"{base_url}/v1/responses", json=payload, stream=True, timeout=300) + resp.raise_for_status() + + for line in resp.iter_lines(decode_unicode=True): + if not line or not line.startswith("data: "): + continue + try: + chunk = json.loads(line[len("data: "):]) + except json.JSONDecodeError: + continue + + etype = chunk.get("type") + if etype == "response.output_text.delta": + delta = chunk.get("delta", "") + if delta: + text_chunks.append(delta) + if t_first_token is None: + t_first_token = time.perf_counter() + elif etype == "response.completed": + break + + t_end = time.perf_counter() + text = "".join(text_chunks) + + return { + "total": t_end - t_start, + "ttft": (t_first_token - t_start) if t_first_token else None, + "completion_tokens": len(text_chunks), # approximate — one chunk per streamer push + "text": text, + } + + +def streaming_request( + base_url: str, messages: list, max_tokens: int, seed: int, + do_sample: bool = False, cache_implementation: str | None = None, + endpoint: str = "chat", +) -> dict: + """Dispatch to chat completions or responses API based on endpoint.""" + kw = dict(base_url=base_url, messages=messages, max_tokens=max_tokens, + seed=seed, do_sample=do_sample, cache_implementation=cache_implementation) + if endpoint == "responses": + return streaming_response(**kw) + return streaming_chat_completion(**kw) + + +# --------------------------------------------------------------------------- +# Scenarios +# --------------------------------------------------------------------------- + + +def bench_pp( + base_url: str, tokenizer, pp: int, warmup: int, iterations: int, seed: int, + do_sample: bool = False, cache_implementation: str | None = None, endpoint: str = "chat", +) -> dict: + """Prefill benchmark: large prompt, max_tokens=1. 
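
A worked example of the number this test reports (placeholder values, purely
illustrative; the formula matches the code below, with TTFT = time to first token):

    pp, ttft = 1024, 0.25   # prompt tokens, median time-to-first-token in seconds
    tok_s = pp / ttft       # 1024 / 0.25 = 4096 prefill tokens/s
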
Measures TTFT ≈ pure prefill time.""" + prompt = make_prompt(tokenizer, pp) + messages = [{"role": "user", "content": prompt}] + kw = {"do_sample": do_sample, "cache_implementation": cache_implementation, "endpoint": endpoint} + + for _ in range(warmup): + streaming_request(base_url, messages, max_tokens=1, seed=seed, **kw) + + ttfts = [] + for _ in range(iterations): + r = streaming_request(base_url, messages, max_tokens=1, seed=seed, **kw) + if r["ttft"] is not None: + ttfts.append(r["ttft"]) + + ttft = statistics.median(ttfts) if ttfts else None + tok_s = pp / ttft if ttft and ttft > 0 else None + + return {"test": f"pp{pp}", "tokens": pp, "tok_s": tok_s, "time": ttft} + + +def bench_tg( + base_url: str, tokenizer, tg: int, warmup: int, iterations: int, seed: int, + tg_prefill: int = 512, do_sample: bool = False, cache_implementation: str | None = None, endpoint: str = "chat", +) -> dict: + """Decode benchmark: generate `tg` tokens after a `tg_prefill`-token prompt.""" + prompt = make_prompt(tokenizer, tg_prefill) + messages = [{"role": "user", "content": prompt}] + kw = {"do_sample": do_sample, "cache_implementation": cache_implementation, "endpoint": endpoint} + + for _ in range(warmup): + streaming_request(base_url, messages, max_tokens=tg, seed=seed, **kw) + + decode_times = [] + token_counts = [] + last_text = "" + for _ in range(iterations): + r = streaming_request(base_url, messages, max_tokens=tg, seed=seed, **kw) + if r["ttft"] is not None: + decode_times.append(r["total"] - r["ttft"]) + token_counts.append(r["completion_tokens"] if r["completion_tokens"] is not None else tg) + last_text = r["text"] + + if decode_times: + dt = statistics.median(decode_times) + toks = statistics.median(token_counts) + tok_s = toks / dt if dt > 0 else None + else: + dt = None + toks = tg + tok_s = None + + return {"test": f"tg{tg}", "tokens": int(toks), "tok_s": tok_s, "time": dt, "text": last_text} + + +# --------------------------------------------------------------------------- +# Output +# --------------------------------------------------------------------------- + + +def format_duration(seconds) -> str: + if seconds is None: + return "N/A" + if seconds < 1.0: + return f"{seconds * 1000:.1f}ms" + return f"{seconds:.2f}s" + + +def format_throughput(value) -> str: + if value is None: + return "N/A" + return f"{value:.2f}" + + +_PREVIEW_WIDTH = 120 + + +def truncate_preview(text: str, width: int = _PREVIEW_WIDTH) -> str: + """Single-line preview of generated text.""" + if not text: + return "" + line = text.replace("\n", " ").strip() + if len(line) > width: + return line[:width - 1] + "\u2026" + return line + + +def print_table(rows: list[dict], title: str = "", reference_texts: dict | None = None, is_reference: bool = False) -> None: + """Print results in a bordered table. + + Args: + reference_texts: dict mapping test name (e.g. "tg128") to reference text. + When provided, decode rows show REF/MATCH/MISMATCH. + is_reference: if True, this is the reference table — show REF instead of MATCH. 
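
A usage sketch (the figures are made-up placeholders; the row schema matches what
``bench_pp`` and ``bench_tg`` above return):

    rows = [
        {"test": "pp1024", "tokens": 1024, "tok_s": 4096.0, "time": 0.25},
        {"test": "tg128", "tokens": 128, "tok_s": 40.0, "time": 3.2, "text": "Hello! How can I help?"},
    ]
    print_table(rows, title="my-model | attn=sdpa (bench)")
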
+ """ + if not rows: + return + + has_text = any(row.get("text") for row in rows) + has_ref = reference_texts is not None and has_text + + headers = ["test", "tokens", "tok/s", "time"] + align = ["<", ">", ">", ">"] + if has_ref: + headers.append("ref") + align.append("<") + if has_text: + headers.append("preview") + align.append("<") + + formatted_rows = [] + for row in rows: + cells = [ + row["test"], + str(row["tokens"]), + format_throughput(row["tok_s"]), + format_duration(row["time"]), + ] + text = row.get("text", "") + if has_ref: + ref_text = reference_texts.get(row["test"]) + if not text: + cells.append("") + elif ref_text is None: + cells.append("") + elif is_reference: + cells.append("REF") + elif text == ref_text: + cells.append("MATCH") + else: + cells.append("MISMATCH") + if has_text: + cells.append(truncate_preview(text)) + formatted_rows.append(cells) + + widths = [max(len(h), *(len(r[i]) for r in formatted_rows)) for i, h in enumerate(headers)] + + def pad(text, width, a): + return text.ljust(width) if a == "<" else text.rjust(width) + + def make_row(cells): + return "| " + " | ".join(pad(c, w, a) for c, w, a in zip(cells, widths, align)) + " |" + + def make_sep(char="-"): + return "+" + "+".join(char * (w + 2) for w in widths) + "+" + + print() + if title: + print(title) + print(make_sep("-")) + print(make_row(headers)) + print(make_sep("=")) + for r in formatted_rows: + print(make_row(r)) + print(make_sep("-")) + print() + + +# --------------------------------------------------------------------------- +# Server management +# --------------------------------------------------------------------------- + + +def start_server( + model: str, port: int, processor: str | None = None, attn_implementation: str | None = None, + compile: bool = False, +): + """Start a transformers serve instance. Returns the Serve object.""" + from transformers.cli.serve_refactored import Serve + + kwargs = {"force_model": model, "port": port, "non_blocking": True, "log_level": "warning"} + if processor: + kwargs["processor"] = processor + if attn_implementation: + kwargs["attn_implementation"] = attn_implementation + if compile: + kwargs["compile"] = True + return Serve(**kwargs) + + +def parse_model_spec(spec: str) -> dict: + """Parse 'model_id' or 'model_id --processor tokenizer_id'. + + Returns {"model": str, "processor": str | None, "tokenizer": str} + """ + parts = spec.split() + model = parts[0] + processor = None + for i, p in enumerate(parts): + if p == "--processor" and i + 1 < len(parts): + processor = parts[i + 1] + tokenizer_id = processor if processor else model + return {"model": model, "processor": processor, "tokenizer": tokenizer_id} + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +def main(): + parser = argparse.ArgumentParser( + description="Benchmark transformers serve (prefill & decode separately)", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog="""examples: + python benchmark_serve.py --model Qwen/Qwen2.5-7B-Instruct + python benchmark_serve.py --model "org/model-GGUF/file.gguf --processor org/model" + python benchmark_serve.py --url http://localhost:8000 --processor Qwen/Qwen2.5-7B-Instruct +""", + ) + parser.add_argument("--model", type=str, action="append", dest="models", + help="Model spec (repeatable). 
For GGUF: 'gguf_id --processor tokenizer_id'") + parser.add_argument("--processor", type=str, default=None, + help="Processor/tokenizer ID for --url mode (default: derived from model)") + parser.add_argument("--port", type=int, default=8642, help="Server port") + parser.add_argument("--url", type=str, default=None, + help="Connect to existing server (skip start/stop)") + parser.add_argument("--warmup", type=int, default=1, help="Warmup iterations (minimum 1)") + parser.add_argument("--iterations", type=int, default=3, help="Measurement iterations") + parser.add_argument("--pp", type=int, nargs="+", default=[256, 1024], help="Prefill token counts") + parser.add_argument("--tg", type=int, nargs="+", default=[128, 512], help="Decode token counts") + parser.add_argument("--tg-prefill", type=int, default=_TG_PREFILL_DEFAULT, + help="Prefill size for decode tests (default: 512)") + parser.add_argument("--attn-impl", type=str, nargs="+", default=["sdpa", "eager", "flash_attention_2"], + help="Attention implementations to benchmark (default: sdpa eager flash_attention_2)") + parser.add_argument("--cache", type=str, default=None, + help="Cache implementation (e.g. 'static' for StaticCache + torch.compile)") + parser.add_argument("--mode", type=str, choices=["bench", "chat"], default="bench", + help="bench: greedy (temp=0). chat: sampling (do_sample=True, temp=0.7)") + parser.add_argument("--endpoint", type=str, choices=["chat", "responses"], default="responses", + help="API endpoint to benchmark (default: responses = /v1/responses)") + parser.add_argument("--seed", type=int, default=42, help="Torch seed") + args = parser.parse_args() + + args.warmup = max(args.warmup, 1) + do_sample = args.mode == "chat" + mode_str = "chat (do_sample=True, temp=0.7)" if do_sample else "bench (greedy, temp=0)" + cache_impl = args.cache + endpoint = args.endpoint + endpoint_path = "/v1/responses" if endpoint == "responses" else "/v1/chat/completions" + + if args.url: + # Against an existing server + base_url = args.url.rstrip("/") + tokenizer_id = args.processor or (args.models[0] if args.models else "Qwen/Qwen2.5-7B-Instruct") + print(f"Using server at {base_url}, endpoint={endpoint_path}, mode={mode_str}") + print(f"Loading tokenizer from {tokenizer_id}...") + tokenizer = AutoTokenizer.from_pretrained(tokenizer_id) + + rows = [] + for pp in args.pp: + print(f" pp{pp}") + rows.append(bench_pp(base_url, tokenizer, pp, args.warmup, args.iterations, args.seed, do_sample=do_sample, cache_implementation=cache_impl, endpoint=endpoint)) + for tg in args.tg: + print(f" tg{tg}") + rows.append(bench_tg(base_url, tokenizer, tg, args.warmup, args.iterations, args.seed, tg_prefill=args.tg_prefill, do_sample=do_sample, cache_implementation=cache_impl, endpoint=endpoint)) + print_table(rows) + + else: + # Start a fresh server per model, benchmark, stop + if not args.models: + args.models = ["Qwen/Qwen2.5-7B-Instruct"] + + for model_str in args.models: + spec = parse_model_spec(model_str) + # Reference texts from the first attn impl (bench mode only) + reference_texts = None + + for attn_impl in args.attn_impl: + print(f"\nStarting server for {spec['model']} (attn={attn_impl})...") + try: + server = start_server(spec["model"], args.port, spec["processor"], attn_implementation=attn_impl, + compile=cache_impl == "static") + except Exception as e: + print(f" ERROR: Failed to start server with attn={attn_impl}: {e}. 
Skipping.") + continue + + base_url = f"http://localhost:{args.port}" + if not wait_for_server(base_url): + print(" ERROR: Server did not become ready. Skipping.") + server.kill_server() + continue + + tokenizer = AutoTokenizer.from_pretrained(spec["tokenizer"]) + + # Warmup (always dynamic cache — static cache compiles shapes, so a short warmup would break longer requests) + streaming_request(base_url, [{"role": "user", "content": "hi"}], max_tokens=5, seed=args.seed, endpoint=endpoint) + + rows = [] + for pp in args.pp: + print(f" pp{pp}") + rows.append(bench_pp(base_url, tokenizer, pp, args.warmup, args.iterations, args.seed, do_sample=do_sample, cache_implementation=cache_impl, endpoint=endpoint)) + for tg in args.tg: + print(f" tg{tg}") + rows.append(bench_tg(base_url, tokenizer, tg, args.warmup, args.iterations, args.seed, tg_prefill=args.tg_prefill, do_sample=do_sample, cache_implementation=cache_impl, endpoint=endpoint)) + + server.kill_server() + + # Build reference from first attn impl in greedy mode + if not do_sample and reference_texts is None: + reference_texts = { + row["test"]: row["text"] for row in rows if row.get("text") + } + # Pass reference_texts so the first impl shows "REF" in the ref column + print_table(rows, title=f"{spec['model']} | attn={attn_impl} ({mode_str})", + reference_texts=reference_texts if len(args.attn_impl) > 1 else None, + is_reference=True) + else: + print_table(rows, title=f"{spec['model']} | attn={attn_impl} ({mode_str})", + reference_texts=reference_texts if not do_sample else None) + + # Summary: check for mismatches across attn impls (bench mode only) + if not do_sample and reference_texts and len(args.attn_impl) > 1: + print_reference_summary(reference_texts, args.attn_impl[0]) + + +def print_reference_summary(reference_texts: dict[str, str], ref_impl: str) -> None: + """Print a summary noting that outputs are compared against the reference implementation.""" + print(f"Reference comparison: all decode outputs compared against '{ref_impl}'.") + print(f" MATCH = identical text (greedy decoding is deterministic)") + print(f" MISMATCH = text differs (FP divergence across attention kernels — check preview for correctness)") + print() + + +if __name__ == "__main__": + main() From 552603c18b832d4b5f5fde8af0f42fdab51b472a Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Thu, 19 Mar 2026 17:34:10 +0000 Subject: [PATCH 15/64] remove cache impl --- tests/cli/benchmark_serve.py | 37 +++++++++++++++++------------------- 1 file changed, 17 insertions(+), 20 deletions(-) diff --git a/tests/cli/benchmark_serve.py b/tests/cli/benchmark_serve.py index 45ce8be727a8..21db7fb7524c 100644 --- a/tests/cli/benchmark_serve.py +++ b/tests/cli/benchmark_serve.py @@ -100,14 +100,13 @@ def wait_for_server(base_url: str, timeout: int = 120) -> bool: def streaming_chat_completion( base_url: str, messages: list, max_tokens: int, seed: int, - do_sample: bool = False, cache_implementation: str | None = None, + do_sample: bool = False, ) -> dict: """Send a streaming chat completion request. 
Returns {total, ttft, completion_tokens, text}.""" gen_cfg = {"max_new_tokens": max_tokens, "min_new_tokens": max_tokens, "do_sample": do_sample} if do_sample: gen_cfg["temperature"] = 0.7 - if cache_implementation: - gen_cfg["cache_implementation"] = cache_implementation + payload = { "messages": messages, "stream": True, @@ -162,14 +161,13 @@ def streaming_chat_completion( def streaming_response( base_url: str, messages: list, max_tokens: int, seed: int, - do_sample: bool = False, cache_implementation: str | None = None, + do_sample: bool = False, ) -> dict: """Send a streaming responses API request. Returns {total, ttft, completion_tokens, text}.""" gen_cfg = {"max_new_tokens": max_tokens, "min_new_tokens": max_tokens, "do_sample": do_sample} if do_sample: gen_cfg["temperature"] = 0.7 - if cache_implementation: - gen_cfg["cache_implementation"] = cache_implementation + # Convert messages to Responses API input format input_messages = messages payload = { @@ -217,12 +215,12 @@ def streaming_response( def streaming_request( base_url: str, messages: list, max_tokens: int, seed: int, - do_sample: bool = False, cache_implementation: str | None = None, + do_sample: bool = False, endpoint: str = "chat", ) -> dict: """Dispatch to chat completions or responses API based on endpoint.""" kw = dict(base_url=base_url, messages=messages, max_tokens=max_tokens, - seed=seed, do_sample=do_sample, cache_implementation=cache_implementation) + seed=seed, do_sample=do_sample) if endpoint == "responses": return streaming_response(**kw) return streaming_chat_completion(**kw) @@ -235,12 +233,12 @@ def streaming_request( def bench_pp( base_url: str, tokenizer, pp: int, warmup: int, iterations: int, seed: int, - do_sample: bool = False, cache_implementation: str | None = None, endpoint: str = "chat", + do_sample: bool = False, endpoint: str = "chat", ) -> dict: """Prefill benchmark: large prompt, max_tokens=1. Measures TTFT ≈ pure prefill time.""" prompt = make_prompt(tokenizer, pp) messages = [{"role": "user", "content": prompt}] - kw = {"do_sample": do_sample, "cache_implementation": cache_implementation, "endpoint": endpoint} + kw = {"do_sample": do_sample, "endpoint": endpoint} for _ in range(warmup): streaming_request(base_url, messages, max_tokens=1, seed=seed, **kw) @@ -259,12 +257,12 @@ def bench_pp( def bench_tg( base_url: str, tokenizer, tg: int, warmup: int, iterations: int, seed: int, - tg_prefill: int = 512, do_sample: bool = False, cache_implementation: str | None = None, endpoint: str = "chat", + tg_prefill: int = 512, do_sample: bool = False, endpoint: str = "chat", ) -> dict: """Decode benchmark: generate `tg` tokens after a `tg_prefill`-token prompt.""" prompt = make_prompt(tokenizer, tg_prefill) messages = [{"role": "user", "content": prompt}] - kw = {"do_sample": do_sample, "cache_implementation": cache_implementation, "endpoint": endpoint} + kw = {"do_sample": do_sample, "endpoint": endpoint} for _ in range(warmup): streaming_request(base_url, messages, max_tokens=tg, seed=seed, **kw) @@ -461,8 +459,8 @@ def main(): help="Prefill size for decode tests (default: 512)") parser.add_argument("--attn-impl", type=str, nargs="+", default=["sdpa", "eager", "flash_attention_2"], help="Attention implementations to benchmark (default: sdpa eager flash_attention_2)") - parser.add_argument("--cache", type=str, default=None, - help="Cache implementation (e.g. 
'static' for StaticCache + torch.compile)") + parser.add_argument("--compile", action="store_true", + help="Enable static cache + torch.compile on the server for faster decode") parser.add_argument("--mode", type=str, choices=["bench", "chat"], default="bench", help="bench: greedy (temp=0). chat: sampling (do_sample=True, temp=0.7)") parser.add_argument("--endpoint", type=str, choices=["chat", "responses"], default="responses", @@ -473,7 +471,6 @@ def main(): args.warmup = max(args.warmup, 1) do_sample = args.mode == "chat" mode_str = "chat (do_sample=True, temp=0.7)" if do_sample else "bench (greedy, temp=0)" - cache_impl = args.cache endpoint = args.endpoint endpoint_path = "/v1/responses" if endpoint == "responses" else "/v1/chat/completions" @@ -488,10 +485,10 @@ def main(): rows = [] for pp in args.pp: print(f" pp{pp}") - rows.append(bench_pp(base_url, tokenizer, pp, args.warmup, args.iterations, args.seed, do_sample=do_sample, cache_implementation=cache_impl, endpoint=endpoint)) + rows.append(bench_pp(base_url, tokenizer, pp, args.warmup, args.iterations, args.seed, do_sample=do_sample, endpoint=endpoint)) for tg in args.tg: print(f" tg{tg}") - rows.append(bench_tg(base_url, tokenizer, tg, args.warmup, args.iterations, args.seed, tg_prefill=args.tg_prefill, do_sample=do_sample, cache_implementation=cache_impl, endpoint=endpoint)) + rows.append(bench_tg(base_url, tokenizer, tg, args.warmup, args.iterations, args.seed, tg_prefill=args.tg_prefill, do_sample=do_sample, endpoint=endpoint)) print_table(rows) else: @@ -508,7 +505,7 @@ def main(): print(f"\nStarting server for {spec['model']} (attn={attn_impl})...") try: server = start_server(spec["model"], args.port, spec["processor"], attn_implementation=attn_impl, - compile=cache_impl == "static") + compile=args.compile) except Exception as e: print(f" ERROR: Failed to start server with attn={attn_impl}: {e}. 
Skipping.") continue @@ -527,10 +524,10 @@ def main(): rows = [] for pp in args.pp: print(f" pp{pp}") - rows.append(bench_pp(base_url, tokenizer, pp, args.warmup, args.iterations, args.seed, do_sample=do_sample, cache_implementation=cache_impl, endpoint=endpoint)) + rows.append(bench_pp(base_url, tokenizer, pp, args.warmup, args.iterations, args.seed, do_sample=do_sample, endpoint=endpoint)) for tg in args.tg: print(f" tg{tg}") - rows.append(bench_tg(base_url, tokenizer, tg, args.warmup, args.iterations, args.seed, tg_prefill=args.tg_prefill, do_sample=do_sample, cache_implementation=cache_impl, endpoint=endpoint)) + rows.append(bench_tg(base_url, tokenizer, tg, args.warmup, args.iterations, args.seed, tg_prefill=args.tg_prefill, do_sample=do_sample, endpoint=endpoint)) server.kill_server() From 3643ece62928dcb326e2336907d850f6f65c59e4 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Fri, 20 Mar 2026 17:34:43 +0000 Subject: [PATCH 16/64] add back load_model --- src/transformers/cli/serve_refactored.py | 5 + src/transformers/cli/serving/model_manager.py | 145 ++++++- src/transformers/cli/serving/server.py | 17 +- src/transformers/cli/serving/utils.py | 109 ++++++ tests/cli/test_serve_refactored.py | 361 +++++++++++++++++- 5 files changed, 619 insertions(+), 18 deletions(-) diff --git a/src/transformers/cli/serve_refactored.py b/src/transformers/cli/serve_refactored.py index c1abf8104cdd..6338c4bbe73b 100644 --- a/src/transformers/cli/serve_refactored.py +++ b/src/transformers/cli/serve_refactored.py @@ -112,6 +112,7 @@ def __init__( force_model=force_model, processor_id=processor, ) + self._model_manager = model_manager # Single persistent thread for all generate() calls — required for # torch.compile + CUDA graphs which use thread-local storage. @@ -152,6 +153,10 @@ def _run(): self._thread = threading.Thread(target=_run, name="uvicorn-thread", daemon=False) self._thread.start() + def reset_loaded_models(self): + """Clear all loaded models from memory.""" + self._model_manager.shutdown() + def kill_server(self): if not self._thread or not self._thread.is_alive(): return diff --git a/src/transformers/cli/serving/model_manager.py b/src/transformers/cli/serving/model_manager.py index 224d43c91086..c184d206f0d1 100644 --- a/src/transformers/cli/serving/model_manager.py +++ b/src/transformers/cli/serving/model_manager.py @@ -17,6 +17,7 @@ from __future__ import annotations +import asyncio import gc import json import threading @@ -30,7 +31,7 @@ from transformers import AutoTokenizer, BitsAndBytesConfig, PreTrainedTokenizerBase from ...utils import logging -from .utils import Modality, reset_torch_cache +from .utils import Modality, make_progress_tqdm_class, reset_torch_cache if TYPE_CHECKING: @@ -128,6 +129,14 @@ def __init__( self.loaded_models: dict[str, TimedModel] = {} + # Thread-safety for concurrent load_model_and_processor calls + self._model_locks: dict[str, threading.Lock] = {} + self._model_locks_guard = threading.Lock() + + # Tracks in-flight loads for fan-out to multiple SSE subscribers (used by load_model_streaming) + self._loading_subscribers: dict[str, list[asyncio.Queue[str | None]]] = {} + self._loading_tasks: dict[str, asyncio.Task] = {} + if force_model is not None: self.load_model_and_processor(self.process_model_name(force_model), processor_id=processor_id) @@ -176,7 +185,7 @@ def _load_processor( except OSError: raise OSError(f"Failed to load processor for {model_id} with AutoProcessor and AutoTokenizer.") - def _load_model(self, model_id_and_revision: str) -> 
PreTrainedModel: + def _load_model(self, model_id_and_revision: str, tqdm_class=None, progress_callback=None) -> PreTrainedModel: """Load a model. GGUF files are detected by the `.gguf` extension and loaded via llama.cpp.""" import torch @@ -205,34 +214,140 @@ def _load_model(self, model_id_and_revision: str) -> PreTrainedModel: "device_map": self.device, "trust_remote_code": self.trust_remote_code, "quantization_config": self.get_quantization_config(), + "tqdm_class": tqdm_class, } + if progress_callback is not None: + progress_callback({"status": "loading", "model": model_id_and_revision, "stage": "config"}) config = AutoConfig.from_pretrained(model_id, **model_kwargs) architecture = getattr(transformers, config.architectures[0]) return architecture.from_pretrained(model_id, **model_kwargs) def load_model_and_processor( - self, model_id_and_revision: str, processor_id: str | None = None + self, + model_id_and_revision: str, + processor_id: str | None = None, + progress_callback=None, + tqdm_class=None, ) -> tuple[PreTrainedModel, ProcessorMixin | PreTrainedTokenizerFast]: """Load a model (or return it from cache), resetting its inactivity timer. Args: model_id_and_revision: Model ID in ``'model_id@revision'`` format. - processor_id: Optional per-request processor override (takes precedence - over the instance-level ``self.processor_id``). + processor_id: Optional per-request processor override. + progress_callback: If provided, called with dicts like + ``{"status": "loading", "model": ..., "stage": ...}`` during loading. + tqdm_class: Optional tqdm subclass for progress bars during ``from_pretrained``. """ - if model_id_and_revision not in self.loaded_models or self.loaded_models[model_id_and_revision].is_deleted(): - processor = self._load_processor(model_id_and_revision, processor_id=processor_id) - model = self._load_model(model_id_and_revision) - self.loaded_models[model_id_and_revision] = TimedModel( - model, timeout_seconds=self.model_timeout, processor=processor - ) - else: - self.loaded_models[model_id_and_revision].reset_timer() - model = self.loaded_models[model_id_and_revision].model - processor = self.loaded_models[model_id_and_revision].processor + # Per-model lock prevents duplicate loads when concurrent requests arrive + with self._model_locks_guard: + lock = self._model_locks.setdefault(model_id_and_revision, threading.Lock()) + + with lock: + if ( + model_id_and_revision not in self.loaded_models + or self.loaded_models[model_id_and_revision].is_deleted() + ): + if progress_callback is not None: + progress_callback({"status": "loading", "model": model_id_and_revision, "stage": "processor"}) + processor = self._load_processor(model_id_and_revision, processor_id=processor_id) + model = self._load_model( + model_id_and_revision, tqdm_class=tqdm_class, progress_callback=progress_callback + ) + self.loaded_models[model_id_and_revision] = TimedModel( + model, timeout_seconds=self.model_timeout, processor=processor + ) + if progress_callback is not None: + progress_callback({"status": "ready", "model": model_id_and_revision, "cached": False}) + else: + self.loaded_models[model_id_and_revision].reset_timer() + model = self.loaded_models[model_id_and_revision].model + processor = self.loaded_models[model_id_and_revision].processor + if progress_callback is not None: + progress_callback({"status": "ready", "model": model_id_and_revision, "cached": True}) return model, processor + async def load_model_streaming(self, model_id: str): + """Load a model and stream progress as SSE 
events. + + Handles three cases: + 1. Model already cached → single ``ready`` event + 2. Load already in progress → join existing subscriber stream + 3. First request → start loading, broadcast to all subscribers + + Yields SSE ``data: ...`` lines. + """ + queue: asyncio.Queue[str | None] = asyncio.Queue() + + # Case 1: already cached + if model_id in self.loaded_models and not self.loaded_models[model_id].is_deleted(): + self.loaded_models[model_id].reset_timer() + yield f"data: {json.dumps({'status': 'ready', 'model': model_id, 'cached': True})}\n\n" + return + + # Case 2: load in progress — join existing subscribers + if model_id in self._loading_tasks: + self._loading_subscribers[model_id].append(queue) + while True: + item = await queue.get() + if item is None: + break + yield item + return + + # Case 3: first request — start the load + self._loading_subscribers[model_id] = [queue] + loop = asyncio.get_running_loop() + + def enqueue(payload: dict): + msg = f"data: {json.dumps(payload)}\n\n" + + def broadcast(): + for q in self._loading_subscribers.get(model_id, []): + q.put_nowait(msg) + + loop.call_soon_threadsafe(broadcast) + + tqdm_class = make_progress_tqdm_class(enqueue, model_id) + + def _tqdm_hook(factory, args, kwargs): + return tqdm_class(*args, **kwargs) + + async def run_load(): + try: + # Install a global tqdm hook so the "Loading weights" bar in + # core_model_loading.py (which uses logging.tqdm) routes through + # our ProgressTqdm. The tqdm_class kwarg only covers download bars. + previous_hook = logging.set_tqdm_hook(_tqdm_hook) + try: + await asyncio.to_thread( + self.load_model_and_processor, + model_id, + progress_callback=enqueue, + tqdm_class=tqdm_class, + ) + finally: + logging.set_tqdm_hook(previous_hook) + except Exception as e: + logger.error(f"Failed to load {model_id}: {e}", exc_info=True) + enqueue({"status": "error", "model": model_id, "message": str(e)}) + finally: + + def _send_sentinel(): + for q in self._loading_subscribers.pop(model_id, []): + q.put_nowait(None) + self._loading_tasks.pop(model_id, None) + + loop.call_soon_threadsafe(_send_sentinel) + + self._loading_tasks[model_id] = asyncio.create_task(run_load()) + + while True: + item = await queue.get() + if item is None: + break + yield item + def shutdown(self) -> None: """Delete all loaded models and free resources.""" for timed in self.loaded_models.values(): diff --git a/src/transformers/cli/serving/server.py b/src/transformers/cli/serving/server.py index 4730509e0dc4..150cec14115f 100644 --- a/src/transformers/cli/serving/server.py +++ b/src/transformers/cli/serving/server.py @@ -22,7 +22,7 @@ from fastapi import FastAPI, Request from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import JSONResponse +from fastapi.responses import JSONResponse, StreamingResponse from ...utils import logging from .chat_completion import ChatCompletionHandler @@ -90,6 +90,21 @@ async def chat_completions(request: Request, body: dict): async def responses(request: Request, body: dict): return response_handler.handle_request(body, request.state.request_id) + @app.post("/load_model") + async def load_model(body: dict): + from fastapi import HTTPException + + model = body.get("model") + if model is None: + raise HTTPException(status_code=422, detail="Missing `model` field in the request body.") + model_id = model_manager.process_model_name(model) + return StreamingResponse(model_manager.load_model_streaming(model_id), media_type="text/event-stream") + + @app.post("/reset") + def reset(): + 
model_manager.shutdown() + return JSONResponse({"status": "ok"}) + @app.get("/v1/models") @app.options("/v1/models") def list_models(): diff --git a/src/transformers/cli/serving/utils.py b/src/transformers/cli/serving/utils.py index c97c0ab2841d..3a3e0a774224 100644 --- a/src/transformers/cli/serving/utils.py +++ b/src/transformers/cli/serving/utils.py @@ -95,6 +95,115 @@ def __init__(self, msg: str): self.msg = msg +# --------------------------------------------------------------------------- +# Progress tracking for model loading +# --------------------------------------------------------------------------- + + +class DownloadAggregator: + """Aggregates byte-progress across multiple concurrent download tqdm bars. + + huggingface_hub opens one tqdm bar per file shard. This class tracks them all and emits + a single aggregate ``{"stage": "download", "progress": {...}}`` event whenever any updates. + """ + + def __init__(self, enqueue, model_id: str): + self.enqueue = enqueue + self.model = model_id + self.bars: dict[int, tuple[int, int | None]] = {} + self.last_emitted_current: int | None = None + + def register(self, bar_id: int, total: int | None): + self.bars[bar_id] = (0, total) + self._emit() + + def update(self, bar_id: int, current: int, total: int | None): + self.bars[bar_id] = (current, total) + self._emit() + + def close(self, bar_id: int): + pass # keep the bar so totals remain correct + + def _emit(self): + agg_current = sum(c for c, _ in self.bars.values()) + if agg_current == self.last_emitted_current: + return + self.last_emitted_current = agg_current + totals = [t for _, t in self.bars.values() if t is not None] + agg_total = sum(totals) if totals else None + self.enqueue( + { + "status": "loading", + "model": self.model, + "stage": "download", + "progress": {"current": agg_current, "total": agg_total}, + } + ) + + +def make_progress_tqdm_class(callback, model_id: str): + """Create a tqdm subclass that routes progress to a callback. + + Bars with ``unit="B"`` are download bars — aggregated via ``DownloadAggregator``. + Other bars (e.g. "Loading weights") emit ``weights`` stage events. 
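+
+    A typical payload passed to ``callback`` looks like this (values are illustrative)::
+
+        {"status": "loading", "model": "org/model", "stage": "download",
+         "progress": {"current": 1048576, "total": 4294967296}}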
+ """ + from tqdm.auto import tqdm as base_tqdm + + download_aggregator = DownloadAggregator(callback, model_id) + + class ProgressTqdm(base_tqdm): + def __init__(self, *args, **kwargs): + self.sse_unit = kwargs.get("unit") or "it" + kwargs["disable"] = True + super().__init__(*args, **kwargs) + self.n = 0 + self.last_emitted = -1 + if self.sse_unit == "B": + self._bar_id = id(self) + download_aggregator.register(self._bar_id, self.total) + + def update(self, n=1): + if n is None: + n = 1 + self.n += n + if self.sse_unit == "B": + download_aggregator.update(self._bar_id, self.n, self.total) + elif self.n != self.last_emitted: + self.last_emitted = self.n + callback( + { + "status": "loading", + "model": model_id, + "stage": "weights", + "progress": {"current": self.n, "total": self.total}, + } + ) + + def __iter__(self): + for item in self.iterable: + self.n += 1 + if self.sse_unit == "B": + download_aggregator.update(self._bar_id, self.n, self.total) + elif self.n != self.last_emitted: + self.last_emitted = self.n + callback( + { + "status": "loading", + "model": model_id, + "stage": "weights", + "progress": {"current": self.n, "total": self.total}, + } + ) + yield item + + def close(self): + if self.sse_unit == "B": + download_aggregator.close(self._bar_id) + super().close() + + return ProgressTqdm + + # --------------------------------------------------------------------------- # Streaming # --------------------------------------------------------------------------- diff --git a/tests/cli/test_serve_refactored.py b/tests/cli/test_serve_refactored.py index 771d936d2a77..b2289175aaaa 100644 --- a/tests/cli/test_serve_refactored.py +++ b/tests/cli/test_serve_refactored.py @@ -532,6 +532,32 @@ def test_streaming(self): text += chunk.choices[0].delta.content self.assertTrue(len(text) > 0) + def test_early_return_due_to_length(self): + """When max_tokens is hit, finish_reason should be 'length'.""" + chunks = list( + self.client.chat.completions.create( + model=self.MODEL, + messages=[{"role": "user", "content": "Hello, how are you?"}], + stream=True, + max_tokens=3, + ) + ) + last = chunks[-1] + self.assertEqual(last.choices[0].finish_reason, "length") + + def test_continues_until_stop(self): + """When model stops naturally, finish_reason should be 'stop'.""" + chunks = list( + self.client.chat.completions.create( + model=self.MODEL, + messages=[{"role": "user", "content": 'Please only answer with "Hi."'}], + stream=True, + max_tokens=30, + ) + ) + last = chunks[-1] + self.assertEqual(last.choices[0].finish_reason, "stop") + def test_stop_strings(self): resp = self.client.chat.completions.create( model=self.MODEL, messages=[{"role": "user", "content": "Count to 10"}], stop=["5"] @@ -575,7 +601,9 @@ def test_non_streaming_usage(self): def test_streaming_usage(self): chunks = list( self.client.chat.completions.create( - model=self.MODEL, messages=[{"role": "user", "content": "Say hello"}], stream=True, + model=self.MODEL, + messages=[{"role": "user", "content": "Say hello"}], + stream=True, ) ) # Last chunk should have usage @@ -585,6 +613,59 @@ def test_streaming_usage(self): self.assertGreater(last.usage.completion_tokens, 0) self.assertEqual(last.usage.total_tokens, last.usage.prompt_tokens + last.usage.completion_tokens) + def test_concurrent_non_streaming(self): + """Two concurrent non-streaming requests should both complete without interference.""" + import concurrent.futures + + prompts = [ + [{"role": "user", "content": "Say hello"}], + [{"role": "user", "content": "Say goodbye"}], + ] 
+ results = [None, None] + + def request_in_thread(index): + client = OpenAI(base_url=f"http://localhost:{self.PORT}/v1", api_key="unused") + results[index] = client.chat.completions.create(model=self.MODEL, messages=prompts[index]) + + with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool: + futures = [pool.submit(request_in_thread, i) for i in range(2)] + concurrent.futures.wait(futures) + for f in futures: + f.result() # re-raise exceptions + + for i in range(2): + self.assertIsNotNone(results[i]) + self.assertIsNotNone(results[i].choices[0].message.content) + self.assertTrue(len(results[i].choices[0].message.content) > 0) + + def test_concurrent_streaming(self): + """Two concurrent streaming requests should both produce complete, non-empty output.""" + import concurrent.futures + + prompts = [ + [{"role": "user", "content": "Say hello"}], + [{"role": "user", "content": "Say goodbye"}], + ] + results = [None, None] + + def stream_in_thread(index): + client = OpenAI(base_url=f"http://localhost:{self.PORT}/v1", api_key="unused") + text = "" + for chunk in client.chat.completions.create(model=self.MODEL, messages=prompts[index], stream=True): + if chunk.choices[0].delta.content: + text += chunk.choices[0].delta.content + results[index] = text + + with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool: + futures = [pool.submit(stream_in_thread, i) for i in range(2)] + concurrent.futures.wait(futures) + for f in futures: + f.result() + + for i in range(2): + self.assertIsNotNone(results[i]) + self.assertTrue(len(results[i]) > 0, f"Request {i} produced empty output") + # --------------------------------------------------------------------------- # 8. Unit tests — Response handler @@ -748,7 +829,6 @@ def test_sse_format(self): self.assertEqual(parsed["response"]["status"], "queued") - # --------------------------------------------------------------------------- # 9. Integration tests — Responses API (need GPU + model) # --------------------------------------------------------------------------- @@ -848,3 +928,280 @@ def test_streaming_usage(self): self.assertGreater(usage.input_tokens, 0) self.assertGreater(usage.output_tokens, 0) self.assertEqual(usage.total_tokens, usage.input_tokens + usage.output_tokens) + + +# --------------------------------------------------------------------------- +# 10. 
Integration tests — /load_model endpoint (need GPU + model) +# --------------------------------------------------------------------------- + + +def _parse_sse_events(response): + """Parse SSE lines from a streaming requests response into a list of dicts.""" + events = [] + for line in response.iter_lines(decode_unicode=True): + if not line or not line.startswith("data: "): + continue + events.append(json.loads(line[len("data: ") :])) + return events + + +@unittest.skipUnless(run_slow, "Set RUN_SLOW=1 to run integration tests") +@require_openai +class TestLoadModel(unittest.TestCase): + """Integration tests for POST /load_model SSE endpoint.""" + + MODEL = "Qwen/Qwen2.5-0.5B-Instruct" + PORT = 8879 + + @classmethod + def setUpClass(cls): + import requests as req + + from transformers.cli.serve_refactored import Serve + + cls.serve = Serve(port=cls.PORT, non_blocking=True, log_level="warning") + for _ in range(30): + try: + if req.get(f"http://localhost:{cls.PORT}/health", timeout=1).status_code == 200: + break + except Exception: + pass + time.sleep(2) + cls.base_url = f"http://localhost:{cls.PORT}" + + @classmethod + def tearDownClass(cls): + cls.serve.kill_server() + + def setUp(self): + # Clear model cache so each test starts fresh + self.serve.reset_loaded_models() + + def _load_model(self, model: str): + import requests as req + + resp = req.post(f"{self.base_url}/load_model", json={"model": model}, stream=True, timeout=120) + events = _parse_sse_events(resp) + return resp, events + + def test_load_model_fresh(self): + """POST /load_model returns SSE events ending with ready.""" + response, events = self._load_model(self.MODEL) + + self.assertEqual(response.status_code, 200) + + stages = [e["stage"] for e in events if e["status"] == "loading" and "stage" in e] + self.assertIn("processor", stages) + self.assertIn("weights", stages) + + last = events[-1] + self.assertEqual(last["status"], "ready") + self.assertFalse(last["cached"]) + + for event in events: + self.assertIn("status", event) + self.assertIn("model", event) + + def test_load_model_cached(self): + """Loading an already-loaded model returns a single ready event with cached: true.""" + self._load_model(self.MODEL) + + _, events = self._load_model(self.MODEL) + + ready_events = [e for e in events if e["status"] == "ready"] + self.assertEqual(len(ready_events), 1) + self.assertTrue(ready_events[0]["cached"]) + + loading_events = [e for e in events if e["status"] == "loading"] + self.assertEqual(len(loading_events), 0) + + def test_load_model_error(self): + """Loading a nonexistent model produces an error event.""" + _, events = self._load_model("nonexistent/model-that-does-not-exist") + + error_events = [e for e in events if e["status"] == "error"] + self.assertGreaterEqual(len(error_events), 1) + self.assertIn("message", error_events[0]) + + def test_load_model_missing_field(self): + """POST /load_model with no model field returns 422.""" + import requests as req + + response = req.post(f"{self.base_url}/load_model", json={}, timeout=30) + self.assertEqual(response.status_code, 422) + + def test_load_model_event_schema(self): + """Every event conforms to the expected schema.""" + _, events = self._load_model(self.MODEL) + + for event in events: + self.assertIsInstance(event["status"], str) + self.assertIsInstance(event["model"], str) + + if event["status"] == "loading": + self.assertIn("stage", event) + if event["stage"] in ("download", "weights") and "progress" in event: + progress = event["progress"] + self.assertIn("current", 
progress) + self.assertIn("total", progress) + self.assertIsInstance(progress["current"], int) + + if event["status"] == "ready": + self.assertIn("cached", event) + self.assertIsInstance(event["cached"], bool) + + def test_load_model_stage_ordering(self): + """Stages appear in the expected order.""" + _, events = self._load_model(self.MODEL) + + stages = [e["stage"] for e in events if e["status"] == "loading" and "stage" in e] + seen = set() + unique_stages = [] + for s in stages: + if s not in seen: + seen.add(s) + unique_stages.append(s) + + expected_order = ["processor", "config", "download", "weights"] + expected_present = [s for s in expected_order if s in unique_stages] + self.assertEqual(unique_stages, expected_present, "Stages appeared out of order") + + def test_concurrent_load_same_model(self): + """Two concurrent /load_model requests both get events and a ready event.""" + import concurrent.futures + + results = [None, None] + + def load_in_thread(index): + import requests as req + + resp = req.post(f"{self.base_url}/load_model", json={"model": self.MODEL}, stream=True, timeout=120) + events = _parse_sse_events(resp) + results[index] = (resp.status_code, events) + + with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool: + futures = [pool.submit(load_in_thread, i) for i in range(2)] + concurrent.futures.wait(futures) + for f in futures: + f.result() + + for i in range(2): + status_code, events = results[i] + self.assertEqual(status_code, 200, f"Caller {i} got non-200 status") + self.assertTrue(len(events) > 0, f"Caller {i} received no events") + ready_events = [e for e in events if e["status"] == "ready"] + self.assertEqual(len(ready_events), 1, f"Caller {i} should get exactly one ready event") + + def test_concurrent_load_second_caller_gets_cached(self): + """If the first /load_model finishes before the second, the second gets cached: true.""" + _, events1 = self._load_model(self.MODEL) + ready1 = [e for e in events1 if e["status"] == "ready"] + self.assertEqual(len(ready1), 1) + self.assertFalse(ready1[0]["cached"]) + + _, events2 = self._load_model(self.MODEL) + ready2 = [e for e in events2 if e["status"] == "ready"] + self.assertEqual(len(ready2), 1) + self.assertTrue(ready2[0]["cached"]) + + loading2 = [e for e in events2 if e["status"] == "loading"] + self.assertEqual(len(loading2), 0) + + def test_load_model_weights_progress_complete(self): + """Weights progress should go from 1 to total, with total matching across events.""" + _, events = self._load_model(self.MODEL) + + weights_events = [e for e in events if e.get("stage") == "weights" and "progress" in e] + self.assertGreater(len(weights_events), 0, "No weights progress events emitted") + + # All events should have the same total + totals = {e["progress"]["total"] for e in weights_events} + self.assertEqual(len(totals), 1, f"Inconsistent totals: {totals}") + total = totals.pop() + self.assertIsNotNone(total) + self.assertGreater(total, 0) + + # First should be 1, last should be total + self.assertEqual(weights_events[0]["progress"]["current"], 1) + self.assertEqual(weights_events[-1]["progress"]["current"], total) + + # Progress should be monotonically increasing + currents = [e["progress"]["current"] for e in weights_events] + self.assertEqual(currents, sorted(currents)) + + def test_load_model_exactly_one_ready(self): + """A fresh load should produce exactly one ready event as the last event.""" + _, events = self._load_model(self.MODEL) + + ready_events = [e for e in events if e["status"] == "ready"] + 
self.assertEqual(len(ready_events), 1) + self.assertEqual(events[-1]["status"], "ready") + + def test_load_model_usable_after_load(self): + """After /load_model completes, the model should be usable for inference.""" + self._load_model(self.MODEL) + + client = OpenAI(base_url=f"http://localhost:{self.PORT}/v1", api_key="unused") + resp = client.chat.completions.create( + model=self.MODEL, + messages=[{"role": "user", "content": "Say hi"}], + max_tokens=5, + ) + self.assertIsNotNone(resp.choices[0].message.content) + self.assertTrue(len(resp.choices[0].message.content) > 0) + + def test_load_model_model_field_matches(self): + """The model field in every event should match the canonical model ID.""" + _, events = self._load_model(self.MODEL) + + for event in events: + self.assertTrue( + event["model"].startswith(self.MODEL), + f"Event model '{event['model']}' doesn't match '{self.MODEL}'", + ) + + def test_concurrent_non_streaming(self): + """Two concurrent non-streaming responses requests should both complete.""" + import concurrent.futures + + inputs = ["Say hello", "Say goodbye"] + results = [None, None] + + def request_in_thread(index): + client = OpenAI(base_url=f"http://localhost:{self.PORT}/v1", api_key="unused") + results[index] = client.responses.create(model=self.MODEL, input=inputs[index], stream=False) + + with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool: + futures = [pool.submit(request_in_thread, i) for i in range(2)] + concurrent.futures.wait(futures) + for f in futures: + f.result() + + for i in range(2): + self.assertIsNotNone(results[i]) + self.assertEqual(results[i].status, "completed") + self.assertTrue(len(results[i].output[0].content[0].text) > 0) + + def test_concurrent_streaming(self): + """Two concurrent streaming responses requests should both produce complete event streams.""" + import concurrent.futures + + inputs = ["Say hello", "Say goodbye"] + results = [None, None] + + def stream_in_thread(index): + client = OpenAI(base_url=f"http://localhost:{self.PORT}/v1", api_key="unused") + events = list(client.responses.create(model=self.MODEL, input=inputs[index], stream=True)) + results[index] = events + + with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool: + futures = [pool.submit(stream_in_thread, i) for i in range(2)] + concurrent.futures.wait(futures) + for f in futures: + f.result() + + for i in range(2): + types = [e.type for e in results[i]] + self.assertIn("response.created", types, f"Request {i} missing created event") + self.assertIn("response.output_text.delta", types, f"Request {i} missing delta events") + self.assertIn("response.completed", types, f"Request {i} missing completed event") From 12c0f5588c9f2137be24b63b95cbf01a206aa392 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Fri, 20 Mar 2026 18:25:08 +0000 Subject: [PATCH 17/64] fix naming --- src/transformers/cli/serving/model_manager.py | 31 ++++++++++--------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/src/transformers/cli/serving/model_manager.py b/src/transformers/cli/serving/model_manager.py index c184d206f0d1..23e57d31e5ed 100644 --- a/src/transformers/cli/serving/model_manager.py +++ b/src/transformers/cli/serving/model_manager.py @@ -267,7 +267,7 @@ def load_model_and_processor( progress_callback({"status": "ready", "model": model_id_and_revision, "cached": True}) return model, processor - async def load_model_streaming(self, model_id: str): + async def load_model_streaming(self, model_id_and_revision: str): """Load a model and stream progress 
as SSE events. Handles three cases: @@ -277,17 +277,18 @@ async def load_model_streaming(self, model_id: str): Yields SSE ``data: ...`` lines. """ + mid = model_id_and_revision queue: asyncio.Queue[str | None] = asyncio.Queue() # Case 1: already cached - if model_id in self.loaded_models and not self.loaded_models[model_id].is_deleted(): - self.loaded_models[model_id].reset_timer() - yield f"data: {json.dumps({'status': 'ready', 'model': model_id, 'cached': True})}\n\n" + if mid in self.loaded_models and not self.loaded_models[mid].is_deleted(): + self.loaded_models[mid].reset_timer() + yield f"data: {json.dumps({'status': 'ready', 'model': mid, 'cached': True})}\n\n" return # Case 2: load in progress — join existing subscribers - if model_id in self._loading_tasks: - self._loading_subscribers[model_id].append(queue) + if mid in self._loading_tasks: + self._loading_subscribers[mid].append(queue) while True: item = await queue.get() if item is None: @@ -296,19 +297,19 @@ async def load_model_streaming(self, model_id: str): return # Case 3: first request — start the load - self._loading_subscribers[model_id] = [queue] + self._loading_subscribers[mid] = [queue] loop = asyncio.get_running_loop() def enqueue(payload: dict): msg = f"data: {json.dumps(payload)}\n\n" def broadcast(): - for q in self._loading_subscribers.get(model_id, []): + for q in self._loading_subscribers.get(mid, []): q.put_nowait(msg) loop.call_soon_threadsafe(broadcast) - tqdm_class = make_progress_tqdm_class(enqueue, model_id) + tqdm_class = make_progress_tqdm_class(enqueue, mid) def _tqdm_hook(factory, args, kwargs): return tqdm_class(*args, **kwargs) @@ -322,25 +323,25 @@ async def run_load(): try: await asyncio.to_thread( self.load_model_and_processor, - model_id, + mid, progress_callback=enqueue, tqdm_class=tqdm_class, ) finally: logging.set_tqdm_hook(previous_hook) except Exception as e: - logger.error(f"Failed to load {model_id}: {e}", exc_info=True) - enqueue({"status": "error", "model": model_id, "message": str(e)}) + logger.error(f"Failed to load {mid}: {e}", exc_info=True) + enqueue({"status": "error", "model": mid, "message": str(e)}) finally: def _send_sentinel(): - for q in self._loading_subscribers.pop(model_id, []): + for q in self._loading_subscribers.pop(mid, []): q.put_nowait(None) - self._loading_tasks.pop(model_id, None) + self._loading_tasks.pop(mid, None) loop.call_soon_threadsafe(_send_sentinel) - self._loading_tasks[model_id] = asyncio.create_task(run_load()) + self._loading_tasks[mid] = asyncio.create_task(run_load()) while True: item = await queue.get() From d4ffdf41de09747753f851a034ecaa53df26731e Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Fri, 20 Mar 2026 18:48:02 +0000 Subject: [PATCH 18/64] add transcription --- src/transformers/cli/serving/server.py | 13 +- src/transformers/cli/serving/transcription.py | 116 ++++++++++++++++ src/transformers/cli/serving/utils.py | 9 +- tests/cli/test_serve_refactored.py | 126 ++++++++++++++++++ 4 files changed, 258 insertions(+), 6 deletions(-) create mode 100644 src/transformers/cli/serving/transcription.py diff --git a/src/transformers/cli/serving/server.py b/src/transformers/cli/serving/server.py index 150cec14115f..fd4f1245b2cd 100644 --- a/src/transformers/cli/serving/server.py +++ b/src/transformers/cli/serving/server.py @@ -28,6 +28,7 @@ from .chat_completion import ChatCompletionHandler from .model_manager import ModelManager from .response import ResponseHandler +from .transcription import TranscriptionHandler from .utils import X_REQUEST_ID @@ 
-90,6 +91,12 @@ async def chat_completions(request: Request, body: dict): async def responses(request: Request, body: dict): return response_handler.handle_request(body, request.state.request_id) + transcription_handler = TranscriptionHandler(model_manager) + + @app.post("/v1/audio/transcriptions") + async def audio_transcriptions(request: Request): + return await transcription_handler.handle_request(request) + @app.post("/load_model") async def load_model(body: dict): from fastapi import HTTPException @@ -97,8 +104,10 @@ async def load_model(body: dict): model = body.get("model") if model is None: raise HTTPException(status_code=422, detail="Missing `model` field in the request body.") - model_id = model_manager.process_model_name(model) - return StreamingResponse(model_manager.load_model_streaming(model_id), media_type="text/event-stream") + model_id_and_revision = model_manager.process_model_name(model) + return StreamingResponse( + model_manager.load_model_streaming(model_id_and_revision), media_type="text/event-stream" + ) @app.post("/reset") def reset(): diff --git a/src/transformers/cli/serving/transcription.py b/src/transformers/cli/serving/transcription.py new file mode 100644 index 000000000000..134cfa72a721 --- /dev/null +++ b/src/transformers/cli/serving/transcription.py @@ -0,0 +1,116 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Handler for the /v1/audio/transcriptions endpoint. +""" + +from __future__ import annotations + +import io +from threading import Thread +from typing import TYPE_CHECKING + +from fastapi.responses import JSONResponse, StreamingResponse + +from ...utils import logging +from .model_manager import ModelManager +from .utils import DirectStreamer, _StreamError + + +if TYPE_CHECKING: + from fastapi import Request + +logger = logging.get_logger(__name__) + + +class TranscriptionHandler: + """Handler for ``POST /v1/audio/transcriptions``. + + Accepts a multipart/form-data request with an audio file and model name, + runs speech-to-text, and returns an OpenAI-compatible Transcription response. + + Supports streaming (``stream=true`` in form data) which yields text chunks + as SSE events, or non-streaming which returns a single JSON response. + """ + + def __init__(self, model_manager: ModelManager): + self.model_manager = model_manager + + async def handle_request(self, request: Request) -> JSONResponse | StreamingResponse: + """Parse multipart form, run transcription, return result.""" + from transformers.utils.import_utils import is_librosa_available + + if not is_librosa_available(): + raise ImportError("Missing librosa dependency for audio transcription. 
Install with `pip install librosa`") + + import librosa + + async with request.form() as form: + file_bytes = await form["file"].read() + model = form["model"] + stream = str(form.get("stream", "false")).lower() == "true" + + model_id_and_revision = self.model_manager.process_model_name(model) + audio_model, audio_processor = self.model_manager.load_model_and_processor(model_id_and_revision) + + # Read audio with librosa at the model's expected sampling rate + model_sampling_rate = audio_processor.feature_extractor.sampling_rate + audio_array, _ = librosa.load(io.BytesIO(file_bytes), sr=model_sampling_rate, mono=True) + audio_inputs = audio_processor(audio_array, sampling_rate=model_sampling_rate, return_tensors="pt").to( + audio_model.device + ) + audio_inputs["input_features"] = audio_inputs["input_features"].to(audio_model.dtype) + + if stream: + return self._streaming(audio_model, audio_processor, audio_inputs) + return self._non_streaming(audio_model, audio_processor, audio_inputs) + + def _non_streaming(self, audio_model, audio_processor, audio_inputs) -> JSONResponse: + from openai.types.audio import Transcription + + generated_ids = audio_model.generate(**audio_inputs) + text = audio_processor.batch_decode(generated_ids, skip_special_tokens=True)[0] + + return JSONResponse(Transcription(text=text).model_dump(exclude_none=True)) + + def _streaming(self, audio_model, audio_processor, audio_inputs) -> StreamingResponse: + import asyncio + + loop = asyncio.get_running_loop() + queue: asyncio.Queue = asyncio.Queue() + + # For processors like WhisperProcessor, the fast tokenizer is at .tokenizer + tokenizer = audio_processor.tokenizer if hasattr(audio_processor, "tokenizer") else audio_processor + streamer = DirectStreamer(tokenizer._tokenizer, loop, queue, skip_special_tokens=True) + gen_kwargs = {**audio_inputs, "streamer": streamer} + + def _run(): + try: + audio_model.generate(**gen_kwargs) + except Exception as e: + loop.call_soon_threadsafe(queue.put_nowait, _StreamError(str(e))) + + Thread(target=_run, daemon=True).start() + + async def sse_gen(): + while True: + text = await queue.get() + if text is None: + break + if isinstance(text, _StreamError): + yield f'data: {{"error": "{text.msg}"}}\n\n' + return + yield f"data: {text}\n\n" + + return StreamingResponse(sse_gen(), media_type="text/event-stream") diff --git a/src/transformers/cli/serving/utils.py b/src/transformers/cli/serving/utils.py index 3a3e0a774224..67f20da82cb6 100644 --- a/src/transformers/cli/serving/utils.py +++ b/src/transformers/cli/serving/utils.py @@ -216,16 +216,17 @@ class DirectStreamer: ``TextIteratorStreamer`` which re-decodes the full sequence each time. Args: - processor: A HuggingFace processor or tokenizer (must have a ``._tokenizer`` attribute). + tokenizer: The raw ``tokenizers.Tokenizer`` instance (i.e. the Rust tokenizer, + typically accessed as ``processor._tokenizer`` or ``processor.tokenizer._tokenizer``). loop: The asyncio event loop to push results to. queue: The asyncio.Queue to push decoded text chunks to. skip_special_tokens: Whether to skip special tokens during decoding. 
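+
+    Example (sketch; ``model``, ``processor``, ``inputs`` and ``loop`` are supplied by the
+    caller, as in ``_start_streaming`` later in this file)::
+
+        queue: asyncio.Queue = asyncio.Queue()
+        streamer = DirectStreamer(processor._tokenizer, loop, queue, skip_special_tokens=True)
+        Thread(target=model.generate, kwargs={**inputs, "streamer": streamer}, daemon=True).start()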
""" - def __init__(self, processor, loop, queue, skip_special_tokens: bool = True): + def __init__(self, tokenizer, loop, queue, skip_special_tokens: bool = True): from tokenizers.decoders import DecodeStream - self._tokenizer = processor._tokenizer # raw tokenizers.Tokenizer + self._tokenizer = tokenizer self._loop = loop self._queue = queue self._decode_stream = DecodeStream([], skip_special_tokens) @@ -389,7 +390,7 @@ def _start_streaming(self, model, processor, inputs, gen_config): """ loop = asyncio.get_running_loop() queue: asyncio.Queue = asyncio.Queue() - streamer = DirectStreamer(processor, loop, queue, skip_special_tokens=True) + streamer = DirectStreamer(processor._tokenizer, loop, queue, skip_special_tokens=True) gen_kwargs = {**inputs, "streamer": streamer, "generation_config": gen_config, "tokenizer": processor} def _run(): diff --git a/tests/cli/test_serve_refactored.py b/tests/cli/test_serve_refactored.py index b2289175aaaa..843638982da5 100644 --- a/tests/cli/test_serve_refactored.py +++ b/tests/cli/test_serve_refactored.py @@ -19,6 +19,7 @@ """ import asyncio +import io import json import os import time @@ -1205,3 +1206,128 @@ def stream_in_thread(index): self.assertIn("response.created", types, f"Request {i} missing created event") self.assertIn("response.output_text.delta", types, f"Request {i} missing delta events") self.assertIn("response.completed", types, f"Request {i} missing completed event") + + +# --------------------------------------------------------------------------- +# 11. Integration tests — Transcription API (need GPU + model + librosa) +# --------------------------------------------------------------------------- + + +def _make_test_wav(duration: float = 2.0, sample_rate: int = 16000) -> bytes: + """Create a simple WAV file with a sine wave. 
Returns raw bytes.""" + import wave + + import numpy as np + + t = np.linspace(0, duration, int(sample_rate * duration), endpoint=False) + audio = (np.sin(2 * np.pi * 440 * t) * 32767).astype(np.int16) + buf = io.BytesIO() + with wave.open(buf, "w") as wf: + wf.setnchannels(1) + wf.setsampwidth(2) + wf.setframerate(sample_rate) + wf.writeframes(audio.tobytes()) + return buf.getvalue() + + +@unittest.skipUnless(run_slow, "Set RUN_SLOW=1 to run integration tests") +@require_openai +class TestTranscription(unittest.TestCase): + """Integration tests for POST /v1/audio/transcriptions with whisper-tiny.""" + + MODEL = "openai/whisper-tiny" + PORT = 8880 + + @classmethod + def setUpClass(cls): + import requests as req + + from transformers.cli.serve_refactored import Serve + + cls.serve = Serve(port=cls.PORT, non_blocking=True, log_level="warning") + for _ in range(30): + try: + if req.get(f"http://localhost:{cls.PORT}/health", timeout=1).status_code == 200: + break + except Exception: + pass + time.sleep(2) + cls.base_url = f"http://localhost:{cls.PORT}" + + @classmethod + def tearDownClass(cls): + cls.serve.kill_server() + + @classmethod + def _get_audio_bytes(cls): + """Download the MLK 'I have a dream' speech sample from HF Hub.""" + if not hasattr(cls, "_audio_bytes"): + from huggingface_hub import hf_hub_download + + path = hf_hub_download("Narsil/asr_dummy", "mlk.flac", repo_type="dataset") + with open(path, "rb") as f: + cls._audio_bytes = f.read() + return cls._audio_bytes + + def test_transcription_returns_text(self): + """POST /v1/audio/transcriptions with real speech returns meaningful transcription.""" + import requests as req + + audio_bytes = self._get_audio_bytes() + resp = req.post( + f"{self.base_url}/v1/audio/transcriptions", + files={"file": ("mlk.flac", audio_bytes, "audio/flac")}, + data={"model": self.MODEL}, + timeout=120, + ) + self.assertEqual(resp.status_code, 200) + data = resp.json() + self.assertIn("text", data) + self.assertIsInstance(data["text"], str) + # Whisper-tiny should recognize at least "dream" from the MLK speech + self.assertIn("dream", data["text"].lower()) + + def test_transcription_openai_client(self): + """Transcription should work via the OpenAI Python client.""" + audio_bytes = self._get_audio_bytes() + client = OpenAI(base_url=f"http://localhost:{self.PORT}/v1", api_key="unused") + result = client.audio.transcriptions.create( + model=self.MODEL, + file=("mlk.flac", audio_bytes), + ) + self.assertIsInstance(result.text, str) + self.assertTrue(len(result.text) > 10) + + def test_transcription_streaming(self): + """Streaming transcription should yield text chunks via SSE.""" + import requests as req + + audio_bytes = self._get_audio_bytes() + resp = req.post( + f"{self.base_url}/v1/audio/transcriptions", + files={"file": ("mlk.flac", audio_bytes, "audio/flac")}, + data={"model": self.MODEL, "stream": "true"}, + stream=True, + timeout=120, + ) + self.assertEqual(resp.status_code, 200) + + chunks = [] + for line in resp.iter_lines(decode_unicode=True): + if line and line.startswith("data: "): + chunks.append(line[len("data: ") :]) + + self.assertGreater(len(chunks), 0, "No streaming chunks received") + full_text = "".join(chunks) + self.assertIn("dream", full_text.lower()) + + def test_transcription_missing_file(self): + """POST without a file should fail.""" + import requests as req + + resp = req.post( + f"{self.base_url}/v1/audio/transcriptions", + data={"model": self.MODEL}, + timeout=30, + ) + self.assertNotEqual(resp.status_code, 200) From 
68cd5bc5f3e2ba4233831aea9e108dd7aaac974a Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Mon, 23 Mar 2026 18:00:53 +0000 Subject: [PATCH 19/64] tool calls better ! --- .../cli/serving/chat_completion.py | 93 +++- src/transformers/cli/serving/response.py | 172 ++++++-- src/transformers/cli/serving/utils.py | 124 ++++++ tests/cli/test_serve_refactored.py | 401 +++++++++++++++++- 4 files changed, 743 insertions(+), 47 deletions(-) diff --git a/src/transformers/cli/serving/chat_completion.py b/src/transformers/cli/serving/chat_completion.py index 88def4c1d0f0..e2b0cc40eb1e 100644 --- a/src/transformers/cli/serving/chat_completion.py +++ b/src/transformers/cli/serving/chat_completion.py @@ -31,9 +31,9 @@ from fastapi import HTTPException from fastapi.responses import JSONResponse, StreamingResponse -from openai.types.chat import ChatCompletion, ChatCompletionMessage +from openai.types.chat import ChatCompletion, ChatCompletionMessage, ChatCompletionMessageToolCall from openai.types.chat.chat_completion import Choice -from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, ChoiceDelta +from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, ChoiceDelta, ChoiceDeltaToolCall from openai.types.chat.chat_completion_chunk import Choice as ChoiceChunk from openai.types.completion_usage import CompletionUsage @@ -43,8 +43,10 @@ from .utils import ( UNUSED_CHAT_COMPLETION_FIELDS, BaseHandler, + ToolCallParser, TransformersCompletionCreateParamsStreaming, _StreamError, + detect_tool_format, get_processor_inputs_from_messages, ) @@ -86,9 +88,31 @@ def handle_request(self, body: dict, request_id: str) -> StreamingResponse | JSO gen_config = self._build_generation_config(body, model.generation_config, processor) + # Detect tool support for the loaded model + # TODO: after tool_call start token, use constrained generation to: + # 1. force generation to pick from the available tool names + # 2. 
force generation to produce valid JSON matching the tool's parameter schema + tool_format = detect_tool_format(model) if body.get("tools") else None + if body.get("stream"): - return self._streaming(request_id, model, processor, model_id, inputs, gen_config) - return self._non_streaming(request_id, model, processor, model_id, inputs, gen_config) + return self._streaming( + request_id, + model, + processor, + model_id, + inputs, + gen_config, + tool_format=tool_format, + ) + return self._non_streaming( + request_id, + model, + processor, + model_id, + inputs, + gen_config, + tool_format=tool_format, + ) # ----- streaming ----- @@ -100,12 +124,15 @@ def _streaming( model_id: str, inputs: dict[str, torch.Tensor], gen_config: GenerationConfig, + tool_format: dict | None = None, ) -> StreamingResponse: """Stream tokens as SSE via DirectStreamer.""" queue, streamer = self._start_streaming(model, processor, inputs, gen_config) input_len = inputs["input_ids"].shape[-1] + parser = ToolCallParser(tool_format) if tool_format else None async def sse_gen() -> AsyncGenerator[str, None]: + has_tool_calls = False yield self._build_chunk_sse(request_id, role="assistant", model=model_id) while True: @@ -115,9 +142,33 @@ async def sse_gen() -> AsyncGenerator[str, None]: elif isinstance(text, _StreamError): yield f'data: {{"error": "{text.msg}"}}\n\n' return - yield self._build_chunk_sse(request_id, content=text, model=model_id) + + # Tool call parsing: None = normal text, CONSUMED = buffering, else = tool call dict + chunk_kwargs = {"content": text} + if parser is not None and (result := parser.feed(text)) is not None: + if result is ToolCallParser.CONSUMED: + continue + has_tool_calls = True + chunk_kwargs = { + "tool_calls": [ + ChoiceDeltaToolCall( + index=0, + type="function", + id=f"{request_id}_tool_call", + function={"name": result["name"], "arguments": result["arguments"]}, + ) + ] + } + + yield self._build_chunk_sse(request_id, model=model_id, **chunk_kwargs) hit_max = gen_config.max_new_tokens is not None and streamer.total_tokens >= gen_config.max_new_tokens + if has_tool_calls: + finish_reason = "tool_calls" + elif hit_max: + finish_reason = "length" + else: + finish_reason = "stop" usage = CompletionUsage( prompt_tokens=input_len, completion_tokens=streamer.total_tokens, @@ -125,7 +176,7 @@ async def sse_gen() -> AsyncGenerator[str, None]: ) yield self._build_chunk_sse( request_id, - finish_reason="length" if hit_max else "stop", + finish_reason=finish_reason, model=model_id, usage=usage, ) @@ -142,6 +193,7 @@ def _non_streaming( model_id: str, inputs: dict[str, torch.Tensor], gen_config: GenerationConfig, + tool_format: dict | None = None, ) -> JSONResponse: """Run generation and return a JSONResponse.""" content, input_len, generated_ids = self._generate_non_streaming(model, processor, inputs, gen_config) @@ -153,13 +205,36 @@ def _non_streaming( completion_tokens=completion_tokens, total_tokens=input_len + completion_tokens, ) + + # Parse tool calls from the generated text + tool_calls = None + if tool_format is not None: + parsed = ToolCallParser.parse(content, tool_format) + if parsed is not None: + tool_calls = [ + ChatCompletionMessageToolCall( + id=f"{request_id}_tool_call", + type="function", + function={"name": tc["name"], "arguments": tc["arguments"]}, + ) + for tc in parsed + ] + + if tool_calls is not None: + finish_reason = "tool_calls" + elif hit_max: + finish_reason = "length" + else: + finish_reason = "stop" + return JSONResponse( self._build_completion( request_id, content, 
model_id, - finish_reason="length" if hit_max else "stop", + finish_reason=finish_reason, usage=usage, + tool_calls=tool_calls, ), media_type="application/json", ) @@ -201,8 +276,10 @@ def _build_completion( model_id: str, finish_reason: str, usage: CompletionUsage | None = None, + tool_calls: list[dict] | None = None, ) -> dict: """Build a non-streaming ChatCompletion response dict.""" + message = ChatCompletionMessage(content=content, role="assistant", tool_calls=tool_calls) result = ChatCompletion( id=request_id, created=int(time.time()), @@ -211,7 +288,7 @@ def _build_completion( choices=[ Choice( index=0, - message=ChatCompletionMessage(content=content, role="assistant"), + message=message, finish_reason=finish_reason, ) ], diff --git a/src/transformers/cli/serving/response.py b/src/transformers/cli/serving/response.py index 5bae80d08815..218ade0218e8 100644 --- a/src/transformers/cli/serving/response.py +++ b/src/transformers/cli/serving/response.py @@ -34,6 +34,8 @@ ResponseError, ResponseErrorEvent, ResponseFailedEvent, + ResponseFunctionCallArgumentsDoneEvent, + ResponseFunctionToolCall, ResponseInProgressEvent, ResponseOutputItemAddedEvent, ResponseOutputItemDoneEvent, @@ -45,7 +47,7 @@ from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails, ResponseUsage from ...utils import logging -from .utils import BaseHandler, _StreamError +from .utils import BaseHandler, ToolCallParser, _StreamError, detect_tool_format if TYPE_CHECKING: @@ -84,14 +86,37 @@ def handle_request(self, body: dict, request_id: str) -> StreamingResponse | JSO messages = self._input_to_messages(body) inputs = processor.apply_chat_template( - messages, add_generation_prompt=True, return_tensors="pt", return_dict=True + messages, + add_generation_prompt=True, + tools=body.get("tools"), + return_tensors="pt", + return_dict=True, ).to(model.device) gen_config = self._build_generation_config(body, model.generation_config, processor) + tool_format = detect_tool_format(model) if body.get("tools") else None if body.get("stream", True): - return self._streaming(request_id, model, processor, model_id, body, inputs, gen_config) - return self._non_streaming(request_id, model, processor, model_id, body, inputs, gen_config) + return self._streaming( + request_id, + model, + processor, + model_id, + body, + inputs, + gen_config, + tool_format=tool_format, + ) + return self._non_streaming( + request_id, + model, + processor, + model_id, + body, + inputs, + gen_config, + tool_format=tool_format, + ) # ----- input conversion ----- @@ -132,14 +157,15 @@ def _streaming( body: dict, inputs: dict, gen_config: GenerationConfig, + tool_format: dict | None = None, ) -> StreamingResponse: """Generate a streaming Responses API reply (SSE) using DirectStreamer.""" queue, streamer = self._start_streaming(model, processor, inputs, gen_config) input_len = inputs["input_ids"].shape[-1] + parser = ToolCallParser(tool_format) if tool_format else None seq = 0 output_index = 0 - content_index = 0 created_at = time.time() resp_id = f"resp_{request_id}" msg_id = f"msg_{request_id}" @@ -149,15 +175,16 @@ def _streaming( "created_at": created_at, "model": model_id, "object": "response", + # Required by pydantic but not used — echo request config back "tools": [], "parallel_tool_calls": body.get("parallel_tool_calls", False), "tool_choice": "auto", } async def event_stream() -> AsyncGenerator[str, None]: - nonlocal seq + nonlocal seq, output_index - # 1. Created + # 1. 
Created + In progress yield self.chunk_to_sse( ResponseCreatedEvent( type="response.created", @@ -166,8 +193,6 @@ async def event_stream() -> AsyncGenerator[str, None]: ) ) seq += 1 - - # 2. In progress yield self.chunk_to_sse( ResponseInProgressEvent( type="response.in_progress", @@ -177,7 +202,7 @@ async def event_stream() -> AsyncGenerator[str, None]: ) seq += 1 - # 3. Output item added + # 2. Output item added (message) yield self.chunk_to_sse( ResponseOutputItemAddedEvent( type="response.output_item.added", @@ -194,21 +219,23 @@ async def event_stream() -> AsyncGenerator[str, None]: ) seq += 1 - # 4. Content part added + # 3. Content part added yield self.chunk_to_sse( ResponseContentPartAddedEvent( type="response.content_part.added", item_id=msg_id, sequence_number=seq, output_index=output_index, - content_index=content_index, + content_index=0, part=ResponseOutputText(type="output_text", text="", annotations=[]), ) ) seq += 1 - # 5. Text deltas from DirectStreamer queue + # 4. Stream tokens full_text = "" + tool_calls = [] + while True: text = await queue.get() if text is None: @@ -231,50 +258,96 @@ async def event_stream() -> AsyncGenerator[str, None]: ) return + # Tool call parsing + if parser is not None and (result := parser.feed(text)) is not None: + if result is not ToolCallParser.CONSUMED: + # Emit tool call as a function_call output item + tc_id = f"{request_id}_tool_call" + name = result["name"] + arguments = result["arguments"] + tc_item = ResponseFunctionToolCall( + id=tc_id, + call_id=tc_id, + type="function_call", + name=name, + arguments=arguments, + status="completed", + ) + tool_calls.append(tc_item) + # output[0] = message, output[1..N] = tool calls (required by OpenAI SSE spec) + output_index += 1 + yield self.chunk_to_sse( + ResponseOutputItemAddedEvent( + type="response.output_item.added", + sequence_number=seq, + output_index=output_index, + item=tc_item, + ) + ) + seq += 1 + yield self.chunk_to_sse( + ResponseFunctionCallArgumentsDoneEvent( + type="response.function_call_arguments.done", + sequence_number=seq, + item_id=tc_id, + output_index=output_index, + arguments=arguments, + name=name, + ) + ) + seq += 1 + yield self.chunk_to_sse( + ResponseOutputItemDoneEvent( + type="response.output_item.done", + sequence_number=seq, + output_index=output_index, + item=tc_item, + ) + ) + seq += 1 + continue + full_text += text yield self.chunk_to_sse( ResponseTextDeltaEvent( type="response.output_text.delta", item_id=msg_id, sequence_number=seq, - output_index=output_index, - content_index=content_index, + output_index=0, + content_index=0, delta=text, logprobs=[], ) ) seq += 1 - # 6. Text done + # 5. Close text output + output_text_part = ResponseOutputText(type="output_text", text=full_text, annotations=[]) yield self.chunk_to_sse( ResponseTextDoneEvent( type="response.output_text.done", item_id=msg_id, sequence_number=seq, - output_index=output_index, + output_index=0, content_index=0, text=full_text, logprobs=[], ) ) seq += 1 - - # 7. Content part done - output_text_part = ResponseOutputText(type="output_text", text=full_text, annotations=[]) yield self.chunk_to_sse( ResponseContentPartDoneEvent( type="response.content_part.done", item_id=msg_id, sequence_number=seq, - output_index=output_index, - content_index=content_index, + output_index=0, + content_index=0, part=output_text_part, ) ) seq += 1 - # 8. 
Output item done - output_item = ResponseOutputMessage( + msg_item = ResponseOutputMessage( id=msg_id, type="message", status="completed", @@ -286,19 +359,20 @@ async def event_stream() -> AsyncGenerator[str, None]: ResponseOutputItemDoneEvent( type="response.output_item.done", sequence_number=seq, - output_index=output_index, - item=output_item, + output_index=0, + item=msg_item, ) ) seq += 1 - # 9. Completed + # 6. Completed + all_output = [msg_item] + list(tool_calls) usage = _make_usage(input_len, streamer.total_tokens) yield self.chunk_to_sse( ResponseCompletedEvent( type="response.completed", sequence_number=seq, - response=Response(**response_base, status="completed", output=[output_item], usage=usage), + response=Response(**response_base, status="completed", output=all_output, usage=usage), ) ) seq += 1 @@ -316,6 +390,7 @@ def _non_streaming( body: dict, inputs: dict, gen_config: GenerationConfig, + tool_format: dict | None = None, ) -> JSONResponse: """Generate a non-streaming Responses API reply (single JSON).""" full_text, input_len, generated_ids = self._generate_non_streaming(model, processor, inputs, gen_config) @@ -325,26 +400,47 @@ def _non_streaming( msg_id = f"msg_{request_id}" output_tokens = len(generated_ids) - output_item = ResponseOutputMessage( - id=msg_id, - type="message", - status="completed", - role="assistant", - content=[ResponseOutputText(type="output_text", text=full_text, annotations=[])], - annotations=[], - ) + output_items = [ + ResponseOutputMessage( + id=msg_id, + type="message", + status="completed", + role="assistant", + content=[ResponseOutputText(type="output_text", text=full_text, annotations=[])], + annotations=[], + ) + ] + + # Parse tool calls from the generated text + if tool_format is not None: + parsed_calls = ToolCallParser.parse(full_text, tool_format) + if parsed_calls is not None: + for i, tc in enumerate(parsed_calls): + tc_id = f"{request_id}_tool_call" + output_items.append( + ResponseFunctionToolCall( + id=tc_id, + call_id=tc_id, + type="function_call", + name=tc["name"], + arguments=tc["arguments"], + status="completed", + ) + ) + usage = _make_usage(input_len, output_tokens) response = Response( id=resp_id, created_at=created_at, status="completed", model=model_id, - output=[output_item], + output=output_items, object="response", + usage=usage, + # Required by pydantic but not used — echo request config back tools=[], parallel_tool_calls=body.get("parallel_tool_calls", False), tool_choice="auto", - usage=usage, ) return JSONResponse(response.model_dump(exclude_none=True)) diff --git a/src/transformers/cli/serving/utils.py b/src/transformers/cli/serving/utils.py index 67f20da82cb6..5d69e0e47858 100644 --- a/src/transformers/cli/serving/utils.py +++ b/src/transformers/cli/serving/utils.py @@ -95,6 +95,130 @@ def __init__(self, msg: str): self.msg = msg +# --------------------------------------------------------------------------- +# Tool call parsing +# --------------------------------------------------------------------------- + +# Model-specific tokens that mark the start/end of a tool call block. +# TODO: extract these from the chat template at runtime instead of hardcoding. +# Qwen/Hermes use /, Mistral uses [TOOL_CALLS], etc. +# The markers are defined in each model's Jinja chat template. 
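For reference, a tool call in the raw decoded output of a Qwen/Hermes-style model looks roughly like the snippet below; the exact marker strings are an assumption taken from those models' chat templates (Mistral-family models wrap the JSON in [TOOL_CALLS] instead), and the parser's only job is to pull out and JSON-decode the payload between them:

    import json

    raw = '<tool_call>\n{"name": "get_weather", "arguments": {"city": "Paris"}}\n</tool_call>'
    start_marker, end_marker = "<tool_call>", "</tool_call>"  # assumed Qwen/Hermes markers
    payload = raw[raw.index(start_marker) + len(start_marker) : raw.index(end_marker)].strip()
    call = json.loads(payload)
    name, arguments = call["name"], json.dumps(call.get("arguments", {}))
    # name == "get_weather", arguments == '{"city": "Paris"}'
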
+_TOOL_CALL_TOKENS = { + "qwen": { + "start": "", + "end": "", + }, +} + + +def detect_tool_format(model) -> dict | None: + """Return the tool call token format (``{"start": ..., "end": ...}``) if supported, else ``None``.""" + architecture = model.config.architectures[0].lower() + for family in _TOOL_CALL_TOKENS: + if family in architecture: + return _TOOL_CALL_TOKENS[family] + return None + + +class ToolCallParser: + """Parses tool calls from model output. + + The model emits tool calls as structured text between start/end tokens + (e.g. ``{"name": "fn", "arguments": {...}}``). + + **Streaming** (``feed``): buffers tokens between start/end markers, parses + the complete block when the end marker is seen, returns a ``ChoiceDeltaToolCall``. + + **Non-streaming** (``parse``): extracts all tool call blocks from complete text. + + Usage:: + + parser = ToolCallParser(tool_format={"start": ..., "end": ...}) + for text_chunk in streamer: + result = parser.feed(text_chunk) + if result is None: + # Normal text — emit as content + elif result is ToolCallParser.CONSUMED: + # Buffering — skip + else: + # result is a ChoiceDeltaToolCall — emit it + """ + + def __init__(self, tool_format: dict): + self._tokens = tool_format + self._inside = False + self._buffer = "" + + # Sentinel: token was consumed by the parser but produced no output. + CONSUMED = object() + + def feed(self, text: str): + """Feed a text chunk (streaming). + + Returns: + - ``None`` — normal text, not a tool token. Emit as content. + - ``CONSUMED`` — token consumed internally (buffering/markers). Skip. + - A ``ChoiceDeltaToolCall`` — emit as a tool call delta. + """ + if text.strip() == self._tokens["start"]: + self._inside = True + self._buffer = "" + return self.CONSUMED + + if text.strip() == self._tokens["end"]: + self._inside = False + block = self._buffer.strip() + self._buffer = "" + return self._parse_block(block) or self.CONSUMED + + if self._inside: + self._buffer += text + return self.CONSUMED + + return None + + @staticmethod + def _extract_name_and_args(block: str) -> tuple[str, str] | None: + """Extract (name, arguments_json) from a tool call block, or None if invalid.""" + if not block: + return None + parsed = json.loads(block) + name = parsed.get("name") + if name is None: + return None + arguments = parsed.get("arguments", {}) + return name, json.dumps(arguments) + + @staticmethod + def parse(text: str, tool_format: dict) -> list[dict] | None: + """Parse tool calls from complete text. + + Returns a list of ``{"name": str, "arguments": str}`` dicts, or ``None`` if none found. + """ + start, end = tool_format["start"], tool_format["end"] + tool_calls = [] + pos = 0 + while True: + s = text.find(start, pos) + if s < 0: + break + e = text.find(end, s + len(start)) + if e < 0: + break + result = ToolCallParser._extract_name_and_args(text[s + len(start) : e].strip()) + if result is not None: + tool_calls.append({"name": result[0], "arguments": result[1]}) + pos = e + len(end) + return tool_calls if tool_calls else None + + def _parse_block(self, block: str) -> dict | None: + """Parse a buffered tool call block. 
Returns ``{"name": str, "arguments": str}`` or None.""" + result = self._extract_name_and_args(block) + if result is None: + return None + return {"name": result[0], "arguments": result[1]} + + # --------------------------------------------------------------------------- # Progress tracking for model loading # --------------------------------------------------------------------------- diff --git a/tests/cli/test_serve_refactored.py b/tests/cli/test_serve_refactored.py index 843638982da5..b8becd955ca8 100644 --- a/tests/cli/test_serve_refactored.py +++ b/tests/cli/test_serve_refactored.py @@ -412,7 +412,153 @@ def test_chunk_to_sse_wraps_plain_string(self): # --------------------------------------------------------------------------- -# 6. App-level tests with ASGI test client (no real model) +# 6. Unit tests — tool parser +# --------------------------------------------------------------------------- + + +QWEN_TOOL_FORMAT = {"start": "", "end": ""} + + +@require_openai +class TestToolParser(unittest.TestCase): + def test_detect_tool_format_qwen(self): + from transformers.cli.serving.utils import detect_tool_format + + model = MagicMock() + model.config.architectures = ["Qwen2ForCausalLM"] + fmt = detect_tool_format(model) + self.assertEqual(fmt, QWEN_TOOL_FORMAT) + + def test_detect_tool_format_unsupported(self): + from transformers.cli.serving.utils import detect_tool_format + + model = MagicMock() + model.config.architectures = ["LlamaForCausalLM"] + self.assertIsNone(detect_tool_format(model)) + + def test_parser_start_token(self): + from transformers.cli.serving.utils import ToolCallParser + + parser = ToolCallParser(QWEN_TOOL_FORMAT) + result = parser.feed("") + self.assertIs(result, ToolCallParser.CONSUMED) + + def test_parser_end_token(self): + from transformers.cli.serving.utils import ToolCallParser + + parser = ToolCallParser(QWEN_TOOL_FORMAT) + parser.feed("") + result = parser.feed("") + self.assertIs(result, ToolCallParser.CONSUMED) + + def test_parser_buffers_until_end(self): + from transformers.cli.serving.utils import ToolCallParser + + parser = ToolCallParser(QWEN_TOOL_FORMAT) + parser.feed("") + # Intermediate tokens are buffered + result = parser.feed('{"name": "my_tool", "arguments": {"x": 1}}') + self.assertIs(result, ToolCallParser.CONSUMED) + # Tool call is emitted on end token + result = parser.feed("") + self.assertIsNot(result, ToolCallParser.CONSUMED) + self.assertEqual(result["name"], "my_tool") + + def test_parser_normal_text_returns_none(self): + from transformers.cli.serving.utils import ToolCallParser + + parser = ToolCallParser(QWEN_TOOL_FORMAT) + result = parser.feed("Hello world") + self.assertIsNone(result) + + def test_parser_full_flow(self): + """Simulate a complete tool call token sequence.""" + from transformers.cli.serving.utils import ToolCallParser + + parser = ToolCallParser(QWEN_TOOL_FORMAT) + tool_calls = [] + + for token in [ + "", + '{"name": "get_weather",', + ' "arguments": {', + '"city": "Paris"', + "}}", + "\n", + "", + ]: + result = parser.feed(token) + if result is not None and result is not ToolCallParser.CONSUMED: + tool_calls.append(result) + + # Single tool call emitted on with both name and arguments + self.assertEqual(len(tool_calls), 1) + self.assertEqual(tool_calls[0]["name"], "get_weather") + self.assertIn("Paris", tool_calls[0]["arguments"]) + + def test_parse_tool_calls_from_text(self): + """Non-streaming tool call parsing from complete text.""" + from transformers.cli.serving.utils import ToolCallParser + + text = '\n{"name": 
"get_weather", "arguments": {"city": "Paris"}}\n' + calls = ToolCallParser.parse(text, QWEN_TOOL_FORMAT) + self.assertIsNotNone(calls) + self.assertEqual(len(calls), 1) + self.assertEqual(calls[0]["name"], "get_weather") + self.assertIn("Paris", calls[0]["arguments"]) + + def test_parse_tool_calls_no_tool_call(self): + """Non-streaming: normal text returns None.""" + from transformers.cli.serving.utils import ToolCallParser + + calls = ToolCallParser.parse("Hello, how can I help?", QWEN_TOOL_FORMAT) + self.assertIsNone(calls) + + def test_parse_multiple_tool_calls(self): + """Non-streaming: multiple tool calls in one response.""" + from transformers.cli.serving.utils import ToolCallParser + + text = ( + '\n{"name": "get_weather", "arguments": {"city": "Paris"}}\n\n' + '\n{"name": "get_weather", "arguments": {"city": "London"}}\n' + ) + calls = ToolCallParser.parse(text, QWEN_TOOL_FORMAT) + self.assertIsNotNone(calls) + self.assertEqual(len(calls), 2) + self.assertEqual(calls[0]["name"], "get_weather") + self.assertIn("Paris", calls[0]["arguments"]) + self.assertEqual(calls[1]["name"], "get_weather") + self.assertIn("London", calls[1]["arguments"]) + + def test_feed_multiple_tool_calls(self): + """Streaming: multiple tool calls emitted sequentially.""" + from transformers.cli.serving.utils import ToolCallParser + + parser = ToolCallParser(QWEN_TOOL_FORMAT) + tool_calls = [] + + tokens = [ + "", + '{"name": "get_weather", "arguments": {"city": "Paris"}}', + "", + "", + '{"name": "get_weather", "arguments": {"city": "London"}}', + "", + ] + for token in tokens: + result = parser.feed(token) + if result is not None and result is not ToolCallParser.CONSUMED: + tool_calls.append(result) + + self.assertEqual(len(tool_calls), 2) + self.assertEqual(tool_calls[0]["name"], "get_weather") + self.assertIn("Paris", tool_calls[0]["arguments"]) + self.assertEqual(tool_calls[1]["name"], "get_weather") + self.assertIn("London", tool_calls[1]["arguments"]) + + +# --------------------------------------------------------------------------- +# 7. 
App-level tests with ASGI test client (no real model) # --------------------------------------------------------------------------- @@ -614,6 +760,110 @@ def test_streaming_usage(self): self.assertGreater(last.usage.completion_tokens, 0) self.assertEqual(last.usage.total_tokens, last.usage.prompt_tokens + last.usage.completion_tokens) + def test_tool_call(self): + """Tool calls should be parsed and emitted as ChoiceDeltaToolCall objects.""" + # Qwen2.5-0.5B-Instruct supports tools (Qwen family) + tool_def = { + "function": { + "name": "get_weather", + "parameters": { + "type": "object", + "properties": {"city": {"type": "string"}}, + }, + "description": "Get the weather for a city.", + }, + "type": "function", + } + chunks = list( + self.client.chat.completions.create( + model=self.MODEL, + messages=[{"role": "user", "content": "What is the weather in Paris?"}], + stream=True, + max_tokens=50, + temperature=0.0, + tools=[tool_def], + ) + ) + + # First chunk should have role="assistant" + self.assertEqual(chunks[0].choices[0].delta.role, "assistant") + + # Model should make a tool call for this prompt + tool_chunks = [c for c in chunks if c.choices[0].delta.tool_calls] + self.assertGreater(len(tool_chunks), 0, "Model did not produce a tool call") + + # First tool call delta should have the function name + first_tool = tool_chunks[0].choices[0].delta.tool_calls[0] + self.assertEqual(first_tool.function.name, "get_weather") + + # finish_reason should be "tool_calls" + last = chunks[-1] + self.assertEqual(last.choices[0].finish_reason, "tool_calls") + + # Arguments should be valid JSON with no trailing brace + args_json = first_tool.function.arguments + import json as json_mod + + parsed_args = json_mod.loads(args_json) + self.assertIsInstance(parsed_args, dict) + + def test_tool_call_non_streaming(self): + """Non-streaming tool calls should return tool_calls in the message.""" + tool_def = { + "function": { + "name": "get_weather", + "parameters": {"type": "object", "properties": {"city": {"type": "string"}}}, + "description": "Get the weather for a city.", + }, + "type": "function", + } + resp = self.client.chat.completions.create( + model=self.MODEL, + messages=[{"role": "user", "content": "What is the weather in Paris?"}], + stream=False, + max_tokens=50, + temperature=0.0, + tools=[tool_def], + ) + self.assertEqual(resp.choices[0].finish_reason, "tool_calls") + self.assertIsNotNone(resp.choices[0].message.tool_calls) + tc = resp.choices[0].message.tool_calls[0] + self.assertEqual(tc.function.name, "get_weather") + + import json as json_mod + + parsed_args = json_mod.loads(tc.function.arguments) + self.assertIsInstance(parsed_args, dict) + + def test_tool_call_multi(self): + """Model should be able to call multiple tools when asked.""" + tool_def = { + "function": { + "name": "get_weather", + "parameters": {"type": "object", "properties": {"city": {"type": "string"}}}, + "description": "Get the weather for a city.", + }, + "type": "function", + } + # Ask for two cities to encourage multiple tool calls + chunks = list( + self.client.chat.completions.create( + model=self.MODEL, + messages=[{"role": "user", "content": "What is the weather in Paris and London?"}], + stream=True, + max_tokens=100, + temperature=0.0, + tools=[tool_def], + ) + ) + tool_chunks = [c for c in chunks if c.choices[0].delta.tool_calls] + # Should have two tool calls — one for Paris, one for London + self.assertEqual(len(tool_chunks), 2, f"Expected 2 tool calls, got {len(tool_chunks)}") + cities = 
{tc.choices[0].delta.tool_calls[0].function.name for tc in tool_chunks} + self.assertEqual(cities, {"get_weather"}) + last = chunks[-1] + self.assertEqual(last.choices[0].finish_reason, "tool_calls") + def test_concurrent_non_streaming(self): """Two concurrent non-streaming requests should both complete without interference.""" import concurrent.futures @@ -930,6 +1180,155 @@ def test_streaming_usage(self): self.assertGreater(usage.output_tokens, 0) self.assertEqual(usage.total_tokens, usage.input_tokens + usage.output_tokens) + def test_tool_call_streaming(self): + """Streaming responses with tools should emit function_call events.""" + tool_def = { + "function": { + "name": "get_weather", + "parameters": {"type": "object", "properties": {"city": {"type": "string"}}}, + "description": "Get the weather for a city.", + }, + "type": "function", + } + events = list( + self.client.responses.create( + model=self.MODEL, + input="What is the weather in Paris?", + stream=True, + max_output_tokens=50, + tools=[tool_def], + ) + ) + types = [e.type for e in events] + self.assertIn("response.created", types) + self.assertIn("response.completed", types) + + # Should have function call events + self.assertIn("response.output_item.added", types) + self.assertIn("response.function_call_arguments.done", types) + + # Check the arguments done event + args_done = [e for e in events if e.type == "response.function_call_arguments.done"] + self.assertGreater(len(args_done), 0) + self.assertEqual(args_done[0].name, "get_weather") + + import json as json_mod + + parsed = json_mod.loads(args_done[0].arguments) + self.assertIsInstance(parsed, dict) + + def test_tool_call_non_streaming(self): + """Non-streaming responses with tools should include function_call output items.""" + tool_def = { + "function": { + "name": "get_weather", + "parameters": {"type": "object", "properties": {"city": {"type": "string"}}}, + "description": "Get the weather for a city.", + }, + "type": "function", + } + resp = self.client.responses.create( + model=self.MODEL, + input="What is the weather in Paris?", + stream=False, + max_output_tokens=50, + tools=[tool_def], + ) + self.assertEqual(resp.status, "completed") + + # Should have at least message + function_call in output + self.assertGreater(len(resp.output), 1) + fc_items = [o for o in resp.output if o.type == "function_call"] + self.assertGreater(len(fc_items), 0) + self.assertEqual(fc_items[0].name, "get_weather") + + import json as json_mod + + parsed = json_mod.loads(fc_items[0].arguments) + self.assertIsInstance(parsed, dict) + + def test_tool_call_multi(self): + """Model should produce multiple tool calls when asked about two cities.""" + tool_def = { + "function": { + "name": "get_weather", + "parameters": {"type": "object", "properties": {"city": {"type": "string"}}}, + "description": "Get the weather for a city.", + }, + "type": "function", + } + events = list( + self.client.responses.create( + model=self.MODEL, + input="What is the weather in Paris and London?", + stream=True, + max_output_tokens=100, + tools=[tool_def], + ) + ) + args_done = [e for e in events if e.type == "response.function_call_arguments.done"] + self.assertEqual(len(args_done), 2, f"Expected 2 tool calls, got {len(args_done)}") + self.assertEqual(events[-1].type, "response.completed") + + def test_multi_turn(self): + """Multi-turn conversation via list input.""" + resp = self.client.responses.create( + model=self.MODEL, + input=[ + {"role": "user", "content": "My name is Alice"}, + {"role": "assistant", 
"content": "Nice to meet you!"}, + {"role": "user", "content": "What is my name?"}, + ], + stream=False, + ) + self.assertEqual(resp.status, "completed") + self.assertIn("Alice", resp.output[0].content[0].text) + + def test_concurrent_non_streaming(self): + """Two concurrent non-streaming responses requests should both complete.""" + import concurrent.futures + + inputs = ["Say hello", "Say goodbye"] + results = [None, None] + + def request_in_thread(index): + client = OpenAI(base_url=f"http://localhost:{self.PORT}/v1", api_key="unused") + results[index] = client.responses.create(model=self.MODEL, input=inputs[index], stream=False) + + with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool: + futures = [pool.submit(request_in_thread, i) for i in range(2)] + concurrent.futures.wait(futures) + for f in futures: + f.result() + + for i in range(2): + self.assertIsNotNone(results[i]) + self.assertEqual(results[i].status, "completed") + self.assertTrue(len(results[i].output[0].content[0].text) > 0) + + def test_concurrent_streaming(self): + """Two concurrent streaming responses requests should both produce complete event streams.""" + import concurrent.futures + + inputs = ["Say hello", "Say goodbye"] + results = [None, None] + + def stream_in_thread(index): + client = OpenAI(base_url=f"http://localhost:{self.PORT}/v1", api_key="unused") + results[index] = list(client.responses.create(model=self.MODEL, input=inputs[index], stream=True)) + + with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool: + futures = [pool.submit(stream_in_thread, i) for i in range(2)] + concurrent.futures.wait(futures) + for f in futures: + f.result() + + for i in range(2): + types = [e.type for e in results[i]] + self.assertIn("response.created", types, f"Request {i} missing created event") + self.assertIn("response.output_text.delta", types, f"Request {i} missing delta events") + self.assertIn("response.completed", types, f"Request {i} missing completed event") + # --------------------------------------------------------------------------- # 10. Integration tests — /load_model endpoint (need GPU + model) From 6da3f3c11aa60355058f551da4d50f11582fc644 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Tue, 24 Mar 2026 10:13:12 +0000 Subject: [PATCH 20/64] vlm support for both response and chat endpoints --- src/transformers/cli/serving/response.py | 10 ++- tests/cli/test_serve_refactored.py | 87 ++++++++++++++++++++++++ 2 files changed, 95 insertions(+), 2 deletions(-) diff --git a/src/transformers/cli/serving/response.py b/src/transformers/cli/serving/response.py index 218ade0218e8..26e213268f92 100644 --- a/src/transformers/cli/serving/response.py +++ b/src/transformers/cli/serving/response.py @@ -47,7 +47,7 @@ from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails, ResponseUsage from ...utils import logging -from .utils import BaseHandler, ToolCallParser, _StreamError, detect_tool_format +from .utils import BaseHandler, ToolCallParser, _StreamError, detect_tool_format, get_processor_inputs_from_messages if TYPE_CHECKING: @@ -84,9 +84,15 @@ def handle_request(self, body: dict, request_id: str) -> StreamingResponse | JSO model_id, model, processor = self._resolve_model(body) + # Two-step input conversion (chat completions skips step 1 since messages are already standard): + # 1. Normalize Responses API input (string/list/dict + instructions) → standard messages list + # 2. Transform message content for the HF processor (VLM image handling, text joining, etc.) 
messages = self._input_to_messages(body) + modality = self.model_manager.get_model_modality(model, processor=processor) + processor_inputs = get_processor_inputs_from_messages(messages, modality) + inputs = processor.apply_chat_template( - messages, + processor_inputs, add_generation_prompt=True, tools=body.get("tools"), return_tensors="pt", diff --git a/tests/cli/test_serve_refactored.py b/tests/cli/test_serve_refactored.py index b8becd955ca8..9840f15410dd 100644 --- a/tests/cli/test_serve_refactored.py +++ b/tests/cli/test_serve_refactored.py @@ -1629,6 +1629,93 @@ def _make_test_wav(duration: float = 2.0, sample_rate: int = 16000) -> bytes: return buf.getvalue() +# --------------------------------------------------------------------------- +# 12. Integration tests — VLM support (need GPU + model) +# --------------------------------------------------------------------------- + + +def _make_test_image_base64() -> str: + """Create a small red 64x64 PNG as a base64 data URL.""" + import base64 + + from PIL import Image + + img = Image.new("RGB", (64, 64), color="red") + buf = io.BytesIO() + img.save(buf, format="PNG") + return f"data:image/png;base64,{base64.b64encode(buf.getvalue()).decode()}" + + +@unittest.skipUnless(run_slow and is_vision_available(), "Set RUN_SLOW=1 and install torchvision + PIL") +@require_openai +class TestVLM(unittest.TestCase): + """Integration tests for VLM (vision-language model) support. Requires torchvision.""" + + MODEL = "HuggingFaceTB/SmolVLM-256M-Instruct" + PORT = 8881 + + @classmethod + def setUpClass(cls): + import requests as req + + from transformers.cli.serve_refactored import Serve + + cls.serve = Serve(port=cls.PORT, non_blocking=True, log_level="warning") + for _ in range(60): + try: + if req.get(f"http://localhost:{cls.PORT}/health", timeout=1).status_code == 200: + break + except Exception: + pass + time.sleep(2) + cls.client = OpenAI(base_url=f"http://localhost:{cls.PORT}/v1", api_key="unused") + + @classmethod + def tearDownClass(cls): + cls.serve.kill_server() + + def test_chat_completion_with_image(self): + """Chat completions should accept image_url content and produce a response.""" + image_url = _make_test_image_base64() + resp = self.client.chat.completions.create( + model=self.MODEL, + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": "What color is this image?"}, + {"type": "image_url", "image_url": {"url": image_url}}, + ], + } + ], + max_tokens=20, + ) + text = resp.choices[0].message.content + self.assertIsNotNone(text) + self.assertIn("red", text.lower(), f"Expected 'red' in response, got: {text}") + + def test_responses_with_image(self): + """Responses API should accept image_url content and produce a response about the image.""" + image_url = _make_test_image_base64() + resp = self.client.responses.create( + model=self.MODEL, + input=[ + { + "role": "user", + "content": [ + {"type": "text", "text": "What color is this image?"}, + {"type": "image_url", "image_url": {"url": image_url}}, + ], + } + ], + stream=False, + max_output_tokens=20, + ) + self.assertEqual(resp.status, "completed") + text = resp.output[0].content[0].text + self.assertIn("red", text.lower(), f"Expected 'red' in response, got: {text}") + + @unittest.skipUnless(run_slow, "Set RUN_SLOW=1 to run integration tests") @require_openai class TestTranscription(unittest.TestCase): From a92ebe298ced900aae51ee085c6b4f849a0d0755 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Tue, 24 Mar 2026 10:16:44 +0000 Subject: [PATCH 21/64] update bench --- 
tests/cli/benchmark_serve.py | 5 +- tests/cli/benchmark_serve_load.py | 471 ++++++++++++++++++++++++++++++ 2 files changed, 474 insertions(+), 2 deletions(-) create mode 100644 tests/cli/benchmark_serve_load.py diff --git a/tests/cli/benchmark_serve.py b/tests/cli/benchmark_serve.py index 21db7fb7524c..69386ec5faa7 100644 --- a/tests/cli/benchmark_serve.py +++ b/tests/cli/benchmark_serve.py @@ -52,6 +52,7 @@ import statistics import time + # Force single GPU — must be set before any CUDA initialization os.environ["CUDA_VISIBLE_DEVICES"] = "0" @@ -552,8 +553,8 @@ def main(): def print_reference_summary(reference_texts: dict[str, str], ref_impl: str) -> None: """Print a summary noting that outputs are compared against the reference implementation.""" print(f"Reference comparison: all decode outputs compared against '{ref_impl}'.") - print(f" MATCH = identical text (greedy decoding is deterministic)") - print(f" MISMATCH = text differs (FP divergence across attention kernels — check preview for correctness)") + print(" MATCH = identical text (greedy decoding is deterministic)") + print(" MISMATCH = text differs (FP divergence across attention kernels — check preview for correctness)") print() diff --git a/tests/cli/benchmark_serve_load.py b/tests/cli/benchmark_serve_load.py new file mode 100644 index 000000000000..65c959818bbb --- /dev/null +++ b/tests/cli/benchmark_serve_load.py @@ -0,0 +1,471 @@ +""" +Load test for `transformers serve` — measures throughput and latency under concurrent requests. + +Unlike benchmark_serve.py (single-user perf), this tests server capacity: +- How many tokens/sec can the server sustain under load? +- What's the latency distribution (p50/p90/p99) as concurrency increases? +- Does the server stay stable under pressure? + +Modes: + --max-concurrency N Send requests with up to N in flight at once + --request-rate R Send R requests/sec (Poisson arrival), let them queue naturally + +Examples: + # Sweep concurrency levels (1, 2, 4, 8) + python tests/cli/benchmark_serve_load.py --model Qwen/Qwen2.5-7B-Instruct \\ + --max-concurrency 1 2 4 8 --num-requests 32 + + # Fixed request rate + python tests/cli/benchmark_serve_load.py --model Qwen/Qwen2.5-7B-Instruct \\ + --request-rate 5.0 --num-requests 50 + + # Against an existing server + python tests/cli/benchmark_serve_load.py --url http://localhost:8000 \\ + --processor Qwen/Qwen2.5-7B-Instruct --max-concurrency 1 4 8 +""" + +import argparse +import asyncio +import json +import os +import random +import statistics +import time + + +os.environ["CUDA_VISIBLE_DEVICES"] = "0" + +import aiohttp + +from transformers import AutoTokenizer + + +# --------------------------------------------------------------------------- +# Prompt generation +# --------------------------------------------------------------------------- + +_FILLER = ( + "The quick brown fox jumps over the lazy dog. " + "Pack my box with five dozen liquor jugs. " + "How vexingly quick daft zebras jump. " + "Sphinx of black quartz, judge my vow. 
" +) * 200 + + +def make_prompt(tokenizer, num_tokens: int) -> str: + token_ids = tokenizer.encode(_FILLER, add_special_tokens=False)[:num_tokens] + return tokenizer.decode(token_ids) + + +def make_prompts(tokenizer, num_requests: int, prompt_tokens: int, variance: float = 0.2) -> list[str]: + """Generate a list of prompts with some length variance to simulate realistic traffic.""" + prompts = [] + for _ in range(num_requests): + # Vary prompt length by ±variance around the target + length = max(10, int(prompt_tokens * (1.0 + random.uniform(-variance, variance)))) + prompts.append(make_prompt(tokenizer, length)) + return prompts + + +# --------------------------------------------------------------------------- +# Request sender +# --------------------------------------------------------------------------- + + +async def send_request( + session: aiohttp.ClientSession, + base_url: str, + prompt: str, + max_new_tokens: int, + seed: int, + endpoint: str = "responses", +) -> dict: + """Send a single streaming request and collect timing metrics.""" + gen_cfg = {"max_new_tokens": max_new_tokens, "do_sample": False} + + if endpoint == "responses": + url = f"{base_url}/v1/responses" + payload = { + "input": [{"role": "user", "content": prompt}], + "stream": True, + "seed": seed, + "generation_config": json.dumps(gen_cfg), + } + else: + url = f"{base_url}/v1/chat/completions" + payload = { + "messages": [{"role": "user", "content": prompt}], + "stream": True, + "seed": seed, + "generation_config": json.dumps(gen_cfg), + } + + t_start = time.perf_counter() + t_first_token = None + token_times = [] + text_chunks = [] + error = None + + try: + async with session.post(url, json=payload, timeout=aiohttp.ClientTimeout(total=300)) as resp: + if resp.status != 200: + error = f"HTTP {resp.status}" + return _make_result(t_start, error=error) + + async for line in resp.content: + line = line.decode("utf-8").strip() + if not line or not line.startswith("data: "): + continue + data_str = line[len("data: "):] + if data_str.strip() == "[DONE]": + break + try: + chunk = json.loads(data_str) + except json.JSONDecodeError: + continue + + # Extract token content based on endpoint format + has_content = False + if endpoint == "responses": + if chunk.get("type") == "response.output_text.delta": + delta = chunk.get("delta", "") + if delta: + text_chunks.append(delta) + has_content = True + elif chunk.get("type") == "response.completed": + break + else: + choices = chunk.get("choices", []) + if choices: + content = choices[0].get("delta", {}).get("content") + if content is not None and content != "": + text_chunks.append(content) + has_content = True + if choices[0].get("finish_reason") is not None: + break + + if has_content: + now = time.perf_counter() + token_times.append(now) + if t_first_token is None: + t_first_token = now + + except asyncio.TimeoutError: + error = "timeout" + except Exception as e: + error = str(e) + + return _make_result(t_start, t_first_token, token_times, text_chunks, error) + + +def _make_result(t_start, t_first_token=None, token_times=None, text_chunks=None, error=None): + t_end = time.perf_counter() + token_times = token_times or [] + text_chunks = text_chunks or [] + + # Inter-token latencies + itl = [] + for i in range(1, len(token_times)): + itl.append(token_times[i] - token_times[i - 1]) + + return { + "e2e": t_end - t_start, + "ttft": (t_first_token - t_start) if t_first_token else None, + "tpot": statistics.mean(itl) if itl else None, # time per output token + "itl": itl, + 
"output_tokens": len(text_chunks), + "text": "".join(text_chunks), + "error": error, + } + + +# --------------------------------------------------------------------------- +# Load generators +# --------------------------------------------------------------------------- + + +async def run_concurrency_test( + base_url: str, + prompts: list[str], + max_new_tokens: int, + max_concurrency: int, + seed: int, + endpoint: str, +) -> list[dict]: + """Send all requests with a concurrency limit via semaphore.""" + semaphore = asyncio.Semaphore(max_concurrency) + results = [] + + async def _limited(session, prompt): + async with semaphore: + return await send_request(session, base_url, prompt, max_new_tokens, seed, endpoint) + + async with aiohttp.ClientSession() as session: + tasks = [_limited(session, p) for p in prompts] + results = await asyncio.gather(*tasks) + + return list(results) + + +async def run_rate_test( + base_url: str, + prompts: list[str], + max_new_tokens: int, + request_rate: float, + seed: int, + endpoint: str, +) -> list[dict]: + """Send requests at a target rate using Poisson inter-arrival times.""" + results = [] + tasks = [] + + async with aiohttp.ClientSession() as session: + for i, prompt in enumerate(prompts): + task = asyncio.create_task( + send_request(session, base_url, prompt, max_new_tokens, seed, endpoint) + ) + tasks.append(task) + + # Poisson inter-arrival: exponential delay + if i < len(prompts) - 1: + delay = random.expovariate(request_rate) + await asyncio.sleep(delay) + + results = await asyncio.gather(*tasks) + + return list(results) + + +# --------------------------------------------------------------------------- +# Metrics +# --------------------------------------------------------------------------- + + +def compute_metrics(results: list[dict], duration: float) -> dict: + """Compute aggregate metrics from individual request results.""" + successful = [r for r in results if r["error"] is None] + failed = [r for r in results if r["error"] is not None] + + if not successful: + return {"error": "all requests failed", "failures": len(failed)} + + total_output_tokens = sum(r["output_tokens"] for r in successful) + + e2e_latencies = [r["e2e"] for r in successful] + ttfts = [r["ttft"] for r in successful if r["ttft"] is not None] + tpots = [r["tpot"] for r in successful if r["tpot"] is not None] + + # Flatten all inter-token latencies + all_itl = [] + for r in successful: + all_itl.extend(r["itl"]) + + def percentiles(values): + if not values: + return {} + values = sorted(values) + n = len(values) + return { + "mean": statistics.mean(values), + "median": statistics.median(values), + "p90": values[int(n * 0.9)], + "p99": values[min(int(n * 0.99), n - 1)], + "min": values[0], + "max": values[-1], + } + + return { + "total_requests": len(results), + "successful": len(successful), + "failed": len(failed), + "duration": duration, + "total_output_tokens": total_output_tokens, + "throughput_req_per_sec": len(successful) / duration, + "throughput_tok_per_sec": total_output_tokens / duration, + "e2e_latency": percentiles(e2e_latencies), + "ttft": percentiles(ttfts), + "tpot": percentiles(tpots), + "itl": percentiles(all_itl), + } + + +# --------------------------------------------------------------------------- +# Output +# --------------------------------------------------------------------------- + + +def format_ms(seconds): + if seconds is None: + return "N/A" + return f"{seconds * 1000:.1f}ms" + + +def print_metrics(metrics: dict, label: str): + print(f"\n{'=' * 70}") + 
print(f" {label}") + print(f"{'=' * 70}") + + if "error" in metrics: + print(f" ERROR: {metrics['error']}") + return + + print(f" Requests: {metrics['successful']} ok / {metrics['failed']} failed / {metrics['total_requests']} total") + print(f" Duration: {metrics['duration']:.1f}s") + print(f" Throughput: {metrics['throughput_req_per_sec']:.2f} req/s, {metrics['throughput_tok_per_sec']:.1f} tok/s") + print(f" Tokens: {metrics['total_output_tokens']} total output") + print() + + headers = ["metric", "mean", "median", "p90", "p99", "min", "max"] + rows = [] + for name in ["e2e_latency", "ttft", "tpot", "itl"]: + p = metrics.get(name, {}) + if not p: + continue + rows.append([ + name.upper().replace("_", " "), + format_ms(p.get("mean")), + format_ms(p.get("median")), + format_ms(p.get("p90")), + format_ms(p.get("p99")), + format_ms(p.get("min")), + format_ms(p.get("max")), + ]) + + if rows: + widths = [max(len(h), *(len(r[i]) for r in rows)) for i, h in enumerate(headers)] + fmt = " " + " | ".join(f"{{:<{w}}}" for w in widths) + sep = " " + "-+-".join("-" * w for w in widths) + print(fmt.format(*headers)) + print(sep) + for row in rows: + print(fmt.format(*row)) + print() + + +# --------------------------------------------------------------------------- +# Server management +# --------------------------------------------------------------------------- + + +def wait_for_server(base_url: str, timeout: int = 120) -> bool: + import requests + + deadline = time.time() + timeout + while time.time() < deadline: + try: + if requests.get(f"{base_url}/health", timeout=2).status_code == 200: + return True + except Exception: + pass + time.sleep(1) + return False + + +def start_server(model: str, port: int, compile: bool = False): + from transformers.cli.serve_refactored import Serve + + kwargs = {"force_model": model, "port": port, "non_blocking": True, "log_level": "warning"} + if compile: + kwargs["compile"] = True + return Serve(**kwargs) + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +async def async_main(args): + base_url = args.url if args.url else f"http://localhost:{args.port}" + server = None + + if not args.url: + print(f"Starting server for {args.model}...") + server = start_server(args.model, args.port, compile=args.compile) + if not wait_for_server(base_url): + print("ERROR: Server did not start") + if server: + server.kill_server() + return + print("Server ready.") + + tokenizer_id = args.processor or args.model + tokenizer = AutoTokenizer.from_pretrained(tokenizer_id) + + # Generate prompts + prompts = make_prompts(tokenizer, args.num_requests, args.prompt_tokens, variance=args.prompt_variance) + print(f"Generated {len(prompts)} prompts (~{args.prompt_tokens} tokens each, ±{int(args.prompt_variance*100)}%)") + print(f"Max new tokens per request: {args.max_new_tokens}") + print(f"Endpoint: /v1/{args.endpoint}") + + # Warmup — use the longest prompt so compilation covers all shorter sizes + warmup_prompt = max(prompts, key=len) + print(f"Warming up ({args.warmup} requests, longest prompt)...") + async with aiohttp.ClientSession() as session: + for i in range(args.warmup): + await send_request(session, base_url, warmup_prompt, args.max_new_tokens, args.seed, args.endpoint) + print("Warmup done.") + + # Run tests + if args.request_rate: + # Rate-based test + label = f"rate={args.request_rate} req/s, {args.num_requests} requests" + print(f"\nRunning: {label}") + t0 = 
time.perf_counter() + results = await run_rate_test( + base_url, prompts, args.max_new_tokens, args.request_rate, args.seed, args.endpoint, + ) + duration = time.perf_counter() - t0 + metrics = compute_metrics(results, duration) + print_metrics(metrics, label) + else: + # Concurrency sweep + for concurrency in args.max_concurrency: + label = f"concurrency={concurrency}, {args.num_requests} requests" + print(f"\nRunning: {label}") + t0 = time.perf_counter() + results = await run_concurrency_test( + base_url, prompts, args.max_new_tokens, concurrency, args.seed, args.endpoint, + ) + duration = time.perf_counter() - t0 + metrics = compute_metrics(results, duration) + print_metrics(metrics, label) + + if server: + server.kill_server() + + +def main(): + parser = argparse.ArgumentParser( + description="Load test for transformers serve", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + parser.add_argument("--model", type=str, default="Qwen/Qwen2.5-7B-Instruct") + parser.add_argument("--processor", type=str, default=None) + parser.add_argument("--url", type=str, default=None, help="Existing server URL (skip start/stop)") + parser.add_argument("--port", type=int, default=8642) + parser.add_argument("--compile", action="store_true", help="Enable --compile on the server") + parser.add_argument("--endpoint", type=str, choices=["chat", "responses"], default="responses") + + # Load parameters + parser.add_argument("--max-concurrency", type=int, nargs="+", default=[1, 2, 4], + help="Concurrency levels to sweep (default: 1 2 4)") + parser.add_argument("--request-rate", type=float, default=None, + help="Target request rate (req/s). Uses Poisson arrivals. Overrides --max-concurrency.") + parser.add_argument("--num-requests", type=int, default=16, help="Total requests per test (default: 16)") + + # Prompt parameters + parser.add_argument("--prompt-tokens", type=int, default=256, help="Target prompt length in tokens (default: 256)") + parser.add_argument("--prompt-variance", type=float, default=0.2, + help="Prompt length variance as fraction (default: 0.2 = ±20%%)") + parser.add_argument("--max-new-tokens", type=int, default=128, help="Max tokens to generate per request (default: 128)") + + parser.add_argument("--warmup", type=int, default=2, help="Warmup requests (default: 2)") + parser.add_argument("--seed", type=int, default=42) + args = parser.parse_args() + + asyncio.run(async_main(args)) + + +if __name__ == "__main__": + main() From 76a5c836638471f235c6669b89b3ae84307587d5 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Tue, 24 Mar 2026 14:57:41 +0000 Subject: [PATCH 22/64] fix vl test --- src/transformers/cli/serving/response.py | 1 + tests/cli/test_serve_refactored.py | 40 +++++++++++------------- 2 files changed, 19 insertions(+), 22 deletions(-) diff --git a/src/transformers/cli/serving/response.py b/src/transformers/cli/serving/response.py index 26e213268f92..96313a4e3294 100644 --- a/src/transformers/cli/serving/response.py +++ b/src/transformers/cli/serving/response.py @@ -97,6 +97,7 @@ def handle_request(self, body: dict, request_id: str) -> StreamingResponse | JSO tools=body.get("tools"), return_tensors="pt", return_dict=True, + tokenize=True, ).to(model.device) gen_config = self._build_generation_config(body, model.generation_config, processor) diff --git a/tests/cli/test_serve_refactored.py b/tests/cli/test_serve_refactored.py index 9840f15410dd..3b04dafaa6a7 100644 --- a/tests/cli/test_serve_refactored.py +++ b/tests/cli/test_serve_refactored.py @@ -1634,16 +1634,8 @@ def 
_make_test_wav(duration: float = 2.0, sample_rate: int = 16000) -> bytes: # --------------------------------------------------------------------------- -def _make_test_image_base64() -> str: - """Create a small red 64x64 PNG as a base64 data URL.""" - import base64 - - from PIL import Image - - img = Image.new("RGB", (64, 64), color="red") - buf = io.BytesIO() - img.save(buf, format="PNG") - return f"data:image/png;base64,{base64.b64encode(buf.getvalue()).decode()}" +# Real image URL for VLM tests (person + dog on a beach) +_DOG_IMAGE_URL = "https://qianwen-res.oss-accelerate-overseas.aliyuncs.com/Qwen2-VL/demo_small.jpg" @unittest.skipUnless(run_slow and is_vision_available(), "Set RUN_SLOW=1 and install torchvision + PIL") @@ -1675,45 +1667,49 @@ def tearDownClass(cls): cls.serve.kill_server() def test_chat_completion_with_image(self): - """Chat completions should accept image_url content and produce a response.""" - image_url = _make_test_image_base64() + """Chat completions should accept image_url content and produce a meaningful response.""" resp = self.client.chat.completions.create( model=self.MODEL, messages=[ { "role": "user", "content": [ - {"type": "text", "text": "What color is this image?"}, - {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "text", "text": "What do you see in this image?"}, + {"type": "image_url", "image_url": {"url": _DOG_IMAGE_URL}}, ], } ], - max_tokens=20, + max_tokens=50, ) text = resp.choices[0].message.content self.assertIsNotNone(text) - self.assertIn("red", text.lower(), f"Expected 'red' in response, got: {text}") + self.assertTrue( + any(word in text.lower() for word in ["dog", "beach", "person"]), + f"Expected dog/beach/person in response, got: {text}", + ) def test_responses_with_image(self): - """Responses API should accept image_url content and produce a response about the image.""" - image_url = _make_test_image_base64() + """Responses API should accept image_url content and produce a meaningful response.""" resp = self.client.responses.create( model=self.MODEL, input=[ { "role": "user", "content": [ - {"type": "text", "text": "What color is this image?"}, - {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "text", "text": "What do you see in this image?"}, + {"type": "image_url", "image_url": {"url": _DOG_IMAGE_URL}}, ], } ], stream=False, - max_output_tokens=20, + max_output_tokens=50, ) self.assertEqual(resp.status, "completed") text = resp.output[0].content[0].text - self.assertIn("red", text.lower(), f"Expected 'red' in response, got: {text}") + self.assertTrue( + any(word in text.lower() for word in ["dog", "beach", "person"]), + f"Expected dog/beach/person in response, got: {text}", + ) @unittest.skipUnless(run_slow, "Set RUN_SLOW=1 to run integration tests") From 31e59c3560893ccaa9d0c9889b6279f4c01fdc52 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Thu, 26 Mar 2026 17:11:17 +0000 Subject: [PATCH 23/64] first iteration of cb --- src/transformers/cli/serve_refactored.py | 34 +- .../cli/serving/chat_completion.py | 160 ++++--- src/transformers/cli/serving/response.py | 405 ++++++++++-------- src/transformers/cli/serving/server.py | 7 +- src/transformers/cli/serving/transcription.py | 40 +- src/transformers/cli/serving/utils.py | 396 ++++++++++++++--- 6 files changed, 699 insertions(+), 343 deletions(-) diff --git a/src/transformers/cli/serve_refactored.py b/src/transformers/cli/serve_refactored.py index 6338c4bbe73b..725fa41f7fab 100644 --- a/src/transformers/cli/serve_refactored.py +++ 
b/src/transformers/cli/serve_refactored.py @@ -75,6 +75,12 @@ def __init__( help="Enable static cache + torch.compile for faster decode (~2.6x). First request triggers compilation (~30s)." ), ] = False, + continuous_batching: Annotated[ + bool, + typer.Option( + help="Enable continuous batching with paged attention for higher throughput on concurrent requests." + ), + ] = False, non_blocking: Annotated[ bool, typer.Option(hidden=True, help="Run server in a background thread. Used by tests.") ] = False, @@ -88,7 +94,8 @@ def __init__( from .serving.model_manager import ModelManager from .serving.response import ResponseHandler from .serving.server import build_server - from .serving.utils import InferenceThread + from .serving.transcription import TranscriptionHandler + from .serving.utils import GenerationState # Seed if default_seed is not None: @@ -113,28 +120,33 @@ def __init__( processor_id=processor, ) self._model_manager = model_manager + self._generation_state = GenerationState(continuous_batching=continuous_batching) - # Single persistent thread for all generate() calls — required for - # torch.compile + CUDA graphs which use thread-local storage. - inference_thread = InferenceThread() - - chat_handler = ChatCompletionHandler( + self._chat_handler = ChatCompletionHandler( model_manager=model_manager, + generation_state=self._generation_state, force_model=force_model, force_processor=processor, - inference_thread=inference_thread, compile=compile, ) - response_handler = ResponseHandler( + self._response_handler = ResponseHandler( model_manager=model_manager, + generation_state=self._generation_state, force_model=force_model, force_processor=processor, - inference_thread=inference_thread, compile=compile, ) - app = build_server(model_manager, chat_handler, response_handler=response_handler, enable_cors=enable_cors) + self._transcription_handler = TranscriptionHandler(model_manager, self._generation_state) + + app = build_server( + model_manager, + self._chat_handler, + response_handler=self._response_handler, + transcription_handler=self._transcription_handler, + enable_cors=enable_cors, + ) config = uvicorn.Config(app, host=host, port=port, log_level=log_level) self.server = uvicorn.Server(config) @@ -158,6 +170,8 @@ def reset_loaded_models(self): self._model_manager.shutdown() def kill_server(self): + self._generation_state.shutdown() + self._model_manager.shutdown() if not self._thread or not self._thread.is_alive(): return self.server.should_exit = True diff --git a/src/transformers/cli/serving/chat_completion.py b/src/transformers/cli/serving/chat_completion.py index e2b0cc40eb1e..ae57207224a8 100644 --- a/src/transformers/cli/serving/chat_completion.py +++ b/src/transformers/cli/serving/chat_completion.py @@ -19,6 +19,7 @@ from __future__ import annotations +import asyncio import time from collections.abc import AsyncGenerator from typing import TYPE_CHECKING @@ -42,6 +43,7 @@ from ...utils import logging from .utils import ( UNUSED_CHAT_COMPLETION_FIELDS, + BaseGenerateManager, BaseHandler, ToolCallParser, TransformersCompletionCreateParamsStreaming, @@ -62,7 +64,7 @@ class ChatCompletionHandler(BaseHandler): # ----- entry point ----- - def handle_request(self, body: dict, request_id: str) -> StreamingResponse | JSONResponse: + async def handle_request(self, body: dict, request_id: str) -> StreamingResponse | JSONResponse: """Validate the request, load the model, and dispatch to streaming or non-streaming.""" self._validate_request(body) @@ -73,20 +75,35 @@ def 
handle_request(self, body: dict, request_id: str) -> StreamingResponse | JSO return JSONResponse({}, status_code=200) model_id, model, processor = self._resolve_model(body) - modality = self.model_manager.get_model_modality(model, processor=processor) + use_cb = self.generation_state.use_continuous_batching(model, modality) + gen_manager = self.generation_state.get_manager(model_id, use_cb=use_cb) processor_inputs = get_processor_inputs_from_messages(messages, modality) - inputs = processor.apply_chat_template( - processor_inputs, - add_generation_prompt=True, - tools=body.get("tools"), - return_tensors="pt", - return_dict=True, - tokenize=True, - ).to(model.device) - - gen_config = self._build_generation_config(body, model.generation_config, processor) + if use_cb: + # CB handles device placement internally — don't create tensors or move + # anything to CUDA here. Pass plain token ID lists only. + inputs = processor.apply_chat_template( + processor_inputs, + add_generation_prompt=True, + tools=body.get("tools"), + return_dict=True, + tokenize=True, + ) + else: + inputs = processor.apply_chat_template( + processor_inputs, + add_generation_prompt=True, + tools=body.get("tools"), + return_tensors="pt", + return_dict=True, + tokenize=True, + ).to(model.device) + + gen_config = self._build_generation_config(body, model.generation_config, processor, use_cb=use_cb) + # TODO: remove when CB supports per-request generation config + if use_cb: + gen_manager.init_cb(model, gen_config) # Detect tool support for the loaded model # TODO: after tool_call start token, use constrained generation to: @@ -102,15 +119,17 @@ def handle_request(self, body: dict, request_id: str) -> StreamingResponse | JSO model_id, inputs, gen_config, + gen_manager=gen_manager, tool_format=tool_format, ) - return self._non_streaming( + return await self._non_streaming( request_id, model, processor, model_id, inputs, gen_config, + gen_manager=gen_manager, tool_format=tool_format, ) @@ -124,68 +143,76 @@ def _streaming( model_id: str, inputs: dict[str, torch.Tensor], gen_config: GenerationConfig, + gen_manager: BaseGenerateManager, tool_format: dict | None = None, ) -> StreamingResponse: """Stream tokens as SSE via DirectStreamer.""" - queue, streamer = self._start_streaming(model, processor, inputs, gen_config) - input_len = inputs["input_ids"].shape[-1] + queue, streamer = gen_manager.generate_streaming(model, processor, inputs, gen_config, request_id=request_id) + input_ids = inputs["input_ids"] + input_len = len(input_ids) if isinstance(input_ids, list) else input_ids.shape[-1] parser = ToolCallParser(tool_format) if tool_format else None async def sse_gen() -> AsyncGenerator[str, None]: has_tool_calls = False - yield self._build_chunk_sse(request_id, role="assistant", model=model_id) - - while True: - text = await queue.get() - if text is None: - break - elif isinstance(text, _StreamError): - yield f'data: {{"error": "{text.msg}"}}\n\n' - return - - # Tool call parsing: None = normal text, CONSUMED = buffering, else = tool call dict - chunk_kwargs = {"content": text} - if parser is not None and (result := parser.feed(text)) is not None: - if result is ToolCallParser.CONSUMED: - continue - has_tool_calls = True - chunk_kwargs = { - "tool_calls": [ - ChoiceDeltaToolCall( - index=0, - type="function", - id=f"{request_id}_tool_call", - function={"name": result["name"], "arguments": result["arguments"]}, - ) - ] - } - - yield self._build_chunk_sse(request_id, model=model_id, **chunk_kwargs) - - hit_max = 
gen_config.max_new_tokens is not None and streamer.total_tokens >= gen_config.max_new_tokens - if has_tool_calls: - finish_reason = "tool_calls" - elif hit_max: - finish_reason = "length" - else: - finish_reason = "stop" - usage = CompletionUsage( - prompt_tokens=input_len, - completion_tokens=streamer.total_tokens, - total_tokens=input_len + streamer.total_tokens, - ) - yield self._build_chunk_sse( - request_id, - finish_reason=finish_reason, - model=model_id, - usage=usage, - ) + try: + yield self._build_chunk_sse(request_id, role="assistant", model=model_id) + + while True: + text = await queue.get() + if text is None: + break + elif isinstance(text, _StreamError): + yield f'data: {{"error": "{text.msg}"}}\n\n' + return + + # Tool call parsing: None = normal text, CONSUMED = buffering, else = tool call dict + chunk_kwargs = {"content": text} + if parser is not None and (result := parser.feed(text)) is not None: + if result is ToolCallParser.CONSUMED: + continue + has_tool_calls = True + chunk_kwargs = { + "tool_calls": [ + ChoiceDeltaToolCall( + index=0, + type="function", + id=f"{request_id}_tool_call", + function={"name": result["name"], "arguments": result["arguments"]}, + ) + ] + } + + yield self._build_chunk_sse(request_id, model=model_id, **chunk_kwargs) + + hit_max = gen_config.max_new_tokens is not None and streamer.total_tokens >= gen_config.max_new_tokens + if has_tool_calls: + finish_reason = "tool_calls" + elif hit_max: + finish_reason = "length" + else: + finish_reason = "stop" + usage = CompletionUsage( + prompt_tokens=input_len, + completion_tokens=streamer.total_tokens, + total_tokens=input_len + streamer.total_tokens, + ) + yield self._build_chunk_sse( + request_id, + finish_reason=finish_reason, + model=model_id, + usage=usage, + ) + except (GeneratorExit, asyncio.CancelledError): + # Client disconnected — abort generation to free GPU. + # Re-raise is mandatory: Python raises RuntimeError if GeneratorExit is swallowed. 
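A standalone sketch of the behaviour that comment refers to (not part of the handler): a generator that keeps yielding after GeneratorExit is delivered makes close() fail, which is why the SSE generator cancels the streamer and then re-raises instead of swallowing the exception:

    def leaky():
        try:
            yield 1
        except GeneratorExit:
            yield 2  # ignoring the close request

    g = leaky()
    next(g)    # advance to the first yield
    g.close()  # RuntimeError: generator ignored GeneratorExit
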
+ streamer.cancel() + raise return StreamingResponse(sse_gen(), media_type="text/event-stream") # ----- non-streaming ----- - def _non_streaming( + async def _non_streaming( self, request_id: str, model: PreTrainedModel, @@ -193,10 +220,11 @@ def _non_streaming( model_id: str, inputs: dict[str, torch.Tensor], gen_config: GenerationConfig, + gen_manager: BaseGenerateManager, tool_format: dict | None = None, ) -> JSONResponse: """Run generation and return a JSONResponse.""" - content, input_len, generated_ids = self._generate_non_streaming(model, processor, inputs, gen_config) + content, input_len, generated_ids = await gen_manager.generate_non_streaming(model, processor, inputs, gen_config, request_id=request_id) hit_max = gen_config.max_new_tokens is not None and len(generated_ids) >= gen_config.max_new_tokens completion_tokens = len(generated_ids) @@ -252,9 +280,9 @@ def _validate_request(self, body: dict) -> None: if unused: raise HTTPException(status_code=422, detail=f"Unsupported fields in the request: {unused}") - def _build_generation_config(self, body: dict, model_generation_config, processor=None): + def _build_generation_config(self, body: dict, model_generation_config, processor=None, use_cb=False): """Chat Completions params on top of base config.""" - generation_config = super()._build_generation_config(body, model_generation_config, processor) + generation_config = super()._build_generation_config(body, model_generation_config, processor, use_cb=use_cb) if body.get("max_tokens") is not None: generation_config.max_new_tokens = int(body["max_tokens"]) diff --git a/src/transformers/cli/serving/response.py b/src/transformers/cli/serving/response.py index 96313a4e3294..da626281d960 100644 --- a/src/transformers/cli/serving/response.py +++ b/src/transformers/cli/serving/response.py @@ -19,6 +19,7 @@ from __future__ import annotations +import asyncio import time from collections.abc import AsyncGenerator from typing import TYPE_CHECKING @@ -47,7 +48,14 @@ from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails, ResponseUsage from ...utils import logging -from .utils import BaseHandler, ToolCallParser, _StreamError, detect_tool_format, get_processor_inputs_from_messages +from .utils import ( + BaseGenerateManager, + BaseHandler, + ToolCallParser, + _StreamError, + detect_tool_format, + get_processor_inputs_from_messages, +) if TYPE_CHECKING: @@ -78,29 +86,45 @@ class ResponseHandler(BaseHandler): # ----- entry point ----- - def handle_request(self, body: dict, request_id: str) -> StreamingResponse | JSONResponse: + async def handle_request(self, body: dict, request_id: str) -> StreamingResponse | JSONResponse: """Validate, load model, dispatch to streaming or non-streaming.""" self._validate_request(body) model_id, model, processor = self._resolve_model(body) + modality = self.model_manager.get_model_modality(model, processor=processor) + use_cb = self.generation_state.use_continuous_batching(model, modality) + gen_manager = self.generation_state.get_manager(model_id, use_cb=use_cb) # Two-step input conversion (chat completions skips step 1 since messages are already standard): # 1. Normalize Responses API input (string/list/dict + instructions) → standard messages list # 2. Transform message content for the HF processor (VLM image handling, text joining, etc.) 
messages = self._input_to_messages(body) - modality = self.model_manager.get_model_modality(model, processor=processor) processor_inputs = get_processor_inputs_from_messages(messages, modality) - inputs = processor.apply_chat_template( - processor_inputs, - add_generation_prompt=True, - tools=body.get("tools"), - return_tensors="pt", - return_dict=True, - tokenize=True, - ).to(model.device) - - gen_config = self._build_generation_config(body, model.generation_config, processor) + if use_cb: + # CB handles device placement internally — don't create tensors or move + # anything to CUDA here. Pass plain token ID lists only. + inputs = processor.apply_chat_template( + processor_inputs, + add_generation_prompt=True, + tools=body.get("tools"), + return_dict=True, + tokenize=True, + ) + else: + inputs = processor.apply_chat_template( + processor_inputs, + add_generation_prompt=True, + tools=body.get("tools"), + return_tensors="pt", + return_dict=True, + tokenize=True, + ).to(model.device) + + gen_config = self._build_generation_config(body, model.generation_config, processor, use_cb=use_cb) + # TODO: remove when CB supports per-request generation config + if use_cb: + gen_manager.init_cb(model, gen_config) tool_format = detect_tool_format(model) if body.get("tools") else None if body.get("stream", True): @@ -112,9 +136,10 @@ def handle_request(self, body: dict, request_id: str) -> StreamingResponse | JSO body, inputs, gen_config, + gen_manager=gen_manager, tool_format=tool_format, ) - return self._non_streaming( + return await self._non_streaming( request_id, model, processor, @@ -122,6 +147,7 @@ def handle_request(self, body: dict, request_id: str) -> StreamingResponse | JSO body, inputs, gen_config, + gen_manager=gen_manager, tool_format=tool_format, ) @@ -164,11 +190,13 @@ def _streaming( body: dict, inputs: dict, gen_config: GenerationConfig, + gen_manager: BaseGenerateManager, tool_format: dict | None = None, ) -> StreamingResponse: """Generate a streaming Responses API reply (SSE) using DirectStreamer.""" - queue, streamer = self._start_streaming(model, processor, inputs, gen_config) - input_len = inputs["input_ids"].shape[-1] + queue, streamer = gen_manager.generate_streaming(model, processor, inputs, gen_config, request_id=request_id) + input_ids = inputs["input_ids"] + input_len = len(input_ids) if isinstance(input_ids, list) else input_ids.shape[-1] parser = ToolCallParser(tool_format) if tool_format else None seq = 0 @@ -191,204 +219,212 @@ def _streaming( async def event_stream() -> AsyncGenerator[str, None]: nonlocal seq, output_index - # 1. Created + In progress - yield self.chunk_to_sse( - ResponseCreatedEvent( - type="response.created", - sequence_number=seq, - response=Response(**response_base, status="queued", output=[]), - ) - ) - seq += 1 - yield self.chunk_to_sse( - ResponseInProgressEvent( - type="response.in_progress", - sequence_number=seq, - response=Response(**response_base, status="in_progress", output=[]), + try: + # 1. Created + In progress + yield self.chunk_to_sse( + ResponseCreatedEvent( + type="response.created", + sequence_number=seq, + response=Response(**response_base, status="queued", output=[]), + ) ) - ) - seq += 1 - - # 2. 
Output item added (message) - yield self.chunk_to_sse( - ResponseOutputItemAddedEvent( - type="response.output_item.added", - sequence_number=seq, - output_index=output_index, - item=ResponseOutputMessage( - id=msg_id, - type="message", - status="in_progress", - role="assistant", - content=[], - ), + seq += 1 + yield self.chunk_to_sse( + ResponseInProgressEvent( + type="response.in_progress", + sequence_number=seq, + response=Response(**response_base, status="in_progress", output=[]), + ) ) - ) - seq += 1 - - # 3. Content part added - yield self.chunk_to_sse( - ResponseContentPartAddedEvent( - type="response.content_part.added", - item_id=msg_id, - sequence_number=seq, - output_index=output_index, - content_index=0, - part=ResponseOutputText(type="output_text", text="", annotations=[]), + seq += 1 + + # 2. Output item added (message) + yield self.chunk_to_sse( + ResponseOutputItemAddedEvent( + type="response.output_item.added", + sequence_number=seq, + output_index=output_index, + item=ResponseOutputMessage( + id=msg_id, + type="message", + status="in_progress", + role="assistant", + content=[], + ), + ) ) - ) - seq += 1 - - # 4. Stream tokens - full_text = "" - tool_calls = [] - - while True: - text = await queue.get() - if text is None: - break - if isinstance(text, _StreamError): - logger.error(f"Exception in response generation: {text.msg}") - yield self.chunk_to_sse(ResponseErrorEvent(type="error", sequence_number=seq, message=text.msg)) - seq += 1 - yield self.chunk_to_sse( - ResponseFailedEvent( - type="response.failed", - sequence_number=seq, - response=Response( - **response_base, - status="failed", - output=[], - error=ResponseError(code="server_error", message=text.msg), - ), - ) + seq += 1 + + # 3. Content part added + yield self.chunk_to_sse( + ResponseContentPartAddedEvent( + type="response.content_part.added", + item_id=msg_id, + sequence_number=seq, + output_index=output_index, + content_index=0, + part=ResponseOutputText(type="output_text", text="", annotations=[]), ) - return - - # Tool call parsing - if parser is not None and (result := parser.feed(text)) is not None: - if result is not ToolCallParser.CONSUMED: - # Emit tool call as a function_call output item - tc_id = f"{request_id}_tool_call" - name = result["name"] - arguments = result["arguments"] - tc_item = ResponseFunctionToolCall( - id=tc_id, - call_id=tc_id, - type="function_call", - name=name, - arguments=arguments, - status="completed", - ) - tool_calls.append(tc_item) - # output[0] = message, output[1..N] = tool calls (required by OpenAI SSE spec) - output_index += 1 + ) + seq += 1 + + # 4. 
Stream tokens + full_text = "" + tool_calls = [] + + while True: + text = await queue.get() + if text is None: + break + if isinstance(text, _StreamError): + logger.error(f"Exception in response generation: {text.msg}") yield self.chunk_to_sse( - ResponseOutputItemAddedEvent( - type="response.output_item.added", - sequence_number=seq, - output_index=output_index, - item=tc_item, - ) + ResponseErrorEvent(type="error", sequence_number=seq, message=text.msg) ) seq += 1 yield self.chunk_to_sse( - ResponseFunctionCallArgumentsDoneEvent( - type="response.function_call_arguments.done", + ResponseFailedEvent( + type="response.failed", sequence_number=seq, - item_id=tc_id, - output_index=output_index, - arguments=arguments, - name=name, + response=Response( + **response_base, + status="failed", + output=[], + error=ResponseError(code="server_error", message=text.msg), + ), ) ) - seq += 1 - yield self.chunk_to_sse( - ResponseOutputItemDoneEvent( - type="response.output_item.done", - sequence_number=seq, - output_index=output_index, - item=tc_item, + return + + # Tool call parsing + if parser is not None and (result := parser.feed(text)) is not None: + if result is not ToolCallParser.CONSUMED: + # Emit tool call as a function_call output item + tc_id = f"{request_id}_tool_call" + name = result["name"] + arguments = result["arguments"] + tc_item = ResponseFunctionToolCall( + id=tc_id, + call_id=tc_id, + type="function_call", + name=name, + arguments=arguments, + status="completed", + ) + tool_calls.append(tc_item) + # output[0] = message, output[1..N] = tool calls (required by OpenAI SSE spec) + output_index += 1 + yield self.chunk_to_sse( + ResponseOutputItemAddedEvent( + type="response.output_item.added", + sequence_number=seq, + output_index=output_index, + item=tc_item, + ) + ) + seq += 1 + yield self.chunk_to_sse( + ResponseFunctionCallArgumentsDoneEvent( + type="response.function_call_arguments.done", + sequence_number=seq, + item_id=tc_id, + output_index=output_index, + arguments=arguments, + name=name, + ) ) + seq += 1 + yield self.chunk_to_sse( + ResponseOutputItemDoneEvent( + type="response.output_item.done", + sequence_number=seq, + output_index=output_index, + item=tc_item, + ) + ) + seq += 1 + continue + + full_text += text + yield self.chunk_to_sse( + ResponseTextDeltaEvent( + type="response.output_text.delta", + item_id=msg_id, + sequence_number=seq, + output_index=0, + content_index=0, + delta=text, + logprobs=[], ) - seq += 1 - continue + ) + seq += 1 - full_text += text + # 5. Close text output + output_text_part = ResponseOutputText(type="output_text", text=full_text, annotations=[]) yield self.chunk_to_sse( - ResponseTextDeltaEvent( - type="response.output_text.delta", + ResponseTextDoneEvent( + type="response.output_text.done", item_id=msg_id, sequence_number=seq, output_index=0, content_index=0, - delta=text, + text=full_text, logprobs=[], ) ) seq += 1 + yield self.chunk_to_sse( + ResponseContentPartDoneEvent( + type="response.content_part.done", + item_id=msg_id, + sequence_number=seq, + output_index=0, + content_index=0, + part=output_text_part, + ) + ) + seq += 1 - # 5. 
Close text output - output_text_part = ResponseOutputText(type="output_text", text=full_text, annotations=[]) - yield self.chunk_to_sse( - ResponseTextDoneEvent( - type="response.output_text.done", - item_id=msg_id, - sequence_number=seq, - output_index=0, - content_index=0, - text=full_text, - logprobs=[], + msg_item = ResponseOutputMessage( + id=msg_id, + type="message", + status="completed", + role="assistant", + content=[output_text_part], + annotations=[], ) - ) - seq += 1 - yield self.chunk_to_sse( - ResponseContentPartDoneEvent( - type="response.content_part.done", - item_id=msg_id, - sequence_number=seq, - output_index=0, - content_index=0, - part=output_text_part, + yield self.chunk_to_sse( + ResponseOutputItemDoneEvent( + type="response.output_item.done", + sequence_number=seq, + output_index=0, + item=msg_item, + ) ) - ) - seq += 1 + seq += 1 - msg_item = ResponseOutputMessage( - id=msg_id, - type="message", - status="completed", - role="assistant", - content=[output_text_part], - annotations=[], - ) - yield self.chunk_to_sse( - ResponseOutputItemDoneEvent( - type="response.output_item.done", - sequence_number=seq, - output_index=0, - item=msg_item, - ) - ) - seq += 1 - - # 6. Completed - all_output = [msg_item] + list(tool_calls) - usage = _make_usage(input_len, streamer.total_tokens) - yield self.chunk_to_sse( - ResponseCompletedEvent( - type="response.completed", - sequence_number=seq, - response=Response(**response_base, status="completed", output=all_output, usage=usage), + # 6. Completed + all_output = [msg_item] + list(tool_calls) + usage = _make_usage(input_len, streamer.total_tokens) + yield self.chunk_to_sse( + ResponseCompletedEvent( + type="response.completed", + sequence_number=seq, + response=Response(**response_base, status="completed", output=all_output, usage=usage), + ) ) - ) - seq += 1 + seq += 1 + except (GeneratorExit, asyncio.CancelledError): + # Client disconnected — abort generation to free GPU. + # Re-raise is mandatory: Python raises RuntimeError if GeneratorExit is swallowed. 
+ streamer.cancel() + raise return StreamingResponse(event_stream(), media_type="text/event-stream") # ----- non-streaming ----- - def _non_streaming( + async def _non_streaming( self, request_id: str, model: PreTrainedModel, @@ -397,10 +433,11 @@ def _non_streaming( body: dict, inputs: dict, gen_config: GenerationConfig, + gen_manager: BaseGenerateManager, tool_format: dict | None = None, ) -> JSONResponse: """Generate a non-streaming Responses API reply (single JSON).""" - full_text, input_len, generated_ids = self._generate_non_streaming(model, processor, inputs, gen_config) + full_text, input_len, generated_ids = await gen_manager.generate_non_streaming(model, processor, inputs, gen_config, request_id=request_id) created_at = time.time() resp_id = f"resp_{request_id}" @@ -459,9 +496,9 @@ def _validate_request(self, body: dict) -> None: if unused: raise HTTPException(status_code=422, detail=f"Unsupported fields in the request: {unused}") - def _build_generation_config(self, body: dict, model_generation_config, processor=None): + def _build_generation_config(self, body: dict, model_generation_config, processor=None, use_cb=False): """Responses API params on top of base config.""" - generation_config = super()._build_generation_config(body, model_generation_config, processor) + generation_config = super()._build_generation_config(body, model_generation_config, processor, use_cb=use_cb) if body.get("max_output_tokens") is not None: generation_config.max_new_tokens = int(body["max_output_tokens"]) diff --git a/src/transformers/cli/serving/server.py b/src/transformers/cli/serving/server.py index fd4f1245b2cd..7be3bb0c4d61 100644 --- a/src/transformers/cli/serving/server.py +++ b/src/transformers/cli/serving/server.py @@ -39,6 +39,7 @@ def build_server( model_manager: ModelManager, chat_handler: ChatCompletionHandler, response_handler: ResponseHandler, + transcription_handler: TranscriptionHandler, enable_cors: bool = False, ) -> FastAPI: """Build and return a configured FastAPI application. 
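The pieces above compose roughly as follows. This is an illustrative sketch only: the handler and ``build_server`` signatures are taken from this diff, while the exact ``ModelManager`` constructor arguments are an assumption (the real wiring lives in the ``Serve`` entry point):

    # Hypothetical wiring sketch, not the actual Serve.__init__ code.
    from transformers.cli.serving.chat_completion import ChatCompletionHandler
    from transformers.cli.serving.model_manager import ModelManager
    from transformers.cli.serving.response import ResponseHandler
    from transformers.cli.serving.server import build_server
    from transformers.cli.serving.transcription import TranscriptionHandler
    from transformers.cli.serving.utils import GenerationState

    model_manager = ModelManager(force_model=None)  # constructor arguments assumed
    generation_state = GenerationState(continuous_batching=False)

    chat_handler = ChatCompletionHandler(model_manager=model_manager, generation_state=generation_state)
    response_handler = ResponseHandler(model_manager=model_manager, generation_state=generation_state)
    transcription_handler = TranscriptionHandler(model_manager, generation_state)

    app = build_server(model_manager, chat_handler, response_handler, transcription_handler, enable_cors=False)
    # then served with uvicorn, e.g. uvicorn.Server(uvicorn.Config(app, host="localhost", port=8000)).run()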
@@ -85,13 +86,11 @@ async def request_id_middleware(request: Request, call_next): @app.post("/v1/chat/completions") async def chat_completions(request: Request, body: dict): - return chat_handler.handle_request(body, request.state.request_id) + return await chat_handler.handle_request(body, request.state.request_id) @app.post("/v1/responses") async def responses(request: Request, body: dict): - return response_handler.handle_request(body, request.state.request_id) - - transcription_handler = TranscriptionHandler(model_manager) + return await response_handler.handle_request(body, request.state.request_id) @app.post("/v1/audio/transcriptions") async def audio_transcriptions(request: Request): diff --git a/src/transformers/cli/serving/transcription.py b/src/transformers/cli/serving/transcription.py index 134cfa72a721..586e4d10d67f 100644 --- a/src/transformers/cli/serving/transcription.py +++ b/src/transformers/cli/serving/transcription.py @@ -18,14 +18,13 @@ from __future__ import annotations import io -from threading import Thread from typing import TYPE_CHECKING from fastapi.responses import JSONResponse, StreamingResponse from ...utils import logging from .model_manager import ModelManager -from .utils import DirectStreamer, _StreamError +from .utils import DirectStreamer, GenerationState, _StreamError if TYPE_CHECKING: @@ -40,12 +39,14 @@ class TranscriptionHandler: Accepts a multipart/form-data request with an audio file and model name, runs speech-to-text, and returns an OpenAI-compatible Transcription response. - Supports streaming (``stream=true`` in form data) which yields text chunks - as SSE events, or non-streaming which returns a single JSON response. + Standalone (does not extend :class:`BaseHandler`) because audio requests use + multipart form data, not JSON bodies, and don't need generation config or + validation. Shares the :class:`GenerationState` for thread safety. """ - def __init__(self, model_manager: ModelManager): + def __init__(self, model_manager: ModelManager, generation_state: GenerationState): self.model_manager = model_manager + self._generation_state = generation_state async def handle_request(self, request: Request) -> JSONResponse | StreamingResponse: """Parse multipart form, run transcription, return result.""" @@ -72,26 +73,33 @@ async def handle_request(self, request: Request) -> JSONResponse | StreamingResp ) audio_inputs["input_features"] = audio_inputs["input_features"].to(audio_model.dtype) - if stream: - return self._streaming(audio_model, audio_processor, audio_inputs) - return self._non_streaming(audio_model, audio_processor, audio_inputs) + # Transcription uses the per-model InferenceThread (no CB for audio). + gen_manager = self._generation_state.get_manager(model_id_and_revision, use_cb=False) + tokenizer = audio_processor.tokenizer if hasattr(audio_processor, "tokenizer") else audio_processor - def _non_streaming(self, audio_model, audio_processor, audio_inputs) -> JSONResponse: + if stream: + return self._streaming(gen_manager, audio_model, tokenizer, audio_inputs) + return await self._non_streaming(gen_manager, audio_model, audio_processor, audio_inputs) + + async def _non_streaming(self, gen_manager, audio_model, audio_processor, audio_inputs) -> JSONResponse: + # Audio models have different inputs (input_features) and decode (batch_decode) + # than text models, so we use async_submit() directly instead of + # generate_non_streaming(). TODO: add generate_audio_non_streaming() when + # more audio modalities are supported. 
from openai.types.audio import Transcription - generated_ids = audio_model.generate(**audio_inputs) + generated_ids = await gen_manager.async_submit(audio_model.generate, **audio_inputs) text = audio_processor.batch_decode(generated_ids, skip_special_tokens=True)[0] - return JSONResponse(Transcription(text=text).model_dump(exclude_none=True)) - def _streaming(self, audio_model, audio_processor, audio_inputs) -> StreamingResponse: + def _streaming(self, gen_manager, audio_model, tokenizer, audio_inputs) -> StreamingResponse: + # Same as _non_streaming — uses submit() directly because audio inputs + # differ from text. TODO: add generate_audio_streaming() when more audio + # modalities are supported. import asyncio loop = asyncio.get_running_loop() queue: asyncio.Queue = asyncio.Queue() - - # For processors like WhisperProcessor, the fast tokenizer is at .tokenizer - tokenizer = audio_processor.tokenizer if hasattr(audio_processor, "tokenizer") else audio_processor streamer = DirectStreamer(tokenizer._tokenizer, loop, queue, skip_special_tokens=True) gen_kwargs = {**audio_inputs, "streamer": streamer} @@ -101,7 +109,7 @@ def _run(): except Exception as e: loop.call_soon_threadsafe(queue.put_nowait, _StreamError(str(e))) - Thread(target=_run, daemon=True).start() + gen_manager.submit(_run) async def sse_gen(): while True: diff --git a/src/transformers/cli/serving/utils.py b/src/transformers/cli/serving/utils.py index 5d69e0e47858..02d54ba2e825 100644 --- a/src/transformers/cli/serving/utils.py +++ b/src/transformers/cli/serving/utils.py @@ -25,13 +25,18 @@ import re import tempfile import threading +from abc import ABC, abstractmethod from concurrent.futures import Future from io import BytesIO from queue import Queue +from transformers.utils import logging from transformers.utils.import_utils import is_openai_available, is_vision_available +logger = logging.get_logger(__name__) + + if is_vision_available(): from PIL import Image @@ -95,6 +100,12 @@ def __init__(self, msg: str): self.msg = msg +class _GenerationCancelled(Exception): + """Raised inside ``DirectStreamer.put()`` to abort ``model.generate()``.""" + + + + # --------------------------------------------------------------------------- # Tool call parsing # --------------------------------------------------------------------------- @@ -334,17 +345,12 @@ def close(self): class DirectStreamer: - """Streamer that decodes tokens incrementally and pushes text to an asyncio.Queue. - - Uses the Rust ``DecodeStream.step()`` for O(1) per-token decode, unlike - ``TextIteratorStreamer`` which re-decodes the full sequence each time. + """Streamer for ``model.generate()`` (used by :class:`GenerateManager`). - Args: - tokenizer: The raw ``tokenizers.Tokenizer`` instance (i.e. the Rust tokenizer, - typically accessed as ``processor._tokenizer`` or ``processor.tokenizer._tokenizer``). - loop: The asyncio event loop to push results to. - queue: The asyncio.Queue to push decoded text chunks to. - skip_special_tokens: Whether to skip special tokens during decoding. + Implements the ``put``/``end`` protocol that ``model.generate()`` expects: + generate calls ``put(token_tensor)`` after each decode step, and ``end()`` + when generation is complete. Tokens are decoded incrementally via the Rust + ``DecodeStream`` (O(1) per token) and pushed as text to an asyncio.Queue. 
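+
+    A rough usage sketch (the ``generate()`` call runs on the inference thread,
+    while the event-loop side drains the queue until it receives ``None``)::
+
+        streamer = DirectStreamer(processor._tokenizer, loop, queue)
+        model.generate(**inputs, streamer=streamer)  # put() per decode step, end() when done
+        # async side: chunk = await queue.get()      # str chunks, then None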
""" def __init__(self, tokenizer, loop, queue, skip_special_tokens: bool = True): @@ -355,14 +361,17 @@ def __init__(self, tokenizer, loop, queue, skip_special_tokens: bool = True): self._queue = queue self._decode_stream = DecodeStream([], skip_special_tokens) self._first = True + self._cancelled = threading.Event() self.total_tokens = 0 def put(self, value) -> None: + """Called by ``model.generate()`` after each decode step with new token(s).""" + if self._cancelled.is_set(): + raise _GenerationCancelled() + # The first put() contains the prompt tokens — skip since we only stream generated tokens. if self._first: self._first = False - return # skip prompt tokens - if len(value.shape) > 1: - value = value[0] + return for token_id in value.tolist(): self.total_tokens += 1 text = self._decode_stream.step(self._tokenizer, token_id) @@ -370,8 +379,53 @@ def put(self, value) -> None: self._loop.call_soon_threadsafe(self._queue.put_nowait, text) def end(self) -> None: + """Called by ``model.generate()`` when generation is complete.""" self._loop.call_soon_threadsafe(self._queue.put_nowait, None) + def cancel(self) -> None: + """Signal cancellation. The next ``put()`` call will raise and abort ``model.generate()``.""" + self._cancelled.set() + + +class CBStreamer: + """Streamer for continuous batching (used by :class:`CBGenerateManager`). + + Same ``put``/``end`` protocol as :class:`DirectStreamer`, but called manually + by :class:`CBGenerateManager` instead of by ``model.generate()``: + ``put(output)`` receives a CB ``GenerationOutput``, decodes new tokens, and + pushes text to the asyncio.Queue. ``end()`` signals the stream is complete. + """ + + def __init__(self, cb_manager, request_id, tokenizer, loop, queue): + from tokenizers.decoders import DecodeStream + + self._cb = cb_manager + self._request_id = request_id + self._loop = loop + self._queue = queue + self._tokenizer = tokenizer + self._decode_stream = DecodeStream([], True) + self._prev_len = 0 + self.total_tokens = 0 + + def put(self, output) -> None: + """Decode new tokens from a CB ``GenerationOutput`` and push text to the queue.""" + new_tokens = output.generated_tokens[self._prev_len :] + self._prev_len = len(output.generated_tokens) + for token_id in new_tokens: + self.total_tokens += 1 + text = self._decode_stream.step(self._tokenizer, token_id) + if text is not None: + self._queue.put_nowait(text) + + def end(self) -> None: + """Signal end of stream.""" + self._queue.put_nowait(None) + + def cancel(self) -> None: + """Cancel the CB request.""" + self._cb.cancel_request(self._request_id) + # --------------------------------------------------------------------------- # Torch helpers @@ -392,16 +446,10 @@ def reset_torch_cache() -> None: class InferenceThread: - """A single persistent thread that runs all model.generate() calls. - - torch.compile with ``mode="reduce-overhead"`` uses CUDA graphs, which store - state in thread-local storage (TLS). If generate() is called from different - threads (e.g. a new Thread per streaming request), the CUDA graph state is - lost or corrupted — causing silent wrong output or crashes. + """Persistent thread for ``model.generate()`` calls. - This class ensures all inference runs on the **same thread**, matching what - vLLM does with its engine loop. Both streaming and non-streaming requests - submit work here. + ``torch.compile`` with CUDA graphs stores state in thread-local storage. + All inference must run on the same thread to avoid corrupted graph state. 
""" def __init__(self): @@ -411,20 +459,265 @@ def __init__(self): def _run(self): while True: - fn, args, kwargs, future = self._queue.get() + fn, args, kwargs, future, loop = self._queue.get() try: result = fn(*args, **kwargs) - future.set_result(result) + if loop is not None: + loop.call_soon_threadsafe(future.set_result, result) + else: + future.set_result(result) except Exception as e: - future.set_exception(e) + if loop is not None: + loop.call_soon_threadsafe(future.set_exception, e) + else: + future.set_exception(e) def submit(self, fn, *args, **kwargs) -> Future: - """Submit a callable to run on the inference thread. Returns a Future.""" + """Submit a callable to the inference thread. Returns a blocking Future.""" future: Future = Future() - self._queue.put((fn, args, kwargs, future)) + self._queue.put((fn, args, kwargs, future, None)) + return future + + def async_submit(self, fn, *args, **kwargs) -> asyncio.Future: + """Submit a callable to the inference thread. Returns an awaitable asyncio.Future.""" + loop = asyncio.get_running_loop() + future = loop.create_future() + self._queue.put((fn, args, kwargs, future, loop)) return future +# --------------------------------------------------------------------------- +# Generation managers +# --------------------------------------------------------------------------- + + +class BaseGenerateManager(ABC): + """Base class for generation managers. + + Subclasses: + - :class:`GenerateManager` — sequential ``model.generate()`` on a persistent thread. + - :class:`CBGenerateManager` — continuous batching with paged attention. + """ + + @abstractmethod + def generate_streaming(self, model, processor, inputs, gen_config, request_id=None): + """Start streaming generation. + + Returns ``(queue, context)`` where *queue* yields ``str | _StreamError | None`` + and *context* exposes ``.total_tokens`` and ``.cancel()``. + """ + + @abstractmethod + def generate_non_streaming(self, model, processor, inputs, gen_config, request_id=None): + """Run generation to completion. Returns ``(text, input_len, generated_ids)``.""" + + @abstractmethod + def stop(self): + """Stop the generation manager and free resources.""" + + +class GenerateManager(BaseGenerateManager): + """Sequential generation via ``model.generate()`` on a persistent thread.""" + + def __init__(self): + self._thread = InferenceThread() + + def generate_streaming(self, model, processor, inputs, gen_config, request_id=None): + loop = asyncio.get_running_loop() + queue: asyncio.Queue = asyncio.Queue() + streamer = DirectStreamer(processor._tokenizer, loop, queue, skip_special_tokens=True) + gen_kwargs = {**inputs, "streamer": streamer, "generation_config": gen_config, "tokenizer": processor} + + def _run(): + try: + model.generate(**gen_kwargs) + except _GenerationCancelled: + loop.call_soon_threadsafe(queue.put_nowait, None) + except Exception as e: + loop.call_soon_threadsafe(queue.put_nowait, _StreamError(str(e))) + + self.submit(_run) + return queue, streamer + + async def generate_non_streaming(self, model, processor, inputs, gen_config, request_id=None): + sequences = await self.async_submit( + model.generate, **inputs, generation_config=gen_config, tokenizer=processor + ) + input_len = inputs["input_ids"].shape[-1] + generated_ids = sequences[0, input_len:] + text = processor.decode(generated_ids, skip_special_tokens=True) + return text, input_len, generated_ids + + def submit(self, fn, *args, **kwargs): + """Submit a callable to the inference thread. 
Returns a blocking Future.""" + return self._thread.submit(fn, *args, **kwargs) + + def async_submit(self, fn, *args, **kwargs): + """Submit a callable to the inference thread. Returns an awaitable asyncio.Future.""" + return self._thread.async_submit(fn, *args, **kwargs) + + def stop(self): + pass # inference thread is a daemon + + +class CBGenerateManager(BaseGenerateManager): + """Continuous batching generation via paged attention. + + Translates between the handler's text-level asyncio.Queue and CB's + token-level interface. Per-request: ``max_new_tokens``, ``eos_token_id``. + + The CB manager is initialized lazily on the first request via + :meth:`ensure_initialized`, using that request's ``gen_config`` for shared + sampling params (temperature, top_p, do_sample). + + .. todo:: Remove :meth:`init_cb` when CB supports per-request + generation config. At that point, ``gen_config`` can be passed directly + to ``add_request`` and the CB manager no longer needs a shared config. + """ + + def __init__(self): + self._cb = None + + def init_cb(self, model, gen_config): + """Initialize the CB manager on first call with the request's generation config. + + .. todo:: Remove when CB supports per-request generation config. + """ + if self._cb is not None: + return + from transformers import LogitsProcessorList + + self._cb = model.init_continuous_batching(generation_config=gen_config) + # TODO: logits processors should be fixed in CB and correctly applied + self._cb.logit_processor = LogitsProcessorList() + self._cb.start() + + def generate_streaming(self, model, processor, inputs, gen_config, request_id=None): + loop = asyncio.get_running_loop() + text_queue: asyncio.Queue = asyncio.Queue() + + input_ids = inputs["input_ids"] + request_id = self._cb.add_request( + input_ids, + request_id=request_id, + max_new_tokens=gen_config.max_new_tokens, + min_new_tokens=gen_config.min_new_tokens, + streaming=True, + eos_token_id=gen_config.eos_token_id, + ) + streamer = CBStreamer(self._cb, request_id, processor._tokenizer, loop, text_queue) + + # Consume CB outputs and decode tokens into the SSE text queue. + # It's a coroutine on the event loop (via async_request_id_iter) + # to avoid spawning a thread per concurrent request. + async def _read_and_decode(): + try: + async for output in self._cb.async_request_id_iter(request_id): + streamer.put(output) + if output.is_finished(): + break + streamer.end() + except Exception as e: + text_queue.put_nowait(_StreamError(str(e))) + + asyncio.ensure_future(_read_and_decode()) + return text_queue, streamer + + async def generate_non_streaming(self, model, processor, inputs, gen_config, request_id=None): + """Non-streaming CB generation, fully async (no per-request thread). + + Uses ``register_async_future`` — the dispatcher resolves a single + asyncio.Future when the result arrives. No per-request queue, no polling + loop — scales to thousands of concurrent requests with minimal event loop + overhead. 
+ """ + input_ids = inputs["input_ids"] + input_len = len(input_ids) + + # Register future BEFORE add_request to avoid race with fast completion + request_id = request_id or f"cb_{id(inputs)}" + future = self._cb.register_async_future(request_id) + + self._cb.add_request( + input_ids, + request_id=request_id, + max_new_tokens=gen_config.max_new_tokens, + min_new_tokens=gen_config.min_new_tokens, + streaming=False, + eos_token_id=gen_config.eos_token_id, + ) + result = await future + if result is None: + raise RuntimeError(f"CB manager stopped before producing a result for {request_id}") + generated_ids = result.generated_tokens + text = processor.decode(generated_ids, skip_special_tokens=True) + return text, input_len, generated_ids + + + @property + def scheduler(self): + """The CB scheduler (for testing/monitoring).""" + return self._cb.batch_processor.scheduler + + def stop(self): + if self._cb is not None: + self._cb.stop(block=True, timeout=2) + + +# --------------------------------------------------------------------------- +# Generation state (shared across handlers) +# --------------------------------------------------------------------------- + + +class GenerationState: + """Shared generation state across all handlers. + + Manages per-model :class:`GenerateManager` instances (each with its own + :class:`InferenceThread` so different models can run concurrently while + ``torch.compile`` / CUDA graphs require same-model-same-thread) and a + single :class:`CBGenerateManager` for continuous batching. + """ + + def __init__(self, continuous_batching: bool = False): + self._continuous_batching = continuous_batching + self._generate_managers: dict[str, GenerateManager] = {} + self._cb_manager: CBGenerateManager | None = None + self._cb_model_id: str | None = None + + def use_continuous_batching(self, model, modality: Modality) -> bool: + """Check if CB can be used. Logs a warning on fallback.""" + if not self._continuous_batching: + return False + can = hasattr(model, "init_continuous_batching") and modality == Modality.LLM + if not can: + logger.warning_once( + f"{model.__class__.__name__} does not support continuous batching. " + "Falling back to sequential generation." + ) + return can + + def get_manager(self, model_id: str, use_cb: bool) -> BaseGenerateManager: + """Return a per-model generation manager, lazily created on first request.""" + if use_cb: + if self._cb_model_id != model_id: + if self._cb_manager is not None: + self._cb_manager.stop() + self._cb_manager = None + if self._cb_manager is None: + self._cb_manager = CBGenerateManager() + self._cb_model_id = model_id + return self._cb_manager + if model_id not in self._generate_managers: + self._generate_managers[model_id] = GenerateManager() + return self._generate_managers[model_id] + + def shutdown(self): + """Stop any active generation managers.""" + if self._cb_manager is not None: + self._cb_manager.stop() + self._cb_manager = None + + # --------------------------------------------------------------------------- # Base handler # --------------------------------------------------------------------------- @@ -433,22 +726,22 @@ def submit(self, fn, *args, **kwargs) -> Future: class BaseHandler: """Shared logic for chat completion and responses handlers. - Subclasses implement ``_streaming`` and ``_non_streaming`` for their - specific SSE / JSON formats, plus ``_validate_request``. + Provides model resolution, generation config building, and SSE formatting. + Generation is delegated to the shared :class:`GenerationState`. 
""" def __init__( self, model_manager, + generation_state: GenerationState, force_model=None, force_processor=None, - inference_thread=None, compile=False, ): self.model_manager = model_manager + self.generation_state = generation_state self.force_model = force_model self.force_processor = force_processor - self._inference_thread = inference_thread or InferenceThread() self._compile = compile @staticmethod @@ -459,7 +752,10 @@ def chunk_to_sse(chunk) -> str: return f"data: {chunk.model_dump_json(exclude_none=True)}\n\n" def _resolve_model(self, body: dict): - """Apply force_model, load model + processor. Returns (model_id, model, processor).""" + """Apply force_model, load model + processor. + + Returns ``(model_id, model, processor)``. + """ if self.force_model is not None: body["model"] = self.force_model @@ -469,7 +765,7 @@ def _resolve_model(self, body: dict): return model_id, model, processor - def _build_generation_config(self, body: dict, model_generation_config, processor=None): + def _build_generation_config(self, body: dict, model_generation_config, processor=None, use_cb=False): """Build a GenerationConfig from shared params (temperature, top_p, seed, generation_config JSON). Subclasses should call ``super()._build_generation_config(...)`` then apply @@ -505,37 +801,11 @@ def _build_generation_config(self, body: dict, model_generation_config, processo if self._compile and generation_config.cache_implementation is None: generation_config.cache_implementation = "static" - return generation_config - - def _start_streaming(self, model, processor, inputs, gen_config): - """Set up DirectStreamer + queue, submit generate to inference thread. + # CB manages its own paged KV cache + if use_cb: + generation_config.use_cache = False - Returns ``(queue, streamer)`` — caller reads from queue to build SSE events. - """ - loop = asyncio.get_running_loop() - queue: asyncio.Queue = asyncio.Queue() - streamer = DirectStreamer(processor._tokenizer, loop, queue, skip_special_tokens=True) - gen_kwargs = {**inputs, "streamer": streamer, "generation_config": gen_config, "tokenizer": processor} - - def _run(): - try: - model.generate(**gen_kwargs) - except Exception as e: - loop.call_soon_threadsafe(queue.put_nowait, _StreamError(str(e))) - - self._inference_thread.submit(_run) - return queue, streamer - - def _generate_non_streaming(self, model, processor, inputs, gen_config): - """Run generate on the inference thread, decode output. 
Returns ``(text, input_len, generated_ids)``.""" - future = self._inference_thread.submit( - model.generate, **inputs, generation_config=gen_config, tokenizer=processor - ) - sequences = future.result() - input_len = inputs["input_ids"].shape[-1] - generated_ids = sequences[0, input_len:] - text = processor.decode(generated_ids, skip_special_tokens=True) - return text, input_len, generated_ids + return generation_config # --------------------------------------------------------------------------- From 962d039172a6235b74645e5cee28e8aa4a3d7c50 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Thu, 26 Mar 2026 17:13:13 +0000 Subject: [PATCH 24/64] cb tests --- tests/cli/test_serve_refactored.py | 310 +++++++++++++++++++++++++++-- 1 file changed, 298 insertions(+), 12 deletions(-) diff --git a/tests/cli/test_serve_refactored.py b/tests/cli/test_serve_refactored.py index 3b04dafaa6a7..5f507c5dbb66 100644 --- a/tests/cli/test_serve_refactored.py +++ b/tests/cli/test_serve_refactored.py @@ -205,7 +205,7 @@ class TestBuildGenerationConfig(unittest.TestCase): def _make_handler(self): from transformers.cli.serving.chat_completion import ChatCompletionHandler - return ChatCompletionHandler(model_manager=MagicMock()) + from transformers.cli.serving.utils import GenerationState; return ChatCompletionHandler(model_manager=MagicMock(), generation_state=GenerationState()) def test_max_tokens(self): from transformers import GenerationConfig @@ -282,7 +282,7 @@ class TestValidation(unittest.TestCase): def _make_handler(self): from transformers.cli.serving.chat_completion import ChatCompletionHandler - return ChatCompletionHandler(model_manager=MagicMock()) + from transformers.cli.serving.utils import GenerationState; return ChatCompletionHandler(model_manager=MagicMock(), generation_state=GenerationState()) def test_valid_request_passes(self): handler = self._make_handler() @@ -375,7 +375,7 @@ class TestChunkSSE(unittest.TestCase): def _make_handler(self): from transformers.cli.serving.chat_completion import ChatCompletionHandler - return ChatCompletionHandler(model_manager=MagicMock()) + from transformers.cli.serving.utils import GenerationState; return ChatCompletionHandler(model_manager=MagicMock(), generation_state=GenerationState()) def test_build_chunk_sse_content(self): handler = self._make_handler() @@ -570,6 +570,7 @@ def setUpClass(cls): from transformers.cli.serving.model_manager import ModelManager from transformers.cli.serving.response import ResponseHandler from transformers.cli.serving.server import build_server + from transformers.cli.serving.transcription import TranscriptionHandler cls.model_manager = MagicMock(spec=ModelManager) cls.model_manager.get_gen_models.return_value = [ @@ -577,7 +578,8 @@ def setUpClass(cls): ] cls.chat_handler = MagicMock(spec=ChatCompletionHandler) cls.response_handler = MagicMock(spec=ResponseHandler) - cls.app = build_server(cls.model_manager, cls.chat_handler, cls.response_handler) + cls.transcription_handler = MagicMock(spec=TranscriptionHandler) + cls.app = build_server(cls.model_manager, cls.chat_handler, cls.response_handler, cls.transcription_handler) def _run(self, coro): return asyncio.get_event_loop().run_until_complete(coro) @@ -646,7 +648,7 @@ class TestChatCompletion(unittest.TestCase): def setUpClass(cls): from transformers.cli.serve_refactored import Serve - cls.serve = Serve(port=cls.PORT, non_blocking=True, log_level="warning") + cls.serve = Serve(port=cls.PORT, non_blocking=True) import requests for _ in range(30): @@ -917,6 +919,36 @@ def 
stream_in_thread(index): self.assertIsNotNone(results[i]) self.assertTrue(len(results[i]) > 0, f"Request {i} produced empty output") + def test_request_cancellation(self): + """Closing a stream early doesn't crash and the server stays healthy.""" + import requests as req + + with req.post( + f"http://localhost:{self.PORT}/v1/chat/completions", + json={ + "model": self.MODEL, + "stream": True, + "messages": [{"role": "user", "content": "Count slowly so I can cancel you."}], + "max_tokens": 500, + }, + stream=True, + timeout=30, + ) as resp: + self.assertEqual(resp.status_code, 200) + chunks_read = 0 + for _ in resp.iter_lines(): + chunks_read += 1 + if chunks_read >= 3: + break + + # Server should still be healthy and serve subsequent requests + resp = self.client.chat.completions.create( + model=self.MODEL, + messages=[{"role": "user", "content": "Say hi"}], + max_tokens=10, + ) + self.assertIsNotNone(resp.choices[0].message.content) + # --------------------------------------------------------------------------- # 8. Unit tests — Response handler @@ -928,7 +960,7 @@ class TestResponseInputConversion(unittest.TestCase): def _make_handler(self): from transformers.cli.serving.response import ResponseHandler - return ResponseHandler(model_manager=MagicMock()) + from transformers.cli.serving.utils import GenerationState; return ResponseHandler(model_manager=MagicMock(), generation_state=GenerationState()) def test_string_input(self): handler = self._make_handler() @@ -976,7 +1008,7 @@ class TestResponseValidation(unittest.TestCase): def _make_handler(self): from transformers.cli.serving.response import ResponseHandler - return ResponseHandler(model_manager=MagicMock()) + from transformers.cli.serving.utils import GenerationState; return ResponseHandler(model_manager=MagicMock(), generation_state=GenerationState()) def test_unsupported_fields_rejected(self): from fastapi import HTTPException @@ -997,7 +1029,7 @@ class TestResponseGenerationConfig(unittest.TestCase): def _make_handler(self): from transformers.cli.serving.response import ResponseHandler - return ResponseHandler(model_manager=MagicMock()) + from transformers.cli.serving.utils import GenerationState; return ResponseHandler(model_manager=MagicMock(), generation_state=GenerationState()) def test_max_output_tokens(self): from transformers import GenerationConfig @@ -1097,7 +1129,7 @@ class TestResponsesIntegration(unittest.TestCase): def setUpClass(cls): from transformers.cli.serve_refactored import Serve - cls.serve = Serve(port=cls.PORT, non_blocking=True, log_level="warning") + cls.serve = Serve(port=cls.PORT, non_blocking=True) import requests for _ in range(30): @@ -1359,7 +1391,7 @@ def setUpClass(cls): from transformers.cli.serve_refactored import Serve - cls.serve = Serve(port=cls.PORT, non_blocking=True, log_level="warning") + cls.serve = Serve(port=cls.PORT, non_blocking=True) for _ in range(30): try: if req.get(f"http://localhost:{cls.PORT}/health", timeout=1).status_code == 200: @@ -1652,7 +1684,7 @@ def setUpClass(cls): from transformers.cli.serve_refactored import Serve - cls.serve = Serve(port=cls.PORT, non_blocking=True, log_level="warning") + cls.serve = Serve(port=cls.PORT, non_blocking=True) for _ in range(60): try: if req.get(f"http://localhost:{cls.PORT}/health", timeout=1).status_code == 200: @@ -1726,7 +1758,7 @@ def setUpClass(cls): from transformers.cli.serve_refactored import Serve - cls.serve = Serve(port=cls.PORT, non_blocking=True, log_level="warning") + cls.serve = Serve(port=cls.PORT, non_blocking=True) 
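+        # Poll /health until the background uvicorn thread is accepting requests.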
for _ in range(30): try: if req.get(f"http://localhost:{cls.PORT}/health", timeout=1).status_code == 200: @@ -1813,3 +1845,257 @@ def test_transcription_missing_file(self): timeout=30, ) self.assertNotEqual(resp.status_code, 200) + + +# --------------------------------------------------------------------------- +# Continuous Batching integration tests +# --------------------------------------------------------------------------- + + +@unittest.skipUnless(run_slow, "Set RUN_SLOW=1 to run integration tests") +@require_openai +class TestContinuousBatchingChatCompletion(unittest.TestCase): + """Integration tests for /v1/chat/completions with continuous batching.""" + + MODEL = "Qwen/Qwen2.5-0.5B-Instruct" + PORT = 8891 + + @classmethod + def setUpClass(cls): + from transformers.cli.serve_refactored import Serve + + cls.serve = Serve( + force_model=cls.MODEL, + port=cls.PORT, + device="cuda:0", + continuous_batching=True, + attn_implementation="sdpa", + default_seed=42, + non_blocking=True, + ) + import requests + + for _ in range(30): + try: + if requests.get(f"http://localhost:{cls.PORT}/health", timeout=1).status_code == 200: + break + except Exception: + pass + time.sleep(2) + + cls.client = OpenAI(base_url=f"http://localhost:{cls.PORT}/v1", api_key="unused") + + @classmethod + def tearDownClass(cls): + cls.serve.kill_server() + + def test_streaming(self): + """Streaming chat completion with CB produces text.""" + text = "" + for chunk in self.client.chat.completions.create( + model=self.MODEL, + messages=[ + {"role": "system", "content": "You are a sports assistant designed to craft sports programs."}, + {"role": "user", "content": "Tell me what you can do."}, + ], + stream=True, + max_tokens=30, + ): + if chunk.choices[0].delta.content: + text += chunk.choices[0].delta.content + self.assertTrue(len(text) > 0) + + def test_non_streaming(self): + """Non-streaming chat completion with CB returns a full response.""" + resp = self.client.chat.completions.create( + model=self.MODEL, + messages=[{"role": "user", "content": "Say hello"}], + max_tokens=20, + ) + self.assertIsNotNone(resp.choices[0].message.content) + self.assertTrue(len(resp.choices[0].message.content) > 0) + + def test_multi_turn(self): + """Multi-turn conversation works with CB.""" + resp = self.client.chat.completions.create( + model=self.MODEL, + messages=[ + {"role": "user", "content": "My name is Alice"}, + {"role": "assistant", "content": "Nice to meet you!"}, + {"role": "user", "content": "What is my name?"}, + ], + max_tokens=20, + ) + self.assertIn("Alice", resp.choices[0].message.content) + + def test_request_cancellation(self): + """Opening a stream and closing it early triggers CB cancellation.""" + import requests as req + + request_id = "test-cb-cancel" + + # Open a streaming request and close after a few chunks + with req.post( + f"http://localhost:{self.PORT}/v1/chat/completions", + headers={"X-Request-ID": request_id}, + json={ + "model": self.MODEL, + "stream": True, + "messages": [{"role": "user", "content": "Count slowly so I can cancel you."}], + }, + stream=True, + timeout=30, + ) as resp: + self.assertEqual(resp.status_code, 200) + chunks_read = 0 + for _ in resp.iter_lines(): + chunks_read += 1 + if chunks_read >= 3: + break + + # Poll for cancellation in the CB scheduler + scheduler = self.serve._generation_state._cb_manager.scheduler + deadline = time.time() + 8.0 + while time.time() < deadline: + if scheduler.request_is_cancelled(request_id): + break + time.sleep(0.1) + + self.assertTrue( + 
scheduler.request_is_cancelled(request_id), + f"Request {request_id} not cancelled in scheduler after stream close.", + ) + + # Server should still be healthy and serve subsequent requests + resp = self.client.chat.completions.create( + model=self.MODEL, + messages=[{"role": "user", "content": "Say hi"}], + max_tokens=10, + ) + self.assertIsNotNone(resp.choices[0].message.content) + + +@unittest.skipUnless(run_slow, "Set RUN_SLOW=1 to run integration tests") +@require_openai +class TestContinuousBatchingResponses(unittest.TestCase): + """Integration tests for /v1/responses with continuous batching.""" + + MODEL = "Qwen/Qwen2.5-0.5B-Instruct" + PORT = 8893 + + @classmethod + def setUpClass(cls): + from transformers.cli.serve_refactored import Serve + + cls.serve = Serve( + force_model=cls.MODEL, + port=cls.PORT, + device="cuda:0", + continuous_batching=True, + attn_implementation="sdpa", + default_seed=42, + non_blocking=True, + ) + import requests + + for _ in range(30): + try: + if requests.get(f"http://localhost:{cls.PORT}/health", timeout=1).status_code == 200: + break + except Exception: + pass + time.sleep(2) + + cls.client = OpenAI(base_url=f"http://localhost:{cls.PORT}/v1", api_key="unused") + + @classmethod + def tearDownClass(cls): + cls.serve.kill_server() + + def test_streaming(self): + """Streaming response with CB produces text.""" + text = "" + stream = self.client.responses.create( + model=self.MODEL, + input="Say hello in one sentence.", + stream=True, + max_output_tokens=30, + ) + for event in stream: + if event.type == "response.output_text.delta": + text += event.delta + self.assertTrue(len(text) > 0) + + def test_non_streaming(self): + """Non-streaming response with CB returns text.""" + resp = self.client.responses.create( + model=self.MODEL, + input="Say hello in one sentence.", + stream=False, + max_output_tokens=30, + ) + content = resp.output[0].content[0].text + self.assertTrue(len(content) > 0) + + def test_multi_turn(self): + """Multi-turn conversation works with CB via Responses API.""" + resp = self.client.responses.create( + model=self.MODEL, + input=[ + {"role": "user", "content": "My name is Alice"}, + {"role": "assistant", "content": "Nice to meet you!"}, + {"role": "user", "content": "What is my name?"}, + ], + stream=False, + max_output_tokens=20, + ) + content = resp.output[0].content[0].text + self.assertIn("Alice", content) + + def test_request_cancellation(self): + """Opening a stream and closing it early triggers CB cancellation.""" + import requests as req + + request_id = "test-cb-resp-cancel" + + with req.post( + f"http://localhost:{self.PORT}/v1/responses", + headers={"X-Request-ID": request_id}, + json={ + "model": self.MODEL, + "stream": True, + "input": "Count slowly so I can cancel you.", + "max_output_tokens": 500, + }, + stream=True, + timeout=30, + ) as resp: + self.assertEqual(resp.status_code, 200) + # Read enough data to ensure CB generation has started, then close. 
+ received = b"" + for chunk in resp.iter_content(chunk_size=512): + received += chunk + if b"output_text.delta" in received: + break + + # Poll for cancellation in the CB scheduler + scheduler = self.serve._generation_state._cb_manager.scheduler + deadline = time.time() + 8.0 + while time.time() < deadline: + if scheduler.request_is_cancelled(request_id): + break + time.sleep(0.1) + + self.assertTrue( + scheduler.request_is_cancelled(request_id), + f"Request {request_id} not cancelled in scheduler after stream close.", + ) + + # Server should still serve subsequent requests + resp = self.client.responses.create( + model=self.MODEL, + input="Say hi", + stream=False, + max_output_tokens=10, + ) + self.assertTrue(len(resp.output[0].content[0].text) > 0) From 13945c1bd0bd1a6f8bdbbf6f1f6508b0244fac67 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Thu, 26 Mar 2026 21:05:11 +0000 Subject: [PATCH 25/64] typing + review --- .../cli/serving/chat_completion.py | 138 +++++--- src/transformers/cli/serving/model_manager.py | 59 +++- src/transformers/cli/serving/response.py | 86 +++-- src/transformers/cli/serving/server.py | 4 +- src/transformers/cli/serving/transcription.py | 29 +- src/transformers/cli/serving/utils.py | 326 ++++++++++-------- 6 files changed, 406 insertions(+), 236 deletions(-) diff --git a/src/transformers/cli/serving/chat_completion.py b/src/transformers/cli/serving/chat_completion.py index ae57207224a8..ccc1d9bf078e 100644 --- a/src/transformers/cli/serving/chat_completion.py +++ b/src/transformers/cli/serving/chat_completion.py @@ -1,4 +1,4 @@ -# Copyright 2025 The HuggingFace Team. All rights reserved. +# Copyright 2026 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,42 +17,66 @@ Supports streaming (SSE via DirectStreamer) and non-streaming (JSON) responses. """ -from __future__ import annotations - import asyncio import time from collections.abc import AsyncGenerator from typing import TYPE_CHECKING -import torch - - -if TYPE_CHECKING: - from transformers import PreTrainedTokenizerFast, ProcessorMixin - from fastapi import HTTPException from fastapi.responses import JSONResponse, StreamingResponse from openai.types.chat import ChatCompletion, ChatCompletionMessage, ChatCompletionMessageToolCall from openai.types.chat.chat_completion import Choice from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, ChoiceDelta, ChoiceDeltaToolCall from openai.types.chat.chat_completion_chunk import Choice as ChoiceChunk +from openai.types.chat.completion_create_params import CompletionCreateParamsStreaming from openai.types.completion_usage import CompletionUsage -from transformers import GenerationConfig, PreTrainedModel - from ...utils import logging from .utils import ( - UNUSED_CHAT_COMPLETION_FIELDS, BaseGenerateManager, BaseHandler, ToolCallParser, - TransformersCompletionCreateParamsStreaming, _StreamError, detect_tool_format, - get_processor_inputs_from_messages, ) +if TYPE_CHECKING: + from transformers import GenerationConfig, PreTrainedModel, PreTrainedTokenizerFast, ProcessorMixin + + +class TransformersCompletionCreateParamsStreaming(CompletionCreateParamsStreaming, total=False): + generation_config: str + processor: str + + +# Fields accepted by the OpenAI schema but not yet supported. +# Receiving these raises an error to avoid silent misbehaviour. +# NOTE: "stop" is NOT in this set — we map it to stop_strings. 
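+# Requests carrying any of these fields (e.g. {"n": 2}) are rejected by
+# _validate_request() with HTTP 422 instead of being silently ignored.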
+UNUSED_CHAT_COMPLETION_FIELDS = { + "audio", + "function_call", + "functions", + "logprobs", + "max_completion_tokens", + "metadata", + "modalities", + "n", + "parallel_tool_calls", + "prediction", + "presence_penalty", + "reasoning_effort", + "response_format", + "service_tier", + "store", + "stream_options", + "tool_choice", + "top_logprobs", + "user", + "web_search_options", +} + + logger = logging.get_logger(__name__) @@ -65,7 +89,15 @@ class ChatCompletionHandler(BaseHandler): # ----- entry point ----- async def handle_request(self, body: dict, request_id: str) -> StreamingResponse | JSONResponse: - """Validate the request, load the model, and dispatch to streaming or non-streaming.""" + """Validate the request, load the model, and dispatch to streaming or non-streaming. + + Args: + body (`dict`): The raw JSON request body (OpenAI chat completion format). + request_id (`str`): Unique request identifier (from header or auto-generated). + + Returns: + `StreamingResponse | JSONResponse`: SSE stream or JSON depending on ``body["stream"]``. + """ self._validate_request(body) messages = body["messages"] @@ -78,7 +110,7 @@ async def handle_request(self, body: dict, request_id: str) -> StreamingResponse modality = self.model_manager.get_model_modality(model, processor=processor) use_cb = self.generation_state.use_continuous_batching(model, modality) gen_manager = self.generation_state.get_manager(model_id, use_cb=use_cb) - processor_inputs = get_processor_inputs_from_messages(messages, modality) + processor_inputs = self.get_processor_inputs_from_messages(messages, modality) if use_cb: # CB handles device placement internally — don't create tensors or move @@ -111,7 +143,8 @@ async def handle_request(self, body: dict, request_id: str) -> StreamingResponse # 2. 
force generation to produce valid JSON matching the tool's parameter schema tool_format = detect_tool_format(model) if body.get("tools") else None - if body.get("stream"): + streaming = body.get("stream") + if streaming: return self._streaming( request_id, model, @@ -122,27 +155,28 @@ async def handle_request(self, body: dict, request_id: str) -> StreamingResponse gen_manager=gen_manager, tool_format=tool_format, ) - return await self._non_streaming( - request_id, - model, - processor, - model_id, - inputs, - gen_config, - gen_manager=gen_manager, - tool_format=tool_format, - ) + else: + return await self._non_streaming( + request_id, + model, + processor, + model_id, + inputs, + gen_config, + gen_manager=gen_manager, + tool_format=tool_format, + ) # ----- streaming ----- def _streaming( self, request_id: str, - model: PreTrainedModel, - processor: ProcessorMixin | PreTrainedTokenizerFast, + model: "PreTrainedModel", + processor: "ProcessorMixin | PreTrainedTokenizerFast", model_id: str, - inputs: dict[str, torch.Tensor], - gen_config: GenerationConfig, + inputs: dict, + gen_config: "GenerationConfig", gen_manager: BaseGenerateManager, tool_format: dict | None = None, ) -> StreamingResponse: @@ -215,11 +249,11 @@ async def sse_gen() -> AsyncGenerator[str, None]: async def _non_streaming( self, request_id: str, - model: PreTrainedModel, - processor: ProcessorMixin | PreTrainedTokenizerFast, + model: "PreTrainedModel", + processor: "ProcessorMixin | PreTrainedTokenizerFast", model_id: str, - inputs: dict[str, torch.Tensor], - gen_config: GenerationConfig, + inputs: dict, + gen_config: "GenerationConfig", gen_manager: BaseGenerateManager, tool_format: dict | None = None, ) -> JSONResponse: @@ -280,8 +314,9 @@ def _validate_request(self, body: dict) -> None: if unused: raise HTTPException(status_code=422, detail=f"Unsupported fields in the request: {unused}") - def _build_generation_config(self, body: dict, model_generation_config, processor=None, use_cb=False): - """Chat Completions params on top of base config.""" + def _build_generation_config(self, body: dict, model_generation_config: "GenerationConfig", processor: "ProcessorMixin | PreTrainedTokenizerFast | None" = None, use_cb: bool = False): + """Apply Chat Completions params (``max_tokens``, ``frequency_penalty``, ``logit_bias``, + ``stop``) on top of the base generation config.""" generation_config = super()._build_generation_config(body, model_generation_config, processor, use_cb=use_cb) if body.get("max_tokens") is not None: @@ -306,7 +341,19 @@ def _build_completion( usage: CompletionUsage | None = None, tool_calls: list[dict] | None = None, ) -> dict: - """Build a non-streaming ChatCompletion response dict.""" + """Build a non-streaming ChatCompletion response dict. + + Args: + request_id (`str`): Unique request identifier. + content (`str`): The generated text. + model_id (`str`): Model ID to include in the response. + finish_reason (`str`): Why generation stopped (``"stop"``, ``"length"``, ``"tool_calls"``). + usage (`CompletionUsage`, *optional*): Token usage statistics. + tool_calls (`list[dict]`, *optional*): Parsed tool calls, if any. + + Returns: + `dict`: Serialized ``ChatCompletion`` ready for JSON response. 
+ """ message = ChatCompletionMessage(content=content, role="assistant", tool_calls=tool_calls) result = ChatCompletion( id=request_id, @@ -334,7 +381,20 @@ def _build_chunk_sse( tool_calls: list | None = None, usage: CompletionUsage | None = None, ) -> str: - """Build a streaming ChatCompletionChunk and format it as an SSE event.""" + """Build a streaming ``ChatCompletionChunk`` and format it as an SSE ``data:`` line. + + Args: + request_id (`str`): Unique request identifier. + content (`str`, *optional*): Text content delta. + model (`str`, *optional*): Model ID. + role (`str`, *optional*): Role (only sent in the first chunk). + finish_reason (`str`, *optional*): Set on the final chunk. + tool_calls (`list`, *optional*): Tool call deltas. + usage (`CompletionUsage`, *optional*): Token usage (sent with the final chunk). + + Returns: + `str`: A formatted SSE event string. + """ chunk = ChatCompletionChunk( id=request_id, created=int(time.time()), diff --git a/src/transformers/cli/serving/model_manager.py b/src/transformers/cli/serving/model_manager.py index 23e57d31e5ed..273b2419c978 100644 --- a/src/transformers/cli/serving/model_manager.py +++ b/src/transformers/cli/serving/model_manager.py @@ -1,4 +1,4 @@ -# Copyright 2025 The HuggingFace Team. All rights reserved. +# Copyright 2026 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,13 +15,12 @@ Model loading, caching, and lifecycle management. """ -from __future__ import annotations - import asyncio import gc import json import threading from functools import lru_cache +from typing import Callable from typing import TYPE_CHECKING from huggingface_hub import scan_cache_dir @@ -52,9 +51,9 @@ class TimedModel: def __init__( self, - model: PreTrainedModel, + model: "PreTrainedModel", timeout_seconds: int, - processor: ProcessorMixin | PreTrainedTokenizerFast | None = None, + processor: "ProcessorMixin | PreTrainedTokenizerFast | None" = None, ): self.model = model self._name_or_path = str(model.name_or_path) @@ -161,7 +160,7 @@ def get_quantization_config(self) -> BitsAndBytesConfig | None: def _load_processor( self, model_id_and_revision: str, processor_id: str | None = None - ) -> ProcessorMixin | PreTrainedTokenizerFast: + ) -> "ProcessorMixin | PreTrainedTokenizerFast": """Load a processor, trying AutoProcessor first then AutoTokenizer. Args: @@ -185,8 +184,17 @@ def _load_processor( except OSError: raise OSError(f"Failed to load processor for {model_id} with AutoProcessor and AutoTokenizer.") - def _load_model(self, model_id_and_revision: str, tqdm_class=None, progress_callback=None) -> PreTrainedModel: - """Load a model. GGUF files are detected by the `.gguf` extension and loaded via llama.cpp.""" + def _load_model(self, model_id_and_revision: str, tqdm_class: type | None = None, progress_callback: Callable | None = None) -> "PreTrainedModel": + """Load a model. GGUF files are detected by the ``.gguf`` extension and loaded via llama.cpp. + + Args: + model_id_and_revision (`str`): Model ID in ``'model_id@revision'`` format. + tqdm_class (*optional*): tqdm subclass for progress bars during ``from_pretrained``. + progress_callback (`callable`, *optional*): Called with progress dicts during loading. + + Returns: + `PreTrainedModel`: The loaded model. 
+ """ import torch from transformers import AutoConfig @@ -227,9 +235,9 @@ def load_model_and_processor( self, model_id_and_revision: str, processor_id: str | None = None, - progress_callback=None, - tqdm_class=None, - ) -> tuple[PreTrainedModel, ProcessorMixin | PreTrainedTokenizerFast]: + progress_callback: Callable | None = None, + tqdm_class: type | None = None, + ) -> "tuple[PreTrainedModel, ProcessorMixin | PreTrainedTokenizerFast]": """Load a model (or return it from cache), resetting its inactivity timer. Args: @@ -275,7 +283,11 @@ async def load_model_streaming(self, model_id_and_revision: str): 2. Load already in progress → join existing subscriber stream 3. First request → start loading, broadcast to all subscribers - Yields SSE ``data: ...`` lines. + Args: + model_id_and_revision (`str`): Model ID in ``'model_id@revision'`` format. + + Yields: + `str`: SSE ``data: ...`` lines with progress updates. """ mid = model_id_and_revision queue: asyncio.Queue[str | None] = asyncio.Queue() @@ -357,9 +369,18 @@ def shutdown(self) -> None: @staticmethod def get_model_modality( - model: PreTrainedModel, processor: ProcessorMixin | PreTrainedTokenizerFast | None = None + model: "PreTrainedModel", processor: "ProcessorMixin | PreTrainedTokenizerFast | None" = None ) -> Modality: - """Detect whether a model is an LLM or VLM based on its architecture.""" + """Detect whether a model is an LLM or VLM based on its architecture. + + Args: + model (`PreTrainedModel`): The loaded model. + processor (`ProcessorMixin | PreTrainedTokenizerFast`, *optional*): + If a plain tokenizer (not a multi-modal processor), short-circuits to LLM. + + Returns: + `Modality`: The detected modality (``Modality.LLM`` or ``Modality.VLM``). + """ if processor is not None and isinstance(processor, PreTrainedTokenizerBase): return Modality.LLM @@ -379,7 +400,15 @@ def get_model_modality( @staticmethod @lru_cache def get_gen_models(cache_dir: str | None = None) -> list[dict]: - """List generative models (LLMs and VLMs) available in the HuggingFace cache.""" + """List generative models (LLMs and VLMs) available in the HuggingFace cache. + + Args: + cache_dir (`str`, *optional*): Path to the HuggingFace cache directory. + Defaults to the standard cache location. + + Returns: + `list[dict]`: OpenAI-compatible model list entries with ``id``, ``object``, etc. + """ from transformers.models.auto.modeling_auto import ( MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES, diff --git a/src/transformers/cli/serving/response.py b/src/transformers/cli/serving/response.py index da626281d960..322a60d8bd61 100644 --- a/src/transformers/cli/serving/response.py +++ b/src/transformers/cli/serving/response.py @@ -1,4 +1,4 @@ -# Copyright 2025 The HuggingFace Team. All rights reserved. +# Copyright 2026 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,8 +17,6 @@ Supports streaming (SSE) and non-streaming (JSON) responses. 
""" -from __future__ import annotations - import asyncio import time from collections.abc import AsyncGenerator @@ -54,7 +52,6 @@ ToolCallParser, _StreamError, detect_tool_format, - get_processor_inputs_from_messages, ) @@ -87,7 +84,15 @@ class ResponseHandler(BaseHandler): # ----- entry point ----- async def handle_request(self, body: dict, request_id: str) -> StreamingResponse | JSONResponse: - """Validate, load model, dispatch to streaming or non-streaming.""" + """Validate, load model, dispatch to streaming or non-streaming. + + Args: + body (`dict`): The raw JSON request body (OpenAI Responses API format). + request_id (`str`): Unique request identifier (from header or auto-generated). + + Returns: + `StreamingResponse | JSONResponse`: SSE stream or JSON depending on ``body["stream"]``. + """ self._validate_request(body) model_id, model, processor = self._resolve_model(body) @@ -99,7 +104,7 @@ async def handle_request(self, body: dict, request_id: str) -> StreamingResponse # 1. Normalize Responses API input (string/list/dict + instructions) → standard messages list # 2. Transform message content for the HF processor (VLM image handling, text joining, etc.) messages = self._input_to_messages(body) - processor_inputs = get_processor_inputs_from_messages(messages, modality) + processor_inputs = self.get_processor_inputs_from_messages(messages, modality) if use_cb: # CB handles device placement internally — don't create tensors or move @@ -127,7 +132,8 @@ async def handle_request(self, body: dict, request_id: str) -> StreamingResponse gen_manager.init_cb(model, gen_config) tool_format = detect_tool_format(model) if body.get("tools") else None - if body.get("stream", True): + streaming = body.get("stream", True) + if streaming: return self._streaming( request_id, model, @@ -139,23 +145,34 @@ async def handle_request(self, body: dict, request_id: str) -> StreamingResponse gen_manager=gen_manager, tool_format=tool_format, ) - return await self._non_streaming( - request_id, - model, - processor, - model_id, - body, - inputs, - gen_config, - gen_manager=gen_manager, - tool_format=tool_format, - ) + else: + return await self._non_streaming( + request_id, + model, + processor, + model_id, + body, + inputs, + gen_config, + gen_manager=gen_manager, + tool_format=tool_format, + ) # ----- input conversion ----- @staticmethod def _input_to_messages(body: dict) -> list[dict]: - """Convert the Responses API ``input`` field to a list of chat messages.""" + """Convert the Responses API ``input`` field to a list of chat messages. + + Handles string, list, and dict inputs. If ``instructions`` is provided, it is + prepended as a system message (or replaces an existing one). + + Args: + body (`dict`): The raw request body containing ``input`` and optionally ``instructions``. + + Returns: + `list[dict]`: Standard OpenAI-format chat messages. + """ inp = body["input"] instructions = body.get("instructions") @@ -184,12 +201,12 @@ def _input_to_messages(body: dict) -> list[dict]: def _streaming( self, request_id: str, - model: PreTrainedModel, - processor: ProcessorMixin | PreTrainedTokenizerFast, + model: "PreTrainedModel", + processor: "ProcessorMixin | PreTrainedTokenizerFast", model_id: str, body: dict, inputs: dict, - gen_config: GenerationConfig, + gen_config: "GenerationConfig", gen_manager: BaseGenerateManager, tool_format: dict | None = None, ) -> StreamingResponse: @@ -405,7 +422,7 @@ async def event_stream() -> AsyncGenerator[str, None]: # 6. 
Completed all_output = [msg_item] + list(tool_calls) - usage = _make_usage(input_len, streamer.total_tokens) + usage = compute_usage(input_len, streamer.total_tokens) yield self.chunk_to_sse( ResponseCompletedEvent( type="response.completed", @@ -427,12 +444,12 @@ async def event_stream() -> AsyncGenerator[str, None]: async def _non_streaming( self, request_id: str, - model: PreTrainedModel, - processor: ProcessorMixin | PreTrainedTokenizerFast, + model: "PreTrainedModel", + processor: "ProcessorMixin | PreTrainedTokenizerFast", model_id: str, body: dict, inputs: dict, - gen_config: GenerationConfig, + gen_config: "GenerationConfig", gen_manager: BaseGenerateManager, tool_format: dict | None = None, ) -> JSONResponse: @@ -472,7 +489,7 @@ async def _non_streaming( ) ) - usage = _make_usage(input_len, output_tokens) + usage = compute_usage(input_len, output_tokens) response = Response( id=resp_id, created_at=created_at, @@ -496,8 +513,8 @@ def _validate_request(self, body: dict) -> None: if unused: raise HTTPException(status_code=422, detail=f"Unsupported fields in the request: {unused}") - def _build_generation_config(self, body: dict, model_generation_config, processor=None, use_cb=False): - """Responses API params on top of base config.""" + def _build_generation_config(self, body: dict, model_generation_config: "GenerationConfig", processor: "ProcessorMixin | PreTrainedTokenizerFast | None" = None, use_cb: bool = False): + """Apply Responses API params (``max_output_tokens``) on top of the base generation config.""" generation_config = super()._build_generation_config(body, model_generation_config, processor, use_cb=use_cb) if body.get("max_output_tokens") is not None: @@ -506,7 +523,16 @@ def _build_generation_config(self, body: dict, model_generation_config, processo return generation_config -def _make_usage(input_tokens: int, output_tokens: int) -> ResponseUsage: +def compute_usage(input_tokens: int, output_tokens: int) -> ResponseUsage: + """Build a ``ResponseUsage`` object for a Responses API reply. + + Args: + input_tokens (`int`): Number of prompt tokens. + output_tokens (`int`): Number of generated tokens. + + Returns: + `ResponseUsage`: Usage statistics with zero-filled detail fields. + """ return ResponseUsage( input_tokens=input_tokens, output_tokens=output_tokens, diff --git a/src/transformers/cli/serving/server.py b/src/transformers/cli/serving/server.py index 7be3bb0c4d61..587e50981b6e 100644 --- a/src/transformers/cli/serving/server.py +++ b/src/transformers/cli/serving/server.py @@ -1,4 +1,4 @@ -# Copyright 2025 The HuggingFace Team. All rights reserved. +# Copyright 2026 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,8 +15,6 @@ FastAPI app factory. """ -from __future__ import annotations - import uuid from contextlib import asynccontextmanager diff --git a/src/transformers/cli/serving/transcription.py b/src/transformers/cli/serving/transcription.py index 586e4d10d67f..8cf627f5f142 100644 --- a/src/transformers/cli/serving/transcription.py +++ b/src/transformers/cli/serving/transcription.py @@ -1,4 +1,4 @@ -# Copyright 2025 The HuggingFace Team. All rights reserved. +# Copyright 2026 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,20 +15,20 @@ Handler for the /v1/audio/transcriptions endpoint. 
""" -from __future__ import annotations - import io from typing import TYPE_CHECKING +from fastapi import Request from fastapi.responses import JSONResponse, StreamingResponse from ...utils import logging from .model_manager import ModelManager -from .utils import DirectStreamer, GenerationState, _StreamError +from .utils import DirectStreamer, GenerateManager, GenerationState, _StreamError if TYPE_CHECKING: - from fastapi import Request + from transformers import PreTrainedModel, ProcessorMixin + logger = logging.get_logger(__name__) @@ -45,11 +45,24 @@ class TranscriptionHandler: """ def __init__(self, model_manager: ModelManager, generation_state: GenerationState): + """ + Args: + model_manager (`ModelManager`): Handles model loading, caching, and lifecycle. + generation_state (`GenerationState`): Shared generation state for thread safety. + """ self.model_manager = model_manager self._generation_state = generation_state async def handle_request(self, request: Request) -> JSONResponse | StreamingResponse: - """Parse multipart form, run transcription, return result.""" + """Parse multipart form, run transcription, return result. + + Args: + request (`Request`): FastAPI request containing multipart form data with + ``file`` (audio bytes), ``model`` (model ID), and optional ``stream`` flag. + + Returns: + `JSONResponse | StreamingResponse`: Transcription result or SSE stream. + """ from transformers.utils.import_utils import is_librosa_available if not is_librosa_available(): @@ -81,7 +94,7 @@ async def handle_request(self, request: Request) -> JSONResponse | StreamingResp return self._streaming(gen_manager, audio_model, tokenizer, audio_inputs) return await self._non_streaming(gen_manager, audio_model, audio_processor, audio_inputs) - async def _non_streaming(self, gen_manager, audio_model, audio_processor, audio_inputs) -> JSONResponse: + async def _non_streaming(self, gen_manager: GenerateManager, audio_model: "PreTrainedModel", audio_processor: "ProcessorMixin", audio_inputs: dict) -> JSONResponse: # Audio models have different inputs (input_features) and decode (batch_decode) # than text models, so we use async_submit() directly instead of # generate_non_streaming(). TODO: add generate_audio_non_streaming() when @@ -92,7 +105,7 @@ async def _non_streaming(self, gen_manager, audio_model, audio_processor, audio_ text = audio_processor.batch_decode(generated_ids, skip_special_tokens=True)[0] return JSONResponse(Transcription(text=text).model_dump(exclude_none=True)) - def _streaming(self, gen_manager, audio_model, tokenizer, audio_inputs) -> StreamingResponse: + def _streaming(self, gen_manager: GenerateManager, audio_model: "PreTrainedModel", tokenizer: "ProcessorMixin", audio_inputs: dict) -> StreamingResponse: # Same as _non_streaming — uses submit() directly because audio inputs # differ from text. TODO: add generate_audio_streaming() when more audio # modalities are supported. diff --git a/src/transformers/cli/serving/utils.py b/src/transformers/cli/serving/utils.py index 02d54ba2e825..f4f4cbd17aff 100644 --- a/src/transformers/cli/serving/utils.py +++ b/src/transformers/cli/serving/utils.py @@ -1,4 +1,4 @@ -# Copyright 2025 The HuggingFace Team. All rights reserved. +# Copyright 2026 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,8 +15,6 @@ Shared types, constants, and utilities for the serving layer. 
""" -from __future__ import annotations - import asyncio import base64 import copy @@ -27,64 +25,31 @@ import threading from abc import ABC, abstractmethod from concurrent.futures import Future +from typing import Callable from io import BytesIO from queue import Queue +from typing import TYPE_CHECKING from transformers.utils import logging -from transformers.utils.import_utils import is_openai_available, is_vision_available - -logger = logging.get_logger(__name__) +if TYPE_CHECKING: + import pydantic + import tokenizers + import torch -if is_vision_available(): - from PIL import Image + from transformers import GenerationConfig, PreTrainedModel, PreTrainedTokenizerFast, ProcessorMixin + from transformers.generation.continuous_batching.continuous_api import ContinuousBatchingManager + from transformers.generation.continuous_batching.requests import GenerationOutput -if is_openai_available(): - from openai.types.chat.completion_create_params import CompletionCreateParamsStreaming + from .model_manager import ModelManager - class TransformersCompletionCreateParamsStreaming(CompletionCreateParamsStreaming, total=False): - generation_config: str - processor: str +logger = logging.get_logger(__name__) -# --------------------------------------------------------------------------- -# Constants -# --------------------------------------------------------------------------- X_REQUEST_ID = "x-request-id" -# Fields accepted by the OpenAI schema but not yet supported. -# Receiving these raises an error to avoid silent misbehaviour. -# NOTE: "stop" is NOT in this set — we map it to stop_strings. -UNUSED_CHAT_COMPLETION_FIELDS = { - "audio", - "function_call", - "functions", - "logprobs", - "max_completion_tokens", - "metadata", - "modalities", - "n", - "parallel_tool_calls", - "prediction", - "presence_penalty", - "reasoning_effort", - "response_format", - "service_tier", - "store", - "stream_options", - "tool_choice", - "top_logprobs", - "user", - "web_search_options", -} - - -# --------------------------------------------------------------------------- -# Types -# --------------------------------------------------------------------------- - class Modality(enum.Enum): LLM = "LLM" @@ -106,10 +71,6 @@ class _GenerationCancelled(Exception): -# --------------------------------------------------------------------------- -# Tool call parsing -# --------------------------------------------------------------------------- - # Model-specific tokens that mark the start/end of a tool call block. # TODO: extract these from the chat template at runtime instead of hardcoding. # Qwen/Hermes use /, Mistral uses [TOOL_CALLS], etc. @@ -122,8 +83,16 @@ class _GenerationCancelled(Exception): } -def detect_tool_format(model) -> dict | None: - """Return the tool call token format (``{"start": ..., "end": ...}``) if supported, else ``None``.""" +def detect_tool_format(model: "PreTrainedModel") -> dict | None: + """Return the tool call token format for a model, if supported. + + Args: + model (`PreTrainedModel`): The loaded model. + + Returns: + `dict | None`: A dict ``{"start": str, "end": str}`` with the model's tool call + delimiters, or ``None`` if the model family is not recognized. 
+ """ architecture = model.config.architectures[0].lower() for family in _TOOL_CALL_TOKENS: if family in architecture: @@ -230,11 +199,6 @@ def _parse_block(self, block: str) -> dict | None: return {"name": result[0], "arguments": result[1]} -# --------------------------------------------------------------------------- -# Progress tracking for model loading -# --------------------------------------------------------------------------- - - class DownloadAggregator: """Aggregates byte-progress across multiple concurrent download tqdm bars. @@ -242,7 +206,7 @@ class DownloadAggregator: a single aggregate ``{"stage": "download", "progress": {...}}`` event whenever any updates. """ - def __init__(self, enqueue, model_id: str): + def __init__(self, enqueue: Callable, model_id: str): self.enqueue = enqueue self.model = model_id self.bars: dict[int, tuple[int, int | None]] = {} @@ -276,11 +240,19 @@ def _emit(self): ) -def make_progress_tqdm_class(callback, model_id: str): +def make_progress_tqdm_class(callback: Callable, model_id: str): """Create a tqdm subclass that routes progress to a callback. Bars with ``unit="B"`` are download bars — aggregated via ``DownloadAggregator``. Other bars (e.g. "Loading weights") emit ``weights`` stage events. + + Args: + callback (`callable`): Called with a dict payload + ``{"status": "loading", "model": ..., "stage": ..., "progress": ...}``. + model_id (`str`): The model ID (included in progress payloads). + + Returns: + A tqdm subclass that forwards progress to *callback*. """ from tqdm.auto import tqdm as base_tqdm @@ -339,11 +311,6 @@ def close(self): return ProgressTqdm -# --------------------------------------------------------------------------- -# Streaming -# --------------------------------------------------------------------------- - - class DirectStreamer: """Streamer for ``model.generate()`` (used by :class:`GenerateManager`). @@ -353,7 +320,15 @@ class DirectStreamer: ``DecodeStream`` (O(1) per token) and pushed as text to an asyncio.Queue. """ - def __init__(self, tokenizer, loop, queue, skip_special_tokens: bool = True): + def __init__(self, tokenizer: "tokenizers.Tokenizer", loop: asyncio.AbstractEventLoop, queue: asyncio.Queue, skip_special_tokens: bool = True): + """ + Args: + tokenizer: The Rust tokenizer (``tokenizer._tokenizer``). + loop (`asyncio.AbstractEventLoop`): The event loop to push decoded text to. + queue (`asyncio.Queue`): The queue that receives decoded text chunks. + skip_special_tokens (`bool`, *optional*, defaults to `True`): + Whether to strip special tokens during decoding. + """ from tokenizers.decoders import DecodeStream self._tokenizer = tokenizer @@ -364,7 +339,7 @@ def __init__(self, tokenizer, loop, queue, skip_special_tokens: bool = True): self._cancelled = threading.Event() self.total_tokens = 0 - def put(self, value) -> None: + def put(self, value: "torch.Tensor") -> None: """Called by ``model.generate()`` after each decode step with new token(s).""" if self._cancelled.is_set(): raise _GenerationCancelled() @@ -396,7 +371,15 @@ class CBStreamer: pushes text to the asyncio.Queue. ``end()`` signals the stream is complete. """ - def __init__(self, cb_manager, request_id, tokenizer, loop, queue): + def __init__(self, cb_manager: "ContinuousBatchingManager", request_id: str, tokenizer: "tokenizers.Tokenizer", loop: asyncio.AbstractEventLoop, queue: asyncio.Queue): + """ + Args: + cb_manager (`ContinuousBatchingManager`): The CB manager instance. + request_id (`str`): The request ID to track in the CB scheduler. 
+ tokenizer: The Rust tokenizer (``tokenizer._tokenizer``). + loop (`asyncio.AbstractEventLoop`): The event loop to push decoded text to. + queue (`asyncio.Queue`): The queue that receives decoded text chunks. + """ from tokenizers.decoders import DecodeStream self._cb = cb_manager @@ -408,7 +391,7 @@ def __init__(self, cb_manager, request_id, tokenizer, loop, queue): self._prev_len = 0 self.total_tokens = 0 - def put(self, output) -> None: + def put(self, output: "GenerationOutput") -> None: """Decode new tokens from a CB ``GenerationOutput`` and push text to the queue.""" new_tokens = output.generated_tokens[self._prev_len :] self._prev_len = len(output.generated_tokens) @@ -427,11 +410,6 @@ def cancel(self) -> None: self._cb.cancel_request(self._request_id) -# --------------------------------------------------------------------------- -# Torch helpers -# --------------------------------------------------------------------------- - - def set_torch_seed(seed: int) -> None: import torch @@ -486,11 +464,6 @@ def async_submit(self, fn, *args, **kwargs) -> asyncio.Future: return future -# --------------------------------------------------------------------------- -# Generation managers -# --------------------------------------------------------------------------- - - class BaseGenerateManager(ABC): """Base class for generation managers. @@ -500,16 +473,36 @@ class BaseGenerateManager(ABC): """ @abstractmethod - def generate_streaming(self, model, processor, inputs, gen_config, request_id=None): + def generate_streaming(self, model: "PreTrainedModel", processor: "ProcessorMixin | PreTrainedTokenizerFast", inputs: dict, gen_config: "GenerationConfig", request_id: str | None = None): """Start streaming generation. - Returns ``(queue, context)`` where *queue* yields ``str | _StreamError | None`` - and *context* exposes ``.total_tokens`` and ``.cancel()``. + Args: + model (`PreTrainedModel`): The loaded model. + processor: The processor or tokenizer for decoding. + inputs (`dict`): Tokenized inputs (tensors for sequential, lists for CB). + gen_config (`GenerationConfig`): Generation parameters. + request_id (`str`, *optional*): Unique request identifier. + + Returns: + `tuple[asyncio.Queue, DirectStreamer | CBStreamer]`: A ``(queue, streamer)`` pair + where *queue* yields ``str | _StreamError | None`` and *streamer* exposes + ``.total_tokens`` and ``.cancel()``. """ @abstractmethod - def generate_non_streaming(self, model, processor, inputs, gen_config, request_id=None): - """Run generation to completion. Returns ``(text, input_len, generated_ids)``.""" + def generate_non_streaming(self, model: "PreTrainedModel", processor: "ProcessorMixin | PreTrainedTokenizerFast", inputs: dict, gen_config: "GenerationConfig", request_id: str | None = None): + """Run generation to completion. + + Args: + model (`PreTrainedModel`): The loaded model. + processor: The processor or tokenizer for decoding. + inputs (`dict`): Tokenized inputs (tensors for sequential, lists for CB). + gen_config (`GenerationConfig`): Generation parameters. + request_id (`str`, *optional*): Unique request identifier. + + Returns: + `tuple[str, int, list[int]]`: ``(text, input_len, generated_ids)``. 
+ """ @abstractmethod def stop(self): @@ -522,7 +515,7 @@ class GenerateManager(BaseGenerateManager): def __init__(self): self._thread = InferenceThread() - def generate_streaming(self, model, processor, inputs, gen_config, request_id=None): + def generate_streaming(self, model: "PreTrainedModel", processor: "ProcessorMixin | PreTrainedTokenizerFast", inputs: dict, gen_config: "GenerationConfig", request_id: str | None = None): loop = asyncio.get_running_loop() queue: asyncio.Queue = asyncio.Queue() streamer = DirectStreamer(processor._tokenizer, loop, queue, skip_special_tokens=True) @@ -539,7 +532,7 @@ def _run(): self.submit(_run) return queue, streamer - async def generate_non_streaming(self, model, processor, inputs, gen_config, request_id=None): + async def generate_non_streaming(self, model: "PreTrainedModel", processor: "ProcessorMixin | PreTrainedTokenizerFast", inputs: dict, gen_config: "GenerationConfig", request_id: str | None = None): sequences = await self.async_submit( model.generate, **inputs, generation_config=gen_config, tokenizer=processor ) @@ -578,10 +571,14 @@ class CBGenerateManager(BaseGenerateManager): def __init__(self): self._cb = None - def init_cb(self, model, gen_config): + def init_cb(self, model: "PreTrainedModel", gen_config: "GenerationConfig"): """Initialize the CB manager on first call with the request's generation config. .. todo:: Remove when CB supports per-request generation config. + + Args: + model (`PreTrainedModel`): The loaded model (must support ``init_continuous_batching``). + gen_config (`GenerationConfig`): Generation config used for shared sampling params. """ if self._cb is not None: return @@ -592,7 +589,7 @@ def init_cb(self, model, gen_config): self._cb.logit_processor = LogitsProcessorList() self._cb.start() - def generate_streaming(self, model, processor, inputs, gen_config, request_id=None): + def generate_streaming(self, model: "PreTrainedModel", processor: "ProcessorMixin | PreTrainedTokenizerFast", inputs: dict, gen_config: "GenerationConfig", request_id: str | None = None): loop = asyncio.get_running_loop() text_queue: asyncio.Queue = asyncio.Queue() @@ -623,7 +620,7 @@ async def _read_and_decode(): asyncio.ensure_future(_read_and_decode()) return text_queue, streamer - async def generate_non_streaming(self, model, processor, inputs, gen_config, request_id=None): + async def generate_non_streaming(self, model: "PreTrainedModel", processor: "ProcessorMixin | PreTrainedTokenizerFast", inputs: dict, gen_config: "GenerationConfig", request_id: str | None = None): """Non-streaming CB generation, fully async (no per-request thread). Uses ``register_async_future`` — the dispatcher resolves a single @@ -664,11 +661,6 @@ def stop(self): self._cb.stop(block=True, timeout=2) -# --------------------------------------------------------------------------- -# Generation state (shared across handlers) -# --------------------------------------------------------------------------- - - class GenerationState: """Shared generation state across all handlers. @@ -676,6 +668,11 @@ class GenerationState: :class:`InferenceThread` so different models can run concurrently while ``torch.compile`` / CUDA graphs require same-model-same-thread) and a single :class:`CBGenerateManager` for continuous batching. + + Args: + continuous_batching (`bool`, *optional*, defaults to `False`): + Whether to use continuous batching with paged attention instead of + sequential ``model.generate()`` calls. 
""" def __init__(self, continuous_batching: bool = False): @@ -684,8 +681,16 @@ def __init__(self, continuous_batching: bool = False): self._cb_manager: CBGenerateManager | None = None self._cb_model_id: str | None = None - def use_continuous_batching(self, model, modality: Modality) -> bool: - """Check if CB can be used. Logs a warning on fallback.""" + def use_continuous_batching(self, model: "PreTrainedModel", modality: Modality) -> bool: + """Check if continuous batching can be used for this model and modality. + + Args: + model (`PreTrainedModel`): The loaded model. + modality (`Modality`): The detected model modality (LLM, VLM, etc.). + + Returns: + `bool`: ``True`` if CB is enabled and the model supports it, ``False`` otherwise. + """ if not self._continuous_batching: return False can = hasattr(model, "init_continuous_batching") and modality == Modality.LLM @@ -697,7 +702,15 @@ def use_continuous_batching(self, model, modality: Modality) -> bool: return can def get_manager(self, model_id: str, use_cb: bool) -> BaseGenerateManager: - """Return a per-model generation manager, lazily created on first request.""" + """Return a per-model generation manager, lazily created on first request. + + Args: + model_id (`str`): The model ID in ``'model_id@revision'`` format. + use_cb (`bool`): Whether to return a CB manager or a sequential one. + + Returns: + `BaseGenerateManager`: Either a `GenerateManager` or `CBGenerateManager`. + """ if use_cb: if self._cb_model_id != model_id: if self._cb_manager is not None: @@ -718,25 +731,32 @@ def shutdown(self): self._cb_manager = None -# --------------------------------------------------------------------------- -# Base handler -# --------------------------------------------------------------------------- - - class BaseHandler: """Shared logic for chat completion and responses handlers. Provides model resolution, generation config building, and SSE formatting. Generation is delegated to the shared :class:`GenerationState`. + + Args: + model_manager (`ModelManager`): + Handles model loading, caching, and lifecycle. + generation_state (`GenerationState`): + Shared state managing per-model generation managers. + force_model (`str`, *optional*): + If set, override the ``model`` field in every request with this model ID. + force_processor (`str`, *optional*): + If set, override the processor/tokenizer model ID. + compile (`bool`, *optional*, defaults to `False`): + Enable ``torch.compile`` with static cache for faster decode. 
""" def __init__( self, - model_manager, + model_manager: "ModelManager", generation_state: GenerationState, - force_model=None, - force_processor=None, - compile=False, + force_model: str | None = None, + force_processor: str | None = None, + compile: bool = False, ): self.model_manager = model_manager self.generation_state = generation_state @@ -745,7 +765,7 @@ def __init__( self._compile = compile @staticmethod - def chunk_to_sse(chunk) -> str: + def chunk_to_sse(chunk: "str | pydantic.BaseModel") -> str: """Format a pydantic model or string as an SSE ``data:`` line.""" if isinstance(chunk, str): return chunk if chunk.startswith("data: ") else f"data: {chunk}\n\n" @@ -765,11 +785,26 @@ def _resolve_model(self, body: dict): return model_id, model, processor - def _build_generation_config(self, body: dict, model_generation_config, processor=None, use_cb=False): + def _build_generation_config(self, body: dict, model_generation_config: "GenerationConfig", processor: "ProcessorMixin | PreTrainedTokenizerFast | None" = None, use_cb: bool = False): """Build a GenerationConfig from shared params (temperature, top_p, seed, generation_config JSON). Subclasses should call ``super()._build_generation_config(...)`` then apply endpoint-specific params (``max_tokens``, ``max_output_tokens``, etc.). + + Args: + body (`dict`): + The raw request body. + model_generation_config (`GenerationConfig`): + The model's default generation config (will be deep-copied). + processor (*optional*): + Processor or tokenizer, used to sync ``eos_token_id`` / ``pad_token_id`` + for GGUF models that lack them. + use_cb (`bool`, *optional*, defaults to `False`): + Whether continuous batching is active. If ``True``, disables the model's + internal KV cache (CB manages its own paged cache). + + Returns: + `GenerationConfig`: A new config with request-specific overrides applied. """ from transformers import GenerationConfig @@ -807,42 +842,51 @@ def _build_generation_config(self, body: dict, model_generation_config, processo return generation_config + @staticmethod + def get_processor_inputs_from_messages(messages: list[dict], modality: Modality) -> list[dict]: + """Convert OpenAI-format messages to the format expected by HF processors. -# --------------------------------------------------------------------------- -# Message preprocessing: OpenAI messages → processor-compatible format -# --------------------------------------------------------------------------- + For LLMs, collapses list content blocks into plain text. For VLMs, converts + ``image_url`` content parts (including base64) into ``{"type": "image", "url": ...}`` + entries that HF processors understand. + Args: + messages (`list[dict]`): OpenAI-format chat messages. + modality (`Modality`): Whether the model is an LLM or VLM. -def get_processor_inputs_from_messages(messages: list[dict], modality: Modality) -> list[dict]: - """Convert OpenAI-format messages to the format expected by HF processors.""" - processor_inputs = [] + Returns: + `list[dict]`: Processor-compatible messages. 
+ """ + processor_inputs = [] - for message in messages: - parsed = {"role": message["role"], "content": []} + for message in messages: + parsed = {"role": message["role"], "content": []} - if modality == Modality.LLM: - if isinstance(message["content"], str): - parsed["content"] = message["content"] - elif isinstance(message["content"], list): - texts = [c["text"] for c in message["content"] if c["type"] == "text"] - parsed["content"] = " ".join(texts) + if modality == Modality.LLM: + if isinstance(message["content"], str): + parsed["content"] = message["content"] + elif isinstance(message["content"], list): + texts = [c["text"] for c in message["content"] if c["type"] == "text"] + parsed["content"] = " ".join(texts) - elif modality == Modality.VLM: - if isinstance(message["content"], str): - parsed["content"].append({"type": "text", "text": message["content"]}) - else: - for content in message["content"]: - if content["type"] == "text": - parsed["content"].append(content) - elif content["type"] == "image_url": - url = content["image_url"]["url"] - if "base64" in url: - image_data = re.sub("^data:image/.+;base64,", "", url) - image = Image.open(BytesIO(base64.b64decode(image_data))) - file = tempfile.NamedTemporaryFile(suffix=".png", delete=False) - image.save(file.name) - url = file.name - parsed["content"].append({"type": "image", "url": url}) - - processor_inputs.append(parsed) - return processor_inputs + elif modality == Modality.VLM: + if isinstance(message["content"], str): + parsed["content"].append({"type": "text", "text": message["content"]}) + else: + for content in message["content"]: + if content["type"] == "text": + parsed["content"].append(content) + elif content["type"] == "image_url": + from PIL import Image + + url = content["image_url"]["url"] + if "base64" in url: + image_data = re.sub("^data:image/.+;base64,", "", url) + image = Image.open(BytesIO(base64.b64decode(image_data))) + file = tempfile.NamedTemporaryFile(suffix=".png", delete=False) + image.save(file.name) + url = file.name + parsed["content"].append({"type": "image", "url": url}) + + processor_inputs.append(parsed) + return processor_inputs From 4abb194f31d538b8c98b3e60b4d24c6111ce6e5f Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Thu, 26 Mar 2026 21:05:21 +0000 Subject: [PATCH 26/64] update test --- tests/cli/test_serve_refactored.py | 42 +++++++++++++++++++++--------- 1 file changed, 29 insertions(+), 13 deletions(-) diff --git a/tests/cli/test_serve_refactored.py b/tests/cli/test_serve_refactored.py index 5f507c5dbb66..473bc243bb25 100644 --- a/tests/cli/test_serve_refactored.py +++ b/tests/cli/test_serve_refactored.py @@ -70,28 +70,36 @@ def test_host_port_blocking(cli): class TestProcessorInputsFromMessages(unittest.TestCase): def test_llm_string_content(self): - from transformers.cli.serving.utils import Modality, get_processor_inputs_from_messages + from transformers.cli.serving.utils import BaseHandler, Modality + + get_processor_inputs_from_messages = BaseHandler.get_processor_inputs_from_messages messages = [{"role": "user", "content": "Hello"}] result = get_processor_inputs_from_messages(messages, Modality.LLM) self.assertEqual(result, [{"role": "user", "content": "Hello"}]) def test_llm_list_content_text_only(self): - from transformers.cli.serving.utils import Modality, get_processor_inputs_from_messages + from transformers.cli.serving.utils import BaseHandler, Modality + + get_processor_inputs_from_messages = BaseHandler.get_processor_inputs_from_messages messages = [{"role": "user", "content": 
[{"type": "text", "text": "A"}, {"type": "text", "text": "B"}]}] result = get_processor_inputs_from_messages(messages, Modality.LLM) self.assertEqual(result, [{"role": "user", "content": "A B"}]) def test_vlm_string_content_wrapped(self): - from transformers.cli.serving.utils import Modality, get_processor_inputs_from_messages + from transformers.cli.serving.utils import BaseHandler, Modality + + get_processor_inputs_from_messages = BaseHandler.get_processor_inputs_from_messages messages = [{"role": "user", "content": "Hello"}] result = get_processor_inputs_from_messages(messages, Modality.VLM) self.assertEqual(result, [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}]) def test_vlm_text_and_image_url(self): - from transformers.cli.serving.utils import Modality, get_processor_inputs_from_messages + from transformers.cli.serving.utils import BaseHandler, Modality + + get_processor_inputs_from_messages = BaseHandler.get_processor_inputs_from_messages messages = [ { @@ -109,7 +117,9 @@ def test_vlm_text_and_image_url(self): def test_llm_multi_turn_conversation(self): """Multi-turn conversation with string content should pass through as-is.""" - from transformers.cli.serving.utils import Modality, get_processor_inputs_from_messages + from transformers.cli.serving.utils import BaseHandler, Modality + + get_processor_inputs_from_messages = BaseHandler.get_processor_inputs_from_messages messages = [ {"role": "user", "content": "How are you?"}, @@ -124,7 +134,9 @@ def test_llm_multi_turn_conversation(self): def test_llm_list_content_with_type(self): """LLM messages with typed content list should extract text and join.""" - from transformers.cli.serving.utils import Modality, get_processor_inputs_from_messages + from transformers.cli.serving.utils import BaseHandler, Modality + + get_processor_inputs_from_messages = BaseHandler.get_processor_inputs_from_messages messages = [ {"role": "user", "content": [{"type": "text", "text": "Hello"}, {"type": "text", "text": "world"}]} @@ -137,7 +149,9 @@ def test_vlm_base64_image_creates_temp_file(self): """Base64 image URLs should be decoded and saved to a temp file.""" import os - from transformers.cli.serving.utils import Modality, get_processor_inputs_from_messages + from transformers.cli.serving.utils import BaseHandler, Modality + + get_processor_inputs_from_messages = BaseHandler.get_processor_inputs_from_messages # Minimal valid 1x1 PNG as base64 base64_url = ( @@ -160,7 +174,9 @@ def test_vlm_base64_image_creates_temp_file(self): def test_vlm_multi_turn(self): """VLM multi-turn: string content should be wrapped in text type.""" - from transformers.cli.serving.utils import Modality, get_processor_inputs_from_messages + from transformers.cli.serving.utils import BaseHandler, Modality + + get_processor_inputs_from_messages = BaseHandler.get_processor_inputs_from_messages messages = [ {"role": "user", "content": "Describe the image"}, @@ -1046,10 +1062,10 @@ def test_default_bumps_short_max_new_tokens(self): @require_openai class TestResponseUsage(unittest.TestCase): - def test_make_usage(self): - from transformers.cli.serving.response import _make_usage + def testcompute_usage(self): + from transformers.cli.serving.response import compute_usage - usage = _make_usage(input_tokens=100, output_tokens=50) + usage = compute_usage(input_tokens=100, output_tokens=50) self.assertEqual(usage.input_tokens, 100) self.assertEqual(usage.output_tokens, 50) self.assertEqual(usage.total_tokens, 150) @@ -1060,9 +1076,9 @@ def 
test_usage_in_completed_response(self): """Usage should serialize correctly inside a Response.""" from openai.types.responses import Response - from transformers.cli.serving.response import _make_usage + from transformers.cli.serving.response import compute_usage - usage = _make_usage(10, 5) + usage = compute_usage(10, 5) response = Response( id="resp_test", created_at=0, From 165898144a9769e08f8da684ac9d55d9ac1176b2 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Thu, 26 Mar 2026 22:23:19 +0000 Subject: [PATCH 27/64] better benchmark --- tests/cli/bench_cb_raw.py | 185 +++++++++++++++++++++++ tests/cli/benchmark_serve.py | 26 +++- tests/cli/benchmark_serve_load.py | 239 +++++++++++++++--------------- 3 files changed, 323 insertions(+), 127 deletions(-) create mode 100644 tests/cli/bench_cb_raw.py diff --git a/tests/cli/bench_cb_raw.py b/tests/cli/bench_cb_raw.py new file mode 100644 index 000000000000..d694c65bfe11 --- /dev/null +++ b/tests/cli/bench_cb_raw.py @@ -0,0 +1,185 @@ +""" +Raw continuous batching benchmark — no HTTP, no serve layer. +2x2 matrix: {non_stream, stream} × {legacy get_result, optimized async}. + +Usage: + CUDA_VISIBLE_DEVICES=0 python tests/cli/bench_cb_raw.py + CUDA_VISIBLE_DEVICES=0 python tests/cli/bench_cb_raw.py --batch 10 50 100 500 1000 2000 +""" + +import argparse +import asyncio +import os +import sys +import time + +os.environ["CUDA_VISIBLE_DEVICES"] = "0" + +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer, ContinuousBatchingConfig, GenerationConfig + + +def make_prompts(tokenizer, n, target_len=256): + filler = "The quick brown fox jumps over the lazy dog. Pack my box with five dozen liquor jugs. " * 100 + ids = tokenizer.encode(filler, add_special_tokens=False) + return [ids[:max(10, int(target_len * (0.8 + 0.4 * (i % 5) / 4)))] for i in range(n)] + + +# --------------------------------------------------------------------------- +# Non-streaming (CB streaming=False → one output per request when finished) +# --------------------------------------------------------------------------- + + +def bench_ns_get_result(mgr, prompts, max_new_tokens): + """Non-stream + get_result: batch add, poll shared queue.""" + N = len(prompts) + t0 = time.perf_counter() + mgr.add_requests(inputs=prompts, max_new_tokens=max_new_tokens, streaming=False) + total = finished = 0 + while finished < N: + r = mgr.get_result(timeout=1) + if r and r.is_finished(): + total += len(r.generated_tokens) + finished += 1 + return total, time.perf_counter() - t0 + + +async def bench_ns_future(mgr, prompts, max_new_tokens): + """Non-stream + future: one asyncio.Future per request, resolved by dispatcher.""" + t0 = time.perf_counter() + futures = [] + for i, ids in enumerate(prompts): + rid = f"nsf_{time.perf_counter_ns()}_{i}" + future = mgr.register_async_future(rid) + mgr.add_request(ids, request_id=rid, max_new_tokens=max_new_tokens, streaming=False) + futures.append(future) + results = await asyncio.gather(*futures) + return sum(len(r.generated_tokens) for r in results), time.perf_counter() - t0 + + +# --------------------------------------------------------------------------- +# Streaming (CB streaming=True → one output per token per request) +# --------------------------------------------------------------------------- + + +def bench_s_get_result(mgr, prompts, max_new_tokens): + """Stream + get_result: batch add, poll shared queue, skip intermediate outputs.""" + N = len(prompts) + t0 = time.perf_counter() + mgr.add_requests(inputs=prompts, 
max_new_tokens=max_new_tokens, streaming=True) + total = finished = 0 + while finished < N: + r = mgr.get_result(timeout=1) + if r is not None and r.is_finished(): + total += len(r.generated_tokens) + finished += 1 + return total, time.perf_counter() - t0 + + +async def bench_s_async_iter(mgr, prompts, max_new_tokens): + """Stream + async_request_id_iter: per-request async queue via dispatcher.""" + t0 = time.perf_counter() + rids = [] + for i, ids in enumerate(prompts): + rid = f"sai_{time.perf_counter_ns()}_{i}" + mgr.add_request(ids, request_id=rid, max_new_tokens=max_new_tokens, streaming=True) + rids.append(rid) + + async def consume(rid): + async for output in mgr.async_request_id_iter(rid): + if output.is_finished(): + return len(output.generated_tokens) + return 0 + + results = await asyncio.gather(*[consume(rid) for rid in rids]) + return sum(results), time.perf_counter() - t0 + + +# --------------------------------------------------------------------------- +# Runner +# --------------------------------------------------------------------------- + +METHODS = { + "ns_get_result": ("Non-stream + get_result", lambda mgr, p, m: bench_ns_get_result(mgr, p, m)), + "ns_future": ("Non-stream + future", lambda mgr, p, m: asyncio.run(bench_ns_future(mgr, p, m))), + "s_get_result": ("Stream + get_result", lambda mgr, p, m: bench_s_get_result(mgr, p, m)), + "s_async_iter": ("Stream + async_iter", lambda mgr, p, m: asyncio.run(bench_s_async_iter(mgr, p, m))), +} + + +def main(): + parser = argparse.ArgumentParser(description="Raw CB benchmark (2x2 matrix)") + parser.add_argument("--model", type=str, default="Qwen/Qwen2.5-7B-Instruct") + parser.add_argument("--batch", type=int, nargs="+", default=[10, 50, 100, 500]) + parser.add_argument("--max-new-tokens", type=int, default=64) + parser.add_argument("--prompt-tokens", type=int, default=256) + parser.add_argument("--warmup", type=int, default=2) + parser.add_argument("--runs", type=int, default=3) + parser.add_argument("--methods", type=str, nargs="+", + default=list(METHODS.keys()), choices=list(METHODS.keys())) + args = parser.parse_args() + + print(f"Model: {args.model}") + print(f"Batch: {args.batch} | Prompt: ~{args.prompt_tokens} tok | Gen: {args.max_new_tokens} tok") + print(f"Warmup: {args.warmup} | Runs: {args.runs} | Methods: {args.methods}") + sys.stdout.flush() + + model = AutoModelForCausalLM.from_pretrained( + args.model, dtype=torch.bfloat16, attn_implementation="flash_attention_3", + ).cuda().eval() + tokenizer = AutoTokenizer.from_pretrained(args.model) + all_prompts = make_prompts(tokenizer, max(args.batch), args.prompt_tokens) + + gen_config = GenerationConfig(max_new_tokens=args.max_new_tokens, do_sample=False) + cb_config = ContinuousBatchingConfig() + + # Header + col_w = 20 + header = f"{'N':>6}" + for m in args.methods: + label = METHODS[m][0] + header += f" | {label:>{col_w}}" + print(f"\n{header}") + print("-" * len(header)) + sys.stdout.flush() + + # Per-batch-size context: each N gets fresh CUDA graph capture + for N in args.batch: + prompts = all_prompts[:N] + + with model.continuous_batching_context_manager( + generation_config=gen_config, continuous_batching_config=cb_config, block=True, timeout=5, + ) as mgr: + # Warmup for this batch size + warmup_prompts = prompts[:min(200, N)] + for _ in range(args.warmup): + bench_ns_get_result(mgr, warmup_prompts, args.max_new_tokens) + + row = f"{N:>6}" + for method_key in args.methods: + _, fn = METHODS[method_key] + best = 0 + for _ in range(args.runs): + tokens, dt = 
fn(mgr, prompts, args.max_new_tokens) + best = max(best, tokens / dt if dt > 0 else 0) + row += f" | {best:>{col_w - 4}.0f} t/s" + print(row, flush=True) + + # Quality check + print("\n--- Quality check ---") + with model.continuous_batching_context_manager( + generation_config=gen_config, continuous_batching_config=cb_config, block=True, timeout=5, + ) as mgr: + async def check(): + for i in range(3): + rid = f"qc_{i}" + future = mgr.register_async_future(rid) + mgr.add_request(all_prompts[i], request_id=rid, max_new_tokens=args.max_new_tokens, streaming=False) + r = await future + text = tokenizer.decode(r.generated_tokens, skip_special_tokens=True)[:80] + print(f" {r.request_id}: {len(r.generated_tokens)} tokens | {text}") + asyncio.run(check()) + + +if __name__ == "__main__": + main() diff --git a/tests/cli/benchmark_serve.py b/tests/cli/benchmark_serve.py index 69386ec5faa7..7782bd64c6b8 100644 --- a/tests/cli/benchmark_serve.py +++ b/tests/cli/benchmark_serve.py @@ -180,6 +180,7 @@ def streaming_response( t_start = time.perf_counter() t_first_token = None + completion_tokens = None text_chunks = [] resp = requests.post(f"{base_url}/v1/responses", json=payload, stream=True, timeout=300) @@ -201,6 +202,8 @@ def streaming_response( if t_first_token is None: t_first_token = time.perf_counter() elif etype == "response.completed": + usage = chunk.get("response", {}).get("usage", {}) + completion_tokens = usage.get("output_tokens") break t_end = time.perf_counter() @@ -209,7 +212,7 @@ def streaming_response( return { "total": t_end - t_start, "ttft": (t_first_token - t_start) if t_first_token else None, - "completion_tokens": len(text_chunks), # approximate — one chunk per streamer push + "completion_tokens": completion_tokens, "text": text, } @@ -400,7 +403,7 @@ def make_sep(char="-"): def start_server( model: str, port: int, processor: str | None = None, attn_implementation: str | None = None, - compile: bool = False, + compile: bool = False, continuous_batching: bool = False, ): """Start a transformers serve instance. Returns the Serve object.""" from transformers.cli.serve_refactored import Serve @@ -412,6 +415,8 @@ def start_server( kwargs["attn_implementation"] = attn_implementation if compile: kwargs["compile"] = True + if continuous_batching: + kwargs["continuous_batching"] = True return Serve(**kwargs) @@ -462,6 +467,8 @@ def main(): help="Attention implementations to benchmark (default: sdpa eager flash_attention_2)") parser.add_argument("--compile", action="store_true", help="Enable static cache + torch.compile on the server for faster decode") + parser.add_argument("--continuous-batching", action="store_true", + help="Enable continuous batching with paged attention") parser.add_argument("--mode", type=str, choices=["bench", "chat"], default="bench", help="bench: greedy (temp=0). chat: sampling (do_sample=True, temp=0.7)") parser.add_argument("--endpoint", type=str, choices=["chat", "responses"], default="responses", @@ -506,7 +513,7 @@ def main(): print(f"\nStarting server for {spec['model']} (attn={attn_impl})...") try: server = start_server(spec["model"], args.port, spec["processor"], attn_implementation=attn_impl, - compile=args.compile) + compile=args.compile, continuous_batching=args.continuous_batching) except Exception as e: print(f" ERROR: Failed to start server with attn={attn_impl}: {e}. 
Skipping.") continue @@ -519,8 +526,17 @@ def main(): tokenizer = AutoTokenizer.from_pretrained(spec["tokenizer"]) - # Warmup (always dynamic cache — static cache compiles shapes, so a short warmup would break longer requests) - streaming_request(base_url, [{"role": "user", "content": "hi"}], max_tokens=5, seed=args.seed, endpoint=endpoint) + # Warmup — use non-streaming when compile is on (first compile call takes ~30s, + # streaming warmup can hang waiting for SSE chunks during compilation) + if args.compile: + warmup_prompt = make_prompt(tokenizer, max(args.pp + [args.tg_prefill])) + gen_cfg = {"max_new_tokens": max(args.tg), "do_sample": False} + payload = {"messages": [{"role": "user", "content": warmup_prompt}], "stream": False, + "seed": args.seed, "generation_config": json.dumps(gen_cfg)} + print(" compile warmup (non-streaming, may take ~30s)...") + requests.post(f"{base_url}/v1/chat/completions", json=payload, timeout=120) + else: + streaming_request(base_url, [{"role": "user", "content": "hi"}], max_tokens=5, seed=args.seed, endpoint=endpoint) rows = [] for pp in args.pp: diff --git a/tests/cli/benchmark_serve_load.py b/tests/cli/benchmark_serve_load.py index 65c959818bbb..737af09c239e 100644 --- a/tests/cli/benchmark_serve_load.py +++ b/tests/cli/benchmark_serve_load.py @@ -6,18 +6,16 @@ - What's the latency distribution (p50/p90/p99) as concurrency increases? - Does the server stay stable under pressure? -Modes: - --max-concurrency N Send requests with up to N in flight at once - --request-rate R Send R requests/sec (Poisson arrival), let them queue naturally +Each --max-concurrency value sends that many requests simultaneously. Examples: - # Sweep concurrency levels (1, 2, 4, 8) + # Sweep concurrency levels python tests/cli/benchmark_serve_load.py --model Qwen/Qwen2.5-7B-Instruct \\ - --max-concurrency 1 2 4 8 --num-requests 32 + --max-concurrency 1 4 8 32 --continuous-batching - # Fixed request rate + # 500 concurrent non-streaming requests python tests/cli/benchmark_serve_load.py --model Qwen/Qwen2.5-7B-Instruct \\ - --request-rate 5.0 --num-requests 50 + --max-concurrency 500 --continuous-batching --no-stream # Against an existing server python tests/cli/benchmark_serve_load.py --url http://localhost:8000 \\ @@ -33,7 +31,7 @@ import time -os.environ["CUDA_VISIBLE_DEVICES"] = "0" +os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0") import aiohttp @@ -79,23 +77,27 @@ async def send_request( max_new_tokens: int, seed: int, endpoint: str = "responses", + model: str | None = None, + stream: bool = True, ) -> dict: - """Send a single streaming request and collect timing metrics.""" + """Send a single request and collect timing metrics.""" gen_cfg = {"max_new_tokens": max_new_tokens, "do_sample": False} if endpoint == "responses": url = f"{base_url}/v1/responses" payload = { + "model": model, "input": [{"role": "user", "content": prompt}], - "stream": True, + "stream": stream, "seed": seed, "generation_config": json.dumps(gen_cfg), } else: url = f"{base_url}/v1/chat/completions" payload = { + "model": model, "messages": [{"role": "user", "content": prompt}], - "stream": True, + "stream": stream, "seed": seed, "generation_config": json.dumps(gen_cfg), } @@ -104,61 +106,85 @@ async def send_request( t_first_token = None token_times = [] text_chunks = [] + non_streaming_tokens = 0 error = None try: async with session.post(url, json=payload, timeout=aiohttp.ClientTimeout(total=300)) as resp: if resp.status != 200: - error = f"HTTP {resp.status}" + error = f"HTTP {resp.status}: {await 
resp.text()}" return _make_result(t_start, error=error) - async for line in resp.content: - line = line.decode("utf-8").strip() - if not line or not line.startswith("data: "): - continue - data_str = line[len("data: "):] - if data_str.strip() == "[DONE]": - break - try: - chunk = json.loads(data_str) - except json.JSONDecodeError: - continue - - # Extract token content based on endpoint format - has_content = False + if not stream: + # Non-streaming: single JSON response — get token count from usage + body = await resp.json() + t_first_token = time.perf_counter() if endpoint == "responses": - if chunk.get("type") == "response.output_text.delta": - delta = chunk.get("delta", "") - if delta: - text_chunks.append(delta) - has_content = True - elif chunk.get("type") == "response.completed": - break + output_tokens = body.get("usage", {}).get("output_tokens", 0) + for item in body.get("output", []): + if item.get("type") == "message": + for part in item.get("content", []): + if part.get("type") == "output_text": + text_chunks.append(part.get("text", "")) else: - choices = chunk.get("choices", []) - if choices: - content = choices[0].get("delta", {}).get("content") - if content is not None and content != "": + output_tokens = body.get("usage", {}).get("completion_tokens", 0) + for choice in body.get("choices", []): + content = choice.get("message", {}).get("content", "") + if content: text_chunks.append(content) - has_content = True - if choices[0].get("finish_reason") is not None: + # Use server-reported token count instead of len(text_chunks) + non_streaming_tokens = output_tokens + token_times.append(t_first_token) + else: + # Streaming: parse SSE events + async for line in resp.content: + line = line.decode("utf-8").strip() + if not line or not line.startswith("data: "): + continue + data_str = line[len("data: "):] + if data_str.strip() == "[DONE]": + break + try: + chunk = json.loads(data_str) + except json.JSONDecodeError: + continue + + # Extract token content based on endpoint format + has_content = False + if endpoint == "responses": + if chunk.get("type") == "response.output_text.delta": + delta = chunk.get("delta", "") + if delta: + text_chunks.append(delta) + has_content = True + elif chunk.get("type") == "response.completed": break - - if has_content: - now = time.perf_counter() - token_times.append(now) - if t_first_token is None: - t_first_token = now + else: + choices = chunk.get("choices", []) + if choices: + content = choices[0].get("delta", {}).get("content") + if content is not None and content != "": + text_chunks.append(content) + has_content = True + if choices[0].get("finish_reason") is not None: + break + + if has_content: + now = time.perf_counter() + token_times.append(now) + if t_first_token is None: + t_first_token = now except asyncio.TimeoutError: error = "timeout" except Exception as e: error = str(e) - return _make_result(t_start, t_first_token, token_times, text_chunks, error) + output_token_count = non_streaming_tokens if not stream else None + return _make_result(t_start, t_first_token, token_times, text_chunks, error, output_token_count=output_token_count) -def _make_result(t_start, t_first_token=None, token_times=None, text_chunks=None, error=None): +def _make_result(t_start, t_first_token=None, token_times=None, text_chunks=None, error=None, output_token_count=None): t_end = time.perf_counter() token_times = token_times or [] text_chunks = text_chunks or [] @@ -173,7 +199,7 @@ def _make_result(t_start, t_first_token=None, token_times=None, text_chunks=None 
"ttft": (t_first_token - t_start) if t_first_token else None, "tpot": statistics.mean(itl) if itl else None, # time per output token "itl": itl, - "output_tokens": len(text_chunks), + "output_tokens": output_token_count if output_token_count is not None else len(text_chunks), "text": "".join(text_chunks), "error": error, } @@ -188,53 +214,21 @@ async def run_concurrency_test( base_url: str, prompts: list[str], max_new_tokens: int, - max_concurrency: int, seed: int, endpoint: str, + model: str | None = None, + stream: bool = True, ) -> list[dict]: - """Send all requests with a concurrency limit via semaphore.""" - semaphore = asyncio.Semaphore(max_concurrency) - results = [] - - async def _limited(session, prompt): - async with semaphore: - return await send_request(session, base_url, prompt, max_new_tokens, seed, endpoint) - - async with aiohttp.ClientSession() as session: - tasks = [_limited(session, p) for p in prompts] + """Send all prompts concurrently and collect results.""" + connector = aiohttp.TCPConnector(limit=0) + timeout = aiohttp.ClientTimeout(total=600) + async with aiohttp.ClientSession(connector=connector, timeout=timeout) as session: + tasks = [send_request(session, base_url, p, max_new_tokens, seed, endpoint, model=model, stream=stream) for p in prompts] results = await asyncio.gather(*tasks) return list(results) -async def run_rate_test( - base_url: str, - prompts: list[str], - max_new_tokens: int, - request_rate: float, - seed: int, - endpoint: str, -) -> list[dict]: - """Send requests at a target rate using Poisson inter-arrival times.""" - results = [] - tasks = [] - - async with aiohttp.ClientSession() as session: - for i, prompt in enumerate(prompts): - task = asyncio.create_task( - send_request(session, base_url, prompt, max_new_tokens, seed, endpoint) - ) - tasks.append(task) - - # Poisson inter-arrival: exponential delay - if i < len(prompts) - 1: - delay = random.expovariate(request_rate) - await asyncio.sleep(delay) - - results = await asyncio.gather(*tasks) - - return list(results) - # --------------------------------------------------------------------------- # Metrics @@ -361,12 +355,17 @@ def wait_for_server(base_url: str, timeout: int = 120) -> bool: return False -def start_server(model: str, port: int, compile: bool = False): +def start_server(model: str, port: int, compile: bool = False, continuous_batching: bool = False, + attn_implementation: str | None = None): from transformers.cli.serve_refactored import Serve kwargs = {"force_model": model, "port": port, "non_blocking": True, "log_level": "warning"} if compile: kwargs["compile"] = True + if continuous_batching: + kwargs["continuous_batching"] = True + if attn_implementation: + kwargs["attn_implementation"] = attn_implementation return Serve(**kwargs) @@ -379,9 +378,14 @@ async def async_main(args): base_url = args.url if args.url else f"http://localhost:{args.port}" server = None + # Default to flash_attention_3 when using continuous batching + if args.continuous_batching and args.attn_impl is None: + args.attn_impl = "flash_attention_3" + if not args.url: print(f"Starting server for {args.model}...") - server = start_server(args.model, args.port, compile=args.compile) + server = start_server(args.model, args.port, compile=args.compile, continuous_batching=args.continuous_batching, + attn_implementation=args.attn_impl) if not wait_for_server(base_url): print("ERROR: Server did not start") if server: @@ -392,44 +396,34 @@ async def async_main(args): tokenizer_id = args.processor or args.model 
tokenizer = AutoTokenizer.from_pretrained(tokenizer_id) - # Generate prompts - prompts = make_prompts(tokenizer, args.num_requests, args.prompt_tokens, variance=args.prompt_variance) - print(f"Generated {len(prompts)} prompts (~{args.prompt_tokens} tokens each, ±{int(args.prompt_variance*100)}%)") + num_requests = max(args.max_concurrency) + prompts = make_prompts(tokenizer, num_requests, args.prompt_tokens, variance=args.prompt_variance) + print(f"Generated {num_requests} prompts (~{args.prompt_tokens} tokens each, ±{int(args.prompt_variance*100)}%)") print(f"Max new tokens per request: {args.max_new_tokens}") - print(f"Endpoint: /v1/{args.endpoint}") - - # Warmup — use the longest prompt so compilation covers all shorter sizes - warmup_prompt = max(prompts, key=len) - print(f"Warming up ({args.warmup} requests, longest prompt)...") - async with aiohttp.ClientSession() as session: - for i in range(args.warmup): - await send_request(session, base_url, warmup_prompt, args.max_new_tokens, args.seed, args.endpoint) + print(f"Endpoint: /v1/{args.endpoint} ({'streaming' if args.stream else 'non-streaming'})") + + # Warmup — small batch to warm CUDA graphs and JIT kernels + warmup_size = min(16, num_requests) + warmup_prompts = prompts[:warmup_size] + print(f"Warming up ({args.warmup}x {warmup_size} requests)...") + for _ in range(args.warmup): + await run_concurrency_test( + base_url, warmup_prompts, args.max_new_tokens, args.seed, args.endpoint, model=args.model, stream=args.stream, + ) print("Warmup done.") - # Run tests - if args.request_rate: - # Rate-based test - label = f"rate={args.request_rate} req/s, {args.num_requests} requests" + # Run tests — one round per concurrency level + for concurrency in args.max_concurrency: + test_prompts = prompts[:concurrency] + label = f"{concurrency} concurrent requests" print(f"\nRunning: {label}") t0 = time.perf_counter() - results = await run_rate_test( - base_url, prompts, args.max_new_tokens, args.request_rate, args.seed, args.endpoint, + results = await run_concurrency_test( + base_url, test_prompts, args.max_new_tokens, args.seed, args.endpoint, model=args.model, stream=args.stream, ) duration = time.perf_counter() - t0 metrics = compute_metrics(results, duration) print_metrics(metrics, label) - else: - # Concurrency sweep - for concurrency in args.max_concurrency: - label = f"concurrency={concurrency}, {args.num_requests} requests" - print(f"\nRunning: {label}") - t0 = time.perf_counter() - results = await run_concurrency_test( - base_url, prompts, args.max_new_tokens, concurrency, args.seed, args.endpoint, - ) - duration = time.perf_counter() - t0 - metrics = compute_metrics(results, duration) - print_metrics(metrics, label) if server: server.kill_server() @@ -445,14 +439,14 @@ def main(): parser.add_argument("--url", type=str, default=None, help="Existing server URL (skip start/stop)") parser.add_argument("--port", type=int, default=8642) parser.add_argument("--compile", action="store_true", help="Enable --compile on the server") + parser.add_argument("--continuous-batching", action="store_true", help="Enable continuous batching on the server") + parser.add_argument("--attn-impl", type=str, default=None, help="Attention implementation (e.g. 
flash_attention_3)") parser.add_argument("--endpoint", type=str, choices=["chat", "responses"], default="responses") + parser.add_argument("--no-stream", action="store_true", help="Use non-streaming requests") # Load parameters parser.add_argument("--max-concurrency", type=int, nargs="+", default=[1, 2, 4], - help="Concurrency levels to sweep (default: 1 2 4)") - parser.add_argument("--request-rate", type=float, default=None, - help="Target request rate (req/s). Uses Poisson arrivals. Overrides --max-concurrency.") - parser.add_argument("--num-requests", type=int, default=16, help="Total requests per test (default: 16)") + help="Number of concurrent requests to send (default: 1 2 4)") # Prompt parameters parser.add_argument("--prompt-tokens", type=int, default=256, help="Target prompt length in tokens (default: 256)") @@ -463,6 +457,7 @@ def main(): parser.add_argument("--warmup", type=int, default=2, help="Warmup requests (default: 2)") parser.add_argument("--seed", type=int, default=42) args = parser.parse_args() + args.stream = not args.no_stream asyncio.run(async_main(args)) From 720ecdbda14b31ff85d30a7eae461f55e8e0caf6 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Fri, 27 Mar 2026 10:19:09 +0000 Subject: [PATCH 28/64] better stream --- src/transformers/cli/serving/utils.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/src/transformers/cli/serving/utils.py b/src/transformers/cli/serving/utils.py index f4f4cbd17aff..c053d8ce31a7 100644 --- a/src/transformers/cli/serving/utils.py +++ b/src/transformers/cli/serving/utils.py @@ -604,20 +604,18 @@ def generate_streaming(self, model: "PreTrainedModel", processor: "ProcessorMixi ) streamer = CBStreamer(self._cb, request_id, processor._tokenizer, loop, text_queue) - # Consume CB outputs and decode tokens into the SSE text queue. - # It's a coroutine on the event loop (via async_request_id_iter) - # to avoid spawning a thread per concurrent request. - async def _read_and_decode(): + # Register a direct callback: the dispatcher calls this on the event loop + # with each GenerationOutput. This decodes tokens and pushes text straight + # to the SSE text_queue — no intermediate async queue or coroutine needed. 
+ def _on_output(output): try: - async for output in self._cb.async_request_id_iter(request_id): - streamer.put(output) - if output.is_finished(): - break - streamer.end() + streamer.put(output) + if output.is_finished(): + streamer.end() except Exception as e: text_queue.put_nowait(_StreamError(str(e))) - asyncio.ensure_future(_read_and_decode()) + self._cb.register_streaming_callback(request_id, _on_output) return text_queue, streamer async def generate_non_streaming(self, model: "PreTrainedModel", processor: "ProcessorMixin | PreTrainedTokenizerFast", inputs: dict, gen_config: "GenerationConfig", request_id: str | None = None): From 442463570c5127a9fc6caac803a20c1c2e2ec14b Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Fri, 27 Mar 2026 15:42:35 +0000 Subject: [PATCH 29/64] update bench --- src/transformers/cli/serving/utils.py | 2 - tests/cli/bench_cb_raw.py | 81 ++++++++++++++++----------- tests/cli/benchmark_serve.py | 6 +- tests/cli/benchmark_serve_load.py | 18 ++++-- 4 files changed, 65 insertions(+), 42 deletions(-) diff --git a/src/transformers/cli/serving/utils.py b/src/transformers/cli/serving/utils.py index c053d8ce31a7..2b3e1913eace 100644 --- a/src/transformers/cli/serving/utils.py +++ b/src/transformers/cli/serving/utils.py @@ -598,7 +598,6 @@ def generate_streaming(self, model: "PreTrainedModel", processor: "ProcessorMixi input_ids, request_id=request_id, max_new_tokens=gen_config.max_new_tokens, - min_new_tokens=gen_config.min_new_tokens, streaming=True, eos_token_id=gen_config.eos_token_id, ) @@ -637,7 +636,6 @@ async def generate_non_streaming(self, model: "PreTrainedModel", processor: "Pro input_ids, request_id=request_id, max_new_tokens=gen_config.max_new_tokens, - min_new_tokens=gen_config.min_new_tokens, streaming=False, eos_token_id=gen_config.eos_token_id, ) diff --git a/tests/cli/bench_cb_raw.py b/tests/cli/bench_cb_raw.py index d694c65bfe11..40ce3f4799a1 100644 --- a/tests/cli/bench_cb_raw.py +++ b/tests/cli/bench_cb_raw.py @@ -38,23 +38,30 @@ def bench_ns_get_result(mgr, prompts, max_new_tokens): total = finished = 0 while finished < N: r = mgr.get_result(timeout=1) - if r and r.is_finished(): + if r is not None and r.is_finished(): total += len(r.generated_tokens) finished += 1 return total, time.perf_counter() - t0 -async def bench_ns_future(mgr, prompts, max_new_tokens): - """Non-stream + future: one asyncio.Future per request, resolved by dispatcher.""" +async def bench_ns_handler(mgr, prompts, max_new_tokens): + """Non-stream + handler: register_result_handler per request, resolve future on finish.""" + loop = asyncio.get_running_loop() t0 = time.perf_counter() futures = [] for i, ids in enumerate(prompts): - rid = f"nsf_{time.perf_counter_ns()}_{i}" - future = mgr.register_async_future(rid) + rid = f"nsh_{time.perf_counter_ns()}_{i}" + future = loop.create_future() + + def _on_result(output, fut=future): + if not fut.done() and output.is_finished(): + fut.set_result(len(output.generated_tokens)) + + mgr.register_result_handler(rid, _on_result) mgr.add_request(ids, request_id=rid, max_new_tokens=max_new_tokens, streaming=False) futures.append(future) results = await asyncio.gather(*futures) - return sum(len(r.generated_tokens) for r in results), time.perf_counter() - t0 + return sum(results), time.perf_counter() - t0 # --------------------------------------------------------------------------- @@ -76,22 +83,24 @@ def bench_s_get_result(mgr, prompts, max_new_tokens): return total, time.perf_counter() - t0 -async def bench_s_async_iter(mgr, prompts, 
max_new_tokens): - """Stream + async_request_id_iter: per-request async queue via dispatcher.""" +async def bench_s_handler(mgr, prompts, max_new_tokens): + """Stream + handler: register_result_handler per request, await future on finish.""" + loop = asyncio.get_running_loop() t0 = time.perf_counter() - rids = [] + futures = [] for i, ids in enumerate(prompts): - rid = f"sai_{time.perf_counter_ns()}_{i}" - mgr.add_request(ids, request_id=rid, max_new_tokens=max_new_tokens, streaming=True) - rids.append(rid) + rid = f"sh_{time.perf_counter_ns()}_{i}" + future = loop.create_future() - async def consume(rid): - async for output in mgr.async_request_id_iter(rid): - if output.is_finished(): - return len(output.generated_tokens) - return 0 + def _on_output(output, fut=future): + if not fut.done() and output.is_finished(): + fut.set_result(len(output.generated_tokens)) - results = await asyncio.gather(*[consume(rid) for rid in rids]) + mgr.register_result_handler(rid, _on_output) + mgr.add_request(ids, request_id=rid, max_new_tokens=max_new_tokens, streaming=True) + futures.append(future) + + results = await asyncio.gather(*futures) return sum(results), time.perf_counter() - t0 @@ -101,9 +110,9 @@ async def consume(rid): METHODS = { "ns_get_result": ("Non-stream + get_result", lambda mgr, p, m: bench_ns_get_result(mgr, p, m)), - "ns_future": ("Non-stream + future", lambda mgr, p, m: asyncio.run(bench_ns_future(mgr, p, m))), + "ns_handler": ("Non-stream + handler", lambda mgr, p, m: asyncio.run(bench_ns_handler(mgr, p, m))), "s_get_result": ("Stream + get_result", lambda mgr, p, m: bench_s_get_result(mgr, p, m)), - "s_async_iter": ("Stream + async_iter", lambda mgr, p, m: asyncio.run(bench_s_async_iter(mgr, p, m))), + "s_handler": ("Stream + handler", lambda mgr, p, m: asyncio.run(bench_s_handler(mgr, p, m))), } @@ -147,22 +156,23 @@ def main(): for N in args.batch: prompts = all_prompts[:N] - with model.continuous_batching_context_manager( - generation_config=gen_config, continuous_batching_config=cb_config, block=True, timeout=5, - ) as mgr: - # Warmup for this batch size - warmup_prompts = prompts[:min(200, N)] - for _ in range(args.warmup): - bench_ns_get_result(mgr, warmup_prompts, args.max_new_tokens) - - row = f"{N:>6}" - for method_key in args.methods: - _, fn = METHODS[method_key] + row = f"{N:>6}" + for method_key in args.methods: + _, fn = METHODS[method_key] + # Fresh CB context per method — each gets its own CUDA graph cache + with model.continuous_batching_context_manager( + generation_config=gen_config, continuous_batching_config=cb_config, block=True, timeout=5, + ) as mgr: + # Warmup with the same method being tested + warmup_prompts = prompts[:min(200, N)] + for _ in range(args.warmup): + fn(mgr, warmup_prompts, args.max_new_tokens) + # Measured runs best = 0 for _ in range(args.runs): tokens, dt = fn(mgr, prompts, args.max_new_tokens) best = max(best, tokens / dt if dt > 0 else 0) - row += f" | {best:>{col_w - 4}.0f} t/s" + row += f" | {best:>{col_w - 4}.0f} t/s" print(row, flush=True) # Quality check @@ -171,9 +181,14 @@ def main(): generation_config=gen_config, continuous_batching_config=cb_config, block=True, timeout=5, ) as mgr: async def check(): + loop = asyncio.get_running_loop() for i in range(3): rid = f"qc_{i}" - future = mgr.register_async_future(rid) + future = loop.create_future() + def _on_qc(output, fut=future): + if not fut.done(): + fut.set_result(output) + mgr.register_result_handler(rid, _on_qc) mgr.add_request(all_prompts[i], request_id=rid, 
max_new_tokens=args.max_new_tokens, streaming=False) r = await future text = tokenizer.decode(r.generated_tokens, skip_special_tokens=True)[:80] diff --git a/tests/cli/benchmark_serve.py b/tests/cli/benchmark_serve.py index 7782bd64c6b8..9461d902f8e6 100644 --- a/tests/cli/benchmark_serve.py +++ b/tests/cli/benchmark_serve.py @@ -104,7 +104,7 @@ def streaming_chat_completion( do_sample: bool = False, ) -> dict: """Send a streaming chat completion request. Returns {total, ttft, completion_tokens, text}.""" - gen_cfg = {"max_new_tokens": max_tokens, "min_new_tokens": max_tokens, "do_sample": do_sample} + gen_cfg = {"max_new_tokens": max_tokens, "do_sample": do_sample, "eos_token_id": -1} if do_sample: gen_cfg["temperature"] = 0.7 @@ -165,7 +165,7 @@ def streaming_response( do_sample: bool = False, ) -> dict: """Send a streaming responses API request. Returns {total, ttft, completion_tokens, text}.""" - gen_cfg = {"max_new_tokens": max_tokens, "min_new_tokens": max_tokens, "do_sample": do_sample} + gen_cfg = {"max_new_tokens": max_tokens, "do_sample": do_sample, "eos_token_id": -1} if do_sample: gen_cfg["temperature"] = 0.7 @@ -530,7 +530,7 @@ def main(): # streaming warmup can hang waiting for SSE chunks during compilation) if args.compile: warmup_prompt = make_prompt(tokenizer, max(args.pp + [args.tg_prefill])) - gen_cfg = {"max_new_tokens": max(args.tg), "do_sample": False} + gen_cfg = {"max_new_tokens": max(args.tg), "do_sample": False, "eos_token_id": -1} payload = {"messages": [{"role": "user", "content": warmup_prompt}], "stream": False, "seed": args.seed, "generation_config": json.dumps(gen_cfg)} print(" compile warmup (non-streaming, may take ~30s)...") diff --git a/tests/cli/benchmark_serve_load.py b/tests/cli/benchmark_serve_load.py index 737af09c239e..68e3af3d7ff4 100644 --- a/tests/cli/benchmark_serve_load.py +++ b/tests/cli/benchmark_serve_load.py @@ -81,7 +81,9 @@ async def send_request( stream: bool = True, ) -> dict: """Send a single request and collect timing metrics.""" - gen_cfg = {"max_new_tokens": max_new_tokens, "do_sample": False} + # eos_token_id=-1 forces exact max_new_tokens generation (no early stopping) + # for consistent benchmarking + gen_cfg = {"max_new_tokens": max_new_tokens, "do_sample": False, "eos_token_id": -1} if endpoint == "responses": url = f"{base_url}/v1/responses" @@ -110,7 +112,7 @@ async def send_request( error = None try: - async with session.post(url, json=payload, timeout=aiohttp.ClientTimeout(total=300)) as resp: + async with session.post(url, json=payload, timeout=aiohttp.ClientTimeout(total=600)) as resp: if resp.status != 200: error = f"HTTP {resp.status}: {await resp.text()}" return _make_result(t_start, error=error) @@ -237,11 +239,14 @@ async def run_concurrency_test( def compute_metrics(results: list[dict], duration: float) -> dict: """Compute aggregate metrics from individual request results.""" + from collections import Counter + successful = [r for r in results if r["error"] is None] failed = [r for r in results if r["error"] is not None] + error_summary = Counter(r["error"] for r in failed) if not successful: - return {"error": "all requests failed", "failures": len(failed)} + return {"error": "all requests failed", "failures": len(failed), "error_summary": error_summary} total_output_tokens = sum(r["output_tokens"] for r in successful) @@ -280,6 +285,7 @@ def percentiles(values): "ttft": percentiles(ttfts), "tpot": percentiles(tpots), "itl": percentiles(all_itl), + "error_summary": error_summary, } @@ -304,6 +310,9 @@ def 
print_metrics(metrics: dict, label: str): return print(f" Requests: {metrics['successful']} ok / {metrics['failed']} failed / {metrics['total_requests']} total") + if metrics.get("error_summary"): + for err, count in metrics["error_summary"].most_common(5): + print(f" - {count}x: {err}") print(f" Duration: {metrics['duration']:.1f}s") print(f" Throughput: {metrics['throughput_req_per_sec']:.2f} req/s, {metrics['throughput_tok_per_sec']:.1f} tok/s") print(f" Tokens: {metrics['total_output_tokens']} total output") @@ -402,7 +411,8 @@ async def async_main(args): print(f"Max new tokens per request: {args.max_new_tokens}") print(f"Endpoint: /v1/{args.endpoint} ({'streaming' if args.stream else 'non-streaming'})") - # Warmup — small batch to warm CUDA graphs and JIT kernels + # Warmup — ramp up to full batch size so CUDA graphs are compiled for + # the batch shapes the scheduler will use under load (~100+ active requests). warmup_size = min(16, num_requests) warmup_prompts = prompts[:warmup_size] print(f"Warming up ({args.warmup}x {warmup_size} requests)...") From 7d0cd776d94b85fc2b7daf1fb9ddf67c22137366 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Fri, 27 Mar 2026 15:42:49 +0000 Subject: [PATCH 30/64] fix --- src/transformers/cli/serving/utils.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/src/transformers/cli/serving/utils.py b/src/transformers/cli/serving/utils.py index 2b3e1913eace..5d44538b7196 100644 --- a/src/transformers/cli/serving/utils.py +++ b/src/transformers/cli/serving/utils.py @@ -614,23 +614,28 @@ def _on_output(output): except Exception as e: text_queue.put_nowait(_StreamError(str(e))) - self._cb.register_streaming_callback(request_id, _on_output) + self._cb.register_result_handler(request_id, _on_output) return text_queue, streamer async def generate_non_streaming(self, model: "PreTrainedModel", processor: "ProcessorMixin | PreTrainedTokenizerFast", inputs: dict, gen_config: "GenerationConfig", request_id: str | None = None): """Non-streaming CB generation, fully async (no per-request thread). - Uses ``register_async_future`` — the dispatcher resolves a single - asyncio.Future when the result arrives. No per-request queue, no polling - loop — scales to thousands of concurrent requests with minimal event loop - overhead. + Registers a handler that resolves an asyncio.Future when the result arrives. + No per-request queue, no polling — just one ``await`` per request. 
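+
+        The handler is registered before ``add_request`` so a request that
+        completes immediately cannot be missed, and the manager drops the
+        handler automatically once the final result has been delivered.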
""" input_ids = inputs["input_ids"] input_len = len(input_ids) # Register future BEFORE add_request to avoid race with fast completion request_id = request_id or f"cb_{id(inputs)}" - future = self._cb.register_async_future(request_id) + loop = asyncio.get_running_loop() + future = loop.create_future() + + def _on_result(result): + if not future.done(): + future.set_result(result) + + self._cb.register_result_handler(request_id, _on_result) self._cb.add_request( input_ids, From 533233ccfc935c2e30e4159408c370993c64bea0 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Fri, 27 Mar 2026 15:52:28 +0000 Subject: [PATCH 31/64] serve refactored --- src/transformers/cli/transformers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/cli/transformers.py b/src/transformers/cli/transformers.py index cefee1ca97c8..6ae79d99c74e 100644 --- a/src/transformers/cli/transformers.py +++ b/src/transformers/cli/transformers.py @@ -18,7 +18,7 @@ from transformers.cli.add_new_model_like import add_new_model_like from transformers.cli.chat import Chat from transformers.cli.download import download -from transformers.cli.serve import Serve +from transformers.cli.serve_refactored import Serve from transformers.cli.system import env, version From 880e6e07717a17b1a3ddd9ff6725cb32f8c97d25 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Fri, 27 Mar 2026 15:54:00 +0000 Subject: [PATCH 32/64] merge --- .../continuous_batching/continuous_api.py | 139 +++++++++++++++--- 1 file changed, 119 insertions(+), 20 deletions(-) diff --git a/src/transformers/generation/continuous_batching/continuous_api.py b/src/transformers/generation/continuous_batching/continuous_api.py index ee7bcc4ddde8..c81885ace5a4 100644 --- a/src/transformers/generation/continuous_batching/continuous_api.py +++ b/src/transformers/generation/continuous_batching/continuous_api.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import asyncio import gc import queue import threading @@ -95,6 +96,7 @@ def __init__( model_device: torch.device, model_dtype: torch.dtype, scheduler: Scheduler, + deliver_outputs: callable, ) -> None: """Initialize the continuous batch processor. @@ -103,11 +105,14 @@ def __init__( config: The model configuration generation_config: The generation configuration input_queue: Queue for incoming requests - output_queue: Queue for outgoing results + output_queue: Queue for outgoing results (used by ``get_result()`` callers) stop_event: Event to signal processing should stop model_device: Device for model inputs/outputs model_dtype: Data type for model inputs/outputs scheduler: The [`Scheduler`] to use + deliver_outputs: Called with a list of ``GenerationOutput`` at the end of each + generation step. Provided by the manager to route results to registered + handlers or fall back to the output_queue. 
""" self.cache = cache self.config = config @@ -118,6 +123,7 @@ def __init__( self.model_device = model_device self.model_dtype = model_dtype self.scheduler = scheduler + self._deliver_outputs = deliver_outputs # Generation-related attributes self.do_sample = getattr(generation_config, "do_sample", True) @@ -327,17 +333,12 @@ def prepare_next_batch(self) -> bool: self.metrics.record_kv_cache_memory_metrics(self.cache) return True - @traced - def _maybe_send_output(self, state: RequestState) -> None: - """Send output to the queue based on streaming mode and request state.""" - if state.streaming or state.status == RequestStatus.FINISHED: - self.output_queue.put(state.to_generation_output()) - @traced def update_batch(self) -> None: """Update request states based on generated tokens.""" requests_in_batch, new_tokens, logprobs = self.inputs_and_outputs.prepare_batch_update() current_logits_index = 0 + pending_outputs = [] for future_state in requests_in_batch: state = future_state.state # Early return if the request is finished @@ -367,7 +368,8 @@ def update_batch(self) -> None: self.metrics.record_request_completion(state.created_time, state.request_id) self.scheduler.finish_request(state.request_id) self.scheduler.block_new_requests = False - self._maybe_send_output(state) + if state.streaming or state.status == RequestStatus.FINISHED: + pending_outputs.append(state.to_generation_output()) # Otherwise, the request is still prefilling, but the prefill has been split elif state.status == RequestStatus.PREFILLING: self.cache.mark_shareable_blocks_as_complete(state, future_state.complete_blocks) @@ -398,6 +400,10 @@ def update_batch(self) -> None: with maybe_stream: self.cache.copy_cache(copy_source, copy_destination) + # Deliver outputs after all GPU work is done to minimize GIL contention + if pending_outputs: + self._deliver_outputs(pending_outputs) + @traced def has_pending_requests(self) -> bool: """Check if there are any active or waiting requests.""" @@ -596,7 +602,14 @@ def __init__( self._use_prefix_sharing = self.continuous_batching_config.allow_block_sharing self.input_queue = queue.Queue(maxsize=self.continuous_batching_config.max_queue_size) + self._has_new_requests = threading.Event() self.output_queue = queue.Queue() + # Per-request result handlers: request_id → (callback, event_loop). + # Registered via register_result_handler(). The generation thread delivers + # outputs directly via call_soon_threadsafe — no dispatcher thread needed. + # Unhandled results fall back to the output_queue for get_result() callers. 
+ self._result_handlers: dict[str, tuple[callable, asyncio.AbstractEventLoop]] = {} + self._result_handlers_lock = threading.Lock() self.stop_event = threading.Event() self.batch_processor: ContinuousBatchProcessor | None = None self._generation_thread = None @@ -624,6 +637,43 @@ def __init__( self.kv_padding_interval_size = self.continuous_batching_config.kv_padding_interval_size self.max_cached_graphs = self.continuous_batching_config.max_cached_graphs + # Log probability generation is not supported yet (TODO) + if self.log_prob_generation: + raise NotImplementedError("log_prob_generation is not supported yet") + + def _register_handler(self, request_id: str, callback: callable, loop: asyncio.AbstractEventLoop) -> None: + """Register a result handler for a request.""" + with self._result_handlers_lock: + self._result_handlers[request_id] = (callback, loop) + + def _unregister_handler(self, request_id: str) -> None: + """Remove a result handler for a request.""" + with self._result_handlers_lock: + self._result_handlers.pop(request_id, None) + + def _deliver_outputs(self, outputs: list[GenerationOutput]) -> None: + """Route outputs directly from the generation thread to registered handlers. + + Called by ``update_batch`` at the end of each generation step. Results with + registered handlers are batched per event loop and delivered via a single + ``call_soon_threadsafe``. Unhandled results fall back to the output_queue. + """ + deliveries: dict[asyncio.AbstractEventLoop, list[tuple[callable, object]]] = {} + with self._result_handlers_lock: + for output in outputs: + entry = self._result_handlers.get(output.request_id) + if entry is not None: + callback, loop = entry + deliveries.setdefault(loop, []).append((callback, output)) + else: + self.output_queue.put(output) + + for loop, items in deliveries.items(): + def _deliver_batch(batch_items=items): + for cb, res in batch_items: + cb(res) + loop.call_soon_threadsafe(_deliver_batch) + @traced def start(self) -> None: """Start the background generation thread.""" @@ -638,6 +688,7 @@ def is_running(self) -> bool: """Check if the background generation thread is running.""" return self._generation_thread is not None and self._generation_thread.is_alive() + # NOTE: don't forget to update `continuous_batching_context_manager` when changing this method's definition def stop(self, block: bool = True, timeout: float | None = None, keep_for_next_session: bool = False) -> None: """Signal the background thread to stop. @@ -721,7 +772,6 @@ def add_request( max_new_tokens = self.generation_config.max_new_tokens if max_new_tokens is None else max_new_tokens eos_token_id = self.generation_config.eos_token_id if eos_token_id is None else eos_token_id - # NOTE: do we want to handle a case when the user wants token ids returned instead of decoded text? state = RequestState( request_id=request_id, initial_tokens=list(input_ids), @@ -732,8 +782,8 @@ def add_request( streaming=streaming, ) - # Use block=True with timeout to handle backpressure if queue is full - self.input_queue.put(state, block=True, timeout=10) # XXX: pass timeout as fn arg? + self.input_queue.put(state, block=True, timeout=10) + self._has_new_requests.set() return request_id def add_requests( @@ -774,16 +824,16 @@ def get_result(self, request_id: str | None = None, timeout: float | None = None """Retrieve one result from the output queue. Args: - timeout: Maximum time to wait for a result + request_id: If set, only return results matching this ID (others are requeued). 
+ timeout: Maximum time to wait for a result. Returns: - Optional[GenerationOutput]: The result data or None if timeout + Optional[GenerationOutput]: The result data or None if timeout. """ if self._generation_thread is None and self.output_queue.empty(): return None try: result = self.output_queue.get(block=True, timeout=timeout) - # NOTE: requeue logic here if request_id is not None and result.request_id != request_id: self.output_queue.put(result) return None @@ -798,16 +848,41 @@ def __iter__(self): if result is not None: yield result - # FIXME: stop iteration when request status is finished? def request_id_iter(self, request_id: str) -> Generator[GenerationOutput]: - """Iterate over results matching a specific request id as they become available.""" - request_cancelled = False - while self._generation_thread is not None and self._generation_thread.is_alive() and not request_cancelled: + """Iterate over results matching a specific request id (blocking). + + Uses the shared output queue with requeue. For high-concurrency serving, + use :meth:`register_result_handler` instead. + """ + while self._generation_thread is not None and self._generation_thread.is_alive(): result = self.get_result(request_id=request_id, timeout=0.1) if result is not None: yield result - if self.batch_processor is not None: - request_cancelled = self.batch_processor.scheduler.request_is_cancelled(request_id) + if result.is_finished(): + return + + + def register_result_handler(self, request_id: str, callback: callable) -> None: + """Register a callback for result delivery (streaming or non-streaming). + + The callback is invoked on the event loop via ``call_soon_threadsafe`` + each time a result is produced for this request. For streaming requests, + this happens on every token; for non-streaming, only on completion. + + The handler is automatically cleaned up when the request finishes. + + Args: + request_id (`str`): The request ID to receive outputs for. + callback (`callable`): Called with a ``GenerationOutput`` for each result. + """ + loop = asyncio.get_running_loop() + + def _auto_cleanup(result): + callback(result) + if result.is_finished(): + self._unregister_handler(request_id) + + self._register_handler(request_id, _auto_cleanup, loop) @traced def _generation_step(self) -> None: @@ -862,6 +937,27 @@ def _run_generation_loop(self) -> None: batch_processor = self._create_batch_processor() # Start the generation loop + scheduler = SCHEDULER_MAPPING.get(self.continuous_batching_config.scheduler, None) + if scheduler is None: + logger.warning( + f"Scheduler '{self.continuous_batching_config.scheduler}' not found. Defaulting to FIFO." + ) + scheduler = FIFOScheduler + + t1 = perf_counter() + batch_processor = ContinuousBatchProcessor( + cache=paged_attention_cache, + config=self.model.config, + generation_config=self.generation_config, + continuous_batching_config=self.continuous_batching_config, + input_queue=self.input_queue, + output_queue=self.output_queue, + stop_event=self.stop_event, + model_device=self.model.device, + model_dtype=self.model.dtype, + scheduler=scheduler(paged_attention_cache), + deliver_outputs=self._deliver_outputs, + ) self.batch_processor = batch_processor self.current_batch = 0 @@ -892,6 +988,9 @@ def _run_generation_loop(self) -> None: def _inner_generation_loop(self, batch_processor: ContinuousBatchProcessor) -> None: # Loop body ends if there is no requests in the batch if not batch_processor.prepare_next_batch(): + # Wait for new requests instead of busy-spinning. 
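+            # The event is set by add_request(); the short timeout bounds how long
+            # the generation thread sleeps while the input queue stays empty.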
+ self._has_new_requests.wait(timeout=0.1) + self._has_new_requests.clear() return self._generation_step() batch_processor.update_batch() From 4aa7fec0168e7a9c32b065b12c35ce86bb232949 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Fri, 27 Mar 2026 16:07:03 +0000 Subject: [PATCH 33/64] update --- .../generation/continuous_batching/continuous_api.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/src/transformers/generation/continuous_batching/continuous_api.py b/src/transformers/generation/continuous_batching/continuous_api.py index c81885ace5a4..56f58b2f68b7 100644 --- a/src/transformers/generation/continuous_batching/continuous_api.py +++ b/src/transformers/generation/continuous_batching/continuous_api.py @@ -110,9 +110,7 @@ def __init__( model_device: Device for model inputs/outputs model_dtype: Data type for model inputs/outputs scheduler: The [`Scheduler`] to use - deliver_outputs: Called with a list of ``GenerationOutput`` at the end of each - generation step. Provided by the manager to route results to registered - handlers or fall back to the output_queue. + deliver_outputs: Callback that receives a list of ``GenerationOutput`` after each step. """ self.cache = cache self.config = config @@ -604,10 +602,6 @@ def __init__( self.input_queue = queue.Queue(maxsize=self.continuous_batching_config.max_queue_size) self._has_new_requests = threading.Event() self.output_queue = queue.Queue() - # Per-request result handlers: request_id → (callback, event_loop). - # Registered via register_result_handler(). The generation thread delivers - # outputs directly via call_soon_threadsafe — no dispatcher thread needed. - # Unhandled results fall back to the output_queue for get_result() callers. self._result_handlers: dict[str, tuple[callable, asyncio.AbstractEventLoop]] = {} self._result_handlers_lock = threading.Lock() self.stop_event = threading.Event() From 3ab4e09241ef816a5a6df5ee4107948633dad661 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Fri, 27 Mar 2026 16:13:06 +0000 Subject: [PATCH 34/64] fix --- .../continuous_batching/continuous_api.py | 27 +++---------------- 1 file changed, 3 insertions(+), 24 deletions(-) diff --git a/src/transformers/generation/continuous_batching/continuous_api.py b/src/transformers/generation/continuous_batching/continuous_api.py index 56f58b2f68b7..2e5697515125 100644 --- a/src/transformers/generation/continuous_batching/continuous_api.py +++ b/src/transformers/generation/continuous_batching/continuous_api.py @@ -663,9 +663,11 @@ def _deliver_outputs(self, outputs: list[GenerationOutput]) -> None: self.output_queue.put(output) for loop, items in deliveries.items(): + def _deliver_batch(batch_items=items): for cb, res in batch_items: cb(res) + loop.call_soon_threadsafe(_deliver_batch) @traced @@ -682,7 +684,6 @@ def is_running(self) -> bool: """Check if the background generation thread is running.""" return self._generation_thread is not None and self._generation_thread.is_alive() - # NOTE: don't forget to update `continuous_batching_context_manager` when changing this method's definition def stop(self, block: bool = True, timeout: float | None = None, keep_for_next_session: bool = False) -> None: """Signal the background thread to stop. @@ -855,7 +856,6 @@ def request_id_iter(self, request_id: str) -> Generator[GenerationOutput]: if result.is_finished(): return - def register_result_handler(self, request_id: str, callback: callable) -> None: """Register a callback for result delivery (streaming or non-streaming). 
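A minimal caller-side sketch of this handler API for a non-streaming request (``cb`` is a running ``ContinuousBatchingManager``; the helper name is illustrative):

    import asyncio

    async def generate_once(cb, input_ids: list[int], request_id: str):
        loop = asyncio.get_running_loop()
        future = loop.create_future()

        def _on_result(output):
            # Non-streaming requests are delivered exactly once, on completion.
            if not future.done():
                future.set_result(output)

        # Register before add_request so a fast completion cannot be missed.
        cb.register_result_handler(request_id, _on_result)
        cb.add_request(input_ids, request_id=request_id, streaming=False)
        return await future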
@@ -915,6 +915,7 @@ def _create_batch_processor(self) -> ContinuousBatchProcessor: model_device=self.model.device, model_dtype=self.model.dtype, scheduler=scheduler(paged_attention_cache), + deliver_outputs=self._deliver_outputs, ) return batch_processor @@ -930,28 +931,6 @@ def _run_generation_loop(self) -> None: else: batch_processor = self._create_batch_processor() - # Start the generation loop - scheduler = SCHEDULER_MAPPING.get(self.continuous_batching_config.scheduler, None) - if scheduler is None: - logger.warning( - f"Scheduler '{self.continuous_batching_config.scheduler}' not found. Defaulting to FIFO." - ) - scheduler = FIFOScheduler - - t1 = perf_counter() - batch_processor = ContinuousBatchProcessor( - cache=paged_attention_cache, - config=self.model.config, - generation_config=self.generation_config, - continuous_batching_config=self.continuous_batching_config, - input_queue=self.input_queue, - output_queue=self.output_queue, - stop_event=self.stop_event, - model_device=self.model.device, - model_dtype=self.model.dtype, - scheduler=scheduler(paged_attention_cache), - deliver_outputs=self._deliver_outputs, - ) self.batch_processor = batch_processor self.current_batch = 0 From 06bacbbf7502f7634c39b5a19e1fc9ec1444aa28 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Fri, 27 Mar 2026 16:19:09 +0000 Subject: [PATCH 35/64] style --- .../generation/continuous_batching/continuous_api.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/transformers/generation/continuous_batching/continuous_api.py b/src/transformers/generation/continuous_batching/continuous_api.py index 2e5697515125..a4f7fe343c0b 100644 --- a/src/transformers/generation/continuous_batching/continuous_api.py +++ b/src/transformers/generation/continuous_batching/continuous_api.py @@ -767,6 +767,7 @@ def add_request( max_new_tokens = self.generation_config.max_new_tokens if max_new_tokens is None else max_new_tokens eos_token_id = self.generation_config.eos_token_id if eos_token_id is None else eos_token_id + # NOTE: do we want to handle a case when the user wants token ids returned instead of decoded text? state = RequestState( request_id=request_id, initial_tokens=list(input_ids), @@ -777,6 +778,7 @@ def add_request( streaming=streaming, ) + # Use block=True with timeout to handle backpressure if queue is full self.input_queue.put(state, block=True, timeout=10) self._has_new_requests.set() return request_id @@ -931,6 +933,7 @@ def _run_generation_loop(self) -> None: else: batch_processor = self._create_batch_processor() + # Start the generation loop self.batch_processor = batch_processor self.current_batch = 0 From ef106187865443ac971e2efc6f1a32f7f23e7b95 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Fri, 27 Mar 2026 16:33:18 +0000 Subject: [PATCH 36/64] simpler --- .../continuous_batching/continuous_api.py | 40 +++++-------------- 1 file changed, 10 insertions(+), 30 deletions(-) diff --git a/src/transformers/generation/continuous_batching/continuous_api.py b/src/transformers/generation/continuous_batching/continuous_api.py index a4f7fe343c0b..5f5a5ea6fd0b 100644 --- a/src/transformers/generation/continuous_batching/continuous_api.py +++ b/src/transformers/generation/continuous_batching/continuous_api.py @@ -110,7 +110,7 @@ def __init__( model_device: Device for model inputs/outputs model_dtype: Data type for model inputs/outputs scheduler: The [`Scheduler`] to use - deliver_outputs: Callback that receives a list of ``GenerationOutput`` after each step. 
+ deliver_outputs: Callback that receives a single ``GenerationOutput``. """ self.cache = cache self.config = config @@ -336,7 +336,6 @@ def update_batch(self) -> None: """Update request states based on generated tokens.""" requests_in_batch, new_tokens, logprobs = self.inputs_and_outputs.prepare_batch_update() current_logits_index = 0 - pending_outputs = [] for future_state in requests_in_batch: state = future_state.state # Early return if the request is finished @@ -367,7 +366,7 @@ def update_batch(self) -> None: self.scheduler.finish_request(state.request_id) self.scheduler.block_new_requests = False if state.streaming or state.status == RequestStatus.FINISHED: - pending_outputs.append(state.to_generation_output()) + self._deliver_outputs(state.to_generation_output()) # Otherwise, the request is still prefilling, but the prefill has been split elif state.status == RequestStatus.PREFILLING: self.cache.mark_shareable_blocks_as_complete(state, future_state.complete_blocks) @@ -398,10 +397,6 @@ def update_batch(self) -> None: with maybe_stream: self.cache.copy_cache(copy_source, copy_destination) - # Deliver outputs after all GPU work is done to minimize GIL contention - if pending_outputs: - self._deliver_outputs(pending_outputs) - @traced def has_pending_requests(self) -> bool: """Check if there are any active or waiting requests.""" @@ -645,30 +640,15 @@ def _unregister_handler(self, request_id: str) -> None: with self._result_handlers_lock: self._result_handlers.pop(request_id, None) - def _deliver_outputs(self, outputs: list[GenerationOutput]) -> None: - """Route outputs directly from the generation thread to registered handlers. - - Called by ``update_batch`` at the end of each generation step. Results with - registered handlers are batched per event loop and delivered via a single - ``call_soon_threadsafe``. Unhandled results fall back to the output_queue. 
- """ - deliveries: dict[asyncio.AbstractEventLoop, list[tuple[callable, object]]] = {} + def _deliver_outputs(self, output: GenerationOutput) -> None: + """Route a single output to its registered handler or the output_queue.""" with self._result_handlers_lock: - for output in outputs: - entry = self._result_handlers.get(output.request_id) - if entry is not None: - callback, loop = entry - deliveries.setdefault(loop, []).append((callback, output)) - else: - self.output_queue.put(output) - - for loop, items in deliveries.items(): - - def _deliver_batch(batch_items=items): - for cb, res in batch_items: - cb(res) - - loop.call_soon_threadsafe(_deliver_batch) + entry = self._result_handlers.get(output.request_id) + if entry is not None: + callback, loop = entry + loop.call_soon_threadsafe(callback, output) + else: + self.output_queue.put(output) @traced def start(self) -> None: From 09d5fe172ee7ffe453a6bf3533cf53d87ab6527c Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Fri, 27 Mar 2026 16:33:23 +0000 Subject: [PATCH 37/64] style --- .../generation/continuous_batching/continuous_api.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/transformers/generation/continuous_batching/continuous_api.py b/src/transformers/generation/continuous_batching/continuous_api.py index 5f5a5ea6fd0b..a14c3ca0d895 100644 --- a/src/transformers/generation/continuous_batching/continuous_api.py +++ b/src/transformers/generation/continuous_batching/continuous_api.py @@ -96,7 +96,7 @@ def __init__( model_device: torch.device, model_dtype: torch.dtype, scheduler: Scheduler, - deliver_outputs: callable, + deliver_output: callable, ) -> None: """Initialize the continuous batch processor. @@ -110,7 +110,7 @@ def __init__( model_device: Device for model inputs/outputs model_dtype: Data type for model inputs/outputs scheduler: The [`Scheduler`] to use - deliver_outputs: Callback that receives a single ``GenerationOutput``. + deliver_output: Callback that receives a single ``GenerationOutput``. 
""" self.cache = cache self.config = config @@ -121,7 +121,7 @@ def __init__( self.model_device = model_device self.model_dtype = model_dtype self.scheduler = scheduler - self._deliver_outputs = deliver_outputs + self._deliver_output = deliver_output # Generation-related attributes self.do_sample = getattr(generation_config, "do_sample", True) @@ -366,7 +366,7 @@ def update_batch(self) -> None: self.scheduler.finish_request(state.request_id) self.scheduler.block_new_requests = False if state.streaming or state.status == RequestStatus.FINISHED: - self._deliver_outputs(state.to_generation_output()) + self._deliver_output(state.to_generation_output()) # Otherwise, the request is still prefilling, but the prefill has been split elif state.status == RequestStatus.PREFILLING: self.cache.mark_shareable_blocks_as_complete(state, future_state.complete_blocks) @@ -640,7 +640,7 @@ def _unregister_handler(self, request_id: str) -> None: with self._result_handlers_lock: self._result_handlers.pop(request_id, None) - def _deliver_outputs(self, output: GenerationOutput) -> None: + def _deliver_output(self, output: GenerationOutput) -> None: """Route a single output to its registered handler or the output_queue.""" with self._result_handlers_lock: entry = self._result_handlers.get(output.request_id) @@ -897,7 +897,7 @@ def _create_batch_processor(self) -> ContinuousBatchProcessor: model_device=self.model.device, model_dtype=self.model.dtype, scheduler=scheduler(paged_attention_cache), - deliver_outputs=self._deliver_outputs, + deliver_output=self._deliver_output, ) return batch_processor From 96b6b8bde595c87c60b7970203152a40805c9ca5 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Fri, 27 Mar 2026 17:25:35 +0000 Subject: [PATCH 38/64] update warmup --- tests/cli/benchmark_serve_load.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/cli/benchmark_serve_load.py b/tests/cli/benchmark_serve_load.py index 68e3af3d7ff4..9297a171903e 100644 --- a/tests/cli/benchmark_serve_load.py +++ b/tests/cli/benchmark_serve_load.py @@ -413,7 +413,7 @@ async def async_main(args): # Warmup — ramp up to full batch size so CUDA graphs are compiled for # the batch shapes the scheduler will use under load (~100+ active requests). - warmup_size = min(16, num_requests) + warmup_size = min(200, num_requests) warmup_prompts = prompts[:warmup_size] print(f"Warming up ({args.warmup}x {warmup_size} requests)...") for _ in range(args.warmup): From 07ecd2a613ce83d9c205ec4217204d2093c4f363 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Fri, 27 Mar 2026 18:06:20 +0000 Subject: [PATCH 39/64] remove llamacpp integration for now --- src/transformers/cli/chat.py | 7 ---- src/transformers/cli/serve_refactored.py | 7 ---- .../cli/serving/chat_completion.py | 8 ++--- src/transformers/cli/serving/model_manager.py | 34 +++---------------- src/transformers/cli/serving/response.py | 7 ++-- src/transformers/cli/serving/utils.py | 20 ++--------- 6 files changed, 15 insertions(+), 68 deletions(-) diff --git a/src/transformers/cli/chat.py b/src/transformers/cli/chat.py index fda3c330834b..f38305550de7 100644 --- a/src/transformers/cli/chat.py +++ b/src/transformers/cli/chat.py @@ -332,10 +332,6 @@ def __init__( help="Path to a local generation config file or to a HuggingFace repo containing a `generation_config.json` file. Other generation settings passed as CLI arguments will be applied on top of this generation config." ), ] = None, - processor: Annotated[ - str | None, - typer.Option(help="Processor/tokenizer model ID. 
Needed for GGUF models whose repos don't include tokenizer files."), - ] = None, ) -> None: """Chat with a model from the command line.""" self.base_url = base_url @@ -354,7 +350,6 @@ def __init__( config.update(**parse_generate_flags(generate_flags)) self.config = config - self.processor = processor self.settings = {"base_url": base_url, "model_id": model_id, "config": self.config.to_dict()} # User settings @@ -563,8 +558,6 @@ async def _inner_run(self): "generation_config": config.to_json_string(), "model": self.model_id, } - if self.processor: - extra_body["processor"] = self.processor stream = client.chat_completion( chat, diff --git a/src/transformers/cli/serve_refactored.py b/src/transformers/cli/serve_refactored.py index 725fa41f7fab..4936b589d9a0 100644 --- a/src/transformers/cli/serve_refactored.py +++ b/src/transformers/cli/serve_refactored.py @@ -56,10 +56,6 @@ def __init__( str | None, typer.Option(help="Quantization method: 'bnb-4bit' or 'bnb-8bit'.") ] = None, trust_remote_code: Annotated[bool, typer.Option(help="Trust remote code when loading.")] = False, - # TODO: auto-detect processor from GGUF base_model metadata so this flag isn't needed - processor: Annotated[ - str | None, typer.Option(help="Processor/tokenizer model ID. Needed for GGUF models.") - ] = None, model_timeout: Annotated[ int, typer.Option(help="Seconds before idle model is unloaded. Ignored when model is set.") ] = 300, @@ -117,7 +113,6 @@ def __init__( quantization=quantization, model_timeout=model_timeout, force_model=force_model, - processor_id=processor, ) self._model_manager = model_manager self._generation_state = GenerationState(continuous_batching=continuous_batching) @@ -126,7 +121,6 @@ def __init__( model_manager=model_manager, generation_state=self._generation_state, force_model=force_model, - force_processor=processor, compile=compile, ) @@ -134,7 +128,6 @@ def __init__( model_manager=model_manager, generation_state=self._generation_state, force_model=force_model, - force_processor=processor, compile=compile, ) diff --git a/src/transformers/cli/serving/chat_completion.py b/src/transformers/cli/serving/chat_completion.py index ccc1d9bf078e..1b79f6c51ad4 100644 --- a/src/transformers/cli/serving/chat_completion.py +++ b/src/transformers/cli/serving/chat_completion.py @@ -18,6 +18,7 @@ """ import asyncio +import json import time from collections.abc import AsyncGenerator from typing import TYPE_CHECKING @@ -47,7 +48,6 @@ class TransformersCompletionCreateParamsStreaming(CompletionCreateParamsStreaming, total=False): generation_config: str - processor: str # Fields accepted by the OpenAI schema but not yet supported. 
@@ -132,7 +132,7 @@ async def handle_request(self, body: dict, request_id: str) -> StreamingResponse tokenize=True, ).to(model.device) - gen_config = self._build_generation_config(body, model.generation_config, processor, use_cb=use_cb) + gen_config = self._build_generation_config(body, model.generation_config, use_cb=use_cb) # TODO: remove when CB supports per-request generation config if use_cb: gen_manager.init_cb(model, gen_config) @@ -314,10 +314,10 @@ def _validate_request(self, body: dict) -> None: if unused: raise HTTPException(status_code=422, detail=f"Unsupported fields in the request: {unused}") - def _build_generation_config(self, body: dict, model_generation_config: "GenerationConfig", processor: "ProcessorMixin | PreTrainedTokenizerFast | None" = None, use_cb: bool = False): + def _build_generation_config(self, body: dict, model_generation_config: "GenerationConfig", use_cb: bool = False): """Apply Chat Completions params (``max_tokens``, ``frequency_penalty``, ``logit_bias``, ``stop``) on top of the base generation config.""" - generation_config = super()._build_generation_config(body, model_generation_config, processor, use_cb=use_cb) + generation_config = super()._build_generation_config(body, model_generation_config, use_cb=use_cb) if body.get("max_tokens") is not None: generation_config.max_new_tokens = int(body["max_tokens"]) diff --git a/src/transformers/cli/serving/model_manager.py b/src/transformers/cli/serving/model_manager.py index 273b2419c978..f44e8fa6bbfa 100644 --- a/src/transformers/cli/serving/model_manager.py +++ b/src/transformers/cli/serving/model_manager.py @@ -103,8 +103,6 @@ class ModelManager: quantization: Quantization method ("bnb-4bit" or "bnb-8bit"). model_timeout: Seconds before an idle model is unloaded. -1 disables. force_model: If set, preload this model at init time. - processor_id: Override processor/tokenizer model ID. Needed for GGUF models - whose repos don't include tokenizer files. """ def __init__( @@ -116,8 +114,6 @@ def __init__( quantization: str | None = None, model_timeout: int = 300, force_model: str | None = None, - # TODO: auto-detect from GGUF base_model metadata - processor_id: str | None = None, ): self.device = device self.dtype = dtype @@ -137,7 +133,7 @@ def __init__( self._loading_tasks: dict[str, asyncio.Task] = {} if force_model is not None: - self.load_model_and_processor(self.process_model_name(force_model), processor_id=processor_id) + self.load_model_and_processor(self.process_model_name(force_model)) @staticmethod def process_model_name(model_id: str) -> str: @@ -159,21 +155,16 @@ def get_quantization_config(self) -> BitsAndBytesConfig | None: return None def _load_processor( - self, model_id_and_revision: str, processor_id: str | None = None + self, model_id_and_revision: str ) -> "ProcessorMixin | PreTrainedTokenizerFast": """Load a processor, trying AutoProcessor first then AutoTokenizer. Args: model_id_and_revision: Model ID in ``'model_id@revision'`` format. - processor_id: Override processor/tokenizer ID (e.g. for GGUF models). - Falls back to ``model_id``. 
""" from transformers import AutoProcessor - if processor_id: - model_id, revision = processor_id, "main" - else: - model_id, revision = model_id_and_revision.split("@", 1) + model_id, revision = model_id_and_revision.split("@", 1) try: return AutoProcessor.from_pretrained(model_id, revision=revision, trust_remote_code=self.trust_remote_code) except OSError: @@ -185,7 +176,7 @@ def _load_processor( raise OSError(f"Failed to load processor for {model_id} with AutoProcessor and AutoTokenizer.") def _load_model(self, model_id_and_revision: str, tqdm_class: type | None = None, progress_callback: Callable | None = None) -> "PreTrainedModel": - """Load a model. GGUF files are detected by the ``.gguf`` extension and loaded via llama.cpp. + """Load a model. Args: model_id_and_revision (`str`): Model ID in ``'model_id@revision'`` format. @@ -201,19 +192,6 @@ def _load_model(self, model_id_and_revision: str, tqdm_class: type | None = None model_id, revision = model_id_and_revision.split("@", 1) - if model_id.endswith(".gguf"): - from llama_cpp_transformers import LlamaCppTransformersModel - - flash_attn = True if self.attn_implementation == "flash_attention_2" else "auto" - return LlamaCppTransformersModel.from_pretrained( - model_id, - revision=revision, - n_gpu_layers=-1, - flash_attn=flash_attn, - n_ctx=8192, - n_batch=2048, - ) - dtype = self.dtype if self.dtype in ["auto", None] else getattr(torch, self.dtype) model_kwargs = { "revision": revision, @@ -234,7 +212,6 @@ def _load_model(self, model_id_and_revision: str, tqdm_class: type | None = None def load_model_and_processor( self, model_id_and_revision: str, - processor_id: str | None = None, progress_callback: Callable | None = None, tqdm_class: type | None = None, ) -> "tuple[PreTrainedModel, ProcessorMixin | PreTrainedTokenizerFast]": @@ -242,7 +219,6 @@ def load_model_and_processor( Args: model_id_and_revision: Model ID in ``'model_id@revision'`` format. - processor_id: Optional per-request processor override. progress_callback: If provided, called with dicts like ``{"status": "loading", "model": ..., "stage": ...}`` during loading. tqdm_class: Optional tqdm subclass for progress bars during ``from_pretrained``. 
@@ -258,7 +234,7 @@ def load_model_and_processor( ): if progress_callback is not None: progress_callback({"status": "loading", "model": model_id_and_revision, "stage": "processor"}) - processor = self._load_processor(model_id_and_revision, processor_id=processor_id) + processor = self._load_processor(model_id_and_revision) model = self._load_model( model_id_and_revision, tqdm_class=tqdm_class, progress_callback=progress_callback ) diff --git a/src/transformers/cli/serving/response.py b/src/transformers/cli/serving/response.py index 322a60d8bd61..5273b9841d17 100644 --- a/src/transformers/cli/serving/response.py +++ b/src/transformers/cli/serving/response.py @@ -18,6 +18,7 @@ """ import asyncio +import json import time from collections.abc import AsyncGenerator from typing import TYPE_CHECKING @@ -126,7 +127,7 @@ async def handle_request(self, body: dict, request_id: str) -> StreamingResponse tokenize=True, ).to(model.device) - gen_config = self._build_generation_config(body, model.generation_config, processor, use_cb=use_cb) + gen_config = self._build_generation_config(body, model.generation_config, use_cb=use_cb) # TODO: remove when CB supports per-request generation config if use_cb: gen_manager.init_cb(model, gen_config) @@ -513,9 +514,9 @@ def _validate_request(self, body: dict) -> None: if unused: raise HTTPException(status_code=422, detail=f"Unsupported fields in the request: {unused}") - def _build_generation_config(self, body: dict, model_generation_config: "GenerationConfig", processor: "ProcessorMixin | PreTrainedTokenizerFast | None" = None, use_cb: bool = False): + def _build_generation_config(self, body: dict, model_generation_config: "GenerationConfig", use_cb: bool = False): """Apply Responses API params (``max_output_tokens``) on top of the base generation config.""" - generation_config = super()._build_generation_config(body, model_generation_config, processor, use_cb=use_cb) + generation_config = super()._build_generation_config(body, model_generation_config, use_cb=use_cb) if body.get("max_output_tokens") is not None: generation_config.max_new_tokens = int(body["max_output_tokens"]) diff --git a/src/transformers/cli/serving/utils.py b/src/transformers/cli/serving/utils.py index 5d44538b7196..427fef08cf53 100644 --- a/src/transformers/cli/serving/utils.py +++ b/src/transformers/cli/serving/utils.py @@ -745,8 +745,6 @@ class BaseHandler: Shared state managing per-model generation managers. force_model (`str`, *optional*): If set, override the ``model`` field in every request with this model ID. - force_processor (`str`, *optional*): - If set, override the processor/tokenizer model ID. compile (`bool`, *optional*, defaults to `False`): Enable ``torch.compile`` with static cache for faster decode. 
""" @@ -756,13 +754,11 @@ def __init__( model_manager: "ModelManager", generation_state: GenerationState, force_model: str | None = None, - force_processor: str | None = None, compile: bool = False, ): self.model_manager = model_manager self.generation_state = generation_state self.force_model = force_model - self.force_processor = force_processor self._compile = compile @staticmethod @@ -781,12 +777,11 @@ def _resolve_model(self, body: dict): body["model"] = self.force_model model_id = self.model_manager.process_model_name(body["model"]) - processor_id = self.force_processor or body.get("processor") - model, processor = self.model_manager.load_model_and_processor(model_id, processor_id=processor_id) + model, processor = self.model_manager.load_model_and_processor(model_id) return model_id, model, processor - def _build_generation_config(self, body: dict, model_generation_config: "GenerationConfig", processor: "ProcessorMixin | PreTrainedTokenizerFast | None" = None, use_cb: bool = False): + def _build_generation_config(self, body: dict, model_generation_config: "GenerationConfig", use_cb: bool = False): """Build a GenerationConfig from shared params (temperature, top_p, seed, generation_config JSON). Subclasses should call ``super()._build_generation_config(...)`` then apply @@ -797,9 +792,6 @@ def _build_generation_config(self, body: dict, model_generation_config: "Generat The raw request body. model_generation_config (`GenerationConfig`): The model's default generation config (will be deep-copied). - processor (*optional*): - Processor or tokenizer, used to sync ``eos_token_id`` / ``pad_token_id`` - for GGUF models that lack them. use_cb (`bool`, *optional*, defaults to `False`): Whether continuous batching is active. If ``True``, disables the model's internal KV cache (CB manages its own paged cache). 
@@ -816,14 +808,6 @@ def _build_generation_config(self, body: dict, model_generation_config: "Generat if generation_config.max_new_tokens is None or generation_config.max_new_tokens < 1024: generation_config.max_new_tokens = 1024 - # GGUF models may not have eos/pad token IDs set — sync from processor - if processor is not None: - tokenizer = processor.tokenizer if hasattr(processor, "tokenizer") else processor - if generation_config.eos_token_id is None and hasattr(tokenizer, "eos_token_id"): - generation_config.eos_token_id = tokenizer.eos_token_id - if generation_config.pad_token_id is None and hasattr(tokenizer, "pad_token_id"): - generation_config.pad_token_id = tokenizer.pad_token_id - if body.get("temperature") is not None: generation_config.temperature = float(body["temperature"]) if float(body["temperature"]) == 0.0: From fad7c25c28bfd51c6df46dd6a35a0687ab60d31e Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Fri, 27 Mar 2026 18:07:46 +0000 Subject: [PATCH 40/64] styke --- src/transformers/cli/serving/chat_completion.py | 1 - src/transformers/cli/serving/model_manager.py | 2 +- src/transformers/cli/serving/response.py | 1 - src/transformers/cli/serving/utils.py | 2 +- tests/cli/bench_cb_raw.py | 2 ++ tests/cli/benchmark_serve.py | 4 ++-- tests/cli/benchmark_serve_load.py | 2 +- tests/cli/test_serve_refactored.py | 12 ++++++++---- 8 files changed, 15 insertions(+), 11 deletions(-) diff --git a/src/transformers/cli/serving/chat_completion.py b/src/transformers/cli/serving/chat_completion.py index 1b79f6c51ad4..fae22d16d4ea 100644 --- a/src/transformers/cli/serving/chat_completion.py +++ b/src/transformers/cli/serving/chat_completion.py @@ -18,7 +18,6 @@ """ import asyncio -import json import time from collections.abc import AsyncGenerator from typing import TYPE_CHECKING diff --git a/src/transformers/cli/serving/model_manager.py b/src/transformers/cli/serving/model_manager.py index f44e8fa6bbfa..12da9f243364 100644 --- a/src/transformers/cli/serving/model_manager.py +++ b/src/transformers/cli/serving/model_manager.py @@ -19,8 +19,8 @@ import gc import json import threading +from collections.abc import Callable from functools import lru_cache -from typing import Callable from typing import TYPE_CHECKING from huggingface_hub import scan_cache_dir diff --git a/src/transformers/cli/serving/response.py b/src/transformers/cli/serving/response.py index 5273b9841d17..7f07a020987a 100644 --- a/src/transformers/cli/serving/response.py +++ b/src/transformers/cli/serving/response.py @@ -18,7 +18,6 @@ """ import asyncio -import json import time from collections.abc import AsyncGenerator from typing import TYPE_CHECKING diff --git a/src/transformers/cli/serving/utils.py b/src/transformers/cli/serving/utils.py index 427fef08cf53..0898230c0a96 100644 --- a/src/transformers/cli/serving/utils.py +++ b/src/transformers/cli/serving/utils.py @@ -24,8 +24,8 @@ import tempfile import threading from abc import ABC, abstractmethod +from collections.abc import Callable from concurrent.futures import Future -from typing import Callable from io import BytesIO from queue import Queue from typing import TYPE_CHECKING diff --git a/tests/cli/bench_cb_raw.py b/tests/cli/bench_cb_raw.py index 40ce3f4799a1..e9879044fc2a 100644 --- a/tests/cli/bench_cb_raw.py +++ b/tests/cli/bench_cb_raw.py @@ -13,9 +13,11 @@ import sys import time + os.environ["CUDA_VISIBLE_DEVICES"] = "0" import torch + from transformers import AutoModelForCausalLM, AutoTokenizer, ContinuousBatchingConfig, GenerationConfig diff --git 
a/tests/cli/benchmark_serve.py b/tests/cli/benchmark_serve.py index 9461d902f8e6..17a76316e706 100644 --- a/tests/cli/benchmark_serve.py +++ b/tests/cli/benchmark_serve.py @@ -223,8 +223,8 @@ def streaming_request( endpoint: str = "chat", ) -> dict: """Dispatch to chat completions or responses API based on endpoint.""" - kw = dict(base_url=base_url, messages=messages, max_tokens=max_tokens, - seed=seed, do_sample=do_sample) + kw = {"base_url": base_url, "messages": messages, "max_tokens": max_tokens, + "seed": seed, "do_sample": do_sample} if endpoint == "responses": return streaming_response(**kw) return streaming_chat_completion(**kw) diff --git a/tests/cli/benchmark_serve_load.py b/tests/cli/benchmark_serve_load.py index 9297a171903e..c25057f8bce1 100644 --- a/tests/cli/benchmark_serve_load.py +++ b/tests/cli/benchmark_serve_load.py @@ -359,7 +359,7 @@ def wait_for_server(base_url: str, timeout: int = 120) -> bool: if requests.get(f"{base_url}/health", timeout=2).status_code == 200: return True except Exception: - pass + pass # noqa: S110 time.sleep(1) return False diff --git a/tests/cli/test_serve_refactored.py b/tests/cli/test_serve_refactored.py index 473bc243bb25..9b8bf429dca9 100644 --- a/tests/cli/test_serve_refactored.py +++ b/tests/cli/test_serve_refactored.py @@ -220,8 +220,9 @@ def test_lists_only_generative_models(self): class TestBuildGenerationConfig(unittest.TestCase): def _make_handler(self): from transformers.cli.serving.chat_completion import ChatCompletionHandler + from transformers.cli.serving.utils import GenerationState - from transformers.cli.serving.utils import GenerationState; return ChatCompletionHandler(model_manager=MagicMock(), generation_state=GenerationState()) + return ChatCompletionHandler(model_manager=MagicMock(), generation_state=GenerationState()) def test_max_tokens(self): from transformers import GenerationConfig @@ -297,8 +298,9 @@ def test_user_max_tokens_overrides_default(self): class TestValidation(unittest.TestCase): def _make_handler(self): from transformers.cli.serving.chat_completion import ChatCompletionHandler + from transformers.cli.serving.utils import GenerationState - from transformers.cli.serving.utils import GenerationState; return ChatCompletionHandler(model_manager=MagicMock(), generation_state=GenerationState()) + return ChatCompletionHandler(model_manager=MagicMock(), generation_state=GenerationState()) def test_valid_request_passes(self): handler = self._make_handler() @@ -390,8 +392,9 @@ def test_timeout_zero_no_delete(self): class TestChunkSSE(unittest.TestCase): def _make_handler(self): from transformers.cli.serving.chat_completion import ChatCompletionHandler + from transformers.cli.serving.utils import GenerationState - from transformers.cli.serving.utils import GenerationState; return ChatCompletionHandler(model_manager=MagicMock(), generation_state=GenerationState()) + return ChatCompletionHandler(model_manager=MagicMock(), generation_state=GenerationState()) def test_build_chunk_sse_content(self): handler = self._make_handler() @@ -975,8 +978,9 @@ def test_request_cancellation(self): class TestResponseInputConversion(unittest.TestCase): def _make_handler(self): from transformers.cli.serving.response import ResponseHandler + from transformers.cli.serving.utils import GenerationState - from transformers.cli.serving.utils import GenerationState; return ResponseHandler(model_manager=MagicMock(), generation_state=GenerationState()) + return ResponseHandler(model_manager=MagicMock(), generation_state=GenerationState()) def 
test_string_input(self): handler = self._make_handler() From feed4cbf7459bec2987f1ec9eba971a7480ba3b2 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Fri, 27 Mar 2026 18:07:52 +0000 Subject: [PATCH 41/64] styke --- tests/cli/test_serve_refactored.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/cli/test_serve_refactored.py b/tests/cli/test_serve_refactored.py index 9b8bf429dca9..221aa3812d90 100644 --- a/tests/cli/test_serve_refactored.py +++ b/tests/cli/test_serve_refactored.py @@ -1027,8 +1027,9 @@ def test_dict_input(self): class TestResponseValidation(unittest.TestCase): def _make_handler(self): from transformers.cli.serving.response import ResponseHandler + from transformers.cli.serving.utils import GenerationState - from transformers.cli.serving.utils import GenerationState; return ResponseHandler(model_manager=MagicMock(), generation_state=GenerationState()) + return ResponseHandler(model_manager=MagicMock(), generation_state=GenerationState()) def test_unsupported_fields_rejected(self): from fastapi import HTTPException From abd40872d5f1a08bc283427cddd616b3a0414ced Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Fri, 27 Mar 2026 18:10:25 +0000 Subject: [PATCH 42/64] style again --- .../cli/serving/chat_completion.py | 4 +- src/transformers/cli/serving/model_manager.py | 8 +- src/transformers/cli/serving/response.py | 4 +- src/transformers/cli/serving/transcription.py | 16 +- src/transformers/cli/serving/utils.py | 74 +++++- tests/cli/bench_cb_raw.py | 39 +++- tests/cli/benchmark_serve.py | 215 ++++++++++++++---- tests/cli/benchmark_serve_load.py | 96 +++++--- tests/cli/test_serve_refactored.py | 17 +- 9 files changed, 356 insertions(+), 117 deletions(-) diff --git a/src/transformers/cli/serving/chat_completion.py b/src/transformers/cli/serving/chat_completion.py index fae22d16d4ea..730f1e04c283 100644 --- a/src/transformers/cli/serving/chat_completion.py +++ b/src/transformers/cli/serving/chat_completion.py @@ -257,7 +257,9 @@ async def _non_streaming( tool_format: dict | None = None, ) -> JSONResponse: """Run generation and return a JSONResponse.""" - content, input_len, generated_ids = await gen_manager.generate_non_streaming(model, processor, inputs, gen_config, request_id=request_id) + content, input_len, generated_ids = await gen_manager.generate_non_streaming( + model, processor, inputs, gen_config, request_id=request_id + ) hit_max = gen_config.max_new_tokens is not None and len(generated_ids) >= gen_config.max_new_tokens completion_tokens = len(generated_ids) diff --git a/src/transformers/cli/serving/model_manager.py b/src/transformers/cli/serving/model_manager.py index 12da9f243364..397650f935f2 100644 --- a/src/transformers/cli/serving/model_manager.py +++ b/src/transformers/cli/serving/model_manager.py @@ -154,9 +154,7 @@ def get_quantization_config(self) -> BitsAndBytesConfig | None: return BitsAndBytesConfig(load_in_8bit=True) return None - def _load_processor( - self, model_id_and_revision: str - ) -> "ProcessorMixin | PreTrainedTokenizerFast": + def _load_processor(self, model_id_and_revision: str) -> "ProcessorMixin | PreTrainedTokenizerFast": """Load a processor, trying AutoProcessor first then AutoTokenizer. 
Args: @@ -175,7 +173,9 @@ def _load_processor( except OSError: raise OSError(f"Failed to load processor for {model_id} with AutoProcessor and AutoTokenizer.") - def _load_model(self, model_id_and_revision: str, tqdm_class: type | None = None, progress_callback: Callable | None = None) -> "PreTrainedModel": + def _load_model( + self, model_id_and_revision: str, tqdm_class: type | None = None, progress_callback: Callable | None = None + ) -> "PreTrainedModel": """Load a model. Args: diff --git a/src/transformers/cli/serving/response.py b/src/transformers/cli/serving/response.py index 7f07a020987a..b24aabbc9cb0 100644 --- a/src/transformers/cli/serving/response.py +++ b/src/transformers/cli/serving/response.py @@ -454,7 +454,9 @@ async def _non_streaming( tool_format: dict | None = None, ) -> JSONResponse: """Generate a non-streaming Responses API reply (single JSON).""" - full_text, input_len, generated_ids = await gen_manager.generate_non_streaming(model, processor, inputs, gen_config, request_id=request_id) + full_text, input_len, generated_ids = await gen_manager.generate_non_streaming( + model, processor, inputs, gen_config, request_id=request_id + ) created_at = time.time() resp_id = f"resp_{request_id}" diff --git a/src/transformers/cli/serving/transcription.py b/src/transformers/cli/serving/transcription.py index 8cf627f5f142..889b3eaaba5f 100644 --- a/src/transformers/cli/serving/transcription.py +++ b/src/transformers/cli/serving/transcription.py @@ -94,7 +94,13 @@ async def handle_request(self, request: Request) -> JSONResponse | StreamingResp return self._streaming(gen_manager, audio_model, tokenizer, audio_inputs) return await self._non_streaming(gen_manager, audio_model, audio_processor, audio_inputs) - async def _non_streaming(self, gen_manager: GenerateManager, audio_model: "PreTrainedModel", audio_processor: "ProcessorMixin", audio_inputs: dict) -> JSONResponse: + async def _non_streaming( + self, + gen_manager: GenerateManager, + audio_model: "PreTrainedModel", + audio_processor: "ProcessorMixin", + audio_inputs: dict, + ) -> JSONResponse: # Audio models have different inputs (input_features) and decode (batch_decode) # than text models, so we use async_submit() directly instead of # generate_non_streaming(). TODO: add generate_audio_non_streaming() when @@ -105,7 +111,13 @@ async def _non_streaming(self, gen_manager: GenerateManager, audio_model: "PreTr text = audio_processor.batch_decode(generated_ids, skip_special_tokens=True)[0] return JSONResponse(Transcription(text=text).model_dump(exclude_none=True)) - def _streaming(self, gen_manager: GenerateManager, audio_model: "PreTrainedModel", tokenizer: "ProcessorMixin", audio_inputs: dict) -> StreamingResponse: + def _streaming( + self, + gen_manager: GenerateManager, + audio_model: "PreTrainedModel", + tokenizer: "ProcessorMixin", + audio_inputs: dict, + ) -> StreamingResponse: # Same as _non_streaming — uses submit() directly because audio inputs # differ from text. TODO: add generate_audio_streaming() when more audio # modalities are supported. diff --git a/src/transformers/cli/serving/utils.py b/src/transformers/cli/serving/utils.py index 0898230c0a96..5599ae450c70 100644 --- a/src/transformers/cli/serving/utils.py +++ b/src/transformers/cli/serving/utils.py @@ -69,8 +69,6 @@ class _GenerationCancelled(Exception): """Raised inside ``DirectStreamer.put()`` to abort ``model.generate()``.""" - - # Model-specific tokens that mark the start/end of a tool call block. 
# TODO: extract these from the chat template at runtime instead of hardcoding. # Qwen/Hermes use /, Mistral uses [TOOL_CALLS], etc. @@ -320,7 +318,13 @@ class DirectStreamer: ``DecodeStream`` (O(1) per token) and pushed as text to an asyncio.Queue. """ - def __init__(self, tokenizer: "tokenizers.Tokenizer", loop: asyncio.AbstractEventLoop, queue: asyncio.Queue, skip_special_tokens: bool = True): + def __init__( + self, + tokenizer: "tokenizers.Tokenizer", + loop: asyncio.AbstractEventLoop, + queue: asyncio.Queue, + skip_special_tokens: bool = True, + ): """ Args: tokenizer: The Rust tokenizer (``tokenizer._tokenizer``). @@ -371,7 +375,14 @@ class CBStreamer: pushes text to the asyncio.Queue. ``end()`` signals the stream is complete. """ - def __init__(self, cb_manager: "ContinuousBatchingManager", request_id: str, tokenizer: "tokenizers.Tokenizer", loop: asyncio.AbstractEventLoop, queue: asyncio.Queue): + def __init__( + self, + cb_manager: "ContinuousBatchingManager", + request_id: str, + tokenizer: "tokenizers.Tokenizer", + loop: asyncio.AbstractEventLoop, + queue: asyncio.Queue, + ): """ Args: cb_manager (`ContinuousBatchingManager`): The CB manager instance. @@ -473,7 +484,14 @@ class BaseGenerateManager(ABC): """ @abstractmethod - def generate_streaming(self, model: "PreTrainedModel", processor: "ProcessorMixin | PreTrainedTokenizerFast", inputs: dict, gen_config: "GenerationConfig", request_id: str | None = None): + def generate_streaming( + self, + model: "PreTrainedModel", + processor: "ProcessorMixin | PreTrainedTokenizerFast", + inputs: dict, + gen_config: "GenerationConfig", + request_id: str | None = None, + ): """Start streaming generation. Args: @@ -490,7 +508,14 @@ def generate_streaming(self, model: "PreTrainedModel", processor: "ProcessorMixi """ @abstractmethod - def generate_non_streaming(self, model: "PreTrainedModel", processor: "ProcessorMixin | PreTrainedTokenizerFast", inputs: dict, gen_config: "GenerationConfig", request_id: str | None = None): + def generate_non_streaming( + self, + model: "PreTrainedModel", + processor: "ProcessorMixin | PreTrainedTokenizerFast", + inputs: dict, + gen_config: "GenerationConfig", + request_id: str | None = None, + ): """Run generation to completion. 
Args: @@ -515,7 +540,14 @@ class GenerateManager(BaseGenerateManager): def __init__(self): self._thread = InferenceThread() - def generate_streaming(self, model: "PreTrainedModel", processor: "ProcessorMixin | PreTrainedTokenizerFast", inputs: dict, gen_config: "GenerationConfig", request_id: str | None = None): + def generate_streaming( + self, + model: "PreTrainedModel", + processor: "ProcessorMixin | PreTrainedTokenizerFast", + inputs: dict, + gen_config: "GenerationConfig", + request_id: str | None = None, + ): loop = asyncio.get_running_loop() queue: asyncio.Queue = asyncio.Queue() streamer = DirectStreamer(processor._tokenizer, loop, queue, skip_special_tokens=True) @@ -532,7 +564,14 @@ def _run(): self.submit(_run) return queue, streamer - async def generate_non_streaming(self, model: "PreTrainedModel", processor: "ProcessorMixin | PreTrainedTokenizerFast", inputs: dict, gen_config: "GenerationConfig", request_id: str | None = None): + async def generate_non_streaming( + self, + model: "PreTrainedModel", + processor: "ProcessorMixin | PreTrainedTokenizerFast", + inputs: dict, + gen_config: "GenerationConfig", + request_id: str | None = None, + ): sequences = await self.async_submit( model.generate, **inputs, generation_config=gen_config, tokenizer=processor ) @@ -589,7 +628,14 @@ def init_cb(self, model: "PreTrainedModel", gen_config: "GenerationConfig"): self._cb.logit_processor = LogitsProcessorList() self._cb.start() - def generate_streaming(self, model: "PreTrainedModel", processor: "ProcessorMixin | PreTrainedTokenizerFast", inputs: dict, gen_config: "GenerationConfig", request_id: str | None = None): + def generate_streaming( + self, + model: "PreTrainedModel", + processor: "ProcessorMixin | PreTrainedTokenizerFast", + inputs: dict, + gen_config: "GenerationConfig", + request_id: str | None = None, + ): loop = asyncio.get_running_loop() text_queue: asyncio.Queue = asyncio.Queue() @@ -617,7 +663,14 @@ def _on_output(output): self._cb.register_result_handler(request_id, _on_output) return text_queue, streamer - async def generate_non_streaming(self, model: "PreTrainedModel", processor: "ProcessorMixin | PreTrainedTokenizerFast", inputs: dict, gen_config: "GenerationConfig", request_id: str | None = None): + async def generate_non_streaming( + self, + model: "PreTrainedModel", + processor: "ProcessorMixin | PreTrainedTokenizerFast", + inputs: dict, + gen_config: "GenerationConfig", + request_id: str | None = None, + ): """Non-streaming CB generation, fully async (no per-request thread). Registers a handler that resolves an asyncio.Future when the result arrives. @@ -651,7 +704,6 @@ def _on_result(result): text = processor.decode(generated_ids, skip_special_tokens=True) return text, input_len, generated_ids - @property def scheduler(self): """The CB scheduler (for testing/monitoring).""" diff --git a/tests/cli/bench_cb_raw.py b/tests/cli/bench_cb_raw.py index e9879044fc2a..15c61de16ec6 100644 --- a/tests/cli/bench_cb_raw.py +++ b/tests/cli/bench_cb_raw.py @@ -24,7 +24,7 @@ def make_prompts(tokenizer, n, target_len=256): filler = "The quick brown fox jumps over the lazy dog. Pack my box with five dozen liquor jugs. 
" * 100 ids = tokenizer.encode(filler, add_special_tokens=False) - return [ids[:max(10, int(target_len * (0.8 + 0.4 * (i % 5) / 4)))] for i in range(n)] + return [ids[: max(10, int(target_len * (0.8 + 0.4 * (i % 5) / 4)))] for i in range(n)] # --------------------------------------------------------------------------- @@ -112,9 +112,9 @@ def _on_output(output, fut=future): METHODS = { "ns_get_result": ("Non-stream + get_result", lambda mgr, p, m: bench_ns_get_result(mgr, p, m)), - "ns_handler": ("Non-stream + handler", lambda mgr, p, m: asyncio.run(bench_ns_handler(mgr, p, m))), - "s_get_result": ("Stream + get_result", lambda mgr, p, m: bench_s_get_result(mgr, p, m)), - "s_handler": ("Stream + handler", lambda mgr, p, m: asyncio.run(bench_s_handler(mgr, p, m))), + "ns_handler": ("Non-stream + handler", lambda mgr, p, m: asyncio.run(bench_ns_handler(mgr, p, m))), + "s_get_result": ("Stream + get_result", lambda mgr, p, m: bench_s_get_result(mgr, p, m)), + "s_handler": ("Stream + handler", lambda mgr, p, m: asyncio.run(bench_s_handler(mgr, p, m))), } @@ -126,8 +126,7 @@ def main(): parser.add_argument("--prompt-tokens", type=int, default=256) parser.add_argument("--warmup", type=int, default=2) parser.add_argument("--runs", type=int, default=3) - parser.add_argument("--methods", type=str, nargs="+", - default=list(METHODS.keys()), choices=list(METHODS.keys())) + parser.add_argument("--methods", type=str, nargs="+", default=list(METHODS.keys()), choices=list(METHODS.keys())) args = parser.parse_args() print(f"Model: {args.model}") @@ -135,9 +134,15 @@ def main(): print(f"Warmup: {args.warmup} | Runs: {args.runs} | Methods: {args.methods}") sys.stdout.flush() - model = AutoModelForCausalLM.from_pretrained( - args.model, dtype=torch.bfloat16, attn_implementation="flash_attention_3", - ).cuda().eval() + model = ( + AutoModelForCausalLM.from_pretrained( + args.model, + dtype=torch.bfloat16, + attn_implementation="flash_attention_3", + ) + .cuda() + .eval() + ) tokenizer = AutoTokenizer.from_pretrained(args.model) all_prompts = make_prompts(tokenizer, max(args.batch), args.prompt_tokens) @@ -163,10 +168,13 @@ def main(): _, fn = METHODS[method_key] # Fresh CB context per method — each gets its own CUDA graph cache with model.continuous_batching_context_manager( - generation_config=gen_config, continuous_batching_config=cb_config, block=True, timeout=5, + generation_config=gen_config, + continuous_batching_config=cb_config, + block=True, + timeout=5, ) as mgr: # Warmup with the same method being tested - warmup_prompts = prompts[:min(200, N)] + warmup_prompts = prompts[: min(200, N)] for _ in range(args.warmup): fn(mgr, warmup_prompts, args.max_new_tokens) # Measured runs @@ -180,21 +188,28 @@ def main(): # Quality check print("\n--- Quality check ---") with model.continuous_batching_context_manager( - generation_config=gen_config, continuous_batching_config=cb_config, block=True, timeout=5, + generation_config=gen_config, + continuous_batching_config=cb_config, + block=True, + timeout=5, ) as mgr: + async def check(): loop = asyncio.get_running_loop() for i in range(3): rid = f"qc_{i}" future = loop.create_future() + def _on_qc(output, fut=future): if not fut.done(): fut.set_result(output) + mgr.register_result_handler(rid, _on_qc) mgr.add_request(all_prompts[i], request_id=rid, max_new_tokens=args.max_new_tokens, streaming=False) r = await future text = tokenizer.decode(r.generated_tokens, skip_special_tokens=True)[:80] print(f" {r.request_id}: {len(r.generated_tokens)} tokens | {text}") + 
asyncio.run(check()) diff --git a/tests/cli/benchmark_serve.py b/tests/cli/benchmark_serve.py index 17a76316e706..ecd371694d20 100644 --- a/tests/cli/benchmark_serve.py +++ b/tests/cli/benchmark_serve.py @@ -100,7 +100,10 @@ def wait_for_server(base_url: str, timeout: int = 120) -> bool: def streaming_chat_completion( - base_url: str, messages: list, max_tokens: int, seed: int, + base_url: str, + messages: list, + max_tokens: int, + seed: int, do_sample: bool = False, ) -> dict: """Send a streaming chat completion request. Returns {total, ttft, completion_tokens, text}.""" @@ -126,7 +129,7 @@ def streaming_chat_completion( for line in resp.iter_lines(decode_unicode=True): if not line or not line.startswith("data: "): continue - data_str = line[len("data: "):] + data_str = line[len("data: ") :] if data_str.strip() == "[DONE]": break try: @@ -161,7 +164,10 @@ def streaming_chat_completion( def streaming_response( - base_url: str, messages: list, max_tokens: int, seed: int, + base_url: str, + messages: list, + max_tokens: int, + seed: int, do_sample: bool = False, ) -> dict: """Send a streaming responses API request. Returns {total, ttft, completion_tokens, text}.""" @@ -190,7 +196,7 @@ def streaming_response( if not line or not line.startswith("data: "): continue try: - chunk = json.loads(line[len("data: "):]) + chunk = json.loads(line[len("data: ") :]) except json.JSONDecodeError: continue @@ -218,13 +224,15 @@ def streaming_response( def streaming_request( - base_url: str, messages: list, max_tokens: int, seed: int, + base_url: str, + messages: list, + max_tokens: int, + seed: int, do_sample: bool = False, endpoint: str = "chat", ) -> dict: """Dispatch to chat completions or responses API based on endpoint.""" - kw = {"base_url": base_url, "messages": messages, "max_tokens": max_tokens, - "seed": seed, "do_sample": do_sample} + kw = {"base_url": base_url, "messages": messages, "max_tokens": max_tokens, "seed": seed, "do_sample": do_sample} if endpoint == "responses": return streaming_response(**kw) return streaming_chat_completion(**kw) @@ -236,8 +244,14 @@ def streaming_request( def bench_pp( - base_url: str, tokenizer, pp: int, warmup: int, iterations: int, seed: int, - do_sample: bool = False, endpoint: str = "chat", + base_url: str, + tokenizer, + pp: int, + warmup: int, + iterations: int, + seed: int, + do_sample: bool = False, + endpoint: str = "chat", ) -> dict: """Prefill benchmark: large prompt, max_tokens=1. 
Measures TTFT ≈ pure prefill time.""" prompt = make_prompt(tokenizer, pp) @@ -260,8 +274,15 @@ def bench_pp( def bench_tg( - base_url: str, tokenizer, tg: int, warmup: int, iterations: int, seed: int, - tg_prefill: int = 512, do_sample: bool = False, endpoint: str = "chat", + base_url: str, + tokenizer, + tg: int, + warmup: int, + iterations: int, + seed: int, + tg_prefill: int = 512, + do_sample: bool = False, + endpoint: str = "chat", ) -> dict: """Decode benchmark: generate `tg` tokens after a `tg_prefill`-token prompt.""" prompt = make_prompt(tokenizer, tg_prefill) @@ -321,11 +342,13 @@ def truncate_preview(text: str, width: int = _PREVIEW_WIDTH) -> str: return "" line = text.replace("\n", " ").strip() if len(line) > width: - return line[:width - 1] + "\u2026" + return line[: width - 1] + "\u2026" return line -def print_table(rows: list[dict], title: str = "", reference_texts: dict | None = None, is_reference: bool = False) -> None: +def print_table( + rows: list[dict], title: str = "", reference_texts: dict | None = None, is_reference: bool = False +) -> None: """Print results in a bordered table. Args: @@ -402,8 +425,12 @@ def make_sep(char="-"): def start_server( - model: str, port: int, processor: str | None = None, attn_implementation: str | None = None, - compile: bool = False, continuous_batching: bool = False, + model: str, + port: int, + processor: str | None = None, + attn_implementation: str | None = None, + compile: bool = False, + continuous_batching: bool = False, ): """Start a transformers serve instance. Returns the Serve object.""" from transformers.cli.serve_refactored import Serve @@ -450,29 +477,55 @@ def main(): python benchmark_serve.py --url http://localhost:8000 --processor Qwen/Qwen2.5-7B-Instruct """, ) - parser.add_argument("--model", type=str, action="append", dest="models", - help="Model spec (repeatable). For GGUF: 'gguf_id --processor tokenizer_id'") - parser.add_argument("--processor", type=str, default=None, - help="Processor/tokenizer ID for --url mode (default: derived from model)") + parser.add_argument( + "--model", + type=str, + action="append", + dest="models", + help="Model spec (repeatable). 
For GGUF: 'gguf_id --processor tokenizer_id'", + ) + parser.add_argument( + "--processor", + type=str, + default=None, + help="Processor/tokenizer ID for --url mode (default: derived from model)", + ) parser.add_argument("--port", type=int, default=8642, help="Server port") - parser.add_argument("--url", type=str, default=None, - help="Connect to existing server (skip start/stop)") + parser.add_argument("--url", type=str, default=None, help="Connect to existing server (skip start/stop)") parser.add_argument("--warmup", type=int, default=1, help="Warmup iterations (minimum 1)") parser.add_argument("--iterations", type=int, default=3, help="Measurement iterations") parser.add_argument("--pp", type=int, nargs="+", default=[256, 1024], help="Prefill token counts") parser.add_argument("--tg", type=int, nargs="+", default=[128, 512], help="Decode token counts") - parser.add_argument("--tg-prefill", type=int, default=_TG_PREFILL_DEFAULT, - help="Prefill size for decode tests (default: 512)") - parser.add_argument("--attn-impl", type=str, nargs="+", default=["sdpa", "eager", "flash_attention_2"], - help="Attention implementations to benchmark (default: sdpa eager flash_attention_2)") - parser.add_argument("--compile", action="store_true", - help="Enable static cache + torch.compile on the server for faster decode") - parser.add_argument("--continuous-batching", action="store_true", - help="Enable continuous batching with paged attention") - parser.add_argument("--mode", type=str, choices=["bench", "chat"], default="bench", - help="bench: greedy (temp=0). chat: sampling (do_sample=True, temp=0.7)") - parser.add_argument("--endpoint", type=str, choices=["chat", "responses"], default="responses", - help="API endpoint to benchmark (default: responses = /v1/responses)") + parser.add_argument( + "--tg-prefill", type=int, default=_TG_PREFILL_DEFAULT, help="Prefill size for decode tests (default: 512)" + ) + parser.add_argument( + "--attn-impl", + type=str, + nargs="+", + default=["sdpa", "eager", "flash_attention_2"], + help="Attention implementations to benchmark (default: sdpa eager flash_attention_2)", + ) + parser.add_argument( + "--compile", action="store_true", help="Enable static cache + torch.compile on the server for faster decode" + ) + parser.add_argument( + "--continuous-batching", action="store_true", help="Enable continuous batching with paged attention" + ) + parser.add_argument( + "--mode", + type=str, + choices=["bench", "chat"], + default="bench", + help="bench: greedy (temp=0). 
chat: sampling (do_sample=True, temp=0.7)", + ) + parser.add_argument( + "--endpoint", + type=str, + choices=["chat", "responses"], + default="responses", + help="API endpoint to benchmark (default: responses = /v1/responses)", + ) parser.add_argument("--seed", type=int, default=42, help="Torch seed") args = parser.parse_args() @@ -493,10 +546,33 @@ def main(): rows = [] for pp in args.pp: print(f" pp{pp}") - rows.append(bench_pp(base_url, tokenizer, pp, args.warmup, args.iterations, args.seed, do_sample=do_sample, endpoint=endpoint)) + rows.append( + bench_pp( + base_url, + tokenizer, + pp, + args.warmup, + args.iterations, + args.seed, + do_sample=do_sample, + endpoint=endpoint, + ) + ) for tg in args.tg: print(f" tg{tg}") - rows.append(bench_tg(base_url, tokenizer, tg, args.warmup, args.iterations, args.seed, tg_prefill=args.tg_prefill, do_sample=do_sample, endpoint=endpoint)) + rows.append( + bench_tg( + base_url, + tokenizer, + tg, + args.warmup, + args.iterations, + args.seed, + tg_prefill=args.tg_prefill, + do_sample=do_sample, + endpoint=endpoint, + ) + ) print_table(rows) else: @@ -512,8 +588,14 @@ def main(): for attn_impl in args.attn_impl: print(f"\nStarting server for {spec['model']} (attn={attn_impl})...") try: - server = start_server(spec["model"], args.port, spec["processor"], attn_implementation=attn_impl, - compile=args.compile, continuous_batching=args.continuous_batching) + server = start_server( + spec["model"], + args.port, + spec["processor"], + attn_implementation=attn_impl, + compile=args.compile, + continuous_batching=args.continuous_batching, + ) except Exception as e: print(f" ERROR: Failed to start server with attn={attn_impl}: {e}. Skipping.") continue @@ -531,35 +613,68 @@ def main(): if args.compile: warmup_prompt = make_prompt(tokenizer, max(args.pp + [args.tg_prefill])) gen_cfg = {"max_new_tokens": max(args.tg), "do_sample": False, "eos_token_id": -1} - payload = {"messages": [{"role": "user", "content": warmup_prompt}], "stream": False, - "seed": args.seed, "generation_config": json.dumps(gen_cfg)} + payload = { + "messages": [{"role": "user", "content": warmup_prompt}], + "stream": False, + "seed": args.seed, + "generation_config": json.dumps(gen_cfg), + } print(" compile warmup (non-streaming, may take ~30s)...") requests.post(f"{base_url}/v1/chat/completions", json=payload, timeout=120) else: - streaming_request(base_url, [{"role": "user", "content": "hi"}], max_tokens=5, seed=args.seed, endpoint=endpoint) + streaming_request( + base_url, [{"role": "user", "content": "hi"}], max_tokens=5, seed=args.seed, endpoint=endpoint + ) rows = [] for pp in args.pp: print(f" pp{pp}") - rows.append(bench_pp(base_url, tokenizer, pp, args.warmup, args.iterations, args.seed, do_sample=do_sample, endpoint=endpoint)) + rows.append( + bench_pp( + base_url, + tokenizer, + pp, + args.warmup, + args.iterations, + args.seed, + do_sample=do_sample, + endpoint=endpoint, + ) + ) for tg in args.tg: print(f" tg{tg}") - rows.append(bench_tg(base_url, tokenizer, tg, args.warmup, args.iterations, args.seed, tg_prefill=args.tg_prefill, do_sample=do_sample, endpoint=endpoint)) + rows.append( + bench_tg( + base_url, + tokenizer, + tg, + args.warmup, + args.iterations, + args.seed, + tg_prefill=args.tg_prefill, + do_sample=do_sample, + endpoint=endpoint, + ) + ) server.kill_server() # Build reference from first attn impl in greedy mode if not do_sample and reference_texts is None: - reference_texts = { - row["test"]: row["text"] for row in rows if row.get("text") - } + reference_texts = 
{row["test"]: row["text"] for row in rows if row.get("text")} # Pass reference_texts so the first impl shows "REF" in the ref column - print_table(rows, title=f"{spec['model']} | attn={attn_impl} ({mode_str})", - reference_texts=reference_texts if len(args.attn_impl) > 1 else None, - is_reference=True) + print_table( + rows, + title=f"{spec['model']} | attn={attn_impl} ({mode_str})", + reference_texts=reference_texts if len(args.attn_impl) > 1 else None, + is_reference=True, + ) else: - print_table(rows, title=f"{spec['model']} | attn={attn_impl} ({mode_str})", - reference_texts=reference_texts if not do_sample else None) + print_table( + rows, + title=f"{spec['model']} | attn={attn_impl} ({mode_str})", + reference_texts=reference_texts if not do_sample else None, + ) # Summary: check for mismatches across attn impls (bench mode only) if not do_sample and reference_texts and len(args.attn_impl) > 1: diff --git a/tests/cli/benchmark_serve_load.py b/tests/cli/benchmark_serve_load.py index c25057f8bce1..bedb5bdc1e6c 100644 --- a/tests/cli/benchmark_serve_load.py +++ b/tests/cli/benchmark_serve_load.py @@ -143,7 +143,7 @@ async def send_request( line = line.decode("utf-8").strip() if not line or not line.startswith("data: "): continue - data_str = line[len("data: "):] + data_str = line[len("data: ") :] if data_str.strip() == "[DONE]": break try: @@ -225,13 +225,15 @@ async def run_concurrency_test( connector = aiohttp.TCPConnector(limit=0) timeout = aiohttp.ClientTimeout(total=600) async with aiohttp.ClientSession(connector=connector, timeout=timeout) as session: - tasks = [send_request(session, base_url, p, max_new_tokens, seed, endpoint, model=model, stream=stream) for p in prompts] + tasks = [ + send_request(session, base_url, p, max_new_tokens, seed, endpoint, model=model, stream=stream) + for p in prompts + ] results = await asyncio.gather(*tasks) return list(results) - # --------------------------------------------------------------------------- # Metrics # --------------------------------------------------------------------------- @@ -309,12 +311,16 @@ def print_metrics(metrics: dict, label: str): print(f" ERROR: {metrics['error']}") return - print(f" Requests: {metrics['successful']} ok / {metrics['failed']} failed / {metrics['total_requests']} total") + print( + f" Requests: {metrics['successful']} ok / {metrics['failed']} failed / {metrics['total_requests']} total" + ) if metrics.get("error_summary"): for err, count in metrics["error_summary"].most_common(5): print(f" - {count}x: {err}") print(f" Duration: {metrics['duration']:.1f}s") - print(f" Throughput: {metrics['throughput_req_per_sec']:.2f} req/s, {metrics['throughput_tok_per_sec']:.1f} tok/s") + print( + f" Throughput: {metrics['throughput_req_per_sec']:.2f} req/s, {metrics['throughput_tok_per_sec']:.1f} tok/s" + ) print(f" Tokens: {metrics['total_output_tokens']} total output") print() @@ -324,15 +330,17 @@ def print_metrics(metrics: dict, label: str): p = metrics.get(name, {}) if not p: continue - rows.append([ - name.upper().replace("_", " "), - format_ms(p.get("mean")), - format_ms(p.get("median")), - format_ms(p.get("p90")), - format_ms(p.get("p99")), - format_ms(p.get("min")), - format_ms(p.get("max")), - ]) + rows.append( + [ + name.upper().replace("_", " "), + format_ms(p.get("mean")), + format_ms(p.get("median")), + format_ms(p.get("p90")), + format_ms(p.get("p99")), + format_ms(p.get("min")), + format_ms(p.get("max")), + ] + ) if rows: widths = [max(len(h), *(len(r[i]) for r in rows)) for i, h in enumerate(headers)] 
@@ -359,13 +367,18 @@ def wait_for_server(base_url: str, timeout: int = 120) -> bool: if requests.get(f"{base_url}/health", timeout=2).status_code == 200: return True except Exception: - pass # noqa: S110 + continue time.sleep(1) return False -def start_server(model: str, port: int, compile: bool = False, continuous_batching: bool = False, - attn_implementation: str | None = None): +def start_server( + model: str, + port: int, + compile: bool = False, + continuous_batching: bool = False, + attn_implementation: str | None = None, +): from transformers.cli.serve_refactored import Serve kwargs = {"force_model": model, "port": port, "non_blocking": True, "log_level": "warning"} @@ -393,8 +406,13 @@ async def async_main(args): if not args.url: print(f"Starting server for {args.model}...") - server = start_server(args.model, args.port, compile=args.compile, continuous_batching=args.continuous_batching, - attn_implementation=args.attn_impl) + server = start_server( + args.model, + args.port, + compile=args.compile, + continuous_batching=args.continuous_batching, + attn_implementation=args.attn_impl, + ) if not wait_for_server(base_url): print("ERROR: Server did not start") if server: @@ -407,7 +425,7 @@ async def async_main(args): num_requests = max(args.max_concurrency) prompts = make_prompts(tokenizer, num_requests, args.prompt_tokens, variance=args.prompt_variance) - print(f"Generated {num_requests} prompts (~{args.prompt_tokens} tokens each, ±{int(args.prompt_variance*100)}%)") + print(f"Generated {num_requests} prompts (~{args.prompt_tokens} tokens each, ±{int(args.prompt_variance * 100)}%)") print(f"Max new tokens per request: {args.max_new_tokens}") print(f"Endpoint: /v1/{args.endpoint} ({'streaming' if args.stream else 'non-streaming'})") @@ -418,7 +436,13 @@ async def async_main(args): print(f"Warming up ({args.warmup}x {warmup_size} requests)...") for _ in range(args.warmup): await run_concurrency_test( - base_url, warmup_prompts, args.max_new_tokens, args.seed, args.endpoint, model=args.model, stream=args.stream, + base_url, + warmup_prompts, + args.max_new_tokens, + args.seed, + args.endpoint, + model=args.model, + stream=args.stream, ) print("Warmup done.") @@ -429,7 +453,13 @@ async def async_main(args): print(f"\nRunning: {label}") t0 = time.perf_counter() results = await run_concurrency_test( - base_url, test_prompts, args.max_new_tokens, args.seed, args.endpoint, model=args.model, stream=args.stream, + base_url, + test_prompts, + args.max_new_tokens, + args.seed, + args.endpoint, + model=args.model, + stream=args.stream, ) duration = time.perf_counter() - t0 metrics = compute_metrics(results, duration) @@ -450,19 +480,29 @@ def main(): parser.add_argument("--port", type=int, default=8642) parser.add_argument("--compile", action="store_true", help="Enable --compile on the server") parser.add_argument("--continuous-batching", action="store_true", help="Enable continuous batching on the server") - parser.add_argument("--attn-impl", type=str, default=None, help="Attention implementation (e.g. flash_attention_3)") + parser.add_argument( + "--attn-impl", type=str, default=None, help="Attention implementation (e.g. 
flash_attention_3)" + ) parser.add_argument("--endpoint", type=str, choices=["chat", "responses"], default="responses") parser.add_argument("--no-stream", action="store_true", help="Use non-streaming requests") # Load parameters - parser.add_argument("--max-concurrency", type=int, nargs="+", default=[1, 2, 4], - help="Number of concurrent requests to send (default: 1 2 4)") + parser.add_argument( + "--max-concurrency", + type=int, + nargs="+", + default=[1, 2, 4], + help="Number of concurrent requests to send (default: 1 2 4)", + ) # Prompt parameters parser.add_argument("--prompt-tokens", type=int, default=256, help="Target prompt length in tokens (default: 256)") - parser.add_argument("--prompt-variance", type=float, default=0.2, - help="Prompt length variance as fraction (default: 0.2 = ±20%%)") - parser.add_argument("--max-new-tokens", type=int, default=128, help="Max tokens to generate per request (default: 128)") + parser.add_argument( + "--prompt-variance", type=float, default=0.2, help="Prompt length variance as fraction (default: 0.2 = ±20%%)" + ) + parser.add_argument( + "--max-new-tokens", type=int, default=128, help="Max tokens to generate per request (default: 128)" + ) parser.add_argument("--warmup", type=int, default=2, help="Warmup requests (default: 2)") parser.add_argument("--seed", type=int, default=42) diff --git a/tests/cli/test_serve_refactored.py b/tests/cli/test_serve_refactored.py index 221aa3812d90..cae52ea4595d 100644 --- a/tests/cli/test_serve_refactored.py +++ b/tests/cli/test_serve_refactored.py @@ -675,7 +675,7 @@ def setUpClass(cls): if requests.get(f"http://localhost:{cls.PORT}/health", timeout=1).status_code == 200: break except Exception: - pass + continue time.sleep(2) cls.client = OpenAI(base_url=f"http://localhost:{cls.PORT}/v1", api_key="unused") @@ -1049,8 +1049,9 @@ def test_valid_request_passes(self): class TestResponseGenerationConfig(unittest.TestCase): def _make_handler(self): from transformers.cli.serving.response import ResponseHandler + from transformers.cli.serving.utils import GenerationState - from transformers.cli.serving.utils import GenerationState; return ResponseHandler(model_manager=MagicMock(), generation_state=GenerationState()) + return ResponseHandler(model_manager=MagicMock(), generation_state=GenerationState()) def test_max_output_tokens(self): from transformers import GenerationConfig @@ -1158,7 +1159,7 @@ def setUpClass(cls): if requests.get(f"http://localhost:{cls.PORT}/health", timeout=1).status_code == 200: break except Exception: - pass + continue time.sleep(2) cls.client = OpenAI(base_url=f"http://localhost:{cls.PORT}/v1", api_key="unused") @@ -1418,7 +1419,7 @@ def setUpClass(cls): if req.get(f"http://localhost:{cls.PORT}/health", timeout=1).status_code == 200: break except Exception: - pass + continue time.sleep(2) cls.base_url = f"http://localhost:{cls.PORT}" @@ -1711,7 +1712,7 @@ def setUpClass(cls): if req.get(f"http://localhost:{cls.PORT}/health", timeout=1).status_code == 200: break except Exception: - pass + continue time.sleep(2) cls.client = OpenAI(base_url=f"http://localhost:{cls.PORT}/v1", api_key="unused") @@ -1785,7 +1786,7 @@ def setUpClass(cls): if req.get(f"http://localhost:{cls.PORT}/health", timeout=1).status_code == 200: break except Exception: - pass + continue time.sleep(2) cls.base_url = f"http://localhost:{cls.PORT}" @@ -1901,7 +1902,7 @@ def setUpClass(cls): if requests.get(f"http://localhost:{cls.PORT}/health", timeout=1).status_code == 200: break except Exception: - pass + continue 
time.sleep(2) cls.client = OpenAI(base_url=f"http://localhost:{cls.PORT}/v1", api_key="unused") @@ -2024,7 +2025,7 @@ def setUpClass(cls): if requests.get(f"http://localhost:{cls.PORT}/health", timeout=1).status_code == 200: break except Exception: - pass + continue time.sleep(2) cls.client = OpenAI(base_url=f"http://localhost:{cls.PORT}/v1", api_key="unused") From d550b9b90060e21dadca655b07b894645500956d Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Fri, 27 Mar 2026 18:15:00 +0000 Subject: [PATCH 43/64] remove annoattion --- src/transformers/cli/serve_refactored.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/transformers/cli/serve_refactored.py b/src/transformers/cli/serve_refactored.py index 4936b589d9a0..5b2364c26c33 100644 --- a/src/transformers/cli/serve_refactored.py +++ b/src/transformers/cli/serve_refactored.py @@ -15,8 +15,6 @@ CLI entry point for `transformers serve`. """ -from __future__ import annotations - import asyncio import threading from typing import Annotated From ac0d6a1c1fe99dd8d5a7bdffe5abd107e32671d1 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Mon, 30 Mar 2026 12:11:27 +0000 Subject: [PATCH 44/64] review ! --- .../continuous_batching/continuous_api.py | 82 +++++++++---------- src/transformers/utils/metrics.py | 3 +- tests/generation/test_continuous_batching.py | 69 +++++++++++++++- 3 files changed, 109 insertions(+), 45 deletions(-) diff --git a/src/transformers/generation/continuous_batching/continuous_api.py b/src/transformers/generation/continuous_batching/continuous_api.py index a14c3ca0d895..d3b568de8d54 100644 --- a/src/transformers/generation/continuous_batching/continuous_api.py +++ b/src/transformers/generation/continuous_batching/continuous_api.py @@ -78,6 +78,30 @@ def _get_logits_processor(self, generation_config: GenerationConfig) -> LogitsPr pass +class OutputRouter: + """Dedicated object for routing generation outputs to the right destination. + + When an async handler is registered for a request, the output is forwarded + to that handler via ``call_soon_threadsafe``. Otherwise the output is placed + on the shared ``output_queue``. + """ + + def __init__(self) -> None: + self.output_queue = queue.Queue() + self.result_handlers: dict[str, tuple[callable, asyncio.AbstractEventLoop]] = {} + self._lock = threading.Lock() + + def deliver(self, output: GenerationOutput) -> None: + """Route a single output to its registered handler or the output_queue.""" + with self._lock: + entry = self.result_handlers.get(output.request_id) + if entry is not None: + callback, loop = entry + loop.call_soon_threadsafe(callback, output) + else: + self.output_queue.put(output) + + # Continuous Batch Processor (Internal Logic) @attach_tracer() class ContinuousBatchProcessor: @@ -91,12 +115,11 @@ def __init__( generation_config: GenerationConfig, continuous_batching_config: ContinuousBatchingConfig, input_queue: queue.Queue, - output_queue: queue.Queue, + output_router: OutputRouter, stop_event: threading.Event, model_device: torch.device, model_dtype: torch.dtype, scheduler: Scheduler, - deliver_output: callable, ) -> None: """Initialize the continuous batch processor. @@ -105,23 +128,21 @@ def __init__( config: The model configuration generation_config: The generation configuration input_queue: Queue for incoming requests - output_queue: Queue for outgoing results (used by ``get_result()`` callers) + output_router: An [`OutputRouter`] object that routes outputs to handlers or the output queue. 
stop_event: Event to signal processing should stop model_device: Device for model inputs/outputs model_dtype: Data type for model inputs/outputs scheduler: The [`Scheduler`] to use - deliver_output: Callback that receives a single ``GenerationOutput``. """ self.cache = cache self.config = config self.cb_config = continuous_batching_config self.input_queue = input_queue - self.output_queue = output_queue + self.output_router = output_router self.stop_event = stop_event self.model_device = model_device self.model_dtype = model_dtype self.scheduler = scheduler - self._deliver_output = deliver_output # Generation-related attributes self.do_sample = getattr(generation_config, "do_sample", True) @@ -181,7 +202,7 @@ def __init__( def __repr__(self) -> str: return ( - f"ContinuousBatchProcessor(input_queue={self.input_queue}, output_queue={self.output_queue}, " + f"ContinuousBatchProcessor(input_queue={self.input_queue}, " f"active_requests={self.scheduler.active_requests}, waiting_requests={self.scheduler.waiting_requests})" + self.inputs_and_outputs.get_model_kwargs().__repr__() ) @@ -254,7 +275,7 @@ def _handle_request_error(self, error: Exception, state: RequestState) -> None: state.generated_tokens = [] self.metrics.record_request_completion(state.created_time, state.request_id) - self.output_queue.put(state.to_generation_output()) + self.output_router.deliver(state.to_generation_output()) # TODO: there should be a way to choose the offloading policy: biggest request, oldest request, etc. # Including a policy to not allow offloading and crashing the generation @@ -366,7 +387,7 @@ def update_batch(self) -> None: self.scheduler.finish_request(state.request_id) self.scheduler.block_new_requests = False if state.streaming or state.status == RequestStatus.FINISHED: - self._deliver_output(state.to_generation_output()) + self.output_router.deliver(state.to_generation_output()) # Otherwise, the request is still prefilling, but the prefill has been split elif state.status == RequestStatus.PREFILLING: self.cache.mark_shareable_blocks_as_complete(state, future_state.complete_blocks) @@ -596,9 +617,7 @@ def __init__( self.input_queue = queue.Queue(maxsize=self.continuous_batching_config.max_queue_size) self._has_new_requests = threading.Event() - self.output_queue = queue.Queue() - self._result_handlers: dict[str, tuple[callable, asyncio.AbstractEventLoop]] = {} - self._result_handlers_lock = threading.Lock() + self.output_router = OutputRouter() self.stop_event = threading.Event() self.batch_processor: ContinuousBatchProcessor | None = None self._generation_thread = None @@ -626,30 +645,6 @@ def __init__( self.kv_padding_interval_size = self.continuous_batching_config.kv_padding_interval_size self.max_cached_graphs = self.continuous_batching_config.max_cached_graphs - # Log probability generation is not supported yet (TODO) - if self.log_prob_generation: - raise NotImplementedError("log_prob_generation is not supported yet") - - def _register_handler(self, request_id: str, callback: callable, loop: asyncio.AbstractEventLoop) -> None: - """Register a result handler for a request.""" - with self._result_handlers_lock: - self._result_handlers[request_id] = (callback, loop) - - def _unregister_handler(self, request_id: str) -> None: - """Remove a result handler for a request.""" - with self._result_handlers_lock: - self._result_handlers.pop(request_id, None) - - def _deliver_output(self, output: GenerationOutput) -> None: - """Route a single output to its registered handler or the output_queue.""" - with 
self._result_handlers_lock: - entry = self._result_handlers.get(output.request_id) - if entry is not None: - callback, loop = entry - loop.call_soon_threadsafe(callback, output) - else: - self.output_queue.put(output) - @traced def start(self) -> None: """Start the background generation thread.""" @@ -807,12 +802,12 @@ def get_result(self, request_id: str | None = None, timeout: float | None = None Returns: Optional[GenerationOutput]: The result data or None if timeout. """ - if self._generation_thread is None and self.output_queue.empty(): + if self._generation_thread is None and self.output_router.output_queue.empty(): return None try: - result = self.output_queue.get(block=True, timeout=timeout) + result = self.output_router.output_queue.get(block=True, timeout=timeout) if request_id is not None and result.request_id != request_id: - self.output_queue.put(result) + self.output_router.output_queue.put(result) return None return result except queue.Empty: @@ -856,9 +851,11 @@ def register_result_handler(self, request_id: str, callback: callable) -> None: def _auto_cleanup(result): callback(result) if result.is_finished(): - self._unregister_handler(request_id) + with self.output_router._lock: + self.output_router.result_handlers.pop(request_id, None) - self._register_handler(request_id, _auto_cleanup, loop) + with self.output_router._lock: + self.output_router.result_handlers[request_id] = (_auto_cleanup, loop) @traced def _generation_step(self) -> None: @@ -892,12 +889,11 @@ def _create_batch_processor(self) -> ContinuousBatchProcessor: generation_config=self.generation_config, continuous_batching_config=self.continuous_batching_config, input_queue=self.input_queue, - output_queue=self.output_queue, + output_router=self.output_router, stop_event=self.stop_event, model_device=self.model.device, model_dtype=self.model.dtype, scheduler=scheduler(paged_attention_cache), - deliver_output=self._deliver_output, ) return batch_processor diff --git a/src/transformers/utils/metrics.py b/src/transformers/utils/metrics.py index 998595dd94ba..87c457d7f8ef 100644 --- a/src/transformers/utils/metrics.py +++ b/src/transformers/utils/metrics.py @@ -297,7 +297,8 @@ def record_batch_metrics(self, requests_in_batch: list) -> None: decode_tokens = 0 prefill_tokens = 0 - for state in requests_in_batch: + for request in requests_in_batch: + state = request.state if state.status == RequestStatus.DECODING: decode_tokens += 1 elif state.status in [RequestStatus.PREFILLING, RequestStatus.PREFILLING_SPLIT]: diff --git a/tests/generation/test_continuous_batching.py b/tests/generation/test_continuous_batching.py index e89d4eaefd61..89857a220111 100644 --- a/tests/generation/test_continuous_batching.py +++ b/tests/generation/test_continuous_batching.py @@ -39,8 +39,9 @@ group_layers_by_attn_type, ) from transformers.generation.continuous_batching.cache_manager import FullAttentionCacheAllocator -from transformers.generation.continuous_batching.continuous_api import ContinuousBatchProcessor +from transformers.generation.continuous_batching.continuous_api import ContinuousBatchProcessor, OutputRouter from transformers.generation.continuous_batching.input_outputs import build_attention_mask +from transformers.generation.continuous_batching.requests import GenerationOutput, RequestStatus from transformers.testing_utils import ( Expectations, require_deterministic_for_xpu, @@ -420,6 +421,30 @@ def test_continuous_batching_no_accelerators(self) -> None: self.assertIsNotNone(output.generated_tokens) 
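# --- Illustrative sketch, not part of the patch ------------------------------
# Minimal async consumer for the handler-based delivery path introduced above:
# once a callback is registered for a request id, the OutputRouter bypasses the
# shared output_queue and pushes each GenerationOutput onto the running event
# loop via call_soon_threadsafe. `manager` is assumed to be an already-started
# ContinuousBatchingManager and `inputs` a plain list of input token ids, as in
# the tests of this patch.
import asyncio

async def collect_outputs(manager, inputs, max_new_tokens=16):
    loop = asyncio.get_running_loop()
    done = loop.create_future()
    outputs = []

    def on_output(output):
        # Invoked on the event loop for every output of this request.
        outputs.append(output)
        if output.is_finished() and not done.done():
            done.set_result(True)

    request_id = manager.add_request(inputs, max_new_tokens=max_new_tokens, streaming=True)
    manager.register_result_handler(request_id, on_output)
    await asyncio.wait_for(done, timeout=30)
    return outputs
# ------------------------------------------------------------------------------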
self.assertGreater(len(output.generated_tokens), 0) + def test_output_router_deliver_to_queue(self): + """Test that OutputRouter.deliver places outputs on the queue when no handler is registered.""" + router = OutputRouter() + output = GenerationOutput(request_id="req_0", status=RequestStatus.FINISHED) + router.deliver(output) + result = router.output_queue.get_nowait() + self.assertEqual(result.request_id, "req_0") + self.assertTrue(router.output_queue.empty()) + + def test_output_router_deliver_to_handler(self): + """Test that OutputRouter.deliver forwards to a registered handler instead of the queue.""" + router = OutputRouter() + received = [] + loop = unittest.mock.Mock() + + with router._lock: + router.result_handlers["req_0"] = (lambda out: received.append(out), loop) + + output = GenerationOutput(request_id="req_0", status=RequestStatus.DECODING) + router.deliver(output) + + loop.call_soon_threadsafe.assert_called_once() + self.assertTrue(router.output_queue.empty()) + @require_torch_accelerator class ContinuousBatchingWithAcceleratorTest(unittest.TestCase): @@ -806,6 +831,48 @@ def test_non_streaming_request(self) -> None: def test_streaming_and_non_streaming_requests_can_alternate(self) -> None: self._test_streaming_or_not_request(with_streaming=True, with_non_streaming=True) + def test_register_result_handler(self) -> None: + """Test that register_result_handler receives streaming outputs through the OutputRouter.""" + import asyncio + + model_id = "Qwen/Qwen2.5-0.5B-Instruct" + max_new_tokens = 3 + + tokenizer, model = get_tokenizer_and_model(model_id, "sdpa", torch_device) + manager = model.init_continuous_batching() + manager.logit_processor = LogitsProcessorList() + manager.start() + + user_messages = ["What is the Transformers library known for?"] + inputs = get_generation_inputs(user_messages, tokenizer, for_continuous_batching=True)[0] + + async def collect_results(): + results = [] + future = asyncio.get_running_loop().create_future() + + def on_result(output): + results.append(output) + if output.is_finished(): + future.set_result(True) + + request_id = manager.add_request(inputs, max_new_tokens=max_new_tokens, streaming=True) + manager.register_result_handler(request_id, on_result) + + await asyncio.wait_for(future, timeout=30) + return results + + results = asyncio.run(collect_results()) + + # Streaming via handler: incremental token count, same as request_id_iter + self.assertEqual(len(results[0].generated_tokens), 1) + self.assertEqual(len(results[1].generated_tokens), 2) + self.assertEqual(len(results[2].generated_tokens), 3) + self.assertTrue(results[-1].is_finished()) + # Queue should be empty — everything went through the handler + self.assertTrue(manager.output_router.output_queue.empty()) + + manager.stop(block=True) + # -----------------------------------------Misc. 
tests----------------------------------------- # # Various tests that don't fit into the other categories # # --------------------------------------------------------------------------------------------- # From 9d52002b37f24ac42c1dd4d432285b99184b62aa Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Mon, 30 Mar 2026 12:13:45 +0000 Subject: [PATCH 45/64] style --- .../generation/continuous_batching/continuous_api.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/generation/continuous_batching/continuous_api.py b/src/transformers/generation/continuous_batching/continuous_api.py index d3b568de8d54..e2d985d60906 100644 --- a/src/transformers/generation/continuous_batching/continuous_api.py +++ b/src/transformers/generation/continuous_batching/continuous_api.py @@ -17,7 +17,7 @@ import queue import threading from abc import abstractmethod -from collections.abc import Generator +from collections.abc import Callable, Generator from contextlib import contextmanager, nullcontext from math import ceil from time import perf_counter @@ -88,7 +88,7 @@ class OutputRouter: def __init__(self) -> None: self.output_queue = queue.Queue() - self.result_handlers: dict[str, tuple[callable, asyncio.AbstractEventLoop]] = {} + self.result_handlers: dict[str, tuple[Callable, asyncio.AbstractEventLoop]] = {} self._lock = threading.Lock() def deliver(self, output: GenerationOutput) -> None: @@ -833,7 +833,7 @@ def request_id_iter(self, request_id: str) -> Generator[GenerationOutput]: if result.is_finished(): return - def register_result_handler(self, request_id: str, callback: callable) -> None: + def register_result_handler(self, request_id: str, callback: Callable) -> None: """Register a callback for result delivery (streaming or non-streaming). The callback is invoked on the event loop via ``call_soon_threadsafe`` From c48aec3ac9648736fa62522463b1cfa840678db7 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Mon, 30 Mar 2026 15:36:48 +0000 Subject: [PATCH 46/64] much cleaner --- src/transformers/cli/serve_refactored.py | 46 ++----- .../cli/serving/chat_completion.py | 55 +++----- src/transformers/cli/serving/model_manager.py | 98 +++++++++----- src/transformers/cli/serving/response.py | 58 ++++---- src/transformers/cli/serving/transcription.py | 70 +++++++--- src/transformers/cli/serving/utils.py | 128 ++++++++++-------- 6 files changed, 234 insertions(+), 221 deletions(-) diff --git a/src/transformers/cli/serve_refactored.py b/src/transformers/cli/serve_refactored.py index 5b2364c26c33..20bac9b5d8f2 100644 --- a/src/transformers/cli/serve_refactored.py +++ b/src/transformers/cli/serve_refactored.py @@ -42,39 +42,30 @@ class Serve: def __init__( self, - # TODO: maybe rename it to model ? force_model: Annotated[str | None, typer.Argument(help="Model to preload and use for all requests.")] = None, # Model options - device: Annotated[str, typer.Option(help="Device for inference; defaults to 'auto'.")] = "auto", - dtype: Annotated[str | None, typer.Option(help="Override model dtype. 'auto' derives from weights.")] = "auto", + continuous_batching: Annotated[ + bool, typer.Option(help="Enable continuous batching with paged attention for higher throughput.") + ] = False, attn_implementation: Annotated[ str | None, typer.Option(help="Attention implementation (e.g. 
flash_attention_2).") ] = None, + compile: Annotated[bool, typer.Option(help="Enable torch.compile for faster inference.")] = False, quantization: Annotated[ str | None, typer.Option(help="Quantization method: 'bnb-4bit' or 'bnb-8bit'.") ] = None, + device: Annotated[str, typer.Option(help="Device for inference (e.g. 'auto', 'cuda:0', 'cpu').")] = "auto", + dtype: Annotated[str | None, typer.Option(help="Override model dtype. 'auto' derives from weights.")] = "auto", trust_remote_code: Annotated[bool, typer.Option(help="Trust remote code when loading.")] = False, model_timeout: Annotated[ - int, typer.Option(help="Seconds before idle model is unloaded. Ignored when model is set.") + int, typer.Option(help="Seconds before idle model is unloaded. Ignored when force_model is set.") ] = 300, # Server options host: Annotated[str, typer.Option(help="Server listen address.")] = "localhost", port: Annotated[int, typer.Option(help="Server listen port.")] = 8000, enable_cors: Annotated[bool, typer.Option(help="Enable permissive CORS.")] = False, - log_level: Annotated[str, typer.Option(help="Logging level (e.g. 'info', 'warning').")] = "info", + log_level: Annotated[str, typer.Option(help="Logging level (e.g. 'info', 'warning').")] = "warning", default_seed: Annotated[int | None, typer.Option(help="Default torch seed.")] = None, - compile: Annotated[ - bool, - typer.Option( - help="Enable static cache + torch.compile for faster decode (~2.6x). First request triggers compilation (~30s)." - ), - ] = False, - continuous_batching: Annotated[ - bool, - typer.Option( - help="Enable continuous batching with paged attention for higher throughput on concurrent requests." - ), - ] = False, non_blocking: Annotated[ bool, typer.Option(hidden=True, help="Run server in a background thread. 
Used by tests.") ] = False, @@ -99,11 +90,7 @@ def __init__( transformers_logger = logging.get_logger("transformers") transformers_logger.setLevel(logging.log_levels[log_level.lower()]) - # Preloaded models should never be auto-unloaded - if force_model: - model_timeout = -1 - - model_manager = ModelManager( + self._model_manager = ModelManager( device=device, dtype=dtype, trust_remote_code=trust_remote_code, @@ -112,27 +99,22 @@ def __init__( model_timeout=model_timeout, force_model=force_model, ) - self._model_manager = model_manager - self._generation_state = GenerationState(continuous_batching=continuous_batching) + self._generation_state = GenerationState(continuous_batching=continuous_batching, compile=compile) self._chat_handler = ChatCompletionHandler( - model_manager=model_manager, + model_manager=self._model_manager, generation_state=self._generation_state, - force_model=force_model, - compile=compile, ) self._response_handler = ResponseHandler( - model_manager=model_manager, + model_manager=self._model_manager, generation_state=self._generation_state, - force_model=force_model, - compile=compile, ) - self._transcription_handler = TranscriptionHandler(model_manager, self._generation_state) + self._transcription_handler = TranscriptionHandler(self._model_manager, self._generation_state) app = build_server( - model_manager, + self._model_manager, self._chat_handler, response_handler=self._response_handler, transcription_handler=self._transcription_handler, diff --git a/src/transformers/cli/serving/chat_completion.py b/src/transformers/cli/serving/chat_completion.py index 730f1e04c283..c933b23980b9 100644 --- a/src/transformers/cli/serving/chat_completion.py +++ b/src/transformers/cli/serving/chat_completion.py @@ -22,7 +22,6 @@ from collections.abc import AsyncGenerator from typing import TYPE_CHECKING -from fastapi import HTTPException from fastapi.responses import JSONResponse, StreamingResponse from openai.types.chat import ChatCompletion, ChatCompletionMessage, ChatCompletionMessageToolCall from openai.types.chat.chat_completion import Choice @@ -85,6 +84,9 @@ class ChatCompletionHandler(BaseHandler): Supports both streaming (SSE) and non-streaming (JSON) responses. """ + _valid_params_class = TransformersCompletionCreateParamsStreaming + _unused_fields = UNUSED_CHAT_COMPLETION_FIELDS + # ----- entry point ----- async def handle_request(self, body: dict, request_id: str) -> StreamingResponse | JSONResponse: @@ -99,37 +101,22 @@ async def handle_request(self, body: dict, request_id: str) -> StreamingResponse """ self._validate_request(body) - messages = body["messages"] - - # HACK: tiny-agents sends requests ending with assistant message — skip - if messages and messages[-1]["role"] == "assistant": - return JSONResponse({}, status_code=200) - model_id, model, processor = self._resolve_model(body) modality = self.model_manager.get_model_modality(model, processor=processor) use_cb = self.generation_state.use_continuous_batching(model, modality) gen_manager = self.generation_state.get_manager(model_id, use_cb=use_cb) - processor_inputs = self.get_processor_inputs_from_messages(messages, modality) - - if use_cb: - # CB handles device placement internally — don't create tensors or move - # anything to CUDA here. Pass plain token ID lists only. 
- inputs = processor.apply_chat_template( - processor_inputs, - add_generation_prompt=True, - tools=body.get("tools"), - return_dict=True, - tokenize=True, - ) - else: - inputs = processor.apply_chat_template( - processor_inputs, - add_generation_prompt=True, - tools=body.get("tools"), - return_tensors="pt", - return_dict=True, - tokenize=True, - ).to(model.device) + processor_inputs = self.get_processor_inputs_from_messages(body["messages"], modality) + + inputs = processor.apply_chat_template( + processor_inputs, + add_generation_prompt=True, + tools=body.get("tools"), + return_tensors=None if use_cb else "pt", + return_dict=True, + tokenize=True, + ) + if not use_cb: + inputs = inputs.to(model.device) gen_config = self._build_generation_config(body, model.generation_config, use_cb=use_cb) # TODO: remove when CB supports per-request generation config @@ -182,6 +169,7 @@ def _streaming( """Stream tokens as SSE via DirectStreamer.""" queue, streamer = gen_manager.generate_streaming(model, processor, inputs, gen_config, request_id=request_id) input_ids = inputs["input_ids"] + # CB returns plain lists, regular path returns tensors input_len = len(input_ids) if isinstance(input_ids, list) else input_ids.shape[-1] parser = ToolCallParser(tool_format) if tool_format else None @@ -304,17 +292,6 @@ async def _non_streaming( # ----- helpers ----- - def _validate_request(self, body: dict) -> None: - """Validate a chat completion request. Raises HTTPException if invalid.""" - input_keys = set(body.keys()) - unexpected = input_keys - TransformersCompletionCreateParamsStreaming.__mutable_keys__ - if unexpected: - raise HTTPException(status_code=422, detail=f"Unexpected keys in the request: {unexpected}") - - unused = input_keys & UNUSED_CHAT_COMPLETION_FIELDS - if unused: - raise HTTPException(status_code=422, detail=f"Unsupported fields in the request: {unused}") - def _build_generation_config(self, body: dict, model_generation_config: "GenerationConfig", use_cb: bool = False): """Apply Chat Completions params (``max_tokens``, ``frequency_penalty``, ``logit_bias``, ``stop``) on top of the base generation config.""" diff --git a/src/transformers/cli/serving/model_manager.py b/src/transformers/cli/serving/model_manager.py index 397650f935f2..fc97dd992a4f 100644 --- a/src/transformers/cli/serving/model_manager.py +++ b/src/transformers/cli/serving/model_manager.py @@ -27,7 +27,7 @@ from tqdm import tqdm import transformers -from transformers import AutoTokenizer, BitsAndBytesConfig, PreTrainedTokenizerBase +from transformers import BitsAndBytesConfig, PreTrainedTokenizerBase from ...utils import logging from .utils import Modality, make_progress_tqdm_class, reset_torch_cache @@ -41,12 +41,13 @@ class TimedModel: - """Wraps a model + processor and auto-deletes them after a period of inactivity. + """Wraps a model + processor and auto-unloads them after a period of inactivity. Args: model: The loaded model. - timeout_seconds: Seconds of inactivity before auto-deletion. Use -1 to disable. + timeout_seconds: Seconds of inactivity before auto-unload. Use -1 to disable. processor: The associated processor or tokenizer. + on_unload: Optional callback invoked after the model is unloaded from memory. 
""" def __init__( @@ -54,11 +55,13 @@ def __init__( model: "PreTrainedModel", timeout_seconds: int, processor: "ProcessorMixin | PreTrainedTokenizerFast | None" = None, + on_unload: "Callable | None" = None, ): self.model = model self._name_or_path = str(model.name_or_path) self.processor = processor self.timeout_seconds = timeout_seconds + self._on_unload = on_unload self._timer = threading.Timer(self.timeout_seconds, self._timeout_reached) self._timer.start() @@ -78,16 +81,14 @@ def delete_model(self) -> None: gc.collect() reset_torch_cache() self._timer.cancel() + if self._on_unload is not None: + self._on_unload() def _timeout_reached(self) -> None: if self.timeout_seconds > 0: self.delete_model() logger.info(f"{self._name_or_path} was removed from memory after {self.timeout_seconds}s of inactivity") - def is_deleted(self) -> bool: - """Check if the model has been deleted (by timeout or manually).""" - return not hasattr(self, "model") or self.model is None - class ModelManager: """Loads, caches, and manages the lifecycle of models. @@ -115,13 +116,6 @@ def __init__( model_timeout: int = 300, force_model: str | None = None, ): - self.device = device - self.dtype = dtype - self.trust_remote_code = trust_remote_code - self.attn_implementation = attn_implementation - self.quantization = quantization - self.model_timeout = model_timeout - self.loaded_models: dict[str, TimedModel] = {} # Thread-safety for concurrent load_model_and_processor calls @@ -132,9 +126,50 @@ def __init__( self._loading_subscribers: dict[str, list[asyncio.Queue[str | None]]] = {} self._loading_tasks: dict[str, asyncio.Task] = {} + # Convert numeric device strings (e.g. "0") to int so device_map works correctly + self.device = int(device) if device.isdigit() else device + self.dtype = self._resolve_dtype(dtype) + self.trust_remote_code = trust_remote_code + self.attn_implementation = attn_implementation + self.quantization = quantization + self.model_timeout = model_timeout + self.force_model = force_model + + self._validate_args() + + # Preloaded models should never be auto-unloaded + if force_model is not None: + self.model_timeout = -1 + + # Preload the forced model after all state is initialized if force_model is not None: self.load_model_and_processor(self.process_model_name(force_model)) + @staticmethod + def _resolve_dtype(dtype: str | None): + import torch + + if dtype in ("auto", None): + return dtype + resolved = getattr(torch, dtype, None) + if not isinstance(resolved, torch.dtype): + raise ValueError( + f"Unsupported dtype: '{dtype}'. Must be 'auto' or a valid torch dtype (e.g. 'float16', 'bfloat16')." + ) + return resolved + + def _validate_args(self): + if self.quantization is not None and self.quantization not in ("bnb-4bit", "bnb-8bit"): + raise ValueError( + f"Unsupported quantization method: '{self.quantization}'. Must be 'bnb-4bit' or 'bnb-8bit'." + ) + VALID_ATTN_IMPLEMENTATIONS = {"eager", "sdpa", "flash_attention_2", "flash_attention_3", "flex_attention"} + if self.attn_implementation is not None and self.attn_implementation not in VALID_ATTN_IMPLEMENTATIONS: + raise ValueError( + f"Unsupported attention implementation: '{self.attn_implementation}'. " + f"Must be one of {VALID_ATTN_IMPLEMENTATIONS}." + ) + @staticmethod def process_model_name(model_id: str) -> str: """Canonicalize to `'model_id@revision'` format. 
Defaults to `@main`.""" @@ -155,7 +190,7 @@ def get_quantization_config(self) -> BitsAndBytesConfig | None: return None def _load_processor(self, model_id_and_revision: str) -> "ProcessorMixin | PreTrainedTokenizerFast": - """Load a processor, trying AutoProcessor first then AutoTokenizer. + """Load a processor for the given model. Args: model_id_and_revision: Model ID in ``'model_id@revision'`` format. @@ -163,15 +198,7 @@ def _load_processor(self, model_id_and_revision: str) -> "ProcessorMixin | PreTr from transformers import AutoProcessor model_id, revision = model_id_and_revision.split("@", 1) - try: - return AutoProcessor.from_pretrained(model_id, revision=revision, trust_remote_code=self.trust_remote_code) - except OSError: - try: - return AutoTokenizer.from_pretrained( - model_id, revision=revision, trust_remote_code=self.trust_remote_code - ) - except OSError: - raise OSError(f"Failed to load processor for {model_id} with AutoProcessor and AutoTokenizer.") + return AutoProcessor.from_pretrained(model_id, revision=revision, trust_remote_code=self.trust_remote_code) def _load_model( self, model_id_and_revision: str, tqdm_class: type | None = None, progress_callback: Callable | None = None @@ -181,22 +208,19 @@ def _load_model( Args: model_id_and_revision (`str`): Model ID in ``'model_id@revision'`` format. tqdm_class (*optional*): tqdm subclass for progress bars during ``from_pretrained``. - progress_callback (`callable`, *optional*): Called with progress dicts during loading. + progress_callback (`Callable`, *optional*): Called with progress dicts during loading. Returns: `PreTrainedModel`: The loaded model. """ - import torch - from transformers import AutoConfig model_id, revision = model_id_and_revision.split("@", 1) - dtype = self.dtype if self.dtype in ["auto", None] else getattr(torch, self.dtype) model_kwargs = { "revision": revision, "attn_implementation": self.attn_implementation, - "dtype": dtype, + "dtype": self.dtype, "device_map": self.device, "trust_remote_code": self.trust_remote_code, "quantization_config": self.get_quantization_config(), @@ -207,6 +231,7 @@ def _load_model( progress_callback({"status": "loading", "model": model_id_and_revision, "stage": "config"}) config = AutoConfig.from_pretrained(model_id, **model_kwargs) architecture = getattr(transformers, config.architectures[0]) + return architecture.from_pretrained(model_id, **model_kwargs) def load_model_and_processor( @@ -228,10 +253,7 @@ def load_model_and_processor( lock = self._model_locks.setdefault(model_id_and_revision, threading.Lock()) with lock: - if ( - model_id_and_revision not in self.loaded_models - or self.loaded_models[model_id_and_revision].is_deleted() - ): + if model_id_and_revision not in self.loaded_models: if progress_callback is not None: progress_callback({"status": "loading", "model": model_id_and_revision, "stage": "processor"}) processor = self._load_processor(model_id_and_revision) @@ -239,7 +261,10 @@ def load_model_and_processor( model_id_and_revision, tqdm_class=tqdm_class, progress_callback=progress_callback ) self.loaded_models[model_id_and_revision] = TimedModel( - model, timeout_seconds=self.model_timeout, processor=processor + model, + timeout_seconds=self.model_timeout, + processor=processor, + on_unload=lambda key=model_id_and_revision: self.loaded_models.pop(key, None), ) if progress_callback is not None: progress_callback({"status": "ready", "model": model_id_and_revision, "cached": False}) @@ -269,7 +294,7 @@ async def load_model_streaming(self, 
model_id_and_revision: str): queue: asyncio.Queue[str | None] = asyncio.Queue() # Case 1: already cached - if mid in self.loaded_models and not self.loaded_models[mid].is_deleted(): + if mid in self.loaded_models: self.loaded_models[mid].reset_timer() yield f"data: {json.dumps({'status': 'ready', 'model': mid, 'cached': True})}\n\n" return @@ -339,9 +364,8 @@ def _send_sentinel(): def shutdown(self) -> None: """Delete all loaded models and free resources.""" - for timed in self.loaded_models.values(): + for timed in list(self.loaded_models.values()): timed.delete_model() - self.loaded_models.clear() @staticmethod def get_model_modality( diff --git a/src/transformers/cli/serving/response.py b/src/transformers/cli/serving/response.py index b24aabbc9cb0..5a290bc082c7 100644 --- a/src/transformers/cli/serving/response.py +++ b/src/transformers/cli/serving/response.py @@ -43,6 +43,7 @@ ResponseTextDeltaEvent, ResponseTextDoneEvent, ) +from openai.types.responses.response_create_params import ResponseCreateParamsStreaming from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails, ResponseUsage from ...utils import logging @@ -61,6 +62,11 @@ logger = logging.get_logger(__name__) + +class TransformersResponseCreateParamsStreaming(ResponseCreateParamsStreaming, total=False): + generation_config: str + + UNUSED_RESPONSE_FIELDS = { "background", "include", @@ -81,6 +87,9 @@ class ResponseHandler(BaseHandler): """Handler for the ``/v1/responses`` endpoint.""" + _valid_params_class = TransformersResponseCreateParamsStreaming + _unused_fields = UNUSED_RESPONSE_FIELDS + # ----- entry point ----- async def handle_request(self, body: dict, request_id: str) -> StreamingResponse | JSONResponse: @@ -106,25 +115,16 @@ async def handle_request(self, body: dict, request_id: str) -> StreamingResponse messages = self._input_to_messages(body) processor_inputs = self.get_processor_inputs_from_messages(messages, modality) - if use_cb: - # CB handles device placement internally — don't create tensors or move - # anything to CUDA here. Pass plain token ID lists only. 
- inputs = processor.apply_chat_template( - processor_inputs, - add_generation_prompt=True, - tools=body.get("tools"), - return_dict=True, - tokenize=True, - ) - else: - inputs = processor.apply_chat_template( - processor_inputs, - add_generation_prompt=True, - tools=body.get("tools"), - return_tensors="pt", - return_dict=True, - tokenize=True, - ).to(model.device) + inputs = processor.apply_chat_template( + processor_inputs, + add_generation_prompt=True, + tools=body.get("tools"), + return_tensors=None if use_cb else "pt", + return_dict=True, + tokenize=True, + ) + if not use_cb: + inputs = inputs.to(model.device) gen_config = self._build_generation_config(body, model.generation_config, use_cb=use_cb) # TODO: remove when CB supports per-request generation config @@ -213,6 +213,7 @@ def _streaming( """Generate a streaming Responses API reply (SSE) using DirectStreamer.""" queue, streamer = gen_manager.generate_streaming(model, processor, inputs, gen_config, request_id=request_id) input_ids = inputs["input_ids"] + # CB returns plain lists, regular path returns tensors input_len = len(input_ids) if isinstance(input_ids, list) else input_ids.shape[-1] parser = ToolCallParser(tool_format) if tool_format else None @@ -458,14 +459,9 @@ async def _non_streaming( model, processor, inputs, gen_config, request_id=request_id ) - created_at = time.time() - resp_id = f"resp_{request_id}" - msg_id = f"msg_{request_id}" - output_tokens = len(generated_ids) - output_items = [ ResponseOutputMessage( - id=msg_id, + id=f"msg_{request_id}", type="message", status="completed", role="assistant", @@ -491,10 +487,10 @@ async def _non_streaming( ) ) - usage = compute_usage(input_len, output_tokens) + usage = compute_usage(input_len, len(generated_ids)) response = Response( - id=resp_id, - created_at=created_at, + id=f"resp_{request_id}", + created_at=time.time(), status="completed", model=model_id, output=output_items, @@ -509,12 +505,6 @@ async def _non_streaming( # ----- helpers ----- - def _validate_request(self, body: dict) -> None: - """Validate a Responses API request.""" - unused = set(body.keys()) & UNUSED_RESPONSE_FIELDS - if unused: - raise HTTPException(status_code=422, detail=f"Unsupported fields in the request: {unused}") - def _build_generation_config(self, body: dict, model_generation_config: "GenerationConfig", use_cb: bool = False): """Apply Responses API params (``max_output_tokens``) on top of the base generation config.""" generation_config = super()._build_generation_config(body, model_generation_config, use_cb=use_cb) diff --git a/src/transformers/cli/serving/transcription.py b/src/transformers/cli/serving/transcription.py index 889b3eaaba5f..0d4a8adaa027 100644 --- a/src/transformers/cli/serving/transcription.py +++ b/src/transformers/cli/serving/transcription.py @@ -18,8 +18,9 @@ import io from typing import TYPE_CHECKING -from fastapi import Request +from fastapi import HTTPException, Request from fastapi.responses import JSONResponse, StreamingResponse +from openai.types.audio.transcription_create_params import TranscriptionCreateParamsBase from ...utils import logging from .model_manager import ModelManager @@ -33,6 +34,21 @@ logger = logging.get_logger(__name__) +class TransformersTranscriptionCreateParams(TranscriptionCreateParamsBase, total=False): + stream: bool + + +UNUSED_TRANSCRIPTION_FIELDS = { + "chunking_strategy", + "include", + "language", + "prompt", + "response_format", + "temperature", + "timestamp_granularities", +} + + class TranscriptionHandler: """Handler for ``POST 
/v1/audio/transcriptions``. @@ -51,7 +67,16 @@ def __init__(self, model_manager: ModelManager, generation_state: GenerationStat generation_state (`GenerationState`): Shared generation state for thread safety. """ self.model_manager = model_manager - self._generation_state = generation_state + self.generation_state = generation_state + + def _validate_request(self, form_keys: set[str]) -> None: + """Validate transcription request fields.""" + unexpected = form_keys - TransformersTranscriptionCreateParams.__mutable_keys__ + if unexpected: + raise HTTPException(status_code=422, detail=f"Unexpected fields in the request: {unexpected}") + unused = form_keys & UNUSED_TRANSCRIPTION_FIELDS + if unused: + logger.warning_once(f"Ignoring unsupported fields in the request: {unused}") async def handle_request(self, request: Request) -> JSONResponse | StreamingResponse: """Parse multipart form, run transcription, return result. @@ -68,31 +93,35 @@ async def handle_request(self, request: Request) -> JSONResponse | StreamingResp if not is_librosa_available(): raise ImportError("Missing librosa dependency for audio transcription. Install with `pip install librosa`") - import librosa - async with request.form() as form: + self._validate_request(set(form.keys())) file_bytes = await form["file"].read() model = form["model"] stream = str(form.get("stream", "false")).lower() == "true" model_id_and_revision = self.model_manager.process_model_name(model) audio_model, audio_processor = self.model_manager.load_model_and_processor(model_id_and_revision) + gen_manager = self.generation_state.get_manager(model_id_and_revision) + audio_inputs = self._prepare_audio_inputs(file_bytes, audio_processor, audio_model) + + if stream: + return self._streaming(gen_manager, audio_model, audio_processor, audio_inputs) + return await self._non_streaming(gen_manager, audio_model, audio_processor, audio_inputs) - # Read audio with librosa at the model's expected sampling rate - model_sampling_rate = audio_processor.feature_extractor.sampling_rate - audio_array, _ = librosa.load(io.BytesIO(file_bytes), sr=model_sampling_rate, mono=True) - audio_inputs = audio_processor(audio_array, sampling_rate=model_sampling_rate, return_tensors="pt").to( + @staticmethod + def _prepare_audio_inputs( + file_bytes: bytes, audio_processor: "ProcessorMixin", audio_model: "PreTrainedModel" + ) -> dict: + """Load audio bytes and convert to model inputs.""" + import librosa + + sampling_rate = audio_processor.feature_extractor.sampling_rate + audio_array, _ = librosa.load(io.BytesIO(file_bytes), sr=sampling_rate, mono=True) + audio_inputs = audio_processor(audio_array, sampling_rate=sampling_rate, return_tensors="pt").to( audio_model.device ) audio_inputs["input_features"] = audio_inputs["input_features"].to(audio_model.dtype) - - # Transcription uses the per-model InferenceThread (no CB for audio). - gen_manager = self._generation_state.get_manager(model_id_and_revision, use_cb=False) - tokenizer = audio_processor.tokenizer if hasattr(audio_processor, "tokenizer") else audio_processor - - if stream: - return self._streaming(gen_manager, audio_model, tokenizer, audio_inputs) - return await self._non_streaming(gen_manager, audio_model, audio_processor, audio_inputs) + return audio_inputs async def _non_streaming( self, @@ -103,8 +132,7 @@ async def _non_streaming( ) -> JSONResponse: # Audio models have different inputs (input_features) and decode (batch_decode) # than text models, so we use async_submit() directly instead of - # generate_non_streaming(). 
TODO: add generate_audio_non_streaming() when - # more audio modalities are supported. + # generate_non_streaming() from openai.types.audio import Transcription generated_ids = await gen_manager.async_submit(audio_model.generate, **audio_inputs) @@ -115,14 +143,14 @@ def _streaming( self, gen_manager: GenerateManager, audio_model: "PreTrainedModel", - tokenizer: "ProcessorMixin", + audio_processor: "ProcessorMixin", audio_inputs: dict, ) -> StreamingResponse: # Same as _non_streaming — uses submit() directly because audio inputs - # differ from text. TODO: add generate_audio_streaming() when more audio - # modalities are supported. + # differ from text. import asyncio + tokenizer = audio_processor.tokenizer if hasattr(audio_processor, "tokenizer") else audio_processor loop = asyncio.get_running_loop() queue: asyncio.Queue = asyncio.Queue() streamer = DirectStreamer(tokenizer._tokenizer, loop, queue, skip_special_tokens=True) diff --git a/src/transformers/cli/serving/utils.py b/src/transformers/cli/serving/utils.py index 5599ae450c70..d46ca261cce3 100644 --- a/src/transformers/cli/serving/utils.py +++ b/src/transformers/cli/serving/utils.py @@ -41,6 +41,7 @@ from transformers import GenerationConfig, PreTrainedModel, PreTrainedTokenizerFast, ProcessorMixin from transformers.generation.continuous_batching.continuous_api import ContinuousBatchingManager from transformers.generation.continuous_batching.requests import GenerationOutput + from transformers.generation.continuous_batching.scheduler import Scheduler from .model_manager import ModelManager @@ -130,7 +131,7 @@ def __init__(self, tool_format: dict): # Sentinel: token was consumed by the parser but produced no output. CONSUMED = object() - def feed(self, text: str): + def feed(self, text: str) -> object | dict | None: """Feed a text chunk (streaming). Returns: @@ -210,18 +211,20 @@ def __init__(self, enqueue: Callable, model_id: str): self.bars: dict[int, tuple[int, int | None]] = {} self.last_emitted_current: int | None = None - def register(self, bar_id: int, total: int | None): + def register(self, bar_id: int, total: int | None) -> None: + """Register a new download bar with its total byte count.""" self.bars[bar_id] = (0, total) self._emit() - def update(self, bar_id: int, current: int, total: int | None): + def update(self, bar_id: int, current: int, total: int | None) -> None: + """Update a bar's current byte count and emit aggregate progress.""" self.bars[bar_id] = (current, total) self._emit() - def close(self, bar_id: int): + def close(self, bar_id: int) -> None: pass # keep the bar so totals remain correct - def _emit(self): + def _emit(self) -> None: agg_current = sum(c for c, _ in self.bars.values()) if agg_current == self.last_emitted_current: return @@ -238,7 +241,7 @@ def _emit(self): ) -def make_progress_tqdm_class(callback: Callable, model_id: str): +def make_progress_tqdm_class(callback: Callable, model_id: str) -> type: """Create a tqdm subclass that routes progress to a callback. Bars with ``unit="B"`` are download bars — aggregated via ``DownloadAggregator``. 
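# --- Illustrative sketch, not part of the patch ------------------------------
# Client-side view of the progress events that DownloadAggregator and the tqdm
# subclass from make_progress_tqdm_class feed into the model-loading SSE stream
# (see load_model_streaming in this series). Field names follow the dicts
# emitted in these files ({"status", "model", "stage", "progress"}); the
# rendering below is only an example.
import json

def render_progress_event(sse_line: str) -> None:
    # Events arrive as SSE lines such as:
    #   data: {"status": "loading", "model": "org/model@main", "stage": "download",
    #          "progress": {"current": 1048576, "total": 4194304}}
    if not sse_line.startswith("data: "):
        return
    event = json.loads(sse_line[len("data: "):].strip())
    if event.get("status") == "ready":
        print(f"{event['model']} ready (cached={event.get('cached', False)})")
        return
    progress = event.get("progress") or {}
    current = progress.get("current", 0)
    total = progress.get("total")
    pct = f" ({100 * current / total:.0f}%)" if total else ""
    print(f"[{event.get('stage', '?')}] {event['model']}: {current}/{total if total is not None else '?'}{pct}")
# ------------------------------------------------------------------------------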
@@ -422,12 +425,14 @@ def cancel(self) -> None: def set_torch_seed(seed: int) -> None: + """Set the PyTorch random seed for reproducible generation.""" import torch torch.manual_seed(seed) def reset_torch_cache() -> None: + """Empty the CUDA cache if a GPU is available.""" import torch if torch.cuda.is_available(): @@ -446,7 +451,7 @@ def __init__(self): self._thread = threading.Thread(target=self._run, daemon=True) self._thread.start() - def _run(self): + def _run(self) -> None: while True: fn, args, kwargs, future, loop = self._queue.get() try: @@ -490,8 +495,8 @@ def generate_streaming( processor: "ProcessorMixin | PreTrainedTokenizerFast", inputs: dict, gen_config: "GenerationConfig", - request_id: str | None = None, - ): + request_id: str, + ) -> tuple[asyncio.Queue, "DirectStreamer | CBStreamer"]: """Start streaming generation. Args: @@ -499,7 +504,7 @@ def generate_streaming( processor: The processor or tokenizer for decoding. inputs (`dict`): Tokenized inputs (tensors for sequential, lists for CB). gen_config (`GenerationConfig`): Generation parameters. - request_id (`str`, *optional*): Unique request identifier. + request_id (`str`): Unique request identifier. Returns: `tuple[asyncio.Queue, DirectStreamer | CBStreamer]`: A ``(queue, streamer)`` pair @@ -514,8 +519,8 @@ def generate_non_streaming( processor: "ProcessorMixin | PreTrainedTokenizerFast", inputs: dict, gen_config: "GenerationConfig", - request_id: str | None = None, - ): + request_id: str, + ) -> tuple[str, int, list[int]]: """Run generation to completion. Args: @@ -523,14 +528,14 @@ def generate_non_streaming( processor: The processor or tokenizer for decoding. inputs (`dict`): Tokenized inputs (tensors for sequential, lists for CB). gen_config (`GenerationConfig`): Generation parameters. - request_id (`str`, *optional*): Unique request identifier. + request_id (`str`): Unique request identifier. Returns: `tuple[str, int, list[int]]`: ``(text, input_len, generated_ids)``. """ @abstractmethod - def stop(self): + def stop(self) -> None: """Stop the generation manager and free resources.""" @@ -546,14 +551,15 @@ def generate_streaming( processor: "ProcessorMixin | PreTrainedTokenizerFast", inputs: dict, gen_config: "GenerationConfig", - request_id: str | None = None, - ): + request_id: str, + ) -> tuple[asyncio.Queue, DirectStreamer]: + """Start streaming generation via ``model.generate()`` on the inference thread.""" loop = asyncio.get_running_loop() queue: asyncio.Queue = asyncio.Queue() streamer = DirectStreamer(processor._tokenizer, loop, queue, skip_special_tokens=True) gen_kwargs = {**inputs, "streamer": streamer, "generation_config": gen_config, "tokenizer": processor} - def _run(): + def _run() -> None: try: model.generate(**gen_kwargs) except _GenerationCancelled: @@ -570,8 +576,9 @@ async def generate_non_streaming( processor: "ProcessorMixin | PreTrainedTokenizerFast", inputs: dict, gen_config: "GenerationConfig", - request_id: str | None = None, - ): + request_id: str, + ) -> tuple[str, int, "torch.Tensor"]: + """Run generation to completion via ``model.generate()`` on the inference thread.""" sequences = await self.async_submit( model.generate, **inputs, generation_config=gen_config, tokenizer=processor ) @@ -580,15 +587,15 @@ async def generate_non_streaming( text = processor.decode(generated_ids, skip_special_tokens=True) return text, input_len, generated_ids - def submit(self, fn, *args, **kwargs): + def submit(self, fn: Callable, *args, **kwargs) -> Future: """Submit a callable to the inference thread. 
Returns a blocking Future.""" return self._thread.submit(fn, *args, **kwargs) - def async_submit(self, fn, *args, **kwargs): + def async_submit(self, fn: Callable, *args, **kwargs) -> asyncio.Future: """Submit a callable to the inference thread. Returns an awaitable asyncio.Future.""" return self._thread.async_submit(fn, *args, **kwargs) - def stop(self): + def stop(self) -> None: pass # inference thread is a daemon @@ -610,7 +617,7 @@ class CBGenerateManager(BaseGenerateManager): def __init__(self): self._cb = None - def init_cb(self, model: "PreTrainedModel", gen_config: "GenerationConfig"): + def init_cb(self, model: "PreTrainedModel", gen_config: "GenerationConfig") -> None: """Initialize the CB manager on first call with the request's generation config. .. todo:: Remove when CB supports per-request generation config. @@ -621,11 +628,8 @@ def init_cb(self, model: "PreTrainedModel", gen_config: "GenerationConfig"): """ if self._cb is not None: return - from transformers import LogitsProcessorList self._cb = model.init_continuous_batching(generation_config=gen_config) - # TODO: logits processors should be fixed in CB and correctly applied - self._cb.logit_processor = LogitsProcessorList() self._cb.start() def generate_streaming( @@ -634,8 +638,9 @@ def generate_streaming( processor: "ProcessorMixin | PreTrainedTokenizerFast", inputs: dict, gen_config: "GenerationConfig", - request_id: str | None = None, - ): + request_id: str, + ) -> tuple[asyncio.Queue, CBStreamer]: + """Start streaming CB generation. Registers a per-request output handler.""" loop = asyncio.get_running_loop() text_queue: asyncio.Queue = asyncio.Queue() @@ -643,15 +648,14 @@ def generate_streaming( request_id = self._cb.add_request( input_ids, request_id=request_id, - max_new_tokens=gen_config.max_new_tokens, streaming=True, + max_new_tokens=gen_config.max_new_tokens, eos_token_id=gen_config.eos_token_id, ) streamer = CBStreamer(self._cb, request_id, processor._tokenizer, loop, text_queue) - # Register a direct callback: the dispatcher calls this on the event loop - # with each GenerationOutput. This decodes tokens and pushes text straight - # to the SSE text_queue — no intermediate async queue or coroutine needed. + # Register a direct callback: the dispatcher calls this on the event loop with each GenerationOutput. + # This decodes tokens and pushes text straight to the SSE text_queue def _on_output(output): try: streamer.put(output) @@ -669,18 +673,13 @@ async def generate_non_streaming( processor: "ProcessorMixin | PreTrainedTokenizerFast", inputs: dict, gen_config: "GenerationConfig", - request_id: str | None = None, - ): - """Non-streaming CB generation, fully async (no per-request thread). - - Registers a handler that resolves an asyncio.Future when the result arrives. - No per-request queue, no polling — just one ``await`` per request. - """ + request_id: str, + ) -> tuple[str, int, list[int]]: + """Run non-streaming CB generation. 
Registers a handler that resolves an asyncio.Future on completion.""" input_ids = inputs["input_ids"] input_len = len(input_ids) # Register future BEFORE add_request to avoid race with fast completion - request_id = request_id or f"cb_{id(inputs)}" loop = asyncio.get_running_loop() future = loop.create_future() @@ -705,11 +704,11 @@ def _on_result(result): return text, input_len, generated_ids @property - def scheduler(self): + def scheduler(self) -> "Scheduler": """The CB scheduler (for testing/monitoring).""" return self._cb.batch_processor.scheduler - def stop(self): + def stop(self) -> None: if self._cb is not None: self._cb.stop(block=True, timeout=2) @@ -728,8 +727,9 @@ class GenerationState: sequential ``model.generate()`` calls. """ - def __init__(self, continuous_batching: bool = False): + def __init__(self, continuous_batching: bool = False, compile: bool = False): self._continuous_batching = continuous_batching + self._compile = compile self._generate_managers: dict[str, GenerateManager] = {} self._cb_manager: CBGenerateManager | None = None self._cb_model_id: str | None = None @@ -754,7 +754,7 @@ def use_continuous_batching(self, model: "PreTrainedModel", modality: Modality) ) return can - def get_manager(self, model_id: str, use_cb: bool) -> BaseGenerateManager: + def get_manager(self, model_id: str, use_cb: bool = False) -> BaseGenerateManager: """Return a per-model generation manager, lazily created on first request. Args: @@ -777,7 +777,7 @@ def get_manager(self, model_id: str, use_cb: bool) -> BaseGenerateManager: self._generate_managers[model_id] = GenerateManager() return self._generate_managers[model_id] - def shutdown(self): + def shutdown(self) -> None: """Stop any active generation managers.""" if self._cb_manager is not None: self._cb_manager.stop() @@ -795,23 +795,31 @@ class BaseHandler: Handles model loading, caching, and lifecycle. generation_state (`GenerationState`): Shared state managing per-model generation managers. - force_model (`str`, *optional*): - If set, override the ``model`` field in every request with this model ID. - compile (`bool`, *optional*, defaults to `False`): - Enable ``torch.compile`` with static cache for faster decode. 
""" + _valid_params_class: type | None = None + _unused_fields: set[str] = set() + def __init__( self, model_manager: "ModelManager", generation_state: GenerationState, - force_model: str | None = None, - compile: bool = False, ): self.model_manager = model_manager self.generation_state = generation_state - self.force_model = force_model - self._compile = compile + + def _validate_request(self, body: dict) -> None: + """Validate request fields against the handler's params class and unused fields.""" + from fastapi import HTTPException + + input_keys = set(body.keys()) + if self._valid_params_class is not None: + unexpected = input_keys - self._valid_params_class.__mutable_keys__ + if unexpected: + raise HTTPException(status_code=422, detail=f"Unexpected fields in the request: {unexpected}") + unused = input_keys & self._unused_fields + if unused: + logger.warning_once(f"Ignoring unsupported fields in the request: {unused}") @staticmethod def chunk_to_sse(chunk: "str | pydantic.BaseModel") -> str: @@ -820,20 +828,22 @@ def chunk_to_sse(chunk: "str | pydantic.BaseModel") -> str: return chunk if chunk.startswith("data: ") else f"data: {chunk}\n\n" return f"data: {chunk.model_dump_json(exclude_none=True)}\n\n" - def _resolve_model(self, body: dict): + def _resolve_model(self, body: dict) -> tuple[str, "PreTrainedModel", "ProcessorMixin | PreTrainedTokenizerFast"]: """Apply force_model, load model + processor. Returns ``(model_id, model, processor)``. """ - if self.force_model is not None: - body["model"] = self.force_model + if self.model_manager.force_model is not None: + body["model"] = self.model_manager.force_model model_id = self.model_manager.process_model_name(body["model"]) model, processor = self.model_manager.load_model_and_processor(model_id) return model_id, model, processor - def _build_generation_config(self, body: dict, model_generation_config: "GenerationConfig", use_cb: bool = False): + def _build_generation_config( + self, body: dict, model_generation_config: "GenerationConfig", use_cb: bool = False + ) -> "GenerationConfig": """Build a GenerationConfig from shared params (temperature, top_p, seed, generation_config JSON). Subclasses should call ``super()._build_generation_config(...)`` then apply @@ -870,13 +880,15 @@ def _build_generation_config(self, body: dict, model_generation_config: "Generat set_torch_seed(body["seed"]) # --compile flag: use static cache + torch.compile for faster decode - if self._compile and generation_config.cache_implementation is None: + if self.generation_state._compile and generation_config.cache_implementation is None: generation_config.cache_implementation = "static" # CB manages its own paged KV cache if use_cb: generation_config.use_cache = False + # TODO: add prefix caching for the non-CB path (reuse KV cache across multi-turn conversations) + return generation_config @staticmethod From b13dacc12264cbbf668c6fee2680f7e5a7f599e7 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Mon, 30 Mar 2026 15:43:44 +0000 Subject: [PATCH 47/64] renamed --- src/transformers/cli/serve.py | 2280 +--------------------- src/transformers/cli/serve_refactored.py | 165 -- 2 files changed, 76 insertions(+), 2369 deletions(-) delete mode 100644 src/transformers/cli/serve_refactored.py diff --git a/src/transformers/cli/serve.py b/src/transformers/cli/serve.py index 874008bfb209..20bac9b5d8f2 100644 --- a/src/transformers/cli/serve.py +++ b/src/transformers/cli/serve.py @@ -11,2283 +11,155 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. +""" +CLI entry point for `transformers serve`. +""" + import asyncio -import base64 -import copy -import enum -import gc -import io -import json -import re -import tempfile import threading -import time -import uuid -from collections.abc import Callable, Generator, Iterable -from contextlib import asynccontextmanager -from functools import lru_cache -from io import BytesIO -from threading import Thread -from typing import TYPE_CHECKING, Annotated, Optional, TypedDict, Union +from typing import Annotated import typer -from huggingface_hub import scan_cache_dir -from tokenizers.decoders import DecodeStream -from tqdm import tqdm -from tqdm.auto import tqdm as base_tqdm -import transformers -from transformers import AutoTokenizer, BitsAndBytesConfig, GenerationConfig, PreTrainedTokenizerBase -from transformers.generation import ( - LogitsProcessorList, - TextIteratorStreamer, -) from transformers.utils import logging from transformers.utils.import_utils import ( is_fastapi_available, - is_librosa_available, is_openai_available, is_pydantic_available, is_uvicorn_available, - is_vision_available, ) +from .serving.utils import set_torch_seed -if TYPE_CHECKING: - from transformers import ( - PreTrainedModel, - PreTrainedTokenizerFast, - ProcessorMixin, - ) - - from ..generation.continuous_batching import ContinuousBatchingManager - - -if is_librosa_available(): - import librosa - -if is_vision_available(): - from PIL import Image serve_dependencies_available = ( is_pydantic_available() and is_fastapi_available() and is_uvicorn_available() and is_openai_available() ) -if serve_dependencies_available: - import uvicorn - from fastapi import FastAPI, HTTPException - from fastapi.middleware.cors import CORSMiddleware - from fastapi.responses import JSONResponse, StreamingResponse - from openai.types.audio.transcription import Transcription - from openai.types.audio.transcription_create_params import TranscriptionCreateParamsBase - from openai.types.chat import ChatCompletion, ChatCompletionMessage, ChatCompletionMessageParam - from openai.types.chat.chat_completion import Choice - from openai.types.chat.chat_completion_chunk import ( - ChatCompletionChunk, - ChoiceDelta, - ChoiceDeltaToolCall, - ChoiceDeltaToolCallFunction, - ) - from openai.types.chat.chat_completion_chunk import ( - Choice as ChoiceChunk, - ) - from openai.types.chat.completion_create_params import CompletionCreateParamsStreaming - from openai.types.responses import ( - Response, - ResponseCompletedEvent, - ResponseContentPartAddedEvent, - ResponseContentPartDoneEvent, - ResponseCreatedEvent, - ResponseError, - ResponseErrorEvent, - ResponseFailedEvent, - ResponseInProgressEvent, - ResponseOutputItemAddedEvent, - ResponseOutputItemDoneEvent, - ResponseOutputMessage, - ResponseOutputText, - ResponseTextDeltaEvent, - ResponseTextDoneEvent, - ) - from openai.types.responses.response_create_params import ResponseCreateParamsStreaming - from pydantic import BaseModel, TypeAdapter, ValidationError - - # Expand OpenAI's request input types with an optional `generation_config` field - class TransformersResponseCreateParamsStreaming(ResponseCreateParamsStreaming, total=False): - """ - OpenAI's ResponseCreateParamsStreaming with an additional field for the generation config (as a json string). 
- """ - - generation_config: str - - class TransformersCompletionCreateParamsStreaming(CompletionCreateParamsStreaming, total=False): - """ - OpenAI's CompletionCreateParamsStreaming with additional fields for the generation config (as a json string) and passing the request_id - """ - - generation_config: str - - class TransformersTranscriptionCreateParams(TranscriptionCreateParamsBase, total=False): - """ - OpenAI's TranscriptionCreateParamsBase with an additional field for the generation config (as a json string). - """ - - file: bytes # Overwritten -- pydantic isn't happy with `typing.IO[bytes]`, present in the original type - generation_config: str - stream: bool = False - - # Contrarily to OpenAI's output types, input types are `TypedDict`, which don't have built-in validation. - response_validator = TypeAdapter(TransformersResponseCreateParamsStreaming) - completion_validator = TypeAdapter(TransformersCompletionCreateParamsStreaming) - transcription_validator = TypeAdapter(TransformersTranscriptionCreateParams) - - # Define request fields that are not yet used in `transformers serve`. Receiving these fields will raise an - # HTTPException. - UNUSED_RESPONSE_FIELDS = { - "background", - "include", - "max_tool_calls", - "previous_response_id", - "prompt", - "reasoning", - "service_tier", - "store", - "text", - "tool_choice", - "top_logprobs", - "truncation", - "user", - } - - UNUSED_CHAT_COMPLETION_FIELDS = { - "audio", - "function_call", - "functions", - "logprobs", - "max_completion_tokens", - "metadata", - "modalities", - "n", - "parallel_tool_calls", - "prediction", - "presence_penalty", - "reasoning_effort", - "response_format", - "service_tier", - "stop", - "store", - "stream_options", - "tool_choice", - "top_logprobs", - "user", - "web_search_options", - } - UNUSED_TRANSCRIPTION_FIELDS = { - "chunking_strategy", - "include", - "language", - "prompt", - "response_format", - "timestamp_granularities", - } - logger = logging.get_logger(__name__) -# Possible tokens that indicate the start/end of a tool call -# TODO (joao, matt): streamline tool token detection logic -_TOOL_CALL_TOKENS = { - "qwen": { - "start": "", - "end": "", - }, -} -_MODELS_WITH_TOOL_SUPPORT = list(_TOOL_CALL_TOKENS.keys()) - -X_REQUEST_ID = "x-request-id" - - -def set_torch_seed(_seed): - import torch - - torch.manual_seed(_seed) - - -def reset_torch_cache(): - import torch - - if torch.cuda.is_available(): - torch.cuda.empty_cache() - - -def torch_ones_like(_input_tensor): - import torch - - return torch.ones_like(_input_tensor) - - -class Modality(enum.Enum): - LLM = "LLM" - VLM = "VLM" - STT = "STT" - TTS = "TTS" - - -def create_generation_config_from_req( - req: dict, - model_generation_config: GenerationConfig, - **kwargs, -) -> GenerationConfig: - """ - Creates a generation config from the parameters of the request. If a generation config is passed in the request, - it will be used as a baseline for parameterization. Otherwise, we will use the model's default generation config. - Other parameters in the request will be applied on top of the baseline. - - Args: - req (`dict`): - The request which may optionally contain generation parameters. - model_generation_config (`GenerationConfig`): - The model's default generation config. - kwargs (`dict`): - Additional parameters to set in the generation config. - - Returns: - The prepared `GenerationConfig` object. - """ - # If there is a generation config in the request, it is a json string serialization from a `GenerationConfig` - # object. 
For simplicity, flags set here take precedence over all other flags. - if req.get("generation_config") is not None: - generation_config = GenerationConfig(**json.loads(req["generation_config"])) - else: - generation_config = copy.deepcopy(model_generation_config) - - non_standard_kwargs = generation_config.update(**kwargs) - # Set extra kwargs that are not in the `GenerationConfig` class (e.g. continuous batching flags) - for k, v in non_standard_kwargs.items(): - if v is not None: - setattr(generation_config, k, v) - - # Response-specific parameters - if req.get("max_output_tokens") is not None: - generation_config.max_new_tokens = int(req["max_output_tokens"]) - - # Completion-specific parameters - if req.get("max_tokens") is not None: - generation_config.max_new_tokens = int(req["max_tokens"]) - if req.get("frequency_penalty") is not None: - generation_config.repetition_penalty = float(req["frequency_penalty"]) - if req.get("logit_bias") is not None: - generation_config.sequence_bias = req["logit_bias"] - if req.get("stop") is not None: - generation_config.stop_strings = req["stop"] - if req.get("temperature") is not None: - generation_config.temperature = float(req["temperature"]) - if float(req["temperature"]) == 0.0: - generation_config.do_sample = False - if req.get("top_p") is not None: - generation_config.top_p = float(req["top_p"]) - if req.get("seed") is not None: - set_torch_seed(req["seed"]) - - return generation_config - - -class DownloadAggregator: - """Aggregates byte-progress across multiple concurrent download tqdm bars into a single SSE stream. - - huggingface_hub opens one tqdm bar per file shard. This class tracks them all and emits - a single aggregate ``{"stage": "download", "progress": {"current": ..., "total": ...}}`` - event whenever any bar updates. - """ - - def __init__(self, enqueue: Callable[[dict], None], model_id_and_revision: str): - self.enqueue = enqueue - self.model = model_id_and_revision - self.bars: dict[int, tuple[int, int | None]] = {} # id -> (current, total) - self.last_emitted_current: int | None = None - - def register(self, bar_id: int, total: int | None): - self.bars[bar_id] = (0, total) - self._emit() - - def update(self, bar_id: int, current: int, total: int | None): - self.bars[bar_id] = (current, total) - self._emit() - - def close(self, bar_id: int): - pass # keep the bar in _bars so totals remain correct - - def _emit(self): - agg_current = sum(c for c, _ in self.bars.values()) - if agg_current == self.last_emitted_current: - return - self.last_emitted_current = agg_current - totals = [t for _, t in self.bars.values() if t is not None] - agg_total = sum(totals) if totals else None - self.enqueue( - { - "status": "loading", - "model": self.model, - "stage": "download", - "progress": {"current": agg_current, "total": agg_total}, - } - ) - - -class DownloadProxy: - """ - Leverages the DownloadAggregator in order to have a coherent tqdm wrapper. 
- """ - - def __init__(self, wrapped_bar, download_aggregator): - self.wrapped_bar = wrapped_bar - self.bar_id = id(wrapped_bar) - self.download_aggregator = download_aggregator - - self.n = 0 - self.total = wrapped_bar.total - - def __getattr__(self, name): - return getattr(self.wrapped_bar, name) - - def update(self, n=1): - if n is None: - n = 1 - - self.n += n - self.download_aggregator.update(self.bar_id, self.n, getattr(self.wrapped_bar, "total", self.total)) - - return self.wrapped_bar.update(n) - - def close(self): - self.download_aggregator.close(self.bar_id) - return self.wrapped_bar.close() - - def __enter__(self): - self.wrapped_bar.__enter__() - return self - - def __exit__(self, *a): - return self.wrapped_bar.__exit__(*a) - - def __iter__(self): - count = 0 - for item in self.wrapped_bar: - count += 1 - self.download_aggregator.update(self.bar_id, count, getattr(self.wrapped_bar, "total", self.total)) - - yield item - - -class WeightsProxy: - """ - Wraps the weight-loading tqdm bar to have finer control over how we emit them to the clinet. - """ - - def __init__(self, wrapped_bar, _callable, model_id_and_revision): - self.wrapped_bar = wrapped_bar - self.last_emitted = -1 - self.callable = _callable - self.model_id_and_revision = model_id_and_revision - - self.n = 0 - self.total = wrapped_bar.total - - def __getattr__(self, name): - return getattr(self.wrapped_bar, name) - - def _emit(self): - if self.n == self.last_emitted: - return - self.last_emitted = self.n - - self.callable( - { - "status": "loading", - "model": self.model_id_and_revision, - "stage": "weights", - "progress": { - "current": self.n, - "total": getattr(self.wrapped_bar, "total", self.total), - }, - } - ) - - def update(self, n=1): - if n is None: - n = 1 - - self.n += n - self._emit() - - return self.wrapped_bar.update(n) - - def close(self): - return self.wrapped_bar.close() - - def __enter__(self): - self.wrapped_bar.__enter__() - return self - - def __exit__(self, *a): - return self.wrapped_bar.__exit__(*a) - - def __iter__(self): - for item in self.wrapped_bar: - self.n += 1 - self._emit() - - yield item - - -def set_tqdm_class(callback, mid): - download_aggregator = DownloadAggregator(callback, mid) - - class ProgressTqdm(base_tqdm): - """tqdm subclass that routes progress to the correct SSE stage. - - Bars with ``unit="B"`` are download bars (one per file shard) — they are - aggregated into a single ``download`` stage stream via ``_DownloadAggregator``. - All other bars are weight-loading bars emitted as ``weights`` stage events. 
- """ - - def __init__(self, *args, **kwargs): - self.sse_unit = kwargs.get("unit") or "it" - kwargs["disable"] = True # suppress server-side display - super().__init__(*args, **kwargs) - self.n = 0 - self.last_emitted = -1 - if self.sse_unit == "B": - self._bar_id = id(self) - download_aggregator.register(self._bar_id, self.total) - - def update(self, n=1): - if n is None: - n = 1 - - self.n += n - - if self.sse_unit == "B": - download_aggregator.update(self._bar_id, self.n, self.total) - elif self.n != self.last_emitted: - self.last_emitted = self.n - - callback( - { - "status": "loading", - "model": mid, - "stage": "weights", - "progress": {"current": self.n, "total": self.total}, - } - ) - - def close(self): - if self.sse_unit == "B": - download_aggregator.close(self._bar_id) - super().close() - - return ProgressTqdm - - -class ToolState: - """Lightweight class to keep track of the tool call state.""" - - def __init__(self): - self.reset() - - def reset(self): - """Reset the tool call state (assumes we're outside a tool call).""" - self.inside_tool_call = False - self.has_tool_name_defined = False - self.arg_nesting_level = 0 - self.buffer = "" - - -class TimedModel: - """ - A class that holds a PreTrainedModel instance and its associated processor. - Automatically deletes the instances after a specified timeout. - """ - - def __init__( - self, - model: "PreTrainedModel", - timeout_seconds: int, - processor: Union["ProcessorMixin", "PreTrainedTokenizerFast"] | None = None, - ): - self.model = model - self._name_or_path = str(model.name_or_path) - self.processor = processor - self.timeout_seconds = timeout_seconds - self._timer = threading.Timer(self.timeout_seconds, self.timeout_reached) - self._timer.start() - - def reset_timer(self): - """Reset the timer for the deletion of the instances.""" - self._timer.cancel() - self._timer = threading.Timer(self.timeout_seconds, self.timeout_reached) - self._timer.start() - - def delete_model(self): - """Delete the wrapped model and processor and clean up resources.""" - if hasattr(self, "model") and self.model is not None: - del self.model - del self.processor - self.model = None - self.processor = None - gc.collect() - - # Clear CUDA cache if available - reset_torch_cache() - - # XXX: in case we manually delete the model, like on server shutdown - self._timer.cancel() - - def timeout_reached(self): - if self.timeout_seconds > 0: - self.delete_model() - logger.warning( - f"{self._name_or_path} was removed from memory after {self.timeout_seconds} seconds of inactivity" - ) - - def is_deleted(self): - """Check if the instances have been deleted.""" - return not hasattr(self, "model") or self.model is None - class Serve: - # Defining a class to help with internal state but in practice it's just a method to call - # TODO: refactor into a proper module with helpers + 1 main method def __init__( self, + force_model: Annotated[str | None, typer.Argument(help="Model to preload and use for all requests.")] = None, + # Model options continuous_batching: Annotated[ - bool | None, - typer.Option(help="Whether to use continuous batching for chat completions."), - ] = None, - device: Annotated[ - str, - typer.Option( - help="Device to use for inference; will default to `auto` and place the model on an accelerator if available." - ), - ] = "auto", - dtype: Annotated[ - str | None, - typer.Option( - help="Override the default `torch.dtype` and load the model under this dtype. 
If `'auto'` is passed, the dtype will be automatically derived from the model's weights." - ), - ] = "auto", - trust_remote_code: Annotated[ - bool, typer.Option(help="Whether to trust remote code when loading a model.") + bool, typer.Option(help="Enable continuous batching with paged attention for higher throughput.") ] = False, attn_implementation: Annotated[ - str | None, - typer.Option(help="Which attention implementation to use."), + str | None, typer.Option(help="Attention implementation (e.g. flash_attention_2).") ] = None, + compile: Annotated[bool, typer.Option(help="Enable torch.compile for faster inference.")] = False, quantization: Annotated[ - str | None, - typer.Option(help="Which quantization method to use. choices: 'bnb-4bit', 'bnb-8bit'"), + str | None, typer.Option(help="Quantization method: 'bnb-4bit' or 'bnb-8bit'.") ] = None, - host: Annotated[str, typer.Option(help="Interface the server will listen to.")] = "localhost", - port: Annotated[int, typer.Option(help="Port the server will listen to.")] = 8000, + device: Annotated[str, typer.Option(help="Device for inference (e.g. 'auto', 'cuda:0', 'cpu').")] = "auto", + dtype: Annotated[str | None, typer.Option(help="Override model dtype. 'auto' derives from weights.")] = "auto", + trust_remote_code: Annotated[bool, typer.Option(help="Trust remote code when loading.")] = False, model_timeout: Annotated[ - int, typer.Option(help="Time in seconds after which a model will be removed from memory.") + int, typer.Option(help="Seconds before idle model is unloaded. Ignored when force_model is set.") ] = 300, - log_level: Annotated[ - str, typer.Option(help="Logging level as a string. Example: 'info' or 'warning'.") - ] = "warning", - default_seed: Annotated[ - int | None, typer.Option(help="The default seed for torch, should be an integer.") - ] = None, - enable_cors: Annotated[ - bool, - typer.Option( - help="Whether to enable CORS. Some apps that make requests from external domains (e.g. Cursor) require CORS to be enabled." - ), - ] = False, - input_validation: Annotated[bool, typer.Option(help="Whether to turn on strict input validation.")] = False, - force_model: Annotated[ - str | None, - typer.Option( - help="Name of the model to be forced on all requests. This is useful for testing Apps that don't allow changing models in the request." - ), - ] = None, + # Server options + host: Annotated[str, typer.Option(help="Server listen address.")] = "localhost", + port: Annotated[int, typer.Option(help="Server listen port.")] = 8000, + enable_cors: Annotated[bool, typer.Option(help="Enable permissive CORS.")] = False, + log_level: Annotated[str, typer.Option(help="Logging level (e.g. 'info', 'warning').")] = "warning", + default_seed: Annotated[int | None, typer.Option(help="Default torch seed.")] = None, non_blocking: Annotated[ - bool, typer.Option(hidden=True, help="Whether to run the server in a separate thread.") + bool, typer.Option(hidden=True, help="Run server in a background thread. Used by tests.") ] = False, ) -> None: if not serve_dependencies_available: - raise ImportError( - "Missing dependencies for the serving CLI. Please install with `pip install transformers[serving]`" - ) + raise ImportError("Missing dependencies for serving. 
Install with `pip install transformers[serving]`") - # Save input arguments - self.attn_implementation = attn_implementation - self.continuous_batching = continuous_batching - self.device = device - self.quantization = quantization + import uvicorn - self.dtype = dtype - self.trust_remote_code = trust_remote_code - self.host = host - self.port = port - self.model_timeout = model_timeout - self.log_level = log_level - self.default_seed = default_seed - self.enable_cors = enable_cors - self.input_validation = input_validation - self.force_model = force_model - self.non_blocking = non_blocking + from .serving.chat_completion import ChatCompletionHandler + from .serving.model_manager import ModelManager + from .serving.response import ResponseHandler + from .serving.server import build_server + from .serving.transcription import TranscriptionHandler + from .serving.utils import GenerationState # Seed if default_seed is not None: set_torch_seed(default_seed) - # Set up logging + # Logging transformers_logger = logging.get_logger("transformers") transformers_logger.setLevel(logging.log_levels[log_level.lower()]) - cb_logger = logging.get_logger("transformers.generation.continuous_batching") - cb_logger.setLevel(logging.log_levels[log_level.lower()]) - - # Internal state: - # 1. Tracks models in memory, to prevent reloading the model unnecessarily - self.loaded_models: dict[str, TimedModel] = {} - self.running_continuous_batching_manager: ContinuousBatchingManager | None = None - - # Tracks in-flight model loads for fan-out to multiple SSE subscribers - self.loading_subscribers: dict[str, list[asyncio.Queue[str | None]]] = {} - self.loading_tasks: dict[str, asyncio.Task] = {} - - # Thread-safety for load_model_and_processor / load_audio_model_and_processor - self.model_locks: dict[str, threading.Lock] = {} - self.model_locks_guard = threading.Lock() - - # 2. preserves information about the last call and last KV cache, to determine whether we can reuse the KV - # cache and avoid re-running prefill - self.last_messages = None - self.last_kv_cache = None - self.last_model = None - - if self.model_timeout is None: - self.model_timeout = -1 if self.force_model else 300 - - if self.force_model: - model_id_and_revision = self.process_model_name(self.force_model) - self.last_model = model_id_and_revision - self.load_model_and_processor(model_id_and_revision) - - @asynccontextmanager - async def lifespan(app: FastAPI): - yield - self.reset_loaded_models() - - app = FastAPI(lifespan=lifespan) - - # Some apps that make requests from external domains (e.g. Cursor) require CORS to be enabled. However, for - # security purposes, it's disabled by default - if self.enable_cors: - app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], - ) - logger.warning_once( - "CORS allow origin is set to `*`. This is not recommended for production environments." 
- ) - - from fastapi import Request - - @app.post("/v1/chat/completions") - def chat_completion(request: Request, body: dict): - self.validate_chat_completion_request(request=body) - - logger.warning(f"[Request received] Model: {body['model']}, CB: {self.continuous_batching}") - - if self.continuous_batching: - return self.continuous_batching_chat_completion(body, request.state.request_id) - else: - return self.generate_chat_completion(body) - - @app.post("/v1/responses") - def responses(request: dict): - self.validate_response_request(request=request) - - logger.warning(f"[Request received] Model: {request['model']}, CB: {self.continuous_batching}") - - # Support non-streaming mode when `stream=false` is provided - stream = request.get("stream", True) - if not stream: - response_obj = self.generate_response_non_streaming(request) - return JSONResponse(response_obj) - - output = self.generate_response(request) - return StreamingResponse(output, media_type="text/event-stream") - - @app.post("/v1/audio/transcriptions") - async def audio_transcriptions(request: Request): - # Parses the multipart/form-data request into the request format used by other endpoints - async with request.form() as form: - parsed_request = TransformersTranscriptionCreateParams( - file=await form["file"].read(), - model=form["model"], - # TODO: add other fields - ) - logger.debug( - f"Received file: {form['file'].filename}; MIME type: {form['file'].content_type}; " - f"size: {form['file'].size / 1024:.2f} KiB" - ) - self.validate_transcription_request(request=parsed_request) - - output = self.generate_transcription(parsed_request) - return StreamingResponse(output, media_type="text/event-stream") - - @app.options("/v1/models") - @app.get("/v1/models") - def get_all_models(): - return JSONResponse({"object": "list", "data": self.get_gen_models()}) - - @app.get("/health") - def healthcheck(): - return JSONResponse({"status": "ok"}) - - @app.post("/load_model") - async def load_model(body: dict): - model = body.get("model") - if model is None: - raise HTTPException(status_code=422, detail="Missing `model` field in the request body.") - - model_id_and_revision = self.process_model_name(model) - - async def event_publisher(): - queue: asyncio.Queue[str | None] = asyncio.Queue() - model_loaded = model_id_and_revision in self.loaded_models - model_not_deleted = model_loaded and not self.loaded_models[model_id_and_revision].is_deleted() - - # Case 1: Model already cached - if model_loaded and model_not_deleted: - self.loaded_models[model_id_and_revision].reset_timer() - yield f"data: {json.dumps({'status': 'ready', 'model': model_id_and_revision, 'cached': True})}\n\n" - return - - # Case 2: Load already happening, join existing subscribers - if model_id_and_revision in self.loading_tasks: - self.loading_subscribers[model_id_and_revision].append(queue) - - while True: - item = await queue.get() - if item is None: - break - yield item - return - - # Case 3: First request — start the load task - self.loading_subscribers[model_id_and_revision] = [queue] - loop = asyncio.get_running_loop() - - def enqueue(payload: dict): - msg = f"data: {json.dumps(payload)}\n\n" - - def broadcast(): - for q in self.loading_subscribers.get(model_id_and_revision, []): - q.put_nowait(msg) - - loop.call_soon_threadsafe(broadcast) - - download_aggregator = DownloadAggregator(enqueue, model_id_and_revision) - - def streaming_tqdm_hook(factory, args, kwargs): - bar = factory(*args, **kwargs) - desc = kwargs.get("desc") or getattr(bar, "desc", None) or 
"" - unit = kwargs.get("unit") or getattr(bar, "unit", "it") - total = getattr(bar, "total", kwargs.get("total")) - - # Only forward byte-progress bars (file downloads) — skip noise like "Fetching N files" - if unit == "B": - download_aggregator.register(id(bar), total) - return DownloadProxy(bar, download_aggregator=download_aggregator) - - # "Loading weights" bar — emit as stage: "weights" with item-count progress - if desc == "Loading weights": - return WeightsProxy(bar, enqueue, model_id_and_revision) - - # Other non-byte, non-weights bars (noise) — return unmodified - return bar - - async def run_load(): - previous_hook = logging.set_tqdm_hook(streaming_tqdm_hook) - try: - await asyncio.to_thread(self.load_model_and_processor, model_id_and_revision, enqueue) - except Exception as e: - logger.error(f"Failed to load {model_id_and_revision}: {e}", exc_info=True) - enqueue({"status": "error", "model": model_id_and_revision, "message": str(e)}) - finally: - logging.set_tqdm_hook(previous_hook) - - def _send_sentinel(): - for q in self.loading_subscribers.pop(model_id_and_revision, []): - q.put_nowait(None) - self.loading_tasks.pop(model_id_and_revision, None) - - loop.call_soon_threadsafe(_send_sentinel) + self._model_manager = ModelManager( + device=device, + dtype=dtype, + trust_remote_code=trust_remote_code, + attn_implementation=attn_implementation, + quantization=quantization, + model_timeout=model_timeout, + force_model=force_model, + ) + self._generation_state = GenerationState(continuous_batching=continuous_batching, compile=compile) - self.loading_tasks[model_id_and_revision] = asyncio.create_task(run_load()) + self._chat_handler = ChatCompletionHandler( + model_manager=self._model_manager, + generation_state=self._generation_state, + ) - while True: - item = await queue.get() - if item is None: - break - yield item + self._response_handler = ResponseHandler( + model_manager=self._model_manager, + generation_state=self._generation_state, + ) - return StreamingResponse(event_publisher(), media_type="text/event-stream") + self._transcription_handler = TranscriptionHandler(self._model_manager, self._generation_state) - @app.middleware("http") - async def get_or_set_request_id(request: Request, call_next): - request_id = request.headers.get(X_REQUEST_ID) or str(uuid.uuid4()) - request.state.request_id = request_id - response = await call_next(request) - response.headers[X_REQUEST_ID] = request_id - return response + app = build_server( + self._model_manager, + self._chat_handler, + response_handler=self._response_handler, + transcription_handler=self._transcription_handler, + enable_cors=enable_cors, + ) - config = uvicorn.Config(app, host=self.host, port=self.port, log_level="info") + config = uvicorn.Config(app, host=host, port=port, log_level=log_level) self.server = uvicorn.Server(config) - if self.non_blocking: + if non_blocking: self.start_server() else: self.server.run() def start_server(self): def _run(): - self._loop = asyncio.new_event_loop() - asyncio.set_event_loop(self._loop) - # serve() is a coroutine; it exits when server.should_exit becomes True - self._loop.run_until_complete(self.server.serve()) + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(self.server.serve()) self._thread = threading.Thread(target=_run, name="uvicorn-thread", daemon=False) self._thread.start() - def kill_server(self): - if not self._thread: - raise ValueError("The server cannot be killed as it was not launched in a separate thread.") - - if not 
self._thread.is_alive(): - raise ValueError("The server is already killed.") - - self.server.should_exit = True - if self._thread and self._thread.is_alive(): - self._thread.join(timeout=2) - def reset_loaded_models(self): - """ - Resets all loaded models. - """ - if self.running_continuous_batching_manager is not None: - logger.warning("Resetting the continuous batching manager.") - self.running_continuous_batching_manager.stop(block=True, timeout=2) - self.running_continuous_batching_manager = None - for model in list(self.loaded_models.values()): - model.delete_model() - self.last_model = None - - def _validate_request( - self, - request: dict, - schema: TypedDict, - validator: "TypeAdapter", - unused_fields: set, - ): - """ - Validates the request against the schema, and checks for unexpected keys. - - Args: - request (`dict`): - The request to validate. - schema (`TypedDict`): - The schema of the request to validate. It is a `TypedDict` definition. - validator (`TypeAdapter`): - The validator to use to validate the request. Built from `schema`. - unused_fields (`set`): - Fields accepted by `schema`, but not used in `transformers serve`. - - Raises: - HTTPException: If the request is invalid or contains unexpected or unused fields. - """ - logger.debug(f"Validating request: {request}") - - # Validate unexpected keys -- Pydantic doesn't validate extra keys in the request. - input_keys = set(request.keys()) - possible_keys = schema.__mutable_keys__ - unexpected_keys = input_keys - possible_keys - if unexpected_keys: - logger.error(f"Unexpected keys in the request: {unexpected_keys}") - raise HTTPException(status_code=422, detail=f"Unexpected keys in the request: {unexpected_keys}") - - if self.input_validation: - # Validate expected keys - try: - validator.validate_python(request) - except ValidationError as e: - logger.error(f"Validation error: {e.errors()}") - raise HTTPException(status_code=422, detail=e.errors()) - - # Validate unused fields - unused_fields_in_request = input_keys & unused_fields - if unused_fields_in_request: - logger.error(f"Unused fields in the request: {unused_fields_in_request}") - raise HTTPException( - status_code=422, detail=f"Unused fields in the request: {unused_fields_in_request}" - ) - - def validate_response_request(self, request: dict): - self._validate_request( - request=request, - schema=TransformersResponseCreateParamsStreaming, - validator=response_validator, - unused_fields=UNUSED_RESPONSE_FIELDS, - ) - - def validate_chat_completion_request(self, request: dict): - self._validate_request( - request=request, - schema=TransformersCompletionCreateParamsStreaming, - validator=completion_validator, - unused_fields=UNUSED_CHAT_COMPLETION_FIELDS, - ) - - def validate_transcription_request(self, request: dict): - self._validate_request( - request=request, - schema=TransformersTranscriptionCreateParams, - validator=transcription_validator, - unused_fields=UNUSED_TRANSCRIPTION_FIELDS, - ) - - def build_chat_completion_chunk( - self, - request_id: str = "", - content: int | None = None, - model: str | None = None, - role: str | None = None, - finish_reason: str | None = None, - tool_calls: list["ChoiceDeltaToolCall"] | None = None, - decode_stream: DecodeStream | None = None, - tokenizer: Optional["PreTrainedTokenizerFast"] = None, - ) -> "ChatCompletionChunk": - """ - Builds a chunk of a streaming OpenAI Chat Completion response. - - IMPORTANT: The serialized chunk won't contain empty fields (fields with `None`). 
Some downstream apps, - like Cursor, assume that when the field exists, it has data. - - Args: - request_id (`str`): - The request ID. - content (`str`, *optional*): - Content of the response from the model. - model (`str`, *optional*): - The model that generated the content. - role (`str`, *optional*): - The role of the next content, until a new role is defined. - finish_reason (`str`, *optional*): - The reason the generation by the model has finished. - tool_calls (`list[ChoiceDeltaToolCall]`, *optional*): - Data about the tool calls, when they are triggered. - - Returns: - `str`: The built chunk, a string containing a JSON string with the payload. - """ - if decode_stream is not None and content is not None and tokenizer is not None: - content = decode_stream.step(tokenizer._tokenizer, content) - - chunk = ChatCompletionChunk( - id=request_id, - created=int(time.time()), - model=model, - choices=[ - ChoiceChunk( - delta=ChoiceDelta( - content=content, - role=role, - tool_calls=tool_calls, - ), - index=0, - finish_reason=finish_reason, - ) - ], - system_fingerprint="", - object="chat.completion.chunk", - ) - - return chunk - - @staticmethod - def chunk_to_sse_element(chunk: "ChatCompletionChunk | BaseModel") -> str: - """ - Builds an event of a streaming OpenAI Response model or a ChatCompletion chunk. - - IMPORTANT: The serialized chunk won't contain empty fields (fields with `None`). Some downstream apps, - like Cursor, assume that when the field exists, it has data. - - Args: - chunk (`BaseModel` or `ChatCompletionChunk`): - The response to build an event from. One of the multiple OpenAI Response output types - - Returns: - `str`: The built chunk, a string containing a JSON string with the payload. - """ - if isinstance(chunk, str): - # Error paths may yield pre-formatted strings — pass them through as-is. - return chunk if chunk.startswith("data: ") else f"data: {chunk}\n\n" - return f"data: {chunk.model_dump_json(exclude_none=True)}\n\n" - - @staticmethod - @lru_cache - def get_gen_models(cache_dir: str | None = None) -> list[dict[str, any]]: - """ - List LLMs and VLMs in the cache. - """ - from transformers.models.auto.modeling_auto import ( - MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, - MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES, - ) - - generative_models = [] - - logger.warning("Scanning the cache directory for LLMs and VLMs.") - for repo in tqdm(scan_cache_dir(cache_dir).repos): - if repo.repo_type != "model": - continue - - refs = repo.refs - for ref, revision_info in refs.items(): - files = revision_info.files - config_path = next((f.file_path for f in files if f.file_name == "config.json"), None) - - if not config_path: - continue - - config = json.loads(config_path.open().read()) - - if not (isinstance(config, dict) and "architectures" in config): - continue - - architectures = config["architectures"] - llms = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values() - vlms = MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.values() - - if any(arch for arch in architectures if arch in [*llms, *vlms]): - author = repo.repo_id.split("/") if "/" in repo.repo_id else "" - repo_handle = repo.repo_id + (f"@{ref}" if ref != "main" else "") - generative_models.append( - { - "owned_by": author, - "id": repo_handle, - "object": "model", - "created": repo.last_modified, - } - ) - - return generative_models - - def continuous_batching_chat_completion(self, req: dict, request_id: str) -> "StreamingResponse | JSONResponse": - """ - Generates an OpenAI Chat Completion using continuous batching. 
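# A hedged client-side sketch (not part of the server code): because the endpoint speaks
# the OpenAI Chat Completions protocol, any OpenAI-compatible client can stream from it,
# whether or not continuous batching is enabled. The base_url, api_key and model id are
# assumptions matching the CLI defaults.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="unused")
stream = client.chat.completions.create(
    model="some-org/some-model",
    messages=[{"role": "user", "content": "Say hi in one word."}],
    stream=True,
)
for chunk in stream:
    delta = chunk.choices[0].delta
    if delta.content:
        print(delta.content, end="", flush=True)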
- - Args: - req (`dict`): The request to generate an OpenAI Chat Completion for. - - Returns: - `Generator[str, None, None]`: A generator that yields the OpenAI Chat Completion chunks. - """ + """Clear all loaded models from memory.""" + self._model_manager.shutdown() - model_id_and_revision = self.process_model_name(req["model"]) - must_discard_cache = model_id_and_revision != self.last_model - - self.last_model = model_id_and_revision - - # When switching models, terminate a continuous batching manager if it is running. - if must_discard_cache: - if self.running_continuous_batching_manager is not None: - self.running_continuous_batching_manager.stop(block=True, timeout=2) - self.running_continuous_batching_manager = None - - model, processor = self.load_model_and_processor(model_id_and_revision) - - # Continuous batching only supports text-only models - if self.get_model_modality(model, processor=processor) != Modality.LLM: - logger.warning_once( - "Continuous batching is not supported for non-text-only models. Falling back to regular generate." - ) - return self.generate_chat_completion(req) - - tokenizer = processor.tokenizer if hasattr(processor, "tokenizer") else processor - - generation_config = create_generation_config_from_req( - req, - model_generation_config=model.generation_config, - eos_token_id=tokenizer.eos_token_id, - pad_token_id=tokenizer.pad_token_id, - use_cache=False, - do_sample=False, - scheduler="fifo", - ) - - if self.running_continuous_batching_manager is None: - self.running_continuous_batching_manager = model.init_continuous_batching( - generation_config=generation_config - ) - - # TODO (Joao, Lysandre): the logits processors should be fixed in continuous batching and correctly applied in non-cb - self.running_continuous_batching_manager.logit_processor = LogitsProcessorList() - self.running_continuous_batching_manager.start() - - # TODO (Joao, Lysandre): this should also work with tool support - modality = self.get_model_modality(model, processor=processor) - processor_inputs = self.get_processor_inputs_from_inbound_messages(req["messages"], modality) - - inputs = processor.apply_chat_template( - processor_inputs, - add_generation_prompt=True, - tools=req.get("tools"), - return_tensors="pt", - return_dict=True, - tokenize=True, - ) - inputs = inputs["input_ids"][0].to(model.device) - - def stream_chat_completion(request_id, decode_stream): - from ..generation.continuous_batching import RequestStatus - - try: - # Emit the assistant role to start the stream. Other chunks won't have a role, as it is implicit - # they come from the assistant. - yield self.build_chat_completion_chunk(request_id, role="assistant", model=model_id_and_revision) - - n_tokens_generated = 0 - for result in self.running_continuous_batching_manager.request_id_iter(request_id): - n_tokens_generated += 1 - - # Always yield the token content (even for the final FINISHED token) - if result.generated_tokens: - token_id = result.generated_tokens[-1] - yield self.build_chat_completion_chunk( - request_id=request_id, - content=token_id, - model=model_id_and_revision, - decode_stream=decode_stream, - tokenizer=tokenizer, - ) - - if result.status == RequestStatus.FINISHED: - generated_all_tokens = ( - generation_config.max_new_tokens is not None - and n_tokens_generated >= generation_config.max_new_tokens - ) - - # If the tokenizer has an eos_token, we can have a more robust check. 
- if hasattr(tokenizer, "eos_token"): - final_token_is_eos = result == tokenizer.eos_token - generated_all_tokens = generated_all_tokens and not final_token_is_eos - - reason = "length" if generated_all_tokens else "stop" - - yield self.build_chat_completion_chunk( - request_id, - finish_reason=reason, - model=model_id_and_revision, - ) - break - - except Exception as e: - logger.error(str(e)) - self.running_continuous_batching_manager.cancel_request(request_id) - yield f'data: {{"error": "{str(e)}"}}' - - def buffer_chat_completion(_request_id): - result = None - while self.running_continuous_batching_manager.is_running() and result is None: - result = self.running_continuous_batching_manager.get_result(request_id=_request_id, timeout=1) - - content = tokenizer.decode(result.generated_tokens) - - chat_completion_result = ChatCompletion( - id=_request_id, - created=int(time.time()), - object="chat.completion", - model=model_id_and_revision, - choices=[ - Choice( - # TODO check the index - index=0, - message=ChatCompletionMessage(content=content, role="assistant"), - finish_reason="stop", - ) - ], - # TODO implement function calling - # TODO implement usage - ) - - return chat_completion_result - - async def cancellation_wrapper_stream(_request_id): - # Enables cancellation in an async context - try: - decode_stream = DecodeStream(inputs.tolist(), False) - for _chunk in stream_chat_completion(_request_id, decode_stream): - yield self.chunk_to_sse_element(_chunk) - await asyncio.sleep(0) - except asyncio.CancelledError: - self.running_continuous_batching_manager.cancel_request(_request_id) - logger.warning(f"Request {_request_id} was cancelled.") - - def cancellation_wrapper_buffer(_request_id): - # Enables cancellation in an async context - try: - return buffer_chat_completion(_request_id) - except asyncio.CancelledError: - self.running_continuous_batching_manager.cancel_request(_request_id) - logger.warning(f"Request {_request_id} was cancelled.") - - request_id = self.running_continuous_batching_manager.add_request( - inputs, request_id=request_id, max_new_tokens=generation_config.max_new_tokens, streaming=req.get("stream") - ) - - if req.get("stream"): - return StreamingResponse(cancellation_wrapper_stream(request_id), media_type="text/event-stream") - else: - chunk = cancellation_wrapper_buffer(request_id) - json_chunk = chunk.model_dump_json(exclude_none=True) - return JSONResponse(json_chunk, media_type="application/json") - - @staticmethod - def get_model_modality(model: "PreTrainedModel", processor=None) -> Modality: - if processor is not None: - if isinstance(processor, PreTrainedTokenizerBase): - return Modality.LLM - - from transformers.models.auto.modeling_auto import ( - MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, - MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES, - ) - - model_classname = model.__class__.__name__ - if model_classname in MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.values(): - modality = Modality.VLM - elif model_classname in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values(): - modality = Modality.LLM - else: - raise ValueError(f"Unknown modality: {model_classname}") - - return modality - - @staticmethod - def get_processor_inputs_from_inbound_messages(messages, modality: Modality): - processor_inputs = [] - - for message in messages: - parsed_message = {"role": message["role"], "content": []} - - if modality == Modality.LLM: - # Input: `content` is a string or a list of dictionaries with a "text" key. - # Output: `content` is a string. 
- if isinstance(message["content"], str): - parsed_content = message["content"] - elif isinstance(message["content"], list): - parsed_content = [] - for content in message["content"]: - if content["type"] == "text": - parsed_content.append(content["text"]) - parsed_content = " ".join(parsed_content) - parsed_message["content"] = parsed_content - - elif modality == Modality.VLM: - # Input: `content` is a string or a list of dictionaries with a "type" key (possible types: "text", - # "image_url"). - # Output: `content` is a list of dictionaries with a "type" key - if isinstance(message["content"], str): - parsed_message["content"].append({"type": "text", "text": message["content"]}) - else: - for content in message["content"]: - if content["type"] == "text": - parsed_message["content"].append(content) - elif content["type"] == "image_url": - if "base64" in content["image_url"]["url"]: - image_data = re.sub("^data:image/.+;base64,", "", content["image_url"]["url"]) - image = Image.open(BytesIO(base64.b64decode(image_data))) - - file = tempfile.NamedTemporaryFile(suffix=".png", delete=False) - url = file.name - - image.save(file.name) - else: - url = content["image_url"]["url"] - - parsed_message["content"].append({"type": "image", "url": url}) - processor_inputs.append(parsed_message) - return processor_inputs - - def generate_chat_completion(self, req: dict) -> "StreamingResponse | JSONResponse": - """ - Generates an OpenAI Chat Completion using `generate`. - - Args: - req (`dict`): The request to generate an OpenAI Chat Completion for. - - Returns: - `Generator[str, None, None]`: A generator that yields the OpenAI Chat Completion chunks. - """ - - # TODO: This should throw an error in case the specified model in the request is different to the forced model. - if self.force_model is not None: - req["model"] = self.force_model - - messages: Iterable[ChatCompletionMessageParam] = req["messages"] - - # HACK for tiny-agents: it sends a request after the assistant message (???). Let's assume we can't have a - # request whose last message is from the assistant. - if messages[-1]["role"] == "assistant": + def kill_server(self): + self._generation_state.shutdown() + self._model_manager.shutdown() + if not self._thread or not self._thread.is_alive(): return - - model_id_and_revision = self.process_model_name(req["model"]) - must_discard_cache = model_id_and_revision != self.last_model - - self.last_model = model_id_and_revision - model, processor = self.load_model_and_processor(model_id_and_revision) - - modality = self.get_model_modality(model, processor=processor) - processor_inputs = self.get_processor_inputs_from_inbound_messages(messages, modality) - - # ====== TOOL PREPROCESSING LOGIC ====== - tool_model_family = None - for supported_model_families in _MODELS_WITH_TOOL_SUPPORT: - if supported_model_families in model.config.architectures[0].lower(): - tool_model_family = supported_model_families - break - # TODO: trigger 2 constrained generations after the tool call start token is emitted: - # 1. force generation to pick from the tool names - # 2. 
force generation to pick from that tool's arguments - # ====== END OF TOOL PREPROCESSING LOGIC ====== - - inputs = processor.apply_chat_template( - processor_inputs, - add_generation_prompt=True, - tools=req.get("tools"), - return_tensors="pt", - return_dict=True, - tokenize=True, - ) - inputs = inputs.to(model.device) - request_id = req.get("request_id", "req_0") - - # Temporary hack for GPTOSS 1: don't filter special tokens - skip_special_tokens = True - if "gptoss" in model.config.architectures[0].lower(): - skip_special_tokens = False - - generation_streamer = TextIteratorStreamer( - processor, - skip_special_tokens=skip_special_tokens, - skip_prompt=True, - ) - - if self.is_continuation(req) and not must_discard_cache: - seq_len = self.last_kv_cache.get_seq_length() - if inputs["input_ids"].shape[-1] > seq_len: - last_kv_cache = self.last_kv_cache - else: - last_kv_cache = None - else: - seq_len = inputs["input_ids"].shape[-1] - last_kv_cache = None - - generation_config = create_generation_config_from_req( - req, - model_generation_config=model.generation_config, - ) - - generation_kwargs = { - **inputs, - "streamer": generation_streamer, - "generation_config": generation_config, - "return_dict_in_generate": True, - "past_key_values": last_kv_cache, - } - - def stream_chat_completion(streamer, _request_id): - # Temporary hack for GPTOS 2: filter out the CoT tokens. Full solution here implies defining new output - # classes and piping the reasoning trace into a new field - filter_cot = False - cot_trace_end = None - if "gptoss" in model.config.architectures[0].lower(): - filter_cot = True - cot_trace_end = "<|channel|>final<|message|>" - - # Thin wrapper to save the KV cache after generation - def generate_with_cache(**kwargs): - generate_output = model.generate(**kwargs) - self.last_kv_cache = generate_output.past_key_values - - thread = Thread(target=generate_with_cache, kwargs=generation_kwargs) - results = "" - - try: - thread.start() - tool_state = ToolState() - - # Emit the assistant role to start the stream. Other chunks won't have a role, as it is implicit - # they come from the assistant. 
- yield self.build_chat_completion_chunk(request_id, role="assistant", model=model_id_and_revision) - - result = "" - n_tokens_generated = 0 - - for result in streamer: - n_tokens_generated += 1 - - # Temporary hack for GPT-OSS 3: don't emit the final "<|return|>" - if "gptoss" in model.config.architectures[0].lower(): - result = result.removesuffix("<|return|>") - results += result - - # (related to temporary hack 2) - if filter_cot: - if cot_trace_end in results: # end of reasoning trace observed -> stop filtering - filter_cot = False - continue - else: - continue - - # ====== TOOL CALL LOGIC ====== - if tool_model_family is not None: - # Start of a tool call: reset state variables, set `inside_tool_call` - if result.strip() == _TOOL_CALL_TOKENS[tool_model_family]["start"]: - tool_state.inside_tool_call = True - continue - - # End of tool call: reset `inside_tool_call`, emit a `finish_reason` - if result.strip() == _TOOL_CALL_TOKENS[tool_model_family]["end"]: - tool_state.reset() - yield self.build_chat_completion_chunk( - request_id=_request_id, - role=None, - finish_reason="tool_calls", - model=model_id_and_revision, - ) - - continue - # Inside a tool call - if tool_state.inside_tool_call: - tool_state.buffer += result - - # First step: extract the tool name (may need several tokens, and we can't emit a delta - # until we have the full name) - if not tool_state.has_tool_name_defined: - tool_name = re.search(r"\"name\": \"(.*?)\"", tool_state.buffer) - if tool_name is None: - continue - else: - tool_name = tool_name.group(1) - tool_state.has_tool_name_defined = True - tool = ChoiceDeltaToolCall( - function=ChoiceDeltaToolCallFunction(name=tool_name), - index=0, - type="function", - id=_request_id + "_tool_call", # Only the first tool call delta has an id - ) - - # Second step: extract tool arguments. The tool arguments can be seen as a json string - # within the tool json string. We emit a delta for the arguments. - else: - # Empty text: skip - if result == "": - continue - # Until we see the `"arguments": {` in the buffer, we skip - # TODO: other models will likely need more elaborate processing here - if '"arguments": {' not in tool_state.buffer: - continue - - # Handle nesting. We want to exclude the last } from the emitted arguments (it's - # closing the outermost nesting level, outside the arguments block) - tool_state.arg_nesting_level += result.count("{") - tool_state.arg_nesting_level -= result.count("}") - if tool_state.arg_nesting_level < 0: - result = "".join(result.split("}")[:-2]) + "}" # e.g. "4}}\n" -> "4}" - - tool = ChoiceDeltaToolCall( - function=ChoiceDeltaToolCallFunction(arguments=result), - index=0, - type="function", - ) - - yield self.build_chat_completion_chunk( - request_id=_request_id, - role=None, - tool_calls=[tool], - model=model_id_and_revision, - ) - continue - # ====== END OF TOOL CALL LOGIC ====== - - # All non-tool related tokens are emitted as assistant messages. Empty text is skipped. - if result != "": - yield self.build_chat_completion_chunk( - _request_id, content=result, model=model_id_and_revision - ) - - generated_all_tokens = ( - generation_config.max_new_tokens is not None - and n_tokens_generated >= generation_config.max_new_tokens - ) - - # If the tokenizer has an eos_token, we can have a more robust check. 
- if hasattr(streamer.tokenizer, "eos_token"): - final_token_is_eos = result == streamer.tokenizer.eos_token - generated_all_tokens = generated_all_tokens and not final_token_is_eos - - reason = "length" if generated_all_tokens else "stop" - - yield self.build_chat_completion_chunk(_request_id, finish_reason=reason, model=model_id_and_revision) - - thread.join() - except Exception as e: - logger.error(str(e)) - yield f'data: {{"error": "{str(e)}"}}' - - finally: - thread.join() - - if req.get("stream"): - return StreamingResponse( - map(self.chunk_to_sse_element, stream_chat_completion(generation_streamer, request_id)), - media_type="text/event-stream", - ) - else: - content = [] - finish_reason = "stop" - - generator = stream_chat_completion(generation_streamer, request_id) - usage = None - - for chunk in generator: - choice = chunk.choices[0] - if getattr(choice.delta, "content", None): - content.append(choice.delta.content) - if choice.finish_reason: - finish_reason = choice.finish_reason - if getattr(chunk, "usage", None): - usage = chunk.usage - - chat_completion_result = ChatCompletion( - id=request_id, - created=int(time.time()), - object="chat.completion", - model=model_id_and_revision, - choices=[ - Choice( - # TODO check the index - index=0, - message=ChatCompletionMessage(content="".join(content), role="assistant"), - finish_reason=finish_reason, - ) - ], - # TODO implement function calling - usage=usage, - ) - - result = chat_completion_result.model_dump(exclude_none=True) - - return JSONResponse(result, media_type="application/json") - - def generate_response(self, req: dict) -> Generator[str, None, None]: - """ - Generates an OpenAI Response using `generate`. - - Args: - req (`dict`): The request to generate an OpenAI Response for. - - Returns: - `Generator[str, None, None]`: A generator that yields the OpenAI Response events. 
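# A rough sketch of the server-sent event sequence produced for a streamed /v1/responses
# call, inferred from the generator below; real payloads carry more fields than the names
# listed here.
EXPECTED_EVENT_ORDER = [
    "response.created",             # request acknowledged (status="queued")
    "response.in_progress",         # generation started
    "response.output_item.added",   # assistant message item opened
    "response.content_part.added",  # text content part opened
    "response.output_text.delta",   # repeated once per streamed piece of text
    "response.output_text.done",    # full text assembled
    "response.content_part.done",
    "response.output_item.done",
    "response.completed",           # or "error" followed by "response.failed" on exception
]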
- """ - # TODO -- Implement non-streaming mode - model_id_and_revision = self.process_model_name(req["model"]) - must_discard_cache = model_id_and_revision != self.last_model - self.last_model = model_id_and_revision - model, processor = self.load_model_and_processor(model_id_and_revision) - - if isinstance(req["input"], str): - inputs = [{"role": "system", "content": req["instructions"]}] if "instructions" in req else [] - inputs.append({"role": "user", "content": req["input"]}) - elif isinstance(req["input"], list): - if "instructions" in req: - if req["input"][0]["role"] != "system": - inputs = [{"role": "system", "content": req["instructions"]}, *req["input"]] - else: - inputs = req["input"] - inputs[0]["content"] = req["instructions"] - else: - inputs = req["input"] - elif isinstance(req["input"], dict): - inputs = [{"role": "system", "content": req["instructions"]}] if "instructions" in req else [] - inputs.append(req["input"]) - else: - raise TypeError("inputs should be a list, dict, or str") - - inputs = processor.apply_chat_template( - inputs, add_generation_prompt=True, return_tensors="pt", return_dict=True - )["input_ids"] - inputs = inputs.to(model.device) - request_id = req.get("previous_response_id", "req_0") - - # Temporary hack for GPT-OSS 1: don't filter special tokens - skip_special_tokens = True - if "gptoss" in model.config.architectures[0].lower(): - skip_special_tokens = False - - generation_streamer = TextIteratorStreamer( - processor, - skip_special_tokens=skip_special_tokens, - skip_prompt=True, - ) - generation_config = create_generation_config_from_req(req, model_generation_config=model.generation_config) - - last_kv_cache = None - if self.is_continuation(req) and not must_discard_cache: - seq_len = self.last_kv_cache.get_seq_length() - if inputs.shape[-1] > seq_len: - last_kv_cache = self.last_kv_cache - - generation_kwargs = { - "inputs": inputs, - "attention_mask": torch_ones_like(inputs), - "streamer": generation_streamer, - "generation_config": generation_config, - "return_dict_in_generate": True, - "past_key_values": last_kv_cache, - } - - def stream_response(streamer, _request_id): - # Temporary hack for GPT-OSS 2: filter out the CoT tokens. 
Full solution here implies defining new output - # classes and piping the reasoning trace into a new field - filter_cot = False - cot_trace_end = None - if "gptoss" in model.config.architectures[0].lower(): - filter_cot = True - cot_trace_end = "<|channel|>final<|message|>" - - # Thin wrapper to save the KV cache after generation - def generate_with_cache(**kwargs): - generate_output = model.generate(**kwargs) - self.last_kv_cache = generate_output.past_key_values - - thread = Thread(target=generate_with_cache, kwargs=generation_kwargs) - sequence_number = 0 - output_index = 0 - content_index = 0 - - try: - thread.start() - created_at = time.time() # the spec expects a unix timestamp in seconds - - # We start by acknowledging the request (the request has `status="queued"`), and then by moving it to - # in progress (`status="in_progress"`) - response_created = ResponseCreatedEvent( - type="response.created", - sequence_number=sequence_number, - response=Response( - id=f"resp_{request_id}", - created_at=created_at, - status="queued", - model=model_id_and_revision, - instructions=req.get("instructions"), - text={"format": {"type": "text"}}, - object="response", - tools=[], - output=[], - parallel_tool_calls=req.get("parallel_tool_calls", False), - tool_choice="auto", - metadata=req.get("metadata"), - ), - ) - sequence_number += 1 - yield self.chunk_to_sse_element(response_created) - - response_in_progress = ResponseInProgressEvent( - type="response.in_progress", - sequence_number=sequence_number, - response=Response( - id=f"resp_{request_id}", - created_at=created_at, - status="in_progress", - model=model_id_and_revision, - instructions=req.get("instructions"), - text={"format": {"type": "text"}}, - object="response", - tools=[], - output=[], - parallel_tool_calls=req.get("parallel_tool_calls", False), - tool_choice="auto", - metadata=req.get("metadata"), - ), - ) - sequence_number += 1 - yield self.chunk_to_sse_element(response_in_progress) - - # Start the output item. Emit the assistant role to start the stream. 
Other chunks won't have a role, - # as it is implicit - response_output_item_added = ResponseOutputItemAddedEvent( - type="response.output_item.added", - sequence_number=sequence_number, - output_index=output_index, - item=ResponseOutputMessage( - id=f"msg_{request_id}", type="message", status="in_progress", role="assistant", content=[] - ), - ) - sequence_number += 1 - yield self.chunk_to_sse_element(response_output_item_added) - - # Start the content part of the event - response_content_part_added = ResponseContentPartAddedEvent( - type="response.content_part.added", - item_id=f"msg_{request_id}", - sequence_number=sequence_number, - output_index=output_index, - content_index=content_index, - part=ResponseOutputText(type="output_text", text="", annotations=[]), - ) - sequence_number += 1 - yield self.chunk_to_sse_element(response_content_part_added) - - # Stream the actual generated text - results = "" - for result in streamer: - # Temporary hack for GPTOS 3: don't emit the final "<|return|>" - if "gptoss" in model.config.architectures[0].lower(): - result = result.removesuffix("<|return|>") - results += result - - # (related to temporary hack 2) - if filter_cot: - if cot_trace_end in results: # end of reasoning trace observed -> stop filtering - filter_cot = False - results = "" # reset the results -> results will now track the final response - continue - else: - response_output_text_delta = ResponseTextDeltaEvent( - type="response.output_text.delta", - item_id=f"msg_{request_id}", - sequence_number=sequence_number, - output_index=output_index, - content_index=content_index, - delta=result, - logprobs=[], - ) - sequence_number += 1 - yield self.chunk_to_sse_element(response_output_text_delta) - else: - # Normal path: emit token deltas when not filtering CoT - if result: - response_output_text_delta = ResponseTextDeltaEvent( - type="response.output_text.delta", - item_id=f"msg_{request_id}", - sequence_number=sequence_number, - output_index=output_index, - content_index=content_index, - delta=result, - logprobs=[], - ) - sequence_number += 1 - yield self.chunk_to_sse_element(response_output_text_delta) - - # Signal the end of the text generation - response_output_text_done = ResponseTextDoneEvent( - type="response.output_text.done", - item_id=f"msg_{request_id}", - sequence_number=sequence_number, - output_index=output_index, - content_index=0, - text=results, - logprobs=[], - ) - sequence_number += 1 - yield self.chunk_to_sse_element(response_output_text_done) - - # Complete the content part - response_content_part_done = ResponseContentPartDoneEvent( - type="response.content_part.done", - item_id=f"msg_{request_id}", - sequence_number=sequence_number, - output_index=output_index, - content_index=content_index, - part=ResponseOutputText(type="output_text", text=response_output_text_done.text, annotations=[]), - ) - sequence_number += 1 - content_index += 1 - yield self.chunk_to_sse_element(response_content_part_done) - - # Complete the output item - response_output_item_done = ResponseOutputItemDoneEvent( - type="response.output_item.done", - sequence_number=sequence_number, - output_index=output_index, - item=ResponseOutputMessage( - id=f"msg_{request_id}", - type="message", - status="completed", - role="assistant", - content=[response_content_part_done.part], - annotations=[], - ), - ) - sequence_number += 1 - output_index += 1 - yield self.chunk_to_sse_element(response_output_item_done) - - # Finally, Complete the event - response_completed = ResponseCompletedEvent( - 
type="response.completed", - sequence_number=sequence_number, - response=Response( - id=f"resp_{request_id}", - created_at=created_at, - status="completed", - model=model_id_and_revision, - instructions=req.get("instructions"), - text={"format": {"type": "text"}}, - output=[response_output_item_done.item], - object="response", - tools=[], - parallel_tool_calls=req.get("parallel_tool_calls", False), - tool_choice="auto", - metadata=req.get("metadata"), - ), - ) - sequence_number += 1 - yield self.chunk_to_sse_element(response_completed) - - thread.join() - except Exception as e: - logger.error(f"Exception in response generation: {str(e)}") - error_event = ResponseErrorEvent( - type="error", - sequence_number=sequence_number, - message=str(e), - ) - sequence_number += 1 - yield self.chunk_to_sse_element(error_event) - - response_failed = ResponseFailedEvent( - type="response.failed", - sequence_number=sequence_number, - response=Response( - id=f"resp_{request_id}", - created_at=created_at, - status="failed", - model=model_id_and_revision, - instructions=req.get("instructions"), - text={"format": {"type": "text"}}, - output=[], - object="response", - tools=[], - parallel_tool_calls=False, - tool_choice="auto", - metadata=req.get("metadata"), - error=ResponseError( - code="server_error", - message=str(e), - ), - ), - ) - sequence_number += 1 - yield self.chunk_to_sse_element(response_failed) - - finally: - thread.join() - - return stream_response(generation_streamer, request_id) - - def generate_response_non_streaming(self, req: dict) -> dict: - """ - Generates an OpenAI Response in non-streaming mode (single JSON payload). - - Args: - req (`dict`): The request to generate an OpenAI Response for. - - Returns: - `dict`: The OpenAI `Response` serialized as a dict. 
- """ - model_id_and_revision = self.process_model_name(req["model"]) - must_discard_cache = model_id_and_revision != self.last_model - self.last_model = model_id_and_revision - model, processor = self.load_model_and_processor(model_id_and_revision) - - if isinstance(req["input"], str): - inputs = [{"role": "system", "content": req["instructions"]}] if "instructions" in req else [] - inputs.append({"role": "user", "content": req["input"]}) - elif isinstance(req["input"], list): - if "instructions" in req: - if req["input"][0]["role"] != "system": - inputs = [{"role": "system", "content": req["instructions"]}, *req["input"]] - else: - inputs = req["input"] - inputs[0]["content"] = req["instructions"] - else: - inputs = req["input"] - elif isinstance(req["input"], dict): - inputs = [{"role": "system", "content": req["instructions"]}] if "instructions" in req else [] - inputs.append(req["input"]) - else: - raise ValueError("inputs should be a list, dict, or str") - - inputs = processor.apply_chat_template( - inputs, add_generation_prompt=True, return_tensors="pt", return_dict=True - )["input_ids"] - inputs = inputs.to(model.device) - request_id = req.get("previous_response_id", "req_0") - - # Temporary hack for GPTOSS 1: don't filter special tokens - skip_special_tokens = True - if "gptoss" in model.config.architectures[0].lower(): - skip_special_tokens = False - - generation_config = create_generation_config_from_req(req, model_generation_config=model.generation_config) - - last_kv_cache = None - if self.is_continuation(req) and not must_discard_cache: - seq_len = self.last_kv_cache.get_seq_length() - if inputs.shape[-1] > seq_len: - last_kv_cache = self.last_kv_cache - - generate_output = model.generate( - inputs=inputs, - attention_mask=torch_ones_like(inputs), - generation_config=generation_config, - return_dict_in_generate=True, - past_key_values=last_kv_cache, - ) - # save KV cache - self.last_kv_cache = generate_output.past_key_values - - # Decode full text - full_text = processor.batch_decode(generate_output.sequences, skip_special_tokens=skip_special_tokens)[0] - - created_at = time.time() - response_output_item = ResponseOutputMessage( - id=f"msg_{request_id}", - type="message", - status="completed", - role="assistant", - content=[ResponseOutputText(type="output_text", text=full_text, annotations=[])], - annotations=[], - ) - response_completed = Response( - id=f"resp_{request_id}", - created_at=created_at, - status="completed", - model=model_id_and_revision, - instructions=req.get("instructions"), - text={"format": {"type": "text"}}, - output=[response_output_item], - object="response", - tools=[], - parallel_tool_calls=req.get("parallel_tool_calls", False), - tool_choice="auto", - metadata=req.get("metadata"), - ) - return response_completed.model_dump(exclude_none=True) - - def generate_transcription(self, req: dict) -> Generator[str, None, None]: - """ - Generates an OpenAI Transcription using the audio file. - - Args: - req (`dict`): The request containing the audio file and model information. - - Returns: - `Generator[str, None, None]`: A generator that yields the transcription result. - """ - # TODO: implement streaming transcription (currently, it's not streaming) - if not is_librosa_available(): - raise ImportError( - "Missing librosa dependency for audio transcription. 
Please install with `pip install librosa`" - ) - model_id_and_revision = self.process_model_name(req["model"]) - audio_model, audio_processor = self.load_audio_model_and_processor(model_id_and_revision) - - generation_streamer = TextIteratorStreamer( - audio_processor.tokenizer, skip_special_tokens=True, skip_prompt=True - ) - generation_config = create_generation_config_from_req( - req, model_generation_config=audio_model.generation_config - ) - - # Read the binary audio file using librosa - model_sampling_rate = audio_processor.feature_extractor.sampling_rate - audio_bytes = io.BytesIO(req["file"]) - audio_array, _ = librosa.load(audio_bytes, sr=model_sampling_rate, mono=True) - audio_inputs = audio_processor(audio_array, sampling_rate=model_sampling_rate, return_tensors="pt").to( - audio_model.device - ) - audio_inputs["input_features"] = audio_inputs["input_features"].to(audio_model.dtype) - - generation_kwargs = { - "streamer": generation_streamer, - "generation_config": generation_config, - "return_dict_in_generate": True, - } - - def _generate_transcription(): - generated_ids = audio_model.generate(**audio_inputs, **generation_kwargs) - transcription_text = audio_processor.batch_decode(generated_ids.sequences, skip_special_tokens=True)[0] - transcription = Transcription(text=transcription_text) - yield f"{transcription.model_dump_json(exclude_none=True)}" - - return _generate_transcription() - - def is_continuation(self, req: dict) -> bool: - """ - Determines whether the current request is a continuation of the last request. In other words, if it is the - same chat session. - - Args: - req (`dict`): The request to check. - - Returns: - `True` if the request is a continuation of the last request, `False` otherwise. - """ - messages = req.get("messages") or req.get("input") # ChatCompletion and Response have different fields - req_continues_last_messages = True - - # No cached messages: this is a new request - if self.last_messages is None: - req_continues_last_messages = False - # The new request has no new rounds of conversation: this is a new request - elif len(self.last_messages) >= len(messages): - req_continues_last_messages = False - # Otherwise, check that the last messages are a subset of the new request - else: - for i in range(len(self.last_messages)): - if self.last_messages[i] != messages[i]: - req_continues_last_messages = False - break - - self.last_messages = messages - return req_continues_last_messages - - def get_quantization_config(self) -> BitsAndBytesConfig | None: - """ - Returns the quantization config for the given CLI arguments. - - Returns: - `Optional[BitsAndBytesConfig]`: The quantization config. - """ - if self.quantization == "bnb-4bit": - quantization_config = BitsAndBytesConfig( - load_in_4bit=True, - bnb_4bit_quant_type="nf4", - bnb_4bit_use_double_quant=True, - ) - elif self.quantization == "bnb-8bit": - quantization_config = BitsAndBytesConfig(load_in_8bit=True) - else: - quantization_config = None - - if quantization_config is not None: - logger.warning(f"Quantization applied with the following config: {quantization_config}") - - return quantization_config - - def process_model_name(self, model_id: str) -> str: - """ - Applies the `force_model` CLI argument and canonicalizes the model name to the format "model_id@revision". - If the model_id DOESN'T contain an @, it defaults to "model_id@main". - - Args: - model_id (`str`): The model ID. 
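# A standalone restatement of the canonicalization rule described above (force_model
# handling omitted); the model ids are illustrative.
def canonicalize_model_name(model_id: str) -> str:
    return model_id if "@" in model_id else f"{model_id}@main"

assert canonicalize_model_name("org/model") == "org/model@main"
assert canonicalize_model_name("org/model@refs/pr/1") == "org/model@refs/pr/1"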
- - Returns: - `str`: The canonicalized model name to be used - """ - if self.force_model is not None: - model_id = self.force_model - if "@" in model_id: - return model_id - return f"{model_id}@main" - - def _load_model_and_data_processor( - self, model_id_and_revision: str, progress_callback: Callable[[dict], None] | None = None - ): - """ - Generic method to load a model and a data processor from a model ID and revision, making use of the serve CLI - arguments. - - Args: - model_id_and_revision (`str`): - The model ID and revision to load. - model_cls (`type[PreTrainedModel]`): - The model class to load. - - Returns: - `tuple[PreTrainedModel, Union[ProcessorMixin, PreTrainedTokenizerFast]]`: The loaded model and - data processor (tokenizer, audio processor, etc.). - """ - import torch - - from transformers import AutoConfig, AutoProcessor - - callback = progress_callback - mid = model_id_and_revision - - tqdm_class = set_tqdm_class(callback, mid) if callback is not None else None - - def emit_progress(stage: str): - if progress_callback is None: - return - progress_callback({"status": "loading", "model": model_id_and_revision, "stage": stage}) - - logger.warning(f"Loading {model_id_and_revision}") - emit_progress("processor") - - if "@" in model_id_and_revision: - model_id, revision = model_id_and_revision.split("@", 1) - else: - model_id, revision = model_id_and_revision, "main" - - try: - data_processor = AutoProcessor.from_pretrained( - model_id, - revision=revision, - trust_remote_code=self.trust_remote_code, - ) - except OSError: - try: - data_processor = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - trust_remote_code=self.trust_remote_code, - ) - except OSError: - raise OSError("Failed to load processor with `AutoProcessor` and `AutoTokenizer`.") - - # processor done — move on to config - dtype = self.dtype if self.dtype in ["auto", None] else getattr(torch, self.dtype) - quantization_config = self.get_quantization_config() - - model_kwargs = { - "revision": revision, - "attn_implementation": self.attn_implementation, - "dtype": dtype, - "device_map": self.device, - "trust_remote_code": self.trust_remote_code, - "quantization_config": quantization_config, - } - - emit_progress("config") - config = AutoConfig.from_pretrained(model_id, **model_kwargs) - architecture = getattr(transformers, config.architectures[0]) - # weights stage events are emitted by _WeightsTqdm (and download by _DownloadAggregator) - model = architecture.from_pretrained(model_id, tqdm_class=tqdm_class, **model_kwargs) - - has_default_max_length = ( - model.generation_config.max_new_tokens is None and model.generation_config.max_length == 20 - ) - has_short_max_new_tokens = ( - model.generation_config.max_new_tokens is not None and model.generation_config.max_new_tokens < 1024 - ) - if has_default_max_length or has_short_max_new_tokens: - model.generation_config.max_new_tokens = 1024 - - return model, data_processor - - def load_model_and_processor( - self, model_id_and_revision: str, progress_callback: Callable[[dict], None] | None = None - ) -> tuple["PreTrainedModel", "PreTrainedTokenizerFast"]: - """ - Loads the text model and processor from the given model ID and revision into the ServeCommand instance. - - Args: - model_id_and_revision (`str`): - The model ID and revision to load. - - Returns: - `tuple[PreTrainedModel, PreTrainedTokenizerFast]`: The loaded text model and processor. 
- """ - with self.model_locks_guard: - lock = self.model_locks.setdefault(model_id_and_revision, threading.Lock()) - - with lock: - if ( - model_id_and_revision not in self.loaded_models - or self.loaded_models[model_id_and_revision].is_deleted() - ): - model, processor = self._load_model_and_data_processor( - model_id_and_revision, progress_callback=progress_callback - ) - self.loaded_models[model_id_and_revision] = TimedModel( - model, - timeout_seconds=self.model_timeout, - processor=processor, - ) - if progress_callback is not None: - progress_callback({"status": "ready", "model": model_id_and_revision, "cached": False}) - else: - self.loaded_models[model_id_and_revision].reset_timer() - model = self.loaded_models[model_id_and_revision].model - processor = self.loaded_models[model_id_and_revision].processor - if progress_callback is not None: - progress_callback({"status": "ready", "model": model_id_and_revision, "cached": True}) - - return model, processor - - def load_audio_model_and_processor(self, model_id_and_revision: str) -> tuple["PreTrainedModel", "ProcessorMixin"]: - """ - Loads the audio model and processor from the given model ID and revision into the ServeCommand instance. - - Args: - model_id_and_revision (`str`): - The model ID and revision to load. - - Returns: - `tuple[PreTrainedModel, ProcessorMixin]`: The loaded audio model and processor. - """ - with self.model_locks_guard: - lock = self.model_locks.setdefault(model_id_and_revision, threading.Lock()) - - with lock: - if ( - model_id_and_revision not in self.loaded_models - or self.loaded_models[model_id_and_revision].is_deleted() - ): - logger.warning(f"Loading model into cache: {model_id_and_revision}") - audio_model, audio_processor = self._load_model_and_data_processor(model_id_and_revision) - self.loaded_models[model_id_and_revision] = TimedModel( - audio_model, - timeout_seconds=self.model_timeout, - processor=audio_processor, - ) - else: - self.loaded_models[model_id_and_revision].reset_timer() - audio_model = self.loaded_models[model_id_and_revision].model - audio_processor = self.loaded_models[model_id_and_revision].processor - - return audio_model, audio_processor + self.server.should_exit = True + self._thread.join(timeout=2) -# set docstring separately to make it look nice (Typer doesn't play well with the class command) Serve.__doc__ = """ Run a FastAPI server to serve models on-demand with an OpenAI compatible API. - Models will be loaded and unloaded automatically based on usage and a timeout. \b -The server will expose the following endpoints: - - POST /v1/chat/completions: Generates chat completions. - - POST /v1/responses: Generates responses. - - POST /v1/audio/transcriptions: Generates transcriptions from audio. - - GET /v1/models: Lists available models for 3rd party tools. +Endpoints: + POST /v1/chat/completions — Chat completions (streaming + non-streaming). + GET /v1/models — Lists available models. + GET /health — Health check. -Requires FastAPI and Uvicorn to be installed. +Requires FastAPI and Uvicorn: pip install transformers[serving] """ - -if __name__ == "__main__": - serve = Serve() diff --git a/src/transformers/cli/serve_refactored.py b/src/transformers/cli/serve_refactored.py deleted file mode 100644 index 20bac9b5d8f2..000000000000 --- a/src/transformers/cli/serve_refactored.py +++ /dev/null @@ -1,165 +0,0 @@ -# Copyright 2025 The HuggingFace Team. All rights reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -CLI entry point for `transformers serve`. -""" - -import asyncio -import threading -from typing import Annotated - -import typer - -from transformers.utils import logging -from transformers.utils.import_utils import ( - is_fastapi_available, - is_openai_available, - is_pydantic_available, - is_uvicorn_available, -) - -from .serving.utils import set_torch_seed - - -serve_dependencies_available = ( - is_pydantic_available() and is_fastapi_available() and is_uvicorn_available() and is_openai_available() -) - -logger = logging.get_logger(__name__) - - -class Serve: - def __init__( - self, - force_model: Annotated[str | None, typer.Argument(help="Model to preload and use for all requests.")] = None, - # Model options - continuous_batching: Annotated[ - bool, typer.Option(help="Enable continuous batching with paged attention for higher throughput.") - ] = False, - attn_implementation: Annotated[ - str | None, typer.Option(help="Attention implementation (e.g. flash_attention_2).") - ] = None, - compile: Annotated[bool, typer.Option(help="Enable torch.compile for faster inference.")] = False, - quantization: Annotated[ - str | None, typer.Option(help="Quantization method: 'bnb-4bit' or 'bnb-8bit'.") - ] = None, - device: Annotated[str, typer.Option(help="Device for inference (e.g. 'auto', 'cuda:0', 'cpu').")] = "auto", - dtype: Annotated[str | None, typer.Option(help="Override model dtype. 'auto' derives from weights.")] = "auto", - trust_remote_code: Annotated[bool, typer.Option(help="Trust remote code when loading.")] = False, - model_timeout: Annotated[ - int, typer.Option(help="Seconds before idle model is unloaded. Ignored when force_model is set.") - ] = 300, - # Server options - host: Annotated[str, typer.Option(help="Server listen address.")] = "localhost", - port: Annotated[int, typer.Option(help="Server listen port.")] = 8000, - enable_cors: Annotated[bool, typer.Option(help="Enable permissive CORS.")] = False, - log_level: Annotated[str, typer.Option(help="Logging level (e.g. 'info', 'warning').")] = "warning", - default_seed: Annotated[int | None, typer.Option(help="Default torch seed.")] = None, - non_blocking: Annotated[ - bool, typer.Option(hidden=True, help="Run server in a background thread. Used by tests.") - ] = False, - ) -> None: - if not serve_dependencies_available: - raise ImportError("Missing dependencies for serving. 
Install with `pip install transformers[serving]`") - - import uvicorn - - from .serving.chat_completion import ChatCompletionHandler - from .serving.model_manager import ModelManager - from .serving.response import ResponseHandler - from .serving.server import build_server - from .serving.transcription import TranscriptionHandler - from .serving.utils import GenerationState - - # Seed - if default_seed is not None: - set_torch_seed(default_seed) - - # Logging - transformers_logger = logging.get_logger("transformers") - transformers_logger.setLevel(logging.log_levels[log_level.lower()]) - - self._model_manager = ModelManager( - device=device, - dtype=dtype, - trust_remote_code=trust_remote_code, - attn_implementation=attn_implementation, - quantization=quantization, - model_timeout=model_timeout, - force_model=force_model, - ) - self._generation_state = GenerationState(continuous_batching=continuous_batching, compile=compile) - - self._chat_handler = ChatCompletionHandler( - model_manager=self._model_manager, - generation_state=self._generation_state, - ) - - self._response_handler = ResponseHandler( - model_manager=self._model_manager, - generation_state=self._generation_state, - ) - - self._transcription_handler = TranscriptionHandler(self._model_manager, self._generation_state) - - app = build_server( - self._model_manager, - self._chat_handler, - response_handler=self._response_handler, - transcription_handler=self._transcription_handler, - enable_cors=enable_cors, - ) - - config = uvicorn.Config(app, host=host, port=port, log_level=log_level) - self.server = uvicorn.Server(config) - - if non_blocking: - self.start_server() - else: - self.server.run() - - def start_server(self): - def _run(): - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - loop.run_until_complete(self.server.serve()) - - self._thread = threading.Thread(target=_run, name="uvicorn-thread", daemon=False) - self._thread.start() - - def reset_loaded_models(self): - """Clear all loaded models from memory.""" - self._model_manager.shutdown() - - def kill_server(self): - self._generation_state.shutdown() - self._model_manager.shutdown() - if not self._thread or not self._thread.is_alive(): - return - self.server.should_exit = True - self._thread.join(timeout=2) - - -Serve.__doc__ = """ -Run a FastAPI server to serve models on-demand with an OpenAI compatible API. -Models will be loaded and unloaded automatically based on usage and a timeout. - -\b -Endpoints: - POST /v1/chat/completions — Chat completions (streaming + non-streaming). - GET /v1/models — Lists available models. - GET /health — Health check. - -Requires FastAPI and Uvicorn: pip install transformers[serving] -""" From 7855606b1300642200e543c503e81f9eae2e9a31 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Mon, 30 Mar 2026 15:50:19 +0000 Subject: [PATCH 48/64] remove bench for now --- tests/cli/bench_cb_raw.py | 217 ---------- tests/cli/benchmark_serve.py | 693 ------------------------------ tests/cli/benchmark_serve_load.py | 516 ---------------------- 3 files changed, 1426 deletions(-) delete mode 100644 tests/cli/bench_cb_raw.py delete mode 100644 tests/cli/benchmark_serve.py delete mode 100644 tests/cli/benchmark_serve_load.py diff --git a/tests/cli/bench_cb_raw.py b/tests/cli/bench_cb_raw.py deleted file mode 100644 index 15c61de16ec6..000000000000 --- a/tests/cli/bench_cb_raw.py +++ /dev/null @@ -1,217 +0,0 @@ -""" -Raw continuous batching benchmark — no HTTP, no serve layer. 
-2x2 matrix: {non_stream, stream} × {legacy get_result, optimized async}. - -Usage: - CUDA_VISIBLE_DEVICES=0 python tests/cli/bench_cb_raw.py - CUDA_VISIBLE_DEVICES=0 python tests/cli/bench_cb_raw.py --batch 10 50 100 500 1000 2000 -""" - -import argparse -import asyncio -import os -import sys -import time - - -os.environ["CUDA_VISIBLE_DEVICES"] = "0" - -import torch - -from transformers import AutoModelForCausalLM, AutoTokenizer, ContinuousBatchingConfig, GenerationConfig - - -def make_prompts(tokenizer, n, target_len=256): - filler = "The quick brown fox jumps over the lazy dog. Pack my box with five dozen liquor jugs. " * 100 - ids = tokenizer.encode(filler, add_special_tokens=False) - return [ids[: max(10, int(target_len * (0.8 + 0.4 * (i % 5) / 4)))] for i in range(n)] - - -# --------------------------------------------------------------------------- -# Non-streaming (CB streaming=False → one output per request when finished) -# --------------------------------------------------------------------------- - - -def bench_ns_get_result(mgr, prompts, max_new_tokens): - """Non-stream + get_result: batch add, poll shared queue.""" - N = len(prompts) - t0 = time.perf_counter() - mgr.add_requests(inputs=prompts, max_new_tokens=max_new_tokens, streaming=False) - total = finished = 0 - while finished < N: - r = mgr.get_result(timeout=1) - if r is not None and r.is_finished(): - total += len(r.generated_tokens) - finished += 1 - return total, time.perf_counter() - t0 - - -async def bench_ns_handler(mgr, prompts, max_new_tokens): - """Non-stream + handler: register_result_handler per request, resolve future on finish.""" - loop = asyncio.get_running_loop() - t0 = time.perf_counter() - futures = [] - for i, ids in enumerate(prompts): - rid = f"nsh_{time.perf_counter_ns()}_{i}" - future = loop.create_future() - - def _on_result(output, fut=future): - if not fut.done() and output.is_finished(): - fut.set_result(len(output.generated_tokens)) - - mgr.register_result_handler(rid, _on_result) - mgr.add_request(ids, request_id=rid, max_new_tokens=max_new_tokens, streaming=False) - futures.append(future) - results = await asyncio.gather(*futures) - return sum(results), time.perf_counter() - t0 - - -# --------------------------------------------------------------------------- -# Streaming (CB streaming=True → one output per token per request) -# --------------------------------------------------------------------------- - - -def bench_s_get_result(mgr, prompts, max_new_tokens): - """Stream + get_result: batch add, poll shared queue, skip intermediate outputs.""" - N = len(prompts) - t0 = time.perf_counter() - mgr.add_requests(inputs=prompts, max_new_tokens=max_new_tokens, streaming=True) - total = finished = 0 - while finished < N: - r = mgr.get_result(timeout=1) - if r is not None and r.is_finished(): - total += len(r.generated_tokens) - finished += 1 - return total, time.perf_counter() - t0 - - -async def bench_s_handler(mgr, prompts, max_new_tokens): - """Stream + handler: register_result_handler per request, await future on finish.""" - loop = asyncio.get_running_loop() - t0 = time.perf_counter() - futures = [] - for i, ids in enumerate(prompts): - rid = f"sh_{time.perf_counter_ns()}_{i}" - future = loop.create_future() - - def _on_output(output, fut=future): - if not fut.done() and output.is_finished(): - fut.set_result(len(output.generated_tokens)) - - mgr.register_result_handler(rid, _on_output) - mgr.add_request(ids, request_id=rid, max_new_tokens=max_new_tokens, streaming=True) - 
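Both handler-based benchmarks above bind `fut=future` as a default argument when defining the callback; this guards against the usual late-binding pitfall, where every closure would otherwise resolve only the last future created in the loop. A self-contained illustration of the binding pattern with plain asyncio (no continuous-batching manager involved):

import asyncio

async def main():
    loop = asyncio.get_running_loop()
    futures = []
    for i in range(3):
        future = loop.create_future()

        def on_done(result, fut=future):  # default arg freezes this iteration's future
            if not fut.done():
                fut.set_result(result)

        loop.call_later(0.01 * (i + 1), on_done, i)
        futures.append(future)
    print(await asyncio.gather(*futures))  # -> [0, 1, 2]

asyncio.run(main())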
futures.append(future) - - results = await asyncio.gather(*futures) - return sum(results), time.perf_counter() - t0 - - -# --------------------------------------------------------------------------- -# Runner -# --------------------------------------------------------------------------- - -METHODS = { - "ns_get_result": ("Non-stream + get_result", lambda mgr, p, m: bench_ns_get_result(mgr, p, m)), - "ns_handler": ("Non-stream + handler", lambda mgr, p, m: asyncio.run(bench_ns_handler(mgr, p, m))), - "s_get_result": ("Stream + get_result", lambda mgr, p, m: bench_s_get_result(mgr, p, m)), - "s_handler": ("Stream + handler", lambda mgr, p, m: asyncio.run(bench_s_handler(mgr, p, m))), -} - - -def main(): - parser = argparse.ArgumentParser(description="Raw CB benchmark (2x2 matrix)") - parser.add_argument("--model", type=str, default="Qwen/Qwen2.5-7B-Instruct") - parser.add_argument("--batch", type=int, nargs="+", default=[10, 50, 100, 500]) - parser.add_argument("--max-new-tokens", type=int, default=64) - parser.add_argument("--prompt-tokens", type=int, default=256) - parser.add_argument("--warmup", type=int, default=2) - parser.add_argument("--runs", type=int, default=3) - parser.add_argument("--methods", type=str, nargs="+", default=list(METHODS.keys()), choices=list(METHODS.keys())) - args = parser.parse_args() - - print(f"Model: {args.model}") - print(f"Batch: {args.batch} | Prompt: ~{args.prompt_tokens} tok | Gen: {args.max_new_tokens} tok") - print(f"Warmup: {args.warmup} | Runs: {args.runs} | Methods: {args.methods}") - sys.stdout.flush() - - model = ( - AutoModelForCausalLM.from_pretrained( - args.model, - dtype=torch.bfloat16, - attn_implementation="flash_attention_3", - ) - .cuda() - .eval() - ) - tokenizer = AutoTokenizer.from_pretrained(args.model) - all_prompts = make_prompts(tokenizer, max(args.batch), args.prompt_tokens) - - gen_config = GenerationConfig(max_new_tokens=args.max_new_tokens, do_sample=False) - cb_config = ContinuousBatchingConfig() - - # Header - col_w = 20 - header = f"{'N':>6}" - for m in args.methods: - label = METHODS[m][0] - header += f" | {label:>{col_w}}" - print(f"\n{header}") - print("-" * len(header)) - sys.stdout.flush() - - # Per-batch-size context: each N gets fresh CUDA graph capture - for N in args.batch: - prompts = all_prompts[:N] - - row = f"{N:>6}" - for method_key in args.methods: - _, fn = METHODS[method_key] - # Fresh CB context per method — each gets its own CUDA graph cache - with model.continuous_batching_context_manager( - generation_config=gen_config, - continuous_batching_config=cb_config, - block=True, - timeout=5, - ) as mgr: - # Warmup with the same method being tested - warmup_prompts = prompts[: min(200, N)] - for _ in range(args.warmup): - fn(mgr, warmup_prompts, args.max_new_tokens) - # Measured runs - best = 0 - for _ in range(args.runs): - tokens, dt = fn(mgr, prompts, args.max_new_tokens) - best = max(best, tokens / dt if dt > 0 else 0) - row += f" | {best:>{col_w - 4}.0f} t/s" - print(row, flush=True) - - # Quality check - print("\n--- Quality check ---") - with model.continuous_batching_context_manager( - generation_config=gen_config, - continuous_batching_config=cb_config, - block=True, - timeout=5, - ) as mgr: - - async def check(): - loop = asyncio.get_running_loop() - for i in range(3): - rid = f"qc_{i}" - future = loop.create_future() - - def _on_qc(output, fut=future): - if not fut.done(): - fut.set_result(output) - - mgr.register_result_handler(rid, _on_qc) - mgr.add_request(all_prompts[i], request_id=rid, 
max_new_tokens=args.max_new_tokens, streaming=False) - r = await future - text = tokenizer.decode(r.generated_tokens, skip_special_tokens=True)[:80] - print(f" {r.request_id}: {len(r.generated_tokens)} tokens | {text}") - - asyncio.run(check()) - - -if __name__ == "__main__": - main() diff --git a/tests/cli/benchmark_serve.py b/tests/cli/benchmark_serve.py deleted file mode 100644 index ecd371694d20..000000000000 --- a/tests/cli/benchmark_serve.py +++ /dev/null @@ -1,693 +0,0 @@ -# Copyright 2025 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Benchmark prefill and decode throughput for `transformers serve`. - -Tests: -- pp (prefill): sends a large prompt with max_tokens=1. Measures TTFT ≈ pure prefill time. - Default sizes: 256, 1024 tokens. -- tg (decode): sends a 512-token prompt (--tg-prefill) and generates many tokens. - Measures decode throughput after subtracting TTFT. Default sizes: 128, 512 tokens. - -Modes: -- bench: greedy decoding (do_sample=False, temp=0). Deterministic, best for reproducible numbers. -- chat: sampling (do_sample=True, temp=0.7). Simulates real chat usage. - -Recommended benchmarks: - - # HF model — greedy - python tests/cli/benchmark_serve.py --model Qwen/Qwen2.5-7B-Instruct --mode bench - - # HF model — sampling (simulates real chat) - python tests/cli/benchmark_serve.py --model Qwen/Qwen2.5-7B-Instruct --mode chat - - # GGUF model — greedy - python tests/cli/benchmark_serve.py \\ - --model "Qwen/Qwen2.5-7B-Instruct-GGUF/qwen2.5-7b-instruct-fp16-00001-of-00004.gguf --processor Qwen/Qwen2.5-7B-Instruct" \\ - --mode bench - - # GGUF model — sampling - python tests/cli/benchmark_serve.py \\ - --model "Qwen/Qwen2.5-7B-Instruct-GGUF/qwen2.5-7b-instruct-fp16-00001-of-00004.gguf --processor Qwen/Qwen2.5-7B-Instruct" \\ - --mode chat - - # Against an existing server - python tests/cli/benchmark_serve.py --url http://localhost:8000 --processor Qwen/Qwen2.5-7B-Instruct -""" - -import argparse -import json -import os -import statistics -import time - - -# Force single GPU — must be set before any CUDA initialization -os.environ["CUDA_VISIBLE_DEVICES"] = "0" - -import requests - -from transformers import AutoTokenizer - - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - -_FILLER = ( - "The quick brown fox jumps over the lazy dog. " - "Pack my box with five dozen liquor jugs. " - "How vexingly quick daft zebras jump. " - "Sphinx of black quartz, judge my vow. 
" -) * 200 - -_TG_PREFILL_DEFAULT = 512 - - -def make_prompt(tokenizer, num_tokens: int) -> str: - """Build a prompt string that tokenizes to exactly `num_tokens` tokens.""" - token_ids = tokenizer.encode(_FILLER, add_special_tokens=False) - if len(token_ids) < num_tokens: - repeats = (num_tokens // len(token_ids)) + 1 - token_ids = (token_ids * repeats)[:num_tokens] - else: - token_ids = token_ids[:num_tokens] - return tokenizer.decode(token_ids) - - -def wait_for_server(base_url: str, timeout: int = 120) -> bool: - """Poll GET /health until 200 or timeout.""" - deadline = time.time() + timeout - while time.time() < deadline: - try: - if requests.get(f"{base_url}/health", timeout=2).status_code == 200: - return True - except requests.ConnectionError: - pass - time.sleep(1) - return False - - -def streaming_chat_completion( - base_url: str, - messages: list, - max_tokens: int, - seed: int, - do_sample: bool = False, -) -> dict: - """Send a streaming chat completion request. Returns {total, ttft, completion_tokens, text}.""" - gen_cfg = {"max_new_tokens": max_tokens, "do_sample": do_sample, "eos_token_id": -1} - if do_sample: - gen_cfg["temperature"] = 0.7 - - payload = { - "messages": messages, - "stream": True, - "seed": seed, - "generation_config": json.dumps(gen_cfg), - } - - t_start = time.perf_counter() - t_first_token = None - completion_tokens = None - text_chunks = [] - - resp = requests.post(f"{base_url}/v1/chat/completions", json=payload, stream=True, timeout=300) - resp.raise_for_status() - - for line in resp.iter_lines(decode_unicode=True): - if not line or not line.startswith("data: "): - continue - data_str = line[len("data: ") :] - if data_str.strip() == "[DONE]": - break - try: - chunk = json.loads(data_str) - except json.JSONDecodeError: - continue - - choices = chunk.get("choices", []) - if not choices: - continue - - content = choices[0].get("delta", {}).get("content") - if content is not None and content != "": - text_chunks.append(content) - if t_first_token is None: - t_first_token = time.perf_counter() - - if chunk.get("usage"): - completion_tokens = chunk["usage"].get("completion_tokens") - - if choices[0].get("finish_reason") is not None: - break - - t_end = time.perf_counter() - - return { - "total": t_end - t_start, - "ttft": (t_first_token - t_start) if t_first_token else None, - "completion_tokens": completion_tokens, - "text": "".join(text_chunks), - } - - -def streaming_response( - base_url: str, - messages: list, - max_tokens: int, - seed: int, - do_sample: bool = False, -) -> dict: - """Send a streaming responses API request. 
Returns {total, ttft, completion_tokens, text}.""" - gen_cfg = {"max_new_tokens": max_tokens, "do_sample": do_sample, "eos_token_id": -1} - if do_sample: - gen_cfg["temperature"] = 0.7 - - # Convert messages to Responses API input format - input_messages = messages - payload = { - "input": input_messages, - "stream": True, - "seed": seed, - "generation_config": json.dumps(gen_cfg), - } - - t_start = time.perf_counter() - t_first_token = None - completion_tokens = None - text_chunks = [] - - resp = requests.post(f"{base_url}/v1/responses", json=payload, stream=True, timeout=300) - resp.raise_for_status() - - for line in resp.iter_lines(decode_unicode=True): - if not line or not line.startswith("data: "): - continue - try: - chunk = json.loads(line[len("data: ") :]) - except json.JSONDecodeError: - continue - - etype = chunk.get("type") - if etype == "response.output_text.delta": - delta = chunk.get("delta", "") - if delta: - text_chunks.append(delta) - if t_first_token is None: - t_first_token = time.perf_counter() - elif etype == "response.completed": - usage = chunk.get("response", {}).get("usage", {}) - completion_tokens = usage.get("output_tokens") - break - - t_end = time.perf_counter() - text = "".join(text_chunks) - - return { - "total": t_end - t_start, - "ttft": (t_first_token - t_start) if t_first_token else None, - "completion_tokens": completion_tokens, - "text": text, - } - - -def streaming_request( - base_url: str, - messages: list, - max_tokens: int, - seed: int, - do_sample: bool = False, - endpoint: str = "chat", -) -> dict: - """Dispatch to chat completions or responses API based on endpoint.""" - kw = {"base_url": base_url, "messages": messages, "max_tokens": max_tokens, "seed": seed, "do_sample": do_sample} - if endpoint == "responses": - return streaming_response(**kw) - return streaming_chat_completion(**kw) - - -# --------------------------------------------------------------------------- -# Scenarios -# --------------------------------------------------------------------------- - - -def bench_pp( - base_url: str, - tokenizer, - pp: int, - warmup: int, - iterations: int, - seed: int, - do_sample: bool = False, - endpoint: str = "chat", -) -> dict: - """Prefill benchmark: large prompt, max_tokens=1. 
Measures TTFT ≈ pure prefill time.""" - prompt = make_prompt(tokenizer, pp) - messages = [{"role": "user", "content": prompt}] - kw = {"do_sample": do_sample, "endpoint": endpoint} - - for _ in range(warmup): - streaming_request(base_url, messages, max_tokens=1, seed=seed, **kw) - - ttfts = [] - for _ in range(iterations): - r = streaming_request(base_url, messages, max_tokens=1, seed=seed, **kw) - if r["ttft"] is not None: - ttfts.append(r["ttft"]) - - ttft = statistics.median(ttfts) if ttfts else None - tok_s = pp / ttft if ttft and ttft > 0 else None - - return {"test": f"pp{pp}", "tokens": pp, "tok_s": tok_s, "time": ttft} - - -def bench_tg( - base_url: str, - tokenizer, - tg: int, - warmup: int, - iterations: int, - seed: int, - tg_prefill: int = 512, - do_sample: bool = False, - endpoint: str = "chat", -) -> dict: - """Decode benchmark: generate `tg` tokens after a `tg_prefill`-token prompt.""" - prompt = make_prompt(tokenizer, tg_prefill) - messages = [{"role": "user", "content": prompt}] - kw = {"do_sample": do_sample, "endpoint": endpoint} - - for _ in range(warmup): - streaming_request(base_url, messages, max_tokens=tg, seed=seed, **kw) - - decode_times = [] - token_counts = [] - last_text = "" - for _ in range(iterations): - r = streaming_request(base_url, messages, max_tokens=tg, seed=seed, **kw) - if r["ttft"] is not None: - decode_times.append(r["total"] - r["ttft"]) - token_counts.append(r["completion_tokens"] if r["completion_tokens"] is not None else tg) - last_text = r["text"] - - if decode_times: - dt = statistics.median(decode_times) - toks = statistics.median(token_counts) - tok_s = toks / dt if dt > 0 else None - else: - dt = None - toks = tg - tok_s = None - - return {"test": f"tg{tg}", "tokens": int(toks), "tok_s": tok_s, "time": dt, "text": last_text} - - -# --------------------------------------------------------------------------- -# Output -# --------------------------------------------------------------------------- - - -def format_duration(seconds) -> str: - if seconds is None: - return "N/A" - if seconds < 1.0: - return f"{seconds * 1000:.1f}ms" - return f"{seconds:.2f}s" - - -def format_throughput(value) -> str: - if value is None: - return "N/A" - return f"{value:.2f}" - - -_PREVIEW_WIDTH = 120 - - -def truncate_preview(text: str, width: int = _PREVIEW_WIDTH) -> str: - """Single-line preview of generated text.""" - if not text: - return "" - line = text.replace("\n", " ").strip() - if len(line) > width: - return line[: width - 1] + "\u2026" - return line - - -def print_table( - rows: list[dict], title: str = "", reference_texts: dict | None = None, is_reference: bool = False -) -> None: - """Print results in a bordered table. - - Args: - reference_texts: dict mapping test name (e.g. "tg128") to reference text. - When provided, decode rows show REF/MATCH/MISMATCH. - is_reference: if True, this is the reference table — show REF instead of MATCH. 
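The split used by bench_pp and bench_tg above boils down to two divisions: prefill throughput is the prompt length over TTFT (with max_tokens=1), and decode throughput is the generated token count over the time spent after the first token. A worked example with illustrative timings:

# pp1024: 1024-token prompt, first (and only) token arrives after 0.50 s
prefill_tok_s = 1024 / 0.50            # 2048 tok/s of pure prefill

# tg128: 128 generated tokens, 4.10 s end to end, of which 0.35 s is TTFT
decode_tok_s = 128 / (4.10 - 0.35)     # ~34 tok/s of steady-state decode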
- """ - if not rows: - return - - has_text = any(row.get("text") for row in rows) - has_ref = reference_texts is not None and has_text - - headers = ["test", "tokens", "tok/s", "time"] - align = ["<", ">", ">", ">"] - if has_ref: - headers.append("ref") - align.append("<") - if has_text: - headers.append("preview") - align.append("<") - - formatted_rows = [] - for row in rows: - cells = [ - row["test"], - str(row["tokens"]), - format_throughput(row["tok_s"]), - format_duration(row["time"]), - ] - text = row.get("text", "") - if has_ref: - ref_text = reference_texts.get(row["test"]) - if not text: - cells.append("") - elif ref_text is None: - cells.append("") - elif is_reference: - cells.append("REF") - elif text == ref_text: - cells.append("MATCH") - else: - cells.append("MISMATCH") - if has_text: - cells.append(truncate_preview(text)) - formatted_rows.append(cells) - - widths = [max(len(h), *(len(r[i]) for r in formatted_rows)) for i, h in enumerate(headers)] - - def pad(text, width, a): - return text.ljust(width) if a == "<" else text.rjust(width) - - def make_row(cells): - return "| " + " | ".join(pad(c, w, a) for c, w, a in zip(cells, widths, align)) + " |" - - def make_sep(char="-"): - return "+" + "+".join(char * (w + 2) for w in widths) + "+" - - print() - if title: - print(title) - print(make_sep("-")) - print(make_row(headers)) - print(make_sep("=")) - for r in formatted_rows: - print(make_row(r)) - print(make_sep("-")) - print() - - -# --------------------------------------------------------------------------- -# Server management -# --------------------------------------------------------------------------- - - -def start_server( - model: str, - port: int, - processor: str | None = None, - attn_implementation: str | None = None, - compile: bool = False, - continuous_batching: bool = False, -): - """Start a transformers serve instance. Returns the Serve object.""" - from transformers.cli.serve_refactored import Serve - - kwargs = {"force_model": model, "port": port, "non_blocking": True, "log_level": "warning"} - if processor: - kwargs["processor"] = processor - if attn_implementation: - kwargs["attn_implementation"] = attn_implementation - if compile: - kwargs["compile"] = True - if continuous_batching: - kwargs["continuous_batching"] = True - return Serve(**kwargs) - - -def parse_model_spec(spec: str) -> dict: - """Parse 'model_id' or 'model_id --processor tokenizer_id'. - - Returns {"model": str, "processor": str | None, "tokenizer": str} - """ - parts = spec.split() - model = parts[0] - processor = None - for i, p in enumerate(parts): - if p == "--processor" and i + 1 < len(parts): - processor = parts[i + 1] - tokenizer_id = processor if processor else model - return {"model": model, "processor": processor, "tokenizer": tokenizer_id} - - -# --------------------------------------------------------------------------- -# Main -# --------------------------------------------------------------------------- - - -def main(): - parser = argparse.ArgumentParser( - description="Benchmark transformers serve (prefill & decode separately)", - formatter_class=argparse.RawDescriptionHelpFormatter, - epilog="""examples: - python benchmark_serve.py --model Qwen/Qwen2.5-7B-Instruct - python benchmark_serve.py --model "org/model-GGUF/file.gguf --processor org/model" - python benchmark_serve.py --url http://localhost:8000 --processor Qwen/Qwen2.5-7B-Instruct -""", - ) - parser.add_argument( - "--model", - type=str, - action="append", - dest="models", - help="Model spec (repeatable). 
For GGUF: 'gguf_id --processor tokenizer_id'", - ) - parser.add_argument( - "--processor", - type=str, - default=None, - help="Processor/tokenizer ID for --url mode (default: derived from model)", - ) - parser.add_argument("--port", type=int, default=8642, help="Server port") - parser.add_argument("--url", type=str, default=None, help="Connect to existing server (skip start/stop)") - parser.add_argument("--warmup", type=int, default=1, help="Warmup iterations (minimum 1)") - parser.add_argument("--iterations", type=int, default=3, help="Measurement iterations") - parser.add_argument("--pp", type=int, nargs="+", default=[256, 1024], help="Prefill token counts") - parser.add_argument("--tg", type=int, nargs="+", default=[128, 512], help="Decode token counts") - parser.add_argument( - "--tg-prefill", type=int, default=_TG_PREFILL_DEFAULT, help="Prefill size for decode tests (default: 512)" - ) - parser.add_argument( - "--attn-impl", - type=str, - nargs="+", - default=["sdpa", "eager", "flash_attention_2"], - help="Attention implementations to benchmark (default: sdpa eager flash_attention_2)", - ) - parser.add_argument( - "--compile", action="store_true", help="Enable static cache + torch.compile on the server for faster decode" - ) - parser.add_argument( - "--continuous-batching", action="store_true", help="Enable continuous batching with paged attention" - ) - parser.add_argument( - "--mode", - type=str, - choices=["bench", "chat"], - default="bench", - help="bench: greedy (temp=0). chat: sampling (do_sample=True, temp=0.7)", - ) - parser.add_argument( - "--endpoint", - type=str, - choices=["chat", "responses"], - default="responses", - help="API endpoint to benchmark (default: responses = /v1/responses)", - ) - parser.add_argument("--seed", type=int, default=42, help="Torch seed") - args = parser.parse_args() - - args.warmup = max(args.warmup, 1) - do_sample = args.mode == "chat" - mode_str = "chat (do_sample=True, temp=0.7)" if do_sample else "bench (greedy, temp=0)" - endpoint = args.endpoint - endpoint_path = "/v1/responses" if endpoint == "responses" else "/v1/chat/completions" - - if args.url: - # Against an existing server - base_url = args.url.rstrip("/") - tokenizer_id = args.processor or (args.models[0] if args.models else "Qwen/Qwen2.5-7B-Instruct") - print(f"Using server at {base_url}, endpoint={endpoint_path}, mode={mode_str}") - print(f"Loading tokenizer from {tokenizer_id}...") - tokenizer = AutoTokenizer.from_pretrained(tokenizer_id) - - rows = [] - for pp in args.pp: - print(f" pp{pp}") - rows.append( - bench_pp( - base_url, - tokenizer, - pp, - args.warmup, - args.iterations, - args.seed, - do_sample=do_sample, - endpoint=endpoint, - ) - ) - for tg in args.tg: - print(f" tg{tg}") - rows.append( - bench_tg( - base_url, - tokenizer, - tg, - args.warmup, - args.iterations, - args.seed, - tg_prefill=args.tg_prefill, - do_sample=do_sample, - endpoint=endpoint, - ) - ) - print_table(rows) - - else: - # Start a fresh server per model, benchmark, stop - if not args.models: - args.models = ["Qwen/Qwen2.5-7B-Instruct"] - - for model_str in args.models: - spec = parse_model_spec(model_str) - # Reference texts from the first attn impl (bench mode only) - reference_texts = None - - for attn_impl in args.attn_impl: - print(f"\nStarting server for {spec['model']} (attn={attn_impl})...") - try: - server = start_server( - spec["model"], - args.port, - spec["processor"], - attn_implementation=attn_impl, - compile=args.compile, - continuous_batching=args.continuous_batching, - ) - except 
Exception as e: - print(f" ERROR: Failed to start server with attn={attn_impl}: {e}. Skipping.") - continue - - base_url = f"http://localhost:{args.port}" - if not wait_for_server(base_url): - print(" ERROR: Server did not become ready. Skipping.") - server.kill_server() - continue - - tokenizer = AutoTokenizer.from_pretrained(spec["tokenizer"]) - - # Warmup — use non-streaming when compile is on (first compile call takes ~30s, - # streaming warmup can hang waiting for SSE chunks during compilation) - if args.compile: - warmup_prompt = make_prompt(tokenizer, max(args.pp + [args.tg_prefill])) - gen_cfg = {"max_new_tokens": max(args.tg), "do_sample": False, "eos_token_id": -1} - payload = { - "messages": [{"role": "user", "content": warmup_prompt}], - "stream": False, - "seed": args.seed, - "generation_config": json.dumps(gen_cfg), - } - print(" compile warmup (non-streaming, may take ~30s)...") - requests.post(f"{base_url}/v1/chat/completions", json=payload, timeout=120) - else: - streaming_request( - base_url, [{"role": "user", "content": "hi"}], max_tokens=5, seed=args.seed, endpoint=endpoint - ) - - rows = [] - for pp in args.pp: - print(f" pp{pp}") - rows.append( - bench_pp( - base_url, - tokenizer, - pp, - args.warmup, - args.iterations, - args.seed, - do_sample=do_sample, - endpoint=endpoint, - ) - ) - for tg in args.tg: - print(f" tg{tg}") - rows.append( - bench_tg( - base_url, - tokenizer, - tg, - args.warmup, - args.iterations, - args.seed, - tg_prefill=args.tg_prefill, - do_sample=do_sample, - endpoint=endpoint, - ) - ) - - server.kill_server() - - # Build reference from first attn impl in greedy mode - if not do_sample and reference_texts is None: - reference_texts = {row["test"]: row["text"] for row in rows if row.get("text")} - # Pass reference_texts so the first impl shows "REF" in the ref column - print_table( - rows, - title=f"{spec['model']} | attn={attn_impl} ({mode_str})", - reference_texts=reference_texts if len(args.attn_impl) > 1 else None, - is_reference=True, - ) - else: - print_table( - rows, - title=f"{spec['model']} | attn={attn_impl} ({mode_str})", - reference_texts=reference_texts if not do_sample else None, - ) - - # Summary: check for mismatches across attn impls (bench mode only) - if not do_sample and reference_texts and len(args.attn_impl) > 1: - print_reference_summary(reference_texts, args.attn_impl[0]) - - -def print_reference_summary(reference_texts: dict[str, str], ref_impl: str) -> None: - """Print a summary noting that outputs are compared against the reference implementation.""" - print(f"Reference comparison: all decode outputs compared against '{ref_impl}'.") - print(" MATCH = identical text (greedy decoding is deterministic)") - print(" MISMATCH = text differs (FP divergence across attention kernels — check preview for correctness)") - print() - - -if __name__ == "__main__": - main() diff --git a/tests/cli/benchmark_serve_load.py b/tests/cli/benchmark_serve_load.py deleted file mode 100644 index bedb5bdc1e6c..000000000000 --- a/tests/cli/benchmark_serve_load.py +++ /dev/null @@ -1,516 +0,0 @@ -""" -Load test for `transformers serve` — measures throughput and latency under concurrent requests. - -Unlike benchmark_serve.py (single-user perf), this tests server capacity: -- How many tokens/sec can the server sustain under load? -- What's the latency distribution (p50/p90/p99) as concurrency increases? -- Does the server stay stable under pressure? - -Each --max-concurrency value sends that many requests simultaneously. 
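The latency percentiles the load test reports (p50/p90/p99) follow the sorted-index definition used by its compute_metrics helper further down. A tiny worked example of that definition, with made-up latencies:

import statistics

latencies = sorted([0.8, 0.9, 0.95, 1.0, 1.05, 1.1, 1.2, 1.4, 2.3, 3.0])  # seconds
p50 = statistics.median(latencies)                                    # 1.075
p90 = latencies[int(len(latencies) * 0.9)]                            # index 9 -> 3.0
p99 = latencies[min(int(len(latencies) * 0.99), len(latencies) - 1)]  # 3.0 for 10 samples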
- -Examples: - # Sweep concurrency levels - python tests/cli/benchmark_serve_load.py --model Qwen/Qwen2.5-7B-Instruct \\ - --max-concurrency 1 4 8 32 --continuous-batching - - # 500 concurrent non-streaming requests - python tests/cli/benchmark_serve_load.py --model Qwen/Qwen2.5-7B-Instruct \\ - --max-concurrency 500 --continuous-batching --no-stream - - # Against an existing server - python tests/cli/benchmark_serve_load.py --url http://localhost:8000 \\ - --processor Qwen/Qwen2.5-7B-Instruct --max-concurrency 1 4 8 -""" - -import argparse -import asyncio -import json -import os -import random -import statistics -import time - - -os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0") - -import aiohttp - -from transformers import AutoTokenizer - - -# --------------------------------------------------------------------------- -# Prompt generation -# --------------------------------------------------------------------------- - -_FILLER = ( - "The quick brown fox jumps over the lazy dog. " - "Pack my box with five dozen liquor jugs. " - "How vexingly quick daft zebras jump. " - "Sphinx of black quartz, judge my vow. " -) * 200 - - -def make_prompt(tokenizer, num_tokens: int) -> str: - token_ids = tokenizer.encode(_FILLER, add_special_tokens=False)[:num_tokens] - return tokenizer.decode(token_ids) - - -def make_prompts(tokenizer, num_requests: int, prompt_tokens: int, variance: float = 0.2) -> list[str]: - """Generate a list of prompts with some length variance to simulate realistic traffic.""" - prompts = [] - for _ in range(num_requests): - # Vary prompt length by ±variance around the target - length = max(10, int(prompt_tokens * (1.0 + random.uniform(-variance, variance)))) - prompts.append(make_prompt(tokenizer, length)) - return prompts - - -# --------------------------------------------------------------------------- -# Request sender -# --------------------------------------------------------------------------- - - -async def send_request( - session: aiohttp.ClientSession, - base_url: str, - prompt: str, - max_new_tokens: int, - seed: int, - endpoint: str = "responses", - model: str | None = None, - stream: bool = True, -) -> dict: - """Send a single request and collect timing metrics.""" - # eos_token_id=-1 forces exact max_new_tokens generation (no early stopping) - # for consistent benchmarking - gen_cfg = {"max_new_tokens": max_new_tokens, "do_sample": False, "eos_token_id": -1} - - if endpoint == "responses": - url = f"{base_url}/v1/responses" - payload = { - "model": model, - "input": [{"role": "user", "content": prompt}], - "stream": stream, - "seed": seed, - "generation_config": json.dumps(gen_cfg), - } - else: - url = f"{base_url}/v1/chat/completions" - payload = { - "model": model, - "messages": [{"role": "user", "content": prompt}], - "stream": stream, - "seed": seed, - "generation_config": json.dumps(gen_cfg), - } - - t_start = time.perf_counter() - t_first_token = None - token_times = [] - text_chunks = [] - non_streaming_tokens = 0 - error = None - - try: - async with session.post(url, json=payload, timeout=aiohttp.ClientTimeout(total=600)) as resp: - if resp.status != 200: - error = f"HTTP {resp.status}: {await resp.text()}" - return _make_result(t_start, error=error) - - if not stream: - # Non-streaming: single JSON response — get token count from usage - body = await resp.json() - t_first_token = time.perf_counter() - if endpoint == "responses": - output_tokens = body.get("usage", {}).get("output_tokens", 0) - for item in body.get("output", []): - if item.get("type") == 
"message": - for part in item.get("content", []): - if part.get("type") == "output_text": - text_chunks.append(part.get("text", "")) - else: - output_tokens = body.get("usage", {}).get("completion_tokens", 0) - for choice in body.get("choices", []): - content = choice.get("message", {}).get("content", "") - if content: - text_chunks.append(content) - # Use server-reported token count instead of len(text_chunks) - non_streaming_tokens = output_tokens - token_times.append(t_first_token) - else: - # Streaming: parse SSE events - async for line in resp.content: - line = line.decode("utf-8").strip() - if not line or not line.startswith("data: "): - continue - data_str = line[len("data: ") :] - if data_str.strip() == "[DONE]": - break - try: - chunk = json.loads(data_str) - except json.JSONDecodeError: - continue - - # Extract token content based on endpoint format - has_content = False - if endpoint == "responses": - if chunk.get("type") == "response.output_text.delta": - delta = chunk.get("delta", "") - if delta: - text_chunks.append(delta) - has_content = True - elif chunk.get("type") == "response.completed": - break - else: - choices = chunk.get("choices", []) - if choices: - content = choices[0].get("delta", {}).get("content") - if content is not None and content != "": - text_chunks.append(content) - has_content = True - if choices[0].get("finish_reason") is not None: - break - - if has_content: - now = time.perf_counter() - token_times.append(now) - if t_first_token is None: - t_first_token = now - - except asyncio.TimeoutError: - error = "timeout" - except Exception as e: - error = str(e) - - output_token_count = non_streaming_tokens if not stream else None - return _make_result(t_start, t_first_token, token_times, text_chunks, error, output_token_count=output_token_count) - - -def _make_result(t_start, t_first_token=None, token_times=None, text_chunks=None, error=None, output_token_count=None): - t_end = time.perf_counter() - token_times = token_times or [] - text_chunks = text_chunks or [] - - # Inter-token latencies - itl = [] - for i in range(1, len(token_times)): - itl.append(token_times[i] - token_times[i - 1]) - - return { - "e2e": t_end - t_start, - "ttft": (t_first_token - t_start) if t_first_token else None, - "tpot": statistics.mean(itl) if itl else None, # time per output token - "itl": itl, - "output_tokens": output_token_count if output_token_count is not None else len(text_chunks), - "text": "".join(text_chunks), - "error": error, - } - - -# --------------------------------------------------------------------------- -# Load generators -# --------------------------------------------------------------------------- - - -async def run_concurrency_test( - base_url: str, - prompts: list[str], - max_new_tokens: int, - seed: int, - endpoint: str, - model: str | None = None, - stream: bool = True, -) -> list[dict]: - """Send all prompts concurrently and collect results.""" - connector = aiohttp.TCPConnector(limit=0) - timeout = aiohttp.ClientTimeout(total=600) - async with aiohttp.ClientSession(connector=connector, timeout=timeout) as session: - tasks = [ - send_request(session, base_url, p, max_new_tokens, seed, endpoint, model=model, stream=stream) - for p in prompts - ] - results = await asyncio.gather(*tasks) - - return list(results) - - -# --------------------------------------------------------------------------- -# Metrics -# --------------------------------------------------------------------------- - - -def compute_metrics(results: list[dict], duration: float) -> dict: - 
"""Compute aggregate metrics from individual request results.""" - from collections import Counter - - successful = [r for r in results if r["error"] is None] - failed = [r for r in results if r["error"] is not None] - error_summary = Counter(r["error"] for r in failed) - - if not successful: - return {"error": "all requests failed", "failures": len(failed), "error_summary": error_summary} - - total_output_tokens = sum(r["output_tokens"] for r in successful) - - e2e_latencies = [r["e2e"] for r in successful] - ttfts = [r["ttft"] for r in successful if r["ttft"] is not None] - tpots = [r["tpot"] for r in successful if r["tpot"] is not None] - - # Flatten all inter-token latencies - all_itl = [] - for r in successful: - all_itl.extend(r["itl"]) - - def percentiles(values): - if not values: - return {} - values = sorted(values) - n = len(values) - return { - "mean": statistics.mean(values), - "median": statistics.median(values), - "p90": values[int(n * 0.9)], - "p99": values[min(int(n * 0.99), n - 1)], - "min": values[0], - "max": values[-1], - } - - return { - "total_requests": len(results), - "successful": len(successful), - "failed": len(failed), - "duration": duration, - "total_output_tokens": total_output_tokens, - "throughput_req_per_sec": len(successful) / duration, - "throughput_tok_per_sec": total_output_tokens / duration, - "e2e_latency": percentiles(e2e_latencies), - "ttft": percentiles(ttfts), - "tpot": percentiles(tpots), - "itl": percentiles(all_itl), - "error_summary": error_summary, - } - - -# --------------------------------------------------------------------------- -# Output -# --------------------------------------------------------------------------- - - -def format_ms(seconds): - if seconds is None: - return "N/A" - return f"{seconds * 1000:.1f}ms" - - -def print_metrics(metrics: dict, label: str): - print(f"\n{'=' * 70}") - print(f" {label}") - print(f"{'=' * 70}") - - if "error" in metrics: - print(f" ERROR: {metrics['error']}") - return - - print( - f" Requests: {metrics['successful']} ok / {metrics['failed']} failed / {metrics['total_requests']} total" - ) - if metrics.get("error_summary"): - for err, count in metrics["error_summary"].most_common(5): - print(f" - {count}x: {err}") - print(f" Duration: {metrics['duration']:.1f}s") - print( - f" Throughput: {metrics['throughput_req_per_sec']:.2f} req/s, {metrics['throughput_tok_per_sec']:.1f} tok/s" - ) - print(f" Tokens: {metrics['total_output_tokens']} total output") - print() - - headers = ["metric", "mean", "median", "p90", "p99", "min", "max"] - rows = [] - for name in ["e2e_latency", "ttft", "tpot", "itl"]: - p = metrics.get(name, {}) - if not p: - continue - rows.append( - [ - name.upper().replace("_", " "), - format_ms(p.get("mean")), - format_ms(p.get("median")), - format_ms(p.get("p90")), - format_ms(p.get("p99")), - format_ms(p.get("min")), - format_ms(p.get("max")), - ] - ) - - if rows: - widths = [max(len(h), *(len(r[i]) for r in rows)) for i, h in enumerate(headers)] - fmt = " " + " | ".join(f"{{:<{w}}}" for w in widths) - sep = " " + "-+-".join("-" * w for w in widths) - print(fmt.format(*headers)) - print(sep) - for row in rows: - print(fmt.format(*row)) - print() - - -# --------------------------------------------------------------------------- -# Server management -# --------------------------------------------------------------------------- - - -def wait_for_server(base_url: str, timeout: int = 120) -> bool: - import requests - - deadline = time.time() + timeout - while time.time() < deadline: - try: 
- if requests.get(f"{base_url}/health", timeout=2).status_code == 200: - return True - except Exception: - continue - time.sleep(1) - return False - - -def start_server( - model: str, - port: int, - compile: bool = False, - continuous_batching: bool = False, - attn_implementation: str | None = None, -): - from transformers.cli.serve_refactored import Serve - - kwargs = {"force_model": model, "port": port, "non_blocking": True, "log_level": "warning"} - if compile: - kwargs["compile"] = True - if continuous_batching: - kwargs["continuous_batching"] = True - if attn_implementation: - kwargs["attn_implementation"] = attn_implementation - return Serve(**kwargs) - - -# --------------------------------------------------------------------------- -# Main -# --------------------------------------------------------------------------- - - -async def async_main(args): - base_url = args.url if args.url else f"http://localhost:{args.port}" - server = None - - # Default to flash_attention_3 when using continuous batching - if args.continuous_batching and args.attn_impl is None: - args.attn_impl = "flash_attention_3" - - if not args.url: - print(f"Starting server for {args.model}...") - server = start_server( - args.model, - args.port, - compile=args.compile, - continuous_batching=args.continuous_batching, - attn_implementation=args.attn_impl, - ) - if not wait_for_server(base_url): - print("ERROR: Server did not start") - if server: - server.kill_server() - return - print("Server ready.") - - tokenizer_id = args.processor or args.model - tokenizer = AutoTokenizer.from_pretrained(tokenizer_id) - - num_requests = max(args.max_concurrency) - prompts = make_prompts(tokenizer, num_requests, args.prompt_tokens, variance=args.prompt_variance) - print(f"Generated {num_requests} prompts (~{args.prompt_tokens} tokens each, ±{int(args.prompt_variance * 100)}%)") - print(f"Max new tokens per request: {args.max_new_tokens}") - print(f"Endpoint: /v1/{args.endpoint} ({'streaming' if args.stream else 'non-streaming'})") - - # Warmup — ramp up to full batch size so CUDA graphs are compiled for - # the batch shapes the scheduler will use under load (~100+ active requests). 
- warmup_size = min(200, num_requests) - warmup_prompts = prompts[:warmup_size] - print(f"Warming up ({args.warmup}x {warmup_size} requests)...") - for _ in range(args.warmup): - await run_concurrency_test( - base_url, - warmup_prompts, - args.max_new_tokens, - args.seed, - args.endpoint, - model=args.model, - stream=args.stream, - ) - print("Warmup done.") - - # Run tests — one round per concurrency level - for concurrency in args.max_concurrency: - test_prompts = prompts[:concurrency] - label = f"{concurrency} concurrent requests" - print(f"\nRunning: {label}") - t0 = time.perf_counter() - results = await run_concurrency_test( - base_url, - test_prompts, - args.max_new_tokens, - args.seed, - args.endpoint, - model=args.model, - stream=args.stream, - ) - duration = time.perf_counter() - t0 - metrics = compute_metrics(results, duration) - print_metrics(metrics, label) - - if server: - server.kill_server() - - -def main(): - parser = argparse.ArgumentParser( - description="Load test for transformers serve", - formatter_class=argparse.RawDescriptionHelpFormatter, - ) - parser.add_argument("--model", type=str, default="Qwen/Qwen2.5-7B-Instruct") - parser.add_argument("--processor", type=str, default=None) - parser.add_argument("--url", type=str, default=None, help="Existing server URL (skip start/stop)") - parser.add_argument("--port", type=int, default=8642) - parser.add_argument("--compile", action="store_true", help="Enable --compile on the server") - parser.add_argument("--continuous-batching", action="store_true", help="Enable continuous batching on the server") - parser.add_argument( - "--attn-impl", type=str, default=None, help="Attention implementation (e.g. flash_attention_3)" - ) - parser.add_argument("--endpoint", type=str, choices=["chat", "responses"], default="responses") - parser.add_argument("--no-stream", action="store_true", help="Use non-streaming requests") - - # Load parameters - parser.add_argument( - "--max-concurrency", - type=int, - nargs="+", - default=[1, 2, 4], - help="Number of concurrent requests to send (default: 1 2 4)", - ) - - # Prompt parameters - parser.add_argument("--prompt-tokens", type=int, default=256, help="Target prompt length in tokens (default: 256)") - parser.add_argument( - "--prompt-variance", type=float, default=0.2, help="Prompt length variance as fraction (default: 0.2 = ±20%%)" - ) - parser.add_argument( - "--max-new-tokens", type=int, default=128, help="Max tokens to generate per request (default: 128)" - ) - - parser.add_argument("--warmup", type=int, default=2, help="Warmup requests (default: 2)") - parser.add_argument("--seed", type=int, default=42) - args = parser.parse_args() - args.stream = not args.no_stream - - asyncio.run(async_main(args)) - - -if __name__ == "__main__": - main() From ef1c71019b8b5c4d05b51f17961cc5301643166d Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Mon, 30 Mar 2026 17:37:59 +0000 Subject: [PATCH 49/64] batch output --- .../continuous_batching/continuous_api.py | 27 ++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/src/transformers/generation/continuous_batching/continuous_api.py b/src/transformers/generation/continuous_batching/continuous_api.py index e2d985d60906..bf64c5e9bc2c 100644 --- a/src/transformers/generation/continuous_batching/continuous_api.py +++ b/src/transformers/generation/continuous_batching/continuous_api.py @@ -101,6 +101,27 @@ def deliver(self, output: GenerationOutput) -> None: else: self.output_queue.put(output) + def deliver_batch(self, outputs: 
list[GenerationOutput]) -> None: + """Route a batch of outputs, using a single ``call_soon_threadsafe`` to minimize cross-thread overhead. + + Outputs without a registered handler fall back to the shared ``output_queue``. + """ + callbacks: list[tuple[Callable, GenerationOutput]] = [] + loop = None + with self._lock: + for output in outputs: + entry = self.result_handlers.get(output.request_id) + if entry is not None: + callback, loop = entry + callbacks.append((callback, output)) + else: + self.output_queue.put(output) + if callbacks: + def _run_batch(batch=callbacks): + for cb, out in batch: + cb(out) + loop.call_soon_threadsafe(_run_batch) + # Continuous Batch Processor (Internal Logic) @attach_tracer() @@ -357,6 +378,7 @@ def update_batch(self) -> None: """Update request states based on generated tokens.""" requests_in_batch, new_tokens, logprobs = self.inputs_and_outputs.prepare_batch_update() current_logits_index = 0 + pending_outputs = [] for future_state in requests_in_batch: state = future_state.state # Early return if the request is finished @@ -387,11 +409,14 @@ def update_batch(self) -> None: self.scheduler.finish_request(state.request_id) self.scheduler.block_new_requests = False if state.streaming or state.status == RequestStatus.FINISHED: - self.output_router.deliver(state.to_generation_output()) + pending_outputs.append(state.to_generation_output()) # Otherwise, the request is still prefilling, but the prefill has been split elif state.status == RequestStatus.PREFILLING: self.cache.mark_shareable_blocks_as_complete(state, future_state.complete_blocks) + if pending_outputs: + self.output_router.deliver_batch(pending_outputs) + # If some requests need to be forked, we do it now copy_source, copy_destination = [], [] while self.scheduler._requests_to_fork: From caaab6e1432483d3c5d5f4d8d659bee29f509005 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Mon, 30 Mar 2026 17:38:55 +0000 Subject: [PATCH 50/64] style --- .../generation/continuous_batching/continuous_api.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/transformers/generation/continuous_batching/continuous_api.py b/src/transformers/generation/continuous_batching/continuous_api.py index bf64c5e9bc2c..cc924cc80578 100644 --- a/src/transformers/generation/continuous_batching/continuous_api.py +++ b/src/transformers/generation/continuous_batching/continuous_api.py @@ -117,9 +117,11 @@ def deliver_batch(self, outputs: list[GenerationOutput]) -> None: else: self.output_queue.put(output) if callbacks: + def _run_batch(batch=callbacks): for cb, out in batch: cb(out) + loop.call_soon_threadsafe(_run_batch) From 4c1cd01dfc8ddf8a2c185d1966be2ebccaeb4208 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Mon, 30 Mar 2026 17:43:32 +0000 Subject: [PATCH 51/64] type --- .../generation/continuous_batching/continuous_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/continuous_batching/continuous_api.py b/src/transformers/generation/continuous_batching/continuous_api.py index cc924cc80578..e7cbdda7292e 100644 --- a/src/transformers/generation/continuous_batching/continuous_api.py +++ b/src/transformers/generation/continuous_batching/continuous_api.py @@ -116,7 +116,7 @@ def deliver_batch(self, outputs: list[GenerationOutput]) -> None: callbacks.append((callback, output)) else: self.output_queue.put(output) - if callbacks: + if callbacks and loop is not None: def _run_batch(batch=callbacks): for cb, out in batch: From 702ff749c8e40da8d19fe2b81550875bc133b430 Mon Sep 17 00:00:00 2001 
From: Marc Sun Date: Mon, 30 Mar 2026 17:52:15 +0000 Subject: [PATCH 52/64] better tests --- src/transformers/cli/serve.py | 13 +- .../cli/serving/chat_completion.py | 20 +- src/transformers/cli/serving/response.py | 52 +- src/transformers/cli/serving/server.py | 12 +- src/transformers/cli/serving/transcription.py | 12 +- src/transformers/cli/transformers.py | 2 +- src/transformers/testing_utils.py | 8 + src/transformers/utils/__init__.py | 1 + src/transformers/utils/import_utils.py | 5 + tests/cli/test_serve_refactored.py | 580 +++++------------- 10 files changed, 234 insertions(+), 471 deletions(-) diff --git a/src/transformers/cli/serve.py b/src/transformers/cli/serve.py index 20bac9b5d8f2..8f07684ac8bb 100644 --- a/src/transformers/cli/serve.py +++ b/src/transformers/cli/serve.py @@ -22,20 +22,11 @@ import typer from transformers.utils import logging -from transformers.utils.import_utils import ( - is_fastapi_available, - is_openai_available, - is_pydantic_available, - is_uvicorn_available, -) +from transformers.utils.import_utils import is_serve_available from .serving.utils import set_torch_seed -serve_dependencies_available = ( - is_pydantic_available() and is_fastapi_available() and is_uvicorn_available() and is_openai_available() -) - logger = logging.get_logger(__name__) @@ -70,7 +61,7 @@ def __init__( bool, typer.Option(hidden=True, help="Run server in a background thread. Used by tests.") ] = False, ) -> None: - if not serve_dependencies_available: + if not is_serve_available(): raise ImportError("Missing dependencies for serving. Install with `pip install transformers[serving]`") import uvicorn diff --git a/src/transformers/cli/serving/chat_completion.py b/src/transformers/cli/serving/chat_completion.py index c933b23980b9..10f22c230e47 100644 --- a/src/transformers/cli/serving/chat_completion.py +++ b/src/transformers/cli/serving/chat_completion.py @@ -22,15 +22,19 @@ from collections.abc import AsyncGenerator from typing import TYPE_CHECKING -from fastapi.responses import JSONResponse, StreamingResponse -from openai.types.chat import ChatCompletion, ChatCompletionMessage, ChatCompletionMessageToolCall -from openai.types.chat.chat_completion import Choice -from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, ChoiceDelta, ChoiceDeltaToolCall -from openai.types.chat.chat_completion_chunk import Choice as ChoiceChunk -from openai.types.chat.completion_create_params import CompletionCreateParamsStreaming -from openai.types.completion_usage import CompletionUsage - from ...utils import logging +from ...utils.import_utils import is_serve_available + + +if is_serve_available(): + from fastapi.responses import JSONResponse, StreamingResponse + from openai.types.chat import ChatCompletion, ChatCompletionMessage, ChatCompletionMessageToolCall + from openai.types.chat.chat_completion import Choice + from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, ChoiceDelta, ChoiceDeltaToolCall + from openai.types.chat.chat_completion_chunk import Choice as ChoiceChunk + from openai.types.chat.completion_create_params import CompletionCreateParamsStreaming + from openai.types.completion_usage import CompletionUsage + from .utils import ( BaseGenerateManager, BaseHandler, diff --git a/src/transformers/cli/serving/response.py b/src/transformers/cli/serving/response.py index 5a290bc082c7..ea6d6ebfb20d 100644 --- a/src/transformers/cli/serving/response.py +++ b/src/transformers/cli/serving/response.py @@ -22,31 +22,35 @@ from collections.abc import AsyncGenerator 
from typing import TYPE_CHECKING -from fastapi import HTTPException -from fastapi.responses import JSONResponse, StreamingResponse -from openai.types.responses import ( - Response, - ResponseCompletedEvent, - ResponseContentPartAddedEvent, - ResponseContentPartDoneEvent, - ResponseCreatedEvent, - ResponseError, - ResponseErrorEvent, - ResponseFailedEvent, - ResponseFunctionCallArgumentsDoneEvent, - ResponseFunctionToolCall, - ResponseInProgressEvent, - ResponseOutputItemAddedEvent, - ResponseOutputItemDoneEvent, - ResponseOutputMessage, - ResponseOutputText, - ResponseTextDeltaEvent, - ResponseTextDoneEvent, -) -from openai.types.responses.response_create_params import ResponseCreateParamsStreaming -from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails, ResponseUsage - from ...utils import logging +from ...utils.import_utils import is_serve_available + + +if is_serve_available(): + from fastapi import HTTPException + from fastapi.responses import JSONResponse, StreamingResponse + from openai.types.responses import ( + Response, + ResponseCompletedEvent, + ResponseContentPartAddedEvent, + ResponseContentPartDoneEvent, + ResponseCreatedEvent, + ResponseError, + ResponseErrorEvent, + ResponseFailedEvent, + ResponseFunctionCallArgumentsDoneEvent, + ResponseFunctionToolCall, + ResponseInProgressEvent, + ResponseOutputItemAddedEvent, + ResponseOutputItemDoneEvent, + ResponseOutputMessage, + ResponseOutputText, + ResponseTextDeltaEvent, + ResponseTextDoneEvent, + ) + from openai.types.responses.response_create_params import ResponseCreateParamsStreaming + from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails, ResponseUsage + from .utils import ( BaseGenerateManager, BaseHandler, diff --git a/src/transformers/cli/serving/server.py b/src/transformers/cli/serving/server.py index 587e50981b6e..ec0f287b36ee 100644 --- a/src/transformers/cli/serving/server.py +++ b/src/transformers/cli/serving/server.py @@ -18,11 +18,15 @@ import uuid from contextlib import asynccontextmanager -from fastapi import FastAPI, Request -from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import JSONResponse, StreamingResponse - from ...utils import logging +from ...utils.import_utils import is_serve_available + + +if is_serve_available(): + from fastapi import FastAPI, Request + from fastapi.middleware.cors import CORSMiddleware + from fastapi.responses import JSONResponse, StreamingResponse + from .chat_completion import ChatCompletionHandler from .model_manager import ModelManager from .response import ResponseHandler diff --git a/src/transformers/cli/serving/transcription.py b/src/transformers/cli/serving/transcription.py index 0d4a8adaa027..be021be414b9 100644 --- a/src/transformers/cli/serving/transcription.py +++ b/src/transformers/cli/serving/transcription.py @@ -18,11 +18,15 @@ import io from typing import TYPE_CHECKING -from fastapi import HTTPException, Request -from fastapi.responses import JSONResponse, StreamingResponse -from openai.types.audio.transcription_create_params import TranscriptionCreateParamsBase - from ...utils import logging +from ...utils.import_utils import is_serve_available + + +if is_serve_available(): + from fastapi import HTTPException, Request + from fastapi.responses import JSONResponse, StreamingResponse + from openai.types.audio.transcription_create_params import TranscriptionCreateParamsBase + from .model_manager import ModelManager from .utils import DirectStreamer, GenerateManager, 
GenerationState, _StreamError diff --git a/src/transformers/cli/transformers.py b/src/transformers/cli/transformers.py index 6ae79d99c74e..cefee1ca97c8 100644 --- a/src/transformers/cli/transformers.py +++ b/src/transformers/cli/transformers.py @@ -18,7 +18,7 @@ from transformers.cli.add_new_model_like import add_new_model_like from transformers.cli.chat import Chat from transformers.cli.download import download -from transformers.cli.serve_refactored import Serve +from transformers.cli.serve import Serve from transformers.cli.system import env, version diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index bdbf213412fe..4b021f2bf013 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -140,6 +140,7 @@ is_scipy_available, is_sentencepiece_available, is_seqio_available, + is_serve_available, is_spacy_available, is_speech_available, is_spqr_available, @@ -1497,6 +1498,13 @@ def require_openai(test_case): return unittest.skipUnless(is_openai_available(), "test requires openai")(test_case) +def require_serve(test_case): + """ + Decorator marking a test that requires the serving dependencies (fastapi, uvicorn, pydantic, openai). + """ + return unittest.skipUnless(is_serve_available(), "test requires serving dependencies")(test_case) + + def require_mistral_common(test_case): """ Decorator marking a test that requires mistral-common. These tests are skipped when mistral-common isn't available. diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index 7b8bfb80ec19..55297dd706d9 100644 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -198,6 +198,7 @@ is_scipy_available, is_sentencepiece_available, is_seqio_available, + is_serve_available, is_sinq_available, is_sklearn_available, is_soundfile_available, diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 6d9cef86e499..a6bed67b9d20 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -749,6 +749,11 @@ def is_openai_available() -> bool: return _is_package_available("openai")[0] +@lru_cache +def is_serve_available() -> bool: + return is_pydantic_available() and is_fastapi_available() and is_uvicorn_available() and is_openai_available() + + @lru_cache def is_pretty_midi_available() -> bool: return _is_package_available("pretty_midi")[0] diff --git a/tests/cli/test_serve_refactored.py b/tests/cli/test_serve_refactored.py index cae52ea4595d..63503cbe0bf9 100644 --- a/tests/cli/test_serve_refactored.py +++ b/tests/cli/test_serve_refactored.py @@ -12,36 +12,68 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -Tests for the refactored serving layer (Phase 1: chat completions). +Tests for the serving layer. 
-Run: pytest tests/cli/test_serve_refactored.py -x -v -Integration tests (need GPU): RUN_SLOW=1 pytest tests/cli/test_serve_refactored.py -x -v -k "Integration" """ import asyncio -import io import json import os +import socket import time import unittest from unittest.mock import MagicMock -from transformers.testing_utils import require_openai -from transformers.utils.import_utils import is_openai_available, is_vision_available +import httpx + +from transformers.cli.serve import Serve +from transformers.cli.serving.chat_completion import ChatCompletionHandler +from transformers.cli.serving.model_manager import ModelManager, TimedModel +from transformers.cli.serving.response import ResponseHandler, compute_usage +from transformers.cli.serving.server import build_server +from transformers.cli.serving.transcription import TranscriptionHandler +from transformers.cli.serving.utils import ( + BaseHandler, + GenerationState, + Modality, + ToolCallParser, + detect_tool_format, +) +from transformers.testing_utils import require_serve, require_torch_accelerator, require_vision, slow +from transformers.utils.import_utils import is_serve_available + + +if is_serve_available(): + from fastapi import HTTPException + from openai import OpenAI + from openai.types.responses import Response, ResponseCreatedEvent -if is_openai_available(): - from openai import OpenAI +def _find_free_port() -> int: + """Return a free TCP port on localhost.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("localhost", 0)) + return s.getsockname()[1] -run_slow = os.environ.get("RUN_SLOW", "0") == "1" +def _start_serve(**kwargs) -> tuple["Serve", int]: + """Start a non-blocking Serve instance on a free port and wait until healthy. -# --------------------------------------------------------------------------- -# 1. CLI tests — verify CLI args reach uvicorn -# --------------------------------------------------------------------------- + Returns ``(serve, port)``. + """ + port = _find_free_port() + serve = Serve(port=port, non_blocking=True, **kwargs) + for _ in range(30): + try: + if httpx.get(f"http://localhost:{port}/health", timeout=2).status_code == 200: + return serve, port + except Exception: # noqa: S110 + pass + time.sleep(1) + raise RuntimeError(f"Server on port {port} did not become healthy in time") -@require_openai +@require_serve def test_host_port_blocking(cli): """CLI args --host and --port are passed to uvicorn.Config, and server.run() is called.""" from unittest.mock import Mock, patch @@ -63,15 +95,8 @@ def test_host_port_blocking(cli): server_instance.run.assert_called_once() -# --------------------------------------------------------------------------- -# 2. 
Unit tests — message parsing -# --------------------------------------------------------------------------- - - class TestProcessorInputsFromMessages(unittest.TestCase): def test_llm_string_content(self): - from transformers.cli.serving.utils import BaseHandler, Modality - get_processor_inputs_from_messages = BaseHandler.get_processor_inputs_from_messages messages = [{"role": "user", "content": "Hello"}] @@ -79,8 +104,6 @@ def test_llm_string_content(self): self.assertEqual(result, [{"role": "user", "content": "Hello"}]) def test_llm_list_content_text_only(self): - from transformers.cli.serving.utils import BaseHandler, Modality - get_processor_inputs_from_messages = BaseHandler.get_processor_inputs_from_messages messages = [{"role": "user", "content": [{"type": "text", "text": "A"}, {"type": "text", "text": "B"}]}] @@ -88,8 +111,6 @@ def test_llm_list_content_text_only(self): self.assertEqual(result, [{"role": "user", "content": "A B"}]) def test_vlm_string_content_wrapped(self): - from transformers.cli.serving.utils import BaseHandler, Modality - get_processor_inputs_from_messages = BaseHandler.get_processor_inputs_from_messages messages = [{"role": "user", "content": "Hello"}] @@ -97,8 +118,6 @@ def test_vlm_string_content_wrapped(self): self.assertEqual(result, [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}]) def test_vlm_text_and_image_url(self): - from transformers.cli.serving.utils import BaseHandler, Modality - get_processor_inputs_from_messages = BaseHandler.get_processor_inputs_from_messages messages = [ @@ -117,7 +136,6 @@ def test_vlm_text_and_image_url(self): def test_llm_multi_turn_conversation(self): """Multi-turn conversation with string content should pass through as-is.""" - from transformers.cli.serving.utils import BaseHandler, Modality get_processor_inputs_from_messages = BaseHandler.get_processor_inputs_from_messages @@ -134,7 +152,6 @@ def test_llm_multi_turn_conversation(self): def test_llm_list_content_with_type(self): """LLM messages with typed content list should extract text and join.""" - from transformers.cli.serving.utils import BaseHandler, Modality get_processor_inputs_from_messages = BaseHandler.get_processor_inputs_from_messages @@ -144,12 +161,9 @@ def test_llm_list_content_with_type(self): result = get_processor_inputs_from_messages(messages, Modality.LLM) self.assertEqual(result[0]["content"], "Hello world") - @unittest.skipUnless(is_vision_available(), "Requires PIL") + @require_vision def test_vlm_base64_image_creates_temp_file(self): """Base64 image URLs should be decoded and saved to a temp file.""" - import os - - from transformers.cli.serving.utils import BaseHandler, Modality get_processor_inputs_from_messages = BaseHandler.get_processor_inputs_from_messages @@ -174,7 +188,6 @@ def test_vlm_base64_image_creates_temp_file(self): def test_vlm_multi_turn(self): """VLM multi-turn: string content should be wrapped in text type.""" - from transformers.cli.serving.utils import BaseHandler, Modality get_processor_inputs_from_messages = BaseHandler.get_processor_inputs_from_messages @@ -197,8 +210,6 @@ def test_lists_only_generative_models(self): from huggingface_hub import hf_hub_download - from transformers.cli.serving.model_manager import ModelManager - with tempfile.TemporaryDirectory() as cache_dir: # Download config.json for a few models hf_hub_download("Qwen/Qwen2.5-0.5B-Instruct", "config.json", cache_dir=cache_dir) @@ -211,17 +222,9 @@ def test_lists_only_generative_models(self): self.assertNotIn("google-bert/bert-base-cased", 
model_ids) -# --------------------------------------------------------------------------- -# 2. Unit tests — generation config mapping -# --------------------------------------------------------------------------- - - -@require_openai +@require_serve class TestBuildGenerationConfig(unittest.TestCase): def _make_handler(self): - from transformers.cli.serving.chat_completion import ChatCompletionHandler - from transformers.cli.serving.utils import GenerationState - return ChatCompletionHandler(model_manager=MagicMock(), generation_state=GenerationState()) def test_max_tokens(self): @@ -289,17 +292,9 @@ def test_user_max_tokens_overrides_default(self): self.assertEqual(result.max_new_tokens, 50) -# --------------------------------------------------------------------------- -# 3. Unit tests — validation -# --------------------------------------------------------------------------- - - -@require_openai +@require_serve class TestValidation(unittest.TestCase): def _make_handler(self): - from transformers.cli.serving.chat_completion import ChatCompletionHandler - from transformers.cli.serving.utils import GenerationState - return ChatCompletionHandler(model_manager=MagicMock(), generation_state=GenerationState()) def test_valid_request_passes(self): @@ -308,92 +303,64 @@ def test_valid_request_passes(self): handler._validate_request({"model": "x", "messages": [{"role": "user", "content": "hi"}], "stream": True}) def test_unexpected_keys_rejected(self): - from fastapi import HTTPException - handler = self._make_handler() with self.assertRaises(HTTPException) as ctx: handler._validate_request({"model": "x", "messages": [], "bogus_field": True}) self.assertEqual(ctx.exception.status_code, 422) self.assertIn("bogus_field", ctx.exception.detail) - def test_unsupported_fields_rejected(self): - from fastapi import HTTPException - + def test_unsupported_fields_warns(self): handler = self._make_handler() - with self.assertRaises(HTTPException) as ctx: + with self.assertLogs("transformers", level="WARNING") as cm: handler._validate_request({"model": "x", "messages": [], "audio": {}}) - self.assertEqual(ctx.exception.status_code, 422) - self.assertIn("audio", ctx.exception.detail) - - -# --------------------------------------------------------------------------- -# 4. 
Unit tests — model manager -# --------------------------------------------------------------------------- + self.assertTrue(any("audio" in msg for msg in cm.output)) class TestModelManager(unittest.TestCase): def test_process_model_name_adds_main(self): - from transformers.cli.serving.model_manager import ModelManager - self.assertEqual(ModelManager.process_model_name("org/model"), "org/model@main") def test_process_model_name_preserves_revision(self): - from transformers.cli.serving.model_manager import ModelManager - self.assertEqual(ModelManager.process_model_name("org/model@dev"), "org/model@dev") def test_quantization_config_4bit(self): - from transformers.cli.serving.model_manager import ModelManager - mm = ModelManager(quantization="bnb-4bit") cfg = mm.get_quantization_config() self.assertTrue(cfg.load_in_4bit) def test_quantization_config_8bit(self): - from transformers.cli.serving.model_manager import ModelManager - mm = ModelManager(quantization="bnb-8bit") cfg = mm.get_quantization_config() self.assertTrue(cfg.load_in_8bit) def test_quantization_config_none(self): - from transformers.cli.serving.model_manager import ModelManager - mm = ModelManager() self.assertIsNone(mm.get_quantization_config()) class TestTimedModel(unittest.TestCase): def test_delete_model(self): - from transformers.cli.serving.model_manager import TimedModel - mock_model = MagicMock() - timed = TimedModel(mock_model, timeout_seconds=9999, processor=MagicMock()) - self.assertFalse(timed.is_deleted()) + deleted = [] + timed = TimedModel( + mock_model, timeout_seconds=9999, processor=MagicMock(), on_unload=lambda: deleted.append(True) + ) + self.assertIsNotNone(timed.model) timed.delete_model() - self.assertTrue(timed.is_deleted()) + self.assertIsNone(timed.model) + self.assertEqual(len(deleted), 1) def test_timeout_zero_no_delete(self): - from transformers.cli.serving.model_manager import TimedModel - mock_model = MagicMock() timed = TimedModel(mock_model, timeout_seconds=0, processor=MagicMock()) timed._timeout_reached() - self.assertFalse(timed.is_deleted()) + self.assertIsNotNone(timed.model) timed._timer.cancel() -# --------------------------------------------------------------------------- -# 5. Unit tests — SSE formatting -# --------------------------------------------------------------------------- - - -@require_openai +@require_serve class TestChunkSSE(unittest.TestCase): def _make_handler(self): - from transformers.cli.serving.chat_completion import ChatCompletionHandler - from transformers.cli.serving.utils import GenerationState - return ChatCompletionHandler(model_manager=MagicMock(), generation_state=GenerationState()) def test_build_chunk_sse_content(self): @@ -418,61 +385,42 @@ def test_build_chunk_sse_finish_reason(self): self.assertEqual(parsed["choices"][0]["finish_reason"], "stop") def test_chunk_to_sse_string_passthrough(self): - from transformers.cli.serving.utils import BaseHandler - result = BaseHandler.chunk_to_sse("data: already formatted\n\n") self.assertEqual(result, "data: already formatted\n\n") def test_chunk_to_sse_wraps_plain_string(self): - from transformers.cli.serving.utils import BaseHandler - result = BaseHandler.chunk_to_sse("hello") self.assertEqual(result, "data: hello\n\n") -# --------------------------------------------------------------------------- -# 6. 
Unit tests — tool parser -# --------------------------------------------------------------------------- - - QWEN_TOOL_FORMAT = {"start": "", "end": ""} -@require_openai +@require_serve class TestToolParser(unittest.TestCase): def test_detect_tool_format_qwen(self): - from transformers.cli.serving.utils import detect_tool_format - model = MagicMock() model.config.architectures = ["Qwen2ForCausalLM"] fmt = detect_tool_format(model) self.assertEqual(fmt, QWEN_TOOL_FORMAT) def test_detect_tool_format_unsupported(self): - from transformers.cli.serving.utils import detect_tool_format - model = MagicMock() model.config.architectures = ["LlamaForCausalLM"] self.assertIsNone(detect_tool_format(model)) def test_parser_start_token(self): - from transformers.cli.serving.utils import ToolCallParser - parser = ToolCallParser(QWEN_TOOL_FORMAT) result = parser.feed("") self.assertIs(result, ToolCallParser.CONSUMED) def test_parser_end_token(self): - from transformers.cli.serving.utils import ToolCallParser - parser = ToolCallParser(QWEN_TOOL_FORMAT) parser.feed("") result = parser.feed("") self.assertIs(result, ToolCallParser.CONSUMED) def test_parser_buffers_until_end(self): - from transformers.cli.serving.utils import ToolCallParser - parser = ToolCallParser(QWEN_TOOL_FORMAT) parser.feed("") # Intermediate tokens are buffered @@ -484,15 +432,12 @@ def test_parser_buffers_until_end(self): self.assertEqual(result["name"], "my_tool") def test_parser_normal_text_returns_none(self): - from transformers.cli.serving.utils import ToolCallParser - parser = ToolCallParser(QWEN_TOOL_FORMAT) result = parser.feed("Hello world") self.assertIsNone(result) def test_parser_full_flow(self): """Simulate a complete tool call token sequence.""" - from transformers.cli.serving.utils import ToolCallParser parser = ToolCallParser(QWEN_TOOL_FORMAT) tool_calls = [] @@ -517,7 +462,6 @@ def test_parser_full_flow(self): def test_parse_tool_calls_from_text(self): """Non-streaming tool call parsing from complete text.""" - from transformers.cli.serving.utils import ToolCallParser text = '\n{"name": "get_weather", "arguments": {"city": "Paris"}}\n' calls = ToolCallParser.parse(text, QWEN_TOOL_FORMAT) @@ -528,14 +472,12 @@ def test_parse_tool_calls_from_text(self): def test_parse_tool_calls_no_tool_call(self): """Non-streaming: normal text returns None.""" - from transformers.cli.serving.utils import ToolCallParser calls = ToolCallParser.parse("Hello, how can I help?", QWEN_TOOL_FORMAT) self.assertIsNone(calls) def test_parse_multiple_tool_calls(self): """Non-streaming: multiple tool calls in one response.""" - from transformers.cli.serving.utils import ToolCallParser text = ( '\n{"name": "get_weather", "arguments": {"city": "Paris"}}\n\n' @@ -551,7 +493,6 @@ def test_parse_multiple_tool_calls(self): def test_feed_multiple_tool_calls(self): """Streaming: multiple tool calls emitted sequentially.""" - from transformers.cli.serving.utils import ToolCallParser parser = ToolCallParser(QWEN_TOOL_FORMAT) tool_calls = [] @@ -576,21 +517,10 @@ def test_feed_multiple_tool_calls(self): self.assertIn("London", tool_calls[1]["arguments"]) -# --------------------------------------------------------------------------- -# 7. 
App-level tests with ASGI test client (no real model) -# --------------------------------------------------------------------------- - - -@require_openai +@require_serve class TestAppRoutes(unittest.TestCase): @classmethod def setUpClass(cls): - from transformers.cli.serving.chat_completion import ChatCompletionHandler - from transformers.cli.serving.model_manager import ModelManager - from transformers.cli.serving.response import ResponseHandler - from transformers.cli.serving.server import build_server - from transformers.cli.serving.transcription import TranscriptionHandler - cls.model_manager = MagicMock(spec=ModelManager) cls.model_manager.get_gen_models.return_value = [ {"id": "test/model", "owned_by": "test", "object": "model", "created": 0} @@ -599,86 +529,48 @@ def setUpClass(cls): cls.response_handler = MagicMock(spec=ResponseHandler) cls.transcription_handler = MagicMock(spec=TranscriptionHandler) cls.app = build_server(cls.model_manager, cls.chat_handler, cls.response_handler, cls.transcription_handler) + cls.transport = httpx.ASGITransport(app=cls.app) - def _run(self, coro): - return asyncio.get_event_loop().run_until_complete(coro) + async def _request(self, method: str, path: str, **kwargs) -> httpx.Response: + async with httpx.AsyncClient(transport=self.transport, base_url="http://test") as c: + return await c.request(method, path, **kwargs) def test_health(self): - from httpx import ASGITransport, AsyncClient - - async def _test(): - async with AsyncClient(transport=ASGITransport(app=self.app), base_url="http://test") as c: - resp = await c.get("/health") - self.assertEqual(resp.status_code, 200) - self.assertEqual(resp.json(), {"status": "ok"}) - - self._run(_test()) + resp = asyncio.get_event_loop().run_until_complete(self._request("GET", "/health")) + self.assertEqual(resp.status_code, 200) + self.assertEqual(resp.json(), {"status": "ok"}) def test_models_list(self): - from httpx import ASGITransport, AsyncClient - - async def _test(): - async with AsyncClient(transport=ASGITransport(app=self.app), base_url="http://test") as c: - resp = await c.get("/v1/models") - self.assertEqual(resp.status_code, 200) - data = resp.json() - self.assertEqual(data["object"], "list") - self.assertEqual(len(data["data"]), 1) - - self._run(_test()) + resp = asyncio.get_event_loop().run_until_complete(self._request("GET", "/v1/models")) + self.assertEqual(resp.status_code, 200) + data = resp.json() + self.assertEqual(data["object"], "list") + self.assertEqual(len(data["data"]), 1) def test_request_id_generated(self): - from httpx import ASGITransport, AsyncClient - - async def _test(): - async with AsyncClient(transport=ASGITransport(app=self.app), base_url="http://test") as c: - resp = await c.get("/health") - self.assertIn("x-request-id", resp.headers) - self.assertEqual(len(resp.headers["x-request-id"]), 36) # UUID length - - self._run(_test()) + resp = asyncio.get_event_loop().run_until_complete(self._request("GET", "/health")) + self.assertIn("x-request-id", resp.headers) + self.assertEqual(len(resp.headers["x-request-id"]), 36) # UUID length def test_request_id_passthrough(self): - from httpx import ASGITransport, AsyncClient - - async def _test(): - async with AsyncClient(transport=ASGITransport(app=self.app), base_url="http://test") as c: - resp = await c.get("/health", headers={"x-request-id": "my-id"}) - self.assertEqual(resp.headers["x-request-id"], "my-id") - - self._run(_test()) - - -# --------------------------------------------------------------------------- -# 7. 
Integration tests (need GPU + model) -# Only test what requires a real model. Everything else is above with mocks. -# --------------------------------------------------------------------------- + resp = asyncio.get_event_loop().run_until_complete( + self._request("GET", "/health", headers={"x-request-id": "my-id"}) + ) + self.assertEqual(resp.headers["x-request-id"], "my-id") -@unittest.skipUnless(run_slow, "Set RUN_SLOW=1 to run integration tests") -@require_openai +@slow +@require_serve class TestChatCompletion(unittest.TestCase): """Integration tests for /v1/chat/completions with a real model.""" MODEL = "Qwen/Qwen2.5-0.5B-Instruct" - PORT = 8877 @classmethod def setUpClass(cls): - from transformers.cli.serve_refactored import Serve - - cls.serve = Serve(port=cls.PORT, non_blocking=True) - import requests - - for _ in range(30): - try: - if requests.get(f"http://localhost:{cls.PORT}/health", timeout=1).status_code == 200: - break - except Exception: - continue - time.sleep(2) - - cls.client = OpenAI(base_url=f"http://localhost:{cls.PORT}/v1", api_key="unused") + cls.serve, port = _start_serve() + cls.base_url = f"http://localhost:{port}" + cls.client = OpenAI(base_url=f"{cls.base_url}/v1", api_key="unused") @classmethod def tearDownClass(cls): @@ -896,7 +788,7 @@ def test_concurrent_non_streaming(self): results = [None, None] def request_in_thread(index): - client = OpenAI(base_url=f"http://localhost:{self.PORT}/v1", api_key="unused") + client = OpenAI(base_url=f"{self.base_url}/v1", api_key="unused") results[index] = client.chat.completions.create(model=self.MODEL, messages=prompts[index]) with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool: @@ -921,7 +813,7 @@ def test_concurrent_streaming(self): results = [None, None] def stream_in_thread(index): - client = OpenAI(base_url=f"http://localhost:{self.PORT}/v1", api_key="unused") + client = OpenAI(base_url=f"{self.base_url}/v1", api_key="unused") text = "" for chunk in client.chat.completions.create(model=self.MODEL, messages=prompts[index], stream=True): if chunk.choices[0].delta.content: @@ -940,17 +832,16 @@ def stream_in_thread(index): def test_request_cancellation(self): """Closing a stream early doesn't crash and the server stays healthy.""" - import requests as req - with req.post( - f"http://localhost:{self.PORT}/v1/chat/completions", + with httpx.stream( + "POST", + f"{self.base_url}/v1/chat/completions", json={ "model": self.MODEL, "stream": True, "messages": [{"role": "user", "content": "Count slowly so I can cancel you."}], "max_tokens": 500, }, - stream=True, timeout=30, ) as resp: self.assertEqual(resp.status_code, 200) @@ -969,17 +860,9 @@ def test_request_cancellation(self): self.assertIsNotNone(resp.choices[0].message.content) -# --------------------------------------------------------------------------- -# 8. 
Unit tests — Response handler -# --------------------------------------------------------------------------- - - -@require_openai +@require_serve class TestResponseInputConversion(unittest.TestCase): def _make_handler(self): - from transformers.cli.serving.response import ResponseHandler - from transformers.cli.serving.utils import GenerationState - return ResponseHandler(model_manager=MagicMock(), generation_state=GenerationState()) def test_string_input(self): @@ -1023,21 +906,16 @@ def test_dict_input(self): self.assertEqual(msgs, [{"role": "user", "content": "Test"}]) -@require_openai +@require_serve class TestResponseValidation(unittest.TestCase): def _make_handler(self): - from transformers.cli.serving.response import ResponseHandler - from transformers.cli.serving.utils import GenerationState - return ResponseHandler(model_manager=MagicMock(), generation_state=GenerationState()) - def test_unsupported_fields_rejected(self): - from fastapi import HTTPException - + def test_unsupported_fields_warns(self): handler = self._make_handler() - with self.assertRaises(HTTPException) as ctx: + with self.assertLogs("transformers", level="WARNING") as cm: handler._validate_request({"model": "x", "input": "hi", "previous_response_id": "abc"}) - self.assertEqual(ctx.exception.status_code, 422) + self.assertTrue(any("previous_response_id" in msg for msg in cm.output)) def test_valid_request_passes(self): handler = self._make_handler() @@ -1045,12 +923,9 @@ def test_valid_request_passes(self): handler._validate_request({"model": "x", "input": "hi"}) -@require_openai +@require_serve class TestResponseGenerationConfig(unittest.TestCase): def _make_handler(self): - from transformers.cli.serving.response import ResponseHandler - from transformers.cli.serving.utils import GenerationState - return ResponseHandler(model_manager=MagicMock(), generation_state=GenerationState()) def test_max_output_tokens(self): @@ -1066,11 +941,9 @@ def test_default_bumps_short_max_new_tokens(self): self.assertEqual(result.max_new_tokens, 1024) -@require_openai +@require_serve class TestResponseUsage(unittest.TestCase): def testcompute_usage(self): - from transformers.cli.serving.response import compute_usage - usage = compute_usage(input_tokens=100, output_tokens=50) self.assertEqual(usage.input_tokens, 100) self.assertEqual(usage.output_tokens, 50) @@ -1080,9 +953,6 @@ def testcompute_usage(self): def test_usage_in_completed_response(self): """Usage should serialize correctly inside a Response.""" - from openai.types.responses import Response - - from transformers.cli.serving.response import compute_usage usage = compute_usage(10, 5) response = Response( @@ -1103,13 +973,9 @@ def test_usage_in_completed_response(self): self.assertEqual(dumped["usage"]["total_tokens"], 15) -@require_openai +@require_serve class TestResponseSSEFormat(unittest.TestCase): def test_sse_format(self): - from openai.types.responses import Response, ResponseCreatedEvent - - from transformers.cli.serving.utils import BaseHandler - event = ResponseCreatedEvent( type="response.created", sequence_number=0, @@ -1134,35 +1000,18 @@ def test_sse_format(self): self.assertEqual(parsed["response"]["status"], "queued") -# --------------------------------------------------------------------------- -# 9. 
Integration tests — Responses API (need GPU + model) -# --------------------------------------------------------------------------- - - -@unittest.skipUnless(run_slow, "Set RUN_SLOW=1 to run integration tests") -@require_openai +@slow +@require_serve class TestResponsesIntegration(unittest.TestCase): """Integration tests for /v1/responses with a real model.""" MODEL = "Qwen/Qwen2.5-0.5B-Instruct" - PORT = 8878 @classmethod def setUpClass(cls): - from transformers.cli.serve_refactored import Serve - - cls.serve = Serve(port=cls.PORT, non_blocking=True) - import requests - - for _ in range(30): - try: - if requests.get(f"http://localhost:{cls.PORT}/health", timeout=1).status_code == 200: - break - except Exception: - continue - time.sleep(2) - - cls.client = OpenAI(base_url=f"http://localhost:{cls.PORT}/v1", api_key="unused") + cls.serve, port = _start_serve() + cls.base_url = f"http://localhost:{port}" + cls.client = OpenAI(base_url=f"{cls.base_url}/v1", api_key="unused") @classmethod def tearDownClass(cls): @@ -1346,7 +1195,7 @@ def test_concurrent_non_streaming(self): results = [None, None] def request_in_thread(index): - client = OpenAI(base_url=f"http://localhost:{self.PORT}/v1", api_key="unused") + client = OpenAI(base_url=f"{self.base_url}/v1", api_key="unused") results[index] = client.responses.create(model=self.MODEL, input=inputs[index], stream=False) with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool: @@ -1368,7 +1217,7 @@ def test_concurrent_streaming(self): results = [None, None] def stream_in_thread(index): - client = OpenAI(base_url=f"http://localhost:{self.PORT}/v1", api_key="unused") + client = OpenAI(base_url=f"{self.base_url}/v1", api_key="unused") results[index] = list(client.responses.create(model=self.MODEL, input=inputs[index], stream=True)) with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool: @@ -1384,44 +1233,27 @@ def stream_in_thread(index): self.assertIn("response.completed", types, f"Request {i} missing completed event") -# --------------------------------------------------------------------------- -# 10. 
Integration tests — /load_model endpoint (need GPU + model) -# --------------------------------------------------------------------------- - - def _parse_sse_events(response): - """Parse SSE lines from a streaming requests response into a list of dicts.""" + """Parse SSE lines from a streaming httpx response into a list of dicts.""" events = [] - for line in response.iter_lines(decode_unicode=True): + for line in response.iter_lines(): if not line or not line.startswith("data: "): continue events.append(json.loads(line[len("data: ") :])) return events -@unittest.skipUnless(run_slow, "Set RUN_SLOW=1 to run integration tests") -@require_openai +@slow +@require_serve class TestLoadModel(unittest.TestCase): """Integration tests for POST /load_model SSE endpoint.""" MODEL = "Qwen/Qwen2.5-0.5B-Instruct" - PORT = 8879 @classmethod def setUpClass(cls): - import requests as req - - from transformers.cli.serve_refactored import Serve - - cls.serve = Serve(port=cls.PORT, non_blocking=True) - for _ in range(30): - try: - if req.get(f"http://localhost:{cls.PORT}/health", timeout=1).status_code == 200: - break - except Exception: - continue - time.sleep(2) - cls.base_url = f"http://localhost:{cls.PORT}" + cls.serve, port = _start_serve() + cls.base_url = f"http://localhost:{port}" @classmethod def tearDownClass(cls): @@ -1432,10 +1264,8 @@ def setUp(self): self.serve.reset_loaded_models() def _load_model(self, model: str): - import requests as req - - resp = req.post(f"{self.base_url}/load_model", json={"model": model}, stream=True, timeout=120) - events = _parse_sse_events(resp) + with httpx.stream("POST", f"{self.base_url}/load_model", json={"model": model}, timeout=120) as resp: + events = _parse_sse_events(resp) return resp, events def test_load_model_fresh(self): @@ -1479,9 +1309,8 @@ def test_load_model_error(self): def test_load_model_missing_field(self): """POST /load_model with no model field returns 422.""" - import requests as req - response = req.post(f"{self.base_url}/load_model", json={}, timeout=30) + response = httpx.post(f"{self.base_url}/load_model", json={}, timeout=30) self.assertEqual(response.status_code, 422) def test_load_model_event_schema(self): @@ -1527,11 +1356,9 @@ def test_concurrent_load_same_model(self): results = [None, None] def load_in_thread(index): - import requests as req - - resp = req.post(f"{self.base_url}/load_model", json={"model": self.MODEL}, stream=True, timeout=120) - events = _parse_sse_events(resp) - results[index] = (resp.status_code, events) + with httpx.stream("POST", f"{self.base_url}/load_model", json={"model": self.MODEL}, timeout=120) as resp: + events = _parse_sse_events(resp) + results[index] = (resp.status_code, events) with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool: futures = [pool.submit(load_in_thread, i) for i in range(2)] @@ -1595,7 +1422,7 @@ def test_load_model_usable_after_load(self): """After /load_model completes, the model should be usable for inference.""" self._load_model(self.MODEL) - client = OpenAI(base_url=f"http://localhost:{self.PORT}/v1", api_key="unused") + client = OpenAI(base_url=f"{self.base_url}/v1", api_key="unused") resp = client.chat.completions.create( model=self.MODEL, messages=[{"role": "user", "content": "Say hi"}], @@ -1622,7 +1449,7 @@ def test_concurrent_non_streaming(self): results = [None, None] def request_in_thread(index): - client = OpenAI(base_url=f"http://localhost:{self.PORT}/v1", api_key="unused") + client = OpenAI(base_url=f"{self.base_url}/v1", api_key="unused") results[index] = 
client.responses.create(model=self.MODEL, input=inputs[index], stream=False) with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool: @@ -1644,7 +1471,7 @@ def test_concurrent_streaming(self): results = [None, None] def stream_in_thread(index): - client = OpenAI(base_url=f"http://localhost:{self.PORT}/v1", api_key="unused") + client = OpenAI(base_url=f"{self.base_url}/v1", api_key="unused") events = list(client.responses.create(model=self.MODEL, input=inputs[index], stream=True)) results[index] = events @@ -1661,60 +1488,23 @@ def stream_in_thread(index): self.assertIn("response.completed", types, f"Request {i} missing completed event") -# --------------------------------------------------------------------------- -# 11. Integration tests — Transcription API (need GPU + model + librosa) -# --------------------------------------------------------------------------- - - -def _make_test_wav(duration: float = 2.0, sample_rate: int = 16000) -> bytes: - """Create a simple WAV file with a sine wave. Returns raw bytes.""" - import wave - - import numpy as np - - t = np.linspace(0, duration, int(sample_rate * duration), endpoint=False) - audio = (np.sin(2 * np.pi * 440 * t) * 32767).astype(np.int16) - buf = io.BytesIO() - with wave.open(buf, "w") as wf: - wf.setnchannels(1) - wf.setsampwidth(2) - wf.setframerate(sample_rate) - wf.writeframes(audio.tobytes()) - return buf.getvalue() - - -# --------------------------------------------------------------------------- -# 12. Integration tests — VLM support (need GPU + model) -# --------------------------------------------------------------------------- - - # Real image URL for VLM tests (person + dog on a beach) _DOG_IMAGE_URL = "https://qianwen-res.oss-accelerate-overseas.aliyuncs.com/Qwen2-VL/demo_small.jpg" -@unittest.skipUnless(run_slow and is_vision_available(), "Set RUN_SLOW=1 and install torchvision + PIL") -@require_openai +@slow +@require_vision +@require_serve class TestVLM(unittest.TestCase): """Integration tests for VLM (vision-language model) support. 
Requires torchvision.""" MODEL = "HuggingFaceTB/SmolVLM-256M-Instruct" - PORT = 8881 @classmethod def setUpClass(cls): - import requests as req - - from transformers.cli.serve_refactored import Serve - - cls.serve = Serve(port=cls.PORT, non_blocking=True) - for _ in range(60): - try: - if req.get(f"http://localhost:{cls.PORT}/health", timeout=1).status_code == 200: - break - except Exception: - continue - time.sleep(2) - cls.client = OpenAI(base_url=f"http://localhost:{cls.PORT}/v1", api_key="unused") + cls.serve, port = _start_serve() + cls.base_url = f"http://localhost:{port}" + cls.client = OpenAI(base_url=f"{cls.base_url}/v1", api_key="unused") @classmethod def tearDownClass(cls): @@ -1766,29 +1556,17 @@ def test_responses_with_image(self): ) -@unittest.skipUnless(run_slow, "Set RUN_SLOW=1 to run integration tests") -@require_openai +@slow +@require_serve class TestTranscription(unittest.TestCase): """Integration tests for POST /v1/audio/transcriptions with whisper-tiny.""" MODEL = "openai/whisper-tiny" - PORT = 8880 @classmethod def setUpClass(cls): - import requests as req - - from transformers.cli.serve_refactored import Serve - - cls.serve = Serve(port=cls.PORT, non_blocking=True) - for _ in range(30): - try: - if req.get(f"http://localhost:{cls.PORT}/health", timeout=1).status_code == 200: - break - except Exception: - continue - time.sleep(2) - cls.base_url = f"http://localhost:{cls.PORT}" + cls.serve, port = _start_serve() + cls.base_url = f"http://localhost:{port}" @classmethod def tearDownClass(cls): @@ -1807,10 +1585,9 @@ def _get_audio_bytes(cls): def test_transcription_returns_text(self): """POST /v1/audio/transcriptions with real speech returns meaningful transcription.""" - import requests as req audio_bytes = self._get_audio_bytes() - resp = req.post( + resp = httpx.post( f"{self.base_url}/v1/audio/transcriptions", files={"file": ("mlk.flac", audio_bytes, "audio/flac")}, data={"model": self.MODEL}, @@ -1826,7 +1603,7 @@ def test_transcription_returns_text(self): def test_transcription_openai_client(self): """Transcription should work via the OpenAI Python client.""" audio_bytes = self._get_audio_bytes() - client = OpenAI(base_url=f"http://localhost:{self.PORT}/v1", api_key="unused") + client = OpenAI(base_url=f"{self.base_url}/v1", api_key="unused") result = client.audio.transcriptions.create( model=self.MODEL, file=("mlk.flac", audio_bytes), @@ -1836,22 +1613,21 @@ def test_transcription_openai_client(self): def test_transcription_streaming(self): """Streaming transcription should yield text chunks via SSE.""" - import requests as req audio_bytes = self._get_audio_bytes() - resp = req.post( + with httpx.stream( + "POST", f"{self.base_url}/v1/audio/transcriptions", files={"file": ("mlk.flac", audio_bytes, "audio/flac")}, data={"model": self.MODEL, "stream": "true"}, - stream=True, timeout=120, - ) - self.assertEqual(resp.status_code, 200) + ) as resp: + self.assertEqual(resp.status_code, 200) - chunks = [] - for line in resp.iter_lines(decode_unicode=True): - if line and line.startswith("data: "): - chunks.append(line[len("data: ") :]) + chunks = [] + for line in resp.iter_lines(): + if line and line.startswith("data: "): + chunks.append(line[len("data: ") :]) self.assertGreater(len(chunks), 0, "No streaming chunks received") full_text = "".join(chunks) @@ -1859,9 +1635,8 @@ def test_transcription_streaming(self): def test_transcription_missing_file(self): """POST without a file should fail.""" - import requests as req - resp = req.post( + resp = httpx.post( 
f"{self.base_url}/v1/audio/transcriptions", data={"model": self.MODEL}, timeout=30, @@ -1869,43 +1644,25 @@ def test_transcription_missing_file(self): self.assertNotEqual(resp.status_code, 200) -# --------------------------------------------------------------------------- -# Continuous Batching integration tests -# --------------------------------------------------------------------------- - - -@unittest.skipUnless(run_slow, "Set RUN_SLOW=1 to run integration tests") -@require_openai +@slow +@require_serve +@require_torch_accelerator class TestContinuousBatchingChatCompletion(unittest.TestCase): """Integration tests for /v1/chat/completions with continuous batching.""" MODEL = "Qwen/Qwen2.5-0.5B-Instruct" - PORT = 8891 @classmethod def setUpClass(cls): - from transformers.cli.serve_refactored import Serve - - cls.serve = Serve( + cls.serve, port = _start_serve( force_model=cls.MODEL, - port=cls.PORT, device="cuda:0", continuous_batching=True, attn_implementation="sdpa", default_seed=42, - non_blocking=True, ) - import requests - - for _ in range(30): - try: - if requests.get(f"http://localhost:{cls.PORT}/health", timeout=1).status_code == 200: - break - except Exception: - continue - time.sleep(2) - - cls.client = OpenAI(base_url=f"http://localhost:{cls.PORT}/v1", api_key="unused") + cls.base_url = f"http://localhost:{port}" + cls.client = OpenAI(base_url=f"{cls.base_url}/v1", api_key="unused") @classmethod def tearDownClass(cls): @@ -1952,20 +1709,19 @@ def test_multi_turn(self): def test_request_cancellation(self): """Opening a stream and closing it early triggers CB cancellation.""" - import requests as req request_id = "test-cb-cancel" # Open a streaming request and close after a few chunks - with req.post( - f"http://localhost:{self.PORT}/v1/chat/completions", + with httpx.stream( + "POST", + f"{self.base_url}/v1/chat/completions", headers={"X-Request-ID": request_id}, json={ "model": self.MODEL, "stream": True, "messages": [{"role": "user", "content": "Count slowly so I can cancel you."}], }, - stream=True, timeout=30, ) as resp: self.assertEqual(resp.status_code, 200) @@ -1997,38 +1753,25 @@ def test_request_cancellation(self): self.assertIsNotNone(resp.choices[0].message.content) -@unittest.skipUnless(run_slow, "Set RUN_SLOW=1 to run integration tests") -@require_openai +@slow +@require_serve +@require_torch_accelerator class TestContinuousBatchingResponses(unittest.TestCase): """Integration tests for /v1/responses with continuous batching.""" MODEL = "Qwen/Qwen2.5-0.5B-Instruct" - PORT = 8893 @classmethod def setUpClass(cls): - from transformers.cli.serve_refactored import Serve - - cls.serve = Serve( + cls.serve, port = _start_serve( force_model=cls.MODEL, - port=cls.PORT, device="cuda:0", continuous_batching=True, attn_implementation="sdpa", default_seed=42, - non_blocking=True, ) - import requests - - for _ in range(30): - try: - if requests.get(f"http://localhost:{cls.PORT}/health", timeout=1).status_code == 200: - break - except Exception: - continue - time.sleep(2) - - cls.client = OpenAI(base_url=f"http://localhost:{cls.PORT}/v1", api_key="unused") + cls.base_url = f"http://localhost:{port}" + cls.client = OpenAI(base_url=f"{cls.base_url}/v1", api_key="unused") @classmethod def tearDownClass(cls): @@ -2076,12 +1819,12 @@ def test_multi_turn(self): def test_request_cancellation(self): """Opening a stream and closing it early triggers CB cancellation.""" - import requests as req request_id = "test-cb-resp-cancel" - with req.post( - 
f"http://localhost:{self.PORT}/v1/responses", + with httpx.stream( + "POST", + f"{self.base_url}/v1/responses", headers={"X-Request-ID": request_id}, json={ "model": self.MODEL, @@ -2089,13 +1832,12 @@ def test_request_cancellation(self): "input": "Count slowly so I can cancel you.", "max_output_tokens": 500, }, - stream=True, timeout=30, ) as resp: self.assertEqual(resp.status_code, 200) # Read enough data to ensure CB generation has started, then close. received = b"" - for chunk in resp.iter_content(chunk_size=512): + for chunk in resp.iter_bytes(chunk_size=512): received += chunk if b"output_text.delta" in received: break From 80b5c780bc81c68d59be714fa07125f98026c075 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Mon, 30 Mar 2026 17:54:31 +0000 Subject: [PATCH 53/64] update test --- tests/cli/test_serve.py | 2398 +++++++++++++++++++--------- tests/cli/test_serve_refactored.py | 1865 --------------------- 2 files changed, 1603 insertions(+), 2660 deletions(-) delete mode 100644 tests/cli/test_serve_refactored.py diff --git a/tests/cli/test_serve.py b/tests/cli/test_serve.py index 39d8cde39b50..63503cbe0bf9 100644 --- a/tests/cli/test_serve.py +++ b/tests/cli/test_serve.py @@ -11,61 +11,73 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +""" +Tests for the serving layer. + +""" + +import asyncio import json import os -import tempfile +import socket import time import unittest -from threading import Thread -from unittest.mock import Mock, patch +from unittest.mock import MagicMock import httpx -from huggingface_hub import ChatCompletionStreamOutput, InferenceClient, hf_hub_download -from parameterized import parameterized - -from transformers import GenerationConfig -from transformers.cli.serve import Modality, Serve -from transformers.testing_utils import require_openai, slow -from transformers.utils.import_utils import ( - is_fastapi_available, - is_openai_available, - is_pydantic_available, - is_uvicorn_available, + +from transformers.cli.serve import Serve +from transformers.cli.serving.chat_completion import ChatCompletionHandler +from transformers.cli.serving.model_manager import ModelManager, TimedModel +from transformers.cli.serving.response import ResponseHandler, compute_usage +from transformers.cli.serving.server import build_server +from transformers.cli.serving.transcription import TranscriptionHandler +from transformers.cli.serving.utils import ( + BaseHandler, + GenerationState, + Modality, + ToolCallParser, + detect_tool_format, ) +from transformers.testing_utils import require_serve, require_torch_accelerator, require_vision, slow +from transformers.utils.import_utils import is_serve_available -serve_dependencies_available = ( - is_pydantic_available() and is_fastapi_available() and is_uvicorn_available() and is_openai_available() -) +if is_serve_available(): + from fastapi import HTTPException + from openai import OpenAI + from openai.types.responses import Response, ResponseCreatedEvent + + +def _find_free_port() -> int: + """Return a free TCP port on localhost.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("localhost", 0)) + return s.getsockname()[1] + + +def _start_serve(**kwargs) -> tuple["Serve", int]: + """Start a non-blocking Serve instance on a free port and wait until healthy. 
-if serve_dependencies_available: - from openai import APIConnectionError, OpenAI - from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall, ChoiceDeltaToolCallFunction - from openai.types.responses import ( - Response, - ResponseCompletedEvent, - ResponseContentPartAddedEvent, - ResponseContentPartDoneEvent, - ResponseCreatedEvent, - ResponseInProgressEvent, - ResponseOutputItemAddedEvent, - ResponseOutputItemDoneEvent, - ResponseTextDeltaEvent, - ResponseTextDoneEvent, - ) - - -@require_openai -def test_help(cli): - """Minimal test: we can invoke the help command.""" - output = cli("serve", "--help") - assert output.exit_code == 0 - assert "serve" in output.output - - -@require_openai + Returns ``(serve, port)``. + """ + port = _find_free_port() + serve = Serve(port=port, non_blocking=True, **kwargs) + for _ in range(30): + try: + if httpx.get(f"http://localhost:{port}/health", timeout=2).status_code == 200: + return serve, port + except Exception: # noqa: S110 + pass + time.sleep(1) + raise RuntimeError(f"Server on port {port} did not become healthy in time") + + +@require_serve def test_host_port_blocking(cli): - """Minimal test: we can set arguments through the CLI - blocking""" + """CLI args --host and --port are passed to uvicorn.Config, and server.run() is called.""" + from unittest.mock import Mock, patch + with ( patch("uvicorn.Config") as ConfigMock, patch("uvicorn.Server") as ServerMock, @@ -73,880 +85,1217 @@ def test_host_port_blocking(cli): server_instance = Mock() ServerMock.return_value = server_instance - # Call the serve CLI with host/port out = cli("serve", "--host", "0.0.0.0", "--port", "9000") _, kwargs = ConfigMock.call_args assert out.exit_code == 0 assert kwargs["host"] == "0.0.0.0" assert kwargs["port"] == 9000 - ServerMock.assert_called_once_with(ConfigMock.return_value) server_instance.run.assert_called_once() -@require_openai -def test_host_port_non_blocking(cli, caplog): - """Minimal test: we can set arguments through the CLI - non-blocking""" - caplog.set_level(100000) - # ^ hack to avoid an issue happening only in CI. We don't check logs anyway so it's fine. 
- # Source: https://github.com/pallets/click/issues/824#issuecomment-562581313 +class TestProcessorInputsFromMessages(unittest.TestCase): + def test_llm_string_content(self): + get_processor_inputs_from_messages = BaseHandler.get_processor_inputs_from_messages - with ( - patch("uvicorn.Config") as ConfigMock, - patch("uvicorn.Server") as ServerMock, - patch.object(Serve, "start_server") as start_mock, - ): - server_instance = Mock() - ServerMock.return_value = server_instance + messages = [{"role": "user", "content": "Hello"}] + result = get_processor_inputs_from_messages(messages, Modality.LLM) + self.assertEqual(result, [{"role": "user", "content": "Hello"}]) - out = cli("serve", "--host", "0.5.0.0", "--port", "9002", "--non-blocking") - assert out.exit_code == 0 + def test_llm_list_content_text_only(self): + get_processor_inputs_from_messages = BaseHandler.get_processor_inputs_from_messages - # Config got the CLI args - _, kwargs = ConfigMock.call_args - assert kwargs["host"] == "0.5.0.0" - assert kwargs["port"] == 9002 + messages = [{"role": "user", "content": [{"type": "text", "text": "A"}, {"type": "text", "text": "B"}]}] + result = get_processor_inputs_from_messages(messages, Modality.LLM) + self.assertEqual(result, [{"role": "user", "content": "A B"}]) - # Non-blocking path uses start_server(), not server.run() - start_mock.assert_called_once() - server_instance.run.assert_not_called() + def test_vlm_string_content_wrapped(self): + get_processor_inputs_from_messages = BaseHandler.get_processor_inputs_from_messages + messages = [{"role": "user", "content": "Hello"}] + result = get_processor_inputs_from_messages(messages, Modality.VLM) + self.assertEqual(result, [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}]) -@require_openai -def test_build_chat_completion_chunk(): - """ - Tests that the chunks are correctly built for the Chat Completion API. The `choices` checks implicitly - confirm that empty fields are not emitted. 
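The three tests above pin down the basic normalization rules per modality; the same expectations written as a standalone snippet (import path taken from this patch, so this restates the asserted values rather than adding new behavior):

from transformers.cli.serving.utils import BaseHandler, Modality

convert = BaseHandler.get_processor_inputs_from_messages
messages = [{"role": "user", "content": "Hello"}]

# LLM modality: plain string content passes through unchanged.
assert convert(messages, Modality.LLM) == [{"role": "user", "content": "Hello"}]
# VLM modality: string content is wrapped into a typed content list.
assert convert(messages, Modality.VLM) == [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}]
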
- """ - dummy = Serve.__new__(Serve) - - # The keys for these fields must be present in every chunk - MANDATORY_FIELDS = ["data", "id", "choices", "created", "model", "object", "system_fingerprint"] - - # Case 1: most fields are provided - chunk = dummy.build_chat_completion_chunk( - request_id="req0", content="hello", finish_reason="stop", role="user", model="dummy_model@main" - ) - chunk = dummy.chunk_to_sse_element(chunk) - for field in MANDATORY_FIELDS: - assert field in chunk - assert '"choices":[{"delta":{"content":"hello","role":"user"},"finish_reason":"stop","index":0}]' in chunk - - # Case 2: only the role is provided -- other fields in 'choices' are omitted - chunk = dummy.build_chat_completion_chunk(request_id="req0", role="user", model="dummy_model@main") - chunk = dummy.chunk_to_sse_element(chunk) - for field in MANDATORY_FIELDS: - assert field in chunk - assert '"choices":[{"delta":{"role":"user"},"index":0}]' in chunk - - # Case 3: only the content is provided -- other fields in 'choices' are omitted - chunk = dummy.build_chat_completion_chunk(request_id="req0", content="hello", model="dummy_model@main") - chunk = dummy.chunk_to_sse_element(chunk) - for field in MANDATORY_FIELDS: - assert field in chunk - assert '"choices":[{"delta":{"content":"hello"},"index":0}]' in chunk - - # Case 4: tool calls support a list of ChoiceDeltaToolCall objects - tool_call = ChoiceDeltaToolCall( - index=0, - function=ChoiceDeltaToolCallFunction(name="foo_bar", arguments='{"foo1": "bar1", "foo2": "bar2"}'), - type="function", - ) - chunk = dummy.build_chat_completion_chunk(request_id="req0", tool_calls=[tool_call], model="dummy_model@main") - chunk = dummy.chunk_to_sse_element(chunk) - for field in MANDATORY_FIELDS: - assert field in chunk - expected_choices_content = ( - 'choices":[{"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\\"foo1\\": \\"bar1\\", ' - '\\"foo2\\": \\"bar2\\"}","name":"foo_bar"},"type":"function"}]},"index":0}]' - ) - assert expected_choices_content in chunk - - -def test_generative_model_list(): - with tempfile.TemporaryDirectory() as cache_dir: - # "download" a few models, including some non-generative models - hf_hub_download("Menlo/Jan-nano", "config.json", cache_dir=cache_dir) - hf_hub_download("Menlo/Jan-nano-128k", "config.json", cache_dir=cache_dir) - hf_hub_download("Qwen/Qwen2.5-0.5B-Instruct", "config.json", cache_dir=cache_dir) - hf_hub_download("HuggingFaceTB/SmolVLM-Instruct", "config.json", cache_dir=cache_dir) - hf_hub_download("google-bert/bert-base-cased", "config.json", cache_dir=cache_dir) - - expected_results = { - "HuggingFaceTB/SmolVLM-Instruct": ["HuggingFaceTB", "SmolVLM-Instruct"], - "Qwen/Qwen2.5-0.5B-Instruct": ["Qwen", "Qwen2.5-0.5B-Instruct"], - "Menlo/Jan-nano": ["Menlo", "Jan-nano"], - "Menlo/Jan-nano-128k": ["Menlo", "Jan-nano-128k"], - } + def test_vlm_text_and_image_url(self): + get_processor_inputs_from_messages = BaseHandler.get_processor_inputs_from_messages - # list models - result = Serve.get_gen_models(cache_dir) - assert len(expected_results) == len(result) + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "What is this?"}, + {"type": "image_url", "image_url": {"url": "https://example.com/img.png"}}, + ], + } + ] + result = get_processor_inputs_from_messages(messages, Modality.VLM) + self.assertEqual(len(result[0]["content"]), 2) + self.assertEqual(result[0]["content"][0]["type"], "text") + self.assertEqual(result[0]["content"][1], {"type": "image", "url": "https://example.com/img.png"}) - 
local_repos = {repo["id"]: repo["owned_by"] for repo in result} + def test_llm_multi_turn_conversation(self): + """Multi-turn conversation with string content should pass through as-is.""" - for key, value in expected_results.items(): - assert key in local_repos - assert local_repos[key] == value + get_processor_inputs_from_messages = BaseHandler.get_processor_inputs_from_messages + messages = [ + {"role": "user", "content": "How are you?"}, + {"role": "assistant", "content": "I'm great!"}, + {"role": "user", "content": "Help me write tests?"}, + ] + result = get_processor_inputs_from_messages(messages, Modality.LLM) + self.assertEqual(len(result), 3) + self.assertEqual(result[0]["content"], "How are you?") + self.assertEqual(result[1]["role"], "assistant") + self.assertEqual(result[2]["content"], "Help me write tests?") -@require_openai -def test_build_response_event(): - """ - Tests that the events are correctly built for the Response API. + def test_llm_list_content_with_type(self): + """LLM messages with typed content list should extract text and join.""" - Contrarily to the Chat Completion API, the Response API has a wide set of possible output objects. This test - only checks a few basic assumptions -- we rely on OpenAI's pydantic models to enforce the correct schema. - """ - dummy = Serve.__new__(Serve) - - response_created = ResponseCreatedEvent( - type="response.created", - sequence_number=0, - response=Response( - id="resp_0", - created_at=time.time(), - status="queued", - model="dummy_model@main", - instructions=None, # <--- is set to None = should NOT be in the output. - text={"format": {"type": "text"}}, - object="response", - tools=[], # <--- empty lists should be in the output (they are often mandatory fields) - output=[], - parallel_tool_calls=False, - tool_choice="auto", - metadata=None, - ), - ) - - event = dummy.chunk_to_sse_element(response_created) - assert event.startswith("data: ") # Sanity check: event formatting - assert '"model":"dummy_model@main"' in event # Sanity check: set field - assert '"status":"queued"' in event - assert "tools" in event # empty lists should be in the output - assert "output" in event - assert "instructions" not in event # None fields should NOT be in the output - assert "metadata" not in event - assert "error" not in event # Unset optional fields should NOT be in the output - assert "top_p" not in event - - -def retry(fn, max_attempts=5, delay=2): - """ - Retry a function up to `max_attempts` times with a `delay` between attempts. - Useful for testing functions that may fail due to server not being ready. 
- """ + get_processor_inputs_from_messages = BaseHandler.get_processor_inputs_from_messages - def wrapper(*args, **kwargs): - nb_attempts = 0 - while True: - nb_attempts += 1 - try: - return fn(*args, **kwargs) - except (httpx.HTTPError, APIConnectionError): - if nb_attempts >= max_attempts: - raise - time.sleep(delay) + messages = [ + {"role": "user", "content": [{"type": "text", "text": "Hello"}, {"type": "text", "text": "world"}]} + ] + result = get_processor_inputs_from_messages(messages, Modality.LLM) + self.assertEqual(result[0]["content"], "Hello world") - return wrapper + @require_vision + def test_vlm_base64_image_creates_temp_file(self): + """Base64 image URLs should be decoded and saved to a temp file.""" + get_processor_inputs_from_messages = BaseHandler.get_processor_inputs_from_messages -class ServeCompletionsMixin: - """ - Mixin class for the Completions API tests, to seamlessly replicate tests across the two versions of the API - (`generate` and `continuous_batching`). - """ + # Minimal valid 1x1 PNG as base64 + base64_url = ( + "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR4" + "2mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==" + ) + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "What is this?"}, + {"type": "image_url", "image_url": {"url": base64_url}}, + ], + } + ] + result = get_processor_inputs_from_messages(messages, Modality.VLM) + image_item = result[0]["content"][1] + self.assertEqual(image_item["type"], "image") + self.assertTrue(os.path.exists(image_item["url"])) # temp file was created - @retry - def run_server(self, request): - with InferenceClient(f"http://localhost:{self.port}") as client: - return list(client.chat_completion(**request)) - - @parameterized.expand( - [ - ("default_request", {}), - ("one_token", {"max_tokens": 1}), - ("different_model", {"model": "HuggingFaceTB/SmolLM2-135M-Instruct"}), - ( - "tool_call", - { - "tools": [ - { - "function": { - "name": "foo_bar", - "parameters": {"type": "object"}, - "description": "Foo bar", - }, - "type": "function", - } - ] - }, - ), + def test_vlm_multi_turn(self): + """VLM multi-turn: string content should be wrapped in text type.""" + + get_processor_inputs_from_messages = BaseHandler.get_processor_inputs_from_messages + + messages = [ + {"role": "user", "content": "Describe the image"}, + {"role": "assistant", "content": "It shows a cat"}, + {"role": "user", "content": "What color?"}, ] - ) - def test_requests(self, test_name: str, request_flags: dict): - """Tests that the completions app gracefully handles GOOD requests, producing the expected output payloads.""" - - request = { - "model": "Qwen/Qwen3-0.6B", - "messages": [{"role": "user", "content": "Hello, how are you?"}], - "stream": True, # We don't support "stream": False yet - "max_tokens": 5, # Small generation by default - } - request.update(request_flags) - all_payloads = self.run_server(request) - - # If a request is successful, the returned payload needs to follow the schema, which we test here. - # NOTE: the output of our server is wrapped by `InferenceClient`, which sends fields even when they - # are empty. 
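The base64 test above assumes the handler materializes `data:` URLs to disk before handing them to the processor. A standalone sketch of that decode step using only the standard library; the helper name is hypothetical and is not the handler's actual code:

import base64
import tempfile


def data_url_to_temp_file(data_url: str) -> str:
    """Decode a data:image/...;base64,... URL into a temp file and return its path."""
    header, _, payload = data_url.partition(",")
    if not header.startswith("data:") or ";base64" not in header:
        raise ValueError("not a base64 data URL")
    # "data:image/png;base64" -> ".png"
    suffix = "." + header.removeprefix("data:").split(";")[0].split("/")[-1]
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as f:
        f.write(base64.b64decode(payload))
        return f.name
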
- - # Finish reason: the last payload should have a finish reason of "length" or "stop", all others should be empty - finish_reasons = [payload.choices[0].finish_reason for payload in all_payloads] - self.assertTrue(finish_reasons[-1] in ["length", "stop"]) - self.assertTrue(all(reason is None for reason in finish_reasons[:-1])) - - # Role: the first payload should have a role of "assistant", all others should be empty - roles = [payload.choices[0].delta.role for payload in all_payloads] - self.assertEqual(roles[0], "assistant") - self.assertTrue(all(role is None for role in roles[1:])) - - # Content: the first and the last payload shouldn't have content (role and finish reason). It may be empty - # in some other payload positions, e.g. tool calls. - contents = [payload.choices[0].delta.content for payload in all_payloads] - self.assertTrue(contents[0] is None and contents[-1] is None) - self.assertTrue(any(content is not None for content in contents[1:-1])) - # TODO: add "usage" field to output and test it - - def test_generation_config_in_request(self): - """Tests that the generation config is correctly passed into the generation call.""" - generation_config = GenerationConfig(do_sample=False, temperature=0.0) - request = { - "model": "Qwen/Qwen3-0.6B", - "messages": [{"role": "user", "content": "Hello, how are you?"}], - "stream": True, - "max_tokens": 10, - "extra_body": { - "generation_config": generation_config.to_json_string(), - }, - } - all_payloads = self.run_server(request) - contents = [payload.choices[0].delta.content for payload in all_payloads] - output_text = "".join([text for text in contents if text is not None]) - # The generation config sets greedy decoding, so the output is reproducible. By default, `Qwen/Qwen3-0.6B` - # sets `do_sample=True` - self.assertEqual(output_text, '\nOkay, the user just asked, "') + result = get_processor_inputs_from_messages(messages, Modality.VLM) + self.assertEqual(len(result), 3) + for msg in result: + self.assertIsInstance(msg["content"], list) + self.assertEqual(msg["content"][0]["type"], "text") - def test_early_return_due_to_length(self): - request = { - "model": "Qwen/Qwen2.5-0.5B-Instruct", - "messages": [{"role": "user", "content": "Hello, how are you?"}], - "stream": True, - "max_tokens": 3, - } - all_payloads = self.run_server(request) - last_payload = all_payloads[-1] - self.assertTrue(last_payload.choices[0]["finish_reason"] == "length") +class TestGenerativeModelList(unittest.TestCase): + def test_lists_only_generative_models(self): + """Should list LLMs and VLMs but not non-generative models like BERT.""" + import tempfile - def test_continues_until_stop(self): - request = { - "model": "Qwen/Qwen2.5-0.5B-Instruct", - "messages": [{"role": "user", "content": 'Please only answer with "Hi."'}], - "stream": True, - "max_tokens": 30, - } + from huggingface_hub import hf_hub_download - all_payloads = self.run_server(request) - last_payload = all_payloads[-1] - self.assertTrue(last_payload.choices[0]["finish_reason"] == "stop") + with tempfile.TemporaryDirectory() as cache_dir: + # Download config.json for a few models + hf_hub_download("Qwen/Qwen2.5-0.5B-Instruct", "config.json", cache_dir=cache_dir) + hf_hub_download("google-bert/bert-base-cased", "config.json", cache_dir=cache_dir) + result = ModelManager.get_gen_models(cache_dir) + model_ids = {r["id"] for r in result} -class ServeCompletionsGenerateMockTests(unittest.TestCase): - def test_processor_inputs_from_inbound_messages_llm(self): - modality = Modality.LLM - messages = 
expected_outputs = [ - {"role": "user", "content": "How are you doing?"}, - {"role": "assistant", "content": "I'm doing great, thank you for asking! How can I assist you today?"}, - {"role": "user", "content": "Can you help me write tests?"}, - ] - outputs = Serve.get_processor_inputs_from_inbound_messages(messages, modality) - self.assertListEqual(expected_outputs, outputs) + self.assertIn("Qwen/Qwen2.5-0.5B-Instruct", model_ids) + self.assertNotIn("google-bert/bert-base-cased", model_ids) - messages_with_type = [ - {"role": "user", "content": [{"type": "text", "text": "How are you doing?"}]}, - { - "role": "assistant", - "content": [ - {"type": "text", "text": "I'm doing great, thank you for asking! How can I assist you today?"} - ], - }, - {"role": "user", "content": [{"type": "text", "text": "Can you help me write tests?"}]}, - ] - outputs = Serve.get_processor_inputs_from_inbound_messages(messages_with_type, modality) - self.assertListEqual(expected_outputs, outputs) - messages_multiple_text = [ - { - "role": "user", - "content": [ - {"type": "text", "text": "How are you doing?"}, - {"type": "text", "text": "I'm doing great, thank you for asking! How can I assist you today?"}, - ], - }, - ] - expected_outputs_multiple_text = [ - { - "role": "user", - "content": "How are you doing? I'm doing great, thank you for asking! How can I assist you today?", - }, - ] - outputs = Serve.get_processor_inputs_from_inbound_messages(messages_multiple_text, modality) - self.assertListEqual(expected_outputs_multiple_text, outputs) +@require_serve +class TestBuildGenerationConfig(unittest.TestCase): + def _make_handler(self): + return ChatCompletionHandler(model_manager=MagicMock(), generation_state=GenerationState()) - def test_processor_inputs_from_inbound_messages_vlm_text_only(self): - modality = Modality.VLM - messages = [ - {"role": "user", "content": "How are you doing?"}, - {"role": "assistant", "content": "I'm doing great, thank you for asking! How can I assist you today?"}, - {"role": "user", "content": "Can you help me write tests?"}, - ] + def test_max_tokens(self): + from transformers import GenerationConfig - expected_outputs = [ - {"role": "user", "content": [{"type": "text", "text": "How are you doing?"}]}, - { - "role": "assistant", - "content": [ - {"type": "text", "text": "I'm doing great, thank you for asking! 
How can I assist you today?"} - ], - }, - {"role": "user", "content": [{"type": "text", "text": "Can you help me write tests?"}]}, - ] + result = self._make_handler()._build_generation_config({"max_tokens": 7}, GenerationConfig()) + self.assertEqual(result.max_new_tokens, 7) - outputs = Serve.get_processor_inputs_from_inbound_messages(messages, modality) - self.assertListEqual(expected_outputs, outputs) + def test_temperature_zero_disables_sampling(self): + from transformers import GenerationConfig - def test_processor_inputs_from_inbound_messages_vlm_text_and_image_in_base_64(self): - modality = Modality.VLM - messages = [ - { - "role": "user", - "content": [ - {"type": "text", "text": "How many pixels are in the image?"}, - { - "type": "image_url", - "image_url": { - "url": "data:image/jpeg;base64,/9j/4AAQSkZJRgABAQAASABIAAD/4QBARXhpZgAATU0AKgAAAAgAAYdpAAQAAAABAAAAGgAAAAAAAqACAAQAAAABAAAABaADAAQAAAABAAAABQAAAAD/7QA4UGhvdG9zaG9wIDMuMAA4QklNBAQAAAAAAAA4QklNBCUAAAAAABDUHYzZjwCyBOmACZjs+EJ+/8AAEQgABQAFAwEiAAIRAQMRAf/EAB8AAAEFAQEBAQEBAAAAAAAAAAABAgMEBQYHCAkKC//EALUQAAIBAwMCBAMFBQQEAAABfQECAwAEEQUSITFBBhNRYQcicRQygZGhCCNCscEVUtHwJDNicoIJChYXGBkaJSYnKCkqNDU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6g4SFhoeIiYqSk5SVlpeYmZqio6Slpqeoqaqys7S1tre4ubrCw8TFxsfIycrS09TV1tfY2drh4uPk5ebn6Onq8fLz9PX29/j5+v/EAB8BAAMBAQEBAQEBAQEAAAAAAAABAgMEBQYHCAkKC//EALURAAIBAgQEAwQHBQQEAAECdwABAgMRBAUhMQYSQVEHYXETIjKBCBRCkaGxwQkjM1LwFWJy0QoWJDThJfEXGBkaJicoKSo1Njc4OTpDREVGR0hJSlNUVVZXWFlaY2RlZmdoaWpzdHV2d3h5eoKDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uLj5OXm5+jp6vLz9PX29/j5+v/bAEMAAQEBAQEBAgEBAgICAgICAwICAgIDBAMDAwMDBAUEBAQEBAQFBQUFBQUFBQYGBgYGBgcHBwcHCAgICAgICAgICP/bAEMBAQEBAgICAwICAwgFBAUICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICP/dAAQAAf/aAAwDAQACEQMRAD8A/v4ooooA/9k=" - }, - }, - ], - }, - { - "role": "assistant", - "content": "The number of pixels in the image cannot be determined from the provided information.", - }, - {"role": "user", "content": "Alright"}, - ] + result = self._make_handler()._build_generation_config({"temperature": 0.0}, GenerationConfig(do_sample=True)) + self.assertFalse(result.do_sample) - expected_outputs = [ - { - "role": "user", - "content": [ - {"type": "text", "text": "How many pixels are in the image?"}, - {"type": "image", "url": "/var/folders/4v/64sxdhsd3gz3r8vhhnyc0mqw0000gn/T/tmp50oyghk6.png"}, - ], - }, - { - "role": "assistant", - "content": [ - { - "type": "text", - "text": "The number of pixels in the image cannot be determined from the provided information.", - } - ], - }, - {"role": "user", "content": [{"type": "text", "text": "Alright"}]}, - ] + def test_frequency_penalty(self): + from transformers import GenerationConfig - outputs = Serve.get_processor_inputs_from_inbound_messages(messages, modality) + result = self._make_handler()._build_generation_config({"frequency_penalty": 0.5}, GenerationConfig()) + self.assertAlmostEqual(result.repetition_penalty, 1.5) - for expected_output, output in zip(expected_outputs, outputs): - expected_output_content = expected_output["content"] - output_content = output["content"] + def test_logit_bias_tuple_keys(self): + from transformers import GenerationConfig - self.assertEqual(type(expected_output_content), type(output_content)) + result = self._make_handler()._build_generation_config({"logit_bias": {"42": 1.0}}, GenerationConfig()) + self.assertEqual(result.sequence_bias, {(42,): 1.0}) - if isinstance(expected_output_content, list): - for expected_output_content_item, 
output_content_item in zip(expected_output_content, output_content): - self.assertIn("type", expected_output_content_item) - self.assertIn("type", output_content_item) - self.assertTrue(expected_output_content_item["type"] == output_content_item["type"]) + def test_stop_strings(self): + from transformers import GenerationConfig - if expected_output_content_item["type"] == "text": - self.assertEqual(expected_output_content_item["text"], output_content_item["text"]) + result = self._make_handler()._build_generation_config({"stop": [""]}, GenerationConfig()) + self.assertEqual(result.stop_strings, [""]) - if expected_output_content_item["type"] == "image": - self.assertTrue(os.path.exists(output_content_item["url"])) - else: - raise ValueError("VLMs should only receive content as lists.") + def test_generation_config_json_overrides(self): + from transformers import GenerationConfig + custom = GenerationConfig(max_new_tokens=5, do_sample=False) + result = self._make_handler()._build_generation_config( + {"generation_config": custom.to_json_string()}, GenerationConfig(max_new_tokens=100) + ) + self.assertEqual(result.max_new_tokens, 5) + self.assertFalse(result.do_sample) -@slow # server startup time is slow on our push CI -@require_openai -class ServeCompletionsGenerateIntegrationTest(ServeCompletionsMixin, unittest.TestCase): - """Tests the `generate` version of the Completions API.""" + def test_generation_config_json_no_defaults_applied(self): + """When generation_config JSON is passed, serving defaults should NOT be applied.""" + from transformers import GenerationConfig - @classmethod - def setUpClass(cls): - """Starts a server for tests to connect to.""" - cls.port = 8001 - cls.server = Serve(port=cls.port, non_blocking=True) + custom = GenerationConfig(max_new_tokens=10) + result = self._make_handler()._build_generation_config( + {"generation_config": custom.to_json_string()}, GenerationConfig() + ) + # Should keep 10, not bump to 1024 + self.assertEqual(result.max_new_tokens, 10) - @classmethod - def tearDownClass(cls): - cls.server.kill_server() + def test_default_bumps_short_max_new_tokens(self): + from transformers import GenerationConfig - @slow - def test_tool_call(self): - """Tests that the tool call is correctly handled and that the payloads are correctly structured.""" - # TODO: move to the mixin when CB also supports tool calls - - request = { - # This model is a small model that's very eager to call tools - # TODO: this is a 4B model. 
Find a smaller model that's eager to call tools - "model": "Menlo/Jan-nano", - # The request should produce a tool call - "messages": [{"role": "user", "content": "Generate an image of a cat."}], - "stream": True, - "max_tokens": 50, - # Reproducibility - "temperature": 0.0, - # This tool is a copy from the tool in the original tiny-agents demo - "tools": [ - { - "function": { - "name": "flux1_schnell_infer", - "parameters": { - "type": "object", - "properties": { - "prompt": {"type": "string"}, - "seed": {"type": "number", "description": "numeric value between 0 and 2147483647"}, - "randomize_seed": {"type": "boolean", "default": True}, - "width": { - "type": "number", - "description": "numeric value between 256 and 2048", - "default": 1024, - }, - "height": { - "type": "number", - "description": "numeric value between 256 and 2048", - "default": 1024, - }, - "num_inference_steps": { - "type": "number", - "description": "numeric value between 1 and 16", - "default": 4, - }, - }, - }, - "description": "Generate an image using the Flux 1 Schnell Image Generator.", - }, - "type": "function", - } - ], - } - all_payloads = self.run_server(request) - - # The first payload should contain the role - roles = [payload.choices[0].delta.role for payload in all_payloads] - self.assertEqual(roles[0], "assistant") - self.assertTrue(all(role is None for role in roles[1:])) - - # All other payloads (except the last one) should be tool call related, for this specific request - contents = [payload.choices[0].delta.content for payload in all_payloads] - self.assertTrue(all(content is None for content in contents)) - - # The first tool call delta should contain the tool name. The other tool call deltas should contain the tool - # arguments. - tool_calls = [payload.choices[0].delta.tool_calls[0] for payload in all_payloads[1:-1]] - first_tool_call = tool_calls[0] - self.assertEqual(first_tool_call["function"]["name"], "flux1_schnell_infer") - self.assertEqual(first_tool_call["function"]["arguments"], None) - other_tool_calls = tool_calls[1:] - self.assertTrue(all(tool_call["function"]["name"] is None for tool_call in other_tool_calls)) - self.assertTrue(all(tool_call["function"]["arguments"] is not None for tool_call in other_tool_calls)) - - # Finally, the last payload should contain a finish reason - finish_reasons = [payload.choices[0].finish_reason for payload in all_payloads] - # TODO: I think the finish reason for a tool call is different? 
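Taken together, the assertions in TestBuildGenerationConfig describe how OpenAI-style request fields map onto GenerationConfig. A combined illustration of that mapping (the tests check each field in isolation, so the single combined call below is an assumption, not something the patch itself verifies):

from unittest.mock import MagicMock

from transformers import GenerationConfig
from transformers.cli.serving.chat_completion import ChatCompletionHandler
from transformers.cli.serving.utils import GenerationState

handler = ChatCompletionHandler(model_manager=MagicMock(), generation_state=GenerationState())
config = handler._build_generation_config(
    {"max_tokens": 7, "temperature": 0.0, "frequency_penalty": 0.5, "logit_bias": {"42": 1.0}},
    GenerationConfig(do_sample=True),
)
assert config.max_new_tokens == 7  # max_tokens -> max_new_tokens
assert config.do_sample is False  # temperature 0.0 disables sampling
assert abs(config.repetition_penalty - 1.5) < 1e-6  # frequency_penalty -> repetition_penalty
assert config.sequence_bias == {(42,): 1.0}  # logit_bias keys become token-id tuples
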
double check this - self.assertTrue(finish_reasons[-1] in ["stop", "length"]) - self.assertTrue(all(reason is None for reason in finish_reasons[:-1])) - - -def _get_scheduler(serve_command): - # Defensive navigation in case any layer is renamed in the future - cbm = getattr(serve_command, "running_continuous_batching_manager", None) - assert cbm is not None, "ServeCommand has no running_continuous_batching_manager" - bp = getattr(cbm, "batch_processor", None) - assert bp is not None, "running_continuous_batching_manager has no batch_processor" - sched = getattr(bp, "scheduler", None) - assert sched is not None, "batch_processor has no scheduler" - return sched - - -def _call_healthcheck(base_url: str): - response = None - retries = 10 - while retries > 0: - try: - response = httpx.get(f"{base_url}/health") - break - except httpx.NetworkError: - time.sleep(0.1) - retries -= 1 - return response + result = self._make_handler()._build_generation_config({}, GenerationConfig(max_new_tokens=20)) + self.assertEqual(result.max_new_tokens, 1024) + def test_user_max_tokens_overrides_default(self): + """User's max_tokens should win over the serving default.""" + from transformers import GenerationConfig -def _open_stream_and_cancel(base_url: str, request_id: str): - with httpx.Client() as s: - with s.stream( - "POST", - f"{base_url}/v1/chat/completions", - headers={"X-Request-ID": request_id}, - json={ - "model": "Qwen/Qwen2.5-0.5B-Instruct", - "stream": True, - "messages": [{"role": "user", "content": "Count slowly so I can cancel you."}], - }, - timeout=30, - ) as resp: - assert resp.status_code == 200 + result = self._make_handler()._build_generation_config({"max_tokens": 50}, GenerationConfig(max_new_tokens=20)) + self.assertEqual(result.max_new_tokens, 50) - wait_for_n_chunks = 3 - for i, _ in enumerate(resp.iter_bytes(chunk_size=None)): - if i >= wait_for_n_chunks: - resp.close() - break + +@require_serve +class TestValidation(unittest.TestCase): + def _make_handler(self): + return ChatCompletionHandler(model_manager=MagicMock(), generation_state=GenerationState()) + + def test_valid_request_passes(self): + handler = self._make_handler() + # Should not raise + handler._validate_request({"model": "x", "messages": [{"role": "user", "content": "hi"}], "stream": True}) + + def test_unexpected_keys_rejected(self): + handler = self._make_handler() + with self.assertRaises(HTTPException) as ctx: + handler._validate_request({"model": "x", "messages": [], "bogus_field": True}) + self.assertEqual(ctx.exception.status_code, 422) + self.assertIn("bogus_field", ctx.exception.detail) + + def test_unsupported_fields_warns(self): + handler = self._make_handler() + with self.assertLogs("transformers", level="WARNING") as cm: + handler._validate_request({"model": "x", "messages": [], "audio": {}}) + self.assertTrue(any("audio" in msg for msg in cm.output)) -@slow # server startup time is slow on our push CI -@require_openai -class ServeCompletionsContinuousBatchingIntegrationTest(ServeCompletionsMixin, unittest.TestCase): - """Tests the `continuous_batching` version of the Completions API.""" +class TestModelManager(unittest.TestCase): + def test_process_model_name_adds_main(self): + self.assertEqual(ModelManager.process_model_name("org/model"), "org/model@main") + def test_process_model_name_preserves_revision(self): + self.assertEqual(ModelManager.process_model_name("org/model@dev"), "org/model@dev") + + def test_quantization_config_4bit(self): + mm = ModelManager(quantization="bnb-4bit") + cfg = 
mm.get_quantization_config() + self.assertTrue(cfg.load_in_4bit) + + def test_quantization_config_8bit(self): + mm = ModelManager(quantization="bnb-8bit") + cfg = mm.get_quantization_config() + self.assertTrue(cfg.load_in_8bit) + + def test_quantization_config_none(self): + mm = ModelManager() + self.assertIsNone(mm.get_quantization_config()) + + +class TestTimedModel(unittest.TestCase): + def test_delete_model(self): + mock_model = MagicMock() + deleted = [] + timed = TimedModel( + mock_model, timeout_seconds=9999, processor=MagicMock(), on_unload=lambda: deleted.append(True) + ) + self.assertIsNotNone(timed.model) + timed.delete_model() + self.assertIsNone(timed.model) + self.assertEqual(len(deleted), 1) + + def test_timeout_zero_no_delete(self): + mock_model = MagicMock() + timed = TimedModel(mock_model, timeout_seconds=0, processor=MagicMock()) + timed._timeout_reached() + self.assertIsNotNone(timed.model) + timed._timer.cancel() + + +@require_serve +class TestChunkSSE(unittest.TestCase): + def _make_handler(self): + return ChatCompletionHandler(model_manager=MagicMock(), generation_state=GenerationState()) + + def test_build_chunk_sse_content(self): + handler = self._make_handler() + sse = handler._build_chunk_sse(request_id="req1", content="hi", model="m") + self.assertTrue(sse.startswith("data: ")) + self.assertTrue(sse.endswith("\n\n")) + parsed = json.loads(sse[len("data: ") :].strip()) + self.assertEqual(parsed["choices"][0]["delta"]["content"], "hi") + + def test_build_chunk_sse_role(self): + handler = self._make_handler() + sse = handler._build_chunk_sse(request_id="req1", role="assistant", model="m") + parsed = json.loads(sse[len("data: ") :].strip()) + self.assertEqual(parsed["choices"][0]["delta"]["role"], "assistant") + self.assertNotIn("content", parsed["choices"][0]["delta"]) + + def test_build_chunk_sse_finish_reason(self): + handler = self._make_handler() + sse = handler._build_chunk_sse(request_id="req1", finish_reason="stop", model="m") + parsed = json.loads(sse[len("data: ") :].strip()) + self.assertEqual(parsed["choices"][0]["finish_reason"], "stop") + + def test_chunk_to_sse_string_passthrough(self): + result = BaseHandler.chunk_to_sse("data: already formatted\n\n") + self.assertEqual(result, "data: already formatted\n\n") + + def test_chunk_to_sse_wraps_plain_string(self): + result = BaseHandler.chunk_to_sse("hello") + self.assertEqual(result, "data: hello\n\n") + + +QWEN_TOOL_FORMAT = {"start": "", "end": ""} + + +@require_serve +class TestToolParser(unittest.TestCase): + def test_detect_tool_format_qwen(self): + model = MagicMock() + model.config.architectures = ["Qwen2ForCausalLM"] + fmt = detect_tool_format(model) + self.assertEqual(fmt, QWEN_TOOL_FORMAT) + + def test_detect_tool_format_unsupported(self): + model = MagicMock() + model.config.architectures = ["LlamaForCausalLM"] + self.assertIsNone(detect_tool_format(model)) + + def test_parser_start_token(self): + parser = ToolCallParser(QWEN_TOOL_FORMAT) + result = parser.feed("") + self.assertIs(result, ToolCallParser.CONSUMED) + + def test_parser_end_token(self): + parser = ToolCallParser(QWEN_TOOL_FORMAT) + parser.feed("") + result = parser.feed("") + self.assertIs(result, ToolCallParser.CONSUMED) + + def test_parser_buffers_until_end(self): + parser = ToolCallParser(QWEN_TOOL_FORMAT) + parser.feed("") + # Intermediate tokens are buffered + result = parser.feed('{"name": "my_tool", "arguments": {"x": 1}}') + self.assertIs(result, ToolCallParser.CONSUMED) + # Tool call is emitted on end token + result = 
parser.feed("") + self.assertIsNot(result, ToolCallParser.CONSUMED) + self.assertEqual(result["name"], "my_tool") + + def test_parser_normal_text_returns_none(self): + parser = ToolCallParser(QWEN_TOOL_FORMAT) + result = parser.feed("Hello world") + self.assertIsNone(result) + + def test_parser_full_flow(self): + """Simulate a complete tool call token sequence.""" + + parser = ToolCallParser(QWEN_TOOL_FORMAT) + tool_calls = [] + + for token in [ + "", + '{"name": "get_weather",', + ' "arguments": {', + '"city": "Paris"', + "}}", + "\n", + "", + ]: + result = parser.feed(token) + if result is not None and result is not ToolCallParser.CONSUMED: + tool_calls.append(result) + + # Single tool call emitted on with both name and arguments + self.assertEqual(len(tool_calls), 1) + self.assertEqual(tool_calls[0]["name"], "get_weather") + self.assertIn("Paris", tool_calls[0]["arguments"]) + + def test_parse_tool_calls_from_text(self): + """Non-streaming tool call parsing from complete text.""" + + text = '\n{"name": "get_weather", "arguments": {"city": "Paris"}}\n' + calls = ToolCallParser.parse(text, QWEN_TOOL_FORMAT) + self.assertIsNotNone(calls) + self.assertEqual(len(calls), 1) + self.assertEqual(calls[0]["name"], "get_weather") + self.assertIn("Paris", calls[0]["arguments"]) + + def test_parse_tool_calls_no_tool_call(self): + """Non-streaming: normal text returns None.""" + + calls = ToolCallParser.parse("Hello, how can I help?", QWEN_TOOL_FORMAT) + self.assertIsNone(calls) + + def test_parse_multiple_tool_calls(self): + """Non-streaming: multiple tool calls in one response.""" + + text = ( + '\n{"name": "get_weather", "arguments": {"city": "Paris"}}\n\n' + '\n{"name": "get_weather", "arguments": {"city": "London"}}\n' + ) + calls = ToolCallParser.parse(text, QWEN_TOOL_FORMAT) + self.assertIsNotNone(calls) + self.assertEqual(len(calls), 2) + self.assertEqual(calls[0]["name"], "get_weather") + self.assertIn("Paris", calls[0]["arguments"]) + self.assertEqual(calls[1]["name"], "get_weather") + self.assertIn("London", calls[1]["arguments"]) + + def test_feed_multiple_tool_calls(self): + """Streaming: multiple tool calls emitted sequentially.""" + + parser = ToolCallParser(QWEN_TOOL_FORMAT) + tool_calls = [] + + tokens = [ + "", + '{"name": "get_weather", "arguments": {"city": "Paris"}}', + "", + "", + '{"name": "get_weather", "arguments": {"city": "London"}}', + "", + ] + for token in tokens: + result = parser.feed(token) + if result is not None and result is not ToolCallParser.CONSUMED: + tool_calls.append(result) + + self.assertEqual(len(tool_calls), 2) + self.assertEqual(tool_calls[0]["name"], "get_weather") + self.assertIn("Paris", tool_calls[0]["arguments"]) + self.assertEqual(tool_calls[1]["name"], "get_weather") + self.assertIn("London", tool_calls[1]["arguments"]) + + +@require_serve +class TestAppRoutes(unittest.TestCase): @classmethod def setUpClass(cls): - """Starts a server for tests to connect to.""" - cls.port = 8002 - cls.server = Serve( - port=cls.port, continuous_batching=True, attn_implementation="sdpa", default_seed=42, non_blocking=True + cls.model_manager = MagicMock(spec=ModelManager) + cls.model_manager.get_gen_models.return_value = [ + {"id": "test/model", "owned_by": "test", "object": "model", "created": 0} + ] + cls.chat_handler = MagicMock(spec=ChatCompletionHandler) + cls.response_handler = MagicMock(spec=ResponseHandler) + cls.transcription_handler = MagicMock(spec=TranscriptionHandler) + cls.app = build_server(cls.model_manager, cls.chat_handler, cls.response_handler, 
cls.transcription_handler) + cls.transport = httpx.ASGITransport(app=cls.app) + + async def _request(self, method: str, path: str, **kwargs) -> httpx.Response: + async with httpx.AsyncClient(transport=self.transport, base_url="http://test") as c: + return await c.request(method, path, **kwargs) + + def test_health(self): + resp = asyncio.get_event_loop().run_until_complete(self._request("GET", "/health")) + self.assertEqual(resp.status_code, 200) + self.assertEqual(resp.json(), {"status": "ok"}) + + def test_models_list(self): + resp = asyncio.get_event_loop().run_until_complete(self._request("GET", "/v1/models")) + self.assertEqual(resp.status_code, 200) + data = resp.json() + self.assertEqual(data["object"], "list") + self.assertEqual(len(data["data"]), 1) + + def test_request_id_generated(self): + resp = asyncio.get_event_loop().run_until_complete(self._request("GET", "/health")) + self.assertIn("x-request-id", resp.headers) + self.assertEqual(len(resp.headers["x-request-id"]), 36) # UUID length + + def test_request_id_passthrough(self): + resp = asyncio.get_event_loop().run_until_complete( + self._request("GET", "/health", headers={"x-request-id": "my-id"}) ) + self.assertEqual(resp.headers["x-request-id"], "my-id") + + +@slow +@require_serve +class TestChatCompletion(unittest.TestCase): + """Integration tests for /v1/chat/completions with a real model.""" + + MODEL = "Qwen/Qwen2.5-0.5B-Instruct" + + @classmethod + def setUpClass(cls): + cls.serve, port = _start_serve() + cls.base_url = f"http://localhost:{port}" + cls.client = OpenAI(base_url=f"{cls.base_url}/v1", api_key="unused") @classmethod def tearDownClass(cls): - cls.server.kill_server() + cls.serve.kill_server() - def test_full_request(self): - """Tests that an inference using the Responses API and Continuous Batching works""" + def test_non_streaming(self): + resp = self.client.chat.completions.create( + model=self.MODEL, messages=[{"role": "user", "content": "Say hello"}] + ) + self.assertIsNotNone(resp.choices[0].message.content) + self.assertIn(resp.choices[0].finish_reason, ("stop", "length")) + + def test_streaming(self): + text = "" + for chunk in self.client.chat.completions.create( + model=self.MODEL, messages=[{"role": "user", "content": "Say hello"}], stream=True + ): + if chunk.choices[0].delta.content: + text += chunk.choices[0].delta.content + self.assertTrue(len(text) > 0) - request = { - "model": "Qwen/Qwen2.5-0.5B-Instruct", - "messages": [ - {"role": "system", "content": "You are a sports assistant designed to craft sports programs."}, - {"role": "user", "content": "Tell me what you can do."}, + def test_early_return_due_to_length(self): + """When max_tokens is hit, finish_reason should be 'length'.""" + chunks = list( + self.client.chat.completions.create( + model=self.MODEL, + messages=[{"role": "user", "content": "Hello, how are you?"}], + stream=True, + max_tokens=3, + ) + ) + last = chunks[-1] + self.assertEqual(last.choices[0].finish_reason, "length") + + def test_continues_until_stop(self): + """When model stops naturally, finish_reason should be 'stop'.""" + chunks = list( + self.client.chat.completions.create( + model=self.MODEL, + messages=[{"role": "user", "content": 'Please only answer with "Hi."'}], + stream=True, + max_tokens=30, + ) + ) + last = chunks[-1] + self.assertEqual(last.choices[0].finish_reason, "stop") + + def test_stop_strings(self): + resp = self.client.chat.completions.create( + model=self.MODEL, messages=[{"role": "user", "content": "Count to 10"}], stop=["5"] + ) + 
self.assertNotIn("6", resp.choices[0].message.content) + + def test_multi_turn(self): + resp = self.client.chat.completions.create( + model=self.MODEL, + messages=[ + {"role": "user", "content": "My name is Alice"}, + {"role": "assistant", "content": "Nice to meet you!"}, + {"role": "user", "content": "What is my name?"}, ], - "stream": True, - "max_tokens": 30, - } - all_payloads = self.run_server(request) + ) + self.assertIn("Alice", resp.choices[0].message.content) - full_text = "" - for token in all_payloads: - if isinstance(token, ChatCompletionStreamOutput) and token.choices and len(token.choices) > 0: - content = token.choices[0].delta.get("content", "") - full_text += content if content is not None else "" + def test_multiple_models_on_demand(self): + """Load two different models via separate requests — both should work.""" + model_a = "Qwen/Qwen2.5-0.5B-Instruct" + model_b = "HuggingFaceTB/SmolLM2-135M-Instruct" + prompt = [{"role": "user", "content": "Say hello"}] - # Verify that the system prompt went through. - self.assertTrue( - full_text.startswith( - "I can assist you with a wide range of tasks, from answering questions to providing information on various sports topics." + resp_a = self.client.chat.completions.create(model=model_a, messages=prompt) + self.assertIn(model_a, resp_a.model) + self.assertIsNotNone(resp_a.choices[0].message.content) + + resp_b = self.client.chat.completions.create(model=model_b, messages=prompt) + self.assertIn(model_b, resp_b.model) + self.assertIsNotNone(resp_b.choices[0].message.content) + + def test_non_streaming_usage(self): + resp = self.client.chat.completions.create( + model=self.MODEL, messages=[{"role": "user", "content": "Say hello"}] + ) + self.assertIsNotNone(resp.usage) + self.assertGreater(resp.usage.prompt_tokens, 0) + self.assertGreater(resp.usage.completion_tokens, 0) + self.assertEqual(resp.usage.total_tokens, resp.usage.prompt_tokens + resp.usage.completion_tokens) + + def test_streaming_usage(self): + chunks = list( + self.client.chat.completions.create( + model=self.MODEL, + messages=[{"role": "user", "content": "Say hello"}], + stream=True, ) ) + # Last chunk should have usage + last = chunks[-1] + self.assertIsNotNone(last.usage) + self.assertGreater(last.usage.prompt_tokens, 0) + self.assertGreater(last.usage.completion_tokens, 0) + self.assertEqual(last.usage.total_tokens, last.usage.prompt_tokens + last.usage.completion_tokens) - def test_max_tokens_not_set_in_req(self): - request = { - "model": "Qwen/Qwen2.5-0.5B-Instruct", - "messages": [ - {"role": "system", "content": "You are a sports assistant designed to craft sports programs."}, - {"role": "user", "content": "Tell me what you can do."}, - ], - "stream": True, + def test_tool_call(self): + """Tool calls should be parsed and emitted as ChoiceDeltaToolCall objects.""" + # Qwen2.5-0.5B-Instruct supports tools (Qwen family) + tool_def = { + "function": { + "name": "get_weather", + "parameters": { + "type": "object", + "properties": {"city": {"type": "string"}}, + }, + "description": "Get the weather for a city.", + }, + "type": "function", } - all_payloads = self.run_server(request) + chunks = list( + self.client.chat.completions.create( + model=self.MODEL, + messages=[{"role": "user", "content": "What is the weather in Paris?"}], + stream=True, + max_tokens=50, + temperature=0.0, + tools=[tool_def], + ) + ) - full_text = "" - for token in all_payloads: - if isinstance(token, ChatCompletionStreamOutput) and token.choices and len(token.choices) > 0: - content = 
token.choices[0].delta.get("content", "") - full_text += content if content is not None else "" + # First chunk should have role="assistant" + self.assertEqual(chunks[0].choices[0].delta.role, "assistant") - # Verify that the system prompt went through. - self.assertTrue( - full_text.startswith( - "I can assist you with a wide range of tasks, from answering questions to providing information on various sports topics." + # Model should make a tool call for this prompt + tool_chunks = [c for c in chunks if c.choices[0].delta.tool_calls] + self.assertGreater(len(tool_chunks), 0, "Model did not produce a tool call") + + # First tool call delta should have the function name + first_tool = tool_chunks[0].choices[0].delta.tool_calls[0] + self.assertEqual(first_tool.function.name, "get_weather") + + # finish_reason should be "tool_calls" + last = chunks[-1] + self.assertEqual(last.choices[0].finish_reason, "tool_calls") + + # Arguments should be valid JSON with no trailing brace + args_json = first_tool.function.arguments + import json as json_mod + + parsed_args = json_mod.loads(args_json) + self.assertIsInstance(parsed_args, dict) + + def test_tool_call_non_streaming(self): + """Non-streaming tool calls should return tool_calls in the message.""" + tool_def = { + "function": { + "name": "get_weather", + "parameters": {"type": "object", "properties": {"city": {"type": "string"}}}, + "description": "Get the weather for a city.", + }, + "type": "function", + } + resp = self.client.chat.completions.create( + model=self.MODEL, + messages=[{"role": "user", "content": "What is the weather in Paris?"}], + stream=False, + max_tokens=50, + temperature=0.0, + tools=[tool_def], + ) + self.assertEqual(resp.choices[0].finish_reason, "tool_calls") + self.assertIsNotNone(resp.choices[0].message.tool_calls) + tc = resp.choices[0].message.tool_calls[0] + self.assertEqual(tc.function.name, "get_weather") + + import json as json_mod + + parsed_args = json_mod.loads(tc.function.arguments) + self.assertIsInstance(parsed_args, dict) + + def test_tool_call_multi(self): + """Model should be able to call multiple tools when asked.""" + tool_def = { + "function": { + "name": "get_weather", + "parameters": {"type": "object", "properties": {"city": {"type": "string"}}}, + "description": "Get the weather for a city.", + }, + "type": "function", + } + # Ask for two cities to encourage multiple tool calls + chunks = list( + self.client.chat.completions.create( + model=self.MODEL, + messages=[{"role": "user", "content": "What is the weather in Paris and London?"}], + stream=True, + max_tokens=100, + temperature=0.0, + tools=[tool_def], ) ) + tool_chunks = [c for c in chunks if c.choices[0].delta.tool_calls] + # Should have two tool calls — one for Paris, one for London + self.assertEqual(len(tool_chunks), 2, f"Expected 2 tool calls, got {len(tool_chunks)}") + cities = {tc.choices[0].delta.tool_calls[0].function.name for tc in tool_chunks} + self.assertEqual(cities, {"get_weather"}) + last = chunks[-1] + self.assertEqual(last.choices[0].finish_reason, "tool_calls") + + def test_concurrent_non_streaming(self): + """Two concurrent non-streaming requests should both complete without interference.""" + import concurrent.futures - def test_request_cancellation(self): - """Tests that a request can be cancelled.""" + prompts = [ + [{"role": "user", "content": "Say hello"}], + [{"role": "user", "content": "Say goodbye"}], + ] + results = [None, None] - base_url = f"http://127.0.0.1:{self.port}" - request_id = "test-cancel" + def 
request_in_thread(index): + client = OpenAI(base_url=f"{self.base_url}/v1", api_key="unused") + results[index] = client.chat.completions.create(model=self.MODEL, messages=prompts[index]) - # Ensure the server is up before sending a request - response = _call_healthcheck(base_url) - self.assertIsNotNone(response, "Failed to connect to the server health endpoint.") - self.assertEqual(response.status_code, 200) + with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool: + futures = [pool.submit(request_in_thread, i) for i in range(2)] + concurrent.futures.wait(futures) + for f in futures: + f.result() # re-raise exceptions + + for i in range(2): + self.assertIsNotNone(results[i]) + self.assertIsNotNone(results[i].choices[0].message.content) + self.assertTrue(len(results[i].choices[0].message.content) > 0) - _open_stream_and_cancel(base_url, request_id) + def test_concurrent_streaming(self): + """Two concurrent streaming requests should both produce complete, non-empty output.""" + import concurrent.futures - scheduler = _get_scheduler(self.server) + prompts = [ + [{"role": "user", "content": "Say hello"}], + [{"role": "user", "content": "Say goodbye"}], + ] + results = [None, None] - # Because cancellation is non-blocking, poll for a short, bounded time. - deadline = time.time() + 8.0 # generous but still CI-friendly - last_seen = None - while time.time() < deadline: - is_cancelled = scheduler.request_is_cancelled(request_id) - if is_cancelled: - break - last_seen = time.time() - time.sleep(0.1) # don't spin the CPU + def stream_in_thread(index): + client = OpenAI(base_url=f"{self.base_url}/v1", api_key="unused") + text = "" + for chunk in client.chat.completions.create(model=self.MODEL, messages=prompts[index], stream=True): + if chunk.choices[0].delta.content: + text += chunk.choices[0].delta.content + results[index] = text - is_cancelled = scheduler.request_is_cancelled(request_id) - self.assertTrue( - is_cancelled, - f"Request {request_id} still present in scheduler after cancellation " - f"(last seen at {last_seen}). 
Check cancellation propagation.", + with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool: + futures = [pool.submit(stream_in_thread, i) for i in range(2)] + concurrent.futures.wait(futures) + for f in futures: + f.result() + + for i in range(2): + self.assertIsNotNone(results[i]) + self.assertTrue(len(results[i]) > 0, f"Request {i} produced empty output") + + def test_request_cancellation(self): + """Closing a stream early doesn't crash and the server stays healthy.""" + + with httpx.stream( + "POST", + f"{self.base_url}/v1/chat/completions", + json={ + "model": self.MODEL, + "stream": True, + "messages": [{"role": "user", "content": "Count slowly so I can cancel you."}], + "max_tokens": 500, + }, + timeout=30, + ) as resp: + self.assertEqual(resp.status_code, 200) + chunks_read = 0 + for _ in resp.iter_lines(): + chunks_read += 1 + if chunks_read >= 3: + break + + # Server should still be healthy and serve subsequent requests + resp = self.client.chat.completions.create( + model=self.MODEL, + messages=[{"role": "user", "content": "Say hi"}], + max_tokens=10, + ) + self.assertIsNotNone(resp.choices[0].message.content) + + +@require_serve +class TestResponseInputConversion(unittest.TestCase): + def _make_handler(self): + return ResponseHandler(model_manager=MagicMock(), generation_state=GenerationState()) + + def test_string_input(self): + handler = self._make_handler() + msgs = handler._input_to_messages({"input": "Hello"}) + self.assertEqual(msgs, [{"role": "user", "content": "Hello"}]) + + def test_string_input_with_instructions(self): + handler = self._make_handler() + msgs = handler._input_to_messages({"input": "Hello", "instructions": "Be brief"}) + self.assertEqual(len(msgs), 2) + self.assertEqual(msgs[0], {"role": "system", "content": "Be brief"}) + self.assertEqual(msgs[1], {"role": "user", "content": "Hello"}) + + def test_list_input(self): + handler = self._make_handler() + msgs = handler._input_to_messages( + {"input": [{"role": "user", "content": "A"}, {"role": "assistant", "content": "B"}]} ) + self.assertEqual(len(msgs), 2) + self.assertEqual(msgs[0]["content"], "A") + + def test_list_input_with_instructions_prepends_system(self): + handler = self._make_handler() + msgs = handler._input_to_messages({"input": [{"role": "user", "content": "Hi"}], "instructions": "Be helpful"}) + self.assertEqual(len(msgs), 2) + self.assertEqual(msgs[0]["role"], "system") + self.assertEqual(msgs[0]["content"], "Be helpful") + + def test_list_input_with_instructions_replaces_existing_system(self): + handler = self._make_handler() + msgs = handler._input_to_messages( + {"input": [{"role": "system", "content": "Old"}, {"role": "user", "content": "Hi"}], "instructions": "New"} + ) + self.assertEqual(len(msgs), 2) + self.assertEqual(msgs[0]["content"], "New") + def test_dict_input(self): + handler = self._make_handler() + msgs = handler._input_to_messages({"input": {"role": "user", "content": "Test"}}) + self.assertEqual(msgs, [{"role": "user", "content": "Test"}]) -@require_openai -class ServeResponsesMixin: - """ - Mixin class for the Completions API tests, to seamlessly replicate tests across the two versions of the API - (`generate` and `continuous_batching`). 
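The _input_to_messages tests above reduce to a small set of rules: string, dict, or list inputs become chat messages, and "instructions" becomes (or replaces) the system message. The same expectations as a standalone snippet, with imports taken from this patch:

from unittest.mock import MagicMock

from transformers.cli.serving.response import ResponseHandler
from transformers.cli.serving.utils import GenerationState

handler = ResponseHandler(model_manager=MagicMock(), generation_state=GenerationState())

# A bare string becomes a single user message.
assert handler._input_to_messages({"input": "Hello"}) == [{"role": "user", "content": "Hello"}]
# "instructions" is prepended as the system message.
assert handler._input_to_messages({"input": "Hello", "instructions": "Be brief"}) == [
    {"role": "system", "content": "Be brief"},
    {"role": "user", "content": "Hello"},
]
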
- """ - @retry - def run_server(self, request): - client = OpenAI(base_url=f"http://localhost:{self.port}/v1", api_key="") - stream = client.responses.create(**request) +@require_serve +class TestResponseValidation(unittest.TestCase): + def _make_handler(self): + return ResponseHandler(model_manager=MagicMock(), generation_state=GenerationState()) - all_payloads = [] - for payload in stream: - all_payloads.append(payload) + def test_unsupported_fields_warns(self): + handler = self._make_handler() + with self.assertLogs("transformers", level="WARNING") as cm: + handler._validate_request({"model": "x", "input": "hi", "previous_response_id": "abc"}) + self.assertTrue(any("previous_response_id" in msg for msg in cm.output)) - return all_payloads + def test_valid_request_passes(self): + handler = self._make_handler() + # Should not raise + handler._validate_request({"model": "x", "input": "hi"}) - def test_request(self): - """Tests that an inference using the Responses API works""" - request = { - "model": "Qwen/Qwen2.5-0.5B-Instruct", - "instructions": "You are a helpful assistant.", - "input": "Hello!", - "stream": True, - "max_output_tokens": 1, - } - all_payloads = self.run_server(request) +@require_serve +class TestResponseGenerationConfig(unittest.TestCase): + def _make_handler(self): + return ResponseHandler(model_manager=MagicMock(), generation_state=GenerationState()) + + def test_max_output_tokens(self): + from transformers import GenerationConfig + + result = self._make_handler()._build_generation_config({"max_output_tokens": 42}, GenerationConfig()) + self.assertEqual(result.max_new_tokens, 42) + + def test_default_bumps_short_max_new_tokens(self): + from transformers import GenerationConfig - # Allow variable number of delta events depending on tokenizer/streamer behavior - self.assertGreaterEqual(len(all_payloads), 8) + result = self._make_handler()._build_generation_config({}, GenerationConfig(max_new_tokens=20)) + self.assertEqual(result.max_new_tokens, 1024) - # Start markers - self.assertIsInstance(all_payloads[0], ResponseCreatedEvent) - self.assertIsInstance(all_payloads[1], ResponseInProgressEvent) - self.assertIsInstance(all_payloads[2], ResponseOutputItemAddedEvent) - self.assertIsInstance(all_payloads[3], ResponseContentPartAddedEvent) - # At least one delta event during streaming - self.assertTrue(any(isinstance(p, ResponseTextDeltaEvent) for p in all_payloads[4:-4])) +@require_serve +class TestResponseUsage(unittest.TestCase): + def testcompute_usage(self): + usage = compute_usage(input_tokens=100, output_tokens=50) + self.assertEqual(usage.input_tokens, 100) + self.assertEqual(usage.output_tokens, 50) + self.assertEqual(usage.total_tokens, 150) + self.assertEqual(usage.input_tokens_details.cached_tokens, 0) + self.assertEqual(usage.output_tokens_details.reasoning_tokens, 0) - # Closing markers - self.assertIsInstance(all_payloads[-4], ResponseTextDoneEvent) - self.assertIsInstance(all_payloads[-3], ResponseContentPartDoneEvent) - self.assertIsInstance(all_payloads[-2], ResponseOutputItemDoneEvent) - self.assertIsInstance(all_payloads[-1], ResponseCompletedEvent) + def test_usage_in_completed_response(self): + """Usage should serialize correctly inside a Response.""" + + usage = compute_usage(10, 5) + response = Response( + id="resp_test", + created_at=0, + status="completed", + model="test", + output=[], + object="response", + tools=[], + parallel_tool_calls=False, + tool_choice="auto", + usage=usage, + ) + dumped = response.model_dump(exclude_none=True) + 
self.assertEqual(dumped["usage"]["input_tokens"], 10) + self.assertEqual(dumped["usage"]["output_tokens"], 5) + self.assertEqual(dumped["usage"]["total_tokens"], 15) + + +@require_serve +class TestResponseSSEFormat(unittest.TestCase): + def test_sse_format(self): + event = ResponseCreatedEvent( + type="response.created", + sequence_number=0, + response=Response( + id="resp_test", + created_at=0, + status="queued", + model="test", + text={"format": {"type": "text"}}, + object="response", + tools=[], + output=[], + parallel_tool_calls=False, + tool_choice="auto", + ), + ) + result = BaseHandler.chunk_to_sse(event) + self.assertTrue(result.startswith("data: ")) + self.assertTrue(result.endswith("\n\n")) + parsed = json.loads(result[len("data: ") :].strip()) + self.assertEqual(parsed["type"], "response.created") + self.assertEqual(parsed["response"]["status"], "queued") - # TODO: one test for each request flag, to confirm it is working as expected - # TODO: speed-based test to confirm that KV cache is working across requests +@slow +@require_serve +class TestResponsesIntegration(unittest.TestCase): + """Integration tests for /v1/responses with a real model.""" -@slow # server startup time is slow on our push CI -@require_openai -class ServeResponsesIntegrationTest(ServeResponsesMixin, unittest.TestCase): - """Tests the Responses API.""" + MODEL = "Qwen/Qwen2.5-0.5B-Instruct" @classmethod def setUpClass(cls): - """Starts a server for tests to connect to.""" - cls.port = 8003 - cls.server = Serve(port=cls.port, default_seed=42, non_blocking=True) + cls.serve, port = _start_serve() + cls.base_url = f"http://localhost:{port}" + cls.client = OpenAI(base_url=f"{cls.base_url}/v1", api_key="unused") @classmethod def tearDownClass(cls): - cls.server.kill_server() - - @slow - def test_full_request(self): - """Tests that an inference using the Responses API works""" - - request = { - "model": "Qwen/Qwen2.5-0.5B-Instruct", - "instructions": "You are a sports assistant designed to craft sports programs.", - "input": "Tell me what you can do.", - "stream": True, - "max_output_tokens": 30, - # Disable sampling for deterministic output - "temperature": 0, + cls.serve.kill_server() + + def test_streaming(self): + events = list( + self.client.responses.create( + model=self.MODEL, + input="Say hello", + stream=True, + max_output_tokens=1, + ) + ) + # At least 8 events: created, in_progress, output_item_added, content_part_added, + # delta(s), text_done, content_part_done, output_item_done, completed + self.assertGreaterEqual(len(events), 8) + + # Start markers (fixed order) + self.assertEqual(events[0].type, "response.created") + self.assertEqual(events[1].type, "response.in_progress") + self.assertEqual(events[2].type, "response.output_item.added") + self.assertEqual(events[3].type, "response.content_part.added") + + # At least one delta + self.assertTrue(any(e.type == "response.output_text.delta" for e in events[4:-4])) + + # Closing markers (fixed order from the end) + self.assertEqual(events[-4].type, "response.output_text.done") + self.assertEqual(events[-3].type, "response.content_part.done") + self.assertEqual(events[-2].type, "response.output_item.done") + self.assertEqual(events[-1].type, "response.completed") + + def test_non_streaming(self): + resp = self.client.responses.create( + model=self.MODEL, + input="Say hello", + stream=False, + ) + self.assertEqual(resp.status, "completed") + self.assertTrue(len(resp.output) > 0) + self.assertTrue(len(resp.output[0].content[0].text) > 0) + + def 
test_non_streaming_usage(self): + resp = self.client.responses.create( + model=self.MODEL, + input="Say hello", + stream=False, + ) + self.assertIsNotNone(resp.usage) + self.assertGreater(resp.usage.input_tokens, 0) + self.assertGreater(resp.usage.output_tokens, 0) + self.assertEqual(resp.usage.total_tokens, resp.usage.input_tokens + resp.usage.output_tokens) + + def test_streaming_usage(self): + events = list( + self.client.responses.create( + model=self.MODEL, + input="Say hello", + stream=True, + max_output_tokens=5, + ) + ) + completed = events[-1] + self.assertEqual(completed.type, "response.completed") + usage = completed.response.usage + self.assertIsNotNone(usage) + self.assertGreater(usage.input_tokens, 0) + self.assertGreater(usage.output_tokens, 0) + self.assertEqual(usage.total_tokens, usage.input_tokens + usage.output_tokens) + + def test_tool_call_streaming(self): + """Streaming responses with tools should emit function_call events.""" + tool_def = { + "function": { + "name": "get_weather", + "parameters": {"type": "object", "properties": {"city": {"type": "string"}}}, + "description": "Get the weather for a city.", + }, + "type": "function", + } + events = list( + self.client.responses.create( + model=self.MODEL, + input="What is the weather in Paris?", + stream=True, + max_output_tokens=50, + tools=[tool_def], + ) + ) + types = [e.type for e in events] + self.assertIn("response.created", types) + self.assertIn("response.completed", types) + + # Should have function call events + self.assertIn("response.output_item.added", types) + self.assertIn("response.function_call_arguments.done", types) + + # Check the arguments done event + args_done = [e for e in events if e.type == "response.function_call_arguments.done"] + self.assertGreater(len(args_done), 0) + self.assertEqual(args_done[0].name, "get_weather") + + import json as json_mod + + parsed = json_mod.loads(args_done[0].arguments) + self.assertIsInstance(parsed, dict) + + def test_tool_call_non_streaming(self): + """Non-streaming responses with tools should include function_call output items.""" + tool_def = { + "function": { + "name": "get_weather", + "parameters": {"type": "object", "properties": {"city": {"type": "string"}}}, + "description": "Get the weather for a city.", + }, + "type": "function", } - all_payloads = self.run_server(request) - - full_text = "" - for token in all_payloads: - if isinstance(token, ResponseTextDeltaEvent): - full_text += token.delta - - # Verify that the system prompt went through. - # With deterministic decoding, exact wording can still vary across versions. - # Assert non-empty output and that it references sports. 
- self.assertTrue(len(full_text) > 0) - self.assertIn("sports", full_text.lower()) - - @slow - def test_non_streaming_request(self): - """Tests that an inference using the Responses API with stream=False returns a single Response payload.""" - from openai import OpenAI - from openai.types.responses import Response as OpenAIResponse - - client = OpenAI(base_url=f"http://localhost:{self.port}/v1", api_key="") - resp = client.responses.create( - model="Qwen/Qwen2.5-0.5B-Instruct", - instructions="You are a helpful assistant.", - input="Hello!", + resp = self.client.responses.create( + model=self.MODEL, + input="What is the weather in Paris?", stream=False, - max_output_tokens=5, + max_output_tokens=50, + tools=[tool_def], ) + self.assertEqual(resp.status, "completed") + + # Should have at least message + function_call in output + self.assertGreater(len(resp.output), 1) + fc_items = [o for o in resp.output if o.type == "function_call"] + self.assertGreater(len(fc_items), 0) + self.assertEqual(fc_items[0].name, "get_weather") - # Should be a single Response object with completed status and one output item containing text - self.assertIsInstance(resp, OpenAIResponse) + import json as json_mod + + parsed = json_mod.loads(fc_items[0].arguments) + self.assertIsInstance(parsed, dict) + + def test_tool_call_multi(self): + """Model should produce multiple tool calls when asked about two cities.""" + tool_def = { + "function": { + "name": "get_weather", + "parameters": {"type": "object", "properties": {"city": {"type": "string"}}}, + "description": "Get the weather for a city.", + }, + "type": "function", + } + events = list( + self.client.responses.create( + model=self.MODEL, + input="What is the weather in Paris and London?", + stream=True, + max_output_tokens=100, + tools=[tool_def], + ) + ) + args_done = [e for e in events if e.type == "response.function_call_arguments.done"] + self.assertEqual(len(args_done), 2, f"Expected 2 tool calls, got {len(args_done)}") + self.assertEqual(events[-1].type, "response.completed") + + def test_multi_turn(self): + """Multi-turn conversation via list input.""" + resp = self.client.responses.create( + model=self.MODEL, + input=[ + {"role": "user", "content": "My name is Alice"}, + {"role": "assistant", "content": "Nice to meet you!"}, + {"role": "user", "content": "What is my name?"}, + ], + stream=False, + ) self.assertEqual(resp.status, "completed") - self.assertTrue(len(resp.output) >= 1) - first_item = resp.output[0] - self.assertEqual(first_item.type, "message") - self.assertEqual(first_item.status, "completed") - self.assertTrue(len(first_item.content) >= 1) - first_part = first_item.content[0] - self.assertEqual(first_part.type, "output_text") - self.assertIsInstance(first_part.text, str) + self.assertIn("Alice", resp.output[0].content[0].text) + def test_concurrent_non_streaming(self): + """Two concurrent non-streaming responses requests should both complete.""" + import concurrent.futures -class ServeInfrastructureTest(unittest.TestCase): - @classmethod - def setUpClass(cls): - cls.port = 8042 - thread = Thread(target=Serve, kwargs={"port": cls.port}) - thread.daemon = True - thread.start() - - def test_healthcheck(self): - """Tests that the healthcheck endpoint works.""" - response = _call_healthcheck(f"http://localhost:{self.port}") - self.assertIsNotNone(response, "Failed to connect to the server health endpoint.") - self.assertEqual(response.status_code, 200) - self.assertEqual(response.json(), {"status": "ok"}) + inputs = ["Say hello", "Say goodbye"] 
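+        # Each worker thread below constructs its own OpenAI client, so the two concurrent
+        # requests share nothing but the server under test.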
+ results = [None, None] + def request_in_thread(index): + client = OpenAI(base_url=f"{self.base_url}/v1", api_key="unused") + results[index] = client.responses.create(model=self.MODEL, input=inputs[index], stream=False) -def parse_sse_events(response): + with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool: + futures = [pool.submit(request_in_thread, i) for i in range(2)] + concurrent.futures.wait(futures) + for f in futures: + f.result() + + for i in range(2): + self.assertIsNotNone(results[i]) + self.assertEqual(results[i].status, "completed") + self.assertTrue(len(results[i].output[0].content[0].text) > 0) + + def test_concurrent_streaming(self): + """Two concurrent streaming responses requests should both produce complete event streams.""" + import concurrent.futures + + inputs = ["Say hello", "Say goodbye"] + results = [None, None] + + def stream_in_thread(index): + client = OpenAI(base_url=f"{self.base_url}/v1", api_key="unused") + results[index] = list(client.responses.create(model=self.MODEL, input=inputs[index], stream=True)) + + with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool: + futures = [pool.submit(stream_in_thread, i) for i in range(2)] + concurrent.futures.wait(futures) + for f in futures: + f.result() + + for i in range(2): + types = [e.type for e in results[i]] + self.assertIn("response.created", types, f"Request {i} missing created event") + self.assertIn("response.output_text.delta", types, f"Request {i} missing delta events") + self.assertIn("response.completed", types, f"Request {i} missing completed event") + + +def _parse_sse_events(response): """Parse SSE lines from a streaming httpx response into a list of dicts.""" events = [] for line in response.iter_lines(): - if not line: + if not line or not line.startswith("data: "): continue - if line.startswith("data: "): - events.append(json.loads(line[6:])) + events.append(json.loads(line[len("data: ") :])) return events @slow -@require_openai -class ServeLoadModelIntegrationTest(unittest.TestCase): - """Tests the /load_model SSE endpoint.""" +@require_serve +class TestLoadModel(unittest.TestCase): + """Integration tests for POST /load_model SSE endpoint.""" + + MODEL = "Qwen/Qwen2.5-0.5B-Instruct" @classmethod def setUpClass(cls): - cls.port = 8043 - cls.server = Serve(port=cls.port, non_blocking=True) - cls.base_url = f"http://localhost:{cls.port}" - # Wait for the server to be ready - response = _call_healthcheck(cls.base_url) - assert response is not None and response.status_code == 200 + cls.serve, port = _start_serve() + cls.base_url = f"http://localhost:{port}" @classmethod def tearDownClass(cls): - cls.server.kill_server() + cls.serve.kill_server() def setUp(self): - # Clear the in-memory model cache so each test starts fresh - self.server.reset_loaded_models() + # Clear model cache so each test starts fresh + self.serve.reset_loaded_models() def _load_model(self, model: str): - with httpx.Client(timeout=120) as client: - with client.stream("POST", f"{self.base_url}/load_model", json={"model": model}) as response: - events = parse_sse_events(response) - return response, events + with httpx.stream("POST", f"{self.base_url}/load_model", json={"model": model}, timeout=120) as resp: + events = _parse_sse_events(resp) + return resp, events def test_load_model_fresh(self): - """POST /load_model with a valid model returns SSE events ending with ready.""" - response, events = self._load_model("Qwen/Qwen2.5-0.5B-Instruct") + """POST /load_model returns SSE events ending with ready.""" + 
response, events = self._load_model(self.MODEL) self.assertEqual(response.status_code, 200) - self.assertIn("text/event-stream", response.headers.get("content-type", "")) - # Extract stages from loading events stages = [e["stage"] for e in events if e["status"] == "loading" and "stage" in e] self.assertIn("processor", stages) - self.assertIn("config", stages) self.assertIn("weights", stages) - # Stages must appear in the correct order - stage_indices = {stage: i for i, stage in enumerate(stages) if stage in ("processor", "config", "weights")} - self.assertLess(stage_indices["processor"], stage_indices["config"]) - self.assertLess(stage_indices["config"], stage_indices["weights"]) - - # Last event is ready with cached: false last = events[-1] self.assertEqual(last["status"], "ready") self.assertFalse(last["cached"]) - # Every event has status and model for event in events: self.assertIn("status", event) self.assertIn("model", event) def test_load_model_cached(self): - """Loading a model that is already in memory returns a single ready event with cached: true.""" - # First load to ensure the model is in memory - self._load_model("Qwen/Qwen2.5-0.5B-Instruct") + """Loading an already-loaded model returns a single ready event with cached: true.""" + self._load_model(self.MODEL) - # Second load should be cached - _, events = self._load_model("Qwen/Qwen2.5-0.5B-Instruct") + _, events = self._load_model(self.MODEL) ready_events = [e for e in events if e["status"] == "ready"] self.assertEqual(len(ready_events), 1) self.assertTrue(ready_events[0]["cached"]) - # No loading events should be present loading_events = [e for e in events if e["status"] == "loading"] self.assertEqual(len(loading_events), 0) @@ -955,18 +1304,18 @@ def test_load_model_error(self): _, events = self._load_model("nonexistent/model-that-does-not-exist") error_events = [e for e in events if e["status"] == "error"] - self.assertGreaterEqual(len(error_events), 1, "Expected at least one error event") + self.assertGreaterEqual(len(error_events), 1) self.assertIn("message", error_events[0]) def test_load_model_missing_field(self): """POST /load_model with no model field returns 422.""" - with httpx.Client(timeout=30) as client: - response = client.post(f"{self.base_url}/load_model", json={}) - self.assertEqual(response.status_code, 422) + + response = httpx.post(f"{self.base_url}/load_model", json={}, timeout=30) + self.assertEqual(response.status_code, 422) def test_load_model_event_schema(self): - """Every event in a load_model stream conforms to the expected schema.""" - _, events = self._load_model("Qwen/Qwen2.5-0.5B-Instruct") + """Every event conforms to the expected schema.""" + _, events = self._load_model(self.MODEL) for event in events: self.assertIsInstance(event["status"], str) @@ -974,7 +1323,6 @@ def test_load_model_event_schema(self): if event["status"] == "loading": self.assertIn("stage", event) - if event["stage"] in ("download", "weights") and "progress" in event: progress = event["progress"] self.assertIn("current", progress) @@ -986,12 +1334,10 @@ def test_load_model_event_schema(self): self.assertIsInstance(event["cached"], bool) def test_load_model_stage_ordering(self): - """Stages in loading events follow the expected order.""" - _, events = self._load_model("Qwen/Qwen2.5-0.5B-Instruct") + """Stages appear in the expected order.""" + _, events = self._load_model(self.MODEL) stages = [e["stage"] for e in events if e["status"] == "loading" and "stage" in e] - - # Deduplicate while preserving order (stages repeat 
for progress ticks) seen = set() unique_stages = [] for s in stages: @@ -1000,29 +1346,23 @@ def test_load_model_stage_ordering(self): unique_stages.append(s) expected_order = ["processor", "config", "download", "weights"] - # Filter expected_order to only stages that are actually present expected_present = [s for s in expected_order if s in unique_stages] - self.assertEqual(unique_stages, expected_present, "Stages appeared out of order") def test_concurrent_load_same_model(self): - """Two concurrent /load_model requests for the same model should both receive progress events - and a final ready event, but the model should only be loaded once.""" + """Two concurrent /load_model requests both get events and a ready event.""" import concurrent.futures - model = "Qwen/Qwen2.5-0.5B-Instruct" results = [None, None] def load_in_thread(index): - with httpx.Client(timeout=120) as client: - with client.stream("POST", f"{self.base_url}/load_model", json={"model": model}) as response: - events = parse_sse_events(response) - results[index] = (response.status_code, events) + with httpx.stream("POST", f"{self.base_url}/load_model", json={"model": self.MODEL}, timeout=120) as resp: + events = _parse_sse_events(resp) + results[index] = (resp.status_code, events) with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool: futures = [pool.submit(load_in_thread, i) for i in range(2)] concurrent.futures.wait(futures) - # Re-raise any exceptions from threads for f in futures: f.result() @@ -1030,28 +1370,496 @@ def load_in_thread(index): status_code, events = results[i] self.assertEqual(status_code, 200, f"Caller {i} got non-200 status") self.assertTrue(len(events) > 0, f"Caller {i} received no events") - ready_events = [e for e in events if e["status"] == "ready"] self.assertEqual(len(ready_events), 1, f"Caller {i} should get exactly one ready event") - self.assertIn("model", ready_events[0]) - - def test_concurrent_load_second_caller_gets_cached_if_first_finishes(self): - """If the first /load_model finishes before the second arrives, - the second caller should get a cached response.""" - model = "Qwen/Qwen2.5-0.5B-Instruct" - # First load — blocks until complete - _, events1 = self._load_model(model) + def test_concurrent_load_second_caller_gets_cached(self): + """If the first /load_model finishes before the second, the second gets cached: true.""" + _, events1 = self._load_model(self.MODEL) ready1 = [e for e in events1 if e["status"] == "ready"] self.assertEqual(len(ready1), 1) self.assertFalse(ready1[0]["cached"]) - # Second load — model is now in memory - _, events2 = self._load_model(model) + _, events2 = self._load_model(self.MODEL) ready2 = [e for e in events2 if e["status"] == "ready"] self.assertEqual(len(ready2), 1) self.assertTrue(ready2[0]["cached"]) - # No loading events on the cached path loading2 = [e for e in events2 if e["status"] == "loading"] self.assertEqual(len(loading2), 0) + + def test_load_model_weights_progress_complete(self): + """Weights progress should go from 1 to total, with total matching across events.""" + _, events = self._load_model(self.MODEL) + + weights_events = [e for e in events if e.get("stage") == "weights" and "progress" in e] + self.assertGreater(len(weights_events), 0, "No weights progress events emitted") + + # All events should have the same total + totals = {e["progress"]["total"] for e in weights_events} + self.assertEqual(len(totals), 1, f"Inconsistent totals: {totals}") + total = totals.pop() + self.assertIsNotNone(total) + self.assertGreater(total, 0) 
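+        # A single-shard checkpoint legitimately reports total == 1, in which case the first
+        # and last progress events checked below coincide.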
+ + # First should be 1, last should be total + self.assertEqual(weights_events[0]["progress"]["current"], 1) + self.assertEqual(weights_events[-1]["progress"]["current"], total) + + # Progress should be monotonically increasing + currents = [e["progress"]["current"] for e in weights_events] + self.assertEqual(currents, sorted(currents)) + + def test_load_model_exactly_one_ready(self): + """A fresh load should produce exactly one ready event as the last event.""" + _, events = self._load_model(self.MODEL) + + ready_events = [e for e in events if e["status"] == "ready"] + self.assertEqual(len(ready_events), 1) + self.assertEqual(events[-1]["status"], "ready") + + def test_load_model_usable_after_load(self): + """After /load_model completes, the model should be usable for inference.""" + self._load_model(self.MODEL) + + client = OpenAI(base_url=f"{self.base_url}/v1", api_key="unused") + resp = client.chat.completions.create( + model=self.MODEL, + messages=[{"role": "user", "content": "Say hi"}], + max_tokens=5, + ) + self.assertIsNotNone(resp.choices[0].message.content) + self.assertTrue(len(resp.choices[0].message.content) > 0) + + def test_load_model_model_field_matches(self): + """The model field in every event should match the canonical model ID.""" + _, events = self._load_model(self.MODEL) + + for event in events: + self.assertTrue( + event["model"].startswith(self.MODEL), + f"Event model '{event['model']}' doesn't match '{self.MODEL}'", + ) + + def test_concurrent_non_streaming(self): + """Two concurrent non-streaming responses requests should both complete.""" + import concurrent.futures + + inputs = ["Say hello", "Say goodbye"] + results = [None, None] + + def request_in_thread(index): + client = OpenAI(base_url=f"{self.base_url}/v1", api_key="unused") + results[index] = client.responses.create(model=self.MODEL, input=inputs[index], stream=False) + + with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool: + futures = [pool.submit(request_in_thread, i) for i in range(2)] + concurrent.futures.wait(futures) + for f in futures: + f.result() + + for i in range(2): + self.assertIsNotNone(results[i]) + self.assertEqual(results[i].status, "completed") + self.assertTrue(len(results[i].output[0].content[0].text) > 0) + + def test_concurrent_streaming(self): + """Two concurrent streaming responses requests should both produce complete event streams.""" + import concurrent.futures + + inputs = ["Say hello", "Say goodbye"] + results = [None, None] + + def stream_in_thread(index): + client = OpenAI(base_url=f"{self.base_url}/v1", api_key="unused") + events = list(client.responses.create(model=self.MODEL, input=inputs[index], stream=True)) + results[index] = events + + with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool: + futures = [pool.submit(stream_in_thread, i) for i in range(2)] + concurrent.futures.wait(futures) + for f in futures: + f.result() + + for i in range(2): + types = [e.type for e in results[i]] + self.assertIn("response.created", types, f"Request {i} missing created event") + self.assertIn("response.output_text.delta", types, f"Request {i} missing delta events") + self.assertIn("response.completed", types, f"Request {i} missing completed event") + + +# Real image URL for VLM tests (person + dog on a beach) +_DOG_IMAGE_URL = "https://qianwen-res.oss-accelerate-overseas.aliyuncs.com/Qwen2-VL/demo_small.jpg" + + +@slow +@require_vision +@require_serve +class TestVLM(unittest.TestCase): + """Integration tests for VLM (vision-language model) support. 
Requires torchvision.""" + + MODEL = "HuggingFaceTB/SmolVLM-256M-Instruct" + + @classmethod + def setUpClass(cls): + cls.serve, port = _start_serve() + cls.base_url = f"http://localhost:{port}" + cls.client = OpenAI(base_url=f"{cls.base_url}/v1", api_key="unused") + + @classmethod + def tearDownClass(cls): + cls.serve.kill_server() + + def test_chat_completion_with_image(self): + """Chat completions should accept image_url content and produce a meaningful response.""" + resp = self.client.chat.completions.create( + model=self.MODEL, + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": "What do you see in this image?"}, + {"type": "image_url", "image_url": {"url": _DOG_IMAGE_URL}}, + ], + } + ], + max_tokens=50, + ) + text = resp.choices[0].message.content + self.assertIsNotNone(text) + self.assertTrue( + any(word in text.lower() for word in ["dog", "beach", "person"]), + f"Expected dog/beach/person in response, got: {text}", + ) + + def test_responses_with_image(self): + """Responses API should accept image_url content and produce a meaningful response.""" + resp = self.client.responses.create( + model=self.MODEL, + input=[ + { + "role": "user", + "content": [ + {"type": "text", "text": "What do you see in this image?"}, + {"type": "image_url", "image_url": {"url": _DOG_IMAGE_URL}}, + ], + } + ], + stream=False, + max_output_tokens=50, + ) + self.assertEqual(resp.status, "completed") + text = resp.output[0].content[0].text + self.assertTrue( + any(word in text.lower() for word in ["dog", "beach", "person"]), + f"Expected dog/beach/person in response, got: {text}", + ) + + +@slow +@require_serve +class TestTranscription(unittest.TestCase): + """Integration tests for POST /v1/audio/transcriptions with whisper-tiny.""" + + MODEL = "openai/whisper-tiny" + + @classmethod + def setUpClass(cls): + cls.serve, port = _start_serve() + cls.base_url = f"http://localhost:{port}" + + @classmethod + def tearDownClass(cls): + cls.serve.kill_server() + + @classmethod + def _get_audio_bytes(cls): + """Download the MLK 'I have a dream' speech sample from HF Hub.""" + if not hasattr(cls, "_audio_bytes"): + from huggingface_hub import hf_hub_download + + path = hf_hub_download("Narsil/asr_dummy", "mlk.flac", repo_type="dataset") + with open(path, "rb") as f: + cls._audio_bytes = f.read() + return cls._audio_bytes + + def test_transcription_returns_text(self): + """POST /v1/audio/transcriptions with real speech returns meaningful transcription.""" + + audio_bytes = self._get_audio_bytes() + resp = httpx.post( + f"{self.base_url}/v1/audio/transcriptions", + files={"file": ("mlk.flac", audio_bytes, "audio/flac")}, + data={"model": self.MODEL}, + timeout=120, + ) + self.assertEqual(resp.status_code, 200) + data = resp.json() + self.assertIn("text", data) + self.assertIsInstance(data["text"], str) + # Whisper-tiny should recognize at least "dream" from the MLK speech + self.assertIn("dream", data["text"].lower()) + + def test_transcription_openai_client(self): + """Transcription should work via the OpenAI Python client.""" + audio_bytes = self._get_audio_bytes() + client = OpenAI(base_url=f"{self.base_url}/v1", api_key="unused") + result = client.audio.transcriptions.create( + model=self.MODEL, + file=("mlk.flac", audio_bytes), + ) + self.assertIsInstance(result.text, str) + self.assertTrue(len(result.text) > 10) + + def test_transcription_streaming(self): + """Streaming transcription should yield text chunks via SSE.""" + + audio_bytes = self._get_audio_bytes() + with httpx.stream( + 
"POST", + f"{self.base_url}/v1/audio/transcriptions", + files={"file": ("mlk.flac", audio_bytes, "audio/flac")}, + data={"model": self.MODEL, "stream": "true"}, + timeout=120, + ) as resp: + self.assertEqual(resp.status_code, 200) + + chunks = [] + for line in resp.iter_lines(): + if line and line.startswith("data: "): + chunks.append(line[len("data: ") :]) + + self.assertGreater(len(chunks), 0, "No streaming chunks received") + full_text = "".join(chunks) + self.assertIn("dream", full_text.lower()) + + def test_transcription_missing_file(self): + """POST without a file should fail.""" + + resp = httpx.post( + f"{self.base_url}/v1/audio/transcriptions", + data={"model": self.MODEL}, + timeout=30, + ) + self.assertNotEqual(resp.status_code, 200) + + +@slow +@require_serve +@require_torch_accelerator +class TestContinuousBatchingChatCompletion(unittest.TestCase): + """Integration tests for /v1/chat/completions with continuous batching.""" + + MODEL = "Qwen/Qwen2.5-0.5B-Instruct" + + @classmethod + def setUpClass(cls): + cls.serve, port = _start_serve( + force_model=cls.MODEL, + device="cuda:0", + continuous_batching=True, + attn_implementation="sdpa", + default_seed=42, + ) + cls.base_url = f"http://localhost:{port}" + cls.client = OpenAI(base_url=f"{cls.base_url}/v1", api_key="unused") + + @classmethod + def tearDownClass(cls): + cls.serve.kill_server() + + def test_streaming(self): + """Streaming chat completion with CB produces text.""" + text = "" + for chunk in self.client.chat.completions.create( + model=self.MODEL, + messages=[ + {"role": "system", "content": "You are a sports assistant designed to craft sports programs."}, + {"role": "user", "content": "Tell me what you can do."}, + ], + stream=True, + max_tokens=30, + ): + if chunk.choices[0].delta.content: + text += chunk.choices[0].delta.content + self.assertTrue(len(text) > 0) + + def test_non_streaming(self): + """Non-streaming chat completion with CB returns a full response.""" + resp = self.client.chat.completions.create( + model=self.MODEL, + messages=[{"role": "user", "content": "Say hello"}], + max_tokens=20, + ) + self.assertIsNotNone(resp.choices[0].message.content) + self.assertTrue(len(resp.choices[0].message.content) > 0) + + def test_multi_turn(self): + """Multi-turn conversation works with CB.""" + resp = self.client.chat.completions.create( + model=self.MODEL, + messages=[ + {"role": "user", "content": "My name is Alice"}, + {"role": "assistant", "content": "Nice to meet you!"}, + {"role": "user", "content": "What is my name?"}, + ], + max_tokens=20, + ) + self.assertIn("Alice", resp.choices[0].message.content) + + def test_request_cancellation(self): + """Opening a stream and closing it early triggers CB cancellation.""" + + request_id = "test-cb-cancel" + + # Open a streaming request and close after a few chunks + with httpx.stream( + "POST", + f"{self.base_url}/v1/chat/completions", + headers={"X-Request-ID": request_id}, + json={ + "model": self.MODEL, + "stream": True, + "messages": [{"role": "user", "content": "Count slowly so I can cancel you."}], + }, + timeout=30, + ) as resp: + self.assertEqual(resp.status_code, 200) + chunks_read = 0 + for _ in resp.iter_lines(): + chunks_read += 1 + if chunks_read >= 3: + break + + # Poll for cancellation in the CB scheduler + scheduler = self.serve._generation_state._cb_manager.scheduler + deadline = time.time() + 8.0 + while time.time() < deadline: + if scheduler.request_is_cancelled(request_id): + break + time.sleep(0.1) + + self.assertTrue( + 
scheduler.request_is_cancelled(request_id), + f"Request {request_id} not cancelled in scheduler after stream close.", + ) + + # Server should still be healthy and serve subsequent requests + resp = self.client.chat.completions.create( + model=self.MODEL, + messages=[{"role": "user", "content": "Say hi"}], + max_tokens=10, + ) + self.assertIsNotNone(resp.choices[0].message.content) + + +@slow +@require_serve +@require_torch_accelerator +class TestContinuousBatchingResponses(unittest.TestCase): + """Integration tests for /v1/responses with continuous batching.""" + + MODEL = "Qwen/Qwen2.5-0.5B-Instruct" + + @classmethod + def setUpClass(cls): + cls.serve, port = _start_serve( + force_model=cls.MODEL, + device="cuda:0", + continuous_batching=True, + attn_implementation="sdpa", + default_seed=42, + ) + cls.base_url = f"http://localhost:{port}" + cls.client = OpenAI(base_url=f"{cls.base_url}/v1", api_key="unused") + + @classmethod + def tearDownClass(cls): + cls.serve.kill_server() + + def test_streaming(self): + """Streaming response with CB produces text.""" + text = "" + stream = self.client.responses.create( + model=self.MODEL, + input="Say hello in one sentence.", + stream=True, + max_output_tokens=30, + ) + for event in stream: + if event.type == "response.output_text.delta": + text += event.delta + self.assertTrue(len(text) > 0) + + def test_non_streaming(self): + """Non-streaming response with CB returns text.""" + resp = self.client.responses.create( + model=self.MODEL, + input="Say hello in one sentence.", + stream=False, + max_output_tokens=30, + ) + content = resp.output[0].content[0].text + self.assertTrue(len(content) > 0) + + def test_multi_turn(self): + """Multi-turn conversation works with CB via Responses API.""" + resp = self.client.responses.create( + model=self.MODEL, + input=[ + {"role": "user", "content": "My name is Alice"}, + {"role": "assistant", "content": "Nice to meet you!"}, + {"role": "user", "content": "What is my name?"}, + ], + stream=False, + max_output_tokens=20, + ) + content = resp.output[0].content[0].text + self.assertIn("Alice", content) + + def test_request_cancellation(self): + """Opening a stream and closing it early triggers CB cancellation.""" + + request_id = "test-cb-resp-cancel" + + with httpx.stream( + "POST", + f"{self.base_url}/v1/responses", + headers={"X-Request-ID": request_id}, + json={ + "model": self.MODEL, + "stream": True, + "input": "Count slowly so I can cancel you.", + "max_output_tokens": 500, + }, + timeout=30, + ) as resp: + self.assertEqual(resp.status_code, 200) + # Read enough data to ensure CB generation has started, then close. 
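+            # Waiting for the first output_text.delta guarantees the request is already in flight
+            # in the continuous-batching scheduler, so closing the stream here exercises cancellation.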
+ received = b"" + for chunk in resp.iter_bytes(chunk_size=512): + received += chunk + if b"output_text.delta" in received: + break + + # Poll for cancellation in the CB scheduler + scheduler = self.serve._generation_state._cb_manager.scheduler + deadline = time.time() + 8.0 + while time.time() < deadline: + if scheduler.request_is_cancelled(request_id): + break + time.sleep(0.1) + + self.assertTrue( + scheduler.request_is_cancelled(request_id), + f"Request {request_id} not cancelled in scheduler after stream close.", + ) + + # Server should still serve subsequent requests + resp = self.client.responses.create( + model=self.MODEL, + input="Say hi", + stream=False, + max_output_tokens=10, + ) + self.assertTrue(len(resp.output[0].content[0].text) > 0) diff --git a/tests/cli/test_serve_refactored.py b/tests/cli/test_serve_refactored.py deleted file mode 100644 index 63503cbe0bf9..000000000000 --- a/tests/cli/test_serve_refactored.py +++ /dev/null @@ -1,1865 +0,0 @@ -# Copyright 2025 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Tests for the serving layer. - -""" - -import asyncio -import json -import os -import socket -import time -import unittest -from unittest.mock import MagicMock - -import httpx - -from transformers.cli.serve import Serve -from transformers.cli.serving.chat_completion import ChatCompletionHandler -from transformers.cli.serving.model_manager import ModelManager, TimedModel -from transformers.cli.serving.response import ResponseHandler, compute_usage -from transformers.cli.serving.server import build_server -from transformers.cli.serving.transcription import TranscriptionHandler -from transformers.cli.serving.utils import ( - BaseHandler, - GenerationState, - Modality, - ToolCallParser, - detect_tool_format, -) -from transformers.testing_utils import require_serve, require_torch_accelerator, require_vision, slow -from transformers.utils.import_utils import is_serve_available - - -if is_serve_available(): - from fastapi import HTTPException - from openai import OpenAI - from openai.types.responses import Response, ResponseCreatedEvent - - -def _find_free_port() -> int: - """Return a free TCP port on localhost.""" - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind(("localhost", 0)) - return s.getsockname()[1] - - -def _start_serve(**kwargs) -> tuple["Serve", int]: - """Start a non-blocking Serve instance on a free port and wait until healthy. - - Returns ``(serve, port)``. 
- """ - port = _find_free_port() - serve = Serve(port=port, non_blocking=True, **kwargs) - for _ in range(30): - try: - if httpx.get(f"http://localhost:{port}/health", timeout=2).status_code == 200: - return serve, port - except Exception: # noqa: S110 - pass - time.sleep(1) - raise RuntimeError(f"Server on port {port} did not become healthy in time") - - -@require_serve -def test_host_port_blocking(cli): - """CLI args --host and --port are passed to uvicorn.Config, and server.run() is called.""" - from unittest.mock import Mock, patch - - with ( - patch("uvicorn.Config") as ConfigMock, - patch("uvicorn.Server") as ServerMock, - ): - server_instance = Mock() - ServerMock.return_value = server_instance - - out = cli("serve", "--host", "0.0.0.0", "--port", "9000") - _, kwargs = ConfigMock.call_args - - assert out.exit_code == 0 - assert kwargs["host"] == "0.0.0.0" - assert kwargs["port"] == 9000 - ServerMock.assert_called_once_with(ConfigMock.return_value) - server_instance.run.assert_called_once() - - -class TestProcessorInputsFromMessages(unittest.TestCase): - def test_llm_string_content(self): - get_processor_inputs_from_messages = BaseHandler.get_processor_inputs_from_messages - - messages = [{"role": "user", "content": "Hello"}] - result = get_processor_inputs_from_messages(messages, Modality.LLM) - self.assertEqual(result, [{"role": "user", "content": "Hello"}]) - - def test_llm_list_content_text_only(self): - get_processor_inputs_from_messages = BaseHandler.get_processor_inputs_from_messages - - messages = [{"role": "user", "content": [{"type": "text", "text": "A"}, {"type": "text", "text": "B"}]}] - result = get_processor_inputs_from_messages(messages, Modality.LLM) - self.assertEqual(result, [{"role": "user", "content": "A B"}]) - - def test_vlm_string_content_wrapped(self): - get_processor_inputs_from_messages = BaseHandler.get_processor_inputs_from_messages - - messages = [{"role": "user", "content": "Hello"}] - result = get_processor_inputs_from_messages(messages, Modality.VLM) - self.assertEqual(result, [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}]) - - def test_vlm_text_and_image_url(self): - get_processor_inputs_from_messages = BaseHandler.get_processor_inputs_from_messages - - messages = [ - { - "role": "user", - "content": [ - {"type": "text", "text": "What is this?"}, - {"type": "image_url", "image_url": {"url": "https://example.com/img.png"}}, - ], - } - ] - result = get_processor_inputs_from_messages(messages, Modality.VLM) - self.assertEqual(len(result[0]["content"]), 2) - self.assertEqual(result[0]["content"][0]["type"], "text") - self.assertEqual(result[0]["content"][1], {"type": "image", "url": "https://example.com/img.png"}) - - def test_llm_multi_turn_conversation(self): - """Multi-turn conversation with string content should pass through as-is.""" - - get_processor_inputs_from_messages = BaseHandler.get_processor_inputs_from_messages - - messages = [ - {"role": "user", "content": "How are you?"}, - {"role": "assistant", "content": "I'm great!"}, - {"role": "user", "content": "Help me write tests?"}, - ] - result = get_processor_inputs_from_messages(messages, Modality.LLM) - self.assertEqual(len(result), 3) - self.assertEqual(result[0]["content"], "How are you?") - self.assertEqual(result[1]["role"], "assistant") - self.assertEqual(result[2]["content"], "Help me write tests?") - - def test_llm_list_content_with_type(self): - """LLM messages with typed content list should extract text and join.""" - - get_processor_inputs_from_messages = 
BaseHandler.get_processor_inputs_from_messages - - messages = [ - {"role": "user", "content": [{"type": "text", "text": "Hello"}, {"type": "text", "text": "world"}]} - ] - result = get_processor_inputs_from_messages(messages, Modality.LLM) - self.assertEqual(result[0]["content"], "Hello world") - - @require_vision - def test_vlm_base64_image_creates_temp_file(self): - """Base64 image URLs should be decoded and saved to a temp file.""" - - get_processor_inputs_from_messages = BaseHandler.get_processor_inputs_from_messages - - # Minimal valid 1x1 PNG as base64 - base64_url = ( - "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR4" - "2mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==" - ) - messages = [ - { - "role": "user", - "content": [ - {"type": "text", "text": "What is this?"}, - {"type": "image_url", "image_url": {"url": base64_url}}, - ], - } - ] - result = get_processor_inputs_from_messages(messages, Modality.VLM) - image_item = result[0]["content"][1] - self.assertEqual(image_item["type"], "image") - self.assertTrue(os.path.exists(image_item["url"])) # temp file was created - - def test_vlm_multi_turn(self): - """VLM multi-turn: string content should be wrapped in text type.""" - - get_processor_inputs_from_messages = BaseHandler.get_processor_inputs_from_messages - - messages = [ - {"role": "user", "content": "Describe the image"}, - {"role": "assistant", "content": "It shows a cat"}, - {"role": "user", "content": "What color?"}, - ] - result = get_processor_inputs_from_messages(messages, Modality.VLM) - self.assertEqual(len(result), 3) - for msg in result: - self.assertIsInstance(msg["content"], list) - self.assertEqual(msg["content"][0]["type"], "text") - - -class TestGenerativeModelList(unittest.TestCase): - def test_lists_only_generative_models(self): - """Should list LLMs and VLMs but not non-generative models like BERT.""" - import tempfile - - from huggingface_hub import hf_hub_download - - with tempfile.TemporaryDirectory() as cache_dir: - # Download config.json for a few models - hf_hub_download("Qwen/Qwen2.5-0.5B-Instruct", "config.json", cache_dir=cache_dir) - hf_hub_download("google-bert/bert-base-cased", "config.json", cache_dir=cache_dir) - - result = ModelManager.get_gen_models(cache_dir) - model_ids = {r["id"] for r in result} - - self.assertIn("Qwen/Qwen2.5-0.5B-Instruct", model_ids) - self.assertNotIn("google-bert/bert-base-cased", model_ids) - - -@require_serve -class TestBuildGenerationConfig(unittest.TestCase): - def _make_handler(self): - return ChatCompletionHandler(model_manager=MagicMock(), generation_state=GenerationState()) - - def test_max_tokens(self): - from transformers import GenerationConfig - - result = self._make_handler()._build_generation_config({"max_tokens": 7}, GenerationConfig()) - self.assertEqual(result.max_new_tokens, 7) - - def test_temperature_zero_disables_sampling(self): - from transformers import GenerationConfig - - result = self._make_handler()._build_generation_config({"temperature": 0.0}, GenerationConfig(do_sample=True)) - self.assertFalse(result.do_sample) - - def test_frequency_penalty(self): - from transformers import GenerationConfig - - result = self._make_handler()._build_generation_config({"frequency_penalty": 0.5}, GenerationConfig()) - self.assertAlmostEqual(result.repetition_penalty, 1.5) - - def test_logit_bias_tuple_keys(self): - from transformers import GenerationConfig - - result = self._make_handler()._build_generation_config({"logit_bias": {"42": 1.0}}, GenerationConfig()) - 
self.assertEqual(result.sequence_bias, {(42,): 1.0}) - - def test_stop_strings(self): - from transformers import GenerationConfig - - result = self._make_handler()._build_generation_config({"stop": [""]}, GenerationConfig()) - self.assertEqual(result.stop_strings, [""]) - - def test_generation_config_json_overrides(self): - from transformers import GenerationConfig - - custom = GenerationConfig(max_new_tokens=5, do_sample=False) - result = self._make_handler()._build_generation_config( - {"generation_config": custom.to_json_string()}, GenerationConfig(max_new_tokens=100) - ) - self.assertEqual(result.max_new_tokens, 5) - self.assertFalse(result.do_sample) - - def test_generation_config_json_no_defaults_applied(self): - """When generation_config JSON is passed, serving defaults should NOT be applied.""" - from transformers import GenerationConfig - - custom = GenerationConfig(max_new_tokens=10) - result = self._make_handler()._build_generation_config( - {"generation_config": custom.to_json_string()}, GenerationConfig() - ) - # Should keep 10, not bump to 1024 - self.assertEqual(result.max_new_tokens, 10) - - def test_default_bumps_short_max_new_tokens(self): - from transformers import GenerationConfig - - result = self._make_handler()._build_generation_config({}, GenerationConfig(max_new_tokens=20)) - self.assertEqual(result.max_new_tokens, 1024) - - def test_user_max_tokens_overrides_default(self): - """User's max_tokens should win over the serving default.""" - from transformers import GenerationConfig - - result = self._make_handler()._build_generation_config({"max_tokens": 50}, GenerationConfig(max_new_tokens=20)) - self.assertEqual(result.max_new_tokens, 50) - - -@require_serve -class TestValidation(unittest.TestCase): - def _make_handler(self): - return ChatCompletionHandler(model_manager=MagicMock(), generation_state=GenerationState()) - - def test_valid_request_passes(self): - handler = self._make_handler() - # Should not raise - handler._validate_request({"model": "x", "messages": [{"role": "user", "content": "hi"}], "stream": True}) - - def test_unexpected_keys_rejected(self): - handler = self._make_handler() - with self.assertRaises(HTTPException) as ctx: - handler._validate_request({"model": "x", "messages": [], "bogus_field": True}) - self.assertEqual(ctx.exception.status_code, 422) - self.assertIn("bogus_field", ctx.exception.detail) - - def test_unsupported_fields_warns(self): - handler = self._make_handler() - with self.assertLogs("transformers", level="WARNING") as cm: - handler._validate_request({"model": "x", "messages": [], "audio": {}}) - self.assertTrue(any("audio" in msg for msg in cm.output)) - - -class TestModelManager(unittest.TestCase): - def test_process_model_name_adds_main(self): - self.assertEqual(ModelManager.process_model_name("org/model"), "org/model@main") - - def test_process_model_name_preserves_revision(self): - self.assertEqual(ModelManager.process_model_name("org/model@dev"), "org/model@dev") - - def test_quantization_config_4bit(self): - mm = ModelManager(quantization="bnb-4bit") - cfg = mm.get_quantization_config() - self.assertTrue(cfg.load_in_4bit) - - def test_quantization_config_8bit(self): - mm = ModelManager(quantization="bnb-8bit") - cfg = mm.get_quantization_config() - self.assertTrue(cfg.load_in_8bit) - - def test_quantization_config_none(self): - mm = ModelManager() - self.assertIsNone(mm.get_quantization_config()) - - -class TestTimedModel(unittest.TestCase): - def test_delete_model(self): - mock_model = MagicMock() - deleted = [] - timed 
= TimedModel( - mock_model, timeout_seconds=9999, processor=MagicMock(), on_unload=lambda: deleted.append(True) - ) - self.assertIsNotNone(timed.model) - timed.delete_model() - self.assertIsNone(timed.model) - self.assertEqual(len(deleted), 1) - - def test_timeout_zero_no_delete(self): - mock_model = MagicMock() - timed = TimedModel(mock_model, timeout_seconds=0, processor=MagicMock()) - timed._timeout_reached() - self.assertIsNotNone(timed.model) - timed._timer.cancel() - - -@require_serve -class TestChunkSSE(unittest.TestCase): - def _make_handler(self): - return ChatCompletionHandler(model_manager=MagicMock(), generation_state=GenerationState()) - - def test_build_chunk_sse_content(self): - handler = self._make_handler() - sse = handler._build_chunk_sse(request_id="req1", content="hi", model="m") - self.assertTrue(sse.startswith("data: ")) - self.assertTrue(sse.endswith("\n\n")) - parsed = json.loads(sse[len("data: ") :].strip()) - self.assertEqual(parsed["choices"][0]["delta"]["content"], "hi") - - def test_build_chunk_sse_role(self): - handler = self._make_handler() - sse = handler._build_chunk_sse(request_id="req1", role="assistant", model="m") - parsed = json.loads(sse[len("data: ") :].strip()) - self.assertEqual(parsed["choices"][0]["delta"]["role"], "assistant") - self.assertNotIn("content", parsed["choices"][0]["delta"]) - - def test_build_chunk_sse_finish_reason(self): - handler = self._make_handler() - sse = handler._build_chunk_sse(request_id="req1", finish_reason="stop", model="m") - parsed = json.loads(sse[len("data: ") :].strip()) - self.assertEqual(parsed["choices"][0]["finish_reason"], "stop") - - def test_chunk_to_sse_string_passthrough(self): - result = BaseHandler.chunk_to_sse("data: already formatted\n\n") - self.assertEqual(result, "data: already formatted\n\n") - - def test_chunk_to_sse_wraps_plain_string(self): - result = BaseHandler.chunk_to_sse("hello") - self.assertEqual(result, "data: hello\n\n") - - -QWEN_TOOL_FORMAT = {"start": "", "end": ""} - - -@require_serve -class TestToolParser(unittest.TestCase): - def test_detect_tool_format_qwen(self): - model = MagicMock() - model.config.architectures = ["Qwen2ForCausalLM"] - fmt = detect_tool_format(model) - self.assertEqual(fmt, QWEN_TOOL_FORMAT) - - def test_detect_tool_format_unsupported(self): - model = MagicMock() - model.config.architectures = ["LlamaForCausalLM"] - self.assertIsNone(detect_tool_format(model)) - - def test_parser_start_token(self): - parser = ToolCallParser(QWEN_TOOL_FORMAT) - result = parser.feed("") - self.assertIs(result, ToolCallParser.CONSUMED) - - def test_parser_end_token(self): - parser = ToolCallParser(QWEN_TOOL_FORMAT) - parser.feed("") - result = parser.feed("") - self.assertIs(result, ToolCallParser.CONSUMED) - - def test_parser_buffers_until_end(self): - parser = ToolCallParser(QWEN_TOOL_FORMAT) - parser.feed("") - # Intermediate tokens are buffered - result = parser.feed('{"name": "my_tool", "arguments": {"x": 1}}') - self.assertIs(result, ToolCallParser.CONSUMED) - # Tool call is emitted on end token - result = parser.feed("") - self.assertIsNot(result, ToolCallParser.CONSUMED) - self.assertEqual(result["name"], "my_tool") - - def test_parser_normal_text_returns_none(self): - parser = ToolCallParser(QWEN_TOOL_FORMAT) - result = parser.feed("Hello world") - self.assertIsNone(result) - - def test_parser_full_flow(self): - """Simulate a complete tool call token sequence.""" - - parser = ToolCallParser(QWEN_TOOL_FORMAT) - tool_calls = [] - - for token in [ - "", - '{"name": 
"get_weather",', - ' "arguments": {', - '"city": "Paris"', - "}}", - "\n", - "", - ]: - result = parser.feed(token) - if result is not None and result is not ToolCallParser.CONSUMED: - tool_calls.append(result) - - # Single tool call emitted on with both name and arguments - self.assertEqual(len(tool_calls), 1) - self.assertEqual(tool_calls[0]["name"], "get_weather") - self.assertIn("Paris", tool_calls[0]["arguments"]) - - def test_parse_tool_calls_from_text(self): - """Non-streaming tool call parsing from complete text.""" - - text = '\n{"name": "get_weather", "arguments": {"city": "Paris"}}\n' - calls = ToolCallParser.parse(text, QWEN_TOOL_FORMAT) - self.assertIsNotNone(calls) - self.assertEqual(len(calls), 1) - self.assertEqual(calls[0]["name"], "get_weather") - self.assertIn("Paris", calls[0]["arguments"]) - - def test_parse_tool_calls_no_tool_call(self): - """Non-streaming: normal text returns None.""" - - calls = ToolCallParser.parse("Hello, how can I help?", QWEN_TOOL_FORMAT) - self.assertIsNone(calls) - - def test_parse_multiple_tool_calls(self): - """Non-streaming: multiple tool calls in one response.""" - - text = ( - '\n{"name": "get_weather", "arguments": {"city": "Paris"}}\n\n' - '\n{"name": "get_weather", "arguments": {"city": "London"}}\n' - ) - calls = ToolCallParser.parse(text, QWEN_TOOL_FORMAT) - self.assertIsNotNone(calls) - self.assertEqual(len(calls), 2) - self.assertEqual(calls[0]["name"], "get_weather") - self.assertIn("Paris", calls[0]["arguments"]) - self.assertEqual(calls[1]["name"], "get_weather") - self.assertIn("London", calls[1]["arguments"]) - - def test_feed_multiple_tool_calls(self): - """Streaming: multiple tool calls emitted sequentially.""" - - parser = ToolCallParser(QWEN_TOOL_FORMAT) - tool_calls = [] - - tokens = [ - "", - '{"name": "get_weather", "arguments": {"city": "Paris"}}', - "", - "", - '{"name": "get_weather", "arguments": {"city": "London"}}', - "", - ] - for token in tokens: - result = parser.feed(token) - if result is not None and result is not ToolCallParser.CONSUMED: - tool_calls.append(result) - - self.assertEqual(len(tool_calls), 2) - self.assertEqual(tool_calls[0]["name"], "get_weather") - self.assertIn("Paris", tool_calls[0]["arguments"]) - self.assertEqual(tool_calls[1]["name"], "get_weather") - self.assertIn("London", tool_calls[1]["arguments"]) - - -@require_serve -class TestAppRoutes(unittest.TestCase): - @classmethod - def setUpClass(cls): - cls.model_manager = MagicMock(spec=ModelManager) - cls.model_manager.get_gen_models.return_value = [ - {"id": "test/model", "owned_by": "test", "object": "model", "created": 0} - ] - cls.chat_handler = MagicMock(spec=ChatCompletionHandler) - cls.response_handler = MagicMock(spec=ResponseHandler) - cls.transcription_handler = MagicMock(spec=TranscriptionHandler) - cls.app = build_server(cls.model_manager, cls.chat_handler, cls.response_handler, cls.transcription_handler) - cls.transport = httpx.ASGITransport(app=cls.app) - - async def _request(self, method: str, path: str, **kwargs) -> httpx.Response: - async with httpx.AsyncClient(transport=self.transport, base_url="http://test") as c: - return await c.request(method, path, **kwargs) - - def test_health(self): - resp = asyncio.get_event_loop().run_until_complete(self._request("GET", "/health")) - self.assertEqual(resp.status_code, 200) - self.assertEqual(resp.json(), {"status": "ok"}) - - def test_models_list(self): - resp = asyncio.get_event_loop().run_until_complete(self._request("GET", "/v1/models")) - self.assertEqual(resp.status_code, 
200) - data = resp.json() - self.assertEqual(data["object"], "list") - self.assertEqual(len(data["data"]), 1) - - def test_request_id_generated(self): - resp = asyncio.get_event_loop().run_until_complete(self._request("GET", "/health")) - self.assertIn("x-request-id", resp.headers) - self.assertEqual(len(resp.headers["x-request-id"]), 36) # UUID length - - def test_request_id_passthrough(self): - resp = asyncio.get_event_loop().run_until_complete( - self._request("GET", "/health", headers={"x-request-id": "my-id"}) - ) - self.assertEqual(resp.headers["x-request-id"], "my-id") - - -@slow -@require_serve -class TestChatCompletion(unittest.TestCase): - """Integration tests for /v1/chat/completions with a real model.""" - - MODEL = "Qwen/Qwen2.5-0.5B-Instruct" - - @classmethod - def setUpClass(cls): - cls.serve, port = _start_serve() - cls.base_url = f"http://localhost:{port}" - cls.client = OpenAI(base_url=f"{cls.base_url}/v1", api_key="unused") - - @classmethod - def tearDownClass(cls): - cls.serve.kill_server() - - def test_non_streaming(self): - resp = self.client.chat.completions.create( - model=self.MODEL, messages=[{"role": "user", "content": "Say hello"}] - ) - self.assertIsNotNone(resp.choices[0].message.content) - self.assertIn(resp.choices[0].finish_reason, ("stop", "length")) - - def test_streaming(self): - text = "" - for chunk in self.client.chat.completions.create( - model=self.MODEL, messages=[{"role": "user", "content": "Say hello"}], stream=True - ): - if chunk.choices[0].delta.content: - text += chunk.choices[0].delta.content - self.assertTrue(len(text) > 0) - - def test_early_return_due_to_length(self): - """When max_tokens is hit, finish_reason should be 'length'.""" - chunks = list( - self.client.chat.completions.create( - model=self.MODEL, - messages=[{"role": "user", "content": "Hello, how are you?"}], - stream=True, - max_tokens=3, - ) - ) - last = chunks[-1] - self.assertEqual(last.choices[0].finish_reason, "length") - - def test_continues_until_stop(self): - """When model stops naturally, finish_reason should be 'stop'.""" - chunks = list( - self.client.chat.completions.create( - model=self.MODEL, - messages=[{"role": "user", "content": 'Please only answer with "Hi."'}], - stream=True, - max_tokens=30, - ) - ) - last = chunks[-1] - self.assertEqual(last.choices[0].finish_reason, "stop") - - def test_stop_strings(self): - resp = self.client.chat.completions.create( - model=self.MODEL, messages=[{"role": "user", "content": "Count to 10"}], stop=["5"] - ) - self.assertNotIn("6", resp.choices[0].message.content) - - def test_multi_turn(self): - resp = self.client.chat.completions.create( - model=self.MODEL, - messages=[ - {"role": "user", "content": "My name is Alice"}, - {"role": "assistant", "content": "Nice to meet you!"}, - {"role": "user", "content": "What is my name?"}, - ], - ) - self.assertIn("Alice", resp.choices[0].message.content) - - def test_multiple_models_on_demand(self): - """Load two different models via separate requests — both should work.""" - model_a = "Qwen/Qwen2.5-0.5B-Instruct" - model_b = "HuggingFaceTB/SmolLM2-135M-Instruct" - prompt = [{"role": "user", "content": "Say hello"}] - - resp_a = self.client.chat.completions.create(model=model_a, messages=prompt) - self.assertIn(model_a, resp_a.model) - self.assertIsNotNone(resp_a.choices[0].message.content) - - resp_b = self.client.chat.completions.create(model=model_b, messages=prompt) - self.assertIn(model_b, resp_b.model) - self.assertIsNotNone(resp_b.choices[0].message.content) - - def 
test_non_streaming_usage(self): - resp = self.client.chat.completions.create( - model=self.MODEL, messages=[{"role": "user", "content": "Say hello"}] - ) - self.assertIsNotNone(resp.usage) - self.assertGreater(resp.usage.prompt_tokens, 0) - self.assertGreater(resp.usage.completion_tokens, 0) - self.assertEqual(resp.usage.total_tokens, resp.usage.prompt_tokens + resp.usage.completion_tokens) - - def test_streaming_usage(self): - chunks = list( - self.client.chat.completions.create( - model=self.MODEL, - messages=[{"role": "user", "content": "Say hello"}], - stream=True, - ) - ) - # Last chunk should have usage - last = chunks[-1] - self.assertIsNotNone(last.usage) - self.assertGreater(last.usage.prompt_tokens, 0) - self.assertGreater(last.usage.completion_tokens, 0) - self.assertEqual(last.usage.total_tokens, last.usage.prompt_tokens + last.usage.completion_tokens) - - def test_tool_call(self): - """Tool calls should be parsed and emitted as ChoiceDeltaToolCall objects.""" - # Qwen2.5-0.5B-Instruct supports tools (Qwen family) - tool_def = { - "function": { - "name": "get_weather", - "parameters": { - "type": "object", - "properties": {"city": {"type": "string"}}, - }, - "description": "Get the weather for a city.", - }, - "type": "function", - } - chunks = list( - self.client.chat.completions.create( - model=self.MODEL, - messages=[{"role": "user", "content": "What is the weather in Paris?"}], - stream=True, - max_tokens=50, - temperature=0.0, - tools=[tool_def], - ) - ) - - # First chunk should have role="assistant" - self.assertEqual(chunks[0].choices[0].delta.role, "assistant") - - # Model should make a tool call for this prompt - tool_chunks = [c for c in chunks if c.choices[0].delta.tool_calls] - self.assertGreater(len(tool_chunks), 0, "Model did not produce a tool call") - - # First tool call delta should have the function name - first_tool = tool_chunks[0].choices[0].delta.tool_calls[0] - self.assertEqual(first_tool.function.name, "get_weather") - - # finish_reason should be "tool_calls" - last = chunks[-1] - self.assertEqual(last.choices[0].finish_reason, "tool_calls") - - # Arguments should be valid JSON with no trailing brace - args_json = first_tool.function.arguments - import json as json_mod - - parsed_args = json_mod.loads(args_json) - self.assertIsInstance(parsed_args, dict) - - def test_tool_call_non_streaming(self): - """Non-streaming tool calls should return tool_calls in the message.""" - tool_def = { - "function": { - "name": "get_weather", - "parameters": {"type": "object", "properties": {"city": {"type": "string"}}}, - "description": "Get the weather for a city.", - }, - "type": "function", - } - resp = self.client.chat.completions.create( - model=self.MODEL, - messages=[{"role": "user", "content": "What is the weather in Paris?"}], - stream=False, - max_tokens=50, - temperature=0.0, - tools=[tool_def], - ) - self.assertEqual(resp.choices[0].finish_reason, "tool_calls") - self.assertIsNotNone(resp.choices[0].message.tool_calls) - tc = resp.choices[0].message.tool_calls[0] - self.assertEqual(tc.function.name, "get_weather") - - import json as json_mod - - parsed_args = json_mod.loads(tc.function.arguments) - self.assertIsInstance(parsed_args, dict) - - def test_tool_call_multi(self): - """Model should be able to call multiple tools when asked.""" - tool_def = { - "function": { - "name": "get_weather", - "parameters": {"type": "object", "properties": {"city": {"type": "string"}}}, - "description": "Get the weather for a city.", - }, - "type": "function", - } - # Ask for 
two cities to encourage multiple tool calls - chunks = list( - self.client.chat.completions.create( - model=self.MODEL, - messages=[{"role": "user", "content": "What is the weather in Paris and London?"}], - stream=True, - max_tokens=100, - temperature=0.0, - tools=[tool_def], - ) - ) - tool_chunks = [c for c in chunks if c.choices[0].delta.tool_calls] - # Should have two tool calls — one for Paris, one for London - self.assertEqual(len(tool_chunks), 2, f"Expected 2 tool calls, got {len(tool_chunks)}") - cities = {tc.choices[0].delta.tool_calls[0].function.name for tc in tool_chunks} - self.assertEqual(cities, {"get_weather"}) - last = chunks[-1] - self.assertEqual(last.choices[0].finish_reason, "tool_calls") - - def test_concurrent_non_streaming(self): - """Two concurrent non-streaming requests should both complete without interference.""" - import concurrent.futures - - prompts = [ - [{"role": "user", "content": "Say hello"}], - [{"role": "user", "content": "Say goodbye"}], - ] - results = [None, None] - - def request_in_thread(index): - client = OpenAI(base_url=f"{self.base_url}/v1", api_key="unused") - results[index] = client.chat.completions.create(model=self.MODEL, messages=prompts[index]) - - with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool: - futures = [pool.submit(request_in_thread, i) for i in range(2)] - concurrent.futures.wait(futures) - for f in futures: - f.result() # re-raise exceptions - - for i in range(2): - self.assertIsNotNone(results[i]) - self.assertIsNotNone(results[i].choices[0].message.content) - self.assertTrue(len(results[i].choices[0].message.content) > 0) - - def test_concurrent_streaming(self): - """Two concurrent streaming requests should both produce complete, non-empty output.""" - import concurrent.futures - - prompts = [ - [{"role": "user", "content": "Say hello"}], - [{"role": "user", "content": "Say goodbye"}], - ] - results = [None, None] - - def stream_in_thread(index): - client = OpenAI(base_url=f"{self.base_url}/v1", api_key="unused") - text = "" - for chunk in client.chat.completions.create(model=self.MODEL, messages=prompts[index], stream=True): - if chunk.choices[0].delta.content: - text += chunk.choices[0].delta.content - results[index] = text - - with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool: - futures = [pool.submit(stream_in_thread, i) for i in range(2)] - concurrent.futures.wait(futures) - for f in futures: - f.result() - - for i in range(2): - self.assertIsNotNone(results[i]) - self.assertTrue(len(results[i]) > 0, f"Request {i} produced empty output") - - def test_request_cancellation(self): - """Closing a stream early doesn't crash and the server stays healthy.""" - - with httpx.stream( - "POST", - f"{self.base_url}/v1/chat/completions", - json={ - "model": self.MODEL, - "stream": True, - "messages": [{"role": "user", "content": "Count slowly so I can cancel you."}], - "max_tokens": 500, - }, - timeout=30, - ) as resp: - self.assertEqual(resp.status_code, 200) - chunks_read = 0 - for _ in resp.iter_lines(): - chunks_read += 1 - if chunks_read >= 3: - break - - # Server should still be healthy and serve subsequent requests - resp = self.client.chat.completions.create( - model=self.MODEL, - messages=[{"role": "user", "content": "Say hi"}], - max_tokens=10, - ) - self.assertIsNotNone(resp.choices[0].message.content) - - -@require_serve -class TestResponseInputConversion(unittest.TestCase): - def _make_handler(self): - return ResponseHandler(model_manager=MagicMock(), generation_state=GenerationState()) - - 
def test_string_input(self): - handler = self._make_handler() - msgs = handler._input_to_messages({"input": "Hello"}) - self.assertEqual(msgs, [{"role": "user", "content": "Hello"}]) - - def test_string_input_with_instructions(self): - handler = self._make_handler() - msgs = handler._input_to_messages({"input": "Hello", "instructions": "Be brief"}) - self.assertEqual(len(msgs), 2) - self.assertEqual(msgs[0], {"role": "system", "content": "Be brief"}) - self.assertEqual(msgs[1], {"role": "user", "content": "Hello"}) - - def test_list_input(self): - handler = self._make_handler() - msgs = handler._input_to_messages( - {"input": [{"role": "user", "content": "A"}, {"role": "assistant", "content": "B"}]} - ) - self.assertEqual(len(msgs), 2) - self.assertEqual(msgs[0]["content"], "A") - - def test_list_input_with_instructions_prepends_system(self): - handler = self._make_handler() - msgs = handler._input_to_messages({"input": [{"role": "user", "content": "Hi"}], "instructions": "Be helpful"}) - self.assertEqual(len(msgs), 2) - self.assertEqual(msgs[0]["role"], "system") - self.assertEqual(msgs[0]["content"], "Be helpful") - - def test_list_input_with_instructions_replaces_existing_system(self): - handler = self._make_handler() - msgs = handler._input_to_messages( - {"input": [{"role": "system", "content": "Old"}, {"role": "user", "content": "Hi"}], "instructions": "New"} - ) - self.assertEqual(len(msgs), 2) - self.assertEqual(msgs[0]["content"], "New") - - def test_dict_input(self): - handler = self._make_handler() - msgs = handler._input_to_messages({"input": {"role": "user", "content": "Test"}}) - self.assertEqual(msgs, [{"role": "user", "content": "Test"}]) - - -@require_serve -class TestResponseValidation(unittest.TestCase): - def _make_handler(self): - return ResponseHandler(model_manager=MagicMock(), generation_state=GenerationState()) - - def test_unsupported_fields_warns(self): - handler = self._make_handler() - with self.assertLogs("transformers", level="WARNING") as cm: - handler._validate_request({"model": "x", "input": "hi", "previous_response_id": "abc"}) - self.assertTrue(any("previous_response_id" in msg for msg in cm.output)) - - def test_valid_request_passes(self): - handler = self._make_handler() - # Should not raise - handler._validate_request({"model": "x", "input": "hi"}) - - -@require_serve -class TestResponseGenerationConfig(unittest.TestCase): - def _make_handler(self): - return ResponseHandler(model_manager=MagicMock(), generation_state=GenerationState()) - - def test_max_output_tokens(self): - from transformers import GenerationConfig - - result = self._make_handler()._build_generation_config({"max_output_tokens": 42}, GenerationConfig()) - self.assertEqual(result.max_new_tokens, 42) - - def test_default_bumps_short_max_new_tokens(self): - from transformers import GenerationConfig - - result = self._make_handler()._build_generation_config({}, GenerationConfig(max_new_tokens=20)) - self.assertEqual(result.max_new_tokens, 1024) - - -@require_serve -class TestResponseUsage(unittest.TestCase): - def testcompute_usage(self): - usage = compute_usage(input_tokens=100, output_tokens=50) - self.assertEqual(usage.input_tokens, 100) - self.assertEqual(usage.output_tokens, 50) - self.assertEqual(usage.total_tokens, 150) - self.assertEqual(usage.input_tokens_details.cached_tokens, 0) - self.assertEqual(usage.output_tokens_details.reasoning_tokens, 0) - - def test_usage_in_completed_response(self): - """Usage should serialize correctly inside a Response.""" - - usage = 
compute_usage(10, 5) - response = Response( - id="resp_test", - created_at=0, - status="completed", - model="test", - output=[], - object="response", - tools=[], - parallel_tool_calls=False, - tool_choice="auto", - usage=usage, - ) - dumped = response.model_dump(exclude_none=True) - self.assertEqual(dumped["usage"]["input_tokens"], 10) - self.assertEqual(dumped["usage"]["output_tokens"], 5) - self.assertEqual(dumped["usage"]["total_tokens"], 15) - - -@require_serve -class TestResponseSSEFormat(unittest.TestCase): - def test_sse_format(self): - event = ResponseCreatedEvent( - type="response.created", - sequence_number=0, - response=Response( - id="resp_test", - created_at=0, - status="queued", - model="test", - text={"format": {"type": "text"}}, - object="response", - tools=[], - output=[], - parallel_tool_calls=False, - tool_choice="auto", - ), - ) - result = BaseHandler.chunk_to_sse(event) - self.assertTrue(result.startswith("data: ")) - self.assertTrue(result.endswith("\n\n")) - parsed = json.loads(result[len("data: ") :].strip()) - self.assertEqual(parsed["type"], "response.created") - self.assertEqual(parsed["response"]["status"], "queued") - - -@slow -@require_serve -class TestResponsesIntegration(unittest.TestCase): - """Integration tests for /v1/responses with a real model.""" - - MODEL = "Qwen/Qwen2.5-0.5B-Instruct" - - @classmethod - def setUpClass(cls): - cls.serve, port = _start_serve() - cls.base_url = f"http://localhost:{port}" - cls.client = OpenAI(base_url=f"{cls.base_url}/v1", api_key="unused") - - @classmethod - def tearDownClass(cls): - cls.serve.kill_server() - - def test_streaming(self): - events = list( - self.client.responses.create( - model=self.MODEL, - input="Say hello", - stream=True, - max_output_tokens=1, - ) - ) - # At least 8 events: created, in_progress, output_item_added, content_part_added, - # delta(s), text_done, content_part_done, output_item_done, completed - self.assertGreaterEqual(len(events), 8) - - # Start markers (fixed order) - self.assertEqual(events[0].type, "response.created") - self.assertEqual(events[1].type, "response.in_progress") - self.assertEqual(events[2].type, "response.output_item.added") - self.assertEqual(events[3].type, "response.content_part.added") - - # At least one delta - self.assertTrue(any(e.type == "response.output_text.delta" for e in events[4:-4])) - - # Closing markers (fixed order from the end) - self.assertEqual(events[-4].type, "response.output_text.done") - self.assertEqual(events[-3].type, "response.content_part.done") - self.assertEqual(events[-2].type, "response.output_item.done") - self.assertEqual(events[-1].type, "response.completed") - - def test_non_streaming(self): - resp = self.client.responses.create( - model=self.MODEL, - input="Say hello", - stream=False, - ) - self.assertEqual(resp.status, "completed") - self.assertTrue(len(resp.output) > 0) - self.assertTrue(len(resp.output[0].content[0].text) > 0) - - def test_non_streaming_usage(self): - resp = self.client.responses.create( - model=self.MODEL, - input="Say hello", - stream=False, - ) - self.assertIsNotNone(resp.usage) - self.assertGreater(resp.usage.input_tokens, 0) - self.assertGreater(resp.usage.output_tokens, 0) - self.assertEqual(resp.usage.total_tokens, resp.usage.input_tokens + resp.usage.output_tokens) - - def test_streaming_usage(self): - events = list( - self.client.responses.create( - model=self.MODEL, - input="Say hello", - stream=True, - max_output_tokens=5, - ) - ) - completed = events[-1] - self.assertEqual(completed.type, 
"response.completed") - usage = completed.response.usage - self.assertIsNotNone(usage) - self.assertGreater(usage.input_tokens, 0) - self.assertGreater(usage.output_tokens, 0) - self.assertEqual(usage.total_tokens, usage.input_tokens + usage.output_tokens) - - def test_tool_call_streaming(self): - """Streaming responses with tools should emit function_call events.""" - tool_def = { - "function": { - "name": "get_weather", - "parameters": {"type": "object", "properties": {"city": {"type": "string"}}}, - "description": "Get the weather for a city.", - }, - "type": "function", - } - events = list( - self.client.responses.create( - model=self.MODEL, - input="What is the weather in Paris?", - stream=True, - max_output_tokens=50, - tools=[tool_def], - ) - ) - types = [e.type for e in events] - self.assertIn("response.created", types) - self.assertIn("response.completed", types) - - # Should have function call events - self.assertIn("response.output_item.added", types) - self.assertIn("response.function_call_arguments.done", types) - - # Check the arguments done event - args_done = [e for e in events if e.type == "response.function_call_arguments.done"] - self.assertGreater(len(args_done), 0) - self.assertEqual(args_done[0].name, "get_weather") - - import json as json_mod - - parsed = json_mod.loads(args_done[0].arguments) - self.assertIsInstance(parsed, dict) - - def test_tool_call_non_streaming(self): - """Non-streaming responses with tools should include function_call output items.""" - tool_def = { - "function": { - "name": "get_weather", - "parameters": {"type": "object", "properties": {"city": {"type": "string"}}}, - "description": "Get the weather for a city.", - }, - "type": "function", - } - resp = self.client.responses.create( - model=self.MODEL, - input="What is the weather in Paris?", - stream=False, - max_output_tokens=50, - tools=[tool_def], - ) - self.assertEqual(resp.status, "completed") - - # Should have at least message + function_call in output - self.assertGreater(len(resp.output), 1) - fc_items = [o for o in resp.output if o.type == "function_call"] - self.assertGreater(len(fc_items), 0) - self.assertEqual(fc_items[0].name, "get_weather") - - import json as json_mod - - parsed = json_mod.loads(fc_items[0].arguments) - self.assertIsInstance(parsed, dict) - - def test_tool_call_multi(self): - """Model should produce multiple tool calls when asked about two cities.""" - tool_def = { - "function": { - "name": "get_weather", - "parameters": {"type": "object", "properties": {"city": {"type": "string"}}}, - "description": "Get the weather for a city.", - }, - "type": "function", - } - events = list( - self.client.responses.create( - model=self.MODEL, - input="What is the weather in Paris and London?", - stream=True, - max_output_tokens=100, - tools=[tool_def], - ) - ) - args_done = [e for e in events if e.type == "response.function_call_arguments.done"] - self.assertEqual(len(args_done), 2, f"Expected 2 tool calls, got {len(args_done)}") - self.assertEqual(events[-1].type, "response.completed") - - def test_multi_turn(self): - """Multi-turn conversation via list input.""" - resp = self.client.responses.create( - model=self.MODEL, - input=[ - {"role": "user", "content": "My name is Alice"}, - {"role": "assistant", "content": "Nice to meet you!"}, - {"role": "user", "content": "What is my name?"}, - ], - stream=False, - ) - self.assertEqual(resp.status, "completed") - self.assertIn("Alice", resp.output[0].content[0].text) - - def test_concurrent_non_streaming(self): - """Two 
concurrent non-streaming responses requests should both complete.""" - import concurrent.futures - - inputs = ["Say hello", "Say goodbye"] - results = [None, None] - - def request_in_thread(index): - client = OpenAI(base_url=f"{self.base_url}/v1", api_key="unused") - results[index] = client.responses.create(model=self.MODEL, input=inputs[index], stream=False) - - with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool: - futures = [pool.submit(request_in_thread, i) for i in range(2)] - concurrent.futures.wait(futures) - for f in futures: - f.result() - - for i in range(2): - self.assertIsNotNone(results[i]) - self.assertEqual(results[i].status, "completed") - self.assertTrue(len(results[i].output[0].content[0].text) > 0) - - def test_concurrent_streaming(self): - """Two concurrent streaming responses requests should both produce complete event streams.""" - import concurrent.futures - - inputs = ["Say hello", "Say goodbye"] - results = [None, None] - - def stream_in_thread(index): - client = OpenAI(base_url=f"{self.base_url}/v1", api_key="unused") - results[index] = list(client.responses.create(model=self.MODEL, input=inputs[index], stream=True)) - - with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool: - futures = [pool.submit(stream_in_thread, i) for i in range(2)] - concurrent.futures.wait(futures) - for f in futures: - f.result() - - for i in range(2): - types = [e.type for e in results[i]] - self.assertIn("response.created", types, f"Request {i} missing created event") - self.assertIn("response.output_text.delta", types, f"Request {i} missing delta events") - self.assertIn("response.completed", types, f"Request {i} missing completed event") - - -def _parse_sse_events(response): - """Parse SSE lines from a streaming httpx response into a list of dicts.""" - events = [] - for line in response.iter_lines(): - if not line or not line.startswith("data: "): - continue - events.append(json.loads(line[len("data: ") :])) - return events - - -@slow -@require_serve -class TestLoadModel(unittest.TestCase): - """Integration tests for POST /load_model SSE endpoint.""" - - MODEL = "Qwen/Qwen2.5-0.5B-Instruct" - - @classmethod - def setUpClass(cls): - cls.serve, port = _start_serve() - cls.base_url = f"http://localhost:{port}" - - @classmethod - def tearDownClass(cls): - cls.serve.kill_server() - - def setUp(self): - # Clear model cache so each test starts fresh - self.serve.reset_loaded_models() - - def _load_model(self, model: str): - with httpx.stream("POST", f"{self.base_url}/load_model", json={"model": model}, timeout=120) as resp: - events = _parse_sse_events(resp) - return resp, events - - def test_load_model_fresh(self): - """POST /load_model returns SSE events ending with ready.""" - response, events = self._load_model(self.MODEL) - - self.assertEqual(response.status_code, 200) - - stages = [e["stage"] for e in events if e["status"] == "loading" and "stage" in e] - self.assertIn("processor", stages) - self.assertIn("weights", stages) - - last = events[-1] - self.assertEqual(last["status"], "ready") - self.assertFalse(last["cached"]) - - for event in events: - self.assertIn("status", event) - self.assertIn("model", event) - - def test_load_model_cached(self): - """Loading an already-loaded model returns a single ready event with cached: true.""" - self._load_model(self.MODEL) - - _, events = self._load_model(self.MODEL) - - ready_events = [e for e in events if e["status"] == "ready"] - self.assertEqual(len(ready_events), 1) - self.assertTrue(ready_events[0]["cached"]) - 
- loading_events = [e for e in events if e["status"] == "loading"] - self.assertEqual(len(loading_events), 0) - - def test_load_model_error(self): - """Loading a nonexistent model produces an error event.""" - _, events = self._load_model("nonexistent/model-that-does-not-exist") - - error_events = [e for e in events if e["status"] == "error"] - self.assertGreaterEqual(len(error_events), 1) - self.assertIn("message", error_events[0]) - - def test_load_model_missing_field(self): - """POST /load_model with no model field returns 422.""" - - response = httpx.post(f"{self.base_url}/load_model", json={}, timeout=30) - self.assertEqual(response.status_code, 422) - - def test_load_model_event_schema(self): - """Every event conforms to the expected schema.""" - _, events = self._load_model(self.MODEL) - - for event in events: - self.assertIsInstance(event["status"], str) - self.assertIsInstance(event["model"], str) - - if event["status"] == "loading": - self.assertIn("stage", event) - if event["stage"] in ("download", "weights") and "progress" in event: - progress = event["progress"] - self.assertIn("current", progress) - self.assertIn("total", progress) - self.assertIsInstance(progress["current"], int) - - if event["status"] == "ready": - self.assertIn("cached", event) - self.assertIsInstance(event["cached"], bool) - - def test_load_model_stage_ordering(self): - """Stages appear in the expected order.""" - _, events = self._load_model(self.MODEL) - - stages = [e["stage"] for e in events if e["status"] == "loading" and "stage" in e] - seen = set() - unique_stages = [] - for s in stages: - if s not in seen: - seen.add(s) - unique_stages.append(s) - - expected_order = ["processor", "config", "download", "weights"] - expected_present = [s for s in expected_order if s in unique_stages] - self.assertEqual(unique_stages, expected_present, "Stages appeared out of order") - - def test_concurrent_load_same_model(self): - """Two concurrent /load_model requests both get events and a ready event.""" - import concurrent.futures - - results = [None, None] - - def load_in_thread(index): - with httpx.stream("POST", f"{self.base_url}/load_model", json={"model": self.MODEL}, timeout=120) as resp: - events = _parse_sse_events(resp) - results[index] = (resp.status_code, events) - - with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool: - futures = [pool.submit(load_in_thread, i) for i in range(2)] - concurrent.futures.wait(futures) - for f in futures: - f.result() - - for i in range(2): - status_code, events = results[i] - self.assertEqual(status_code, 200, f"Caller {i} got non-200 status") - self.assertTrue(len(events) > 0, f"Caller {i} received no events") - ready_events = [e for e in events if e["status"] == "ready"] - self.assertEqual(len(ready_events), 1, f"Caller {i} should get exactly one ready event") - - def test_concurrent_load_second_caller_gets_cached(self): - """If the first /load_model finishes before the second, the second gets cached: true.""" - _, events1 = self._load_model(self.MODEL) - ready1 = [e for e in events1 if e["status"] == "ready"] - self.assertEqual(len(ready1), 1) - self.assertFalse(ready1[0]["cached"]) - - _, events2 = self._load_model(self.MODEL) - ready2 = [e for e in events2 if e["status"] == "ready"] - self.assertEqual(len(ready2), 1) - self.assertTrue(ready2[0]["cached"]) - - loading2 = [e for e in events2 if e["status"] == "loading"] - self.assertEqual(len(loading2), 0) - - def test_load_model_weights_progress_complete(self): - """Weights progress should go from 1 to 
total, with total matching across events.""" - _, events = self._load_model(self.MODEL) - - weights_events = [e for e in events if e.get("stage") == "weights" and "progress" in e] - self.assertGreater(len(weights_events), 0, "No weights progress events emitted") - - # All events should have the same total - totals = {e["progress"]["total"] for e in weights_events} - self.assertEqual(len(totals), 1, f"Inconsistent totals: {totals}") - total = totals.pop() - self.assertIsNotNone(total) - self.assertGreater(total, 0) - - # First should be 1, last should be total - self.assertEqual(weights_events[0]["progress"]["current"], 1) - self.assertEqual(weights_events[-1]["progress"]["current"], total) - - # Progress should be monotonically increasing - currents = [e["progress"]["current"] for e in weights_events] - self.assertEqual(currents, sorted(currents)) - - def test_load_model_exactly_one_ready(self): - """A fresh load should produce exactly one ready event as the last event.""" - _, events = self._load_model(self.MODEL) - - ready_events = [e for e in events if e["status"] == "ready"] - self.assertEqual(len(ready_events), 1) - self.assertEqual(events[-1]["status"], "ready") - - def test_load_model_usable_after_load(self): - """After /load_model completes, the model should be usable for inference.""" - self._load_model(self.MODEL) - - client = OpenAI(base_url=f"{self.base_url}/v1", api_key="unused") - resp = client.chat.completions.create( - model=self.MODEL, - messages=[{"role": "user", "content": "Say hi"}], - max_tokens=5, - ) - self.assertIsNotNone(resp.choices[0].message.content) - self.assertTrue(len(resp.choices[0].message.content) > 0) - - def test_load_model_model_field_matches(self): - """The model field in every event should match the canonical model ID.""" - _, events = self._load_model(self.MODEL) - - for event in events: - self.assertTrue( - event["model"].startswith(self.MODEL), - f"Event model '{event['model']}' doesn't match '{self.MODEL}'", - ) - - def test_concurrent_non_streaming(self): - """Two concurrent non-streaming responses requests should both complete.""" - import concurrent.futures - - inputs = ["Say hello", "Say goodbye"] - results = [None, None] - - def request_in_thread(index): - client = OpenAI(base_url=f"{self.base_url}/v1", api_key="unused") - results[index] = client.responses.create(model=self.MODEL, input=inputs[index], stream=False) - - with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool: - futures = [pool.submit(request_in_thread, i) for i in range(2)] - concurrent.futures.wait(futures) - for f in futures: - f.result() - - for i in range(2): - self.assertIsNotNone(results[i]) - self.assertEqual(results[i].status, "completed") - self.assertTrue(len(results[i].output[0].content[0].text) > 0) - - def test_concurrent_streaming(self): - """Two concurrent streaming responses requests should both produce complete event streams.""" - import concurrent.futures - - inputs = ["Say hello", "Say goodbye"] - results = [None, None] - - def stream_in_thread(index): - client = OpenAI(base_url=f"{self.base_url}/v1", api_key="unused") - events = list(client.responses.create(model=self.MODEL, input=inputs[index], stream=True)) - results[index] = events - - with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool: - futures = [pool.submit(stream_in_thread, i) for i in range(2)] - concurrent.futures.wait(futures) - for f in futures: - f.result() - - for i in range(2): - types = [e.type for e in results[i]] - self.assertIn("response.created", types, 
f"Request {i} missing created event") - self.assertIn("response.output_text.delta", types, f"Request {i} missing delta events") - self.assertIn("response.completed", types, f"Request {i} missing completed event") - - -# Real image URL for VLM tests (person + dog on a beach) -_DOG_IMAGE_URL = "https://qianwen-res.oss-accelerate-overseas.aliyuncs.com/Qwen2-VL/demo_small.jpg" - - -@slow -@require_vision -@require_serve -class TestVLM(unittest.TestCase): - """Integration tests for VLM (vision-language model) support. Requires torchvision.""" - - MODEL = "HuggingFaceTB/SmolVLM-256M-Instruct" - - @classmethod - def setUpClass(cls): - cls.serve, port = _start_serve() - cls.base_url = f"http://localhost:{port}" - cls.client = OpenAI(base_url=f"{cls.base_url}/v1", api_key="unused") - - @classmethod - def tearDownClass(cls): - cls.serve.kill_server() - - def test_chat_completion_with_image(self): - """Chat completions should accept image_url content and produce a meaningful response.""" - resp = self.client.chat.completions.create( - model=self.MODEL, - messages=[ - { - "role": "user", - "content": [ - {"type": "text", "text": "What do you see in this image?"}, - {"type": "image_url", "image_url": {"url": _DOG_IMAGE_URL}}, - ], - } - ], - max_tokens=50, - ) - text = resp.choices[0].message.content - self.assertIsNotNone(text) - self.assertTrue( - any(word in text.lower() for word in ["dog", "beach", "person"]), - f"Expected dog/beach/person in response, got: {text}", - ) - - def test_responses_with_image(self): - """Responses API should accept image_url content and produce a meaningful response.""" - resp = self.client.responses.create( - model=self.MODEL, - input=[ - { - "role": "user", - "content": [ - {"type": "text", "text": "What do you see in this image?"}, - {"type": "image_url", "image_url": {"url": _DOG_IMAGE_URL}}, - ], - } - ], - stream=False, - max_output_tokens=50, - ) - self.assertEqual(resp.status, "completed") - text = resp.output[0].content[0].text - self.assertTrue( - any(word in text.lower() for word in ["dog", "beach", "person"]), - f"Expected dog/beach/person in response, got: {text}", - ) - - -@slow -@require_serve -class TestTranscription(unittest.TestCase): - """Integration tests for POST /v1/audio/transcriptions with whisper-tiny.""" - - MODEL = "openai/whisper-tiny" - - @classmethod - def setUpClass(cls): - cls.serve, port = _start_serve() - cls.base_url = f"http://localhost:{port}" - - @classmethod - def tearDownClass(cls): - cls.serve.kill_server() - - @classmethod - def _get_audio_bytes(cls): - """Download the MLK 'I have a dream' speech sample from HF Hub.""" - if not hasattr(cls, "_audio_bytes"): - from huggingface_hub import hf_hub_download - - path = hf_hub_download("Narsil/asr_dummy", "mlk.flac", repo_type="dataset") - with open(path, "rb") as f: - cls._audio_bytes = f.read() - return cls._audio_bytes - - def test_transcription_returns_text(self): - """POST /v1/audio/transcriptions with real speech returns meaningful transcription.""" - - audio_bytes = self._get_audio_bytes() - resp = httpx.post( - f"{self.base_url}/v1/audio/transcriptions", - files={"file": ("mlk.flac", audio_bytes, "audio/flac")}, - data={"model": self.MODEL}, - timeout=120, - ) - self.assertEqual(resp.status_code, 200) - data = resp.json() - self.assertIn("text", data) - self.assertIsInstance(data["text"], str) - # Whisper-tiny should recognize at least "dream" from the MLK speech - self.assertIn("dream", data["text"].lower()) - - def test_transcription_openai_client(self): - """Transcription 
should work via the OpenAI Python client.""" - audio_bytes = self._get_audio_bytes() - client = OpenAI(base_url=f"{self.base_url}/v1", api_key="unused") - result = client.audio.transcriptions.create( - model=self.MODEL, - file=("mlk.flac", audio_bytes), - ) - self.assertIsInstance(result.text, str) - self.assertTrue(len(result.text) > 10) - - def test_transcription_streaming(self): - """Streaming transcription should yield text chunks via SSE.""" - - audio_bytes = self._get_audio_bytes() - with httpx.stream( - "POST", - f"{self.base_url}/v1/audio/transcriptions", - files={"file": ("mlk.flac", audio_bytes, "audio/flac")}, - data={"model": self.MODEL, "stream": "true"}, - timeout=120, - ) as resp: - self.assertEqual(resp.status_code, 200) - - chunks = [] - for line in resp.iter_lines(): - if line and line.startswith("data: "): - chunks.append(line[len("data: ") :]) - - self.assertGreater(len(chunks), 0, "No streaming chunks received") - full_text = "".join(chunks) - self.assertIn("dream", full_text.lower()) - - def test_transcription_missing_file(self): - """POST without a file should fail.""" - - resp = httpx.post( - f"{self.base_url}/v1/audio/transcriptions", - data={"model": self.MODEL}, - timeout=30, - ) - self.assertNotEqual(resp.status_code, 200) - - -@slow -@require_serve -@require_torch_accelerator -class TestContinuousBatchingChatCompletion(unittest.TestCase): - """Integration tests for /v1/chat/completions with continuous batching.""" - - MODEL = "Qwen/Qwen2.5-0.5B-Instruct" - - @classmethod - def setUpClass(cls): - cls.serve, port = _start_serve( - force_model=cls.MODEL, - device="cuda:0", - continuous_batching=True, - attn_implementation="sdpa", - default_seed=42, - ) - cls.base_url = f"http://localhost:{port}" - cls.client = OpenAI(base_url=f"{cls.base_url}/v1", api_key="unused") - - @classmethod - def tearDownClass(cls): - cls.serve.kill_server() - - def test_streaming(self): - """Streaming chat completion with CB produces text.""" - text = "" - for chunk in self.client.chat.completions.create( - model=self.MODEL, - messages=[ - {"role": "system", "content": "You are a sports assistant designed to craft sports programs."}, - {"role": "user", "content": "Tell me what you can do."}, - ], - stream=True, - max_tokens=30, - ): - if chunk.choices[0].delta.content: - text += chunk.choices[0].delta.content - self.assertTrue(len(text) > 0) - - def test_non_streaming(self): - """Non-streaming chat completion with CB returns a full response.""" - resp = self.client.chat.completions.create( - model=self.MODEL, - messages=[{"role": "user", "content": "Say hello"}], - max_tokens=20, - ) - self.assertIsNotNone(resp.choices[0].message.content) - self.assertTrue(len(resp.choices[0].message.content) > 0) - - def test_multi_turn(self): - """Multi-turn conversation works with CB.""" - resp = self.client.chat.completions.create( - model=self.MODEL, - messages=[ - {"role": "user", "content": "My name is Alice"}, - {"role": "assistant", "content": "Nice to meet you!"}, - {"role": "user", "content": "What is my name?"}, - ], - max_tokens=20, - ) - self.assertIn("Alice", resp.choices[0].message.content) - - def test_request_cancellation(self): - """Opening a stream and closing it early triggers CB cancellation.""" - - request_id = "test-cb-cancel" - - # Open a streaming request and close after a few chunks - with httpx.stream( - "POST", - f"{self.base_url}/v1/chat/completions", - headers={"X-Request-ID": request_id}, - json={ - "model": self.MODEL, - "stream": True, - "messages": [{"role": "user", 
"content": "Count slowly so I can cancel you."}], - }, - timeout=30, - ) as resp: - self.assertEqual(resp.status_code, 200) - chunks_read = 0 - for _ in resp.iter_lines(): - chunks_read += 1 - if chunks_read >= 3: - break - - # Poll for cancellation in the CB scheduler - scheduler = self.serve._generation_state._cb_manager.scheduler - deadline = time.time() + 8.0 - while time.time() < deadline: - if scheduler.request_is_cancelled(request_id): - break - time.sleep(0.1) - - self.assertTrue( - scheduler.request_is_cancelled(request_id), - f"Request {request_id} not cancelled in scheduler after stream close.", - ) - - # Server should still be healthy and serve subsequent requests - resp = self.client.chat.completions.create( - model=self.MODEL, - messages=[{"role": "user", "content": "Say hi"}], - max_tokens=10, - ) - self.assertIsNotNone(resp.choices[0].message.content) - - -@slow -@require_serve -@require_torch_accelerator -class TestContinuousBatchingResponses(unittest.TestCase): - """Integration tests for /v1/responses with continuous batching.""" - - MODEL = "Qwen/Qwen2.5-0.5B-Instruct" - - @classmethod - def setUpClass(cls): - cls.serve, port = _start_serve( - force_model=cls.MODEL, - device="cuda:0", - continuous_batching=True, - attn_implementation="sdpa", - default_seed=42, - ) - cls.base_url = f"http://localhost:{port}" - cls.client = OpenAI(base_url=f"{cls.base_url}/v1", api_key="unused") - - @classmethod - def tearDownClass(cls): - cls.serve.kill_server() - - def test_streaming(self): - """Streaming response with CB produces text.""" - text = "" - stream = self.client.responses.create( - model=self.MODEL, - input="Say hello in one sentence.", - stream=True, - max_output_tokens=30, - ) - for event in stream: - if event.type == "response.output_text.delta": - text += event.delta - self.assertTrue(len(text) > 0) - - def test_non_streaming(self): - """Non-streaming response with CB returns text.""" - resp = self.client.responses.create( - model=self.MODEL, - input="Say hello in one sentence.", - stream=False, - max_output_tokens=30, - ) - content = resp.output[0].content[0].text - self.assertTrue(len(content) > 0) - - def test_multi_turn(self): - """Multi-turn conversation works with CB via Responses API.""" - resp = self.client.responses.create( - model=self.MODEL, - input=[ - {"role": "user", "content": "My name is Alice"}, - {"role": "assistant", "content": "Nice to meet you!"}, - {"role": "user", "content": "What is my name?"}, - ], - stream=False, - max_output_tokens=20, - ) - content = resp.output[0].content[0].text - self.assertIn("Alice", content) - - def test_request_cancellation(self): - """Opening a stream and closing it early triggers CB cancellation.""" - - request_id = "test-cb-resp-cancel" - - with httpx.stream( - "POST", - f"{self.base_url}/v1/responses", - headers={"X-Request-ID": request_id}, - json={ - "model": self.MODEL, - "stream": True, - "input": "Count slowly so I can cancel you.", - "max_output_tokens": 500, - }, - timeout=30, - ) as resp: - self.assertEqual(resp.status_code, 200) - # Read enough data to ensure CB generation has started, then close. 
- received = b"" - for chunk in resp.iter_bytes(chunk_size=512): - received += chunk - if b"output_text.delta" in received: - break - - # Poll for cancellation in the CB scheduler - scheduler = self.serve._generation_state._cb_manager.scheduler - deadline = time.time() + 8.0 - while time.time() < deadline: - if scheduler.request_is_cancelled(request_id): - break - time.sleep(0.1) - - self.assertTrue( - scheduler.request_is_cancelled(request_id), - f"Request {request_id} not cancelled in scheduler after stream close.", - ) - - # Server should still serve subsequent requests - resp = self.client.responses.create( - model=self.MODEL, - input="Say hi", - stream=False, - max_output_tokens=10, - ) - self.assertTrue(len(resp.output[0].content[0].text) > 0) From a8461fc8af1f50ba56aaea70b97a61739fb4bb27 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Mon, 30 Mar 2026 18:07:37 +0000 Subject: [PATCH 54/64] queue draining --- .../cli/serving/chat_completion.py | 65 ++++--- src/transformers/cli/serving/response.py | 180 ++++++++++-------- 2 files changed, 143 insertions(+), 102 deletions(-) diff --git a/src/transformers/cli/serving/chat_completion.py b/src/transformers/cli/serving/chat_completion.py index 10f22c230e47..07835167bd6e 100644 --- a/src/transformers/cli/serving/chat_completion.py +++ b/src/transformers/cli/serving/chat_completion.py @@ -182,32 +182,47 @@ async def sse_gen() -> AsyncGenerator[str, None]: try: yield self._build_chunk_sse(request_id, role="assistant", model=model_id) - while True: + done = False + while not done: text = await queue.get() - if text is None: - break - elif isinstance(text, _StreamError): - yield f'data: {{"error": "{text.msg}"}}\n\n' - return - - # Tool call parsing: None = normal text, CONSUMED = buffering, else = tool call dict - chunk_kwargs = {"content": text} - if parser is not None and (result := parser.feed(text)) is not None: - if result is ToolCallParser.CONSUMED: - continue - has_tool_calls = True - chunk_kwargs = { - "tool_calls": [ - ChoiceDeltaToolCall( - index=0, - type="function", - id=f"{request_id}_tool_call", - function={"name": result["name"], "arguments": result["arguments"]}, - ) - ] - } - - yield self._build_chunk_sse(request_id, model=model_id, **chunk_kwargs) + batch = [text] + try: + while True: + batch.append(queue.get_nowait()) + except asyncio.QueueEmpty: + pass + + sse_parts: list[str] = [] + for text in batch: + if text is None: + done = True + break + if isinstance(text, _StreamError): + sse_parts.append(f'data: {{"error": "{text.msg}"}}\n\n') + yield "".join(sse_parts) + return + + # Tool call parsing: None = normal text, CONSUMED = buffering, else = tool call dict + chunk_kwargs = {"content": text} + if parser is not None and (result := parser.feed(text)) is not None: + if result is ToolCallParser.CONSUMED: + continue + has_tool_calls = True + chunk_kwargs = { + "tool_calls": [ + ChoiceDeltaToolCall( + index=0, + type="function", + id=f"{request_id}_tool_call", + function={"name": result["name"], "arguments": result["arguments"]}, + ) + ] + } + + sse_parts.append(self._build_chunk_sse(request_id, model=model_id, **chunk_kwargs)) + + if sse_parts: + yield "".join(sse_parts) hit_max = gen_config.max_new_tokens is not None and streamer.total_tokens >= gen_config.max_new_tokens if has_tool_calls: diff --git a/src/transformers/cli/serving/response.py b/src/transformers/cli/serving/response.py index ea6d6ebfb20d..9bdcdd6912ac 100644 --- a/src/transformers/cli/serving/response.py +++ b/src/transformers/cli/serving/response.py @@ -290,96 
+290,122 @@ async def event_stream() -> AsyncGenerator[str, None]: ) seq += 1 - # 4. Stream tokens + # 4. Stream tokens — drain queue to batch HTTP writes full_text = "" tool_calls = [] + done = False - while True: + while not done: text = await queue.get() - if text is None: - break - if isinstance(text, _StreamError): - logger.error(f"Exception in response generation: {text.msg}") - yield self.chunk_to_sse( - ResponseErrorEvent(type="error", sequence_number=seq, message=text.msg) - ) - seq += 1 - yield self.chunk_to_sse( - ResponseFailedEvent( - type="response.failed", - sequence_number=seq, - response=Response( - **response_base, - status="failed", - output=[], - error=ResponseError(code="server_error", message=text.msg), - ), - ) - ) - return - - # Tool call parsing - if parser is not None and (result := parser.feed(text)) is not None: - if result is not ToolCallParser.CONSUMED: - # Emit tool call as a function_call output item - tc_id = f"{request_id}_tool_call" - name = result["name"] - arguments = result["arguments"] - tc_item = ResponseFunctionToolCall( - id=tc_id, - call_id=tc_id, - type="function_call", - name=name, - arguments=arguments, - status="completed", - ) - tool_calls.append(tc_item) - # output[0] = message, output[1..N] = tool calls (required by OpenAI SSE spec) - output_index += 1 - yield self.chunk_to_sse( - ResponseOutputItemAddedEvent( - type="response.output_item.added", - sequence_number=seq, - output_index=output_index, - item=tc_item, + # Drain all available tokens for one batched HTTP write + batch = [text] + try: + while True: + batch.append(queue.get_nowait()) + except asyncio.QueueEmpty: + pass + + sse_parts: list[str] = [] + for text in batch: + if text is None: + done = True + break + if isinstance(text, _StreamError): + logger.error(f"Exception in response generation: {text.msg}") + sse_parts.append( + self.chunk_to_sse( + ResponseErrorEvent(type="error", sequence_number=seq, message=text.msg) ) ) seq += 1 - yield self.chunk_to_sse( - ResponseFunctionCallArgumentsDoneEvent( - type="response.function_call_arguments.done", - sequence_number=seq, - item_id=tc_id, - output_index=output_index, - arguments=arguments, - name=name, + sse_parts.append( + self.chunk_to_sse( + ResponseFailedEvent( + type="response.failed", + sequence_number=seq, + response=Response( + **response_base, + status="failed", + output=[], + error=ResponseError(code="server_error", message=text.msg), + ), + ) ) ) - seq += 1 - yield self.chunk_to_sse( - ResponseOutputItemDoneEvent( - type="response.output_item.done", + yield "".join(sse_parts) + return + + # Tool call parsing + if parser is not None and (result := parser.feed(text)) is not None: + if result is not ToolCallParser.CONSUMED: + tc_id = f"{request_id}_tool_call" + name = result["name"] + arguments = result["arguments"] + tc_item = ResponseFunctionToolCall( + id=tc_id, + call_id=tc_id, + type="function_call", + name=name, + arguments=arguments, + status="completed", + ) + tool_calls.append(tc_item) + output_index += 1 + sse_parts.append( + self.chunk_to_sse( + ResponseOutputItemAddedEvent( + type="response.output_item.added", + sequence_number=seq, + output_index=output_index, + item=tc_item, + ) + ) + ) + seq += 1 + sse_parts.append( + self.chunk_to_sse( + ResponseFunctionCallArgumentsDoneEvent( + type="response.function_call_arguments.done", + sequence_number=seq, + item_id=tc_id, + output_index=output_index, + arguments=arguments, + name=name, + ) + ) + ) + seq += 1 + sse_parts.append( + self.chunk_to_sse( + 
ResponseOutputItemDoneEvent( + type="response.output_item.done", + sequence_number=seq, + output_index=output_index, + item=tc_item, + ) + ) + ) + seq += 1 + continue + + full_text += text + sse_parts.append( + self.chunk_to_sse( + ResponseTextDeltaEvent( + type="response.output_text.delta", + item_id=msg_id, sequence_number=seq, - output_index=output_index, - item=tc_item, + output_index=0, + content_index=0, + delta=text, + logprobs=[], ) ) - seq += 1 - continue - - full_text += text - yield self.chunk_to_sse( - ResponseTextDeltaEvent( - type="response.output_text.delta", - item_id=msg_id, - sequence_number=seq, - output_index=0, - content_index=0, - delta=text, - logprobs=[], ) - ) - seq += 1 + seq += 1 + + if sse_parts: + yield "".join(sse_parts) # 5. Close text output output_text_part = ResponseOutputText(type="output_text", text=full_text, annotations=[]) From 40417ee30b119d89a5c1dc520d11ca6c501e930a Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Tue, 31 Mar 2026 14:52:52 +0000 Subject: [PATCH 55/64] some logs --- src/transformers/cli/chat.py | 1 + src/transformers/cli/serve.py | 2 +- src/transformers/cli/serving/chat_completion.py | 3 +-- src/transformers/cli/serving/model_manager.py | 1 + src/transformers/cli/serving/response.py | 3 +-- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/transformers/cli/chat.py b/src/transformers/cli/chat.py index f38305550de7..968fd290c75b 100644 --- a/src/transformers/cli/chat.py +++ b/src/transformers/cli/chat.py @@ -163,6 +163,7 @@ async def stream_output(self, stream: AsyncIterator[ChatCompletionStreamOutput]) elapsed = time.time() - start_time if elapsed > 0 and completion_tokens > 0: tok_per_sec = completion_tokens / elapsed + self._console.print() self._console.print(f"[dim]{completion_tokens} tokens in {elapsed:.1f}s ({tok_per_sec:.1f} tok/s)[/dim]") self._console.print() diff --git a/src/transformers/cli/serve.py b/src/transformers/cli/serve.py index 8f07684ac8bb..f4a69e49a76a 100644 --- a/src/transformers/cli/serve.py +++ b/src/transformers/cli/serve.py @@ -112,7 +112,7 @@ def __init__( enable_cors=enable_cors, ) - config = uvicorn.Config(app, host=host, port=port, log_level=log_level) + config = uvicorn.Config(app, host=host, port=port, log_level="info") self.server = uvicorn.Server(config) if non_blocking: diff --git a/src/transformers/cli/serving/chat_completion.py b/src/transformers/cli/serving/chat_completion.py index 30272eb4697f..520eab7abbf4 100644 --- a/src/transformers/cli/serving/chat_completion.py +++ b/src/transformers/cli/serving/chat_completion.py @@ -92,8 +92,6 @@ class ChatCompletionHandler(BaseHandler): _valid_params_class = TransformersCompletionCreateParamsStreaming _unused_fields = UNUSED_CHAT_COMPLETION_FIELDS - # ----- entry point ----- - async def handle_request(self, body: dict, request_id: str) -> StreamingResponse | JSONResponse: """Validate the request, load the model, and dispatch to streaming or non-streaming. 
@@ -109,6 +107,7 @@ async def handle_request(self, body: dict, request_id: str) -> StreamingResponse
         model_id, model, processor = self._resolve_model(body)
         modality = self.model_manager.get_model_modality(model, processor=processor)
         use_cb = self.generation_state.use_continuous_batching(model, modality)
+        logger.warning(f"[Request received] Model: {model_id}, CB: {use_cb}")
         gen_manager = self.generation_state.get_manager(model_id, use_cb=use_cb)
 
         processor_inputs = self.get_processor_inputs_from_messages(body["messages"], modality)
diff --git a/src/transformers/cli/serving/model_manager.py b/src/transformers/cli/serving/model_manager.py
index fc97dd992a4f..10706dae1d64 100644
--- a/src/transformers/cli/serving/model_manager.py
+++ b/src/transformers/cli/serving/model_manager.py
@@ -254,6 +254,7 @@ def load_model_and_processor(
 
         with lock:
             if model_id_and_revision not in self.loaded_models:
+                logger.warning(f"Loading {model_id_and_revision}")
                 if progress_callback is not None:
                     progress_callback({"status": "loading", "model": model_id_and_revision, "stage": "processor"})
                 processor = self._load_processor(model_id_and_revision)
diff --git a/src/transformers/cli/serving/response.py b/src/transformers/cli/serving/response.py
index 3b488558158f..23d49a480d33 100644
--- a/src/transformers/cli/serving/response.py
+++ b/src/transformers/cli/serving/response.py
@@ -95,8 +95,6 @@ class ResponseHandler(BaseHandler):
     _valid_params_class = TransformersResponseCreateParamsStreaming
     _unused_fields = UNUSED_RESPONSE_FIELDS
 
-    # ----- entry point -----
-
     async def handle_request(self, body: dict, request_id: str) -> StreamingResponse | JSONResponse:
         """Validate, load model, dispatch to streaming or non-streaming.
 
@@ -112,6 +110,7 @@ async def handle_request(self, body: dict, request_id: str) -> StreamingResponse
         model_id, model, processor = self._resolve_model(body)
         modality = self.model_manager.get_model_modality(model, processor=processor)
         use_cb = self.generation_state.use_continuous_batching(model, modality)
+        logger.warning(f"[Request received] Model: {model_id}, CB: {use_cb}")
         gen_manager = self.generation_state.get_manager(model_id, use_cb=use_cb)
 
         # Two-step input conversion (chat completions skips step 1 since messages are already standard):

From ced96c2bb66898f7c032a6cd81ba79a82810ec6e Mon Sep 17 00:00:00 2001
From: Marc Sun
Date: Tue, 31 Mar 2026 16:01:58 +0000
Subject: [PATCH 56/64] readd nathan feature + some minor fixes

---
 .../pytorch/transformers_serve_cb_eval_job.py | 16 ++++++---
 src/transformers/cli/serve.py                 | 33 +++++++++++++++++--
 src/transformers/cli/serving/model_manager.py | 11 +++++--
 src/transformers/cli/serving/utils.py         | 25 +++++++++++---
 tests/cli/test_serve.py                       | 22 ++++++++++++-
 5 files changed, 93 insertions(+), 14 deletions(-)

diff --git a/examples/pytorch/transformers_serve_cb_eval_job.py b/examples/pytorch/transformers_serve_cb_eval_job.py
index c6355427b161..77ccf7ffb2f3 100644
--- a/examples/pytorch/transformers_serve_cb_eval_job.py
+++ b/examples/pytorch/transformers_serve_cb_eval_job.py
@@ -19,7 +19,7 @@ from inspect_evals.gpqa import gpqa_diamond
 
-def wait_for_server_up(server_process, timeout=600):
+def wait_for_server_up(server_process, port=8000, timeout=600):
     start_time = time.time()
 
     import urllib.error
 
     while time.time() - start_time < timeout:
         try:
             req = urllib.request.urlopen(f"http://127.0.0.1:{port}/health", timeout=2)
             if
req.status == 200: elapsed = time.time() - start_time print("\n" + "=" * 70) @@ -69,6 +69,12 @@ def main(): action="store_true", help="Disable continuous batching (enabled by default)", ) + parser.add_argument( + "--port", + type=int, + default=8000, + help="Port for the transformers serve server (default: 8000)", + ) parser.add_argument( "--limit", type=int, @@ -133,6 +139,7 @@ def main(): serve_cmd = [ "transformers", "serve", + args.model, ] # Add continuous batching if not disabled @@ -154,6 +161,7 @@ def main(): # Always use sdpa attention implementation serve_cmd.extend(["--attn-implementation", "kernels-community/flash-attn2"]) + serve_cmd.extend(["--port", str(args.port)]) print("Starting transformers serve with continuous batching...") print(f"Model: {args.model}") @@ -171,13 +179,13 @@ def main(): # Start server with output going directly to stdout/stderr server_process = subprocess.Popen(serve_cmd, stdout=None, stderr=None) - wait_for_server_up(server_process, timeout=600) + wait_for_server_up(server_process, port=args.port, timeout=600) eval( gpqa_diamond, model=f"openai-api/transformers-serve/{args.model}", log_dir=args.log_dir, - model_base_url="http://localhost:8000/v1", + model_base_url=f"http://localhost:{args.port}/v1", display="plain", limit=args.limit, model_args={"stream": False}, diff --git a/src/transformers/cli/serve.py b/src/transformers/cli/serve.py index f4a69e49a76a..edb1bb8f5c32 100644 --- a/src/transformers/cli/serve.py +++ b/src/transformers/cli/serve.py @@ -36,8 +36,24 @@ def __init__( force_model: Annotated[str | None, typer.Argument(help="Model to preload and use for all requests.")] = None, # Model options continuous_batching: Annotated[ - bool, typer.Option(help="Enable continuous batching with paged attention for higher throughput.") + bool, + typer.Option(help="Enable continuous batching with paged attention. Configure with --cb-* flags."), ] = False, + cb_block_size: Annotated[ + int | None, typer.Option(help="KV cache block size in tokens for continuous batching.") + ] = None, + cb_num_blocks: Annotated[ + int | None, typer.Option(help="Number of KV cache blocks for continuous batching.") + ] = None, + cb_max_batch_tokens: Annotated[ + int | None, typer.Option(help="Maximum tokens per batch for continuous batching.") + ] = None, + cb_max_memory_percent: Annotated[ + float | None, typer.Option(help="Max GPU memory fraction for KV cache (0.0-1.0).") + ] = None, + cb_use_cuda_graph: Annotated[ + bool | None, typer.Option(help="Enable CUDA graphs for continuous batching.") + ] = None, attn_implementation: Annotated[ str | None, typer.Option(help="Attention implementation (e.g. 
flash_attention_2).") ] = None, @@ -90,7 +106,20 @@ def __init__( model_timeout=model_timeout, force_model=force_model, ) - self._generation_state = GenerationState(continuous_batching=continuous_batching, compile=compile) + from transformers import ContinuousBatchingConfig + + cb_config = ContinuousBatchingConfig( + block_size=cb_block_size, + num_blocks=cb_num_blocks, + max_batch_tokens=cb_max_batch_tokens, + max_memory_percent=cb_max_memory_percent, + use_cuda_graph=cb_use_cuda_graph, + ) + self._generation_state = GenerationState( + continuous_batching=continuous_batching, + compile=compile, + cb_config=cb_config, + ) self._chat_handler = ChatCompletionHandler( model_manager=self._model_manager, diff --git a/src/transformers/cli/serving/model_manager.py b/src/transformers/cli/serving/model_manager.py index 10706dae1d64..39b477302c58 100644 --- a/src/transformers/cli/serving/model_manager.py +++ b/src/transformers/cli/serving/model_manager.py @@ -164,10 +164,17 @@ def _validate_args(self): f"Unsupported quantization method: '{self.quantization}'. Must be 'bnb-4bit' or 'bnb-8bit'." ) VALID_ATTN_IMPLEMENTATIONS = {"eager", "sdpa", "flash_attention_2", "flash_attention_3", "flex_attention"} - if self.attn_implementation is not None and self.attn_implementation not in VALID_ATTN_IMPLEMENTATIONS: + is_kernels_community = self.attn_implementation is not None and self.attn_implementation.startswith( + "kernels-community/" + ) + if ( + self.attn_implementation is not None + and not is_kernels_community + and self.attn_implementation not in VALID_ATTN_IMPLEMENTATIONS + ): raise ValueError( f"Unsupported attention implementation: '{self.attn_implementation}'. " - f"Must be one of {VALID_ATTN_IMPLEMENTATIONS}." + f"Must be one of {VALID_ATTN_IMPLEMENTATIONS} or a kernels-community kernel (e.g. 'kernels-community/flash-attn2')." ) @staticmethod diff --git a/src/transformers/cli/serving/utils.py b/src/transformers/cli/serving/utils.py index d46ca261cce3..d9828d123b12 100644 --- a/src/transformers/cli/serving/utils.py +++ b/src/transformers/cli/serving/utils.py @@ -38,7 +38,13 @@ import tokenizers import torch - from transformers import GenerationConfig, PreTrainedModel, PreTrainedTokenizerFast, ProcessorMixin + from transformers import ( + ContinuousBatchingConfig, + GenerationConfig, + PreTrainedModel, + PreTrainedTokenizerFast, + ProcessorMixin, + ) from transformers.generation.continuous_batching.continuous_api import ContinuousBatchingManager from transformers.generation.continuous_batching.requests import GenerationOutput from transformers.generation.continuous_batching.scheduler import Scheduler @@ -614,8 +620,9 @@ class CBGenerateManager(BaseGenerateManager): to ``add_request`` and the CB manager no longer needs a shared config. """ - def __init__(self): + def __init__(self, cb_config: "ContinuousBatchingConfig | None" = None): self._cb = None + self._cb_config = cb_config def init_cb(self, model: "PreTrainedModel", gen_config: "GenerationConfig") -> None: """Initialize the CB manager on first call with the request's generation config. @@ -629,7 +636,9 @@ def init_cb(self, model: "PreTrainedModel", gen_config: "GenerationConfig") -> N if self._cb is not None: return - self._cb = model.init_continuous_batching(generation_config=gen_config) + self._cb = model.init_continuous_batching( + generation_config=gen_config, continuous_batching_config=self._cb_config + ) self._cb.start() def generate_streaming( @@ -727,9 +736,15 @@ class GenerationState: sequential ``model.generate()`` calls. 
""" - def __init__(self, continuous_batching: bool = False, compile: bool = False): + def __init__( + self, + continuous_batching: bool = False, + compile: bool = False, + cb_config: "ContinuousBatchingConfig | None" = None, + ): self._continuous_batching = continuous_batching self._compile = compile + self._cb_config = cb_config self._generate_managers: dict[str, GenerateManager] = {} self._cb_manager: CBGenerateManager | None = None self._cb_model_id: str | None = None @@ -770,7 +785,7 @@ def get_manager(self, model_id: str, use_cb: bool = False) -> BaseGenerateManage self._cb_manager.stop() self._cb_manager = None if self._cb_manager is None: - self._cb_manager = CBGenerateManager() + self._cb_manager = CBGenerateManager(cb_config=self._cb_config) self._cb_model_id = model_id return self._cb_manager if model_id not in self._generate_managers: diff --git a/tests/cli/test_serve.py b/tests/cli/test_serve.py index 63503cbe0bf9..d6c6fcbcfed2 100644 --- a/tests/cli/test_serve.py +++ b/tests/cli/test_serve.py @@ -39,7 +39,7 @@ ToolCallParser, detect_tool_format, ) -from transformers.testing_utils import require_serve, require_torch_accelerator, require_vision, slow +from transformers.testing_utils import require_librosa, require_serve, require_torch_accelerator, require_vision, slow from transformers.utils.import_utils import is_serve_available @@ -1557,6 +1557,7 @@ def test_responses_with_image(self): @slow +@require_librosa @require_serve class TestTranscription(unittest.TestCase): """Integration tests for POST /v1/audio/transcriptions with whisper-tiny.""" @@ -1694,6 +1695,25 @@ def test_non_streaming(self): self.assertIsNotNone(resp.choices[0].message.content) self.assertTrue(len(resp.choices[0].message.content) > 0) + def test_non_streaming_response_json_format(self): + """Non-streaming CB responses return proper JSON objects, not double-encoded strings.""" + response = self.client.chat.completions.create( + model=self.MODEL, + messages=[{"role": "user", "content": "Say hello"}], + stream=False, + max_tokens=5, + ) + self.assertIsNotNone(response) + self.assertIsNotNone(response.id) + self.assertIsNotNone(response.choices) + self.assertEqual(len(response.choices), 1) + + choice = response.choices[0] + self.assertIsNotNone(choice.message) + self.assertIsNotNone(choice.message.content) + self.assertEqual(choice.message.role, "assistant") + self.assertIsInstance(choice.message.content, str) + def test_multi_turn(self): """Multi-turn conversation works with CB.""" resp = self.client.chat.completions.create( From ff02cd796d5587008fcbfdbf4b6ee9edcc86cc7e Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Tue, 31 Mar 2026 16:40:11 +0000 Subject: [PATCH 57/64] fix --- src/transformers/cli/serve.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/src/transformers/cli/serve.py b/src/transformers/cli/serve.py index edb1bb8f5c32..4d9e6c712e7c 100644 --- a/src/transformers/cli/serve.py +++ b/src/transformers/cli/serve.py @@ -108,13 +108,18 @@ def __init__( ) from transformers import ContinuousBatchingConfig - cb_config = ContinuousBatchingConfig( - block_size=cb_block_size, - num_blocks=cb_num_blocks, - max_batch_tokens=cb_max_batch_tokens, - max_memory_percent=cb_max_memory_percent, - use_cuda_graph=cb_use_cuda_graph, - ) + cb_kwargs = { + k: v + for k, v in { + "block_size": cb_block_size, + "num_blocks": cb_num_blocks, + "max_batch_tokens": cb_max_batch_tokens, + "max_memory_percent": cb_max_memory_percent, + "use_cuda_graph": cb_use_cuda_graph, + }.items() + if v 
is not None + } + cb_config = ContinuousBatchingConfig(**cb_kwargs) if cb_kwargs else None self._generation_state = GenerationState( continuous_batching=continuous_batching, compile=compile, From 307498eea8c40bb7e18162a680263fdb1f1d366d Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Tue, 31 Mar 2026 16:58:34 +0000 Subject: [PATCH 58/64] guard transcription --- docker/transformers-all-latest-gpu/Dockerfile | 3 +++ src/transformers/cli/serving/transcription.py | 6 +++++- src/transformers/testing_utils.py | 8 ++++++++ src/transformers/utils/import_utils.py | 5 +++++ tests/cli/test_serve.py | 10 +++++++++- 5 files changed, 30 insertions(+), 2 deletions(-) diff --git a/docker/transformers-all-latest-gpu/Dockerfile b/docker/transformers-all-latest-gpu/Dockerfile index ab4bb8b12114..1380dc8a405c 100644 --- a/docker/transformers-all-latest-gpu/Dockerfile +++ b/docker/transformers-all-latest-gpu/Dockerfile @@ -105,6 +105,9 @@ RUN python3 -m pip install --no-cache-dir python-Levenshtein # For `FastSpeech2ConformerTokenizer` tokenizer RUN python3 -m pip install --no-cache-dir g2p-en +# For serving tests (audio pipelines) +RUN python3 -m pip install --no-cache-dir librosa python-multipart + # For Some bitsandbytes tests RUN python3 -m pip install --no-cache-dir einops diff --git a/src/transformers/cli/serving/transcription.py b/src/transformers/cli/serving/transcription.py index be021be414b9..b63add7e5ed6 100644 --- a/src/transformers/cli/serving/transcription.py +++ b/src/transformers/cli/serving/transcription.py @@ -92,10 +92,14 @@ async def handle_request(self, request: Request) -> JSONResponse | StreamingResp Returns: `JSONResponse | StreamingResponse`: Transcription result or SSE stream. """ - from transformers.utils.import_utils import is_librosa_available + from transformers.utils.import_utils import is_librosa_available, is_multipart_available if not is_librosa_available(): raise ImportError("Missing librosa dependency for audio transcription. Install with `pip install librosa`") + if not is_multipart_available(): + raise ImportError( + "Missing python-multipart dependency for file uploads. 
Install with `pip install python-multipart`" + ) async with request.form() as form: self._validate_request(set(form.keys())) diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 4b021f2bf013..fdc730df55c8 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -115,6 +115,7 @@ is_liger_kernel_available, is_lomo_available, is_mistral_common_available, + is_multipart_available, is_natten_available, is_nltk_available, is_numba_available, @@ -1415,6 +1416,13 @@ def require_librosa(test_case): return unittest.skipUnless(is_librosa_available(), "test requires librosa")(test_case) +def require_multipart(test_case): + """ + Decorator marking a test that requires python-multipart + """ + return unittest.skipUnless(is_multipart_available(), "test requires python-multipart")(test_case) + + def require_liger_kernel(test_case): """ Decorator marking a test that requires liger_kernel diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 569e825d928b..70ab6074427f 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -725,6 +725,11 @@ def is_librosa_available() -> bool: return _is_package_available("librosa")[0] +@lru_cache +def is_multipart_available() -> bool: + return _is_package_available("multipart", package_name="python-multipart")[0] + + @lru_cache def is_essentia_available() -> bool: return _is_package_available("essentia")[0] diff --git a/tests/cli/test_serve.py b/tests/cli/test_serve.py index d6c6fcbcfed2..9be3dbeb99ff 100644 --- a/tests/cli/test_serve.py +++ b/tests/cli/test_serve.py @@ -39,7 +39,14 @@ ToolCallParser, detect_tool_format, ) -from transformers.testing_utils import require_librosa, require_serve, require_torch_accelerator, require_vision, slow +from transformers.testing_utils import ( + require_librosa, + require_multipart, + require_serve, + require_torch_accelerator, + require_vision, + slow, +) from transformers.utils.import_utils import is_serve_available @@ -1558,6 +1565,7 @@ def test_responses_with_image(self): @slow @require_librosa +@require_multipart @require_serve class TestTranscription(unittest.TestCase): """Integration tests for POST /v1/audio/transcriptions with whisper-tiny.""" From ffe4c64e07088ed3664a77cfed1b2d2b79fb4895 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Tue, 31 Mar 2026 17:01:05 +0000 Subject: [PATCH 59/64] better now --- src/transformers/utils/import_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 70ab6074427f..18278e8c319e 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -727,7 +727,7 @@ def is_librosa_available() -> bool: @lru_cache def is_multipart_available() -> bool: - return _is_package_available("multipart", package_name="python-multipart")[0] + return _is_package_available("multipart")[0] @lru_cache From 06a7881570b3bd41e0932f03ca4edfc5a6f2a8b7 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Tue, 31 Mar 2026 17:04:14 +0000 Subject: [PATCH 60/64] fix --- src/transformers/utils/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index 55297dd706d9..3f5c7cac386b 100644 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -164,6 +164,7 @@ is_matplotlib_available, is_mistral_common_available, is_mlx_available, + 
is_multipart_available, is_natten_available, is_ninja_available, is_nltk_available, From 052cbc782b7196ee94085d4018ab09d065460bbe Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Wed, 1 Apr 2026 10:16:51 +0000 Subject: [PATCH 61/64] adding lock to see if this helps --- src/transformers/cli/serving/utils.py | 31 ++++++++++++++++----------- 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/src/transformers/cli/serving/utils.py b/src/transformers/cli/serving/utils.py index d9828d123b12..2c11f8d8b607 100644 --- a/src/transformers/cli/serving/utils.py +++ b/src/transformers/cli/serving/utils.py @@ -623,6 +623,7 @@ class CBGenerateManager(BaseGenerateManager): def __init__(self, cb_config: "ContinuousBatchingConfig | None" = None): self._cb = None self._cb_config = cb_config + self._init_lock = threading.Lock() def init_cb(self, model: "PreTrainedModel", gen_config: "GenerationConfig") -> None: """Initialize the CB manager on first call with the request's generation config. @@ -636,10 +637,14 @@ def init_cb(self, model: "PreTrainedModel", gen_config: "GenerationConfig") -> N if self._cb is not None: return - self._cb = model.init_continuous_batching( - generation_config=gen_config, continuous_batching_config=self._cb_config - ) - self._cb.start() + with self._init_lock: + if self._cb is not None: + return + + self._cb = model.init_continuous_batching( + generation_config=gen_config, continuous_batching_config=self._cb_config + ) + self._cb.start() def generate_streaming( self, @@ -748,6 +753,7 @@ def __init__( self._generate_managers: dict[str, GenerateManager] = {} self._cb_manager: CBGenerateManager | None = None self._cb_model_id: str | None = None + self._cb_manager_lock = threading.Lock() def use_continuous_batching(self, model: "PreTrainedModel", modality: Modality) -> bool: """Check if continuous batching can be used for this model and modality. @@ -780,14 +786,15 @@ def get_manager(self, model_id: str, use_cb: bool = False) -> BaseGenerateManage `BaseGenerateManager`: Either a `GenerateManager` or `CBGenerateManager`. 
""" if use_cb: - if self._cb_model_id != model_id: - if self._cb_manager is not None: - self._cb_manager.stop() - self._cb_manager = None - if self._cb_manager is None: - self._cb_manager = CBGenerateManager(cb_config=self._cb_config) - self._cb_model_id = model_id - return self._cb_manager + with self._cb_manager_lock: + if self._cb_model_id != model_id: + if self._cb_manager is not None: + self._cb_manager.stop() + self._cb_manager = None + if self._cb_manager is None: + self._cb_manager = CBGenerateManager(cb_config=self._cb_config) + self._cb_model_id = model_id + return self._cb_manager if model_id not in self._generate_managers: self._generate_managers[model_id] = GenerateManager() return self._generate_managers[model_id] From 67997271fd059bc82945976e4b71ee1fbbbfdbda Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Wed, 1 Apr 2026 11:42:17 +0000 Subject: [PATCH 62/64] remove locks --- src/transformers/cli/serving/utils.py | 31 +++++++++++---------------- 1 file changed, 12 insertions(+), 19 deletions(-) diff --git a/src/transformers/cli/serving/utils.py b/src/transformers/cli/serving/utils.py index 2c11f8d8b607..d9828d123b12 100644 --- a/src/transformers/cli/serving/utils.py +++ b/src/transformers/cli/serving/utils.py @@ -623,7 +623,6 @@ class CBGenerateManager(BaseGenerateManager): def __init__(self, cb_config: "ContinuousBatchingConfig | None" = None): self._cb = None self._cb_config = cb_config - self._init_lock = threading.Lock() def init_cb(self, model: "PreTrainedModel", gen_config: "GenerationConfig") -> None: """Initialize the CB manager on first call with the request's generation config. @@ -637,14 +636,10 @@ def init_cb(self, model: "PreTrainedModel", gen_config: "GenerationConfig") -> N if self._cb is not None: return - with self._init_lock: - if self._cb is not None: - return - - self._cb = model.init_continuous_batching( - generation_config=gen_config, continuous_batching_config=self._cb_config - ) - self._cb.start() + self._cb = model.init_continuous_batching( + generation_config=gen_config, continuous_batching_config=self._cb_config + ) + self._cb.start() def generate_streaming( self, @@ -753,7 +748,6 @@ def __init__( self._generate_managers: dict[str, GenerateManager] = {} self._cb_manager: CBGenerateManager | None = None self._cb_model_id: str | None = None - self._cb_manager_lock = threading.Lock() def use_continuous_batching(self, model: "PreTrainedModel", modality: Modality) -> bool: """Check if continuous batching can be used for this model and modality. @@ -786,15 +780,14 @@ def get_manager(self, model_id: str, use_cb: bool = False) -> BaseGenerateManage `BaseGenerateManager`: Either a `GenerateManager` or `CBGenerateManager`. 
""" if use_cb: - with self._cb_manager_lock: - if self._cb_model_id != model_id: - if self._cb_manager is not None: - self._cb_manager.stop() - self._cb_manager = None - if self._cb_manager is None: - self._cb_manager = CBGenerateManager(cb_config=self._cb_config) - self._cb_model_id = model_id - return self._cb_manager + if self._cb_model_id != model_id: + if self._cb_manager is not None: + self._cb_manager.stop() + self._cb_manager = None + if self._cb_manager is None: + self._cb_manager = CBGenerateManager(cb_config=self._cb_config) + self._cb_model_id = model_id + return self._cb_manager if model_id not in self._generate_managers: self._generate_managers[model_id] = GenerateManager() return self._generate_managers[model_id] From 3a07c867eb13de24d733cfd6fd35ab9baa1cf80c Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Wed, 1 Apr 2026 12:04:36 +0000 Subject: [PATCH 63/64] lock again --- src/transformers/cli/serving/utils.py | 31 ++++++++++++++++----------- 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/src/transformers/cli/serving/utils.py b/src/transformers/cli/serving/utils.py index d9828d123b12..2c11f8d8b607 100644 --- a/src/transformers/cli/serving/utils.py +++ b/src/transformers/cli/serving/utils.py @@ -623,6 +623,7 @@ class CBGenerateManager(BaseGenerateManager): def __init__(self, cb_config: "ContinuousBatchingConfig | None" = None): self._cb = None self._cb_config = cb_config + self._init_lock = threading.Lock() def init_cb(self, model: "PreTrainedModel", gen_config: "GenerationConfig") -> None: """Initialize the CB manager on first call with the request's generation config. @@ -636,10 +637,14 @@ def init_cb(self, model: "PreTrainedModel", gen_config: "GenerationConfig") -> N if self._cb is not None: return - self._cb = model.init_continuous_batching( - generation_config=gen_config, continuous_batching_config=self._cb_config - ) - self._cb.start() + with self._init_lock: + if self._cb is not None: + return + + self._cb = model.init_continuous_batching( + generation_config=gen_config, continuous_batching_config=self._cb_config + ) + self._cb.start() def generate_streaming( self, @@ -748,6 +753,7 @@ def __init__( self._generate_managers: dict[str, GenerateManager] = {} self._cb_manager: CBGenerateManager | None = None self._cb_model_id: str | None = None + self._cb_manager_lock = threading.Lock() def use_continuous_batching(self, model: "PreTrainedModel", modality: Modality) -> bool: """Check if continuous batching can be used for this model and modality. @@ -780,14 +786,15 @@ def get_manager(self, model_id: str, use_cb: bool = False) -> BaseGenerateManage `BaseGenerateManager`: Either a `GenerateManager` or `CBGenerateManager`. 
""" if use_cb: - if self._cb_model_id != model_id: - if self._cb_manager is not None: - self._cb_manager.stop() - self._cb_manager = None - if self._cb_manager is None: - self._cb_manager = CBGenerateManager(cb_config=self._cb_config) - self._cb_model_id = model_id - return self._cb_manager + with self._cb_manager_lock: + if self._cb_model_id != model_id: + if self._cb_manager is not None: + self._cb_manager.stop() + self._cb_manager = None + if self._cb_manager is None: + self._cb_manager = CBGenerateManager(cb_config=self._cb_config) + self._cb_model_id = model_id + return self._cb_manager if model_id not in self._generate_managers: self._generate_managers[model_id] = GenerateManager() return self._generate_managers[model_id] From 7a7abf204e81f7d3fe2f03bdedff9aaee60a84d2 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Wed, 1 Apr 2026 13:52:53 +0000 Subject: [PATCH 64/64] update bench and remove lock for now --- .../pytorch/transformers_serve_cb_eval_job.py | 22 +++++++++---- src/transformers/cli/serving/utils.py | 31 +++++++------------ 2 files changed, 28 insertions(+), 25 deletions(-) diff --git a/examples/pytorch/transformers_serve_cb_eval_job.py b/examples/pytorch/transformers_serve_cb_eval_job.py index 77ccf7ffb2f3..b30f71e4aa2e 100644 --- a/examples/pytorch/transformers_serve_cb_eval_job.py +++ b/examples/pytorch/transformers_serve_cb_eval_job.py @@ -16,7 +16,6 @@ from inspect_ai import eval from inspect_ai.log import bundle_log_dir -from inspect_evals.gpqa import gpqa_diamond def wait_for_server_up(server_process, port=8000, timeout=600): @@ -79,13 +78,19 @@ def main(): "--limit", type=int, default=10, - help="Number of evaluation samples to run (default: 5)", + help="Number of evaluation samples to run (default: 10)", ) parser.add_argument( "--max-connections", type=int, default=10, - help="Maximum concurrent connections for evaluation (default: 2)", + help="Maximum concurrent connections for evaluation (default: 10)", + ) + parser.add_argument( + "--temperature", + type=float, + default=0, + help="Temperature for generation (default: 0)", ) parser.add_argument( "--log-dir", @@ -127,7 +132,7 @@ def main(): ) parser.add_argument( "--cb-use-cuda-graph", - action="store_true", + action=argparse.BooleanOptionalAction, help="Enable CUDA graphs for continuous batching performance", ) @@ -157,7 +162,10 @@ def main(): serve_cmd.extend(["--cb-max-memory-percent", str(args.cb_max_memory_percent)]) - serve_cmd.append("--cb-use-cuda-graph") + if args.cb_use_cuda_graph is True: + serve_cmd.append("--cb-use-cuda-graph") + elif args.cb_use_cuda_graph is False: + serve_cmd.append("--no-cb-use-cuda-graph") # Always use sdpa attention implementation serve_cmd.extend(["--attn-implementation", "kernels-community/flash-attn2"]) @@ -171,6 +179,7 @@ def main(): print(f"CB Max Batch Tokens: {args.cb_max_batch_tokens if args.cb_max_batch_tokens else 'auto'}") print(f"CB Max Memory: {args.cb_max_memory_percent * 100}%") print(f"CB CUDA Graph: {args.cb_use_cuda_graph}") + print(f"Temperature: {args.temperature}") print(f"Command: {' '.join(serve_cmd)}") print("=" * 70) print("SERVER OUTPUT:") @@ -182,13 +191,14 @@ def main(): wait_for_server_up(server_process, port=args.port, timeout=600) eval( - gpqa_diamond, + "hf/Idavidrein/gpqa/diamond", model=f"openai-api/transformers-serve/{args.model}", log_dir=args.log_dir, model_base_url=f"http://localhost:{args.port}/v1", display="plain", limit=args.limit, model_args={"stream": False}, + temperature=args.temperature, max_connections=args.max_connections, 
max_tokens=2048, ) diff --git a/src/transformers/cli/serving/utils.py b/src/transformers/cli/serving/utils.py index 2c11f8d8b607..d9828d123b12 100644 --- a/src/transformers/cli/serving/utils.py +++ b/src/transformers/cli/serving/utils.py @@ -623,7 +623,6 @@ class CBGenerateManager(BaseGenerateManager): def __init__(self, cb_config: "ContinuousBatchingConfig | None" = None): self._cb = None self._cb_config = cb_config - self._init_lock = threading.Lock() def init_cb(self, model: "PreTrainedModel", gen_config: "GenerationConfig") -> None: """Initialize the CB manager on first call with the request's generation config. @@ -637,14 +636,10 @@ def init_cb(self, model: "PreTrainedModel", gen_config: "GenerationConfig") -> N if self._cb is not None: return - with self._init_lock: - if self._cb is not None: - return - - self._cb = model.init_continuous_batching( - generation_config=gen_config, continuous_batching_config=self._cb_config - ) - self._cb.start() + self._cb = model.init_continuous_batching( + generation_config=gen_config, continuous_batching_config=self._cb_config + ) + self._cb.start() def generate_streaming( self, @@ -753,7 +748,6 @@ def __init__( self._generate_managers: dict[str, GenerateManager] = {} self._cb_manager: CBGenerateManager | None = None self._cb_model_id: str | None = None - self._cb_manager_lock = threading.Lock() def use_continuous_batching(self, model: "PreTrainedModel", modality: Modality) -> bool: """Check if continuous batching can be used for this model and modality. @@ -786,15 +780,14 @@ def get_manager(self, model_id: str, use_cb: bool = False) -> BaseGenerateManage `BaseGenerateManager`: Either a `GenerateManager` or `CBGenerateManager`. """ if use_cb: - with self._cb_manager_lock: - if self._cb_model_id != model_id: - if self._cb_manager is not None: - self._cb_manager.stop() - self._cb_manager = None - if self._cb_manager is None: - self._cb_manager = CBGenerateManager(cb_config=self._cb_config) - self._cb_model_id = model_id - return self._cb_manager + if self._cb_model_id != model_id: + if self._cb_manager is not None: + self._cb_manager.stop() + self._cb_manager = None + if self._cb_manager is None: + self._cb_manager = CBGenerateManager(cb_config=self._cb_config) + self._cb_model_id = model_id + return self._cb_manager if model_id not in self._generate_managers: self._generate_managers[model_id] = GenerateManager() return self._generate_managers[model_id]