diff --git a/setup.py b/setup.py
index 0fa835d5fb4a..42c865b1b9ba 100644
--- a/setup.py
+++ b/setup.py
@@ -124,7 +124,9 @@
"rjieba",
"rouge-score!=0.0.7,!=0.0.8,!=0.1,!=0.1.1",
"ruff==0.14.10",
- "transformers-mlinter @ git+https://github.com/huggingface/transformers-mlinter@b9d319ce264c106f97a959d926ef42bc3c0ea4d1",
+ # When bumping `transformers-mlinter`, sync repo-local rule overrides from
+ # `utils/rules.toml` back into the released package.
+ "transformers-mlinter==0.1.1",
"ty==0.0.20",
# `sacrebleu` not used in `transformers`. However, it is needed in several tests, when a test calls
# `evaluate.load("sacrebleu")`. This metric is used in the examples that we use to test the `Trainer` with, in the
@@ -295,7 +297,7 @@ def finalize_options(self):
pass
def run(self):
- if SUPPORTED_PYTHON_VERSIONS[0] >= PYTHON_MINOR_VERSION:
+ if SUPPORTED_PYTHON_VERSIONS[0] > PYTHON_MINOR_VERSION:
print(
f"Table updated only when running 3.{SUPPORTED_PYTHON_VERSIONS[0]}.x, detected version is {sys.version}."
)
@@ -328,7 +330,7 @@ def run(self):
setup(
name="transformers",
- version="5.6.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
+ version="5.7.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
author="The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)",
author_email="transformers@huggingface.co",
description="Transformers: the model-definition framework for state-of-the-art machine learning models in text, vision, audio, and multimodal models, for both inference and training.",
diff --git a/src/transformers/cli/serve.py b/src/transformers/cli/serve.py
index 4d9e6c712e7c..14f8901339c1 100644
--- a/src/transformers/cli/serve.py
+++ b/src/transformers/cli/serve.py
@@ -16,6 +16,8 @@
"""
import asyncio
+import enum
+import json
import threading
from typing import Annotated
@@ -30,6 +32,12 @@
logger = logging.get_logger(__name__)
+class ReasoningMode(str, enum.Enum):
+ ON = "on"
+ OFF = "off"
+ AUTO = "auto"
+
+
class Serve:
def __init__(
self,
@@ -39,6 +47,32 @@ def __init__(
bool,
typer.Option(help="Enable continuous batching with paged attention. Configure with --cb-* flags."),
] = False,
+ attn_implementation: Annotated[
+ str | None, typer.Option(help="Attention implementation (e.g. flash_attention_2).")
+ ] = None,
+ compile: Annotated[bool, typer.Option(help="Enable torch.compile for faster inference.")] = False,
+ quantization: Annotated[
+ str | None, typer.Option(help="Quantization method: 'bnb-4bit' or 'bnb-8bit'.")
+ ] = None,
+ reasoning: Annotated[
+ ReasoningMode, typer.Option(help="Reasoning mode. 'auto' uses the chat template default.")
+ ] = ReasoningMode.AUTO,
+ chat_template_kwargs: Annotated[
+ str | None,
+ typer.Option(
+ help=(
+ "Default JSON kwargs forwarded to apply_chat_template "
+ "(e.g. '{\"enable_thinking\": true}'); per-request chat_template_kwargs override these."
+ )
+ ),
+ ] = None,
+ device: Annotated[str, typer.Option(help="Device for inference (e.g. 'auto', 'cuda:0', 'cpu').")] = "auto",
+ dtype: Annotated[str | None, typer.Option(help="Override model dtype. 'auto' derives from weights.")] = "auto",
+ trust_remote_code: Annotated[bool, typer.Option(help="Trust remote code when loading.")] = False,
+ model_timeout: Annotated[
+ int, typer.Option(help="Seconds before idle model is unloaded. Ignored when force_model is set.")
+ ] = 300,
+ # Continuous batching tuning
cb_block_size: Annotated[
int | None, typer.Option(help="KV cache block size in tokens for continuous batching.")
] = None,
@@ -54,19 +88,6 @@ def __init__(
cb_use_cuda_graph: Annotated[
bool | None, typer.Option(help="Enable CUDA graphs for continuous batching.")
] = None,
- attn_implementation: Annotated[
- str | None, typer.Option(help="Attention implementation (e.g. flash_attention_2).")
- ] = None,
- compile: Annotated[bool, typer.Option(help="Enable torch.compile for faster inference.")] = False,
- quantization: Annotated[
- str | None, typer.Option(help="Quantization method: 'bnb-4bit' or 'bnb-8bit'.")
- ] = None,
- device: Annotated[str, typer.Option(help="Device for inference (e.g. 'auto', 'cuda:0', 'cpu').")] = "auto",
- dtype: Annotated[str | None, typer.Option(help="Override model dtype. 'auto' derives from weights.")] = "auto",
- trust_remote_code: Annotated[bool, typer.Option(help="Trust remote code when loading.")] = False,
- model_timeout: Annotated[
- int, typer.Option(help="Seconds before idle model is unloaded. Ignored when force_model is set.")
- ] = 300,
# Server options
host: Annotated[str, typer.Option(help="Server listen address.")] = "localhost",
port: Annotated[int, typer.Option(help="Server listen port.")] = 8000,
@@ -126,14 +147,28 @@ def __init__(
cb_config=cb_config,
)
+        if chat_template_kwargs:
+            try:
+                chat_template_kwargs = json.loads(chat_template_kwargs)
+            except json.JSONDecodeError as e:
+                raise typer.BadParameter(f"--chat-template-kwargs is not valid JSON: {e}") from e
+            if not isinstance(chat_template_kwargs, dict):
+                raise typer.BadParameter("--chat-template-kwargs must be a JSON object")
+ else:
+ chat_template_kwargs = {}
+
+ if reasoning == ReasoningMode.ON:
+ chat_template_kwargs["enable_thinking"] = True
+ elif reasoning == ReasoningMode.OFF:
+ chat_template_kwargs["enable_thinking"] = False
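+        # For example, `transformers serve --reasoning off` is equivalent to
+        # passing --chat-template-kwargs '{"enable_thinking": false}'; per-request
+        # chat_template_kwargs merged later in the handlers still take precedence.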
+
self._chat_handler = ChatCompletionHandler(
model_manager=self._model_manager,
generation_state=self._generation_state,
+ chat_template_kwargs=chat_template_kwargs,
)
self._response_handler = ResponseHandler(
model_manager=self._model_manager,
generation_state=self._generation_state,
+ chat_template_kwargs=chat_template_kwargs,
)
self._transcription_handler = TranscriptionHandler(self._model_manager, self._generation_state)
diff --git a/src/transformers/cli/serving/chat_completion.py b/src/transformers/cli/serving/chat_completion.py
index 161a25a02f41..0e6a3bb1fad6 100644
--- a/src/transformers/cli/serving/chat_completion.py
+++ b/src/transformers/cli/serving/chat_completion.py
@@ -40,8 +40,11 @@
BaseGenerateManager,
BaseHandler,
Modality,
+ ReasoningText,
_StreamError,
+ get_reasoning_config,
get_tool_call_config,
+ parse_reasoning,
parse_tool_calls,
)
@@ -53,6 +56,7 @@
class TransformersCompletionCreateParamsStreaming(CompletionCreateParamsStreaming, total=False):
generation_config: str
seed: int
+ chat_template_kwargs: dict
# Fields accepted by the OpenAI schema but not yet supported.
@@ -118,10 +122,13 @@ async def handle_request(self, body: dict, request_id: str) -> StreamingResponse
for msg in processor_inputs
for c in (msg.get("content") if isinstance(msg.get("content"), list) else [])
)
- # Default to 32 frames for video (Gemma 4 default); some processors load all frames otherwise
- chat_template_kwargs = {}
+ # Default to 32 frames for video (Gemma 4 default); some processors load all frames otherwise.
+            # Merge order (later wins): built-in video default -> server default -> request-level kwargs.
+ chat_template_kwargs: dict = {}
if has_video:
chat_template_kwargs["num_frames"] = 32
+ chat_template_kwargs.update(self.chat_template_kwargs)
+            chat_template_kwargs.update(body.get("chat_template_kwargs") or {})
inputs = processor.apply_chat_template(
processor_inputs,
add_generation_prompt=True,
@@ -141,6 +148,7 @@ async def handle_request(self, body: dict, request_id: str) -> StreamingResponse
gen_manager.init_cb(model, gen_config)
tool_config = get_tool_call_config(processor, model) if body.get("tools") else None
+ reasoning_config = get_reasoning_config(processor, model, inputs["input_ids"])
streaming = body.get("stream")
if streaming:
@@ -153,6 +161,7 @@ async def handle_request(self, body: dict, request_id: str) -> StreamingResponse
gen_config,
gen_manager=gen_manager,
tool_config=tool_config,
+ reasoning_config=reasoning_config,
)
else:
return await self._non_streaming(
@@ -164,6 +173,7 @@ async def handle_request(self, body: dict, request_id: str) -> StreamingResponse
gen_config,
gen_manager=gen_manager,
tool_config=tool_config,
+ reasoning_config=reasoning_config,
)
# ----- streaming -----
@@ -178,6 +188,7 @@ def _streaming(
gen_config: "GenerationConfig",
gen_manager: BaseGenerateManager,
tool_config: dict | None = None,
+ reasoning_config: dict | None = None,
) -> StreamingResponse:
"""Stream tokens as SSE via DirectStreamer."""
queue, streamer = gen_manager.generate_streaming(
@@ -187,6 +198,7 @@ def _streaming(
gen_config,
request_id=request_id,
tool_config=tool_config,
+ reasoning_config=reasoning_config,
)
input_ids = inputs["input_ids"]
# CB returns plain lists, regular path returns tensors
@@ -216,7 +228,10 @@ async def sse_gen() -> AsyncGenerator[str, None]:
yield "".join(sse_parts)
return
- sse_parts.append(self._build_chunk_sse(request_id, model=model_id, content=text))
+ if isinstance(text, ReasoningText):
+ sse_parts.append(self._build_chunk_sse(request_id, model=model_id, reasoning_content=text))
+ else:
+ sse_parts.append(self._build_chunk_sse(request_id, model=model_id, content=text))
if sse_parts:
yield "".join(sse_parts)
@@ -280,6 +295,7 @@ async def _non_streaming(
gen_config: "GenerationConfig",
gen_manager: BaseGenerateManager,
tool_config: dict | None = None,
+ reasoning_config: dict | None = None,
) -> JSONResponse:
"""Run generation and return a JSONResponse."""
content, input_len, generated_ids = await gen_manager.generate_non_streaming(
@@ -307,6 +323,10 @@ async def _non_streaming(
for i, tc in enumerate(parsed)
]
+ reasoning_content = None
+ if reasoning_config is not None:
+ content, reasoning_content = parse_reasoning(processor, generated_ids, content, reasoning_config)
+
if tool_calls is not None:
finish_reason = "tool_calls"
elif hit_max:
@@ -322,6 +342,7 @@ async def _non_streaming(
finish_reason=finish_reason,
usage=usage,
tool_calls=tool_calls,
+ reasoning_content=reasoning_content,
),
media_type="application/json",
)
@@ -354,6 +375,7 @@ def _build_completion(
finish_reason: str,
usage: CompletionUsage | None = None,
tool_calls: list[dict] | None = None,
+ reasoning_content: str | None = None,
) -> dict:
"""Build a non-streaming ChatCompletion response dict.
@@ -364,11 +386,14 @@ def _build_completion(
finish_reason (`str`): Why generation stopped (``"stop"``, ``"length"``, ``"tool_calls"``).
usage (`CompletionUsage`, *optional*): Token usage statistics.
tool_calls (`list[dict]`, *optional*): Parsed tool calls, if any.
+ reasoning_content (`str`, *optional*): Chain-of-thought content extracted from the response.
Returns:
`dict`: Serialized ``ChatCompletion`` ready for JSON response.
"""
- message = ChatCompletionMessage(content=content, role="assistant", tool_calls=tool_calls)
+ message = ChatCompletionMessage(
+ content=content, role="assistant", tool_calls=tool_calls, reasoning_content=reasoning_content
+ )
result = ChatCompletion(
id=request_id,
created=int(time.time()),
@@ -394,6 +419,7 @@ def _build_chunk_sse(
finish_reason: str | None = None,
tool_calls: list | None = None,
usage: CompletionUsage | None = None,
+ reasoning_content: str | None = None,
) -> str:
"""Build a streaming ``ChatCompletionChunk`` and format it as an SSE ``data:`` line.
@@ -405,6 +431,7 @@ def _build_chunk_sse(
finish_reason (`str`, *optional*): Set on the final chunk.
tool_calls (`list`, *optional*): Tool call deltas.
usage (`CompletionUsage`, *optional*): Token usage (sent with the final chunk).
+ reasoning_content (`str`, *optional*): Reasoning/thinking delta (OpenAI-compatible extension).
Returns:
`str`: A formatted SSE event string.
@@ -415,7 +442,9 @@ def _build_chunk_sse(
model=model,
choices=[
ChoiceChunk(
- delta=ChoiceDelta(content=content, role=role, tool_calls=tool_calls),
+ delta=ChoiceDelta(
+ content=content, role=role, tool_calls=tool_calls, reasoning_content=reasoning_content
+ ),
index=0,
finish_reason=finish_reason,
)
diff --git a/src/transformers/cli/serving/response.py b/src/transformers/cli/serving/response.py
index 4d29dfd1d6a2..002391c2479a 100644
--- a/src/transformers/cli/serving/response.py
+++ b/src/transformers/cli/serving/response.py
@@ -45,6 +45,9 @@
ResponseOutputItemDoneEvent,
ResponseOutputMessage,
ResponseOutputText,
+ ResponseReasoningItem,
+ ResponseReasoningTextDeltaEvent,
+ ResponseReasoningTextDoneEvent,
ResponseTextDeltaEvent,
ResponseTextDoneEvent,
)
@@ -56,8 +59,11 @@
BaseGenerateManager,
BaseHandler,
Modality,
+ ReasoningText,
_StreamError,
+ get_reasoning_config,
get_tool_call_config,
+ parse_reasoning,
parse_tool_calls,
)
@@ -80,7 +86,6 @@ class TransformersResponseCreateParamsStreaming(ResponseCreateParamsStreaming, t
"max_tool_calls",
"previous_response_id",
"prompt",
- "reasoning",
"service_tier",
"store",
"text",
@@ -127,10 +132,13 @@ async def handle_request(self, body: dict, request_id: str) -> StreamingResponse
for c in (msg.get("content") if isinstance(msg.get("content"), list) else [])
)
- # Default to 32 frames for video (Gemma 4 default); some processors load all frames otherwise
- chat_template_kwargs = {}
+ # Default to 32 frames for video (Gemma 4 default); some processors load all frames otherwise.
+            # Merge order (later wins): built-in video default -> server default -> request-level kwargs.
+ chat_template_kwargs: dict = {}
if has_video:
chat_template_kwargs["num_frames"] = 32
+ chat_template_kwargs.update(self.chat_template_kwargs)
+ chat_template_kwargs.update(body.get("chat_template_kwargs") or {})
# updates the flat tool structure to the one expected by the `apply_chat_template` method.
tools = self._normalize_tools(body.get("tools"))
inputs = processor.apply_chat_template(
@@ -151,6 +159,7 @@ async def handle_request(self, body: dict, request_id: str) -> StreamingResponse
if use_cb:
gen_manager.init_cb(model, gen_config)
tool_config = get_tool_call_config(processor, model) if body.get("tools") else None
+ reasoning_config = get_reasoning_config(processor, model, inputs["input_ids"])
streaming = body.get("stream", True)
if streaming:
@@ -164,6 +173,7 @@ async def handle_request(self, body: dict, request_id: str) -> StreamingResponse
gen_config,
gen_manager=gen_manager,
tool_config=tool_config,
+ reasoning_config=reasoning_config,
)
else:
return await self._non_streaming(
@@ -176,6 +186,7 @@ async def handle_request(self, body: dict, request_id: str) -> StreamingResponse
gen_config,
gen_manager=gen_manager,
tool_config=tool_config,
+ reasoning_config=reasoning_config,
)
# ----- input conversion -----
@@ -247,16 +258,26 @@ def _normalize_response_items(items: list[dict]) -> list[dict]:
Input items may be a mix of:
- Messages (``EasyInputMessageParam`` with ``role``, or ``type: "message"``).
+ - ``reasoning`` — buffered and attached as ``reasoning_content`` to the next assistant message.
- ``function_call`` — merged as ``tool_calls`` onto the preceding assistant message.
- ``function_call_output`` — converted to ``role: "tool"`` messages.
"""
messages = []
+ pending_reasoning: str | None = None
for item in items:
item_type = item.get("type")
+ if item_type == "reasoning":
+ pending_reasoning = "".join(c["text"] for c in item.get("content") or [])
+ continue
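+        # e.g. a ``reasoning`` item directly preceding an assistant message folds
+        # into a single {"role": "assistant", "content": ..., "reasoning_content": ...}
+        # dict, the same shape the chat-completions path accepts for prior turns.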
+
if "role" in item:
- messages.append({"role": item["role"], "content": item.get("content", "")})
+ msg = {"role": item["role"], "content": item.get("content", "")}
+ if pending_reasoning is not None and item["role"] == "assistant":
+ msg["reasoning_content"] = pending_reasoning
+ pending_reasoning = None
+ messages.append(msg)
elif item_type == "function_call":
tc = {
@@ -295,6 +316,7 @@ def _streaming(
gen_config: "GenerationConfig",
gen_manager: BaseGenerateManager,
tool_config: dict | None = None,
+ reasoning_config: dict | None = None,
) -> StreamingResponse:
"""Generate a streaming Responses API reply (SSE) using DirectStreamer."""
queue, streamer = gen_manager.generate_streaming(
@@ -304,16 +326,17 @@ def _streaming(
gen_config,
request_id=request_id,
tool_config=tool_config,
+ reasoning_config=reasoning_config,
)
input_ids = inputs["input_ids"]
# CB returns plain lists, regular path returns tensors
input_len = len(input_ids) if isinstance(input_ids, list) else input_ids.shape[-1]
seq = 0
- output_index = 0
created_at = time.time()
resp_id = f"resp_{request_id}"
msg_id = f"msg_{request_id}"
+ reasoning_id = f"rs_{request_id}"
response_base = {
"id": resp_id,
@@ -327,7 +350,7 @@ def _streaming(
}
async def event_stream() -> AsyncGenerator[str, None]:
- nonlocal seq, output_index
+ nonlocal seq
try:
# 1. Created + In progress
@@ -348,44 +371,163 @@ async def event_stream() -> AsyncGenerator[str, None]:
)
seq += 1
- # 2. Output item added (message)
- yield self.chunk_to_sse(
- ResponseOutputItemAddedEvent(
- type="response.output_item.added",
- sequence_number=seq,
- output_index=output_index,
- item=ResponseOutputMessage(
- id=msg_id,
- type="message",
- status="in_progress",
- role="assistant",
- content=[],
- ),
- )
- )
- seq += 1
-
- # 3. Content part added
- yield self.chunk_to_sse(
- ResponseContentPartAddedEvent(
- type="response.content_part.added",
- item_id=msg_id,
- sequence_number=seq,
- output_index=output_index,
- content_index=0,
- part=ResponseOutputText(type="output_text", text="", annotations=[]),
- )
- )
- seq += 1
-
- # 4. Stream tokens — drain queue to batch HTTP writes
+ # 2. Stream tokens — items are opened lazily so reasoning (if any)
+ # appears as a separate output item before the message item.
full_text = ""
+ full_reasoning = ""
tool_calls = []
+ output_index = 0
+ reasoning_open = False
+ message_open = False
+ reasoning_item = None
+ message_item = None
done = False
+ def open_reasoning() -> str:
+ """Emit ``output_item.added`` for an in-progress reasoning item."""
+ nonlocal seq, reasoning_open
+ reasoning_open = True
+ sse = self.chunk_to_sse(
+ ResponseOutputItemAddedEvent(
+ type="response.output_item.added",
+ sequence_number=seq,
+ output_index=output_index,
+ item=ResponseReasoningItem(
+ id=reasoning_id, type="reasoning", summary=[], content=[], status="in_progress"
+ ),
+ )
+ )
+ seq += 1
+ return sse
+
+ def close_reasoning() -> str:
+ """Emit ``reasoning_text.done`` + ``output_item.done`` for the completed reasoning item."""
+ nonlocal seq, reasoning_open, reasoning_item
+ reasoning_item = ResponseReasoningItem(
+ id=reasoning_id,
+ type="reasoning",
+ summary=[],
+ content=[{"type": "reasoning_text", "text": full_reasoning}],
+ status="completed",
+ )
+ parts = [
+ self.chunk_to_sse(
+ ResponseReasoningTextDoneEvent(
+ type="response.reasoning_text.done",
+ item_id=reasoning_id,
+ sequence_number=seq,
+ output_index=output_index,
+ content_index=0,
+ text=full_reasoning,
+ )
+ )
+ ]
+ seq += 1
+ parts.append(
+ self.chunk_to_sse(
+ ResponseOutputItemDoneEvent(
+ type="response.output_item.done",
+ sequence_number=seq,
+ output_index=output_index,
+ item=reasoning_item,
+ )
+ )
+ )
+ seq += 1
+ reasoning_open = False
+ return "".join(parts)
+
+ def open_message() -> str:
+ """Emit ``output_item.added`` + ``content_part.added`` for an in-progress message."""
+ nonlocal seq, message_open
+ message_open = True
+ parts = [
+ self.chunk_to_sse(
+ ResponseOutputItemAddedEvent(
+ type="response.output_item.added",
+ sequence_number=seq,
+ output_index=output_index,
+ item=ResponseOutputMessage(
+ id=msg_id,
+ type="message",
+ status="in_progress",
+ role="assistant",
+ content=[],
+ ),
+ )
+ )
+ ]
+ seq += 1
+ parts.append(
+ self.chunk_to_sse(
+ ResponseContentPartAddedEvent(
+ type="response.content_part.added",
+ item_id=msg_id,
+ sequence_number=seq,
+ output_index=output_index,
+ content_index=0,
+ part=ResponseOutputText(type="output_text", text="", annotations=[]),
+ )
+ )
+ )
+ seq += 1
+ return "".join(parts)
+
+ def close_message() -> str:
+ """Emit ``output_text.done`` + ``content_part.done`` + ``output_item.done`` for the message."""
+ nonlocal seq, message_open, message_item
+ output_text_part = ResponseOutputText(type="output_text", text=full_text, annotations=[])
+ message_item = ResponseOutputMessage(
+ id=msg_id,
+ type="message",
+ status="completed",
+ role="assistant",
+ content=[output_text_part],
+ annotations=[], # type: ignore[call-arg]
+ )
+ parts = [
+ self.chunk_to_sse(
+ ResponseTextDoneEvent(
+ type="response.output_text.done",
+ item_id=msg_id,
+ sequence_number=seq,
+ output_index=output_index,
+ content_index=0,
+ text=full_text,
+ logprobs=[],
+ )
+ )
+ ]
+ seq += 1
+ parts.append(
+ self.chunk_to_sse(
+ ResponseContentPartDoneEvent(
+ type="response.content_part.done",
+ item_id=msg_id,
+ sequence_number=seq,
+ output_index=output_index,
+ content_index=0,
+ part=output_text_part,
+ )
+ )
+ )
+ seq += 1
+ parts.append(
+ self.chunk_to_sse(
+ ResponseOutputItemDoneEvent(
+ type="response.output_item.done",
+ sequence_number=seq,
+ output_index=output_index,
+ item=message_item,
+ )
+ )
+ )
+ seq += 1
+ message_open = False
+ return "".join(parts)
+
while not done:
text = await queue.get()
- # Drain all available tokens for one batched HTTP write
batch = [text]
try:
while True:
@@ -423,26 +565,59 @@ async def event_stream() -> AsyncGenerator[str, None]:
yield "".join(sse_parts)
return
- full_text += text
- sse_parts.append(
- self.chunk_to_sse(
- ResponseTextDeltaEvent(
- type="response.output_text.delta",
- item_id=msg_id,
- sequence_number=seq,
- output_index=0,
- content_index=0,
- delta=text,
- logprobs=[],
+ if isinstance(text, ReasoningText):
+ if not reasoning_open:
+ sse_parts.append(open_reasoning())
+ full_reasoning += text
+ sse_parts.append(
+ self.chunk_to_sse(
+ ResponseReasoningTextDeltaEvent(
+ type="response.reasoning_text.delta",
+ item_id=reasoning_id,
+ sequence_number=seq,
+ output_index=output_index,
+ content_index=0,
+ delta=text,
+ )
)
)
- )
- seq += 1
+ seq += 1
+ else:
+ if reasoning_open:
+ sse_parts.append(close_reasoning())
+ output_index += 1
+ if not message_open:
+ sse_parts.append(open_message())
+ full_text += text
+ sse_parts.append(
+ self.chunk_to_sse(
+ ResponseTextDeltaEvent(
+ type="response.output_text.delta",
+ item_id=msg_id,
+ sequence_number=seq,
+ output_index=output_index,
+ content_index=0,
+ delta=text,
+ logprobs=[],
+ )
+ )
+ )
+ seq += 1
if sse_parts:
yield "".join(sse_parts)
- # 5. Tool calls are parsed after generation completes (not during streaming),
+ # Close any open reasoning section that didn't transition to content.
+ if reasoning_open:
+ yield close_reasoning()
+ output_index += 1
+
+ # Close message section (open it first if no content was emitted).
+ if not message_open:
+ yield open_message()
+ yield close_message()
+
+ # 3. Tool calls are parsed after generation completes (not during streaming),
# because the full token sequence is needed for reliable parsing.
if tool_config:
parsed = parse_tool_calls(processor, streamer.generated_token_ids, tool_config["schema"])
@@ -489,52 +664,12 @@ async def event_stream() -> AsyncGenerator[str, None]:
)
seq += 1
- # 6. Close text output
- output_text_part = ResponseOutputText(type="output_text", text=full_text, annotations=[])
- yield self.chunk_to_sse(
- ResponseTextDoneEvent(
- type="response.output_text.done",
- item_id=msg_id,
- sequence_number=seq,
- output_index=0,
- content_index=0,
- text=full_text,
- logprobs=[],
- )
- )
- seq += 1
- yield self.chunk_to_sse(
- ResponseContentPartDoneEvent(
- type="response.content_part.done",
- item_id=msg_id,
- sequence_number=seq,
- output_index=0,
- content_index=0,
- part=output_text_part,
- )
- )
- seq += 1
-
- msg_item = ResponseOutputMessage(
- id=msg_id,
- type="message",
- status="completed",
- role="assistant",
- content=[output_text_part],
- annotations=[], # type: ignore[call-arg]
- )
- yield self.chunk_to_sse(
- ResponseOutputItemDoneEvent(
- type="response.output_item.done",
- sequence_number=seq,
- output_index=0,
- item=msg_item,
- )
- )
- seq += 1
-
- # 7. Completed
- all_output = [msg_item] + list(tool_calls)
+ # 4. Completed
+ all_output = []
+ if reasoning_item is not None:
+ all_output.append(reasoning_item)
+ all_output.append(message_item)
+ all_output.extend(tool_calls)
usage = compute_usage(input_len, streamer.total_tokens)
yield self.chunk_to_sse(
ResponseCompletedEvent(
@@ -565,13 +700,28 @@ async def _non_streaming(
gen_config: "GenerationConfig",
gen_manager: BaseGenerateManager,
tool_config: dict | None = None,
+ reasoning_config: dict | None = None,
) -> JSONResponse:
"""Generate a non-streaming Responses API reply (single JSON)."""
full_text, input_len, generated_ids = await gen_manager.generate_non_streaming(
model, processor, inputs, gen_config, request_id=request_id
)
- output_items = [
+ output_items = []
+ if reasoning_config is not None:
+ full_text, reasoning_content = parse_reasoning(processor, generated_ids, full_text, reasoning_config)
+ if reasoning_content is not None:
+ output_items.append(
+ ResponseReasoningItem(
+ id=f"rs_{request_id}",
+ type="reasoning",
+ summary=[],
+ content=[{"type": "reasoning_text", "text": reasoning_content}],
+ status="completed",
+ )
+ )
+
+ output_items.append(
ResponseOutputMessage(
id=f"msg_{request_id}",
type="message",
@@ -580,7 +730,7 @@ async def _non_streaming(
content=[ResponseOutputText(type="output_text", text=full_text, annotations=[])],
annotations=[], # type: ignore[call-arg]
)
- ]
+ )
if tool_config is not None:
parsed = parse_tool_calls(processor, generated_ids, tool_config["schema"])
diff --git a/src/transformers/cli/serving/utils.py b/src/transformers/cli/serving/utils.py
index d786a828fc28..33f927fffc9f 100644
--- a/src/transformers/cli/serving/utils.py
+++ b/src/transformers/cli/serving/utils.py
@@ -73,6 +73,14 @@ class _GenerationCancelled(Exception):
"""Raised inside ``DirectStreamer.put()`` to abort ``model.generate()``."""
+class ReasoningText(str):
+ """Tagged str subclass: text chunk belonging to a thinking/reasoning block.
+
+ Streamers wrap reasoning text with this so handlers can route it to
+ ``reasoning_content`` deltas instead of ``content``.
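+
+    Note that ``"".join()`` over a batch of chunks returns a plain ``str``, so
+    handlers must check ``isinstance(text, ReasoningText)`` per chunk, before
+    any concatenation.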
+ """
+
+
# Fallback tool call configs for models that don't declare stc_token/etc_token/response_schema
# on their tokenizer.
# Keys are matched via substring against model_type (e.g. "qwen" matches "qwen2", "qwen3_vl", etc.).
@@ -156,6 +164,110 @@ def parse_tool_calls(processor, generated_ids, schema: dict) -> list[dict] | Non
return tool_calls if tool_calls else None
+# Default start/end tokens + schema. The opening token is optional so prefilled
+# ``<think>`` prompts still match.
+_DEFAULT_THINKING_TOKENS = {
+    "start": ["<think>"],
+    "end": "</think>",
+ "schema": {
+ "type": "object",
+ "properties": {
+ "thinking": {"type": "string"},
+ "content": {"type": "string"},
+ },
+        "x-regex": r"\s*(?:<think>)?(?P<thinking>.*?)\s*</think>(?P<content>.*)",
+ },
+}
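+# With this schema, "<think>compute 17 * 23</think>391" parses into
+# thinking="compute 17 * 23" and content="391"; the regex also matches when a
+# prefilled prompt left the opening <think> out of the generated text.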
+# Streaming-side token IDs for families whose ``response_schema`` uses non-default
+# start/end tokens. Post-hoc parsing uses the schema; this only feeds the
+# streamer's token-level detector.
+_THINKING_TOKENS = {
+ "gemma4": {"start": ["<|channel>", "thought"], "end": ""},
+}
+
+
+def get_reasoning_config(processor, model: "PreTrainedModel", input_ids=None) -> dict | None:
+ """Return reasoning config for the model, or ``None`` if not supported.
+
+ The config drives both streaming detection (token IDs) and post-hoc parsing
+ (response schema). Returns a dict with:
+ - ``start_ids`` (`list[int]`): Token ID sequence that opens a thinking block.
+ - ``end_id`` (`int`): Token ID that closes the block.
+ - ``schema`` (`dict`): Response schema with ``thinking`` / ``content``
+ properties for :func:`parse_reasoning`.
+ - ``start_in_thinking`` (`bool`, only when ``input_ids`` is given): Whether
+ the rendered prompt already opened an unclosed thinking block (prefilled
+ by the template), so the model's output begins inside the block.
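+
+    A sketch of the returned shape (token IDs below are illustrative, not from
+    a real tokenizer):
+
+    ```python
+    config = get_reasoning_config(processor, model, inputs["input_ids"])
+    # {"start_ids": [1001], "end_id": 1002, "schema": {...}, "start_in_thinking": False}
+    ```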
+ """
+ tokenizer = getattr(processor, "tokenizer", processor)
+ model_type = model.config.model_type.lower()
+ thinking_tokens = next(
+ (v for k, v in _THINKING_TOKENS.items() if k in model_type),
+ _DEFAULT_THINKING_TOKENS,
+ )
+ start_ids = [tokenizer.convert_tokens_to_ids(t) for t in thinking_tokens["start"]]
+ end_id = tokenizer.convert_tokens_to_ids(thinking_tokens["end"])
+ if any(tid in (None, tokenizer.unk_token_id) for tid in start_ids) or end_id in (None, tokenizer.unk_token_id):
+ return None
+ # Custom-token families (e.g. Gemma 4) provide their schema via the tokenizer;
+    # default ``<think>`` falls back to the schema baked into ``_DEFAULT_THINKING_TOKENS``.
+ schema = getattr(tokenizer, "response_schema", None)
+    if not (schema and "thinking" in schema.get("properties", {})):
+ schema = _DEFAULT_THINKING_TOKENS["schema"]
+ config: dict = {"start_ids": start_ids, "end_id": end_id, "schema": schema}
+ if input_ids is not None:
+ config["start_in_thinking"] = _starts_in_thinking(input_ids, start_ids)
+ return config
+
+
+def parse_reasoning(processor, generated_ids, content: str, reasoning_config: dict) -> tuple[str, str | None]:
+ """Split generated output into ``(content, reasoning_content)`` via ``parse_response``.
+
+ If the schema's regex matches (closing marker present), use it. For prompts
+ that prefill the opener (QwQ-32B, DeepSeek-R1) the entire output is reasoning
+    until ``</think>`` arrives — when that's truncated, fall back to treating
+ all decoded text as reasoning. Returns ``(content, None)`` otherwise.
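+
+    For instance, with the default ``<think>`` schema, a raw output of
+    ``"<think>17 * 23 = 391</think>It is 391."`` comes back as
+    ``("It is 391.", "17 * 23 = 391")``.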
+ """
+ parsed = processor.parse_response(generated_ids, reasoning_config["schema"])
+ if parsed:
+ reasoning = parsed.get("thinking", "").strip()
+ if reasoning:
+ return parsed.get("content", ""), reasoning
+    # Prefilled opener (QwQ-32B, DeepSeek-R1) truncated before ``</think>`` —
+ # no anchor for the schema regex; treat all output as reasoning.
+ if reasoning_config.get("start_in_thinking"):
+ return "", content.strip()
+ return content, None
+
+
+def _starts_in_thinking(input_ids, start_ids: list[int]) -> bool:
+ """True if the rendered prompt ends with an unclosed thinking block.
+
+ Some reasoning-model chat templates prefill the thinking opener as the final
+    prompt tokens (e.g. DeepSeek-R1, QwQ-32B emit ``<think>\\n`` at the end when
+ ``add_generation_prompt=True``). In those cases the model resumes *inside*
+    the block, so its output contains only ``...reasoning</think>answer`` with
+ no opening tag — the streamer must start with ``_inside_thinking=True``.
+
+ The prefill always lands at the tail of the prompt (optionally followed by a
+ single whitespace token like ``\\n``), so we only inspect the last few tokens.
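+
+    For instance, with hypothetical ``start_ids=[7, 8]``, a prompt ending in
+    ``[..., 7, 8]`` or ``[..., 7, 8, 198]`` returns ``True``, while one ending
+    in ``[..., 7, 9]`` returns ``False``.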
+ """
+ if hasattr(input_ids, "tolist"):
+ input_ids = input_ids.tolist()
+ if input_ids and isinstance(input_ids[0], list):
+ if len(input_ids) != 1:
+ return False
+ input_ids = input_ids[0]
+ n = len(start_ids)
+ # Match start_ids at the tail, allowing up to one trailing token (e.g. "\n").
+ for trailing in (0, 1):
+ if len(input_ids) >= n + trailing:
+ end = len(input_ids) - trailing
+ if input_ids[end - n : end] == start_ids:
+ return True
+ return False
+
+
class DownloadAggregator:
"""Aggregates byte-progress across multiple concurrent download tqdm bars.
@@ -286,6 +398,7 @@ def __init__(
queue: asyncio.Queue,
skip_special_tokens: bool = True,
tool_config: dict | None = None,
+ reasoning_config: dict | None = None,
):
"""
Args:
@@ -297,6 +410,9 @@ def __init__(
tool_config (`dict`, *optional*): Tool call config from ``get_tool_call_config``.
When set, tokens between stc/etc delimiters (inclusive) are suppressed
from the queue so tool call markup is never streamed to the client.
+ reasoning_config (`dict`, *optional*): Thinking config from ``get_reasoning_config``.
+ When set, tokens between start/end delimiters are wrapped as
+ :class:`ReasoningText` so handlers route them to ``reasoning_content``.
"""
from tokenizers.decoders import DecodeStream
@@ -307,6 +423,10 @@ def __init__(
self._stc_id = tool_config["stc_id"] if tool_config else None
self._etc_id = tool_config["etc_id"] if tool_config else None
self._inside_tool_call = False
+ self._thinking_start_ids = reasoning_config["start_ids"] if reasoning_config else None
+ self._thinking_end_id = reasoning_config["end_id"] if reasoning_config else None
+ self._inside_thinking = bool(reasoning_config and reasoning_config.get("start_in_thinking"))
+ self._thinking_prefix: list[int] = []
self._first = True
self._cancelled = threading.Event()
self.total_tokens = 0
@@ -329,9 +449,33 @@ def put(self, value: "torch.Tensor") -> None:
elif token_id == self._etc_id:
self._inside_tool_call = False
+ is_start_or_end_token = self._advance_thinking_state(token_id)
+
text = self._decode_stream.step(self._tokenizer, token_id)
- if text is not None and not self._inside_tool_call and token_id != self._etc_id:
- self._loop.call_soon_threadsafe(self._queue.put_nowait, text)
+ if text is None or self._inside_tool_call or token_id == self._etc_id or is_start_or_end_token:
+ continue
+ if self._inside_thinking:
+ text = ReasoningText(text)
+ self._loop.call_soon_threadsafe(self._queue.put_nowait, text)
+
+ def _advance_thinking_state(self, token_id: int) -> bool:
+ """Mutate thinking state; return ``True`` if ``token_id`` is a start or end token — suppress from output."""
+ if self._thinking_start_ids is None:
+ return False
+ if self._inside_thinking:
+ if token_id == self._thinking_end_id:
+ self._inside_thinking = False
+ return True
+ return False
+ expected = self._thinking_start_ids[len(self._thinking_prefix)]
+        if token_id != expected:
+            # A failed partial match may itself begin a new one (e.g. the first
+            # start token repeated), so re-seed rather than just resetting.
+            self._thinking_prefix = [token_id] if token_id == self._thinking_start_ids[0] else []
+            return bool(self._thinking_prefix)
+        self._thinking_prefix.append(token_id)
+        if len(self._thinking_prefix) == len(self._thinking_start_ids):
+            self._inside_thinking = True
+            self._thinking_prefix = []
+        # Partial prefix tokens are suppressed too; if the match later fails they
+        # are dropped rather than replayed.
+        return True
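+
+    # Example with hypothetical start_ids=[7, 8] and end_id=9: token 7 returns
+    # True (partial match, suppressed), token 8 returns True and flips
+    # _inside_thinking on, and a later 9 returns True and flips it off; any
+    # other mismatched token resets the prefix and streams normally (unless it
+    # re-seeds a new match).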
def end(self) -> None:
"""Called by ``model.generate()`` when generation is complete."""
@@ -359,6 +503,7 @@ def __init__(
loop: asyncio.AbstractEventLoop,
queue: asyncio.Queue,
tool_config: dict | None = None,
+ reasoning_config: dict | None = None,
):
"""
Args:
@@ -368,6 +513,7 @@ def __init__(
loop (`asyncio.AbstractEventLoop`): The event loop to push decoded text to.
queue (`asyncio.Queue`): The queue that receives decoded text chunks.
tool_config (`dict`, *optional*): Tool call config (see ``DirectStreamer``).
+ reasoning_config (`dict`, *optional*): Thinking config (see ``DirectStreamer``).
"""
from tokenizers.decoders import DecodeStream
@@ -380,6 +526,10 @@ def __init__(
self._stc_id = tool_config["stc_id"] if tool_config else None
self._etc_id = tool_config["etc_id"] if tool_config else None
self._inside_tool_call = False
+ self._thinking_start_ids = reasoning_config["start_ids"] if reasoning_config else None
+ self._thinking_end_id = reasoning_config["end_id"] if reasoning_config else None
+ self._inside_thinking = bool(reasoning_config and reasoning_config.get("start_in_thinking"))
+ self._thinking_prefix: list[int] = []
self._prev_len = 0
self.total_tokens = 0
self.generated_token_ids: list[int] = []
@@ -397,9 +547,33 @@ def put(self, output: "GenerationOutput") -> None:
elif token_id == self._etc_id:
self._inside_tool_call = False
+ is_start_or_end_token = self._advance_thinking_state(token_id)
+
text = self._decode_stream.step(self._tokenizer, token_id)
- if text is not None and not self._inside_tool_call and token_id != self._etc_id:
- self._queue.put_nowait(text)
+ if text is None or self._inside_tool_call or token_id == self._etc_id or is_start_or_end_token:
+ continue
+ if self._inside_thinking:
+ text = ReasoningText(text)
+ self._queue.put_nowait(text)
+
+ def _advance_thinking_state(self, token_id: int) -> bool:
+ """Mutate thinking state; return ``True`` if ``token_id`` is a start or end token — suppress from output."""
+ if self._thinking_start_ids is None:
+ return False
+ if self._inside_thinking:
+ if token_id == self._thinking_end_id:
+ self._inside_thinking = False
+ return True
+ return False
+ expected = self._thinking_start_ids[len(self._thinking_prefix)]
+        if token_id != expected:
+            # A failed partial match may itself begin a new one (e.g. the first
+            # start token repeated), so re-seed rather than just resetting.
+            self._thinking_prefix = [token_id] if token_id == self._thinking_start_ids[0] else []
+            return bool(self._thinking_prefix)
+        self._thinking_prefix.append(token_id)
+        if len(self._thinking_prefix) == len(self._thinking_start_ids):
+            self._inside_thinking = True
+            self._thinking_prefix = []
+        # Partial prefix tokens are suppressed too; if the match later fails they
+        # are dropped rather than replayed.
+        return True
def end(self) -> None:
"""Signal end of stream."""
@@ -486,6 +660,7 @@ def generate_streaming(
gen_config: "GenerationConfig",
request_id: str,
tool_config: dict | None = None,
+ reasoning_config: dict | None = None,
) -> tuple[asyncio.Queue, "DirectStreamer | CBStreamer"]:
"""Start streaming generation.
@@ -497,6 +672,8 @@ def generate_streaming(
request_id (`str`): Unique request identifier.
tool_config (`dict`, *optional*): Tool call config from ``get_tool_call_config``.
When set, tool call tokens (between stc/etc) are suppressed from output.
+ reasoning_config (`dict`, *optional*): Thinking config from ``get_reasoning_config``.
+ When set, thinking tokens are wrapped as :class:`ReasoningText`.
Returns:
`tuple[asyncio.Queue, DirectStreamer | CBStreamer]`: A ``(queue, streamer)`` pair
@@ -545,13 +722,16 @@ def generate_streaming(
gen_config: "GenerationConfig",
request_id: str,
tool_config: dict | None = None,
+ reasoning_config: dict | None = None,
) -> tuple[asyncio.Queue, DirectStreamer]:
"""Start streaming generation via ``model.generate()`` on the inference thread."""
loop = asyncio.get_running_loop()
queue: asyncio.Queue = asyncio.Queue()
# ProcessorMixin exposes the fast tokenizer as .tokenizer; PreTrainedTokenizerFast is already one.
rust_tokenizer = getattr(processor, "tokenizer", processor)._tokenizer # type: ignore[union-attr]
- streamer = DirectStreamer(rust_tokenizer, loop, queue, tool_config=tool_config)
+ streamer = DirectStreamer(
+ rust_tokenizer, loop, queue, tool_config=tool_config, reasoning_config=reasoning_config
+ )
gen_kwargs = {**inputs, "streamer": streamer, "generation_config": gen_config, "tokenizer": processor}
if hasattr(model, "has_talker"):
gen_kwargs["generation_mode"] = "text"
@@ -643,6 +823,7 @@ def generate_streaming(
gen_config: "GenerationConfig",
request_id: str,
tool_config: dict | None = None,
+ reasoning_config: dict | None = None,
) -> tuple[asyncio.Queue, CBStreamer]:
"""Start streaming CB generation. Registers a per-request output handler."""
cb = self._cb
@@ -662,7 +843,15 @@ def generate_streaming(
)
# ProcessorMixin exposes the fast tokenizer as .tokenizer; PreTrainedTokenizerFast is already one.
rust_tokenizer = getattr(processor, "tokenizer", processor)._tokenizer # type: ignore[union-attr]
- streamer = CBStreamer(self._cb, request_id, rust_tokenizer, loop, text_queue, tool_config=tool_config)
+ streamer = CBStreamer(
+ self._cb,
+ request_id,
+ rust_tokenizer,
+ loop,
+ text_queue,
+ tool_config=tool_config,
+ reasoning_config=reasoning_config,
+ )
# Register a direct callback: the dispatcher calls this on the event loop with each GenerationOutput.
# This decodes tokens and pushes text straight to the SSE text_queue
@@ -826,9 +1015,11 @@ def __init__(
self,
model_manager: "ModelManager",
generation_state: GenerationState,
+ chat_template_kwargs: dict | None = None,
):
self.model_manager = model_manager
self.generation_state = generation_state
+ self.chat_template_kwargs = chat_template_kwargs or {}
def _validate_request(self, body: dict) -> None:
"""Validate request fields against the handler's params class and unused fields."""
diff --git a/src/transformers/dependency_versions_table.py b/src/transformers/dependency_versions_table.py
index b08aa558d795..1a721ca2a82a 100644
--- a/src/transformers/dependency_versions_table.py
+++ b/src/transformers/dependency_versions_table.py
@@ -56,6 +56,7 @@
"rjieba": "rjieba",
"rouge-score": "rouge-score!=0.0.7,!=0.0.8,!=0.1,!=0.1.1",
"ruff": "ruff==0.14.10",
+ "transformers-mlinter": "transformers-mlinter==0.1.1",
"ty": "ty==0.0.20",
"sacrebleu": "sacrebleu>=1.4.12,<2.0.0",
"sacremoses": "sacremoses",
diff --git a/src/transformers/models/esm/configuration_esm.py b/src/transformers/models/esm/configuration_esm.py
index a00dcf8b39e3..7875d88ecee8 100644
--- a/src/transformers/models/esm/configuration_esm.py
+++ b/src/transformers/models/esm/configuration_esm.py
@@ -159,12 +159,12 @@ class EsmConfig(PreTrainedConfig):
mask_token_id (`int`, *optional*):
The index of the mask token in the vocabulary. This must be included in the config because of the
"mask-dropout" scaling trick, which will scale the inputs depending on the number of masked tokens.
+ rope_theta (`float`, defaults to 10000.0):
+ The base period of the RoPE embeddings. Only used when `position_embedding_type` is set to `"rotary"`.
position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
Type of position embedding. Choose either `"absolute"` or "rotary"`.
emb_layer_norm_before (`bool`, *optional*):
Whether to apply layer normalization after embeddings but before the main stem of the network.
- rope_theta (`float`, defaults to 10000.0):
- The base period of the RoPE embeddings. Only used when `position_embedding_type` is set to `"rotary"`.
token_dropout (`bool`, defaults to `False`):
When this is enabled, masked tokens are treated as if they had been dropped out by input dropout.
is_folding_model (`bool`, defaults to `False`):
diff --git a/tests/cli/test_serve.py b/tests/cli/test_serve.py
index c54de16b32bd..926a3ed46046 100644
--- a/tests/cli/test_serve.py
+++ b/tests/cli/test_serve.py
@@ -1958,6 +1958,193 @@ class TestToolCallGemma(_TestToolCallBase, unittest.TestCase):
MODEL = "google/gemma-4-E2B-it"
+class _TestReasoningBase:
+ """Base class for reasoning integration tests. Subclasses set MODEL.
+
+ A single server is shared across all tests in a subclass via setUpClass.
+ """
+
+ MODEL: str
+ USER_PROMPT = "What is 17 * 23? Think briefly, then answer in one sentence."
+ EXPECTED_ANSWER = "391"
+ MAX_TOKENS = 512
+
+ @classmethod
+ def setUpClass(cls):
+ cls.serve, port = _start_serve()
+ cls.base_url = f"http://localhost:{port}"
+ cls.client = OpenAI(base_url=f"{cls.base_url}/v1", api_key="unused")
+
+ @classmethod
+ def tearDownClass(cls):
+ cls.serve.kill_server()
+
+ @staticmethod
+ def _reasoning_field(obj):
+ """Return ``reasoning_content`` from a chat message or delta (handles model_extra)."""
+ return getattr(obj, "reasoning_content", None) or (obj.model_extra or {}).get("reasoning_content")
+
+ # ----- chat completions -----
+
+ def test_chat_non_streaming(self):
+ """Chat completions: non-streaming surfaces ``reasoning_content`` and strips delimiters."""
+ msg = (
+ self.client.chat.completions.create(
+ model=self.MODEL,
+ messages=[{"role": "user", "content": self.USER_PROMPT}],
+ stream=False,
+ max_tokens=self.MAX_TOKENS,
+ temperature=0.0,
+ )
+ .choices[0]
+ .message
+ )
+ reasoning = self._reasoning_field(msg)
+ self.assertIn(self.EXPECTED_ANSWER, reasoning or "", f"answer missing from reasoning: {reasoning!r}")
+ self.assertIn(self.EXPECTED_ANSWER, msg.content or "", f"answer missing from content: {msg.content!r}")
+        self.assertNotIn("</think>", msg.content or "")
+ self.assertNotIn("<|channel>", msg.content or "")
+ self.assertNotIn(reasoning.strip()[:30], msg.content or "")
+
+ def test_chat_streaming(self):
+ """Chat completions: streaming emits ``reasoning_content`` deltas; content stays clean."""
+ chunks = list(
+ self.client.chat.completions.create(
+ model=self.MODEL,
+ messages=[{"role": "user", "content": self.USER_PROMPT}],
+ stream=True,
+ max_tokens=self.MAX_TOKENS,
+ temperature=0.0,
+ )
+ )
+ reasoning_text = "".join(self._reasoning_field(c.choices[0].delta) or "" for c in chunks)
+ self.assertIn(self.EXPECTED_ANSWER, reasoning_text, f"answer missing from reasoning: {reasoning_text!r}")
+ content = "".join(c.choices[0].delta.content or "" for c in chunks)
+ self.assertIn(self.EXPECTED_ANSWER, content, f"answer missing from content: {content!r}")
+        self.assertNotIn("</think>", content)
+ self.assertNotIn("<|channel>", content)
+
+ def test_chat_multi_turn_round_trips_reasoning(self):
+ """Chat completions: reasoning_content from a prior turn round-trips through input."""
+ first = (
+ self.client.chat.completions.create(
+ model=self.MODEL,
+ messages=[{"role": "user", "content": self.USER_PROMPT}],
+ stream=False,
+ max_tokens=self.MAX_TOKENS,
+ temperature=0.0,
+ )
+ .choices[0]
+ .message
+ )
+ reasoning = self._reasoning_field(first)
+ self.assertTrue(reasoning)
+ second = self.client.chat.completions.create(
+ model=self.MODEL,
+ messages=[
+ {"role": "user", "content": self.USER_PROMPT},
+ {"role": "assistant", "content": first.content or "", "reasoning_content": reasoning},
+ {"role": "user", "content": "Now multiply that result by 2."},
+ ],
+ stream=False,
+ max_tokens=self.MAX_TOKENS,
+ temperature=0.0,
+ )
+ self.assertIsNotNone(second.choices[0].message.content)
+
+ # ----- responses -----
+
+ def test_response_non_streaming(self):
+ """Responses API: non-streaming includes a reasoning item before the message item."""
+ resp = self.client.responses.create(
+ model=self.MODEL,
+ input=self.USER_PROMPT,
+ stream=False,
+ max_output_tokens=self.MAX_TOKENS,
+ temperature=0.0,
+ )
+ types = [item.type for item in resp.output]
+ self.assertIn("reasoning", types, f"expected reasoning item, got types: {types}")
+ self.assertIn("message", types)
+ self.assertLess(types.index("reasoning"), types.index("message"))
+ reasoning_text = next(item for item in resp.output if item.type == "reasoning").content[0].text
+ self.assertIn(self.EXPECTED_ANSWER, reasoning_text, f"answer missing from reasoning: {reasoning_text!r}")
+ message_text = next(item for item in resp.output if item.type == "message").content[0].text
+ self.assertIn(self.EXPECTED_ANSWER, message_text, f"answer missing from message: {message_text!r}")
+        self.assertNotIn("</think>", message_text)
+ self.assertNotIn("<|channel>", message_text)
+
+ def test_response_streaming(self):
+ """Responses API: streaming emits reasoning_text events and a separate reasoning item."""
+ events = list(
+ self.client.responses.create(
+ model=self.MODEL,
+ input=self.USER_PROMPT,
+ stream=True,
+ max_output_tokens=self.MAX_TOKENS,
+ temperature=0.0,
+ )
+ )
+ added = [e for e in events if e.type == "response.output_item.added"]
+ self.assertGreaterEqual(len(added), 2)
+ self.assertEqual(added[0].item.type, "reasoning")
+ self.assertEqual(added[1].item.type, "message")
+ # Coherence: concat of reasoning_text.delta events == reasoning_text.done.text, and contains the answer.
+ reasoning_text = "".join(e.delta for e in events if e.type == "response.reasoning_text.delta")
+ done = next(e for e in events if e.type == "response.reasoning_text.done")
+ self.assertEqual(reasoning_text, done.text)
+ self.assertIn(self.EXPECTED_ANSWER, reasoning_text, f"answer missing from reasoning: {reasoning_text!r}")
+ content = "".join(e.delta for e in events if e.type == "response.output_text.delta")
+ self.assertIn(self.EXPECTED_ANSWER, content, f"answer missing from content: {content!r}")
+        self.assertNotIn("</think>", content)
+ self.assertNotIn("<|channel>", content)
+
+ def test_response_multi_turn_round_trips_reasoning(self):
+ """Responses API: ``reasoning`` items echoed back as input are accepted."""
+ first = self.client.responses.create(
+ model=self.MODEL,
+ input=self.USER_PROMPT,
+ stream=False,
+ max_output_tokens=self.MAX_TOKENS,
+ temperature=0.0,
+ )
+ reasoning_item = next((i for i in first.output if i.type == "reasoning"), None)
+ message_item = next((i for i in first.output if i.type == "message"), None)
+ self.assertIsNotNone(reasoning_item)
+ self.assertIsNotNone(message_item)
+ second = self.client.responses.create(
+ model=self.MODEL,
+ input=[
+ {"role": "user", "content": self.USER_PROMPT},
+ reasoning_item.model_dump(exclude_none=True),
+ {"role": "assistant", "content": message_item.content[0].text},
+ {"role": "user", "content": "Now multiply that result by 2."},
+ ],
+ stream=False,
+ max_output_tokens=self.MAX_TOKENS,
+ temperature=0.0,
+ )
+ self.assertEqual(second.status, "completed")
+
+
+@slow
+@require_serve
+@require_torch_accelerator
+class TestReasoningQwen(_TestReasoningBase, unittest.TestCase):
+    """Reasoning tests with Qwen3 (inline <think>...</think> tags)."""
+
+ MODEL = "Qwen/Qwen3-1.7B"
+
+
+@slow
+@require_serve
+@require_torch_accelerator
+class TestReasoningGemma(_TestReasoningBase, unittest.TestCase):
+ """Reasoning tests with Gemma 4 (response_schema-based thinking channel)."""
+
+ MODEL = "google/gemma-4-E2B-it"
+
+
@slow
@require_librosa
@require_multipart