diff --git a/docker/transformers-all-latest-gpu/Dockerfile b/docker/transformers-all-latest-gpu/Dockerfile index ab4bb8b12114..1380dc8a405c 100644 --- a/docker/transformers-all-latest-gpu/Dockerfile +++ b/docker/transformers-all-latest-gpu/Dockerfile @@ -105,6 +105,9 @@ RUN python3 -m pip install --no-cache-dir python-Levenshtein # For `FastSpeech2ConformerTokenizer` tokenizer RUN python3 -m pip install --no-cache-dir g2p-en +# For serving tests (audio pipelines) +RUN python3 -m pip install --no-cache-dir librosa python-multipart + # For Some bitsandbytes tests RUN python3 -m pip install --no-cache-dir einops diff --git a/examples/pytorch/transformers_serve_cb_eval_job.py b/examples/pytorch/transformers_serve_cb_eval_job.py index c6355427b161..b30f71e4aa2e 100644 --- a/examples/pytorch/transformers_serve_cb_eval_job.py +++ b/examples/pytorch/transformers_serve_cb_eval_job.py @@ -16,10 +16,9 @@ from inspect_ai import eval from inspect_ai.log import bundle_log_dir -from inspect_evals.gpqa import gpqa_diamond -def wait_for_server_up(server_process, timeout=600): +def wait_for_server_up(server_process, port=8000, timeout=600): start_time = time.time() import urllib.error @@ -27,7 +26,7 @@ def wait_for_server_up(server_process, timeout=600): while time.time() - start_time < timeout: try: - req = urllib.request.urlopen("http://127.0.0.1:8000/health", timeout=2) + req = urllib.request.urlopen(f"http://127.0.0.1:{port}/health", timeout=2) if req.status == 200: elapsed = time.time() - start_time print("\n" + "=" * 70) @@ -69,17 +68,29 @@ def main(): action="store_true", help="Disable continuous batching (enabled by default)", ) + parser.add_argument( + "--port", + type=int, + default=8000, + help="Port for the transformers serve server (default: 8000)", + ) parser.add_argument( "--limit", type=int, default=10, - help="Number of evaluation samples to run (default: 5)", + help="Number of evaluation samples to run (default: 10)", ) parser.add_argument( "--max-connections", type=int, default=10, - help="Maximum concurrent connections for evaluation (default: 2)", + help="Maximum concurrent connections for evaluation (default: 10)", + ) + parser.add_argument( + "--temperature", + type=float, + default=0, + help="Temperature for generation (default: 0)", ) parser.add_argument( "--log-dir", @@ -121,7 +132,7 @@ def main(): ) parser.add_argument( "--cb-use-cuda-graph", - action="store_true", + action=argparse.BooleanOptionalAction, help="Enable CUDA graphs for continuous batching performance", ) @@ -133,6 +144,7 @@ def main(): serve_cmd = [ "transformers", "serve", + args.model, ] # Add continuous batching if not disabled @@ -150,10 +162,14 @@ def main(): serve_cmd.extend(["--cb-max-memory-percent", str(args.cb_max_memory_percent)]) - serve_cmd.append("--cb-use-cuda-graph") + if args.cb_use_cuda_graph is True: + serve_cmd.append("--cb-use-cuda-graph") + elif args.cb_use_cuda_graph is False: + serve_cmd.append("--no-cb-use-cuda-graph") # Always use sdpa attention implementation serve_cmd.extend(["--attn-implementation", "kernels-community/flash-attn2"]) + serve_cmd.extend(["--port", str(args.port)]) print("Starting transformers serve with continuous batching...") print(f"Model: {args.model}") @@ -163,6 +179,7 @@ def main(): print(f"CB Max Batch Tokens: {args.cb_max_batch_tokens if args.cb_max_batch_tokens else 'auto'}") print(f"CB Max Memory: {args.cb_max_memory_percent * 100}%") print(f"CB CUDA Graph: {args.cb_use_cuda_graph}") + print(f"Temperature: {args.temperature}") print(f"Command: {' 
'.join(serve_cmd)}") print("=" * 70) print("SERVER OUTPUT:") @@ -171,16 +188,17 @@ def main(): # Start server with output going directly to stdout/stderr server_process = subprocess.Popen(serve_cmd, stdout=None, stderr=None) - wait_for_server_up(server_process, timeout=600) + wait_for_server_up(server_process, port=args.port, timeout=600) eval( - gpqa_diamond, + "hf/Idavidrein/gpqa/diamond", model=f"openai-api/transformers-serve/{args.model}", log_dir=args.log_dir, - model_base_url="http://localhost:8000/v1", + model_base_url=f"http://localhost:{args.port}/v1", display="plain", limit=args.limit, model_args={"stream": False}, + temperature=args.temperature, max_connections=args.max_connections, max_tokens=2048, ) diff --git a/src/transformers/cli/chat.py b/src/transformers/cli/chat.py index c6e434b07481..968fd290c75b 100644 --- a/src/transformers/cli/chat.py +++ b/src/transformers/cli/chat.py @@ -114,11 +114,17 @@ async def stream_output(self, stream: AsyncIterator[ChatCompletionStreamOutput]) self._console.print(f"[bold blue]<{self.model_id}>:") with Live(console=self._console, refresh_per_second=4) as live: text = "" + completion_tokens = 0 + start_time = time.time() finish_reason: str | None = None async for token in await stream: outputs = token.choices[0].delta.content finish_reason = getattr(token.choices[0], "finish_reason", finish_reason) + usage = getattr(token, "usage", None) + if usage is not None: + completion_tokens = getattr(usage, "completion_tokens", completion_tokens) + if not outputs: continue @@ -154,6 +160,11 @@ async def stream_output(self, stream: AsyncIterator[ChatCompletionStreamOutput]) # Update the Live console output live.update(markdown, refresh=True) + elapsed = time.time() - start_time + if elapsed > 0 and completion_tokens > 0: + tok_per_sec = completion_tokens / elapsed + self._console.print() + self._console.print(f"[dim]{completion_tokens} tokens in {elapsed:.1f}s ({tok_per_sec:.1f} tok/s)[/dim]") self._console.print() return text, finish_reason @@ -544,14 +555,16 @@ async def _inner_run(self): else: chat.append({"role": "user", "content": user_input}) + extra_body = { + "generation_config": config.to_json_string(), + "model": self.model_id, + } + stream = client.chat_completion( chat, stream=True, model=self.model_id, - extra_body={ - "generation_config": config.to_json_string(), - "model": self.model_id, - }, + extra_body=extra_body, ) model_output, finish_reason = await interface.stream_output(stream) diff --git a/src/transformers/cli/serve.py b/src/transformers/cli/serve.py index 7337eb305b61..4d9e6c712e7c 100644 --- a/src/transformers/cli/serve.py +++ b/src/transformers/cli/serve.py @@ -11,2340 +11,180 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +""" +CLI entry point for `transformers serve`. 
+""" + import asyncio -import base64 -import copy -import enum -import gc -import io -import json -import re -import tempfile import threading -import time -import uuid -from collections.abc import Callable, Generator, Iterable -from contextlib import asynccontextmanager -from functools import lru_cache -from io import BytesIO -from threading import Thread -from typing import TYPE_CHECKING, Annotated, Optional, TypedDict, Union +from typing import Annotated import typer -from huggingface_hub import scan_cache_dir -from tokenizers.decoders import DecodeStream -from tqdm import tqdm -from tqdm.auto import tqdm as base_tqdm -import transformers -from transformers import AutoTokenizer, BitsAndBytesConfig, GenerationConfig, PreTrainedTokenizerBase -from transformers.generation import ( - LogitsProcessorList, - TextIteratorStreamer, -) from transformers.utils import logging -from transformers.utils.import_utils import ( - is_fastapi_available, - is_librosa_available, - is_openai_available, - is_pydantic_available, - is_uvicorn_available, - is_vision_available, -) - - -if TYPE_CHECKING: - from transformers import ( - PreTrainedModel, - PreTrainedTokenizerFast, - ProcessorMixin, - ) - - from ..generation.continuous_batching import ContinuousBatchingManager - - -if is_librosa_available(): - import librosa - -if is_vision_available(): - from PIL import Image - -serve_dependencies_available = ( - is_pydantic_available() and is_fastapi_available() and is_uvicorn_available() and is_openai_available() -) -if serve_dependencies_available: - import uvicorn - from fastapi import FastAPI, HTTPException - from fastapi.middleware.cors import CORSMiddleware - from fastapi.responses import JSONResponse, StreamingResponse - from openai.types.audio.transcription import Transcription - from openai.types.audio.transcription_create_params import TranscriptionCreateParamsBase - from openai.types.chat import ChatCompletion, ChatCompletionMessage, ChatCompletionMessageParam - from openai.types.chat.chat_completion import Choice - from openai.types.chat.chat_completion_chunk import ( - ChatCompletionChunk, - ChoiceDelta, - ChoiceDeltaToolCall, - ChoiceDeltaToolCallFunction, - ) - from openai.types.chat.chat_completion_chunk import ( - Choice as ChoiceChunk, - ) - from openai.types.chat.completion_create_params import CompletionCreateParamsStreaming - from openai.types.responses import ( - Response, - ResponseCompletedEvent, - ResponseContentPartAddedEvent, - ResponseContentPartDoneEvent, - ResponseCreatedEvent, - ResponseError, - ResponseErrorEvent, - ResponseFailedEvent, - ResponseInProgressEvent, - ResponseOutputItemAddedEvent, - ResponseOutputItemDoneEvent, - ResponseOutputMessage, - ResponseOutputText, - ResponseTextDeltaEvent, - ResponseTextDoneEvent, - ) - from openai.types.responses.response_create_params import ResponseCreateParamsStreaming - from pydantic import BaseModel, TypeAdapter, ValidationError - - # Expand OpenAI's request input types with an optional `generation_config` field - class TransformersResponseCreateParamsStreaming(ResponseCreateParamsStreaming, total=False): - """ - OpenAI's ResponseCreateParamsStreaming with an additional field for the generation config (as a json string). 
- """ - - generation_config: str - - class TransformersCompletionCreateParamsStreaming(CompletionCreateParamsStreaming, total=False): - """ - OpenAI's CompletionCreateParamsStreaming with additional fields for the generation config (as a json string) and passing the request_id - """ +from transformers.utils.import_utils import is_serve_available - generation_config: str - - class TransformersTranscriptionCreateParams(TranscriptionCreateParamsBase, total=False): - """ - OpenAI's TranscriptionCreateParamsBase with an additional field for the generation config (as a json string). - """ - - file: bytes # Overwritten -- pydantic isn't happy with `typing.IO[bytes]`, present in the original type - generation_config: str - stream: bool = False - - # Contrarily to OpenAI's output types, input types are `TypedDict`, which don't have built-in validation. - response_validator = TypeAdapter(TransformersResponseCreateParamsStreaming) - completion_validator = TypeAdapter(TransformersCompletionCreateParamsStreaming) - transcription_validator = TypeAdapter(TransformersTranscriptionCreateParams) - - # Define request fields that are not yet used in `transformers serve`. Receiving these fields will raise an - # HTTPException. - UNUSED_RESPONSE_FIELDS = { - "background", - "include", - "max_tool_calls", - "previous_response_id", - "prompt", - "reasoning", - "service_tier", - "store", - "text", - "tool_choice", - "top_logprobs", - "truncation", - "user", - } - - UNUSED_CHAT_COMPLETION_FIELDS = { - "audio", - "function_call", - "functions", - "logprobs", - "max_completion_tokens", - "metadata", - "modalities", - "n", - "parallel_tool_calls", - "prediction", - "presence_penalty", - "reasoning_effort", - "response_format", - "service_tier", - "stop", - "store", - "stream_options", - "tool_choice", - "top_logprobs", - "user", - "web_search_options", - } - UNUSED_TRANSCRIPTION_FIELDS = { - "chunking_strategy", - "include", - "language", - "prompt", - "response_format", - "timestamp_granularities", - } +from .serving.utils import set_torch_seed logger = logging.get_logger(__name__) -# Possible tokens that indicate the start/end of a tool call -# TODO (joao, matt): streamline tool token detection logic -_TOOL_CALL_TOKENS = { - "qwen": { - "start": "", - "end": "", - }, -} -_MODELS_WITH_TOOL_SUPPORT = list(_TOOL_CALL_TOKENS.keys()) - -X_REQUEST_ID = "x-request-id" - - -def set_torch_seed(_seed): - import torch - - torch.manual_seed(_seed) - - -def reset_torch_cache(): - import torch - - if torch.cuda.is_available(): - torch.cuda.empty_cache() - - -def torch_ones_like(_input_tensor): - import torch - - return torch.ones_like(_input_tensor) - - -class Modality(enum.Enum): - LLM = "LLM" - VLM = "VLM" - STT = "STT" - TTS = "TTS" - - -def create_generation_config_from_req( - req: dict, - model_generation_config: GenerationConfig, - **kwargs, -) -> GenerationConfig: - """ - Creates a generation config from the parameters of the request. If a generation config is passed in the request, - it will be used as a baseline for parameterization. Otherwise, we will use the model's default generation config. - Other parameters in the request will be applied on top of the baseline. - - Args: - req (`dict`): - The request which may optionally contain generation parameters. - model_generation_config (`GenerationConfig`): - The model's default generation config. - kwargs (`dict`): - Additional parameters to set in the generation config. - - Returns: - The prepared `GenerationConfig` object. 
- """ - # If there is a generation config in the request, it is a json string serialization from a `GenerationConfig` - # object. For simplicity, flags set here take precedence over all other flags. - if req.get("generation_config") is not None: - generation_config = GenerationConfig(**json.loads(req["generation_config"])) - else: - generation_config = copy.deepcopy(model_generation_config) - - non_standard_kwargs = generation_config.update(**kwargs) - # Set extra kwargs that are not in the `GenerationConfig` class (e.g. continuous batching flags) - for k, v in non_standard_kwargs.items(): - if v is not None: - setattr(generation_config, k, v) - - # Response-specific parameters - if req.get("max_output_tokens") is not None: - generation_config.max_new_tokens = int(req["max_output_tokens"]) - - # Completion-specific parameters - if req.get("max_tokens") is not None: - generation_config.max_new_tokens = int(req["max_tokens"]) - if req.get("frequency_penalty") is not None: - generation_config.repetition_penalty = float(req["frequency_penalty"]) - if req.get("logit_bias") is not None: - generation_config.sequence_bias = req["logit_bias"] - if req.get("stop") is not None: - generation_config.stop_strings = req["stop"] - if req.get("temperature") is not None: - generation_config.temperature = float(req["temperature"]) - if float(req["temperature"]) == 0.0: - generation_config.do_sample = False - if req.get("top_p") is not None: - generation_config.top_p = float(req["top_p"]) - if req.get("seed") is not None: - set_torch_seed(req["seed"]) - - return generation_config - - -class DownloadAggregator: - """Aggregates byte-progress across multiple concurrent download tqdm bars into a single SSE stream. - - huggingface_hub opens one tqdm bar per file shard. This class tracks them all and emits - a single aggregate ``{"stage": "download", "progress": {"current": ..., "total": ...}}`` - event whenever any bar updates. - """ - - def __init__(self, enqueue: Callable[[dict], None], model_id_and_revision: str): - self.enqueue = enqueue - self.model = model_id_and_revision - self.bars: dict[int, tuple[int, int | None]] = {} # id -> (current, total) - self.last_emitted_current: int | None = None - - def register(self, bar_id: int, total: int | None): - self.bars[bar_id] = (0, total) - self._emit() - - def update(self, bar_id: int, current: int, total: int | None): - self.bars[bar_id] = (current, total) - self._emit() - - def close(self, bar_id: int): - pass # keep the bar in _bars so totals remain correct - - def _emit(self): - agg_current = sum(c for c, _ in self.bars.values()) - if agg_current == self.last_emitted_current: - return - self.last_emitted_current = agg_current - totals = [t for _, t in self.bars.values() if t is not None] - agg_total = sum(totals) if totals else None - self.enqueue( - { - "status": "loading", - "model": self.model, - "stage": "download", - "progress": {"current": agg_current, "total": agg_total}, - } - ) - - -class DownloadProxy: - """ - Leverages the DownloadAggregator in order to have a coherent tqdm wrapper. 
- """ - - def __init__(self, wrapped_bar, download_aggregator): - self.wrapped_bar = wrapped_bar - self.bar_id = id(wrapped_bar) - self.download_aggregator = download_aggregator - - self.n = 0 - self.total = wrapped_bar.total - - def __getattr__(self, name): - return getattr(self.wrapped_bar, name) - - def update(self, n=1): - if n is None: - n = 1 - - self.n += n - self.download_aggregator.update(self.bar_id, self.n, getattr(self.wrapped_bar, "total", self.total)) - - return self.wrapped_bar.update(n) - - def close(self): - self.download_aggregator.close(self.bar_id) - return self.wrapped_bar.close() - - def __enter__(self): - self.wrapped_bar.__enter__() - return self - - def __exit__(self, *a): - return self.wrapped_bar.__exit__(*a) - - def __iter__(self): - count = 0 - for item in self.wrapped_bar: - count += 1 - self.download_aggregator.update(self.bar_id, count, getattr(self.wrapped_bar, "total", self.total)) - - yield item - - -class WeightsProxy: - """ - Wraps the weight-loading tqdm bar to have finer control over how we emit them to the clinet. - """ - - def __init__(self, wrapped_bar, _callable, model_id_and_revision): - self.wrapped_bar = wrapped_bar - self.last_emitted = -1 - self.callable = _callable - self.model_id_and_revision = model_id_and_revision - - self.n = 0 - self.total = wrapped_bar.total - - def __getattr__(self, name): - return getattr(self.wrapped_bar, name) - - def _emit(self): - if self.n == self.last_emitted: - return - self.last_emitted = self.n - - self.callable( - { - "status": "loading", - "model": self.model_id_and_revision, - "stage": "weights", - "progress": { - "current": self.n, - "total": getattr(self.wrapped_bar, "total", self.total), - }, - } - ) - - def update(self, n=1): - if n is None: - n = 1 - - self.n += n - self._emit() - - return self.wrapped_bar.update(n) - - def close(self): - return self.wrapped_bar.close() - - def __enter__(self): - self.wrapped_bar.__enter__() - return self - - def __exit__(self, *a): - return self.wrapped_bar.__exit__(*a) - - def __iter__(self): - for item in self.wrapped_bar: - self.n += 1 - self._emit() - - yield item - - -def set_tqdm_class(callback, mid): - download_aggregator = DownloadAggregator(callback, mid) - - class ProgressTqdm(base_tqdm): - """tqdm subclass that routes progress to the correct SSE stage. - - Bars with ``unit="B"`` are download bars (one per file shard) — they are - aggregated into a single ``download`` stage stream via ``_DownloadAggregator``. - All other bars are weight-loading bars emitted as ``weights`` stage events. 
- """ - - def __init__(self, *args, **kwargs): - self.sse_unit = kwargs.get("unit") or "it" - kwargs["disable"] = True # suppress server-side display - super().__init__(*args, **kwargs) - self.n = 0 - self.last_emitted = -1 - if self.sse_unit == "B": - self._bar_id = id(self) - download_aggregator.register(self._bar_id, self.total) - - def update(self, n=1): - if n is None: - n = 1 - - self.n += n - - if self.sse_unit == "B": - download_aggregator.update(self._bar_id, self.n, self.total) - elif self.n != self.last_emitted: - self.last_emitted = self.n - - callback( - { - "status": "loading", - "model": mid, - "stage": "weights", - "progress": {"current": self.n, "total": self.total}, - } - ) - - def close(self): - if self.sse_unit == "B": - download_aggregator.close(self._bar_id) - super().close() - - return ProgressTqdm - - -class ToolState: - """Lightweight class to keep track of the tool call state.""" - - def __init__(self): - self.reset() - - def reset(self): - """Reset the tool call state (assumes we're outside a tool call).""" - self.inside_tool_call = False - self.has_tool_name_defined = False - self.arg_nesting_level = 0 - self.buffer = "" - - -class TimedModel: - """ - A class that holds a PreTrainedModel instance and its associated processor. - Automatically deletes the instances after a specified timeout. - """ - - def __init__( - self, - model: "PreTrainedModel", - timeout_seconds: int, - processor: Union["ProcessorMixin", "PreTrainedTokenizerFast"] | None = None, - ): - self.model = model - self._name_or_path = str(model.name_or_path) - self.processor = processor - self.timeout_seconds = timeout_seconds - self._timer = threading.Timer(self.timeout_seconds, self.timeout_reached) - self._timer.start() - - def reset_timer(self): - """Reset the timer for the deletion of the instances.""" - self._timer.cancel() - self._timer = threading.Timer(self.timeout_seconds, self.timeout_reached) - self._timer.start() - - def delete_model(self): - """Delete the wrapped model and processor and clean up resources.""" - if hasattr(self, "model") and self.model is not None: - del self.model - del self.processor - self.model = None - self.processor = None - gc.collect() - - # Clear CUDA cache if available - reset_torch_cache() - - # XXX: in case we manually delete the model, like on server shutdown - self._timer.cancel() - - def timeout_reached(self): - if self.timeout_seconds > 0: - self.delete_model() - logger.warning( - f"{self._name_or_path} was removed from memory after {self.timeout_seconds} seconds of inactivity" - ) - - def is_deleted(self): - """Check if the instances have been deleted.""" - return not hasattr(self, "model") or self.model is None - class Serve: - # Defining a class to help with internal state but in practice it's just a method to call - # TODO: refactor into a proper module with helpers + 1 main method def __init__( self, + force_model: Annotated[str | None, typer.Argument(help="Model to preload and use for all requests.")] = None, + # Model options continuous_batching: Annotated[ - bool | None, - typer.Option(help="Whether to use continuous batching for chat completions. Configure with --cb-* flags."), - ] = None, + bool, + typer.Option(help="Enable continuous batching with paged attention. Configure with --cb-* flags."), + ] = False, cb_block_size: Annotated[ - int | None, - typer.Option(help="Size of each KV cache block in tokens for continuous batching. 
Default: 256."), + int | None, typer.Option(help="KV cache block size in tokens for continuous batching.") ] = None, cb_num_blocks: Annotated[ - int | None, - typer.Option( - help="Number of blocks in KV cache for continuous batching. Default: auto-inferred from GPU memory." - ), + int | None, typer.Option(help="Number of KV cache blocks for continuous batching.") ] = None, cb_max_batch_tokens: Annotated[ - int | None, - typer.Option( - help="Maximum number of tokens in a batch for continuous batching. Default: auto-inferred from GPU memory." - ), + int | None, typer.Option(help="Maximum tokens per batch for continuous batching.") ] = None, cb_max_memory_percent: Annotated[ - float | None, - typer.Option( - help="Maximum percentage of free GPU memory to use for KV cache in continuous batching (0.0-1.0). Default: 0.8." - ), + float | None, typer.Option(help="Max GPU memory fraction for KV cache (0.0-1.0).") ] = None, cb_use_cuda_graph: Annotated[ - bool | None, - typer.Option( - help="Enable CUDA graphs for continuous batching performance. Default: auto-inferred based on attention implementation." - ), + bool | None, typer.Option(help="Enable CUDA graphs for continuous batching.") ] = None, - device: Annotated[ - str, - typer.Option( - help="Device to use for inference; will default to `auto` and place the model on an accelerator if available." - ), - ] = "auto", - dtype: Annotated[ - str | None, - typer.Option( - help="Override the default `torch.dtype` and load the model under this dtype. If `'auto'` is passed, the dtype will be automatically derived from the model's weights." - ), - ] = "auto", - trust_remote_code: Annotated[ - bool, typer.Option(help="Whether to trust remote code when loading a model.") - ] = False, attn_implementation: Annotated[ - str | None, - typer.Option(help="Which attention implementation to use."), + str | None, typer.Option(help="Attention implementation (e.g. flash_attention_2).") ] = None, + compile: Annotated[bool, typer.Option(help="Enable torch.compile for faster inference.")] = False, quantization: Annotated[ - str | None, - typer.Option(help="Which quantization method to use. choices: 'bnb-4bit', 'bnb-8bit'"), + str | None, typer.Option(help="Quantization method: 'bnb-4bit' or 'bnb-8bit'.") ] = None, - host: Annotated[str, typer.Option(help="Interface the server will listen to.")] = "localhost", - port: Annotated[int, typer.Option(help="Port the server will listen to.")] = 8000, + device: Annotated[str, typer.Option(help="Device for inference (e.g. 'auto', 'cuda:0', 'cpu').")] = "auto", + dtype: Annotated[str | None, typer.Option(help="Override model dtype. 'auto' derives from weights.")] = "auto", + trust_remote_code: Annotated[bool, typer.Option(help="Trust remote code when loading.")] = False, model_timeout: Annotated[ - int, typer.Option(help="Time in seconds after which a model will be removed from memory.") + int, typer.Option(help="Seconds before idle model is unloaded. Ignored when force_model is set.") ] = 300, - log_level: Annotated[ - str, typer.Option(help="Logging level as a string. Example: 'info' or 'warning'.") - ] = "warning", - default_seed: Annotated[ - int | None, typer.Option(help="The default seed for torch, should be an integer.") - ] = None, - enable_cors: Annotated[ - bool, - typer.Option( - help="Whether to enable CORS. Some apps that make requests from external domains (e.g. Cursor) require CORS to be enabled." 
- ), - ] = False, - input_validation: Annotated[bool, typer.Option(help="Whether to turn on strict input validation.")] = False, - force_model: Annotated[ - str | None, - typer.Option( - help="Name of the model to be forced on all requests. This is useful for testing Apps that don't allow changing models in the request." - ), - ] = None, + # Server options + host: Annotated[str, typer.Option(help="Server listen address.")] = "localhost", + port: Annotated[int, typer.Option(help="Server listen port.")] = 8000, + enable_cors: Annotated[bool, typer.Option(help="Enable permissive CORS.")] = False, + log_level: Annotated[str, typer.Option(help="Logging level (e.g. 'info', 'warning').")] = "warning", + default_seed: Annotated[int | None, typer.Option(help="Default torch seed.")] = None, non_blocking: Annotated[ - bool, typer.Option(hidden=True, help="Whether to run the server in a separate thread.") + bool, typer.Option(hidden=True, help="Run server in a background thread. Used by tests.") ] = False, ) -> None: - if not serve_dependencies_available: - raise ImportError( - "Missing dependencies for the serving CLI. Please install with `pip install transformers[serving]`" - ) - - # Save input arguments - self.attn_implementation = attn_implementation - self.continuous_batching = continuous_batching - self.device = device - self.quantization = quantization + if not is_serve_available(): + raise ImportError("Missing dependencies for serving. Install with `pip install transformers[serving]`") - self.dtype = dtype - self.trust_remote_code = trust_remote_code - self.host = host - self.port = port - self.model_timeout = model_timeout - self.log_level = log_level - self.default_seed = default_seed - self.enable_cors = enable_cors - self.input_validation = input_validation - self.force_model = force_model - self.non_blocking = non_blocking + import uvicorn - # Continuous batching configuration arguments - self.cb_block_size = cb_block_size - self.cb_num_blocks = cb_num_blocks - self.cb_max_batch_tokens = cb_max_batch_tokens - self.cb_max_memory_percent = cb_max_memory_percent - self.cb_use_cuda_graph = cb_use_cuda_graph + from .serving.chat_completion import ChatCompletionHandler + from .serving.model_manager import ModelManager + from .serving.response import ResponseHandler + from .serving.server import build_server + from .serving.transcription import TranscriptionHandler + from .serving.utils import GenerationState # Seed if default_seed is not None: set_torch_seed(default_seed) - # Set up logging + # Logging transformers_logger = logging.get_logger("transformers") transformers_logger.setLevel(logging.log_levels[log_level.lower()]) - cb_logger = logging.get_logger("transformers.generation.continuous_batching") - cb_logger.setLevel(logging.log_levels[log_level.lower()]) - - # Internal state: - # 1. Tracks models in memory, to prevent reloading the model unnecessarily - self.loaded_models: dict[str, TimedModel] = {} - self.running_continuous_batching_manager: ContinuousBatchingManager | None = None - - # Tracks in-flight model loads for fan-out to multiple SSE subscribers - self.loading_subscribers: dict[str, list[asyncio.Queue[str | None]]] = {} - self.loading_tasks: dict[str, asyncio.Task] = {} - - # Thread-safety for load_model_and_processor / load_audio_model_and_processor - self.model_locks: dict[str, threading.Lock] = {} - self.model_locks_guard = threading.Lock() - # Thread-safety for continuous batching manager init/teardown - self._cb_manager_lock = threading.Lock() - - # 2. 
preserves information about the last call and last KV cache, to determine whether we can reuse the KV - # cache and avoid re-running prefill - self.last_messages = None - self.last_kv_cache = None - self.last_model = None - - if self.model_timeout is None: - self.model_timeout = -1 if self.force_model else 300 - - if self.force_model: - model_id_and_revision = self.process_model_name(self.force_model) - self.last_model = model_id_and_revision - self.load_model_and_processor(model_id_and_revision) - - @asynccontextmanager - async def lifespan(app: FastAPI): - yield - self.reset_loaded_models() - - app = FastAPI(lifespan=lifespan) - - # Some apps that make requests from external domains (e.g. Cursor) require CORS to be enabled. However, for - # security purposes, it's disabled by default - if self.enable_cors: - app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], - ) - logger.warning_once( - "CORS allow origin is set to `*`. This is not recommended for production environments." - ) - - from fastapi import Request - - @app.post("/v1/chat/completions") - def chat_completion(request: Request, body: dict): - self.validate_chat_completion_request(request=body) - - logger.warning(f"[Request received] Model: {body['model']}, CB: {self.continuous_batching}") - - if self.continuous_batching: - return self.continuous_batching_chat_completion(body, request.state.request_id) - else: - return self.generate_chat_completion(body) - - @app.post("/v1/responses") - def responses(request: dict): - self.validate_response_request(request=request) - - logger.warning(f"[Request received] Model: {request['model']}, CB: {self.continuous_batching}") - - # Support non-streaming mode when `stream=false` is provided - stream = request.get("stream", True) - if not stream: - response_obj = self.generate_response_non_streaming(request) - return JSONResponse(response_obj) - - output = self.generate_response(request) - return StreamingResponse(output, media_type="text/event-stream") - - @app.post("/v1/audio/transcriptions") - async def audio_transcriptions(request: Request): - # Parses the multipart/form-data request into the request format used by other endpoints - async with request.form() as form: - parsed_request = TransformersTranscriptionCreateParams( - file=await form["file"].read(), - model=form["model"], - # TODO: add other fields - ) - logger.debug( - f"Received file: {form['file'].filename}; MIME type: {form['file'].content_type}; " - f"size: {form['file'].size / 1024:.2f} KiB" - ) - self.validate_transcription_request(request=parsed_request) - - output = self.generate_transcription(parsed_request) - return StreamingResponse(output, media_type="text/event-stream") - - @app.options("/v1/models") - @app.get("/v1/models") - def get_all_models(): - return JSONResponse({"object": "list", "data": self.get_gen_models()}) - - @app.get("/health") - def healthcheck(): - return JSONResponse({"status": "ok"}) - - @app.post("/load_model") - async def load_model(body: dict): - model = body.get("model") - if model is None: - raise HTTPException(status_code=422, detail="Missing `model` field in the request body.") - - model_id_and_revision = self.process_model_name(model) - - async def event_publisher(): - queue: asyncio.Queue[str | None] = asyncio.Queue() - model_loaded = model_id_and_revision in self.loaded_models - model_not_deleted = model_loaded and not self.loaded_models[model_id_and_revision].is_deleted() - - # Case 1: Model already cached - 
if model_loaded and model_not_deleted: - self.loaded_models[model_id_and_revision].reset_timer() - yield f"data: {json.dumps({'status': 'ready', 'model': model_id_and_revision, 'cached': True})}\n\n" - return - - # Case 2: Load already happening, join existing subscribers - if model_id_and_revision in self.loading_tasks: - self.loading_subscribers[model_id_and_revision].append(queue) - - while True: - item = await queue.get() - if item is None: - break - yield item - return - - # Case 3: First request — start the load task - self.loading_subscribers[model_id_and_revision] = [queue] - loop = asyncio.get_running_loop() - - def enqueue(payload: dict): - msg = f"data: {json.dumps(payload)}\n\n" - - def broadcast(): - for q in self.loading_subscribers.get(model_id_and_revision, []): - q.put_nowait(msg) - - loop.call_soon_threadsafe(broadcast) - - download_aggregator = DownloadAggregator(enqueue, model_id_and_revision) - - def streaming_tqdm_hook(factory, args, kwargs): - bar = factory(*args, **kwargs) - desc = kwargs.get("desc") or getattr(bar, "desc", None) or "" - unit = kwargs.get("unit") or getattr(bar, "unit", "it") - total = getattr(bar, "total", kwargs.get("total")) - - # Only forward byte-progress bars (file downloads) — skip noise like "Fetching N files" - if unit == "B": - download_aggregator.register(id(bar), total) - return DownloadProxy(bar, download_aggregator=download_aggregator) - - # "Loading weights" bar — emit as stage: "weights" with item-count progress - if desc == "Loading weights": - return WeightsProxy(bar, enqueue, model_id_and_revision) - - # Other non-byte, non-weights bars (noise) — return unmodified - return bar - - async def run_load(): - previous_hook = logging.set_tqdm_hook(streaming_tqdm_hook) - try: - await asyncio.to_thread(self.load_model_and_processor, model_id_and_revision, enqueue) - except Exception as e: - logger.error(f"Failed to load {model_id_and_revision}: {e}", exc_info=True) - enqueue({"status": "error", "model": model_id_and_revision, "message": str(e)}) - finally: - logging.set_tqdm_hook(previous_hook) - - def _send_sentinel(): - for q in self.loading_subscribers.pop(model_id_and_revision, []): - q.put_nowait(None) - self.loading_tasks.pop(model_id_and_revision, None) - - loop.call_soon_threadsafe(_send_sentinel) + self._model_manager = ModelManager( + device=device, + dtype=dtype, + trust_remote_code=trust_remote_code, + attn_implementation=attn_implementation, + quantization=quantization, + model_timeout=model_timeout, + force_model=force_model, + ) + from transformers import ContinuousBatchingConfig + + cb_kwargs = { + k: v + for k, v in { + "block_size": cb_block_size, + "num_blocks": cb_num_blocks, + "max_batch_tokens": cb_max_batch_tokens, + "max_memory_percent": cb_max_memory_percent, + "use_cuda_graph": cb_use_cuda_graph, + }.items() + if v is not None + } + cb_config = ContinuousBatchingConfig(**cb_kwargs) if cb_kwargs else None + self._generation_state = GenerationState( + continuous_batching=continuous_batching, + compile=compile, + cb_config=cb_config, + ) - self.loading_tasks[model_id_and_revision] = asyncio.create_task(run_load()) + self._chat_handler = ChatCompletionHandler( + model_manager=self._model_manager, + generation_state=self._generation_state, + ) - while True: - item = await queue.get() - if item is None: - break - yield item + self._response_handler = ResponseHandler( + model_manager=self._model_manager, + generation_state=self._generation_state, + ) - return StreamingResponse(event_publisher(), 
media_type="text/event-stream") + self._transcription_handler = TranscriptionHandler(self._model_manager, self._generation_state) - @app.middleware("http") - async def get_or_set_request_id(request: Request, call_next): - request_id = request.headers.get(X_REQUEST_ID) or str(uuid.uuid4()) - request.state.request_id = request_id - response = await call_next(request) - response.headers[X_REQUEST_ID] = request_id - return response + app = build_server( + self._model_manager, + self._chat_handler, + response_handler=self._response_handler, + transcription_handler=self._transcription_handler, + enable_cors=enable_cors, + ) - config = uvicorn.Config(app, host=self.host, port=self.port, log_level="info") + config = uvicorn.Config(app, host=host, port=port, log_level="info") self.server = uvicorn.Server(config) - if self.non_blocking: + if non_blocking: self.start_server() else: self.server.run() def start_server(self): def _run(): - self._loop = asyncio.new_event_loop() - asyncio.set_event_loop(self._loop) - # serve() is a coroutine; it exits when server.should_exit becomes True - self._loop.run_until_complete(self.server.serve()) + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(self.server.serve()) self._thread = threading.Thread(target=_run, name="uvicorn-thread", daemon=False) self._thread.start() - def kill_server(self): - if not self._thread: - raise ValueError("The server cannot be killed as it was not launched in a separate thread.") - - if not self._thread.is_alive(): - raise ValueError("The server is already killed.") - - self.server.should_exit = True - if self._thread and self._thread.is_alive(): - self._thread.join(timeout=2) - def reset_loaded_models(self): - """ - Resets all loaded models. - """ - if self.running_continuous_batching_manager is not None: - logger.warning("Resetting the continuous batching manager.") - self.running_continuous_batching_manager.stop(block=True, timeout=2) - self.running_continuous_batching_manager = None - for model in list(self.loaded_models.values()): - model.delete_model() - self.last_model = None - - def _validate_request( - self, - request: dict, - schema: TypedDict, - validator: "TypeAdapter", - unused_fields: set, - ): - """ - Validates the request against the schema, and checks for unexpected keys. - - Args: - request (`dict`): - The request to validate. - schema (`TypedDict`): - The schema of the request to validate. It is a `TypedDict` definition. - validator (`TypeAdapter`): - The validator to use to validate the request. Built from `schema`. - unused_fields (`set`): - Fields accepted by `schema`, but not used in `transformers serve`. - - Raises: - HTTPException: If the request is invalid or contains unexpected or unused fields. - """ - logger.debug(f"Validating request: {request}") - - # Validate unexpected keys -- Pydantic doesn't validate extra keys in the request. 
- input_keys = set(request.keys()) - possible_keys = schema.__mutable_keys__ - unexpected_keys = input_keys - possible_keys - if unexpected_keys: - logger.error(f"Unexpected keys in the request: {unexpected_keys}") - raise HTTPException(status_code=422, detail=f"Unexpected keys in the request: {unexpected_keys}") - - if self.input_validation: - # Validate expected keys - try: - validator.validate_python(request) - except ValidationError as e: - logger.error(f"Validation error: {e.errors()}") - raise HTTPException(status_code=422, detail=e.errors()) - - # Validate unused fields - unused_fields_in_request = input_keys & unused_fields - if unused_fields_in_request: - logger.error(f"Unused fields in the request: {unused_fields_in_request}") - raise HTTPException( - status_code=422, detail=f"Unused fields in the request: {unused_fields_in_request}" - ) - - def validate_response_request(self, request: dict): - self._validate_request( - request=request, - schema=TransformersResponseCreateParamsStreaming, - validator=response_validator, - unused_fields=UNUSED_RESPONSE_FIELDS, - ) - - def validate_chat_completion_request(self, request: dict): - self._validate_request( - request=request, - schema=TransformersCompletionCreateParamsStreaming, - validator=completion_validator, - unused_fields=UNUSED_CHAT_COMPLETION_FIELDS, - ) - - def validate_transcription_request(self, request: dict): - self._validate_request( - request=request, - schema=TransformersTranscriptionCreateParams, - validator=transcription_validator, - unused_fields=UNUSED_TRANSCRIPTION_FIELDS, - ) - - def build_chat_completion_chunk( - self, - request_id: str = "", - content: int | None = None, - model: str | None = None, - role: str | None = None, - finish_reason: str | None = None, - tool_calls: list["ChoiceDeltaToolCall"] | None = None, - decode_stream: DecodeStream | None = None, - tokenizer: Optional["PreTrainedTokenizerFast"] = None, - ) -> "ChatCompletionChunk": - """ - Builds a chunk of a streaming OpenAI Chat Completion response. - - IMPORTANT: The serialized chunk won't contain empty fields (fields with `None`). Some downstream apps, - like Cursor, assume that when the field exists, it has data. - - Args: - request_id (`str`): - The request ID. - content (`str`, *optional*): - Content of the response from the model. - model (`str`, *optional*): - The model that generated the content. - role (`str`, *optional*): - The role of the next content, until a new role is defined. - finish_reason (`str`, *optional*): - The reason the generation by the model has finished. - tool_calls (`list[ChoiceDeltaToolCall]`, *optional*): - Data about the tool calls, when they are triggered. - - Returns: - `str`: The built chunk, a string containing a JSON string with the payload. - """ - if decode_stream is not None and content is not None and tokenizer is not None: - content = decode_stream.step(tokenizer._tokenizer, content) - - chunk = ChatCompletionChunk( - id=request_id, - created=int(time.time()), - model=model, - choices=[ - ChoiceChunk( - delta=ChoiceDelta( - content=content, - role=role, - tool_calls=tool_calls, - ), - index=0, - finish_reason=finish_reason, - ) - ], - system_fingerprint="", - object="chat.completion.chunk", - ) - - return chunk - - @staticmethod - def chunk_to_sse_element(chunk: "ChatCompletionChunk | BaseModel") -> str: - """ - Builds an event of a streaming OpenAI Response model or a ChatCompletion chunk. - - IMPORTANT: The serialized chunk won't contain empty fields (fields with `None`). 
Some downstream apps, - like Cursor, assume that when the field exists, it has data. - - Args: - chunk (`BaseModel` or `ChatCompletionChunk`): - The response to build an event from. One of the multiple OpenAI Response output types - - Returns: - `str`: The built chunk, a string containing a JSON string with the payload. - """ - if isinstance(chunk, str): - # Error paths may yield pre-formatted strings — pass them through as-is. - return chunk if chunk.startswith("data: ") else f"data: {chunk}\n\n" - return f"data: {chunk.model_dump_json(exclude_none=True)}\n\n" - - @staticmethod - @lru_cache - def get_gen_models(cache_dir: str | None = None) -> list[dict[str, any]]: - """ - List LLMs and VLMs in the cache. - """ - from transformers.models.auto.modeling_auto import ( - MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, - MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES, - ) - - generative_models = [] - - logger.warning("Scanning the cache directory for LLMs and VLMs.") - for repo in tqdm(scan_cache_dir(cache_dir).repos): - if repo.repo_type != "model": - continue - - refs = repo.refs - for ref, revision_info in refs.items(): - files = revision_info.files - config_path = next((f.file_path for f in files if f.file_name == "config.json"), None) - - if not config_path: - continue - - config = json.loads(config_path.open().read()) - - if not (isinstance(config, dict) and "architectures" in config): - continue - - architectures = config["architectures"] - llms = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values() - vlms = MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.values() - - if any(arch for arch in architectures if arch in [*llms, *vlms]): - author = repo.repo_id.split("/") if "/" in repo.repo_id else "" - repo_handle = repo.repo_id + (f"@{ref}" if ref != "main" else "") - generative_models.append( - { - "owned_by": author, - "id": repo_handle, - "object": "model", - "created": repo.last_modified, - } - ) - - return generative_models - - def continuous_batching_chat_completion(self, req: dict, request_id: str) -> "StreamingResponse | JSONResponse": - """ - Generates an OpenAI Chat Completion using continuous batching. - - Args: - req (`dict`): The request to generate an OpenAI Chat Completion for. - - Returns: - `Generator[str, None, None]`: A generator that yields the OpenAI Chat Completion chunks. - """ - - model_id_and_revision = self.process_model_name(req["model"]) + """Clear all loaded models from memory.""" + self._model_manager.shutdown() - with self._cb_manager_lock: - must_discard_cache = model_id_and_revision != self.last_model - self.last_model = model_id_and_revision - - # When switching models, terminate a continuous batching manager if it is running. - if must_discard_cache: - if self.running_continuous_batching_manager is not None: - self.running_continuous_batching_manager.stop(block=True, timeout=2) - self.running_continuous_batching_manager = None - - model, processor = self.load_model_and_processor(model_id_and_revision) - - # Continuous batching only supports text-only models - if self.get_model_modality(model, processor=processor) != Modality.LLM: - logger.warning_once( - "Continuous batching is not supported for non-text-only models. Falling back to regular generate." 
- ) - return self.generate_chat_completion(req) - - tokenizer = processor.tokenizer if hasattr(processor, "tokenizer") else processor - - generation_config = create_generation_config_from_req( - req, - model_generation_config=model.generation_config, - eos_token_id=tokenizer.eos_token_id, - pad_token_id=tokenizer.pad_token_id, - use_cache=False, - do_sample=False, - scheduler="fifo", - ) - - if self.running_continuous_batching_manager is None: - from transformers import ContinuousBatchingConfig - - # Build continuous batching config from CLI arguments - cb_config_kwargs = {} - if self.cb_block_size is not None: - cb_config_kwargs["block_size"] = self.cb_block_size - if self.cb_num_blocks is not None: - cb_config_kwargs["num_blocks"] = self.cb_num_blocks - if self.cb_max_batch_tokens is not None: - cb_config_kwargs["max_batch_tokens"] = self.cb_max_batch_tokens - if self.cb_max_memory_percent is not None: - cb_config_kwargs["max_memory_percent"] = self.cb_max_memory_percent - if self.cb_use_cuda_graph is not None: - cb_config_kwargs["use_cuda_graph"] = self.cb_use_cuda_graph - - cb_config = ContinuousBatchingConfig(**cb_config_kwargs) if cb_config_kwargs else None - - self.running_continuous_batching_manager = model.init_continuous_batching( - generation_config=generation_config, - continuous_batching_config=cb_config, - ) - - # TODO (Joao, Lysandre): the logits processors should be fixed in continuous batching and correctly applied in non-cb - self.running_continuous_batching_manager.logit_processor = LogitsProcessorList() - self.running_continuous_batching_manager.start() - - # TODO (Joao, Lysandre): this should also work with tool support - modality = self.get_model_modality(model, processor=processor) - processor_inputs = self.get_processor_inputs_from_inbound_messages(req["messages"], modality) - - inputs = processor.apply_chat_template( - processor_inputs, - add_generation_prompt=True, - tools=req.get("tools"), - return_tensors="pt", - return_dict=True, - tokenize=True, - ) - inputs = inputs["input_ids"][0].to(model.device) - - def stream_chat_completion(request_id, decode_stream): - from ..generation.continuous_batching import RequestStatus - - try: - # Emit the assistant role to start the stream. Other chunks won't have a role, as it is implicit - # they come from the assistant. - yield self.build_chat_completion_chunk(request_id, role="assistant", model=model_id_and_revision) - - n_tokens_generated = 0 - for result in self.running_continuous_batching_manager.request_id_iter(request_id): - n_tokens_generated += 1 - - # Always yield the token content (even for the final FINISHED token) - if result.generated_tokens: - token_id = result.generated_tokens[-1] - yield self.build_chat_completion_chunk( - request_id=request_id, - content=token_id, - model=model_id_and_revision, - decode_stream=decode_stream, - tokenizer=tokenizer, - ) - - if result.status == RequestStatus.FINISHED: - generated_all_tokens = ( - generation_config.max_new_tokens is not None - and n_tokens_generated >= generation_config.max_new_tokens - ) - - # If the tokenizer has an eos_token, we can have a more robust check. 
- if hasattr(tokenizer, "eos_token"): - final_token_is_eos = result == tokenizer.eos_token - generated_all_tokens = generated_all_tokens and not final_token_is_eos - - reason = "length" if generated_all_tokens else "stop" - - yield self.build_chat_completion_chunk( - request_id, - finish_reason=reason, - model=model_id_and_revision, - ) - break - - except Exception as e: - logger.error(str(e)) - self.running_continuous_batching_manager.cancel_request(request_id) - yield f'data: {{"error": "{str(e)}"}}' - - def buffer_chat_completion(_request_id): - result = None - while self.running_continuous_batching_manager.is_running() and result is None: - result = self.running_continuous_batching_manager.get_result(request_id=_request_id, timeout=1) - - if result is None: - raise RuntimeError(f"Request {_request_id} failed: generation loop stopped before producing a result.") - - content = tokenizer.decode(result.generated_tokens) - - chat_completion_result = ChatCompletion( - id=_request_id, - created=int(time.time()), - object="chat.completion", - model=model_id_and_revision, - choices=[ - Choice( - # TODO check the index - index=0, - message=ChatCompletionMessage(content=content, role="assistant"), - finish_reason="stop", - ) - ], - # TODO implement function calling - # TODO implement usage - ) - - return chat_completion_result - - async def cancellation_wrapper_stream(_request_id): - # Enables cancellation in an async context - try: - decode_stream = DecodeStream(inputs.tolist(), False) - for _chunk in stream_chat_completion(_request_id, decode_stream): - yield self.chunk_to_sse_element(_chunk) - await asyncio.sleep(0) - except asyncio.CancelledError: - self.running_continuous_batching_manager.cancel_request(_request_id) - logger.warning(f"Request {_request_id} was cancelled.") - - def cancellation_wrapper_buffer(_request_id): - # Enables cancellation in an async context - try: - return buffer_chat_completion(_request_id) - except asyncio.CancelledError: - self.running_continuous_batching_manager.cancel_request(_request_id) - logger.warning(f"Request {_request_id} was cancelled.") - - request_id = self.running_continuous_batching_manager.add_request( - inputs, request_id=request_id, max_new_tokens=generation_config.max_new_tokens, streaming=req.get("stream") - ) - - if req.get("stream"): - return StreamingResponse(cancellation_wrapper_stream(request_id), media_type="text/event-stream") - else: - chunk = cancellation_wrapper_buffer(request_id) - json_chunk = chunk.model_dump(exclude_none=True) - return JSONResponse(json_chunk, media_type="application/json") - - @staticmethod - def get_model_modality(model: "PreTrainedModel", processor=None) -> Modality: - if processor is not None: - if isinstance(processor, PreTrainedTokenizerBase): - return Modality.LLM - - from transformers.models.auto.modeling_auto import ( - MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, - MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES, - ) - - model_classname = model.__class__.__name__ - if model_classname in MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.values(): - modality = Modality.VLM - elif model_classname in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values(): - modality = Modality.LLM - else: - raise ValueError(f"Unknown modality: {model_classname}") - - return modality - - @staticmethod - def get_processor_inputs_from_inbound_messages(messages, modality: Modality): - processor_inputs = [] - - for message in messages: - parsed_message = {"role": message["role"], "content": []} - - if modality == Modality.LLM: - # Input: `content` is a string or 
a list of dictionaries with a "text" key. - # Output: `content` is a string. - if isinstance(message["content"], str): - parsed_content = message["content"] - elif isinstance(message["content"], list): - parsed_content = [] - for content in message["content"]: - if content["type"] == "text": - parsed_content.append(content["text"]) - parsed_content = " ".join(parsed_content) - parsed_message["content"] = parsed_content - - elif modality == Modality.VLM: - # Input: `content` is a string or a list of dictionaries with a "type" key (possible types: "text", - # "image_url"). - # Output: `content` is a list of dictionaries with a "type" key - if isinstance(message["content"], str): - parsed_message["content"].append({"type": "text", "text": message["content"]}) - else: - for content in message["content"]: - if content["type"] == "text": - parsed_message["content"].append(content) - elif content["type"] == "image_url": - if "base64" in content["image_url"]["url"]: - image_data = re.sub("^data:image/.+;base64,", "", content["image_url"]["url"]) - image = Image.open(BytesIO(base64.b64decode(image_data))) - - file = tempfile.NamedTemporaryFile(suffix=".png", delete=False) - url = file.name - - image.save(file.name) - else: - url = content["image_url"]["url"] - - parsed_message["content"].append({"type": "image", "url": url}) - processor_inputs.append(parsed_message) - return processor_inputs - - def generate_chat_completion(self, req: dict) -> "StreamingResponse | JSONResponse": - """ - Generates an OpenAI Chat Completion using `generate`. - - Args: - req (`dict`): The request to generate an OpenAI Chat Completion for. - - Returns: - `Generator[str, None, None]`: A generator that yields the OpenAI Chat Completion chunks. - """ - - # TODO: This should throw an error in case the specified model in the request is different to the forced model. - if self.force_model is not None: - req["model"] = self.force_model - - messages: Iterable[ChatCompletionMessageParam] = req["messages"] - - # HACK for tiny-agents: it sends a request after the assistant message (???). Let's assume we can't have a - # request whose last message is from the assistant. - if messages[-1]["role"] == "assistant": + def kill_server(self): + self._generation_state.shutdown() + self._model_manager.shutdown() + if not self._thread or not self._thread.is_alive(): return - - model_id_and_revision = self.process_model_name(req["model"]) - must_discard_cache = model_id_and_revision != self.last_model - - self.last_model = model_id_and_revision - model, processor = self.load_model_and_processor(model_id_and_revision) - - modality = self.get_model_modality(model, processor=processor) - processor_inputs = self.get_processor_inputs_from_inbound_messages(messages, modality) - - # ====== TOOL PREPROCESSING LOGIC ====== - tool_model_family = None - for supported_model_families in _MODELS_WITH_TOOL_SUPPORT: - if supported_model_families in model.config.architectures[0].lower(): - tool_model_family = supported_model_families - break - # TODO: trigger 2 constrained generations after the tool call start token is emitted: - # 1. force generation to pick from the tool names - # 2. 
force generation to pick from that tool's arguments - # ====== END OF TOOL PREPROCESSING LOGIC ====== - - inputs = processor.apply_chat_template( - processor_inputs, - add_generation_prompt=True, - tools=req.get("tools"), - return_tensors="pt", - return_dict=True, - tokenize=True, - ) - inputs = inputs.to(model.device) - request_id = req.get("request_id", "req_0") - - # Temporary hack for GPTOSS 1: don't filter special tokens - skip_special_tokens = True - if "gptoss" in model.config.architectures[0].lower(): - skip_special_tokens = False - - generation_streamer = TextIteratorStreamer( - processor, - skip_special_tokens=skip_special_tokens, - skip_prompt=True, - ) - - if self.is_continuation(req) and not must_discard_cache: - seq_len = self.last_kv_cache.get_seq_length() - if inputs["input_ids"].shape[-1] > seq_len: - last_kv_cache = self.last_kv_cache - else: - last_kv_cache = None - else: - seq_len = inputs["input_ids"].shape[-1] - last_kv_cache = None - - generation_config = create_generation_config_from_req( - req, - model_generation_config=model.generation_config, - ) - - generation_kwargs = { - **inputs, - "streamer": generation_streamer, - "generation_config": generation_config, - "return_dict_in_generate": True, - "past_key_values": last_kv_cache, - } - - def stream_chat_completion(streamer, _request_id): - # Temporary hack for GPTOS 2: filter out the CoT tokens. Full solution here implies defining new output - # classes and piping the reasoning trace into a new field - filter_cot = False - cot_trace_end = None - if "gptoss" in model.config.architectures[0].lower(): - filter_cot = True - cot_trace_end = "<|channel|>final<|message|>" - - # Thin wrapper to save the KV cache after generation - def generate_with_cache(**kwargs): - generate_output = model.generate(**kwargs) - self.last_kv_cache = generate_output.past_key_values - - thread = Thread(target=generate_with_cache, kwargs=generation_kwargs) - results = "" - - try: - thread.start() - tool_state = ToolState() - - # Emit the assistant role to start the stream. Other chunks won't have a role, as it is implicit - # they come from the assistant. 
- yield self.build_chat_completion_chunk(request_id, role="assistant", model=model_id_and_revision) - - result = "" - n_tokens_generated = 0 - - for result in streamer: - n_tokens_generated += 1 - - # Temporary hack for GPT-OSS 3: don't emit the final "<|return|>" - if "gptoss" in model.config.architectures[0].lower(): - result = result.removesuffix("<|return|>") - results += result - - # (related to temporary hack 2) - if filter_cot: - if cot_trace_end in results: # end of reasoning trace observed -> stop filtering - filter_cot = False - continue - else: - continue - - # ====== TOOL CALL LOGIC ====== - if tool_model_family is not None: - # Start of a tool call: reset state variables, set `inside_tool_call` - if result.strip() == _TOOL_CALL_TOKENS[tool_model_family]["start"]: - tool_state.inside_tool_call = True - continue - - # End of tool call: reset `inside_tool_call`, emit a `finish_reason` - if result.strip() == _TOOL_CALL_TOKENS[tool_model_family]["end"]: - tool_state.reset() - yield self.build_chat_completion_chunk( - request_id=_request_id, - role=None, - finish_reason="tool_calls", - model=model_id_and_revision, - ) - - continue - # Inside a tool call - if tool_state.inside_tool_call: - tool_state.buffer += result - - # First step: extract the tool name (may need several tokens, and we can't emit a delta - # until we have the full name) - if not tool_state.has_tool_name_defined: - tool_name = re.search(r"\"name\": \"(.*?)\"", tool_state.buffer) - if tool_name is None: - continue - else: - tool_name = tool_name.group(1) - tool_state.has_tool_name_defined = True - tool = ChoiceDeltaToolCall( - function=ChoiceDeltaToolCallFunction(name=tool_name), - index=0, - type="function", - id=_request_id + "_tool_call", # Only the first tool call delta has an id - ) - - # Second step: extract tool arguments. The tool arguments can be seen as a json string - # within the tool json string. We emit a delta for the arguments. - else: - # Empty text: skip - if result == "": - continue - # Until we see the `"arguments": {` in the buffer, we skip - # TODO: other models will likely need more elaborate processing here - if '"arguments": {' not in tool_state.buffer: - continue - - # Handle nesting. We want to exclude the last } from the emitted arguments (it's - # closing the outermost nesting level, outside the arguments block) - tool_state.arg_nesting_level += result.count("{") - tool_state.arg_nesting_level -= result.count("}") - if tool_state.arg_nesting_level < 0: - result = "".join(result.split("}")[:-2]) + "}" # e.g. "4}}\n" -> "4}" - - tool = ChoiceDeltaToolCall( - function=ChoiceDeltaToolCallFunction(arguments=result), - index=0, - type="function", - ) - - yield self.build_chat_completion_chunk( - request_id=_request_id, - role=None, - tool_calls=[tool], - model=model_id_and_revision, - ) - continue - # ====== END OF TOOL CALL LOGIC ====== - - # All non-tool related tokens are emitted as assistant messages. Empty text is skipped. - if result != "": - yield self.build_chat_completion_chunk( - _request_id, content=result, model=model_id_and_revision - ) - - generated_all_tokens = ( - generation_config.max_new_tokens is not None - and n_tokens_generated >= generation_config.max_new_tokens - ) - - # If the tokenizer has an eos_token, we can have a more robust check. 
- if hasattr(streamer.tokenizer, "eos_token"): - final_token_is_eos = result == streamer.tokenizer.eos_token - generated_all_tokens = generated_all_tokens and not final_token_is_eos - - reason = "length" if generated_all_tokens else "stop" - - yield self.build_chat_completion_chunk(_request_id, finish_reason=reason, model=model_id_and_revision) - - thread.join() - except Exception as e: - logger.error(str(e)) - yield f'data: {{"error": "{str(e)}"}}' - - finally: - thread.join() - - if req.get("stream"): - return StreamingResponse( - map(self.chunk_to_sse_element, stream_chat_completion(generation_streamer, request_id)), - media_type="text/event-stream", - ) - else: - content = [] - finish_reason = "stop" - - generator = stream_chat_completion(generation_streamer, request_id) - usage = None - - for chunk in generator: - choice = chunk.choices[0] - if getattr(choice.delta, "content", None): - content.append(choice.delta.content) - if choice.finish_reason: - finish_reason = choice.finish_reason - if getattr(chunk, "usage", None): - usage = chunk.usage - - chat_completion_result = ChatCompletion( - id=request_id, - created=int(time.time()), - object="chat.completion", - model=model_id_and_revision, - choices=[ - Choice( - # TODO check the index - index=0, - message=ChatCompletionMessage(content="".join(content), role="assistant"), - finish_reason=finish_reason, - ) - ], - # TODO implement function calling - usage=usage, - ) - - result = chat_completion_result.model_dump(exclude_none=True) - - return JSONResponse(result, media_type="application/json") - - def generate_response(self, req: dict) -> Generator[str, None, None]: - """ - Generates an OpenAI Response using `generate`. - - Args: - req (`dict`): The request to generate an OpenAI Response for. - - Returns: - `Generator[str, None, None]`: A generator that yields the OpenAI Response events. 
- """ - # TODO -- Implement non-streaming mode - model_id_and_revision = self.process_model_name(req["model"]) - must_discard_cache = model_id_and_revision != self.last_model - self.last_model = model_id_and_revision - model, processor = self.load_model_and_processor(model_id_and_revision) - - if isinstance(req["input"], str): - inputs = [{"role": "system", "content": req["instructions"]}] if "instructions" in req else [] - inputs.append({"role": "user", "content": req["input"]}) - elif isinstance(req["input"], list): - if "instructions" in req: - if req["input"][0]["role"] != "system": - inputs = [{"role": "system", "content": req["instructions"]}, *req["input"]] - else: - inputs = req["input"] - inputs[0]["content"] = req["instructions"] - else: - inputs = req["input"] - elif isinstance(req["input"], dict): - inputs = [{"role": "system", "content": req["instructions"]}] if "instructions" in req else [] - inputs.append(req["input"]) - else: - raise TypeError("inputs should be a list, dict, or str") - - inputs = processor.apply_chat_template(inputs, tokenize=True, add_generation_prompt=True, return_tensors="pt") - inputs = inputs.to(model.device) - request_id = req.get("previous_response_id", "req_0") - - # Temporary hack for GPT-OSS 1: don't filter special tokens - skip_special_tokens = True - if "gptoss" in model.config.architectures[0].lower(): - skip_special_tokens = False - - generation_streamer = TextIteratorStreamer( - processor, - skip_special_tokens=skip_special_tokens, - skip_prompt=True, - ) - generation_config = create_generation_config_from_req(req, model_generation_config=model.generation_config) - - last_kv_cache = None - if self.is_continuation(req) and not must_discard_cache: - seq_len = self.last_kv_cache.get_seq_length() - if inputs.shape[-1] > seq_len: - last_kv_cache = self.last_kv_cache - - generation_kwargs = { - "inputs": inputs, - "attention_mask": torch_ones_like(inputs), - "streamer": generation_streamer, - "generation_config": generation_config, - "return_dict_in_generate": True, - "past_key_values": last_kv_cache, - } - - def stream_response(streamer, _request_id): - # Temporary hack for GPT-OSS 2: filter out the CoT tokens. 
Full solution here implies defining new output - # classes and piping the reasoning trace into a new field - filter_cot = False - cot_trace_end = None - if "gptoss" in model.config.architectures[0].lower(): - filter_cot = True - cot_trace_end = "<|channel|>final<|message|>" - - # Thin wrapper to save the KV cache after generation - def generate_with_cache(**kwargs): - generate_output = model.generate(**kwargs) - self.last_kv_cache = generate_output.past_key_values - - thread = Thread(target=generate_with_cache, kwargs=generation_kwargs) - sequence_number = 0 - output_index = 0 - content_index = 0 - - try: - thread.start() - created_at = time.time() # the spec expects a unix timestamp in seconds - - # We start by acknowledging the request (the request has `status="queued"`), and then by moving it to - # in progress (`status="in_progress"`) - response_created = ResponseCreatedEvent( - type="response.created", - sequence_number=sequence_number, - response=Response( - id=f"resp_{request_id}", - created_at=created_at, - status="queued", - model=model_id_and_revision, - instructions=req.get("instructions"), - text={"format": {"type": "text"}}, - object="response", - tools=[], - output=[], - parallel_tool_calls=req.get("parallel_tool_calls", False), - tool_choice="auto", - metadata=req.get("metadata"), - ), - ) - sequence_number += 1 - yield self.chunk_to_sse_element(response_created) - - response_in_progress = ResponseInProgressEvent( - type="response.in_progress", - sequence_number=sequence_number, - response=Response( - id=f"resp_{request_id}", - created_at=created_at, - status="in_progress", - model=model_id_and_revision, - instructions=req.get("instructions"), - text={"format": {"type": "text"}}, - object="response", - tools=[], - output=[], - parallel_tool_calls=req.get("parallel_tool_calls", False), - tool_choice="auto", - metadata=req.get("metadata"), - ), - ) - sequence_number += 1 - yield self.chunk_to_sse_element(response_in_progress) - - # Start the output item. Emit the assistant role to start the stream. 
Other chunks won't have a role, - # as it is implicit - response_output_item_added = ResponseOutputItemAddedEvent( - type="response.output_item.added", - sequence_number=sequence_number, - output_index=output_index, - item=ResponseOutputMessage( - id=f"msg_{request_id}", type="message", status="in_progress", role="assistant", content=[] - ), - ) - sequence_number += 1 - yield self.chunk_to_sse_element(response_output_item_added) - - # Start the content part of the event - response_content_part_added = ResponseContentPartAddedEvent( - type="response.content_part.added", - item_id=f"msg_{request_id}", - sequence_number=sequence_number, - output_index=output_index, - content_index=content_index, - part=ResponseOutputText(type="output_text", text="", annotations=[]), - ) - sequence_number += 1 - yield self.chunk_to_sse_element(response_content_part_added) - - # Stream the actual generated text - results = "" - for result in streamer: - # Temporary hack for GPTOS 3: don't emit the final "<|return|>" - if "gptoss" in model.config.architectures[0].lower(): - result = result.removesuffix("<|return|>") - results += result - - # (related to temporary hack 2) - if filter_cot: - if cot_trace_end in results: # end of reasoning trace observed -> stop filtering - filter_cot = False - results = "" # reset the results -> results will now track the final response - continue - else: - response_output_text_delta = ResponseTextDeltaEvent( - type="response.output_text.delta", - item_id=f"msg_{request_id}", - sequence_number=sequence_number, - output_index=output_index, - content_index=content_index, - delta=result, - logprobs=[], - ) - sequence_number += 1 - yield self.chunk_to_sse_element(response_output_text_delta) - else: - # Normal path: emit token deltas when not filtering CoT - if result: - response_output_text_delta = ResponseTextDeltaEvent( - type="response.output_text.delta", - item_id=f"msg_{request_id}", - sequence_number=sequence_number, - output_index=output_index, - content_index=content_index, - delta=result, - logprobs=[], - ) - sequence_number += 1 - yield self.chunk_to_sse_element(response_output_text_delta) - - # Signal the end of the text generation - response_output_text_done = ResponseTextDoneEvent( - type="response.output_text.done", - item_id=f"msg_{request_id}", - sequence_number=sequence_number, - output_index=output_index, - content_index=0, - text=results, - logprobs=[], - ) - sequence_number += 1 - yield self.chunk_to_sse_element(response_output_text_done) - - # Complete the content part - response_content_part_done = ResponseContentPartDoneEvent( - type="response.content_part.done", - item_id=f"msg_{request_id}", - sequence_number=sequence_number, - output_index=output_index, - content_index=content_index, - part=ResponseOutputText(type="output_text", text=response_output_text_done.text, annotations=[]), - ) - sequence_number += 1 - content_index += 1 - yield self.chunk_to_sse_element(response_content_part_done) - - # Complete the output item - response_output_item_done = ResponseOutputItemDoneEvent( - type="response.output_item.done", - sequence_number=sequence_number, - output_index=output_index, - item=ResponseOutputMessage( - id=f"msg_{request_id}", - type="message", - status="completed", - role="assistant", - content=[response_content_part_done.part], - annotations=[], - ), - ) - sequence_number += 1 - output_index += 1 - yield self.chunk_to_sse_element(response_output_item_done) - - # Finally, Complete the event - response_completed = ResponseCompletedEvent( - 
type="response.completed", - sequence_number=sequence_number, - response=Response( - id=f"resp_{request_id}", - created_at=created_at, - status="completed", - model=model_id_and_revision, - instructions=req.get("instructions"), - text={"format": {"type": "text"}}, - output=[response_output_item_done.item], - object="response", - tools=[], - parallel_tool_calls=req.get("parallel_tool_calls", False), - tool_choice="auto", - metadata=req.get("metadata"), - ), - ) - sequence_number += 1 - yield self.chunk_to_sse_element(response_completed) - - thread.join() - except Exception as e: - logger.error(f"Exception in response generation: {str(e)}") - error_event = ResponseErrorEvent( - type="error", - sequence_number=sequence_number, - message=str(e), - ) - sequence_number += 1 - yield self.chunk_to_sse_element(error_event) - - response_failed = ResponseFailedEvent( - type="response.failed", - sequence_number=sequence_number, - response=Response( - id=f"resp_{request_id}", - created_at=created_at, - status="failed", - model=model_id_and_revision, - instructions=req.get("instructions"), - text={"format": {"type": "text"}}, - output=[], - object="response", - tools=[], - parallel_tool_calls=False, - tool_choice="auto", - metadata=req.get("metadata"), - error=ResponseError( - code="server_error", - message=str(e), - ), - ), - ) - sequence_number += 1 - yield self.chunk_to_sse_element(response_failed) - - finally: - thread.join() - - return stream_response(generation_streamer, request_id) - - def generate_response_non_streaming(self, req: dict) -> dict: - """ - Generates an OpenAI Response in non-streaming mode (single JSON payload). - - Args: - req (`dict`): The request to generate an OpenAI Response for. - - Returns: - `dict`: The OpenAI `Response` serialized as a dict. 
- """ - model_id_and_revision = self.process_model_name(req["model"]) - must_discard_cache = model_id_and_revision != self.last_model - self.last_model = model_id_and_revision - model, processor = self.load_model_and_processor(model_id_and_revision) - - if isinstance(req["input"], str): - inputs = [{"role": "system", "content": req["instructions"]}] if "instructions" in req else [] - inputs.append({"role": "user", "content": req["input"]}) - elif isinstance(req["input"], list): - if "instructions" in req: - if req["input"][0]["role"] != "system": - inputs = [{"role": "system", "content": req["instructions"]}, *req["input"]] - else: - inputs = req["input"] - inputs[0]["content"] = req["instructions"] - else: - inputs = req["input"] - elif isinstance(req["input"], dict): - inputs = [{"role": "system", "content": req["instructions"]}] if "instructions" in req else [] - inputs.append(req["input"]) - else: - raise ValueError("inputs should be a list, dict, or str") - - inputs = processor.apply_chat_template( - inputs, add_generation_prompt=True, return_tensors="pt", return_dict=True - )["input_ids"] - inputs = inputs.to(model.device) - request_id = req.get("previous_response_id", "req_0") - - # Temporary hack for GPTOSS 1: don't filter special tokens - skip_special_tokens = True - if "gptoss" in model.config.architectures[0].lower(): - skip_special_tokens = False - - generation_config = create_generation_config_from_req(req, model_generation_config=model.generation_config) - - last_kv_cache = None - if self.is_continuation(req) and not must_discard_cache: - seq_len = self.last_kv_cache.get_seq_length() - if inputs.shape[-1] > seq_len: - last_kv_cache = self.last_kv_cache - - generate_output = model.generate( - inputs=inputs, - attention_mask=torch_ones_like(inputs), - generation_config=generation_config, - return_dict_in_generate=True, - past_key_values=last_kv_cache, - ) - # save KV cache - self.last_kv_cache = generate_output.past_key_values - - # Decode full text - full_text = processor.batch_decode(generate_output.sequences, skip_special_tokens=skip_special_tokens)[0] - - created_at = time.time() - response_output_item = ResponseOutputMessage( - id=f"msg_{request_id}", - type="message", - status="completed", - role="assistant", - content=[ResponseOutputText(type="output_text", text=full_text, annotations=[])], - annotations=[], - ) - response_completed = Response( - id=f"resp_{request_id}", - created_at=created_at, - status="completed", - model=model_id_and_revision, - instructions=req.get("instructions"), - text={"format": {"type": "text"}}, - output=[response_output_item], - object="response", - tools=[], - parallel_tool_calls=req.get("parallel_tool_calls", False), - tool_choice="auto", - metadata=req.get("metadata"), - ) - return response_completed.model_dump(exclude_none=True) - - def generate_transcription(self, req: dict) -> Generator[str, None, None]: - """ - Generates an OpenAI Transcription using the audio file. - - Args: - req (`dict`): The request containing the audio file and model information. - - Returns: - `Generator[str, None, None]`: A generator that yields the transcription result. - """ - # TODO: implement streaming transcription (currently, it's not streaming) - if not is_librosa_available(): - raise ImportError( - "Missing librosa dependency for audio transcription. 
Please install with `pip install librosa`" - ) - model_id_and_revision = self.process_model_name(req["model"]) - audio_model, audio_processor = self.load_audio_model_and_processor(model_id_and_revision) - - generation_streamer = TextIteratorStreamer( - audio_processor.tokenizer, skip_special_tokens=True, skip_prompt=True - ) - generation_config = create_generation_config_from_req( - req, model_generation_config=audio_model.generation_config - ) - - # Read the binary audio file using librosa - model_sampling_rate = audio_processor.feature_extractor.sampling_rate - audio_bytes = io.BytesIO(req["file"]) - audio_array, _ = librosa.load(audio_bytes, sr=model_sampling_rate, mono=True) - audio_inputs = audio_processor(audio_array, sampling_rate=model_sampling_rate, return_tensors="pt").to( - audio_model.device - ) - audio_inputs["input_features"] = audio_inputs["input_features"].to(audio_model.dtype) - - generation_kwargs = { - "streamer": generation_streamer, - "generation_config": generation_config, - "return_dict_in_generate": True, - } - - def _generate_transcription(): - generated_ids = audio_model.generate(**audio_inputs, **generation_kwargs) - transcription_text = audio_processor.batch_decode(generated_ids.sequences, skip_special_tokens=True)[0] - transcription = Transcription(text=transcription_text) - yield f"{transcription.model_dump_json(exclude_none=True)}" - - return _generate_transcription() - - def is_continuation(self, req: dict) -> bool: - """ - Determines whether the current request is a continuation of the last request. In other words, if it is the - same chat session. - - Args: - req (`dict`): The request to check. - - Returns: - `True` if the request is a continuation of the last request, `False` otherwise. - """ - messages = req.get("messages") or req.get("input") # ChatCompletion and Response have different fields - req_continues_last_messages = True - - # No cached messages: this is a new request - if self.last_messages is None: - req_continues_last_messages = False - # The new request has no new rounds of conversation: this is a new request - elif len(self.last_messages) >= len(messages): - req_continues_last_messages = False - # Otherwise, check that the last messages are a subset of the new request - else: - for i in range(len(self.last_messages)): - if self.last_messages[i] != messages[i]: - req_continues_last_messages = False - break - - self.last_messages = messages - return req_continues_last_messages - - def get_quantization_config(self) -> BitsAndBytesConfig | None: - """ - Returns the quantization config for the given CLI arguments. - - Returns: - `Optional[BitsAndBytesConfig]`: The quantization config. - """ - if self.quantization == "bnb-4bit": - quantization_config = BitsAndBytesConfig( - load_in_4bit=True, - bnb_4bit_quant_type="nf4", - bnb_4bit_use_double_quant=True, - ) - elif self.quantization == "bnb-8bit": - quantization_config = BitsAndBytesConfig(load_in_8bit=True) - else: - quantization_config = None - - if quantization_config is not None: - logger.warning(f"Quantization applied with the following config: {quantization_config}") - - return quantization_config - - def process_model_name(self, model_id: str) -> str: - """ - Applies the `force_model` CLI argument and canonicalizes the model name to the format "model_id@revision". - If the model_id DOESN'T contain an @, it defaults to "model_id@main". - - Args: - model_id (`str`): The model ID. 
- - Returns: - `str`: The canonicalized model name to be used - """ - if self.force_model is not None: - model_id = self.force_model - if "@" in model_id: - return model_id - return f"{model_id}@main" - - def _load_model_and_data_processor( - self, model_id_and_revision: str, progress_callback: Callable[[dict], None] | None = None - ): - """ - Generic method to load a model and a data processor from a model ID and revision, making use of the serve CLI - arguments. - - Args: - model_id_and_revision (`str`): - The model ID and revision to load. - model_cls (`type[PreTrainedModel]`): - The model class to load. - - Returns: - `tuple[PreTrainedModel, Union[ProcessorMixin, PreTrainedTokenizerFast]]`: The loaded model and - data processor (tokenizer, audio processor, etc.). - """ - import torch - - from transformers import AutoConfig, AutoProcessor - - callback = progress_callback - mid = model_id_and_revision - - tqdm_class = set_tqdm_class(callback, mid) if callback is not None else None - - def emit_progress(stage: str): - if progress_callback is None: - return - progress_callback({"status": "loading", "model": model_id_and_revision, "stage": stage}) - - logger.warning(f"Loading {model_id_and_revision}") - emit_progress("processor") - - if "@" in model_id_and_revision: - model_id, revision = model_id_and_revision.split("@", 1) - else: - model_id, revision = model_id_and_revision, "main" - - try: - data_processor = AutoProcessor.from_pretrained( - model_id, - revision=revision, - trust_remote_code=self.trust_remote_code, - ) - except OSError: - try: - data_processor = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - trust_remote_code=self.trust_remote_code, - ) - except OSError: - raise OSError("Failed to load processor with `AutoProcessor` and `AutoTokenizer`.") - - # processor done — move on to config - dtype = self.dtype if self.dtype in ["auto", None] else getattr(torch, self.dtype) - quantization_config = self.get_quantization_config() - - model_kwargs = { - "revision": revision, - "attn_implementation": self.attn_implementation, - "dtype": dtype, - "device_map": self.device, - "trust_remote_code": self.trust_remote_code, - "quantization_config": quantization_config, - } - - emit_progress("config") - config = AutoConfig.from_pretrained(model_id, **model_kwargs) - architecture = getattr(transformers, config.architectures[0]) - # weights stage events are emitted by _WeightsTqdm (and download by _DownloadAggregator) - model = architecture.from_pretrained(model_id, tqdm_class=tqdm_class, **model_kwargs) - - has_default_max_length = ( - model.generation_config.max_new_tokens is None and model.generation_config.max_length == 20 - ) - has_short_max_new_tokens = ( - model.generation_config.max_new_tokens is not None and model.generation_config.max_new_tokens < 1024 - ) - if has_default_max_length or has_short_max_new_tokens: - model.generation_config.max_new_tokens = 1024 - - return model, data_processor - - def load_model_and_processor( - self, model_id_and_revision: str, progress_callback: Callable[[dict], None] | None = None - ) -> tuple["PreTrainedModel", "PreTrainedTokenizerFast"]: - """ - Loads the text model and processor from the given model ID and revision into the ServeCommand instance. - - Args: - model_id_and_revision (`str`): - The model ID and revision to load. - - Returns: - `tuple[PreTrainedModel, PreTrainedTokenizerFast]`: The loaded text model and processor. 
- """ - with self.model_locks_guard: - lock = self.model_locks.setdefault(model_id_and_revision, threading.Lock()) - - with lock: - if ( - model_id_and_revision not in self.loaded_models - or self.loaded_models[model_id_and_revision].is_deleted() - ): - model, processor = self._load_model_and_data_processor( - model_id_and_revision, progress_callback=progress_callback - ) - self.loaded_models[model_id_and_revision] = TimedModel( - model, - timeout_seconds=self.model_timeout, - processor=processor, - ) - if progress_callback is not None: - progress_callback({"status": "ready", "model": model_id_and_revision, "cached": False}) - else: - self.loaded_models[model_id_and_revision].reset_timer() - model = self.loaded_models[model_id_and_revision].model - processor = self.loaded_models[model_id_and_revision].processor - if progress_callback is not None: - progress_callback({"status": "ready", "model": model_id_and_revision, "cached": True}) - - return model, processor - - def load_audio_model_and_processor(self, model_id_and_revision: str) -> tuple["PreTrainedModel", "ProcessorMixin"]: - """ - Loads the audio model and processor from the given model ID and revision into the ServeCommand instance. - - Args: - model_id_and_revision (`str`): - The model ID and revision to load. - - Returns: - `tuple[PreTrainedModel, ProcessorMixin]`: The loaded audio model and processor. - """ - with self.model_locks_guard: - lock = self.model_locks.setdefault(model_id_and_revision, threading.Lock()) - - with lock: - if ( - model_id_and_revision not in self.loaded_models - or self.loaded_models[model_id_and_revision].is_deleted() - ): - logger.warning(f"Loading model into cache: {model_id_and_revision}") - audio_model, audio_processor = self._load_model_and_data_processor(model_id_and_revision) - self.loaded_models[model_id_and_revision] = TimedModel( - audio_model, - timeout_seconds=self.model_timeout, - processor=audio_processor, - ) - else: - self.loaded_models[model_id_and_revision].reset_timer() - audio_model = self.loaded_models[model_id_and_revision].model - audio_processor = self.loaded_models[model_id_and_revision].processor - - return audio_model, audio_processor + self.server.should_exit = True + self._thread.join(timeout=2) -# set docstring separately to make it look nice (Typer doesn't play well with the class command) Serve.__doc__ = """ Run a FastAPI server to serve models on-demand with an OpenAI compatible API. - Models will be loaded and unloaded automatically based on usage and a timeout. \b -The server will expose the following endpoints: - - POST /v1/chat/completions: Generates chat completions. - - POST /v1/responses: Generates responses. - - POST /v1/audio/transcriptions: Generates transcriptions from audio. - - GET /v1/models: Lists available models for 3rd party tools. +Endpoints: + POST /v1/chat/completions — Chat completions (streaming + non-streaming). + GET /v1/models — Lists available models. + GET /health — Health check. -Requires FastAPI and Uvicorn to be installed. +Requires FastAPI and Uvicorn: pip install transformers[serving] """ - -if __name__ == "__main__": - serve = Serve() diff --git a/src/transformers/cli/serving/__init__.py b/src/transformers/cli/serving/__init__.py new file mode 100644 index 000000000000..118d3a9c2012 --- /dev/null +++ b/src/transformers/cli/serving/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .model_manager import ModelManager +from .server import build_server +from .utils import Modality diff --git a/src/transformers/cli/serving/chat_completion.py b/src/transformers/cli/serving/chat_completion.py new file mode 100644 index 000000000000..520eab7abbf4 --- /dev/null +++ b/src/transformers/cli/serving/chat_completion.py @@ -0,0 +1,410 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Handler for the /v1/chat/completions endpoint. + +Supports streaming (SSE via DirectStreamer) and non-streaming (JSON) responses. +""" + +import asyncio +import time +from collections.abc import AsyncGenerator +from typing import TYPE_CHECKING + +from ...utils import logging +from ...utils.import_utils import is_serve_available + + +if is_serve_available(): + from fastapi.responses import JSONResponse, StreamingResponse + from openai.types.chat import ChatCompletion, ChatCompletionMessage, ChatCompletionMessageToolCall + from openai.types.chat.chat_completion import Choice + from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, ChoiceDelta, ChoiceDeltaToolCall + from openai.types.chat.chat_completion_chunk import Choice as ChoiceChunk + from openai.types.chat.completion_create_params import CompletionCreateParamsStreaming + from openai.types.completion_usage import CompletionUsage + +from .utils import ( + BaseGenerateManager, + BaseHandler, + ToolCallParser, + _StreamError, + detect_tool_format, +) + + +if TYPE_CHECKING: + from transformers import GenerationConfig, PreTrainedModel, PreTrainedTokenizerFast, ProcessorMixin + + +class TransformersCompletionCreateParamsStreaming(CompletionCreateParamsStreaming, total=False): + generation_config: str + seed: int + + +# Fields accepted by the OpenAI schema but not yet supported. +# Receiving these raises an error to avoid silent misbehaviour. +# NOTE: "stop" is NOT in this set — we map it to stop_strings. +UNUSED_CHAT_COMPLETION_FIELDS = { + "audio", + "function_call", + "functions", + "logprobs", + "max_completion_tokens", + "metadata", + "modalities", + "n", + "parallel_tool_calls", + "prediction", + "presence_penalty", + "reasoning_effort", + "response_format", + "service_tier", + "store", + "stream_options", + "tool_choice", + "top_logprobs", + "user", + "web_search_options", +} + + +logger = logging.get_logger(__name__) + + +class ChatCompletionHandler(BaseHandler): + """Handler for the `/v1/chat/completions` endpoint. 
+ + Supports both streaming (SSE) and non-streaming (JSON) responses. + """ + + _valid_params_class = TransformersCompletionCreateParamsStreaming + _unused_fields = UNUSED_CHAT_COMPLETION_FIELDS + + async def handle_request(self, body: dict, request_id: str) -> StreamingResponse | JSONResponse: + """Validate the request, load the model, and dispatch to streaming or non-streaming. + + Args: + body (`dict`): The raw JSON request body (OpenAI chat completion format). + request_id (`str`): Unique request identifier (from header or auto-generated). + + Returns: + `StreamingResponse | JSONResponse`: SSE stream or JSON depending on ``body["stream"]``. + """ + self._validate_request(body) + + model_id, model, processor = self._resolve_model(body) + modality = self.model_manager.get_model_modality(model, processor=processor) + use_cb = self.generation_state.use_continuous_batching(model, modality) + logger.warning(f"[Request received] Model: {model_id}, CB: {use_cb}") + gen_manager = self.generation_state.get_manager(model_id, use_cb=use_cb) + processor_inputs = self.get_processor_inputs_from_messages(body["messages"], modality) + + inputs = processor.apply_chat_template( + processor_inputs, + add_generation_prompt=True, + tools=body.get("tools"), + return_tensors=None if use_cb else "pt", + return_dict=True, + tokenize=True, + ) + if not use_cb: + inputs = inputs.to(model.device) + + gen_config = self._build_generation_config(body, model.generation_config, use_cb=use_cb) + # TODO: remove when CB supports per-request generation config + if use_cb: + gen_manager.init_cb(model, gen_config) + + # Detect tool support for the loaded model + # TODO: after tool_call start token, use constrained generation to: + # 1. force generation to pick from the available tool names + # 2. 
force generation to produce valid JSON matching the tool's parameter schema + tool_format = detect_tool_format(model) if body.get("tools") else None + + streaming = body.get("stream") + if streaming: + return self._streaming( + request_id, + model, + processor, + model_id, + inputs, + gen_config, + gen_manager=gen_manager, + tool_format=tool_format, + ) + else: + return await self._non_streaming( + request_id, + model, + processor, + model_id, + inputs, + gen_config, + gen_manager=gen_manager, + tool_format=tool_format, + ) + + # ----- streaming ----- + + def _streaming( + self, + request_id: str, + model: "PreTrainedModel", + processor: "ProcessorMixin | PreTrainedTokenizerFast", + model_id: str, + inputs: dict, + gen_config: "GenerationConfig", + gen_manager: BaseGenerateManager, + tool_format: dict | None = None, + ) -> StreamingResponse: + """Stream tokens as SSE via DirectStreamer.""" + queue, streamer = gen_manager.generate_streaming(model, processor, inputs, gen_config, request_id=request_id) + input_ids = inputs["input_ids"] + # CB returns plain lists, regular path returns tensors + input_len = len(input_ids) if isinstance(input_ids, list) else input_ids.shape[-1] + parser = ToolCallParser(tool_format) if tool_format else None + + async def sse_gen() -> AsyncGenerator[str, None]: + has_tool_calls = False + try: + yield self._build_chunk_sse(request_id, role="assistant", model=model_id) + + done = False + while not done: + text = await queue.get() + batch = [text] + try: + while True: + batch.append(queue.get_nowait()) + except asyncio.QueueEmpty: + pass + + sse_parts: list[str] = [] + for text in batch: + if text is None: + done = True + break + if isinstance(text, _StreamError): + sse_parts.append(f'data: {{"error": "{text.msg}"}}\n\n') + yield "".join(sse_parts) + return + + # Tool call parsing: None = normal text, CONSUMED = buffering, else = tool call dict + chunk_kwargs = {"content": text} + if parser is not None and (result := parser.feed(text)) is not None: + if result is ToolCallParser.CONSUMED: + continue + has_tool_calls = True + chunk_kwargs = { + "tool_calls": [ + ChoiceDeltaToolCall( + index=0, + type="function", + id=f"{request_id}_tool_call", + function={"name": result["name"], "arguments": result["arguments"]}, + ) + ] + } + + sse_parts.append(self._build_chunk_sse(request_id, model=model_id, **chunk_kwargs)) + + if sse_parts: + yield "".join(sse_parts) + + hit_max = gen_config.max_new_tokens is not None and streamer.total_tokens >= gen_config.max_new_tokens + if has_tool_calls: + finish_reason = "tool_calls" + elif hit_max: + finish_reason = "length" + else: + finish_reason = "stop" + usage = CompletionUsage( + prompt_tokens=input_len, + completion_tokens=streamer.total_tokens, + total_tokens=input_len + streamer.total_tokens, + ) + yield self._build_chunk_sse( + request_id, + finish_reason=finish_reason, + model=model_id, + usage=usage, + ) + except (GeneratorExit, asyncio.CancelledError): + # Client disconnected — abort generation to free GPU. + # Re-raise is mandatory: Python raises RuntimeError if GeneratorExit is swallowed. 
+ streamer.cancel() + raise + + return StreamingResponse(sse_gen(), media_type="text/event-stream") + + # ----- non-streaming ----- + + async def _non_streaming( + self, + request_id: str, + model: "PreTrainedModel", + processor: "ProcessorMixin | PreTrainedTokenizerFast", + model_id: str, + inputs: dict, + gen_config: "GenerationConfig", + gen_manager: BaseGenerateManager, + tool_format: dict | None = None, + ) -> JSONResponse: + """Run generation and return a JSONResponse.""" + content, input_len, generated_ids = await gen_manager.generate_non_streaming( + model, processor, inputs, gen_config, request_id=request_id + ) + + hit_max = gen_config.max_new_tokens is not None and len(generated_ids) >= gen_config.max_new_tokens + completion_tokens = len(generated_ids) + usage = CompletionUsage( + prompt_tokens=input_len, + completion_tokens=completion_tokens, + total_tokens=input_len + completion_tokens, + ) + + # Parse tool calls from the generated text + tool_calls = None + if tool_format is not None: + parsed = ToolCallParser.parse(content, tool_format) + if parsed is not None: + tool_calls = [ + ChatCompletionMessageToolCall( + id=f"{request_id}_tool_call", + type="function", + function={"name": tc["name"], "arguments": tc["arguments"]}, + ) + for tc in parsed + ] + + if tool_calls is not None: + finish_reason = "tool_calls" + elif hit_max: + finish_reason = "length" + else: + finish_reason = "stop" + + return JSONResponse( + self._build_completion( + request_id, + content, + model_id, + finish_reason=finish_reason, + usage=usage, + tool_calls=tool_calls, + ), + media_type="application/json", + ) + + # ----- helpers ----- + + def _build_generation_config(self, body: dict, model_generation_config: "GenerationConfig", use_cb: bool = False): + """Apply Chat Completions params (``max_tokens``, ``frequency_penalty``, ``logit_bias``, + ``stop``) on top of the base generation config.""" + generation_config = super()._build_generation_config(body, model_generation_config, use_cb=use_cb) + + if body.get("max_tokens") is not None: + generation_config.max_new_tokens = int(body["max_tokens"]) + if body.get("frequency_penalty") is not None: + generation_config.repetition_penalty = 1.0 + float(body["frequency_penalty"]) + if body.get("logit_bias") is not None: + generation_config.sequence_bias = {(int(k),): v for k, v in body["logit_bias"].items()} + if body.get("stop") is not None: + generation_config.stop_strings = body["stop"] + + return generation_config + + # ----- response builders ----- + + def _build_completion( + self, + request_id: str, + content: str, + model_id: str, + finish_reason: str, + usage: CompletionUsage | None = None, + tool_calls: list[dict] | None = None, + ) -> dict: + """Build a non-streaming ChatCompletion response dict. + + Args: + request_id (`str`): Unique request identifier. + content (`str`): The generated text. + model_id (`str`): Model ID to include in the response. + finish_reason (`str`): Why generation stopped (``"stop"``, ``"length"``, ``"tool_calls"``). + usage (`CompletionUsage`, *optional*): Token usage statistics. + tool_calls (`list[dict]`, *optional*): Parsed tool calls, if any. + + Returns: + `dict`: Serialized ``ChatCompletion`` ready for JSON response. 
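+
+        The payload mirrors the OpenAI schema; as an illustrative sketch it looks like
+        ``{"id": ..., "object": "chat.completion", "choices": [{"index": 0, "message": {...},
+        "finish_reason": "stop"}], "usage": {...}}``. ``model_dump(exclude_none=True)`` drops
+        fields that were left unset (for example ``tool_calls`` when no tool was called).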
+ """ + message = ChatCompletionMessage(content=content, role="assistant", tool_calls=tool_calls) + result = ChatCompletion( + id=request_id, + created=int(time.time()), + object="chat.completion", + model=model_id, + choices=[ + Choice( + index=0, + message=message, + finish_reason=finish_reason, + ) + ], + usage=usage, + ) + return result.model_dump(exclude_none=True) + + def _build_chunk_sse( + self, + request_id: str = "", + content: str | None = None, + model: str | None = None, + role: str | None = None, + finish_reason: str | None = None, + tool_calls: list | None = None, + usage: CompletionUsage | None = None, + ) -> str: + """Build a streaming ``ChatCompletionChunk`` and format it as an SSE ``data:`` line. + + Args: + request_id (`str`): Unique request identifier. + content (`str`, *optional*): Text content delta. + model (`str`, *optional*): Model ID. + role (`str`, *optional*): Role (only sent in the first chunk). + finish_reason (`str`, *optional*): Set on the final chunk. + tool_calls (`list`, *optional*): Tool call deltas. + usage (`CompletionUsage`, *optional*): Token usage (sent with the final chunk). + + Returns: + `str`: A formatted SSE event string. + """ + chunk = ChatCompletionChunk( + id=request_id, + created=int(time.time()), + model=model, + choices=[ + ChoiceChunk( + delta=ChoiceDelta(content=content, role=role, tool_calls=tool_calls), + index=0, + finish_reason=finish_reason, + ) + ], + usage=usage, + system_fingerprint="", + object="chat.completion.chunk", + ) + return self.chunk_to_sse(chunk) diff --git a/src/transformers/cli/serving/model_manager.py b/src/transformers/cli/serving/model_manager.py new file mode 100644 index 000000000000..39b477302c58 --- /dev/null +++ b/src/transformers/cli/serving/model_manager.py @@ -0,0 +1,457 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Model loading, caching, and lifecycle management. +""" + +import asyncio +import gc +import json +import threading +from collections.abc import Callable +from functools import lru_cache +from typing import TYPE_CHECKING + +from huggingface_hub import scan_cache_dir +from tqdm import tqdm + +import transformers +from transformers import BitsAndBytesConfig, PreTrainedTokenizerBase + +from ...utils import logging +from .utils import Modality, make_progress_tqdm_class, reset_torch_cache + + +if TYPE_CHECKING: + from transformers import PreTrainedModel, PreTrainedTokenizerFast, ProcessorMixin + + +logger = logging.get_logger(__name__) + + +class TimedModel: + """Wraps a model + processor and auto-unloads them after a period of inactivity. + + Args: + model: The loaded model. + timeout_seconds: Seconds of inactivity before auto-unload. Use -1 to disable. + processor: The associated processor or tokenizer. + on_unload: Optional callback invoked after the model is unloaded from memory. 
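+
+    The inactivity timer is started in ``__init__`` and restarted by ``reset_timer()`` on every
+    request; when it fires with a positive timeout, ``delete_model()`` releases the model and
+    processor, forces garbage collection, and invokes ``on_unload``.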
+ """ + + def __init__( + self, + model: "PreTrainedModel", + timeout_seconds: int, + processor: "ProcessorMixin | PreTrainedTokenizerFast | None" = None, + on_unload: "Callable | None" = None, + ): + self.model = model + self._name_or_path = str(model.name_or_path) + self.processor = processor + self.timeout_seconds = timeout_seconds + self._on_unload = on_unload + self._timer = threading.Timer(self.timeout_seconds, self._timeout_reached) + self._timer.start() + + def reset_timer(self) -> None: + """Reset the inactivity timer (called on each request).""" + self._timer.cancel() + self._timer = threading.Timer(self.timeout_seconds, self._timeout_reached) + self._timer.start() + + def delete_model(self) -> None: + """Delete the model and processor, free GPU memory.""" + if hasattr(self, "model") and self.model is not None: + del self.model + del self.processor + self.model = None + self.processor = None + gc.collect() + reset_torch_cache() + self._timer.cancel() + if self._on_unload is not None: + self._on_unload() + + def _timeout_reached(self) -> None: + if self.timeout_seconds > 0: + self.delete_model() + logger.info(f"{self._name_or_path} was removed from memory after {self.timeout_seconds}s of inactivity") + + +class ModelManager: + """Loads, caches, and manages the lifecycle of models. + + Handlers receive a reference to this and call `load_model_and_processor()` + to get a model ready for inference. + + Args: + device: Device to place models on (e.g. "auto", "cuda", "cpu"). + dtype: Torch dtype override. "auto" derives from model weights. + trust_remote_code: Whether to trust remote code when loading models. + attn_implementation: Attention implementation override (e.g. "flash_attention_2"). + quantization: Quantization method ("bnb-4bit" or "bnb-8bit"). + model_timeout: Seconds before an idle model is unloaded. -1 disables. + force_model: If set, preload this model at init time. + """ + + def __init__( + self, + device: str = "auto", + dtype: str | None = "auto", + trust_remote_code: bool = False, + attn_implementation: str | None = None, + quantization: str | None = None, + model_timeout: int = 300, + force_model: str | None = None, + ): + self.loaded_models: dict[str, TimedModel] = {} + + # Thread-safety for concurrent load_model_and_processor calls + self._model_locks: dict[str, threading.Lock] = {} + self._model_locks_guard = threading.Lock() + + # Tracks in-flight loads for fan-out to multiple SSE subscribers (used by load_model_streaming) + self._loading_subscribers: dict[str, list[asyncio.Queue[str | None]]] = {} + self._loading_tasks: dict[str, asyncio.Task] = {} + + # Convert numeric device strings (e.g. 
"0") to int so device_map works correctly + self.device = int(device) if device.isdigit() else device + self.dtype = self._resolve_dtype(dtype) + self.trust_remote_code = trust_remote_code + self.attn_implementation = attn_implementation + self.quantization = quantization + self.model_timeout = model_timeout + self.force_model = force_model + + self._validate_args() + + # Preloaded models should never be auto-unloaded + if force_model is not None: + self.model_timeout = -1 + + # Preload the forced model after all state is initialized + if force_model is not None: + self.load_model_and_processor(self.process_model_name(force_model)) + + @staticmethod + def _resolve_dtype(dtype: str | None): + import torch + + if dtype in ("auto", None): + return dtype + resolved = getattr(torch, dtype, None) + if not isinstance(resolved, torch.dtype): + raise ValueError( + f"Unsupported dtype: '{dtype}'. Must be 'auto' or a valid torch dtype (e.g. 'float16', 'bfloat16')." + ) + return resolved + + def _validate_args(self): + if self.quantization is not None and self.quantization not in ("bnb-4bit", "bnb-8bit"): + raise ValueError( + f"Unsupported quantization method: '{self.quantization}'. Must be 'bnb-4bit' or 'bnb-8bit'." + ) + VALID_ATTN_IMPLEMENTATIONS = {"eager", "sdpa", "flash_attention_2", "flash_attention_3", "flex_attention"} + is_kernels_community = self.attn_implementation is not None and self.attn_implementation.startswith( + "kernels-community/" + ) + if ( + self.attn_implementation is not None + and not is_kernels_community + and self.attn_implementation not in VALID_ATTN_IMPLEMENTATIONS + ): + raise ValueError( + f"Unsupported attention implementation: '{self.attn_implementation}'. " + f"Must be one of {VALID_ATTN_IMPLEMENTATIONS} or a kernels-community kernel (e.g. 'kernels-community/flash-attn2')." + ) + + @staticmethod + def process_model_name(model_id: str) -> str: + """Canonicalize to `'model_id@revision'` format. Defaults to `@main`.""" + if "@" in model_id: + return model_id + return f"{model_id}@main" + + def get_quantization_config(self) -> BitsAndBytesConfig | None: + """Return a BitsAndBytesConfig based on the `quantization` setting, or None.""" + if self.quantization == "bnb-4bit": + return BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_use_double_quant=True, + ) + elif self.quantization == "bnb-8bit": + return BitsAndBytesConfig(load_in_8bit=True) + return None + + def _load_processor(self, model_id_and_revision: str) -> "ProcessorMixin | PreTrainedTokenizerFast": + """Load a processor for the given model. + + Args: + model_id_and_revision: Model ID in ``'model_id@revision'`` format. + """ + from transformers import AutoProcessor + + model_id, revision = model_id_and_revision.split("@", 1) + return AutoProcessor.from_pretrained(model_id, revision=revision, trust_remote_code=self.trust_remote_code) + + def _load_model( + self, model_id_and_revision: str, tqdm_class: type | None = None, progress_callback: Callable | None = None + ) -> "PreTrainedModel": + """Load a model. + + Args: + model_id_and_revision (`str`): Model ID in ``'model_id@revision'`` format. + tqdm_class (*optional*): tqdm subclass for progress bars during ``from_pretrained``. + progress_callback (`Callable`, *optional*): Called with progress dicts during loading. + + Returns: + `PreTrainedModel`: The loaded model. 
+ """ + from transformers import AutoConfig + + model_id, revision = model_id_and_revision.split("@", 1) + + model_kwargs = { + "revision": revision, + "attn_implementation": self.attn_implementation, + "dtype": self.dtype, + "device_map": self.device, + "trust_remote_code": self.trust_remote_code, + "quantization_config": self.get_quantization_config(), + "tqdm_class": tqdm_class, + } + + if progress_callback is not None: + progress_callback({"status": "loading", "model": model_id_and_revision, "stage": "config"}) + config = AutoConfig.from_pretrained(model_id, **model_kwargs) + architecture = getattr(transformers, config.architectures[0]) + + return architecture.from_pretrained(model_id, **model_kwargs) + + def load_model_and_processor( + self, + model_id_and_revision: str, + progress_callback: Callable | None = None, + tqdm_class: type | None = None, + ) -> "tuple[PreTrainedModel, ProcessorMixin | PreTrainedTokenizerFast]": + """Load a model (or return it from cache), resetting its inactivity timer. + + Args: + model_id_and_revision: Model ID in ``'model_id@revision'`` format. + progress_callback: If provided, called with dicts like + ``{"status": "loading", "model": ..., "stage": ...}`` during loading. + tqdm_class: Optional tqdm subclass for progress bars during ``from_pretrained``. + """ + # Per-model lock prevents duplicate loads when concurrent requests arrive + with self._model_locks_guard: + lock = self._model_locks.setdefault(model_id_and_revision, threading.Lock()) + + with lock: + if model_id_and_revision not in self.loaded_models: + logger.warning(f"Loading {model_id_and_revision}") + if progress_callback is not None: + progress_callback({"status": "loading", "model": model_id_and_revision, "stage": "processor"}) + processor = self._load_processor(model_id_and_revision) + model = self._load_model( + model_id_and_revision, tqdm_class=tqdm_class, progress_callback=progress_callback + ) + self.loaded_models[model_id_and_revision] = TimedModel( + model, + timeout_seconds=self.model_timeout, + processor=processor, + on_unload=lambda key=model_id_and_revision: self.loaded_models.pop(key, None), + ) + if progress_callback is not None: + progress_callback({"status": "ready", "model": model_id_and_revision, "cached": False}) + else: + self.loaded_models[model_id_and_revision].reset_timer() + model = self.loaded_models[model_id_and_revision].model + processor = self.loaded_models[model_id_and_revision].processor + if progress_callback is not None: + progress_callback({"status": "ready", "model": model_id_and_revision, "cached": True}) + return model, processor + + async def load_model_streaming(self, model_id_and_revision: str): + """Load a model and stream progress as SSE events. + + Handles three cases: + 1. Model already cached → single ``ready`` event + 2. Load already in progress → join existing subscriber stream + 3. First request → start loading, broadcast to all subscribers + + Args: + model_id_and_revision (`str`): Model ID in ``'model_id@revision'`` format. + + Yields: + `str`: SSE ``data: ...`` lines with progress updates. 
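+
+        A client sees SSE lines shaped like the following (model name illustrative; the exact
+        sequence depends on whether weights need to be downloaded):
+
+        ```
+        data: {"status": "loading", "model": "some-org/some-model@main", "stage": "processor"}
+        data: {"status": "loading", "model": "some-org/some-model@main", "stage": "config"}
+        data: {"status": "ready", "model": "some-org/some-model@main", "cached": false}
+        ```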
+ """ + mid = model_id_and_revision + queue: asyncio.Queue[str | None] = asyncio.Queue() + + # Case 1: already cached + if mid in self.loaded_models: + self.loaded_models[mid].reset_timer() + yield f"data: {json.dumps({'status': 'ready', 'model': mid, 'cached': True})}\n\n" + return + + # Case 2: load in progress — join existing subscribers + if mid in self._loading_tasks: + self._loading_subscribers[mid].append(queue) + while True: + item = await queue.get() + if item is None: + break + yield item + return + + # Case 3: first request — start the load + self._loading_subscribers[mid] = [queue] + loop = asyncio.get_running_loop() + + def enqueue(payload: dict): + msg = f"data: {json.dumps(payload)}\n\n" + + def broadcast(): + for q in self._loading_subscribers.get(mid, []): + q.put_nowait(msg) + + loop.call_soon_threadsafe(broadcast) + + tqdm_class = make_progress_tqdm_class(enqueue, mid) + + def _tqdm_hook(factory, args, kwargs): + return tqdm_class(*args, **kwargs) + + async def run_load(): + try: + # Install a global tqdm hook so the "Loading weights" bar in + # core_model_loading.py (which uses logging.tqdm) routes through + # our ProgressTqdm. The tqdm_class kwarg only covers download bars. + previous_hook = logging.set_tqdm_hook(_tqdm_hook) + try: + await asyncio.to_thread( + self.load_model_and_processor, + mid, + progress_callback=enqueue, + tqdm_class=tqdm_class, + ) + finally: + logging.set_tqdm_hook(previous_hook) + except Exception as e: + logger.error(f"Failed to load {mid}: {e}", exc_info=True) + enqueue({"status": "error", "model": mid, "message": str(e)}) + finally: + + def _send_sentinel(): + for q in self._loading_subscribers.pop(mid, []): + q.put_nowait(None) + self._loading_tasks.pop(mid, None) + + loop.call_soon_threadsafe(_send_sentinel) + + self._loading_tasks[mid] = asyncio.create_task(run_load()) + + while True: + item = await queue.get() + if item is None: + break + yield item + + def shutdown(self) -> None: + """Delete all loaded models and free resources.""" + for timed in list(self.loaded_models.values()): + timed.delete_model() + + @staticmethod + def get_model_modality( + model: "PreTrainedModel", processor: "ProcessorMixin | PreTrainedTokenizerFast | None" = None + ) -> Modality: + """Detect whether a model is an LLM or VLM based on its architecture. + + Args: + model (`PreTrainedModel`): The loaded model. + processor (`ProcessorMixin | PreTrainedTokenizerFast`, *optional*): + If a plain tokenizer (not a multi-modal processor), short-circuits to LLM. + + Returns: + `Modality`: The detected modality (``Modality.LLM`` or ``Modality.VLM``). + """ + if processor is not None and isinstance(processor, PreTrainedTokenizerBase): + return Modality.LLM + + from transformers.models.auto.modeling_auto import ( + MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, + MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES, + ) + + model_classname = model.__class__.__name__ + if model_classname in MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.values(): + return Modality.VLM + elif model_classname in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values(): + return Modality.LLM + else: + raise ValueError(f"Unknown modality for: {model_classname}") + + @staticmethod + @lru_cache + def get_gen_models(cache_dir: str | None = None) -> list[dict]: + """List generative models (LLMs and VLMs) available in the HuggingFace cache. + + Args: + cache_dir (`str`, *optional*): Path to the HuggingFace cache directory. + Defaults to the standard cache location. 
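+            The value is passed straight through to ``huggingface_hub.scan_cache_dir``.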
+ + Returns: + `list[dict]`: OpenAI-compatible model list entries with ``id``, ``object``, etc. + """ + from transformers.models.auto.modeling_auto import ( + MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, + MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES, + ) + + generative_models = [] + logger.warning("Scanning the cache directory for LLMs and VLMs.") + + for repo in tqdm(scan_cache_dir(cache_dir).repos): + if repo.repo_type != "model": + continue + + for ref, revision_info in repo.refs.items(): + config_path = next((f.file_path for f in revision_info.files if f.file_name == "config.json"), None) + if not config_path: + continue + + config = json.loads(config_path.open().read()) + if not (isinstance(config, dict) and "architectures" in config): + continue + + architectures = config["architectures"] + llms = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values() + vlms = MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.values() + + if any(arch for arch in architectures if arch in [*llms, *vlms]): + author = repo.repo_id.split("/") if "/" in repo.repo_id else "" + repo_handle = repo.repo_id + (f"@{ref}" if ref != "main" else "") + generative_models.append( + { + "owned_by": author, + "id": repo_handle, + "object": "model", + "created": repo.last_modified, + } + ) + + return generative_models diff --git a/src/transformers/cli/serving/response.py b/src/transformers/cli/serving/response.py new file mode 100644 index 000000000000..23d49a480d33 --- /dev/null +++ b/src/transformers/cli/serving/response.py @@ -0,0 +1,564 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Handler for the /v1/responses endpoint (OpenAI Responses API). + +Supports streaming (SSE) and non-streaming (JSON) responses. 
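+
+Streamed replies follow the usual Responses API event order: ``response.created`` and
+``response.in_progress`` first, then ``response.output_item.added`` and
+``response.content_part.added``, a series of ``response.output_text.delta`` events, and finally
+the corresponding ``done`` events and ``response.completed`` (or ``response.failed`` on error).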
+""" + +import asyncio +import time +from collections.abc import AsyncGenerator +from typing import TYPE_CHECKING + +from ...utils import logging +from ...utils.import_utils import is_serve_available + + +if is_serve_available(): + from fastapi import HTTPException + from fastapi.responses import JSONResponse, StreamingResponse + from openai.types.responses import ( + Response, + ResponseCompletedEvent, + ResponseContentPartAddedEvent, + ResponseContentPartDoneEvent, + ResponseCreatedEvent, + ResponseError, + ResponseErrorEvent, + ResponseFailedEvent, + ResponseFunctionCallArgumentsDoneEvent, + ResponseFunctionToolCall, + ResponseInProgressEvent, + ResponseOutputItemAddedEvent, + ResponseOutputItemDoneEvent, + ResponseOutputMessage, + ResponseOutputText, + ResponseTextDeltaEvent, + ResponseTextDoneEvent, + ) + from openai.types.responses.response_create_params import ResponseCreateParamsStreaming + from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails, ResponseUsage + +from .utils import ( + BaseGenerateManager, + BaseHandler, + ToolCallParser, + _StreamError, + detect_tool_format, +) + + +if TYPE_CHECKING: + from transformers import GenerationConfig, PreTrainedModel, PreTrainedTokenizerFast, ProcessorMixin + + +logger = logging.get_logger(__name__) + + +class TransformersResponseCreateParamsStreaming(ResponseCreateParamsStreaming, total=False): + generation_config: str + seed: int + + +UNUSED_RESPONSE_FIELDS = { + "background", + "include", + "max_tool_calls", + "previous_response_id", + "prompt", + "reasoning", + "service_tier", + "store", + "text", + "tool_choice", + "top_logprobs", + "truncation", + "user", +} + + +class ResponseHandler(BaseHandler): + """Handler for the ``/v1/responses`` endpoint.""" + + _valid_params_class = TransformersResponseCreateParamsStreaming + _unused_fields = UNUSED_RESPONSE_FIELDS + + async def handle_request(self, body: dict, request_id: str) -> StreamingResponse | JSONResponse: + """Validate, load model, dispatch to streaming or non-streaming. + + Args: + body (`dict`): The raw JSON request body (OpenAI Responses API format). + request_id (`str`): Unique request identifier (from header or auto-generated). + + Returns: + `StreamingResponse | JSONResponse`: SSE stream or JSON depending on ``body["stream"]``. + """ + self._validate_request(body) + + model_id, model, processor = self._resolve_model(body) + modality = self.model_manager.get_model_modality(model, processor=processor) + use_cb = self.generation_state.use_continuous_batching(model, modality) + logger.warning(f"[Request received] Model: {model_id}, CB: {use_cb}") + gen_manager = self.generation_state.get_manager(model_id, use_cb=use_cb) + + # Two-step input conversion (chat completions skips step 1 since messages are already standard): + # 1. Normalize Responses API input (string/list/dict + instructions) → standard messages list + # 2. Transform message content for the HF processor (VLM image handling, text joining, etc.) 
+ messages = self._input_to_messages(body) + processor_inputs = self.get_processor_inputs_from_messages(messages, modality) + + inputs = processor.apply_chat_template( + processor_inputs, + add_generation_prompt=True, + tools=body.get("tools"), + return_tensors=None if use_cb else "pt", + return_dict=True, + tokenize=True, + ) + if not use_cb: + inputs = inputs.to(model.device) + + gen_config = self._build_generation_config(body, model.generation_config, use_cb=use_cb) + # TODO: remove when CB supports per-request generation config + if use_cb: + gen_manager.init_cb(model, gen_config) + tool_format = detect_tool_format(model) if body.get("tools") else None + + streaming = body.get("stream", True) + if streaming: + return self._streaming( + request_id, + model, + processor, + model_id, + body, + inputs, + gen_config, + gen_manager=gen_manager, + tool_format=tool_format, + ) + else: + return await self._non_streaming( + request_id, + model, + processor, + model_id, + body, + inputs, + gen_config, + gen_manager=gen_manager, + tool_format=tool_format, + ) + + # ----- input conversion ----- + + @staticmethod + def _input_to_messages(body: dict) -> list[dict]: + """Convert the Responses API ``input`` field to a list of chat messages. + + Handles string, list, and dict inputs. If ``instructions`` is provided, it is + prepended as a system message (or replaces an existing one). + + Args: + body (`dict`): The raw request body containing ``input`` and optionally ``instructions``. + + Returns: + `list[dict]`: Standard OpenAI-format chat messages. + """ + inp = body["input"] + instructions = body.get("instructions") + + if isinstance(inp, str): + messages = [{"role": "system", "content": instructions}] if instructions else [] + messages.append({"role": "user", "content": inp}) + elif isinstance(inp, list): + if instructions: + if inp[0]["role"] != "system": + messages = [{"role": "system", "content": instructions}, *inp] + else: + messages = list(inp) + messages[0]["content"] = instructions + else: + messages = inp + elif isinstance(inp, dict): + messages = [{"role": "system", "content": instructions}] if instructions else [] + messages.append(inp) + else: + raise HTTPException(status_code=422, detail="'input' must be a string, list, or dict") + + return messages + + # ----- streaming ----- + + def _streaming( + self, + request_id: str, + model: "PreTrainedModel", + processor: "ProcessorMixin | PreTrainedTokenizerFast", + model_id: str, + body: dict, + inputs: dict, + gen_config: "GenerationConfig", + gen_manager: BaseGenerateManager, + tool_format: dict | None = None, + ) -> StreamingResponse: + """Generate a streaming Responses API reply (SSE) using DirectStreamer.""" + queue, streamer = gen_manager.generate_streaming(model, processor, inputs, gen_config, request_id=request_id) + input_ids = inputs["input_ids"] + # CB returns plain lists, regular path returns tensors + input_len = len(input_ids) if isinstance(input_ids, list) else input_ids.shape[-1] + parser = ToolCallParser(tool_format) if tool_format else None + + seq = 0 + output_index = 0 + created_at = time.time() + resp_id = f"resp_{request_id}" + msg_id = f"msg_{request_id}" + + response_base = { + "id": resp_id, + "created_at": created_at, + "model": model_id, + "object": "response", + # Required by pydantic but not used — echo request config back + "tools": [], + "parallel_tool_calls": body.get("parallel_tool_calls", False), + "tool_choice": "auto", + } + + async def event_stream() -> AsyncGenerator[str, None]: + nonlocal seq, output_index 
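+            # Event order emitted below: response.created, response.in_progress,
+            # response.output_item.added, response.content_part.added, response.output_text.delta
+            # (repeated), the matching *.done events, then response.completed.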
+ + try: + # 1. Created + In progress + yield self.chunk_to_sse( + ResponseCreatedEvent( + type="response.created", + sequence_number=seq, + response=Response(**response_base, status="queued", output=[]), + ) + ) + seq += 1 + yield self.chunk_to_sse( + ResponseInProgressEvent( + type="response.in_progress", + sequence_number=seq, + response=Response(**response_base, status="in_progress", output=[]), + ) + ) + seq += 1 + + # 2. Output item added (message) + yield self.chunk_to_sse( + ResponseOutputItemAddedEvent( + type="response.output_item.added", + sequence_number=seq, + output_index=output_index, + item=ResponseOutputMessage( + id=msg_id, + type="message", + status="in_progress", + role="assistant", + content=[], + ), + ) + ) + seq += 1 + + # 3. Content part added + yield self.chunk_to_sse( + ResponseContentPartAddedEvent( + type="response.content_part.added", + item_id=msg_id, + sequence_number=seq, + output_index=output_index, + content_index=0, + part=ResponseOutputText(type="output_text", text="", annotations=[]), + ) + ) + seq += 1 + + # 4. Stream tokens — drain queue to batch HTTP writes + full_text = "" + tool_calls = [] + done = False + + while not done: + text = await queue.get() + # Drain all available tokens for one batched HTTP write + batch = [text] + try: + while True: + batch.append(queue.get_nowait()) + except asyncio.QueueEmpty: + pass + + sse_parts: list[str] = [] + for text in batch: + if text is None: + done = True + break + if isinstance(text, _StreamError): + logger.error(f"Exception in response generation: {text.msg}") + sse_parts.append( + self.chunk_to_sse( + ResponseErrorEvent(type="error", sequence_number=seq, message=text.msg) + ) + ) + seq += 1 + sse_parts.append( + self.chunk_to_sse( + ResponseFailedEvent( + type="response.failed", + sequence_number=seq, + response=Response( + **response_base, + status="failed", + output=[], + error=ResponseError(code="server_error", message=text.msg), + ), + ) + ) + ) + yield "".join(sse_parts) + return + + # Tool call parsing + if parser is not None and (result := parser.feed(text)) is not None: + if result is not ToolCallParser.CONSUMED: + tc_id = f"{request_id}_tool_call" + name = result["name"] + arguments = result["arguments"] + tc_item = ResponseFunctionToolCall( + id=tc_id, + call_id=tc_id, + type="function_call", + name=name, + arguments=arguments, + status="completed", + ) + tool_calls.append(tc_item) + output_index += 1 + sse_parts.append( + self.chunk_to_sse( + ResponseOutputItemAddedEvent( + type="response.output_item.added", + sequence_number=seq, + output_index=output_index, + item=tc_item, + ) + ) + ) + seq += 1 + sse_parts.append( + self.chunk_to_sse( + ResponseFunctionCallArgumentsDoneEvent( + type="response.function_call_arguments.done", + sequence_number=seq, + item_id=tc_id, + output_index=output_index, + arguments=arguments, + name=name, + ) + ) + ) + seq += 1 + sse_parts.append( + self.chunk_to_sse( + ResponseOutputItemDoneEvent( + type="response.output_item.done", + sequence_number=seq, + output_index=output_index, + item=tc_item, + ) + ) + ) + seq += 1 + continue + + full_text += text + sse_parts.append( + self.chunk_to_sse( + ResponseTextDeltaEvent( + type="response.output_text.delta", + item_id=msg_id, + sequence_number=seq, + output_index=0, + content_index=0, + delta=text, + logprobs=[], + ) + ) + ) + seq += 1 + + if sse_parts: + yield "".join(sse_parts) + + # 5. 
Close text output + output_text_part = ResponseOutputText(type="output_text", text=full_text, annotations=[]) + yield self.chunk_to_sse( + ResponseTextDoneEvent( + type="response.output_text.done", + item_id=msg_id, + sequence_number=seq, + output_index=0, + content_index=0, + text=full_text, + logprobs=[], + ) + ) + seq += 1 + yield self.chunk_to_sse( + ResponseContentPartDoneEvent( + type="response.content_part.done", + item_id=msg_id, + sequence_number=seq, + output_index=0, + content_index=0, + part=output_text_part, + ) + ) + seq += 1 + + msg_item = ResponseOutputMessage( + id=msg_id, + type="message", + status="completed", + role="assistant", + content=[output_text_part], + annotations=[], + ) + yield self.chunk_to_sse( + ResponseOutputItemDoneEvent( + type="response.output_item.done", + sequence_number=seq, + output_index=0, + item=msg_item, + ) + ) + seq += 1 + + # 6. Completed + all_output = [msg_item] + list(tool_calls) + usage = compute_usage(input_len, streamer.total_tokens) + yield self.chunk_to_sse( + ResponseCompletedEvent( + type="response.completed", + sequence_number=seq, + response=Response(**response_base, status="completed", output=all_output, usage=usage), + ) + ) + seq += 1 + except (GeneratorExit, asyncio.CancelledError): + # Client disconnected — abort generation to free GPU. + # Re-raise is mandatory: Python raises RuntimeError if GeneratorExit is swallowed. + streamer.cancel() + raise + + return StreamingResponse(event_stream(), media_type="text/event-stream") + + # ----- non-streaming ----- + + async def _non_streaming( + self, + request_id: str, + model: "PreTrainedModel", + processor: "ProcessorMixin | PreTrainedTokenizerFast", + model_id: str, + body: dict, + inputs: dict, + gen_config: "GenerationConfig", + gen_manager: BaseGenerateManager, + tool_format: dict | None = None, + ) -> JSONResponse: + """Generate a non-streaming Responses API reply (single JSON).""" + full_text, input_len, generated_ids = await gen_manager.generate_non_streaming( + model, processor, inputs, gen_config, request_id=request_id + ) + + output_items = [ + ResponseOutputMessage( + id=f"msg_{request_id}", + type="message", + status="completed", + role="assistant", + content=[ResponseOutputText(type="output_text", text=full_text, annotations=[])], + annotations=[], + ) + ] + + # Parse tool calls from the generated text + if tool_format is not None: + parsed_calls = ToolCallParser.parse(full_text, tool_format) + if parsed_calls is not None: + for i, tc in enumerate(parsed_calls): + tc_id = f"{request_id}_tool_call" + output_items.append( + ResponseFunctionToolCall( + id=tc_id, + call_id=tc_id, + type="function_call", + name=tc["name"], + arguments=tc["arguments"], + status="completed", + ) + ) + + usage = compute_usage(input_len, len(generated_ids)) + response = Response( + id=f"resp_{request_id}", + created_at=time.time(), + status="completed", + model=model_id, + output=output_items, + object="response", + usage=usage, + # Required by pydantic but not used — echo request config back + tools=[], + parallel_tool_calls=body.get("parallel_tool_calls", False), + tool_choice="auto", + ) + return JSONResponse(response.model_dump(exclude_none=True)) + + # ----- helpers ----- + + def _build_generation_config(self, body: dict, model_generation_config: "GenerationConfig", use_cb: bool = False): + """Apply Responses API params (``max_output_tokens``) on top of the base generation config.""" + generation_config = super()._build_generation_config(body, model_generation_config, use_cb=use_cb) + + if 
body.get("max_output_tokens") is not None: + generation_config.max_new_tokens = int(body["max_output_tokens"]) + + return generation_config + + +def compute_usage(input_tokens: int, output_tokens: int) -> ResponseUsage: + """Build a ``ResponseUsage`` object for a Responses API reply. + + Args: + input_tokens (`int`): Number of prompt tokens. + output_tokens (`int`): Number of generated tokens. + + Returns: + `ResponseUsage`: Usage statistics with zero-filled detail fields. + """ + return ResponseUsage( + input_tokens=input_tokens, + output_tokens=output_tokens, + total_tokens=input_tokens + output_tokens, + input_tokens_details=InputTokensDetails(cached_tokens=0), + output_tokens_details=OutputTokensDetails(reasoning_tokens=0), + ) diff --git a/src/transformers/cli/serving/server.py b/src/transformers/cli/serving/server.py new file mode 100644 index 000000000000..ec0f287b36ee --- /dev/null +++ b/src/transformers/cli/serving/server.py @@ -0,0 +1,127 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +FastAPI app factory. +""" + +import uuid +from contextlib import asynccontextmanager + +from ...utils import logging +from ...utils.import_utils import is_serve_available + + +if is_serve_available(): + from fastapi import FastAPI, Request + from fastapi.middleware.cors import CORSMiddleware + from fastapi.responses import JSONResponse, StreamingResponse + +from .chat_completion import ChatCompletionHandler +from .model_manager import ModelManager +from .response import ResponseHandler +from .transcription import TranscriptionHandler +from .utils import X_REQUEST_ID + + +logger = logging.get_logger(__name__) + + +def build_server( + model_manager: ModelManager, + chat_handler: ChatCompletionHandler, + response_handler: ResponseHandler, + transcription_handler: TranscriptionHandler, + enable_cors: bool = False, +) -> FastAPI: + """Build and return a configured FastAPI application. + + Args: + model_manager: Handles model loading, caching, and cleanup. + chat_handler: Handles `/v1/chat/completions` requests. + response_handler: Handles `/v1/responses` requests. + enable_cors: If `True`, adds permissive CORS middleware (allow all origins). + + Returns: + A FastAPI app ready to be passed to uvicorn. + """ + + @asynccontextmanager + async def lifespan(app: FastAPI): + yield + model_manager.shutdown() + + app = FastAPI(lifespan=lifespan) + + if enable_cors: + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + logger.warning_once("CORS allow origin is set to `*`. 
Not recommended for production.") + + # ---- Middleware ---- + + @app.middleware("http") + async def request_id_middleware(request: Request, call_next): + """Get or set the request ID in the header.""" + request_id = request.headers.get(X_REQUEST_ID) or str(uuid.uuid4()) + request.state.request_id = request_id + response = await call_next(request) + response.headers[X_REQUEST_ID] = request_id + return response + + # ---- Routes ---- + + @app.post("/v1/chat/completions") + async def chat_completions(request: Request, body: dict): + return await chat_handler.handle_request(body, request.state.request_id) + + @app.post("/v1/responses") + async def responses(request: Request, body: dict): + return await response_handler.handle_request(body, request.state.request_id) + + @app.post("/v1/audio/transcriptions") + async def audio_transcriptions(request: Request): + return await transcription_handler.handle_request(request) + + @app.post("/load_model") + async def load_model(body: dict): + from fastapi import HTTPException + + model = body.get("model") + if model is None: + raise HTTPException(status_code=422, detail="Missing `model` field in the request body.") + model_id_and_revision = model_manager.process_model_name(model) + return StreamingResponse( + model_manager.load_model_streaming(model_id_and_revision), media_type="text/event-stream" + ) + + @app.post("/reset") + def reset(): + model_manager.shutdown() + return JSONResponse({"status": "ok"}) + + @app.get("/v1/models") + @app.options("/v1/models") + def list_models(): + return JSONResponse({"object": "list", "data": model_manager.get_gen_models()}) + + @app.get("/health") + def health(): + return JSONResponse({"status": "ok"}) + + return app diff --git a/src/transformers/cli/serving/transcription.py b/src/transformers/cli/serving/transcription.py new file mode 100644 index 000000000000..b63add7e5ed6 --- /dev/null +++ b/src/transformers/cli/serving/transcription.py @@ -0,0 +1,185 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Handler for the /v1/audio/transcriptions endpoint. 
+""" + +import io +from typing import TYPE_CHECKING + +from ...utils import logging +from ...utils.import_utils import is_serve_available + + +if is_serve_available(): + from fastapi import HTTPException, Request + from fastapi.responses import JSONResponse, StreamingResponse + from openai.types.audio.transcription_create_params import TranscriptionCreateParamsBase + +from .model_manager import ModelManager +from .utils import DirectStreamer, GenerateManager, GenerationState, _StreamError + + +if TYPE_CHECKING: + from transformers import PreTrainedModel, ProcessorMixin + + +logger = logging.get_logger(__name__) + + +class TransformersTranscriptionCreateParams(TranscriptionCreateParamsBase, total=False): + stream: bool + + +UNUSED_TRANSCRIPTION_FIELDS = { + "chunking_strategy", + "include", + "language", + "prompt", + "response_format", + "temperature", + "timestamp_granularities", +} + + +class TranscriptionHandler: + """Handler for ``POST /v1/audio/transcriptions``. + + Accepts a multipart/form-data request with an audio file and model name, + runs speech-to-text, and returns an OpenAI-compatible Transcription response. + + Standalone (does not extend :class:`BaseHandler`) because audio requests use + multipart form data, not JSON bodies, and don't need generation config or + validation. Shares the :class:`GenerationState` for thread safety. + """ + + def __init__(self, model_manager: ModelManager, generation_state: GenerationState): + """ + Args: + model_manager (`ModelManager`): Handles model loading, caching, and lifecycle. + generation_state (`GenerationState`): Shared generation state for thread safety. + """ + self.model_manager = model_manager + self.generation_state = generation_state + + def _validate_request(self, form_keys: set[str]) -> None: + """Validate transcription request fields.""" + unexpected = form_keys - TransformersTranscriptionCreateParams.__mutable_keys__ + if unexpected: + raise HTTPException(status_code=422, detail=f"Unexpected fields in the request: {unexpected}") + unused = form_keys & UNUSED_TRANSCRIPTION_FIELDS + if unused: + logger.warning_once(f"Ignoring unsupported fields in the request: {unused}") + + async def handle_request(self, request: Request) -> JSONResponse | StreamingResponse: + """Parse multipart form, run transcription, return result. + + Args: + request (`Request`): FastAPI request containing multipart form data with + ``file`` (audio bytes), ``model`` (model ID), and optional ``stream`` flag. + + Returns: + `JSONResponse | StreamingResponse`: Transcription result or SSE stream. + """ + from transformers.utils.import_utils import is_librosa_available, is_multipart_available + + if not is_librosa_available(): + raise ImportError("Missing librosa dependency for audio transcription. Install with `pip install librosa`") + if not is_multipart_available(): + raise ImportError( + "Missing python-multipart dependency for file uploads. 
Install with `pip install python-multipart`" + ) + + async with request.form() as form: + self._validate_request(set(form.keys())) + file_bytes = await form["file"].read() + model = form["model"] + stream = str(form.get("stream", "false")).lower() == "true" + + model_id_and_revision = self.model_manager.process_model_name(model) + audio_model, audio_processor = self.model_manager.load_model_and_processor(model_id_and_revision) + gen_manager = self.generation_state.get_manager(model_id_and_revision) + audio_inputs = self._prepare_audio_inputs(file_bytes, audio_processor, audio_model) + + if stream: + return self._streaming(gen_manager, audio_model, audio_processor, audio_inputs) + return await self._non_streaming(gen_manager, audio_model, audio_processor, audio_inputs) + + @staticmethod + def _prepare_audio_inputs( + file_bytes: bytes, audio_processor: "ProcessorMixin", audio_model: "PreTrainedModel" + ) -> dict: + """Load audio bytes and convert to model inputs.""" + import librosa + + sampling_rate = audio_processor.feature_extractor.sampling_rate + audio_array, _ = librosa.load(io.BytesIO(file_bytes), sr=sampling_rate, mono=True) + audio_inputs = audio_processor(audio_array, sampling_rate=sampling_rate, return_tensors="pt").to( + audio_model.device + ) + audio_inputs["input_features"] = audio_inputs["input_features"].to(audio_model.dtype) + return audio_inputs + + async def _non_streaming( + self, + gen_manager: GenerateManager, + audio_model: "PreTrainedModel", + audio_processor: "ProcessorMixin", + audio_inputs: dict, + ) -> JSONResponse: + # Audio models have different inputs (input_features) and decode (batch_decode) + # than text models, so we use async_submit() directly instead of + # generate_non_streaming() + from openai.types.audio import Transcription + + generated_ids = await gen_manager.async_submit(audio_model.generate, **audio_inputs) + text = audio_processor.batch_decode(generated_ids, skip_special_tokens=True)[0] + return JSONResponse(Transcription(text=text).model_dump(exclude_none=True)) + + def _streaming( + self, + gen_manager: GenerateManager, + audio_model: "PreTrainedModel", + audio_processor: "ProcessorMixin", + audio_inputs: dict, + ) -> StreamingResponse: + # Same as _non_streaming — uses submit() directly because audio inputs + # differ from text. + import asyncio + + tokenizer = audio_processor.tokenizer if hasattr(audio_processor, "tokenizer") else audio_processor + loop = asyncio.get_running_loop() + queue: asyncio.Queue = asyncio.Queue() + streamer = DirectStreamer(tokenizer._tokenizer, loop, queue, skip_special_tokens=True) + gen_kwargs = {**audio_inputs, "streamer": streamer} + + def _run(): + try: + audio_model.generate(**gen_kwargs) + except Exception as e: + loop.call_soon_threadsafe(queue.put_nowait, _StreamError(str(e))) + + gen_manager.submit(_run) + + async def sse_gen(): + while True: + text = await queue.get() + if text is None: + break + if isinstance(text, _StreamError): + yield f'data: {{"error": "{text.msg}"}}\n\n' + return + yield f"data: {text}\n\n" + + return StreamingResponse(sse_gen(), media_type="text/event-stream") diff --git a/src/transformers/cli/serving/utils.py b/src/transformers/cli/serving/utils.py new file mode 100644 index 000000000000..d9828d123b12 --- /dev/null +++ b/src/transformers/cli/serving/utils.py @@ -0,0 +1,956 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Shared types, constants, and utilities for the serving layer.
+"""
+
+import asyncio
+import base64
+import copy
+import enum
+import json
+import re
+import tempfile
+import threading
+from abc import ABC, abstractmethod
+from collections.abc import Callable
+from concurrent.futures import Future
+from io import BytesIO
+from queue import Queue
+from typing import TYPE_CHECKING
+
+from transformers.utils import logging
+
+
+if TYPE_CHECKING:
+    import pydantic
+    import tokenizers
+    import torch
+
+    from transformers import (
+        ContinuousBatchingConfig,
+        GenerationConfig,
+        PreTrainedModel,
+        PreTrainedTokenizerFast,
+        ProcessorMixin,
+    )
+    from transformers.generation.continuous_batching.continuous_api import ContinuousBatchingManager
+    from transformers.generation.continuous_batching.requests import GenerationOutput
+    from transformers.generation.continuous_batching.scheduler import Scheduler
+
+    from .model_manager import ModelManager
+
+
+logger = logging.get_logger(__name__)
+
+
+X_REQUEST_ID = "x-request-id"
+
+
+class Modality(enum.Enum):
+    LLM = "LLM"
+    VLM = "VLM"
+    STT = "STT"
+    TTS = "TTS"
+
+
+class _StreamError:
+    """Sentinel to signal an error from the generate thread."""
+
+    def __init__(self, msg: str):
+        self.msg = msg
+
+
+class _GenerationCancelled(Exception):
+    """Raised inside ``DirectStreamer.put()`` to abort ``model.generate()``."""
+
+
+# Model-specific tokens that mark the start/end of a tool call block.
+# TODO: extract these from the chat template at runtime instead of hardcoding.
+# Qwen/Hermes use <tool_call>/</tool_call>, Mistral uses [TOOL_CALLS], etc.
+# The markers are defined in each model's Jinja chat template.
+_TOOL_CALL_TOKENS = {
+    "qwen": {
+        "start": "<tool_call>",
+        "end": "</tool_call>",
+    },
+}
+
+
+def detect_tool_format(model: "PreTrainedModel") -> dict | None:
+    """Return the tool call token format for a model, if supported.
+
+    Args:
+        model (`PreTrainedModel`): The loaded model.
+
+    Returns:
+        `dict | None`: A dict ``{"start": str, "end": str}`` with the model's tool call
+            delimiters, or ``None`` if the model family is not recognized.
+    """
+    architecture = model.config.architectures[0].lower()
+    for family in _TOOL_CALL_TOKENS:
+        if family in architecture:
+            return _TOOL_CALL_TOKENS[family]
+    return None
+
+
+class ToolCallParser:
+    """Parses tool calls from model output.
+
+    The model emits tool calls as structured text between start/end tokens
+    (e.g. ``<tool_call>{"name": "fn", "arguments": {...}}</tool_call>``).
+
+    **Streaming** (``feed``): buffers tokens between start/end markers, parses
+    the complete block when the end marker is seen, returns a ``ChoiceDeltaToolCall``.
+
+    **Non-streaming** (``parse``): extracts all tool call blocks from complete text.
+ + Usage:: + + parser = ToolCallParser(tool_format={"start": ..., "end": ...}) + for text_chunk in streamer: + result = parser.feed(text_chunk) + if result is None: + # Normal text — emit as content + elif result is ToolCallParser.CONSUMED: + # Buffering — skip + else: + # result is a ChoiceDeltaToolCall — emit it + """ + + def __init__(self, tool_format: dict): + self._tokens = tool_format + self._inside = False + self._buffer = "" + + # Sentinel: token was consumed by the parser but produced no output. + CONSUMED = object() + + def feed(self, text: str) -> object | dict | None: + """Feed a text chunk (streaming). + + Returns: + - ``None`` — normal text, not a tool token. Emit as content. + - ``CONSUMED`` — token consumed internally (buffering/markers). Skip. + - A ``ChoiceDeltaToolCall`` — emit as a tool call delta. + """ + if text.strip() == self._tokens["start"]: + self._inside = True + self._buffer = "" + return self.CONSUMED + + if text.strip() == self._tokens["end"]: + self._inside = False + block = self._buffer.strip() + self._buffer = "" + return self._parse_block(block) or self.CONSUMED + + if self._inside: + self._buffer += text + return self.CONSUMED + + return None + + @staticmethod + def _extract_name_and_args(block: str) -> tuple[str, str] | None: + """Extract (name, arguments_json) from a tool call block, or None if invalid.""" + if not block: + return None + parsed = json.loads(block) + name = parsed.get("name") + if name is None: + return None + arguments = parsed.get("arguments", {}) + return name, json.dumps(arguments) + + @staticmethod + def parse(text: str, tool_format: dict) -> list[dict] | None: + """Parse tool calls from complete text. + + Returns a list of ``{"name": str, "arguments": str}`` dicts, or ``None`` if none found. + """ + start, end = tool_format["start"], tool_format["end"] + tool_calls = [] + pos = 0 + while True: + s = text.find(start, pos) + if s < 0: + break + e = text.find(end, s + len(start)) + if e < 0: + break + result = ToolCallParser._extract_name_and_args(text[s + len(start) : e].strip()) + if result is not None: + tool_calls.append({"name": result[0], "arguments": result[1]}) + pos = e + len(end) + return tool_calls if tool_calls else None + + def _parse_block(self, block: str) -> dict | None: + """Parse a buffered tool call block. Returns ``{"name": str, "arguments": str}`` or None.""" + result = self._extract_name_and_args(block) + if result is None: + return None + return {"name": result[0], "arguments": result[1]} + + +class DownloadAggregator: + """Aggregates byte-progress across multiple concurrent download tqdm bars. + + huggingface_hub opens one tqdm bar per file shard. This class tracks them all and emits + a single aggregate ``{"stage": "download", "progress": {...}}`` event whenever any updates. 
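+
+    An emitted payload looks like this (illustrative values)::
+
+        {"status": "loading", "model": "...", "stage": "download", "progress": {"current": 1048576, "total": 5242880}}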
+ """ + + def __init__(self, enqueue: Callable, model_id: str): + self.enqueue = enqueue + self.model = model_id + self.bars: dict[int, tuple[int, int | None]] = {} + self.last_emitted_current: int | None = None + + def register(self, bar_id: int, total: int | None) -> None: + """Register a new download bar with its total byte count.""" + self.bars[bar_id] = (0, total) + self._emit() + + def update(self, bar_id: int, current: int, total: int | None) -> None: + """Update a bar's current byte count and emit aggregate progress.""" + self.bars[bar_id] = (current, total) + self._emit() + + def close(self, bar_id: int) -> None: + pass # keep the bar so totals remain correct + + def _emit(self) -> None: + agg_current = sum(c for c, _ in self.bars.values()) + if agg_current == self.last_emitted_current: + return + self.last_emitted_current = agg_current + totals = [t for _, t in self.bars.values() if t is not None] + agg_total = sum(totals) if totals else None + self.enqueue( + { + "status": "loading", + "model": self.model, + "stage": "download", + "progress": {"current": agg_current, "total": agg_total}, + } + ) + + +def make_progress_tqdm_class(callback: Callable, model_id: str) -> type: + """Create a tqdm subclass that routes progress to a callback. + + Bars with ``unit="B"`` are download bars — aggregated via ``DownloadAggregator``. + Other bars (e.g. "Loading weights") emit ``weights`` stage events. + + Args: + callback (`callable`): Called with a dict payload + ``{"status": "loading", "model": ..., "stage": ..., "progress": ...}``. + model_id (`str`): The model ID (included in progress payloads). + + Returns: + A tqdm subclass that forwards progress to *callback*. + """ + from tqdm.auto import tqdm as base_tqdm + + download_aggregator = DownloadAggregator(callback, model_id) + + class ProgressTqdm(base_tqdm): + def __init__(self, *args, **kwargs): + self.sse_unit = kwargs.get("unit") or "it" + kwargs["disable"] = True + super().__init__(*args, **kwargs) + self.n = 0 + self.last_emitted = -1 + if self.sse_unit == "B": + self._bar_id = id(self) + download_aggregator.register(self._bar_id, self.total) + + def update(self, n=1): + if n is None: + n = 1 + self.n += n + if self.sse_unit == "B": + download_aggregator.update(self._bar_id, self.n, self.total) + elif self.n != self.last_emitted: + self.last_emitted = self.n + callback( + { + "status": "loading", + "model": model_id, + "stage": "weights", + "progress": {"current": self.n, "total": self.total}, + } + ) + + def __iter__(self): + for item in self.iterable: + self.n += 1 + if self.sse_unit == "B": + download_aggregator.update(self._bar_id, self.n, self.total) + elif self.n != self.last_emitted: + self.last_emitted = self.n + callback( + { + "status": "loading", + "model": model_id, + "stage": "weights", + "progress": {"current": self.n, "total": self.total}, + } + ) + yield item + + def close(self): + if self.sse_unit == "B": + download_aggregator.close(self._bar_id) + super().close() + + return ProgressTqdm + + +class DirectStreamer: + """Streamer for ``model.generate()`` (used by :class:`GenerateManager`). + + Implements the ``put``/``end`` protocol that ``model.generate()`` expects: + generate calls ``put(token_tensor)`` after each decode step, and ``end()`` + when generation is complete. Tokens are decoded incrementally via the Rust + ``DecodeStream`` (O(1) per token) and pushed as text to an asyncio.Queue. 
+ """ + + def __init__( + self, + tokenizer: "tokenizers.Tokenizer", + loop: asyncio.AbstractEventLoop, + queue: asyncio.Queue, + skip_special_tokens: bool = True, + ): + """ + Args: + tokenizer: The Rust tokenizer (``tokenizer._tokenizer``). + loop (`asyncio.AbstractEventLoop`): The event loop to push decoded text to. + queue (`asyncio.Queue`): The queue that receives decoded text chunks. + skip_special_tokens (`bool`, *optional*, defaults to `True`): + Whether to strip special tokens during decoding. + """ + from tokenizers.decoders import DecodeStream + + self._tokenizer = tokenizer + self._loop = loop + self._queue = queue + self._decode_stream = DecodeStream([], skip_special_tokens) + self._first = True + self._cancelled = threading.Event() + self.total_tokens = 0 + + def put(self, value: "torch.Tensor") -> None: + """Called by ``model.generate()`` after each decode step with new token(s).""" + if self._cancelled.is_set(): + raise _GenerationCancelled() + # The first put() contains the prompt tokens — skip since we only stream generated tokens. + if self._first: + self._first = False + return + for token_id in value.tolist(): + self.total_tokens += 1 + text = self._decode_stream.step(self._tokenizer, token_id) + if text is not None: + self._loop.call_soon_threadsafe(self._queue.put_nowait, text) + + def end(self) -> None: + """Called by ``model.generate()`` when generation is complete.""" + self._loop.call_soon_threadsafe(self._queue.put_nowait, None) + + def cancel(self) -> None: + """Signal cancellation. The next ``put()`` call will raise and abort ``model.generate()``.""" + self._cancelled.set() + + +class CBStreamer: + """Streamer for continuous batching (used by :class:`CBGenerateManager`). + + Same ``put``/``end`` protocol as :class:`DirectStreamer`, but called manually + by :class:`CBGenerateManager` instead of by ``model.generate()``: + ``put(output)`` receives a CB ``GenerationOutput``, decodes new tokens, and + pushes text to the asyncio.Queue. ``end()`` signals the stream is complete. + """ + + def __init__( + self, + cb_manager: "ContinuousBatchingManager", + request_id: str, + tokenizer: "tokenizers.Tokenizer", + loop: asyncio.AbstractEventLoop, + queue: asyncio.Queue, + ): + """ + Args: + cb_manager (`ContinuousBatchingManager`): The CB manager instance. + request_id (`str`): The request ID to track in the CB scheduler. + tokenizer: The Rust tokenizer (``tokenizer._tokenizer``). + loop (`asyncio.AbstractEventLoop`): The event loop to push decoded text to. + queue (`asyncio.Queue`): The queue that receives decoded text chunks. 
+ """ + from tokenizers.decoders import DecodeStream + + self._cb = cb_manager + self._request_id = request_id + self._loop = loop + self._queue = queue + self._tokenizer = tokenizer + self._decode_stream = DecodeStream([], True) + self._prev_len = 0 + self.total_tokens = 0 + + def put(self, output: "GenerationOutput") -> None: + """Decode new tokens from a CB ``GenerationOutput`` and push text to the queue.""" + new_tokens = output.generated_tokens[self._prev_len :] + self._prev_len = len(output.generated_tokens) + for token_id in new_tokens: + self.total_tokens += 1 + text = self._decode_stream.step(self._tokenizer, token_id) + if text is not None: + self._queue.put_nowait(text) + + def end(self) -> None: + """Signal end of stream.""" + self._queue.put_nowait(None) + + def cancel(self) -> None: + """Cancel the CB request.""" + self._cb.cancel_request(self._request_id) + + +def set_torch_seed(seed: int) -> None: + """Set the PyTorch random seed for reproducible generation.""" + import torch + + torch.manual_seed(seed) + + +def reset_torch_cache() -> None: + """Empty the CUDA cache if a GPU is available.""" + import torch + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + +class InferenceThread: + """Persistent thread for ``model.generate()`` calls. + + ``torch.compile`` with CUDA graphs stores state in thread-local storage. + All inference must run on the same thread to avoid corrupted graph state. + """ + + def __init__(self): + self._queue: Queue = Queue() + self._thread = threading.Thread(target=self._run, daemon=True) + self._thread.start() + + def _run(self) -> None: + while True: + fn, args, kwargs, future, loop = self._queue.get() + try: + result = fn(*args, **kwargs) + if loop is not None: + loop.call_soon_threadsafe(future.set_result, result) + else: + future.set_result(result) + except Exception as e: + if loop is not None: + loop.call_soon_threadsafe(future.set_exception, e) + else: + future.set_exception(e) + + def submit(self, fn, *args, **kwargs) -> Future: + """Submit a callable to the inference thread. Returns a blocking Future.""" + future: Future = Future() + self._queue.put((fn, args, kwargs, future, None)) + return future + + def async_submit(self, fn, *args, **kwargs) -> asyncio.Future: + """Submit a callable to the inference thread. Returns an awaitable asyncio.Future.""" + loop = asyncio.get_running_loop() + future = loop.create_future() + self._queue.put((fn, args, kwargs, future, loop)) + return future + + +class BaseGenerateManager(ABC): + """Base class for generation managers. + + Subclasses: + - :class:`GenerateManager` — sequential ``model.generate()`` on a persistent thread. + - :class:`CBGenerateManager` — continuous batching with paged attention. + """ + + @abstractmethod + def generate_streaming( + self, + model: "PreTrainedModel", + processor: "ProcessorMixin | PreTrainedTokenizerFast", + inputs: dict, + gen_config: "GenerationConfig", + request_id: str, + ) -> tuple[asyncio.Queue, "DirectStreamer | CBStreamer"]: + """Start streaming generation. + + Args: + model (`PreTrainedModel`): The loaded model. + processor: The processor or tokenizer for decoding. + inputs (`dict`): Tokenized inputs (tensors for sequential, lists for CB). + gen_config (`GenerationConfig`): Generation parameters. + request_id (`str`): Unique request identifier. + + Returns: + `tuple[asyncio.Queue, DirectStreamer | CBStreamer]`: A ``(queue, streamer)`` pair + where *queue* yields ``str | _StreamError | None`` and *streamer* exposes + ``.total_tokens`` and ``.cancel()``. 
+ """ + + @abstractmethod + def generate_non_streaming( + self, + model: "PreTrainedModel", + processor: "ProcessorMixin | PreTrainedTokenizerFast", + inputs: dict, + gen_config: "GenerationConfig", + request_id: str, + ) -> tuple[str, int, list[int]]: + """Run generation to completion. + + Args: + model (`PreTrainedModel`): The loaded model. + processor: The processor or tokenizer for decoding. + inputs (`dict`): Tokenized inputs (tensors for sequential, lists for CB). + gen_config (`GenerationConfig`): Generation parameters. + request_id (`str`): Unique request identifier. + + Returns: + `tuple[str, int, list[int]]`: ``(text, input_len, generated_ids)``. + """ + + @abstractmethod + def stop(self) -> None: + """Stop the generation manager and free resources.""" + + +class GenerateManager(BaseGenerateManager): + """Sequential generation via ``model.generate()`` on a persistent thread.""" + + def __init__(self): + self._thread = InferenceThread() + + def generate_streaming( + self, + model: "PreTrainedModel", + processor: "ProcessorMixin | PreTrainedTokenizerFast", + inputs: dict, + gen_config: "GenerationConfig", + request_id: str, + ) -> tuple[asyncio.Queue, DirectStreamer]: + """Start streaming generation via ``model.generate()`` on the inference thread.""" + loop = asyncio.get_running_loop() + queue: asyncio.Queue = asyncio.Queue() + streamer = DirectStreamer(processor._tokenizer, loop, queue, skip_special_tokens=True) + gen_kwargs = {**inputs, "streamer": streamer, "generation_config": gen_config, "tokenizer": processor} + + def _run() -> None: + try: + model.generate(**gen_kwargs) + except _GenerationCancelled: + loop.call_soon_threadsafe(queue.put_nowait, None) + except Exception as e: + loop.call_soon_threadsafe(queue.put_nowait, _StreamError(str(e))) + + self.submit(_run) + return queue, streamer + + async def generate_non_streaming( + self, + model: "PreTrainedModel", + processor: "ProcessorMixin | PreTrainedTokenizerFast", + inputs: dict, + gen_config: "GenerationConfig", + request_id: str, + ) -> tuple[str, int, "torch.Tensor"]: + """Run generation to completion via ``model.generate()`` on the inference thread.""" + sequences = await self.async_submit( + model.generate, **inputs, generation_config=gen_config, tokenizer=processor + ) + input_len = inputs["input_ids"].shape[-1] + generated_ids = sequences[0, input_len:] + text = processor.decode(generated_ids, skip_special_tokens=True) + return text, input_len, generated_ids + + def submit(self, fn: Callable, *args, **kwargs) -> Future: + """Submit a callable to the inference thread. Returns a blocking Future.""" + return self._thread.submit(fn, *args, **kwargs) + + def async_submit(self, fn: Callable, *args, **kwargs) -> asyncio.Future: + """Submit a callable to the inference thread. Returns an awaitable asyncio.Future.""" + return self._thread.async_submit(fn, *args, **kwargs) + + def stop(self) -> None: + pass # inference thread is a daemon + + +class CBGenerateManager(BaseGenerateManager): + """Continuous batching generation via paged attention. + + Translates between the handler's text-level asyncio.Queue and CB's + token-level interface. Per-request: ``max_new_tokens``, ``eos_token_id``. + + The CB manager is initialized lazily on the first request via + :meth:`ensure_initialized`, using that request's ``gen_config`` for shared + sampling params (temperature, top_p, do_sample). + + .. todo:: Remove :meth:`init_cb` when CB supports per-request + generation config. 
At that point, ``gen_config`` can be passed directly + to ``add_request`` and the CB manager no longer needs a shared config. + """ + + def __init__(self, cb_config: "ContinuousBatchingConfig | None" = None): + self._cb = None + self._cb_config = cb_config + + def init_cb(self, model: "PreTrainedModel", gen_config: "GenerationConfig") -> None: + """Initialize the CB manager on first call with the request's generation config. + + .. todo:: Remove when CB supports per-request generation config. + + Args: + model (`PreTrainedModel`): The loaded model (must support ``init_continuous_batching``). + gen_config (`GenerationConfig`): Generation config used for shared sampling params. + """ + if self._cb is not None: + return + + self._cb = model.init_continuous_batching( + generation_config=gen_config, continuous_batching_config=self._cb_config + ) + self._cb.start() + + def generate_streaming( + self, + model: "PreTrainedModel", + processor: "ProcessorMixin | PreTrainedTokenizerFast", + inputs: dict, + gen_config: "GenerationConfig", + request_id: str, + ) -> tuple[asyncio.Queue, CBStreamer]: + """Start streaming CB generation. Registers a per-request output handler.""" + loop = asyncio.get_running_loop() + text_queue: asyncio.Queue = asyncio.Queue() + + input_ids = inputs["input_ids"] + request_id = self._cb.add_request( + input_ids, + request_id=request_id, + streaming=True, + max_new_tokens=gen_config.max_new_tokens, + eos_token_id=gen_config.eos_token_id, + ) + streamer = CBStreamer(self._cb, request_id, processor._tokenizer, loop, text_queue) + + # Register a direct callback: the dispatcher calls this on the event loop with each GenerationOutput. + # This decodes tokens and pushes text straight to the SSE text_queue + def _on_output(output): + try: + streamer.put(output) + if output.is_finished(): + streamer.end() + except Exception as e: + text_queue.put_nowait(_StreamError(str(e))) + + self._cb.register_result_handler(request_id, _on_output) + return text_queue, streamer + + async def generate_non_streaming( + self, + model: "PreTrainedModel", + processor: "ProcessorMixin | PreTrainedTokenizerFast", + inputs: dict, + gen_config: "GenerationConfig", + request_id: str, + ) -> tuple[str, int, list[int]]: + """Run non-streaming CB generation. Registers a handler that resolves an asyncio.Future on completion.""" + input_ids = inputs["input_ids"] + input_len = len(input_ids) + + # Register future BEFORE add_request to avoid race with fast completion + loop = asyncio.get_running_loop() + future = loop.create_future() + + def _on_result(result): + if not future.done(): + future.set_result(result) + + self._cb.register_result_handler(request_id, _on_result) + + self._cb.add_request( + input_ids, + request_id=request_id, + max_new_tokens=gen_config.max_new_tokens, + streaming=False, + eos_token_id=gen_config.eos_token_id, + ) + result = await future + if result is None: + raise RuntimeError(f"CB manager stopped before producing a result for {request_id}") + generated_ids = result.generated_tokens + text = processor.decode(generated_ids, skip_special_tokens=True) + return text, input_len, generated_ids + + @property + def scheduler(self) -> "Scheduler": + """The CB scheduler (for testing/monitoring).""" + return self._cb.batch_processor.scheduler + + def stop(self) -> None: + if self._cb is not None: + self._cb.stop(block=True, timeout=2) + + +class GenerationState: + """Shared generation state across all handlers. 
+ + Manages per-model :class:`GenerateManager` instances (each with its own + :class:`InferenceThread` so different models can run concurrently while + ``torch.compile`` / CUDA graphs require same-model-same-thread) and a + single :class:`CBGenerateManager` for continuous batching. + + Args: + continuous_batching (`bool`, *optional*, defaults to `False`): + Whether to use continuous batching with paged attention instead of + sequential ``model.generate()`` calls. + """ + + def __init__( + self, + continuous_batching: bool = False, + compile: bool = False, + cb_config: "ContinuousBatchingConfig | None" = None, + ): + self._continuous_batching = continuous_batching + self._compile = compile + self._cb_config = cb_config + self._generate_managers: dict[str, GenerateManager] = {} + self._cb_manager: CBGenerateManager | None = None + self._cb_model_id: str | None = None + + def use_continuous_batching(self, model: "PreTrainedModel", modality: Modality) -> bool: + """Check if continuous batching can be used for this model and modality. + + Args: + model (`PreTrainedModel`): The loaded model. + modality (`Modality`): The detected model modality (LLM, VLM, etc.). + + Returns: + `bool`: ``True`` if CB is enabled and the model supports it, ``False`` otherwise. + """ + if not self._continuous_batching: + return False + can = hasattr(model, "init_continuous_batching") and modality == Modality.LLM + if not can: + logger.warning_once( + f"{model.__class__.__name__} does not support continuous batching. " + "Falling back to sequential generation." + ) + return can + + def get_manager(self, model_id: str, use_cb: bool = False) -> BaseGenerateManager: + """Return a per-model generation manager, lazily created on first request. + + Args: + model_id (`str`): The model ID in ``'model_id@revision'`` format. + use_cb (`bool`): Whether to return a CB manager or a sequential one. + + Returns: + `BaseGenerateManager`: Either a `GenerateManager` or `CBGenerateManager`. + """ + if use_cb: + if self._cb_model_id != model_id: + if self._cb_manager is not None: + self._cb_manager.stop() + self._cb_manager = None + if self._cb_manager is None: + self._cb_manager = CBGenerateManager(cb_config=self._cb_config) + self._cb_model_id = model_id + return self._cb_manager + if model_id not in self._generate_managers: + self._generate_managers[model_id] = GenerateManager() + return self._generate_managers[model_id] + + def shutdown(self) -> None: + """Stop any active generation managers.""" + if self._cb_manager is not None: + self._cb_manager.stop() + self._cb_manager = None + + +class BaseHandler: + """Shared logic for chat completion and responses handlers. + + Provides model resolution, generation config building, and SSE formatting. + Generation is delegated to the shared :class:`GenerationState`. + + Args: + model_manager (`ModelManager`): + Handles model loading, caching, and lifecycle. + generation_state (`GenerationState`): + Shared state managing per-model generation managers. 
+ """ + + _valid_params_class: type | None = None + _unused_fields: set[str] = set() + + def __init__( + self, + model_manager: "ModelManager", + generation_state: GenerationState, + ): + self.model_manager = model_manager + self.generation_state = generation_state + + def _validate_request(self, body: dict) -> None: + """Validate request fields against the handler's params class and unused fields.""" + from fastapi import HTTPException + + input_keys = set(body.keys()) + if self._valid_params_class is not None: + unexpected = input_keys - self._valid_params_class.__mutable_keys__ + if unexpected: + raise HTTPException(status_code=422, detail=f"Unexpected fields in the request: {unexpected}") + unused = input_keys & self._unused_fields + if unused: + logger.warning_once(f"Ignoring unsupported fields in the request: {unused}") + + @staticmethod + def chunk_to_sse(chunk: "str | pydantic.BaseModel") -> str: + """Format a pydantic model or string as an SSE ``data:`` line.""" + if isinstance(chunk, str): + return chunk if chunk.startswith("data: ") else f"data: {chunk}\n\n" + return f"data: {chunk.model_dump_json(exclude_none=True)}\n\n" + + def _resolve_model(self, body: dict) -> tuple[str, "PreTrainedModel", "ProcessorMixin | PreTrainedTokenizerFast"]: + """Apply force_model, load model + processor. + + Returns ``(model_id, model, processor)``. + """ + if self.model_manager.force_model is not None: + body["model"] = self.model_manager.force_model + + model_id = self.model_manager.process_model_name(body["model"]) + model, processor = self.model_manager.load_model_and_processor(model_id) + + return model_id, model, processor + + def _build_generation_config( + self, body: dict, model_generation_config: "GenerationConfig", use_cb: bool = False + ) -> "GenerationConfig": + """Build a GenerationConfig from shared params (temperature, top_p, seed, generation_config JSON). + + Subclasses should call ``super()._build_generation_config(...)`` then apply + endpoint-specific params (``max_tokens``, ``max_output_tokens``, etc.). + + Args: + body (`dict`): + The raw request body. + model_generation_config (`GenerationConfig`): + The model's default generation config (will be deep-copied). + use_cb (`bool`, *optional*, defaults to `False`): + Whether continuous batching is active. If ``True``, disables the model's + internal KV cache (CB manages its own paged cache). + + Returns: + `GenerationConfig`: A new config with request-specific overrides applied. 
+ """ + from transformers import GenerationConfig + + if body.get("generation_config") is not None: + generation_config = GenerationConfig(**json.loads(body["generation_config"])) + else: + generation_config = copy.deepcopy(model_generation_config) + if generation_config.max_new_tokens is None or generation_config.max_new_tokens < 1024: + generation_config.max_new_tokens = 1024 + + if body.get("temperature") is not None: + generation_config.temperature = float(body["temperature"]) + if float(body["temperature"]) == 0.0: + generation_config.do_sample = False + if body.get("top_p") is not None: + generation_config.top_p = float(body["top_p"]) + if body.get("seed") is not None: + set_torch_seed(body["seed"]) + + # --compile flag: use static cache + torch.compile for faster decode + if self.generation_state._compile and generation_config.cache_implementation is None: + generation_config.cache_implementation = "static" + + # CB manages its own paged KV cache + if use_cb: + generation_config.use_cache = False + + # TODO: add prefix caching for the non-CB path (reuse KV cache across multi-turn conversations) + + return generation_config + + @staticmethod + def get_processor_inputs_from_messages(messages: list[dict], modality: Modality) -> list[dict]: + """Convert OpenAI-format messages to the format expected by HF processors. + + For LLMs, collapses list content blocks into plain text. For VLMs, converts + ``image_url`` content parts (including base64) into ``{"type": "image", "url": ...}`` + entries that HF processors understand. + + Args: + messages (`list[dict]`): OpenAI-format chat messages. + modality (`Modality`): Whether the model is an LLM or VLM. + + Returns: + `list[dict]`: Processor-compatible messages. + """ + processor_inputs = [] + + for message in messages: + parsed = {"role": message["role"], "content": []} + + if modality == Modality.LLM: + if isinstance(message["content"], str): + parsed["content"] = message["content"] + elif isinstance(message["content"], list): + texts = [c["text"] for c in message["content"] if c["type"] == "text"] + parsed["content"] = " ".join(texts) + + elif modality == Modality.VLM: + if isinstance(message["content"], str): + parsed["content"].append({"type": "text", "text": message["content"]}) + else: + for content in message["content"]: + if content["type"] == "text": + parsed["content"].append(content) + elif content["type"] == "image_url": + from PIL import Image + + url = content["image_url"]["url"] + if "base64" in url: + image_data = re.sub("^data:image/.+;base64,", "", url) + image = Image.open(BytesIO(base64.b64decode(image_data))) + file = tempfile.NamedTemporaryFile(suffix=".png", delete=False) + image.save(file.name) + url = file.name + parsed["content"].append({"type": "image", "url": url}) + + processor_inputs.append(parsed) + return processor_inputs diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index bdbf213412fe..fdc730df55c8 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -115,6 +115,7 @@ is_liger_kernel_available, is_lomo_available, is_mistral_common_available, + is_multipart_available, is_natten_available, is_nltk_available, is_numba_available, @@ -140,6 +141,7 @@ is_scipy_available, is_sentencepiece_available, is_seqio_available, + is_serve_available, is_spacy_available, is_speech_available, is_spqr_available, @@ -1414,6 +1416,13 @@ def require_librosa(test_case): return unittest.skipUnless(is_librosa_available(), "test requires librosa")(test_case) +def 
require_multipart(test_case): + """ + Decorator marking a test that requires python-multipart + """ + return unittest.skipUnless(is_multipart_available(), "test requires python-multipart")(test_case) + + def require_liger_kernel(test_case): """ Decorator marking a test that requires liger_kernel @@ -1497,6 +1506,13 @@ def require_openai(test_case): return unittest.skipUnless(is_openai_available(), "test requires openai")(test_case) +def require_serve(test_case): + """ + Decorator marking a test that requires the serving dependencies (fastapi, uvicorn, pydantic, openai). + """ + return unittest.skipUnless(is_serve_available(), "test requires serving dependencies")(test_case) + + def require_mistral_common(test_case): """ Decorator marking a test that requires mistral-common. These tests are skipped when mistral-common isn't available. diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index 7b8bfb80ec19..3f5c7cac386b 100644 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -164,6 +164,7 @@ is_matplotlib_available, is_mistral_common_available, is_mlx_available, + is_multipart_available, is_natten_available, is_ninja_available, is_nltk_available, @@ -198,6 +199,7 @@ is_scipy_available, is_sentencepiece_available, is_seqio_available, + is_serve_available, is_sinq_available, is_sklearn_available, is_soundfile_available, diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 5d17f9a01771..1e1ac2545f05 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -741,6 +741,11 @@ def is_librosa_available() -> bool: return _is_package_available("librosa")[0] +@lru_cache +def is_multipart_available() -> bool: + return _is_package_available("multipart")[0] + + @lru_cache def is_essentia_available() -> bool: return _is_package_available("essentia")[0] @@ -766,6 +771,11 @@ def is_openai_available() -> bool: return _is_package_available("openai")[0] +@lru_cache +def is_serve_available() -> bool: + return is_pydantic_available() and is_fastapi_available() and is_uvicorn_available() and is_openai_available() + + @lru_cache def is_pretty_midi_available() -> bool: return _is_package_available("pretty_midi")[0] diff --git a/tests/cli/test_serve.py b/tests/cli/test_serve.py index 33f33cce7aab..9be3dbeb99ff 100644 --- a/tests/cli/test_serve.py +++ b/tests/cli/test_serve.py @@ -11,61 +11,80 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +""" +Tests for the serving layer. 
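+
+These tests exercise the request handlers (chat completions and responses), the model
+manager, tool-call parsing, and the FastAPI app routes, plus slow end-to-end tests that
+run a real model behind a live server.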
+ +""" + +import asyncio import json import os -import tempfile +import socket import time import unittest -from threading import Thread -from unittest.mock import Mock, patch +from unittest.mock import MagicMock import httpx -from huggingface_hub import ChatCompletionStreamOutput, InferenceClient, hf_hub_download -from parameterized import parameterized - -from transformers import GenerationConfig -from transformers.cli.serve import Modality, Serve -from transformers.testing_utils import require_openai, slow -from transformers.utils.import_utils import ( - is_fastapi_available, - is_openai_available, - is_pydantic_available, - is_uvicorn_available, + +from transformers.cli.serve import Serve +from transformers.cli.serving.chat_completion import ChatCompletionHandler +from transformers.cli.serving.model_manager import ModelManager, TimedModel +from transformers.cli.serving.response import ResponseHandler, compute_usage +from transformers.cli.serving.server import build_server +from transformers.cli.serving.transcription import TranscriptionHandler +from transformers.cli.serving.utils import ( + BaseHandler, + GenerationState, + Modality, + ToolCallParser, + detect_tool_format, +) +from transformers.testing_utils import ( + require_librosa, + require_multipart, + require_serve, + require_torch_accelerator, + require_vision, + slow, ) +from transformers.utils.import_utils import is_serve_available -serve_dependencies_available = ( - is_pydantic_available() and is_fastapi_available() and is_uvicorn_available() and is_openai_available() -) +if is_serve_available(): + from fastapi import HTTPException + from openai import OpenAI + from openai.types.responses import Response, ResponseCreatedEvent + + +def _find_free_port() -> int: + """Return a free TCP port on localhost.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("localhost", 0)) + return s.getsockname()[1] + -if serve_dependencies_available: - from openai import APIConnectionError, OpenAI - from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall, ChoiceDeltaToolCallFunction - from openai.types.responses import ( - Response, - ResponseCompletedEvent, - ResponseContentPartAddedEvent, - ResponseContentPartDoneEvent, - ResponseCreatedEvent, - ResponseInProgressEvent, - ResponseOutputItemAddedEvent, - ResponseOutputItemDoneEvent, - ResponseTextDeltaEvent, - ResponseTextDoneEvent, - ) - - -@require_openai -def test_help(cli): - """Minimal test: we can invoke the help command.""" - output = cli("serve", "--help") - assert output.exit_code == 0 - assert "serve" in output.output - - -@require_openai +def _start_serve(**kwargs) -> tuple["Serve", int]: + """Start a non-blocking Serve instance on a free port and wait until healthy. + + Returns ``(serve, port)``. 
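+
+    Typical use in a test class (sketch, mirroring the integration tests below)::
+
+        cls.serve, port = _start_serve()
+        cls.base_url = f"http://localhost:{port}"
+        # ... run requests against cls.base_url ...
+        cls.serve.kill_server()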
+ """ + port = _find_free_port() + serve = Serve(port=port, non_blocking=True, **kwargs) + for _ in range(30): + try: + if httpx.get(f"http://localhost:{port}/health", timeout=2).status_code == 200: + return serve, port + except Exception: # noqa: S110 + pass + time.sleep(1) + raise RuntimeError(f"Server on port {port} did not become healthy in time") + + +@require_serve def test_host_port_blocking(cli): - """Minimal test: we can set arguments through the CLI - blocking""" + """CLI args --host and --port are passed to uvicorn.Config, and server.run() is called.""" + from unittest.mock import Mock, patch + with ( patch("uvicorn.Config") as ConfigMock, patch("uvicorn.Server") as ServerMock, @@ -73,912 +92,1217 @@ def test_host_port_blocking(cli): server_instance = Mock() ServerMock.return_value = server_instance - # Call the serve CLI with host/port out = cli("serve", "--host", "0.0.0.0", "--port", "9000") _, kwargs = ConfigMock.call_args assert out.exit_code == 0 assert kwargs["host"] == "0.0.0.0" assert kwargs["port"] == 9000 - ServerMock.assert_called_once_with(ConfigMock.return_value) server_instance.run.assert_called_once() -@require_openai -def test_host_port_non_blocking(cli, caplog): - """Minimal test: we can set arguments through the CLI - non-blocking""" - caplog.set_level(100000) - # ^ hack to avoid an issue happening only in CI. We don't check logs anyway so it's fine. - # Source: https://github.com/pallets/click/issues/824#issuecomment-562581313 +class TestProcessorInputsFromMessages(unittest.TestCase): + def test_llm_string_content(self): + get_processor_inputs_from_messages = BaseHandler.get_processor_inputs_from_messages - with ( - patch("uvicorn.Config") as ConfigMock, - patch("uvicorn.Server") as ServerMock, - patch.object(Serve, "start_server") as start_mock, - ): - server_instance = Mock() - ServerMock.return_value = server_instance + messages = [{"role": "user", "content": "Hello"}] + result = get_processor_inputs_from_messages(messages, Modality.LLM) + self.assertEqual(result, [{"role": "user", "content": "Hello"}]) - out = cli("serve", "--host", "0.5.0.0", "--port", "9002", "--non-blocking") - assert out.exit_code == 0 + def test_llm_list_content_text_only(self): + get_processor_inputs_from_messages = BaseHandler.get_processor_inputs_from_messages - # Config got the CLI args - _, kwargs = ConfigMock.call_args - assert kwargs["host"] == "0.5.0.0" - assert kwargs["port"] == 9002 + messages = [{"role": "user", "content": [{"type": "text", "text": "A"}, {"type": "text", "text": "B"}]}] + result = get_processor_inputs_from_messages(messages, Modality.LLM) + self.assertEqual(result, [{"role": "user", "content": "A B"}]) - # Non-blocking path uses start_server(), not server.run() - start_mock.assert_called_once() - server_instance.run.assert_not_called() + def test_vlm_string_content_wrapped(self): + get_processor_inputs_from_messages = BaseHandler.get_processor_inputs_from_messages + messages = [{"role": "user", "content": "Hello"}] + result = get_processor_inputs_from_messages(messages, Modality.VLM) + self.assertEqual(result, [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}]) -@require_openai -def test_build_chat_completion_chunk(): - """ - Tests that the chunks are correctly built for the Chat Completion API. The `choices` checks implicitly - confirm that empty fields are not emitted. 
- """ - dummy = Serve.__new__(Serve) - - # The keys for these fields must be present in every chunk - MANDATORY_FIELDS = ["data", "id", "choices", "created", "model", "object", "system_fingerprint"] - - # Case 1: most fields are provided - chunk = dummy.build_chat_completion_chunk( - request_id="req0", content="hello", finish_reason="stop", role="user", model="dummy_model@main" - ) - chunk = dummy.chunk_to_sse_element(chunk) - for field in MANDATORY_FIELDS: - assert field in chunk - assert '"choices":[{"delta":{"content":"hello","role":"user"},"finish_reason":"stop","index":0}]' in chunk - - # Case 2: only the role is provided -- other fields in 'choices' are omitted - chunk = dummy.build_chat_completion_chunk(request_id="req0", role="user", model="dummy_model@main") - chunk = dummy.chunk_to_sse_element(chunk) - for field in MANDATORY_FIELDS: - assert field in chunk - assert '"choices":[{"delta":{"role":"user"},"index":0}]' in chunk - - # Case 3: only the content is provided -- other fields in 'choices' are omitted - chunk = dummy.build_chat_completion_chunk(request_id="req0", content="hello", model="dummy_model@main") - chunk = dummy.chunk_to_sse_element(chunk) - for field in MANDATORY_FIELDS: - assert field in chunk - assert '"choices":[{"delta":{"content":"hello"},"index":0}]' in chunk - - # Case 4: tool calls support a list of ChoiceDeltaToolCall objects - tool_call = ChoiceDeltaToolCall( - index=0, - function=ChoiceDeltaToolCallFunction(name="foo_bar", arguments='{"foo1": "bar1", "foo2": "bar2"}'), - type="function", - ) - chunk = dummy.build_chat_completion_chunk(request_id="req0", tool_calls=[tool_call], model="dummy_model@main") - chunk = dummy.chunk_to_sse_element(chunk) - for field in MANDATORY_FIELDS: - assert field in chunk - expected_choices_content = ( - 'choices":[{"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\\"foo1\\": \\"bar1\\", ' - '\\"foo2\\": \\"bar2\\"}","name":"foo_bar"},"type":"function"}]},"index":0}]' - ) - assert expected_choices_content in chunk - - -def test_generative_model_list(): - with tempfile.TemporaryDirectory() as cache_dir: - # "download" a few models, including some non-generative models - hf_hub_download("Menlo/Jan-nano", "config.json", cache_dir=cache_dir) - hf_hub_download("Menlo/Jan-nano-128k", "config.json", cache_dir=cache_dir) - hf_hub_download("Qwen/Qwen2.5-0.5B-Instruct", "config.json", cache_dir=cache_dir) - hf_hub_download("HuggingFaceTB/SmolVLM-Instruct", "config.json", cache_dir=cache_dir) - hf_hub_download("google-bert/bert-base-cased", "config.json", cache_dir=cache_dir) - - expected_results = { - "HuggingFaceTB/SmolVLM-Instruct": ["HuggingFaceTB", "SmolVLM-Instruct"], - "Qwen/Qwen2.5-0.5B-Instruct": ["Qwen", "Qwen2.5-0.5B-Instruct"], - "Menlo/Jan-nano": ["Menlo", "Jan-nano"], - "Menlo/Jan-nano-128k": ["Menlo", "Jan-nano-128k"], - } + def test_vlm_text_and_image_url(self): + get_processor_inputs_from_messages = BaseHandler.get_processor_inputs_from_messages - # list models - result = Serve.get_gen_models(cache_dir) - assert len(expected_results) == len(result) + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "What is this?"}, + {"type": "image_url", "image_url": {"url": "https://example.com/img.png"}}, + ], + } + ] + result = get_processor_inputs_from_messages(messages, Modality.VLM) + self.assertEqual(len(result[0]["content"]), 2) + self.assertEqual(result[0]["content"][0]["type"], "text") + self.assertEqual(result[0]["content"][1], {"type": "image", "url": "https://example.com/img.png"}) - 
local_repos = {repo["id"]: repo["owned_by"] for repo in result} + def test_llm_multi_turn_conversation(self): + """Multi-turn conversation with string content should pass through as-is.""" - for key, value in expected_results.items(): - assert key in local_repos - assert local_repos[key] == value + get_processor_inputs_from_messages = BaseHandler.get_processor_inputs_from_messages + messages = [ + {"role": "user", "content": "How are you?"}, + {"role": "assistant", "content": "I'm great!"}, + {"role": "user", "content": "Help me write tests?"}, + ] + result = get_processor_inputs_from_messages(messages, Modality.LLM) + self.assertEqual(len(result), 3) + self.assertEqual(result[0]["content"], "How are you?") + self.assertEqual(result[1]["role"], "assistant") + self.assertEqual(result[2]["content"], "Help me write tests?") -@require_openai -def test_build_response_event(): - """ - Tests that the events are correctly built for the Response API. + def test_llm_list_content_with_type(self): + """LLM messages with typed content list should extract text and join.""" - Contrarily to the Chat Completion API, the Response API has a wide set of possible output objects. This test - only checks a few basic assumptions -- we rely on OpenAI's pydantic models to enforce the correct schema. - """ - dummy = Serve.__new__(Serve) - - response_created = ResponseCreatedEvent( - type="response.created", - sequence_number=0, - response=Response( - id="resp_0", - created_at=time.time(), - status="queued", - model="dummy_model@main", - instructions=None, # <--- is set to None = should NOT be in the output. - text={"format": {"type": "text"}}, - object="response", - tools=[], # <--- empty lists should be in the output (they are often mandatory fields) - output=[], - parallel_tool_calls=False, - tool_choice="auto", - metadata=None, - ), - ) - - event = dummy.chunk_to_sse_element(response_created) - assert event.startswith("data: ") # Sanity check: event formatting - assert '"model":"dummy_model@main"' in event # Sanity check: set field - assert '"status":"queued"' in event - assert "tools" in event # empty lists should be in the output - assert "output" in event - assert "instructions" not in event # None fields should NOT be in the output - assert "metadata" not in event - assert "error" not in event # Unset optional fields should NOT be in the output - assert "top_p" not in event - - -def retry(fn, max_attempts=5, delay=2): - """ - Retry a function up to `max_attempts` times with a `delay` between attempts. - Useful for testing functions that may fail due to server not being ready. 
- """ + get_processor_inputs_from_messages = BaseHandler.get_processor_inputs_from_messages - def wrapper(*args, **kwargs): - nb_attempts = 0 - while True: - nb_attempts += 1 - try: - return fn(*args, **kwargs) - except (httpx.HTTPError, APIConnectionError): - if nb_attempts >= max_attempts: - raise - time.sleep(delay) + messages = [ + {"role": "user", "content": [{"type": "text", "text": "Hello"}, {"type": "text", "text": "world"}]} + ] + result = get_processor_inputs_from_messages(messages, Modality.LLM) + self.assertEqual(result[0]["content"], "Hello world") - return wrapper + @require_vision + def test_vlm_base64_image_creates_temp_file(self): + """Base64 image URLs should be decoded and saved to a temp file.""" + get_processor_inputs_from_messages = BaseHandler.get_processor_inputs_from_messages -class ServeCompletionsMixin: - """ - Mixin class for the Completions API tests, to seamlessly replicate tests across the two versions of the API - (`generate` and `continuous_batching`). - """ + # Minimal valid 1x1 PNG as base64 + base64_url = ( + "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR4" + "2mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==" + ) + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "What is this?"}, + {"type": "image_url", "image_url": {"url": base64_url}}, + ], + } + ] + result = get_processor_inputs_from_messages(messages, Modality.VLM) + image_item = result[0]["content"][1] + self.assertEqual(image_item["type"], "image") + self.assertTrue(os.path.exists(image_item["url"])) # temp file was created - @retry - def run_server(self, request): - with InferenceClient(f"http://localhost:{self.port}") as client: - return list(client.chat_completion(**request)) - - @parameterized.expand( - [ - ("default_request", {}), - ("one_token", {"max_tokens": 1}), - ("different_model", {"model": "HuggingFaceTB/SmolLM2-135M-Instruct"}), - ( - "tool_call", - { - "tools": [ - { - "function": { - "name": "foo_bar", - "parameters": {"type": "object"}, - "description": "Foo bar", - }, - "type": "function", - } - ] - }, - ), + def test_vlm_multi_turn(self): + """VLM multi-turn: string content should be wrapped in text type.""" + + get_processor_inputs_from_messages = BaseHandler.get_processor_inputs_from_messages + + messages = [ + {"role": "user", "content": "Describe the image"}, + {"role": "assistant", "content": "It shows a cat"}, + {"role": "user", "content": "What color?"}, ] - ) - def test_requests(self, test_name: str, request_flags: dict): - """Tests that the completions app gracefully handles GOOD requests, producing the expected output payloads.""" - - request = { - "model": "Qwen/Qwen3-0.6B", - "messages": [{"role": "user", "content": "Hello, how are you?"}], - "stream": True, # We don't support "stream": False yet - "max_tokens": 5, # Small generation by default - } - request.update(request_flags) - all_payloads = self.run_server(request) - - # If a request is successful, the returned payload needs to follow the schema, which we test here. - # NOTE: the output of our server is wrapped by `InferenceClient`, which sends fields even when they - # are empty. 
- - # Finish reason: the last payload should have a finish reason of "length" or "stop", all others should be empty - finish_reasons = [payload.choices[0].finish_reason for payload in all_payloads] - self.assertTrue(finish_reasons[-1] in ["length", "stop"]) - self.assertTrue(all(reason is None for reason in finish_reasons[:-1])) - - # Role: the first payload should have a role of "assistant", all others should be empty - roles = [payload.choices[0].delta.role for payload in all_payloads] - self.assertEqual(roles[0], "assistant") - self.assertTrue(all(role is None for role in roles[1:])) - - # Content: the first and the last payload shouldn't have content (role and finish reason). It may be empty - # in some other payload positions, e.g. tool calls. - contents = [payload.choices[0].delta.content for payload in all_payloads] - self.assertTrue(contents[0] is None and contents[-1] is None) - self.assertTrue(any(content is not None for content in contents[1:-1])) - # TODO: add "usage" field to output and test it - - def test_generation_config_in_request(self): - """Tests that the generation config is correctly passed into the generation call.""" - generation_config = GenerationConfig(do_sample=False, temperature=0.0) - request = { - "model": "Qwen/Qwen3-0.6B", - "messages": [{"role": "user", "content": "Hello, how are you?"}], - "stream": True, - "max_tokens": 10, - "extra_body": { - "generation_config": generation_config.to_json_string(), - }, - } - all_payloads = self.run_server(request) - contents = [payload.choices[0].delta.content for payload in all_payloads] - output_text = "".join([text for text in contents if text is not None]) - # The generation config sets greedy decoding, so the output is reproducible. By default, `Qwen/Qwen3-0.6B` - # sets `do_sample=True` - self.assertEqual(output_text, '\nOkay, the user just asked, "') + result = get_processor_inputs_from_messages(messages, Modality.VLM) + self.assertEqual(len(result), 3) + for msg in result: + self.assertIsInstance(msg["content"], list) + self.assertEqual(msg["content"][0]["type"], "text") - def test_early_return_due_to_length(self): - request = { - "model": "Qwen/Qwen2.5-0.5B-Instruct", - "messages": [{"role": "user", "content": "Hello, how are you?"}], - "stream": True, - "max_tokens": 3, - } - all_payloads = self.run_server(request) - last_payload = all_payloads[-1] - self.assertTrue(last_payload.choices[0]["finish_reason"] == "length") +class TestGenerativeModelList(unittest.TestCase): + def test_lists_only_generative_models(self): + """Should list LLMs and VLMs but not non-generative models like BERT.""" + import tempfile - def test_continues_until_stop(self): - request = { - "model": "Qwen/Qwen2.5-0.5B-Instruct", - "messages": [{"role": "user", "content": 'Please only answer with "Hi."'}], - "stream": True, - "max_tokens": 30, - } + from huggingface_hub import hf_hub_download - all_payloads = self.run_server(request) - last_payload = all_payloads[-1] - self.assertTrue(last_payload.choices[0]["finish_reason"] == "stop") + with tempfile.TemporaryDirectory() as cache_dir: + # Download config.json for a few models + hf_hub_download("Qwen/Qwen2.5-0.5B-Instruct", "config.json", cache_dir=cache_dir) + hf_hub_download("google-bert/bert-base-cased", "config.json", cache_dir=cache_dir) + result = ModelManager.get_gen_models(cache_dir) + model_ids = {r["id"] for r in result} -class ServeCompletionsGenerateMockTests(unittest.TestCase): - def test_processor_inputs_from_inbound_messages_llm(self): - modality = Modality.LLM - messages = 
expected_outputs = [ - {"role": "user", "content": "How are you doing?"}, - {"role": "assistant", "content": "I'm doing great, thank you for asking! How can I assist you today?"}, - {"role": "user", "content": "Can you help me write tests?"}, - ] - outputs = Serve.get_processor_inputs_from_inbound_messages(messages, modality) - self.assertListEqual(expected_outputs, outputs) + self.assertIn("Qwen/Qwen2.5-0.5B-Instruct", model_ids) + self.assertNotIn("google-bert/bert-base-cased", model_ids) - messages_with_type = [ - {"role": "user", "content": [{"type": "text", "text": "How are you doing?"}]}, - { - "role": "assistant", - "content": [ - {"type": "text", "text": "I'm doing great, thank you for asking! How can I assist you today?"} - ], - }, - {"role": "user", "content": [{"type": "text", "text": "Can you help me write tests?"}]}, - ] - outputs = Serve.get_processor_inputs_from_inbound_messages(messages_with_type, modality) - self.assertListEqual(expected_outputs, outputs) - messages_multiple_text = [ - { - "role": "user", - "content": [ - {"type": "text", "text": "How are you doing?"}, - {"type": "text", "text": "I'm doing great, thank you for asking! How can I assist you today?"}, - ], - }, - ] - expected_outputs_multiple_text = [ - { - "role": "user", - "content": "How are you doing? I'm doing great, thank you for asking! How can I assist you today?", - }, - ] - outputs = Serve.get_processor_inputs_from_inbound_messages(messages_multiple_text, modality) - self.assertListEqual(expected_outputs_multiple_text, outputs) +@require_serve +class TestBuildGenerationConfig(unittest.TestCase): + def _make_handler(self): + return ChatCompletionHandler(model_manager=MagicMock(), generation_state=GenerationState()) - def test_processor_inputs_from_inbound_messages_vlm_text_only(self): - modality = Modality.VLM - messages = [ - {"role": "user", "content": "How are you doing?"}, - {"role": "assistant", "content": "I'm doing great, thank you for asking! How can I assist you today?"}, - {"role": "user", "content": "Can you help me write tests?"}, - ] + def test_max_tokens(self): + from transformers import GenerationConfig - expected_outputs = [ - {"role": "user", "content": [{"type": "text", "text": "How are you doing?"}]}, - { - "role": "assistant", - "content": [ - {"type": "text", "text": "I'm doing great, thank you for asking! 
How can I assist you today?"} - ], - }, - {"role": "user", "content": [{"type": "text", "text": "Can you help me write tests?"}]}, - ] + result = self._make_handler()._build_generation_config({"max_tokens": 7}, GenerationConfig()) + self.assertEqual(result.max_new_tokens, 7) - outputs = Serve.get_processor_inputs_from_inbound_messages(messages, modality) - self.assertListEqual(expected_outputs, outputs) + def test_temperature_zero_disables_sampling(self): + from transformers import GenerationConfig - def test_processor_inputs_from_inbound_messages_vlm_text_and_image_in_base_64(self): - modality = Modality.VLM - messages = [ - { - "role": "user", - "content": [ - {"type": "text", "text": "How many pixels are in the image?"}, - { - "type": "image_url", - "image_url": { - "url": "data:image/jpeg;base64,/9j/4AAQSkZJRgABAQAASABIAAD/4QBARXhpZgAATU0AKgAAAAgAAYdpAAQAAAABAAAAGgAAAAAAAqACAAQAAAABAAAABaADAAQAAAABAAAABQAAAAD/7QA4UGhvdG9zaG9wIDMuMAA4QklNBAQAAAAAAAA4QklNBCUAAAAAABDUHYzZjwCyBOmACZjs+EJ+/8AAEQgABQAFAwEiAAIRAQMRAf/EAB8AAAEFAQEBAQEBAAAAAAAAAAABAgMEBQYHCAkKC//EALUQAAIBAwMCBAMFBQQEAAABfQECAwAEEQUSITFBBhNRYQcicRQygZGhCCNCscEVUtHwJDNicoIJChYXGBkaJSYnKCkqNDU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6g4SFhoeIiYqSk5SVlpeYmZqio6Slpqeoqaqys7S1tre4ubrCw8TFxsfIycrS09TV1tfY2drh4uPk5ebn6Onq8fLz9PX29/j5+v/EAB8BAAMBAQEBAQEBAQEAAAAAAAABAgMEBQYHCAkKC//EALURAAIBAgQEAwQHBQQEAAECdwABAgMRBAUhMQYSQVEHYXETIjKBCBRCkaGxwQkjM1LwFWJy0QoWJDThJfEXGBkaJicoKSo1Njc4OTpDREVGR0hJSlNUVVZXWFlaY2RlZmdoaWpzdHV2d3h5eoKDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uLj5OXm5+jp6vLz9PX29/j5+v/bAEMAAQEBAQEBAgEBAgICAgICAwICAgIDBAMDAwMDBAUEBAQEBAQFBQUFBQUFBQYGBgYGBgcHBwcHCAgICAgICAgICP/bAEMBAQEBAgICAwICAwgFBAUICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICP/dAAQAAf/aAAwDAQACEQMRAD8A/v4ooooA/9k=" - }, - }, - ], - }, - { - "role": "assistant", - "content": "The number of pixels in the image cannot be determined from the provided information.", - }, - {"role": "user", "content": "Alright"}, - ] + result = self._make_handler()._build_generation_config({"temperature": 0.0}, GenerationConfig(do_sample=True)) + self.assertFalse(result.do_sample) - expected_outputs = [ - { - "role": "user", - "content": [ - {"type": "text", "text": "How many pixels are in the image?"}, - {"type": "image", "url": "/var/folders/4v/64sxdhsd3gz3r8vhhnyc0mqw0000gn/T/tmp50oyghk6.png"}, - ], - }, - { - "role": "assistant", - "content": [ - { - "type": "text", - "text": "The number of pixels in the image cannot be determined from the provided information.", - } - ], - }, - {"role": "user", "content": [{"type": "text", "text": "Alright"}]}, - ] + def test_frequency_penalty(self): + from transformers import GenerationConfig - outputs = Serve.get_processor_inputs_from_inbound_messages(messages, modality) + result = self._make_handler()._build_generation_config({"frequency_penalty": 0.5}, GenerationConfig()) + self.assertAlmostEqual(result.repetition_penalty, 1.5) - for expected_output, output in zip(expected_outputs, outputs): - expected_output_content = expected_output["content"] - output_content = output["content"] + def test_logit_bias_tuple_keys(self): + from transformers import GenerationConfig - self.assertEqual(type(expected_output_content), type(output_content)) + result = self._make_handler()._build_generation_config({"logit_bias": {"42": 1.0}}, GenerationConfig()) + self.assertEqual(result.sequence_bias, {(42,): 1.0}) - if isinstance(expected_output_content, list): - for expected_output_content_item, 
output_content_item in zip(expected_output_content, output_content): - self.assertIn("type", expected_output_content_item) - self.assertIn("type", output_content_item) - self.assertTrue(expected_output_content_item["type"] == output_content_item["type"]) + def test_stop_strings(self): + from transformers import GenerationConfig - if expected_output_content_item["type"] == "text": - self.assertEqual(expected_output_content_item["text"], output_content_item["text"]) + result = self._make_handler()._build_generation_config({"stop": [""]}, GenerationConfig()) + self.assertEqual(result.stop_strings, [""]) - if expected_output_content_item["type"] == "image": - self.assertTrue(os.path.exists(output_content_item["url"])) - else: - raise ValueError("VLMs should only receive content as lists.") + def test_generation_config_json_overrides(self): + from transformers import GenerationConfig + custom = GenerationConfig(max_new_tokens=5, do_sample=False) + result = self._make_handler()._build_generation_config( + {"generation_config": custom.to_json_string()}, GenerationConfig(max_new_tokens=100) + ) + self.assertEqual(result.max_new_tokens, 5) + self.assertFalse(result.do_sample) -@slow # server startup time is slow on our push CI -@require_openai -class ServeCompletionsGenerateIntegrationTest(ServeCompletionsMixin, unittest.TestCase): - """Tests the `generate` version of the Completions API.""" + def test_generation_config_json_no_defaults_applied(self): + """When generation_config JSON is passed, serving defaults should NOT be applied.""" + from transformers import GenerationConfig - @classmethod - def setUpClass(cls): - """Starts a server for tests to connect to.""" - cls.port = 8001 - cls.server = Serve(port=cls.port, non_blocking=True) + custom = GenerationConfig(max_new_tokens=10) + result = self._make_handler()._build_generation_config( + {"generation_config": custom.to_json_string()}, GenerationConfig() + ) + # Should keep 10, not bump to 1024 + self.assertEqual(result.max_new_tokens, 10) - @classmethod - def tearDownClass(cls): - cls.server.kill_server() + def test_default_bumps_short_max_new_tokens(self): + from transformers import GenerationConfig - @slow - def test_tool_call(self): - """Tests that the tool call is correctly handled and that the payloads are correctly structured.""" - # TODO: move to the mixin when CB also supports tool calls - - request = { - # This model is a small model that's very eager to call tools - # TODO: this is a 4B model. 
Find a smaller model that's eager to call tools - "model": "Menlo/Jan-nano", - # The request should produce a tool call - "messages": [{"role": "user", "content": "Generate an image of a cat."}], - "stream": True, - "max_tokens": 50, - # Reproducibility - "temperature": 0.0, - # This tool is a copy from the tool in the original tiny-agents demo - "tools": [ - { - "function": { - "name": "flux1_schnell_infer", - "parameters": { - "type": "object", - "properties": { - "prompt": {"type": "string"}, - "seed": {"type": "number", "description": "numeric value between 0 and 2147483647"}, - "randomize_seed": {"type": "boolean", "default": True}, - "width": { - "type": "number", - "description": "numeric value between 256 and 2048", - "default": 1024, - }, - "height": { - "type": "number", - "description": "numeric value between 256 and 2048", - "default": 1024, - }, - "num_inference_steps": { - "type": "number", - "description": "numeric value between 1 and 16", - "default": 4, - }, - }, - }, - "description": "Generate an image using the Flux 1 Schnell Image Generator.", - }, - "type": "function", - } - ], - } - all_payloads = self.run_server(request) - - # The first payload should contain the role - roles = [payload.choices[0].delta.role for payload in all_payloads] - self.assertEqual(roles[0], "assistant") - self.assertTrue(all(role is None for role in roles[1:])) - - # All other payloads (except the last one) should be tool call related, for this specific request - contents = [payload.choices[0].delta.content for payload in all_payloads] - self.assertTrue(all(content is None for content in contents)) - - # The first tool call delta should contain the tool name. The other tool call deltas should contain the tool - # arguments. - tool_calls = [payload.choices[0].delta.tool_calls[0] for payload in all_payloads[1:-1]] - first_tool_call = tool_calls[0] - self.assertEqual(first_tool_call["function"]["name"], "flux1_schnell_infer") - self.assertEqual(first_tool_call["function"]["arguments"], None) - other_tool_calls = tool_calls[1:] - self.assertTrue(all(tool_call["function"]["name"] is None for tool_call in other_tool_calls)) - self.assertTrue(all(tool_call["function"]["arguments"] is not None for tool_call in other_tool_calls)) - - # Finally, the last payload should contain a finish reason - finish_reasons = [payload.choices[0].finish_reason for payload in all_payloads] - # TODO: I think the finish reason for a tool call is different? 
double check this - self.assertTrue(finish_reasons[-1] in ["stop", "length"]) - self.assertTrue(all(reason is None for reason in finish_reasons[:-1])) - - -def _get_scheduler(serve_command): - # Defensive navigation in case any layer is renamed in the future - cbm = getattr(serve_command, "running_continuous_batching_manager", None) - assert cbm is not None, "ServeCommand has no running_continuous_batching_manager" - bp = getattr(cbm, "batch_processor", None) - assert bp is not None, "running_continuous_batching_manager has no batch_processor" - sched = getattr(bp, "scheduler", None) - assert sched is not None, "batch_processor has no scheduler" - return sched - - -def _call_healthcheck(base_url: str): - response = None - retries = 10 - while retries > 0: - try: - response = httpx.get(f"{base_url}/health") - break - except httpx.NetworkError: - time.sleep(0.1) - retries -= 1 - return response + result = self._make_handler()._build_generation_config({}, GenerationConfig(max_new_tokens=20)) + self.assertEqual(result.max_new_tokens, 1024) + def test_user_max_tokens_overrides_default(self): + """User's max_tokens should win over the serving default.""" + from transformers import GenerationConfig -def _open_stream_and_cancel(base_url: str, request_id: str): - with httpx.Client() as s: - with s.stream( - "POST", - f"{base_url}/v1/chat/completions", - headers={"X-Request-ID": request_id}, - json={ - "model": "Qwen/Qwen2.5-0.5B-Instruct", - "stream": True, - "messages": [{"role": "user", "content": "Count slowly so I can cancel you."}], - }, - timeout=30, - ) as resp: - assert resp.status_code == 200 + result = self._make_handler()._build_generation_config({"max_tokens": 50}, GenerationConfig(max_new_tokens=20)) + self.assertEqual(result.max_new_tokens, 50) + + +@require_serve +class TestValidation(unittest.TestCase): + def _make_handler(self): + return ChatCompletionHandler(model_manager=MagicMock(), generation_state=GenerationState()) + + def test_valid_request_passes(self): + handler = self._make_handler() + # Should not raise + handler._validate_request({"model": "x", "messages": [{"role": "user", "content": "hi"}], "stream": True}) + + def test_unexpected_keys_rejected(self): + handler = self._make_handler() + with self.assertRaises(HTTPException) as ctx: + handler._validate_request({"model": "x", "messages": [], "bogus_field": True}) + self.assertEqual(ctx.exception.status_code, 422) + self.assertIn("bogus_field", ctx.exception.detail) + + def test_unsupported_fields_warns(self): + handler = self._make_handler() + with self.assertLogs("transformers", level="WARNING") as cm: + handler._validate_request({"model": "x", "messages": [], "audio": {}}) + self.assertTrue(any("audio" in msg for msg in cm.output)) + + +class TestModelManager(unittest.TestCase): + def test_process_model_name_adds_main(self): + self.assertEqual(ModelManager.process_model_name("org/model"), "org/model@main") + + def test_process_model_name_preserves_revision(self): + self.assertEqual(ModelManager.process_model_name("org/model@dev"), "org/model@dev") + + def test_quantization_config_4bit(self): + mm = ModelManager(quantization="bnb-4bit") + cfg = mm.get_quantization_config() + self.assertTrue(cfg.load_in_4bit) + + def test_quantization_config_8bit(self): + mm = ModelManager(quantization="bnb-8bit") + cfg = mm.get_quantization_config() + self.assertTrue(cfg.load_in_8bit) + + def test_quantization_config_none(self): + mm = ModelManager() + self.assertIsNone(mm.get_quantization_config()) - wait_for_n_chunks = 3 - for i, _ in 
enumerate(resp.iter_bytes(chunk_size=None)): - if i >= wait_for_n_chunks: - resp.close() - break +class TestTimedModel(unittest.TestCase): + def test_delete_model(self): + mock_model = MagicMock() + deleted = [] + timed = TimedModel( + mock_model, timeout_seconds=9999, processor=MagicMock(), on_unload=lambda: deleted.append(True) + ) + self.assertIsNotNone(timed.model) + timed.delete_model() + self.assertIsNone(timed.model) + self.assertEqual(len(deleted), 1) + + def test_timeout_zero_no_delete(self): + mock_model = MagicMock() + timed = TimedModel(mock_model, timeout_seconds=0, processor=MagicMock()) + timed._timeout_reached() + self.assertIsNotNone(timed.model) + timed._timer.cancel() + + +@require_serve +class TestChunkSSE(unittest.TestCase): + def _make_handler(self): + return ChatCompletionHandler(model_manager=MagicMock(), generation_state=GenerationState()) + + def test_build_chunk_sse_content(self): + handler = self._make_handler() + sse = handler._build_chunk_sse(request_id="req1", content="hi", model="m") + self.assertTrue(sse.startswith("data: ")) + self.assertTrue(sse.endswith("\n\n")) + parsed = json.loads(sse[len("data: ") :].strip()) + self.assertEqual(parsed["choices"][0]["delta"]["content"], "hi") + + def test_build_chunk_sse_role(self): + handler = self._make_handler() + sse = handler._build_chunk_sse(request_id="req1", role="assistant", model="m") + parsed = json.loads(sse[len("data: ") :].strip()) + self.assertEqual(parsed["choices"][0]["delta"]["role"], "assistant") + self.assertNotIn("content", parsed["choices"][0]["delta"]) + + def test_build_chunk_sse_finish_reason(self): + handler = self._make_handler() + sse = handler._build_chunk_sse(request_id="req1", finish_reason="stop", model="m") + parsed = json.loads(sse[len("data: ") :].strip()) + self.assertEqual(parsed["choices"][0]["finish_reason"], "stop") + + def test_chunk_to_sse_string_passthrough(self): + result = BaseHandler.chunk_to_sse("data: already formatted\n\n") + self.assertEqual(result, "data: already formatted\n\n") + + def test_chunk_to_sse_wraps_plain_string(self): + result = BaseHandler.chunk_to_sse("hello") + self.assertEqual(result, "data: hello\n\n") + + +QWEN_TOOL_FORMAT = {"start": "", "end": ""} + + +@require_serve +class TestToolParser(unittest.TestCase): + def test_detect_tool_format_qwen(self): + model = MagicMock() + model.config.architectures = ["Qwen2ForCausalLM"] + fmt = detect_tool_format(model) + self.assertEqual(fmt, QWEN_TOOL_FORMAT) + + def test_detect_tool_format_unsupported(self): + model = MagicMock() + model.config.architectures = ["LlamaForCausalLM"] + self.assertIsNone(detect_tool_format(model)) + + def test_parser_start_token(self): + parser = ToolCallParser(QWEN_TOOL_FORMAT) + result = parser.feed("") + self.assertIs(result, ToolCallParser.CONSUMED) + + def test_parser_end_token(self): + parser = ToolCallParser(QWEN_TOOL_FORMAT) + parser.feed("") + result = parser.feed("") + self.assertIs(result, ToolCallParser.CONSUMED) + + def test_parser_buffers_until_end(self): + parser = ToolCallParser(QWEN_TOOL_FORMAT) + parser.feed("") + # Intermediate tokens are buffered + result = parser.feed('{"name": "my_tool", "arguments": {"x": 1}}') + self.assertIs(result, ToolCallParser.CONSUMED) + # Tool call is emitted on end token + result = parser.feed("") + self.assertIsNot(result, ToolCallParser.CONSUMED) + self.assertEqual(result["name"], "my_tool") + + def test_parser_normal_text_returns_none(self): + parser = ToolCallParser(QWEN_TOOL_FORMAT) + result = parser.feed("Hello world") + 
self.assertIsNone(result) + + def test_parser_full_flow(self): + """Simulate a complete tool call token sequence.""" + + parser = ToolCallParser(QWEN_TOOL_FORMAT) + tool_calls = [] + + for token in [ + "", + '{"name": "get_weather",', + ' "arguments": {', + '"city": "Paris"', + "}}", + "\n", + "", + ]: + result = parser.feed(token) + if result is not None and result is not ToolCallParser.CONSUMED: + tool_calls.append(result) + + # Single tool call emitted on with both name and arguments + self.assertEqual(len(tool_calls), 1) + self.assertEqual(tool_calls[0]["name"], "get_weather") + self.assertIn("Paris", tool_calls[0]["arguments"]) + + def test_parse_tool_calls_from_text(self): + """Non-streaming tool call parsing from complete text.""" + + text = '\n{"name": "get_weather", "arguments": {"city": "Paris"}}\n' + calls = ToolCallParser.parse(text, QWEN_TOOL_FORMAT) + self.assertIsNotNone(calls) + self.assertEqual(len(calls), 1) + self.assertEqual(calls[0]["name"], "get_weather") + self.assertIn("Paris", calls[0]["arguments"]) + + def test_parse_tool_calls_no_tool_call(self): + """Non-streaming: normal text returns None.""" + + calls = ToolCallParser.parse("Hello, how can I help?", QWEN_TOOL_FORMAT) + self.assertIsNone(calls) + + def test_parse_multiple_tool_calls(self): + """Non-streaming: multiple tool calls in one response.""" + + text = ( + '\n{"name": "get_weather", "arguments": {"city": "Paris"}}\n\n' + '\n{"name": "get_weather", "arguments": {"city": "London"}}\n' + ) + calls = ToolCallParser.parse(text, QWEN_TOOL_FORMAT) + self.assertIsNotNone(calls) + self.assertEqual(len(calls), 2) + self.assertEqual(calls[0]["name"], "get_weather") + self.assertIn("Paris", calls[0]["arguments"]) + self.assertEqual(calls[1]["name"], "get_weather") + self.assertIn("London", calls[1]["arguments"]) + + def test_feed_multiple_tool_calls(self): + """Streaming: multiple tool calls emitted sequentially.""" + + parser = ToolCallParser(QWEN_TOOL_FORMAT) + tool_calls = [] + + tokens = [ + "", + '{"name": "get_weather", "arguments": {"city": "Paris"}}', + "", + "", + '{"name": "get_weather", "arguments": {"city": "London"}}', + "", + ] + for token in tokens: + result = parser.feed(token) + if result is not None and result is not ToolCallParser.CONSUMED: + tool_calls.append(result) + + self.assertEqual(len(tool_calls), 2) + self.assertEqual(tool_calls[0]["name"], "get_weather") + self.assertIn("Paris", tool_calls[0]["arguments"]) + self.assertEqual(tool_calls[1]["name"], "get_weather") + self.assertIn("London", tool_calls[1]["arguments"]) -@slow # server startup time is slow on our push CI -@require_openai -class ServeCompletionsContinuousBatchingIntegrationTest(ServeCompletionsMixin, unittest.TestCase): - """Tests the `continuous_batching` version of the Completions API.""" +@require_serve +class TestAppRoutes(unittest.TestCase): @classmethod def setUpClass(cls): - """Starts a server for tests to connect to.""" - cls.port = 8002 - cls.server = Serve( - port=cls.port, continuous_batching=True, attn_implementation="sdpa", default_seed=42, non_blocking=True + cls.model_manager = MagicMock(spec=ModelManager) + cls.model_manager.get_gen_models.return_value = [ + {"id": "test/model", "owned_by": "test", "object": "model", "created": 0} + ] + cls.chat_handler = MagicMock(spec=ChatCompletionHandler) + cls.response_handler = MagicMock(spec=ResponseHandler) + cls.transcription_handler = MagicMock(spec=TranscriptionHandler) + cls.app = build_server(cls.model_manager, cls.chat_handler, cls.response_handler, 
cls.transcription_handler) + cls.transport = httpx.ASGITransport(app=cls.app) + + async def _request(self, method: str, path: str, **kwargs) -> httpx.Response: + async with httpx.AsyncClient(transport=self.transport, base_url="http://test") as c: + return await c.request(method, path, **kwargs) + + def test_health(self): + resp = asyncio.get_event_loop().run_until_complete(self._request("GET", "/health")) + self.assertEqual(resp.status_code, 200) + self.assertEqual(resp.json(), {"status": "ok"}) + + def test_models_list(self): + resp = asyncio.get_event_loop().run_until_complete(self._request("GET", "/v1/models")) + self.assertEqual(resp.status_code, 200) + data = resp.json() + self.assertEqual(data["object"], "list") + self.assertEqual(len(data["data"]), 1) + + def test_request_id_generated(self): + resp = asyncio.get_event_loop().run_until_complete(self._request("GET", "/health")) + self.assertIn("x-request-id", resp.headers) + self.assertEqual(len(resp.headers["x-request-id"]), 36) # UUID length + + def test_request_id_passthrough(self): + resp = asyncio.get_event_loop().run_until_complete( + self._request("GET", "/health", headers={"x-request-id": "my-id"}) ) + self.assertEqual(resp.headers["x-request-id"], "my-id") + + +@slow +@require_serve +class TestChatCompletion(unittest.TestCase): + """Integration tests for /v1/chat/completions with a real model.""" + + MODEL = "Qwen/Qwen2.5-0.5B-Instruct" + + @classmethod + def setUpClass(cls): + cls.serve, port = _start_serve() + cls.base_url = f"http://localhost:{port}" + cls.client = OpenAI(base_url=f"{cls.base_url}/v1", api_key="unused") @classmethod def tearDownClass(cls): - cls.server.kill_server() + cls.serve.kill_server() - def test_full_request(self): - """Tests that an inference using the Responses API and Continuous Batching works""" + def test_non_streaming(self): + resp = self.client.chat.completions.create( + model=self.MODEL, messages=[{"role": "user", "content": "Say hello"}] + ) + self.assertIsNotNone(resp.choices[0].message.content) + self.assertIn(resp.choices[0].finish_reason, ("stop", "length")) + + def test_streaming(self): + text = "" + for chunk in self.client.chat.completions.create( + model=self.MODEL, messages=[{"role": "user", "content": "Say hello"}], stream=True + ): + if chunk.choices[0].delta.content: + text += chunk.choices[0].delta.content + self.assertTrue(len(text) > 0) - request = { - "model": "Qwen/Qwen2.5-0.5B-Instruct", - "messages": [ - {"role": "system", "content": "You are a sports assistant designed to craft sports programs."}, - {"role": "user", "content": "Tell me what you can do."}, + def test_early_return_due_to_length(self): + """When max_tokens is hit, finish_reason should be 'length'.""" + chunks = list( + self.client.chat.completions.create( + model=self.MODEL, + messages=[{"role": "user", "content": "Hello, how are you?"}], + stream=True, + max_tokens=3, + ) + ) + last = chunks[-1] + self.assertEqual(last.choices[0].finish_reason, "length") + + def test_continues_until_stop(self): + """When model stops naturally, finish_reason should be 'stop'.""" + chunks = list( + self.client.chat.completions.create( + model=self.MODEL, + messages=[{"role": "user", "content": 'Please only answer with "Hi."'}], + stream=True, + max_tokens=30, + ) + ) + last = chunks[-1] + self.assertEqual(last.choices[0].finish_reason, "stop") + + def test_stop_strings(self): + resp = self.client.chat.completions.create( + model=self.MODEL, messages=[{"role": "user", "content": "Count to 10"}], stop=["5"] + ) + 
self.assertNotIn("6", resp.choices[0].message.content) + + def test_multi_turn(self): + resp = self.client.chat.completions.create( + model=self.MODEL, + messages=[ + {"role": "user", "content": "My name is Alice"}, + {"role": "assistant", "content": "Nice to meet you!"}, + {"role": "user", "content": "What is my name?"}, ], - "stream": True, - "max_tokens": 30, - } - all_payloads = self.run_server(request) + ) + self.assertIn("Alice", resp.choices[0].message.content) - full_text = "" - for token in all_payloads: - if isinstance(token, ChatCompletionStreamOutput) and token.choices and len(token.choices) > 0: - content = token.choices[0].delta.get("content", "") - full_text += content if content is not None else "" + def test_multiple_models_on_demand(self): + """Load two different models via separate requests — both should work.""" + model_a = "Qwen/Qwen2.5-0.5B-Instruct" + model_b = "HuggingFaceTB/SmolLM2-135M-Instruct" + prompt = [{"role": "user", "content": "Say hello"}] - # Verify that the system prompt went through. - self.assertTrue( - full_text.startswith( - "I can assist you with a wide range of tasks, from answering questions to providing information on various sports topics." + resp_a = self.client.chat.completions.create(model=model_a, messages=prompt) + self.assertIn(model_a, resp_a.model) + self.assertIsNotNone(resp_a.choices[0].message.content) + + resp_b = self.client.chat.completions.create(model=model_b, messages=prompt) + self.assertIn(model_b, resp_b.model) + self.assertIsNotNone(resp_b.choices[0].message.content) + + def test_non_streaming_usage(self): + resp = self.client.chat.completions.create( + model=self.MODEL, messages=[{"role": "user", "content": "Say hello"}] + ) + self.assertIsNotNone(resp.usage) + self.assertGreater(resp.usage.prompt_tokens, 0) + self.assertGreater(resp.usage.completion_tokens, 0) + self.assertEqual(resp.usage.total_tokens, resp.usage.prompt_tokens + resp.usage.completion_tokens) + + def test_streaming_usage(self): + chunks = list( + self.client.chat.completions.create( + model=self.MODEL, + messages=[{"role": "user", "content": "Say hello"}], + stream=True, ) ) + # Last chunk should have usage + last = chunks[-1] + self.assertIsNotNone(last.usage) + self.assertGreater(last.usage.prompt_tokens, 0) + self.assertGreater(last.usage.completion_tokens, 0) + self.assertEqual(last.usage.total_tokens, last.usage.prompt_tokens + last.usage.completion_tokens) - def test_max_tokens_not_set_in_req(self): - request = { - "model": "Qwen/Qwen2.5-0.5B-Instruct", - "messages": [ - {"role": "system", "content": "You are a sports assistant designed to craft sports programs."}, - {"role": "user", "content": "Tell me what you can do."}, - ], - "stream": True, + def test_tool_call(self): + """Tool calls should be parsed and emitted as ChoiceDeltaToolCall objects.""" + # Qwen2.5-0.5B-Instruct supports tools (Qwen family) + tool_def = { + "function": { + "name": "get_weather", + "parameters": { + "type": "object", + "properties": {"city": {"type": "string"}}, + }, + "description": "Get the weather for a city.", + }, + "type": "function", } - all_payloads = self.run_server(request) + chunks = list( + self.client.chat.completions.create( + model=self.MODEL, + messages=[{"role": "user", "content": "What is the weather in Paris?"}], + stream=True, + max_tokens=50, + temperature=0.0, + tools=[tool_def], + ) + ) - full_text = "" - for token in all_payloads: - if isinstance(token, ChatCompletionStreamOutput) and token.choices and len(token.choices) > 0: - content = 
token.choices[0].delta.get("content", "") - full_text += content if content is not None else "" + # First chunk should have role="assistant" + self.assertEqual(chunks[0].choices[0].delta.role, "assistant") - # Verify that the system prompt went through. - self.assertTrue( - full_text.startswith( - "I can assist you with a wide range of tasks, from answering questions to providing information on various sports topics." + # Model should make a tool call for this prompt + tool_chunks = [c for c in chunks if c.choices[0].delta.tool_calls] + self.assertGreater(len(tool_chunks), 0, "Model did not produce a tool call") + + # First tool call delta should have the function name + first_tool = tool_chunks[0].choices[0].delta.tool_calls[0] + self.assertEqual(first_tool.function.name, "get_weather") + + # finish_reason should be "tool_calls" + last = chunks[-1] + self.assertEqual(last.choices[0].finish_reason, "tool_calls") + + # Arguments should be valid JSON with no trailing brace + args_json = first_tool.function.arguments + import json as json_mod + + parsed_args = json_mod.loads(args_json) + self.assertIsInstance(parsed_args, dict) + + def test_tool_call_non_streaming(self): + """Non-streaming tool calls should return tool_calls in the message.""" + tool_def = { + "function": { + "name": "get_weather", + "parameters": {"type": "object", "properties": {"city": {"type": "string"}}}, + "description": "Get the weather for a city.", + }, + "type": "function", + } + resp = self.client.chat.completions.create( + model=self.MODEL, + messages=[{"role": "user", "content": "What is the weather in Paris?"}], + stream=False, + max_tokens=50, + temperature=0.0, + tools=[tool_def], + ) + self.assertEqual(resp.choices[0].finish_reason, "tool_calls") + self.assertIsNotNone(resp.choices[0].message.tool_calls) + tc = resp.choices[0].message.tool_calls[0] + self.assertEqual(tc.function.name, "get_weather") + + import json as json_mod + + parsed_args = json_mod.loads(tc.function.arguments) + self.assertIsInstance(parsed_args, dict) + + def test_tool_call_multi(self): + """Model should be able to call multiple tools when asked.""" + tool_def = { + "function": { + "name": "get_weather", + "parameters": {"type": "object", "properties": {"city": {"type": "string"}}}, + "description": "Get the weather for a city.", + }, + "type": "function", + } + # Ask for two cities to encourage multiple tool calls + chunks = list( + self.client.chat.completions.create( + model=self.MODEL, + messages=[{"role": "user", "content": "What is the weather in Paris and London?"}], + stream=True, + max_tokens=100, + temperature=0.0, + tools=[tool_def], ) ) + tool_chunks = [c for c in chunks if c.choices[0].delta.tool_calls] + # Should have two tool calls — one for Paris, one for London + self.assertEqual(len(tool_chunks), 2, f"Expected 2 tool calls, got {len(tool_chunks)}") + cities = {tc.choices[0].delta.tool_calls[0].function.name for tc in tool_chunks} + self.assertEqual(cities, {"get_weather"}) + last = chunks[-1] + self.assertEqual(last.choices[0].finish_reason, "tool_calls") + + def test_concurrent_non_streaming(self): + """Two concurrent non-streaming requests should both complete without interference.""" + import concurrent.futures - def test_request_cancellation(self): - """Tests that a request can be cancelled.""" + prompts = [ + [{"role": "user", "content": "Say hello"}], + [{"role": "user", "content": "Say goodbye"}], + ] + results = [None, None] - base_url = f"http://127.0.0.1:{self.port}" - request_id = "test-cancel" + def 
request_in_thread(index): + client = OpenAI(base_url=f"{self.base_url}/v1", api_key="unused") + results[index] = client.chat.completions.create(model=self.MODEL, messages=prompts[index]) - # Ensure the server is up before sending a request - response = _call_healthcheck(base_url) - self.assertIsNotNone(response, "Failed to connect to the server health endpoint.") - self.assertEqual(response.status_code, 200) + with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool: + futures = [pool.submit(request_in_thread, i) for i in range(2)] + concurrent.futures.wait(futures) + for f in futures: + f.result() # re-raise exceptions - _open_stream_and_cancel(base_url, request_id) + for i in range(2): + self.assertIsNotNone(results[i]) + self.assertIsNotNone(results[i].choices[0].message.content) + self.assertTrue(len(results[i].choices[0].message.content) > 0) - scheduler = _get_scheduler(self.server) + def test_concurrent_streaming(self): + """Two concurrent streaming requests should both produce complete, non-empty output.""" + import concurrent.futures - # Because cancellation is non-blocking, poll for a short, bounded time. - deadline = time.time() + 8.0 # generous but still CI-friendly - last_seen = None - while time.time() < deadline: - is_cancelled = scheduler.request_is_cancelled(request_id) - if is_cancelled: - break - last_seen = time.time() - time.sleep(0.1) # don't spin the CPU + prompts = [ + [{"role": "user", "content": "Say hello"}], + [{"role": "user", "content": "Say goodbye"}], + ] + results = [None, None] - is_cancelled = scheduler.request_is_cancelled(request_id) - self.assertTrue( - is_cancelled, - f"Request {request_id} still present in scheduler after cancellation " - f"(last seen at {last_seen}). Check cancellation propagation.", - ) + def stream_in_thread(index): + client = OpenAI(base_url=f"{self.base_url}/v1", api_key="unused") + text = "" + for chunk in client.chat.completions.create(model=self.MODEL, messages=prompts[index], stream=True): + if chunk.choices[0].delta.content: + text += chunk.choices[0].delta.content + results[index] = text - def test_non_streaming_response_json_format(self): - """ - Tests that non-streaming continuous batching responses return proper JSON objects, - not double-encoded JSON strings (regression test for JSON serialization fix). 
- """ - client = OpenAI(base_url=f"http://localhost:{self.port}/v1", api_key="") - - # Make a non-streaming request - response = client.chat.completions.create( - model="Qwen/Qwen2.5-0.5B-Instruct", - messages=[{"role": "user", "content": "Say hello"}], - stream=False, - max_tokens=5, + with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool: + futures = [pool.submit(stream_in_thread, i) for i in range(2)] + concurrent.futures.wait(futures) + for f in futures: + f.result() + + for i in range(2): + self.assertIsNotNone(results[i]) + self.assertTrue(len(results[i]) > 0, f"Request {i} produced empty output") + + def test_request_cancellation(self): + """Closing a stream early doesn't crash and the server stays healthy.""" + + with httpx.stream( + "POST", + f"{self.base_url}/v1/chat/completions", + json={ + "model": self.MODEL, + "stream": True, + "messages": [{"role": "user", "content": "Count slowly so I can cancel you."}], + "max_tokens": 500, + }, + timeout=30, + ) as resp: + self.assertEqual(resp.status_code, 200) + chunks_read = 0 + for _ in resp.iter_lines(): + chunks_read += 1 + if chunks_read >= 3: + break + + # Server should still be healthy and serve subsequent requests + resp = self.client.chat.completions.create( + model=self.MODEL, + messages=[{"role": "user", "content": "Say hi"}], + max_tokens=10, + ) + self.assertIsNotNone(resp.choices[0].message.content) + + +@require_serve +class TestResponseInputConversion(unittest.TestCase): + def _make_handler(self): + return ResponseHandler(model_manager=MagicMock(), generation_state=GenerationState()) + + def test_string_input(self): + handler = self._make_handler() + msgs = handler._input_to_messages({"input": "Hello"}) + self.assertEqual(msgs, [{"role": "user", "content": "Hello"}]) + + def test_string_input_with_instructions(self): + handler = self._make_handler() + msgs = handler._input_to_messages({"input": "Hello", "instructions": "Be brief"}) + self.assertEqual(len(msgs), 2) + self.assertEqual(msgs[0], {"role": "system", "content": "Be brief"}) + self.assertEqual(msgs[1], {"role": "user", "content": "Hello"}) + + def test_list_input(self): + handler = self._make_handler() + msgs = handler._input_to_messages( + {"input": [{"role": "user", "content": "A"}, {"role": "assistant", "content": "B"}]} + ) + self.assertEqual(len(msgs), 2) + self.assertEqual(msgs[0]["content"], "A") + + def test_list_input_with_instructions_prepends_system(self): + handler = self._make_handler() + msgs = handler._input_to_messages({"input": [{"role": "user", "content": "Hi"}], "instructions": "Be helpful"}) + self.assertEqual(len(msgs), 2) + self.assertEqual(msgs[0]["role"], "system") + self.assertEqual(msgs[0]["content"], "Be helpful") + + def test_list_input_with_instructions_replaces_existing_system(self): + handler = self._make_handler() + msgs = handler._input_to_messages( + {"input": [{"role": "system", "content": "Old"}, {"role": "user", "content": "Hi"}], "instructions": "New"} ) + self.assertEqual(len(msgs), 2) + self.assertEqual(msgs[0]["content"], "New") - # Verify response is a proper ChatCompletion object (not a string) - self.assertIsNotNone(response) - self.assertIsNotNone(response.id) - self.assertIsNotNone(response.choices) - self.assertEqual(len(response.choices), 1) + def test_dict_input(self): + handler = self._make_handler() + msgs = handler._input_to_messages({"input": {"role": "user", "content": "Test"}}) + self.assertEqual(msgs, [{"role": "user", "content": "Test"}]) - # Verify the choice has proper structure - choice = 
response.choices[0]
-        self.assertIsNotNone(choice.message)
-        self.assertIsNotNone(choice.message.content)
-        self.assertEqual(choice.message.role, "assistant")
-
-        # Verify content is a string, not a serialized JSON
-        content = choice.message.content
-        self.assertIsInstance(content, str)
-        self.assertTrue(len(content) > 0)
+
+@require_serve
+class TestResponseInputConversion(unittest.TestCase):
+    pass  # (placeholder removed below; see original class above)
+
+@require_serve
+class TestResponseValidation(unittest.TestCase):
+    def _make_handler(self):
+        return ResponseHandler(model_manager=MagicMock(), generation_state=GenerationState())
+
+    def test_unsupported_fields_warns(self):
+        handler = self._make_handler()
+        with self.assertLogs("transformers", level="WARNING") as cm:
+            handler._validate_request({"model": "x", "input": "hi", "previous_response_id": "abc"})
+        self.assertTrue(any("previous_response_id" in msg for msg in cm.output))
-
-@require_openai
-class ServeResponsesMixin:
-    """
-    Mixin class for the Completions API tests, to seamlessly replicate tests across the two versions of the API
-    (`generate` and `continuous_batching`).
-    """
+
+    def test_valid_request_passes(self):
+        handler = self._make_handler()
+        # Should not raise
+        handler._validate_request({"model": "x", "input": "hi"})
-
-    @retry
-    def run_server(self, request):
-        client = OpenAI(base_url=f"http://localhost:{self.port}/v1", api_key="")
-        stream = client.responses.create(**request)
-        all_payloads = []
-        for payload in stream:
-            all_payloads.append(payload)
+
+
+@require_serve
+class TestResponseGenerationConfig(unittest.TestCase):
+    def _make_handler(self):
+        return ResponseHandler(model_manager=MagicMock(), generation_state=GenerationState())
-
-        return all_payloads
+
+    def test_max_output_tokens(self):
+        from transformers import GenerationConfig
-
-    def test_request(self):
-        """Tests that an inference using the Responses API works"""
+
+        result = self._make_handler()._build_generation_config({"max_output_tokens": 42}, GenerationConfig())
+        self.assertEqual(result.max_new_tokens, 42)
-
-        request = {
-            "model": "Qwen/Qwen2.5-0.5B-Instruct",
-            "instructions": "You are a helpful assistant.",
-            "input": "Hello!",
-            "stream": True,
-            "max_output_tokens": 1,
-        }
-        all_payloads = self.run_server(request)
+
+    def test_default_bumps_short_max_new_tokens(self):
+        from transformers import GenerationConfig
-
-        # Allow variable number of delta events depending on tokenizer/streamer behavior
-        self.assertGreaterEqual(len(all_payloads), 8)
+
+        result = self._make_handler()._build_generation_config({}, GenerationConfig(max_new_tokens=20))
+        self.assertEqual(result.max_new_tokens, 1024)
-
-        # Start markers
-        self.assertIsInstance(all_payloads[0], ResponseCreatedEvent)
-        self.assertIsInstance(all_payloads[1], ResponseInProgressEvent)
-        self.assertIsInstance(all_payloads[2], ResponseOutputItemAddedEvent)
-        self.assertIsInstance(all_payloads[3], ResponseContentPartAddedEvent)
-
-        # At least one delta event during streaming
-        self.assertTrue(any(isinstance(p, ResponseTextDeltaEvent) for p in all_payloads[4:-4]))
+
+
+@require_serve
+class TestResponseUsage(unittest.TestCase):
+    def test_compute_usage(self):
+        usage = compute_usage(input_tokens=100, output_tokens=50)
+        self.assertEqual(usage.input_tokens, 100)
+        self.assertEqual(usage.output_tokens, 50)
+        self.assertEqual(usage.total_tokens, 150)
+        self.assertEqual(usage.input_tokens_details.cached_tokens, 0)
+        self.assertEqual(usage.output_tokens_details.reasoning_tokens, 0)
-
-        # Closing markers
-        self.assertIsInstance(all_payloads[-4], ResponseTextDoneEvent)
-        self.assertIsInstance(all_payloads[-3], ResponseContentPartDoneEvent)
-        
self.assertIsInstance(all_payloads[-2], ResponseOutputItemDoneEvent) - self.assertIsInstance(all_payloads[-1], ResponseCompletedEvent) + def test_usage_in_completed_response(self): + """Usage should serialize correctly inside a Response.""" + + usage = compute_usage(10, 5) + response = Response( + id="resp_test", + created_at=0, + status="completed", + model="test", + output=[], + object="response", + tools=[], + parallel_tool_calls=False, + tool_choice="auto", + usage=usage, + ) + dumped = response.model_dump(exclude_none=True) + self.assertEqual(dumped["usage"]["input_tokens"], 10) + self.assertEqual(dumped["usage"]["output_tokens"], 5) + self.assertEqual(dumped["usage"]["total_tokens"], 15) + + +@require_serve +class TestResponseSSEFormat(unittest.TestCase): + def test_sse_format(self): + event = ResponseCreatedEvent( + type="response.created", + sequence_number=0, + response=Response( + id="resp_test", + created_at=0, + status="queued", + model="test", + text={"format": {"type": "text"}}, + object="response", + tools=[], + output=[], + parallel_tool_calls=False, + tool_choice="auto", + ), + ) + result = BaseHandler.chunk_to_sse(event) + self.assertTrue(result.startswith("data: ")) + self.assertTrue(result.endswith("\n\n")) + parsed = json.loads(result[len("data: ") :].strip()) + self.assertEqual(parsed["type"], "response.created") + self.assertEqual(parsed["response"]["status"], "queued") - # TODO: one test for each request flag, to confirm it is working as expected - # TODO: speed-based test to confirm that KV cache is working across requests +@slow +@require_serve +class TestResponsesIntegration(unittest.TestCase): + """Integration tests for /v1/responses with a real model.""" -@slow # server startup time is slow on our push CI -@require_openai -class ServeResponsesIntegrationTest(ServeResponsesMixin, unittest.TestCase): - """Tests the Responses API.""" + MODEL = "Qwen/Qwen2.5-0.5B-Instruct" @classmethod def setUpClass(cls): - """Starts a server for tests to connect to.""" - cls.port = 8003 - cls.server = Serve(port=cls.port, default_seed=42, non_blocking=True) + cls.serve, port = _start_serve() + cls.base_url = f"http://localhost:{port}" + cls.client = OpenAI(base_url=f"{cls.base_url}/v1", api_key="unused") @classmethod def tearDownClass(cls): - cls.server.kill_server() - - @slow - def test_full_request(self): - """Tests that an inference using the Responses API works""" - - request = { - "model": "Qwen/Qwen2.5-0.5B-Instruct", - "instructions": "You are a sports assistant designed to craft sports programs.", - "input": "Tell me what you can do.", - "stream": True, - "max_output_tokens": 30, - # Disable sampling for deterministic output - "temperature": 0, + cls.serve.kill_server() + + def test_streaming(self): + events = list( + self.client.responses.create( + model=self.MODEL, + input="Say hello", + stream=True, + max_output_tokens=1, + ) + ) + # At least 8 events: created, in_progress, output_item_added, content_part_added, + # delta(s), text_done, content_part_done, output_item_done, completed + self.assertGreaterEqual(len(events), 8) + + # Start markers (fixed order) + self.assertEqual(events[0].type, "response.created") + self.assertEqual(events[1].type, "response.in_progress") + self.assertEqual(events[2].type, "response.output_item.added") + self.assertEqual(events[3].type, "response.content_part.added") + + # At least one delta + self.assertTrue(any(e.type == "response.output_text.delta" for e in events[4:-4])) + + # Closing markers (fixed order from the end) + 
self.assertEqual(events[-4].type, "response.output_text.done") + self.assertEqual(events[-3].type, "response.content_part.done") + self.assertEqual(events[-2].type, "response.output_item.done") + self.assertEqual(events[-1].type, "response.completed") + + def test_non_streaming(self): + resp = self.client.responses.create( + model=self.MODEL, + input="Say hello", + stream=False, + ) + self.assertEqual(resp.status, "completed") + self.assertTrue(len(resp.output) > 0) + self.assertTrue(len(resp.output[0].content[0].text) > 0) + + def test_non_streaming_usage(self): + resp = self.client.responses.create( + model=self.MODEL, + input="Say hello", + stream=False, + ) + self.assertIsNotNone(resp.usage) + self.assertGreater(resp.usage.input_tokens, 0) + self.assertGreater(resp.usage.output_tokens, 0) + self.assertEqual(resp.usage.total_tokens, resp.usage.input_tokens + resp.usage.output_tokens) + + def test_streaming_usage(self): + events = list( + self.client.responses.create( + model=self.MODEL, + input="Say hello", + stream=True, + max_output_tokens=5, + ) + ) + completed = events[-1] + self.assertEqual(completed.type, "response.completed") + usage = completed.response.usage + self.assertIsNotNone(usage) + self.assertGreater(usage.input_tokens, 0) + self.assertGreater(usage.output_tokens, 0) + self.assertEqual(usage.total_tokens, usage.input_tokens + usage.output_tokens) + + def test_tool_call_streaming(self): + """Streaming responses with tools should emit function_call events.""" + tool_def = { + "function": { + "name": "get_weather", + "parameters": {"type": "object", "properties": {"city": {"type": "string"}}}, + "description": "Get the weather for a city.", + }, + "type": "function", + } + events = list( + self.client.responses.create( + model=self.MODEL, + input="What is the weather in Paris?", + stream=True, + max_output_tokens=50, + tools=[tool_def], + ) + ) + types = [e.type for e in events] + self.assertIn("response.created", types) + self.assertIn("response.completed", types) + + # Should have function call events + self.assertIn("response.output_item.added", types) + self.assertIn("response.function_call_arguments.done", types) + + # Check the arguments done event + args_done = [e for e in events if e.type == "response.function_call_arguments.done"] + self.assertGreater(len(args_done), 0) + self.assertEqual(args_done[0].name, "get_weather") + + import json as json_mod + + parsed = json_mod.loads(args_done[0].arguments) + self.assertIsInstance(parsed, dict) + + def test_tool_call_non_streaming(self): + """Non-streaming responses with tools should include function_call output items.""" + tool_def = { + "function": { + "name": "get_weather", + "parameters": {"type": "object", "properties": {"city": {"type": "string"}}}, + "description": "Get the weather for a city.", + }, + "type": "function", } - all_payloads = self.run_server(request) - - full_text = "" - for token in all_payloads: - if isinstance(token, ResponseTextDeltaEvent): - full_text += token.delta - - # Verify that the system prompt went through. - # With deterministic decoding, exact wording can still vary across versions. - # Assert non-empty output and that it references sports. 
- self.assertTrue(len(full_text) > 0) - self.assertIn("sports", full_text.lower()) - - @slow - def test_non_streaming_request(self): - """Tests that an inference using the Responses API with stream=False returns a single Response payload.""" - from openai import OpenAI - from openai.types.responses import Response as OpenAIResponse - - client = OpenAI(base_url=f"http://localhost:{self.port}/v1", api_key="") - resp = client.responses.create( - model="Qwen/Qwen2.5-0.5B-Instruct", - instructions="You are a helpful assistant.", - input="Hello!", + resp = self.client.responses.create( + model=self.MODEL, + input="What is the weather in Paris?", stream=False, - max_output_tokens=5, + max_output_tokens=50, + tools=[tool_def], ) + self.assertEqual(resp.status, "completed") + + # Should have at least message + function_call in output + self.assertGreater(len(resp.output), 1) + fc_items = [o for o in resp.output if o.type == "function_call"] + self.assertGreater(len(fc_items), 0) + self.assertEqual(fc_items[0].name, "get_weather") + + import json as json_mod - # Should be a single Response object with completed status and one output item containing text - self.assertIsInstance(resp, OpenAIResponse) + parsed = json_mod.loads(fc_items[0].arguments) + self.assertIsInstance(parsed, dict) + + def test_tool_call_multi(self): + """Model should produce multiple tool calls when asked about two cities.""" + tool_def = { + "function": { + "name": "get_weather", + "parameters": {"type": "object", "properties": {"city": {"type": "string"}}}, + "description": "Get the weather for a city.", + }, + "type": "function", + } + events = list( + self.client.responses.create( + model=self.MODEL, + input="What is the weather in Paris and London?", + stream=True, + max_output_tokens=100, + tools=[tool_def], + ) + ) + args_done = [e for e in events if e.type == "response.function_call_arguments.done"] + self.assertEqual(len(args_done), 2, f"Expected 2 tool calls, got {len(args_done)}") + self.assertEqual(events[-1].type, "response.completed") + + def test_multi_turn(self): + """Multi-turn conversation via list input.""" + resp = self.client.responses.create( + model=self.MODEL, + input=[ + {"role": "user", "content": "My name is Alice"}, + {"role": "assistant", "content": "Nice to meet you!"}, + {"role": "user", "content": "What is my name?"}, + ], + stream=False, + ) self.assertEqual(resp.status, "completed") - self.assertTrue(len(resp.output) >= 1) - first_item = resp.output[0] - self.assertEqual(first_item.type, "message") - self.assertEqual(first_item.status, "completed") - self.assertTrue(len(first_item.content) >= 1) - first_part = first_item.content[0] - self.assertEqual(first_part.type, "output_text") - self.assertIsInstance(first_part.text, str) + self.assertIn("Alice", resp.output[0].content[0].text) + def test_concurrent_non_streaming(self): + """Two concurrent non-streaming responses requests should both complete.""" + import concurrent.futures -class ServeInfrastructureTest(unittest.TestCase): - @classmethod - def setUpClass(cls): - cls.port = 8042 - thread = Thread(target=Serve, kwargs={"port": cls.port}) - thread.daemon = True - thread.start() - - def test_healthcheck(self): - """Tests that the healthcheck endpoint works.""" - response = _call_healthcheck(f"http://localhost:{self.port}") - self.assertIsNotNone(response, "Failed to connect to the server health endpoint.") - self.assertEqual(response.status_code, 200) - self.assertEqual(response.json(), {"status": "ok"}) + inputs = ["Say hello", "Say goodbye"] 
+ results = [None, None] + + def request_in_thread(index): + client = OpenAI(base_url=f"{self.base_url}/v1", api_key="unused") + results[index] = client.responses.create(model=self.MODEL, input=inputs[index], stream=False) + + with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool: + futures = [pool.submit(request_in_thread, i) for i in range(2)] + concurrent.futures.wait(futures) + for f in futures: + f.result() + + for i in range(2): + self.assertIsNotNone(results[i]) + self.assertEqual(results[i].status, "completed") + self.assertTrue(len(results[i].output[0].content[0].text) > 0) + + def test_concurrent_streaming(self): + """Two concurrent streaming responses requests should both produce complete event streams.""" + import concurrent.futures + + inputs = ["Say hello", "Say goodbye"] + results = [None, None] + def stream_in_thread(index): + client = OpenAI(base_url=f"{self.base_url}/v1", api_key="unused") + results[index] = list(client.responses.create(model=self.MODEL, input=inputs[index], stream=True)) -def parse_sse_events(response): + with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool: + futures = [pool.submit(stream_in_thread, i) for i in range(2)] + concurrent.futures.wait(futures) + for f in futures: + f.result() + + for i in range(2): + types = [e.type for e in results[i]] + self.assertIn("response.created", types, f"Request {i} missing created event") + self.assertIn("response.output_text.delta", types, f"Request {i} missing delta events") + self.assertIn("response.completed", types, f"Request {i} missing completed event") + + +def _parse_sse_events(response): """Parse SSE lines from a streaming httpx response into a list of dicts.""" events = [] for line in response.iter_lines(): - if not line: + if not line or not line.startswith("data: "): continue - if line.startswith("data: "): - events.append(json.loads(line[6:])) + events.append(json.loads(line[len("data: ") :])) return events @slow -@require_openai -class ServeLoadModelIntegrationTest(unittest.TestCase): - """Tests the /load_model SSE endpoint.""" +@require_serve +class TestLoadModel(unittest.TestCase): + """Integration tests for POST /load_model SSE endpoint.""" + + MODEL = "Qwen/Qwen2.5-0.5B-Instruct" @classmethod def setUpClass(cls): - cls.port = 8043 - cls.server = Serve(port=cls.port, non_blocking=True) - cls.base_url = f"http://localhost:{cls.port}" - # Wait for the server to be ready - response = _call_healthcheck(cls.base_url) - assert response is not None and response.status_code == 200 + cls.serve, port = _start_serve() + cls.base_url = f"http://localhost:{port}" @classmethod def tearDownClass(cls): - cls.server.kill_server() + cls.serve.kill_server() def setUp(self): - # Clear the in-memory model cache so each test starts fresh - self.server.reset_loaded_models() + # Clear model cache so each test starts fresh + self.serve.reset_loaded_models() def _load_model(self, model: str): - with httpx.Client(timeout=120) as client: - with client.stream("POST", f"{self.base_url}/load_model", json={"model": model}) as response: - events = parse_sse_events(response) - return response, events + with httpx.stream("POST", f"{self.base_url}/load_model", json={"model": model}, timeout=120) as resp: + events = _parse_sse_events(resp) + return resp, events def test_load_model_fresh(self): - """POST /load_model with a valid model returns SSE events ending with ready.""" - response, events = self._load_model("Qwen/Qwen2.5-0.5B-Instruct") + """POST /load_model returns SSE events ending with ready.""" + 
response, events = self._load_model(self.MODEL) self.assertEqual(response.status_code, 200) - self.assertIn("text/event-stream", response.headers.get("content-type", "")) - # Extract stages from loading events stages = [e["stage"] for e in events if e["status"] == "loading" and "stage" in e] self.assertIn("processor", stages) - self.assertIn("config", stages) self.assertIn("weights", stages) - # Stages must appear in the correct order - stage_indices = {stage: i for i, stage in enumerate(stages) if stage in ("processor", "config", "weights")} - self.assertLess(stage_indices["processor"], stage_indices["config"]) - self.assertLess(stage_indices["config"], stage_indices["weights"]) - - # Last event is ready with cached: false last = events[-1] self.assertEqual(last["status"], "ready") self.assertFalse(last["cached"]) - # Every event has status and model for event in events: self.assertIn("status", event) self.assertIn("model", event) def test_load_model_cached(self): - """Loading a model that is already in memory returns a single ready event with cached: true.""" - # First load to ensure the model is in memory - self._load_model("Qwen/Qwen2.5-0.5B-Instruct") + """Loading an already-loaded model returns a single ready event with cached: true.""" + self._load_model(self.MODEL) - # Second load should be cached - _, events = self._load_model("Qwen/Qwen2.5-0.5B-Instruct") + _, events = self._load_model(self.MODEL) ready_events = [e for e in events if e["status"] == "ready"] self.assertEqual(len(ready_events), 1) self.assertTrue(ready_events[0]["cached"]) - # No loading events should be present loading_events = [e for e in events if e["status"] == "loading"] self.assertEqual(len(loading_events), 0) @@ -987,18 +1311,18 @@ def test_load_model_error(self): _, events = self._load_model("nonexistent/model-that-does-not-exist") error_events = [e for e in events if e["status"] == "error"] - self.assertGreaterEqual(len(error_events), 1, "Expected at least one error event") + self.assertGreaterEqual(len(error_events), 1) self.assertIn("message", error_events[0]) def test_load_model_missing_field(self): """POST /load_model with no model field returns 422.""" - with httpx.Client(timeout=30) as client: - response = client.post(f"{self.base_url}/load_model", json={}) - self.assertEqual(response.status_code, 422) + + response = httpx.post(f"{self.base_url}/load_model", json={}, timeout=30) + self.assertEqual(response.status_code, 422) def test_load_model_event_schema(self): - """Every event in a load_model stream conforms to the expected schema.""" - _, events = self._load_model("Qwen/Qwen2.5-0.5B-Instruct") + """Every event conforms to the expected schema.""" + _, events = self._load_model(self.MODEL) for event in events: self.assertIsInstance(event["status"], str) @@ -1006,7 +1330,6 @@ def test_load_model_event_schema(self): if event["status"] == "loading": self.assertIn("stage", event) - if event["stage"] in ("download", "weights") and "progress" in event: progress = event["progress"] self.assertIn("current", progress) @@ -1018,12 +1341,10 @@ def test_load_model_event_schema(self): self.assertIsInstance(event["cached"], bool) def test_load_model_stage_ordering(self): - """Stages in loading events follow the expected order.""" - _, events = self._load_model("Qwen/Qwen2.5-0.5B-Instruct") + """Stages appear in the expected order.""" + _, events = self._load_model(self.MODEL) stages = [e["stage"] for e in events if e["status"] == "loading" and "stage" in e] - - # Deduplicate while preserving order (stages repeat 
for progress ticks) seen = set() unique_stages = [] for s in stages: @@ -1032,29 +1353,23 @@ def test_load_model_stage_ordering(self): unique_stages.append(s) expected_order = ["processor", "config", "download", "weights"] - # Filter expected_order to only stages that are actually present expected_present = [s for s in expected_order if s in unique_stages] - self.assertEqual(unique_stages, expected_present, "Stages appeared out of order") def test_concurrent_load_same_model(self): - """Two concurrent /load_model requests for the same model should both receive progress events - and a final ready event, but the model should only be loaded once.""" + """Two concurrent /load_model requests both get events and a ready event.""" import concurrent.futures - model = "Qwen/Qwen2.5-0.5B-Instruct" results = [None, None] def load_in_thread(index): - with httpx.Client(timeout=120) as client: - with client.stream("POST", f"{self.base_url}/load_model", json={"model": model}) as response: - events = parse_sse_events(response) - results[index] = (response.status_code, events) + with httpx.stream("POST", f"{self.base_url}/load_model", json={"model": self.MODEL}, timeout=120) as resp: + events = _parse_sse_events(resp) + results[index] = (resp.status_code, events) with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool: futures = [pool.submit(load_in_thread, i) for i in range(2)] concurrent.futures.wait(futures) - # Re-raise any exceptions from threads for f in futures: f.result() @@ -1062,28 +1377,517 @@ def load_in_thread(index): status_code, events = results[i] self.assertEqual(status_code, 200, f"Caller {i} got non-200 status") self.assertTrue(len(events) > 0, f"Caller {i} received no events") - ready_events = [e for e in events if e["status"] == "ready"] self.assertEqual(len(ready_events), 1, f"Caller {i} should get exactly one ready event") - self.assertIn("model", ready_events[0]) - def test_concurrent_load_second_caller_gets_cached_if_first_finishes(self): - """If the first /load_model finishes before the second arrives, - the second caller should get a cached response.""" - model = "Qwen/Qwen2.5-0.5B-Instruct" - - # First load — blocks until complete - _, events1 = self._load_model(model) + def test_concurrent_load_second_caller_gets_cached(self): + """If the first /load_model finishes before the second, the second gets cached: true.""" + _, events1 = self._load_model(self.MODEL) ready1 = [e for e in events1 if e["status"] == "ready"] self.assertEqual(len(ready1), 1) self.assertFalse(ready1[0]["cached"]) - # Second load — model is now in memory - _, events2 = self._load_model(model) + _, events2 = self._load_model(self.MODEL) ready2 = [e for e in events2 if e["status"] == "ready"] self.assertEqual(len(ready2), 1) self.assertTrue(ready2[0]["cached"]) - # No loading events on the cached path loading2 = [e for e in events2 if e["status"] == "loading"] self.assertEqual(len(loading2), 0) + + def test_load_model_weights_progress_complete(self): + """Weights progress should go from 1 to total, with total matching across events.""" + _, events = self._load_model(self.MODEL) + + weights_events = [e for e in events if e.get("stage") == "weights" and "progress" in e] + self.assertGreater(len(weights_events), 0, "No weights progress events emitted") + + # All events should have the same total + totals = {e["progress"]["total"] for e in weights_events} + self.assertEqual(len(totals), 1, f"Inconsistent totals: {totals}") + total = totals.pop() + self.assertIsNotNone(total) + self.assertGreater(total, 0) 
+
+        # First should be 1, last should be total
+        self.assertEqual(weights_events[0]["progress"]["current"], 1)
+        self.assertEqual(weights_events[-1]["progress"]["current"], total)
+
+        # Progress should be monotonically increasing
+        currents = [e["progress"]["current"] for e in weights_events]
+        self.assertEqual(currents, sorted(currents))
+
+    def test_load_model_exactly_one_ready(self):
+        """A fresh load should produce exactly one ready event as the last event."""
+        _, events = self._load_model(self.MODEL)
+
+        ready_events = [e for e in events if e["status"] == "ready"]
+        self.assertEqual(len(ready_events), 1)
+        self.assertEqual(events[-1]["status"], "ready")
+
+    def test_load_model_usable_after_load(self):
+        """After /load_model completes, the model should be usable for inference."""
+        self._load_model(self.MODEL)
+
+        client = OpenAI(base_url=f"{self.base_url}/v1", api_key="unused")
+        resp = client.chat.completions.create(
+            model=self.MODEL,
+            messages=[{"role": "user", "content": "Say hi"}],
+            max_tokens=5,
+        )
+        self.assertIsNotNone(resp.choices[0].message.content)
+        self.assertTrue(len(resp.choices[0].message.content) > 0)
+
+    def test_load_model_model_field_matches(self):
+        """The model field in every event should match the canonical model ID."""
+        _, events = self._load_model(self.MODEL)
+
+        for event in events:
+            self.assertTrue(
+                event["model"].startswith(self.MODEL),
+                f"Event model '{event['model']}' doesn't match '{self.MODEL}'",
+            )
+
+
+# Real image URL for VLM tests (person + dog on a beach)
+_DOG_IMAGE_URL = "https://qianwen-res.oss-accelerate-overseas.aliyuncs.com/Qwen2-VL/demo_small.jpg"
+
+
+@slow
+@require_vision
+@require_serve
+class TestVLM(unittest.TestCase):
+    """Integration tests for VLM (vision-language model) support.
Requires torchvision.""" + + MODEL = "HuggingFaceTB/SmolVLM-256M-Instruct" + + @classmethod + def setUpClass(cls): + cls.serve, port = _start_serve() + cls.base_url = f"http://localhost:{port}" + cls.client = OpenAI(base_url=f"{cls.base_url}/v1", api_key="unused") + + @classmethod + def tearDownClass(cls): + cls.serve.kill_server() + + def test_chat_completion_with_image(self): + """Chat completions should accept image_url content and produce a meaningful response.""" + resp = self.client.chat.completions.create( + model=self.MODEL, + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": "What do you see in this image?"}, + {"type": "image_url", "image_url": {"url": _DOG_IMAGE_URL}}, + ], + } + ], + max_tokens=50, + ) + text = resp.choices[0].message.content + self.assertIsNotNone(text) + self.assertTrue( + any(word in text.lower() for word in ["dog", "beach", "person"]), + f"Expected dog/beach/person in response, got: {text}", + ) + + def test_responses_with_image(self): + """Responses API should accept image_url content and produce a meaningful response.""" + resp = self.client.responses.create( + model=self.MODEL, + input=[ + { + "role": "user", + "content": [ + {"type": "text", "text": "What do you see in this image?"}, + {"type": "image_url", "image_url": {"url": _DOG_IMAGE_URL}}, + ], + } + ], + stream=False, + max_output_tokens=50, + ) + self.assertEqual(resp.status, "completed") + text = resp.output[0].content[0].text + self.assertTrue( + any(word in text.lower() for word in ["dog", "beach", "person"]), + f"Expected dog/beach/person in response, got: {text}", + ) + + +@slow +@require_librosa +@require_multipart +@require_serve +class TestTranscription(unittest.TestCase): + """Integration tests for POST /v1/audio/transcriptions with whisper-tiny.""" + + MODEL = "openai/whisper-tiny" + + @classmethod + def setUpClass(cls): + cls.serve, port = _start_serve() + cls.base_url = f"http://localhost:{port}" + + @classmethod + def tearDownClass(cls): + cls.serve.kill_server() + + @classmethod + def _get_audio_bytes(cls): + """Download the MLK 'I have a dream' speech sample from HF Hub.""" + if not hasattr(cls, "_audio_bytes"): + from huggingface_hub import hf_hub_download + + path = hf_hub_download("Narsil/asr_dummy", "mlk.flac", repo_type="dataset") + with open(path, "rb") as f: + cls._audio_bytes = f.read() + return cls._audio_bytes + + def test_transcription_returns_text(self): + """POST /v1/audio/transcriptions with real speech returns meaningful transcription.""" + + audio_bytes = self._get_audio_bytes() + resp = httpx.post( + f"{self.base_url}/v1/audio/transcriptions", + files={"file": ("mlk.flac", audio_bytes, "audio/flac")}, + data={"model": self.MODEL}, + timeout=120, + ) + self.assertEqual(resp.status_code, 200) + data = resp.json() + self.assertIn("text", data) + self.assertIsInstance(data["text"], str) + # Whisper-tiny should recognize at least "dream" from the MLK speech + self.assertIn("dream", data["text"].lower()) + + def test_transcription_openai_client(self): + """Transcription should work via the OpenAI Python client.""" + audio_bytes = self._get_audio_bytes() + client = OpenAI(base_url=f"{self.base_url}/v1", api_key="unused") + result = client.audio.transcriptions.create( + model=self.MODEL, + file=("mlk.flac", audio_bytes), + ) + self.assertIsInstance(result.text, str) + self.assertTrue(len(result.text) > 10) + + def test_transcription_streaming(self): + """Streaming transcription should yield text chunks via SSE.""" + + audio_bytes = 
self._get_audio_bytes() + with httpx.stream( + "POST", + f"{self.base_url}/v1/audio/transcriptions", + files={"file": ("mlk.flac", audio_bytes, "audio/flac")}, + data={"model": self.MODEL, "stream": "true"}, + timeout=120, + ) as resp: + self.assertEqual(resp.status_code, 200) + + chunks = [] + for line in resp.iter_lines(): + if line and line.startswith("data: "): + chunks.append(line[len("data: ") :]) + + self.assertGreater(len(chunks), 0, "No streaming chunks received") + full_text = "".join(chunks) + self.assertIn("dream", full_text.lower()) + + def test_transcription_missing_file(self): + """POST without a file should fail.""" + + resp = httpx.post( + f"{self.base_url}/v1/audio/transcriptions", + data={"model": self.MODEL}, + timeout=30, + ) + self.assertNotEqual(resp.status_code, 200) + + +@slow +@require_serve +@require_torch_accelerator +class TestContinuousBatchingChatCompletion(unittest.TestCase): + """Integration tests for /v1/chat/completions with continuous batching.""" + + MODEL = "Qwen/Qwen2.5-0.5B-Instruct" + + @classmethod + def setUpClass(cls): + cls.serve, port = _start_serve( + force_model=cls.MODEL, + device="cuda:0", + continuous_batching=True, + attn_implementation="sdpa", + default_seed=42, + ) + cls.base_url = f"http://localhost:{port}" + cls.client = OpenAI(base_url=f"{cls.base_url}/v1", api_key="unused") + + @classmethod + def tearDownClass(cls): + cls.serve.kill_server() + + def test_streaming(self): + """Streaming chat completion with CB produces text.""" + text = "" + for chunk in self.client.chat.completions.create( + model=self.MODEL, + messages=[ + {"role": "system", "content": "You are a sports assistant designed to craft sports programs."}, + {"role": "user", "content": "Tell me what you can do."}, + ], + stream=True, + max_tokens=30, + ): + if chunk.choices[0].delta.content: + text += chunk.choices[0].delta.content + self.assertTrue(len(text) > 0) + + def test_non_streaming(self): + """Non-streaming chat completion with CB returns a full response.""" + resp = self.client.chat.completions.create( + model=self.MODEL, + messages=[{"role": "user", "content": "Say hello"}], + max_tokens=20, + ) + self.assertIsNotNone(resp.choices[0].message.content) + self.assertTrue(len(resp.choices[0].message.content) > 0) + + def test_non_streaming_response_json_format(self): + """Non-streaming CB responses return proper JSON objects, not double-encoded strings.""" + response = self.client.chat.completions.create( + model=self.MODEL, + messages=[{"role": "user", "content": "Say hello"}], + stream=False, + max_tokens=5, + ) + self.assertIsNotNone(response) + self.assertIsNotNone(response.id) + self.assertIsNotNone(response.choices) + self.assertEqual(len(response.choices), 1) + + choice = response.choices[0] + self.assertIsNotNone(choice.message) + self.assertIsNotNone(choice.message.content) + self.assertEqual(choice.message.role, "assistant") + self.assertIsInstance(choice.message.content, str) + + def test_multi_turn(self): + """Multi-turn conversation works with CB.""" + resp = self.client.chat.completions.create( + model=self.MODEL, + messages=[ + {"role": "user", "content": "My name is Alice"}, + {"role": "assistant", "content": "Nice to meet you!"}, + {"role": "user", "content": "What is my name?"}, + ], + max_tokens=20, + ) + self.assertIn("Alice", resp.choices[0].message.content) + + def test_request_cancellation(self): + """Opening a stream and closing it early triggers CB cancellation.""" + + request_id = "test-cb-cancel" + + # Open a streaming request and close 
after a few chunks + with httpx.stream( + "POST", + f"{self.base_url}/v1/chat/completions", + headers={"X-Request-ID": request_id}, + json={ + "model": self.MODEL, + "stream": True, + "messages": [{"role": "user", "content": "Count slowly so I can cancel you."}], + }, + timeout=30, + ) as resp: + self.assertEqual(resp.status_code, 200) + chunks_read = 0 + for _ in resp.iter_lines(): + chunks_read += 1 + if chunks_read >= 3: + break + + # Poll for cancellation in the CB scheduler + scheduler = self.serve._generation_state._cb_manager.scheduler + deadline = time.time() + 8.0 + while time.time() < deadline: + if scheduler.request_is_cancelled(request_id): + break + time.sleep(0.1) + + self.assertTrue( + scheduler.request_is_cancelled(request_id), + f"Request {request_id} not cancelled in scheduler after stream close.", + ) + + # Server should still be healthy and serve subsequent requests + resp = self.client.chat.completions.create( + model=self.MODEL, + messages=[{"role": "user", "content": "Say hi"}], + max_tokens=10, + ) + self.assertIsNotNone(resp.choices[0].message.content) + + +@slow +@require_serve +@require_torch_accelerator +class TestContinuousBatchingResponses(unittest.TestCase): + """Integration tests for /v1/responses with continuous batching.""" + + MODEL = "Qwen/Qwen2.5-0.5B-Instruct" + + @classmethod + def setUpClass(cls): + cls.serve, port = _start_serve( + force_model=cls.MODEL, + device="cuda:0", + continuous_batching=True, + attn_implementation="sdpa", + default_seed=42, + ) + cls.base_url = f"http://localhost:{port}" + cls.client = OpenAI(base_url=f"{cls.base_url}/v1", api_key="unused") + + @classmethod + def tearDownClass(cls): + cls.serve.kill_server() + + def test_streaming(self): + """Streaming response with CB produces text.""" + text = "" + stream = self.client.responses.create( + model=self.MODEL, + input="Say hello in one sentence.", + stream=True, + max_output_tokens=30, + ) + for event in stream: + if event.type == "response.output_text.delta": + text += event.delta + self.assertTrue(len(text) > 0) + + def test_non_streaming(self): + """Non-streaming response with CB returns text.""" + resp = self.client.responses.create( + model=self.MODEL, + input="Say hello in one sentence.", + stream=False, + max_output_tokens=30, + ) + content = resp.output[0].content[0].text + self.assertTrue(len(content) > 0) + + def test_multi_turn(self): + """Multi-turn conversation works with CB via Responses API.""" + resp = self.client.responses.create( + model=self.MODEL, + input=[ + {"role": "user", "content": "My name is Alice"}, + {"role": "assistant", "content": "Nice to meet you!"}, + {"role": "user", "content": "What is my name?"}, + ], + stream=False, + max_output_tokens=20, + ) + content = resp.output[0].content[0].text + self.assertIn("Alice", content) + + def test_request_cancellation(self): + """Opening a stream and closing it early triggers CB cancellation.""" + + request_id = "test-cb-resp-cancel" + + with httpx.stream( + "POST", + f"{self.base_url}/v1/responses", + headers={"X-Request-ID": request_id}, + json={ + "model": self.MODEL, + "stream": True, + "input": "Count slowly so I can cancel you.", + "max_output_tokens": 500, + }, + timeout=30, + ) as resp: + self.assertEqual(resp.status_code, 200) + # Read enough data to ensure CB generation has started, then close. 
+ received = b"" + for chunk in resp.iter_bytes(chunk_size=512): + received += chunk + if b"output_text.delta" in received: + break + + # Poll for cancellation in the CB scheduler + scheduler = self.serve._generation_state._cb_manager.scheduler + deadline = time.time() + 8.0 + while time.time() < deadline: + if scheduler.request_is_cancelled(request_id): + break + time.sleep(0.1) + + self.assertTrue( + scheduler.request_is_cancelled(request_id), + f"Request {request_id} not cancelled in scheduler after stream close.", + ) + + # Server should still serve subsequent requests + resp = self.client.responses.create( + model=self.MODEL, + input="Say hi", + stream=False, + max_output_tokens=10, + ) + self.assertTrue(len(resp.output[0].content[0].text) > 0)