Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions nemo_automodel/components/datasets/llm/chat_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,12 +208,16 @@ def _normalize_tool_calls(tool_calls: Any) -> List[Dict[str, Any]]:
raise ValueError(f"assistant message `tool_calls[{idx}]` must be a dict")

tool_call_id = tool_call.get("id")
if not isinstance(tool_call_id, str) or not tool_call_id:
raise ValueError(f"assistant message `tool_calls[{idx}].id` must be a non-empty string")
if tool_call_id is None or tool_call_id == "":
tool_call_id = f"call_{idx}"
elif not isinstance(tool_call_id, str):
raise ValueError(f"assistant message `tool_calls[{idx}].id` must be a string when provided")

tool_call_type = tool_call.get("type")
if not isinstance(tool_call_type, str) or not tool_call_type:
raise ValueError(f"assistant message `tool_calls[{idx}].type` must be a non-empty string")
if tool_call_type is None or tool_call_type == "":
tool_call_type = "function"
elif not isinstance(tool_call_type, str):
raise ValueError(f"assistant message `tool_calls[{idx}].type` must be a string when provided")

function = tool_call.get("function")
if not isinstance(function, dict):
Expand All @@ -232,6 +236,8 @@ def _normalize_tool_calls(tool_calls: Any) -> List[Dict[str, Any]]:
normalized_function["arguments"] = json.dumps(function_arguments)

normalized_tool_call = dict(tool_call)
normalized_tool_call["id"] = tool_call_id
normalized_tool_call["type"] = tool_call_type
normalized_tool_call["function"] = normalized_function
normalized_tool_calls.append(normalized_tool_call)

Expand Down
63 changes: 46 additions & 17 deletions tests/unit_tests/datasets/llm/test_chat_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,26 @@ def test_normalize_messages_supports_reasoning_and_tool_call_fields():
assert none_reasoning[0]["reasoning_content"] == ""


def test_normalize_tool_calls_autofills_missing_id_and_type():
    """Missing or empty tool-call ``id``/``type`` fields are auto-filled by normalization."""
    messages = [
        {
            "role": "assistant",
            "content": "",
            "tool_calls": [
                {"function": {"name": "fn_a", "arguments": {"x": 1}}},
                {"id": "", "type": "", "function": {"name": "fn_b", "arguments": "{}"}},
            ],
        }
    ]

    tool_calls = tcd._normalize_messages(messages)[0]["tool_calls"]
    first, second = tool_calls

    # The first call had no id/type at all; both get synthesized defaults,
    # and dict-valued arguments are serialized to a JSON string.
    assert first["id"] == "call_0"
    assert first["type"] == "function"
    assert first["function"]["arguments"] == '{"x": 1}'
    # The second call had empty-string id/type; those are treated as missing too.
    assert second["id"] == "call_1"
    assert second["type"] == "function"


@pytest.mark.parametrize(
("message", "error_pattern"),
[
Expand All @@ -90,7 +110,7 @@ def test_normalize_messages_supports_reasoning_and_tool_call_fields():
{
"role": "assistant",
"content": "",
"tool_calls": [{"type": "function", "function": {"name": "fn", "arguments": "{}"}}],
"tool_calls": [{"id": 123, "type": "function", "function": {"name": "fn", "arguments": "{}"}}],
}
],
"tool_calls\\[0\\]\\.id",
Expand All @@ -100,7 +120,7 @@ def test_normalize_messages_supports_reasoning_and_tool_call_fields():
{
"role": "assistant",
"content": "",
"tool_calls": [{"id": "call_1", "function": {"name": "fn", "arguments": "{}"}}],
"tool_calls": [{"id": "call_1", "type": 1, "function": {"name": "fn", "arguments": "{}"}}],
}
],
"tool_calls\\[0\\]\\.type",
Expand Down Expand Up @@ -140,19 +160,23 @@ def test_load_openai_messages_local_and_errors(tmp_path, monkeypatch):
# Create local files: JSONL and JSON
jsonl = tmp_path / "data.jsonl"
jsonl.write_text(
"\n".join([
json.dumps({"messages": [{"role": "user", "content": "u1"}]}),
json.dumps({"messages": [{"role": "assistant", "content": "a1"}]}),
]),
"\n".join(
[
json.dumps({"messages": [{"role": "user", "content": "u1"}]}),
json.dumps({"messages": [{"role": "assistant", "content": "a1"}]}),
]
),
encoding="utf-8",
)

json_file = tmp_path / "data.json"
json_file.write_text(
json.dumps([
{"messages": [{"role": "user", "content": "u2"}]},
{"messages": [{"role": "assistant", "content": "a2"}]},
]),
json.dumps(
[
{"messages": [{"role": "user", "content": "u2"}]},
{"messages": [{"role": "assistant", "content": "a2"}]},
]
),
encoding="utf-8",
)

Expand Down Expand Up @@ -260,12 +284,14 @@ class Tok:
calls = []

def fake_format(tokenizer, normalized, eos_id, pad_id, **kwargs):
    """Stand-in for ``format_chat_template`` that records each invocation's arguments."""
    recorded = {
        "normalized": normalized,
        "eos": eos_id,
        "pad": pad_id,
        "kwargs": kwargs,
    }
    calls.append(recorded)
    # Fixed, minimal tokenization result — enough for the caller's shape checks.
    return {"input_ids": [1, 2], "labels": [0, 1], "attention_mask": [1, 1]}

monkeypatch.setattr(tcd, "format_chat_template", fake_format)
Expand Down Expand Up @@ -323,14 +349,17 @@ def fake_format(tokenizer, normalized, eos_id, pad_id, **kwargs):

def test_chat_dataset_skip_invalid_samples_does_not_filter_structured_bad_rows(monkeypatch):
"""skip_invalid_samples only affects JSONL parse errors, not invalid message rows after load."""

class Tok:
    """Minimal tokenizer stub exposing only the attributes the dataset code reads."""

    # presumably mirrors a HF tokenizer's eos id — TODO confirm against chat_dataset usage
    eos_token_id = 1
    # non-empty template so the chat-template code path is exercised
    chat_template = "{{ default }}"

tok = Tok()
monkeypatch.setattr(tcd, "_has_chat_template", lambda _tok: True)
monkeypatch.setattr(tcd, "_add_pad_token", lambda _tok: 3)
monkeypatch.setattr(tcd, "format_chat_template", lambda *a, **k: {"input_ids": [1], "labels": [1], "attention_mask": [1]})
monkeypatch.setattr(
tcd, "format_chat_template", lambda *a, **k: {"input_ids": [1], "labels": [1], "attention_mask": [1]}
)

dataset_rows = [
{"messages": [{"role": "user", "content": "ok"}]},
Expand Down
Loading