Add "x-regex-substitutions" support to recursive_parse.

A schema node may carry "x-regex-substitutions": a list of [pattern, replacement]
pairs. Each pair is applied in order to the node's string content with re.sub
(flags=re.DOTALL), before the "x-regex", "x-regex-iterator" and
"x-regex-key-value" keys are read. A node that has substitutions but receives
non-string content raises TypeError, like the other regex keys.

diff --git a/src/transformers/utils/chat_parsing_utils.py b/src/transformers/utils/chat_parsing_utils.py
index d1ab2791c3f7..1035d65b094a 100644
--- a/src/transformers/utils/chat_parsing_utils.py
+++ b/src/transformers/utils/chat_parsing_utils.py
@@ -70,14 +70,22 @@ def recursive_parse(
 
     # If not, we have to do a little parsing. First, set some vars and do basic validation
     node_type = node_schema.get("type")
-    has_regex = "x-regex" in node_schema or "x-regex-iterator" in node_schema or "x-regex-key-value" in node_schema
+    has_regex = (
+        "x-regex" in node_schema
+        or "x-regex-iterator" in node_schema
+        or "x-regex-key-value" in node_schema
+        or "x-regex-substitutions" in node_schema
+    )
     if has_regex and not isinstance(node_content, str):
         raise TypeError(
-            "Schema node got a non-string input, but has a regex for parsing.\n"
+            "Schema node got a non-string input, but has a regex for parsing or substitution.\n"
             f"Input: {node_content}\n"
             f"Schema: {node_schema}"
         )
 
+    node_subs = node_schema.get("x-regex-substitutions", [])
+    for node_sub in node_subs:
+        node_content = re.sub(node_sub[0], node_sub[1], node_content, flags=re.DOTALL)
     node_regex = node_schema.get("x-regex")
     node_regex_iterator = node_schema.get("x-regex-iterator")
     node_regex_to_dict = node_schema.get("x-regex-key-value")
diff --git a/tests/utils/test_chat_parsing_utils.py b/tests/utils/test_chat_parsing_utils.py
index 6f91f9a2d97b..60bc181d1c88 100644
--- a/tests/utils/test_chat_parsing_utils.py
+++ b/tests/utils/test_chat_parsing_utils.py
@@ -180,6 +180,42 @@
     },
 }
 
+re_sub_schema = {
+    "type": "object",
+    "properties": {
+        "role": {"const": "assistant"},
+        "thinking": {"type": "string"},
+        "content": {"type": "string"},
+        "tool_calls": {
+            "x-regex-iterator": r"<\|tool_call>(.*?)<tool_call\|>",
+            "type": "array",
+            "items": {
+                "type": "object",
+                "properties": {
+                    "type": {"const": "function"},
+                    "function": {
+                        "type": "object",
+                        "x-regex": r"call\:(?P<name>\w+)(?P<arguments>\{.*\})",
+                        "properties": {
+                            "name": {
+                                "type": "string",
+                            },
+                            "arguments": {
+                                "type": "object",
+                                "x-regex-key-value": r'(?P<key>\w+):(?P<value><\|"\|>.*?<\|"\|>|[^,}]+)',
+                                "additionalProperties": {
+                                    "x-regex-substitutions": [[r'^<\|"\|>|<\|"\|>$', ""]],
+                                },
+                            },
+                        },
+                    },
+                },
+            },
+        },
+    },
+    "x-regex": r"(\<\|channel\>thought\n(?P<thinking>.*?)\<channel\|\>)?(?P<content>(?:(?!\<\|tool_call\>).)+)?(?P<tool_calls>\<\|tool_call\>.*\<tool_call\|\>)?",
+}
+
 prefix_items_schema = {
     # Not intended to be "realistic", just checks that prefixItems can handle a heterogeneous array
     "x-regex-iterator": r"<block>(.*?)<\/block>",
@@ -389,6 +425,27 @@ def test_qwen3_tool_calls(self):
             },
         )
 
+    def test_re_sub_schema(self):
+        """Test that a schema doing re substitutions to enable JSON parsing works."""
+        model_out = '<|channel>thought\nThe user is asking for the current temperature in Paris. I should check the available tools to see if there\'s a function that can provide this information.<channel|><|tool_call>call:get_current_temperature{detail_level:0,location:<|"|>Paris, France<|"|>,unit:<|"|>celsius<|"|>}<tool_call|><|tool_response>'
+        parsed = recursive_parse(model_out, re_sub_schema)
+        self.assertEqual(
+            parsed,
+            {
+                "role": "assistant",
+                "thinking": "The user is asking for the current temperature in Paris. I should check the available tools to see if there's a function that can provide this information.",
+                "tool_calls": [
+                    {
+                        "type": "function",
+                        "function": {
+                            "name": "get_current_temperature",
+                            "arguments": {"detail_level": "0", "location": "Paris, France", "unit": "celsius"},
+                        },
+                    }
+                ],
+            },
+        )
+
     def test_required_fields_present(self):
         """Test that required fields pass validation when present in the output."""
         schema = {