diff --git a/src/transformers/cli/serving/utils.py b/src/transformers/cli/serving/utils.py index d9828d123b12..781dbcd332f0 100644 --- a/src/transformers/cli/serving/utils.py +++ b/src/transformers/cli/serving/utils.py @@ -31,6 +31,7 @@ from typing import TYPE_CHECKING from transformers.utils import logging +from transformers.utils.chat_template_utils import get_message_content if TYPE_CHECKING: @@ -194,7 +195,7 @@ def parse(text: str, tool_format: dict) -> list[dict] | None: if result is not None: tool_calls.append({"name": result[0], "arguments": result[1]}) pos = e + len(end) - return tool_calls if tool_calls else None + return tool_calls or None def _parse_block(self, block: str) -> dict | None: """Parse a buffered tool call block. Returns ``{"name": str, "arguments": str}`` or None.""" @@ -927,17 +928,23 @@ def get_processor_inputs_from_messages(messages: list[dict], modality: Modality) parsed = {"role": message["role"], "content": []} if modality == Modality.LLM: - if isinstance(message["content"], str): - parsed["content"] = message["content"] - elif isinstance(message["content"], list): - texts = [c["text"] for c in message["content"] if c["type"] == "text"] + # Default to "" so assistant messages with tool_calls but no + # content (#45290) become an empty string for LLM context. + message_content = get_message_content(message, default="") + if isinstance(message_content, str): + parsed["content"] = message_content + elif isinstance(message_content, list): + texts = [c["text"] for c in message_content if c["type"] == "text"] parsed["content"] = " ".join(texts) elif modality == Modality.VLM: - if isinstance(message["content"], str): - parsed["content"].append({"type": "text", "text": message["content"]}) - else: - for content in message["content"]: + # Default to [] so assistant messages with tool_calls but no + # content (#45290) iterate over zero blocks for VLM context. + message_content = get_message_content(message, default=[]) + if isinstance(message_content, str): + parsed["content"].append({"type": "text", "text": message_content}) + elif isinstance(message_content, list): + for content in message_content: if content["type"] == "text": parsed["content"].append(content) elif content["type"] == "image_url": diff --git a/src/transformers/models/smolvlm/processing_smolvlm.py b/src/transformers/models/smolvlm/processing_smolvlm.py index 69eb7d6e5a61..8cae61b9324a 100644 --- a/src/transformers/models/smolvlm/processing_smolvlm.py +++ b/src/transformers/models/smolvlm/processing_smolvlm.py @@ -23,6 +23,7 @@ from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack from ...tokenization_utils_base import BatchEncoding, TextInput from ...utils import auto_docstring, is_num2words_available, logging +from ...utils.chat_template_utils import get_message_content from ...video_utils import VideoInput @@ -333,7 +334,7 @@ def apply_chat_template( (isinstance(content, dict) and content["type"] == "video") for conversation in conversations for message in conversation - for content in message["content"] + for content in get_message_content(message, default=[]) ) if chat_template is None and has_video: # re-assign to the correct default template for BC, if user is not requesting their own template diff --git a/src/transformers/processing_utils.py b/src/transformers/processing_utils.py index 8a73a68b0cf5..1df6fb42d3b4 100644 --- a/src/transformers/processing_utils.py +++ b/src/transformers/processing_utils.py @@ -58,7 +58,7 @@ list_repo_templates, logging, ) -from .utils.chat_template_utils import _get_template_variables, render_jinja_template +from .utils.chat_template_utils import _get_template_variables, get_message_content, render_jinja_template from .utils.type_validators import ( device_validator, image_size_validator, @@ -1804,10 +1804,13 @@ def apply_chat_template( for conversation in conversations: images, videos = [], [] for message in conversation: - visuals = [content for content in message["content"] if content["type"] in ["image", "video"]] + # Use get_message_content to safely default to [] when an + # assistant message has tool_calls but no content (#45290). + message_content = get_message_content(message, default=[]) + visuals = [content for content in message_content if content["type"] in ["image", "video"]] audio_fnames = [ content[key] - for content in message["content"] + for content in message_content for key in ["audio", "url", "path"] if key in content and content["type"] == "audio" ] @@ -1834,7 +1837,9 @@ def apply_chat_template( for fname in video_fnames: # This updates the template in-place and adds audio entry # to ensure `audio` token is added by jinja - message["content"].append({"type": "audio"}) + content_list = get_message_content(message, default=[]) + if isinstance(content_list, list): + content_list.append({"type": "audio"}) batch_audios.append(load_audio(fname, sampling_rate=sampling_rate)) # Currently all processors can accept nested list of batches, but not flat list of visuals @@ -1886,7 +1891,7 @@ def apply_chat_template( text=prompt, images=batch_images if images_exist else None, videos=batch_videos if videos_exist else None, - audio=batch_audios if batch_audios else None, + audio=batch_audios or None, **processor_kwargs, ) @@ -1913,7 +1918,7 @@ def apply_chat_template( # Ensure end_pos is also within bounds if end_pos > len(input_ids[i]): end_pos = len(input_ids[i]) - for token_id in range(start_pos, end_pos if end_pos else len(input_ids[i])): + for token_id in range(start_pos, end_pos or len(input_ids[i])): current_mask[token_id] = 1 assistant_masks.append(current_mask) out["assistant_masks"] = assistant_masks diff --git a/src/transformers/tokenization_mistral_common.py b/src/transformers/tokenization_mistral_common.py index 1f218fe40873..a67537708b7c 100644 --- a/src/transformers/tokenization_mistral_common.py +++ b/src/transformers/tokenization_mistral_common.py @@ -35,6 +35,7 @@ TruncationStrategy, ) from transformers.utils import PaddingStrategy, TensorType, add_end_docstrings, logging, to_py_obj +from transformers.utils.chat_template_utils import get_message_content from transformers.utils.import_utils import is_mistral_common_available, is_torch_available, requires @@ -1127,7 +1128,9 @@ def _maybe_adapt_message(message: dict[str, Any]) -> None: """Adapt message to `mistral-common` format and leave validation to `mistral-common`.""" if not isinstance(message, dict): return message - maybe_list_content: str | list[dict[str, str | dict[str, Any]]] | None = message.get("content") + # Use get_message_content to keep the missing-content convention consistent + # across all chat-template entry points (#45290). + maybe_list_content: str | list[dict[str, str | dict[str, Any]]] = get_message_content(message, default=[]) if not maybe_list_content or isinstance(maybe_list_content, str): return message diff --git a/src/transformers/utils/chat_template_utils.py b/src/transformers/utils/chat_template_utils.py index 3c5ec4acf8d3..57d8a5ac1883 100644 --- a/src/transformers/utils/chat_template_utils.py +++ b/src/transformers/utils/chat_template_utils.py @@ -51,6 +51,43 @@ BASIC_TYPES = (int, float, str, bool, Any, type(None), ...) + + +def get_message_content(message: dict, *, default: list | str | None = None) -> Any: + """Return the ``content`` field of a chat message, defaulting if missing. + + Per the OpenAI chat-completion spec, an assistant message MAY omit + ``content`` when ``tool_calls`` is present (the assistant is requesting + tool execution and has no textual reply yet). Code that iterates over + message content for processing — extracting visuals from a multimodal + conversation, joining text segments for a text-only model, or rendering + a Jinja template — must default to a safe empty value rather than + raising ``KeyError``. + + This helper centralizes that convention so individual processors, + tokenizers, and serving handlers don't each have to re-implement it + inline (and don't drift apart over time). + + Args: + message: + A chat message dict, e.g. + ``{"role": "assistant", "tool_calls": [...]}``. + default: + **Keyword-only.** The value to return when the ``content`` key + is missing. Pass ``[]`` for code that iterates over multimodal + content blocks, and ``""`` for code that treats content as a + single text string. If ``None`` (the default), an empty list + is returned — callers SHOULD pass an explicit default that + matches the shape their code expects. + + Returns: + ``message["content"]`` if the key is present, otherwise *default*. + """ + if default is None: + default = [] + return message.get("content", default) + + # Extracts the initial segment of the docstring, containing the function description description_re = re.compile(r"^(.*?)[\n\s]*(Args:|Returns:|Raises:|\Z)", re.DOTALL) # Extracts the Args: block from the docstring @@ -539,9 +576,17 @@ def render_jinja_template( if hasattr(chat, "messages"): # Indicates it's a Conversation object chat = chat.messages + # Ensure all messages have a content field for Jinja template compatibility. + # Per the OpenAI spec, assistant messages with tool_calls MAY omit content; + # default to an empty string here (text-only template default) so the + # rendered template still has a content field to reference. See + # ``get_message_content`` above for the convention. + chat = [ + {**msg, "content": get_message_content(msg, default="")} if "content" not in msg else msg for msg in chat + ] if continue_final_message: chat = deepcopy(chat) - final_message = chat[-1]["content"] + final_message = get_message_content(chat[-1], default="") if isinstance(final_message, (list, tuple)): for content_block in reversed(final_message): if "text" in content_block: @@ -554,7 +599,7 @@ def render_jinja_template( "continue_final_message is set but we could not find any text to continue in the final message!" ) else: - chat[-1]["content"] = chat[-1]["content"] + continue_final_message_tag + chat[-1]["content"] = final_message + continue_final_message_tag if return_assistant_tokens_mask: rendered_chat, generation_indices = _render_with_assistant_indices( compiled_template=compiled_template, @@ -597,11 +642,16 @@ def render_jinja_template( def is_valid_message(message): """ - Check that input is a valid message in a chat, namely a dict with "role" and "content" keys. + Check that input is a valid message in a chat. A valid message is a dict with a ``"role"`` + key and either a ``"content"`` key or a ``"tool_calls"`` key — assistant messages that + request tool execution may omit ``"content"`` per the OpenAI chat-completion spec + (see ``get_message_content``). """ if not isinstance(message, dict): return False - if not ("role" in message and "content" in message): + if "role" not in message: + return False + if "content" not in message and "tool_calls" not in message: return False return True @@ -609,10 +659,18 @@ def is_valid_message(message): class Chat: """This class is intended to just be used internally for pipelines and not exposed to users. We convert chats to this format because the rest of the pipeline code tends to assume that lists of messages are - actually a batch of samples rather than messages in the same conversation.""" + actually a batch of samples rather than messages in the same conversation. + + Each message must be a dict with a ``"role"`` key and either a ``"content"`` key or a + ``"tool_calls"`` key. Assistant messages that request tool execution may omit ``"content"`` + per the OpenAI chat-completion spec — see ``is_valid_message`` and ``get_message_content``. + """ def __init__(self, messages: dict): for message in messages: if not is_valid_message(message): - raise ValueError("When passing chat dicts as input, each dict must have a 'role' and 'content' key.") + raise ValueError( + "When passing chat dicts as input, each dict must have a 'role' key and either a " + "'content' or 'tool_calls' key." + ) self.messages = messages diff --git a/tests/cli/test_serve.py b/tests/cli/test_serve.py index 9be3dbeb99ff..06cf5e007b35 100644 --- a/tests/cli/test_serve.py +++ b/tests/cli/test_serve.py @@ -209,6 +209,55 @@ def test_vlm_multi_turn(self): self.assertIsInstance(msg["content"], list) self.assertEqual(msg["content"][0]["type"], "text") + # Regression tests for issue #45290 — assistant messages with tool_calls but no content + def test_vlm_tool_calls_without_content(self): + """VLM modality must not crash on an assistant message that has tool_calls but no content field.""" + get_processor_inputs_from_messages = BaseHandler.get_processor_inputs_from_messages + + messages = [ + {"role": "user", "content": "What's the weather in London?"}, + { + "role": "assistant", + "tool_calls": [ + { + "id": "call_def456", + "type": "function", + "function": {"name": "get_weather", "arguments": '{"location": "London"}'}, + } + ], + }, + ] + + result = get_processor_inputs_from_messages(messages, Modality.VLM) + self.assertIsInstance(result, list) + self.assertEqual(result[1]["role"], "assistant") + # Missing content should default to an empty list for VLM + self.assertEqual(result[1]["content"], []) + + def test_llm_tool_calls_without_content(self): + """LLM modality must not crash on an assistant message that has tool_calls but no content field.""" + get_processor_inputs_from_messages = BaseHandler.get_processor_inputs_from_messages + + messages = [ + {"role": "user", "content": "What's the weather in London?"}, + { + "role": "assistant", + "tool_calls": [ + { + "id": "call_def456", + "type": "function", + "function": {"name": "get_weather", "arguments": '{"location": "London"}'}, + } + ], + }, + ] + + result = get_processor_inputs_from_messages(messages, Modality.LLM) + self.assertIsInstance(result, list) + self.assertEqual(result[1]["role"], "assistant") + # Missing content should default to an empty string for LLM + self.assertEqual(result[1]["content"], "") + class TestGenerativeModelList(unittest.TestCase): def test_lists_only_generative_models(self): diff --git a/tests/test_processing_common.py b/tests/test_processing_common.py index 5c632571b678..7ff6fda30981 100644 --- a/tests/test_processing_common.py +++ b/tests/test_processing_common.py @@ -2017,3 +2017,56 @@ def test_get_num_multimodal_tokens_matches_processor_call(self): num_image_tokens_from_call = inputs.mm_token_type_ids.sum(-1).tolist() num_image_tokens_from_helper = processor._get_num_multimodal_tokens(image_sizes=image_sizes) self.assertListEqual(num_image_tokens_from_call, num_image_tokens_from_helper["num_image_tokens"]) + + def test_apply_chat_template_with_tool_calls_no_content(self): + """Regression test for issue #45290: assistant messages with tool_calls but no content field should not raise KeyError.""" + processor = self.get_processor() + + # Skip if processor doesn't support apply_chat_template + if not hasattr(processor, "apply_chat_template"): + self.skipTest("Processor doesn't support apply_chat_template") + + # Skip processors without a default chat template — they raise ValueError + # before reaching the patched code path. + chat_template = getattr(processor, "chat_template", None) + if chat_template is None: + self.skipTest("Processor has no chat_template") + + # The bug only manifests for chat templates that actually iterate over + # tool_calls. For templates that don't reference tool_calls at all, the + # test is meaningless because the template would never inspect that + # field. Skip those rather than asserting against templates that were + # never designed to handle tool-calling at all. + chat_template_source = ( + chat_template + if isinstance(chat_template, str) + else (chat_template.get("default") if isinstance(chat_template, dict) else "") + ) + if not chat_template_source or "tool_calls" not in chat_template_source: + self.skipTest("Processor's chat_template does not support tool_calls") + + # Message with tool_calls but no content field (valid per OpenAI spec) + messages = [ + {"role": "user", "content": "What's the weather?"}, + { + "role": "assistant", + "tool_calls": [ + { + "id": "call_123", + "type": "function", + "function": {"name": "get_weather", "arguments": '{"location": "NYC"}'}, + } + ], + }, + ] + + # Should not raise KeyError + try: + result = processor.apply_chat_template(messages, tokenize=False) + # If we get here, the fix worked + self.assertIsInstance(result, str) + except KeyError as e: + if "content" in str(e): + self.fail(f"apply_chat_template raised KeyError for missing 'content' field: {e}") + else: + raise diff --git a/tests/test_tokenization_mistral_common.py b/tests/test_tokenization_mistral_common.py index 96d44e2cc80c..fcacb5d66552 100644 --- a/tests/test_tokenization_mistral_common.py +++ b/tests/test_tokenization_mistral_common.py @@ -2274,3 +2274,35 @@ def test_prepare_for_model(self): # unsupported kwargs should raise ValueError with self.assertRaises(ValueError): self.tokenizer.prepare_for_model(token_ids, add_special_tokens=False, unsupported_arg="") + + @require_mistral_common + def test_apply_chat_template_with_tool_calls_no_content(self): + """Regression test for issue #45290: Mistral tokenizer should handle assistant messages with tool_calls but no content field.""" + tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3") + + # Message with tool_calls but no content field (valid per OpenAI spec) + messages = [ + {"role": "user", "content": "What's the weather in Paris?"}, + { + "role": "assistant", + "tool_calls": [ + { + # Mistral tokenizer requires tool_call ids to be exactly 9 alphanumeric chars. + "id": "abc123xyz", + "type": "function", + "function": {"name": "get_weather", "arguments": '{"location": "Paris"}'}, + } + ], + }, + ] + + # Should not raise KeyError when processing messages with tool_calls but no content + try: + result = tokenizer.apply_chat_template(messages, tokenize=False) + # If we get here, the fix worked + self.assertIsInstance(result, str) + except KeyError as e: + if "content" in str(e): + self.fail(f"Mistral tokenizer raised KeyError for missing 'content' field: {e}") + else: + raise diff --git a/tests/utils/test_chat_template_utils.py b/tests/utils/test_chat_template_utils.py index ec94e5e35c57..b67d96dee582 100644 --- a/tests/utils/test_chat_template_utils.py +++ b/tests/utils/test_chat_template_utils.py @@ -16,6 +16,7 @@ from typing import Literal from transformers.utils import DocstringParsingException, TypeHintParsingException, get_json_schema +from transformers.utils.chat_template_utils import get_message_content, is_valid_message class JsonSchemaGeneratorTest(unittest.TestCase): @@ -611,3 +612,72 @@ def fn(x: str, y: list[str | int] | None, z: tuple[str | int, str] = (42, "hello }, } self.assertEqual(schema["function"], expected_schema) + + +class GetMessageContentTest(unittest.TestCase): + """Tests for get_message_content (issue #45290).""" + + def test_returns_content_when_present(self): + self.assertEqual(get_message_content({"role": "user", "content": "hi"}), "hi") + self.assertEqual( + get_message_content({"role": "user", "content": [{"type": "text", "text": "x"}]}), + [{"type": "text", "text": "x"}], + ) + + def test_returns_default_list_when_content_missing(self): + # Default is [] — multimodal-iterating callers expect a list. + self.assertEqual(get_message_content({"role": "assistant", "tool_calls": []}), []) + + def test_returns_explicit_default_string(self): + # Text-only callers pass default="". + self.assertEqual( + get_message_content({"role": "assistant", "tool_calls": []}, default=""), + "", + ) + + def test_returns_explicit_default_list(self): + self.assertEqual( + get_message_content({"role": "assistant", "tool_calls": []}, default=[]), + [], + ) + + def test_default_is_keyword_only(self): + # Passing default positionally must raise TypeError to keep call sites self-documenting. + with self.assertRaises(TypeError): + get_message_content({"role": "user"}, []) # type: ignore[misc] + + +class IsValidMessageTest(unittest.TestCase): + """Tests for is_valid_message (issue #45290).""" + + def test_accepts_role_and_content(self): + self.assertTrue(is_valid_message({"role": "user", "content": "hi"})) + self.assertTrue(is_valid_message({"role": "assistant", "content": [{"type": "text", "text": "ok"}]})) + + def test_accepts_role_and_tool_calls_without_content(self): + # Per OpenAI chat-completion spec, assistant messages may omit content + # when tool_calls is present. + self.assertTrue( + is_valid_message( + { + "role": "assistant", + "tool_calls": [ + {"id": "abc12345x", "type": "function", "function": {"name": "f", "arguments": "{}"}} + ], + } + ) + ) + + def test_rejects_missing_role(self): + self.assertFalse(is_valid_message({"content": "hi"})) + self.assertFalse(is_valid_message({"tool_calls": []})) + + def test_rejects_message_missing_both_content_and_tool_calls(self): + # Negative case explicitly requested by review of #45290 patch. + self.assertFalse(is_valid_message({"role": "assistant"})) + self.assertFalse(is_valid_message({"role": "user"})) + + def test_rejects_non_dict(self): + self.assertFalse(is_valid_message("not a dict")) + self.assertFalse(is_valid_message(["role", "user"])) + self.assertFalse(is_valid_message(None))