From 13cee5ba6dda40ef01218bf3477771ab150d37b7 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Mon, 28 Apr 2025 21:39:46 -0700 Subject: [PATCH 1/6] make dtype of empty messages consistent Signed-off-by: ashors1 --- nemo_rl/data/llm_message_utils.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/nemo_rl/data/llm_message_utils.py b/nemo_rl/data/llm_message_utils.py index ec7ccdbed5..55b72b7163 100644 --- a/nemo_rl/data/llm_message_utils.py +++ b/nemo_rl/data/llm_message_utils.py @@ -420,6 +420,12 @@ def get_formatted_message_log( new_message["token_ids"] = tokenizer( message_chunk, return_tensors="pt", add_special_tokens=False )["input_ids"][0] + if len(new_message["token_ids"]) == 0: + # if there is an empty message, the empty `token_ids` tensor ends up being in fp32, + # which causes `_validate_tensor_consistency` to fail. To fix this, we convert the + # empty tensor to int64. + new_message["token_ids"] = new_message["token_ids"].to(torch.int64) + new_message["content"] = message_chunk new_message_log.append(new_message) From 18ce2fcfaaa8c0f3375e5b0215e74044d756e45f Mon Sep 17 00:00:00 2001 From: ashors1 Date: Tue, 29 Apr 2025 14:14:12 -0700 Subject: [PATCH 2/6] add a unit test Signed-off-by: ashors1 --- tests/unit/data/test_llm_message_utils.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/unit/data/test_llm_message_utils.py b/tests/unit/data/test_llm_message_utils.py index 1a7ddc568c..f182cf4eb9 100644 --- a/tests/unit/data/test_llm_message_utils.py +++ b/tests/unit/data/test_llm_message_utils.py @@ -18,6 +18,7 @@ from transformers import AutoTokenizer from nemo_rl.data.llm_message_utils import ( + _validate_tensor_consistency, message_log_to_flat_messages, get_keys_from_message_log, batched_message_log_to_flat_message, @@ -405,6 +406,21 @@ def test_get_formatted_message_log_qwen( assert actual_text == expected_text +def test_formatted_message_log_empty_message(): + message_log = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": ""}, + ] + tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct") + task_data_spec = TaskDataSpec(task_name="test") + result = get_formatted_message_log(message_log, tokenizer, task_data_spec) + flat_result = message_log_to_flat_messages(result) + for k in flat_result.keys(): + if isinstance(flat_result[k], torch.Tensor): + # make sure validate_tensor_consistency does not raise an error when one of the messages is empty + _validate_tensor_consistency([flat_result[k]]) + + def test_add_loss_mask_to_chat_message_log( tokenized_chat_message_log: LLMMessageLogType, ): From 90f3143ee6c64be48ccea27c523cd0ba89d833f3 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Tue, 29 Apr 2025 14:25:46 -0700 Subject: [PATCH 3/6] improve test to cover actual failing case Signed-off-by: ashors1 --- tests/unit/data/test_llm_message_utils.py | 35 +++++++++++++++++------ 1 file changed, 27 insertions(+), 8 deletions(-) diff --git a/tests/unit/data/test_llm_message_utils.py b/tests/unit/data/test_llm_message_utils.py index f182cf4eb9..522c2c14a0 100644 --- a/tests/unit/data/test_llm_message_utils.py +++ b/tests/unit/data/test_llm_message_utils.py @@ -17,6 +17,7 @@ from typing import Dict, List from transformers import AutoTokenizer +from nemo_rl.data.hf_datasets import COMMON_CHAT_TEMPLATES from nemo_rl.data.llm_message_utils import ( _validate_tensor_consistency, message_log_to_flat_messages, @@ -407,18 +408,36 @@ def test_get_formatted_message_log_qwen( def test_formatted_message_log_empty_message(): - message_log = [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": ""}, + message_logs = [ + [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": ""}, + ], + [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello!"}, + ], ] tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct") + tokenizer.chat_template = COMMON_CHAT_TEMPLATES.passthrough_prompt_response task_data_spec = TaskDataSpec(task_name="test") - result = get_formatted_message_log(message_log, tokenizer, task_data_spec) - flat_result = message_log_to_flat_messages(result) - for k in flat_result.keys(): - if isinstance(flat_result[k], torch.Tensor): + result = [ + get_formatted_message_log( + message_log, + tokenizer, + task_data_spec, + add_bos_token=False, + add_eos_token=False, + ) + for message_log in message_logs + ] + flat_result = [message_log_to_flat_messages(m) for m in result] + for k in flat_result[0].keys(): + if isinstance(flat_result[0][k], torch.Tensor): # make sure validate_tensor_consistency does not raise an error when one of the messages is empty - _validate_tensor_consistency([flat_result[k]]) + _validate_tensor_consistency( + [flat_result[i][k] for i in range(len(flat_result))] + ) def test_add_loss_mask_to_chat_message_log( From 425a113beccf710f8f487fd3be67ec8613c63c6e Mon Sep 17 00:00:00 2001 From: ashors1 Date: Tue, 29 Apr 2025 15:15:30 -0700 Subject: [PATCH 4/6] linting Signed-off-by: ashors1 --- nemo_rl/data/datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo_rl/data/datasets.py b/nemo_rl/data/datasets.py index 70b1dd9786..b52cbc0a84 100644 --- a/nemo_rl/data/datasets.py +++ b/nemo_rl/data/datasets.py @@ -14,8 +14,8 @@ from typing import Any, Dict, List, Tuple, Union import torch -from datasets import Dataset +from datasets import Dataset from nemo_rl.data.interfaces import ( DatumSpec, TaskDataProcessFnCallable, From 55e7f2b732a40b88d8aa7f9f0ddb8f3a962f4fdf Mon Sep 17 00:00:00 2001 From: ashors1 Date: Tue, 29 Apr 2025 15:33:01 -0700 Subject: [PATCH 5/6] linting Signed-off-by: ashors1 --- nemo_rl/data/llm_message_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo_rl/data/llm_message_utils.py b/nemo_rl/data/llm_message_utils.py index 539b4e3dd8..51cd5a279d 100644 --- a/nemo_rl/data/llm_message_utils.py +++ b/nemo_rl/data/llm_message_utils.py @@ -15,8 +15,8 @@ from typing import Dict, List import torch - from datasets import Dataset + from nemo_rl.data.interfaces import ( FlatMessagesType, LLMMessageLogType, From 974a1bcd3e59727040edf790858bbe0d8db13cf1 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Tue, 29 Apr 2025 15:34:53 -0700 Subject: [PATCH 6/6] linting Signed-off-by: ashors1 --- nemo_rl/data/datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo_rl/data/datasets.py b/nemo_rl/data/datasets.py index b52cbc0a84..70b1dd9786 100644 --- a/nemo_rl/data/datasets.py +++ b/nemo_rl/data/datasets.py @@ -14,8 +14,8 @@ from typing import Any, Dict, List, Tuple, Union import torch - from datasets import Dataset + from nemo_rl.data.interfaces import ( DatumSpec, TaskDataProcessFnCallable,