8 changes: 5 additions & 3 deletions setup.py
@@ -124,7 +124,9 @@
"rjieba",
"rouge-score!=0.0.7,!=0.0.8,!=0.1,!=0.1.1",
"ruff==0.14.10",
"transformers-mlinter @ git+https://github.com/huggingface/transformers-mlinter@b9d319ce264c106f97a959d926ef42bc3c0ea4d1",
# When bumping `transformers-mlinter`, sync repo-local rule overrides from
# `utils/rules.toml` back into the released package.
"transformers-mlinter==0.1.1",
"ty==0.0.20",
# `sacrebleu` not used in `transformers`. However, it is needed in several tests, when a test calls
# `evaluate.load("sacrebleu")`. This metric is used in the examples that we use to test the `Trainer` with, in the
@@ -295,7 +297,7 @@ def finalize_options(self):
pass

def run(self):
if SUPPORTED_PYTHON_VERSIONS[0] >= PYTHON_MINOR_VERSION:
if SUPPORTED_PYTHON_VERSIONS[0] > PYTHON_MINOR_VERSION:
print(
f"Table updated only when running 3.{SUPPORTED_PYTHON_VERSIONS[0]}.x, detected version is {sys.version}."
)
@@ -328,7 +330,7 @@ def run(self):

setup(
name="transformers",
version="5.6.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
version="5.7.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
author="The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)",
author_email="transformers@huggingface.co",
description="Transformers: the model-definition framework for state-of-the-art machine learning models in text, vision, audio, and multimodal models, for both inference and training.",
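Side note on the `>=` → `>` change in `run()` above: a minimal sketch of the guard as I read it, assuming `SUPPORTED_PYTHON_VERSIONS` is an ascending list of supported minor versions defined elsewhere in `setup.py`; the concrete values below are illustrative only.

```python
import sys

# Assumed shape: ascending minor versions, so index 0 is the lowest supported one.
SUPPORTED_PYTHON_VERSIONS = [10, 11, 12, 13]  # illustrative values
PYTHON_MINOR_VERSION = sys.version_info.minor

# The old `>=` also fired when the interpreter matched the lowest supported
# version, skipping the table update on the one version meant to produce it.
# The strict `>` only skips when the interpreter is older than every supported version.
if SUPPORTED_PYTHON_VERSIONS[0] > PYTHON_MINOR_VERSION:
    print(
        f"Table updated only when running 3.{SUPPORTED_PYTHON_VERSIONS[0]}.x, detected version is {sys.version}."
    )
```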
61 changes: 48 additions & 13 deletions src/transformers/cli/serve.py
@@ -16,6 +16,8 @@
"""

import asyncio
import enum
import json
import threading
from typing import Annotated

@@ -30,6 +32,12 @@
logger = logging.get_logger(__name__)


class ReasoningMode(str, enum.Enum):
ON = "on"
OFF = "off"
AUTO = "auto"


class Serve:
def __init__(
self,
@@ -39,6 +47,32 @@ def __init__(
bool,
typer.Option(help="Enable continuous batching with paged attention. Configure with --cb-* flags."),
] = False,
attn_implementation: Annotated[
str | None, typer.Option(help="Attention implementation (e.g. flash_attention_2).")
] = None,
compile: Annotated[bool, typer.Option(help="Enable torch.compile for faster inference.")] = False,
quantization: Annotated[
str | None, typer.Option(help="Quantization method: 'bnb-4bit' or 'bnb-8bit'.")
] = None,
reasoning: Annotated[
ReasoningMode, typer.Option(help="Reasoning mode. 'auto' uses the chat template default.")
] = ReasoningMode.AUTO,
chat_template_kwargs: Annotated[
str | None,
typer.Option(
help=(
"Default JSON kwargs forwarded to apply_chat_template "
"(e.g. '{\"enable_thinking\": true}'); per-request chat_template_kwargs override these."
)
),
] = None,
device: Annotated[str, typer.Option(help="Device for inference (e.g. 'auto', 'cuda:0', 'cpu').")] = "auto",
dtype: Annotated[str | None, typer.Option(help="Override model dtype. 'auto' derives from weights.")] = "auto",
trust_remote_code: Annotated[bool, typer.Option(help="Trust remote code when loading.")] = False,
model_timeout: Annotated[
int, typer.Option(help="Seconds before idle model is unloaded. Ignored when force_model is set.")
] = 300,
# Continuous batching tuning
cb_block_size: Annotated[
int | None, typer.Option(help="KV cache block size in tokens for continuous batching.")
] = None,
@@ -54,19 +88,6 @@ def __init__(
cb_use_cuda_graph: Annotated[
bool | None, typer.Option(help="Enable CUDA graphs for continuous batching.")
] = None,
attn_implementation: Annotated[
str | None, typer.Option(help="Attention implementation (e.g. flash_attention_2).")
] = None,
compile: Annotated[bool, typer.Option(help="Enable torch.compile for faster inference.")] = False,
quantization: Annotated[
str | None, typer.Option(help="Quantization method: 'bnb-4bit' or 'bnb-8bit'.")
] = None,
device: Annotated[str, typer.Option(help="Device for inference (e.g. 'auto', 'cuda:0', 'cpu').")] = "auto",
dtype: Annotated[str | None, typer.Option(help="Override model dtype. 'auto' derives from weights.")] = "auto",
trust_remote_code: Annotated[bool, typer.Option(help="Trust remote code when loading.")] = False,
model_timeout: Annotated[
int, typer.Option(help="Seconds before idle model is unloaded. Ignored when force_model is set.")
] = 300,
# Server options
host: Annotated[str, typer.Option(help="Server listen address.")] = "localhost",
port: Annotated[int, typer.Option(help="Server listen port.")] = 8000,
@@ -126,14 +147,28 @@ def __init__(
cb_config=cb_config,
)

if chat_template_kwargs:
chat_template_kwargs = json.loads(chat_template_kwargs)
if not isinstance(chat_template_kwargs, dict):
raise typer.BadParameter("--chat-template-kwargs must be a JSON object")
else:
chat_template_kwargs = {}

if reasoning == ReasoningMode.ON:
chat_template_kwargs["enable_thinking"] = True
elif reasoning == ReasoningMode.OFF:
chat_template_kwargs["enable_thinking"] = False

self._chat_handler = ChatCompletionHandler(
model_manager=self._model_manager,
generation_state=self._generation_state,
chat_template_kwargs=chat_template_kwargs,
)

self._response_handler = ResponseHandler(
model_manager=self._model_manager,
generation_state=self._generation_state,
chat_template_kwargs=chat_template_kwargs,
)

self._transcription_handler = TranscriptionHandler(self._model_manager, self._generation_state)
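To see the new serve flags end to end, here is a hedged sketch: the launch command assumes the `transformers serve` entry point, typer's default flag spellings for the options added above, and a placeholder model name; the client side assumes the server's OpenAI-compatible `/v1` route and that the extra `reasoning_content` field is surfaced by the `openai` client (it is read defensively with `getattr`).

```python
# Assumed launch command (flag names derived from the options added above):
#   transformers serve --reasoning on --chat-template-kwargs '{"enable_thinking": true}' --port 8000
from openai import OpenAI  # requires the `openai` client package

client = OpenAI(base_url="http://localhost:8000/v1", api_key="unused")

completion = client.chat.completions.create(
    model="Qwen/Qwen3-4B",  # placeholder model name
    messages=[{"role": "user", "content": "What is 17 * 23?"}],
)

message = completion.choices[0].message
# With --reasoning on, the chat template is asked to emit thinking and the
# parsed chain-of-thought comes back separately from the final answer.
print(getattr(message, "reasoning_content", None))
print(message.content)
```

The streaming and per-request override paths are exercised in the sketch after the `chat_completion.py` diff below.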
39 changes: 34 additions & 5 deletions src/transformers/cli/serving/chat_completion.py
@@ -40,8 +40,11 @@
BaseGenerateManager,
BaseHandler,
Modality,
ReasoningText,
_StreamError,
get_reasoning_config,
get_tool_call_config,
parse_reasoning,
parse_tool_calls,
)

@@ -53,6 +56,7 @@
class TransformersCompletionCreateParamsStreaming(CompletionCreateParamsStreaming, total=False):
generation_config: str
seed: int
chat_template_kwargs: dict


# Fields accepted by the OpenAI schema but not yet supported.
@@ -118,10 +122,13 @@ async def handle_request(self, body: dict, request_id: str) -> StreamingResponse
for msg in processor_inputs
for c in (msg.get("content") if isinstance(msg.get("content"), list) else [])
)
# Default to 32 frames for video (Gemma 4 default); some processors load all frames otherwise
chat_template_kwargs = {}
# Default to 32 frames for video (Gemma 4 default); some processors load all frames otherwise.
# Merge order (later wins): built-in video default → server default → request-level kwargs.
chat_template_kwargs: dict = {}
if has_video:
chat_template_kwargs["num_frames"] = 32
chat_template_kwargs.update(self.chat_template_kwargs)
chat_template_kwargs.update(body.get("chat_template_kwargs", {}))
inputs = processor.apply_chat_template(
processor_inputs,
add_generation_prompt=True,
@@ -141,6 +148,7 @@
gen_manager.init_cb(model, gen_config)

tool_config = get_tool_call_config(processor, model) if body.get("tools") else None
reasoning_config = get_reasoning_config(processor, model, inputs["input_ids"])

streaming = body.get("stream")
if streaming:
@@ -153,6 +161,7 @@
gen_config,
gen_manager=gen_manager,
tool_config=tool_config,
reasoning_config=reasoning_config,
)
else:
return await self._non_streaming(
@@ -164,6 +173,7 @@ async def handle_request(self, body: dict, request_id: str) -> StreamingResponse
gen_config,
gen_manager=gen_manager,
tool_config=tool_config,
reasoning_config=reasoning_config,
)

# ----- streaming -----
@@ -178,6 +188,7 @@ def _streaming(
gen_config: "GenerationConfig",
gen_manager: BaseGenerateManager,
tool_config: dict | None = None,
reasoning_config: dict | None = None,
) -> StreamingResponse:
"""Stream tokens as SSE via DirectStreamer."""
queue, streamer = gen_manager.generate_streaming(
@@ -187,6 +198,7 @@
gen_config,
request_id=request_id,
tool_config=tool_config,
reasoning_config=reasoning_config,
)
input_ids = inputs["input_ids"]
# CB returns plain lists, regular path returns tensors
@@ -216,7 +228,10 @@ async def sse_gen() -> AsyncGenerator[str, None]:
yield "".join(sse_parts)
return

sse_parts.append(self._build_chunk_sse(request_id, model=model_id, content=text))
if isinstance(text, ReasoningText):
sse_parts.append(self._build_chunk_sse(request_id, model=model_id, reasoning_content=text))
else:
sse_parts.append(self._build_chunk_sse(request_id, model=model_id, content=text))

if sse_parts:
yield "".join(sse_parts)
@@ -280,6 +295,7 @@ async def _non_streaming(
gen_config: "GenerationConfig",
gen_manager: BaseGenerateManager,
tool_config: dict | None = None,
reasoning_config: dict | None = None,
) -> JSONResponse:
"""Run generation and return a JSONResponse."""
content, input_len, generated_ids = await gen_manager.generate_non_streaming(
@@ -307,6 +323,10 @@ async def _non_streaming(
for i, tc in enumerate(parsed)
]

reasoning_content = None
if reasoning_config is not None:
content, reasoning_content = parse_reasoning(processor, generated_ids, content, reasoning_config)

if tool_calls is not None:
finish_reason = "tool_calls"
elif hit_max:
@@ -322,6 +342,7 @@
finish_reason=finish_reason,
usage=usage,
tool_calls=tool_calls,
reasoning_content=reasoning_content,
),
media_type="application/json",
)
@@ -354,6 +375,7 @@ def _build_completion(
finish_reason: str,
usage: CompletionUsage | None = None,
tool_calls: list[dict] | None = None,
reasoning_content: str | None = None,
) -> dict:
"""Build a non-streaming ChatCompletion response dict.

@@ -364,11 +386,14 @@
finish_reason (`str`): Why generation stopped (``"stop"``, ``"length"``, ``"tool_calls"``).
usage (`CompletionUsage`, *optional*): Token usage statistics.
tool_calls (`list[dict]`, *optional*): Parsed tool calls, if any.
reasoning_content (`str`, *optional*): Chain-of-thought content extracted from the response.

Returns:
`dict`: Serialized ``ChatCompletion`` ready for JSON response.
"""
message = ChatCompletionMessage(content=content, role="assistant", tool_calls=tool_calls)
message = ChatCompletionMessage(
content=content, role="assistant", tool_calls=tool_calls, reasoning_content=reasoning_content
)
result = ChatCompletion(
id=request_id,
created=int(time.time()),
Expand All @@ -394,6 +419,7 @@ def _build_chunk_sse(
finish_reason: str | None = None,
tool_calls: list | None = None,
usage: CompletionUsage | None = None,
reasoning_content: str | None = None,
) -> str:
"""Build a streaming ``ChatCompletionChunk`` and format it as an SSE ``data:`` line.

@@ -405,6 +431,7 @@
finish_reason (`str`, *optional*): Set on the final chunk.
tool_calls (`list`, *optional*): Tool call deltas.
usage (`CompletionUsage`, *optional*): Token usage (sent with the final chunk).
reasoning_content (`str`, *optional*): Reasoning/thinking delta (OpenAI-compatible extension).

Returns:
`str`: A formatted SSE event string.
@@ -415,7 +442,9 @@
model=model,
choices=[
ChoiceChunk(
delta=ChoiceDelta(content=content, role=role, tool_calls=tool_calls),
delta=ChoiceDelta(
content=content, role=role, tool_calls=tool_calls, reasoning_content=reasoning_content
),
index=0,
finish_reason=finish_reason,
)
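And a streaming sketch covering the request-level override added to `TransformersCompletionCreateParamsStreaming`: a `chat_template_kwargs` object in the request body wins over the server default (per the merge-order comment in `handle_request`), and reasoning deltas arrive as `reasoning_content` on the chunk delta. The endpoint, model name, and the `extra_body` plumbing through the `openai` client are assumptions, not part of this diff.

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="unused")

stream = client.chat.completions.create(
    model="Qwen/Qwen3-4B",  # placeholder model name
    messages=[{"role": "user", "content": "Briefly: why is the sky blue?"}],
    stream=True,
    # Per-request kwargs override the server-level --chat-template-kwargs default.
    extra_body={"chat_template_kwargs": {"enable_thinking": False}},
)

for chunk in stream:
    if not chunk.choices:  # e.g. the final usage-only chunk
        continue
    delta = chunk.choices[0].delta
    reasoning = getattr(delta, "reasoning_content", None)
    if reasoning:
        print(f"[thinking] {reasoning}", end="")
    elif delta.content:
        print(delta.content, end="")
```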