diff --git a/src/transformers/cli/serving/chat_completion.py b/src/transformers/cli/serving/chat_completion.py
index 161a25a02f41..c25ba58f7e52 100644
--- a/src/transformers/cli/serving/chat_completion.py
+++ b/src/transformers/cli/serving/chat_completion.py
@@ -23,10 +23,11 @@
 from typing import TYPE_CHECKING
 
 from ...utils import logging
-from ...utils.import_utils import is_serve_available
+from .utils import BaseGenerateManager, BaseHandler, Modality, _StreamError, get_tool_call_config, parse_tool_calls
 
-if is_serve_available():
+# --- Import guard: fall back to inert stubs when the serve extras are missing ---
+try:
     from fastapi.responses import JSONResponse, StreamingResponse
     from openai.types.chat import ChatCompletion, ChatCompletionMessage, ChatCompletionMessageToolCall
     from openai.types.chat.chat_completion import Choice
@@ -35,26 +36,62 @@
     from openai.types.chat.completion_create_params import CompletionCreateParamsStreaming
     from openai.types.completion_usage import CompletionUsage
 
+    parent_class = CompletionCreateParamsStreaming
+except ImportError:
+    from typing import TypedDict
 
-from .utils import (
-    BaseGenerateManager,
-    BaseHandler,
-    Modality,
-    _StreamError,
-    get_tool_call_config,
-    parse_tool_calls,
-)
+    # Inert stand-ins so module-level references still resolve: attribute
+    # reads return None and attribute writes land in the underlying dict.
+    class _DummyDict(dict):
+        def __getattr__(self, name):
+            return None
+
+        def __setattr__(self, name, value):
+            self[name] = value
 
-if TYPE_CHECKING:
-    from transformers import GenerationConfig, PreTrainedModel, PreTrainedTokenizerFast, ProcessorMixin
+    class ChatCompletion(_DummyDict):
+        pass
+
+    class ChatCompletionMessage(_DummyDict):
+        pass
+
+    class ChatCompletionMessageToolCall(_DummyDict):
+        pass
+
+    class Choice(_DummyDict):
+        pass
+
+    class ChatCompletionChunk(_DummyDict):
+        pass
+
+    class ChoiceDelta(_DummyDict):
+        pass
+
+    class ChoiceDeltaToolCall(_DummyDict):
+        pass
+
+    class ChoiceChunk(_DummyDict):
+        pass
+
+    class CompletionCreateParamsStreaming(_DummyDict):
+        pass
+
+    class CompletionUsage(_DummyDict):
+        pass
 
-class TransformersCompletionCreateParamsStreaming(CompletionCreateParamsStreaming, total=False):
+    parent_class = TypedDict
+
+
+class TransformersCompletionCreateParamsStreaming(parent_class, total=False):  # type: ignore
     generation_config: str
     seed: int
+# --- end import guard ---
+
+
+if TYPE_CHECKING:
+    from transformers import GenerationConfig, PreTrainedModel, PreTrainedTokenizerFast, ProcessorMixin
+
+
 # Fields accepted by the OpenAI schema but not yet supported.
 # Receiving these raises an error to avoid silent misbehaviour.
 # NOTE: "stop" is NOT in this set — we map it to stop_strings.
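Note on the fallback above: subclassing through the `parent_class` alias works in both branches because `TypedDict` (and the OpenAI param classes built on it) participates in class creation via `__mro_entries__`. A minimal standalone sketch, assuming only the standard library (`Params` is an illustrative name, not the module's own):

```python
# Minimal sketch of the aliased-base pattern used in the patch above.
from typing import TypedDict

try:
    from openai.types.chat.completion_create_params import CompletionCreateParamsStreaming

    parent_class = CompletionCreateParamsStreaming
except ImportError:
    parent_class = TypedDict


class Params(parent_class, total=False):  # type: ignore[misc]
    generation_config: str
    seed: int


# Caveat: request validation further down relies on `__mutable_keys__`, which
# plain `typing.TypedDict` only gained in Python 3.13 (PEP 705); the OpenAI
# classes get it from typing_extensions. On older interpreters the stub branch
# therefore yields an empty set here.
print(getattr(Params, "__mutable_keys__", set()))
```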
@@ -133,7 +170,7 @@ async def handle_request(self, body: dict, request_id: str) -> StreamingResponse
             **chat_template_kwargs,
         )
         if not use_cb:
-            inputs = inputs.to(model.device)  # type: ignore[union-attr]
+            inputs = inputs.to(model.device)  # type: ignore
 
         gen_config = self._build_generation_config(body, model.generation_config, use_cb=use_cb)
         # TODO: remove when CB supports per-request generation config
@@ -237,7 +274,10 @@ async def sse_gen() -> AsyncGenerator[str, None]:
                                     index=i,
                                     type="function",
                                     id=f"{request_id}_tool_call_{i}",
-                                    function={"name": tc["name"], "arguments": tc["arguments"]},
+                                    function={
+                                        "name": tc["name"],
+                                        "arguments": tc["arguments"],
+                                    },
                                 )
                             ],
                         )
@@ -328,7 +368,12 @@ async def _non_streaming(
 
     # ----- helpers -----
 
-    def _build_generation_config(self, body: dict, model_generation_config: "GenerationConfig", use_cb: bool = False):
+    def _build_generation_config(
+        self,
+        body: dict,
+        model_generation_config: "GenerationConfig",
+        use_cb: bool = False,
+    ):
         """Apply Chat Completions params (``max_tokens``, ``frequency_penalty``, ``logit_bias``, ``stop``) on top of
         the base generation config."""
         generation_config = super()._build_generation_config(body, model_generation_config, use_cb=use_cb)
diff --git a/src/transformers/cli/serving/completion.py b/src/transformers/cli/serving/completion.py
index 52c1f1b8471d..ed04fffb12a8 100644
--- a/src/transformers/cli/serving/completion.py
+++ b/src/transformers/cli/serving/completion.py
@@ -22,7 +22,7 @@
 import asyncio
 import time
 from collections.abc import AsyncGenerator
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, TypedDict
 
 from ...utils import logging
 from ...utils.import_utils import is_serve_available
@@ -34,7 +34,6 @@
     from openai.types import Completion, CompletionChoice, CompletionUsage
     from openai.types.completion_create_params import CompletionCreateParamsBase
 
-
 from .utils import BaseGenerateManager, BaseHandler, _StreamError
@@ -42,11 +41,21 @@
     from transformers import GenerationConfig, PreTrainedModel, PreTrainedTokenizerFast, ProcessorMixin
 
 
-class TransformersTextCompletionCreateParams(CompletionCreateParamsBase, total=False):
-    generation_config: str
-    seed: int
-    stream: bool
+# --- Schema fallback: subclass the OpenAI params when available ---
+if "CompletionCreateParamsBase" in globals():
+    # The real OpenAI class was imported above, so extend it
+    class TransformersTextCompletionCreateParams(CompletionCreateParamsBase, total=False):
+        generation_config: str
+        seed: int
+        stream: bool
+
+else:
+    # Fall back to a plain TypedDict when the OpenAI types are missing
+    class TransformersTextCompletionCreateParams(TypedDict, total=False):
+        generation_config: str
+        seed: int
+        stream: bool
+
+# --- end schema fallback ---
 
 
 # Fields accepted by the OpenAI schema but not yet supported.
 UNUSED_LEGACY_COMPLETION_FIELDS = {
@@ -109,10 +118,26 @@ async def handle_request(self, body: dict, request_id: str) -> "StreamingRespons
         streaming = body.get("stream")
 
         if streaming:
-            return self._streaming(request_id, model, processor, model_id, inputs, gen_config, gen_manager, suffix)
+            return self._streaming(
+                request_id,
+                model,
+                processor,
+                model_id,
+                inputs,
+                gen_config,
+                gen_manager,
+                suffix,
+            )
         else:
             return await self._non_streaming(
-                request_id, model, processor, model_id, inputs, gen_config, gen_manager, suffix
+                request_id,
+                model,
+                processor,
+                model_id,
+                inputs,
+                gen_config,
+                gen_manager,
+                suffix,
             )
 
     # ----- streaming -----
@@ -261,7 +286,12 @@ def _build_chunk_sse(
 
     # ----- generation config -----
 
-    def _build_generation_config(self, body: dict, model_generation_config: "GenerationConfig", use_cb: bool = False):
+    def _build_generation_config(
+        self,
+        body: dict,
+        model_generation_config: "GenerationConfig",
+        use_cb: bool = False,
+    ):
         """Apply legacy completion params (``max_tokens``, ``frequency_penalty``, ``stop``) on top of base config."""
         generation_config = super()._build_generation_config(body, model_generation_config, use_cb=use_cb)
diff --git a/src/transformers/cli/serving/model_manager.py b/src/transformers/cli/serving/model_manager.py
index 826199ee4b01..d718b99738b1 100644
--- a/src/transformers/cli/serving/model_manager.py
+++ b/src/transformers/cli/serving/model_manager.py
@@ -159,11 +159,20 @@ def _resolve_dtype(dtype: str | None):
         return resolved
 
     def _validate_args(self):
-        if self.quantization is not None and self.quantization not in ("bnb-4bit", "bnb-8bit"):
+        if self.quantization is not None and self.quantization not in (
+            "bnb-4bit",
+            "bnb-8bit",
+        ):
             raise ValueError(
                 f"Unsupported quantization method: '{self.quantization}'. Must be 'bnb-4bit' or 'bnb-8bit'."
             )
-        VALID_ATTN_IMPLEMENTATIONS = {"eager", "sdpa", "flash_attention_2", "flash_attention_3", "flex_attention"}
+        VALID_ATTN_IMPLEMENTATIONS = {
+            "eager",
+            "sdpa",
+            "flash_attention_2",
+            "flash_attention_3",
+            "flex_attention",
+        }
         is_kernels_community = self.attn_implementation is not None and self.attn_implementation.startswith(
             "kernels-community/"
         )
@@ -208,7 +217,10 @@ def _load_processor(self, model_id_and_revision: str) -> "ProcessorMixin | PreTr
         return AutoProcessor.from_pretrained(model_id, revision=revision, trust_remote_code=self.trust_remote_code)
 
     def _load_model(
-        self, model_id_and_revision: str, tqdm_class: type | None = None, progress_callback: Callable | None = None
+        self,
+        model_id_and_revision: str,
+        tqdm_class: type | None = None,
+        progress_callback: Callable | None = None,
     ) -> "PreTrainedModel":
         """Load a model.
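The `_validate_args` hunks above only reformat the checks; for reference, here is a standalone sketch of that validation logic. The quantization check and the `kernels-community/` prefix test mirror the diff; the final attention check is cut off by the hunk, so the version below is an assumption, and `validate_backend_args` is an illustrative name:

```python
# Standalone sketch of the backend-argument validation visible above.
VALID_ATTN_IMPLEMENTATIONS = {
    "eager",
    "sdpa",
    "flash_attention_2",
    "flash_attention_3",
    "flex_attention",
}


def validate_backend_args(quantization: str | None, attn_implementation: str | None) -> None:
    if quantization is not None and quantization not in ("bnb-4bit", "bnb-8bit"):
        raise ValueError(
            f"Unsupported quantization method: '{quantization}'. Must be 'bnb-4bit' or 'bnb-8bit'."
        )
    # Hub kernels referenced as "kernels-community/<repo>" bypass the fixed allow-list.
    is_kernels_community = attn_implementation is not None and attn_implementation.startswith(
        "kernels-community/"
    )
    # Assumed shape of the check that the hunk truncates:
    if (
        attn_implementation is not None
        and not is_kernels_community
        and attn_implementation not in VALID_ATTN_IMPLEMENTATIONS
    ):
        raise ValueError(f"Unsupported attention implementation: '{attn_implementation}'.")


validate_backend_args("bnb-4bit", "sdpa")  # passes
validate_backend_args(None, "kernels-community/flash-attn")  # passes via the prefix escape hatch
```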
@@ -270,10 +282,18 @@ def load_model_and_processor(
         if model_id_and_revision not in self.loaded_models:
             logger.warning(f"Loading {model_id_and_revision}")
             if progress_callback is not None:
-                progress_callback({"status": "loading", "model": model_id_and_revision, "stage": "processor"})
+                progress_callback(
+                    {
+                        "status": "loading",
+                        "model": model_id_and_revision,
+                        "stage": "processor",
+                    }
+                )
             processor = self._load_processor(model_id_and_revision)
             model = self._load_model(
-                model_id_and_revision, tqdm_class=tqdm_class, progress_callback=progress_callback
+                model_id_and_revision,
+                tqdm_class=tqdm_class,
+                progress_callback=progress_callback,
             )
             self.loaded_models[model_id_and_revision] = TimedModel(
                 model,
@@ -282,13 +302,25 @@
                 on_unload=lambda key=model_id_and_revision: self.loaded_models.pop(key, None),
             )
             if progress_callback is not None:
-                progress_callback({"status": "ready", "model": model_id_and_revision, "cached": False})
+                progress_callback(
+                    {
+                        "status": "ready",
+                        "model": model_id_and_revision,
+                        "cached": False,
+                    }
+                )
         else:
             self.loaded_models[model_id_and_revision].reset_timer()
             model = self.loaded_models[model_id_and_revision].model
             processor = self.loaded_models[model_id_and_revision].processor
             if progress_callback is not None:
-                progress_callback({"status": "ready", "model": model_id_and_revision, "cached": True})
+                progress_callback(
+                    {
+                        "status": "ready",
+                        "model": model_id_and_revision,
+                        "cached": True,
+                    }
+                )
         return model, processor
 
     async def load_model_streaming(self, model_id_and_revision: str):
@@ -384,7 +416,8 @@ def shutdown(self) -> None:
     @staticmethod
     def get_model_modality(
-        model: "PreTrainedModel", processor: "ProcessorMixin | PreTrainedTokenizerFast | None" = None
+        model: "PreTrainedModel",
+        processor: "ProcessorMixin | PreTrainedTokenizerFast | None" = None,
     ) -> Modality:
         """Detect whether a model is an LLM or VLM based on its architecture.
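The `progress_callback` payloads reformatted above form a small event protocol: one `{"status": "loading", ...}` event per stage, then a single `{"status": "ready", ...}` event whose `cached` flag says whether the model was already resident. A minimal consumer, assuming only the payload shapes visible in the hunk (`on_progress` is an illustrative name):

```python
# Minimal consumer for the progress_callback events emitted above.
def on_progress(event: dict) -> None:
    if event["status"] == "loading":
        print(f"loading {event['model']} (stage: {event['stage']})")
    elif event["status"] == "ready":
        source = "cache hit" if event["cached"] else "fresh load"
        print(f"{event['model']} ready ({source})")


# Exercising the handler with the three payloads the diff constructs:
on_progress({"status": "loading", "model": "org/model@main", "stage": "processor"})
on_progress({"status": "ready", "model": "org/model@main", "cached": False})
on_progress({"status": "ready", "model": "org/model@main", "cached": True})
```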
@@ -441,7 +474,10 @@ def get_gen_models(cache_dir: str | None = None) -> list[dict]:
             continue
 
         for ref, revision_info in repo.refs.items():
-            config_path = next((f.file_path for f in revision_info.files if f.file_name == "config.json"), None)
+            config_path = next(
+                (f.file_path for f in revision_info.files if f.file_name == "config.json"),
+                None,
+            )
             if not config_path:
                 continue
diff --git a/src/transformers/cli/serving/response.py b/src/transformers/cli/serving/response.py
index 4d29dfd1d6a2..f8e2491b5e34 100644
--- a/src/transformers/cli/serving/response.py
+++ b/src/transformers/cli/serving/response.py
@@ -20,7 +20,7 @@
 import asyncio
 import time
 from collections.abc import AsyncGenerator
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, TypedDict
 
 from ...utils import logging
 from ...utils.import_utils import is_serve_available
@@ -48,18 +48,16 @@
         ResponseTextDeltaEvent,
         ResponseTextDoneEvent,
     )
-    from openai.types.responses.response_create_params import ResponseCreateParamsStreaming
-    from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails, ResponseUsage
-
+    from openai.types.responses.response_create_params import (
+        ResponseCreateParamsStreaming,
+    )
+    from openai.types.responses.response_usage import (
+        InputTokensDetails,
+        OutputTokensDetails,
+        ResponseUsage,
+    )
 
-from .utils import (
-    BaseGenerateManager,
-    BaseHandler,
-    Modality,
-    _StreamError,
-    get_tool_call_config,
-    parse_tool_calls,
-)
+from .utils import BaseGenerateManager, BaseHandler, Modality, _StreamError, get_tool_call_config, parse_tool_calls
 
 
 if TYPE_CHECKING:
@@ -69,10 +67,21 @@
 logger = logging.get_logger(__name__)
 
 
-class TransformersResponseCreateParamsStreaming(ResponseCreateParamsStreaming, total=False):
-    generation_config: str
-    seed: int
+# --- Schema fallback: subclass the OpenAI params when available ---
+if "ResponseCreateParamsStreaming" in globals():
+
+    class TransformersResponseCreateParamsStreaming(ResponseCreateParamsStreaming, total=False):
+        generation_config: str
+        seed: int
+
+else:
+
+    class TransformersResponseCreateParamsStreaming(TypedDict, total=False):
+        generation_config: str
+        seed: int
+
+
+# --- end schema fallback ---
 
 UNUSED_RESPONSE_FIELDS = {
     "background",
@@ -192,7 +201,14 @@ def _normalize_tools(tools: list[dict] | None) -> list[dict] | None:
     if not tools:
         return tools
     return [
-        {"type": "function", "function": {k: v for k, v in t.items() if k != "type"}} if "function" not in t else t
+        (
+            {
+                "type": "function",
+                "function": {k: v for k, v in t.items() if k != "type"},
+            }
+            if "function" not in t
+            else t
+        )
         for t in tools
     ]
@@ -278,7 +294,10 @@ def _normalize_response_items(items: list[dict]) -> list[dict]:
             )
 
         else:
-            raise HTTPException(status_code=422, detail=f"Unsupported input item type: {item_type!r}")
+            raise HTTPException(
+                status_code=422,
+                detail=f"Unsupported input item type: {item_type!r}",
+            )
 
     return messages
@@ -402,7 +421,11 @@ async def event_stream() -> AsyncGenerator[str, None]:
                     logger.error(f"Exception in response generation: {text.msg}")
                     sse_parts.append(
                         self.chunk_to_sse(
-                            ResponseErrorEvent(type="error", sequence_number=seq, message=text.msg)
+                            ResponseErrorEvent(
+                                type="error",
+                                sequence_number=seq,
+                                message=text.msg,
+                            )
                         )
                     )
                     seq += 1
@@ -540,7 +563,12 @@ async def event_stream() -> AsyncGenerator[str, None]:
                 ResponseCompletedEvent(
                     type="response.completed",
                     sequence_number=seq,
-                    response=Response(**response_base, status="completed", output=all_output, usage=usage),
+                    response=Response(
+                        **response_base,
+                        status="completed",
+                        output=all_output,
+                        usage=usage,
+                    ),
                 )
             )
             seq += 1
@@ -616,7 +644,12 @@ async def _non_streaming(
 
     # ----- helpers -----
 
-    def _build_generation_config(self, body: dict, model_generation_config: "GenerationConfig", use_cb: bool = False):
+    def _build_generation_config(
+        self,
+        body: dict,
+        model_generation_config: "GenerationConfig",
+        use_cb: bool = False,
+    ):
         """Apply Responses API params (``max_output_tokens``) on top of the base generation config."""
         generation_config = super()._build_generation_config(body, model_generation_config, use_cb=use_cb)
diff --git a/src/transformers/cli/serving/server.py b/src/transformers/cli/serving/server.py
index 13a9565db590..64e276d5bb56 100644
--- a/src/transformers/cli/serving/server.py
+++ b/src/transformers/cli/serving/server.py
@@ -113,7 +113,8 @@ async def load_model(body: dict):
         raise HTTPException(status_code=422, detail="Missing `model` field in the request body.")
     model_id_and_revision = model_manager.process_model_name(model)
     return StreamingResponse(
-        model_manager.load_model_streaming(model_id_and_revision), media_type="text/event-stream"
+        model_manager.load_model_streaming(model_id_and_revision),
+        media_type="text/event-stream",
    )
 
 @app.post("/reset")
diff --git a/src/transformers/cli/serving/transcription.py b/src/transformers/cli/serving/transcription.py
index 5865dc77029f..fc853a1eb46b 100644
--- a/src/transformers/cli/serving/transcription.py
+++ b/src/transformers/cli/serving/transcription.py
@@ -16,7 +16,7 @@
 """
 import io
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, TypedDict
 
 from ...utils import logging
 from ...utils.import_utils import is_serve_available
@@ -25,7 +25,9 @@
 if is_serve_available():
     from fastapi import HTTPException, Request
     from fastapi.responses import JSONResponse, StreamingResponse
-    from openai.types.audio.transcription_create_params import TranscriptionCreateParamsBase
+    from openai.types.audio.transcription_create_params import (
+        TranscriptionCreateParamsBase,
+    )
 
     from .model_manager import ModelManager
     from .utils import DirectStreamer, GenerateManager, GenerationState, _StreamError
@@ -38,8 +40,21 @@
 logger = logging.get_logger(__name__)
 
 
-class TransformersTranscriptionCreateParams(TranscriptionCreateParamsBase, total=False):
-    stream: bool
+# --- Schema fallback: subclass the OpenAI params when available ---
+if "TranscriptionCreateParamsBase" in globals():
+
+    class TransformersTranscriptionCreateParams(TranscriptionCreateParamsBase, total=False):
+        stream: bool
+
+else:
+
+    class TransformersTranscriptionCreateParams(TypedDict, total=False):
+        stream: bool
+
+
+# --- end schema fallback ---
 
 
 UNUSED_TRANSCRIPTION_FIELDS = {
@@ -77,7 +92,10 @@ def _validate_request(self, form_keys: set[str]) -> None:
         """Validate transcription request fields."""
         unexpected = form_keys - getattr(TransformersTranscriptionCreateParams, "__mutable_keys__", set())
         if unexpected:
-            raise HTTPException(status_code=422, detail=f"Unexpected fields in the request: {unexpected}")
+            raise HTTPException(
+                status_code=422,
+                detail=f"Unexpected fields in the request: {unexpected}",
+            )
         unused = form_keys & UNUSED_TRANSCRIPTION_FIELDS
         if unused:
             logger.warning_once(f"Ignoring unsupported fields in the request: {unused}")
@@ -116,7 +134,10 @@ async def handle_request(self, request: Request) -> JSONResponse | StreamingResp
         audio_model, audio_processor = self.model_manager.load_model_and_processor(model_id_and_revision)
         base_manager = self.generation_state.get_manager(model_id_and_revision)
         if not isinstance(base_manager, GenerateManager):
-            raise HTTPException(status_code=400, detail="Audio transcription requires sequential generation (not CB)")
+            raise HTTPException(
+                status_code=400,
+                detail="Audio transcription requires sequential generation (not CB)",
+            )
         gen_manager = base_manager
 
         audio_inputs = self._prepare_audio_inputs(file_bytes, audio_processor, audio_model)
@@ -126,7 +147,9 @@ async def handle_request(self, request: Request) -> JSONResponse | StreamingResp
 
     @staticmethod
     def _prepare_audio_inputs(
-        file_bytes: bytes, audio_processor: "ProcessorMixin", audio_model: "PreTrainedModel"
+        file_bytes: bytes,
+        audio_processor: "ProcessorMixin",
+        audio_model: "PreTrainedModel",
     ) -> dict:
         """Load audio bytes and convert to model inputs."""
         import librosa
diff --git a/src/transformers/cli/serving/utils.py b/src/transformers/cli/serving/utils.py
index d786a828fc28..50f901060af2 100644
--- a/src/transformers/cli/serving/utils.py
+++ b/src/transformers/cli/serving/utils.py
@@ -108,7 +108,10 @@ def get_tool_call_config(processor, model: "PreTrainedModel") -> dict | None:
         schema = response_schema["properties"]["tool_calls"]
     else:
         # Fallback: known model families without full tokenizer config
-        fallback = next((v for k, v in _TOOL_CALL_FALLBACKS.items() if k in model.config.model_type), None)
+        fallback = next(
+            (v for k, v in _TOOL_CALL_FALLBACKS.items() if k in model.config.model_type),
+            None,
+        )
         if fallback is None:
             return None
         stc, etc, schema = fallback["stc"], fallback["etc"], fallback["schema"]
@@ -131,7 +134,7 @@ def _normalize_tool_call(tool_call: dict) -> dict:
     arguments = function.get("arguments", {})
     return {
         "name": function["name"],
-        "arguments": json.dumps(arguments) if not isinstance(arguments, str) else arguments,
+        "arguments": (json.dumps(arguments) if not isinstance(arguments, str) else arguments),
     }
@@ -153,7 +156,7 @@ def parse_tool_calls(processor, generated_ids, schema: dict) -> list[dict] | Non
     if not isinstance(parsed, list):
         parsed = [parsed]
     tool_calls = [_normalize_tool_call(tool_call) for tool_call in parsed]
-    return tool_calls if tool_calls else None
+    return tool_calls or None
 
 
 class DownloadAggregator:
@@ -552,7 +555,12 @@ def generate_streaming(
         # ProcessorMixin exposes the fast tokenizer as .tokenizer; PreTrainedTokenizerFast is already one.
         rust_tokenizer = getattr(processor, "tokenizer", processor)._tokenizer  # type: ignore[union-attr]
         streamer = DirectStreamer(rust_tokenizer, loop, queue, tool_config=tool_config)
-        gen_kwargs = {**inputs, "streamer": streamer, "generation_config": gen_config, "tokenizer": processor}
+        gen_kwargs = {
+            **inputs,
+            "streamer": streamer,
+            "generation_config": gen_config,
+            "tokenizer": processor,
+        }
         if hasattr(model, "has_talker"):
             gen_kwargs["generation_mode"] = "text"
@@ -578,7 +586,11 @@ async def generate_non_streaming(
         """Run generation to completion via ``model.generate()`` on the inference thread."""
         # Multimodal models (e.g. Qwen2.5-Omni) may generate audio alongside text by default;
         # force text-only output since the serve layer only handles text
-        generate_kwargs = {**inputs, "generation_config": gen_config, "tokenizer": processor}
+        generate_kwargs = {
+            **inputs,
+            "generation_config": gen_config,
+            "tokenizer": processor,
+        }
         if hasattr(model, "has_talker"):
             generate_kwargs["generation_mode"] = "text"
         sequences = await self.async_submit(model.generate, **generate_kwargs)
@@ -662,7 +674,14 @@ def generate_streaming(
         )
         # ProcessorMixin exposes the fast tokenizer as .tokenizer; PreTrainedTokenizerFast is already one.
         rust_tokenizer = getattr(processor, "tokenizer", processor)._tokenizer  # type: ignore[union-attr]
-        streamer = CBStreamer(self._cb, request_id, rust_tokenizer, loop, text_queue, tool_config=tool_config)
+        streamer = CBStreamer(
+            self._cb,
+            request_id,
+            rust_tokenizer,
+            loop,
+            text_queue,
+            tool_config=tool_config,
+        )
 
         # Register a direct callback: the dispatcher calls this on the event loop with each GenerationOutput.
         # This decodes tokens and pushes text straight to the SSE text_queue
@@ -838,7 +857,10 @@ def _validate_request(self, body: dict) -> None:
         if self._valid_params_class is not None:
             unexpected = input_keys - getattr(self._valid_params_class, "__mutable_keys__", set())
             if unexpected:
-                raise HTTPException(status_code=422, detail=f"Unexpected fields in the request: {unexpected}")
+                raise HTTPException(
+                    status_code=422,
+                    detail=f"Unexpected fields in the request: {unexpected}",
+                )
         unused = input_keys & self._unused_fields
         if unused:
             logger.warning_once(f"Ignoring unsupported fields in the request: {unused}")
@@ -872,7 +894,10 @@ def _resolve_model(self, body: dict) -> tuple[str, "PreTrainedModel", "Processor
         return model_id, model, processor
 
     def _build_generation_config(
-        self, body: dict, model_generation_config: "GenerationConfig", use_cb: bool = False
+        self,
+        body: dict,
+        model_generation_config: "GenerationConfig",
+        use_cb: bool = False,
     ) -> "GenerationConfig":
         """Build a GenerationConfig from shared params (temperature, top_p, seed, generation_config JSON).
@@ -959,7 +984,10 @@ def get_processor_inputs_from_messages(messages: list[dict], modality: Modality)
             if content_type in ("text", "input_text", "output_text"):
                 parsed["content"].append({"type": "text", "text": content["text"]})
             # Image: chat completions ("image_url") and Responses API ("input_image")
-            elif content_type in ("image_url", "input_image") and modality in (Modality.VLM, Modality.MULTIMODAL):
+            elif content_type in ("image_url", "input_image") and modality in (
+                Modality.VLM,
+                Modality.MULTIMODAL,
+            ):
                 # chat completions: {"image_url": {"url": "..."}}, Responses API: {"image_url": "..."}
                 url = content["image_url"]
                 if isinstance(url, dict):
@@ -972,7 +1000,10 @@ def get_processor_inputs_from_messages(messages: list[dict], modality: Modality)
                 audio_b64 = input_audio["data"]
                 parsed["content"].append({"type": "audio", "url": f"data:audio/{fmt};base64,{audio_b64}"})
             # Extensions (not part of the OpenAI API standard)
-            elif content_type == "video_url" and modality in (Modality.VLM, Modality.MULTIMODAL):
+            elif content_type == "video_url" and modality in (
+                Modality.VLM,
+                Modality.MULTIMODAL,
+            ):
                 parsed["content"].append({"type": "video", "url": content["video_url"]["url"]})
             elif content_type == "audio_url" and modality == Modality.MULTIMODAL:
                 parsed["content"].append({"type": "audio", "url": content["audio_url"]["url"]})
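Across transcription.py and utils.py the request validation keeps the same shape: subtract the params TypedDict's `__mutable_keys__` from the incoming keys, reject with a 422 if anything is left over, and only warn on known-but-unsupported fields. A standalone sketch of that pattern (ValueError stands in for `fastapi.HTTPException` so it runs without FastAPI; `Params` and `UNUSED_FIELDS` are illustrative):

```python
# Standalone sketch of the __mutable_keys__ validation used by _validate_request above.
# typing_extensions (>= 4.9) provides __mutable_keys__ on every Python version;
# plain typing.TypedDict only has it from 3.13 on.
from typing_extensions import TypedDict


class Params(TypedDict, total=False):
    temperature: float
    seed: int
    stream: bool


UNUSED_FIELDS = {"seed"}


def validate_request(body: dict) -> None:
    input_keys = set(body)
    unexpected = input_keys - getattr(Params, "__mutable_keys__", set())
    if unexpected:
        raise ValueError(f"Unexpected fields in the request: {unexpected}")  # 422 in the real code
    unused = input_keys & UNUSED_FIELDS
    if unused:
        print(f"Ignoring unsupported fields in the request: {unused}")  # logger.warning_once upstream


validate_request({"temperature": 0.7, "stream": True})  # accepted silently
validate_request({"seed": 1})  # warns, still accepted
# validate_request({"beam_width": 4})  # would raise: unexpected field
```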