From 236bbc76a158eb8e561476f8cbe6d22d2f020e62 Mon Sep 17 00:00:00 2001 From: Atsunori Fujita Date: Thu, 5 Jun 2025 07:02:38 +0000 Subject: [PATCH 1/3] Add OpenAI format dataset for SFT Signed-off-by: Atsunori Fujita --- examples/run_sft.py | 20 ++++--- nemo_rl/data/hf_datasets/__init__.py | 2 + .../data/hf_datasets/oai_format_dataset.py | 58 +++++++++++++++++++ 3 files changed, 71 insertions(+), 9 deletions(-) create mode 100644 nemo_rl/data/hf_datasets/oai_format_dataset.py diff --git a/examples/run_sft.py b/examples/run_sft.py index 8eb93b5adc..65ab32fe7f 100644 --- a/examples/run_sft.py +++ b/examples/run_sft.py @@ -35,9 +35,7 @@ def parse_args(): """Parse command line arguments.""" parser = argparse.ArgumentParser(description="Run SFT training with configuration") - parser.add_argument( - "--config", type=str, default=None, help="Path to YAML config file" - ) + parser.add_argument("--config", type=str, default=None, help="Path to YAML config file") # Parse known args for the script args, overrides = parser.parse_known_args() @@ -74,9 +72,7 @@ def sft_preprocessor( if length > max_seq_length: # make smaller and mask out for message in message_log: - message["token_ids"] = message["token_ids"][ - : min(4, max_seq_length // len(message_log)) - ] + message["token_ids"] = message["token_ids"][: min(4, max_seq_length // len(message_log))] loss_multiplier = 0.0 output = { @@ -109,6 +105,14 @@ def setup_data(tokenizer: AutoTokenizer, data_config: DataConfig): output_key=data_config["output_key"], prompt_file=data_config["prompt_file"], ) + elif data_cls == "openai_format": + data = hf_datasets.OpenAIFormatDataset( + data_config["train_data_path"], + data_config["val_data_path"], + data_config["chat_key"], + data_config["system_key"], + data_config["system_prompt"], + ) else: raise ValueError(f"Unknown dataset class: {data_cls}") print( @@ -173,9 +177,7 @@ def main(): config["logger"]["log_dir"] = get_next_experiment_dir(config["logger"]["log_dir"]) print(f"📊 Using log directory: {config['logger']['log_dir']}") if config["checkpointing"]["enabled"]: - print( - f"📊 Using checkpoint directory: {config['checkpointing']['checkpoint_dir']}" - ) + print(f"📊 Using checkpoint directory: {config['checkpointing']['checkpoint_dir']}") init_ray() diff --git a/nemo_rl/data/hf_datasets/__init__.py b/nemo_rl/data/hf_datasets/__init__.py index 54d4fd9c34..72bd1b6591 100644 --- a/nemo_rl/data/hf_datasets/__init__.py +++ b/nemo_rl/data/hf_datasets/__init__.py @@ -16,6 +16,7 @@ from nemo_rl.data.hf_datasets.dpo import DPODataset from nemo_rl.data.hf_datasets.helpsteer3 import HelpSteer3Dataset from nemo_rl.data.hf_datasets.oasst import OasstDataset +from nemo_rl.data.hf_datasets.oai_format_dataset import OpenAIFormatDataset from nemo_rl.data.hf_datasets.openmathinstruct2 import OpenMathInstruct2Dataset from nemo_rl.data.hf_datasets.prompt_response_dataset import ( PromptResponseDataset, @@ -26,6 +27,7 @@ "DPODataset", "HelpSteer3Dataset", "OasstDataset", + "OpenAIFormatDataset", "OpenMathInstruct2Dataset", "PromptResponseDataset", "SquadDataset", diff --git a/nemo_rl/data/hf_datasets/oai_format_dataset.py b/nemo_rl/data/hf_datasets/oai_format_dataset.py new file mode 100644 index 0000000000..5aea8ef584 --- /dev/null +++ b/nemo_rl/data/hf_datasets/oai_format_dataset.py @@ -0,0 +1,58 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +from datasets import load_dataset + +from nemo_rl.data.interfaces import TaskDataSpec + + +class OpenAIFormatDataset: + def __init__( + self, + train_ds_path: str, + val_ds_path: str, + chat_key: str = "messages", + system_key: str = None, + system_prompt: str = None, + ): + self.chat_key = chat_key + self.system_key = system_key + self.system_prompt = system_prompt + train_original_dataset = load_dataset("json", data_files=train_ds_path)["train"] + val_original_dataset = load_dataset("json", data_files=val_ds_path)["train"] + + formatted_train_dataset = train_original_dataset.map(self.add_messages_key) + formatted_val_dataset = val_original_dataset.map(self.add_messages_key) + + self.formatted_ds = { + "train": formatted_train_dataset, + "validation": formatted_val_dataset, + } + + self.task_spec = TaskDataSpec( + "json_dataset", + ) + + def add_messages_key( + self, example: dict[str, Any], + ) -> dict[str, list[dict[str, Any]]]: + messages = [message for message in example[self.chat_key]] + if self.system_key in example: + messages = [{"role": "system", "content": example[self.system_key]}] + messages + elif self.system_prompt: + messages = [{"role": "system", "content": self.system_prompt}] + messages + assert messages[-1]["role"] == "assistant" + return {"messages": messages} From b5343be3d529e5928d4a79b54d5017cf8ad7b0f1 Mon Sep 17 00:00:00 2001 From: Atsunori Fujita Date: Wed, 18 Jun 2025 11:01:49 +0000 Subject: [PATCH 2/3] Add class docstrings and unit test Signed-off-by: Atsunori Fujita --- .../data/hf_datasets/oai_format_dataset.py | 18 ++- .../hf_datasets/test_oai_format_dataset.py | 108 ++++++++++++++++++ 2 files changed, 125 insertions(+), 1 deletion(-) create mode 100644 tests/unit/data/hf_datasets/test_oai_format_dataset.py diff --git a/nemo_rl/data/hf_datasets/oai_format_dataset.py b/nemo_rl/data/hf_datasets/oai_format_dataset.py index 5aea8ef584..d466bde638 100644 --- a/nemo_rl/data/hf_datasets/oai_format_dataset.py +++ b/nemo_rl/data/hf_datasets/oai_format_dataset.py @@ -20,6 +20,21 @@ class OpenAIFormatDataset: + """ + This class is used to load an SFT dataset in the OpenAI format. + The dataset should be in the following format: + { + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What is the capital of France?"}, + {"role": "assistant", "content": "The capital of France is Paris."} + ] + } + system_key or system_prompt is optional. If provided, it will be added to the beginning of the dataset. + chat_key should be the key of the messages list. Multi-turn conversations are supported. + However, the last message should be the assistant's response. + """ + def __init__( self, train_ds_path: str, @@ -47,7 +62,8 @@ def __init__( ) def add_messages_key( - self, example: dict[str, Any], + self, + example: dict[str, Any], ) -> dict[str, list[dict[str, Any]]]: messages = [message for message in example[self.chat_key]] if self.system_key in example: diff --git a/tests/unit/data/hf_datasets/test_oai_format_dataset.py b/tests/unit/data/hf_datasets/test_oai_format_dataset.py new file mode 100644 index 0000000000..26a540e782 --- /dev/null +++ b/tests/unit/data/hf_datasets/test_oai_format_dataset.py @@ -0,0 +1,108 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import tempfile + +import pytest +from transformers import AutoTokenizer + +from nemo_rl.data.hf_datasets.chat_templates import COMMON_CHAT_TEMPLATES +from nemo_rl.data.hf_datasets.oai_format_dataset import ( + OpenAIFormatDataset, +) + + +@pytest.fixture +def sample_data(request): + chat_key = request.param[0] + system_key = request.param[1] + + train_data = { + chat_key: [ + {"role": "user", "content": "What is the capital of France?"}, + {"role": "assistant", "content": "The capital of France is Paris."}, + ], + } + val_data = { + chat_key: [ + {"role": "user", "content": "What is the capital of Germany?"}, + {"role": "assistant", "content": "The capital of Germany is Berlin."}, + ], + } + + if system_key is not None: + train_data[system_key] = "You are a helpful assistant." + if system_key is not None: + val_data[system_key] = "You are a helpful assistant." + + # Create temporary files for train and validation data + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as train_file: + json.dump(train_data, train_file) + train_path = train_file.name + + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as val_file: + json.dump(val_data, val_file) + val_path = val_file.name + + return train_path, val_path + + +@pytest.mark.parametrize("sample_data", [("messages", None)], indirect=True) +def test_dataset_initialization(sample_data): + train_path, val_path = sample_data + dataset = OpenAIFormatDataset(train_path, val_path) + + assert dataset.chat_key == "messages" + assert "train" in dataset.formatted_ds + assert "validation" in dataset.formatted_ds + + +@pytest.mark.parametrize("sample_data", [("conversations", None)], indirect=True) +def test_custom_keys(sample_data): + train_path, val_path = sample_data + dataset = OpenAIFormatDataset( + train_path, val_path, chat_key="conversations", system_prompt="You are a helpful assistant." + ) + + assert dataset.chat_key == "conversations" + assert dataset.system_prompt == "You are a helpful assistant." + + +@pytest.mark.parametrize("sample_data", [("messages", "system_key")], indirect=True) +def test_message_formatting(sample_data): + train_path, val_path = sample_data + dataset = OpenAIFormatDataset(train_path, val_path, chat_key="messages", system_key="system_key") + + first_example = dataset.formatted_ds["train"][0] + + assert first_example["messages"][0]["role"] == "system" + assert first_example["messages"][0]["content"] == "You are a helpful assistant." + assert first_example["messages"][1]["role"] == "user" + assert first_example["messages"][1]["content"] == "What is the capital of France?" + assert first_example["messages"][2]["role"] == "assistant" + assert first_example["messages"][2]["content"] == "The capital of France is Paris." + + chat_template = COMMON_CHAT_TEMPLATES.passthrough_prompt_response + tokenizer = AutoTokenizer.from_pretrained("Meta-Llama/Meta-Llama-3-8B-Instruct") + + combined_message = tokenizer.apply_chat_template( + first_example["messages"], + chat_template=chat_template, + tokenize=False, + add_generation_prompt=False, + add_special_tokens=False, + ) + + assert combined_message == "".join(message["content"] for message in first_example["messages"]) From 8a966bc1d9015571be7590412c7131083db9fe05 Mon Sep 17 00:00:00 2001 From: Atsunori Fujita Date: Wed, 18 Jun 2025 23:51:46 +0000 Subject: [PATCH 3/3] Apply pre-commit Signed-off-by: Atsunori Fujita --- examples/run_sft.py | 12 ++++++++--- nemo_rl/data/hf_datasets/__init__.py | 2 +- .../data/hf_datasets/oai_format_dataset.py | 16 ++++++++------ .../hf_datasets/test_oai_format_dataset.py | 21 ++++++++++++++----- 4 files changed, 36 insertions(+), 15 deletions(-) diff --git a/examples/run_sft.py b/examples/run_sft.py index 65ab32fe7f..ce5b258b0c 100644 --- a/examples/run_sft.py +++ b/examples/run_sft.py @@ -35,7 +35,9 @@ def parse_args(): """Parse command line arguments.""" parser = argparse.ArgumentParser(description="Run SFT training with configuration") - parser.add_argument("--config", type=str, default=None, help="Path to YAML config file") + parser.add_argument( + "--config", type=str, default=None, help="Path to YAML config file" + ) # Parse known args for the script args, overrides = parser.parse_known_args() @@ -72,7 +74,9 @@ def sft_preprocessor( if length > max_seq_length: # make smaller and mask out for message in message_log: - message["token_ids"] = message["token_ids"][: min(4, max_seq_length // len(message_log))] + message["token_ids"] = message["token_ids"][ + : min(4, max_seq_length // len(message_log)) + ] loss_multiplier = 0.0 output = { @@ -177,7 +181,9 @@ def main(): config["logger"]["log_dir"] = get_next_experiment_dir(config["logger"]["log_dir"]) print(f"📊 Using log directory: {config['logger']['log_dir']}") if config["checkpointing"]["enabled"]: - print(f"📊 Using checkpoint directory: {config['checkpointing']['checkpoint_dir']}") + print( + f"📊 Using checkpoint directory: {config['checkpointing']['checkpoint_dir']}" + ) init_ray() diff --git a/nemo_rl/data/hf_datasets/__init__.py b/nemo_rl/data/hf_datasets/__init__.py index 72bd1b6591..aa5596397c 100644 --- a/nemo_rl/data/hf_datasets/__init__.py +++ b/nemo_rl/data/hf_datasets/__init__.py @@ -15,8 +15,8 @@ from nemo_rl.data.hf_datasets.chat_templates import COMMON_CHAT_TEMPLATES from nemo_rl.data.hf_datasets.dpo import DPODataset from nemo_rl.data.hf_datasets.helpsteer3 import HelpSteer3Dataset -from nemo_rl.data.hf_datasets.oasst import OasstDataset from nemo_rl.data.hf_datasets.oai_format_dataset import OpenAIFormatDataset +from nemo_rl.data.hf_datasets.oasst import OasstDataset from nemo_rl.data.hf_datasets.openmathinstruct2 import OpenMathInstruct2Dataset from nemo_rl.data.hf_datasets.prompt_response_dataset import ( PromptResponseDataset, diff --git a/nemo_rl/data/hf_datasets/oai_format_dataset.py b/nemo_rl/data/hf_datasets/oai_format_dataset.py index d466bde638..22d01346bc 100644 --- a/nemo_rl/data/hf_datasets/oai_format_dataset.py +++ b/nemo_rl/data/hf_datasets/oai_format_dataset.py @@ -20,8 +20,8 @@ class OpenAIFormatDataset: - """ - This class is used to load an SFT dataset in the OpenAI format. + """This class is used to load an SFT dataset in the OpenAI format. + The dataset should be in the following format: { "messages": [ @@ -30,9 +30,11 @@ class OpenAIFormatDataset: {"role": "assistant", "content": "The capital of France is Paris."} ] } - system_key or system_prompt is optional. If provided, it will be added to the beginning of the dataset. - chat_key should be the key of the messages list. Multi-turn conversations are supported. - However, the last message should be the assistant's response. + system_key and system_prompt are optional. If provided, it will be added to the + beginning of the dataset. + chat_key should be the key of the messages list. Multi-turn conversations are + supported. + The last message in the conversation must be from the assistant. """ def __init__( @@ -67,7 +69,9 @@ def add_messages_key( ) -> dict[str, list[dict[str, Any]]]: messages = [message for message in example[self.chat_key]] if self.system_key in example: - messages = [{"role": "system", "content": example[self.system_key]}] + messages + messages = [ + {"role": "system", "content": example[self.system_key]} + ] + messages elif self.system_prompt: messages = [{"role": "system", "content": self.system_prompt}] + messages assert messages[-1]["role"] == "assistant" diff --git a/tests/unit/data/hf_datasets/test_oai_format_dataset.py b/tests/unit/data/hf_datasets/test_oai_format_dataset.py index 26a540e782..4ba75a6a1d 100644 --- a/tests/unit/data/hf_datasets/test_oai_format_dataset.py +++ b/tests/unit/data/hf_datasets/test_oai_format_dataset.py @@ -48,11 +48,15 @@ def sample_data(request): val_data[system_key] = "You are a helpful assistant." # Create temporary files for train and validation data - with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as train_file: + with tempfile.NamedTemporaryFile( + mode="w", suffix=".json", delete=False + ) as train_file: json.dump(train_data, train_file) train_path = train_file.name - with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as val_file: + with tempfile.NamedTemporaryFile( + mode="w", suffix=".json", delete=False + ) as val_file: json.dump(val_data, val_file) val_path = val_file.name @@ -73,7 +77,10 @@ def test_dataset_initialization(sample_data): def test_custom_keys(sample_data): train_path, val_path = sample_data dataset = OpenAIFormatDataset( - train_path, val_path, chat_key="conversations", system_prompt="You are a helpful assistant." + train_path, + val_path, + chat_key="conversations", + system_prompt="You are a helpful assistant.", ) assert dataset.chat_key == "conversations" @@ -83,7 +90,9 @@ def test_custom_keys(sample_data): @pytest.mark.parametrize("sample_data", [("messages", "system_key")], indirect=True) def test_message_formatting(sample_data): train_path, val_path = sample_data - dataset = OpenAIFormatDataset(train_path, val_path, chat_key="messages", system_key="system_key") + dataset = OpenAIFormatDataset( + train_path, val_path, chat_key="messages", system_key="system_key" + ) first_example = dataset.formatted_ds["train"][0] @@ -105,4 +114,6 @@ def test_message_formatting(sample_data): add_special_tokens=False, ) - assert combined_message == "".join(message["content"] for message in first_example["messages"]) + assert combined_message == "".join( + message["content"] for message in first_example["messages"] + )