diff --git a/docker/transformers-all-latest-gpu/Dockerfile b/docker/transformers-all-latest-gpu/Dockerfile
index ab4bb8b12114..1380dc8a405c 100644
--- a/docker/transformers-all-latest-gpu/Dockerfile
+++ b/docker/transformers-all-latest-gpu/Dockerfile
@@ -105,6 +105,9 @@ RUN python3 -m pip install --no-cache-dir python-Levenshtein
# For `FastSpeech2ConformerTokenizer` tokenizer
RUN python3 -m pip install --no-cache-dir g2p-en
+# For serving tests (librosa for audio pipelines, python-multipart for multipart form parsing)
+RUN python3 -m pip install --no-cache-dir librosa python-multipart
+
# For some bitsandbytes tests
RUN python3 -m pip install --no-cache-dir einops
diff --git a/examples/pytorch/transformers_serve_cb_eval_job.py b/examples/pytorch/transformers_serve_cb_eval_job.py
index c6355427b161..b30f71e4aa2e 100644
--- a/examples/pytorch/transformers_serve_cb_eval_job.py
+++ b/examples/pytorch/transformers_serve_cb_eval_job.py
@@ -16,10 +16,9 @@
from inspect_ai import eval
from inspect_ai.log import bundle_log_dir
-from inspect_evals.gpqa import gpqa_diamond
-def wait_for_server_up(server_process, timeout=600):
+def wait_for_server_up(server_process, port=8000, timeout=600):
start_time = time.time()
import urllib.error
@@ -27,7 +26,7 @@ def wait_for_server_up(server_process, timeout=600):
while time.time() - start_time < timeout:
try:
- req = urllib.request.urlopen("http://127.0.0.1:8000/health", timeout=2)
+ req = urllib.request.urlopen(f"http://127.0.0.1:{port}/health", timeout=2)
if req.status == 200:
elapsed = time.time() - start_time
print("\n" + "=" * 70)
@@ -69,17 +68,29 @@ def main():
action="store_true",
help="Disable continuous batching (enabled by default)",
)
+ parser.add_argument(
+ "--port",
+ type=int,
+ default=8000,
+ help="Port for the transformers serve server (default: 8000)",
+ )
parser.add_argument(
"--limit",
type=int,
default=10,
- help="Number of evaluation samples to run (default: 5)",
+ help="Number of evaluation samples to run (default: 10)",
)
parser.add_argument(
"--max-connections",
type=int,
default=10,
- help="Maximum concurrent connections for evaluation (default: 2)",
+ help="Maximum concurrent connections for evaluation (default: 10)",
+ )
+ parser.add_argument(
+ "--temperature",
+ type=float,
+ default=0,
+ help="Temperature for generation (default: 0)",
)
parser.add_argument(
"--log-dir",
@@ -121,7 +132,7 @@ def main():
)
parser.add_argument(
"--cb-use-cuda-graph",
- action="store_true",
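+        # BooleanOptionalAction also generates a --no-cb-use-cuda-graph flag; left unset, the value is None.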
+ action=argparse.BooleanOptionalAction,
help="Enable CUDA graphs for continuous batching performance",
)
@@ -133,6 +144,7 @@ def main():
serve_cmd = [
"transformers",
"serve",
+ args.model,
]
# Add continuous batching if not disabled
@@ -150,10 +162,14 @@ def main():
serve_cmd.extend(["--cb-max-memory-percent", str(args.cb_max_memory_percent)])
- serve_cmd.append("--cb-use-cuda-graph")
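+    # Tri-state: None (flag omitted) keeps the server default; True/False forwards an explicit choice.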
+ if args.cb_use_cuda_graph is True:
+ serve_cmd.append("--cb-use-cuda-graph")
+ elif args.cb_use_cuda_graph is False:
+ serve_cmd.append("--no-cb-use-cuda-graph")
    # Always use the kernels-community/flash-attn2 attention implementation
serve_cmd.extend(["--attn-implementation", "kernels-community/flash-attn2"])
+ serve_cmd.extend(["--port", str(args.port)])
print("Starting transformers serve with continuous batching...")
print(f"Model: {args.model}")
@@ -163,6 +179,7 @@ def main():
print(f"CB Max Batch Tokens: {args.cb_max_batch_tokens if args.cb_max_batch_tokens else 'auto'}")
print(f"CB Max Memory: {args.cb_max_memory_percent * 100}%")
print(f"CB CUDA Graph: {args.cb_use_cuda_graph}")
+ print(f"Temperature: {args.temperature}")
print(f"Command: {' '.join(serve_cmd)}")
print("=" * 70)
print("SERVER OUTPUT:")
@@ -171,16 +188,17 @@ def main():
# Start server with output going directly to stdout/stderr
server_process = subprocess.Popen(serve_cmd, stdout=None, stderr=None)
- wait_for_server_up(server_process, timeout=600)
+ wait_for_server_up(server_process, port=args.port, timeout=600)
eval(
- gpqa_diamond,
+ "hf/Idavidrein/gpqa/diamond",
model=f"openai-api/transformers-serve/{args.model}",
log_dir=args.log_dir,
- model_base_url="http://localhost:8000/v1",
+ model_base_url=f"http://localhost:{args.port}/v1",
display="plain",
limit=args.limit,
model_args={"stream": False},
+ temperature=args.temperature,
max_connections=args.max_connections,
max_tokens=2048,
)
diff --git a/src/transformers/cli/chat.py b/src/transformers/cli/chat.py
index c6e434b07481..968fd290c75b 100644
--- a/src/transformers/cli/chat.py
+++ b/src/transformers/cli/chat.py
@@ -114,11 +114,17 @@ async def stream_output(self, stream: AsyncIterator[ChatCompletionStreamOutput])
self._console.print(f"[bold blue]<{self.model_id}>:")
with Live(console=self._console, refresh_per_second=4) as live:
text = ""
+ completion_tokens = 0
+ start_time = time.time()
finish_reason: str | None = None
async for token in await stream:
outputs = token.choices[0].delta.content
finish_reason = getattr(token.choices[0], "finish_reason", finish_reason)
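+                # Usage may only appear on some chunks (typically the last); keep the latest value seen.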
+ usage = getattr(token, "usage", None)
+ if usage is not None:
+ completion_tokens = getattr(usage, "completion_tokens", completion_tokens)
+
if not outputs:
continue
@@ -154,6 +160,11 @@ async def stream_output(self, stream: AsyncIterator[ChatCompletionStreamOutput])
# Update the Live console output
live.update(markdown, refresh=True)
+ elapsed = time.time() - start_time
+ if elapsed > 0 and completion_tokens > 0:
+ tok_per_sec = completion_tokens / elapsed
+ self._console.print()
+ self._console.print(f"[dim]{completion_tokens} tokens in {elapsed:.1f}s ({tok_per_sec:.1f} tok/s)[/dim]")
self._console.print()
return text, finish_reason
@@ -544,14 +555,16 @@ async def _inner_run(self):
else:
chat.append({"role": "user", "content": user_input})
+ extra_body = {
+ "generation_config": config.to_json_string(),
+ "model": self.model_id,
+ }
+
stream = client.chat_completion(
chat,
stream=True,
model=self.model_id,
- extra_body={
- "generation_config": config.to_json_string(),
- "model": self.model_id,
- },
+ extra_body=extra_body,
)
model_output, finish_reason = await interface.stream_output(stream)
diff --git a/src/transformers/cli/serve.py b/src/transformers/cli/serve.py
index 7337eb305b61..4d9e6c712e7c 100644
--- a/src/transformers/cli/serve.py
+++ b/src/transformers/cli/serve.py
@@ -11,2340 +11,180 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+"""
+CLI entry point for `transformers serve`.
+"""
+
import asyncio
-import base64
-import copy
-import enum
-import gc
-import io
-import json
-import re
-import tempfile
import threading
-import time
-import uuid
-from collections.abc import Callable, Generator, Iterable
-from contextlib import asynccontextmanager
-from functools import lru_cache
-from io import BytesIO
-from threading import Thread
-from typing import TYPE_CHECKING, Annotated, Optional, TypedDict, Union
+from typing import Annotated
import typer
-from huggingface_hub import scan_cache_dir
-from tokenizers.decoders import DecodeStream
-from tqdm import tqdm
-from tqdm.auto import tqdm as base_tqdm
-import transformers
-from transformers import AutoTokenizer, BitsAndBytesConfig, GenerationConfig, PreTrainedTokenizerBase
-from transformers.generation import (
- LogitsProcessorList,
- TextIteratorStreamer,
-)
from transformers.utils import logging
-from transformers.utils.import_utils import (
- is_fastapi_available,
- is_librosa_available,
- is_openai_available,
- is_pydantic_available,
- is_uvicorn_available,
- is_vision_available,
-)
-
-
-if TYPE_CHECKING:
- from transformers import (
- PreTrainedModel,
- PreTrainedTokenizerFast,
- ProcessorMixin,
- )
-
- from ..generation.continuous_batching import ContinuousBatchingManager
-
-
-if is_librosa_available():
- import librosa
-
-if is_vision_available():
- from PIL import Image
-
-serve_dependencies_available = (
- is_pydantic_available() and is_fastapi_available() and is_uvicorn_available() and is_openai_available()
-)
-if serve_dependencies_available:
- import uvicorn
- from fastapi import FastAPI, HTTPException
- from fastapi.middleware.cors import CORSMiddleware
- from fastapi.responses import JSONResponse, StreamingResponse
- from openai.types.audio.transcription import Transcription
- from openai.types.audio.transcription_create_params import TranscriptionCreateParamsBase
- from openai.types.chat import ChatCompletion, ChatCompletionMessage, ChatCompletionMessageParam
- from openai.types.chat.chat_completion import Choice
- from openai.types.chat.chat_completion_chunk import (
- ChatCompletionChunk,
- ChoiceDelta,
- ChoiceDeltaToolCall,
- ChoiceDeltaToolCallFunction,
- )
- from openai.types.chat.chat_completion_chunk import (
- Choice as ChoiceChunk,
- )
- from openai.types.chat.completion_create_params import CompletionCreateParamsStreaming
- from openai.types.responses import (
- Response,
- ResponseCompletedEvent,
- ResponseContentPartAddedEvent,
- ResponseContentPartDoneEvent,
- ResponseCreatedEvent,
- ResponseError,
- ResponseErrorEvent,
- ResponseFailedEvent,
- ResponseInProgressEvent,
- ResponseOutputItemAddedEvent,
- ResponseOutputItemDoneEvent,
- ResponseOutputMessage,
- ResponseOutputText,
- ResponseTextDeltaEvent,
- ResponseTextDoneEvent,
- )
- from openai.types.responses.response_create_params import ResponseCreateParamsStreaming
- from pydantic import BaseModel, TypeAdapter, ValidationError
-
- # Expand OpenAI's request input types with an optional `generation_config` field
- class TransformersResponseCreateParamsStreaming(ResponseCreateParamsStreaming, total=False):
- """
- OpenAI's ResponseCreateParamsStreaming with an additional field for the generation config (as a json string).
- """
-
- generation_config: str
-
- class TransformersCompletionCreateParamsStreaming(CompletionCreateParamsStreaming, total=False):
- """
- OpenAI's CompletionCreateParamsStreaming with additional fields for the generation config (as a json string) and passing the request_id
- """
+from transformers.utils.import_utils import is_serve_available
- generation_config: str
-
- class TransformersTranscriptionCreateParams(TranscriptionCreateParamsBase, total=False):
- """
- OpenAI's TranscriptionCreateParamsBase with an additional field for the generation config (as a json string).
- """
-
- file: bytes # Overwritten -- pydantic isn't happy with `typing.IO[bytes]`, present in the original type
- generation_config: str
- stream: bool = False
-
- # Contrarily to OpenAI's output types, input types are `TypedDict`, which don't have built-in validation.
- response_validator = TypeAdapter(TransformersResponseCreateParamsStreaming)
- completion_validator = TypeAdapter(TransformersCompletionCreateParamsStreaming)
- transcription_validator = TypeAdapter(TransformersTranscriptionCreateParams)
-
- # Define request fields that are not yet used in `transformers serve`. Receiving these fields will raise an
- # HTTPException.
- UNUSED_RESPONSE_FIELDS = {
- "background",
- "include",
- "max_tool_calls",
- "previous_response_id",
- "prompt",
- "reasoning",
- "service_tier",
- "store",
- "text",
- "tool_choice",
- "top_logprobs",
- "truncation",
- "user",
- }
-
- UNUSED_CHAT_COMPLETION_FIELDS = {
- "audio",
- "function_call",
- "functions",
- "logprobs",
- "max_completion_tokens",
- "metadata",
- "modalities",
- "n",
- "parallel_tool_calls",
- "prediction",
- "presence_penalty",
- "reasoning_effort",
- "response_format",
- "service_tier",
- "stop",
- "store",
- "stream_options",
- "tool_choice",
- "top_logprobs",
- "user",
- "web_search_options",
- }
- UNUSED_TRANSCRIPTION_FIELDS = {
- "chunking_strategy",
- "include",
- "language",
- "prompt",
- "response_format",
- "timestamp_granularities",
- }
+from .serving.utils import set_torch_seed
logger = logging.get_logger(__name__)
-# Possible tokens that indicate the start/end of a tool call
-# TODO (joao, matt): streamline tool token detection logic
-_TOOL_CALL_TOKENS = {
- "qwen": {
- "start": "",
- "end": "",
- },
-}
-_MODELS_WITH_TOOL_SUPPORT = list(_TOOL_CALL_TOKENS.keys())
-
-X_REQUEST_ID = "x-request-id"
-
-
-def set_torch_seed(_seed):
- import torch
-
- torch.manual_seed(_seed)
-
-
-def reset_torch_cache():
- import torch
-
- if torch.cuda.is_available():
- torch.cuda.empty_cache()
-
-
-def torch_ones_like(_input_tensor):
- import torch
-
- return torch.ones_like(_input_tensor)
-
-
-class Modality(enum.Enum):
- LLM = "LLM"
- VLM = "VLM"
- STT = "STT"
- TTS = "TTS"
-
-
-def create_generation_config_from_req(
- req: dict,
- model_generation_config: GenerationConfig,
- **kwargs,
-) -> GenerationConfig:
- """
- Creates a generation config from the parameters of the request. If a generation config is passed in the request,
- it will be used as a baseline for parameterization. Otherwise, we will use the model's default generation config.
- Other parameters in the request will be applied on top of the baseline.
-
- Args:
- req (`dict`):
- The request which may optionally contain generation parameters.
- model_generation_config (`GenerationConfig`):
- The model's default generation config.
- kwargs (`dict`):
- Additional parameters to set in the generation config.
-
- Returns:
- The prepared `GenerationConfig` object.
- """
- # If there is a generation config in the request, it is a json string serialization from a `GenerationConfig`
- # object. For simplicity, flags set here take precedence over all other flags.
- if req.get("generation_config") is not None:
- generation_config = GenerationConfig(**json.loads(req["generation_config"]))
- else:
- generation_config = copy.deepcopy(model_generation_config)
-
- non_standard_kwargs = generation_config.update(**kwargs)
- # Set extra kwargs that are not in the `GenerationConfig` class (e.g. continuous batching flags)
- for k, v in non_standard_kwargs.items():
- if v is not None:
- setattr(generation_config, k, v)
-
- # Response-specific parameters
- if req.get("max_output_tokens") is not None:
- generation_config.max_new_tokens = int(req["max_output_tokens"])
-
- # Completion-specific parameters
- if req.get("max_tokens") is not None:
- generation_config.max_new_tokens = int(req["max_tokens"])
- if req.get("frequency_penalty") is not None:
- generation_config.repetition_penalty = float(req["frequency_penalty"])
- if req.get("logit_bias") is not None:
- generation_config.sequence_bias = req["logit_bias"]
- if req.get("stop") is not None:
- generation_config.stop_strings = req["stop"]
- if req.get("temperature") is not None:
- generation_config.temperature = float(req["temperature"])
- if float(req["temperature"]) == 0.0:
- generation_config.do_sample = False
- if req.get("top_p") is not None:
- generation_config.top_p = float(req["top_p"])
- if req.get("seed") is not None:
- set_torch_seed(req["seed"])
-
- return generation_config
-
-
-class DownloadAggregator:
- """Aggregates byte-progress across multiple concurrent download tqdm bars into a single SSE stream.
-
- huggingface_hub opens one tqdm bar per file shard. This class tracks them all and emits
- a single aggregate ``{"stage": "download", "progress": {"current": ..., "total": ...}}``
- event whenever any bar updates.
- """
-
- def __init__(self, enqueue: Callable[[dict], None], model_id_and_revision: str):
- self.enqueue = enqueue
- self.model = model_id_and_revision
- self.bars: dict[int, tuple[int, int | None]] = {} # id -> (current, total)
- self.last_emitted_current: int | None = None
-
- def register(self, bar_id: int, total: int | None):
- self.bars[bar_id] = (0, total)
- self._emit()
-
- def update(self, bar_id: int, current: int, total: int | None):
- self.bars[bar_id] = (current, total)
- self._emit()
-
- def close(self, bar_id: int):
- pass # keep the bar in _bars so totals remain correct
-
- def _emit(self):
- agg_current = sum(c for c, _ in self.bars.values())
- if agg_current == self.last_emitted_current:
- return
- self.last_emitted_current = agg_current
- totals = [t for _, t in self.bars.values() if t is not None]
- agg_total = sum(totals) if totals else None
- self.enqueue(
- {
- "status": "loading",
- "model": self.model,
- "stage": "download",
- "progress": {"current": agg_current, "total": agg_total},
- }
- )
-
-
-class DownloadProxy:
- """
- Leverages the DownloadAggregator in order to have a coherent tqdm wrapper.
- """
-
- def __init__(self, wrapped_bar, download_aggregator):
- self.wrapped_bar = wrapped_bar
- self.bar_id = id(wrapped_bar)
- self.download_aggregator = download_aggregator
-
- self.n = 0
- self.total = wrapped_bar.total
-
- def __getattr__(self, name):
- return getattr(self.wrapped_bar, name)
-
- def update(self, n=1):
- if n is None:
- n = 1
-
- self.n += n
- self.download_aggregator.update(self.bar_id, self.n, getattr(self.wrapped_bar, "total", self.total))
-
- return self.wrapped_bar.update(n)
-
- def close(self):
- self.download_aggregator.close(self.bar_id)
- return self.wrapped_bar.close()
-
- def __enter__(self):
- self.wrapped_bar.__enter__()
- return self
-
- def __exit__(self, *a):
- return self.wrapped_bar.__exit__(*a)
-
- def __iter__(self):
- count = 0
- for item in self.wrapped_bar:
- count += 1
- self.download_aggregator.update(self.bar_id, count, getattr(self.wrapped_bar, "total", self.total))
-
- yield item
-
-
-class WeightsProxy:
- """
- Wraps the weight-loading tqdm bar to have finer control over how we emit them to the clinet.
- """
-
- def __init__(self, wrapped_bar, _callable, model_id_and_revision):
- self.wrapped_bar = wrapped_bar
- self.last_emitted = -1
- self.callable = _callable
- self.model_id_and_revision = model_id_and_revision
-
- self.n = 0
- self.total = wrapped_bar.total
-
- def __getattr__(self, name):
- return getattr(self.wrapped_bar, name)
-
- def _emit(self):
- if self.n == self.last_emitted:
- return
- self.last_emitted = self.n
-
- self.callable(
- {
- "status": "loading",
- "model": self.model_id_and_revision,
- "stage": "weights",
- "progress": {
- "current": self.n,
- "total": getattr(self.wrapped_bar, "total", self.total),
- },
- }
- )
-
- def update(self, n=1):
- if n is None:
- n = 1
-
- self.n += n
- self._emit()
-
- return self.wrapped_bar.update(n)
-
- def close(self):
- return self.wrapped_bar.close()
-
- def __enter__(self):
- self.wrapped_bar.__enter__()
- return self
-
- def __exit__(self, *a):
- return self.wrapped_bar.__exit__(*a)
-
- def __iter__(self):
- for item in self.wrapped_bar:
- self.n += 1
- self._emit()
-
- yield item
-
-
-def set_tqdm_class(callback, mid):
- download_aggregator = DownloadAggregator(callback, mid)
-
- class ProgressTqdm(base_tqdm):
- """tqdm subclass that routes progress to the correct SSE stage.
-
- Bars with ``unit="B"`` are download bars (one per file shard) — they are
- aggregated into a single ``download`` stage stream via ``_DownloadAggregator``.
- All other bars are weight-loading bars emitted as ``weights`` stage events.
- """
-
- def __init__(self, *args, **kwargs):
- self.sse_unit = kwargs.get("unit") or "it"
- kwargs["disable"] = True # suppress server-side display
- super().__init__(*args, **kwargs)
- self.n = 0
- self.last_emitted = -1
- if self.sse_unit == "B":
- self._bar_id = id(self)
- download_aggregator.register(self._bar_id, self.total)
-
- def update(self, n=1):
- if n is None:
- n = 1
-
- self.n += n
-
- if self.sse_unit == "B":
- download_aggregator.update(self._bar_id, self.n, self.total)
- elif self.n != self.last_emitted:
- self.last_emitted = self.n
-
- callback(
- {
- "status": "loading",
- "model": mid,
- "stage": "weights",
- "progress": {"current": self.n, "total": self.total},
- }
- )
-
- def close(self):
- if self.sse_unit == "B":
- download_aggregator.close(self._bar_id)
- super().close()
-
- return ProgressTqdm
-
-
-class ToolState:
- """Lightweight class to keep track of the tool call state."""
-
- def __init__(self):
- self.reset()
-
- def reset(self):
- """Reset the tool call state (assumes we're outside a tool call)."""
- self.inside_tool_call = False
- self.has_tool_name_defined = False
- self.arg_nesting_level = 0
- self.buffer = ""
-
-
-class TimedModel:
- """
- A class that holds a PreTrainedModel instance and its associated processor.
- Automatically deletes the instances after a specified timeout.
- """
-
- def __init__(
- self,
- model: "PreTrainedModel",
- timeout_seconds: int,
- processor: Union["ProcessorMixin", "PreTrainedTokenizerFast"] | None = None,
- ):
- self.model = model
- self._name_or_path = str(model.name_or_path)
- self.processor = processor
- self.timeout_seconds = timeout_seconds
- self._timer = threading.Timer(self.timeout_seconds, self.timeout_reached)
- self._timer.start()
-
- def reset_timer(self):
- """Reset the timer for the deletion of the instances."""
- self._timer.cancel()
- self._timer = threading.Timer(self.timeout_seconds, self.timeout_reached)
- self._timer.start()
-
- def delete_model(self):
- """Delete the wrapped model and processor and clean up resources."""
- if hasattr(self, "model") and self.model is not None:
- del self.model
- del self.processor
- self.model = None
- self.processor = None
- gc.collect()
-
- # Clear CUDA cache if available
- reset_torch_cache()
-
- # XXX: in case we manually delete the model, like on server shutdown
- self._timer.cancel()
-
- def timeout_reached(self):
- if self.timeout_seconds > 0:
- self.delete_model()
- logger.warning(
- f"{self._name_or_path} was removed from memory after {self.timeout_seconds} seconds of inactivity"
- )
-
- def is_deleted(self):
- """Check if the instances have been deleted."""
- return not hasattr(self, "model") or self.model is None
-
class Serve:
- # Defining a class to help with internal state but in practice it's just a method to call
- # TODO: refactor into a proper module with helpers + 1 main method
def __init__(
self,
+ force_model: Annotated[str | None, typer.Argument(help="Model to preload and use for all requests.")] = None,
+ # Model options
continuous_batching: Annotated[
- bool | None,
- typer.Option(help="Whether to use continuous batching for chat completions. Configure with --cb-* flags."),
- ] = None,
+ bool,
+ typer.Option(help="Enable continuous batching with paged attention. Configure with --cb-* flags."),
+ ] = False,
cb_block_size: Annotated[
- int | None,
- typer.Option(help="Size of each KV cache block in tokens for continuous batching. Default: 256."),
+ int | None, typer.Option(help="KV cache block size in tokens for continuous batching.")
] = None,
cb_num_blocks: Annotated[
- int | None,
- typer.Option(
- help="Number of blocks in KV cache for continuous batching. Default: auto-inferred from GPU memory."
- ),
+ int | None, typer.Option(help="Number of KV cache blocks for continuous batching.")
] = None,
cb_max_batch_tokens: Annotated[
- int | None,
- typer.Option(
- help="Maximum number of tokens in a batch for continuous batching. Default: auto-inferred from GPU memory."
- ),
+ int | None, typer.Option(help="Maximum tokens per batch for continuous batching.")
] = None,
cb_max_memory_percent: Annotated[
- float | None,
- typer.Option(
- help="Maximum percentage of free GPU memory to use for KV cache in continuous batching (0.0-1.0). Default: 0.8."
- ),
+ float | None, typer.Option(help="Max GPU memory fraction for KV cache (0.0-1.0).")
] = None,
cb_use_cuda_graph: Annotated[
- bool | None,
- typer.Option(
- help="Enable CUDA graphs for continuous batching performance. Default: auto-inferred based on attention implementation."
- ),
+ bool | None, typer.Option(help="Enable CUDA graphs for continuous batching.")
] = None,
- device: Annotated[
- str,
- typer.Option(
- help="Device to use for inference; will default to `auto` and place the model on an accelerator if available."
- ),
- ] = "auto",
- dtype: Annotated[
- str | None,
- typer.Option(
- help="Override the default `torch.dtype` and load the model under this dtype. If `'auto'` is passed, the dtype will be automatically derived from the model's weights."
- ),
- ] = "auto",
- trust_remote_code: Annotated[
- bool, typer.Option(help="Whether to trust remote code when loading a model.")
- ] = False,
attn_implementation: Annotated[
- str | None,
- typer.Option(help="Which attention implementation to use."),
+ str | None, typer.Option(help="Attention implementation (e.g. flash_attention_2).")
] = None,
+ compile: Annotated[bool, typer.Option(help="Enable torch.compile for faster inference.")] = False,
quantization: Annotated[
- str | None,
- typer.Option(help="Which quantization method to use. choices: 'bnb-4bit', 'bnb-8bit'"),
+ str | None, typer.Option(help="Quantization method: 'bnb-4bit' or 'bnb-8bit'.")
] = None,
- host: Annotated[str, typer.Option(help="Interface the server will listen to.")] = "localhost",
- port: Annotated[int, typer.Option(help="Port the server will listen to.")] = 8000,
+ device: Annotated[str, typer.Option(help="Device for inference (e.g. 'auto', 'cuda:0', 'cpu').")] = "auto",
+ dtype: Annotated[str | None, typer.Option(help="Override model dtype. 'auto' derives from weights.")] = "auto",
+ trust_remote_code: Annotated[bool, typer.Option(help="Trust remote code when loading.")] = False,
model_timeout: Annotated[
- int, typer.Option(help="Time in seconds after which a model will be removed from memory.")
+        int, typer.Option(help="Seconds before an idle model is unloaded. Ignored when force_model is set.")
] = 300,
- log_level: Annotated[
- str, typer.Option(help="Logging level as a string. Example: 'info' or 'warning'.")
- ] = "warning",
- default_seed: Annotated[
- int | None, typer.Option(help="The default seed for torch, should be an integer.")
- ] = None,
- enable_cors: Annotated[
- bool,
- typer.Option(
- help="Whether to enable CORS. Some apps that make requests from external domains (e.g. Cursor) require CORS to be enabled."
- ),
- ] = False,
- input_validation: Annotated[bool, typer.Option(help="Whether to turn on strict input validation.")] = False,
- force_model: Annotated[
- str | None,
- typer.Option(
- help="Name of the model to be forced on all requests. This is useful for testing Apps that don't allow changing models in the request."
- ),
- ] = None,
+ # Server options
+ host: Annotated[str, typer.Option(help="Server listen address.")] = "localhost",
+ port: Annotated[int, typer.Option(help="Server listen port.")] = 8000,
+ enable_cors: Annotated[bool, typer.Option(help="Enable permissive CORS.")] = False,
+ log_level: Annotated[str, typer.Option(help="Logging level (e.g. 'info', 'warning').")] = "warning",
+ default_seed: Annotated[int | None, typer.Option(help="Default torch seed.")] = None,
non_blocking: Annotated[
- bool, typer.Option(hidden=True, help="Whether to run the server in a separate thread.")
+ bool, typer.Option(hidden=True, help="Run server in a background thread. Used by tests.")
] = False,
) -> None:
- if not serve_dependencies_available:
- raise ImportError(
- "Missing dependencies for the serving CLI. Please install with `pip install transformers[serving]`"
- )
-
- # Save input arguments
- self.attn_implementation = attn_implementation
- self.continuous_batching = continuous_batching
- self.device = device
- self.quantization = quantization
+ if not is_serve_available():
+ raise ImportError("Missing dependencies for serving. Install with `pip install transformers[serving]`")
- self.dtype = dtype
- self.trust_remote_code = trust_remote_code
- self.host = host
- self.port = port
- self.model_timeout = model_timeout
- self.log_level = log_level
- self.default_seed = default_seed
- self.enable_cors = enable_cors
- self.input_validation = input_validation
- self.force_model = force_model
- self.non_blocking = non_blocking
+ import uvicorn
- # Continuous batching configuration arguments
- self.cb_block_size = cb_block_size
- self.cb_num_blocks = cb_num_blocks
- self.cb_max_batch_tokens = cb_max_batch_tokens
- self.cb_max_memory_percent = cb_max_memory_percent
- self.cb_use_cuda_graph = cb_use_cuda_graph
+ from .serving.chat_completion import ChatCompletionHandler
+ from .serving.model_manager import ModelManager
+ from .serving.response import ResponseHandler
+ from .serving.server import build_server
+ from .serving.transcription import TranscriptionHandler
+ from .serving.utils import GenerationState
# Seed
if default_seed is not None:
set_torch_seed(default_seed)
- # Set up logging
+ # Logging
transformers_logger = logging.get_logger("transformers")
transformers_logger.setLevel(logging.log_levels[log_level.lower()])
- cb_logger = logging.get_logger("transformers.generation.continuous_batching")
- cb_logger.setLevel(logging.log_levels[log_level.lower()])
-
- # Internal state:
- # 1. Tracks models in memory, to prevent reloading the model unnecessarily
- self.loaded_models: dict[str, TimedModel] = {}
- self.running_continuous_batching_manager: ContinuousBatchingManager | None = None
-
- # Tracks in-flight model loads for fan-out to multiple SSE subscribers
- self.loading_subscribers: dict[str, list[asyncio.Queue[str | None]]] = {}
- self.loading_tasks: dict[str, asyncio.Task] = {}
-
- # Thread-safety for load_model_and_processor / load_audio_model_and_processor
- self.model_locks: dict[str, threading.Lock] = {}
- self.model_locks_guard = threading.Lock()
- # Thread-safety for continuous batching manager init/teardown
- self._cb_manager_lock = threading.Lock()
-
- # 2. preserves information about the last call and last KV cache, to determine whether we can reuse the KV
- # cache and avoid re-running prefill
- self.last_messages = None
- self.last_kv_cache = None
- self.last_model = None
-
- if self.model_timeout is None:
- self.model_timeout = -1 if self.force_model else 300
-
- if self.force_model:
- model_id_and_revision = self.process_model_name(self.force_model)
- self.last_model = model_id_and_revision
- self.load_model_and_processor(model_id_and_revision)
-
- @asynccontextmanager
- async def lifespan(app: FastAPI):
- yield
- self.reset_loaded_models()
-
- app = FastAPI(lifespan=lifespan)
-
- # Some apps that make requests from external domains (e.g. Cursor) require CORS to be enabled. However, for
- # security purposes, it's disabled by default
- if self.enable_cors:
- app.add_middleware(
- CORSMiddleware,
- allow_origins=["*"],
- allow_credentials=True,
- allow_methods=["*"],
- allow_headers=["*"],
- )
- logger.warning_once(
- "CORS allow origin is set to `*`. This is not recommended for production environments."
- )
-
- from fastapi import Request
-
- @app.post("/v1/chat/completions")
- def chat_completion(request: Request, body: dict):
- self.validate_chat_completion_request(request=body)
-
- logger.warning(f"[Request received] Model: {body['model']}, CB: {self.continuous_batching}")
-
- if self.continuous_batching:
- return self.continuous_batching_chat_completion(body, request.state.request_id)
- else:
- return self.generate_chat_completion(body)
-
- @app.post("/v1/responses")
- def responses(request: dict):
- self.validate_response_request(request=request)
-
- logger.warning(f"[Request received] Model: {request['model']}, CB: {self.continuous_batching}")
-
- # Support non-streaming mode when `stream=false` is provided
- stream = request.get("stream", True)
- if not stream:
- response_obj = self.generate_response_non_streaming(request)
- return JSONResponse(response_obj)
-
- output = self.generate_response(request)
- return StreamingResponse(output, media_type="text/event-stream")
-
- @app.post("/v1/audio/transcriptions")
- async def audio_transcriptions(request: Request):
- # Parses the multipart/form-data request into the request format used by other endpoints
- async with request.form() as form:
- parsed_request = TransformersTranscriptionCreateParams(
- file=await form["file"].read(),
- model=form["model"],
- # TODO: add other fields
- )
- logger.debug(
- f"Received file: {form['file'].filename}; MIME type: {form['file'].content_type}; "
- f"size: {form['file'].size / 1024:.2f} KiB"
- )
- self.validate_transcription_request(request=parsed_request)
-
- output = self.generate_transcription(parsed_request)
- return StreamingResponse(output, media_type="text/event-stream")
-
- @app.options("/v1/models")
- @app.get("/v1/models")
- def get_all_models():
- return JSONResponse({"object": "list", "data": self.get_gen_models()})
-
- @app.get("/health")
- def healthcheck():
- return JSONResponse({"status": "ok"})
-
- @app.post("/load_model")
- async def load_model(body: dict):
- model = body.get("model")
- if model is None:
- raise HTTPException(status_code=422, detail="Missing `model` field in the request body.")
-
- model_id_and_revision = self.process_model_name(model)
-
- async def event_publisher():
- queue: asyncio.Queue[str | None] = asyncio.Queue()
- model_loaded = model_id_and_revision in self.loaded_models
- model_not_deleted = model_loaded and not self.loaded_models[model_id_and_revision].is_deleted()
-
- # Case 1: Model already cached
- if model_loaded and model_not_deleted:
- self.loaded_models[model_id_and_revision].reset_timer()
- yield f"data: {json.dumps({'status': 'ready', 'model': model_id_and_revision, 'cached': True})}\n\n"
- return
-
- # Case 2: Load already happening, join existing subscribers
- if model_id_and_revision in self.loading_tasks:
- self.loading_subscribers[model_id_and_revision].append(queue)
-
- while True:
- item = await queue.get()
- if item is None:
- break
- yield item
- return
-
- # Case 3: First request — start the load task
- self.loading_subscribers[model_id_and_revision] = [queue]
- loop = asyncio.get_running_loop()
-
- def enqueue(payload: dict):
- msg = f"data: {json.dumps(payload)}\n\n"
-
- def broadcast():
- for q in self.loading_subscribers.get(model_id_and_revision, []):
- q.put_nowait(msg)
-
- loop.call_soon_threadsafe(broadcast)
-
- download_aggregator = DownloadAggregator(enqueue, model_id_and_revision)
-
- def streaming_tqdm_hook(factory, args, kwargs):
- bar = factory(*args, **kwargs)
- desc = kwargs.get("desc") or getattr(bar, "desc", None) or ""
- unit = kwargs.get("unit") or getattr(bar, "unit", "it")
- total = getattr(bar, "total", kwargs.get("total"))
-
- # Only forward byte-progress bars (file downloads) — skip noise like "Fetching N files"
- if unit == "B":
- download_aggregator.register(id(bar), total)
- return DownloadProxy(bar, download_aggregator=download_aggregator)
-
- # "Loading weights" bar — emit as stage: "weights" with item-count progress
- if desc == "Loading weights":
- return WeightsProxy(bar, enqueue, model_id_and_revision)
-
- # Other non-byte, non-weights bars (noise) — return unmodified
- return bar
-
- async def run_load():
- previous_hook = logging.set_tqdm_hook(streaming_tqdm_hook)
- try:
- await asyncio.to_thread(self.load_model_and_processor, model_id_and_revision, enqueue)
- except Exception as e:
- logger.error(f"Failed to load {model_id_and_revision}: {e}", exc_info=True)
- enqueue({"status": "error", "model": model_id_and_revision, "message": str(e)})
- finally:
- logging.set_tqdm_hook(previous_hook)
-
- def _send_sentinel():
- for q in self.loading_subscribers.pop(model_id_and_revision, []):
- q.put_nowait(None)
- self.loading_tasks.pop(model_id_and_revision, None)
-
- loop.call_soon_threadsafe(_send_sentinel)
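+        # Model loading/unloading (idle timeouts, force_model pinning) is delegated to ModelManager.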
+ self._model_manager = ModelManager(
+ device=device,
+ dtype=dtype,
+ trust_remote_code=trust_remote_code,
+ attn_implementation=attn_implementation,
+ quantization=quantization,
+ model_timeout=model_timeout,
+ force_model=force_model,
+ )
+ from transformers import ContinuousBatchingConfig
+
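+        # Forward only the --cb-* flags that were explicitly set, so ContinuousBatchingConfig keeps its own defaults.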
+ cb_kwargs = {
+ k: v
+ for k, v in {
+ "block_size": cb_block_size,
+ "num_blocks": cb_num_blocks,
+ "max_batch_tokens": cb_max_batch_tokens,
+ "max_memory_percent": cb_max_memory_percent,
+ "use_cuda_graph": cb_use_cuda_graph,
+ }.items()
+ if v is not None
+ }
+ cb_config = ContinuousBatchingConfig(**cb_kwargs) if cb_kwargs else None
+ self._generation_state = GenerationState(
+ continuous_batching=continuous_batching,
+ compile=compile,
+ cb_config=cb_config,
+ )
- self.loading_tasks[model_id_and_revision] = asyncio.create_task(run_load())
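+        # One handler per OpenAI-compatible endpoint, all sharing the model manager and generation state.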
+ self._chat_handler = ChatCompletionHandler(
+ model_manager=self._model_manager,
+ generation_state=self._generation_state,
+ )
- while True:
- item = await queue.get()
- if item is None:
- break
- yield item
+ self._response_handler = ResponseHandler(
+ model_manager=self._model_manager,
+ generation_state=self._generation_state,
+ )
- return StreamingResponse(event_publisher(), media_type="text/event-stream")
+ self._transcription_handler = TranscriptionHandler(self._model_manager, self._generation_state)
- @app.middleware("http")
- async def get_or_set_request_id(request: Request, call_next):
- request_id = request.headers.get(X_REQUEST_ID) or str(uuid.uuid4())
- request.state.request_id = request_id
- response = await call_next(request)
- response.headers[X_REQUEST_ID] = request_id
- return response
+ app = build_server(
+ self._model_manager,
+ self._chat_handler,
+ response_handler=self._response_handler,
+ transcription_handler=self._transcription_handler,
+ enable_cors=enable_cors,
+ )
- config = uvicorn.Config(app, host=self.host, port=self.port, log_level="info")
+ config = uvicorn.Config(app, host=host, port=port, log_level="info")
self.server = uvicorn.Server(config)
- if self.non_blocking:
+ if non_blocking:
self.start_server()
else:
self.server.run()
def start_server(self):
def _run():
- self._loop = asyncio.new_event_loop()
- asyncio.set_event_loop(self._loop)
- # serve() is a coroutine; it exits when server.should_exit becomes True
- self._loop.run_until_complete(self.server.serve())
+ loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(loop)
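+            # serve() is a coroutine; it exits when server.should_exit becomes True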
+ loop.run_until_complete(self.server.serve())
self._thread = threading.Thread(target=_run, name="uvicorn-thread", daemon=False)
self._thread.start()
- def kill_server(self):
- if not self._thread:
- raise ValueError("The server cannot be killed as it was not launched in a separate thread.")
-
- if not self._thread.is_alive():
- raise ValueError("The server is already killed.")
-
- self.server.should_exit = True
- if self._thread and self._thread.is_alive():
- self._thread.join(timeout=2)
-
def reset_loaded_models(self):
- """
- Resets all loaded models.
- """
- if self.running_continuous_batching_manager is not None:
- logger.warning("Resetting the continuous batching manager.")
- self.running_continuous_batching_manager.stop(block=True, timeout=2)
- self.running_continuous_batching_manager = None
- for model in list(self.loaded_models.values()):
- model.delete_model()
- self.last_model = None
-
- def _validate_request(
- self,
- request: dict,
- schema: TypedDict,
- validator: "TypeAdapter",
- unused_fields: set,
- ):
- """
- Validates the request against the schema, and checks for unexpected keys.
-
- Args:
- request (`dict`):
- The request to validate.
- schema (`TypedDict`):
- The schema of the request to validate. It is a `TypedDict` definition.
- validator (`TypeAdapter`):
- The validator to use to validate the request. Built from `schema`.
- unused_fields (`set`):
- Fields accepted by `schema`, but not used in `transformers serve`.
-
- Raises:
- HTTPException: If the request is invalid or contains unexpected or unused fields.
- """
- logger.debug(f"Validating request: {request}")
-
- # Validate unexpected keys -- Pydantic doesn't validate extra keys in the request.
- input_keys = set(request.keys())
- possible_keys = schema.__mutable_keys__
- unexpected_keys = input_keys - possible_keys
- if unexpected_keys:
- logger.error(f"Unexpected keys in the request: {unexpected_keys}")
- raise HTTPException(status_code=422, detail=f"Unexpected keys in the request: {unexpected_keys}")
-
- if self.input_validation:
- # Validate expected keys
- try:
- validator.validate_python(request)
- except ValidationError as e:
- logger.error(f"Validation error: {e.errors()}")
- raise HTTPException(status_code=422, detail=e.errors())
-
- # Validate unused fields
- unused_fields_in_request = input_keys & unused_fields
- if unused_fields_in_request:
- logger.error(f"Unused fields in the request: {unused_fields_in_request}")
- raise HTTPException(
- status_code=422, detail=f"Unused fields in the request: {unused_fields_in_request}"
- )
-
- def validate_response_request(self, request: dict):
- self._validate_request(
- request=request,
- schema=TransformersResponseCreateParamsStreaming,
- validator=response_validator,
- unused_fields=UNUSED_RESPONSE_FIELDS,
- )
-
- def validate_chat_completion_request(self, request: dict):
- self._validate_request(
- request=request,
- schema=TransformersCompletionCreateParamsStreaming,
- validator=completion_validator,
- unused_fields=UNUSED_CHAT_COMPLETION_FIELDS,
- )
-
- def validate_transcription_request(self, request: dict):
- self._validate_request(
- request=request,
- schema=TransformersTranscriptionCreateParams,
- validator=transcription_validator,
- unused_fields=UNUSED_TRANSCRIPTION_FIELDS,
- )
-
- def build_chat_completion_chunk(
- self,
- request_id: str = "",
- content: int | None = None,
- model: str | None = None,
- role: str | None = None,
- finish_reason: str | None = None,
- tool_calls: list["ChoiceDeltaToolCall"] | None = None,
- decode_stream: DecodeStream | None = None,
- tokenizer: Optional["PreTrainedTokenizerFast"] = None,
- ) -> "ChatCompletionChunk":
- """
- Builds a chunk of a streaming OpenAI Chat Completion response.
-
- IMPORTANT: The serialized chunk won't contain empty fields (fields with `None`). Some downstream apps,
- like Cursor, assume that when the field exists, it has data.
-
- Args:
- request_id (`str`):
- The request ID.
- content (`str`, *optional*):
- Content of the response from the model.
- model (`str`, *optional*):
- The model that generated the content.
- role (`str`, *optional*):
- The role of the next content, until a new role is defined.
- finish_reason (`str`, *optional*):
- The reason the generation by the model has finished.
- tool_calls (`list[ChoiceDeltaToolCall]`, *optional*):
- Data about the tool calls, when they are triggered.
-
- Returns:
- `str`: The built chunk, a string containing a JSON string with the payload.
- """
- if decode_stream is not None and content is not None and tokenizer is not None:
- content = decode_stream.step(tokenizer._tokenizer, content)
-
- chunk = ChatCompletionChunk(
- id=request_id,
- created=int(time.time()),
- model=model,
- choices=[
- ChoiceChunk(
- delta=ChoiceDelta(
- content=content,
- role=role,
- tool_calls=tool_calls,
- ),
- index=0,
- finish_reason=finish_reason,
- )
- ],
- system_fingerprint="",
- object="chat.completion.chunk",
- )
-
- return chunk
-
- @staticmethod
- def chunk_to_sse_element(chunk: "ChatCompletionChunk | BaseModel") -> str:
- """
- Builds an event of a streaming OpenAI Response model or a ChatCompletion chunk.
-
- IMPORTANT: The serialized chunk won't contain empty fields (fields with `None`). Some downstream apps,
- like Cursor, assume that when the field exists, it has data.
-
- Args:
- chunk (`BaseModel` or `ChatCompletionChunk`):
- The response to build an event from. One of the multiple OpenAI Response output types
-
- Returns:
- `str`: The built chunk, a string containing a JSON string with the payload.
- """
- if isinstance(chunk, str):
- # Error paths may yield pre-formatted strings — pass them through as-is.
- return chunk if chunk.startswith("data: ") else f"data: {chunk}\n\n"
- return f"data: {chunk.model_dump_json(exclude_none=True)}\n\n"
-
- @staticmethod
- @lru_cache
- def get_gen_models(cache_dir: str | None = None) -> list[dict[str, any]]:
- """
- List LLMs and VLMs in the cache.
- """
- from transformers.models.auto.modeling_auto import (
- MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
- MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES,
- )
-
- generative_models = []
-
- logger.warning("Scanning the cache directory for LLMs and VLMs.")
- for repo in tqdm(scan_cache_dir(cache_dir).repos):
- if repo.repo_type != "model":
- continue
-
- refs = repo.refs
- for ref, revision_info in refs.items():
- files = revision_info.files
- config_path = next((f.file_path for f in files if f.file_name == "config.json"), None)
-
- if not config_path:
- continue
-
- config = json.loads(config_path.open().read())
-
- if not (isinstance(config, dict) and "architectures" in config):
- continue
-
- architectures = config["architectures"]
- llms = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values()
- vlms = MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.values()
-
- if any(arch for arch in architectures if arch in [*llms, *vlms]):
- author = repo.repo_id.split("/") if "/" in repo.repo_id else ""
- repo_handle = repo.repo_id + (f"@{ref}" if ref != "main" else "")
- generative_models.append(
- {
- "owned_by": author,
- "id": repo_handle,
- "object": "model",
- "created": repo.last_modified,
- }
- )
-
- return generative_models
-
- def continuous_batching_chat_completion(self, req: dict, request_id: str) -> "StreamingResponse | JSONResponse":
- """
- Generates an OpenAI Chat Completion using continuous batching.
-
- Args:
- req (`dict`): The request to generate an OpenAI Chat Completion for.
-
- Returns:
- `Generator[str, None, None]`: A generator that yields the OpenAI Chat Completion chunks.
- """
-
- model_id_and_revision = self.process_model_name(req["model"])
+ """Clear all loaded models from memory."""
+ self._model_manager.shutdown()
- with self._cb_manager_lock:
- must_discard_cache = model_id_and_revision != self.last_model
- self.last_model = model_id_and_revision
-
- # When switching models, terminate a continuous batching manager if it is running.
- if must_discard_cache:
- if self.running_continuous_batching_manager is not None:
- self.running_continuous_batching_manager.stop(block=True, timeout=2)
- self.running_continuous_batching_manager = None
-
- model, processor = self.load_model_and_processor(model_id_and_revision)
-
- # Continuous batching only supports text-only models
- if self.get_model_modality(model, processor=processor) != Modality.LLM:
- logger.warning_once(
- "Continuous batching is not supported for non-text-only models. Falling back to regular generate."
- )
- return self.generate_chat_completion(req)
-
- tokenizer = processor.tokenizer if hasattr(processor, "tokenizer") else processor
-
- generation_config = create_generation_config_from_req(
- req,
- model_generation_config=model.generation_config,
- eos_token_id=tokenizer.eos_token_id,
- pad_token_id=tokenizer.pad_token_id,
- use_cache=False,
- do_sample=False,
- scheduler="fifo",
- )
-
- if self.running_continuous_batching_manager is None:
- from transformers import ContinuousBatchingConfig
-
- # Build continuous batching config from CLI arguments
- cb_config_kwargs = {}
- if self.cb_block_size is not None:
- cb_config_kwargs["block_size"] = self.cb_block_size
- if self.cb_num_blocks is not None:
- cb_config_kwargs["num_blocks"] = self.cb_num_blocks
- if self.cb_max_batch_tokens is not None:
- cb_config_kwargs["max_batch_tokens"] = self.cb_max_batch_tokens
- if self.cb_max_memory_percent is not None:
- cb_config_kwargs["max_memory_percent"] = self.cb_max_memory_percent
- if self.cb_use_cuda_graph is not None:
- cb_config_kwargs["use_cuda_graph"] = self.cb_use_cuda_graph
-
- cb_config = ContinuousBatchingConfig(**cb_config_kwargs) if cb_config_kwargs else None
-
- self.running_continuous_batching_manager = model.init_continuous_batching(
- generation_config=generation_config,
- continuous_batching_config=cb_config,
- )
-
- # TODO (Joao, Lysandre): the logits processors should be fixed in continuous batching and correctly applied in non-cb
- self.running_continuous_batching_manager.logit_processor = LogitsProcessorList()
- self.running_continuous_batching_manager.start()
-
- # TODO (Joao, Lysandre): this should also work with tool support
- modality = self.get_model_modality(model, processor=processor)
- processor_inputs = self.get_processor_inputs_from_inbound_messages(req["messages"], modality)
-
- inputs = processor.apply_chat_template(
- processor_inputs,
- add_generation_prompt=True,
- tools=req.get("tools"),
- return_tensors="pt",
- return_dict=True,
- tokenize=True,
- )
- inputs = inputs["input_ids"][0].to(model.device)
-
- def stream_chat_completion(request_id, decode_stream):
- from ..generation.continuous_batching import RequestStatus
-
- try:
- # Emit the assistant role to start the stream. Other chunks won't have a role, as it is implicit
- # they come from the assistant.
- yield self.build_chat_completion_chunk(request_id, role="assistant", model=model_id_and_revision)
-
- n_tokens_generated = 0
- for result in self.running_continuous_batching_manager.request_id_iter(request_id):
- n_tokens_generated += 1
-
- # Always yield the token content (even for the final FINISHED token)
- if result.generated_tokens:
- token_id = result.generated_tokens[-1]
- yield self.build_chat_completion_chunk(
- request_id=request_id,
- content=token_id,
- model=model_id_and_revision,
- decode_stream=decode_stream,
- tokenizer=tokenizer,
- )
-
- if result.status == RequestStatus.FINISHED:
- generated_all_tokens = (
- generation_config.max_new_tokens is not None
- and n_tokens_generated >= generation_config.max_new_tokens
- )
-
- # If the tokenizer has an eos_token, we can have a more robust check.
- if hasattr(tokenizer, "eos_token"):
- final_token_is_eos = result == tokenizer.eos_token
- generated_all_tokens = generated_all_tokens and not final_token_is_eos
-
- reason = "length" if generated_all_tokens else "stop"
-
- yield self.build_chat_completion_chunk(
- request_id,
- finish_reason=reason,
- model=model_id_and_revision,
- )
- break
-
- except Exception as e:
- logger.error(str(e))
- self.running_continuous_batching_manager.cancel_request(request_id)
- yield f'data: {{"error": "{str(e)}"}}'
-
- def buffer_chat_completion(_request_id):
- result = None
- while self.running_continuous_batching_manager.is_running() and result is None:
- result = self.running_continuous_batching_manager.get_result(request_id=_request_id, timeout=1)
-
- if result is None:
- raise RuntimeError(f"Request {_request_id} failed: generation loop stopped before producing a result.")
-
- content = tokenizer.decode(result.generated_tokens)
-
- chat_completion_result = ChatCompletion(
- id=_request_id,
- created=int(time.time()),
- object="chat.completion",
- model=model_id_and_revision,
- choices=[
- Choice(
- # TODO check the index
- index=0,
- message=ChatCompletionMessage(content=content, role="assistant"),
- finish_reason="stop",
- )
- ],
- # TODO implement function calling
- # TODO implement usage
- )
-
- return chat_completion_result
-
- async def cancellation_wrapper_stream(_request_id):
- # Enables cancellation in an async context
- try:
- decode_stream = DecodeStream(inputs.tolist(), False)
- for _chunk in stream_chat_completion(_request_id, decode_stream):
- yield self.chunk_to_sse_element(_chunk)
- await asyncio.sleep(0)
- except asyncio.CancelledError:
- self.running_continuous_batching_manager.cancel_request(_request_id)
- logger.warning(f"Request {_request_id} was cancelled.")
-
- def cancellation_wrapper_buffer(_request_id):
- # Enables cancellation in an async context
- try:
- return buffer_chat_completion(_request_id)
- except asyncio.CancelledError:
- self.running_continuous_batching_manager.cancel_request(_request_id)
- logger.warning(f"Request {_request_id} was cancelled.")
-
- request_id = self.running_continuous_batching_manager.add_request(
- inputs, request_id=request_id, max_new_tokens=generation_config.max_new_tokens, streaming=req.get("stream")
- )
-
- if req.get("stream"):
- return StreamingResponse(cancellation_wrapper_stream(request_id), media_type="text/event-stream")
- else:
- chunk = cancellation_wrapper_buffer(request_id)
- json_chunk = chunk.model_dump(exclude_none=True)
- return JSONResponse(json_chunk, media_type="application/json")
-
- @staticmethod
- def get_model_modality(model: "PreTrainedModel", processor=None) -> Modality:
- if processor is not None:
- if isinstance(processor, PreTrainedTokenizerBase):
- return Modality.LLM
-
- from transformers.models.auto.modeling_auto import (
- MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
- MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES,
- )
-
- model_classname = model.__class__.__name__
- if model_classname in MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.values():
- modality = Modality.VLM
- elif model_classname in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
- modality = Modality.LLM
- else:
- raise ValueError(f"Unknown modality: {model_classname}")
-
- return modality
-
- @staticmethod
- def get_processor_inputs_from_inbound_messages(messages, modality: Modality):
- processor_inputs = []
-
- for message in messages:
- parsed_message = {"role": message["role"], "content": []}
-
- if modality == Modality.LLM:
- # Input: `content` is a string or a list of dictionaries with a "text" key.
- # Output: `content` is a string.
- if isinstance(message["content"], str):
- parsed_content = message["content"]
- elif isinstance(message["content"], list):
- parsed_content = []
- for content in message["content"]:
- if content["type"] == "text":
- parsed_content.append(content["text"])
- parsed_content = " ".join(parsed_content)
- parsed_message["content"] = parsed_content
-
- elif modality == Modality.VLM:
- # Input: `content` is a string or a list of dictionaries with a "type" key (possible types: "text",
- # "image_url").
- # Output: `content` is a list of dictionaries with a "type" key
- if isinstance(message["content"], str):
- parsed_message["content"].append({"type": "text", "text": message["content"]})
- else:
- for content in message["content"]:
- if content["type"] == "text":
- parsed_message["content"].append(content)
- elif content["type"] == "image_url":
- if "base64" in content["image_url"]["url"]:
- image_data = re.sub("^data:image/.+;base64,", "", content["image_url"]["url"])
- image = Image.open(BytesIO(base64.b64decode(image_data)))
-
- file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
- url = file.name
-
- image.save(file.name)
- else:
- url = content["image_url"]["url"]
-
- parsed_message["content"].append({"type": "image", "url": url})
- processor_inputs.append(parsed_message)
- return processor_inputs
-
- def generate_chat_completion(self, req: dict) -> "StreamingResponse | JSONResponse":
- """
- Generates an OpenAI Chat Completion using `generate`.
-
- Args:
- req (`dict`): The request to generate an OpenAI Chat Completion for.
-
- Returns:
- `Generator[str, None, None]`: A generator that yields the OpenAI Chat Completion chunks.
- """
-
- # TODO: This should throw an error in case the specified model in the request is different to the forced model.
- if self.force_model is not None:
- req["model"] = self.force_model
-
- messages: Iterable[ChatCompletionMessageParam] = req["messages"]
-
- # HACK for tiny-agents: it sends a request after the assistant message (???). Let's assume we can't have a
- # request whose last message is from the assistant.
- if messages[-1]["role"] == "assistant":
+ def kill_server(self):
+ self._generation_state.shutdown()
+ self._model_manager.shutdown()
+        if getattr(self, "_thread", None) is None or not self._thread.is_alive():
return
-
- model_id_and_revision = self.process_model_name(req["model"])
- must_discard_cache = model_id_and_revision != self.last_model
-
- self.last_model = model_id_and_revision
- model, processor = self.load_model_and_processor(model_id_and_revision)
-
- modality = self.get_model_modality(model, processor=processor)
- processor_inputs = self.get_processor_inputs_from_inbound_messages(messages, modality)
-
- # ====== TOOL PREPROCESSING LOGIC ======
- tool_model_family = None
- for supported_model_families in _MODELS_WITH_TOOL_SUPPORT:
- if supported_model_families in model.config.architectures[0].lower():
- tool_model_family = supported_model_families
- break
- # TODO: trigger 2 constrained generations after the tool call start token is emitted:
- # 1. force generation to pick from the tool names
- # 2. force generation to pick from that tool's arguments
- # ====== END OF TOOL PREPROCESSING LOGIC ======
-
- inputs = processor.apply_chat_template(
- processor_inputs,
- add_generation_prompt=True,
- tools=req.get("tools"),
- return_tensors="pt",
- return_dict=True,
- tokenize=True,
- )
- inputs = inputs.to(model.device)
- request_id = req.get("request_id", "req_0")
-
- # Temporary hack for GPTOSS 1: don't filter special tokens
- skip_special_tokens = True
- if "gptoss" in model.config.architectures[0].lower():
- skip_special_tokens = False
-
- generation_streamer = TextIteratorStreamer(
- processor,
- skip_special_tokens=skip_special_tokens,
- skip_prompt=True,
- )
-
- if self.is_continuation(req) and not must_discard_cache:
- seq_len = self.last_kv_cache.get_seq_length()
- if inputs["input_ids"].shape[-1] > seq_len:
- last_kv_cache = self.last_kv_cache
- else:
- last_kv_cache = None
- else:
- seq_len = inputs["input_ids"].shape[-1]
- last_kv_cache = None
-
- generation_config = create_generation_config_from_req(
- req,
- model_generation_config=model.generation_config,
- )
-
- generation_kwargs = {
- **inputs,
- "streamer": generation_streamer,
- "generation_config": generation_config,
- "return_dict_in_generate": True,
- "past_key_values": last_kv_cache,
- }
-
- def stream_chat_completion(streamer, _request_id):
- # Temporary hack for GPTOS 2: filter out the CoT tokens. Full solution here implies defining new output
- # classes and piping the reasoning trace into a new field
- filter_cot = False
- cot_trace_end = None
- if "gptoss" in model.config.architectures[0].lower():
- filter_cot = True
- cot_trace_end = "<|channel|>final<|message|>"
-
- # Thin wrapper to save the KV cache after generation
- def generate_with_cache(**kwargs):
- generate_output = model.generate(**kwargs)
- self.last_kv_cache = generate_output.past_key_values
-
- thread = Thread(target=generate_with_cache, kwargs=generation_kwargs)
- results = ""
-
- try:
- thread.start()
- tool_state = ToolState()
-
- # Emit the assistant role to start the stream. Other chunks won't have a role, as it is implicit
- # they come from the assistant.
- yield self.build_chat_completion_chunk(request_id, role="assistant", model=model_id_and_revision)
-
- result = ""
- n_tokens_generated = 0
-
- for result in streamer:
- n_tokens_generated += 1
-
- # Temporary hack for GPT-OSS 3: don't emit the final "<|return|>"
- if "gptoss" in model.config.architectures[0].lower():
- result = result.removesuffix("<|return|>")
- results += result
-
- # (related to temporary hack 2)
- if filter_cot:
- if cot_trace_end in results: # end of reasoning trace observed -> stop filtering
- filter_cot = False
- continue
- else:
- continue
-
- # ====== TOOL CALL LOGIC ======
- if tool_model_family is not None:
- # Start of a tool call: reset state variables, set `inside_tool_call`
- if result.strip() == _TOOL_CALL_TOKENS[tool_model_family]["start"]:
- tool_state.inside_tool_call = True
- continue
-
- # End of tool call: reset `inside_tool_call`, emit a `finish_reason`
- if result.strip() == _TOOL_CALL_TOKENS[tool_model_family]["end"]:
- tool_state.reset()
- yield self.build_chat_completion_chunk(
- request_id=_request_id,
- role=None,
- finish_reason="tool_calls",
- model=model_id_and_revision,
- )
-
- continue
- # Inside a tool call
- if tool_state.inside_tool_call:
- tool_state.buffer += result
-
- # First step: extract the tool name (may need several tokens, and we can't emit a delta
- # until we have the full name)
- if not tool_state.has_tool_name_defined:
- tool_name = re.search(r"\"name\": \"(.*?)\"", tool_state.buffer)
- if tool_name is None:
- continue
- else:
- tool_name = tool_name.group(1)
- tool_state.has_tool_name_defined = True
- tool = ChoiceDeltaToolCall(
- function=ChoiceDeltaToolCallFunction(name=tool_name),
- index=0,
- type="function",
- id=_request_id + "_tool_call", # Only the first tool call delta has an id
- )
-
- # Second step: extract tool arguments. The tool arguments can be seen as a json string
- # within the tool json string. We emit a delta for the arguments.
- else:
- # Empty text: skip
- if result == "":
- continue
- # Until we see the `"arguments": {` in the buffer, we skip
- # TODO: other models will likely need more elaborate processing here
- if '"arguments": {' not in tool_state.buffer:
- continue
-
- # Handle nesting. We want to exclude the last } from the emitted arguments (it's
- # closing the outermost nesting level, outside the arguments block)
- tool_state.arg_nesting_level += result.count("{")
- tool_state.arg_nesting_level -= result.count("}")
- if tool_state.arg_nesting_level < 0:
- result = "".join(result.split("}")[:-2]) + "}" # e.g. "4}}\n" -> "4}"
-
- tool = ChoiceDeltaToolCall(
- function=ChoiceDeltaToolCallFunction(arguments=result),
- index=0,
- type="function",
- )
-
- yield self.build_chat_completion_chunk(
- request_id=_request_id,
- role=None,
- tool_calls=[tool],
- model=model_id_and_revision,
- )
- continue
- # ====== END OF TOOL CALL LOGIC ======
-
- # All non-tool related tokens are emitted as assistant messages. Empty text is skipped.
- if result != "":
- yield self.build_chat_completion_chunk(
- _request_id, content=result, model=model_id_and_revision
- )
-
- generated_all_tokens = (
- generation_config.max_new_tokens is not None
- and n_tokens_generated >= generation_config.max_new_tokens
- )
-
- # If the tokenizer has an eos_token, we can have a more robust check.
- if hasattr(streamer.tokenizer, "eos_token"):
- final_token_is_eos = result == streamer.tokenizer.eos_token
- generated_all_tokens = generated_all_tokens and not final_token_is_eos
-
- reason = "length" if generated_all_tokens else "stop"
-
- yield self.build_chat_completion_chunk(_request_id, finish_reason=reason, model=model_id_and_revision)
-
- thread.join()
- except Exception as e:
- logger.error(str(e))
- yield f'data: {{"error": "{str(e)}"}}'
-
- finally:
- thread.join()
-
- if req.get("stream"):
- return StreamingResponse(
- map(self.chunk_to_sse_element, stream_chat_completion(generation_streamer, request_id)),
- media_type="text/event-stream",
- )
- else:
- content = []
- finish_reason = "stop"
-
- generator = stream_chat_completion(generation_streamer, request_id)
- usage = None
-
- for chunk in generator:
- choice = chunk.choices[0]
- if getattr(choice.delta, "content", None):
- content.append(choice.delta.content)
- if choice.finish_reason:
- finish_reason = choice.finish_reason
- if getattr(chunk, "usage", None):
- usage = chunk.usage
-
- chat_completion_result = ChatCompletion(
- id=request_id,
- created=int(time.time()),
- object="chat.completion",
- model=model_id_and_revision,
- choices=[
- Choice(
- # TODO check the index
- index=0,
- message=ChatCompletionMessage(content="".join(content), role="assistant"),
- finish_reason=finish_reason,
- )
- ],
- # TODO implement function calling
- usage=usage,
- )
-
- result = chat_completion_result.model_dump(exclude_none=True)
-
- return JSONResponse(result, media_type="application/json")
-
- def generate_response(self, req: dict) -> Generator[str, None, None]:
- """
- Generates an OpenAI Response using `generate`.
-
- Args:
- req (`dict`): The request to generate an OpenAI Response for.
-
- Returns:
- `Generator[str, None, None]`: A generator that yields the OpenAI Response events.
- """
- # TODO -- Implement non-streaming mode
- model_id_and_revision = self.process_model_name(req["model"])
- must_discard_cache = model_id_and_revision != self.last_model
- self.last_model = model_id_and_revision
- model, processor = self.load_model_and_processor(model_id_and_revision)
-
- if isinstance(req["input"], str):
- inputs = [{"role": "system", "content": req["instructions"]}] if "instructions" in req else []
- inputs.append({"role": "user", "content": req["input"]})
- elif isinstance(req["input"], list):
- if "instructions" in req:
- if req["input"][0]["role"] != "system":
- inputs = [{"role": "system", "content": req["instructions"]}, *req["input"]]
- else:
- inputs = req["input"]
- inputs[0]["content"] = req["instructions"]
- else:
- inputs = req["input"]
- elif isinstance(req["input"], dict):
- inputs = [{"role": "system", "content": req["instructions"]}] if "instructions" in req else []
- inputs.append(req["input"])
- else:
- raise TypeError("inputs should be a list, dict, or str")
-
- inputs = processor.apply_chat_template(inputs, tokenize=True, add_generation_prompt=True, return_tensors="pt")
- inputs = inputs.to(model.device)
- request_id = req.get("previous_response_id", "req_0")
-
- # Temporary hack for GPT-OSS 1: don't filter special tokens
- skip_special_tokens = True
- if "gptoss" in model.config.architectures[0].lower():
- skip_special_tokens = False
-
- generation_streamer = TextIteratorStreamer(
- processor,
- skip_special_tokens=skip_special_tokens,
- skip_prompt=True,
- )
- generation_config = create_generation_config_from_req(req, model_generation_config=model.generation_config)
-
- last_kv_cache = None
- if self.is_continuation(req) and not must_discard_cache:
- seq_len = self.last_kv_cache.get_seq_length()
- if inputs.shape[-1] > seq_len:
- last_kv_cache = self.last_kv_cache
-
- generation_kwargs = {
- "inputs": inputs,
- "attention_mask": torch_ones_like(inputs),
- "streamer": generation_streamer,
- "generation_config": generation_config,
- "return_dict_in_generate": True,
- "past_key_values": last_kv_cache,
- }
-
- def stream_response(streamer, _request_id):
- # Temporary hack for GPT-OSS 2: filter out the CoT tokens. Full solution here implies defining new output
- # classes and piping the reasoning trace into a new field
- filter_cot = False
- cot_trace_end = None
- if "gptoss" in model.config.architectures[0].lower():
- filter_cot = True
- cot_trace_end = "<|channel|>final<|message|>"
-
- # Thin wrapper to save the KV cache after generation
- def generate_with_cache(**kwargs):
- generate_output = model.generate(**kwargs)
- self.last_kv_cache = generate_output.past_key_values
-
- thread = Thread(target=generate_with_cache, kwargs=generation_kwargs)
- sequence_number = 0
- output_index = 0
- content_index = 0
-
- try:
- thread.start()
- created_at = time.time() # the spec expects a unix timestamp in seconds
-
- # We start by acknowledging the request (the request has `status="queued"`), and then by moving it to
- # in progress (`status="in_progress"`)
- response_created = ResponseCreatedEvent(
- type="response.created",
- sequence_number=sequence_number,
- response=Response(
- id=f"resp_{request_id}",
- created_at=created_at,
- status="queued",
- model=model_id_and_revision,
- instructions=req.get("instructions"),
- text={"format": {"type": "text"}},
- object="response",
- tools=[],
- output=[],
- parallel_tool_calls=req.get("parallel_tool_calls", False),
- tool_choice="auto",
- metadata=req.get("metadata"),
- ),
- )
- sequence_number += 1
- yield self.chunk_to_sse_element(response_created)
-
- response_in_progress = ResponseInProgressEvent(
- type="response.in_progress",
- sequence_number=sequence_number,
- response=Response(
- id=f"resp_{request_id}",
- created_at=created_at,
- status="in_progress",
- model=model_id_and_revision,
- instructions=req.get("instructions"),
- text={"format": {"type": "text"}},
- object="response",
- tools=[],
- output=[],
- parallel_tool_calls=req.get("parallel_tool_calls", False),
- tool_choice="auto",
- metadata=req.get("metadata"),
- ),
- )
- sequence_number += 1
- yield self.chunk_to_sse_element(response_in_progress)
-
- # Start the output item. Emit the assistant role to start the stream. Other chunks won't have a role,
- # as it is implicit
- response_output_item_added = ResponseOutputItemAddedEvent(
- type="response.output_item.added",
- sequence_number=sequence_number,
- output_index=output_index,
- item=ResponseOutputMessage(
- id=f"msg_{request_id}", type="message", status="in_progress", role="assistant", content=[]
- ),
- )
- sequence_number += 1
- yield self.chunk_to_sse_element(response_output_item_added)
-
- # Start the content part of the event
- response_content_part_added = ResponseContentPartAddedEvent(
- type="response.content_part.added",
- item_id=f"msg_{request_id}",
- sequence_number=sequence_number,
- output_index=output_index,
- content_index=content_index,
- part=ResponseOutputText(type="output_text", text="", annotations=[]),
- )
- sequence_number += 1
- yield self.chunk_to_sse_element(response_content_part_added)
-
- # Stream the actual generated text
- results = ""
- for result in streamer:
- # Temporary hack for GPTOS 3: don't emit the final "<|return|>"
- if "gptoss" in model.config.architectures[0].lower():
- result = result.removesuffix("<|return|>")
- results += result
-
- # (related to temporary hack 2)
- if filter_cot:
- if cot_trace_end in results: # end of reasoning trace observed -> stop filtering
- filter_cot = False
- results = "" # reset the results -> results will now track the final response
- continue
- else:
- response_output_text_delta = ResponseTextDeltaEvent(
- type="response.output_text.delta",
- item_id=f"msg_{request_id}",
- sequence_number=sequence_number,
- output_index=output_index,
- content_index=content_index,
- delta=result,
- logprobs=[],
- )
- sequence_number += 1
- yield self.chunk_to_sse_element(response_output_text_delta)
- else:
- # Normal path: emit token deltas when not filtering CoT
- if result:
- response_output_text_delta = ResponseTextDeltaEvent(
- type="response.output_text.delta",
- item_id=f"msg_{request_id}",
- sequence_number=sequence_number,
- output_index=output_index,
- content_index=content_index,
- delta=result,
- logprobs=[],
- )
- sequence_number += 1
- yield self.chunk_to_sse_element(response_output_text_delta)
-
- # Signal the end of the text generation
- response_output_text_done = ResponseTextDoneEvent(
- type="response.output_text.done",
- item_id=f"msg_{request_id}",
- sequence_number=sequence_number,
- output_index=output_index,
- content_index=0,
- text=results,
- logprobs=[],
- )
- sequence_number += 1
- yield self.chunk_to_sse_element(response_output_text_done)
-
- # Complete the content part
- response_content_part_done = ResponseContentPartDoneEvent(
- type="response.content_part.done",
- item_id=f"msg_{request_id}",
- sequence_number=sequence_number,
- output_index=output_index,
- content_index=content_index,
- part=ResponseOutputText(type="output_text", text=response_output_text_done.text, annotations=[]),
- )
- sequence_number += 1
- content_index += 1
- yield self.chunk_to_sse_element(response_content_part_done)
-
- # Complete the output item
- response_output_item_done = ResponseOutputItemDoneEvent(
- type="response.output_item.done",
- sequence_number=sequence_number,
- output_index=output_index,
- item=ResponseOutputMessage(
- id=f"msg_{request_id}",
- type="message",
- status="completed",
- role="assistant",
- content=[response_content_part_done.part],
- annotations=[],
- ),
- )
- sequence_number += 1
- output_index += 1
- yield self.chunk_to_sse_element(response_output_item_done)
-
- # Finally, Complete the event
- response_completed = ResponseCompletedEvent(
- type="response.completed",
- sequence_number=sequence_number,
- response=Response(
- id=f"resp_{request_id}",
- created_at=created_at,
- status="completed",
- model=model_id_and_revision,
- instructions=req.get("instructions"),
- text={"format": {"type": "text"}},
- output=[response_output_item_done.item],
- object="response",
- tools=[],
- parallel_tool_calls=req.get("parallel_tool_calls", False),
- tool_choice="auto",
- metadata=req.get("metadata"),
- ),
- )
- sequence_number += 1
- yield self.chunk_to_sse_element(response_completed)
-
- thread.join()
- except Exception as e:
- logger.error(f"Exception in response generation: {str(e)}")
- error_event = ResponseErrorEvent(
- type="error",
- sequence_number=sequence_number,
- message=str(e),
- )
- sequence_number += 1
- yield self.chunk_to_sse_element(error_event)
-
- response_failed = ResponseFailedEvent(
- type="response.failed",
- sequence_number=sequence_number,
- response=Response(
- id=f"resp_{request_id}",
- created_at=created_at,
- status="failed",
- model=model_id_and_revision,
- instructions=req.get("instructions"),
- text={"format": {"type": "text"}},
- output=[],
- object="response",
- tools=[],
- parallel_tool_calls=False,
- tool_choice="auto",
- metadata=req.get("metadata"),
- error=ResponseError(
- code="server_error",
- message=str(e),
- ),
- ),
- )
- sequence_number += 1
- yield self.chunk_to_sse_element(response_failed)
-
- finally:
- thread.join()
-
- return stream_response(generation_streamer, request_id)
-
- def generate_response_non_streaming(self, req: dict) -> dict:
- """
- Generates an OpenAI Response in non-streaming mode (single JSON payload).
-
- Args:
- req (`dict`): The request to generate an OpenAI Response for.
-
- Returns:
- `dict`: The OpenAI `Response` serialized as a dict.
- """
- model_id_and_revision = self.process_model_name(req["model"])
- must_discard_cache = model_id_and_revision != self.last_model
- self.last_model = model_id_and_revision
- model, processor = self.load_model_and_processor(model_id_and_revision)
-
- if isinstance(req["input"], str):
- inputs = [{"role": "system", "content": req["instructions"]}] if "instructions" in req else []
- inputs.append({"role": "user", "content": req["input"]})
- elif isinstance(req["input"], list):
- if "instructions" in req:
- if req["input"][0]["role"] != "system":
- inputs = [{"role": "system", "content": req["instructions"]}, *req["input"]]
- else:
- inputs = req["input"]
- inputs[0]["content"] = req["instructions"]
- else:
- inputs = req["input"]
- elif isinstance(req["input"], dict):
- inputs = [{"role": "system", "content": req["instructions"]}] if "instructions" in req else []
- inputs.append(req["input"])
- else:
- raise ValueError("inputs should be a list, dict, or str")
-
- inputs = processor.apply_chat_template(
- inputs, add_generation_prompt=True, return_tensors="pt", return_dict=True
- )["input_ids"]
- inputs = inputs.to(model.device)
- request_id = req.get("previous_response_id", "req_0")
-
- # Temporary hack for GPTOSS 1: don't filter special tokens
- skip_special_tokens = True
- if "gptoss" in model.config.architectures[0].lower():
- skip_special_tokens = False
-
- generation_config = create_generation_config_from_req(req, model_generation_config=model.generation_config)
-
- last_kv_cache = None
- if self.is_continuation(req) and not must_discard_cache:
- seq_len = self.last_kv_cache.get_seq_length()
- if inputs.shape[-1] > seq_len:
- last_kv_cache = self.last_kv_cache
-
- generate_output = model.generate(
- inputs=inputs,
- attention_mask=torch_ones_like(inputs),
- generation_config=generation_config,
- return_dict_in_generate=True,
- past_key_values=last_kv_cache,
- )
- # save KV cache
- self.last_kv_cache = generate_output.past_key_values
-
- # Decode full text
- full_text = processor.batch_decode(generate_output.sequences, skip_special_tokens=skip_special_tokens)[0]
-
- created_at = time.time()
- response_output_item = ResponseOutputMessage(
- id=f"msg_{request_id}",
- type="message",
- status="completed",
- role="assistant",
- content=[ResponseOutputText(type="output_text", text=full_text, annotations=[])],
- annotations=[],
- )
- response_completed = Response(
- id=f"resp_{request_id}",
- created_at=created_at,
- status="completed",
- model=model_id_and_revision,
- instructions=req.get("instructions"),
- text={"format": {"type": "text"}},
- output=[response_output_item],
- object="response",
- tools=[],
- parallel_tool_calls=req.get("parallel_tool_calls", False),
- tool_choice="auto",
- metadata=req.get("metadata"),
- )
- return response_completed.model_dump(exclude_none=True)
-
- def generate_transcription(self, req: dict) -> Generator[str, None, None]:
- """
- Generates an OpenAI Transcription using the audio file.
-
- Args:
- req (`dict`): The request containing the audio file and model information.
-
- Returns:
- `Generator[str, None, None]`: A generator that yields the transcription result.
- """
- # TODO: implement streaming transcription (currently, it's not streaming)
- if not is_librosa_available():
- raise ImportError(
- "Missing librosa dependency for audio transcription. Please install with `pip install librosa`"
- )
- model_id_and_revision = self.process_model_name(req["model"])
- audio_model, audio_processor = self.load_audio_model_and_processor(model_id_and_revision)
-
- generation_streamer = TextIteratorStreamer(
- audio_processor.tokenizer, skip_special_tokens=True, skip_prompt=True
- )
- generation_config = create_generation_config_from_req(
- req, model_generation_config=audio_model.generation_config
- )
-
- # Read the binary audio file using librosa
- model_sampling_rate = audio_processor.feature_extractor.sampling_rate
- audio_bytes = io.BytesIO(req["file"])
- audio_array, _ = librosa.load(audio_bytes, sr=model_sampling_rate, mono=True)
- audio_inputs = audio_processor(audio_array, sampling_rate=model_sampling_rate, return_tensors="pt").to(
- audio_model.device
- )
- audio_inputs["input_features"] = audio_inputs["input_features"].to(audio_model.dtype)
-
- generation_kwargs = {
- "streamer": generation_streamer,
- "generation_config": generation_config,
- "return_dict_in_generate": True,
- }
-
- def _generate_transcription():
- generated_ids = audio_model.generate(**audio_inputs, **generation_kwargs)
- transcription_text = audio_processor.batch_decode(generated_ids.sequences, skip_special_tokens=True)[0]
- transcription = Transcription(text=transcription_text)
- yield f"{transcription.model_dump_json(exclude_none=True)}"
-
- return _generate_transcription()
-
- def is_continuation(self, req: dict) -> bool:
- """
- Determines whether the current request is a continuation of the last request. In other words, if it is the
- same chat session.
-
- Args:
- req (`dict`): The request to check.
-
- Returns:
- `True` if the request is a continuation of the last request, `False` otherwise.
- """
- messages = req.get("messages") or req.get("input") # ChatCompletion and Response have different fields
- req_continues_last_messages = True
-
- # No cached messages: this is a new request
- if self.last_messages is None:
- req_continues_last_messages = False
- # The new request has no new rounds of conversation: this is a new request
- elif len(self.last_messages) >= len(messages):
- req_continues_last_messages = False
- # Otherwise, check that the last messages are a subset of the new request
- else:
- for i in range(len(self.last_messages)):
- if self.last_messages[i] != messages[i]:
- req_continues_last_messages = False
- break
-
- self.last_messages = messages
- return req_continues_last_messages
-
- def get_quantization_config(self) -> BitsAndBytesConfig | None:
- """
- Returns the quantization config for the given CLI arguments.
-
- Returns:
- `Optional[BitsAndBytesConfig]`: The quantization config.
- """
- if self.quantization == "bnb-4bit":
- quantization_config = BitsAndBytesConfig(
- load_in_4bit=True,
- bnb_4bit_quant_type="nf4",
- bnb_4bit_use_double_quant=True,
- )
- elif self.quantization == "bnb-8bit":
- quantization_config = BitsAndBytesConfig(load_in_8bit=True)
- else:
- quantization_config = None
-
- if quantization_config is not None:
- logger.warning(f"Quantization applied with the following config: {quantization_config}")
-
- return quantization_config
-
- def process_model_name(self, model_id: str) -> str:
- """
- Applies the `force_model` CLI argument and canonicalizes the model name to the format "model_id@revision".
- If the model_id DOESN'T contain an @, it defaults to "model_id@main".
-
- Args:
- model_id (`str`): The model ID.
-
- Returns:
- `str`: The canonicalized model name to be used
- """
- if self.force_model is not None:
- model_id = self.force_model
- if "@" in model_id:
- return model_id
- return f"{model_id}@main"
-
- def _load_model_and_data_processor(
- self, model_id_and_revision: str, progress_callback: Callable[[dict], None] | None = None
- ):
- """
- Generic method to load a model and a data processor from a model ID and revision, making use of the serve CLI
- arguments.
-
- Args:
- model_id_and_revision (`str`):
- The model ID and revision to load.
- model_cls (`type[PreTrainedModel]`):
- The model class to load.
-
- Returns:
- `tuple[PreTrainedModel, Union[ProcessorMixin, PreTrainedTokenizerFast]]`: The loaded model and
- data processor (tokenizer, audio processor, etc.).
- """
- import torch
-
- from transformers import AutoConfig, AutoProcessor
-
- callback = progress_callback
- mid = model_id_and_revision
-
- tqdm_class = set_tqdm_class(callback, mid) if callback is not None else None
-
- def emit_progress(stage: str):
- if progress_callback is None:
- return
- progress_callback({"status": "loading", "model": model_id_and_revision, "stage": stage})
-
- logger.warning(f"Loading {model_id_and_revision}")
- emit_progress("processor")
-
- if "@" in model_id_and_revision:
- model_id, revision = model_id_and_revision.split("@", 1)
- else:
- model_id, revision = model_id_and_revision, "main"
-
- try:
- data_processor = AutoProcessor.from_pretrained(
- model_id,
- revision=revision,
- trust_remote_code=self.trust_remote_code,
- )
- except OSError:
- try:
- data_processor = AutoTokenizer.from_pretrained(
- model_id,
- revision=revision,
- trust_remote_code=self.trust_remote_code,
- )
- except OSError:
- raise OSError("Failed to load processor with `AutoProcessor` and `AutoTokenizer`.")
-
- # processor done — move on to config
- dtype = self.dtype if self.dtype in ["auto", None] else getattr(torch, self.dtype)
- quantization_config = self.get_quantization_config()
-
- model_kwargs = {
- "revision": revision,
- "attn_implementation": self.attn_implementation,
- "dtype": dtype,
- "device_map": self.device,
- "trust_remote_code": self.trust_remote_code,
- "quantization_config": quantization_config,
- }
-
- emit_progress("config")
- config = AutoConfig.from_pretrained(model_id, **model_kwargs)
- architecture = getattr(transformers, config.architectures[0])
- # weights stage events are emitted by _WeightsTqdm (and download by _DownloadAggregator)
- model = architecture.from_pretrained(model_id, tqdm_class=tqdm_class, **model_kwargs)
-
- has_default_max_length = (
- model.generation_config.max_new_tokens is None and model.generation_config.max_length == 20
- )
- has_short_max_new_tokens = (
- model.generation_config.max_new_tokens is not None and model.generation_config.max_new_tokens < 1024
- )
- if has_default_max_length or has_short_max_new_tokens:
- model.generation_config.max_new_tokens = 1024
-
- return model, data_processor
-
- def load_model_and_processor(
- self, model_id_and_revision: str, progress_callback: Callable[[dict], None] | None = None
- ) -> tuple["PreTrainedModel", "PreTrainedTokenizerFast"]:
- """
- Loads the text model and processor from the given model ID and revision into the ServeCommand instance.
-
- Args:
- model_id_and_revision (`str`):
- The model ID and revision to load.
-
- Returns:
- `tuple[PreTrainedModel, PreTrainedTokenizerFast]`: The loaded text model and processor.
- """
- with self.model_locks_guard:
- lock = self.model_locks.setdefault(model_id_and_revision, threading.Lock())
-
- with lock:
- if (
- model_id_and_revision not in self.loaded_models
- or self.loaded_models[model_id_and_revision].is_deleted()
- ):
- model, processor = self._load_model_and_data_processor(
- model_id_and_revision, progress_callback=progress_callback
- )
- self.loaded_models[model_id_and_revision] = TimedModel(
- model,
- timeout_seconds=self.model_timeout,
- processor=processor,
- )
- if progress_callback is not None:
- progress_callback({"status": "ready", "model": model_id_and_revision, "cached": False})
- else:
- self.loaded_models[model_id_and_revision].reset_timer()
- model = self.loaded_models[model_id_and_revision].model
- processor = self.loaded_models[model_id_and_revision].processor
- if progress_callback is not None:
- progress_callback({"status": "ready", "model": model_id_and_revision, "cached": True})
-
- return model, processor
-
- def load_audio_model_and_processor(self, model_id_and_revision: str) -> tuple["PreTrainedModel", "ProcessorMixin"]:
- """
- Loads the audio model and processor from the given model ID and revision into the ServeCommand instance.
-
- Args:
- model_id_and_revision (`str`):
- The model ID and revision to load.
-
- Returns:
- `tuple[PreTrainedModel, ProcessorMixin]`: The loaded audio model and processor.
- """
- with self.model_locks_guard:
- lock = self.model_locks.setdefault(model_id_and_revision, threading.Lock())
-
- with lock:
- if (
- model_id_and_revision not in self.loaded_models
- or self.loaded_models[model_id_and_revision].is_deleted()
- ):
- logger.warning(f"Loading model into cache: {model_id_and_revision}")
- audio_model, audio_processor = self._load_model_and_data_processor(model_id_and_revision)
- self.loaded_models[model_id_and_revision] = TimedModel(
- audio_model,
- timeout_seconds=self.model_timeout,
- processor=audio_processor,
- )
- else:
- self.loaded_models[model_id_and_revision].reset_timer()
- audio_model = self.loaded_models[model_id_and_revision].model
- audio_processor = self.loaded_models[model_id_and_revision].processor
-
- return audio_model, audio_processor
+ self.server.should_exit = True
+ self._thread.join(timeout=2)
-# set docstring separately to make it look nice (Typer doesn't play well with the class command)
Serve.__doc__ = """
Run a FastAPI server to serve models on-demand with an OpenAI compatible API.
-
Models will be loaded and unloaded automatically based on usage and a timeout.
\b
-The server will expose the following endpoints:
- - POST /v1/chat/completions: Generates chat completions.
- - POST /v1/responses: Generates responses.
- - POST /v1/audio/transcriptions: Generates transcriptions from audio.
- - GET /v1/models: Lists available models for 3rd party tools.
+Endpoints:
+    POST /v1/chat/completions — Chat completions (streaming + non-streaming).
+    POST /v1/responses — OpenAI Responses API (streaming + non-streaming).
+    GET /v1/models — Lists available models.
+    GET /health — Health check.
-Requires FastAPI and Uvicorn to be installed.
+Requires FastAPI and Uvicorn: pip install transformers[serving]
"""
-
-if __name__ == "__main__":
- serve = Serve()
diff --git a/src/transformers/cli/serving/__init__.py b/src/transformers/cli/serving/__init__.py
new file mode 100644
index 000000000000..118d3a9c2012
--- /dev/null
+++ b/src/transformers/cli/serving/__init__.py
@@ -0,0 +1,17 @@
+# Copyright 2026 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .model_manager import ModelManager
+from .server import build_server
+from .utils import Modality
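+
+# Usage sketch (assumed wiring; `build_server`'s exact signature lives in
+# .server and may differ):
+#
+#   import uvicorn
+#   from transformers.cli.serving import ModelManager, build_server
+#
+#   app = build_server(ModelManager(device="auto"))
+#   uvicorn.run(app, host="127.0.0.1", port=8000)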
diff --git a/src/transformers/cli/serving/chat_completion.py b/src/transformers/cli/serving/chat_completion.py
new file mode 100644
index 000000000000..520eab7abbf4
--- /dev/null
+++ b/src/transformers/cli/serving/chat_completion.py
@@ -0,0 +1,410 @@
+# Copyright 2026 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Handler for the /v1/chat/completions endpoint.
+
+Supports streaming (SSE via DirectStreamer) and non-streaming (JSON) responses.
+"""
+
+import asyncio
+import json
+import time
+from collections.abc import AsyncGenerator
+from typing import TYPE_CHECKING
+
+from ...utils import logging
+from ...utils.import_utils import is_serve_available
+
+
+if is_serve_available():
+ from fastapi.responses import JSONResponse, StreamingResponse
+ from openai.types.chat import ChatCompletion, ChatCompletionMessage, ChatCompletionMessageToolCall
+ from openai.types.chat.chat_completion import Choice
+ from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, ChoiceDelta, ChoiceDeltaToolCall
+ from openai.types.chat.chat_completion_chunk import Choice as ChoiceChunk
+ from openai.types.chat.completion_create_params import CompletionCreateParamsStreaming
+ from openai.types.completion_usage import CompletionUsage
+
+from .utils import (
+ BaseGenerateManager,
+ BaseHandler,
+ ToolCallParser,
+ _StreamError,
+ detect_tool_format,
+)
+
+
+if TYPE_CHECKING:
+ from transformers import GenerationConfig, PreTrainedModel, PreTrainedTokenizerFast, ProcessorMixin
+
+
+class TransformersCompletionCreateParamsStreaming(CompletionCreateParamsStreaming, total=False):
+ generation_config: str
+ seed: int
+
+
+# Fields accepted by the OpenAI schema but not yet supported.
+# Receiving these raises an error to avoid silent misbehaviour.
+# NOTE: "stop" is NOT in this set — we map it to stop_strings.
+UNUSED_CHAT_COMPLETION_FIELDS = {
+ "audio",
+ "function_call",
+ "functions",
+ "logprobs",
+ "max_completion_tokens",
+ "metadata",
+ "modalities",
+ "n",
+ "parallel_tool_calls",
+ "prediction",
+ "presence_penalty",
+ "reasoning_effort",
+ "response_format",
+ "service_tier",
+ "store",
+ "stream_options",
+ "tool_choice",
+ "top_logprobs",
+ "user",
+ "web_search_options",
+}
+
+
+logger = logging.get_logger(__name__)
+
+
+class ChatCompletionHandler(BaseHandler):
+ """Handler for the `/v1/chat/completions` endpoint.
+
+ Supports both streaming (SSE) and non-streaming (JSON) responses.
+ """
+
+ _valid_params_class = TransformersCompletionCreateParamsStreaming
+ _unused_fields = UNUSED_CHAT_COMPLETION_FIELDS
+
+ async def handle_request(self, body: dict, request_id: str) -> StreamingResponse | JSONResponse:
+ """Validate the request, load the model, and dispatch to streaming or non-streaming.
+
+ Args:
+ body (`dict`): The raw JSON request body (OpenAI chat completion format).
+ request_id (`str`): Unique request identifier (from header or auto-generated).
+
+ Returns:
+ `StreamingResponse | JSONResponse`: SSE stream or JSON depending on ``body["stream"]``.
+ """
+ self._validate_request(body)
+
+ model_id, model, processor = self._resolve_model(body)
+ modality = self.model_manager.get_model_modality(model, processor=processor)
+ use_cb = self.generation_state.use_continuous_batching(model, modality)
+ logger.warning(f"[Request received] Model: {model_id}, CB: {use_cb}")
+ gen_manager = self.generation_state.get_manager(model_id, use_cb=use_cb)
+ processor_inputs = self.get_processor_inputs_from_messages(body["messages"], modality)
+
+ inputs = processor.apply_chat_template(
+ processor_inputs,
+ add_generation_prompt=True,
+ tools=body.get("tools"),
+ return_tensors=None if use_cb else "pt",
+ return_dict=True,
+ tokenize=True,
+ )
+ if not use_cb:
+ inputs = inputs.to(model.device)
+
+ gen_config = self._build_generation_config(body, model.generation_config, use_cb=use_cb)
+ # TODO: remove when CB supports per-request generation config
+ if use_cb:
+ gen_manager.init_cb(model, gen_config)
+
+ # Detect tool support for the loaded model
+ # TODO: after tool_call start token, use constrained generation to:
+ # 1. force generation to pick from the available tool names
+ # 2. force generation to produce valid JSON matching the tool's parameter schema
+ tool_format = detect_tool_format(model) if body.get("tools") else None
+
+ streaming = body.get("stream")
+ if streaming:
+ return self._streaming(
+ request_id,
+ model,
+ processor,
+ model_id,
+ inputs,
+ gen_config,
+ gen_manager=gen_manager,
+ tool_format=tool_format,
+ )
+ else:
+ return await self._non_streaming(
+ request_id,
+ model,
+ processor,
+ model_id,
+ inputs,
+ gen_config,
+ gen_manager=gen_manager,
+ tool_format=tool_format,
+ )
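+
+    # Request sketch (OpenAI chat-completions shape, illustrative values):
+    #   POST /v1/chat/completions
+    #   {"model": "Qwen/Qwen2.5-0.5B", "stream": true,
+    #    "messages": [{"role": "user", "content": "Hello"}]}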
+
+ # ----- streaming -----
+
+ def _streaming(
+ self,
+ request_id: str,
+ model: "PreTrainedModel",
+ processor: "ProcessorMixin | PreTrainedTokenizerFast",
+ model_id: str,
+ inputs: dict,
+ gen_config: "GenerationConfig",
+ gen_manager: BaseGenerateManager,
+ tool_format: dict | None = None,
+ ) -> StreamingResponse:
+ """Stream tokens as SSE via DirectStreamer."""
+ queue, streamer = gen_manager.generate_streaming(model, processor, inputs, gen_config, request_id=request_id)
+ input_ids = inputs["input_ids"]
+ # CB returns plain lists, regular path returns tensors
+ input_len = len(input_ids) if isinstance(input_ids, list) else input_ids.shape[-1]
+ parser = ToolCallParser(tool_format) if tool_format else None
+
+ async def sse_gen() -> AsyncGenerator[str, None]:
+ has_tool_calls = False
+ try:
+ yield self._build_chunk_sse(request_id, role="assistant", model=model_id)
+
+ done = False
+ while not done:
+ text = await queue.get()
+ batch = [text]
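+                    # Drain anything already queued so several tokens coalesce
+                    # into one SSE write instead of one write per token.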
+ try:
+ while True:
+ batch.append(queue.get_nowait())
+ except asyncio.QueueEmpty:
+ pass
+
+ sse_parts: list[str] = []
+ for text in batch:
+ if text is None:
+ done = True
+ break
+ if isinstance(text, _StreamError):
+                            sse_parts.append(f"data: {json.dumps({'error': text.msg})}\n\n")
+ yield "".join(sse_parts)
+ return
+
+ # Tool call parsing: None = normal text, CONSUMED = buffering, else = tool call dict
+ chunk_kwargs = {"content": text}
+ if parser is not None and (result := parser.feed(text)) is not None:
+ if result is ToolCallParser.CONSUMED:
+ continue
+ has_tool_calls = True
+ chunk_kwargs = {
+ "tool_calls": [
+ ChoiceDeltaToolCall(
+ index=0,
+ type="function",
+ id=f"{request_id}_tool_call",
+ function={"name": result["name"], "arguments": result["arguments"]},
+ )
+ ]
+ }
+
+ sse_parts.append(self._build_chunk_sse(request_id, model=model_id, **chunk_kwargs))
+
+ if sse_parts:
+ yield "".join(sse_parts)
+
+ hit_max = gen_config.max_new_tokens is not None and streamer.total_tokens >= gen_config.max_new_tokens
+ if has_tool_calls:
+ finish_reason = "tool_calls"
+ elif hit_max:
+ finish_reason = "length"
+ else:
+ finish_reason = "stop"
+ usage = CompletionUsage(
+ prompt_tokens=input_len,
+ completion_tokens=streamer.total_tokens,
+ total_tokens=input_len + streamer.total_tokens,
+ )
+ yield self._build_chunk_sse(
+ request_id,
+ finish_reason=finish_reason,
+ model=model_id,
+ usage=usage,
+ )
+ except (GeneratorExit, asyncio.CancelledError):
+ # Client disconnected — abort generation to free GPU.
+ # Re-raise is mandatory: Python raises RuntimeError if GeneratorExit is swallowed.
+ streamer.cancel()
+ raise
+
+ return StreamingResponse(sse_gen(), media_type="text/event-stream")
+
+ # ----- non-streaming -----
+
+ async def _non_streaming(
+ self,
+ request_id: str,
+ model: "PreTrainedModel",
+ processor: "ProcessorMixin | PreTrainedTokenizerFast",
+ model_id: str,
+ inputs: dict,
+ gen_config: "GenerationConfig",
+ gen_manager: BaseGenerateManager,
+ tool_format: dict | None = None,
+ ) -> JSONResponse:
+ """Run generation and return a JSONResponse."""
+ content, input_len, generated_ids = await gen_manager.generate_non_streaming(
+ model, processor, inputs, gen_config, request_id=request_id
+ )
+
+ hit_max = gen_config.max_new_tokens is not None and len(generated_ids) >= gen_config.max_new_tokens
+ completion_tokens = len(generated_ids)
+ usage = CompletionUsage(
+ prompt_tokens=input_len,
+ completion_tokens=completion_tokens,
+ total_tokens=input_len + completion_tokens,
+ )
+
+ # Parse tool calls from the generated text
+ tool_calls = None
+ if tool_format is not None:
+ parsed = ToolCallParser.parse(content, tool_format)
+ if parsed is not None:
+ tool_calls = [
+ ChatCompletionMessageToolCall(
+ id=f"{request_id}_tool_call",
+ type="function",
+ function={"name": tc["name"], "arguments": tc["arguments"]},
+ )
+ for tc in parsed
+ ]
+
+ if tool_calls is not None:
+ finish_reason = "tool_calls"
+ elif hit_max:
+ finish_reason = "length"
+ else:
+ finish_reason = "stop"
+
+ return JSONResponse(
+ self._build_completion(
+ request_id,
+ content,
+ model_id,
+ finish_reason=finish_reason,
+ usage=usage,
+ tool_calls=tool_calls,
+ ),
+ media_type="application/json",
+ )
+
+ # ----- helpers -----
+
+ def _build_generation_config(self, body: dict, model_generation_config: "GenerationConfig", use_cb: bool = False):
+ """Apply Chat Completions params (``max_tokens``, ``frequency_penalty``, ``logit_bias``,
+ ``stop``) on top of the base generation config."""
+ generation_config = super()._build_generation_config(body, model_generation_config, use_cb=use_cb)
+
+ if body.get("max_tokens") is not None:
+ generation_config.max_new_tokens = int(body["max_tokens"])
+ if body.get("frequency_penalty") is not None:
+ generation_config.repetition_penalty = 1.0 + float(body["frequency_penalty"])
+ if body.get("logit_bias") is not None:
+ generation_config.sequence_bias = {(int(k),): v for k, v in body["logit_bias"].items()}
+ if body.get("stop") is not None:
+ generation_config.stop_strings = body["stop"]
+
+ return generation_config
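+
+    # Illustrative mapping for a hypothetical body
+    # {"max_tokens": 64, "frequency_penalty": 0.5, "logit_bias": {"50256": -100}, "stop": ["\n\n"]}:
+    #   max_new_tokens=64, repetition_penalty=1.5,
+    #   sequence_bias={(50256,): -100}, stop_strings=["\n\n"]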
+
+ # ----- response builders -----
+
+ def _build_completion(
+ self,
+ request_id: str,
+ content: str,
+ model_id: str,
+ finish_reason: str,
+ usage: CompletionUsage | None = None,
+ tool_calls: list[dict] | None = None,
+ ) -> dict:
+ """Build a non-streaming ChatCompletion response dict.
+
+ Args:
+ request_id (`str`): Unique request identifier.
+ content (`str`): The generated text.
+ model_id (`str`): Model ID to include in the response.
+ finish_reason (`str`): Why generation stopped (``"stop"``, ``"length"``, ``"tool_calls"``).
+ usage (`CompletionUsage`, *optional*): Token usage statistics.
+ tool_calls (`list[dict]`, *optional*): Parsed tool calls, if any.
+
+ Returns:
+ `dict`: Serialized ``ChatCompletion`` ready for JSON response.
+ """
+ message = ChatCompletionMessage(content=content, role="assistant", tool_calls=tool_calls)
+ result = ChatCompletion(
+ id=request_id,
+ created=int(time.time()),
+ object="chat.completion",
+ model=model_id,
+ choices=[
+ Choice(
+ index=0,
+ message=message,
+ finish_reason=finish_reason,
+ )
+ ],
+ usage=usage,
+ )
+ return result.model_dump(exclude_none=True)
+
+ def _build_chunk_sse(
+ self,
+ request_id: str = "",
+ content: str | None = None,
+ model: str | None = None,
+ role: str | None = None,
+ finish_reason: str | None = None,
+ tool_calls: list | None = None,
+ usage: CompletionUsage | None = None,
+ ) -> str:
+ """Build a streaming ``ChatCompletionChunk`` and format it as an SSE ``data:`` line.
+
+ Args:
+ request_id (`str`): Unique request identifier.
+ content (`str`, *optional*): Text content delta.
+ model (`str`, *optional*): Model ID.
+ role (`str`, *optional*): Role (only sent in the first chunk).
+ finish_reason (`str`, *optional*): Set on the final chunk.
+ tool_calls (`list`, *optional*): Tool call deltas.
+ usage (`CompletionUsage`, *optional*): Token usage (sent with the final chunk).
+
+ Returns:
+ `str`: A formatted SSE event string.
+ """
+ chunk = ChatCompletionChunk(
+ id=request_id,
+ created=int(time.time()),
+ model=model,
+ choices=[
+ ChoiceChunk(
+ delta=ChoiceDelta(content=content, role=role, tool_calls=tool_calls),
+ index=0,
+ finish_reason=finish_reason,
+ )
+ ],
+ usage=usage,
+ system_fingerprint="",
+ object="chat.completion.chunk",
+ )
+ return self.chunk_to_sse(chunk)
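+
+# Wire-format sketch: each chunk becomes one SSE event, e.g.
+#   data: {"id": "req_0", "object": "chat.completion.chunk",
+#          "choices": [{"index": 0, "delta": {"content": "Hello"}}], ...}
+# followed by a blank line (assuming `chunk_to_sse` applies the usual
+# OpenAI-style "data: <json>\n\n" framing).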
diff --git a/src/transformers/cli/serving/model_manager.py b/src/transformers/cli/serving/model_manager.py
new file mode 100644
index 000000000000..39b477302c58
--- /dev/null
+++ b/src/transformers/cli/serving/model_manager.py
@@ -0,0 +1,457 @@
+# Copyright 2026 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Model loading, caching, and lifecycle management.
+"""
+
+import asyncio
+import gc
+import json
+import threading
+from collections.abc import Callable
+from functools import lru_cache
+from typing import TYPE_CHECKING
+
+from huggingface_hub import scan_cache_dir
+from tqdm import tqdm
+
+import transformers
+from transformers import BitsAndBytesConfig, PreTrainedTokenizerBase
+
+from ...utils import logging
+from .utils import Modality, make_progress_tqdm_class, reset_torch_cache
+
+
+if TYPE_CHECKING:
+ from transformers import PreTrainedModel, PreTrainedTokenizerFast, ProcessorMixin
+
+
+logger = logging.get_logger(__name__)
+
+
+class TimedModel:
+ """Wraps a model + processor and auto-unloads them after a period of inactivity.
+
+ Args:
+ model: The loaded model.
+ timeout_seconds: Seconds of inactivity before auto-unload. Use -1 to disable.
+ processor: The associated processor or tokenizer.
+ on_unload: Optional callback invoked after the model is unloaded from memory.
+ """
+
+ def __init__(
+ self,
+ model: "PreTrainedModel",
+ timeout_seconds: int,
+ processor: "ProcessorMixin | PreTrainedTokenizerFast | None" = None,
+ on_unload: "Callable | None" = None,
+ ):
+ self.model = model
+ self._name_or_path = str(model.name_or_path)
+ self.processor = processor
+ self.timeout_seconds = timeout_seconds
+ self._on_unload = on_unload
+ self._timer = threading.Timer(self.timeout_seconds, self._timeout_reached)
+ self._timer.start()
+
+ def reset_timer(self) -> None:
+ """Reset the inactivity timer (called on each request)."""
+ self._timer.cancel()
+ self._timer = threading.Timer(self.timeout_seconds, self._timeout_reached)
+ self._timer.start()
+
+ def delete_model(self) -> None:
+ """Delete the model and processor, free GPU memory."""
+ if hasattr(self, "model") and self.model is not None:
+ del self.model
+ del self.processor
+ self.model = None
+ self.processor = None
+ gc.collect()
+ reset_torch_cache()
+ self._timer.cancel()
+ if self._on_unload is not None:
+ self._on_unload()
+
+ def _timeout_reached(self) -> None:
+ if self.timeout_seconds > 0:
+ self.delete_model()
+ logger.info(f"{self._name_or_path} was removed from memory after {self.timeout_seconds}s of inactivity")
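+
+# Usage sketch (`cache` and `key` are placeholder names):
+#   timed = TimedModel(model, timeout_seconds=300, processor=processor,
+#                      on_unload=lambda: cache.pop(key, None))
+#   timed.reset_timer()  # call on every request that touches the model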
+
+
+class ModelManager:
+ """Loads, caches, and manages the lifecycle of models.
+
+ Handlers receive a reference to this and call `load_model_and_processor()`
+ to get a model ready for inference.
+
+ Args:
+ device: Device to place models on (e.g. "auto", "cuda", "cpu").
+ dtype: Torch dtype override. "auto" derives from model weights.
+ trust_remote_code: Whether to trust remote code when loading models.
+ attn_implementation: Attention implementation override (e.g. "flash_attention_2").
+ quantization: Quantization method ("bnb-4bit" or "bnb-8bit").
+ model_timeout: Seconds before an idle model is unloaded. -1 disables.
+ force_model: If set, preload this model at init time.
+ """
+
+ def __init__(
+ self,
+ device: str = "auto",
+ dtype: str | None = "auto",
+ trust_remote_code: bool = False,
+ attn_implementation: str | None = None,
+ quantization: str | None = None,
+ model_timeout: int = 300,
+ force_model: str | None = None,
+ ):
+ self.loaded_models: dict[str, TimedModel] = {}
+
+ # Thread-safety for concurrent load_model_and_processor calls
+ self._model_locks: dict[str, threading.Lock] = {}
+ self._model_locks_guard = threading.Lock()
+
+ # Tracks in-flight loads for fan-out to multiple SSE subscribers (used by load_model_streaming)
+ self._loading_subscribers: dict[str, list[asyncio.Queue[str | None]]] = {}
+ self._loading_tasks: dict[str, asyncio.Task] = {}
+
+ # Convert numeric device strings (e.g. "0") to int so device_map works correctly
+ self.device = int(device) if device.isdigit() else device
+ self.dtype = self._resolve_dtype(dtype)
+ self.trust_remote_code = trust_remote_code
+ self.attn_implementation = attn_implementation
+ self.quantization = quantization
+ self.model_timeout = model_timeout
+ self.force_model = force_model
+
+ self._validate_args()
+
+        # Preload the forced model once all state above is initialized;
+        # preloaded models should never be auto-unloaded.
+        if force_model is not None:
+            self.model_timeout = -1
+            self.load_model_and_processor(self.process_model_name(force_model))
+
+ @staticmethod
+ def _resolve_dtype(dtype: str | None):
+ import torch
+
+ if dtype in ("auto", None):
+ return dtype
+ resolved = getattr(torch, dtype, None)
+ if not isinstance(resolved, torch.dtype):
+ raise ValueError(
+ f"Unsupported dtype: '{dtype}'. Must be 'auto' or a valid torch dtype (e.g. 'float16', 'bfloat16')."
+ )
+ return resolved
+
+ def _validate_args(self):
+ if self.quantization is not None and self.quantization not in ("bnb-4bit", "bnb-8bit"):
+ raise ValueError(
+ f"Unsupported quantization method: '{self.quantization}'. Must be 'bnb-4bit' or 'bnb-8bit'."
+ )
+ VALID_ATTN_IMPLEMENTATIONS = {"eager", "sdpa", "flash_attention_2", "flash_attention_3", "flex_attention"}
+ is_kernels_community = self.attn_implementation is not None and self.attn_implementation.startswith(
+ "kernels-community/"
+ )
+ if (
+ self.attn_implementation is not None
+ and not is_kernels_community
+ and self.attn_implementation not in VALID_ATTN_IMPLEMENTATIONS
+ ):
+ raise ValueError(
+ f"Unsupported attention implementation: '{self.attn_implementation}'. "
+ f"Must be one of {VALID_ATTN_IMPLEMENTATIONS} or a kernels-community kernel (e.g. 'kernels-community/flash-attn2')."
+ )
+
+ @staticmethod
+ def process_model_name(model_id: str) -> str:
+ """Canonicalize to `'model_id@revision'` format. Defaults to `@main`."""
+ if "@" in model_id:
+ return model_id
+ return f"{model_id}@main"
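+    # e.g. process_model_name("Qwen/Qwen2.5-0.5B")     -> "Qwen/Qwen2.5-0.5B@main"
+    #      process_model_name("Qwen/Qwen2.5-0.5B@dev") -> "Qwen/Qwen2.5-0.5B@dev"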
+
+ def get_quantization_config(self) -> BitsAndBytesConfig | None:
+ """Return a BitsAndBytesConfig based on the `quantization` setting, or None."""
+ if self.quantization == "bnb-4bit":
+ return BitsAndBytesConfig(
+ load_in_4bit=True,
+ bnb_4bit_quant_type="nf4",
+ bnb_4bit_use_double_quant=True,
+ )
+ elif self.quantization == "bnb-8bit":
+ return BitsAndBytesConfig(load_in_8bit=True)
+ return None
+
+ def _load_processor(self, model_id_and_revision: str) -> "ProcessorMixin | PreTrainedTokenizerFast":
+ """Load a processor for the given model.
+
+ Args:
+ model_id_and_revision: Model ID in ``'model_id@revision'`` format.
+ """
+ from transformers import AutoProcessor
+
+ model_id, revision = model_id_and_revision.split("@", 1)
+ return AutoProcessor.from_pretrained(model_id, revision=revision, trust_remote_code=self.trust_remote_code)
+
+ def _load_model(
+ self, model_id_and_revision: str, tqdm_class: type | None = None, progress_callback: Callable | None = None
+ ) -> "PreTrainedModel":
+ """Load a model.
+
+ Args:
+ model_id_and_revision (`str`): Model ID in ``'model_id@revision'`` format.
+ tqdm_class (*optional*): tqdm subclass for progress bars during ``from_pretrained``.
+ progress_callback (`Callable`, *optional*): Called with progress dicts during loading.
+
+ Returns:
+ `PreTrainedModel`: The loaded model.
+ """
+ from transformers import AutoConfig
+
+ model_id, revision = model_id_and_revision.split("@", 1)
+
+ model_kwargs = {
+ "revision": revision,
+ "attn_implementation": self.attn_implementation,
+ "dtype": self.dtype,
+ "device_map": self.device,
+ "trust_remote_code": self.trust_remote_code,
+ "quantization_config": self.get_quantization_config(),
+ "tqdm_class": tqdm_class,
+ }
+
+ if progress_callback is not None:
+ progress_callback({"status": "loading", "model": model_id_and_revision, "stage": "config"})
+ config = AutoConfig.from_pretrained(model_id, **model_kwargs)
+ architecture = getattr(transformers, config.architectures[0])
+
+ return architecture.from_pretrained(model_id, **model_kwargs)
+
+ def load_model_and_processor(
+ self,
+ model_id_and_revision: str,
+ progress_callback: Callable | None = None,
+ tqdm_class: type | None = None,
+ ) -> "tuple[PreTrainedModel, ProcessorMixin | PreTrainedTokenizerFast]":
+ """Load a model (or return it from cache), resetting its inactivity timer.
+
+ Args:
+ model_id_and_revision: Model ID in ``'model_id@revision'`` format.
+ progress_callback: If provided, called with dicts like
+ ``{"status": "loading", "model": ..., "stage": ...}`` during loading.
+ tqdm_class: Optional tqdm subclass for progress bars during ``from_pretrained``.
+ """
+ # Per-model lock prevents duplicate loads when concurrent requests arrive
+ with self._model_locks_guard:
+ lock = self._model_locks.setdefault(model_id_and_revision, threading.Lock())
+
+ with lock:
+ if model_id_and_revision not in self.loaded_models:
+ logger.warning(f"Loading {model_id_and_revision}")
+ if progress_callback is not None:
+ progress_callback({"status": "loading", "model": model_id_and_revision, "stage": "processor"})
+ processor = self._load_processor(model_id_and_revision)
+ model = self._load_model(
+ model_id_and_revision, tqdm_class=tqdm_class, progress_callback=progress_callback
+ )
+ self.loaded_models[model_id_and_revision] = TimedModel(
+ model,
+ timeout_seconds=self.model_timeout,
+ processor=processor,
+ on_unload=lambda key=model_id_and_revision: self.loaded_models.pop(key, None),
+ )
+ if progress_callback is not None:
+ progress_callback({"status": "ready", "model": model_id_and_revision, "cached": False})
+ else:
+ self.loaded_models[model_id_and_revision].reset_timer()
+ model = self.loaded_models[model_id_and_revision].model
+ processor = self.loaded_models[model_id_and_revision].processor
+ if progress_callback is not None:
+ progress_callback({"status": "ready", "model": model_id_and_revision, "cached": True})
+ return model, processor
+
+ async def load_model_streaming(self, model_id_and_revision: str):
+ """Load a model and stream progress as SSE events.
+
+ Handles three cases:
+ 1. Model already cached → single ``ready`` event
+ 2. Load already in progress → join existing subscriber stream
+ 3. First request → start loading, broadcast to all subscribers
+
+ Args:
+ model_id_and_revision (`str`): Model ID in ``'model_id@revision'`` format.
+
+ Yields:
+ `str`: SSE ``data: ...`` lines with progress updates.
+ """
+ mid = model_id_and_revision
+ queue: asyncio.Queue[str | None] = asyncio.Queue()
+
+ # Case 1: already cached
+ if mid in self.loaded_models:
+ self.loaded_models[mid].reset_timer()
+ yield f"data: {json.dumps({'status': 'ready', 'model': mid, 'cached': True})}\n\n"
+ return
+
+ # Case 2: load in progress — join existing subscribers
+ if mid in self._loading_tasks:
+ self._loading_subscribers[mid].append(queue)
+ while True:
+ item = await queue.get()
+ if item is None:
+ break
+ yield item
+ return
+
+ # Case 3: first request — start the load
+ self._loading_subscribers[mid] = [queue]
+ loop = asyncio.get_running_loop()
+
+ def enqueue(payload: dict):
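+            # Runs on the loader thread; asyncio.Queue is not thread-safe, so
+            # hop back onto the event loop before fanning out to subscribers.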
+ msg = f"data: {json.dumps(payload)}\n\n"
+
+ def broadcast():
+ for q in self._loading_subscribers.get(mid, []):
+ q.put_nowait(msg)
+
+ loop.call_soon_threadsafe(broadcast)
+
+ tqdm_class = make_progress_tqdm_class(enqueue, mid)
+
+ def _tqdm_hook(factory, args, kwargs):
+ return tqdm_class(*args, **kwargs)
+
+ async def run_load():
+ try:
+ # Install a global tqdm hook so the "Loading weights" bar in
+ # core_model_loading.py (which uses logging.tqdm) routes through
+ # our ProgressTqdm. The tqdm_class kwarg only covers download bars.
+ previous_hook = logging.set_tqdm_hook(_tqdm_hook)
+ try:
+ await asyncio.to_thread(
+ self.load_model_and_processor,
+ mid,
+ progress_callback=enqueue,
+ tqdm_class=tqdm_class,
+ )
+ finally:
+ logging.set_tqdm_hook(previous_hook)
+ except Exception as e:
+ logger.error(f"Failed to load {mid}: {e}", exc_info=True)
+ enqueue({"status": "error", "model": mid, "message": str(e)})
+ finally:
+
+ def _send_sentinel():
+ for q in self._loading_subscribers.pop(mid, []):
+ q.put_nowait(None)
+ self._loading_tasks.pop(mid, None)
+
+ loop.call_soon_threadsafe(_send_sentinel)
+
+ self._loading_tasks[mid] = asyncio.create_task(run_load())
+
+ while True:
+ item = await queue.get()
+ if item is None:
+ break
+ yield item
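+
+    # Event-stream sketch (payload shapes from the callbacks above):
+    #   data: {"status": "loading", "model": "m@main", "stage": "processor"}
+    #   data: {"status": "loading", "model": "m@main", "stage": "config"}
+    #   data: {"status": "ready", "model": "m@main", "cached": false}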
+
+ def shutdown(self) -> None:
+ """Delete all loaded models and free resources."""
+ for timed in list(self.loaded_models.values()):
+ timed.delete_model()
+
+ @staticmethod
+ def get_model_modality(
+ model: "PreTrainedModel", processor: "ProcessorMixin | PreTrainedTokenizerFast | None" = None
+ ) -> Modality:
+ """Detect whether a model is an LLM or VLM based on its architecture.
+
+ Args:
+ model (`PreTrainedModel`): The loaded model.
+ processor (`ProcessorMixin | PreTrainedTokenizerFast`, *optional*):
+ If a plain tokenizer (not a multi-modal processor), short-circuits to LLM.
+
+ Returns:
+ `Modality`: The detected modality (``Modality.LLM`` or ``Modality.VLM``).
+ """
+ if processor is not None and isinstance(processor, PreTrainedTokenizerBase):
+ return Modality.LLM
+
+ from transformers.models.auto.modeling_auto import (
+ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
+ MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES,
+ )
+
+ model_classname = model.__class__.__name__
+ if model_classname in MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.values():
+ return Modality.VLM
+ elif model_classname in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
+ return Modality.LLM
+ else:
+ raise ValueError(f"Unknown modality for: {model_classname}")
+
+ @staticmethod
+ @lru_cache
+ def get_gen_models(cache_dir: str | None = None) -> list[dict]:
+ """List generative models (LLMs and VLMs) available in the HuggingFace cache.
+
+ Args:
+ cache_dir (`str`, *optional*): Path to the HuggingFace cache directory.
+ Defaults to the standard cache location.
+
+ Returns:
+ `list[dict]`: OpenAI-compatible model list entries with ``id``, ``object``, etc.
+ """
+ from transformers.models.auto.modeling_auto import (
+ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
+ MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES,
+ )
+
+ generative_models = []
+ logger.warning("Scanning the cache directory for LLMs and VLMs.")
+
+ for repo in tqdm(scan_cache_dir(cache_dir).repos):
+ if repo.repo_type != "model":
+ continue
+
+ for ref, revision_info in repo.refs.items():
+ config_path = next((f.file_path for f in revision_info.files if f.file_name == "config.json"), None)
+ if not config_path:
+ continue
+
+                config = json.loads(config_path.read_text())
+ if not (isinstance(config, dict) and "architectures" in config):
+ continue
+
+ architectures = config["architectures"]
+ llms = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values()
+ vlms = MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.values()
+
+                if any(arch in [*llms, *vlms] for arch in architectures):
+                    author = repo.repo_id.split("/")[0] if "/" in repo.repo_id else ""
+ repo_handle = repo.repo_id + (f"@{ref}" if ref != "main" else "")
+ generative_models.append(
+ {
+ "owned_by": author,
+ "id": repo_handle,
+ "object": "model",
+ "created": repo.last_modified,
+ }
+ )
+
+ return generative_models
diff --git a/src/transformers/cli/serving/response.py b/src/transformers/cli/serving/response.py
new file mode 100644
index 000000000000..23d49a480d33
--- /dev/null
+++ b/src/transformers/cli/serving/response.py
@@ -0,0 +1,564 @@
+# Copyright 2026 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Handler for the /v1/responses endpoint (OpenAI Responses API).
+
+Supports streaming (SSE) and non-streaming (JSON) responses.
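+
+Illustrative client sketch (assumes a running server and the ``openai`` client;
+the model id is a placeholder)::
+
+    from openai import OpenAI
+
+    client = OpenAI(base_url="http://localhost:8000/v1", api_key="unused")
+    response = client.responses.create(
+        model="Qwen/Qwen2.5-0.5B-Instruct",
+        input="Write a haiku about GPUs.",
+        stream=False,
+    )
+    print(response.output_text)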
+"""
+
+import asyncio
+import time
+from collections.abc import AsyncGenerator
+from typing import TYPE_CHECKING
+
+from ...utils import logging
+from ...utils.import_utils import is_serve_available
+
+
+if is_serve_available():
+ from fastapi import HTTPException
+ from fastapi.responses import JSONResponse, StreamingResponse
+ from openai.types.responses import (
+ Response,
+ ResponseCompletedEvent,
+ ResponseContentPartAddedEvent,
+ ResponseContentPartDoneEvent,
+ ResponseCreatedEvent,
+ ResponseError,
+ ResponseErrorEvent,
+ ResponseFailedEvent,
+ ResponseFunctionCallArgumentsDoneEvent,
+ ResponseFunctionToolCall,
+ ResponseInProgressEvent,
+ ResponseOutputItemAddedEvent,
+ ResponseOutputItemDoneEvent,
+ ResponseOutputMessage,
+ ResponseOutputText,
+ ResponseTextDeltaEvent,
+ ResponseTextDoneEvent,
+ )
+ from openai.types.responses.response_create_params import ResponseCreateParamsStreaming
+ from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails, ResponseUsage
+
+from .utils import (
+ BaseGenerateManager,
+ BaseHandler,
+ ToolCallParser,
+ _StreamError,
+ detect_tool_format,
+)
+
+
+if TYPE_CHECKING:
+ from transformers import GenerationConfig, PreTrainedModel, PreTrainedTokenizerFast, ProcessorMixin
+
+
+logger = logging.get_logger(__name__)
+
+
+class TransformersResponseCreateParamsStreaming(ResponseCreateParamsStreaming, total=False):
+ generation_config: str
+ seed: int
+
+
+UNUSED_RESPONSE_FIELDS = {
+ "background",
+ "include",
+ "max_tool_calls",
+ "previous_response_id",
+ "prompt",
+ "reasoning",
+ "service_tier",
+ "store",
+ "text",
+ "tool_choice",
+ "top_logprobs",
+ "truncation",
+ "user",
+}
+
+
+class ResponseHandler(BaseHandler):
+ """Handler for the ``/v1/responses`` endpoint."""
+
+ _valid_params_class = TransformersResponseCreateParamsStreaming
+ _unused_fields = UNUSED_RESPONSE_FIELDS
+
+ async def handle_request(self, body: dict, request_id: str) -> StreamingResponse | JSONResponse:
+ """Validate, load model, dispatch to streaming or non-streaming.
+
+ Args:
+ body (`dict`): The raw JSON request body (OpenAI Responses API format).
+ request_id (`str`): Unique request identifier (from header or auto-generated).
+
+ Returns:
+ `StreamingResponse | JSONResponse`: SSE stream or JSON depending on ``body["stream"]``.
+ """
+ self._validate_request(body)
+
+ model_id, model, processor = self._resolve_model(body)
+ modality = self.model_manager.get_model_modality(model, processor=processor)
+ use_cb = self.generation_state.use_continuous_batching(model, modality)
+ logger.warning(f"[Request received] Model: {model_id}, CB: {use_cb}")
+ gen_manager = self.generation_state.get_manager(model_id, use_cb=use_cb)
+
+ # Two-step input conversion (chat completions skips step 1 since messages are already standard):
+ # 1. Normalize Responses API input (string/list/dict + instructions) → standard messages list
+ # 2. Transform message content for the HF processor (VLM image handling, text joining, etc.)
+ messages = self._input_to_messages(body)
+ processor_inputs = self.get_processor_inputs_from_messages(messages, modality)
+
+ inputs = processor.apply_chat_template(
+ processor_inputs,
+ add_generation_prompt=True,
+ tools=body.get("tools"),
+ return_tensors=None if use_cb else "pt",
+ return_dict=True,
+ tokenize=True,
+ )
+ if not use_cb:
+ inputs = inputs.to(model.device)
+
+ gen_config = self._build_generation_config(body, model.generation_config, use_cb=use_cb)
+ # TODO: remove when CB supports per-request generation config
+ if use_cb:
+ gen_manager.init_cb(model, gen_config)
+ tool_format = detect_tool_format(model) if body.get("tools") else None
+
+ streaming = body.get("stream", True)
+ if streaming:
+ return self._streaming(
+ request_id,
+ model,
+ processor,
+ model_id,
+ body,
+ inputs,
+ gen_config,
+ gen_manager=gen_manager,
+ tool_format=tool_format,
+ )
+ else:
+ return await self._non_streaming(
+ request_id,
+ model,
+ processor,
+ model_id,
+ body,
+ inputs,
+ gen_config,
+ gen_manager=gen_manager,
+ tool_format=tool_format,
+ )
+
+ # ----- input conversion -----
+
+ @staticmethod
+ def _input_to_messages(body: dict) -> list[dict]:
+ """Convert the Responses API ``input`` field to a list of chat messages.
+
+ Handles string, list, and dict inputs. If ``instructions`` is provided, it is
+ prepended as a system message (or replaces an existing one).
+
+ Args:
+ body (`dict`): The raw request body containing ``input`` and optionally ``instructions``.
+
+ Returns:
+ `list[dict]`: Standard OpenAI-format chat messages.
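+
+        Minimal sketch of the string-input case::
+
+            ResponseHandler._input_to_messages({"input": "Hi", "instructions": "Be brief."})
+            # → [{"role": "system", "content": "Be brief."},
+            #    {"role": "user", "content": "Hi"}]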
+ """
+ inp = body["input"]
+ instructions = body.get("instructions")
+
+ if isinstance(inp, str):
+ messages = [{"role": "system", "content": instructions}] if instructions else []
+ messages.append({"role": "user", "content": inp})
+ elif isinstance(inp, list):
+ if instructions:
+                if not inp or inp[0]["role"] != "system":
+ messages = [{"role": "system", "content": instructions}, *inp]
+ else:
+ messages = list(inp)
+ messages[0]["content"] = instructions
+ else:
+ messages = inp
+ elif isinstance(inp, dict):
+ messages = [{"role": "system", "content": instructions}] if instructions else []
+ messages.append(inp)
+ else:
+ raise HTTPException(status_code=422, detail="'input' must be a string, list, or dict")
+
+ return messages
+
+ # ----- streaming -----
+
+ def _streaming(
+ self,
+ request_id: str,
+ model: "PreTrainedModel",
+ processor: "ProcessorMixin | PreTrainedTokenizerFast",
+ model_id: str,
+ body: dict,
+ inputs: dict,
+ gen_config: "GenerationConfig",
+ gen_manager: BaseGenerateManager,
+ tool_format: dict | None = None,
+ ) -> StreamingResponse:
+ """Generate a streaming Responses API reply (SSE) using DirectStreamer."""
+ queue, streamer = gen_manager.generate_streaming(model, processor, inputs, gen_config, request_id=request_id)
+ input_ids = inputs["input_ids"]
+ # CB returns plain lists, regular path returns tensors
+ input_len = len(input_ids) if isinstance(input_ids, list) else input_ids.shape[-1]
+ parser = ToolCallParser(tool_format) if tool_format else None
+
+ seq = 0
+ output_index = 0
+ created_at = time.time()
+ resp_id = f"resp_{request_id}"
+ msg_id = f"msg_{request_id}"
+
+ response_base = {
+ "id": resp_id,
+ "created_at": created_at,
+ "model": model_id,
+ "object": "response",
+ # Required by pydantic but not used — echo request config back
+ "tools": [],
+ "parallel_tool_calls": body.get("parallel_tool_calls", False),
+ "tool_choice": "auto",
+ }
+
+ async def event_stream() -> AsyncGenerator[str, None]:
+ nonlocal seq, output_index
+
+ try:
+ # 1. Created + In progress
+ yield self.chunk_to_sse(
+ ResponseCreatedEvent(
+ type="response.created",
+ sequence_number=seq,
+ response=Response(**response_base, status="queued", output=[]),
+ )
+ )
+ seq += 1
+ yield self.chunk_to_sse(
+ ResponseInProgressEvent(
+ type="response.in_progress",
+ sequence_number=seq,
+ response=Response(**response_base, status="in_progress", output=[]),
+ )
+ )
+ seq += 1
+
+ # 2. Output item added (message)
+ yield self.chunk_to_sse(
+ ResponseOutputItemAddedEvent(
+ type="response.output_item.added",
+ sequence_number=seq,
+ output_index=output_index,
+ item=ResponseOutputMessage(
+ id=msg_id,
+ type="message",
+ status="in_progress",
+ role="assistant",
+ content=[],
+ ),
+ )
+ )
+ seq += 1
+
+ # 3. Content part added
+ yield self.chunk_to_sse(
+ ResponseContentPartAddedEvent(
+ type="response.content_part.added",
+ item_id=msg_id,
+ sequence_number=seq,
+ output_index=output_index,
+ content_index=0,
+ part=ResponseOutputText(type="output_text", text="", annotations=[]),
+ )
+ )
+ seq += 1
+
+ # 4. Stream tokens — drain queue to batch HTTP writes
+ full_text = ""
+ tool_calls = []
+ done = False
+
+ while not done:
+ text = await queue.get()
+ # Drain all available tokens for one batched HTTP write
+ batch = [text]
+ try:
+ while True:
+ batch.append(queue.get_nowait())
+ except asyncio.QueueEmpty:
+ pass
+
+ sse_parts: list[str] = []
+ for text in batch:
+ if text is None:
+ done = True
+ break
+ if isinstance(text, _StreamError):
+ logger.error(f"Exception in response generation: {text.msg}")
+ sse_parts.append(
+ self.chunk_to_sse(
+ ResponseErrorEvent(type="error", sequence_number=seq, message=text.msg)
+ )
+ )
+ seq += 1
+ sse_parts.append(
+ self.chunk_to_sse(
+ ResponseFailedEvent(
+ type="response.failed",
+ sequence_number=seq,
+ response=Response(
+ **response_base,
+ status="failed",
+ output=[],
+ error=ResponseError(code="server_error", message=text.msg),
+ ),
+ )
+ )
+ )
+ yield "".join(sse_parts)
+ return
+
+ # Tool call parsing
+ if parser is not None and (result := parser.feed(text)) is not None:
+ if result is not ToolCallParser.CONSUMED:
+ tc_id = f"{request_id}_tool_call"
+ name = result["name"]
+ arguments = result["arguments"]
+ tc_item = ResponseFunctionToolCall(
+ id=tc_id,
+ call_id=tc_id,
+ type="function_call",
+ name=name,
+ arguments=arguments,
+ status="completed",
+ )
+ tool_calls.append(tc_item)
+ output_index += 1
+ sse_parts.append(
+ self.chunk_to_sse(
+ ResponseOutputItemAddedEvent(
+ type="response.output_item.added",
+ sequence_number=seq,
+ output_index=output_index,
+ item=tc_item,
+ )
+ )
+ )
+ seq += 1
+ sse_parts.append(
+ self.chunk_to_sse(
+ ResponseFunctionCallArgumentsDoneEvent(
+ type="response.function_call_arguments.done",
+ sequence_number=seq,
+ item_id=tc_id,
+ output_index=output_index,
+ arguments=arguments,
+ name=name,
+ )
+ )
+ )
+ seq += 1
+ sse_parts.append(
+ self.chunk_to_sse(
+ ResponseOutputItemDoneEvent(
+ type="response.output_item.done",
+ sequence_number=seq,
+ output_index=output_index,
+ item=tc_item,
+ )
+ )
+ )
+ seq += 1
+ continue
+
+ full_text += text
+ sse_parts.append(
+ self.chunk_to_sse(
+ ResponseTextDeltaEvent(
+ type="response.output_text.delta",
+ item_id=msg_id,
+ sequence_number=seq,
+ output_index=0,
+ content_index=0,
+ delta=text,
+ logprobs=[],
+ )
+ )
+ )
+ seq += 1
+
+ if sse_parts:
+ yield "".join(sse_parts)
+
+ # 5. Close text output
+ output_text_part = ResponseOutputText(type="output_text", text=full_text, annotations=[])
+ yield self.chunk_to_sse(
+ ResponseTextDoneEvent(
+ type="response.output_text.done",
+ item_id=msg_id,
+ sequence_number=seq,
+ output_index=0,
+ content_index=0,
+ text=full_text,
+ logprobs=[],
+ )
+ )
+ seq += 1
+ yield self.chunk_to_sse(
+ ResponseContentPartDoneEvent(
+ type="response.content_part.done",
+ item_id=msg_id,
+ sequence_number=seq,
+ output_index=0,
+ content_index=0,
+ part=output_text_part,
+ )
+ )
+ seq += 1
+
+ msg_item = ResponseOutputMessage(
+ id=msg_id,
+ type="message",
+ status="completed",
+ role="assistant",
+ content=[output_text_part],
+ annotations=[],
+ )
+ yield self.chunk_to_sse(
+ ResponseOutputItemDoneEvent(
+ type="response.output_item.done",
+ sequence_number=seq,
+ output_index=0,
+ item=msg_item,
+ )
+ )
+ seq += 1
+
+ # 6. Completed
+ all_output = [msg_item] + list(tool_calls)
+ usage = compute_usage(input_len, streamer.total_tokens)
+ yield self.chunk_to_sse(
+ ResponseCompletedEvent(
+ type="response.completed",
+ sequence_number=seq,
+ response=Response(**response_base, status="completed", output=all_output, usage=usage),
+ )
+ )
+ seq += 1
+ except (GeneratorExit, asyncio.CancelledError):
+ # Client disconnected — abort generation to free GPU.
+ # Re-raise is mandatory: Python raises RuntimeError if GeneratorExit is swallowed.
+ streamer.cancel()
+ raise
+
+ return StreamingResponse(event_stream(), media_type="text/event-stream")
+
+ # ----- non-streaming -----
+
+ async def _non_streaming(
+ self,
+ request_id: str,
+ model: "PreTrainedModel",
+ processor: "ProcessorMixin | PreTrainedTokenizerFast",
+ model_id: str,
+ body: dict,
+ inputs: dict,
+ gen_config: "GenerationConfig",
+ gen_manager: BaseGenerateManager,
+ tool_format: dict | None = None,
+ ) -> JSONResponse:
+ """Generate a non-streaming Responses API reply (single JSON)."""
+ full_text, input_len, generated_ids = await gen_manager.generate_non_streaming(
+ model, processor, inputs, gen_config, request_id=request_id
+ )
+
+ output_items = [
+ ResponseOutputMessage(
+ id=f"msg_{request_id}",
+ type="message",
+ status="completed",
+ role="assistant",
+ content=[ResponseOutputText(type="output_text", text=full_text, annotations=[])],
+ annotations=[],
+ )
+ ]
+
+ # Parse tool calls from the generated text
+ if tool_format is not None:
+ parsed_calls = ToolCallParser.parse(full_text, tool_format)
+ if parsed_calls is not None:
+ for i, tc in enumerate(parsed_calls):
+                    tc_id = f"{request_id}_tool_call_{i}"
+ output_items.append(
+ ResponseFunctionToolCall(
+ id=tc_id,
+ call_id=tc_id,
+ type="function_call",
+ name=tc["name"],
+ arguments=tc["arguments"],
+ status="completed",
+ )
+ )
+
+ usage = compute_usage(input_len, len(generated_ids))
+ response = Response(
+ id=f"resp_{request_id}",
+ created_at=time.time(),
+ status="completed",
+ model=model_id,
+ output=output_items,
+ object="response",
+ usage=usage,
+ # Required by pydantic but not used — echo request config back
+ tools=[],
+ parallel_tool_calls=body.get("parallel_tool_calls", False),
+ tool_choice="auto",
+ )
+ return JSONResponse(response.model_dump(exclude_none=True))
+
+ # ----- helpers -----
+
+ def _build_generation_config(self, body: dict, model_generation_config: "GenerationConfig", use_cb: bool = False):
+ """Apply Responses API params (``max_output_tokens``) on top of the base generation config."""
+ generation_config = super()._build_generation_config(body, model_generation_config, use_cb=use_cb)
+
+ if body.get("max_output_tokens") is not None:
+ generation_config.max_new_tokens = int(body["max_output_tokens"])
+
+ return generation_config
+
+
+def compute_usage(input_tokens: int, output_tokens: int) -> ResponseUsage:
+ """Build a ``ResponseUsage`` object for a Responses API reply.
+
+ Args:
+ input_tokens (`int`): Number of prompt tokens.
+ output_tokens (`int`): Number of generated tokens.
+
+ Returns:
+ `ResponseUsage`: Usage statistics with zero-filled detail fields.
+ """
+ return ResponseUsage(
+ input_tokens=input_tokens,
+ output_tokens=output_tokens,
+ total_tokens=input_tokens + output_tokens,
+ input_tokens_details=InputTokensDetails(cached_tokens=0),
+ output_tokens_details=OutputTokensDetails(reasoning_tokens=0),
+ )
diff --git a/src/transformers/cli/serving/server.py b/src/transformers/cli/serving/server.py
new file mode 100644
index 000000000000..ec0f287b36ee
--- /dev/null
+++ b/src/transformers/cli/serving/server.py
@@ -0,0 +1,127 @@
+# Copyright 2026 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+FastAPI app factory.
+"""
+
+import uuid
+from contextlib import asynccontextmanager
+
+from ...utils import logging
+from ...utils.import_utils import is_serve_available
+
+
+if is_serve_available():
+ from fastapi import FastAPI, Request
+ from fastapi.middleware.cors import CORSMiddleware
+ from fastapi.responses import JSONResponse, StreamingResponse
+
+from .chat_completion import ChatCompletionHandler
+from .model_manager import ModelManager
+from .response import ResponseHandler
+from .transcription import TranscriptionHandler
+from .utils import X_REQUEST_ID
+
+
+logger = logging.get_logger(__name__)
+
+
+def build_server(
+ model_manager: ModelManager,
+ chat_handler: ChatCompletionHandler,
+ response_handler: ResponseHandler,
+ transcription_handler: TranscriptionHandler,
+ enable_cors: bool = False,
+) -> FastAPI:
+ """Build and return a configured FastAPI application.
+
+ Args:
+ model_manager: Handles model loading, caching, and cleanup.
+ chat_handler: Handles `/v1/chat/completions` requests.
+        response_handler: Handles `/v1/responses` requests.
+        transcription_handler: Handles `/v1/audio/transcriptions` requests.
+        enable_cors: If `True`, adds permissive CORS middleware (allow all origins).
+
+ Returns:
+ A FastAPI app ready to be passed to uvicorn.
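+
+    Example sketch (assumes the manager and handlers are already constructed)::
+
+        import uvicorn
+
+        app = build_server(model_manager, chat_handler, response_handler, transcription_handler)
+        uvicorn.run(app, host="127.0.0.1", port=8000)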
+ """
+
+ @asynccontextmanager
+ async def lifespan(app: FastAPI):
+ yield
+ model_manager.shutdown()
+
+ app = FastAPI(lifespan=lifespan)
+
+ if enable_cors:
+ app.add_middleware(
+ CORSMiddleware,
+ allow_origins=["*"],
+ allow_credentials=True,
+ allow_methods=["*"],
+ allow_headers=["*"],
+ )
+ logger.warning_once("CORS allow origin is set to `*`. Not recommended for production.")
+
+ # ---- Middleware ----
+
+ @app.middleware("http")
+ async def request_id_middleware(request: Request, call_next):
+ """Get or set the request ID in the header."""
+ request_id = request.headers.get(X_REQUEST_ID) or str(uuid.uuid4())
+ request.state.request_id = request_id
+ response = await call_next(request)
+ response.headers[X_REQUEST_ID] = request_id
+ return response
+
+ # ---- Routes ----
+
+ @app.post("/v1/chat/completions")
+ async def chat_completions(request: Request, body: dict):
+ return await chat_handler.handle_request(body, request.state.request_id)
+
+ @app.post("/v1/responses")
+ async def responses(request: Request, body: dict):
+ return await response_handler.handle_request(body, request.state.request_id)
+
+ @app.post("/v1/audio/transcriptions")
+ async def audio_transcriptions(request: Request):
+ return await transcription_handler.handle_request(request)
+
+ @app.post("/load_model")
+ async def load_model(body: dict):
+ from fastapi import HTTPException
+
+ model = body.get("model")
+ if model is None:
+ raise HTTPException(status_code=422, detail="Missing `model` field in the request body.")
+ model_id_and_revision = model_manager.process_model_name(model)
+ return StreamingResponse(
+ model_manager.load_model_streaming(model_id_and_revision), media_type="text/event-stream"
+ )
+
+ @app.post("/reset")
+ def reset():
+ model_manager.shutdown()
+ return JSONResponse({"status": "ok"})
+
+ @app.get("/v1/models")
+ @app.options("/v1/models")
+ def list_models():
+ return JSONResponse({"object": "list", "data": model_manager.get_gen_models()})
+
+ @app.get("/health")
+ def health():
+ return JSONResponse({"status": "ok"})
+
+ return app
diff --git a/src/transformers/cli/serving/transcription.py b/src/transformers/cli/serving/transcription.py
new file mode 100644
index 000000000000..b63add7e5ed6
--- /dev/null
+++ b/src/transformers/cli/serving/transcription.py
@@ -0,0 +1,185 @@
+# Copyright 2026 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Handler for the /v1/audio/transcriptions endpoint.
+"""
+
+import io
+import json
+from typing import TYPE_CHECKING
+
+from ...utils import logging
+from ...utils.import_utils import is_serve_available
+
+
+if is_serve_available():
+ from fastapi import HTTPException, Request
+ from fastapi.responses import JSONResponse, StreamingResponse
+ from openai.types.audio.transcription_create_params import TranscriptionCreateParamsBase
+
+from .model_manager import ModelManager
+from .utils import DirectStreamer, GenerateManager, GenerationState, _StreamError
+
+
+if TYPE_CHECKING:
+ from transformers import PreTrainedModel, ProcessorMixin
+
+
+logger = logging.get_logger(__name__)
+
+
+class TransformersTranscriptionCreateParams(TranscriptionCreateParamsBase, total=False):
+ stream: bool
+
+
+UNUSED_TRANSCRIPTION_FIELDS = {
+ "chunking_strategy",
+ "include",
+ "language",
+ "prompt",
+ "response_format",
+ "temperature",
+ "timestamp_granularities",
+}
+
+
+class TranscriptionHandler:
+ """Handler for ``POST /v1/audio/transcriptions``.
+
+ Accepts a multipart/form-data request with an audio file and model name,
+ runs speech-to-text, and returns an OpenAI-compatible Transcription response.
+
+ Standalone (does not extend :class:`BaseHandler`) because audio requests use
+ multipart form data, not JSON bodies, and don't need generation config or
+ validation. Shares the :class:`GenerationState` for thread safety.
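+
+    Illustrative client sketch (assumes a running server, a local ``sample.wav``,
+    and a placeholder model id)::
+
+        from openai import OpenAI
+
+        client = OpenAI(base_url="http://localhost:8000/v1", api_key="unused")
+        with open("sample.wav", "rb") as f:
+            transcription = client.audio.transcriptions.create(
+                model="openai/whisper-large-v3", file=f
+            )
+        print(transcription.text)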
+ """
+
+ def __init__(self, model_manager: ModelManager, generation_state: GenerationState):
+ """
+ Args:
+ model_manager (`ModelManager`): Handles model loading, caching, and lifecycle.
+ generation_state (`GenerationState`): Shared generation state for thread safety.
+ """
+ self.model_manager = model_manager
+ self.generation_state = generation_state
+
+ def _validate_request(self, form_keys: set[str]) -> None:
+ """Validate transcription request fields."""
+ unexpected = form_keys - TransformersTranscriptionCreateParams.__mutable_keys__
+ if unexpected:
+ raise HTTPException(status_code=422, detail=f"Unexpected fields in the request: {unexpected}")
+ unused = form_keys & UNUSED_TRANSCRIPTION_FIELDS
+ if unused:
+ logger.warning_once(f"Ignoring unsupported fields in the request: {unused}")
+
+ async def handle_request(self, request: Request) -> JSONResponse | StreamingResponse:
+ """Parse multipart form, run transcription, return result.
+
+ Args:
+ request (`Request`): FastAPI request containing multipart form data with
+ ``file`` (audio bytes), ``model`` (model ID), and optional ``stream`` flag.
+
+ Returns:
+ `JSONResponse | StreamingResponse`: Transcription result or SSE stream.
+ """
+ from transformers.utils.import_utils import is_librosa_available, is_multipart_available
+
+ if not is_librosa_available():
+ raise ImportError("Missing librosa dependency for audio transcription. Install with `pip install librosa`")
+ if not is_multipart_available():
+ raise ImportError(
+ "Missing python-multipart dependency for file uploads. Install with `pip install python-multipart`"
+ )
+
+ async with request.form() as form:
+ self._validate_request(set(form.keys()))
+ file_bytes = await form["file"].read()
+ model = form["model"]
+ stream = str(form.get("stream", "false")).lower() == "true"
+
+ model_id_and_revision = self.model_manager.process_model_name(model)
+ audio_model, audio_processor = self.model_manager.load_model_and_processor(model_id_and_revision)
+ gen_manager = self.generation_state.get_manager(model_id_and_revision)
+ audio_inputs = self._prepare_audio_inputs(file_bytes, audio_processor, audio_model)
+
+ if stream:
+ return self._streaming(gen_manager, audio_model, audio_processor, audio_inputs)
+ return await self._non_streaming(gen_manager, audio_model, audio_processor, audio_inputs)
+
+ @staticmethod
+ def _prepare_audio_inputs(
+ file_bytes: bytes, audio_processor: "ProcessorMixin", audio_model: "PreTrainedModel"
+ ) -> dict:
+ """Load audio bytes and convert to model inputs."""
+ import librosa
+
+ sampling_rate = audio_processor.feature_extractor.sampling_rate
+ audio_array, _ = librosa.load(io.BytesIO(file_bytes), sr=sampling_rate, mono=True)
+ audio_inputs = audio_processor(audio_array, sampling_rate=sampling_rate, return_tensors="pt").to(
+ audio_model.device
+ )
+ audio_inputs["input_features"] = audio_inputs["input_features"].to(audio_model.dtype)
+ return audio_inputs
+
+ async def _non_streaming(
+ self,
+ gen_manager: GenerateManager,
+ audio_model: "PreTrainedModel",
+ audio_processor: "ProcessorMixin",
+ audio_inputs: dict,
+ ) -> JSONResponse:
+ # Audio models have different inputs (input_features) and decode (batch_decode)
+ # than text models, so we use async_submit() directly instead of
+ # generate_non_streaming()
+ from openai.types.audio import Transcription
+
+ generated_ids = await gen_manager.async_submit(audio_model.generate, **audio_inputs)
+ text = audio_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+ return JSONResponse(Transcription(text=text).model_dump(exclude_none=True))
+
+ def _streaming(
+ self,
+ gen_manager: GenerateManager,
+ audio_model: "PreTrainedModel",
+ audio_processor: "ProcessorMixin",
+ audio_inputs: dict,
+ ) -> StreamingResponse:
+ # Same as _non_streaming — uses submit() directly because audio inputs
+ # differ from text.
+ import asyncio
+
+ tokenizer = audio_processor.tokenizer if hasattr(audio_processor, "tokenizer") else audio_processor
+ loop = asyncio.get_running_loop()
+ queue: asyncio.Queue = asyncio.Queue()
+ streamer = DirectStreamer(tokenizer._tokenizer, loop, queue, skip_special_tokens=True)
+ gen_kwargs = {**audio_inputs, "streamer": streamer}
+
+ def _run():
+ try:
+ audio_model.generate(**gen_kwargs)
+ except Exception as e:
+ loop.call_soon_threadsafe(queue.put_nowait, _StreamError(str(e)))
+
+ gen_manager.submit(_run)
+
+ async def sse_gen():
+ while True:
+ text = await queue.get()
+ if text is None:
+ break
+                if isinstance(text, _StreamError):
+                    yield f"data: {json.dumps({'error': text.msg})}\n\n"
+ return
+ yield f"data: {text}\n\n"
+
+ return StreamingResponse(sse_gen(), media_type="text/event-stream")
diff --git a/src/transformers/cli/serving/utils.py b/src/transformers/cli/serving/utils.py
new file mode 100644
index 000000000000..d9828d123b12
--- /dev/null
+++ b/src/transformers/cli/serving/utils.py
@@ -0,0 +1,956 @@
+# Copyright 2026 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Shared types, constants, and utilities for the serving layer.
+"""
+
+import asyncio
+import base64
+import copy
+import enum
+import json
+import re
+import tempfile
+import threading
+from abc import ABC, abstractmethod
+from collections.abc import Callable
+from concurrent.futures import Future
+from io import BytesIO
+from queue import Queue
+from typing import TYPE_CHECKING
+
+from transformers.utils import logging
+
+
+if TYPE_CHECKING:
+ import pydantic
+ import tokenizers
+ import torch
+
+ from transformers import (
+ ContinuousBatchingConfig,
+ GenerationConfig,
+ PreTrainedModel,
+ PreTrainedTokenizerFast,
+ ProcessorMixin,
+ )
+ from transformers.generation.continuous_batching.continuous_api import ContinuousBatchingManager
+ from transformers.generation.continuous_batching.requests import GenerationOutput
+ from transformers.generation.continuous_batching.scheduler import Scheduler
+
+ from .model_manager import ModelManager
+
+
+logger = logging.get_logger(__name__)
+
+
+X_REQUEST_ID = "x-request-id"
+
+
+class Modality(enum.Enum):
+ LLM = "LLM"
+ VLM = "VLM"
+ STT = "STT"
+ TTS = "TTS"
+
+
+class _StreamError:
+ """Sentinel to signal an error from the generate thread."""
+
+ def __init__(self, msg: str):
+ self.msg = msg
+
+
+class _GenerationCancelled(Exception):
+ """Raised inside ``DirectStreamer.put()`` to abort ``model.generate()``."""
+
+
+# Model-specific tokens that mark the start/end of a tool call block.
+# TODO: extract these from the chat template at runtime instead of hardcoding.
+# Qwen/Hermes use <tool_call>/</tool_call>, Mistral uses [TOOL_CALLS], etc.
+# The markers are defined in each model's Jinja chat template.
+_TOOL_CALL_TOKENS = {
+ "qwen": {
+        "start": "<tool_call>",
+        "end": "</tool_call>",
+ },
+}
+
+
+def detect_tool_format(model: "PreTrainedModel") -> dict | None:
+ """Return the tool call token format for a model, if supported.
+
+ Args:
+ model (`PreTrainedModel`): The loaded model.
+
+ Returns:
+ `dict | None`: A dict ``{"start": str, "end": str}`` with the model's tool call
+ delimiters, or ``None`` if the model family is not recognized.
+ """
+ architecture = model.config.architectures[0].lower()
+ for family in _TOOL_CALL_TOKENS:
+ if family in architecture:
+ return _TOOL_CALL_TOKENS[family]
+ return None
+
+
+class ToolCallParser:
+ """Parses tool calls from model output.
+
+ The model emits tool calls as structured text between start/end tokens
+    (e.g. ``<tool_call>{"name": "fn", "arguments": {...}}</tool_call>``).
+
+ **Streaming** (``feed``): buffers tokens between start/end markers, parses
+    the complete block when the end marker is seen, and returns a ``{"name": str, "arguments": str}`` dict.
+
+ **Non-streaming** (``parse``): extracts all tool call blocks from complete text.
+
+ Usage::
+
+ parser = ToolCallParser(tool_format={"start": ..., "end": ...})
+ for text_chunk in streamer:
+ result = parser.feed(text_chunk)
+            if result is None:
+                emit_content(text_chunk)   # normal text
+            elif result is ToolCallParser.CONSUMED:
+                continue                   # buffering, nothing to emit yet
+            else:
+                emit_tool_call(result)     # {"name": str, "arguments": str} dict
+ """
+
+ def __init__(self, tool_format: dict):
+ self._tokens = tool_format
+ self._inside = False
+ self._buffer = ""
+
+ # Sentinel: token was consumed by the parser but produced no output.
+ CONSUMED = object()
+
+ def feed(self, text: str) -> object | dict | None:
+ """Feed a text chunk (streaming).
+
+ Returns:
+ - ``None`` — normal text, not a tool token. Emit as content.
+ - ``CONSUMED`` — token consumed internally (buffering/markers). Skip.
+            - A ``{"name": str, "arguments": str}`` dict: emit as a tool call.
+ """
+ if text.strip() == self._tokens["start"]:
+ self._inside = True
+ self._buffer = ""
+ return self.CONSUMED
+
+ if text.strip() == self._tokens["end"]:
+ self._inside = False
+ block = self._buffer.strip()
+ self._buffer = ""
+ return self._parse_block(block) or self.CONSUMED
+
+ if self._inside:
+ self._buffer += text
+ return self.CONSUMED
+
+ return None
+
+ @staticmethod
+ def _extract_name_and_args(block: str) -> tuple[str, str] | None:
+ """Extract (name, arguments_json) from a tool call block, or None if invalid."""
+ if not block:
+ return None
+        try:
+            parsed = json.loads(block)
+        except json.JSONDecodeError:
+            return None
+        if not isinstance(parsed, dict):
+            return None
+        name = parsed.get("name")
+        if name is None:
+            return None
+ arguments = parsed.get("arguments", {})
+ return name, json.dumps(arguments)
+
+ @staticmethod
+ def parse(text: str, tool_format: dict) -> list[dict] | None:
+ """Parse tool calls from complete text.
+
+ Returns a list of ``{"name": str, "arguments": str}`` dicts, or ``None`` if none found.
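+
+        Example (assuming Qwen-style delimiters)::
+
+            fmt = {"start": "<tool_call>", "end": "</tool_call>"}
+            text = 'ok <tool_call>{"name": "get_weather", "arguments": {"city": "Paris"}}</tool_call>'
+            ToolCallParser.parse(text, fmt)
+            # → [{"name": "get_weather", "arguments": '{"city": "Paris"}'}]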
+ """
+ start, end = tool_format["start"], tool_format["end"]
+ tool_calls = []
+ pos = 0
+ while True:
+ s = text.find(start, pos)
+ if s < 0:
+ break
+ e = text.find(end, s + len(start))
+ if e < 0:
+ break
+ result = ToolCallParser._extract_name_and_args(text[s + len(start) : e].strip())
+ if result is not None:
+ tool_calls.append({"name": result[0], "arguments": result[1]})
+ pos = e + len(end)
+ return tool_calls if tool_calls else None
+
+ def _parse_block(self, block: str) -> dict | None:
+ """Parse a buffered tool call block. Returns ``{"name": str, "arguments": str}`` or None."""
+ result = self._extract_name_and_args(block)
+ if result is None:
+ return None
+ return {"name": result[0], "arguments": result[1]}
+
+
+class DownloadAggregator:
+ """Aggregates byte-progress across multiple concurrent download tqdm bars.
+
+ huggingface_hub opens one tqdm bar per file shard. This class tracks them all and emits
+    a single aggregate ``{"stage": "download", "progress": {...}}`` event whenever any of them updates.
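+
+    Sketch (using ``print`` as the enqueue callback and a placeholder model id)::
+
+        agg = DownloadAggregator(print, "some/model")
+        agg.register(bar_id=1, total=100)            # emits current=0, total=100
+        agg.update(bar_id=1, current=50, total=100)  # emits current=50, total=100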
+ """
+
+ def __init__(self, enqueue: Callable, model_id: str):
+ self.enqueue = enqueue
+ self.model = model_id
+ self.bars: dict[int, tuple[int, int | None]] = {}
+ self.last_emitted_current: int | None = None
+
+ def register(self, bar_id: int, total: int | None) -> None:
+ """Register a new download bar with its total byte count."""
+ self.bars[bar_id] = (0, total)
+ self._emit()
+
+ def update(self, bar_id: int, current: int, total: int | None) -> None:
+ """Update a bar's current byte count and emit aggregate progress."""
+ self.bars[bar_id] = (current, total)
+ self._emit()
+
+ def close(self, bar_id: int) -> None:
+ pass # keep the bar so totals remain correct
+
+ def _emit(self) -> None:
+ agg_current = sum(c for c, _ in self.bars.values())
+ if agg_current == self.last_emitted_current:
+ return
+ self.last_emitted_current = agg_current
+ totals = [t for _, t in self.bars.values() if t is not None]
+ agg_total = sum(totals) if totals else None
+ self.enqueue(
+ {
+ "status": "loading",
+ "model": self.model,
+ "stage": "download",
+ "progress": {"current": agg_current, "total": agg_total},
+ }
+ )
+
+
+def make_progress_tqdm_class(callback: Callable, model_id: str) -> type:
+ """Create a tqdm subclass that routes progress to a callback.
+
+ Bars with ``unit="B"`` are download bars — aggregated via ``DownloadAggregator``.
+ Other bars (e.g. "Loading weights") emit ``weights`` stage events.
+
+ Args:
+ callback (`callable`): Called with a dict payload
+ ``{"status": "loading", "model": ..., "stage": ..., "progress": ...}``.
+ model_id (`str`): The model ID (included in progress payloads).
+
+ Returns:
+ A tqdm subclass that forwards progress to *callback*.
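+
+    Sketch (collecting events in a list instead of sending SSE; the model id is
+    a placeholder)::
+
+        events = []
+        ProgressTqdm = make_progress_tqdm_class(events.append, "some/model")
+        for _ in ProgressTqdm(range(3)):  # non-byte bar → "weights" stage events
+            pass
+        assert events[-1]["progress"] == {"current": 3, "total": 3}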
+ """
+ from tqdm.auto import tqdm as base_tqdm
+
+ download_aggregator = DownloadAggregator(callback, model_id)
+
+ class ProgressTqdm(base_tqdm):
+ def __init__(self, *args, **kwargs):
+ self.sse_unit = kwargs.get("unit") or "it"
+ kwargs["disable"] = True
+ super().__init__(*args, **kwargs)
+ self.n = 0
+ self.last_emitted = -1
+ if self.sse_unit == "B":
+ self._bar_id = id(self)
+ download_aggregator.register(self._bar_id, self.total)
+
+ def update(self, n=1):
+ if n is None:
+ n = 1
+ self.n += n
+ if self.sse_unit == "B":
+ download_aggregator.update(self._bar_id, self.n, self.total)
+ elif self.n != self.last_emitted:
+ self.last_emitted = self.n
+ callback(
+ {
+ "status": "loading",
+ "model": model_id,
+ "stage": "weights",
+ "progress": {"current": self.n, "total": self.total},
+ }
+ )
+
+        def __iter__(self):
+            # Route per-item progress through ``update`` so iterable-driven
+            # bars emit the same events as manually updated ones.
+            for item in self.iterable:
+                self.update(1)
+                yield item
+
+ def close(self):
+ if self.sse_unit == "B":
+ download_aggregator.close(self._bar_id)
+ super().close()
+
+ return ProgressTqdm
+
+
+class DirectStreamer:
+ """Streamer for ``model.generate()`` (used by :class:`GenerateManager`).
+
+ Implements the ``put``/``end`` protocol that ``model.generate()`` expects:
+ generate calls ``put(token_tensor)`` after each decode step, and ``end()``
+ when generation is complete. Tokens are decoded incrementally via the Rust
+ ``DecodeStream`` (O(1) per token) and pushed as text to an asyncio.Queue.
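+
+    Protocol sketch (``model.generate()`` normally drives these calls;
+    ``tokenizer``, ``prompt_ids`` and ``new_token`` are assumed to exist)::
+
+        loop = asyncio.get_running_loop()
+        queue: asyncio.Queue = asyncio.Queue()
+        streamer = DirectStreamer(tokenizer._tokenizer, loop, queue)
+        streamer.put(prompt_ids)  # first call carries the prompt and is skipped
+        streamer.put(new_token)   # decoded text is pushed onto ``queue``
+        streamer.end()            # pushes the ``None`` end-of-stream sentinel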
+ """
+
+ def __init__(
+ self,
+ tokenizer: "tokenizers.Tokenizer",
+ loop: asyncio.AbstractEventLoop,
+ queue: asyncio.Queue,
+ skip_special_tokens: bool = True,
+ ):
+ """
+ Args:
+ tokenizer: The Rust tokenizer (``tokenizer._tokenizer``).
+ loop (`asyncio.AbstractEventLoop`): The event loop to push decoded text to.
+ queue (`asyncio.Queue`): The queue that receives decoded text chunks.
+ skip_special_tokens (`bool`, *optional*, defaults to `True`):
+ Whether to strip special tokens during decoding.
+ """
+ from tokenizers.decoders import DecodeStream
+
+ self._tokenizer = tokenizer
+ self._loop = loop
+ self._queue = queue
+ self._decode_stream = DecodeStream([], skip_special_tokens)
+ self._first = True
+ self._cancelled = threading.Event()
+ self.total_tokens = 0
+
+ def put(self, value: "torch.Tensor") -> None:
+ """Called by ``model.generate()`` after each decode step with new token(s)."""
+ if self._cancelled.is_set():
+ raise _GenerationCancelled()
+ # The first put() contains the prompt tokens — skip since we only stream generated tokens.
+ if self._first:
+ self._first = False
+ return
+ for token_id in value.tolist():
+ self.total_tokens += 1
+ text = self._decode_stream.step(self._tokenizer, token_id)
+ if text is not None:
+ self._loop.call_soon_threadsafe(self._queue.put_nowait, text)
+
+ def end(self) -> None:
+ """Called by ``model.generate()`` when generation is complete."""
+ self._loop.call_soon_threadsafe(self._queue.put_nowait, None)
+
+ def cancel(self) -> None:
+ """Signal cancellation. The next ``put()`` call will raise and abort ``model.generate()``."""
+ self._cancelled.set()
+
+
+class CBStreamer:
+ """Streamer for continuous batching (used by :class:`CBGenerateManager`).
+
+ Same ``put``/``end`` protocol as :class:`DirectStreamer`, but called manually
+ by :class:`CBGenerateManager` instead of by ``model.generate()``:
+ ``put(output)`` receives a CB ``GenerationOutput``, decodes new tokens, and
+ pushes text to the asyncio.Queue. ``end()`` signals the stream is complete.
+ """
+
+ def __init__(
+ self,
+ cb_manager: "ContinuousBatchingManager",
+ request_id: str,
+ tokenizer: "tokenizers.Tokenizer",
+ loop: asyncio.AbstractEventLoop,
+ queue: asyncio.Queue,
+ ):
+ """
+ Args:
+ cb_manager (`ContinuousBatchingManager`): The CB manager instance.
+ request_id (`str`): The request ID to track in the CB scheduler.
+ tokenizer: The Rust tokenizer (``tokenizer._tokenizer``).
+ loop (`asyncio.AbstractEventLoop`): The event loop to push decoded text to.
+ queue (`asyncio.Queue`): The queue that receives decoded text chunks.
+ """
+ from tokenizers.decoders import DecodeStream
+
+ self._cb = cb_manager
+ self._request_id = request_id
+ self._loop = loop
+ self._queue = queue
+ self._tokenizer = tokenizer
+ self._decode_stream = DecodeStream([], True)
+ self._prev_len = 0
+ self.total_tokens = 0
+
+ def put(self, output: "GenerationOutput") -> None:
+ """Decode new tokens from a CB ``GenerationOutput`` and push text to the queue."""
+ new_tokens = output.generated_tokens[self._prev_len :]
+ self._prev_len = len(output.generated_tokens)
+ for token_id in new_tokens:
+ self.total_tokens += 1
+ text = self._decode_stream.step(self._tokenizer, token_id)
+ if text is not None:
+ self._queue.put_nowait(text)
+
+ def end(self) -> None:
+ """Signal end of stream."""
+ self._queue.put_nowait(None)
+
+ def cancel(self) -> None:
+ """Cancel the CB request."""
+ self._cb.cancel_request(self._request_id)
+
+
+def set_torch_seed(seed: int) -> None:
+ """Set the PyTorch random seed for reproducible generation."""
+ import torch
+
+ torch.manual_seed(seed)
+
+
+def reset_torch_cache() -> None:
+ """Empty the CUDA cache if a GPU is available."""
+ import torch
+
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+
+class InferenceThread:
+ """Persistent thread for ``model.generate()`` calls.
+
+ ``torch.compile`` with CUDA graphs stores state in thread-local storage.
+ All inference must run on the same thread to avoid corrupted graph state.
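+
+    Example (any callable can be dispatched; shown with a trivial function)::
+
+        thread = InferenceThread()
+        future = thread.submit(sum, [1, 2, 3])
+        assert future.result() == 6  # blocks until the worker thread runs it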
+ """
+
+ def __init__(self):
+ self._queue: Queue = Queue()
+ self._thread = threading.Thread(target=self._run, daemon=True)
+ self._thread.start()
+
+ def _run(self) -> None:
+ while True:
+ fn, args, kwargs, future, loop = self._queue.get()
+ try:
+ result = fn(*args, **kwargs)
+ if loop is not None:
+ loop.call_soon_threadsafe(future.set_result, result)
+ else:
+ future.set_result(result)
+ except Exception as e:
+ if loop is not None:
+ loop.call_soon_threadsafe(future.set_exception, e)
+ else:
+ future.set_exception(e)
+
+ def submit(self, fn, *args, **kwargs) -> Future:
+ """Submit a callable to the inference thread. Returns a blocking Future."""
+ future: Future = Future()
+ self._queue.put((fn, args, kwargs, future, None))
+ return future
+
+ def async_submit(self, fn, *args, **kwargs) -> asyncio.Future:
+ """Submit a callable to the inference thread. Returns an awaitable asyncio.Future."""
+ loop = asyncio.get_running_loop()
+ future = loop.create_future()
+ self._queue.put((fn, args, kwargs, future, loop))
+ return future
+
+
+class BaseGenerateManager(ABC):
+ """Base class for generation managers.
+
+ Subclasses:
+ - :class:`GenerateManager` — sequential ``model.generate()`` on a persistent thread.
+ - :class:`CBGenerateManager` — continuous batching with paged attention.
+ """
+
+ @abstractmethod
+ def generate_streaming(
+ self,
+ model: "PreTrainedModel",
+ processor: "ProcessorMixin | PreTrainedTokenizerFast",
+ inputs: dict,
+ gen_config: "GenerationConfig",
+ request_id: str,
+ ) -> tuple[asyncio.Queue, "DirectStreamer | CBStreamer"]:
+ """Start streaming generation.
+
+ Args:
+ model (`PreTrainedModel`): The loaded model.
+ processor: The processor or tokenizer for decoding.
+ inputs (`dict`): Tokenized inputs (tensors for sequential, lists for CB).
+ gen_config (`GenerationConfig`): Generation parameters.
+ request_id (`str`): Unique request identifier.
+
+ Returns:
+ `tuple[asyncio.Queue, DirectStreamer | CBStreamer]`: A ``(queue, streamer)`` pair
+ where *queue* yields ``str | _StreamError | None`` and *streamer* exposes
+ ``.total_tokens`` and ``.cancel()``.
+ """
+
+ @abstractmethod
+ def generate_non_streaming(
+ self,
+ model: "PreTrainedModel",
+ processor: "ProcessorMixin | PreTrainedTokenizerFast",
+ inputs: dict,
+ gen_config: "GenerationConfig",
+ request_id: str,
+ ) -> tuple[str, int, list[int]]:
+ """Run generation to completion.
+
+ Args:
+ model (`PreTrainedModel`): The loaded model.
+ processor: The processor or tokenizer for decoding.
+ inputs (`dict`): Tokenized inputs (tensors for sequential, lists for CB).
+ gen_config (`GenerationConfig`): Generation parameters.
+ request_id (`str`): Unique request identifier.
+
+ Returns:
+ `tuple[str, int, list[int]]`: ``(text, input_len, generated_ids)``.
+ """
+
+ @abstractmethod
+ def stop(self) -> None:
+ """Stop the generation manager and free resources."""
+
+
+class GenerateManager(BaseGenerateManager):
+ """Sequential generation via ``model.generate()`` on a persistent thread."""
+
+ def __init__(self):
+ self._thread = InferenceThread()
+
+ def generate_streaming(
+ self,
+ model: "PreTrainedModel",
+ processor: "ProcessorMixin | PreTrainedTokenizerFast",
+ inputs: dict,
+ gen_config: "GenerationConfig",
+ request_id: str,
+ ) -> tuple[asyncio.Queue, DirectStreamer]:
+ """Start streaming generation via ``model.generate()`` on the inference thread."""
+ loop = asyncio.get_running_loop()
+ queue: asyncio.Queue = asyncio.Queue()
+ streamer = DirectStreamer(processor._tokenizer, loop, queue, skip_special_tokens=True)
+ gen_kwargs = {**inputs, "streamer": streamer, "generation_config": gen_config, "tokenizer": processor}
+
+ def _run() -> None:
+ try:
+ model.generate(**gen_kwargs)
+ except _GenerationCancelled:
+ loop.call_soon_threadsafe(queue.put_nowait, None)
+ except Exception as e:
+ loop.call_soon_threadsafe(queue.put_nowait, _StreamError(str(e)))
+
+ self.submit(_run)
+ return queue, streamer
+
+ async def generate_non_streaming(
+ self,
+ model: "PreTrainedModel",
+ processor: "ProcessorMixin | PreTrainedTokenizerFast",
+ inputs: dict,
+ gen_config: "GenerationConfig",
+ request_id: str,
+ ) -> tuple[str, int, "torch.Tensor"]:
+ """Run generation to completion via ``model.generate()`` on the inference thread."""
+ sequences = await self.async_submit(
+ model.generate, **inputs, generation_config=gen_config, tokenizer=processor
+ )
+ input_len = inputs["input_ids"].shape[-1]
+ generated_ids = sequences[0, input_len:]
+ text = processor.decode(generated_ids, skip_special_tokens=True)
+ return text, input_len, generated_ids
+
+ def submit(self, fn: Callable, *args, **kwargs) -> Future:
+ """Submit a callable to the inference thread. Returns a blocking Future."""
+ return self._thread.submit(fn, *args, **kwargs)
+
+ def async_submit(self, fn: Callable, *args, **kwargs) -> asyncio.Future:
+ """Submit a callable to the inference thread. Returns an awaitable asyncio.Future."""
+ return self._thread.async_submit(fn, *args, **kwargs)
+
+ def stop(self) -> None:
+ pass # inference thread is a daemon
+
+
+class CBGenerateManager(BaseGenerateManager):
+ """Continuous batching generation via paged attention.
+
+ Translates between the handler's text-level asyncio.Queue and CB's
+    token-level interface. Per-request parameters are ``max_new_tokens`` and ``eos_token_id``.
+
+ The CB manager is initialized lazily on the first request via
+    :meth:`init_cb`, using that request's ``gen_config`` for shared
+ sampling params (temperature, top_p, do_sample).
+
+ .. todo:: Remove :meth:`init_cb` when CB supports per-request
+ generation config. At that point, ``gen_config`` can be passed directly
+ to ``add_request`` and the CB manager no longer needs a shared config.
+ """
+
+ def __init__(self, cb_config: "ContinuousBatchingConfig | None" = None):
+ self._cb = None
+ self._cb_config = cb_config
+
+ def init_cb(self, model: "PreTrainedModel", gen_config: "GenerationConfig") -> None:
+ """Initialize the CB manager on first call with the request's generation config.
+
+ .. todo:: Remove when CB supports per-request generation config.
+
+ Args:
+ model (`PreTrainedModel`): The loaded model (must support ``init_continuous_batching``).
+ gen_config (`GenerationConfig`): Generation config used for shared sampling params.
+ """
+ if self._cb is not None:
+ return
+
+ self._cb = model.init_continuous_batching(
+ generation_config=gen_config, continuous_batching_config=self._cb_config
+ )
+ self._cb.start()
+
+ def generate_streaming(
+ self,
+ model: "PreTrainedModel",
+ processor: "ProcessorMixin | PreTrainedTokenizerFast",
+ inputs: dict,
+ gen_config: "GenerationConfig",
+ request_id: str,
+ ) -> tuple[asyncio.Queue, CBStreamer]:
+ """Start streaming CB generation. Registers a per-request output handler."""
+ loop = asyncio.get_running_loop()
+ text_queue: asyncio.Queue = asyncio.Queue()
+
+ input_ids = inputs["input_ids"]
+ request_id = self._cb.add_request(
+ input_ids,
+ request_id=request_id,
+ streaming=True,
+ max_new_tokens=gen_config.max_new_tokens,
+ eos_token_id=gen_config.eos_token_id,
+ )
+ streamer = CBStreamer(self._cb, request_id, processor._tokenizer, loop, text_queue)
+
+ # Register a direct callback: the dispatcher calls this on the event loop with each GenerationOutput.
+ # This decodes tokens and pushes text straight to the SSE text_queue
+ def _on_output(output):
+ try:
+ streamer.put(output)
+ if output.is_finished():
+ streamer.end()
+ except Exception as e:
+ text_queue.put_nowait(_StreamError(str(e)))
+
+ self._cb.register_result_handler(request_id, _on_output)
+ return text_queue, streamer
+
+ async def generate_non_streaming(
+ self,
+ model: "PreTrainedModel",
+ processor: "ProcessorMixin | PreTrainedTokenizerFast",
+ inputs: dict,
+ gen_config: "GenerationConfig",
+ request_id: str,
+ ) -> tuple[str, int, list[int]]:
+ """Run non-streaming CB generation. Registers a handler that resolves an asyncio.Future on completion."""
+ input_ids = inputs["input_ids"]
+ input_len = len(input_ids)
+
+ # Register future BEFORE add_request to avoid race with fast completion
+ loop = asyncio.get_running_loop()
+ future = loop.create_future()
+
+ def _on_result(result):
+ if not future.done():
+ future.set_result(result)
+
+ self._cb.register_result_handler(request_id, _on_result)
+
+ self._cb.add_request(
+ input_ids,
+ request_id=request_id,
+ max_new_tokens=gen_config.max_new_tokens,
+ streaming=False,
+ eos_token_id=gen_config.eos_token_id,
+ )
+ result = await future
+ if result is None:
+ raise RuntimeError(f"CB manager stopped before producing a result for {request_id}")
+ generated_ids = result.generated_tokens
+ text = processor.decode(generated_ids, skip_special_tokens=True)
+ return text, input_len, generated_ids
+
+ @property
+ def scheduler(self) -> "Scheduler":
+ """The CB scheduler (for testing/monitoring)."""
+ return self._cb.batch_processor.scheduler
+
+ def stop(self) -> None:
+ if self._cb is not None:
+ self._cb.stop(block=True, timeout=2)
+
+
+class GenerationState:
+ """Shared generation state across all handlers.
+
+ Manages per-model :class:`GenerateManager` instances (each with its own
+ :class:`InferenceThread` so different models can run concurrently while
+ ``torch.compile`` / CUDA graphs require same-model-same-thread) and a
+ single :class:`CBGenerateManager` for continuous batching.
+
+ Args:
+ continuous_batching (`bool`, *optional*, defaults to `False`):
+ Whether to use continuous batching with paged attention instead of
+ sequential ``model.generate()`` calls.
+ """
+
+ def __init__(
+ self,
+ continuous_batching: bool = False,
+ compile: bool = False,
+ cb_config: "ContinuousBatchingConfig | None" = None,
+ ):
+ self._continuous_batching = continuous_batching
+ self._compile = compile
+ self._cb_config = cb_config
+ self._generate_managers: dict[str, GenerateManager] = {}
+ self._cb_manager: CBGenerateManager | None = None
+ self._cb_model_id: str | None = None
+
+ def use_continuous_batching(self, model: "PreTrainedModel", modality: Modality) -> bool:
+ """Check if continuous batching can be used for this model and modality.
+
+ Args:
+ model (`PreTrainedModel`): The loaded model.
+ modality (`Modality`): The detected model modality (LLM, VLM, etc.).
+
+ Returns:
+ `bool`: ``True`` if CB is enabled and the model supports it, ``False`` otherwise.
+ """
+ if not self._continuous_batching:
+ return False
+ can = hasattr(model, "init_continuous_batching") and modality == Modality.LLM
+ if not can:
+ logger.warning_once(
+                f"Continuous batching is not supported for {model.__class__.__name__} "
+                f"(modality: {modality.value}). Falling back to sequential generation."
+ )
+ return can
+
+ def get_manager(self, model_id: str, use_cb: bool = False) -> BaseGenerateManager:
+ """Return a per-model generation manager, lazily created on first request.
+
+ Args:
+ model_id (`str`): The model ID in ``'model_id@revision'`` format.
+ use_cb (`bool`): Whether to return a CB manager or a sequential one.
+
+ Returns:
+ `BaseGenerateManager`: Either a `GenerateManager` or `CBGenerateManager`.
+ """
+ if use_cb:
+ if self._cb_model_id != model_id:
+ if self._cb_manager is not None:
+ self._cb_manager.stop()
+ self._cb_manager = None
+ if self._cb_manager is None:
+ self._cb_manager = CBGenerateManager(cb_config=self._cb_config)
+ self._cb_model_id = model_id
+ return self._cb_manager
+ if model_id not in self._generate_managers:
+ self._generate_managers[model_id] = GenerateManager()
+ return self._generate_managers[model_id]
+
+ def shutdown(self) -> None:
+ """Stop any active generation managers."""
+ if self._cb_manager is not None:
+ self._cb_manager.stop()
+ self._cb_manager = None
+
+
+class BaseHandler:
+ """Shared logic for chat completion and responses handlers.
+
+ Provides model resolution, generation config building, and SSE formatting.
+ Generation is delegated to the shared :class:`GenerationState`.
+
+ Args:
+ model_manager (`ModelManager`):
+ Handles model loading, caching, and lifecycle.
+ generation_state (`GenerationState`):
+ Shared state managing per-model generation managers.
+ """
+
+ _valid_params_class: type | None = None
+ _unused_fields: set[str] = set()
+
+ def __init__(
+ self,
+ model_manager: "ModelManager",
+ generation_state: GenerationState,
+ ):
+ self.model_manager = model_manager
+ self.generation_state = generation_state
+
+ def _validate_request(self, body: dict) -> None:
+ """Validate request fields against the handler's params class and unused fields."""
+ from fastapi import HTTPException
+
+ input_keys = set(body.keys())
+ if self._valid_params_class is not None:
+ unexpected = input_keys - self._valid_params_class.__mutable_keys__
+ if unexpected:
+ raise HTTPException(status_code=422, detail=f"Unexpected fields in the request: {unexpected}")
+ unused = input_keys & self._unused_fields
+ if unused:
+ logger.warning_once(f"Ignoring unsupported fields in the request: {unused}")
+
+ @staticmethod
+ def chunk_to_sse(chunk: "str | pydantic.BaseModel") -> str:
+ """Format a pydantic model or string as an SSE ``data:`` line."""
+ if isinstance(chunk, str):
+ return chunk if chunk.startswith("data: ") else f"data: {chunk}\n\n"
+ return f"data: {chunk.model_dump_json(exclude_none=True)}\n\n"
+
+ def _resolve_model(self, body: dict) -> tuple[str, "PreTrainedModel", "ProcessorMixin | PreTrainedTokenizerFast"]:
+ """Apply force_model, load model + processor.
+
+ Returns ``(model_id, model, processor)``.
+ """
+ if self.model_manager.force_model is not None:
+ body["model"] = self.model_manager.force_model
+
+ model_id = self.model_manager.process_model_name(body["model"])
+ model, processor = self.model_manager.load_model_and_processor(model_id)
+
+ return model_id, model, processor
+
+ def _build_generation_config(
+ self, body: dict, model_generation_config: "GenerationConfig", use_cb: bool = False
+ ) -> "GenerationConfig":
+ """Build a GenerationConfig from shared params (temperature, top_p, seed, generation_config JSON).
+
+ Subclasses should call ``super()._build_generation_config(...)`` then apply
+ endpoint-specific params (``max_tokens``, ``max_output_tokens``, etc.).
+
+ Args:
+ body (`dict`):
+ The raw request body.
+ model_generation_config (`GenerationConfig`):
+ The model's default generation config (will be deep-copied).
+ use_cb (`bool`, *optional*, defaults to `False`):
+ Whether continuous batching is active. If ``True``, disables the model's
+ internal KV cache (CB manages its own paged cache).
+
+ Returns:
+ `GenerationConfig`: A new config with request-specific overrides applied.
+ """
+ from transformers import GenerationConfig
+
+ if body.get("generation_config") is not None:
+ generation_config = GenerationConfig(**json.loads(body["generation_config"]))
+ else:
+ generation_config = copy.deepcopy(model_generation_config)
+ if generation_config.max_new_tokens is None or generation_config.max_new_tokens < 1024:
+ generation_config.max_new_tokens = 1024
+
+ if body.get("temperature") is not None:
+ generation_config.temperature = float(body["temperature"])
+            if generation_config.temperature == 0.0:
+ generation_config.do_sample = False
+ if body.get("top_p") is not None:
+ generation_config.top_p = float(body["top_p"])
+ if body.get("seed") is not None:
+ set_torch_seed(body["seed"])
+
+ # --compile flag: use static cache + torch.compile for faster decode
+ if self.generation_state._compile and generation_config.cache_implementation is None:
+ generation_config.cache_implementation = "static"
+
+ # CB manages its own paged KV cache
+ if use_cb:
+ generation_config.use_cache = False
+
+ # TODO: add prefix caching for the non-CB path (reuse KV cache across multi-turn conversations)
+
+ return generation_config
+
+ @staticmethod
+ def get_processor_inputs_from_messages(messages: list[dict], modality: Modality) -> list[dict]:
+ """Convert OpenAI-format messages to the format expected by HF processors.
+
+ For LLMs, collapses list content blocks into plain text. For VLMs, converts
+ ``image_url`` content parts (including base64) into ``{"type": "image", "url": ...}``
+ entries that HF processors understand.
+
+ Args:
+ messages (`list[dict]`): OpenAI-format chat messages.
+ modality (`Modality`): Whether the model is an LLM or VLM.
+
+ Returns:
+ `list[dict]`: Processor-compatible messages.
+ """
+ processor_inputs = []
+
+ for message in messages:
+ parsed = {"role": message["role"], "content": []}
+
+ if modality == Modality.LLM:
+ if isinstance(message["content"], str):
+ parsed["content"] = message["content"]
+ elif isinstance(message["content"], list):
+ texts = [c["text"] for c in message["content"] if c["type"] == "text"]
+ parsed["content"] = " ".join(texts)
+
+ elif modality == Modality.VLM:
+ if isinstance(message["content"], str):
+ parsed["content"].append({"type": "text", "text": message["content"]})
+ else:
+ for content in message["content"]:
+ if content["type"] == "text":
+ parsed["content"].append(content)
+ elif content["type"] == "image_url":
+ from PIL import Image
+
+ url = content["image_url"]["url"]
+ if "base64" in url:
+ image_data = re.sub("^data:image/.+;base64,", "", url)
+ image = Image.open(BytesIO(base64.b64decode(image_data)))
+ file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
+ image.save(file.name)
+ url = file.name
+ parsed["content"].append({"type": "image", "url": url})
+
+ processor_inputs.append(parsed)
+ return processor_inputs
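# Illustrative sketch (editor's addition, not part of the patch): exercising the
# message-conversion helper above via the import path this diff introduces. The
# image URL is hypothetical; the expected output mirrors the tests further down.
from transformers.cli.serving.utils import BaseHandler, Modality

messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe this image"},
            {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
        ],
    }
]
converted = BaseHandler.get_processor_inputs_from_messages(messages, Modality.VLM)
# converted == [{"role": "user", "content": [
#     {"type": "text", "text": "Describe this image"},
#     {"type": "image", "url": "https://example.com/cat.png"}]}]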
diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py
index bdbf213412fe..fdc730df55c8 100644
--- a/src/transformers/testing_utils.py
+++ b/src/transformers/testing_utils.py
@@ -115,6 +115,7 @@
is_liger_kernel_available,
is_lomo_available,
is_mistral_common_available,
+ is_multipart_available,
is_natten_available,
is_nltk_available,
is_numba_available,
@@ -140,6 +141,7 @@
is_scipy_available,
is_sentencepiece_available,
is_seqio_available,
+ is_serve_available,
is_spacy_available,
is_speech_available,
is_spqr_available,
@@ -1414,6 +1416,13 @@ def require_librosa(test_case):
return unittest.skipUnless(is_librosa_available(), "test requires librosa")(test_case)
+def require_multipart(test_case):
+ """
+ Decorator marking a test that requires python-multipart.
+ """
+ return unittest.skipUnless(is_multipart_available(), "test requires python-multipart")(test_case)
+
+
def require_liger_kernel(test_case):
"""
Decorator marking a test that requires liger_kernel
@@ -1497,6 +1506,13 @@ def require_openai(test_case):
return unittest.skipUnless(is_openai_available(), "test requires openai")(test_case)
+def require_serve(test_case):
+ """
+ Decorator marking a test that requires the serving dependencies (fastapi, uvicorn, pydantic, openai).
+ """
+ return unittest.skipUnless(is_serve_available(), "test requires serving dependencies")(test_case)
+
+
def require_mistral_common(test_case):
"""
Decorator marking a test that requires mistral-common. These tests are skipped when mistral-common isn't available.
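
# Usage sketch (not part of the patch): the new decorators compose like the
# existing require_* helpers, skipping cleanly when dependencies are missing.
import unittest

from transformers.testing_utils import require_multipart, require_serve


@require_serve
@require_multipart
class ExampleServingTest(unittest.TestCase):  # hypothetical test class
    def test_placeholder(self):
        self.assertTrue(True)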
diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py
index 7b8bfb80ec19..3f5c7cac386b 100644
--- a/src/transformers/utils/__init__.py
+++ b/src/transformers/utils/__init__.py
@@ -164,6 +164,7 @@
is_matplotlib_available,
is_mistral_common_available,
is_mlx_available,
+ is_multipart_available,
is_natten_available,
is_ninja_available,
is_nltk_available,
@@ -198,6 +199,7 @@
is_scipy_available,
is_sentencepiece_available,
is_seqio_available,
+ is_serve_available,
is_sinq_available,
is_sklearn_available,
is_soundfile_available,
diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py
index 5d17f9a01771..1e1ac2545f05 100644
--- a/src/transformers/utils/import_utils.py
+++ b/src/transformers/utils/import_utils.py
@@ -741,6 +741,11 @@ def is_librosa_available() -> bool:
return _is_package_available("librosa")[0]
+@lru_cache
+def is_multipart_available() -> bool:
+ return _is_package_available("multipart")[0]
+
+
@lru_cache
def is_essentia_available() -> bool:
return _is_package_available("essentia")[0]
@@ -766,6 +771,11 @@ def is_openai_available() -> bool:
return _is_package_available("openai")[0]
+@lru_cache
+def is_serve_available() -> bool:
+ return is_pydantic_available() and is_fastapi_available() and is_uvicorn_available() and is_openai_available()
+
+
@lru_cache
def is_pretty_midi_available() -> bool:
return _is_package_available("pretty_midi")[0]
diff --git a/tests/cli/test_serve.py b/tests/cli/test_serve.py
index 33f33cce7aab..9be3dbeb99ff 100644
--- a/tests/cli/test_serve.py
+++ b/tests/cli/test_serve.py
@@ -11,61 +11,80 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+"""
+Tests for the serving layer.
+
+"""
+
+import asyncio
import json
import os
-import tempfile
+import socket
import time
import unittest
-from threading import Thread
-from unittest.mock import Mock, patch
+from unittest.mock import MagicMock
import httpx
-from huggingface_hub import ChatCompletionStreamOutput, InferenceClient, hf_hub_download
-from parameterized import parameterized
-
-from transformers import GenerationConfig
-from transformers.cli.serve import Modality, Serve
-from transformers.testing_utils import require_openai, slow
-from transformers.utils.import_utils import (
- is_fastapi_available,
- is_openai_available,
- is_pydantic_available,
- is_uvicorn_available,
+
+from transformers.cli.serve import Serve
+from transformers.cli.serving.chat_completion import ChatCompletionHandler
+from transformers.cli.serving.model_manager import ModelManager, TimedModel
+from transformers.cli.serving.response import ResponseHandler, compute_usage
+from transformers.cli.serving.server import build_server
+from transformers.cli.serving.transcription import TranscriptionHandler
+from transformers.cli.serving.utils import (
+ BaseHandler,
+ GenerationState,
+ Modality,
+ ToolCallParser,
+ detect_tool_format,
+)
+from transformers.testing_utils import (
+ require_librosa,
+ require_multipart,
+ require_serve,
+ require_torch_accelerator,
+ require_vision,
+ slow,
)
+from transformers.utils.import_utils import is_serve_available
-serve_dependencies_available = (
- is_pydantic_available() and is_fastapi_available() and is_uvicorn_available() and is_openai_available()
-)
+if is_serve_available():
+ from fastapi import HTTPException
+ from openai import OpenAI
+ from openai.types.responses import Response, ResponseCreatedEvent
+
+
+def _find_free_port() -> int:
+ """Return a free TCP port on localhost."""
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+ s.bind(("localhost", 0))
+ return s.getsockname()[1]
+
-if serve_dependencies_available:
- from openai import APIConnectionError, OpenAI
- from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall, ChoiceDeltaToolCallFunction
- from openai.types.responses import (
- Response,
- ResponseCompletedEvent,
- ResponseContentPartAddedEvent,
- ResponseContentPartDoneEvent,
- ResponseCreatedEvent,
- ResponseInProgressEvent,
- ResponseOutputItemAddedEvent,
- ResponseOutputItemDoneEvent,
- ResponseTextDeltaEvent,
- ResponseTextDoneEvent,
- )
-
-
-@require_openai
-def test_help(cli):
- """Minimal test: we can invoke the help command."""
- output = cli("serve", "--help")
- assert output.exit_code == 0
- assert "serve" in output.output
-
-
-@require_openai
+def _start_serve(**kwargs) -> tuple["Serve", int]:
+ """Start a non-blocking Serve instance on a free port and wait until healthy.
+
+ Returns ``(serve, port)``.
+ """
+ port = _find_free_port()
+ serve = Serve(port=port, non_blocking=True, **kwargs)
+ for _ in range(30):
+ try:
+ if httpx.get(f"http://localhost:{port}/health", timeout=2).status_code == 200:
+ return serve, port
+ except Exception: # noqa: S110
+ pass
+ time.sleep(1)
+ raise RuntimeError(f"Server on port {port} did not become healthy in time")
+
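# Pattern sketch (not part of the patch) that the integration tests below follow,
# using this module's helpers: start a throwaway server on a free port, point an
# OpenAI client at it, and always tear it down.
serve, port = _start_serve()
try:
    client = OpenAI(base_url=f"http://localhost:{port}/v1", api_key="unused")
    # ... issue requests against the local server ...
finally:
    serve.kill_server()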
+
+@require_serve
def test_host_port_blocking(cli):
- """Minimal test: we can set arguments through the CLI - blocking"""
+ """CLI args --host and --port are passed to uvicorn.Config, and server.run() is called."""
+ from unittest.mock import Mock, patch
+
with (
patch("uvicorn.Config") as ConfigMock,
patch("uvicorn.Server") as ServerMock,
@@ -73,912 +92,1217 @@ def test_host_port_blocking(cli):
server_instance = Mock()
ServerMock.return_value = server_instance
- # Call the serve CLI with host/port
out = cli("serve", "--host", "0.0.0.0", "--port", "9000")
_, kwargs = ConfigMock.call_args
assert out.exit_code == 0
assert kwargs["host"] == "0.0.0.0"
assert kwargs["port"] == 9000
-
ServerMock.assert_called_once_with(ConfigMock.return_value)
server_instance.run.assert_called_once()
-@require_openai
-def test_host_port_non_blocking(cli, caplog):
- """Minimal test: we can set arguments through the CLI - non-blocking"""
- caplog.set_level(100000)
- # ^ hack to avoid an issue happening only in CI. We don't check logs anyway so it's fine.
- # Source: https://github.com/pallets/click/issues/824#issuecomment-562581313
+class TestProcessorInputsFromMessages(unittest.TestCase):
+ def test_llm_string_content(self):
+ get_processor_inputs_from_messages = BaseHandler.get_processor_inputs_from_messages
- with (
- patch("uvicorn.Config") as ConfigMock,
- patch("uvicorn.Server") as ServerMock,
- patch.object(Serve, "start_server") as start_mock,
- ):
- server_instance = Mock()
- ServerMock.return_value = server_instance
+ messages = [{"role": "user", "content": "Hello"}]
+ result = get_processor_inputs_from_messages(messages, Modality.LLM)
+ self.assertEqual(result, [{"role": "user", "content": "Hello"}])
- out = cli("serve", "--host", "0.5.0.0", "--port", "9002", "--non-blocking")
- assert out.exit_code == 0
+ def test_llm_list_content_text_only(self):
+ get_processor_inputs_from_messages = BaseHandler.get_processor_inputs_from_messages
- # Config got the CLI args
- _, kwargs = ConfigMock.call_args
- assert kwargs["host"] == "0.5.0.0"
- assert kwargs["port"] == 9002
+ messages = [{"role": "user", "content": [{"type": "text", "text": "A"}, {"type": "text", "text": "B"}]}]
+ result = get_processor_inputs_from_messages(messages, Modality.LLM)
+ self.assertEqual(result, [{"role": "user", "content": "A B"}])
- # Non-blocking path uses start_server(), not server.run()
- start_mock.assert_called_once()
- server_instance.run.assert_not_called()
+ def test_vlm_string_content_wrapped(self):
+ get_processor_inputs_from_messages = BaseHandler.get_processor_inputs_from_messages
+ messages = [{"role": "user", "content": "Hello"}]
+ result = get_processor_inputs_from_messages(messages, Modality.VLM)
+ self.assertEqual(result, [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}])
-@require_openai
-def test_build_chat_completion_chunk():
- """
- Tests that the chunks are correctly built for the Chat Completion API. The `choices` checks implicitly
- confirm that empty fields are not emitted.
- """
- dummy = Serve.__new__(Serve)
-
- # The keys for these fields must be present in every chunk
- MANDATORY_FIELDS = ["data", "id", "choices", "created", "model", "object", "system_fingerprint"]
-
- # Case 1: most fields are provided
- chunk = dummy.build_chat_completion_chunk(
- request_id="req0", content="hello", finish_reason="stop", role="user", model="dummy_model@main"
- )
- chunk = dummy.chunk_to_sse_element(chunk)
- for field in MANDATORY_FIELDS:
- assert field in chunk
- assert '"choices":[{"delta":{"content":"hello","role":"user"},"finish_reason":"stop","index":0}]' in chunk
-
- # Case 2: only the role is provided -- other fields in 'choices' are omitted
- chunk = dummy.build_chat_completion_chunk(request_id="req0", role="user", model="dummy_model@main")
- chunk = dummy.chunk_to_sse_element(chunk)
- for field in MANDATORY_FIELDS:
- assert field in chunk
- assert '"choices":[{"delta":{"role":"user"},"index":0}]' in chunk
-
- # Case 3: only the content is provided -- other fields in 'choices' are omitted
- chunk = dummy.build_chat_completion_chunk(request_id="req0", content="hello", model="dummy_model@main")
- chunk = dummy.chunk_to_sse_element(chunk)
- for field in MANDATORY_FIELDS:
- assert field in chunk
- assert '"choices":[{"delta":{"content":"hello"},"index":0}]' in chunk
-
- # Case 4: tool calls support a list of ChoiceDeltaToolCall objects
- tool_call = ChoiceDeltaToolCall(
- index=0,
- function=ChoiceDeltaToolCallFunction(name="foo_bar", arguments='{"foo1": "bar1", "foo2": "bar2"}'),
- type="function",
- )
- chunk = dummy.build_chat_completion_chunk(request_id="req0", tool_calls=[tool_call], model="dummy_model@main")
- chunk = dummy.chunk_to_sse_element(chunk)
- for field in MANDATORY_FIELDS:
- assert field in chunk
- expected_choices_content = (
- 'choices":[{"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\\"foo1\\": \\"bar1\\", '
- '\\"foo2\\": \\"bar2\\"}","name":"foo_bar"},"type":"function"}]},"index":0}]'
- )
- assert expected_choices_content in chunk
-
-
-def test_generative_model_list():
- with tempfile.TemporaryDirectory() as cache_dir:
- # "download" a few models, including some non-generative models
- hf_hub_download("Menlo/Jan-nano", "config.json", cache_dir=cache_dir)
- hf_hub_download("Menlo/Jan-nano-128k", "config.json", cache_dir=cache_dir)
- hf_hub_download("Qwen/Qwen2.5-0.5B-Instruct", "config.json", cache_dir=cache_dir)
- hf_hub_download("HuggingFaceTB/SmolVLM-Instruct", "config.json", cache_dir=cache_dir)
- hf_hub_download("google-bert/bert-base-cased", "config.json", cache_dir=cache_dir)
-
- expected_results = {
- "HuggingFaceTB/SmolVLM-Instruct": ["HuggingFaceTB", "SmolVLM-Instruct"],
- "Qwen/Qwen2.5-0.5B-Instruct": ["Qwen", "Qwen2.5-0.5B-Instruct"],
- "Menlo/Jan-nano": ["Menlo", "Jan-nano"],
- "Menlo/Jan-nano-128k": ["Menlo", "Jan-nano-128k"],
- }
+ def test_vlm_text_and_image_url(self):
+ get_processor_inputs_from_messages = BaseHandler.get_processor_inputs_from_messages
- # list models
- result = Serve.get_gen_models(cache_dir)
- assert len(expected_results) == len(result)
+ messages = [
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": "What is this?"},
+ {"type": "image_url", "image_url": {"url": "https://example.com/img.png"}},
+ ],
+ }
+ ]
+ result = get_processor_inputs_from_messages(messages, Modality.VLM)
+ self.assertEqual(len(result[0]["content"]), 2)
+ self.assertEqual(result[0]["content"][0]["type"], "text")
+ self.assertEqual(result[0]["content"][1], {"type": "image", "url": "https://example.com/img.png"})
- local_repos = {repo["id"]: repo["owned_by"] for repo in result}
+ def test_llm_multi_turn_conversation(self):
+ """Multi-turn conversation with string content should pass through as-is."""
- for key, value in expected_results.items():
- assert key in local_repos
- assert local_repos[key] == value
+ get_processor_inputs_from_messages = BaseHandler.get_processor_inputs_from_messages
+ messages = [
+ {"role": "user", "content": "How are you?"},
+ {"role": "assistant", "content": "I'm great!"},
+ {"role": "user", "content": "Help me write tests?"},
+ ]
+ result = get_processor_inputs_from_messages(messages, Modality.LLM)
+ self.assertEqual(len(result), 3)
+ self.assertEqual(result[0]["content"], "How are you?")
+ self.assertEqual(result[1]["role"], "assistant")
+ self.assertEqual(result[2]["content"], "Help me write tests?")
-@require_openai
-def test_build_response_event():
- """
- Tests that the events are correctly built for the Response API.
+ def test_llm_list_content_with_type(self):
+ """LLM messages with typed content list should extract text and join."""
- Contrarily to the Chat Completion API, the Response API has a wide set of possible output objects. This test
- only checks a few basic assumptions -- we rely on OpenAI's pydantic models to enforce the correct schema.
- """
- dummy = Serve.__new__(Serve)
-
- response_created = ResponseCreatedEvent(
- type="response.created",
- sequence_number=0,
- response=Response(
- id="resp_0",
- created_at=time.time(),
- status="queued",
- model="dummy_model@main",
- instructions=None, # <--- is set to None = should NOT be in the output.
- text={"format": {"type": "text"}},
- object="response",
- tools=[], # <--- empty lists should be in the output (they are often mandatory fields)
- output=[],
- parallel_tool_calls=False,
- tool_choice="auto",
- metadata=None,
- ),
- )
-
- event = dummy.chunk_to_sse_element(response_created)
- assert event.startswith("data: ") # Sanity check: event formatting
- assert '"model":"dummy_model@main"' in event # Sanity check: set field
- assert '"status":"queued"' in event
- assert "tools" in event # empty lists should be in the output
- assert "output" in event
- assert "instructions" not in event # None fields should NOT be in the output
- assert "metadata" not in event
- assert "error" not in event # Unset optional fields should NOT be in the output
- assert "top_p" not in event
-
-
-def retry(fn, max_attempts=5, delay=2):
- """
- Retry a function up to `max_attempts` times with a `delay` between attempts.
- Useful for testing functions that may fail due to server not being ready.
- """
+ get_processor_inputs_from_messages = BaseHandler.get_processor_inputs_from_messages
- def wrapper(*args, **kwargs):
- nb_attempts = 0
- while True:
- nb_attempts += 1
- try:
- return fn(*args, **kwargs)
- except (httpx.HTTPError, APIConnectionError):
- if nb_attempts >= max_attempts:
- raise
- time.sleep(delay)
+ messages = [
+ {"role": "user", "content": [{"type": "text", "text": "Hello"}, {"type": "text", "text": "world"}]}
+ ]
+ result = get_processor_inputs_from_messages(messages, Modality.LLM)
+ self.assertEqual(result[0]["content"], "Hello world")
- return wrapper
+ @require_vision
+ def test_vlm_base64_image_creates_temp_file(self):
+ """Base64 image URLs should be decoded and saved to a temp file."""
+ get_processor_inputs_from_messages = BaseHandler.get_processor_inputs_from_messages
-class ServeCompletionsMixin:
- """
- Mixin class for the Completions API tests, to seamlessly replicate tests across the two versions of the API
- (`generate` and `continuous_batching`).
- """
+ # Minimal valid 1x1 PNG as base64
+ base64_url = (
+ "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR4"
+ "2mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="
+ )
+ messages = [
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": "What is this?"},
+ {"type": "image_url", "image_url": {"url": base64_url}},
+ ],
+ }
+ ]
+ result = get_processor_inputs_from_messages(messages, Modality.VLM)
+ image_item = result[0]["content"][1]
+ self.assertEqual(image_item["type"], "image")
+ self.assertTrue(os.path.exists(image_item["url"])) # temp file was created
- @retry
- def run_server(self, request):
- with InferenceClient(f"http://localhost:{self.port}") as client:
- return list(client.chat_completion(**request))
-
- @parameterized.expand(
- [
- ("default_request", {}),
- ("one_token", {"max_tokens": 1}),
- ("different_model", {"model": "HuggingFaceTB/SmolLM2-135M-Instruct"}),
- (
- "tool_call",
- {
- "tools": [
- {
- "function": {
- "name": "foo_bar",
- "parameters": {"type": "object"},
- "description": "Foo bar",
- },
- "type": "function",
- }
- ]
- },
- ),
+ def test_vlm_multi_turn(self):
+ """VLM multi-turn: string content should be wrapped in text type."""
+
+ get_processor_inputs_from_messages = BaseHandler.get_processor_inputs_from_messages
+
+ messages = [
+ {"role": "user", "content": "Describe the image"},
+ {"role": "assistant", "content": "It shows a cat"},
+ {"role": "user", "content": "What color?"},
]
- )
- def test_requests(self, test_name: str, request_flags: dict):
- """Tests that the completions app gracefully handles GOOD requests, producing the expected output payloads."""
-
- request = {
- "model": "Qwen/Qwen3-0.6B",
- "messages": [{"role": "user", "content": "Hello, how are you?"}],
- "stream": True, # We don't support "stream": False yet
- "max_tokens": 5, # Small generation by default
- }
- request.update(request_flags)
- all_payloads = self.run_server(request)
-
- # If a request is successful, the returned payload needs to follow the schema, which we test here.
- # NOTE: the output of our server is wrapped by `InferenceClient`, which sends fields even when they
- # are empty.
-
- # Finish reason: the last payload should have a finish reason of "length" or "stop", all others should be empty
- finish_reasons = [payload.choices[0].finish_reason for payload in all_payloads]
- self.assertTrue(finish_reasons[-1] in ["length", "stop"])
- self.assertTrue(all(reason is None for reason in finish_reasons[:-1]))
-
- # Role: the first payload should have a role of "assistant", all others should be empty
- roles = [payload.choices[0].delta.role for payload in all_payloads]
- self.assertEqual(roles[0], "assistant")
- self.assertTrue(all(role is None for role in roles[1:]))
-
- # Content: the first and the last payload shouldn't have content (role and finish reason). It may be empty
- # in some other payload positions, e.g. tool calls.
- contents = [payload.choices[0].delta.content for payload in all_payloads]
- self.assertTrue(contents[0] is None and contents[-1] is None)
- self.assertTrue(any(content is not None for content in contents[1:-1]))
- # TODO: add "usage" field to output and test it
-
- def test_generation_config_in_request(self):
- """Tests that the generation config is correctly passed into the generation call."""
- generation_config = GenerationConfig(do_sample=False, temperature=0.0)
- request = {
- "model": "Qwen/Qwen3-0.6B",
- "messages": [{"role": "user", "content": "Hello, how are you?"}],
- "stream": True,
- "max_tokens": 10,
- "extra_body": {
- "generation_config": generation_config.to_json_string(),
- },
- }
- all_payloads = self.run_server(request)
- contents = [payload.choices[0].delta.content for payload in all_payloads]
- output_text = "".join([text for text in contents if text is not None])
- # The generation config sets greedy decoding, so the output is reproducible. By default, `Qwen/Qwen3-0.6B`
- # sets `do_sample=True`
- self.assertEqual(output_text, '<think>\nOkay, the user just asked, "')
+ result = get_processor_inputs_from_messages(messages, Modality.VLM)
+ self.assertEqual(len(result), 3)
+ for msg in result:
+ self.assertIsInstance(msg["content"], list)
+ self.assertEqual(msg["content"][0]["type"], "text")
- def test_early_return_due_to_length(self):
- request = {
- "model": "Qwen/Qwen2.5-0.5B-Instruct",
- "messages": [{"role": "user", "content": "Hello, how are you?"}],
- "stream": True,
- "max_tokens": 3,
- }
- all_payloads = self.run_server(request)
- last_payload = all_payloads[-1]
- self.assertTrue(last_payload.choices[0]["finish_reason"] == "length")
+class TestGenerativeModelList(unittest.TestCase):
+ def test_lists_only_generative_models(self):
+ """Should list LLMs and VLMs but not non-generative models like BERT."""
+ import tempfile
- def test_continues_until_stop(self):
- request = {
- "model": "Qwen/Qwen2.5-0.5B-Instruct",
- "messages": [{"role": "user", "content": 'Please only answer with "Hi."'}],
- "stream": True,
- "max_tokens": 30,
- }
+ from huggingface_hub import hf_hub_download
- all_payloads = self.run_server(request)
- last_payload = all_payloads[-1]
- self.assertTrue(last_payload.choices[0]["finish_reason"] == "stop")
+ with tempfile.TemporaryDirectory() as cache_dir:
+ # Download config.json for a few models
+ hf_hub_download("Qwen/Qwen2.5-0.5B-Instruct", "config.json", cache_dir=cache_dir)
+ hf_hub_download("google-bert/bert-base-cased", "config.json", cache_dir=cache_dir)
+ result = ModelManager.get_gen_models(cache_dir)
+ model_ids = {r["id"] for r in result}
-class ServeCompletionsGenerateMockTests(unittest.TestCase):
- def test_processor_inputs_from_inbound_messages_llm(self):
- modality = Modality.LLM
- messages = expected_outputs = [
- {"role": "user", "content": "How are you doing?"},
- {"role": "assistant", "content": "I'm doing great, thank you for asking! How can I assist you today?"},
- {"role": "user", "content": "Can you help me write tests?"},
- ]
- outputs = Serve.get_processor_inputs_from_inbound_messages(messages, modality)
- self.assertListEqual(expected_outputs, outputs)
+ self.assertIn("Qwen/Qwen2.5-0.5B-Instruct", model_ids)
+ self.assertNotIn("google-bert/bert-base-cased", model_ids)
- messages_with_type = [
- {"role": "user", "content": [{"type": "text", "text": "How are you doing?"}]},
- {
- "role": "assistant",
- "content": [
- {"type": "text", "text": "I'm doing great, thank you for asking! How can I assist you today?"}
- ],
- },
- {"role": "user", "content": [{"type": "text", "text": "Can you help me write tests?"}]},
- ]
- outputs = Serve.get_processor_inputs_from_inbound_messages(messages_with_type, modality)
- self.assertListEqual(expected_outputs, outputs)
- messages_multiple_text = [
- {
- "role": "user",
- "content": [
- {"type": "text", "text": "How are you doing?"},
- {"type": "text", "text": "I'm doing great, thank you for asking! How can I assist you today?"},
- ],
- },
- ]
- expected_outputs_multiple_text = [
- {
- "role": "user",
- "content": "How are you doing? I'm doing great, thank you for asking! How can I assist you today?",
- },
- ]
- outputs = Serve.get_processor_inputs_from_inbound_messages(messages_multiple_text, modality)
- self.assertListEqual(expected_outputs_multiple_text, outputs)
+@require_serve
+class TestBuildGenerationConfig(unittest.TestCase):
+ def _make_handler(self):
+ return ChatCompletionHandler(model_manager=MagicMock(), generation_state=GenerationState())
- def test_processor_inputs_from_inbound_messages_vlm_text_only(self):
- modality = Modality.VLM
- messages = [
- {"role": "user", "content": "How are you doing?"},
- {"role": "assistant", "content": "I'm doing great, thank you for asking! How can I assist you today?"},
- {"role": "user", "content": "Can you help me write tests?"},
- ]
+ def test_max_tokens(self):
+ from transformers import GenerationConfig
- expected_outputs = [
- {"role": "user", "content": [{"type": "text", "text": "How are you doing?"}]},
- {
- "role": "assistant",
- "content": [
- {"type": "text", "text": "I'm doing great, thank you for asking! How can I assist you today?"}
- ],
- },
- {"role": "user", "content": [{"type": "text", "text": "Can you help me write tests?"}]},
- ]
+ result = self._make_handler()._build_generation_config({"max_tokens": 7}, GenerationConfig())
+ self.assertEqual(result.max_new_tokens, 7)
- outputs = Serve.get_processor_inputs_from_inbound_messages(messages, modality)
- self.assertListEqual(expected_outputs, outputs)
+ def test_temperature_zero_disables_sampling(self):
+ from transformers import GenerationConfig
- def test_processor_inputs_from_inbound_messages_vlm_text_and_image_in_base_64(self):
- modality = Modality.VLM
- messages = [
- {
- "role": "user",
- "content": [
- {"type": "text", "text": "How many pixels are in the image?"},
- {
- "type": "image_url",
- "image_url": {
- "url": "data:image/jpeg;base64,/9j/4AAQSkZJRgABAQAASABIAAD/4QBARXhpZgAATU0AKgAAAAgAAYdpAAQAAAABAAAAGgAAAAAAAqACAAQAAAABAAAABaADAAQAAAABAAAABQAAAAD/7QA4UGhvdG9zaG9wIDMuMAA4QklNBAQAAAAAAAA4QklNBCUAAAAAABDUHYzZjwCyBOmACZjs+EJ+/8AAEQgABQAFAwEiAAIRAQMRAf/EAB8AAAEFAQEBAQEBAAAAAAAAAAABAgMEBQYHCAkKC//EALUQAAIBAwMCBAMFBQQEAAABfQECAwAEEQUSITFBBhNRYQcicRQygZGhCCNCscEVUtHwJDNicoIJChYXGBkaJSYnKCkqNDU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6g4SFhoeIiYqSk5SVlpeYmZqio6Slpqeoqaqys7S1tre4ubrCw8TFxsfIycrS09TV1tfY2drh4uPk5ebn6Onq8fLz9PX29/j5+v/EAB8BAAMBAQEBAQEBAQEAAAAAAAABAgMEBQYHCAkKC//EALURAAIBAgQEAwQHBQQEAAECdwABAgMRBAUhMQYSQVEHYXETIjKBCBRCkaGxwQkjM1LwFWJy0QoWJDThJfEXGBkaJicoKSo1Njc4OTpDREVGR0hJSlNUVVZXWFlaY2RlZmdoaWpzdHV2d3h5eoKDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uLj5OXm5+jp6vLz9PX29/j5+v/bAEMAAQEBAQEBAgEBAgICAgICAwICAgIDBAMDAwMDBAUEBAQEBAQFBQUFBQUFBQYGBgYGBgcHBwcHCAgICAgICAgICP/bAEMBAQEBAgICAwICAwgFBAUICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICP/dAAQAAf/aAAwDAQACEQMRAD8A/v4ooooA/9k="
- },
- },
- ],
- },
- {
- "role": "assistant",
- "content": "The number of pixels in the image cannot be determined from the provided information.",
- },
- {"role": "user", "content": "Alright"},
- ]
+ result = self._make_handler()._build_generation_config({"temperature": 0.0}, GenerationConfig(do_sample=True))
+ self.assertFalse(result.do_sample)
- expected_outputs = [
- {
- "role": "user",
- "content": [
- {"type": "text", "text": "How many pixels are in the image?"},
- {"type": "image", "url": "/var/folders/4v/64sxdhsd3gz3r8vhhnyc0mqw0000gn/T/tmp50oyghk6.png"},
- ],
- },
- {
- "role": "assistant",
- "content": [
- {
- "type": "text",
- "text": "The number of pixels in the image cannot be determined from the provided information.",
- }
- ],
- },
- {"role": "user", "content": [{"type": "text", "text": "Alright"}]},
- ]
+ def test_frequency_penalty(self):
+ from transformers import GenerationConfig
- outputs = Serve.get_processor_inputs_from_inbound_messages(messages, modality)
+ result = self._make_handler()._build_generation_config({"frequency_penalty": 0.5}, GenerationConfig())
+ self.assertAlmostEqual(result.repetition_penalty, 1.5)
- for expected_output, output in zip(expected_outputs, outputs):
- expected_output_content = expected_output["content"]
- output_content = output["content"]
+ def test_logit_bias_tuple_keys(self):
+ from transformers import GenerationConfig
- self.assertEqual(type(expected_output_content), type(output_content))
+ result = self._make_handler()._build_generation_config({"logit_bias": {"42": 1.0}}, GenerationConfig())
+ self.assertEqual(result.sequence_bias, {(42,): 1.0})
- if isinstance(expected_output_content, list):
- for expected_output_content_item, output_content_item in zip(expected_output_content, output_content):
- self.assertIn("type", expected_output_content_item)
- self.assertIn("type", output_content_item)
- self.assertTrue(expected_output_content_item["type"] == output_content_item["type"])
+ def test_stop_strings(self):
+ from transformers import GenerationConfig
- if expected_output_content_item["type"] == "text":
- self.assertEqual(expected_output_content_item["text"], output_content_item["text"])
+ result = self._make_handler()._build_generation_config({"stop": ["<|im_end|>"]}, GenerationConfig())
+ self.assertEqual(result.stop_strings, ["<|im_end|>"])
- if expected_output_content_item["type"] == "image":
- self.assertTrue(os.path.exists(output_content_item["url"]))
- else:
- raise ValueError("VLMs should only receive content as lists.")
+ def test_generation_config_json_overrides(self):
+ from transformers import GenerationConfig
+ custom = GenerationConfig(max_new_tokens=5, do_sample=False)
+ result = self._make_handler()._build_generation_config(
+ {"generation_config": custom.to_json_string()}, GenerationConfig(max_new_tokens=100)
+ )
+ self.assertEqual(result.max_new_tokens, 5)
+ self.assertFalse(result.do_sample)
-@slow # server startup time is slow on our push CI
-@require_openai
-class ServeCompletionsGenerateIntegrationTest(ServeCompletionsMixin, unittest.TestCase):
- """Tests the `generate` version of the Completions API."""
+ def test_generation_config_json_no_defaults_applied(self):
+ """When generation_config JSON is passed, serving defaults should NOT be applied."""
+ from transformers import GenerationConfig
- @classmethod
- def setUpClass(cls):
- """Starts a server for tests to connect to."""
- cls.port = 8001
- cls.server = Serve(port=cls.port, non_blocking=True)
+ custom = GenerationConfig(max_new_tokens=10)
+ result = self._make_handler()._build_generation_config(
+ {"generation_config": custom.to_json_string()}, GenerationConfig()
+ )
+ # Should keep 10, not bump to 1024
+ self.assertEqual(result.max_new_tokens, 10)
- @classmethod
- def tearDownClass(cls):
- cls.server.kill_server()
+ def test_default_bumps_short_max_new_tokens(self):
+ from transformers import GenerationConfig
- @slow
- def test_tool_call(self):
- """Tests that the tool call is correctly handled and that the payloads are correctly structured."""
- # TODO: move to the mixin when CB also supports tool calls
-
- request = {
- # This model is a small model that's very eager to call tools
- # TODO: this is a 4B model. Find a smaller model that's eager to call tools
- "model": "Menlo/Jan-nano",
- # The request should produce a tool call
- "messages": [{"role": "user", "content": "Generate an image of a cat."}],
- "stream": True,
- "max_tokens": 50,
- # Reproducibility
- "temperature": 0.0,
- # This tool is a copy from the tool in the original tiny-agents demo
- "tools": [
- {
- "function": {
- "name": "flux1_schnell_infer",
- "parameters": {
- "type": "object",
- "properties": {
- "prompt": {"type": "string"},
- "seed": {"type": "number", "description": "numeric value between 0 and 2147483647"},
- "randomize_seed": {"type": "boolean", "default": True},
- "width": {
- "type": "number",
- "description": "numeric value between 256 and 2048",
- "default": 1024,
- },
- "height": {
- "type": "number",
- "description": "numeric value between 256 and 2048",
- "default": 1024,
- },
- "num_inference_steps": {
- "type": "number",
- "description": "numeric value between 1 and 16",
- "default": 4,
- },
- },
- },
- "description": "Generate an image using the Flux 1 Schnell Image Generator.",
- },
- "type": "function",
- }
- ],
- }
- all_payloads = self.run_server(request)
-
- # The first payload should contain the role
- roles = [payload.choices[0].delta.role for payload in all_payloads]
- self.assertEqual(roles[0], "assistant")
- self.assertTrue(all(role is None for role in roles[1:]))
-
- # All other payloads (except the last one) should be tool call related, for this specific request
- contents = [payload.choices[0].delta.content for payload in all_payloads]
- self.assertTrue(all(content is None for content in contents))
-
- # The first tool call delta should contain the tool name. The other tool call deltas should contain the tool
- # arguments.
- tool_calls = [payload.choices[0].delta.tool_calls[0] for payload in all_payloads[1:-1]]
- first_tool_call = tool_calls[0]
- self.assertEqual(first_tool_call["function"]["name"], "flux1_schnell_infer")
- self.assertEqual(first_tool_call["function"]["arguments"], None)
- other_tool_calls = tool_calls[1:]
- self.assertTrue(all(tool_call["function"]["name"] is None for tool_call in other_tool_calls))
- self.assertTrue(all(tool_call["function"]["arguments"] is not None for tool_call in other_tool_calls))
-
- # Finally, the last payload should contain a finish reason
- finish_reasons = [payload.choices[0].finish_reason for payload in all_payloads]
- # TODO: I think the finish reason for a tool call is different? double check this
- self.assertTrue(finish_reasons[-1] in ["stop", "length"])
- self.assertTrue(all(reason is None for reason in finish_reasons[:-1]))
-
-
-def _get_scheduler(serve_command):
- # Defensive navigation in case any layer is renamed in the future
- cbm = getattr(serve_command, "running_continuous_batching_manager", None)
- assert cbm is not None, "ServeCommand has no running_continuous_batching_manager"
- bp = getattr(cbm, "batch_processor", None)
- assert bp is not None, "running_continuous_batching_manager has no batch_processor"
- sched = getattr(bp, "scheduler", None)
- assert sched is not None, "batch_processor has no scheduler"
- return sched
-
-
-def _call_healthcheck(base_url: str):
- response = None
- retries = 10
- while retries > 0:
- try:
- response = httpx.get(f"{base_url}/health")
- break
- except httpx.NetworkError:
- time.sleep(0.1)
- retries -= 1
- return response
+ result = self._make_handler()._build_generation_config({}, GenerationConfig(max_new_tokens=20))
+ self.assertEqual(result.max_new_tokens, 1024)
+ def test_user_max_tokens_overrides_default(self):
+ """User's max_tokens should win over the serving default."""
+ from transformers import GenerationConfig
-def _open_stream_and_cancel(base_url: str, request_id: str):
- with httpx.Client() as s:
- with s.stream(
- "POST",
- f"{base_url}/v1/chat/completions",
- headers={"X-Request-ID": request_id},
- json={
- "model": "Qwen/Qwen2.5-0.5B-Instruct",
- "stream": True,
- "messages": [{"role": "user", "content": "Count slowly so I can cancel you."}],
- },
- timeout=30,
- ) as resp:
- assert resp.status_code == 200
+ result = self._make_handler()._build_generation_config({"max_tokens": 50}, GenerationConfig(max_new_tokens=20))
+ self.assertEqual(result.max_new_tokens, 50)
+
+
+@require_serve
+class TestValidation(unittest.TestCase):
+ def _make_handler(self):
+ return ChatCompletionHandler(model_manager=MagicMock(), generation_state=GenerationState())
+
+ def test_valid_request_passes(self):
+ handler = self._make_handler()
+ # Should not raise
+ handler._validate_request({"model": "x", "messages": [{"role": "user", "content": "hi"}], "stream": True})
+
+ def test_unexpected_keys_rejected(self):
+ handler = self._make_handler()
+ with self.assertRaises(HTTPException) as ctx:
+ handler._validate_request({"model": "x", "messages": [], "bogus_field": True})
+ self.assertEqual(ctx.exception.status_code, 422)
+ self.assertIn("bogus_field", ctx.exception.detail)
+
+ def test_unsupported_fields_warns(self):
+ handler = self._make_handler()
+ with self.assertLogs("transformers", level="WARNING") as cm:
+ handler._validate_request({"model": "x", "messages": [], "audio": {}})
+ self.assertTrue(any("audio" in msg for msg in cm.output))
+
+
+class TestModelManager(unittest.TestCase):
+ def test_process_model_name_adds_main(self):
+ self.assertEqual(ModelManager.process_model_name("org/model"), "org/model@main")
+
+ def test_process_model_name_preserves_revision(self):
+ self.assertEqual(ModelManager.process_model_name("org/model@dev"), "org/model@dev")
+
+ def test_quantization_config_4bit(self):
+ mm = ModelManager(quantization="bnb-4bit")
+ cfg = mm.get_quantization_config()
+ self.assertTrue(cfg.load_in_4bit)
+
+ def test_quantization_config_8bit(self):
+ mm = ModelManager(quantization="bnb-8bit")
+ cfg = mm.get_quantization_config()
+ self.assertTrue(cfg.load_in_8bit)
+
+ def test_quantization_config_none(self):
+ mm = ModelManager()
+ self.assertIsNone(mm.get_quantization_config())
- wait_for_n_chunks = 3
- for i, _ in enumerate(resp.iter_bytes(chunk_size=None)):
- if i >= wait_for_n_chunks:
- resp.close()
- break
+class TestTimedModel(unittest.TestCase):
+ def test_delete_model(self):
+ mock_model = MagicMock()
+ deleted = []
+ timed = TimedModel(
+ mock_model, timeout_seconds=9999, processor=MagicMock(), on_unload=lambda: deleted.append(True)
+ )
+ self.assertIsNotNone(timed.model)
+ timed.delete_model()
+ self.assertIsNone(timed.model)
+ self.assertEqual(len(deleted), 1)
+
+ def test_timeout_zero_no_delete(self):
+ mock_model = MagicMock()
+ timed = TimedModel(mock_model, timeout_seconds=0, processor=MagicMock())
+ timed._timeout_reached()
+ self.assertIsNotNone(timed.model)
+ timed._timer.cancel()
+
+
+@require_serve
+class TestChunkSSE(unittest.TestCase):
+ def _make_handler(self):
+ return ChatCompletionHandler(model_manager=MagicMock(), generation_state=GenerationState())
+
+ def test_build_chunk_sse_content(self):
+ handler = self._make_handler()
+ sse = handler._build_chunk_sse(request_id="req1", content="hi", model="m")
+ self.assertTrue(sse.startswith("data: "))
+ self.assertTrue(sse.endswith("\n\n"))
+ parsed = json.loads(sse[len("data: ") :].strip())
+ self.assertEqual(parsed["choices"][0]["delta"]["content"], "hi")
+
+ def test_build_chunk_sse_role(self):
+ handler = self._make_handler()
+ sse = handler._build_chunk_sse(request_id="req1", role="assistant", model="m")
+ parsed = json.loads(sse[len("data: ") :].strip())
+ self.assertEqual(parsed["choices"][0]["delta"]["role"], "assistant")
+ self.assertNotIn("content", parsed["choices"][0]["delta"])
+
+ def test_build_chunk_sse_finish_reason(self):
+ handler = self._make_handler()
+ sse = handler._build_chunk_sse(request_id="req1", finish_reason="stop", model="m")
+ parsed = json.loads(sse[len("data: ") :].strip())
+ self.assertEqual(parsed["choices"][0]["finish_reason"], "stop")
+
+ def test_chunk_to_sse_string_passthrough(self):
+ result = BaseHandler.chunk_to_sse("data: already formatted\n\n")
+ self.assertEqual(result, "data: already formatted\n\n")
+
+ def test_chunk_to_sse_wraps_plain_string(self):
+ result = BaseHandler.chunk_to_sse("hello")
+ self.assertEqual(result, "data: hello\n\n")
+
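# Wire-format sketch (module context, not part of the patch): each streamed
# chunk is a single SSE event, i.e. a "data: ..." line terminated by a blank
# line, and already-formatted strings pass through untouched.
assert BaseHandler.chunk_to_sse("hello") == "data: hello\n\n"
assert BaseHandler.chunk_to_sse("data: already formatted\n\n") == "data: already formatted\n\n"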
+
+QWEN_TOOL_FORMAT = {"start": "<tool_call>", "end": "</tool_call>"}
+
+
+@require_serve
+class TestToolParser(unittest.TestCase):
+ def test_detect_tool_format_qwen(self):
+ model = MagicMock()
+ model.config.architectures = ["Qwen2ForCausalLM"]
+ fmt = detect_tool_format(model)
+ self.assertEqual(fmt, QWEN_TOOL_FORMAT)
+
+ def test_detect_tool_format_unsupported(self):
+ model = MagicMock()
+ model.config.architectures = ["LlamaForCausalLM"]
+ self.assertIsNone(detect_tool_format(model))
+
+ def test_parser_start_token(self):
+ parser = ToolCallParser(QWEN_TOOL_FORMAT)
+ result = parser.feed("<tool_call>")
+ self.assertIs(result, ToolCallParser.CONSUMED)
+
+ def test_parser_end_token(self):
+ parser = ToolCallParser(QWEN_TOOL_FORMAT)
+ parser.feed("<tool_call>")
+ result = parser.feed("</tool_call>")
+ self.assertIs(result, ToolCallParser.CONSUMED)
+
+ def test_parser_buffers_until_end(self):
+ parser = ToolCallParser(QWEN_TOOL_FORMAT)
+ parser.feed("<tool_call>")
+ # Intermediate tokens are buffered
+ result = parser.feed('{"name": "my_tool", "arguments": {"x": 1}}')
+ self.assertIs(result, ToolCallParser.CONSUMED)
+ # Tool call is emitted on end token
+ result = parser.feed("</tool_call>")
+ self.assertIsNot(result, ToolCallParser.CONSUMED)
+ self.assertEqual(result["name"], "my_tool")
+
+ def test_parser_normal_text_returns_none(self):
+ parser = ToolCallParser(QWEN_TOOL_FORMAT)
+ result = parser.feed("Hello world")
+ self.assertIsNone(result)
+
+ def test_parser_full_flow(self):
+ """Simulate a complete tool call token sequence."""
+
+ parser = ToolCallParser(QWEN_TOOL_FORMAT)
+ tool_calls = []
+
+ for token in [
+ "",
+ '{"name": "get_weather",',
+ ' "arguments": {',
+ '"city": "Paris"',
+ "}}",
+ "\n",
+ "",
+ ]:
+ result = parser.feed(token)
+ if result is not None and result is not ToolCallParser.CONSUMED:
+ tool_calls.append(result)
+
+ # Single tool call emitted on "</tool_call>" with both name and arguments
+ self.assertEqual(len(tool_calls), 1)
+ self.assertEqual(tool_calls[0]["name"], "get_weather")
+ self.assertIn("Paris", tool_calls[0]["arguments"])
+
+ def test_parse_tool_calls_from_text(self):
+ """Non-streaming tool call parsing from complete text."""
+
+ text = '<tool_call>\n{"name": "get_weather", "arguments": {"city": "Paris"}}\n</tool_call>'
+ calls = ToolCallParser.parse(text, QWEN_TOOL_FORMAT)
+ self.assertIsNotNone(calls)
+ self.assertEqual(len(calls), 1)
+ self.assertEqual(calls[0]["name"], "get_weather")
+ self.assertIn("Paris", calls[0]["arguments"])
+
+ def test_parse_tool_calls_no_tool_call(self):
+ """Non-streaming: normal text returns None."""
+
+ calls = ToolCallParser.parse("Hello, how can I help?", QWEN_TOOL_FORMAT)
+ self.assertIsNone(calls)
+
+ def test_parse_multiple_tool_calls(self):
+ """Non-streaming: multiple tool calls in one response."""
+
+ text = (
+ '\n{"name": "get_weather", "arguments": {"city": "Paris"}}\n\n'
+ '\n{"name": "get_weather", "arguments": {"city": "London"}}\n'
+ )
+ calls = ToolCallParser.parse(text, QWEN_TOOL_FORMAT)
+ self.assertIsNotNone(calls)
+ self.assertEqual(len(calls), 2)
+ self.assertEqual(calls[0]["name"], "get_weather")
+ self.assertIn("Paris", calls[0]["arguments"])
+ self.assertEqual(calls[1]["name"], "get_weather")
+ self.assertIn("London", calls[1]["arguments"])
+
+ def test_feed_multiple_tool_calls(self):
+ """Streaming: multiple tool calls emitted sequentially."""
+
+ parser = ToolCallParser(QWEN_TOOL_FORMAT)
+ tool_calls = []
+
+ tokens = [
+ "",
+ '{"name": "get_weather", "arguments": {"city": "Paris"}}',
+ "",
+ "",
+ '{"name": "get_weather", "arguments": {"city": "London"}}',
+ "",
+ ]
+ for token in tokens:
+ result = parser.feed(token)
+ if result is not None and result is not ToolCallParser.CONSUMED:
+ tool_calls.append(result)
+
+ self.assertEqual(len(tool_calls), 2)
+ self.assertEqual(tool_calls[0]["name"], "get_weather")
+ self.assertIn("Paris", tool_calls[0]["arguments"])
+ self.assertEqual(tool_calls[1]["name"], "get_weather")
+ self.assertIn("London", tool_calls[1]["arguments"])
-@slow # server startup time is slow on our push CI
-@require_openai
-class ServeCompletionsContinuousBatchingIntegrationTest(ServeCompletionsMixin, unittest.TestCase):
- """Tests the `continuous_batching` version of the Completions API."""
+@require_serve
+class TestAppRoutes(unittest.TestCase):
@classmethod
def setUpClass(cls):
- """Starts a server for tests to connect to."""
- cls.port = 8002
- cls.server = Serve(
- port=cls.port, continuous_batching=True, attn_implementation="sdpa", default_seed=42, non_blocking=True
+ cls.model_manager = MagicMock(spec=ModelManager)
+ cls.model_manager.get_gen_models.return_value = [
+ {"id": "test/model", "owned_by": "test", "object": "model", "created": 0}
+ ]
+ cls.chat_handler = MagicMock(spec=ChatCompletionHandler)
+ cls.response_handler = MagicMock(spec=ResponseHandler)
+ cls.transcription_handler = MagicMock(spec=TranscriptionHandler)
+ cls.app = build_server(cls.model_manager, cls.chat_handler, cls.response_handler, cls.transcription_handler)
+ cls.transport = httpx.ASGITransport(app=cls.app)
+
+ async def _request(self, method: str, path: str, **kwargs) -> httpx.Response:
+ async with httpx.AsyncClient(transport=self.transport, base_url="http://test") as c:
+ return await c.request(method, path, **kwargs)
+
+ def test_health(self):
+ resp = asyncio.run(self._request("GET", "/health"))
+ self.assertEqual(resp.status_code, 200)
+ self.assertEqual(resp.json(), {"status": "ok"})
+
+ def test_models_list(self):
+ resp = asyncio.run(self._request("GET", "/v1/models"))
+ self.assertEqual(resp.status_code, 200)
+ data = resp.json()
+ self.assertEqual(data["object"], "list")
+ self.assertEqual(len(data["data"]), 1)
+
+ def test_request_id_generated(self):
+ resp = asyncio.run(self._request("GET", "/health"))
+ self.assertIn("x-request-id", resp.headers)
+ self.assertEqual(len(resp.headers["x-request-id"]), 36) # UUID length
+
+ def test_request_id_passthrough(self):
+ resp = asyncio.run(self._request("GET", "/health", headers={"x-request-id": "my-id"}))
+ self.assertEqual(resp.headers["x-request-id"], "my-id")
+
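# Manual sketch (not part of the patch) of the same routes against a live
# server, assuming the default --port of 8000:
import httpx

assert httpx.get("http://localhost:8000/health").json() == {"status": "ok"}
print(httpx.get("http://localhost:8000/v1/models").json()["object"])  # "list"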
+
+@slow
+@require_serve
+class TestChatCompletion(unittest.TestCase):
+ """Integration tests for /v1/chat/completions with a real model."""
+
+ MODEL = "Qwen/Qwen2.5-0.5B-Instruct"
+
+ @classmethod
+ def setUpClass(cls):
+ cls.serve, port = _start_serve()
+ cls.base_url = f"http://localhost:{port}"
+ cls.client = OpenAI(base_url=f"{cls.base_url}/v1", api_key="unused")
@classmethod
def tearDownClass(cls):
- cls.server.kill_server()
+ cls.serve.kill_server()
- def test_full_request(self):
- """Tests that an inference using the Responses API and Continuous Batching works"""
+ def test_non_streaming(self):
+ resp = self.client.chat.completions.create(
+ model=self.MODEL, messages=[{"role": "user", "content": "Say hello"}]
+ )
+ self.assertIsNotNone(resp.choices[0].message.content)
+ self.assertIn(resp.choices[0].finish_reason, ("stop", "length"))
+
+ def test_streaming(self):
+ text = ""
+ for chunk in self.client.chat.completions.create(
+ model=self.MODEL, messages=[{"role": "user", "content": "Say hello"}], stream=True
+ ):
+ if chunk.choices[0].delta.content:
+ text += chunk.choices[0].delta.content
+ self.assertTrue(len(text) > 0)
- request = {
- "model": "Qwen/Qwen2.5-0.5B-Instruct",
- "messages": [
- {"role": "system", "content": "You are a sports assistant designed to craft sports programs."},
- {"role": "user", "content": "Tell me what you can do."},
+ def test_early_return_due_to_length(self):
+ """When max_tokens is hit, finish_reason should be 'length'."""
+ chunks = list(
+ self.client.chat.completions.create(
+ model=self.MODEL,
+ messages=[{"role": "user", "content": "Hello, how are you?"}],
+ stream=True,
+ max_tokens=3,
+ )
+ )
+ last = chunks[-1]
+ self.assertEqual(last.choices[0].finish_reason, "length")
+
+ def test_continues_until_stop(self):
+ """When model stops naturally, finish_reason should be 'stop'."""
+ chunks = list(
+ self.client.chat.completions.create(
+ model=self.MODEL,
+ messages=[{"role": "user", "content": 'Please only answer with "Hi."'}],
+ stream=True,
+ max_tokens=30,
+ )
+ )
+ last = chunks[-1]
+ self.assertEqual(last.choices[0].finish_reason, "stop")
+
+ def test_stop_strings(self):
+ resp = self.client.chat.completions.create(
+ model=self.MODEL, messages=[{"role": "user", "content": "Count to 10"}], stop=["5"]
+ )
+ self.assertNotIn("6", resp.choices[0].message.content)
+
+ def test_multi_turn(self):
+ resp = self.client.chat.completions.create(
+ model=self.MODEL,
+ messages=[
+ {"role": "user", "content": "My name is Alice"},
+ {"role": "assistant", "content": "Nice to meet you!"},
+ {"role": "user", "content": "What is my name?"},
],
- "stream": True,
- "max_tokens": 30,
- }
- all_payloads = self.run_server(request)
+ )
+ self.assertIn("Alice", resp.choices[0].message.content)
- full_text = ""
- for token in all_payloads:
- if isinstance(token, ChatCompletionStreamOutput) and token.choices and len(token.choices) > 0:
- content = token.choices[0].delta.get("content", "")
- full_text += content if content is not None else ""
+ def test_multiple_models_on_demand(self):
+ """Load two different models via separate requests — both should work."""
+ model_a = "Qwen/Qwen2.5-0.5B-Instruct"
+ model_b = "HuggingFaceTB/SmolLM2-135M-Instruct"
+ prompt = [{"role": "user", "content": "Say hello"}]
- # Verify that the system prompt went through.
- self.assertTrue(
- full_text.startswith(
- "I can assist you with a wide range of tasks, from answering questions to providing information on various sports topics."
+ resp_a = self.client.chat.completions.create(model=model_a, messages=prompt)
+ self.assertIn(model_a, resp_a.model)
+ self.assertIsNotNone(resp_a.choices[0].message.content)
+
+ resp_b = self.client.chat.completions.create(model=model_b, messages=prompt)
+ self.assertIn(model_b, resp_b.model)
+ self.assertIsNotNone(resp_b.choices[0].message.content)
+
+ def test_non_streaming_usage(self):
+ resp = self.client.chat.completions.create(
+ model=self.MODEL, messages=[{"role": "user", "content": "Say hello"}]
+ )
+ self.assertIsNotNone(resp.usage)
+ self.assertGreater(resp.usage.prompt_tokens, 0)
+ self.assertGreater(resp.usage.completion_tokens, 0)
+ self.assertEqual(resp.usage.total_tokens, resp.usage.prompt_tokens + resp.usage.completion_tokens)
+
+ def test_streaming_usage(self):
+ chunks = list(
+ self.client.chat.completions.create(
+ model=self.MODEL,
+ messages=[{"role": "user", "content": "Say hello"}],
+ stream=True,
)
)
+ # Last chunk should have usage
+ last = chunks[-1]
+ self.assertIsNotNone(last.usage)
+ self.assertGreater(last.usage.prompt_tokens, 0)
+ self.assertGreater(last.usage.completion_tokens, 0)
+ self.assertEqual(last.usage.total_tokens, last.usage.prompt_tokens + last.usage.completion_tokens)
- def test_max_tokens_not_set_in_req(self):
- request = {
- "model": "Qwen/Qwen2.5-0.5B-Instruct",
- "messages": [
- {"role": "system", "content": "You are a sports assistant designed to craft sports programs."},
- {"role": "user", "content": "Tell me what you can do."},
- ],
- "stream": True,
+ def test_tool_call(self):
+ """Tool calls should be parsed and emitted as ChoiceDeltaToolCall objects."""
+ # Qwen2.5-0.5B-Instruct supports tools (Qwen family)
+ tool_def = {
+ "function": {
+ "name": "get_weather",
+ "parameters": {
+ "type": "object",
+ "properties": {"city": {"type": "string"}},
+ },
+ "description": "Get the weather for a city.",
+ },
+ "type": "function",
}
- all_payloads = self.run_server(request)
+ chunks = list(
+ self.client.chat.completions.create(
+ model=self.MODEL,
+ messages=[{"role": "user", "content": "What is the weather in Paris?"}],
+ stream=True,
+ max_tokens=50,
+ temperature=0.0,
+ tools=[tool_def],
+ )
+ )
- full_text = ""
- for token in all_payloads:
- if isinstance(token, ChatCompletionStreamOutput) and token.choices and len(token.choices) > 0:
- content = token.choices[0].delta.get("content", "")
- full_text += content if content is not None else ""
+ # First chunk should have role="assistant"
+ self.assertEqual(chunks[0].choices[0].delta.role, "assistant")
- # Verify that the system prompt went through.
- self.assertTrue(
- full_text.startswith(
- "I can assist you with a wide range of tasks, from answering questions to providing information on various sports topics."
+ # Model should make a tool call for this prompt
+ tool_chunks = [c for c in chunks if c.choices[0].delta.tool_calls]
+ self.assertGreater(len(tool_chunks), 0, "Model did not produce a tool call")
+
+ # First tool call delta should have the function name
+ first_tool = tool_chunks[0].choices[0].delta.tool_calls[0]
+ self.assertEqual(first_tool.function.name, "get_weather")
+
+ # finish_reason should be "tool_calls"
+ last = chunks[-1]
+ self.assertEqual(last.choices[0].finish_reason, "tool_calls")
+
+ # Arguments should be valid JSON with no trailing brace
+ args_json = first_tool.function.arguments
+ parsed_args = json.loads(args_json)
+ self.assertIsInstance(parsed_args, dict)
+
+ def test_tool_call_non_streaming(self):
+ """Non-streaming tool calls should return tool_calls in the message."""
+ tool_def = {
+ "function": {
+ "name": "get_weather",
+ "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
+ "description": "Get the weather for a city.",
+ },
+ "type": "function",
+ }
+ resp = self.client.chat.completions.create(
+ model=self.MODEL,
+ messages=[{"role": "user", "content": "What is the weather in Paris?"}],
+ stream=False,
+ max_tokens=50,
+ temperature=0.0,
+ tools=[tool_def],
+ )
+ self.assertEqual(resp.choices[0].finish_reason, "tool_calls")
+ self.assertIsNotNone(resp.choices[0].message.tool_calls)
+ tc = resp.choices[0].message.tool_calls[0]
+ self.assertEqual(tc.function.name, "get_weather")
+
+ parsed_args = json.loads(tc.function.arguments)
+ self.assertIsInstance(parsed_args, dict)
+
+ def test_tool_call_multi(self):
+ """Model should be able to call multiple tools when asked."""
+ tool_def = {
+ "function": {
+ "name": "get_weather",
+ "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
+ "description": "Get the weather for a city.",
+ },
+ "type": "function",
+ }
+ # Ask for two cities to encourage multiple tool calls
+ chunks = list(
+ self.client.chat.completions.create(
+ model=self.MODEL,
+ messages=[{"role": "user", "content": "What is the weather in Paris and London?"}],
+ stream=True,
+ max_tokens=100,
+ temperature=0.0,
+ tools=[tool_def],
)
)
+ tool_chunks = [c for c in chunks if c.choices[0].delta.tool_calls]
+ # Should have two tool calls: one for Paris, one for London
+ self.assertEqual(len(tool_chunks), 2, f"Expected 2 tool calls, got {len(tool_chunks)}")
+ names = {tc.choices[0].delta.tool_calls[0].function.name for tc in tool_chunks}
+ self.assertEqual(names, {"get_weather"})
+ last = chunks[-1]
+ self.assertEqual(last.choices[0].finish_reason, "tool_calls")
+
+ def test_concurrent_non_streaming(self):
+ """Two concurrent non-streaming requests should both complete without interference."""
+ import concurrent.futures
- def test_request_cancellation(self):
- """Tests that a request can be cancelled."""
+ prompts = [
+ [{"role": "user", "content": "Say hello"}],
+ [{"role": "user", "content": "Say goodbye"}],
+ ]
+ results = [None, None]
- base_url = f"http://127.0.0.1:{self.port}"
- request_id = "test-cancel"
+ def request_in_thread(index):
+ client = OpenAI(base_url=f"{self.base_url}/v1", api_key="unused")
+ results[index] = client.chat.completions.create(model=self.MODEL, messages=prompts[index])
- # Ensure the server is up before sending a request
- response = _call_healthcheck(base_url)
- self.assertIsNotNone(response, "Failed to connect to the server health endpoint.")
- self.assertEqual(response.status_code, 200)
+ with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool:
+ futures = [pool.submit(request_in_thread, i) for i in range(2)]
+ concurrent.futures.wait(futures)
+ for f in futures:
+ f.result() # re-raise exceptions
- _open_stream_and_cancel(base_url, request_id)
+ for i in range(2):
+ self.assertIsNotNone(results[i])
+ self.assertIsNotNone(results[i].choices[0].message.content)
+ self.assertTrue(len(results[i].choices[0].message.content) > 0)
- scheduler = _get_scheduler(self.server)
+
+ def test_concurrent_streaming(self):
+ """Two concurrent streaming requests should both produce complete, non-empty output."""
+ import concurrent.futures
- # Because cancellation is non-blocking, poll for a short, bounded time.
- deadline = time.time() + 8.0 # generous but still CI-friendly
- last_seen = None
- while time.time() < deadline:
- is_cancelled = scheduler.request_is_cancelled(request_id)
- if is_cancelled:
- break
- last_seen = time.time()
- time.sleep(0.1) # don't spin the CPU
+ prompts = [
+ [{"role": "user", "content": "Say hello"}],
+ [{"role": "user", "content": "Say goodbye"}],
+ ]
+ results = [None, None]
- is_cancelled = scheduler.request_is_cancelled(request_id)
- self.assertTrue(
- is_cancelled,
- f"Request {request_id} still present in scheduler after cancellation "
- f"(last seen at {last_seen}). Check cancellation propagation.",
- )
+
+ def stream_in_thread(index):
+ client = OpenAI(base_url=f"{self.base_url}/v1", api_key="unused")
+ text = ""
+ for chunk in client.chat.completions.create(model=self.MODEL, messages=prompts[index], stream=True):
+ if chunk.choices[0].delta.content:
+ text += chunk.choices[0].delta.content
+ results[index] = text
- def test_non_streaming_response_json_format(self):
- """
- Tests that non-streaming continuous batching responses return proper JSON objects,
- not double-encoded JSON strings (regression test for JSON serialization fix).
- """
- client = OpenAI(base_url=f"http://localhost:{self.port}/v1", api_key="")
-
- # Make a non-streaming request
- response = client.chat.completions.create(
- model="Qwen/Qwen2.5-0.5B-Instruct",
- messages=[{"role": "user", "content": "Say hello"}],
- stream=False,
- max_tokens=5,
+ with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool:
+ futures = [pool.submit(stream_in_thread, i) for i in range(2)]
+ concurrent.futures.wait(futures)
+ for f in futures:
+ f.result()
+
+ for i in range(2):
+ self.assertIsNotNone(results[i])
+ self.assertTrue(len(results[i]) > 0, f"Request {i} produced empty output")
+
+ def test_request_cancellation(self):
+ """Closing a stream early doesn't crash and the server stays healthy."""
+
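+ # Exiting the with-block below closes the connection mid-stream, simulating a client cancel.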
+ with httpx.stream(
+ "POST",
+ f"{self.base_url}/v1/chat/completions",
+ json={
+ "model": self.MODEL,
+ "stream": True,
+ "messages": [{"role": "user", "content": "Count slowly so I can cancel you."}],
+ "max_tokens": 500,
+ },
+ timeout=30,
+ ) as resp:
+ self.assertEqual(resp.status_code, 200)
+ chunks_read = 0
+ for _ in resp.iter_lines():
+ chunks_read += 1
+ if chunks_read >= 3:
+ break
+
+ # Server should still be healthy and serve subsequent requests
+ resp = self.client.chat.completions.create(
+ model=self.MODEL,
+ messages=[{"role": "user", "content": "Say hi"}],
+ max_tokens=10,
+ )
+ self.assertIsNotNone(resp.choices[0].message.content)
+
+
+@require_serve
+class TestResponseInputConversion(unittest.TestCase):
+ def _make_handler(self):
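+ # Input conversion never touches the model, so a mocked manager is sufficient.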
+ return ResponseHandler(model_manager=MagicMock(), generation_state=GenerationState())
+
+ def test_string_input(self):
+ handler = self._make_handler()
+ msgs = handler._input_to_messages({"input": "Hello"})
+ self.assertEqual(msgs, [{"role": "user", "content": "Hello"}])
+
+ def test_string_input_with_instructions(self):
+ handler = self._make_handler()
+ msgs = handler._input_to_messages({"input": "Hello", "instructions": "Be brief"})
+ self.assertEqual(len(msgs), 2)
+ self.assertEqual(msgs[0], {"role": "system", "content": "Be brief"})
+ self.assertEqual(msgs[1], {"role": "user", "content": "Hello"})
+
+ def test_list_input(self):
+ handler = self._make_handler()
+ msgs = handler._input_to_messages(
+ {"input": [{"role": "user", "content": "A"}, {"role": "assistant", "content": "B"}]}
+ )
+ self.assertEqual(len(msgs), 2)
+ self.assertEqual(msgs[0]["content"], "A")
+
+ def test_list_input_with_instructions_prepends_system(self):
+ handler = self._make_handler()
+ msgs = handler._input_to_messages({"input": [{"role": "user", "content": "Hi"}], "instructions": "Be helpful"})
+ self.assertEqual(len(msgs), 2)
+ self.assertEqual(msgs[0]["role"], "system")
+ self.assertEqual(msgs[0]["content"], "Be helpful")
+
+ def test_list_input_with_instructions_replaces_existing_system(self):
+ handler = self._make_handler()
+ msgs = handler._input_to_messages(
+ {"input": [{"role": "system", "content": "Old"}, {"role": "user", "content": "Hi"}], "instructions": "New"}
)
+ self.assertEqual(len(msgs), 2)
+ self.assertEqual(msgs[0]["content"], "New")
- # Verify response is a proper ChatCompletion object (not a string)
- self.assertIsNotNone(response)
- self.assertIsNotNone(response.id)
- self.assertIsNotNone(response.choices)
- self.assertEqual(len(response.choices), 1)
+
+ def test_dict_input(self):
+ handler = self._make_handler()
+ msgs = handler._input_to_messages({"input": {"role": "user", "content": "Test"}})
+ self.assertEqual(msgs, [{"role": "user", "content": "Test"}])
- # Verify the choice has proper structure
- choice = response.choices[0]
- self.assertIsNotNone(choice.message)
- self.assertIsNotNone(choice.message.content)
- self.assertEqual(choice.message.role, "assistant")
- # Verify content is a string, not a serialized JSON
- content = choice.message.content
- self.assertIsInstance(content, str)
- self.assertTrue(len(content) > 0)
+
+
+@require_serve
+class TestResponseValidation(unittest.TestCase):
+ def _make_handler(self):
+ return ResponseHandler(model_manager=MagicMock(), generation_state=GenerationState())
+
+ def test_unsupported_fields_warns(self):
+ handler = self._make_handler()
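+ # Unsupported fields should trigger a warning, not an exception.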
+ with self.assertLogs("transformers", level="WARNING") as cm:
+ handler._validate_request({"model": "x", "input": "hi", "previous_response_id": "abc"})
+ self.assertTrue(any("previous_response_id" in msg for msg in cm.output))
-@require_openai
-class ServeResponsesMixin:
- """
- Mixin class for the Completions API tests, to seamlessly replicate tests across the two versions of the API
- (`generate` and `continuous_batching`).
- """
+
+ def test_valid_request_passes(self):
+ handler = self._make_handler()
+ # Should not raise
+ handler._validate_request({"model": "x", "input": "hi"})
- @retry
- def run_server(self, request):
- client = OpenAI(base_url=f"http://localhost:{self.port}/v1", api_key="")
- stream = client.responses.create(**request)
- all_payloads = []
- for payload in stream:
- all_payloads.append(payload)
+
+
+@require_serve
+class TestResponseGenerationConfig(unittest.TestCase):
+ def _make_handler(self):
+ return ResponseHandler(model_manager=MagicMock(), generation_state=GenerationState())
- return all_payloads
+
+ def test_max_output_tokens(self):
+ from transformers import GenerationConfig
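+ # The request's max_output_tokens should map onto GenerationConfig.max_new_tokens.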
- def test_request(self):
- """Tests that an inference using the Responses API works"""
+ result = self._make_handler()._build_generation_config({"max_output_tokens": 42}, GenerationConfig())
+ self.assertEqual(result.max_new_tokens, 42)
- request = {
- "model": "Qwen/Qwen2.5-0.5B-Instruct",
- "instructions": "You are a helpful assistant.",
- "input": "Hello!",
- "stream": True,
- "max_output_tokens": 1,
- }
- all_payloads = self.run_server(request)
+
+ def test_default_bumps_short_max_new_tokens(self):
+ from transformers import GenerationConfig
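+ # A short model default (20 new tokens) should be bumped to the serving default of 1024.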
- # Allow variable number of delta events depending on tokenizer/streamer behavior
- self.assertGreaterEqual(len(all_payloads), 8)
+ result = self._make_handler()._build_generation_config({}, GenerationConfig(max_new_tokens=20))
+ self.assertEqual(result.max_new_tokens, 1024)
- # Start markers
- self.assertIsInstance(all_payloads[0], ResponseCreatedEvent)
- self.assertIsInstance(all_payloads[1], ResponseInProgressEvent)
- self.assertIsInstance(all_payloads[2], ResponseOutputItemAddedEvent)
- self.assertIsInstance(all_payloads[3], ResponseContentPartAddedEvent)
- # At least one delta event during streaming
- self.assertTrue(any(isinstance(p, ResponseTextDeltaEvent) for p in all_payloads[4:-4]))
+
+
+@require_serve
+class TestResponseUsage(unittest.TestCase):
+ def test_compute_usage(self):
+ usage = compute_usage(input_tokens=100, output_tokens=50)
+ self.assertEqual(usage.input_tokens, 100)
+ self.assertEqual(usage.output_tokens, 50)
+ self.assertEqual(usage.total_tokens, 150)
+ self.assertEqual(usage.input_tokens_details.cached_tokens, 0)
+ self.assertEqual(usage.output_tokens_details.reasoning_tokens, 0)
- # Closing markers
- self.assertIsInstance(all_payloads[-4], ResponseTextDoneEvent)
- self.assertIsInstance(all_payloads[-3], ResponseContentPartDoneEvent)
- self.assertIsInstance(all_payloads[-2], ResponseOutputItemDoneEvent)
- self.assertIsInstance(all_payloads[-1], ResponseCompletedEvent)
+
+ def test_usage_in_completed_response(self):
+ """Usage should serialize correctly inside a Response."""
+
+ usage = compute_usage(10, 5)
+ response = Response(
+ id="resp_test",
+ created_at=0,
+ status="completed",
+ model="test",
+ output=[],
+ object="response",
+ tools=[],
+ parallel_tool_calls=False,
+ tool_choice="auto",
+ usage=usage,
+ )
+ dumped = response.model_dump(exclude_none=True)
+ self.assertEqual(dumped["usage"]["input_tokens"], 10)
+ self.assertEqual(dumped["usage"]["output_tokens"], 5)
+ self.assertEqual(dumped["usage"]["total_tokens"], 15)
+
+
+@require_serve
+class TestResponseSSEFormat(unittest.TestCase):
+ def test_sse_format(self):
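+ # SSE frames are "data: <json>" terminated by a blank line; verify framing and payload.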
+ event = ResponseCreatedEvent(
+ type="response.created",
+ sequence_number=0,
+ response=Response(
+ id="resp_test",
+ created_at=0,
+ status="queued",
+ model="test",
+ text={"format": {"type": "text"}},
+ object="response",
+ tools=[],
+ output=[],
+ parallel_tool_calls=False,
+ tool_choice="auto",
+ ),
+ )
+ result = BaseHandler.chunk_to_sse(event)
+ self.assertTrue(result.startswith("data: "))
+ self.assertTrue(result.endswith("\n\n"))
+ parsed = json.loads(result[len("data: ") :].strip())
+ self.assertEqual(parsed["type"], "response.created")
+ self.assertEqual(parsed["response"]["status"], "queued")
- # TODO: one test for each request flag, to confirm it is working as expected
- # TODO: speed-based test to confirm that KV cache is working across requests
+
+
+@slow
+@require_serve
+class TestResponsesIntegration(unittest.TestCase):
+ """Integration tests for /v1/responses with a real model."""
-@slow # server startup time is slow on our push CI
-@require_openai
-class ServeResponsesIntegrationTest(ServeResponsesMixin, unittest.TestCase):
- """Tests the Responses API."""
+ MODEL = "Qwen/Qwen2.5-0.5B-Instruct"
@classmethod
def setUpClass(cls):
- """Starts a server for tests to connect to."""
- cls.port = 8003
- cls.server = Serve(port=cls.port, default_seed=42, non_blocking=True)
+ cls.serve, port = _start_serve()
+ cls.base_url = f"http://localhost:{port}"
+ cls.client = OpenAI(base_url=f"{cls.base_url}/v1", api_key="unused")
@classmethod
def tearDownClass(cls):
- cls.server.kill_server()
-
- @slow
- def test_full_request(self):
- """Tests that an inference using the Responses API works"""
-
- request = {
- "model": "Qwen/Qwen2.5-0.5B-Instruct",
- "instructions": "You are a sports assistant designed to craft sports programs.",
- "input": "Tell me what you can do.",
- "stream": True,
- "max_output_tokens": 30,
- # Disable sampling for deterministic output
- "temperature": 0,
+ cls.serve.kill_server()
+
+ def test_streaming(self):
+ events = list(
+ self.client.responses.create(
+ model=self.MODEL,
+ input="Say hello",
+ stream=True,
+ max_output_tokens=1,
+ )
+ )
+ # At least 9 events: created, in_progress, output_item_added, content_part_added,
+ # one or more deltas, text_done, content_part_done, output_item_done, completed
+ self.assertGreaterEqual(len(events), 9)
+
+ # Start markers (fixed order)
+ self.assertEqual(events[0].type, "response.created")
+ self.assertEqual(events[1].type, "response.in_progress")
+ self.assertEqual(events[2].type, "response.output_item.added")
+ self.assertEqual(events[3].type, "response.content_part.added")
+
+ # At least one delta
+ self.assertTrue(any(e.type == "response.output_text.delta" for e in events[4:-4]))
+
+ # Closing markers (fixed order from the end)
+ self.assertEqual(events[-4].type, "response.output_text.done")
+ self.assertEqual(events[-3].type, "response.content_part.done")
+ self.assertEqual(events[-2].type, "response.output_item.done")
+ self.assertEqual(events[-1].type, "response.completed")
+
+ def test_non_streaming(self):
+ resp = self.client.responses.create(
+ model=self.MODEL,
+ input="Say hello",
+ stream=False,
+ )
+ self.assertEqual(resp.status, "completed")
+ self.assertTrue(len(resp.output) > 0)
+ self.assertTrue(len(resp.output[0].content[0].text) > 0)
+
+ def test_non_streaming_usage(self):
+ resp = self.client.responses.create(
+ model=self.MODEL,
+ input="Say hello",
+ stream=False,
+ )
+ self.assertIsNotNone(resp.usage)
+ self.assertGreater(resp.usage.input_tokens, 0)
+ self.assertGreater(resp.usage.output_tokens, 0)
+ self.assertEqual(resp.usage.total_tokens, resp.usage.input_tokens + resp.usage.output_tokens)
+
+ def test_streaming_usage(self):
+ events = list(
+ self.client.responses.create(
+ model=self.MODEL,
+ input="Say hello",
+ stream=True,
+ max_output_tokens=5,
+ )
+ )
+ completed = events[-1]
+ self.assertEqual(completed.type, "response.completed")
+ usage = completed.response.usage
+ self.assertIsNotNone(usage)
+ self.assertGreater(usage.input_tokens, 0)
+ self.assertGreater(usage.output_tokens, 0)
+ self.assertEqual(usage.total_tokens, usage.input_tokens + usage.output_tokens)
+
+ def test_tool_call_streaming(self):
+ """Streaming responses with tools should emit function_call events."""
+ tool_def = {
+ "function": {
+ "name": "get_weather",
+ "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
+ "description": "Get the weather for a city.",
+ },
+ "type": "function",
+ }
+ events = list(
+ self.client.responses.create(
+ model=self.MODEL,
+ input="What is the weather in Paris?",
+ stream=True,
+ max_output_tokens=50,
+ tools=[tool_def],
+ )
+ )
+ types = [e.type for e in events]
+ self.assertIn("response.created", types)
+ self.assertIn("response.completed", types)
+
+ # Should have function call events
+ self.assertIn("response.output_item.added", types)
+ self.assertIn("response.function_call_arguments.done", types)
+
+ # Check the arguments done event
+ args_done = [e for e in events if e.type == "response.function_call_arguments.done"]
+ self.assertGreater(len(args_done), 0)
+ self.assertEqual(args_done[0].name, "get_weather")
+
+ parsed = json.loads(args_done[0].arguments)
+ self.assertIsInstance(parsed, dict)
+
+ def test_tool_call_non_streaming(self):
+ """Non-streaming responses with tools should include function_call output items."""
+ tool_def = {
+ "function": {
+ "name": "get_weather",
+ "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
+ "description": "Get the weather for a city.",
+ },
+ "type": "function",
}
- all_payloads = self.run_server(request)
-
- full_text = ""
- for token in all_payloads:
- if isinstance(token, ResponseTextDeltaEvent):
- full_text += token.delta
-
- # Verify that the system prompt went through.
- # With deterministic decoding, exact wording can still vary across versions.
- # Assert non-empty output and that it references sports.
- self.assertTrue(len(full_text) > 0)
- self.assertIn("sports", full_text.lower())
-
- @slow
- def test_non_streaming_request(self):
- """Tests that an inference using the Responses API with stream=False returns a single Response payload."""
- from openai import OpenAI
- from openai.types.responses import Response as OpenAIResponse
-
- client = OpenAI(base_url=f"http://localhost:{self.port}/v1", api_key="")
- resp = client.responses.create(
- model="Qwen/Qwen2.5-0.5B-Instruct",
- instructions="You are a helpful assistant.",
- input="Hello!",
+ resp = self.client.responses.create(
+ model=self.MODEL,
+ input="What is the weather in Paris?",
stream=False,
- max_output_tokens=5,
+ max_output_tokens=50,
+ tools=[tool_def],
)
+ self.assertEqual(resp.status, "completed")
+
+ # Should have at least message + function_call in output
+ self.assertGreater(len(resp.output), 1)
+ fc_items = [o for o in resp.output if o.type == "function_call"]
+ self.assertGreater(len(fc_items), 0)
+ self.assertEqual(fc_items[0].name, "get_weather")
+
- # Should be a single Response object with completed status and one output item containing text
- self.assertIsInstance(resp, OpenAIResponse)
+ parsed = json.loads(fc_items[0].arguments)
+ self.assertIsInstance(parsed, dict)
+
+ def test_tool_call_multi(self):
+ """Model should produce multiple tool calls when asked about two cities."""
+ tool_def = {
+ "function": {
+ "name": "get_weather",
+ "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
+ "description": "Get the weather for a city.",
+ },
+ "type": "function",
+ }
+ events = list(
+ self.client.responses.create(
+ model=self.MODEL,
+ input="What is the weather in Paris and London?",
+ stream=True,
+ max_output_tokens=100,
+ tools=[tool_def],
+ )
+ )
+ args_done = [e for e in events if e.type == "response.function_call_arguments.done"]
+ self.assertEqual(len(args_done), 2, f"Expected 2 tool calls, got {len(args_done)}")
+ self.assertEqual(events[-1].type, "response.completed")
+
+ def test_multi_turn(self):
+ """Multi-turn conversation via list input."""
+ resp = self.client.responses.create(
+ model=self.MODEL,
+ input=[
+ {"role": "user", "content": "My name is Alice"},
+ {"role": "assistant", "content": "Nice to meet you!"},
+ {"role": "user", "content": "What is my name?"},
+ ],
+ stream=False,
+ )
self.assertEqual(resp.status, "completed")
- self.assertTrue(len(resp.output) >= 1)
- first_item = resp.output[0]
- self.assertEqual(first_item.type, "message")
- self.assertEqual(first_item.status, "completed")
- self.assertTrue(len(first_item.content) >= 1)
- first_part = first_item.content[0]
- self.assertEqual(first_part.type, "output_text")
- self.assertIsInstance(first_part.text, str)
+ self.assertIn("Alice", resp.output[0].content[0].text)
+
+ def test_concurrent_non_streaming(self):
+ """Two concurrent non-streaming responses requests should both complete."""
+ import concurrent.futures
-class ServeInfrastructureTest(unittest.TestCase):
- @classmethod
- def setUpClass(cls):
- cls.port = 8042
- thread = Thread(target=Serve, kwargs={"port": cls.port})
- thread.daemon = True
- thread.start()
-
- def test_healthcheck(self):
- """Tests that the healthcheck endpoint works."""
- response = _call_healthcheck(f"http://localhost:{self.port}")
- self.assertIsNotNone(response, "Failed to connect to the server health endpoint.")
- self.assertEqual(response.status_code, 200)
- self.assertEqual(response.json(), {"status": "ok"})
+ inputs = ["Say hello", "Say goodbye"]
+ results = [None, None]
+
+ def request_in_thread(index):
+ client = OpenAI(base_url=f"{self.base_url}/v1", api_key="unused")
+ results[index] = client.responses.create(model=self.MODEL, input=inputs[index], stream=False)
+
+ with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool:
+ futures = [pool.submit(request_in_thread, i) for i in range(2)]
+ concurrent.futures.wait(futures)
+ for f in futures:
+ f.result()
+
+ for i in range(2):
+ self.assertIsNotNone(results[i])
+ self.assertEqual(results[i].status, "completed")
+ self.assertTrue(len(results[i].output[0].content[0].text) > 0)
+
+ def test_concurrent_streaming(self):
+ """Two concurrent streaming responses requests should both produce complete event streams."""
+ import concurrent.futures
+
+ inputs = ["Say hello", "Say goodbye"]
+ results = [None, None]
+
+ def stream_in_thread(index):
+ client = OpenAI(base_url=f"{self.base_url}/v1", api_key="unused")
+ results[index] = list(client.responses.create(model=self.MODEL, input=inputs[index], stream=True))
-def parse_sse_events(response):
+ with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool:
+ futures = [pool.submit(stream_in_thread, i) for i in range(2)]
+ concurrent.futures.wait(futures)
+ for f in futures:
+ f.result()
+
+ for i in range(2):
+ types = [e.type for e in results[i]]
+ self.assertIn("response.created", types, f"Request {i} missing created event")
+ self.assertIn("response.output_text.delta", types, f"Request {i} missing delta events")
+ self.assertIn("response.completed", types, f"Request {i} missing completed event")
+
+
+def _parse_sse_events(response):
"""Parse SSE lines from a streaming httpx response into a list of dicts."""
events = []
for line in response.iter_lines():
- if not line:
+ if not line or not line.startswith("data: "):
continue
- if line.startswith("data: "):
- events.append(json.loads(line[6:]))
+ events.append(json.loads(line[len("data: ") :]))
return events
@slow
-@require_openai
-class ServeLoadModelIntegrationTest(unittest.TestCase):
- """Tests the /load_model SSE endpoint."""
+@require_serve
+class TestLoadModel(unittest.TestCase):
+ """Integration tests for POST /load_model SSE endpoint."""
+
+ MODEL = "Qwen/Qwen2.5-0.5B-Instruct"
@classmethod
def setUpClass(cls):
- cls.port = 8043
- cls.server = Serve(port=cls.port, non_blocking=True)
- cls.base_url = f"http://localhost:{cls.port}"
- # Wait for the server to be ready
- response = _call_healthcheck(cls.base_url)
- assert response is not None and response.status_code == 200
+ cls.serve, port = _start_serve()
+ cls.base_url = f"http://localhost:{port}"
@classmethod
def tearDownClass(cls):
- cls.server.kill_server()
+ cls.serve.kill_server()
def setUp(self):
- # Clear the in-memory model cache so each test starts fresh
- self.server.reset_loaded_models()
+ # Clear model cache so each test starts fresh
+ self.serve.reset_loaded_models()
def _load_model(self, model: str):
- with httpx.Client(timeout=120) as client:
- with client.stream("POST", f"{self.base_url}/load_model", json={"model": model}) as response:
- events = parse_sse_events(response)
- return response, events
+ with httpx.stream("POST", f"{self.base_url}/load_model", json={"model": model}, timeout=120) as resp:
+ events = _parse_sse_events(resp)
+ return resp, events
def test_load_model_fresh(self):
- """POST /load_model with a valid model returns SSE events ending with ready."""
- response, events = self._load_model("Qwen/Qwen2.5-0.5B-Instruct")
+ """POST /load_model returns SSE events ending with ready."""
+ response, events = self._load_model(self.MODEL)
self.assertEqual(response.status_code, 200)
- self.assertIn("text/event-stream", response.headers.get("content-type", ""))
- # Extract stages from loading events
stages = [e["stage"] for e in events if e["status"] == "loading" and "stage" in e]
self.assertIn("processor", stages)
- self.assertIn("config", stages)
self.assertIn("weights", stages)
- # Stages must appear in the correct order
- stage_indices = {stage: i for i, stage in enumerate(stages) if stage in ("processor", "config", "weights")}
- self.assertLess(stage_indices["processor"], stage_indices["config"])
- self.assertLess(stage_indices["config"], stage_indices["weights"])
-
- # Last event is ready with cached: false
last = events[-1]
self.assertEqual(last["status"], "ready")
self.assertFalse(last["cached"])
- # Every event has status and model
for event in events:
self.assertIn("status", event)
self.assertIn("model", event)
def test_load_model_cached(self):
- """Loading a model that is already in memory returns a single ready event with cached: true."""
- # First load to ensure the model is in memory
- self._load_model("Qwen/Qwen2.5-0.5B-Instruct")
+ """Loading an already-loaded model returns a single ready event with cached: true."""
+ self._load_model(self.MODEL)
- # Second load should be cached
- _, events = self._load_model("Qwen/Qwen2.5-0.5B-Instruct")
+ _, events = self._load_model(self.MODEL)
ready_events = [e for e in events if e["status"] == "ready"]
self.assertEqual(len(ready_events), 1)
self.assertTrue(ready_events[0]["cached"])
- # No loading events should be present
loading_events = [e for e in events if e["status"] == "loading"]
self.assertEqual(len(loading_events), 0)
@@ -987,18 +1311,18 @@ def test_load_model_error(self):
_, events = self._load_model("nonexistent/model-that-does-not-exist")
error_events = [e for e in events if e["status"] == "error"]
- self.assertGreaterEqual(len(error_events), 1, "Expected at least one error event")
+ self.assertGreaterEqual(len(error_events), 1)
self.assertIn("message", error_events[0])
def test_load_model_missing_field(self):
"""POST /load_model with no model field returns 422."""
- with httpx.Client(timeout=30) as client:
- response = client.post(f"{self.base_url}/load_model", json={})
- self.assertEqual(response.status_code, 422)
+
+ response = httpx.post(f"{self.base_url}/load_model", json={}, timeout=30)
+ self.assertEqual(response.status_code, 422)
def test_load_model_event_schema(self):
- """Every event in a load_model stream conforms to the expected schema."""
- _, events = self._load_model("Qwen/Qwen2.5-0.5B-Instruct")
+ """Every event conforms to the expected schema."""
+ _, events = self._load_model(self.MODEL)
for event in events:
self.assertIsInstance(event["status"], str)
@@ -1006,7 +1330,6 @@ def test_load_model_event_schema(self):
if event["status"] == "loading":
self.assertIn("stage", event)
-
if event["stage"] in ("download", "weights") and "progress" in event:
progress = event["progress"]
self.assertIn("current", progress)
@@ -1018,12 +1341,10 @@ def test_load_model_event_schema(self):
self.assertIsInstance(event["cached"], bool)
def test_load_model_stage_ordering(self):
- """Stages in loading events follow the expected order."""
- _, events = self._load_model("Qwen/Qwen2.5-0.5B-Instruct")
+ """Stages appear in the expected order."""
+ _, events = self._load_model(self.MODEL)
stages = [e["stage"] for e in events if e["status"] == "loading" and "stage" in e]
-
- # Deduplicate while preserving order (stages repeat for progress ticks)
seen = set()
unique_stages = []
for s in stages:
@@ -1032,29 +1353,23 @@ def test_load_model_stage_ordering(self):
unique_stages.append(s)
expected_order = ["processor", "config", "download", "weights"]
- # Filter expected_order to only stages that are actually present
expected_present = [s for s in expected_order if s in unique_stages]
-
self.assertEqual(unique_stages, expected_present, "Stages appeared out of order")
def test_concurrent_load_same_model(self):
- """Two concurrent /load_model requests for the same model should both receive progress events
- and a final ready event, but the model should only be loaded once."""
+ """Two concurrent /load_model requests both get events and a ready event."""
import concurrent.futures
- model = "Qwen/Qwen2.5-0.5B-Instruct"
results = [None, None]
def load_in_thread(index):
- with httpx.Client(timeout=120) as client:
- with client.stream("POST", f"{self.base_url}/load_model", json={"model": model}) as response:
- events = parse_sse_events(response)
- results[index] = (response.status_code, events)
+ with httpx.stream("POST", f"{self.base_url}/load_model", json={"model": self.MODEL}, timeout=120) as resp:
+ events = _parse_sse_events(resp)
+ results[index] = (resp.status_code, events)
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool:
futures = [pool.submit(load_in_thread, i) for i in range(2)]
concurrent.futures.wait(futures)
- # Re-raise any exceptions from threads
for f in futures:
f.result()
@@ -1062,28 +1377,517 @@ def load_in_thread(index):
status_code, events = results[i]
self.assertEqual(status_code, 200, f"Caller {i} got non-200 status")
self.assertTrue(len(events) > 0, f"Caller {i} received no events")
-
ready_events = [e for e in events if e["status"] == "ready"]
self.assertEqual(len(ready_events), 1, f"Caller {i} should get exactly one ready event")
- self.assertIn("model", ready_events[0])
- def test_concurrent_load_second_caller_gets_cached_if_first_finishes(self):
- """If the first /load_model finishes before the second arrives,
- the second caller should get a cached response."""
- model = "Qwen/Qwen2.5-0.5B-Instruct"
-
- # First load — blocks until complete
- _, events1 = self._load_model(model)
+ def test_concurrent_load_second_caller_gets_cached(self):
+ """If the first /load_model finishes before the second, the second gets cached: true."""
+ _, events1 = self._load_model(self.MODEL)
ready1 = [e for e in events1 if e["status"] == "ready"]
self.assertEqual(len(ready1), 1)
self.assertFalse(ready1[0]["cached"])
- # Second load — model is now in memory
- _, events2 = self._load_model(model)
+ _, events2 = self._load_model(self.MODEL)
ready2 = [e for e in events2 if e["status"] == "ready"]
self.assertEqual(len(ready2), 1)
self.assertTrue(ready2[0]["cached"])
- # No loading events on the cached path
loading2 = [e for e in events2 if e["status"] == "loading"]
self.assertEqual(len(loading2), 0)
+
+ def test_load_model_weights_progress_complete(self):
+ """Weights progress should go from 1 to total, with total matching across events."""
+ _, events = self._load_model(self.MODEL)
+
+ weights_events = [e for e in events if e.get("stage") == "weights" and "progress" in e]
+ self.assertGreater(len(weights_events), 0, "No weights progress events emitted")
+
+ # All events should have the same total
+ totals = {e["progress"]["total"] for e in weights_events}
+ self.assertEqual(len(totals), 1, f"Inconsistent totals: {totals}")
+ total = totals.pop()
+ self.assertIsNotNone(total)
+ self.assertGreater(total, 0)
+
+ # First should be 1, last should be total
+ self.assertEqual(weights_events[0]["progress"]["current"], 1)
+ self.assertEqual(weights_events[-1]["progress"]["current"], total)
+
+ # Progress should be non-decreasing
+ currents = [e["progress"]["current"] for e in weights_events]
+ self.assertEqual(currents, sorted(currents))
+
+ def test_load_model_exactly_one_ready(self):
+ """A fresh load should produce exactly one ready event as the last event."""
+ _, events = self._load_model(self.MODEL)
+
+ ready_events = [e for e in events if e["status"] == "ready"]
+ self.assertEqual(len(ready_events), 1)
+ self.assertEqual(events[-1]["status"], "ready")
+
+ def test_load_model_usable_after_load(self):
+ """After /load_model completes, the model should be usable for inference."""
+ self._load_model(self.MODEL)
+
+ client = OpenAI(base_url=f"{self.base_url}/v1", api_key="unused")
+ resp = client.chat.completions.create(
+ model=self.MODEL,
+ messages=[{"role": "user", "content": "Say hi"}],
+ max_tokens=5,
+ )
+ self.assertIsNotNone(resp.choices[0].message.content)
+ self.assertTrue(len(resp.choices[0].message.content) > 0)
+
+ def test_load_model_model_field_matches(self):
+ """The model field in every event should match the canonical model ID."""
+ _, events = self._load_model(self.MODEL)
+
+ for event in events:
+ self.assertTrue(
+ event["model"].startswith(self.MODEL),
+ f"Event model '{event['model']}' doesn't match '{self.MODEL}'",
+ )
+
+
+# Real image URL for VLM tests (person + dog on a beach)
+_DOG_IMAGE_URL = "https://qianwen-res.oss-accelerate-overseas.aliyuncs.com/Qwen2-VL/demo_small.jpg"
+
+
+@slow
+@require_vision
+@require_serve
+class TestVLM(unittest.TestCase):
+ """Integration tests for VLM (vision-language model) support. Requires torchvision."""
+
+ MODEL = "HuggingFaceTB/SmolVLM-256M-Instruct"
+
+ @classmethod
+ def setUpClass(cls):
+ cls.serve, port = _start_serve()
+ cls.base_url = f"http://localhost:{port}"
+ cls.client = OpenAI(base_url=f"{cls.base_url}/v1", api_key="unused")
+
+ @classmethod
+ def tearDownClass(cls):
+ cls.serve.kill_server()
+
+ def test_chat_completion_with_image(self):
+ """Chat completions should accept image_url content and produce a meaningful response."""
+ resp = self.client.chat.completions.create(
+ model=self.MODEL,
+ messages=[
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": "What do you see in this image?"},
+ {"type": "image_url", "image_url": {"url": _DOG_IMAGE_URL}},
+ ],
+ }
+ ],
+ max_tokens=50,
+ )
+ text = resp.choices[0].message.content
+ self.assertIsNotNone(text)
+ self.assertTrue(
+ any(word in text.lower() for word in ["dog", "beach", "person"]),
+ f"Expected dog/beach/person in response, got: {text}",
+ )
+
+ def test_responses_with_image(self):
+ """Responses API should accept image_url content and produce a meaningful response."""
+ resp = self.client.responses.create(
+ model=self.MODEL,
+ input=[
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": "What do you see in this image?"},
+ {"type": "image_url", "image_url": {"url": _DOG_IMAGE_URL}},
+ ],
+ }
+ ],
+ stream=False,
+ max_output_tokens=50,
+ )
+ self.assertEqual(resp.status, "completed")
+ text = resp.output[0].content[0].text
+ self.assertTrue(
+ any(word in text.lower() for word in ["dog", "beach", "person"]),
+ f"Expected dog/beach/person in response, got: {text}",
+ )
+
+
+@slow
+@require_librosa
+@require_multipart
+@require_serve
+class TestTranscription(unittest.TestCase):
+ """Integration tests for POST /v1/audio/transcriptions with whisper-tiny."""
+
+ MODEL = "openai/whisper-tiny"
+
+ @classmethod
+ def setUpClass(cls):
+ cls.serve, port = _start_serve()
+ cls.base_url = f"http://localhost:{port}"
+
+ @classmethod
+ def tearDownClass(cls):
+ cls.serve.kill_server()
+
+ @classmethod
+ def _get_audio_bytes(cls):
+ """Download the MLK 'I have a dream' speech sample from HF Hub."""
+ if not hasattr(cls, "_audio_bytes"):
+ from huggingface_hub import hf_hub_download
+
+ path = hf_hub_download("Narsil/asr_dummy", "mlk.flac", repo_type="dataset")
+ with open(path, "rb") as f:
+ cls._audio_bytes = f.read()
+ return cls._audio_bytes
+
+ def test_transcription_returns_text(self):
+ """POST /v1/audio/transcriptions with real speech returns meaningful transcription."""
+
+ audio_bytes = self._get_audio_bytes()
+ resp = httpx.post(
+ f"{self.base_url}/v1/audio/transcriptions",
+ files={"file": ("mlk.flac", audio_bytes, "audio/flac")},
+ data={"model": self.MODEL},
+ timeout=120,
+ )
+ self.assertEqual(resp.status_code, 200)
+ data = resp.json()
+ self.assertIn("text", data)
+ self.assertIsInstance(data["text"], str)
+ # Whisper-tiny should recognize at least "dream" from the MLK speech
+ self.assertIn("dream", data["text"].lower())
+
+ def test_transcription_openai_client(self):
+ """Transcription should work via the OpenAI Python client."""
+ audio_bytes = self._get_audio_bytes()
+ client = OpenAI(base_url=f"{self.base_url}/v1", api_key="unused")
+ result = client.audio.transcriptions.create(
+ model=self.MODEL,
+ file=("mlk.flac", audio_bytes),
+ )
+ self.assertIsInstance(result.text, str)
+ self.assertTrue(len(result.text) > 10)
+
+ def test_transcription_streaming(self):
+ """Streaming transcription should yield text chunks via SSE."""
+
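+ # Multipart form fields are strings, so the stream flag is sent as "true".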
+ audio_bytes = self._get_audio_bytes()
+ with httpx.stream(
+ "POST",
+ f"{self.base_url}/v1/audio/transcriptions",
+ files={"file": ("mlk.flac", audio_bytes, "audio/flac")},
+ data={"model": self.MODEL, "stream": "true"},
+ timeout=120,
+ ) as resp:
+ self.assertEqual(resp.status_code, 200)
+
+ chunks = []
+ for line in resp.iter_lines():
+ if line and line.startswith("data: "):
+ chunks.append(line[len("data: ") :])
+
+ self.assertGreater(len(chunks), 0, "No streaming chunks received")
+ full_text = "".join(chunks)
+ self.assertIn("dream", full_text.lower())
+
+ def test_transcription_missing_file(self):
+ """POST without a file should fail."""
+
+ resp = httpx.post(
+ f"{self.base_url}/v1/audio/transcriptions",
+ data={"model": self.MODEL},
+ timeout=30,
+ )
+ self.assertNotEqual(resp.status_code, 200)
+
+
+@slow
+@require_serve
+@require_torch_accelerator
+class TestContinuousBatchingChatCompletion(unittest.TestCase):
+ """Integration tests for /v1/chat/completions with continuous batching."""
+
+ MODEL = "Qwen/Qwen2.5-0.5B-Instruct"
+
+ @classmethod
+ def setUpClass(cls):
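+ # A fixed seed keeps CB generations deterministic across runs.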
+ cls.serve, port = _start_serve(
+ force_model=cls.MODEL,
+ device="cuda:0",
+ continuous_batching=True,
+ attn_implementation="sdpa",
+ default_seed=42,
+ )
+ cls.base_url = f"http://localhost:{port}"
+ cls.client = OpenAI(base_url=f"{cls.base_url}/v1", api_key="unused")
+
+ @classmethod
+ def tearDownClass(cls):
+ cls.serve.kill_server()
+
+ def test_streaming(self):
+ """Streaming chat completion with CB produces text."""
+ text = ""
+ for chunk in self.client.chat.completions.create(
+ model=self.MODEL,
+ messages=[
+ {"role": "system", "content": "You are a sports assistant designed to craft sports programs."},
+ {"role": "user", "content": "Tell me what you can do."},
+ ],
+ stream=True,
+ max_tokens=30,
+ ):
+ if chunk.choices[0].delta.content:
+ text += chunk.choices[0].delta.content
+ self.assertTrue(len(text) > 0)
+
+ def test_non_streaming(self):
+ """Non-streaming chat completion with CB returns a full response."""
+ resp = self.client.chat.completions.create(
+ model=self.MODEL,
+ messages=[{"role": "user", "content": "Say hello"}],
+ max_tokens=20,
+ )
+ self.assertIsNotNone(resp.choices[0].message.content)
+ self.assertTrue(len(resp.choices[0].message.content) > 0)
+
+ def test_non_streaming_response_json_format(self):
+ """Non-streaming CB responses return proper JSON objects, not double-encoded strings."""
+ response = self.client.chat.completions.create(
+ model=self.MODEL,
+ messages=[{"role": "user", "content": "Say hello"}],
+ stream=False,
+ max_tokens=5,
+ )
+ self.assertIsNotNone(response)
+ self.assertIsNotNone(response.id)
+ self.assertIsNotNone(response.choices)
+ self.assertEqual(len(response.choices), 1)
+
+ choice = response.choices[0]
+ self.assertIsNotNone(choice.message)
+ self.assertIsNotNone(choice.message.content)
+ self.assertEqual(choice.message.role, "assistant")
+ self.assertIsInstance(choice.message.content, str)
+
+ def test_multi_turn(self):
+ """Multi-turn conversation works with CB."""
+ resp = self.client.chat.completions.create(
+ model=self.MODEL,
+ messages=[
+ {"role": "user", "content": "My name is Alice"},
+ {"role": "assistant", "content": "Nice to meet you!"},
+ {"role": "user", "content": "What is my name?"},
+ ],
+ max_tokens=20,
+ )
+ self.assertIn("Alice", resp.choices[0].message.content)
+
+ def test_request_cancellation(self):
+ """Opening a stream and closing it early triggers CB cancellation."""
+
+ request_id = "test-cb-cancel"
+
+ # Open a streaming request and close after a few chunks
+ with httpx.stream(
+ "POST",
+ f"{self.base_url}/v1/chat/completions",
+ headers={"X-Request-ID": request_id},
+ json={
+ "model": self.MODEL,
+ "stream": True,
+ "messages": [{"role": "user", "content": "Count slowly so I can cancel you."}],
+ },
+ timeout=30,
+ ) as resp:
+ self.assertEqual(resp.status_code, 200)
+ chunks_read = 0
+ for _ in resp.iter_lines():
+ chunks_read += 1
+ if chunks_read >= 3:
+ break
+
+ # Poll for cancellation in the CB scheduler
+ scheduler = self.serve._generation_state._cb_manager.scheduler
+ deadline = time.time() + 8.0
+ while time.time() < deadline:
+ if scheduler.request_is_cancelled(request_id):
+ break
+ time.sleep(0.1)
+
+ self.assertTrue(
+ scheduler.request_is_cancelled(request_id),
+ f"Request {request_id} not cancelled in scheduler after stream close.",
+ )
+
+ # Server should still be healthy and serve subsequent requests
+ resp = self.client.chat.completions.create(
+ model=self.MODEL,
+ messages=[{"role": "user", "content": "Say hi"}],
+ max_tokens=10,
+ )
+ self.assertIsNotNone(resp.choices[0].message.content)
+
+
+@slow
+@require_serve
+@require_torch_accelerator
+class TestContinuousBatchingResponses(unittest.TestCase):
+ """Integration tests for /v1/responses with continuous batching."""
+
+ MODEL = "Qwen/Qwen2.5-0.5B-Instruct"
+
+ @classmethod
+ def setUpClass(cls):
+ cls.serve, port = _start_serve(
+ force_model=cls.MODEL,
+ device="cuda:0",
+ continuous_batching=True,
+ attn_implementation="sdpa",
+ default_seed=42,
+ )
+ cls.base_url = f"http://localhost:{port}"
+ cls.client = OpenAI(base_url=f"{cls.base_url}/v1", api_key="unused")
+
+ @classmethod
+ def tearDownClass(cls):
+ cls.serve.kill_server()
+
+ def test_streaming(self):
+ """Streaming response with CB produces text."""
+ text = ""
+ stream = self.client.responses.create(
+ model=self.MODEL,
+ input="Say hello in one sentence.",
+ stream=True,
+ max_output_tokens=30,
+ )
+ for event in stream:
+ if event.type == "response.output_text.delta":
+ text += event.delta
+ self.assertTrue(len(text) > 0)
+
+ def test_non_streaming(self):
+ """Non-streaming response with CB returns text."""
+ resp = self.client.responses.create(
+ model=self.MODEL,
+ input="Say hello in one sentence.",
+ stream=False,
+ max_output_tokens=30,
+ )
+ content = resp.output[0].content[0].text
+ self.assertTrue(len(content) > 0)
+
+ def test_multi_turn(self):
+ """Multi-turn conversation works with CB via Responses API."""
+ resp = self.client.responses.create(
+ model=self.MODEL,
+ input=[
+ {"role": "user", "content": "My name is Alice"},
+ {"role": "assistant", "content": "Nice to meet you!"},
+ {"role": "user", "content": "What is my name?"},
+ ],
+ stream=False,
+ max_output_tokens=20,
+ )
+ content = resp.output[0].content[0].text
+ self.assertIn("Alice", content)
+
+ def test_request_cancellation(self):
+ """Opening a stream and closing it early triggers CB cancellation."""
+
+ request_id = "test-cb-resp-cancel"
+
+ with httpx.stream(
+ "POST",
+ f"{self.base_url}/v1/responses",
+ headers={"X-Request-ID": request_id},
+ json={
+ "model": self.MODEL,
+ "stream": True,
+ "input": "Count slowly so I can cancel you.",
+ "max_output_tokens": 500,
+ },
+ timeout=30,
+ ) as resp:
+ self.assertEqual(resp.status_code, 200)
+ # Read enough data to ensure CB generation has started, then close.
+ received = b""
+ for chunk in resp.iter_bytes(chunk_size=512):
+ received += chunk
+ if b"output_text.delta" in received:
+ break
+
+ # Poll for cancellation in the CB scheduler
+ scheduler = self.serve._generation_state._cb_manager.scheduler
+ deadline = time.time() + 8.0
+ while time.time() < deadline:
+ if scheduler.request_is_cancelled(request_id):
+ break
+ time.sleep(0.1)
+
+ self.assertTrue(
+ scheduler.request_is_cancelled(request_id),
+ f"Request {request_id} not cancelled in scheduler after stream close.",
+ )
+
+ # Server should still serve subsequent requests
+ resp = self.client.responses.create(
+ model=self.MODEL,
+ input="Say hi",
+ stream=False,
+ max_output_tokens=10,
+ )
+ self.assertTrue(len(resp.output[0].content[0].text) > 0)