Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 16 additions & 9 deletions src/transformers/cli/serving/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from typing import TYPE_CHECKING

from transformers.utils import logging
from transformers.utils.chat_template_utils import get_message_content


if TYPE_CHECKING:
Expand Down Expand Up @@ -194,7 +195,7 @@ def parse(text: str, tool_format: dict) -> list[dict] | None:
if result is not None:
tool_calls.append({"name": result[0], "arguments": result[1]})
pos = e + len(end)
return tool_calls if tool_calls else None
return tool_calls or None

def _parse_block(self, block: str) -> dict | None:
"""Parse a buffered tool call block. Returns ``{"name": str, "arguments": str}`` or None."""
Expand Down Expand Up @@ -927,17 +928,23 @@ def get_processor_inputs_from_messages(messages: list[dict], modality: Modality)
parsed = {"role": message["role"], "content": []}

if modality == Modality.LLM:
if isinstance(message["content"], str):
parsed["content"] = message["content"]
elif isinstance(message["content"], list):
texts = [c["text"] for c in message["content"] if c["type"] == "text"]
# Default to "" so assistant messages with tool_calls but no
# content (#45290) become an empty string for LLM context.
message_content = get_message_content(message, default="")
if isinstance(message_content, str):
parsed["content"] = message_content
elif isinstance(message_content, list):
texts = [c["text"] for c in message_content if c["type"] == "text"]
parsed["content"] = " ".join(texts)

elif modality == Modality.VLM:
if isinstance(message["content"], str):
parsed["content"].append({"type": "text", "text": message["content"]})
else:
for content in message["content"]:
# Default to [] so assistant messages with tool_calls but no
# content (#45290) iterate over zero blocks for VLM context.
message_content = get_message_content(message, default=[])
if isinstance(message_content, str):
parsed["content"].append({"type": "text", "text": message_content})
elif isinstance(message_content, list):
for content in message_content:
if content["type"] == "text":
parsed["content"].append(content)
elif content["type"] == "image_url":
Expand Down
3 changes: 2 additions & 1 deletion src/transformers/models/smolvlm/processing_smolvlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
from ...tokenization_utils_base import BatchEncoding, TextInput
from ...utils import auto_docstring, is_num2words_available, logging
from ...utils.chat_template_utils import get_message_content
from ...video_utils import VideoInput


