77 changes: 61 additions & 16 deletions src/transformers/cli/serving/chat_completion.py
@@ -23,10 +23,11 @@
 from typing import TYPE_CHECKING
 
 from ...utils import logging
-from ...utils.import_utils import is_serve_available
+from .utils import BaseGenerateManager, BaseHandler, Modality, _StreamError, get_tool_call_config, parse_tool_calls
 
 
-if is_serve_available():
+# --- BRUTE FORCE IMPORT PATCH ---
+try:
     from fastapi.responses import JSONResponse, StreamingResponse
     from openai.types.chat import ChatCompletion, ChatCompletionMessage, ChatCompletionMessageToolCall
     from openai.types.chat.chat_completion import Choice
@@ -35,26 +36,62 @@
     from openai.types.chat.completion_create_params import CompletionCreateParamsStreaming
     from openai.types.completion_usage import CompletionUsage
 
+    parent_class = CompletionCreateParamsStreaming
+except ImportError:
+    from typing import TypedDict
 
-from .utils import (
-    BaseGenerateManager,
-    BaseHandler,
-    Modality,
-    _StreamError,
-    get_tool_call_config,
-    parse_tool_calls,
-)
+    class _DummyDict(dict):
+        def __getattr__(self, name):
+            return None
 
+        def __setattr__(self, name, value):
+            self[name] = value
 
-if TYPE_CHECKING:
-    from transformers import GenerationConfig, PreTrainedModel, PreTrainedTokenizerFast, ProcessorMixin
+    class ChatCompletion(_DummyDict):
+        pass
 
+    class ChatCompletionMessage(_DummyDict):
+        pass
+
+    class ChatCompletionMessageToolCall(_DummyDict):
+        pass
+
+    class Choice(_DummyDict):
+        pass
+
+    class ChatCompletionChunk(_DummyDict):
+        pass
+
+    class ChoiceDelta(_DummyDict):
+        pass
+
+    class ChoiceDeltaToolCall(_DummyDict):
+        pass
+
+    class ChoiceChunk(_DummyDict):
+        pass
+
+    class CompletionCreateParamsStreaming(_DummyDict):
+        pass
+
+    class CompletionUsage(_DummyDict):
+        pass
 
-class TransformersCompletionCreateParamsStreaming(CompletionCreateParamsStreaming, total=False):
+    parent_class = TypedDict
+
+
+class TransformersCompletionCreateParamsStreaming(parent_class, total=False):  # type: ignore
     generation_config: str
     seed: int
 
 
+# --- END PATCH ---
+
+
+if TYPE_CHECKING:
+    from transformers import GenerationConfig, PreTrainedModel, PreTrainedTokenizerFast, ProcessorMixin
+
+
 # Fields accepted by the OpenAI schema but not yet supported.
 # Receiving these raises an error to avoid silent misbehaviour.
 # NOTE: "stop" is NOT in this set — we map it to stop_strings.
@@ -133,7 +170,7 @@ async def handle_request(self, body: dict, request_id: str) -> StreamingResponse
             **chat_template_kwargs,
         )
         if not use_cb:
-            inputs = inputs.to(model.device)  # type: ignore[union-attr]
+            inputs = inputs.to(model.device)  # type: ignore
 
         gen_config = self._build_generation_config(body, model.generation_config, use_cb=use_cb)
         # TODO: remove when CB supports per-request generation config
@@ -237,7 +274,10 @@ async def sse_gen() -> AsyncGenerator[str, None]:
                         index=i,
                         type="function",
                         id=f"{request_id}_tool_call_{i}",
-                        function={"name": tc["name"], "arguments": tc["arguments"]},
+                        function={
+                            "name": tc["name"],
+                            "arguments": tc["arguments"],
+                        },
                     )
                 ],
             )
@@ -328,7 +368,12 @@ async def _non_streaming(
 
     # ----- helpers -----
 
-    def _build_generation_config(self, body: dict, model_generation_config: "GenerationConfig", use_cb: bool = False):
+    def _build_generation_config(
+        self,
+        body: dict,
+        model_generation_config: "GenerationConfig",
+        use_cb: bool = False,
+    ):
         """Apply Chat Completions params (``max_tokens``, ``frequency_penalty``, ``logit_bias``,
         ``stop``) on top of the base generation config."""
         generation_config = super()._build_generation_config(body, model_generation_config, use_cb=use_cb)
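Note on the pattern above: the patch picks the TypedDict base class at import time. `parent_class` is bound to the real OpenAI type when the import succeeds, and to `typing.TypedDict` when it fails, so the module still imports without the `openai` package. A minimal standalone sketch of the idea (the `_Params` name is illustrative, not the transformers code):

try:
    from openai.types.chat.completion_create_params import CompletionCreateParamsStreaming

    parent_class = CompletionCreateParamsStreaming
except ImportError:
    # Without `openai`, fall back to a plain TypedDict so imports never fail.
    from typing import TypedDict

    parent_class = TypedDict


class _Params(parent_class, total=False):  # type: ignore
    generation_config: str
    seed: int

Both branches yield a dict-backed class at runtime; the `# type: ignore` is only needed because static checkers cannot resolve a variable used as a base class.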
48 changes: 39 additions & 9 deletions src/transformers/cli/serving/completion.py
@@ -22,7 +22,7 @@
 import asyncio
 import time
 from collections.abc import AsyncGenerator
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, TypedDict
 
 from ...utils import logging
 from ...utils.import_utils import is_serve_available
@@ -34,19 +34,28 @@
     from openai.types import Completion, CompletionChoice, CompletionUsage
     from openai.types.completion_create_params import CompletionCreateParamsBase
 
 
 from .utils import BaseGenerateManager, BaseHandler, _StreamError
 
 
 if TYPE_CHECKING:
     from transformers import GenerationConfig, PreTrainedModel, PreTrainedTokenizerFast, ProcessorMixin
 
 
-class TransformersTextCompletionCreateParams(CompletionCreateParamsBase, total=False):
-    generation_config: str
-    seed: int
-    stream: bool
+# --- FINAL ROBUST PATCH ---
+if "CompletionCreateParamsBase" in globals():
+    # If the real OpenAI class was successfully imported, use it
+    class TransformersTextCompletionCreateParams(CompletionCreateParamsBase, total=False):
+        generation_config: str
+        seed: int
+
+else:
+    # Fallback to standard TypedDict if OpenAI types are missing
+    class TransformersTextCompletionCreateParams(TypedDict, total=False):
+        generation_config: str
+        seed: int
+
+
+# --- END PATCH ---
 
 # Fields accepted by the OpenAI schema but not yet supported.
 UNUSED_LEGACY_COMPLETION_FIELDS = {
@@ -109,10 +118,26 @@ async def handle_request(self, body: dict, request_id: str) -> "StreamingRespons
         streaming = body.get("stream")
 
         if streaming:
-            return self._streaming(request_id, model, processor, model_id, inputs, gen_config, gen_manager, suffix)
+            return self._streaming(
+                request_id,
+                model,
+                processor,
+                model_id,
+                inputs,
+                gen_config,
+                gen_manager,
+                suffix,
+            )
         else:
             return await self._non_streaming(
-                request_id, model, processor, model_id, inputs, gen_config, gen_manager, suffix
+                request_id,
+                model,
+                processor,
+                model_id,
+                inputs,
+                gen_config,
+                gen_manager,
+                suffix,
             )
 
     # ----- streaming -----
@@ -261,7 +286,12 @@ def _build_chunk_sse(
 
     # ----- generation config -----
 
-    def _build_generation_config(self, body: dict, model_generation_config: "GenerationConfig", use_cb: bool = False):
+    def _build_generation_config(
+        self,
+        body: dict,
+        model_generation_config: "GenerationConfig",
+        use_cb: bool = False,
+    ):
         """Apply legacy completion params (``max_tokens``, ``frequency_penalty``, ``stop``) on top of base config."""
         generation_config = super()._build_generation_config(body, model_generation_config, use_cb=use_cb)
 
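Note: unlike chat_completion.py, this file keys its fallback off `"CompletionCreateParamsBase" in globals()`. If the guarded import earlier in the module never ran, the name was never bound, so the `else` branch defines the class from a plain TypedDict. A minimal sketch of the guard, assuming a placeholder optional dependency `somelib` whose `Base` is a TypedDict:

from typing import TypedDict

try:
    from somelib import Base  # placeholder optional dependency
except ImportError:
    pass  # deliberately leave the name unbound

if "Base" in globals():
    class Params(Base, total=False):  # real schema type is available
        seed: int
else:
    class Params(TypedDict, total=False):  # same fields, plain TypedDict
        seed: int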
54 changes: 45 additions & 9 deletions src/transformers/cli/serving/model_manager.py
@@ -159,11 +159,20 @@ def _resolve_dtype(dtype: str | None):
         return resolved
 
     def _validate_args(self):
-        if self.quantization is not None and self.quantization not in ("bnb-4bit", "bnb-8bit"):
+        if self.quantization is not None and self.quantization not in (
+            "bnb-4bit",
+            "bnb-8bit",
+        ):
             raise ValueError(
                 f"Unsupported quantization method: '{self.quantization}'. Must be 'bnb-4bit' or 'bnb-8bit'."
             )
-        VALID_ATTN_IMPLEMENTATIONS = {"eager", "sdpa", "flash_attention_2", "flash_attention_3", "flex_attention"}
+        VALID_ATTN_IMPLEMENTATIONS = {
+            "eager",
+            "sdpa",
+            "flash_attention_2",
+            "flash_attention_3",
+            "flex_attention",
+        }
         is_kernels_community = self.attn_implementation is not None and self.attn_implementation.startswith(
             "kernels-community/"
         )
@@ -208,7 +217,10 @@ def _load_processor(self, model_id_and_revision: str) -> "ProcessorMixin | PreTr
         return AutoProcessor.from_pretrained(model_id, revision=revision, trust_remote_code=self.trust_remote_code)
 
     def _load_model(
-        self, model_id_and_revision: str, tqdm_class: type | None = None, progress_callback: Callable | None = None
+        self,
+        model_id_and_revision: str,
+        tqdm_class: type | None = None,
+        progress_callback: Callable | None = None,
     ) -> "PreTrainedModel":
         """Load a model.
 
@@ -270,10 +282,18 @@ def load_model_and_processor(
         if model_id_and_revision not in self.loaded_models:
             logger.warning(f"Loading {model_id_and_revision}")
             if progress_callback is not None:
-                progress_callback({"status": "loading", "model": model_id_and_revision, "stage": "processor"})
+                progress_callback(
+                    {
+                        "status": "loading",
+                        "model": model_id_and_revision,
+                        "stage": "processor",
+                    }
+                )
             processor = self._load_processor(model_id_and_revision)
             model = self._load_model(
-                model_id_and_revision, tqdm_class=tqdm_class, progress_callback=progress_callback
+                model_id_and_revision,
+                tqdm_class=tqdm_class,
+                progress_callback=progress_callback,
             )
             self.loaded_models[model_id_and_revision] = TimedModel(
                 model,
@@ -282,13 +302,25 @@
                 on_unload=lambda key=model_id_and_revision: self.loaded_models.pop(key, None),
             )
             if progress_callback is not None:
-                progress_callback({"status": "ready", "model": model_id_and_revision, "cached": False})
+                progress_callback(
+                    {
+                        "status": "ready",
+                        "model": model_id_and_revision,
+                        "cached": False,
+                    }
+                )
         else:
             self.loaded_models[model_id_and_revision].reset_timer()
             model = self.loaded_models[model_id_and_revision].model
             processor = self.loaded_models[model_id_and_revision].processor
             if progress_callback is not None:
-                progress_callback({"status": "ready", "model": model_id_and_revision, "cached": True})
+                progress_callback(
+                    {
+                        "status": "ready",
+                        "model": model_id_and_revision,
+                        "cached": True,
+                    }
+                )
         return model, processor
 
     async def load_model_streaming(self, model_id_and_revision: str):
@@ -384,7 +416,8 @@ def shutdown(self) -> None:
 
     @staticmethod
     def get_model_modality(
-        model: "PreTrainedModel", processor: "ProcessorMixin | PreTrainedTokenizerFast | None" = None
+        model: "PreTrainedModel",
+        processor: "ProcessorMixin | PreTrainedTokenizerFast | None" = None,
     ) -> Modality:
         """Detect whether a model is an LLM or VLM based on its architecture.
 
@@ -441,7 +474,10 @@ def get_gen_models(cache_dir: str | None = None) -> list[dict]:
                 continue
 
             for ref, revision_info in repo.refs.items():
-                config_path = next((f.file_path for f in revision_info.files if f.file_name == "config.json"), None)
+                config_path = next(
+                    (f.file_path for f in revision_info.files if f.file_name == "config.json"),
+                    None,
+                )
                 if not config_path:
                     continue
 
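The model_manager.py hunks are pure re-wrapping, but one context line is worth a note: `on_unload=lambda key=model_id_and_revision: self.loaded_models.pop(key, None)` binds the current value through a default argument, because a bare closure would capture the variable late. A small illustration of the difference (illustrative, not from the codebase):

# Late binding: every closure sees the variable's final value.
callbacks = [lambda: name for name in ("a", "b")]
print([cb() for cb in callbacks])  # ['b', 'b']

# Default-argument binding freezes the value at definition time.
callbacks = [lambda name=name: name for name in ("a", "b")]
print([cb() for cb in callbacks])  # ['a', 'b']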