From 750e7a4ab747ee4aa830bf63883a203fb41ea944 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Tue, 28 Apr 2026 17:00:11 +0000 Subject: [PATCH 1/2] reasoning --- src/transformers/cli/serve.py | 61 ++- .../cli/serving/chat_completion.py | 39 +- src/transformers/cli/serving/response.py | 350 +++++++++++++----- src/transformers/cli/serving/utils.py | 203 +++++++++- tests/cli/test_serve.py | 179 +++++++++ 5 files changed, 708 insertions(+), 124 deletions(-) diff --git a/src/transformers/cli/serve.py b/src/transformers/cli/serve.py index 4d9e6c712e7c..14f8901339c1 100644 --- a/src/transformers/cli/serve.py +++ b/src/transformers/cli/serve.py @@ -16,6 +16,8 @@ """ import asyncio +import enum +import json import threading from typing import Annotated @@ -30,6 +32,12 @@ logger = logging.get_logger(__name__) +class ReasoningMode(str, enum.Enum): + ON = "on" + OFF = "off" + AUTO = "auto" + + class Serve: def __init__( self, @@ -39,6 +47,32 @@ def __init__( bool, typer.Option(help="Enable continuous batching with paged attention. Configure with --cb-* flags."), ] = False, + attn_implementation: Annotated[ + str | None, typer.Option(help="Attention implementation (e.g. flash_attention_2).") + ] = None, + compile: Annotated[bool, typer.Option(help="Enable torch.compile for faster inference.")] = False, + quantization: Annotated[ + str | None, typer.Option(help="Quantization method: 'bnb-4bit' or 'bnb-8bit'.") + ] = None, + reasoning: Annotated[ + ReasoningMode, typer.Option(help="Reasoning mode. 'auto' uses the chat template default.") + ] = ReasoningMode.AUTO, + chat_template_kwargs: Annotated[ + str | None, + typer.Option( + help=( + "Default JSON kwargs forwarded to apply_chat_template " + "(e.g. '{\"enable_thinking\": true}'); per-request chat_template_kwargs override these." + ) + ), + ] = None, + device: Annotated[str, typer.Option(help="Device for inference (e.g. 'auto', 'cuda:0', 'cpu').")] = "auto", + dtype: Annotated[str | None, typer.Option(help="Override model dtype. 'auto' derives from weights.")] = "auto", + trust_remote_code: Annotated[bool, typer.Option(help="Trust remote code when loading.")] = False, + model_timeout: Annotated[ + int, typer.Option(help="Seconds before idle model is unloaded. Ignored when force_model is set.") + ] = 300, + # Continuous batching tuning cb_block_size: Annotated[ int | None, typer.Option(help="KV cache block size in tokens for continuous batching.") ] = None, @@ -54,19 +88,6 @@ def __init__( cb_use_cuda_graph: Annotated[ bool | None, typer.Option(help="Enable CUDA graphs for continuous batching.") ] = None, - attn_implementation: Annotated[ - str | None, typer.Option(help="Attention implementation (e.g. flash_attention_2).") - ] = None, - compile: Annotated[bool, typer.Option(help="Enable torch.compile for faster inference.")] = False, - quantization: Annotated[ - str | None, typer.Option(help="Quantization method: 'bnb-4bit' or 'bnb-8bit'.") - ] = None, - device: Annotated[str, typer.Option(help="Device for inference (e.g. 'auto', 'cuda:0', 'cpu').")] = "auto", - dtype: Annotated[str | None, typer.Option(help="Override model dtype. 'auto' derives from weights.")] = "auto", - trust_remote_code: Annotated[bool, typer.Option(help="Trust remote code when loading.")] = False, - model_timeout: Annotated[ - int, typer.Option(help="Seconds before idle model is unloaded. 
Ignored when force_model is set.") - ] = 300, # Server options host: Annotated[str, typer.Option(help="Server listen address.")] = "localhost", port: Annotated[int, typer.Option(help="Server listen port.")] = 8000, @@ -126,14 +147,28 @@ def __init__( cb_config=cb_config, ) + if chat_template_kwargs: + chat_template_kwargs = json.loads(chat_template_kwargs) + if not isinstance(chat_template_kwargs, dict): + raise typer.BadParameter("--chat-template-kwargs must be a JSON object") + else: + chat_template_kwargs = {} + + if reasoning == ReasoningMode.ON: + chat_template_kwargs["enable_thinking"] = True + elif reasoning == ReasoningMode.OFF: + chat_template_kwargs["enable_thinking"] = False + self._chat_handler = ChatCompletionHandler( model_manager=self._model_manager, generation_state=self._generation_state, + chat_template_kwargs=chat_template_kwargs, ) self._response_handler = ResponseHandler( model_manager=self._model_manager, generation_state=self._generation_state, + chat_template_kwargs=chat_template_kwargs, ) self._transcription_handler = TranscriptionHandler(self._model_manager, self._generation_state) diff --git a/src/transformers/cli/serving/chat_completion.py b/src/transformers/cli/serving/chat_completion.py index 161a25a02f41..7444f28c1a5c 100644 --- a/src/transformers/cli/serving/chat_completion.py +++ b/src/transformers/cli/serving/chat_completion.py @@ -40,8 +40,11 @@ BaseGenerateManager, BaseHandler, Modality, + ReasoningText, _StreamError, + get_reasoning_config, get_tool_call_config, + parse_reasoning, parse_tool_calls, ) @@ -53,6 +56,7 @@ class TransformersCompletionCreateParamsStreaming(CompletionCreateParamsStreaming, total=False): generation_config: str seed: int + chat_template_kwargs: dict # Fields accepted by the OpenAI schema but not yet supported. @@ -118,10 +122,13 @@ async def handle_request(self, body: dict, request_id: str) -> StreamingResponse for msg in processor_inputs for c in (msg.get("content") if isinstance(msg.get("content"), list) else []) ) - # Default to 32 frames for video (Gemma 4 default); some processors load all frames otherwise - chat_template_kwargs = {} + # Default to 32 frames for video (Gemma 4 default); some processors load all frames otherwise. + # Merge order (later wins): custom default -> server default → request-level kwargs. 
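+        # For example (hypothetical values): with a server default of {"enable_thinking": True}
+        # and a request carrying {"enable_thinking": False}, the request value wins, so
+        # thinking is disabled for that call only.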
+ chat_template_kwargs: dict = {} if has_video: chat_template_kwargs["num_frames"] = 32 + chat_template_kwargs.update(self.chat_template_kwargs) + chat_template_kwargs.update(body.get("chat_template_kwargs", {})) inputs = processor.apply_chat_template( processor_inputs, add_generation_prompt=True, @@ -141,6 +148,7 @@ async def handle_request(self, body: dict, request_id: str) -> StreamingResponse gen_manager.init_cb(model, gen_config) tool_config = get_tool_call_config(processor, model) if body.get("tools") else None + reasoning_config = get_reasoning_config(processor, model, inputs["input_ids"]) streaming = body.get("stream") if streaming: @@ -153,6 +161,7 @@ async def handle_request(self, body: dict, request_id: str) -> StreamingResponse gen_config, gen_manager=gen_manager, tool_config=tool_config, + reasoning_config=reasoning_config, ) else: return await self._non_streaming( @@ -164,6 +173,7 @@ async def handle_request(self, body: dict, request_id: str) -> StreamingResponse gen_config, gen_manager=gen_manager, tool_config=tool_config, + reasoning_config=reasoning_config, ) # ----- streaming ----- @@ -178,6 +188,7 @@ def _streaming( gen_config: "GenerationConfig", gen_manager: BaseGenerateManager, tool_config: dict | None = None, + reasoning_config: dict | None = None, ) -> StreamingResponse: """Stream tokens as SSE via DirectStreamer.""" queue, streamer = gen_manager.generate_streaming( @@ -187,6 +198,7 @@ def _streaming( gen_config, request_id=request_id, tool_config=tool_config, + reasoning_config=reasoning_config, ) input_ids = inputs["input_ids"] # CB returns plain lists, regular path returns tensors @@ -216,7 +228,10 @@ async def sse_gen() -> AsyncGenerator[str, None]: yield "".join(sse_parts) return - sse_parts.append(self._build_chunk_sse(request_id, model=model_id, content=text)) + if isinstance(text, ReasoningText): + sse_parts.append(self._build_chunk_sse(request_id, model=model_id, reasoning_content=text)) + else: + sse_parts.append(self._build_chunk_sse(request_id, model=model_id, content=text)) if sse_parts: yield "".join(sse_parts) @@ -280,6 +295,7 @@ async def _non_streaming( gen_config: "GenerationConfig", gen_manager: BaseGenerateManager, tool_config: dict | None = None, + reasoning_config: dict | None = None, ) -> JSONResponse: """Run generation and return a JSONResponse.""" content, input_len, generated_ids = await gen_manager.generate_non_streaming( @@ -307,6 +323,10 @@ async def _non_streaming( for i, tc in enumerate(parsed) ] + reasoning_content = None + if reasoning_config is not None: + content, reasoning_content = parse_reasoning(processor, generated_ids, content, reasoning_config) + if tool_calls is not None: finish_reason = "tool_calls" elif hit_max: @@ -322,6 +342,7 @@ async def _non_streaming( finish_reason=finish_reason, usage=usage, tool_calls=tool_calls, + reasoning_content=reasoning_content, ), media_type="application/json", ) @@ -354,6 +375,7 @@ def _build_completion( finish_reason: str, usage: CompletionUsage | None = None, tool_calls: list[dict] | None = None, + reasoning_content: str | None = None, ) -> dict: """Build a non-streaming ChatCompletion response dict. @@ -364,11 +386,14 @@ def _build_completion( finish_reason (`str`): Why generation stopped (``"stop"``, ``"length"``, ``"tool_calls"``). usage (`CompletionUsage`, *optional*): Token usage statistics. tool_calls (`list[dict]`, *optional*): Parsed tool calls, if any. + reasoning_content (`str`, *optional*): Chain-of-thought content extracted from the response. 
Returns: `dict`: Serialized ``ChatCompletion`` ready for JSON response. """ - message = ChatCompletionMessage(content=content, role="assistant", tool_calls=tool_calls) + message = ChatCompletionMessage( + content=content, role="assistant", tool_calls=tool_calls, reasoning_content=reasoning_content + ) result = ChatCompletion( id=request_id, created=int(time.time()), @@ -394,6 +419,7 @@ def _build_chunk_sse( finish_reason: str | None = None, tool_calls: list | None = None, usage: CompletionUsage | None = None, + reasoning_content: str | None = None, ) -> str: """Build a streaming ``ChatCompletionChunk`` and format it as an SSE ``data:`` line. @@ -405,6 +431,7 @@ def _build_chunk_sse( finish_reason (`str`, *optional*): Set on the final chunk. tool_calls (`list`, *optional*): Tool call deltas. usage (`CompletionUsage`, *optional*): Token usage (sent with the final chunk). + reasoning_content (`str`, *optional*): Reasoning/thinking delta (OpenAI-compatible extension). Returns: `str`: A formatted SSE event string. @@ -415,7 +442,9 @@ def _build_chunk_sse( model=model, choices=[ ChoiceChunk( - delta=ChoiceDelta(content=content, role=role, tool_calls=tool_calls), + delta=ChoiceDelta( + content=content, role=role, tool_calls=tool_calls, reasoning_content=reasoning_content + ), index=0, finish_reason=finish_reason, ) diff --git a/src/transformers/cli/serving/response.py b/src/transformers/cli/serving/response.py index 4d29dfd1d6a2..002391c2479a 100644 --- a/src/transformers/cli/serving/response.py +++ b/src/transformers/cli/serving/response.py @@ -45,6 +45,9 @@ ResponseOutputItemDoneEvent, ResponseOutputMessage, ResponseOutputText, + ResponseReasoningItem, + ResponseReasoningTextDeltaEvent, + ResponseReasoningTextDoneEvent, ResponseTextDeltaEvent, ResponseTextDoneEvent, ) @@ -56,8 +59,11 @@ BaseGenerateManager, BaseHandler, Modality, + ReasoningText, _StreamError, + get_reasoning_config, get_tool_call_config, + parse_reasoning, parse_tool_calls, ) @@ -80,7 +86,6 @@ class TransformersResponseCreateParamsStreaming(ResponseCreateParamsStreaming, t "max_tool_calls", "previous_response_id", "prompt", - "reasoning", "service_tier", "store", "text", @@ -127,10 +132,13 @@ async def handle_request(self, body: dict, request_id: str) -> StreamingResponse for c in (msg.get("content") if isinstance(msg.get("content"), list) else []) ) - # Default to 32 frames for video (Gemma 4 default); some processors load all frames otherwise - chat_template_kwargs = {} + # Default to 32 frames for video (Gemma 4 default); some processors load all frames otherwise. + # Merge order (later wins): custom default -> server default → request-level kwargs. + chat_template_kwargs: dict = {} if has_video: chat_template_kwargs["num_frames"] = 32 + chat_template_kwargs.update(self.chat_template_kwargs) + chat_template_kwargs.update(body.get("chat_template_kwargs") or {}) # updates the flat tool structure to the one expected by the `apply_chat_template` method. 
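+        # e.g. a flat Responses-style tool (hypothetical shape):
+        #     {"type": "function", "name": "get_weather", "parameters": {...}}
+        # is rewrapped into the nested form chat templates expect:
+        #     {"type": "function", "function": {"name": "get_weather", "parameters": {...}}}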
tools = self._normalize_tools(body.get("tools")) inputs = processor.apply_chat_template( @@ -151,6 +159,7 @@ async def handle_request(self, body: dict, request_id: str) -> StreamingResponse if use_cb: gen_manager.init_cb(model, gen_config) tool_config = get_tool_call_config(processor, model) if body.get("tools") else None + reasoning_config = get_reasoning_config(processor, model, inputs["input_ids"]) streaming = body.get("stream", True) if streaming: @@ -164,6 +173,7 @@ async def handle_request(self, body: dict, request_id: str) -> StreamingResponse gen_config, gen_manager=gen_manager, tool_config=tool_config, + reasoning_config=reasoning_config, ) else: return await self._non_streaming( @@ -176,6 +186,7 @@ async def handle_request(self, body: dict, request_id: str) -> StreamingResponse gen_config, gen_manager=gen_manager, tool_config=tool_config, + reasoning_config=reasoning_config, ) # ----- input conversion ----- @@ -247,16 +258,26 @@ def _normalize_response_items(items: list[dict]) -> list[dict]: Input items may be a mix of: - Messages (``EasyInputMessageParam`` with ``role``, or ``type: "message"``). + - ``reasoning`` — buffered and attached as ``reasoning_content`` to the next assistant message. - ``function_call`` — merged as ``tool_calls`` onto the preceding assistant message. - ``function_call_output`` — converted to ``role: "tool"`` messages. """ messages = [] + pending_reasoning: str | None = None for item in items: item_type = item.get("type") + if item_type == "reasoning": + pending_reasoning = "".join(c["text"] for c in item.get("content") or []) + continue + if "role" in item: - messages.append({"role": item["role"], "content": item.get("content", "")}) + msg = {"role": item["role"], "content": item.get("content", "")} + if pending_reasoning is not None and item["role"] == "assistant": + msg["reasoning_content"] = pending_reasoning + pending_reasoning = None + messages.append(msg) elif item_type == "function_call": tc = { @@ -295,6 +316,7 @@ def _streaming( gen_config: "GenerationConfig", gen_manager: BaseGenerateManager, tool_config: dict | None = None, + reasoning_config: dict | None = None, ) -> StreamingResponse: """Generate a streaming Responses API reply (SSE) using DirectStreamer.""" queue, streamer = gen_manager.generate_streaming( @@ -304,16 +326,17 @@ def _streaming( gen_config, request_id=request_id, tool_config=tool_config, + reasoning_config=reasoning_config, ) input_ids = inputs["input_ids"] # CB returns plain lists, regular path returns tensors input_len = len(input_ids) if isinstance(input_ids, list) else input_ids.shape[-1] seq = 0 - output_index = 0 created_at = time.time() resp_id = f"resp_{request_id}" msg_id = f"msg_{request_id}" + reasoning_id = f"rs_{request_id}" response_base = { "id": resp_id, @@ -327,7 +350,7 @@ def _streaming( } async def event_stream() -> AsyncGenerator[str, None]: - nonlocal seq, output_index + nonlocal seq try: # 1. Created + In progress @@ -348,44 +371,163 @@ async def event_stream() -> AsyncGenerator[str, None]: ) seq += 1 - # 2. Output item added (message) - yield self.chunk_to_sse( - ResponseOutputItemAddedEvent( - type="response.output_item.added", - sequence_number=seq, - output_index=output_index, - item=ResponseOutputMessage( - id=msg_id, - type="message", - status="in_progress", - role="assistant", - content=[], - ), - ) - ) - seq += 1 - - # 3. 
Content part added - yield self.chunk_to_sse( - ResponseContentPartAddedEvent( - type="response.content_part.added", - item_id=msg_id, - sequence_number=seq, - output_index=output_index, - content_index=0, - part=ResponseOutputText(type="output_text", text="", annotations=[]), - ) - ) - seq += 1 - - # 4. Stream tokens — drain queue to batch HTTP writes + # 2. Stream tokens — items are opened lazily so reasoning (if any) + # appears as a separate output item before the message item. full_text = "" + full_reasoning = "" tool_calls = [] + output_index = 0 + reasoning_open = False + message_open = False + reasoning_item = None + message_item = None done = False + def open_reasoning() -> str: + """Emit ``output_item.added`` for an in-progress reasoning item.""" + nonlocal seq, reasoning_open + reasoning_open = True + sse = self.chunk_to_sse( + ResponseOutputItemAddedEvent( + type="response.output_item.added", + sequence_number=seq, + output_index=output_index, + item=ResponseReasoningItem( + id=reasoning_id, type="reasoning", summary=[], content=[], status="in_progress" + ), + ) + ) + seq += 1 + return sse + + def close_reasoning() -> str: + """Emit ``reasoning_text.done`` + ``output_item.done`` for the completed reasoning item.""" + nonlocal seq, reasoning_open, reasoning_item + reasoning_item = ResponseReasoningItem( + id=reasoning_id, + type="reasoning", + summary=[], + content=[{"type": "reasoning_text", "text": full_reasoning}], + status="completed", + ) + parts = [ + self.chunk_to_sse( + ResponseReasoningTextDoneEvent( + type="response.reasoning_text.done", + item_id=reasoning_id, + sequence_number=seq, + output_index=output_index, + content_index=0, + text=full_reasoning, + ) + ) + ] + seq += 1 + parts.append( + self.chunk_to_sse( + ResponseOutputItemDoneEvent( + type="response.output_item.done", + sequence_number=seq, + output_index=output_index, + item=reasoning_item, + ) + ) + ) + seq += 1 + reasoning_open = False + return "".join(parts) + + def open_message() -> str: + """Emit ``output_item.added`` + ``content_part.added`` for an in-progress message.""" + nonlocal seq, message_open + message_open = True + parts = [ + self.chunk_to_sse( + ResponseOutputItemAddedEvent( + type="response.output_item.added", + sequence_number=seq, + output_index=output_index, + item=ResponseOutputMessage( + id=msg_id, + type="message", + status="in_progress", + role="assistant", + content=[], + ), + ) + ) + ] + seq += 1 + parts.append( + self.chunk_to_sse( + ResponseContentPartAddedEvent( + type="response.content_part.added", + item_id=msg_id, + sequence_number=seq, + output_index=output_index, + content_index=0, + part=ResponseOutputText(type="output_text", text="", annotations=[]), + ) + ) + ) + seq += 1 + return "".join(parts) + + def close_message() -> str: + """Emit ``output_text.done`` + ``content_part.done`` + ``output_item.done`` for the message.""" + nonlocal seq, message_open, message_item + output_text_part = ResponseOutputText(type="output_text", text=full_text, annotations=[]) + message_item = ResponseOutputMessage( + id=msg_id, + type="message", + status="completed", + role="assistant", + content=[output_text_part], + annotations=[], # type: ignore[call-arg] + ) + parts = [ + self.chunk_to_sse( + ResponseTextDoneEvent( + type="response.output_text.done", + item_id=msg_id, + sequence_number=seq, + output_index=output_index, + content_index=0, + text=full_text, + logprobs=[], + ) + ) + ] + seq += 1 + parts.append( + self.chunk_to_sse( + ResponseContentPartDoneEvent( + 
type="response.content_part.done", + item_id=msg_id, + sequence_number=seq, + output_index=output_index, + content_index=0, + part=output_text_part, + ) + ) + ) + seq += 1 + parts.append( + self.chunk_to_sse( + ResponseOutputItemDoneEvent( + type="response.output_item.done", + sequence_number=seq, + output_index=output_index, + item=message_item, + ) + ) + ) + seq += 1 + message_open = False + return "".join(parts) + while not done: text = await queue.get() - # Drain all available tokens for one batched HTTP write batch = [text] try: while True: @@ -423,26 +565,59 @@ async def event_stream() -> AsyncGenerator[str, None]: yield "".join(sse_parts) return - full_text += text - sse_parts.append( - self.chunk_to_sse( - ResponseTextDeltaEvent( - type="response.output_text.delta", - item_id=msg_id, - sequence_number=seq, - output_index=0, - content_index=0, - delta=text, - logprobs=[], + if isinstance(text, ReasoningText): + if not reasoning_open: + sse_parts.append(open_reasoning()) + full_reasoning += text + sse_parts.append( + self.chunk_to_sse( + ResponseReasoningTextDeltaEvent( + type="response.reasoning_text.delta", + item_id=reasoning_id, + sequence_number=seq, + output_index=output_index, + content_index=0, + delta=text, + ) ) ) - ) - seq += 1 + seq += 1 + else: + if reasoning_open: + sse_parts.append(close_reasoning()) + output_index += 1 + if not message_open: + sse_parts.append(open_message()) + full_text += text + sse_parts.append( + self.chunk_to_sse( + ResponseTextDeltaEvent( + type="response.output_text.delta", + item_id=msg_id, + sequence_number=seq, + output_index=output_index, + content_index=0, + delta=text, + logprobs=[], + ) + ) + ) + seq += 1 if sse_parts: yield "".join(sse_parts) - # 5. Tool calls are parsed after generation completes (not during streaming), + # Close any open reasoning section that didn't transition to content. + if reasoning_open: + yield close_reasoning() + output_index += 1 + + # Close message section (open it first if no content was emitted). + if not message_open: + yield open_message() + yield close_message() + + # 3. Tool calls are parsed after generation completes (not during streaming), # because the full token sequence is needed for reliable parsing. if tool_config: parsed = parse_tool_calls(processor, streamer.generated_token_ids, tool_config["schema"]) @@ -489,52 +664,12 @@ async def event_stream() -> AsyncGenerator[str, None]: ) seq += 1 - # 6. Close text output - output_text_part = ResponseOutputText(type="output_text", text=full_text, annotations=[]) - yield self.chunk_to_sse( - ResponseTextDoneEvent( - type="response.output_text.done", - item_id=msg_id, - sequence_number=seq, - output_index=0, - content_index=0, - text=full_text, - logprobs=[], - ) - ) - seq += 1 - yield self.chunk_to_sse( - ResponseContentPartDoneEvent( - type="response.content_part.done", - item_id=msg_id, - sequence_number=seq, - output_index=0, - content_index=0, - part=output_text_part, - ) - ) - seq += 1 - - msg_item = ResponseOutputMessage( - id=msg_id, - type="message", - status="completed", - role="assistant", - content=[output_text_part], - annotations=[], # type: ignore[call-arg] - ) - yield self.chunk_to_sse( - ResponseOutputItemDoneEvent( - type="response.output_item.done", - sequence_number=seq, - output_index=0, - item=msg_item, - ) - ) - seq += 1 - - # 7. Completed - all_output = [msg_item] + list(tool_calls) + # 4. 
Completed
+            all_output = []
+            if reasoning_item is not None:
+                all_output.append(reasoning_item)
+            all_output.append(message_item)
+            all_output.extend(tool_calls)
             usage = compute_usage(input_len, streamer.total_tokens)
             yield self.chunk_to_sse(
                 ResponseCompletedEvent(
@@ -565,13 +700,28 @@ async def _non_streaming(
         gen_config: "GenerationConfig",
         gen_manager: BaseGenerateManager,
         tool_config: dict | None = None,
+        reasoning_config: dict | None = None,
     ) -> JSONResponse:
         """Generate a non-streaming Responses API reply (single JSON)."""
         full_text, input_len, generated_ids = await gen_manager.generate_non_streaming(
             model, processor, inputs, gen_config, request_id=request_id
         )

-        output_items = [
+        output_items = []
+        if reasoning_config is not None:
+            full_text, reasoning_content = parse_reasoning(processor, generated_ids, full_text, reasoning_config)
+            if reasoning_content is not None:
+                output_items.append(
+                    ResponseReasoningItem(
+                        id=f"rs_{request_id}",
+                        type="reasoning",
+                        summary=[],
+                        content=[{"type": "reasoning_text", "text": reasoning_content}],
+                        status="completed",
+                    )
+                )
+
+        output_items.append(
             ResponseOutputMessage(
                 id=f"msg_{request_id}",
                 type="message",
@@ -580,7 +730,7 @@ async def _non_streaming(
                 content=[ResponseOutputText(type="output_text", text=full_text, annotations=[])],
                 annotations=[],  # type: ignore[call-arg]
             )
-        ]
+        )

         if tool_config is not None:
             parsed = parse_tool_calls(processor, generated_ids, tool_config["schema"])
diff --git a/src/transformers/cli/serving/utils.py b/src/transformers/cli/serving/utils.py
index d786a828fc28..d8d1178e3a8c 100644
--- a/src/transformers/cli/serving/utils.py
+++ b/src/transformers/cli/serving/utils.py
@@ -19,6 +19,7 @@
 import copy
 import enum
 import json
+import re
 import threading
 from abc import ABC, abstractmethod
 from collections.abc import Callable
@@ -73,6 +74,14 @@ class _GenerationCancelled(Exception):
     """Raised inside ``DirectStreamer.put()`` to abort ``model.generate()``."""


+class ReasoningText(str):
+    """Tagged str subclass: text chunk belonging to a thinking/reasoning block.
+
+    Streamers wrap reasoning text with this so handlers can route it to
+    ``reasoning_content`` deltas instead of ``content``.
+    """
+
+
 # Fallback tool call configs for models that don't declare stc_token/etc_token/response_schema
 # on their tokenizer.
 # Keys are matched via substring against model_type (e.g. "qwen" matches "qwen2", "qwen3_vl", etc.).
@@ -156,6 +165,109 @@ def parse_tool_calls(processor, generated_ids, schema: dict) -> list[dict] | Non
     return tool_calls if tool_calls else None


+# Default start/end tokens + schema. The opening token is optional so prefilled
+# ``<think>`` prompts still match.
+_DEFAULT_THINKING_TOKENS = {
+    "start": ["<think>"],
+    "end": "</think>",
+    "schema": {
+        "type": "object",
+        "properties": {
+            "thinking": {"type": "string"},
+            "content": {"type": "string"},
+        },
+        "x-regex": r"\s*(?:<think>)?(?P<thinking>.*?)</think>\s*(?P<content>.*)",
+    },
+}
+# Streaming-side token IDs for families whose ``response_schema`` uses non-default
+# start/end tokens. Post-hoc parsing uses the schema; this only feeds the
+# streamer's token-level detector.
+_THINKING_TOKENS = {
+    "gemma4": {"start": ["<|channel>", "thought"], "end": ""},
+}
+
+
+def get_reasoning_config(processor, model: "PreTrainedModel", input_ids=None) -> dict | None:
+    """Return reasoning config for the model, or ``None`` if not supported.
+
+    The config drives both streaming detection (token IDs) and post-hoc parsing
+    (response schema).
+    Returns a dict with:
+
+    - ``start_ids`` (`list[int]`): Token ID sequence that opens a thinking block.
+    - ``end_id`` (`int`): Token ID that closes the block.
+    - ``schema`` (`dict`): Response schema with ``thinking`` / ``content``
+      properties for :func:`parse_reasoning`.
+    - ``start_in_thinking`` (`bool`, only when ``input_ids`` is given): Whether
+      the rendered prompt already opened an unclosed thinking block (prefilled
+      by the template), so the model's output begins inside the block.
+    """
+    tokenizer = getattr(processor, "tokenizer", processor)
+    model_type = model.config.model_type.lower()
+    thinking_tokens = next(
+        (v for k, v in _THINKING_TOKENS.items() if k in model_type),
+        _DEFAULT_THINKING_TOKENS,
+    )
+    start_ids = [tokenizer.convert_tokens_to_ids(t) for t in thinking_tokens["start"]]
+    end_id = tokenizer.convert_tokens_to_ids(thinking_tokens["end"])
+    if any(tid in (None, tokenizer.unk_token_id) for tid in start_ids) or end_id in (None, tokenizer.unk_token_id):
+        return None
+    # Custom-token families (e.g. Gemma 4) provide their schema via the tokenizer;
+    # default ``<think>`` falls back to the schema baked into ``_DEFAULT_THINKING_TOKENS``.
+    schema = getattr(tokenizer, "response_schema", None)
+    if not (schema and "thinking" in schema["properties"]):
+        schema = _DEFAULT_THINKING_TOKENS["schema"]
+    config: dict = {"start_ids": start_ids, "end_id": end_id, "schema": schema}
+    if input_ids is not None:
+        config["start_in_thinking"] = _starts_in_thinking(input_ids, start_ids)
+    return config
+
+
+def parse_reasoning(processor, generated_ids, content: str, reasoning_config: dict) -> tuple[str, str | None]:
+    """Split generated output into ``(content, reasoning_content)`` via ``parse_response``.
+
+    If the schema's regex matches (closing marker present), use it. For prompts
+    that prefill the opener (QwQ-32B, DeepSeek-R1) the entire output is reasoning
+    until ``</think>`` arrives — when that's truncated, fall back to treating
+    all decoded text as reasoning. Returns ``(content, None)`` otherwise.
+    """
+    parsed = processor.parse_response(generated_ids, reasoning_config["schema"])
+    if parsed:
+        reasoning = parsed.get("thinking", "").strip()
+        if reasoning:
+            return parsed.get("content", ""), reasoning
+    # Prefilled opener (QwQ-32B, DeepSeek-R1) truncated before ``</think>`` —
+    # no anchor for the schema regex; treat all output as reasoning.
+    if reasoning_config.get("start_in_thinking"):
+        return "", content.strip()
+    return content, None
+
+
+def _starts_in_thinking(input_ids, start_ids: list[int]) -> bool:
+    """True if the rendered prompt ends with an unclosed thinking block.
+
+    Some reasoning-model chat templates prefill the thinking opener as the final
+    prompt tokens (e.g. DeepSeek-R1, QwQ-32B emit ``<think>\\n`` at the end when
+    ``add_generation_prompt=True``). In those cases the model resumes *inside*
+    the block, so its output contains only ``...reasoning</think>answer`` with
+    no opening tag — the streamer must start with ``_inside_thinking=True``.
+
+    The prefill always lands at the tail of the prompt (optionally followed by a
+    single whitespace token like ``\\n``), so we only inspect the last few tokens.
+    """
+    if hasattr(input_ids, "tolist"):
+        input_ids = input_ids.tolist()
+    if input_ids and isinstance(input_ids[0], list):
+        if len(input_ids) != 1:
+            return False
+        input_ids = input_ids[0]
+    n = len(start_ids)
+    # Match start_ids at the tail, allowing up to one trailing token (e.g. "\n").
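+    # Worked example (hypothetical IDs): with start_ids=[101] for "<think>", a prompt
+    # ending [..., 101] matches at trailing=0, and one ending [..., 101, 198] (a
+    # trailing "\n" token) matches at trailing=1; both report an open thinking block.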
+ for trailing in (0, 1): + if len(input_ids) >= n + trailing: + end = len(input_ids) - trailing + if input_ids[end - n : end] == start_ids: + return True + return False + class DownloadAggregator: """Aggregates byte-progress across multiple concurrent download tqdm bars. @@ -286,6 +398,7 @@ def __init__( queue: asyncio.Queue, skip_special_tokens: bool = True, tool_config: dict | None = None, + reasoning_config: dict | None = None, ): """ Args: @@ -297,6 +410,9 @@ def __init__( tool_config (`dict`, *optional*): Tool call config from ``get_tool_call_config``. When set, tokens between stc/etc delimiters (inclusive) are suppressed from the queue so tool call markup is never streamed to the client. + reasoning_config (`dict`, *optional*): Thinking config from ``get_reasoning_config``. + When set, tokens between start/end delimiters are wrapped as + :class:`ReasoningText` so handlers route them to ``reasoning_content``. """ from tokenizers.decoders import DecodeStream @@ -307,6 +423,10 @@ def __init__( self._stc_id = tool_config["stc_id"] if tool_config else None self._etc_id = tool_config["etc_id"] if tool_config else None self._inside_tool_call = False + self._thinking_start_ids = reasoning_config["start_ids"] if reasoning_config else None + self._thinking_end_id = reasoning_config["end_id"] if reasoning_config else None + self._inside_thinking = bool(reasoning_config and reasoning_config.get("start_in_thinking")) + self._thinking_prefix: list[int] = [] self._first = True self._cancelled = threading.Event() self.total_tokens = 0 @@ -329,9 +449,33 @@ def put(self, value: "torch.Tensor") -> None: elif token_id == self._etc_id: self._inside_tool_call = False + is_start_or_end_token = self._advance_thinking_state(token_id) + text = self._decode_stream.step(self._tokenizer, token_id) - if text is not None and not self._inside_tool_call and token_id != self._etc_id: - self._loop.call_soon_threadsafe(self._queue.put_nowait, text) + if text is None or self._inside_tool_call or token_id == self._etc_id or is_start_or_end_token: + continue + if self._inside_thinking: + text = ReasoningText(text) + self._loop.call_soon_threadsafe(self._queue.put_nowait, text) + + def _advance_thinking_state(self, token_id: int) -> bool: + """Mutate thinking state; return ``True`` if ``token_id`` is a start or end token — suppress from output.""" + if self._thinking_start_ids is None: + return False + if self._inside_thinking: + if token_id == self._thinking_end_id: + self._inside_thinking = False + return True + return False + expected = self._thinking_start_ids[len(self._thinking_prefix)] + if token_id != expected: + self._thinking_prefix = [] + return False + self._thinking_prefix.append(token_id) + if len(self._thinking_prefix) == len(self._thinking_start_ids): + self._inside_thinking = True + self._thinking_prefix = [] + return True def end(self) -> None: """Called by ``model.generate()`` when generation is complete.""" @@ -359,6 +503,7 @@ def __init__( loop: asyncio.AbstractEventLoop, queue: asyncio.Queue, tool_config: dict | None = None, + reasoning_config: dict | None = None, ): """ Args: @@ -368,6 +513,7 @@ def __init__( loop (`asyncio.AbstractEventLoop`): The event loop to push decoded text to. queue (`asyncio.Queue`): The queue that receives decoded text chunks. tool_config (`dict`, *optional*): Tool call config (see ``DirectStreamer``). + reasoning_config (`dict`, *optional*): Thinking config (see ``DirectStreamer``). 
""" from tokenizers.decoders import DecodeStream @@ -380,6 +526,10 @@ def __init__( self._stc_id = tool_config["stc_id"] if tool_config else None self._etc_id = tool_config["etc_id"] if tool_config else None self._inside_tool_call = False + self._thinking_start_ids = reasoning_config["start_ids"] if reasoning_config else None + self._thinking_end_id = reasoning_config["end_id"] if reasoning_config else None + self._inside_thinking = bool(reasoning_config and reasoning_config.get("start_in_thinking")) + self._thinking_prefix: list[int] = [] self._prev_len = 0 self.total_tokens = 0 self.generated_token_ids: list[int] = [] @@ -397,9 +547,33 @@ def put(self, output: "GenerationOutput") -> None: elif token_id == self._etc_id: self._inside_tool_call = False + is_start_or_end_token = self._advance_thinking_state(token_id) + text = self._decode_stream.step(self._tokenizer, token_id) - if text is not None and not self._inside_tool_call and token_id != self._etc_id: - self._queue.put_nowait(text) + if text is None or self._inside_tool_call or token_id == self._etc_id or is_start_or_end_token: + continue + if self._inside_thinking: + text = ReasoningText(text) + self._queue.put_nowait(text) + + def _advance_thinking_state(self, token_id: int) -> bool: + """Mutate thinking state; return ``True`` if ``token_id`` is a start or end token — suppress from output.""" + if self._thinking_start_ids is None: + return False + if self._inside_thinking: + if token_id == self._thinking_end_id: + self._inside_thinking = False + return True + return False + expected = self._thinking_start_ids[len(self._thinking_prefix)] + if token_id != expected: + self._thinking_prefix = [] + return False + self._thinking_prefix.append(token_id) + if len(self._thinking_prefix) == len(self._thinking_start_ids): + self._inside_thinking = True + self._thinking_prefix = [] + return True def end(self) -> None: """Signal end of stream.""" @@ -486,6 +660,7 @@ def generate_streaming( gen_config: "GenerationConfig", request_id: str, tool_config: dict | None = None, + reasoning_config: dict | None = None, ) -> tuple[asyncio.Queue, "DirectStreamer | CBStreamer"]: """Start streaming generation. @@ -497,6 +672,8 @@ def generate_streaming( request_id (`str`): Unique request identifier. tool_config (`dict`, *optional*): Tool call config from ``get_tool_call_config``. When set, tool call tokens (between stc/etc) are suppressed from output. + reasoning_config (`dict`, *optional*): Thinking config from ``get_reasoning_config``. + When set, thinking tokens are wrapped as :class:`ReasoningText`. Returns: `tuple[asyncio.Queue, DirectStreamer | CBStreamer]`: A ``(queue, streamer)`` pair @@ -545,13 +722,16 @@ def generate_streaming( gen_config: "GenerationConfig", request_id: str, tool_config: dict | None = None, + reasoning_config: dict | None = None, ) -> tuple[asyncio.Queue, DirectStreamer]: """Start streaming generation via ``model.generate()`` on the inference thread.""" loop = asyncio.get_running_loop() queue: asyncio.Queue = asyncio.Queue() # ProcessorMixin exposes the fast tokenizer as .tokenizer; PreTrainedTokenizerFast is already one. 
rust_tokenizer = getattr(processor, "tokenizer", processor)._tokenizer # type: ignore[union-attr] - streamer = DirectStreamer(rust_tokenizer, loop, queue, tool_config=tool_config) + streamer = DirectStreamer( + rust_tokenizer, loop, queue, tool_config=tool_config, reasoning_config=reasoning_config + ) gen_kwargs = {**inputs, "streamer": streamer, "generation_config": gen_config, "tokenizer": processor} if hasattr(model, "has_talker"): gen_kwargs["generation_mode"] = "text" @@ -643,6 +823,7 @@ def generate_streaming( gen_config: "GenerationConfig", request_id: str, tool_config: dict | None = None, + reasoning_config: dict | None = None, ) -> tuple[asyncio.Queue, CBStreamer]: """Start streaming CB generation. Registers a per-request output handler.""" cb = self._cb @@ -662,7 +843,15 @@ def generate_streaming( ) # ProcessorMixin exposes the fast tokenizer as .tokenizer; PreTrainedTokenizerFast is already one. rust_tokenizer = getattr(processor, "tokenizer", processor)._tokenizer # type: ignore[union-attr] - streamer = CBStreamer(self._cb, request_id, rust_tokenizer, loop, text_queue, tool_config=tool_config) + streamer = CBStreamer( + self._cb, + request_id, + rust_tokenizer, + loop, + text_queue, + tool_config=tool_config, + reasoning_config=reasoning_config, + ) # Register a direct callback: the dispatcher calls this on the event loop with each GenerationOutput. # This decodes tokens and pushes text straight to the SSE text_queue @@ -826,9 +1015,11 @@ def __init__( self, model_manager: "ModelManager", generation_state: GenerationState, + chat_template_kwargs: dict | None = None, ): self.model_manager = model_manager self.generation_state = generation_state + self.chat_template_kwargs = chat_template_kwargs or {} def _validate_request(self, body: dict) -> None: """Validate request fields against the handler's params class and unused fields.""" diff --git a/tests/cli/test_serve.py b/tests/cli/test_serve.py index c54de16b32bd..e567295957d2 100644 --- a/tests/cli/test_serve.py +++ b/tests/cli/test_serve.py @@ -1958,6 +1958,185 @@ class TestToolCallGemma(_TestToolCallBase, unittest.TestCase): MODEL = "google/gemma-4-E2B-it" +class _TestReasoningBase: + """Base class for reasoning integration tests. Subclasses set MODEL. + + A single server is shared across all tests in a subclass via setUpClass. + """ + + MODEL: str + USER_PROMPT = "What is 17 * 23? Think briefly, then answer in one sentence." 
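+    # 17 * 23 = 391; MAX_TOKENS leaves headroom for a thinking block plus the one-sentence answer.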
+ EXPECTED_ANSWER = "391" + MAX_TOKENS = 512 + + @classmethod + def setUpClass(cls): + cls.serve, port = _start_serve() + cls.base_url = f"http://localhost:{port}" + cls.client = OpenAI(base_url=f"{cls.base_url}/v1", api_key="unused") + + @classmethod + def tearDownClass(cls): + cls.serve.kill_server() + + @staticmethod + def _reasoning_field(obj): + """Return ``reasoning_content`` from a chat message or delta (handles model_extra).""" + return getattr(obj, "reasoning_content", None) or (obj.model_extra or {}).get("reasoning_content") + + # ----- chat completions ----- + + def test_chat_non_streaming(self): + """Chat completions: non-streaming surfaces ``reasoning_content`` and strips delimiters.""" + msg = self.client.chat.completions.create( + model=self.MODEL, + messages=[{"role": "user", "content": self.USER_PROMPT}], + stream=False, + max_tokens=self.MAX_TOKENS, + temperature=0.0, + ).choices[0].message + reasoning = self._reasoning_field(msg) + self.assertIn(self.EXPECTED_ANSWER, reasoning or "", f"answer missing from reasoning: {reasoning!r}") + self.assertIn(self.EXPECTED_ANSWER, msg.content or "", f"answer missing from content: {msg.content!r}") + self.assertNotIn("", msg.content or "") + self.assertNotIn("<|channel>", msg.content or "") + self.assertNotIn(reasoning.strip()[:30], msg.content or "") + + def test_chat_streaming(self): + """Chat completions: streaming emits ``reasoning_content`` deltas; content stays clean.""" + chunks = list( + self.client.chat.completions.create( + model=self.MODEL, + messages=[{"role": "user", "content": self.USER_PROMPT}], + stream=True, + max_tokens=self.MAX_TOKENS, + temperature=0.0, + ) + ) + reasoning_text = "".join(self._reasoning_field(c.choices[0].delta) or "" for c in chunks) + self.assertIn(self.EXPECTED_ANSWER, reasoning_text, f"answer missing from reasoning: {reasoning_text!r}") + content = "".join(c.choices[0].delta.content or "" for c in chunks) + self.assertIn(self.EXPECTED_ANSWER, content, f"answer missing from content: {content!r}") + self.assertNotIn("", content) + self.assertNotIn("<|channel>", content) + + def test_chat_multi_turn_round_trips_reasoning(self): + """Chat completions: reasoning_content from a prior turn round-trips through input.""" + first = self.client.chat.completions.create( + model=self.MODEL, + messages=[{"role": "user", "content": self.USER_PROMPT}], + stream=False, + max_tokens=self.MAX_TOKENS, + temperature=0.0, + ).choices[0].message + reasoning = self._reasoning_field(first) + self.assertTrue(reasoning) + second = self.client.chat.completions.create( + model=self.MODEL, + messages=[ + {"role": "user", "content": self.USER_PROMPT}, + {"role": "assistant", "content": first.content or "", "reasoning_content": reasoning}, + {"role": "user", "content": "Now multiply that result by 2."}, + ], + stream=False, + max_tokens=self.MAX_TOKENS, + temperature=0.0, + ) + self.assertIsNotNone(second.choices[0].message.content) + + # ----- responses ----- + + def test_response_non_streaming(self): + """Responses API: non-streaming includes a reasoning item before the message item.""" + resp = self.client.responses.create( + model=self.MODEL, + input=self.USER_PROMPT, + stream=False, + max_output_tokens=self.MAX_TOKENS, + temperature=0.0, + ) + types = [item.type for item in resp.output] + self.assertIn("reasoning", types, f"expected reasoning item, got types: {types}") + self.assertIn("message", types) + self.assertLess(types.index("reasoning"), types.index("message")) + reasoning_text = next(item for item in 
+        self.assertIn(self.EXPECTED_ANSWER, reasoning_text, f"answer missing from reasoning: {reasoning_text!r}")
+        message_text = next(item for item in resp.output if item.type == "message").content[0].text
+        self.assertIn(self.EXPECTED_ANSWER, message_text, f"answer missing from message: {message_text!r}")
+        self.assertNotIn("</think>", message_text)
+        self.assertNotIn("<|channel>", message_text)
+
+    def test_response_streaming(self):
+        """Responses API: streaming emits reasoning_text events and a separate reasoning item."""
+        events = list(
+            self.client.responses.create(
+                model=self.MODEL,
+                input=self.USER_PROMPT,
+                stream=True,
+                max_output_tokens=self.MAX_TOKENS,
+                temperature=0.0,
+            )
+        )
+        added = [e for e in events if e.type == "response.output_item.added"]
+        self.assertGreaterEqual(len(added), 2)
+        self.assertEqual(added[0].item.type, "reasoning")
+        self.assertEqual(added[1].item.type, "message")
+        # Coherence: concat of reasoning_text.delta events == reasoning_text.done.text, and contains the answer.
+        reasoning_text = "".join(e.delta for e in events if e.type == "response.reasoning_text.delta")
+        done = next(e for e in events if e.type == "response.reasoning_text.done")
+        self.assertEqual(reasoning_text, done.text)
+        self.assertIn(self.EXPECTED_ANSWER, reasoning_text, f"answer missing from reasoning: {reasoning_text!r}")
+        content = "".join(e.delta for e in events if e.type == "response.output_text.delta")
+        self.assertIn(self.EXPECTED_ANSWER, content, f"answer missing from content: {content!r}")
+        self.assertNotIn("</think>", content)
+        self.assertNotIn("<|channel>", content)
+
+    def test_response_multi_turn_round_trips_reasoning(self):
+        """Responses API: ``reasoning`` items echoed back as input are accepted."""
+        first = self.client.responses.create(
+            model=self.MODEL,
+            input=self.USER_PROMPT,
+            stream=False,
+            max_output_tokens=self.MAX_TOKENS,
+            temperature=0.0,
+        )
+        reasoning_item = next((i for i in first.output if i.type == "reasoning"), None)
+        message_item = next((i for i in first.output if i.type == "message"), None)
+        self.assertIsNotNone(reasoning_item)
+        self.assertIsNotNone(message_item)
+        second = self.client.responses.create(
+            model=self.MODEL,
+            input=[
+                {"role": "user", "content": self.USER_PROMPT},
+                reasoning_item.model_dump(exclude_none=True),
+                {"role": "assistant", "content": message_item.content[0].text},
+                {"role": "user", "content": "Now multiply that result by 2."},
+            ],
+            stream=False,
+            max_output_tokens=self.MAX_TOKENS,
+            temperature=0.0,
+        )
+        self.assertEqual(second.status, "completed")
+
+
+@slow
+@require_serve
+@require_torch_accelerator
+class TestReasoningQwen(_TestReasoningBase, unittest.TestCase):
+    """Reasoning tests with Qwen3 (inline <think>...</think>
tags).""" + + MODEL = "Qwen/Qwen3-1.7B" + + +@slow +@require_serve +@require_torch_accelerator +class TestReasoningGemma(_TestReasoningBase, unittest.TestCase): + """Reasoning tests with Gemma 4 (response_schema-based thinking channel).""" + + MODEL = "google/gemma-4-E2B-it" + + @slow @require_librosa @require_multipart From 737a35881afea374bd9311760734857f5e0abeda Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Tue, 28 Apr 2026 17:07:53 +0000 Subject: [PATCH 2/2] Apply repo consistency fixes --- setup.py | 8 +++-- .../cli/serving/chat_completion.py | 2 +- src/transformers/cli/serving/utils.py | 2 +- src/transformers/dependency_versions_table.py | 1 + .../models/esm/configuration_esm.py | 4 +-- tests/cli/test_serve.py | 36 +++++++++++-------- 6 files changed, 32 insertions(+), 21 deletions(-) diff --git a/setup.py b/setup.py index 0fa835d5fb4a..42c865b1b9ba 100644 --- a/setup.py +++ b/setup.py @@ -124,7 +124,9 @@ "rjieba", "rouge-score!=0.0.7,!=0.0.8,!=0.1,!=0.1.1", "ruff==0.14.10", - "transformers-mlinter @ git+https://github.com/huggingface/transformers-mlinter@b9d319ce264c106f97a959d926ef42bc3c0ea4d1", + # When bumping `transformers-mlinter`, sync repo-local rule overrides from + # `utils/rules.toml` back into the released package. + "transformers-mlinter==0.1.1", "ty==0.0.20", # `sacrebleu` not used in `transformers`. However, it is needed in several tests, when a test calls # `evaluate.load("sacrebleu")`. This metric is used in the examples that we use to test the `Trainer` with, in the @@ -295,7 +297,7 @@ def finalize_options(self): pass def run(self): - if SUPPORTED_PYTHON_VERSIONS[0] >= PYTHON_MINOR_VERSION: + if SUPPORTED_PYTHON_VERSIONS[0] > PYTHON_MINOR_VERSION: print( f"Table updated only when running 3.{SUPPORTED_PYTHON_VERSIONS[0]}.x, detected version is {sys.version}." 
) @@ -328,7 +330,7 @@ def run(self): setup( name="transformers", - version="5.6.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) + version="5.7.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) author="The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)", author_email="transformers@huggingface.co", description="Transformers: the model-definition framework for state-of-the-art machine learning models in text, vision, audio, and multimodal models, for both inference and training.", diff --git a/src/transformers/cli/serving/chat_completion.py b/src/transformers/cli/serving/chat_completion.py index 7444f28c1a5c..0e6a3bb1fad6 100644 --- a/src/transformers/cli/serving/chat_completion.py +++ b/src/transformers/cli/serving/chat_completion.py @@ -323,7 +323,7 @@ async def _non_streaming( for i, tc in enumerate(parsed) ] - reasoning_content = None + reasoning_content = None if reasoning_config is not None: content, reasoning_content = parse_reasoning(processor, generated_ids, content, reasoning_config) diff --git a/src/transformers/cli/serving/utils.py b/src/transformers/cli/serving/utils.py index d8d1178e3a8c..33f927fffc9f 100644 --- a/src/transformers/cli/serving/utils.py +++ b/src/transformers/cli/serving/utils.py @@ -19,7 +19,6 @@ import copy import enum import json -import re import threading from abc import ABC, abstractmethod from collections.abc import Callable @@ -268,6 +267,7 @@ def _starts_in_thinking(input_ids, start_ids: list[int]) -> bool: return True return False + class DownloadAggregator: """Aggregates byte-progress across multiple concurrent download tqdm bars. diff --git a/src/transformers/dependency_versions_table.py b/src/transformers/dependency_versions_table.py index b08aa558d795..1a721ca2a82a 100644 --- a/src/transformers/dependency_versions_table.py +++ b/src/transformers/dependency_versions_table.py @@ -56,6 +56,7 @@ "rjieba": "rjieba", "rouge-score": "rouge-score!=0.0.7,!=0.0.8,!=0.1,!=0.1.1", "ruff": "ruff==0.14.10", + "transformers-mlinter": "transformers-mlinter==0.1.1", "ty": "ty==0.0.20", "sacrebleu": "sacrebleu>=1.4.12,<2.0.0", "sacremoses": "sacremoses", diff --git a/src/transformers/models/esm/configuration_esm.py b/src/transformers/models/esm/configuration_esm.py index a00dcf8b39e3..7875d88ecee8 100644 --- a/src/transformers/models/esm/configuration_esm.py +++ b/src/transformers/models/esm/configuration_esm.py @@ -159,12 +159,12 @@ class EsmConfig(PreTrainedConfig): mask_token_id (`int`, *optional*): The index of the mask token in the vocabulary. This must be included in the config because of the "mask-dropout" scaling trick, which will scale the inputs depending on the number of masked tokens. + rope_theta (`float`, defaults to 10000.0): + The base period of the RoPE embeddings. Only used when `position_embedding_type` is set to `"rotary"`. position_embedding_type (`str`, *optional*, defaults to `"absolute"`): Type of position embedding. Choose either `"absolute"` or "rotary"`. emb_layer_norm_before (`bool`, *optional*): Whether to apply layer normalization after embeddings but before the main stem of the network. - rope_theta (`float`, defaults to 10000.0): - The base period of the RoPE embeddings. Only used when `position_embedding_type` is set to `"rotary"`. 
token_dropout (`bool`, defaults to `False`): When this is enabled, masked tokens are treated as if they had been dropped out by input dropout. is_folding_model (`bool`, defaults to `False`): diff --git a/tests/cli/test_serve.py b/tests/cli/test_serve.py index e567295957d2..926a3ed46046 100644 --- a/tests/cli/test_serve.py +++ b/tests/cli/test_serve.py @@ -1988,13 +1988,17 @@ def _reasoning_field(obj): def test_chat_non_streaming(self): """Chat completions: non-streaming surfaces ``reasoning_content`` and strips delimiters.""" - msg = self.client.chat.completions.create( - model=self.MODEL, - messages=[{"role": "user", "content": self.USER_PROMPT}], - stream=False, - max_tokens=self.MAX_TOKENS, - temperature=0.0, - ).choices[0].message + msg = ( + self.client.chat.completions.create( + model=self.MODEL, + messages=[{"role": "user", "content": self.USER_PROMPT}], + stream=False, + max_tokens=self.MAX_TOKENS, + temperature=0.0, + ) + .choices[0] + .message + ) reasoning = self._reasoning_field(msg) self.assertIn(self.EXPECTED_ANSWER, reasoning or "", f"answer missing from reasoning: {reasoning!r}") self.assertIn(self.EXPECTED_ANSWER, msg.content or "", f"answer missing from content: {msg.content!r}") @@ -2022,13 +2026,17 @@ def test_chat_streaming(self): def test_chat_multi_turn_round_trips_reasoning(self): """Chat completions: reasoning_content from a prior turn round-trips through input.""" - first = self.client.chat.completions.create( - model=self.MODEL, - messages=[{"role": "user", "content": self.USER_PROMPT}], - stream=False, - max_tokens=self.MAX_TOKENS, - temperature=0.0, - ).choices[0].message + first = ( + self.client.chat.completions.create( + model=self.MODEL, + messages=[{"role": "user", "content": self.USER_PROMPT}], + stream=False, + max_tokens=self.MAX_TOKENS, + temperature=0.0, + ) + .choices[0] + .message + ) reasoning = self._reasoning_field(first) self.assertTrue(reasoning) second = self.client.chat.completions.create(
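A quick end-to-end check of the new field from the client side (a minimal sketch, not part of the patch: it assumes a server started with `transformers serve --reasoning on` on the default port 8000, a reasoning-capable model such as Qwen/Qwen3-1.7B, and the `openai` client installed; model name and prompt are illustrative):

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="unused")

stream = client.chat.completions.create(
    model="Qwen/Qwen3-1.7B",
    messages=[{"role": "user", "content": "What is 17 * 23?"}],
    stream=True,
    max_tokens=512,
)
for chunk in stream:
    delta = chunk.choices[0].delta
    # Reasoning tokens arrive on the `reasoning_content` extension field, which the
    # OpenAI client surfaces via pydantic's model_extra; answer tokens use `content`.
    reasoning = getattr(delta, "reasoning_content", None) or (delta.model_extra or {}).get("reasoning_content")
    if reasoning:
        print(reasoning, end="", flush=True)  # thinking stream
    elif delta.content:
        print(delta.content, end="", flush=True)  # answer stream
print()

Non-streaming calls surface the same text on `message.reasoning_content`, which is exactly what the tests above assert.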