diff --git a/nemo_rl/data/llm_message_utils.py b/nemo_rl/data/llm_message_utils.py
index f2d24fc421..51cd5a279d 100644
--- a/nemo_rl/data/llm_message_utils.py
+++ b/nemo_rl/data/llm_message_utils.py
@@ -421,6 +421,12 @@ def get_formatted_message_log(
         new_message["token_ids"] = tokenizer(
             message_chunk, return_tensors="pt", add_special_tokens=False
         )["input_ids"][0]
+        if len(new_message["token_ids"]) == 0:
+            # if there is an empty message, the empty `token_ids` tensor ends up being in fp32,
+            # which causes `_validate_tensor_consistency` to fail. To fix this, we convert the
+            # empty tensor to int64.
+            new_message["token_ids"] = new_message["token_ids"].to(torch.int64)
+
         new_message["content"] = message_chunk
         new_message_log.append(new_message)

diff --git a/tests/unit/data/test_llm_message_utils.py b/tests/unit/data/test_llm_message_utils.py
index 0a5cb3ef4b..fc4c6c6b8d 100644
--- a/tests/unit/data/test_llm_message_utils.py
+++ b/tests/unit/data/test_llm_message_utils.py
@@ -18,8 +18,10 @@
 import torch
 from transformers import AutoTokenizer

+from nemo_rl.data.hf_datasets import COMMON_CHAT_TEMPLATES
 from nemo_rl.data.interfaces import LLMMessageLogType, TaskDataSpec
 from nemo_rl.data.llm_message_utils import (
+    _validate_tensor_consistency,
     add_loss_mask_to_message_log,
     batched_message_log_to_flat_message,
     get_first_index_that_differs,
@@ -406,6 +408,39 @@ def test_get_formatted_message_log_qwen(
     assert actual_text == expected_text


+def test_formatted_message_log_empty_message():
+    message_logs = [
+        [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": ""},
+        ],
+        [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": "Hello!"},
+        ],
+    ]
+    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
+    tokenizer.chat_template = COMMON_CHAT_TEMPLATES.passthrough_prompt_response
+    task_data_spec = TaskDataSpec(task_name="test")
+    result = [
+        get_formatted_message_log(
+            message_log,
+            tokenizer,
+            task_data_spec,
+            add_bos_token=False,
+            add_eos_token=False,
+        )
+        for message_log in message_logs
+    ]
+    flat_result = [message_log_to_flat_messages(m) for m in result]
+    for k in flat_result[0].keys():
+        if isinstance(flat_result[0][k], torch.Tensor):
+            # make sure validate_tensor_consistency does not raise an error when one of the messages is empty
+            _validate_tensor_consistency(
+                [flat_result[i][k] for i in range(len(flat_result))]
+            )
+
+
 def test_add_loss_mask_to_chat_message_log(
     tokenized_chat_message_log: LLMMessageLogType,
 ):
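
For context on the fix: when a tokenizer receives an empty string, the resulting
empty `input_ids` tensor defaults to float32 (the same way `torch.tensor([])` does),
while non-empty messages produce int64 token ids, so the per-message tensors end up
with inconsistent dtypes. The sketch below reproduces the quirk; it is illustrative
only and uses the ungated `gpt2` tokenizer as an assumption (the test above uses
Llama 3, and any Hugging Face tokenizer should behave the same way).

    import torch
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative choice; any tokenizer works

    # A non-empty chunk tokenizes to int64 token ids, as expected.
    nonempty = tokenizer("Hello!", return_tensors="pt", add_special_tokens=False)["input_ids"][0]
    print(nonempty.dtype)  # torch.int64

    # An empty chunk yields a zero-length tensor, which defaults to float32,
    # so the dtypes across messages no longer agree.
    empty = tokenizer("", return_tensors="pt", add_special_tokens=False)["input_ids"][0]
    print(empty.dtype)  # torch.float32

    # The cast added in the diff restores a single dtype across messages,
    # which is what the downstream consistency check expects.
    print(empty.to(torch.int64).dtype)  # torch.int64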