Expand Down Expand Up @@ -333,7 +334,7 @@ def apply_chat_template(
(isinstance(content, dict) and content["type"] == "video")
for conversation in conversations
for message in conversation
for content in message["content"]
for content in get_message_content(message, default=[])
)
if chat_template is None and has_video:
# re-assign to the correct default template for BC, if user is not requesting their own template
Expand Down
17 changes: 11 additions & 6 deletions src/transformers/processing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
list_repo_templates,
logging,
)
from .utils.chat_template_utils import _get_template_variables, render_jinja_template
from .utils.chat_template_utils import _get_template_variables, get_message_content, render_jinja_template
from .utils.type_validators import (
device_validator,
image_size_validator,
Expand Down Expand Up @@ -1804,10 +1804,13 @@ def apply_chat_template(
for conversation in conversations:
images, videos = [], []
for message in conversation:
visuals = [content for content in message["content"] if content["type"] in ["image", "video"]]
# Use get_message_content to safely default to [] when an
# assistant message has tool_calls but no content (#45290).
message_content = get_message_content(message, default=[])
visuals = [content for content in message_content if content["type"] in ["image", "video"]]
audio_fnames = [
content[key]
for content in message["content"]
for content in message_content
for key in ["audio", "url", "path"]
if key in content and content["type"] == "audio"
]
Expand All @@ -1834,7 +1837,9 @@ def apply_chat_template(
for fname in video_fnames:
# This updates the template in-place and adds audio entry
# to ensure `audio` token is added by jinja
message["content"].append({"type": "audio"})
content_list = get_message_content(message, default=[])
if isinstance(content_list, list):
content_list.append({"type": "audio"})
batch_audios.append(load_audio(fname, sampling_rate=sampling_rate))

# Currently all processors can accept nested list of batches, but not flat list of visuals
Expand Down Expand Up @@ -1886,7 +1891,7 @@ def apply_chat_template(
text=prompt,
images=batch_images if images_exist else None,
videos=batch_videos if videos_exist else None,
audio=batch_audios if batch_audios else None,
audio=batch_audios or None,
**processor_kwargs,
)

Expand All @@ -1913,7 +1918,7 @@ def apply_chat_template(
# Ensure end_pos is also within bounds
if end_pos > len(input_ids[i]):
end_pos = len(input_ids[i])
for token_id in range(start_pos, end_pos if end_pos else len(input_ids[i])):
for token_id in range(start_pos, end_pos or len(input_ids[i])):
current_mask[token_id] = 1
assistant_masks.append(current_mask)
out["assistant_masks"] = assistant_masks
Expand Down
5 changes: 4 additions & 1 deletion src/transformers/tokenization_mistral_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
TruncationStrategy,
)
from transformers.utils import PaddingStrategy, TensorType, add_end_docstrings, logging, to_py_obj
from transformers.utils.chat_template_utils import get_message_content
from transformers.utils.import_utils import is_mistral_common_available, is_torch_available, requires


Expand Down Expand Up @@ -1127,7 +1128,9 @@ def _maybe_adapt_message(message: dict[str, Any]) -> None:
"""Adapt message to `mistral-common` format and leave validation to `mistral-common`."""
if not isinstance(message, dict):
return message
maybe_list_content: str | list[dict[str, str | dict[str, Any]]] | None = message.get("content")
# Use get_message_content to keep the missing-content convention consistent
# across all chat-template entry points (#45290).
maybe_list_content: str | list[dict[str, str | dict[str, Any]]] = get_message_content(message, default=[])
if not maybe_list_content or isinstance(maybe_list_content, str):
return message

Expand Down
70 changes: 64 additions & 6 deletions src/transformers/utils/chat_template_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,43 @@


BASIC_TYPES = (int, float, str, bool, Any, type(None), ...)


def get_message_content(message: dict, *, default: list | str | None = None) -> Any:
    """Fetch ``message["content"]``, substituting *default* when the key is absent.

    The OpenAI chat-completion spec allows an assistant message to omit
    ``content`` entirely when it carries ``tool_calls`` (the model is asking
    for a tool to be run and has produced no text yet). Any code that walks
    message content — collecting multimodal content blocks, joining text
    segments for a text-only model, or rendering a Jinja chat template —
    therefore needs a safe fallback instead of a ``KeyError``. This helper is
    the single place that fallback convention lives, so processors,
    tokenizers, and serving handlers stay in sync instead of each
    re-implementing it inline and drifting apart.

    Args:
        message:
            A chat message dict such as
            ``{"role": "assistant", "tool_calls": [...]}``.
        default:
            **Keyword-only.** Value substituted when ``content`` is missing:
            pass ``[]`` for callers that iterate multimodal content blocks,
            and ``""`` for callers that treat content as a single text
            string. Leaving it as ``None`` (the default) yields an empty
            list, but callers SHOULD pass an explicit default matching the
            shape their code expects.

    Returns:
        ``message["content"]`` when the key is present, otherwise *default*.
    """
    fallback = [] if default is None else default
    return message.get("content", fallback)


# Extracts the initial segment of the docstring, containing the function description
description_re = re.compile(r"^(.*?)[\n\s]*(Args:|Returns:|Raises:|\Z)", re.DOTALL)
# Extracts the Args: block from the docstring
Expand Down Expand Up @@ -539,9 +576,17 @@ def render_jinja_template(
if hasattr(chat, "messages"):
# Indicates it's a Conversation object
chat = chat.messages
# Ensure all messages have a content field for Jinja template compatibility.
# Per the OpenAI spec, assistant messages with tool_calls MAY omit content;
# default to an empty string here (text-only template default) so the
# rendered template still has a content field to reference. See
# ``get_message_content`` above for the convention.
chat = [
{**msg, "content": get_message_content(msg, default="")} if "content" not in msg else msg for msg in chat
]
if continue_final_message:
chat = deepcopy(chat)
final_message = chat[-1]["content"]
final_message = get_message_content(chat[-1], default="")
if isinstance(final_message, (list, tuple)):
for content_block in reversed(final_message):
if "text" in content_block:
Expand All @@ -554,7 +599,7 @@ def render_jinja_template(
"continue_final_message is set but we could not find any text to continue in the final message!"
)
else:
chat[-1]["content"] = chat[-1]["content"] + continue_final_message_tag
chat[-1]["content"] = final_message + continue_final_message_tag
if return_assistant_tokens_mask:
rendered_chat, generation_indices = _render_with_assistant_indices(
compiled_template=compiled_template,
Expand Down Expand Up @@ -597,22 +642,35 @@ def render_jinja_template(

def is_valid_message(message):
    """
    Return ``True`` when the input is a well-formed chat message: a dict carrying a
    ``"role"`` key plus at least one of ``"content"`` or ``"tool_calls"``. Assistant
    messages that request tool execution may legitimately omit ``"content"`` per the
    OpenAI chat-completion spec (see ``get_message_content``).
    """
    return (
        isinstance(message, dict)
        and "role" in message
        and ("content" in message or "tool_calls" in message)
    )


class Chat:
    """This class is intended to just be used internally for pipelines and not exposed to users. We convert chats
    to this format because the rest of the pipeline code tends to assume that lists of messages are
    actually a batch of samples rather than messages in the same conversation.

    Each message must be a dict with a ``"role"`` key and either a ``"content"`` key or a
    ``"tool_calls"`` key. Assistant messages that request tool execution may omit ``"content"``
    per the OpenAI chat-completion spec — see ``is_valid_message`` and ``get_message_content``.
    """

    def __init__(self, messages: list[dict]):
        # NOTE: the annotation was previously `dict`, but the loop below shows this is
        # a list of message dicts — annotate it accordingly.
        # Fail fast on the first malformed entry so the error points at the user's
        # input rather than surfacing later as an obscure KeyError deep in a pipeline.
        for message in messages:
            if not is_valid_message(message):
                raise ValueError(
                    "When passing chat dicts as input, each dict must have a 'role' key and either a "
                    "'content' or 'tool_calls' key."
                )
        self.messages = messages
49 changes: 49 additions & 0 deletions tests/cli/test_serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,55 @@ def test_vlm_multi_turn(self):
self.assertIsInstance(msg["content"], list)
self.assertEqual(msg["content"][0]["type"], "text")

# Regression tests for issue #45290 — assistant messages with tool_calls but no content
def test_vlm_tool_calls_without_content(self):
    """VLM modality must not crash on an assistant message that has tool_calls but no content field."""
    get_processor_inputs_from_messages = BaseHandler.get_processor_inputs_from_messages

    weather_call = {
        "id": "call_def456",
        "type": "function",
        "function": {"name": "get_weather", "arguments": '{"location": "London"}'},
    }
    messages = [
        {"role": "user", "content": "What's the weather in London?"},
        {"role": "assistant", "tool_calls": [weather_call]},
    ]

    parsed = get_processor_inputs_from_messages(messages, Modality.VLM)
    self.assertIsInstance(parsed, list)
    assistant_entry = parsed[1]
    self.assertEqual(assistant_entry["role"], "assistant")
    # Missing content should default to an empty list for VLM
    self.assertEqual(assistant_entry["content"], [])

def test_llm_tool_calls_without_content(self):
    """LLM modality must not crash on an assistant message that has tool_calls but no content field."""
    get_processor_inputs_from_messages = BaseHandler.get_processor_inputs_from_messages

    weather_call = {
        "id": "call_def456",
        "type": "function",
        "function": {"name": "get_weather", "arguments": '{"location": "London"}'},
    }
    messages = [
        {"role": "user", "content": "What's the weather in London?"},
        {"role": "assistant", "tool_calls": [weather_call]},
    ]

    parsed = get_processor_inputs_from_messages(messages, Modality.LLM)
    self.assertIsInstance(parsed, list)
    assistant_entry = parsed[1]
    self.assertEqual(assistant_entry["role"], "assistant")
    # Missing content should default to an empty string for LLM
    self.assertEqual(assistant_entry["content"], "")


class TestGenerativeModelList(unittest.TestCase):
def test_lists_only_generative_models(self):
Expand Down
53 changes: 53 additions & 0 deletions tests/test_processing_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2017,3 +2017,56 @@ def test_get_num_multimodal_tokens_matches_processor_call(self):
num_image_tokens_from_call = inputs.mm_token_type_ids.sum(-1).tolist()
num_image_tokens_from_helper = processor._get_num_multimodal_tokens(image_sizes=image_sizes)
self.assertListEqual(num_image_tokens_from_call, num_image_tokens_from_helper["num_image_tokens"])

def test_apply_chat_template_with_tool_calls_no_content(self):
    """Regression test for issue #45290: assistant messages with tool_calls but no content field should not raise KeyError."""
    processor = self.get_processor()

    # Skip if processor doesn't support apply_chat_template
    if not hasattr(processor, "apply_chat_template"):
        self.skipTest("Processor doesn't support apply_chat_template")

    # Skip processors without a default chat template — they raise ValueError
    # before reaching the patched code path.
    chat_template = getattr(processor, "chat_template", None)
    if chat_template is None:
        self.skipTest("Processor has no chat_template")

    # The bug only manifests for templates that actually reference tool_calls.
    # A template that never inspects that field cannot trip the KeyError, so
    # asserting against it would be meaningless — skip those processors.
    if isinstance(chat_template, str):
        template_text = chat_template
    elif isinstance(chat_template, dict):
        template_text = chat_template.get("default")
    else:
        template_text = ""
    if not template_text or "tool_calls" not in template_text:
        self.skipTest("Processor's chat_template does not support tool_calls")

    # Message with tool_calls but no content field (valid per OpenAI spec)
    messages = [
        {"role": "user", "content": "What's the weather?"},
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "id": "call_123",
                    "type": "function",
                    "function": {"name": "get_weather", "arguments": '{"location": "NYC"}'},
                }
            ],
        },
    ]

    # Rendering must not raise KeyError for the missing 'content' field.
    try:
        rendered = processor.apply_chat_template(messages, tokenize=False)
    except KeyError as e:
        if "content" not in str(e):
            raise
        self.fail(f"apply_chat_template raised KeyError for missing 'content' field: {e}")
    else:
        # If we get here, the fix worked
        self.assertIsInstance(rendered, str)
32 changes: 32 additions & 0 deletions tests/test_tokenization_mistral_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2274,3 +2274,35 @@ def test_prepare_for_model(self):
# unsupported kwargs should raise ValueError
with self.assertRaises(ValueError):
self.tokenizer.prepare_for_model(token_ids, add_special_tokens=False, unsupported_arg="")

@require_mistral_common
def test_apply_chat_template_with_tool_calls_no_content(self):
    """Regression test for issue #45290: Mistral tokenizer should handle assistant messages with tool_calls but no content field."""
    tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")

    # Mistral tokenizer requires tool_call ids to be exactly 9 alphanumeric chars.
    weather_call = {
        "id": "abc123xyz",
        "type": "function",
        "function": {"name": "get_weather", "arguments": '{"location": "Paris"}'},
    }
    # Message with tool_calls but no content field (valid per OpenAI spec)
    messages = [
        {"role": "user", "content": "What's the weather in Paris?"},
        {"role": "assistant", "tool_calls": [weather_call]},
    ]

    # Rendering must not raise KeyError when content is absent alongside tool_calls.
    try:
        rendered = tokenizer.apply_chat_template(messages, tokenize=False)
    except KeyError as e:
        if "content" not in str(e):
            raise
        self.fail(f"Mistral tokenizer raised KeyError for missing 'content' field: {e}")
    else:
        # If we get here, the fix worked
        self.assertIsInstance(rendered, str)
Loading
Loading