Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions nemo_rl/data/llm_message_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,12 @@ def get_formatted_message_log(
new_message["token_ids"] = tokenizer(
message_chunk, return_tensors="pt", add_special_tokens=False
)["input_ids"][0]
if len(new_message["token_ids"]) == 0:
# if there is an empty message, the empty `token_ids` tensor ends up being in fp32,
# which causes `_validate_tensor_consistency` to fail. To fix this, we convert the
# empty tensor to int64.
new_message["token_ids"] = new_message["token_ids"].to(torch.int64)

new_message["content"] = message_chunk
new_message_log.append(new_message)

Expand Down
35 changes: 35 additions & 0 deletions tests/unit/data/test_llm_message_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@
import torch
from transformers import AutoTokenizer

from nemo_rl.data.hf_datasets import COMMON_CHAT_TEMPLATES
from nemo_rl.data.interfaces import LLMMessageLogType, TaskDataSpec
from nemo_rl.data.llm_message_utils import (
_validate_tensor_consistency,
add_loss_mask_to_message_log,
batched_message_log_to_flat_message,
get_first_index_that_differs,
Expand Down Expand Up @@ -406,6 +408,39 @@ def test_get_formatted_message_log_qwen(
assert actual_text == expected_text


def test_formatted_message_log_empty_message():
    """Formatting a log containing an empty message must still yield tensors
    that pass `_validate_tensor_consistency` (empty token tensors could
    otherwise end up with a mismatched dtype)."""
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
    tokenizer.chat_template = COMMON_CHAT_TEMPLATES.passthrough_prompt_response
    task_data_spec = TaskDataSpec(task_name="test")

    conversations = [
        [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": ""},
        ],
        [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello!"},
        ],
    ]

    formatted = []
    for conversation in conversations:
        formatted.append(
            get_formatted_message_log(
                conversation,
                tokenizer,
                task_data_spec,
                add_bos_token=False,
                add_eos_token=False,
            )
        )
    flattened = [message_log_to_flat_messages(log) for log in formatted]

    # make sure validate_tensor_consistency does not raise an error when one of
    # the messages is empty
    for key, value in flattened[0].items():
        if isinstance(value, torch.Tensor):
            _validate_tensor_consistency([flat[key] for flat in flattened])


def test_add_loss_mask_to_chat_message_log(
tokenized_chat_message_log: LLMMessageLogType,
):
Expand Down