diff --git a/nemo_automodel/components/datasets/llm/chat_dataset.py b/nemo_automodel/components/datasets/llm/chat_dataset.py index 0030923019..96c3276176 100644 --- a/nemo_automodel/components/datasets/llm/chat_dataset.py +++ b/nemo_automodel/components/datasets/llm/chat_dataset.py @@ -208,12 +208,16 @@ def _normalize_tool_calls(tool_calls: Any) -> List[Dict[str, Any]]: raise ValueError(f"assistant message `tool_calls[{idx}]` must be a dict") tool_call_id = tool_call.get("id") - if not isinstance(tool_call_id, str) or not tool_call_id: - raise ValueError(f"assistant message `tool_calls[{idx}].id` must be a non-empty string") + if tool_call_id is None or tool_call_id == "": + tool_call_id = f"call_{idx}" + elif not isinstance(tool_call_id, str): + raise ValueError(f"assistant message `tool_calls[{idx}].id` must be a string when provided") tool_call_type = tool_call.get("type") - if not isinstance(tool_call_type, str) or not tool_call_type: - raise ValueError(f"assistant message `tool_calls[{idx}].type` must be a non-empty string") + if tool_call_type is None or tool_call_type == "": + tool_call_type = "function" + elif not isinstance(tool_call_type, str): + raise ValueError(f"assistant message `tool_calls[{idx}].type` must be a string when provided") function = tool_call.get("function") if not isinstance(function, dict): @@ -232,6 +236,8 @@ def _normalize_tool_calls(tool_calls: Any) -> List[Dict[str, Any]]: normalized_function["arguments"] = json.dumps(function_arguments) normalized_tool_call = dict(tool_call) + normalized_tool_call["id"] = tool_call_id + normalized_tool_call["type"] = tool_call_type normalized_tool_call["function"] = normalized_function normalized_tool_calls.append(normalized_tool_call) diff --git a/tests/unit_tests/datasets/llm/test_chat_dataset.py b/tests/unit_tests/datasets/llm/test_chat_dataset.py index 6335eb9323..d2ca81ca09 100644 --- a/tests/unit_tests/datasets/llm/test_chat_dataset.py +++ b/tests/unit_tests/datasets/llm/test_chat_dataset.py @@ -74,6 +74,26 @@ def test_normalize_messages_supports_reasoning_and_tool_call_fields(): assert none_reasoning[0]["reasoning_content"] == "" +def test_normalize_tool_calls_autofills_missing_id_and_type(): + msgs = [ + { + "role": "assistant", + "content": "", + "tool_calls": [ + {"function": {"name": "fn_a", "arguments": {"x": 1}}}, + {"id": "", "type": "", "function": {"name": "fn_b", "arguments": "{}"}}, + ], + } + ] + norm = tcd._normalize_messages(msgs) + calls = norm[0]["tool_calls"] + assert calls[0]["id"] == "call_0" + assert calls[0]["type"] == "function" + assert calls[0]["function"]["arguments"] == '{"x": 1}' + assert calls[1]["id"] == "call_1" + assert calls[1]["type"] == "function" + + @pytest.mark.parametrize( ("message", "error_pattern"), [ @@ -90,7 +110,7 @@ def test_normalize_messages_supports_reasoning_and_tool_call_fields(): { "role": "assistant", "content": "", - "tool_calls": [{"type": "function", "function": {"name": "fn", "arguments": "{}"}}], + "tool_calls": [{"id": 123, "type": "function", "function": {"name": "fn", "arguments": "{}"}}], } ], "tool_calls\\[0\\]\\.id", @@ -100,7 +120,7 @@ def test_normalize_messages_supports_reasoning_and_tool_call_fields(): { "role": "assistant", "content": "", - "tool_calls": [{"id": "call_1", "function": {"name": "fn", "arguments": "{}"}}], + "tool_calls": [{"id": "call_1", "type": 1, "function": {"name": "fn", "arguments": "{}"}}], } ], "tool_calls\\[0\\]\\.type", @@ -140,19 +160,23 @@ def test_load_openai_messages_local_and_errors(tmp_path, monkeypatch): # Create local files: JSONL and JSON jsonl = tmp_path / "data.jsonl" jsonl.write_text( - "\n".join([ - json.dumps({"messages": [{"role": "user", "content": "u1"}]}), - json.dumps({"messages": [{"role": "assistant", "content": "a1"}]}), - ]), + "\n".join( + [ + json.dumps({"messages": [{"role": "user", "content": "u1"}]}), + json.dumps({"messages": [{"role": "assistant", "content": "a1"}]}), + ] + ), encoding="utf-8", ) json_file = tmp_path / "data.json" json_file.write_text( - json.dumps([ - {"messages": [{"role": "user", "content": "u2"}]}, - {"messages": [{"role": "assistant", "content": "a2"}]}, - ]), + json.dumps( + [ + {"messages": [{"role": "user", "content": "u2"}]}, + {"messages": [{"role": "assistant", "content": "a2"}]}, + ] + ), encoding="utf-8", ) @@ -260,12 +284,14 @@ class Tok: calls = [] def fake_format(tokenizer, normalized, eos_id, pad_id, **kwargs): - calls.append({ - "normalized": normalized, - "eos": eos_id, - "pad": pad_id, - "kwargs": kwargs, - }) + calls.append( + { + "normalized": normalized, + "eos": eos_id, + "pad": pad_id, + "kwargs": kwargs, + } + ) return {"input_ids": [1, 2], "labels": [0, 1], "attention_mask": [1, 1]} monkeypatch.setattr(tcd, "format_chat_template", fake_format) @@ -323,6 +349,7 @@ def fake_format(tokenizer, normalized, eos_id, pad_id, **kwargs): def test_chat_dataset_skip_invalid_samples_does_not_filter_structured_bad_rows(monkeypatch): """skip_invalid_samples only affects JSONL parse errors, not invalid message rows after load.""" + class Tok: eos_token_id = 1 chat_template = "{{ default }}" @@ -330,7 +357,9 @@ class Tok: tok = Tok() monkeypatch.setattr(tcd, "_has_chat_template", lambda _tok: True) monkeypatch.setattr(tcd, "_add_pad_token", lambda _tok: 3) - monkeypatch.setattr(tcd, "format_chat_template", lambda *a, **k: {"input_ids": [1], "labels": [1], "attention_mask": [1]}) + monkeypatch.setattr( + tcd, "format_chat_template", lambda *a, **k: {"input_ids": [1], "labels": [1], "attention_mask": [1]} + ) dataset_rows = [ {"messages": [{"role": "user", "content": "ok"}]},