From ee4c59821747f41be4ebbf5c69fce30438d2337f Mon Sep 17 00:00:00 2001 From: abhiprd200 Date: Fri, 24 Apr 2026 23:04:33 +0530 Subject: [PATCH 1/5] Fix NameError in serving CLI due to conditional import asymmetry The serving module conditionally imports OpenAI types behind is_serve_available(), but unconditionally inherits from them in the global scope. This causes a fatal NameError when the server boots without the serving extras installed. This patch provides dummy fallback types to allow the CLI classes to initialize safely. --- .../cli/serving/chat_completion.py | 36 ++++++++++++++----- src/transformers/cli/serving/completion.py | 19 ++++++---- src/transformers/cli/serving/transcription.py | 14 ++++++-- 3 files changed, 52 insertions(+), 17 deletions(-) diff --git a/src/transformers/cli/serving/chat_completion.py b/src/transformers/cli/serving/chat_completion.py index 161a25a02f41..31f58bf629ed 100644 --- a/src/transformers/cli/serving/chat_completion.py +++ b/src/transformers/cli/serving/chat_completion.py @@ -26,7 +26,8 @@ from ...utils.import_utils import is_serve_available -if is_serve_available(): +# --- BRUTE FORCE IMPORT PATCH --- +try: from fastapi.responses import JSONResponse, StreamingResponse from openai.types.chat import ChatCompletion, ChatCompletionMessage, ChatCompletionMessageToolCall from openai.types.chat.chat_completion import Choice @@ -34,8 +35,32 @@ from openai.types.chat.chat_completion_chunk import Choice as ChoiceChunk from openai.types.chat.completion_create_params import CompletionCreateParamsStreaming from openai.types.completion_usage import CompletionUsage - - + + parent_class = CompletionCreateParamsStreaming +except ImportError: + from typing import TypedDict + + class _DummyDict(dict): + def __getattr__(self, name): return None + def __setattr__(self, name, value): self[name] = value + + class ChatCompletion(_DummyDict): pass + class ChatCompletionMessage(_DummyDict): pass + class ChatCompletionMessageToolCall(_DummyDict): pass + class Choice(_DummyDict): pass + class ChatCompletionChunk(_DummyDict): pass + class ChoiceDelta(_DummyDict): pass + class ChoiceDeltaToolCall(_DummyDict): pass + class ChoiceChunk(_DummyDict): pass + class CompletionCreateParamsStreaming(_DummyDict): pass + class CompletionUsage(_DummyDict): pass + + parent_class = TypedDict + +class TransformersCompletionCreateParamsStreaming(parent_class, total=False): + generation_config: str + seed: int +# --- END PATCH --- from .utils import ( BaseGenerateManager, BaseHandler, @@ -50,11 +75,6 @@ from transformers import GenerationConfig, PreTrainedModel, PreTrainedTokenizerFast, ProcessorMixin -class TransformersCompletionCreateParamsStreaming(CompletionCreateParamsStreaming, total=False): - generation_config: str - seed: int - - # Fields accepted by the OpenAI schema but not yet supported. # Receiving these raises an error to avoid silent misbehaviour. # NOTE: "stop" is NOT in this set — we map it to stop_strings. 
diff --git a/src/transformers/cli/serving/completion.py b/src/transformers/cli/serving/completion.py index 52c1f1b8471d..0cd40b3e9669 100644 --- a/src/transformers/cli/serving/completion.py +++ b/src/transformers/cli/serving/completion.py @@ -22,7 +22,7 @@ import asyncio import time from collections.abc import AsyncGenerator -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, TypedDict from ...utils import logging from ...utils.import_utils import is_serve_available @@ -42,11 +42,18 @@ from transformers import GenerationConfig, PreTrainedModel, PreTrainedTokenizerFast, ProcessorMixin -class TransformersTextCompletionCreateParams(CompletionCreateParamsBase, total=False): - generation_config: str - seed: int - stream: bool - +# --- FINAL ROBUST PATCH --- +if "CompletionCreateParamsBase" in globals(): + # If the real OpenAI class was successfully imported, use it + class TransformersTextCompletionCreateParams(CompletionCreateParamsBase, total=False): + generation_config: str + seed: int +else: + # Fallback to standard TypedDict if OpenAI types are missing + class TransformersTextCompletionCreateParams(TypedDict, total=False): + generation_config: str + seed: int +# --- END PATCH --- # Fields accepted by the OpenAI schema but not yet supported. UNUSED_LEGACY_COMPLETION_FIELDS = { diff --git a/src/transformers/cli/serving/transcription.py b/src/transformers/cli/serving/transcription.py index 5865dc77029f..d6730f1092e0 100644 --- a/src/transformers/cli/serving/transcription.py +++ b/src/transformers/cli/serving/transcription.py @@ -16,7 +16,7 @@ """ import io -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, TypedDict from ...utils import logging from ...utils.import_utils import is_serve_available @@ -38,8 +38,16 @@ logger = logging.get_logger(__name__) -class TransformersTranscriptionCreateParams(TranscriptionCreateParamsBase, total=False): - stream: bool +# --- FINAL ROBUST PATCH --- +if "TranscriptionCreateParamsBase" in globals(): + class TransformersTranscriptionCreateParams(TranscriptionCreateParamsBase, total=False): + generation_config: str + seed: int +else: + class TransformersTranscriptionCreateParams(TypedDict, total=False): + generation_config: str + seed: int +# --- END PATCH --- UNUSED_TRANSCRIPTION_FIELDS = { From 1e3504bc4db7b027cc47feae9db118f3f39f6238 Mon Sep 17 00:00:00 2001 From: abhiprd200 Date: Fri, 24 Apr 2026 23:11:47 +0530 Subject: [PATCH 2/5] Fix NameError in serving CLI due to conditional import asymmetry The serving module conditionally imports OpenAI types behind is_serve_available(), but unconditionally inherits from them in the global scope. This causes a fatal NameError when the server boots without the serving extras installed. This patch provides dummy fallback types to allow the CLI classes to initialize safely. 
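For context, here is a minimal self-contained sketch of the failure mode
and of the fallback pattern this series applies (the names `Params` and
`_Base` are illustrative placeholders, not identifiers from the patch):

    # Before: the import is guarded, but the subclass definition below
    # still executes at import time, so the module dies with a NameError
    # whenever the serving extras are absent.
    #
    #     if is_serve_available():
    #         from openai.types.completion_create_params import CompletionCreateParamsBase
    #
    #     class Params(CompletionCreateParamsBase, total=False):  # NameError
    #         seed: int
    #
    # After: always bind a usable base class before subclassing.
    from typing import TypedDict

    try:
        from openai.types.completion_create_params import CompletionCreateParamsBase as _Base
    except ImportError:  # serving extras not installed
        _Base = TypedDict  # fallback keeps the module importable

    class Params(_Base, total=False):  # valid with or without openai
        generation_config: str
        seed: int

With the fallback in place the module always imports; actually serving
requests still requires the extras, the dummy types only keep the CLI
importable.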
--- src/transformers/cli/serving/response.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/src/transformers/cli/serving/response.py b/src/transformers/cli/serving/response.py index 4d29dfd1d6a2..4ac93660c89a 100644 --- a/src/transformers/cli/serving/response.py +++ b/src/transformers/cli/serving/response.py @@ -20,7 +20,7 @@ import asyncio import time from collections.abc import AsyncGenerator -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, TypedDict from ...utils import logging from ...utils.import_utils import is_serve_available @@ -69,10 +69,16 @@ logger = logging.get_logger(__name__) -class TransformersResponseCreateParamsStreaming(ResponseCreateParamsStreaming, total=False): - generation_config: str - seed: int - +# --- FINAL ROBUST PATCH --- +if "ResponseCreateParamsStreaming" in globals(): + class TransformersResponseCreateParamsStreaming(ResponseCreateParamsStreaming, total=False): + generation_config: str + seed: int +else: + class TransformersResponseCreateParamsStreaming(TypedDict, total=False): + generation_config: str + seed: int +# --- END PATCH --- UNUSED_RESPONSE_FIELDS = { "background", From 662508f6be3d57b5c7f07a115683e386eb0a7c34 Mon Sep 17 00:00:00 2001 From: abhiprd200 Date: Sat, 25 Apr 2026 19:37:13 +0530 Subject: [PATCH 3/5] style: fix formatting and linting across all serving files --- .../cli/serving/chat_completion.py | 147 +++++++++++++----- src/transformers/cli/serving/completion.py | 81 ++++++++-- src/transformers/cli/serving/model_manager.py | 89 +++++++++-- src/transformers/cli/serving/response.py | 125 +++++++++++---- src/transformers/cli/serving/server.py | 11 +- src/transformers/cli/serving/transcription.py | 89 ++++++++--- src/transformers/cli/serving/utils.py | 129 +++++++++++---- 7 files changed, 513 insertions(+), 158 deletions(-) diff --git a/src/transformers/cli/serving/chat_completion.py b/src/transformers/cli/serving/chat_completion.py index 31f58bf629ed..97e3b3597b2a 100644 --- a/src/transformers/cli/serving/chat_completion.py +++ b/src/transformers/cli/serving/chat_completion.py @@ -23,7 +23,7 @@ from typing import TYPE_CHECKING from ...utils import logging -from ...utils.import_utils import is_serve_available +from .utils import BaseGenerateManager, BaseHandler, Modality, _StreamError, get_tool_call_config, parse_tool_calls # --- BRUTE FORCE IMPORT PATCH --- @@ -35,40 +35,57 @@ from openai.types.chat.chat_completion_chunk import Choice as ChoiceChunk from openai.types.chat.completion_create_params import CompletionCreateParamsStreaming from openai.types.completion_usage import CompletionUsage - + parent_class = CompletionCreateParamsStreaming except ImportError: from typing import TypedDict - + class _DummyDict(dict): - def __getattr__(self, name): return None - def __setattr__(self, name, value): self[name] = value - - class ChatCompletion(_DummyDict): pass - class ChatCompletionMessage(_DummyDict): pass - class ChatCompletionMessageToolCall(_DummyDict): pass - class Choice(_DummyDict): pass - class ChatCompletionChunk(_DummyDict): pass - class ChoiceDelta(_DummyDict): pass - class ChoiceDeltaToolCall(_DummyDict): pass - class ChoiceChunk(_DummyDict): pass - class CompletionCreateParamsStreaming(_DummyDict): pass - class CompletionUsage(_DummyDict): pass - + def __getattr__(self, name): + return None + + def __setattr__(self, name, value): + self[name] = value + + class ChatCompletion(_DummyDict): + pass + + class ChatCompletionMessage(_DummyDict): + pass + + class 
ChatCompletionMessageToolCall(_DummyDict): + pass + + class Choice(_DummyDict): + pass + + class ChatCompletionChunk(_DummyDict): + pass + + class ChoiceDelta(_DummyDict): + pass + + class ChoiceDeltaToolCall(_DummyDict): + pass + + class ChoiceChunk(_DummyDict): + pass + + class CompletionCreateParamsStreaming(_DummyDict): + pass + + class CompletionUsage(_DummyDict): + pass + parent_class = TypedDict + class TransformersCompletionCreateParamsStreaming(parent_class, total=False): generation_config: str seed: int + + # --- END PATCH --- -from .utils import ( - BaseGenerateManager, - BaseHandler, - Modality, - _StreamError, - get_tool_call_config, - parse_tool_calls, -) if TYPE_CHECKING: @@ -114,7 +131,9 @@ class ChatCompletionHandler(BaseHandler): _valid_params_class = TransformersCompletionCreateParamsStreaming _unused_fields = UNUSED_CHAT_COMPLETION_FIELDS - async def handle_request(self, body: dict, request_id: str) -> StreamingResponse | JSONResponse: + async def handle_request( + self, body: dict, request_id: str + ) -> StreamingResponse | JSONResponse: """Validate the request, load the model, and dispatch to streaming or non-streaming. Args: @@ -131,12 +150,16 @@ async def handle_request(self, body: dict, request_id: str) -> StreamingResponse use_cb = self.generation_state.use_continuous_batching(model, modality) logger.warning(f"[Request received] Model: {model_id}, CB: {use_cb}") gen_manager = self.generation_state.get_manager(model_id, use_cb=use_cb) - processor_inputs = self.get_processor_inputs_from_messages(body["messages"], modality) + processor_inputs = self.get_processor_inputs_from_messages( + body["messages"], modality + ) has_video = any( c.get("type") == "video" for msg in processor_inputs - for c in (msg.get("content") if isinstance(msg.get("content"), list) else []) + for c in ( + msg.get("content") if isinstance(msg.get("content"), list) else [] + ) ) # Default to 32 frames for video (Gemma 4 default); some processors load all frames otherwise chat_template_kwargs = {} @@ -155,12 +178,16 @@ async def handle_request(self, body: dict, request_id: str) -> StreamingResponse if not use_cb: inputs = inputs.to(model.device) # type: ignore[union-attr] - gen_config = self._build_generation_config(body, model.generation_config, use_cb=use_cb) + gen_config = self._build_generation_config( + body, model.generation_config, use_cb=use_cb + ) # TODO: remove when CB supports per-request generation config if use_cb: gen_manager.init_cb(model, gen_config) - tool_config = get_tool_call_config(processor, model) if body.get("tools") else None + tool_config = ( + get_tool_call_config(processor, model) if body.get("tools") else None + ) streaming = body.get("stream") if streaming: @@ -210,11 +237,15 @@ def _streaming( ) input_ids = inputs["input_ids"] # CB returns plain lists, regular path returns tensors - input_len = len(input_ids) if isinstance(input_ids, list) else input_ids.shape[-1] + input_len = ( + len(input_ids) if isinstance(input_ids, list) else input_ids.shape[-1] + ) async def sse_gen() -> AsyncGenerator[str, None]: try: - yield self._build_chunk_sse(request_id, role="assistant", model=model_id) + yield self._build_chunk_sse( + request_id, role="assistant", model=model_id + ) done = False while not done: @@ -236,7 +267,11 @@ async def sse_gen() -> AsyncGenerator[str, None]: yield "".join(sse_parts) return - sse_parts.append(self._build_chunk_sse(request_id, model=model_id, content=text)) + sse_parts.append( + self._build_chunk_sse( + request_id, model=model_id, content=text + ) 
+ ) if sse_parts: yield "".join(sse_parts) @@ -245,7 +280,9 @@ async def sse_gen() -> AsyncGenerator[str, None]: # because the full token sequence is needed for reliable parsing. has_tool_calls = False if tool_config: - parsed = parse_tool_calls(processor, streamer.generated_token_ids, tool_config["schema"]) + parsed = parse_tool_calls( + processor, streamer.generated_token_ids, tool_config["schema"] + ) if parsed: has_tool_calls = True for i, tc in enumerate(parsed): @@ -257,12 +294,18 @@ async def sse_gen() -> AsyncGenerator[str, None]: index=i, type="function", id=f"{request_id}_tool_call_{i}", - function={"name": tc["name"], "arguments": tc["arguments"]}, + function={ + "name": tc["name"], + "arguments": tc["arguments"], + }, ) ], ) - hit_max = gen_config.max_new_tokens is not None and streamer.total_tokens >= gen_config.max_new_tokens + hit_max = ( + gen_config.max_new_tokens is not None + and streamer.total_tokens >= gen_config.max_new_tokens + ) if has_tool_calls: finish_reason = "tool_calls" elif hit_max: @@ -306,7 +349,10 @@ async def _non_streaming( model, processor, inputs, gen_config, request_id=request_id ) - hit_max = gen_config.max_new_tokens is not None and len(generated_ids) >= gen_config.max_new_tokens + hit_max = ( + gen_config.max_new_tokens is not None + and len(generated_ids) >= gen_config.max_new_tokens + ) completion_tokens = len(generated_ids) usage = CompletionUsage( prompt_tokens=input_len, @@ -348,17 +394,28 @@ async def _non_streaming( # ----- helpers ----- - def _build_generation_config(self, body: dict, model_generation_config: "GenerationConfig", use_cb: bool = False): + def _build_generation_config( + self, + body: dict, + model_generation_config: "GenerationConfig", + use_cb: bool = False, + ): """Apply Chat Completions params (``max_tokens``, ``frequency_penalty``, ``logit_bias``, ``stop``) on top of the base generation config.""" - generation_config = super()._build_generation_config(body, model_generation_config, use_cb=use_cb) + generation_config = super()._build_generation_config( + body, model_generation_config, use_cb=use_cb + ) if body.get("max_tokens") is not None: generation_config.max_new_tokens = int(body["max_tokens"]) if body.get("frequency_penalty") is not None: - generation_config.repetition_penalty = 1.0 + float(body["frequency_penalty"]) + generation_config.repetition_penalty = 1.0 + float( + body["frequency_penalty"] + ) if body.get("logit_bias") is not None: - generation_config.sequence_bias = {(int(k),): v for k, v in body["logit_bias"].items()} + generation_config.sequence_bias = { + (int(k),): v for k, v in body["logit_bias"].items() + } if body.get("stop") is not None: generation_config.stop_strings = body["stop"] @@ -388,7 +445,9 @@ def _build_completion( Returns: `dict`: Serialized ``ChatCompletion`` ready for JSON response. 
""" - message = ChatCompletionMessage(content=content, role="assistant", tool_calls=tool_calls) + message = ChatCompletionMessage( + content=content, role="assistant", tool_calls=tool_calls + ) result = ChatCompletion( id=request_id, created=int(time.time()), @@ -435,7 +494,9 @@ def _build_chunk_sse( model=model, choices=[ ChoiceChunk( - delta=ChoiceDelta(content=content, role=role, tool_calls=tool_calls), + delta=ChoiceDelta( + content=content, role=role, tool_calls=tool_calls + ), index=0, finish_reason=finish_reason, ) diff --git a/src/transformers/cli/serving/completion.py b/src/transformers/cli/serving/completion.py index 0cd40b3e9669..d37b3103324c 100644 --- a/src/transformers/cli/serving/completion.py +++ b/src/transformers/cli/serving/completion.py @@ -34,7 +34,6 @@ from openai.types import Completion, CompletionChoice, CompletionUsage from openai.types.completion_create_params import CompletionCreateParamsBase - from .utils import BaseGenerateManager, BaseHandler, _StreamError @@ -45,14 +44,19 @@ # --- FINAL ROBUST PATCH --- if "CompletionCreateParamsBase" in globals(): # If the real OpenAI class was successfully imported, use it - class TransformersTextCompletionCreateParams(CompletionCreateParamsBase, total=False): + class TransformersTextCompletionCreateParams( + CompletionCreateParamsBase, total=False + ): generation_config: str seed: int + else: # Fallback to standard TypedDict if OpenAI types are missing class TransformersTextCompletionCreateParams(TypedDict, total=False): generation_config: str seed: int + + # --- END PATCH --- # Fields accepted by the OpenAI schema but not yet supported. @@ -81,7 +85,9 @@ class CompletionHandler(BaseHandler): _valid_params_class = TransformersTextCompletionCreateParams _unused_fields = UNUSED_LEGACY_COMPLETION_FIELDS - async def handle_request(self, body: dict, request_id: str) -> "StreamingResponse | JSONResponse": + async def handle_request( + self, body: dict, request_id: str + ) -> "StreamingResponse | JSONResponse": """Validate the request, load the model, and dispatch to streaming or non-streaming. 
Args: @@ -108,7 +114,9 @@ async def handle_request(self, body: dict, request_id: str) -> "StreamingRespons if not use_cb: inputs = inputs.to(model.device) - gen_config = self._build_generation_config(body, model.generation_config, use_cb=use_cb) + gen_config = self._build_generation_config( + body, model.generation_config, use_cb=use_cb + ) if use_cb: gen_manager.init_cb(model, gen_config) @@ -116,10 +124,26 @@ async def handle_request(self, body: dict, request_id: str) -> "StreamingRespons streaming = body.get("stream") if streaming: - return self._streaming(request_id, model, processor, model_id, inputs, gen_config, gen_manager, suffix) + return self._streaming( + request_id, + model, + processor, + model_id, + inputs, + gen_config, + gen_manager, + suffix, + ) else: return await self._non_streaming( - request_id, model, processor, model_id, inputs, gen_config, gen_manager, suffix + request_id, + model, + processor, + model_id, + inputs, + gen_config, + gen_manager, + suffix, ) # ----- streaming ----- @@ -136,9 +160,13 @@ def _streaming( suffix: str | None = None, ) -> "StreamingResponse": """Stream tokens as SSE.""" - queue, streamer = gen_manager.generate_streaming(model, processor, inputs, gen_config, request_id=request_id) + queue, streamer = gen_manager.generate_streaming( + model, processor, inputs, gen_config, request_id=request_id + ) input_ids = inputs["input_ids"] - input_len = len(input_ids) if isinstance(input_ids, list) else input_ids.shape[-1] + input_len = ( + len(input_ids) if isinstance(input_ids, list) else input_ids.shape[-1] + ) async def sse_gen() -> AsyncGenerator[str, None]: try: @@ -162,12 +190,17 @@ async def sse_gen() -> AsyncGenerator[str, None]: yield "".join(sse_parts) return - sse_parts.append(self._build_chunk_sse(request_id, model_id, text=text)) + sse_parts.append( + self._build_chunk_sse(request_id, model_id, text=text) + ) if sse_parts: yield "".join(sse_parts) - hit_max = gen_config.max_new_tokens is not None and streamer.total_tokens >= gen_config.max_new_tokens + hit_max = ( + gen_config.max_new_tokens is not None + and streamer.total_tokens >= gen_config.max_new_tokens + ) finish_reason = "length" if hit_max else "stop" if suffix is not None: @@ -177,7 +210,9 @@ async def sse_gen() -> AsyncGenerator[str, None]: completion_tokens=streamer.total_tokens, total_tokens=input_len + streamer.total_tokens, ) - yield self._build_chunk_sse(request_id, model_id, finish_reason=finish_reason, usage=usage) + yield self._build_chunk_sse( + request_id, model_id, finish_reason=finish_reason, usage=usage + ) except (GeneratorExit, asyncio.CancelledError): streamer.cancel() raise @@ -206,7 +241,10 @@ async def _non_streaming( text = text + suffix completion_tokens = len(generated_ids) - hit_max = gen_config.max_new_tokens is not None and completion_tokens >= gen_config.max_new_tokens + hit_max = ( + gen_config.max_new_tokens is not None + and completion_tokens >= gen_config.max_new_tokens + ) finish_reason = "length" if hit_max else "stop" usage = CompletionUsage( @@ -231,7 +269,9 @@ async def _non_streaming( usage=usage, ) - return JSONResponse(result.model_dump(exclude_none=True), media_type="application/json") + return JSONResponse( + result.model_dump(exclude_none=True), media_type="application/json" + ) # ----- helpers ----- @@ -268,14 +308,23 @@ def _build_chunk_sse( # ----- generation config ----- - def _build_generation_config(self, body: dict, model_generation_config: "GenerationConfig", use_cb: bool = False): + def _build_generation_config( + self, + body: 
dict, + model_generation_config: "GenerationConfig", + use_cb: bool = False, + ): """Apply legacy completion params (``max_tokens``, ``frequency_penalty``, ``stop``) on top of base config.""" - generation_config = super()._build_generation_config(body, model_generation_config, use_cb=use_cb) + generation_config = super()._build_generation_config( + body, model_generation_config, use_cb=use_cb + ) if body.get("max_tokens") is not None: generation_config.max_new_tokens = int(body["max_tokens"]) if body.get("frequency_penalty") is not None: - generation_config.repetition_penalty = 1.0 + float(body["frequency_penalty"]) + generation_config.repetition_penalty = 1.0 + float( + body["frequency_penalty"] + ) if body.get("stop") is not None: generation_config.stop_strings = body["stop"] diff --git a/src/transformers/cli/serving/model_manager.py b/src/transformers/cli/serving/model_manager.py index 826199ee4b01..07004ea9393d 100644 --- a/src/transformers/cli/serving/model_manager.py +++ b/src/transformers/cli/serving/model_manager.py @@ -87,7 +87,9 @@ def delete_model(self) -> None: def _timeout_reached(self) -> None: if self.timeout_seconds > 0: self.delete_model() - logger.info(f"{self._name_or_path} was removed from memory after {self.timeout_seconds}s of inactivity") + logger.info( + f"{self._name_or_path} was removed from memory after {self.timeout_seconds}s of inactivity" + ) class ModelManager: @@ -159,13 +161,23 @@ def _resolve_dtype(dtype: str | None): return resolved def _validate_args(self): - if self.quantization is not None and self.quantization not in ("bnb-4bit", "bnb-8bit"): + if self.quantization is not None and self.quantization not in ( + "bnb-4bit", + "bnb-8bit", + ): raise ValueError( f"Unsupported quantization method: '{self.quantization}'. Must be 'bnb-4bit' or 'bnb-8bit'." ) - VALID_ATTN_IMPLEMENTATIONS = {"eager", "sdpa", "flash_attention_2", "flash_attention_3", "flex_attention"} - is_kernels_community = self.attn_implementation is not None and self.attn_implementation.startswith( - "kernels-community/" + VALID_ATTN_IMPLEMENTATIONS = { + "eager", + "sdpa", + "flash_attention_2", + "flash_attention_3", + "flex_attention", + } + is_kernels_community = ( + self.attn_implementation is not None + and self.attn_implementation.startswith("kernels-community/") ) if ( self.attn_implementation is not None @@ -196,7 +208,9 @@ def get_quantization_config(self) -> BitsAndBytesConfig | None: return BitsAndBytesConfig(load_in_8bit=True) return None - def _load_processor(self, model_id_and_revision: str) -> "ProcessorMixin | PreTrainedTokenizerFast": + def _load_processor( + self, model_id_and_revision: str + ) -> "ProcessorMixin | PreTrainedTokenizerFast": """Load a processor for the given model. Args: @@ -205,10 +219,15 @@ def _load_processor(self, model_id_and_revision: str) -> "ProcessorMixin | PreTr from transformers import AutoProcessor model_id, revision = model_id_and_revision.split("@", 1) - return AutoProcessor.from_pretrained(model_id, revision=revision, trust_remote_code=self.trust_remote_code) + return AutoProcessor.from_pretrained( + model_id, revision=revision, trust_remote_code=self.trust_remote_code + ) def _load_model( - self, model_id_and_revision: str, tqdm_class: type | None = None, progress_callback: Callable | None = None + self, + model_id_and_revision: str, + tqdm_class: type | None = None, + progress_callback: Callable | None = None, ) -> "PreTrainedModel": """Load a model. 
@@ -235,7 +254,9 @@ def _load_model( } if progress_callback is not None: - progress_callback({"status": "loading", "model": model_id_and_revision, "stage": "config"}) + progress_callback( + {"status": "loading", "model": model_id_and_revision, "stage": "config"} + ) config = AutoConfig.from_pretrained(model_id, **model_kwargs) from transformers.models.auto.modeling_auto import MODEL_FOR_MULTIMODAL_LM_MAPPING_NAMES @@ -270,25 +291,47 @@ def load_model_and_processor( if model_id_and_revision not in self.loaded_models: logger.warning(f"Loading {model_id_and_revision}") if progress_callback is not None: - progress_callback({"status": "loading", "model": model_id_and_revision, "stage": "processor"}) + progress_callback( + { + "status": "loading", + "model": model_id_and_revision, + "stage": "processor", + } + ) processor = self._load_processor(model_id_and_revision) model = self._load_model( - model_id_and_revision, tqdm_class=tqdm_class, progress_callback=progress_callback + model_id_and_revision, + tqdm_class=tqdm_class, + progress_callback=progress_callback, ) self.loaded_models[model_id_and_revision] = TimedModel( model, timeout_seconds=self.model_timeout, processor=processor, - on_unload=lambda key=model_id_and_revision: self.loaded_models.pop(key, None), + on_unload=lambda key=model_id_and_revision: self.loaded_models.pop( + key, None + ), ) if progress_callback is not None: - progress_callback({"status": "ready", "model": model_id_and_revision, "cached": False}) + progress_callback( + { + "status": "ready", + "model": model_id_and_revision, + "cached": False, + } + ) else: self.loaded_models[model_id_and_revision].reset_timer() model = self.loaded_models[model_id_and_revision].model processor = self.loaded_models[model_id_and_revision].processor if progress_callback is not None: - progress_callback({"status": "ready", "model": model_id_and_revision, "cached": True}) + progress_callback( + { + "status": "ready", + "model": model_id_and_revision, + "cached": True, + } + ) return model, processor async def load_model_streaming(self, model_id_and_revision: str): @@ -384,7 +427,8 @@ def shutdown(self) -> None: @staticmethod def get_model_modality( - model: "PreTrainedModel", processor: "ProcessorMixin | PreTrainedTokenizerFast | None" = None + model: "PreTrainedModel", + processor: "ProcessorMixin | PreTrainedTokenizerFast | None" = None, ) -> Modality: """Detect whether a model is an LLM or VLM based on its architecture. 
@@ -441,7 +485,14 @@ def get_gen_models(cache_dir: str | None = None) -> list[dict]: continue for ref, revision_info in repo.refs.items(): - config_path = next((f.file_path for f in revision_info.files if f.file_name == "config.json"), None) + config_path = next( + ( + f.file_path + for f in revision_info.files + if f.file_name == "config.json" + ), + None, + ) if not config_path: continue @@ -454,7 +505,11 @@ def get_gen_models(cache_dir: str | None = None) -> list[dict]: vlms = MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.values() multimodal = MODEL_FOR_MULTIMODAL_LM_MAPPING_NAMES.values() - if any(arch for arch in architectures if arch in [*llms, *vlms, *multimodal]): + if any( + arch + for arch in architectures + if arch in [*llms, *vlms, *multimodal] + ): author = repo.repo_id.split("/") if "/" in repo.repo_id else "" repo_handle = repo.repo_id + (f"@{ref}" if ref != "main" else "") generative_models.append( diff --git a/src/transformers/cli/serving/response.py b/src/transformers/cli/serving/response.py index 4ac93660c89a..a5689bffdfb4 100644 --- a/src/transformers/cli/serving/response.py +++ b/src/transformers/cli/serving/response.py @@ -48,18 +48,16 @@ ResponseTextDeltaEvent, ResponseTextDoneEvent, ) - from openai.types.responses.response_create_params import ResponseCreateParamsStreaming - from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails, ResponseUsage - + from openai.types.responses.response_create_params import ( + ResponseCreateParamsStreaming, + ) + from openai.types.responses.response_usage import ( + InputTokensDetails, + OutputTokensDetails, + ResponseUsage, + ) -from .utils import ( - BaseGenerateManager, - BaseHandler, - Modality, - _StreamError, - get_tool_call_config, - parse_tool_calls, -) +from .utils import BaseGenerateManager, BaseHandler, Modality, _StreamError, get_tool_call_config, parse_tool_calls if TYPE_CHECKING: @@ -71,13 +69,20 @@ # --- FINAL ROBUST PATCH --- if "ResponseCreateParamsStreaming" in globals(): - class TransformersResponseCreateParamsStreaming(ResponseCreateParamsStreaming, total=False): + + class TransformersResponseCreateParamsStreaming( + ResponseCreateParamsStreaming, total=False + ): generation_config: str seed: int + else: + class TransformersResponseCreateParamsStreaming(TypedDict, total=False): generation_config: str seed: int + + # --- END PATCH --- UNUSED_RESPONSE_FIELDS = { @@ -103,7 +108,9 @@ class ResponseHandler(BaseHandler): _valid_params_class = TransformersResponseCreateParamsStreaming _unused_fields = UNUSED_RESPONSE_FIELDS - async def handle_request(self, body: dict, request_id: str) -> StreamingResponse | JSONResponse: + async def handle_request( + self, body: dict, request_id: str + ) -> StreamingResponse | JSONResponse: """Validate, load model, dispatch to streaming or non-streaming. 
Args: @@ -130,7 +137,9 @@ async def handle_request(self, body: dict, request_id: str) -> StreamingResponse has_video = any( c.get("type") == "video" for msg in processor_inputs - for c in (msg.get("content") if isinstance(msg.get("content"), list) else []) + for c in ( + msg.get("content") if isinstance(msg.get("content"), list) else [] + ) ) # Default to 32 frames for video (Gemma 4 default); some processors load all frames otherwise @@ -152,11 +161,15 @@ async def handle_request(self, body: dict, request_id: str) -> StreamingResponse if not use_cb: inputs = inputs.to(model.device) # type: ignore[union-attr] - gen_config = self._build_generation_config(body, model.generation_config, use_cb=use_cb) + gen_config = self._build_generation_config( + body, model.generation_config, use_cb=use_cb + ) # TODO: remove when CB supports per-request generation config if use_cb: gen_manager.init_cb(model, gen_config) - tool_config = get_tool_call_config(processor, model) if body.get("tools") else None + tool_config = ( + get_tool_call_config(processor, model) if body.get("tools") else None + ) streaming = body.get("stream", True) if streaming: @@ -198,7 +211,14 @@ def _normalize_tools(tools: list[dict] | None) -> list[dict] | None: if not tools: return tools return [ - {"type": "function", "function": {k: v for k, v in t.items() if k != "type"}} if "function" not in t else t + ( + { + "type": "function", + "function": {k: v for k, v in t.items() if k != "type"}, + } + if "function" not in t + else t + ) for t in tools ] @@ -236,7 +256,9 @@ def _normalize_input(body: dict) -> list[dict]: else: messages = ResponseHandler._normalize_response_items(inp) else: - raise HTTPException(status_code=422, detail="'input' must be a string or list") + raise HTTPException( + status_code=422, detail="'input' must be a string or list" + ) # Prepend instructions as a system message if instructions: @@ -262,7 +284,9 @@ def _normalize_response_items(items: list[dict]) -> list[dict]: item_type = item.get("type") if "role" in item: - messages.append({"role": item["role"], "content": item.get("content", "")}) + messages.append( + {"role": item["role"], "content": item.get("content", "")} + ) elif item_type == "function_call": tc = { @@ -284,7 +308,10 @@ def _normalize_response_items(items: list[dict]) -> list[dict]: ) else: - raise HTTPException(status_code=422, detail=f"Unsupported input item type: {item_type!r}") + raise HTTPException( + status_code=422, + detail=f"Unsupported input item type: {item_type!r}", + ) return messages @@ -313,7 +340,9 @@ def _streaming( ) input_ids = inputs["input_ids"] # CB returns plain lists, regular path returns tensors - input_len = len(input_ids) if isinstance(input_ids, list) else input_ids.shape[-1] + input_len = ( + len(input_ids) if isinstance(input_ids, list) else input_ids.shape[-1] + ) seq = 0 output_index = 0 @@ -349,7 +378,9 @@ async def event_stream() -> AsyncGenerator[str, None]: ResponseInProgressEvent( type="response.in_progress", sequence_number=seq, - response=Response(**response_base, status="in_progress", output=[]), + response=Response( + **response_base, status="in_progress", output=[] + ), ) ) seq += 1 @@ -379,7 +410,9 @@ async def event_stream() -> AsyncGenerator[str, None]: sequence_number=seq, output_index=output_index, content_index=0, - part=ResponseOutputText(type="output_text", text="", annotations=[]), + part=ResponseOutputText( + type="output_text", text="", annotations=[] + ), ) ) seq += 1 @@ -405,10 +438,16 @@ async def event_stream() -> AsyncGenerator[str, 
None]: done = True break if isinstance(text, _StreamError): - logger.error(f"Exception in response generation: {text.msg}") + logger.error( + f"Exception in response generation: {text.msg}" + ) sse_parts.append( self.chunk_to_sse( - ResponseErrorEvent(type="error", sequence_number=seq, message=text.msg) + ResponseErrorEvent( + type="error", + sequence_number=seq, + message=text.msg, + ) ) ) seq += 1 @@ -421,7 +460,9 @@ async def event_stream() -> AsyncGenerator[str, None]: **response_base, status="failed", output=[], - error=ResponseError(code="server_error", message=text.msg), + error=ResponseError( + code="server_error", message=text.msg + ), ), ) ) @@ -451,7 +492,9 @@ async def event_stream() -> AsyncGenerator[str, None]: # 5. Tool calls are parsed after generation completes (not during streaming), # because the full token sequence is needed for reliable parsing. if tool_config: - parsed = parse_tool_calls(processor, streamer.generated_token_ids, tool_config["schema"]) + parsed = parse_tool_calls( + processor, streamer.generated_token_ids, tool_config["schema"] + ) if parsed: for i, tc in enumerate(parsed): tc_id = f"{request_id}_tool_call_{i}" @@ -496,7 +539,9 @@ async def event_stream() -> AsyncGenerator[str, None]: seq += 1 # 6. Close text output - output_text_part = ResponseOutputText(type="output_text", text=full_text, annotations=[]) + output_text_part = ResponseOutputText( + type="output_text", text=full_text, annotations=[] + ) yield self.chunk_to_sse( ResponseTextDoneEvent( type="response.output_text.done", @@ -546,7 +591,12 @@ async def event_stream() -> AsyncGenerator[str, None]: ResponseCompletedEvent( type="response.completed", sequence_number=seq, - response=Response(**response_base, status="completed", output=all_output, usage=usage), + response=Response( + **response_base, + status="completed", + output=all_output, + usage=usage, + ), ) ) seq += 1 @@ -583,7 +633,11 @@ async def _non_streaming( type="message", status="completed", role="assistant", - content=[ResponseOutputText(type="output_text", text=full_text, annotations=[])], + content=[ + ResponseOutputText( + type="output_text", text=full_text, annotations=[] + ) + ], annotations=[], # type: ignore[call-arg] ) ] @@ -622,9 +676,16 @@ async def _non_streaming( # ----- helpers ----- - def _build_generation_config(self, body: dict, model_generation_config: "GenerationConfig", use_cb: bool = False): + def _build_generation_config( + self, + body: dict, + model_generation_config: "GenerationConfig", + use_cb: bool = False, + ): """Apply Responses API params (``max_output_tokens``) on top of the base generation config.""" - generation_config = super()._build_generation_config(body, model_generation_config, use_cb=use_cb) + generation_config = super()._build_generation_config( + body, model_generation_config, use_cb=use_cb + ) if body.get("max_output_tokens") is not None: generation_config.max_new_tokens = int(body["max_output_tokens"]) diff --git a/src/transformers/cli/serving/server.py b/src/transformers/cli/serving/server.py index 13a9565db590..1fe47c3e7296 100644 --- a/src/transformers/cli/serving/server.py +++ b/src/transformers/cli/serving/server.py @@ -73,7 +73,9 @@ async def lifespan(app: FastAPI): allow_methods=["*"], allow_headers=["*"], ) - logger.warning_once("CORS allow origin is set to `*`. Not recommended for production.") + logger.warning_once( + "CORS allow origin is set to `*`. Not recommended for production." 
+ ) # ---- Middleware ---- @@ -110,10 +112,13 @@ async def load_model(body: dict): model = body.get("model") if model is None: - raise HTTPException(status_code=422, detail="Missing `model` field in the request body.") + raise HTTPException( + status_code=422, detail="Missing `model` field in the request body." + ) model_id_and_revision = model_manager.process_model_name(model) return StreamingResponse( - model_manager.load_model_streaming(model_id_and_revision), media_type="text/event-stream" + model_manager.load_model_streaming(model_id_and_revision), + media_type="text/event-stream", ) @app.post("/reset") diff --git a/src/transformers/cli/serving/transcription.py b/src/transformers/cli/serving/transcription.py index d6730f1092e0..e69ef617b816 100644 --- a/src/transformers/cli/serving/transcription.py +++ b/src/transformers/cli/serving/transcription.py @@ -25,7 +25,9 @@ if is_serve_available(): from fastapi import HTTPException, Request from fastapi.responses import JSONResponse, StreamingResponse - from openai.types.audio.transcription_create_params import TranscriptionCreateParamsBase + from openai.types.audio.transcription_create_params import ( + TranscriptionCreateParamsBase, + ) from .model_manager import ModelManager from .utils import DirectStreamer, GenerateManager, GenerationState, _StreamError @@ -40,13 +42,20 @@ # --- FINAL ROBUST PATCH --- if "TranscriptionCreateParamsBase" in globals(): - class TransformersTranscriptionCreateParams(TranscriptionCreateParamsBase, total=False): + + class TransformersTranscriptionCreateParams( + TranscriptionCreateParamsBase, total=False + ): generation_config: str seed: int + else: + class TransformersTranscriptionCreateParams(TypedDict, total=False): generation_config: str seed: int + + # --- END PATCH --- @@ -83,14 +92,21 @@ def __init__(self, model_manager: ModelManager, generation_state: GenerationStat def _validate_request(self, form_keys: set[str]) -> None: """Validate transcription request fields.""" - unexpected = form_keys - getattr(TransformersTranscriptionCreateParams, "__mutable_keys__", set()) + unexpected = form_keys - getattr( + TransformersTranscriptionCreateParams, "__mutable_keys__", set() + ) if unexpected: - raise HTTPException(status_code=422, detail=f"Unexpected fields in the request: {unexpected}") + raise HTTPException( + status_code=422, + detail=f"Unexpected fields in the request: {unexpected}", + ) unused = form_keys & UNUSED_TRANSCRIPTION_FIELDS if unused: logger.warning_once(f"Ignoring unsupported fields in the request: {unused}") - async def handle_request(self, request: Request) -> JSONResponse | StreamingResponse: + async def handle_request( + self, request: Request + ) -> JSONResponse | StreamingResponse: """Parse multipart form, run transcription, return result. Args: @@ -103,7 +119,9 @@ async def handle_request(self, request: Request) -> JSONResponse | StreamingResp from transformers.utils.import_utils import is_librosa_available, is_multipart_available if not is_librosa_available(): - raise ImportError("Missing librosa dependency for audio transcription. Install with `pip install librosa`") + raise ImportError( + "Missing librosa dependency for audio transcription. Install with `pip install librosa`" + ) if not is_multipart_available(): raise ImportError( "Missing python-multipart dependency for file uploads. 
Install with `pip install python-multipart`" @@ -113,38 +131,59 @@ async def handle_request(self, request: Request) -> JSONResponse | StreamingResp self._validate_request(set(form.keys())) file_field = form["file"] if isinstance(file_field, str): - raise HTTPException(status_code=422, detail="Expected file upload, got string") + raise HTTPException( + status_code=422, detail="Expected file upload, got string" + ) file_bytes = await file_field.read() model = form["model"] if not isinstance(model, str): - raise HTTPException(status_code=422, detail="Expected model name as string") + raise HTTPException( + status_code=422, detail="Expected model name as string" + ) stream = str(form.get("stream", "false")).lower() == "true" model_id_and_revision = self.model_manager.process_model_name(model) - audio_model, audio_processor = self.model_manager.load_model_and_processor(model_id_and_revision) + audio_model, audio_processor = self.model_manager.load_model_and_processor( + model_id_and_revision + ) base_manager = self.generation_state.get_manager(model_id_and_revision) if not isinstance(base_manager, GenerateManager): - raise HTTPException(status_code=400, detail="Audio transcription requires sequential generation (not CB)") + raise HTTPException( + status_code=400, + detail="Audio transcription requires sequential generation (not CB)", + ) gen_manager = base_manager - audio_inputs = self._prepare_audio_inputs(file_bytes, audio_processor, audio_model) + audio_inputs = self._prepare_audio_inputs( + file_bytes, audio_processor, audio_model + ) if stream: - return self._streaming(gen_manager, audio_model, audio_processor, audio_inputs) - return await self._non_streaming(gen_manager, audio_model, audio_processor, audio_inputs) + return self._streaming( + gen_manager, audio_model, audio_processor, audio_inputs + ) + return await self._non_streaming( + gen_manager, audio_model, audio_processor, audio_inputs + ) @staticmethod def _prepare_audio_inputs( - file_bytes: bytes, audio_processor: "ProcessorMixin", audio_model: "PreTrainedModel" + file_bytes: bytes, + audio_processor: "ProcessorMixin", + audio_model: "PreTrainedModel", ) -> dict: """Load audio bytes and convert to model inputs.""" import librosa sampling_rate = audio_processor.feature_extractor.sampling_rate - audio_array, _ = librosa.load(io.BytesIO(file_bytes), sr=sampling_rate, mono=True) - audio_inputs = audio_processor(audio_array, sampling_rate=sampling_rate, return_tensors="pt").to( - audio_model.device + audio_array, _ = librosa.load( + io.BytesIO(file_bytes), sr=sampling_rate, mono=True + ) + audio_inputs = audio_processor( + audio_array, sampling_rate=sampling_rate, return_tensors="pt" + ).to(audio_model.device) + audio_inputs["input_features"] = audio_inputs["input_features"].to( + audio_model.dtype ) - audio_inputs["input_features"] = audio_inputs["input_features"].to(audio_model.dtype) return audio_inputs async def _non_streaming( @@ -159,7 +198,9 @@ async def _non_streaming( # generate_non_streaming() from openai.types.audio import Transcription - generated_ids = await gen_manager.async_submit(audio_model.generate, **audio_inputs) + generated_ids = await gen_manager.async_submit( + audio_model.generate, **audio_inputs + ) text = audio_processor.batch_decode(generated_ids, skip_special_tokens=True)[0] return JSONResponse(Transcription(text=text).model_dump(exclude_none=True)) @@ -174,10 +215,16 @@ def _streaming( # differ from text. 
import asyncio - tokenizer = audio_processor.tokenizer if hasattr(audio_processor, "tokenizer") else audio_processor + tokenizer = ( + audio_processor.tokenizer + if hasattr(audio_processor, "tokenizer") + else audio_processor + ) loop = asyncio.get_running_loop() queue: asyncio.Queue = asyncio.Queue() - streamer = DirectStreamer(tokenizer._tokenizer, loop, queue, skip_special_tokens=True) + streamer = DirectStreamer( + tokenizer._tokenizer, loop, queue, skip_special_tokens=True + ) gen_kwargs = {**audio_inputs, "streamer": streamer} def _run(): diff --git a/src/transformers/cli/serving/utils.py b/src/transformers/cli/serving/utils.py index d786a828fc28..caf44771cd2a 100644 --- a/src/transformers/cli/serving/utils.py +++ b/src/transformers/cli/serving/utils.py @@ -108,7 +108,14 @@ def get_tool_call_config(processor, model: "PreTrainedModel") -> dict | None: schema = response_schema["properties"]["tool_calls"] else: # Fallback: known model families without full tokenizer config - fallback = next((v for k, v in _TOOL_CALL_FALLBACKS.items() if k in model.config.model_type), None) + fallback = next( + ( + v + for k, v in _TOOL_CALL_FALLBACKS.items() + if k in model.config.model_type + ), + None, + ) if fallback is None: return None stc, etc, schema = fallback["stc"], fallback["etc"], fallback["schema"] @@ -131,7 +138,9 @@ def _normalize_tool_call(tool_call: dict) -> dict: arguments = function.get("arguments", {}) return { "name": function["name"], - "arguments": json.dumps(arguments) if not isinstance(arguments, str) else arguments, + "arguments": ( + json.dumps(arguments) if not isinstance(arguments, str) else arguments + ), } @@ -153,7 +162,7 @@ def parse_tool_calls(processor, generated_ids, schema: dict) -> list[dict] | Non if not isinstance(parsed, list): parsed = [parsed] tool_calls = [_normalize_tool_call(tool_call) for tool_call in parsed] - return tool_calls if tool_calls else None + return tool_calls or None class DownloadAggregator: @@ -330,7 +339,11 @@ def put(self, value: "torch.Tensor") -> None: self._inside_tool_call = False text = self._decode_stream.step(self._tokenizer, token_id) - if text is not None and not self._inside_tool_call and token_id != self._etc_id: + if ( + text is not None + and not self._inside_tool_call + and token_id != self._etc_id + ): self._loop.call_soon_threadsafe(self._queue.put_nowait, text) def end(self) -> None: @@ -398,7 +411,11 @@ def put(self, output: "GenerationOutput") -> None: self._inside_tool_call = False text = self._decode_stream.step(self._tokenizer, token_id) - if text is not None and not self._inside_tool_call and token_id != self._etc_id: + if ( + text is not None + and not self._inside_tool_call + and token_id != self._etc_id + ): self._queue.put_nowait(text) def end(self) -> None: @@ -552,7 +569,12 @@ def generate_streaming( # ProcessorMixin exposes the fast tokenizer as .tokenizer; PreTrainedTokenizerFast is already one. 
rust_tokenizer = getattr(processor, "tokenizer", processor)._tokenizer # type: ignore[union-attr] streamer = DirectStreamer(rust_tokenizer, loop, queue, tool_config=tool_config) - gen_kwargs = {**inputs, "streamer": streamer, "generation_config": gen_config, "tokenizer": processor} + gen_kwargs = { + **inputs, + "streamer": streamer, + "generation_config": gen_config, + "tokenizer": processor, + } if hasattr(model, "has_talker"): gen_kwargs["generation_mode"] = "text" @@ -578,7 +600,11 @@ async def generate_non_streaming( """Run generation to completion via ``model.generate()`` on the inference thread.""" # Multimodal models (e.g. Qwen2.5-Omni) may generate audio alongside text by default; # force text-only output since the serve layer only handles text - generate_kwargs = {**inputs, "generation_config": gen_config, "tokenizer": processor} + generate_kwargs = { + **inputs, + "generation_config": gen_config, + "tokenizer": processor, + } if hasattr(model, "has_talker"): generate_kwargs["generation_mode"] = "text" sequences = await self.async_submit(model.generate, **generate_kwargs) @@ -662,7 +688,14 @@ def generate_streaming( ) # ProcessorMixin exposes the fast tokenizer as .tokenizer; PreTrainedTokenizerFast is already one. rust_tokenizer = getattr(processor, "tokenizer", processor)._tokenizer # type: ignore[union-attr] - streamer = CBStreamer(self._cb, request_id, rust_tokenizer, loop, text_queue, tool_config=tool_config) + streamer = CBStreamer( + self._cb, + request_id, + rust_tokenizer, + loop, + text_queue, + tool_config=tool_config, + ) # Register a direct callback: the dispatcher calls this on the event loop with each GenerationOutput. # This decodes tokens and pushes text straight to the SSE text_queue @@ -712,7 +745,9 @@ def _on_result(result): ) result = await future if result is None: - raise RuntimeError(f"CB manager stopped before producing a result for {request_id}") + raise RuntimeError( + f"CB manager stopped before producing a result for {request_id}" + ) generated_ids = result.generated_tokens text = processor.decode(generated_ids, skip_special_tokens=True) return text, input_len, generated_ids @@ -756,7 +791,9 @@ def __init__( self._cb_manager: CBGenerateManager | None = None self._cb_model_id: str | None = None - def use_continuous_batching(self, model: "PreTrainedModel", modality: Modality) -> bool: + def use_continuous_batching( + self, model: "PreTrainedModel", modality: Modality + ) -> bool: """Check if continuous batching can be used for this model and modality. 
Args: @@ -836,9 +873,14 @@ def _validate_request(self, body: dict) -> None: input_keys = set(body.keys()) if self._valid_params_class is not None: - unexpected = input_keys - getattr(self._valid_params_class, "__mutable_keys__", set()) + unexpected = input_keys - getattr( + self._valid_params_class, "__mutable_keys__", set() + ) if unexpected: - raise HTTPException(status_code=422, detail=f"Unexpected fields in the request: {unexpected}") + raise HTTPException( + status_code=422, + detail=f"Unexpected fields in the request: {unexpected}", + ) unused = input_keys & self._unused_fields if unused: logger.warning_once(f"Ignoring unsupported fields in the request: {unused}") @@ -850,7 +892,9 @@ def chunk_to_sse(chunk: "str | pydantic.BaseModel") -> str: return chunk if chunk.startswith("data: ") else f"data: {chunk}\n\n" return f"data: {chunk.model_dump_json(exclude_none=True)}\n\n" - def _resolve_model(self, body: dict) -> tuple[str, "PreTrainedModel", "ProcessorMixin | PreTrainedTokenizerFast"]: + def _resolve_model( + self, body: dict + ) -> tuple[str, "PreTrainedModel", "ProcessorMixin | PreTrainedTokenizerFast"]: """Apply force_model, load model + processor. Returns ``(model_id, model, processor)``. @@ -862,7 +906,9 @@ def _resolve_model(self, body: dict) -> tuple[str, "PreTrainedModel", "Processor if requested is not None and requested != self.model_manager.force_model: raise HTTPException( status_code=400, - detail=(f"Server is pinned to '{self.model_manager.force_model}'; requested '{requested}'."), + detail=( + f"Server is pinned to '{self.model_manager.force_model}'; requested '{requested}'." + ), ) body["model"] = self.model_manager.force_model @@ -872,7 +918,10 @@ def _resolve_model(self, body: dict) -> tuple[str, "PreTrainedModel", "Processor return model_id, model, processor def _build_generation_config( - self, body: dict, model_generation_config: "GenerationConfig", use_cb: bool = False + self, + body: dict, + model_generation_config: "GenerationConfig", + use_cb: bool = False, ) -> "GenerationConfig": """Build a GenerationConfig from shared params (temperature, top_p, seed, generation_config JSON). 
@@ -894,10 +943,15 @@ def _build_generation_config( from transformers import GenerationConfig if body.get("generation_config") is not None: - generation_config = GenerationConfig(**json.loads(body["generation_config"])) + generation_config = GenerationConfig( + **json.loads(body["generation_config"]) + ) else: generation_config = copy.deepcopy(model_generation_config) - if generation_config.max_new_tokens is None or generation_config.max_new_tokens < 1024: + if ( + generation_config.max_new_tokens is None + or generation_config.max_new_tokens < 1024 + ): generation_config.max_new_tokens = 1024 if body.get("temperature") is not None: @@ -910,7 +964,10 @@ def _build_generation_config( set_torch_seed(body["seed"]) # --compile flag: use static cache + torch.compile for faster decode - if self.generation_state._compile and generation_config.cache_implementation is None: + if ( + self.generation_state._compile + and generation_config.cache_implementation is None + ): generation_config.cache_implementation = "static" # CB manages its own paged KV cache @@ -922,7 +979,9 @@ def _build_generation_config( return generation_config @staticmethod - def get_processor_inputs_from_messages(messages: list[dict], modality: Modality) -> list[dict]: + def get_processor_inputs_from_messages( + messages: list[dict], modality: Modality + ) -> list[dict]: """Convert OpenAI-format messages to the format expected by HF processors. All modalities extract text. VLM additionally handles ``image_url`` and ``video_url``. @@ -949,7 +1008,9 @@ def get_processor_inputs_from_messages(messages: list[dict], modality: Modality) # When tool_calls are present, ignore content — it's either empty or contains # raw tool call markup that would confuse the chat template if rendered. - raw_content = [] if "tool_calls" in message else (message.get("content") or []) + raw_content = ( + [] if "tool_calls" in message else (message.get("content") or []) + ) if isinstance(raw_content, str): raw_content = [{"type": "text", "text": raw_content}] @@ -959,7 +1020,10 @@ def get_processor_inputs_from_messages(messages: list[dict], modality: Modality) if content_type in ("text", "input_text", "output_text"): parsed["content"].append({"type": "text", "text": content["text"]}) # Image: chat completions ("image_url") and Responses API ("input_image") - elif content_type in ("image_url", "input_image") and modality in (Modality.VLM, Modality.MULTIMODAL): + elif content_type in ("image_url", "input_image") and modality in ( + Modality.VLM, + Modality.MULTIMODAL, + ): # chat completions: {"image_url": {"url": "..."}}, Responses API: {"image_url": "..."} url = content["image_url"] if isinstance(url, dict): @@ -968,14 +1032,27 @@ def get_processor_inputs_from_messages(messages: list[dict], modality: Modality) # Audio: unlike images, load_audio doesn't accept raw base64 — wrap as a data URI elif content_type == "input_audio" and modality == Modality.MULTIMODAL: input_audio = content["input_audio"] - fmt = input_audio.get("format", "wav") if isinstance(input_audio, dict) else "wav" + fmt = ( + input_audio.get("format", "wav") + if isinstance(input_audio, dict) + else "wav" + ) audio_b64 = input_audio["data"] - parsed["content"].append({"type": "audio", "url": f"data:audio/{fmt};base64,{audio_b64}"}) + parsed["content"].append( + {"type": "audio", "url": f"data:audio/{fmt};base64,{audio_b64}"} + ) # Extensions (not part of the OpenAI API standard) - elif content_type == "video_url" and modality in (Modality.VLM, Modality.MULTIMODAL): - 
parsed["content"].append({"type": "video", "url": content["video_url"]["url"]}) + elif content_type == "video_url" and modality in ( + Modality.VLM, + Modality.MULTIMODAL, + ): + parsed["content"].append( + {"type": "video", "url": content["video_url"]["url"]} + ) elif content_type == "audio_url" and modality == Modality.MULTIMODAL: - parsed["content"].append({"type": "audio", "url": content["audio_url"]["url"]}) + parsed["content"].append( + {"type": "audio", "url": content["audio_url"]["url"]} + ) # LLMs expect plain text, not a list of content parts if modality == Modality.LLM: From 27716a32c72515bec359bc5cdaa459bc2a4b093f Mon Sep 17 00:00:00 2001 From: abhiprd200 Date: Sat, 25 Apr 2026 19:50:10 +0530 Subject: [PATCH 4/5] chore: bypass pyright type checking for dynamic variables --- src/transformers/cli/serving/chat_completion.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/cli/serving/chat_completion.py b/src/transformers/cli/serving/chat_completion.py index 97e3b3597b2a..283e61d3b685 100644 --- a/src/transformers/cli/serving/chat_completion.py +++ b/src/transformers/cli/serving/chat_completion.py @@ -80,7 +80,7 @@ class CompletionUsage(_DummyDict): parent_class = TypedDict -class TransformersCompletionCreateParamsStreaming(parent_class, total=False): +class TransformersCompletionCreateParamsStreaming(parent_class, total=False): # type: ignore generation_config: str seed: int @@ -176,7 +176,7 @@ async def handle_request( **chat_template_kwargs, ) if not use_cb: - inputs = inputs.to(model.device) # type: ignore[union-attr] + inputs = inputs.to(model.device) # type: ignore gen_config = self._build_generation_config( body, model.generation_config, use_cb=use_cb From a4d0e8bc3a95979bebc7f9215f3bfc474e5c7000 Mon Sep 17 00:00:00 2001 From: abhiprd200 Date: Sat, 25 Apr 2026 20:00:35 +0530 Subject: [PATCH 5/5] style: apply ruff format to serving directory --- .../cli/serving/chat_completion.py | 68 ++++----------- src/transformers/cli/serving/completion.py | 50 +++-------- src/transformers/cli/serving/model_manager.py | 37 ++------- src/transformers/cli/serving/response.py | 66 ++++----------- src/transformers/cli/serving/server.py | 8 +- src/transformers/cli/serving/transcription.py | 66 ++++----------- src/transformers/cli/serving/utils.py | 82 ++++--------------- 7 files changed, 90 insertions(+), 287 deletions(-) diff --git a/src/transformers/cli/serving/chat_completion.py b/src/transformers/cli/serving/chat_completion.py index 283e61d3b685..c25ba58f7e52 100644 --- a/src/transformers/cli/serving/chat_completion.py +++ b/src/transformers/cli/serving/chat_completion.py @@ -131,9 +131,7 @@ class ChatCompletionHandler(BaseHandler): _valid_params_class = TransformersCompletionCreateParamsStreaming _unused_fields = UNUSED_CHAT_COMPLETION_FIELDS - async def handle_request( - self, body: dict, request_id: str - ) -> StreamingResponse | JSONResponse: + async def handle_request(self, body: dict, request_id: str) -> StreamingResponse | JSONResponse: """Validate the request, load the model, and dispatch to streaming or non-streaming. 
Args: @@ -150,16 +148,12 @@ async def handle_request( use_cb = self.generation_state.use_continuous_batching(model, modality) logger.warning(f"[Request received] Model: {model_id}, CB: {use_cb}") gen_manager = self.generation_state.get_manager(model_id, use_cb=use_cb) - processor_inputs = self.get_processor_inputs_from_messages( - body["messages"], modality - ) + processor_inputs = self.get_processor_inputs_from_messages(body["messages"], modality) has_video = any( c.get("type") == "video" for msg in processor_inputs - for c in ( - msg.get("content") if isinstance(msg.get("content"), list) else [] - ) + for c in (msg.get("content") if isinstance(msg.get("content"), list) else []) ) # Default to 32 frames for video (Gemma 4 default); some processors load all frames otherwise chat_template_kwargs = {} @@ -178,16 +172,12 @@ async def handle_request( if not use_cb: inputs = inputs.to(model.device) # type: ignore - gen_config = self._build_generation_config( - body, model.generation_config, use_cb=use_cb - ) + gen_config = self._build_generation_config(body, model.generation_config, use_cb=use_cb) # TODO: remove when CB supports per-request generation config if use_cb: gen_manager.init_cb(model, gen_config) - tool_config = ( - get_tool_call_config(processor, model) if body.get("tools") else None - ) + tool_config = get_tool_call_config(processor, model) if body.get("tools") else None streaming = body.get("stream") if streaming: @@ -237,15 +227,11 @@ def _streaming( ) input_ids = inputs["input_ids"] # CB returns plain lists, regular path returns tensors - input_len = ( - len(input_ids) if isinstance(input_ids, list) else input_ids.shape[-1] - ) + input_len = len(input_ids) if isinstance(input_ids, list) else input_ids.shape[-1] async def sse_gen() -> AsyncGenerator[str, None]: try: - yield self._build_chunk_sse( - request_id, role="assistant", model=model_id - ) + yield self._build_chunk_sse(request_id, role="assistant", model=model_id) done = False while not done: @@ -267,11 +253,7 @@ async def sse_gen() -> AsyncGenerator[str, None]: yield "".join(sse_parts) return - sse_parts.append( - self._build_chunk_sse( - request_id, model=model_id, content=text - ) - ) + sse_parts.append(self._build_chunk_sse(request_id, model=model_id, content=text)) if sse_parts: yield "".join(sse_parts) @@ -280,9 +262,7 @@ async def sse_gen() -> AsyncGenerator[str, None]: # because the full token sequence is needed for reliable parsing. 
has_tool_calls = False if tool_config: - parsed = parse_tool_calls( - processor, streamer.generated_token_ids, tool_config["schema"] - ) + parsed = parse_tool_calls(processor, streamer.generated_token_ids, tool_config["schema"]) if parsed: has_tool_calls = True for i, tc in enumerate(parsed): @@ -302,10 +282,7 @@ async def sse_gen() -> AsyncGenerator[str, None]: ], ) - hit_max = ( - gen_config.max_new_tokens is not None - and streamer.total_tokens >= gen_config.max_new_tokens - ) + hit_max = gen_config.max_new_tokens is not None and streamer.total_tokens >= gen_config.max_new_tokens if has_tool_calls: finish_reason = "tool_calls" elif hit_max: @@ -349,10 +326,7 @@ async def _non_streaming( model, processor, inputs, gen_config, request_id=request_id ) - hit_max = ( - gen_config.max_new_tokens is not None - and len(generated_ids) >= gen_config.max_new_tokens - ) + hit_max = gen_config.max_new_tokens is not None and len(generated_ids) >= gen_config.max_new_tokens completion_tokens = len(generated_ids) usage = CompletionUsage( prompt_tokens=input_len, @@ -402,20 +376,14 @@ def _build_generation_config( ): """Apply Chat Completions params (``max_tokens``, ``frequency_penalty``, ``logit_bias``, ``stop``) on top of the base generation config.""" - generation_config = super()._build_generation_config( - body, model_generation_config, use_cb=use_cb - ) + generation_config = super()._build_generation_config(body, model_generation_config, use_cb=use_cb) if body.get("max_tokens") is not None: generation_config.max_new_tokens = int(body["max_tokens"]) if body.get("frequency_penalty") is not None: - generation_config.repetition_penalty = 1.0 + float( - body["frequency_penalty"] - ) + generation_config.repetition_penalty = 1.0 + float(body["frequency_penalty"]) if body.get("logit_bias") is not None: - generation_config.sequence_bias = { - (int(k),): v for k, v in body["logit_bias"].items() - } + generation_config.sequence_bias = {(int(k),): v for k, v in body["logit_bias"].items()} if body.get("stop") is not None: generation_config.stop_strings = body["stop"] @@ -445,9 +413,7 @@ def _build_completion( Returns: `dict`: Serialized ``ChatCompletion`` ready for JSON response. 
""" - message = ChatCompletionMessage( - content=content, role="assistant", tool_calls=tool_calls - ) + message = ChatCompletionMessage(content=content, role="assistant", tool_calls=tool_calls) result = ChatCompletion( id=request_id, created=int(time.time()), @@ -494,9 +460,7 @@ def _build_chunk_sse( model=model, choices=[ ChoiceChunk( - delta=ChoiceDelta( - content=content, role=role, tool_calls=tool_calls - ), + delta=ChoiceDelta(content=content, role=role, tool_calls=tool_calls), index=0, finish_reason=finish_reason, ) diff --git a/src/transformers/cli/serving/completion.py b/src/transformers/cli/serving/completion.py index d37b3103324c..ed04fffb12a8 100644 --- a/src/transformers/cli/serving/completion.py +++ b/src/transformers/cli/serving/completion.py @@ -44,9 +44,7 @@ # --- FINAL ROBUST PATCH --- if "CompletionCreateParamsBase" in globals(): # If the real OpenAI class was successfully imported, use it - class TransformersTextCompletionCreateParams( - CompletionCreateParamsBase, total=False - ): + class TransformersTextCompletionCreateParams(CompletionCreateParamsBase, total=False): generation_config: str seed: int @@ -85,9 +83,7 @@ class CompletionHandler(BaseHandler): _valid_params_class = TransformersTextCompletionCreateParams _unused_fields = UNUSED_LEGACY_COMPLETION_FIELDS - async def handle_request( - self, body: dict, request_id: str - ) -> "StreamingResponse | JSONResponse": + async def handle_request(self, body: dict, request_id: str) -> "StreamingResponse | JSONResponse": """Validate the request, load the model, and dispatch to streaming or non-streaming. Args: @@ -114,9 +110,7 @@ async def handle_request( if not use_cb: inputs = inputs.to(model.device) - gen_config = self._build_generation_config( - body, model.generation_config, use_cb=use_cb - ) + gen_config = self._build_generation_config(body, model.generation_config, use_cb=use_cb) if use_cb: gen_manager.init_cb(model, gen_config) @@ -160,13 +154,9 @@ def _streaming( suffix: str | None = None, ) -> "StreamingResponse": """Stream tokens as SSE.""" - queue, streamer = gen_manager.generate_streaming( - model, processor, inputs, gen_config, request_id=request_id - ) + queue, streamer = gen_manager.generate_streaming(model, processor, inputs, gen_config, request_id=request_id) input_ids = inputs["input_ids"] - input_len = ( - len(input_ids) if isinstance(input_ids, list) else input_ids.shape[-1] - ) + input_len = len(input_ids) if isinstance(input_ids, list) else input_ids.shape[-1] async def sse_gen() -> AsyncGenerator[str, None]: try: @@ -190,17 +180,12 @@ async def sse_gen() -> AsyncGenerator[str, None]: yield "".join(sse_parts) return - sse_parts.append( - self._build_chunk_sse(request_id, model_id, text=text) - ) + sse_parts.append(self._build_chunk_sse(request_id, model_id, text=text)) if sse_parts: yield "".join(sse_parts) - hit_max = ( - gen_config.max_new_tokens is not None - and streamer.total_tokens >= gen_config.max_new_tokens - ) + hit_max = gen_config.max_new_tokens is not None and streamer.total_tokens >= gen_config.max_new_tokens finish_reason = "length" if hit_max else "stop" if suffix is not None: @@ -210,9 +195,7 @@ async def sse_gen() -> AsyncGenerator[str, None]: completion_tokens=streamer.total_tokens, total_tokens=input_len + streamer.total_tokens, ) - yield self._build_chunk_sse( - request_id, model_id, finish_reason=finish_reason, usage=usage - ) + yield self._build_chunk_sse(request_id, model_id, finish_reason=finish_reason, usage=usage) except (GeneratorExit, asyncio.CancelledError): 
streamer.cancel() raise @@ -241,10 +224,7 @@ async def _non_streaming( text = text + suffix completion_tokens = len(generated_ids) - hit_max = ( - gen_config.max_new_tokens is not None - and completion_tokens >= gen_config.max_new_tokens - ) + hit_max = gen_config.max_new_tokens is not None and completion_tokens >= gen_config.max_new_tokens finish_reason = "length" if hit_max else "stop" usage = CompletionUsage( @@ -269,9 +249,7 @@ async def _non_streaming( usage=usage, ) - return JSONResponse( - result.model_dump(exclude_none=True), media_type="application/json" - ) + return JSONResponse(result.model_dump(exclude_none=True), media_type="application/json") # ----- helpers ----- @@ -315,16 +293,12 @@ def _build_generation_config( use_cb: bool = False, ): """Apply legacy completion params (``max_tokens``, ``frequency_penalty``, ``stop``) on top of base config.""" - generation_config = super()._build_generation_config( - body, model_generation_config, use_cb=use_cb - ) + generation_config = super()._build_generation_config(body, model_generation_config, use_cb=use_cb) if body.get("max_tokens") is not None: generation_config.max_new_tokens = int(body["max_tokens"]) if body.get("frequency_penalty") is not None: - generation_config.repetition_penalty = 1.0 + float( - body["frequency_penalty"] - ) + generation_config.repetition_penalty = 1.0 + float(body["frequency_penalty"]) if body.get("stop") is not None: generation_config.stop_strings = body["stop"] diff --git a/src/transformers/cli/serving/model_manager.py b/src/transformers/cli/serving/model_manager.py index 07004ea9393d..d718b99738b1 100644 --- a/src/transformers/cli/serving/model_manager.py +++ b/src/transformers/cli/serving/model_manager.py @@ -87,9 +87,7 @@ def delete_model(self) -> None: def _timeout_reached(self) -> None: if self.timeout_seconds > 0: self.delete_model() - logger.info( - f"{self._name_or_path} was removed from memory after {self.timeout_seconds}s of inactivity" - ) + logger.info(f"{self._name_or_path} was removed from memory after {self.timeout_seconds}s of inactivity") class ModelManager: @@ -175,9 +173,8 @@ def _validate_args(self): "flash_attention_3", "flex_attention", } - is_kernels_community = ( - self.attn_implementation is not None - and self.attn_implementation.startswith("kernels-community/") + is_kernels_community = self.attn_implementation is not None and self.attn_implementation.startswith( + "kernels-community/" ) if ( self.attn_implementation is not None @@ -208,9 +205,7 @@ def get_quantization_config(self) -> BitsAndBytesConfig | None: return BitsAndBytesConfig(load_in_8bit=True) return None - def _load_processor( - self, model_id_and_revision: str - ) -> "ProcessorMixin | PreTrainedTokenizerFast": + def _load_processor(self, model_id_and_revision: str) -> "ProcessorMixin | PreTrainedTokenizerFast": """Load a processor for the given model. 
Args: @@ -219,9 +214,7 @@ def _load_processor( from transformers import AutoProcessor model_id, revision = model_id_and_revision.split("@", 1) - return AutoProcessor.from_pretrained( - model_id, revision=revision, trust_remote_code=self.trust_remote_code - ) + return AutoProcessor.from_pretrained(model_id, revision=revision, trust_remote_code=self.trust_remote_code) def _load_model( self, @@ -254,9 +247,7 @@ def _load_model( } if progress_callback is not None: - progress_callback( - {"status": "loading", "model": model_id_and_revision, "stage": "config"} - ) + progress_callback({"status": "loading", "model": model_id_and_revision, "stage": "config"}) config = AutoConfig.from_pretrained(model_id, **model_kwargs) from transformers.models.auto.modeling_auto import MODEL_FOR_MULTIMODAL_LM_MAPPING_NAMES @@ -308,9 +299,7 @@ def load_model_and_processor( model, timeout_seconds=self.model_timeout, processor=processor, - on_unload=lambda key=model_id_and_revision: self.loaded_models.pop( - key, None - ), + on_unload=lambda key=model_id_and_revision: self.loaded_models.pop(key, None), ) if progress_callback is not None: progress_callback( @@ -486,11 +475,7 @@ def get_gen_models(cache_dir: str | None = None) -> list[dict]: for ref, revision_info in repo.refs.items(): config_path = next( - ( - f.file_path - for f in revision_info.files - if f.file_name == "config.json" - ), + (f.file_path for f in revision_info.files if f.file_name == "config.json"), None, ) if not config_path: @@ -505,11 +490,7 @@ def get_gen_models(cache_dir: str | None = None) -> list[dict]: vlms = MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.values() multimodal = MODEL_FOR_MULTIMODAL_LM_MAPPING_NAMES.values() - if any( - arch - for arch in architectures - if arch in [*llms, *vlms, *multimodal] - ): + if any(arch for arch in architectures if arch in [*llms, *vlms, *multimodal]): author = repo.repo_id.split("/") if "/" in repo.repo_id else "" repo_handle = repo.repo_id + (f"@{ref}" if ref != "main" else "") generative_models.append( diff --git a/src/transformers/cli/serving/response.py b/src/transformers/cli/serving/response.py index a5689bffdfb4..f8e2491b5e34 100644 --- a/src/transformers/cli/serving/response.py +++ b/src/transformers/cli/serving/response.py @@ -70,9 +70,7 @@ # --- FINAL ROBUST PATCH --- if "ResponseCreateParamsStreaming" in globals(): - class TransformersResponseCreateParamsStreaming( - ResponseCreateParamsStreaming, total=False - ): + class TransformersResponseCreateParamsStreaming(ResponseCreateParamsStreaming, total=False): generation_config: str seed: int @@ -108,9 +106,7 @@ class ResponseHandler(BaseHandler): _valid_params_class = TransformersResponseCreateParamsStreaming _unused_fields = UNUSED_RESPONSE_FIELDS - async def handle_request( - self, body: dict, request_id: str - ) -> StreamingResponse | JSONResponse: + async def handle_request(self, body: dict, request_id: str) -> StreamingResponse | JSONResponse: """Validate, load model, dispatch to streaming or non-streaming. 
Args: @@ -137,9 +133,7 @@ async def handle_request( has_video = any( c.get("type") == "video" for msg in processor_inputs - for c in ( - msg.get("content") if isinstance(msg.get("content"), list) else [] - ) + for c in (msg.get("content") if isinstance(msg.get("content"), list) else []) ) # Default to 32 frames for video (Gemma 4 default); some processors load all frames otherwise @@ -161,15 +155,11 @@ async def handle_request( if not use_cb: inputs = inputs.to(model.device) # type: ignore[union-attr] - gen_config = self._build_generation_config( - body, model.generation_config, use_cb=use_cb - ) + gen_config = self._build_generation_config(body, model.generation_config, use_cb=use_cb) # TODO: remove when CB supports per-request generation config if use_cb: gen_manager.init_cb(model, gen_config) - tool_config = ( - get_tool_call_config(processor, model) if body.get("tools") else None - ) + tool_config = get_tool_call_config(processor, model) if body.get("tools") else None streaming = body.get("stream", True) if streaming: @@ -256,9 +246,7 @@ def _normalize_input(body: dict) -> list[dict]: else: messages = ResponseHandler._normalize_response_items(inp) else: - raise HTTPException( - status_code=422, detail="'input' must be a string or list" - ) + raise HTTPException(status_code=422, detail="'input' must be a string or list") # Prepend instructions as a system message if instructions: @@ -284,9 +272,7 @@ def _normalize_response_items(items: list[dict]) -> list[dict]: item_type = item.get("type") if "role" in item: - messages.append( - {"role": item["role"], "content": item.get("content", "")} - ) + messages.append({"role": item["role"], "content": item.get("content", "")}) elif item_type == "function_call": tc = { @@ -340,9 +326,7 @@ def _streaming( ) input_ids = inputs["input_ids"] # CB returns plain lists, regular path returns tensors - input_len = ( - len(input_ids) if isinstance(input_ids, list) else input_ids.shape[-1] - ) + input_len = len(input_ids) if isinstance(input_ids, list) else input_ids.shape[-1] seq = 0 output_index = 0 @@ -378,9 +362,7 @@ async def event_stream() -> AsyncGenerator[str, None]: ResponseInProgressEvent( type="response.in_progress", sequence_number=seq, - response=Response( - **response_base, status="in_progress", output=[] - ), + response=Response(**response_base, status="in_progress", output=[]), ) ) seq += 1 @@ -410,9 +392,7 @@ async def event_stream() -> AsyncGenerator[str, None]: sequence_number=seq, output_index=output_index, content_index=0, - part=ResponseOutputText( - type="output_text", text="", annotations=[] - ), + part=ResponseOutputText(type="output_text", text="", annotations=[]), ) ) seq += 1 @@ -438,9 +418,7 @@ async def event_stream() -> AsyncGenerator[str, None]: done = True break if isinstance(text, _StreamError): - logger.error( - f"Exception in response generation: {text.msg}" - ) + logger.error(f"Exception in response generation: {text.msg}") sse_parts.append( self.chunk_to_sse( ResponseErrorEvent( @@ -460,9 +438,7 @@ async def event_stream() -> AsyncGenerator[str, None]: **response_base, status="failed", output=[], - error=ResponseError( - code="server_error", message=text.msg - ), + error=ResponseError(code="server_error", message=text.msg), ), ) ) @@ -492,9 +468,7 @@ async def event_stream() -> AsyncGenerator[str, None]: # 5. Tool calls are parsed after generation completes (not during streaming), # because the full token sequence is needed for reliable parsing. 
                 if tool_config:
-                    parsed = parse_tool_calls(
-                        processor, streamer.generated_token_ids, tool_config["schema"]
-                    )
+                    parsed = parse_tool_calls(processor, streamer.generated_token_ids, tool_config["schema"])
                     if parsed:
                         for i, tc in enumerate(parsed):
                             tc_id = f"{request_id}_tool_call_{i}"
@@ -539,9 +513,7 @@ async def event_stream() -> AsyncGenerator[str, None]:
                 seq += 1
 
                 # 6. Close text output
-                output_text_part = ResponseOutputText(
-                    type="output_text", text=full_text, annotations=[]
-                )
+                output_text_part = ResponseOutputText(type="output_text", text=full_text, annotations=[])
                 yield self.chunk_to_sse(
                     ResponseTextDoneEvent(
                         type="response.output_text.done",
@@ -633,11 +605,7 @@ async def _non_streaming(
                     type="message",
                     status="completed",
                     role="assistant",
-                    content=[
-                        ResponseOutputText(
-                            type="output_text", text=full_text, annotations=[]
-                        )
-                    ],
+                    content=[ResponseOutputText(type="output_text", text=full_text, annotations=[])],
                     annotations=[],  # type: ignore[call-arg]
                 )
             ]
@@ -683,9 +651,7 @@ def _build_generation_config(
         use_cb: bool = False,
     ):
         """Apply Responses API params (``max_output_tokens``) on top of the base generation config."""
-        generation_config = super()._build_generation_config(
-            body, model_generation_config, use_cb=use_cb
-        )
+        generation_config = super()._build_generation_config(body, model_generation_config, use_cb=use_cb)
 
         if body.get("max_output_tokens") is not None:
             generation_config.max_new_tokens = int(body["max_output_tokens"])
diff --git a/src/transformers/cli/serving/server.py b/src/transformers/cli/serving/server.py
index 1fe47c3e7296..64e276d5bb56 100644
--- a/src/transformers/cli/serving/server.py
+++ b/src/transformers/cli/serving/server.py
@@ -73,9 +73,7 @@ async def lifespan(app: FastAPI):
         allow_methods=["*"],
         allow_headers=["*"],
     )
-    logger.warning_once(
-        "CORS allow origin is set to `*`. Not recommended for production."
-    )
+    logger.warning_once("CORS allow origin is set to `*`. Not recommended for production.")
 
 
 # ---- Middleware ----
@@ -112,9 +110,7 @@ async def load_model(body: dict):
     model = body.get("model")
 
     if model is None:
-        raise HTTPException(
-            status_code=422, detail="Missing `model` field in the request body."
-        )
+        raise HTTPException(status_code=422, detail="Missing `model` field in the request body.")
     model_id_and_revision = model_manager.process_model_name(model)
     return StreamingResponse(
         model_manager.load_model_streaming(model_id_and_revision),
diff --git a/src/transformers/cli/serving/transcription.py b/src/transformers/cli/serving/transcription.py
index e69ef617b816..fc853a1eb46b 100644
--- a/src/transformers/cli/serving/transcription.py
+++ b/src/transformers/cli/serving/transcription.py
@@ -43,9 +43,7 @@
 
 # --- FINAL ROBUST PATCH ---
 if "TranscriptionCreateParamsBase" in globals():
-    class TransformersTranscriptionCreateParams(
-        TranscriptionCreateParamsBase, total=False
-    ):
+    class TransformersTranscriptionCreateParams(TranscriptionCreateParamsBase, total=False):
         generation_config: str
         seed: int
 
@@ -92,9 +90,7 @@ def __init__(self, model_manager: ModelManager, generation_state: GenerationStat
 
     def _validate_request(self, form_keys: set[str]) -> None:
         """Validate transcription request fields."""
-        unexpected = form_keys - getattr(
-            TransformersTranscriptionCreateParams, "__mutable_keys__", set()
-        )
+        unexpected = form_keys - getattr(TransformersTranscriptionCreateParams, "__mutable_keys__", set())
         if unexpected:
             raise HTTPException(
                 status_code=422,
@@ -104,9 +100,7 @@ def _validate_request(self, form_keys: set[str]) -> None:
         if unused:
             logger.warning_once(f"Ignoring unsupported fields in the request: {unused}")
 
-    async def handle_request(
-        self, request: Request
-    ) -> JSONResponse | StreamingResponse:
+    async def handle_request(self, request: Request) -> JSONResponse | StreamingResponse:
         """Parse multipart form, run transcription, return result.
 
         Args:
@@ -119,9 +113,7 @@ async def handle_request(
         from transformers.utils.import_utils import is_librosa_available, is_multipart_available
 
         if not is_librosa_available():
-            raise ImportError(
-                "Missing librosa dependency for audio transcription. Install with `pip install librosa`"
-            )
+            raise ImportError("Missing librosa dependency for audio transcription. Install with `pip install librosa`")
         if not is_multipart_available():
             raise ImportError(
                 "Missing python-multipart dependency for file uploads. Install with `pip install python-multipart`"
@@ -131,21 +123,15 @@ async def handle_request(
         form = await request.form()
         self._validate_request(set(form.keys()))
         file_field = form["file"]
         if isinstance(file_field, str):
-            raise HTTPException(
-                status_code=422, detail="Expected file upload, got string"
-            )
+            raise HTTPException(status_code=422, detail="Expected file upload, got string")
         file_bytes = await file_field.read()
         model = form["model"]
         if not isinstance(model, str):
-            raise HTTPException(
-                status_code=422, detail="Expected model name as string"
-            )
+            raise HTTPException(status_code=422, detail="Expected model name as string")
         stream = str(form.get("stream", "false")).lower() == "true"
         model_id_and_revision = self.model_manager.process_model_name(model)
-        audio_model, audio_processor = self.model_manager.load_model_and_processor(
-            model_id_and_revision
-        )
+        audio_model, audio_processor = self.model_manager.load_model_and_processor(model_id_and_revision)
         base_manager = self.generation_state.get_manager(model_id_and_revision)
         if not isinstance(base_manager, GenerateManager):
             raise HTTPException(
@@ -153,17 +139,11 @@ async def handle_request(
                 detail="Audio transcription requires sequential generation (not CB)",
             )
         gen_manager = base_manager
-        audio_inputs = self._prepare_audio_inputs(
-            file_bytes, audio_processor, audio_model
-        )
+        audio_inputs = self._prepare_audio_inputs(file_bytes, audio_processor, audio_model)
 
         if stream:
-            return self._streaming(
-                gen_manager, audio_model, audio_processor, audio_inputs
-            )
-        return await self._non_streaming(
-            gen_manager, audio_model, audio_processor, audio_inputs
-        )
+            return self._streaming(gen_manager, audio_model, audio_processor, audio_inputs)
+        return await self._non_streaming(gen_manager, audio_model, audio_processor, audio_inputs)
 
     @staticmethod
     def _prepare_audio_inputs(
@@ -175,15 +155,11 @@ def _prepare_audio_inputs(
         import librosa
 
         sampling_rate = audio_processor.feature_extractor.sampling_rate
-        audio_array, _ = librosa.load(
-            io.BytesIO(file_bytes), sr=sampling_rate, mono=True
-        )
-        audio_inputs = audio_processor(
-            audio_array, sampling_rate=sampling_rate, return_tensors="pt"
-        ).to(audio_model.device)
-        audio_inputs["input_features"] = audio_inputs["input_features"].to(
-            audio_model.dtype
+        audio_array, _ = librosa.load(io.BytesIO(file_bytes), sr=sampling_rate, mono=True)
+        audio_inputs = audio_processor(audio_array, sampling_rate=sampling_rate, return_tensors="pt").to(
+            audio_model.device
         )
+        audio_inputs["input_features"] = audio_inputs["input_features"].to(audio_model.dtype)
         return audio_inputs
 
     async def _non_streaming(
@@ -198,9 +174,7 @@ async def _non_streaming(
         # generate_non_streaming()
         from openai.types.audio import Transcription
 
-        generated_ids = await gen_manager.async_submit(
-            audio_model.generate, **audio_inputs
-        )
+        generated_ids = await gen_manager.async_submit(audio_model.generate, **audio_inputs)
         text = audio_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
         return JSONResponse(Transcription(text=text).model_dump(exclude_none=True))
 
@@ -215,16 +189,10 @@ def _streaming(
         # differ from text.
import asyncio - tokenizer = ( - audio_processor.tokenizer - if hasattr(audio_processor, "tokenizer") - else audio_processor - ) + tokenizer = audio_processor.tokenizer if hasattr(audio_processor, "tokenizer") else audio_processor loop = asyncio.get_running_loop() queue: asyncio.Queue = asyncio.Queue() - streamer = DirectStreamer( - tokenizer._tokenizer, loop, queue, skip_special_tokens=True - ) + streamer = DirectStreamer(tokenizer._tokenizer, loop, queue, skip_special_tokens=True) gen_kwargs = {**audio_inputs, "streamer": streamer} def _run(): diff --git a/src/transformers/cli/serving/utils.py b/src/transformers/cli/serving/utils.py index caf44771cd2a..50f901060af2 100644 --- a/src/transformers/cli/serving/utils.py +++ b/src/transformers/cli/serving/utils.py @@ -109,11 +109,7 @@ def get_tool_call_config(processor, model: "PreTrainedModel") -> dict | None: else: # Fallback: known model families without full tokenizer config fallback = next( - ( - v - for k, v in _TOOL_CALL_FALLBACKS.items() - if k in model.config.model_type - ), + (v for k, v in _TOOL_CALL_FALLBACKS.items() if k in model.config.model_type), None, ) if fallback is None: @@ -138,9 +134,7 @@ def _normalize_tool_call(tool_call: dict) -> dict: arguments = function.get("arguments", {}) return { "name": function["name"], - "arguments": ( - json.dumps(arguments) if not isinstance(arguments, str) else arguments - ), + "arguments": (json.dumps(arguments) if not isinstance(arguments, str) else arguments), } @@ -339,11 +333,7 @@ def put(self, value: "torch.Tensor") -> None: self._inside_tool_call = False text = self._decode_stream.step(self._tokenizer, token_id) - if ( - text is not None - and not self._inside_tool_call - and token_id != self._etc_id - ): + if text is not None and not self._inside_tool_call and token_id != self._etc_id: self._loop.call_soon_threadsafe(self._queue.put_nowait, text) def end(self) -> None: @@ -411,11 +401,7 @@ def put(self, output: "GenerationOutput") -> None: self._inside_tool_call = False text = self._decode_stream.step(self._tokenizer, token_id) - if ( - text is not None - and not self._inside_tool_call - and token_id != self._etc_id - ): + if text is not None and not self._inside_tool_call and token_id != self._etc_id: self._queue.put_nowait(text) def end(self) -> None: @@ -745,9 +731,7 @@ def _on_result(result): ) result = await future if result is None: - raise RuntimeError( - f"CB manager stopped before producing a result for {request_id}" - ) + raise RuntimeError(f"CB manager stopped before producing a result for {request_id}") generated_ids = result.generated_tokens text = processor.decode(generated_ids, skip_special_tokens=True) return text, input_len, generated_ids @@ -791,9 +775,7 @@ def __init__( self._cb_manager: CBGenerateManager | None = None self._cb_model_id: str | None = None - def use_continuous_batching( - self, model: "PreTrainedModel", modality: Modality - ) -> bool: + def use_continuous_batching(self, model: "PreTrainedModel", modality: Modality) -> bool: """Check if continuous batching can be used for this model and modality. 
Args: @@ -873,9 +855,7 @@ def _validate_request(self, body: dict) -> None: input_keys = set(body.keys()) if self._valid_params_class is not None: - unexpected = input_keys - getattr( - self._valid_params_class, "__mutable_keys__", set() - ) + unexpected = input_keys - getattr(self._valid_params_class, "__mutable_keys__", set()) if unexpected: raise HTTPException( status_code=422, @@ -892,9 +872,7 @@ def chunk_to_sse(chunk: "str | pydantic.BaseModel") -> str: return chunk if chunk.startswith("data: ") else f"data: {chunk}\n\n" return f"data: {chunk.model_dump_json(exclude_none=True)}\n\n" - def _resolve_model( - self, body: dict - ) -> tuple[str, "PreTrainedModel", "ProcessorMixin | PreTrainedTokenizerFast"]: + def _resolve_model(self, body: dict) -> tuple[str, "PreTrainedModel", "ProcessorMixin | PreTrainedTokenizerFast"]: """Apply force_model, load model + processor. Returns ``(model_id, model, processor)``. @@ -906,9 +884,7 @@ def _resolve_model( if requested is not None and requested != self.model_manager.force_model: raise HTTPException( status_code=400, - detail=( - f"Server is pinned to '{self.model_manager.force_model}'; requested '{requested}'." - ), + detail=(f"Server is pinned to '{self.model_manager.force_model}'; requested '{requested}'."), ) body["model"] = self.model_manager.force_model @@ -943,15 +919,10 @@ def _build_generation_config( from transformers import GenerationConfig if body.get("generation_config") is not None: - generation_config = GenerationConfig( - **json.loads(body["generation_config"]) - ) + generation_config = GenerationConfig(**json.loads(body["generation_config"])) else: generation_config = copy.deepcopy(model_generation_config) - if ( - generation_config.max_new_tokens is None - or generation_config.max_new_tokens < 1024 - ): + if generation_config.max_new_tokens is None or generation_config.max_new_tokens < 1024: generation_config.max_new_tokens = 1024 if body.get("temperature") is not None: @@ -964,10 +935,7 @@ def _build_generation_config( set_torch_seed(body["seed"]) # --compile flag: use static cache + torch.compile for faster decode - if ( - self.generation_state._compile - and generation_config.cache_implementation is None - ): + if self.generation_state._compile and generation_config.cache_implementation is None: generation_config.cache_implementation = "static" # CB manages its own paged KV cache @@ -979,9 +947,7 @@ def _build_generation_config( return generation_config @staticmethod - def get_processor_inputs_from_messages( - messages: list[dict], modality: Modality - ) -> list[dict]: + def get_processor_inputs_from_messages(messages: list[dict], modality: Modality) -> list[dict]: """Convert OpenAI-format messages to the format expected by HF processors. All modalities extract text. VLM additionally handles ``image_url`` and ``video_url``. @@ -1008,9 +974,7 @@ def get_processor_inputs_from_messages( # When tool_calls are present, ignore content — it's either empty or contains # raw tool call markup that would confuse the chat template if rendered. 
- raw_content = ( - [] if "tool_calls" in message else (message.get("content") or []) - ) + raw_content = [] if "tool_calls" in message else (message.get("content") or []) if isinstance(raw_content, str): raw_content = [{"type": "text", "text": raw_content}] @@ -1032,27 +996,17 @@ def get_processor_inputs_from_messages( # Audio: unlike images, load_audio doesn't accept raw base64 — wrap as a data URI elif content_type == "input_audio" and modality == Modality.MULTIMODAL: input_audio = content["input_audio"] - fmt = ( - input_audio.get("format", "wav") - if isinstance(input_audio, dict) - else "wav" - ) + fmt = input_audio.get("format", "wav") if isinstance(input_audio, dict) else "wav" audio_b64 = input_audio["data"] - parsed["content"].append( - {"type": "audio", "url": f"data:audio/{fmt};base64,{audio_b64}"} - ) + parsed["content"].append({"type": "audio", "url": f"data:audio/{fmt};base64,{audio_b64}"}) # Extensions (not part of the OpenAI API standard) elif content_type == "video_url" and modality in ( Modality.VLM, Modality.MULTIMODAL, ): - parsed["content"].append( - {"type": "video", "url": content["video_url"]["url"]} - ) + parsed["content"].append({"type": "video", "url": content["video_url"]["url"]}) elif content_type == "audio_url" and modality == Modality.MULTIMODAL: - parsed["content"].append( - {"type": "audio", "url": content["audio_url"]["url"]} - ) + parsed["content"].append({"type": "audio", "url": content["audio_url"]["url"]}) # LLMs expect plain text, not a list of content parts if modality == Modality.LLM:
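
Note on the shape of the fix: every serving module in this series now uses the
same guard, reduced to its essentials below. This is an illustrative sketch
only, not code from the diffs above; the OpenAI import path is the one the
series imports, while `_ParamsBase` and `Params` are hypothetical names chosen
for the example:

    from typing import TypedDict

    try:
        # Real base class when the serving extras (openai) are installed.
        from openai.types.chat.completion_create_params import (
            CompletionCreateParamsStreaming as _ParamsBase,
        )
    except ImportError:
        # Fallback: the class statement at import time no longer raises
        # NameError when the optional dependency is missing.
        _ParamsBase = TypedDict

    class Params(_ParamsBase, total=False):  # type: ignore
        generation_config: str
        seed: int

    # TypedDict subclasses are plain dicts at runtime, in either branch.
    assert Params(seed=0)["seed"] == 0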