From 2e7bc5905ab6dec448fefb7b43260f400be25ecb Mon Sep 17 00:00:00 2001 From: ashors1 Date: Fri, 4 Apr 2025 13:33:54 -0700 Subject: [PATCH 01/57] initial dpo implementation Signed-off-by: ashors1 --- nemo_reinforcer/algorithms/dpo.py | 406 ++++++++++++++++++ nemo_reinforcer/algorithms/loss_functions.py | 104 ++++- nemo_reinforcer/data/datasets.py | 49 +++ nemo_reinforcer/data/hf_datasets/dpo.py | 27 ++ .../data/hf_datasets/helpsteer3.py | 26 ++ nemo_reinforcer/data/llm_message_utils.py | 13 + 6 files changed, 617 insertions(+), 8 deletions(-) create mode 100755 nemo_reinforcer/algorithms/dpo.py create mode 100644 nemo_reinforcer/data/hf_datasets/dpo.py create mode 100644 nemo_reinforcer/data/hf_datasets/helpsteer3.py diff --git a/nemo_reinforcer/algorithms/dpo.py b/nemo_reinforcer/algorithms/dpo.py new file mode 100755 index 0000000000..c732b84360 --- /dev/null +++ b/nemo_reinforcer/algorithms/dpo.py @@ -0,0 +1,406 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from pathlib import Path +from typing import Optional, Tuple, TypedDict + +import numpy as np +import torch +from torchdata.stateful_dataloader import StatefulDataLoader +from nemo_reinforcer.algorithms.loss_functions import ( + DPOLossFn, +) +from nemo_reinforcer.algorithms.utils import set_seed +from nemo_reinforcer.data import DataConfig +from nemo_reinforcer.data.datasets import AllTaskProcessedDataset, dpo_collate_fn +from nemo_reinforcer.data.interfaces import TaskDataSpec +from nemo_reinforcer.data.llm_message_utils import ( + add_dpo_loss_mask_to_message_log, + batched_message_log_to_flat_message, +) +from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict +from nemo_reinforcer.distributed.virtual_cluster import ClusterConfig, RayVirtualCluster +from nemo_reinforcer.models.interfaces import PolicyInterface +from nemo_reinforcer.models.policy.hf_policy import HfPolicy +from nemo_reinforcer.models.policy import PolicyConfig +from nemo_reinforcer.utils.checkpoint import CheckpointManager, CheckpointingConfig +from nemo_reinforcer.utils.logger import Logger, LoggerConfig +from nemo_reinforcer.utils.timer import Timer + + +class DPOSaveState(TypedDict): + step: int + val_loss: float + consumed_samples: int + + +def _default_dpo_save_state() -> DPOSaveState: + return { + "step": 0, + "consumed_samples": 0, + } + + +class DPOConfig(TypedDict): + max_num_steps: int + val_period: int + val_batches: int + val_global_batch_size: int + val_micro_batch_size: int + val_at_start: bool + seed: int + + reference_policy_kl_penalty: float + ## TODO: support below + # preference_average_log_probs: False + # sft_average_log_probs: ${.preference_average_log_probs} + # gt_reward_scale: 1. 
+ # preference_loss: dpo + preference_loss_weight: float + sft_loss_weight: float + + +class MasterConfig(TypedDict): + policy: PolicyConfig + data: DataConfig + dpo: DPOConfig + logger: LoggerConfig + cluster: ClusterConfig + checkpointing: CheckpointingConfig + + +# ======================================================= +# Setup & Initialization +# ======================================================= +def setup( + master_config: MasterConfig, + train_dataset: AllTaskProcessedDataset, ## TODO: figure out dataset stuff for DPO + val_dataset: AllTaskProcessedDataset, +) -> Tuple[ + HfPolicy, + RayVirtualCluster, + StatefulDataLoader, + StatefulDataLoader, + DPOLossFn, + MasterConfig, + Logger, + TaskDataSpec, + DPOSaveState, +]: + """Main entry point for running DPO algorithm. + + Returns: + Tuple of policy, cluster, dataloader, tokenizer, loss_fn, math_env, master_config, logger + """ + set_seed(master_config["dpo"]["seed"]) + + # Extract individual configs for easier access + policy_config = master_config["policy"] + data_config = master_config["data"] + logger_config = master_config["logger"] + cluster_config = master_config["cluster"] + dpo_config = master_config["dpo"] + + # ========================== + # Logger + # ========================== + logger = Logger(logger_config) + logger.log_hyperparams(master_config) + + # ========================== + # Checkpointing + # ========================== + checkpointer = CheckpointManager(master_config["checkpointing"]) + last_checkpoint_path = checkpointer.get_latest_checkpoint_path() + dpo_save_state: Optional[DPOSaveState] = checkpointer.load_training_info( + last_checkpoint_path + ) + # config validation checks + if master_config["checkpointing"]["enabled"]: + assert master_config["checkpointing"]["save_period"] > 0 + assert ( + master_config["checkpointing"]["save_period"] + % master_config["dpo"]["val_period"] + == 0 + ), ( + f"Checkpointing save period {master_config['checkpointing']['save_period']} " + f"must be a multiple of validation period {master_config['dpo']['val_period']}" + f", or we won't know what metric to save!" 
+ ) + + # ========================== + # Data + # ========================== + train_dataloader = StatefulDataLoader( + train_dataset, + batch_size=policy_config["train_global_batch_size"], + shuffle=True, + collate_fn=dpo_collate_fn, + ) + + if last_checkpoint_path is not None: + dataloader_state_dict = torch.load( + os.path.join(last_checkpoint_path, "train_dataloader.pt") + ) + train_dataloader.load_state_dict(dataloader_state_dict) + + val_dataloader = StatefulDataLoader( + val_dataset, + batch_size=dpo_config["val_global_batch_size"], + shuffle=False, + collate_fn=dpo_collate_fn, + drop_last=True, + ) + + # ========================== + # Cluster + # ========================== + print("\nā–¶ Setting up compute cluster...") + cluster = RayVirtualCluster( + name="dpo_cluster", + bundle_ct_per_node_list=[cluster_config["gpus_per_node"]] + * cluster_config["num_nodes"], + use_gpus=True, + num_gpus_per_node=cluster_config["gpus_per_node"], + max_colocated_worker_groups=1, + ) + print(f" āœ“ Ray cluster initialized with {cluster_config['num_nodes']} nodes") + + # ========================== + # Training + # ========================== + print("\nā–¶ Setting up model...") + policy = HfPolicy( + cluster=cluster, + config=policy_config, + weights_path=Path(last_checkpoint_path) / "policy.pt" + if last_checkpoint_path + else None, + optimizer_path=Path(last_checkpoint_path) / "policy_optimizer.pt" + if last_checkpoint_path + else None, + init_optimizer=True, + init_reference_model=True, + ) + loss_fn = DPOLossFn(master_config["dpo"]) + print(f" āœ“ Model initialized") + + print("\n" + "=" * 60) + print(" " * 18 + "SETUP COMPLETE") + print("=" * 60 + "\n") + + return ( + policy, + cluster, + train_dataloader, + val_dataloader, + loss_fn, + logger, + checkpointer, + dpo_save_state, + master_config, + ) + + +# ======================================================= +# Training & Validation +# ======================================================= +def validate( + policy: PolicyInterface, + val_dataloader: StatefulDataLoader, + tokenizer, + loss_fn, + step: int, + master_config: MasterConfig, + dpo_task_spec: TaskDataSpec, + val_batches: int, + val_batch_size: int, + val_mbs: int, +): + """Run validation on the validation dataset.""" + ### TODO ### + + +def dpo_train( + policy, + train_dataloader, + val_dataloader, + tokenizer, + loss_fn, + master_config, + logger, + dpo_task_spec, ### TODO: flesh out + checkpointer, + dpo_save_state, +): + # Run dpo training + timer = Timer() + + if dpo_save_state is None: + dpo_save_state = _default_dpo_save_state() + step = 0 + else: + step = dpo_save_state["step"] + + dpo_config = master_config["dpo"] + # Validation configuration + val_period = dpo_config["val_period"] + val_at_start = dpo_config["val_at_start"] + + # Run validation at the start if configured + """if val_at_start and step == 0: + print("\nšŸ” Running initial validation...") + + ## TODO + val_metrics, validation_timings = validate( + policy, + val_dataloader, + tokenizer, + loss_fn, + step=0, + master_config=master_config, + dpo_task_spec=dpo_task_spec, + val_batches=dpo_config["val_batches"], + val_batch_size=dpo_config["val_global_batch_size"], + val_mbs=dpo_config["val_micro_batch_size"], + ) + + logger.log_metrics(val_metrics, step, prefix="validation") + logger.log_metrics(validation_timings, step, prefix="timing/validation")""" + + policy.prepare_for_training() + + for batch in train_dataloader: + print(f"\n{'=' * 25} Step {step + 1}/{len(train_dataloader)} {'=' * 25}") + + with 
timer.time("total_step_time"): + ## add loss mask based on role to every message + add_dpo_loss_mask_to_message_log( + batch["message_log"], + ) + + cat_and_padded, input_lengths = batched_message_log_to_flat_message( + batch["message_log"], + ## TODO: update pad value + pad_value_dict={"token_ids": tokenizer.eos_token_id}, + ) + + train_data: BatchedDataDict = BatchedDataDict( + { + "input_ids": cat_and_padded["token_ids"], + "input_lengths": input_lengths, + "token_mask": cat_and_padded["token_loss_mask"], + "sample_mask": batch["loss_multiplier"], + } + ) + + # Prepare batch and generate responses + print("ā–¶ Preparing batch...") + with timer.time("get_ref_policy_logprobs"): + ## append ref policy logprobs to batch + batch = policy.get_reference_policy_logprobs(train_data) + + train_data["reference_policy_logprobs"] = batch["reference_logprobs"] + + ## train_data.to("cpu") + print("ā–¶ Taking a training step...") + train_results = policy.train( + train_data, + loss_fn, + eval_mode=False, + gbs=master_config["policy"]["train_global_batch_size"] * 2, + mbs=master_config["policy"]["train_micro_batch_size"] * 2, + ) + + """# Run validation if it's a validation step + if val_period > 0 and (step + 1) % val_period == 0: + val_metrics, validation_timings = validate( + policy, + val_dataloader, + tokenizer, + loss_fn, + step=step + 1, + master_config=master_config, + dpo_task_spec=dpo_task_spec, + val_batches=dpo_config["val_batches"], + val_batch_size=dpo_config["val_global_batch_size"], + val_mbs=dpo_config["val_micro_batch_size"], + ) + logger.log_metrics( + validation_timings, step + 1, prefix="timing/validation" + ) + logger.log_metrics(val_metrics, step + 1, prefix="validation") + + ## Checkpointing + dpo_save_state["consumed_samples"] += master_config["policy"][ + "train_global_batch_size" + ] + if ( + master_config["checkpointing"]["enabled"] + and (step + 1) % master_config["checkpointing"]["save_period"] == 0 + ): # +1 because step is 0-indexed + dpo_save_state["step"] = step + 1 + dpo_save_state["val_loss"] = val_metrics["val_loss"] + with timer.time("checkpointing"): + print(f"Saving checkpoint for step {step + 1}...") + checkpoint_path = checkpointer.init_tmp_checkpoint( + step + 1, dpo_save_state, master_config + ) + ## TODO: move checkpointing logic elsewhere? + policy.save_checkpoint( + os.path.join(checkpoint_path, "policy.pt"), + os.path.join(checkpoint_path, "policy_optimizer.pt"), + ## NOTE: below is a workaround to avoid a bug with checkpointing + ## this should be removed once the bug is fixed + offload_to_cpu=False, + ) + torch.save( + train_dataloader.state_dict(), + os.path.join(checkpoint_path, "train_dataloader.pt"), + ) + checkpointer.finalize_checkpoint(checkpoint_path)""" + + ## TODO: add more DPO metrics + ## accuracy, sft loss, preference loss, etc. 
+ losses = train_results["loss"] + metrics = { + "loss": train_results["loss"].numpy(), + } + metrics.update(train_results["all_mb_metrics"]) + metrics = {k: np.mean(v).item() for k, v in metrics.items()} + timing_metrics = timer.get_timing_metrics(reduction_op="sum") + + print("\nšŸ“Š Training Results:") + print(f" • Loss: {float(metrics['loss']):.4f}") + print("\nā±ļø Timing:") + # Display total time first, separately + total_time = timing_metrics.get("total_step_time", 0) + print(f" • Total step time: {total_time:.2f}s") + + # Display all other timing metrics (if any) + for k, v in sorted( + timing_metrics.items(), key=lambda item: item[1], reverse=True + ): + if k != "total_step_time": + percent = (v / total_time * 100) if total_time > 0 else 0 + print(f" • {k}: {v:.2f}s ({percent:.1f}%)") + + logger.log_metrics(metrics, step + 1, prefix="train") + logger.log_metrics(timing_metrics, step + 1, prefix="timing/train") + + timer.reset() + step += 1 + + if step >= master_config["dpo"]["max_num_steps"]: + break diff --git a/nemo_reinforcer/algorithms/loss_functions.py b/nemo_reinforcer/algorithms/loss_functions.py index 158c9824eb..6172e9043e 100644 --- a/nemo_reinforcer/algorithms/loss_functions.py +++ b/nemo_reinforcer/algorithms/loss_functions.py @@ -149,7 +149,10 @@ def __call__( class NLLLoss(LossFunction): def __call__( - self, next_token_logits: torch.Tensor, data: BatchedDataDict + self, + next_token_logits: torch.Tensor, + data: BatchedDataDict, + reduce_across_batch: bool = True, ) -> Tuple[torch.Tensor, dict]: # logits shape: [batch_size, seq_len, vocab_size] # Get the next token logits for each position @@ -168,14 +171,99 @@ def __call__( # Only compute loss on generated tokens (not input tokens) # by applying the token_loss_mask (shifted by 1 since we're predicting next tokens) - num_unmasked_tokens = torch.sum(mask) - if num_unmasked_tokens == 0: - # prevent division by zero - num_unmasked_tokens = torch.tensor(1) - loss = -torch.sum(token_logprobs * mask) / num_unmasked_tokens + num_unmasked_tokens = torch.sum(mask, -1) + num_unmasked_tokens[num_unmasked_tokens == 0] = 1 + + if reduce_across_batch: + num_unmasked_tokens = num_unmasked_tokens.sum().item() + loss = (-torch.sum(token_logprobs * mask) / num_unmasked_tokens).item() + else: + loss = -torch.sum(token_logprobs * mask, dim=-1) / num_unmasked_tokens return loss, { - "loss": loss.item(), - "num_unmasked_tokens": num_unmasked_tokens.item(), + "loss": loss, + "num_unmasked_tokens": num_unmasked_tokens, "total_tokens": mask.numel(), } + + +class DPOLossConfig(TypedDict): + reference_policy_kl_penalty: float + preference_loss_weight: float = 1.0 + sft_loss_weight: float = 0.0 + + +class DPOLossDataDict(TypedDict): + """Required keys for the Clipped Policy Gradient loss function.""" + + input_ids: torch.Tensor + reference_policy_logprobs: torch.Tensor + token_mask: torch.Tensor + sample_mask: torch.Tensor + + +class DPOLossFn(LossFunction): + def __init__(self, cfg: DPOLossConfig): + self.reference_policy_kl_penalty = cfg["reference_policy_kl_penalty"] + self.preference_loss_weight = cfg["preference_loss_weight"] + self.sft_loss_weight = cfg["sft_loss_weight"] + self.sft_loss = NLLLoss() + + def split_output_tensor(self, tensor: torch.Tensor): + return torch.split(tensor, tensor.shape[0] // 2, dim=0) + + def preference_loss( + self, next_token_logits: torch.Tensor, data: BatchedDataDict[DPOLossDataDict] + ) -> torch.Tensor: + ## TODO: make sure this token mask only includes the chosen / rejected responses + ## and not prior 
assistant tokens + ## TODO: there's some duplicate code here with the NLLLoss function. We should refactor + token_mask = data["token_mask"][:, 1:] + sample_mask = data["sample_mask"] + mask = token_mask * sample_mask.unsqueeze(-1) + + next_tokens = data.get("input_ids")[:, 1:].cuda() # Skip first token + next_token_logprobs = torch.nn.functional.log_softmax(next_token_logits, dim=-1) + logprobs = next_token_logprobs[:, :-1] # Remove last position's logits + + # Gather the logprobs for the actual next tokens + token_logprobs = logprobs.gather( + dim=-1, index=next_tokens.unsqueeze(-1) + ).squeeze(-1) + + ref_logprobs = data["reference_policy_logprobs"][:, :-1] + + diff = token_logprobs - ref_logprobs + + ## TODO: provide the option to average over tokens + rewards = (diff * mask).sum(-1) + + rewards_chosen, rewards_rejected = self.split_output_tensor(rewards) + rewards_delta = rewards_chosen - rewards_rejected + return -torch.nn.functional.logsigmoid( + self.reference_policy_kl_penalty * rewards_delta + ).mean(0) + + def __call__( + self, next_token_logits: torch.Tensor, data: BatchedDataDict[DPOLossDataDict] + ) -> Tuple[torch.Tensor, dict]: + sft_loss, _ = self.sft_loss(next_token_logits, data, reduce_across_batch=False) + sft_loss_chosen, sft_loss_rejected = self.split_output_tensor(sft_loss) + + ## average over the batch dimension + sft_loss_chosen = sft_loss_chosen.mean(0) + + preference_loss = self.preference_loss(next_token_logits, data) + + dpo_loss = ( + self.sft_loss_weight * sft_loss_chosen + + self.preference_loss_weight * preference_loss + ) + + ## TODO: fix initial preference loss -- should be exactly 0.69315 + print(f"{preference_loss.item()=}") + return dpo_loss, { + "loss": dpo_loss.item(), + "sft_loss": sft_loss_chosen.item(), + "preference_loss": preference_loss.item(), + } diff --git a/nemo_reinforcer/data/datasets.py b/nemo_reinforcer/data/datasets.py index 8a81c85fb2..42c4f14dd1 100644 --- a/nemo_reinforcer/data/datasets.py +++ b/nemo_reinforcer/data/datasets.py @@ -181,3 +181,52 @@ def eval_collate_fn(data_batch: List[DatumSpec]) -> BatchedDataDict: idx=idx, ) return output + + +def dpo_collate_fn(data_batch: List[DatumSpec]) -> BatchedDataDict: + """Collate function for DPO training. + + This function separates the chosen and rejected responses to create + two examples per prompt. The chosen and rejected examples are concatenated + along the batch dimension, resulting in a batch size of 2 * len(data_batch). 
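+
+    For a batch of N prompts the resulting message_log is ordered as
+    [chosen_0, ..., chosen_N-1, rejected_0, ..., rejected_N-1], i.e. all chosen
+    responses first, followed by all rejected responses, which matches the
+    half-and-half split performed by DPOLossFn.split_output_tensor.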
+ """ + message_log_chosen = [datum_spec["message_log_chosen"] for datum_spec in data_batch] + message_log_rejected = [ + datum_spec["message_log_rejected"] for datum_spec in data_batch + ] + length_chosen = torch.tensor( + [datum_spec["length_chosen"] for datum_spec in data_batch] + ) + length_rejected = torch.tensor( + [datum_spec["length_rejected"] for datum_spec in data_batch] + ) + loss_multiplier = torch.tensor( + [datum_spec["loss_multiplier"] for datum_spec in data_batch] + ) + + ## TODO + # extra_env_info = [datum_spec["extra_env_info"] for datum_spec in data_batch] + + ## conctenate chosen and rejected examples + message_log = message_log_chosen + message_log_rejected + length = torch.cat([length_chosen, length_rejected]) + loss_multiplier = torch.cat([loss_multiplier] * 2) + + task_names = [] + for datum_spec in data_batch: + task_names.append(datum_spec.get("task_name", None)) + task_names = task_names * 2 + + idx = [datum_spec["idx"] for datum_spec in data_batch] * 2 + batch_max_length = torch.ones_like(length) * length.max() + + output = BatchedDataDict( + message_log=message_log, + length=length, + loss_multiplier=loss_multiplier, + # extra_env_info=extra_env_info, + task_name=task_names, + idx=idx, + batch_max_length=batch_max_length, + ) + return output diff --git a/nemo_reinforcer/data/hf_datasets/dpo.py b/nemo_reinforcer/data/hf_datasets/dpo.py new file mode 100644 index 0000000000..483351af60 --- /dev/null +++ b/nemo_reinforcer/data/hf_datasets/dpo.py @@ -0,0 +1,27 @@ +from datasets import load_dataset + +from nemo_reinforcer.data.hf_datasets.interfaces import HfDataset + + +## assumptions about DPO dataset: +## json files should have the following keys: +## "prompt" +## "chosen_response" +## "rejected_response" +class DPODataset(HfDataset): + def __init__(self, train_data_path: str, val_data_path: str): + ## TODO: assuming for now that data has been split into train and val + ## as an offline preprocessing step + + ## TODO: update the keys to match with what's expected from apply_chat_template + ## we need to do this outisde of the data class because we want to keep + ## chosen and rejected responses for a given prompt together when shuffling + self.formatted_ds = { + "train": load_dataset("json", data_files=train_data_path), + "validation": load_dataset("json", data_files=val_data_path), + } + super().__init__( + dataset_name="dpo", + ## no custom template. 
Assume we use tokenizer's template + # custom_template=COMMON_CHAT_TEMPLATES.simple_role_header, + ) diff --git a/nemo_reinforcer/data/hf_datasets/helpsteer3.py b/nemo_reinforcer/data/hf_datasets/helpsteer3.py new file mode 100644 index 0000000000..991a82f4e1 --- /dev/null +++ b/nemo_reinforcer/data/hf_datasets/helpsteer3.py @@ -0,0 +1,26 @@ +from datasets import load_dataset +from nemo_reinforcer.data.hf_datasets.interfaces import HfDataset + + +def format_helpsteer3(data): + context = data["context"] + response_1 = data["response1"] + response_2 = data["response2"] + overall_preference = data["overall_preference"] + + return { + "prompt": data["context"], + "chosen_response": response_1 if overall_preference < 0 else response_2, + "rejected_response": response_2 if overall_preference < 0 else response_1, + } + + +class HelpSteer3Dataset(HfDataset): + def __init__(self): + ds = load_dataset("nvidia/HelpSteer3", "preference") + self.formatted_ds = ds.map(format_helpsteer3) + + super().__init__( + dataset_name="HelpSteer3", + custom_template=None, ## use tokenizer's template + ) diff --git a/nemo_reinforcer/data/llm_message_utils.py b/nemo_reinforcer/data/llm_message_utils.py index f6cb3c8079..40474498e1 100644 --- a/nemo_reinforcer/data/llm_message_utils.py +++ b/nemo_reinforcer/data/llm_message_utils.py @@ -127,6 +127,19 @@ def add_loss_mask_to_message_log( sentence["token_loss_mask"] = torch.zeros_like(sentence["token_ids"]) +## TODO: VERIFY +def add_dpo_loss_mask_to_message_log( + message_log: LLMMessageLogType, +) -> None: + """Only unmask the final assistant message in the log.""" + for message in message_log: + for i, sentence in enumerate(message): + if i == len(message) - 1: + sentence["token_loss_mask"] = torch.ones_like(sentence["token_ids"]) + else: + sentence["token_loss_mask"] = torch.zeros_like(sentence["token_ids"]) + + def _pad_tensor( tensor: torch.Tensor, max_len: int, From f44a028186ac73a4330e30787dfee8e1fec256c1 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Fri, 4 Apr 2025 16:37:14 -0700 Subject: [PATCH 02/57] bug fixes Signed-off-by: ashors1 --- nemo_reinforcer/algorithms/dpo.py | 14 ++++++- nemo_reinforcer/algorithms/loss_functions.py | 21 +++++----- nemo_reinforcer/data/datasets.py | 40 ++++++++------------ nemo_reinforcer/models/policy/hf_policy.py | 24 ++++++++---- 4 files changed, 56 insertions(+), 43 deletions(-) diff --git a/nemo_reinforcer/algorithms/dpo.py b/nemo_reinforcer/algorithms/dpo.py index c732b84360..1f4bc233db 100755 --- a/nemo_reinforcer/algorithms/dpo.py +++ b/nemo_reinforcer/algorithms/dpo.py @@ -309,9 +309,19 @@ def dpo_train( print("ā–¶ Preparing batch...") with timer.time("get_ref_policy_logprobs"): ## append ref policy logprobs to batch - batch = policy.get_reference_policy_logprobs(train_data) + batch = policy.get_reference_policy_logprobs( + train_data, + ## TODO: make more robust + micro_batch_size=master_config["policy"]["train_micro_batch_size"] + * 2, + ) - train_data["reference_policy_logprobs"] = batch["reference_logprobs"] + ## roll the reference logprobs by one to the left + ## this ensures that the logprobs correspond to the next token + ## in the sequence + train_data["reference_policy_logprobs"] = torch.roll( + batch["reference_logprobs"], -1, dims=-1 + ) ## train_data.to("cpu") print("ā–¶ Taking a training step...") diff --git a/nemo_reinforcer/algorithms/loss_functions.py b/nemo_reinforcer/algorithms/loss_functions.py index 6172e9043e..fac8bb4e73 100644 --- a/nemo_reinforcer/algorithms/loss_functions.py +++ 
b/nemo_reinforcer/algorithms/loss_functions.py @@ -210,7 +210,7 @@ def __init__(self, cfg: DPOLossConfig): self.sft_loss = NLLLoss() def split_output_tensor(self, tensor: torch.Tensor): - return torch.split(tensor, tensor.shape[0] // 2, dim=0) + return tensor[::2], tensor[1::2] def preference_loss( self, next_token_logits: torch.Tensor, data: BatchedDataDict[DPOLossDataDict] @@ -233,10 +233,9 @@ def preference_loss( ref_logprobs = data["reference_policy_logprobs"][:, :-1] - diff = token_logprobs - ref_logprobs - + diff = (token_logprobs - ref_logprobs) * token_mask ## TODO: provide the option to average over tokens - rewards = (diff * mask).sum(-1) + rewards = diff.sum(-1) rewards_chosen, rewards_rejected = self.split_output_tensor(rewards) rewards_delta = rewards_chosen - rewards_rejected @@ -247,11 +246,15 @@ def preference_loss( def __call__( self, next_token_logits: torch.Tensor, data: BatchedDataDict[DPOLossDataDict] ) -> Tuple[torch.Tensor, dict]: - sft_loss, _ = self.sft_loss(next_token_logits, data, reduce_across_batch=False) - sft_loss_chosen, sft_loss_rejected = self.split_output_tensor(sft_loss) + sft_loss_chosen = torch.tensor(0.0) + if self.sft_loss_weight > 0: + sft_loss, _ = self.sft_loss( + next_token_logits, data, reduce_across_batch=False + ) + sft_loss_chosen, sft_loss_rejected = self.split_output_tensor(sft_loss) - ## average over the batch dimension - sft_loss_chosen = sft_loss_chosen.mean(0) + ## average over the batch dimension + sft_loss_chosen = sft_loss_chosen.mean(0) preference_loss = self.preference_loss(next_token_logits, data) @@ -260,8 +263,6 @@ def __call__( + self.preference_loss_weight * preference_loss ) - ## TODO: fix initial preference loss -- should be exactly 0.69315 - print(f"{preference_loss.item()=}") return dpo_loss, { "loss": dpo_loss.item(), "sft_loss": sft_loss_chosen.item(), diff --git a/nemo_reinforcer/data/datasets.py b/nemo_reinforcer/data/datasets.py index 42c4f14dd1..4bf175f486 100644 --- a/nemo_reinforcer/data/datasets.py +++ b/nemo_reinforcer/data/datasets.py @@ -190,34 +190,26 @@ def dpo_collate_fn(data_batch: List[DatumSpec]) -> BatchedDataDict: two examples per prompt. The chosen and rejected examples are concatenated along the batch dimension, resulting in a batch size of 2 * len(data_batch). 
""" - message_log_chosen = [datum_spec["message_log_chosen"] for datum_spec in data_batch] - message_log_rejected = [ - datum_spec["message_log_rejected"] for datum_spec in data_batch - ] - length_chosen = torch.tensor( - [datum_spec["length_chosen"] for datum_spec in data_batch] - ) - length_rejected = torch.tensor( - [datum_spec["length_rejected"] for datum_spec in data_batch] - ) - loss_multiplier = torch.tensor( - [datum_spec["loss_multiplier"] for datum_spec in data_batch] - ) + message_log = [] + length = [] + loss_multiplier = [] + idx = [] + task_names = [] + for datum_spec in data_batch: + ## interleave chosen and rejected examples + message_log.append(datum_spec["message_log_chosen"]) + message_log.append(datum_spec["message_log_rejected"]) + length.append(datum_spec["length_chosen"]) + length.append(datum_spec["length_rejected"]) + loss_multiplier.extend([datum_spec["loss_multiplier"]] * 2) + idx.extend([datum_spec["idx"]] * 2) + task_names.extend([datum_spec.get("task_name", None)] * 2) + length = torch.tensor(length) + loss_multiplier = torch.tensor(loss_multiplier) ## TODO # extra_env_info = [datum_spec["extra_env_info"] for datum_spec in data_batch] - ## conctenate chosen and rejected examples - message_log = message_log_chosen + message_log_rejected - length = torch.cat([length_chosen, length_rejected]) - loss_multiplier = torch.cat([loss_multiplier] * 2) - - task_names = [] - for datum_spec in data_batch: - task_names.append(datum_spec.get("task_name", None)) - task_names = task_names * 2 - - idx = [datum_spec["idx"] for datum_spec in data_batch] * 2 batch_max_length = torch.ones_like(length) * length.max() output = BatchedDataDict( diff --git a/nemo_reinforcer/models/policy/hf_policy.py b/nemo_reinforcer/models/policy/hf_policy.py index ebc9e879f2..7284d96901 100644 --- a/nemo_reinforcer/models/policy/hf_policy.py +++ b/nemo_reinforcer/models/policy/hf_policy.py @@ -347,10 +347,12 @@ def train( return metrics - def get_logprobs(self, data: BatchedDataDict) -> BatchedDataDict: + def get_logprobs( + self, data: BatchedDataDict, micro_batch_size: int = None + ) -> BatchedDataDict: """Get the logprobs of the model for a batch of data. - Uses the configured logprob_batch_size to do microbatching. + If no micro-batch size is provided, uses the configured logprob_batch_size to do microbatching. Input data is assumed to be right-padded. The method internally converts to left-padded format for computation, and returns outputs in right-padded format. @@ -360,7 +362,11 @@ def get_logprobs(self, data: BatchedDataDict) -> BatchedDataDict: We use the convention that the logprob of the first token is 0 so that the sequence length is maintained. The logprob of input token i is specified at position i in the output logprobs tensor. """ - logprob_batch_size = self.cfg["logprob_batch_size"] + logprob_batch_size = ( + micro_batch_size + if micro_batch_size is not None + else self.cfg["logprob_batch_size"] + ) all_log_probs = [] self.model.eval() @@ -454,7 +460,9 @@ def use_reference_model(self): gc.collect() torch.cuda.empty_cache() - def get_reference_policy_logprobs(self, data: BatchedDataDict) -> BatchedDataDict: + def get_reference_policy_logprobs( + self, data: BatchedDataDict, micro_batch_size: int = None + ) -> BatchedDataDict: """Get the logprobs from the reference policy for a batch of data. 
Returns: @@ -463,7 +471,7 @@ def get_reference_policy_logprobs(self, data: BatchedDataDict) -> BatchedDataDic The logprob of input token i is specified at position i in the output logprobs tensor. """ with self.use_reference_model(): - reference_logprobs = self.get_logprobs(data) + reference_logprobs = self.get_logprobs(data, micro_batch_size) return_data = BatchedDataDict() return_data["reference_logprobs"] = reference_logprobs["logprobs"].cpu() @@ -941,7 +949,7 @@ def get_logprobs( return logprobs def get_reference_policy_logprobs( - self, data: BatchedDataDict[GenerationDatumSpec] + self, data: BatchedDataDict[GenerationDatumSpec], micro_batch_size: int = None ) -> BatchedDataDict: """Get the logprobs of the reference policy for a data dict. @@ -949,7 +957,9 @@ def get_reference_policy_logprobs( """ sharded_data = data.shard_by_batch_size(self.dp_size, batch_size=None) futures = self.worker_group.run_all_workers_multiple_data( - "get_reference_policy_logprobs", sharded_data + "get_reference_policy_logprobs", + sharded_data, + common_kwargs={"micro_batch_size": micro_batch_size}, ) logprobs = BatchedDataDict.from_batches( self.worker_group.get_all_worker_results(futures) From 33e8535516052a2266e98db4fa631d3a69627cce Mon Sep 17 00:00:00 2001 From: ashors1 Date: Mon, 7 Apr 2025 15:16:42 -0700 Subject: [PATCH 03/57] small perf gains Signed-off-by: ashors1 --- nemo_reinforcer/algorithms/dpo.py | 61 +++++++++------------- nemo_reinforcer/data/datasets.py | 33 ++++++++++-- nemo_reinforcer/models/policy/hf_policy.py | 9 ++-- 3 files changed, 60 insertions(+), 43 deletions(-) diff --git a/nemo_reinforcer/algorithms/dpo.py b/nemo_reinforcer/algorithms/dpo.py index 1f4bc233db..15a7b2f3d6 100755 --- a/nemo_reinforcer/algorithms/dpo.py +++ b/nemo_reinforcer/algorithms/dpo.py @@ -12,11 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import os +from functools import partial from pathlib import Path from typing import Optional, Tuple, TypedDict +from tqdm import tqdm import numpy as np import torch +from transformers import AutoTokenizer from torchdata.stateful_dataloader import StatefulDataLoader from nemo_reinforcer.algorithms.loss_functions import ( DPOLossFn, @@ -142,11 +145,13 @@ def setup( # ========================== # Data # ========================== + ## TODO: clean up + tokenizer = AutoTokenizer.from_pretrained(policy_config["model_name"]) train_dataloader = StatefulDataLoader( train_dataset, batch_size=policy_config["train_global_batch_size"], shuffle=True, - collate_fn=dpo_collate_fn, + collate_fn=partial(dpo_collate_fn, tokenizer=tokenizer), ) if last_checkpoint_path is not None: @@ -281,52 +286,34 @@ def dpo_train( policy.prepare_for_training() - for batch in train_dataloader: - print(f"\n{'=' * 25} Step {step + 1}/{len(train_dataloader)} {'=' * 25}") - - with timer.time("total_step_time"): - ## add loss mask based on role to every message - add_dpo_loss_mask_to_message_log( - batch["message_log"], - ) - - cat_and_padded, input_lengths = batched_message_log_to_flat_message( - batch["message_log"], - ## TODO: update pad value - pad_value_dict={"token_ids": tokenizer.eos_token_id}, - ) + def augment_dataloader(dataloader): + dataloader_iter = iter(dataloader) + while True: + try: + batch = next(dataloader_iter) - train_data: BatchedDataDict = BatchedDataDict( - { - "input_ids": cat_and_padded["token_ids"], - "input_lengths": input_lengths, - "token_mask": cat_and_padded["token_loss_mask"], - "sample_mask": batch["loss_multiplier"], - } - ) - - # Prepare batch and generate responses - print("ā–¶ Preparing batch...") - with timer.time("get_ref_policy_logprobs"): ## append ref policy logprobs to batch - batch = policy.get_reference_policy_logprobs( - train_data, + logprobs = policy.get_reference_policy_logprobs( + batch, ## TODO: make more robust micro_batch_size=master_config["policy"]["train_micro_batch_size"] * 2, - ) + )["reference_logprobs"] + batch["reference_policy_logprobs"] = torch.roll(logprobs, -1, dims=-1) - ## roll the reference logprobs by one to the left - ## this ensures that the logprobs correspond to the next token - ## in the sequence - train_data["reference_policy_logprobs"] = torch.roll( - batch["reference_logprobs"], -1, dims=-1 - ) + yield batch + + except StopIteration: + break + for batch in augment_dataloader(train_dataloader): + print(f"\n{'=' * 25} Step {step + 1}/{len(train_dataloader)} {'=' * 25}") + + with timer.time("total_step_time"): ## train_data.to("cpu") print("ā–¶ Taking a training step...") train_results = policy.train( - train_data, + batch, loss_fn, eval_mode=False, gbs=master_config["policy"]["train_global_batch_size"] * 2, diff --git a/nemo_reinforcer/data/datasets.py b/nemo_reinforcer/data/datasets.py index 4bf175f486..53cad2433a 100644 --- a/nemo_reinforcer/data/datasets.py +++ b/nemo_reinforcer/data/datasets.py @@ -183,7 +183,13 @@ def eval_collate_fn(data_batch: List[DatumSpec]) -> BatchedDataDict: return output -def dpo_collate_fn(data_batch: List[DatumSpec]) -> BatchedDataDict: +from nemo_reinforcer.data.llm_message_utils import ( + add_dpo_loss_mask_to_message_log, + batched_message_log_to_flat_message, +) + + +def dpo_collate_fn(data_batch: List[DatumSpec], tokenizer) -> BatchedDataDict: """Collate function for DPO training. 
This function separates the chosen and rejected responses to create @@ -212,7 +218,7 @@ def dpo_collate_fn(data_batch: List[DatumSpec]) -> BatchedDataDict: batch_max_length = torch.ones_like(length) * length.max() - output = BatchedDataDict( + batch = BatchedDataDict( message_log=message_log, length=length, loss_multiplier=loss_multiplier, @@ -221,4 +227,25 @@ def dpo_collate_fn(data_batch: List[DatumSpec]) -> BatchedDataDict: idx=idx, batch_max_length=batch_max_length, ) - return output + + ## add loss mask based on role to every message + add_dpo_loss_mask_to_message_log( + batch["message_log"], + ) + + cat_and_padded, input_lengths = batched_message_log_to_flat_message( + batch["message_log"], + ## TODO: update pad value + pad_value_dict={"token_ids": tokenizer.eos_token_id}, + ) + + train_data: BatchedDataDict = BatchedDataDict( + { + "input_ids": cat_and_padded["token_ids"], + "input_lengths": input_lengths, + "token_mask": cat_and_padded["token_loss_mask"], + "sample_mask": batch["loss_multiplier"], + } + ) + + return train_data diff --git a/nemo_reinforcer/models/policy/hf_policy.py b/nemo_reinforcer/models/policy/hf_policy.py index 7284d96901..7ed605db2c 100644 --- a/nemo_reinforcer/models/policy/hf_policy.py +++ b/nemo_reinforcer/models/policy/hf_policy.py @@ -436,6 +436,7 @@ def use_reference_model(self): On exit: Restores original references and re-flips cuda/cpu """ + # yield try: # Save original references original_model = self.model @@ -470,6 +471,7 @@ def get_reference_policy_logprobs( We use the convention that the logprob of the first token is 0 so that the sequence length is maintained. The logprob of input token i is specified at position i in the output logprobs tensor. """ + ## TODO: investigate this. This is super slow with self.use_reference_model(): reference_logprobs = self.get_logprobs(data, micro_batch_size) @@ -801,15 +803,16 @@ def offload_after_refit(self): ) def move_to_cpu(self, model): + ## this is the slowest part for param in model.parameters(): - param.data = param.data.to("cpu") + param.data = param.data.to("cpu", non_blocking=True, copy=True) for buffer in model.buffers(): - buffer.data = buffer.data.to("cpu") + buffer.data = buffer.data.to("cpu", non_blocking=True, copy=True) + ## commenting this out improves perf by 3x if hasattr(model, "_fsdp_wrapped_module"): model._fsdp_wrapped_module.to("cpu") - return model def save_checkpoint( From b430180c74d33eebfd188f92929beb99f2b4b3e8 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Wed, 9 Apr 2025 09:24:29 -0700 Subject: [PATCH 04/57] make dpo work with jsonl Signed-off-by: ashors1 --- examples/configs/dpo.yaml | 67 +++++++ examples/run_dpo.py | 230 ++++++++++++++++++++++++ nemo_reinforcer/data/hf_datasets/dpo.py | 8 +- 3 files changed, 301 insertions(+), 4 deletions(-) create mode 100644 examples/configs/dpo.yaml create mode 100644 examples/run_dpo.py diff --git a/examples/configs/dpo.yaml b/examples/configs/dpo.yaml new file mode 100644 index 0000000000..22f6dac102 --- /dev/null +++ b/examples/configs/dpo.yaml @@ -0,0 +1,67 @@ +# DPO Algorithm Configuration +dpo: + max_num_steps: 60 + val_period: 10 + val_batches: 8 + val_global_batch_size: 32 + val_micro_batch_size: 2 + val_at_start: true + seed: 42 + + reference_policy_kl_penalty: 0.2 + ## TODO: support below + #preference_average_log_probs: False # whether normalizing log probs according to the sequence length in preference_loss + #sft_average_log_probs: ${.preference_average_log_probs} # whether normalizing log probs according to the sequence length in 
sft_loss + #gt_reward_scale: 1. # the scale of the rewards in RPO + #preference_loss: dpo # the preference loss, we support dpo, ipo, rpo_sq, rpo_bwd_kl, rpo_fwd_kl + preference_loss_weight: 1 # the coefficient of the preference loss + sft_loss_weight: 0 # the coefficient of the SFT loss + +checkpointing: + enabled: true + checkpoint_dir: "results/dpo" + metric_name: "val_loss" + higher_is_better: false + keep_top_k: 3 + save_period: 10 + +## TODO: OOM with mbs 2?? +policy: + model_name: "meta-llama/Llama-3.2-1B-Instruct" + tokenizer_name: "meta-llama/Llama-3.2-1B-Instruct" + train_global_batch_size: 2 + train_micro_batch_size: 1 + logprob_batch_size: ${policy.train_micro_batch_size} + max_total_sequence_length: 1024 + precision: "float32" + + optimizer: + name: "torch.optim.AdamW" + kwargs: + lr: 5.0e-6 + weight_decay: 0.1 + betas: [0.9, 0.98] + eps: 1e-5 + +data: + max_input_seq_length: ${policy.max_total_sequence_length} + train_data_path: "/path/to/train.jsonl" + val_data_path: "/path/to/val.jsonl" + +logger: + log_dir: "logs" # Base directory for all logs + wandb_enabled: false # Make sure you do a ``wandb login [Your API key]'' before running + tensorboard_enabled: false + monitor_gpus: false # If true, will monitor GPU usage and log to wandb and/or tensorboard + wandb: + project: "dpo-dev" + name: "dpo" + tensorboard: + log_dir: "tb_logs-dpo-dev" + gpu_monitoring: + collection_interval: 10 # How often to collect GPU usage metrics (in seconds) + flush_interval: 10 # How often to flush GPU usage metrics to the loggers (in seconds) + +cluster: + gpus_per_node: 1 + num_nodes: 1 diff --git a/examples/run_dpo.py b/examples/run_dpo.py new file mode 100644 index 0000000000..d3247734a3 --- /dev/null +++ b/examples/run_dpo.py @@ -0,0 +1,230 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
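+
+"""Example entry point for DPO training.
+
+Loads a YAML config (with optional dotlist overrides from the CLI), builds the
+preference dataset and tokenizer, then calls setup() and dpo_train() from
+nemo_reinforcer.algorithms.dpo.
+"""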
+ +import argparse +import os +import pprint +from typing import Dict, Any + +from omegaconf import OmegaConf + +from nemo_reinforcer.algorithms.dpo import MasterConfig, dpo_train, setup +from nemo_reinforcer.distributed.virtual_cluster import init_ray +from nemo_reinforcer.utils.config import load_config +from nemo_reinforcer.utils.logger import get_next_experiment_dir +from nemo_reinforcer.data import DataConfig, hf_datasets +from nemo_reinforcer.data.datasets import AllTaskProcessedDataset +from nemo_reinforcer.data.interfaces import TaskDataSpec, DatumSpec +from nemo_reinforcer.data.llm_message_utils import get_formatted_message_log +from transformers import AutoTokenizer +from nemo_reinforcer.models.policy import PolicyConfig + +# from nemo_reinforcer.data.hf_datasets.helpsteer3 import HelpSteer3Dataset +from nemo_reinforcer.data.hf_datasets.dpo import DPODataset + + +def parse_args(): + """Parse command line arguments.""" + parser = argparse.ArgumentParser(description="Run DPO training with configuration") + parser.add_argument( + "--config", type=str, default=None, help="Path to YAML config file" + ) + + # Parse known args for the script + args, remaining = parser.parse_known_args() + + # Convert remaining args to OmegaConf format + overrides = OmegaConf.from_dotlist(remaining) + + return args, overrides + + +# ======================================================= +# Data Processing +# ======================================================= +def dpo_preprocessor( + datum_dict: Dict[str, Any], + task_data_spec: TaskDataSpec, + tokenizer, + max_seq_length: int, + idx: int, +) -> DatumSpec: + """Process a datum dictionary for DPO training.""" + print(f"{datum_dict=}") + if isinstance(datum_dict["prompt"], list): + messages_chosen = datum_dict["prompt"] + messages_rejected = datum_dict["prompt"] + else: + messages_chosen = [ + { + "role": "user", + "content": datum_dict["prompt"], + }, + ] + messages_rejected = [ + { + "role": "user", + "content": datum_dict["prompt"], + }, + ] + + ## TODO: sometimes the context above includes assistant, but we don't want to train + ## on that. Only want to train on the chosen and rejected responses... right? How do we ensure this? + messages_chosen.append( + { + "role": "assistant", + "content": datum_dict["chosen_response"], + }, + ) + + messages_rejected.append( + { + "role": "assistant", + "content": datum_dict["rejected_response"], + }, + ) + + ## TODO: DO NOT APPLY CHAT TEMPLATE! 
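+    ## (DPODataset registers a passthrough chat template that simply concatenates
+    ## message contents, so the intent is for the prompt and responses to be
+    ## tokenized verbatim by get_formatted_message_log below)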
+ message_log_chosen = get_formatted_message_log( + messages_chosen, tokenizer, task_data_spec + ) + message_log_rejected = get_formatted_message_log( + messages_rejected, tokenizer, task_data_spec + ) + + length_chosen = sum(len(m["token_ids"]) for m in message_log_chosen) + length_rejected = sum(len(m["token_ids"]) for m in message_log_rejected) + + loss_multiplier = 1.0 + if max(length_chosen, length_rejected) > max_seq_length: + # make smaller and mask out + for message in message_log_chosen: + message["token_ids"] = message["token_ids"][ + : min(4, max_seq_length // len(message_log_chosen)) + ] + for message in message_log_rejected: + message["token_ids"] = message["token_ids"][ + : min(4, max_seq_length // len(message_log_rejected)) + ] + loss_multiplier = 0.0 + + output = { + "message_log_chosen": message_log_chosen, + "length_chosen": length_chosen, + "message_log_rejected": message_log_rejected, + "length_rejected": length_rejected, + "extra_env_info": None, + "loss_multiplier": loss_multiplier, + "idx": idx, + } + return output + + +def setup_data(data_config: DataConfig, policy_config: PolicyConfig): + print("\nā–¶ Setting up data...") + # data = HelpSteer3Dataset() + # train_dataset = data.formatted_ds["train"] + # val_dataset = data.formatted_ds["validation"] + + data = DPODataset( + train_data_path=data_config["train_data_path"], + val_data_path=data_config["val_data_path"], + ) + train_dataset = data.formatted_ds["train"] + val_dataset = data.formatted_ds["validation"] + + dpo_task_spec = data.task_spec + + tokenizer = AutoTokenizer.from_pretrained(policy_config["model_name"]) + + train_dataset = AllTaskProcessedDataset( + train_dataset, + tokenizer, + dpo_task_spec, + dpo_preprocessor, + max_seq_length=data_config["max_input_seq_length"], + ) + + val_dataset = AllTaskProcessedDataset( + val_dataset, + tokenizer, + dpo_task_spec, + dpo_preprocessor, + max_seq_length=data_config["max_input_seq_length"], + ) + + return train_dataset, val_dataset, tokenizer, dpo_task_spec + + +def main(): + """Main entry point.""" + args, overrides = parse_args() + + if not args.config: + args.config = os.path.join(os.path.dirname(__file__), "configs", "dpo.yaml") + + config = load_config(args.config) + print(f"Loaded configuration from: {args.config}") + + if overrides: + print(f"Overrides: {overrides}") + config = OmegaConf.merge(config, overrides) + + config: MasterConfig = OmegaConf.to_container(config, resolve=True) + print("Applied CLI overrides") + + # Print config + print("Final config:") + pprint.pprint(config) + + config["logger"]["log_dir"] = get_next_experiment_dir(config["logger"]["log_dir"]) + print(f"šŸ“Š Using log directory: {config['logger']['log_dir']}") + if config["checkpointing"]["enabled"]: + print( + f"šŸ“Š Using checkpoint directory: {config['checkpointing']['checkpoint_dir']}" + ) + + init_ray() + + # setup data + dataset, val_dataset, tokenizer, dpo_task_spec = setup_data( + config["data"], config["policy"] + ) + ( + policy, + cluster, + train_dataloader, + val_dataloader, + loss_fn, + logger, + checkpointer, + dpo_save_state, + master_config, + ) = setup(config, dataset, val_dataset) + dpo_train( + policy, + train_dataloader, + val_dataloader, + tokenizer, + loss_fn, + master_config, + logger, + dpo_task_spec, + checkpointer, + dpo_save_state, + ) + + +if __name__ == "__main__": + main() diff --git a/nemo_reinforcer/data/hf_datasets/dpo.py b/nemo_reinforcer/data/hf_datasets/dpo.py index 483351af60..0131f6cd5a 100644 --- a/nemo_reinforcer/data/hf_datasets/dpo.py +++ 
b/nemo_reinforcer/data/hf_datasets/dpo.py @@ -17,11 +17,11 @@ def __init__(self, train_data_path: str, val_data_path: str): ## we need to do this outisde of the data class because we want to keep ## chosen and rejected responses for a given prompt together when shuffling self.formatted_ds = { - "train": load_dataset("json", data_files=train_data_path), - "validation": load_dataset("json", data_files=val_data_path), + "train": load_dataset("json", data_files=train_data_path, split="train"), + "validation": load_dataset("json", data_files=val_data_path, split="train"), } super().__init__( dataset_name="dpo", - ## no custom template. Assume we use tokenizer's template - # custom_template=COMMON_CHAT_TEMPLATES.simple_role_header, + ## passthrough template + custom_template="{% for message in messages %}{{ message['content'] }}{% endfor %}", ) From 4cc5a7f91b7ac1177009f564b9ad403b21e200fa Mon Sep 17 00:00:00 2001 From: ashors1 Date: Wed, 9 Apr 2025 14:22:45 -0700 Subject: [PATCH 05/57] add validation and checkpointing Signed-off-by: ashors1 --- examples/run_dpo.py | 1 - nemo_reinforcer/algorithms/dpo.py | 121 ++++++++++++++----- nemo_reinforcer/algorithms/loss_functions.py | 23 +++- nemo_reinforcer/data/datasets.py | 2 +- nemo_reinforcer/models/policy/hf_policy.py | 18 ++- 5 files changed, 123 insertions(+), 42 deletions(-) diff --git a/examples/run_dpo.py b/examples/run_dpo.py index d3247734a3..638bb77b12 100644 --- a/examples/run_dpo.py +++ b/examples/run_dpo.py @@ -61,7 +61,6 @@ def dpo_preprocessor( idx: int, ) -> DatumSpec: """Process a datum dictionary for DPO training.""" - print(f"{datum_dict=}") if isinstance(datum_dict["prompt"], list): messages_chosen = datum_dict["prompt"] messages_rejected = datum_dict["prompt"] diff --git a/nemo_reinforcer/algorithms/dpo.py b/nemo_reinforcer/algorithms/dpo.py index 15a7b2f3d6..9c1d68e242 100755 --- a/nemo_reinforcer/algorithms/dpo.py +++ b/nemo_reinforcer/algorithms/dpo.py @@ -164,7 +164,7 @@ def setup( val_dataset, batch_size=dpo_config["val_global_batch_size"], shuffle=False, - collate_fn=dpo_collate_fn, + collate_fn=partial(dpo_collate_fn, tokenizer=tokenizer), drop_last=True, ) @@ -218,6 +218,26 @@ def setup( ) +def augment_dataloader(dataloader, policy, master_config): + dataloader_iter = iter(dataloader) + while True: + try: + batch = next(dataloader_iter) + + ## append ref policy logprobs to batch + logprobs = policy.get_reference_policy_logprobs( + batch, + ## TODO: make more robust + micro_batch_size=master_config["policy"]["train_micro_batch_size"] * 2, + )["reference_logprobs"].to("cpu") + batch["reference_policy_logprobs"] = torch.roll(logprobs, -1, dims=-1) + + yield batch + + except StopIteration: + break + + # ======================================================= # Training & Validation # ======================================================= @@ -234,7 +254,60 @@ def validate( val_mbs: int, ): """Run validation on the validation dataset.""" - ### TODO ### + if val_dataloader is None: + print(" āš ļø No validation dataloader provided, skipping validation") + return + + timer = Timer() + + with timer.time("total_validation_time"): + print(f"ā–¶ Starting validation at step {step}...") + + # Show a progress indicator for validation + # val_total = len(val_dataloader) + + for batch_idx, val_batch in enumerate( + augment_dataloader(val_dataloader, policy, master_config) + ): + ## just run model fwd + val_results = policy.train( + val_batch, + loss_fn, + eval_mode=True, + gbs=val_batch_size, + mbs=val_mbs, + ) + + ## TODO: this should 
already be averaged across microbatches.. why isn't it? + val_metrics = { + "loss": val_results["loss"].numpy(), + } + val_metrics.update(val_results["all_mb_metrics"]) + val_metrics = {k: np.mean(v).item() for k, v in val_metrics.items()} + if val_batches > 0 and batch_idx >= val_batches: + break + + # Calculate validation metrics + policy.prepare_for_training() + + # Get timing metrics + timing_metrics = timer.get_timing_metrics(reduction_op="sum") + validation_time = timing_metrics.get("total_validation_time", 0) + + # Print summary of validation results + print("\nšŸ“Š Validation Results:") + print(f" • Validation loss: {float(val_metrics['loss']):.4f}") + print(f" • Validation accuracy: {float(val_metrics['accuracy']):.4f}") + + # Print timing information + print("\n ā±ļø Validation Timing:") + validation_time = timing_metrics.get("total_validation_time", 0) + print(f" • Total validation time: {validation_time:.2f}s") + + # Make sure to reset the timer after validation + timer.reset() + + return val_metrics, timing_metrics def dpo_train( @@ -264,7 +337,7 @@ def dpo_train( val_at_start = dpo_config["val_at_start"] # Run validation at the start if configured - """if val_at_start and step == 0: + if val_at_start and step == 0: print("\nšŸ” Running initial validation...") ## TODO @@ -282,35 +355,14 @@ def dpo_train( ) logger.log_metrics(val_metrics, step, prefix="validation") - logger.log_metrics(validation_timings, step, prefix="timing/validation")""" + logger.log_metrics(validation_timings, step, prefix="timing/validation") policy.prepare_for_training() - def augment_dataloader(dataloader): - dataloader_iter = iter(dataloader) - while True: - try: - batch = next(dataloader_iter) - - ## append ref policy logprobs to batch - logprobs = policy.get_reference_policy_logprobs( - batch, - ## TODO: make more robust - micro_batch_size=master_config["policy"]["train_micro_batch_size"] - * 2, - )["reference_logprobs"] - batch["reference_policy_logprobs"] = torch.roll(logprobs, -1, dims=-1) - - yield batch - - except StopIteration: - break - - for batch in augment_dataloader(train_dataloader): + for batch in augment_dataloader(train_dataloader, policy, master_config): print(f"\n{'=' * 25} Step {step + 1}/{len(train_dataloader)} {'=' * 25}") with timer.time("total_step_time"): - ## train_data.to("cpu") print("ā–¶ Taking a training step...") train_results = policy.train( batch, @@ -320,7 +372,7 @@ def augment_dataloader(dataloader): mbs=master_config["policy"]["train_micro_batch_size"] * 2, ) - """# Run validation if it's a validation step + # Run validation if it's a validation step if val_period > 0 and (step + 1) % val_period == 0: val_metrics, validation_timings = validate( policy, @@ -348,9 +400,16 @@ def augment_dataloader(dataloader): and (step + 1) % master_config["checkpointing"]["save_period"] == 0 ): # +1 because step is 0-indexed dpo_save_state["step"] = step + 1 - dpo_save_state["val_loss"] = val_metrics["val_loss"] + dpo_save_state["val_loss"] = val_metrics["loss"] with timer.time("checkpointing"): print(f"Saving checkpoint for step {step + 1}...") + is_last_checkpoint = ( + min( + len(train_dataloader), master_config["dpo"]["max_num_steps"] + ) + - (step + 1) + < master_config["checkpointing"]["save_period"] + ) checkpoint_path = checkpointer.init_tmp_checkpoint( step + 1, dpo_save_state, master_config ) @@ -358,15 +417,13 @@ def augment_dataloader(dataloader): policy.save_checkpoint( os.path.join(checkpoint_path, "policy.pt"), os.path.join(checkpoint_path, "policy_optimizer.pt"), - 
## NOTE: below is a workaround to avoid a bug with checkpointing - ## this should be removed once the bug is fixed - offload_to_cpu=False, + save_hf=is_last_checkpoint, ) torch.save( train_dataloader.state_dict(), os.path.join(checkpoint_path, "train_dataloader.pt"), ) - checkpointer.finalize_checkpoint(checkpoint_path)""" + checkpointer.finalize_checkpoint(checkpoint_path) ## TODO: add more DPO metrics ## accuracy, sft loss, preference loss, etc. diff --git a/nemo_reinforcer/algorithms/loss_functions.py b/nemo_reinforcer/algorithms/loss_functions.py index fac8bb4e73..9b5688641a 100644 --- a/nemo_reinforcer/algorithms/loss_functions.py +++ b/nemo_reinforcer/algorithms/loss_functions.py @@ -237,11 +237,24 @@ def preference_loss( ## TODO: provide the option to average over tokens rewards = diff.sum(-1) + ## TODO: make configurable + if True: # average_log_probs: + # need to guard against divide by zero in case labels are all -100 + num_tokens_for_loss = token_mask.sum(-1) + rewards = rewards / num_tokens_for_loss.clamp(min=1) + + ## ignore the batches whose sample_mask is 0 + rewards = rewards[data["sample_mask"] == 1] + + if len(rewards) == 0: + return torch.tensor(0.0), torch.tensor(0.0) + rewards_chosen, rewards_rejected = self.split_output_tensor(rewards) rewards_delta = rewards_chosen - rewards_rejected + return -torch.nn.functional.logsigmoid( self.reference_policy_kl_penalty * rewards_delta - ).mean(0) + ).mean(0), (rewards_chosen > rewards_rejected).float().mean(0) def __call__( self, next_token_logits: torch.Tensor, data: BatchedDataDict[DPOLossDataDict] @@ -251,12 +264,17 @@ def __call__( sft_loss, _ = self.sft_loss( next_token_logits, data, reduce_across_batch=False ) + ## ignore the batches whose sample_mask is 0 + sft_loss = sft_loss[data["sample_mask"] == 1] sft_loss_chosen, sft_loss_rejected = self.split_output_tensor(sft_loss) + if len(sft_loss_chosen) == 0: + sft_loss_chosen = torch.tensor(0.0) + ## average over the batch dimension sft_loss_chosen = sft_loss_chosen.mean(0) - preference_loss = self.preference_loss(next_token_logits, data) + preference_loss, accuracy = self.preference_loss(next_token_logits, data) dpo_loss = ( self.sft_loss_weight * sft_loss_chosen @@ -267,4 +285,5 @@ def __call__( "loss": dpo_loss.item(), "sft_loss": sft_loss_chosen.item(), "preference_loss": preference_loss.item(), + "accuracy": accuracy.item(), } diff --git a/nemo_reinforcer/data/datasets.py b/nemo_reinforcer/data/datasets.py index 53cad2433a..6ef060164f 100644 --- a/nemo_reinforcer/data/datasets.py +++ b/nemo_reinforcer/data/datasets.py @@ -244,7 +244,7 @@ def dpo_collate_fn(data_batch: List[DatumSpec], tokenizer) -> BatchedDataDict: "input_ids": cat_and_padded["token_ids"], "input_lengths": input_lengths, "token_mask": cat_and_padded["token_loss_mask"], - "sample_mask": batch["loss_multiplier"], + "sample_mask": loss_multiplier, } ) diff --git a/nemo_reinforcer/models/policy/hf_policy.py b/nemo_reinforcer/models/policy/hf_policy.py index 720acbc6ff..aedfbfc633 100644 --- a/nemo_reinforcer/models/policy/hf_policy.py +++ b/nemo_reinforcer/models/policy/hf_policy.py @@ -314,11 +314,16 @@ def train( # Backward pass # Loss is accumulated across microbatches, so we need to scale by the number of microbatches - loss = loss / num_microbatches - if not eval_mode: - loss.backward() - mb_losses.append(loss.item()) - all_mb_metrics.append(loss_metrics) + # loss = loss / num_microbatches + + ## TODO: improve this + ## loss = 0 indicates that there are no valid examples in the microbatch + ## we should 
probably use a reserved value here + if loss != 0: + if not eval_mode: + loss.backward() + mb_losses.append(loss.item()) + all_mb_metrics.append(loss_metrics) # Clip gradients if not eval_mode: @@ -327,7 +332,8 @@ def train( # Update parameters self.optimizer.step() self.scheduler.step() - losses.append(torch.tensor(mb_losses).sum().item()) + + losses.append(torch.tensor(mb_losses).mean().item()) # Compute global loss across all ranks with torch.no_grad(): From c136aa7ee0d97fcabf770bd7079ac5f6b575cb2e Mon Sep 17 00:00:00 2001 From: Anna Shors Date: Fri, 11 Apr 2025 09:16:09 -0700 Subject: [PATCH 06/57] revert handling of too-long lines Signed-off-by: Anna Shors --- nemo_reinforcer/algorithms/loss_functions.py | 12 ++++++------ nemo_reinforcer/models/policy/hf_policy.py | 10 +++++----- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/nemo_reinforcer/algorithms/loss_functions.py b/nemo_reinforcer/algorithms/loss_functions.py index 9b5688641a..7a08909150 100644 --- a/nemo_reinforcer/algorithms/loss_functions.py +++ b/nemo_reinforcer/algorithms/loss_functions.py @@ -244,10 +244,10 @@ def preference_loss( rewards = rewards / num_tokens_for_loss.clamp(min=1) ## ignore the batches whose sample_mask is 0 - rewards = rewards[data["sample_mask"] == 1] + #rewards = rewards[data["sample_mask"] == 1] - if len(rewards) == 0: - return torch.tensor(0.0), torch.tensor(0.0) + #if len(rewards) == 0: + # return torch.tensor(0.0), torch.tensor(0.0) rewards_chosen, rewards_rejected = self.split_output_tensor(rewards) rewards_delta = rewards_chosen - rewards_rejected @@ -265,11 +265,11 @@ def __call__( next_token_logits, data, reduce_across_batch=False ) ## ignore the batches whose sample_mask is 0 - sft_loss = sft_loss[data["sample_mask"] == 1] + #sft_loss = sft_loss[data["sample_mask"] == 1] sft_loss_chosen, sft_loss_rejected = self.split_output_tensor(sft_loss) - if len(sft_loss_chosen) == 0: - sft_loss_chosen = torch.tensor(0.0) + #if len(sft_loss_chosen) == 0: + # sft_loss_chosen = torch.tensor(0.0) ## average over the batch dimension sft_loss_chosen = sft_loss_chosen.mean(0) diff --git a/nemo_reinforcer/models/policy/hf_policy.py b/nemo_reinforcer/models/policy/hf_policy.py index aedfbfc633..507c4a42c9 100644 --- a/nemo_reinforcer/models/policy/hf_policy.py +++ b/nemo_reinforcer/models/policy/hf_policy.py @@ -319,11 +319,11 @@ def train( ## TODO: improve this ## loss = 0 indicates that there are no valid examples in the microbatch ## we should probably use a reserved value here - if loss != 0: - if not eval_mode: - loss.backward() - mb_losses.append(loss.item()) - all_mb_metrics.append(loss_metrics) + #if loss != 0: + if not eval_mode: + loss.backward() + mb_losses.append(loss.item()) + all_mb_metrics.append(loss_metrics) # Clip gradients if not eval_mode: From e2307ae5039ac10ffc047e54602a72dcf42f2a4c Mon Sep 17 00:00:00 2001 From: Anna Shors Date: Fri, 11 Apr 2025 09:30:57 -0700 Subject: [PATCH 07/57] fix running validation with different batch size than training Signed-off-by: Anna Shors --- nemo_reinforcer/models/policy/hf_policy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo_reinforcer/models/policy/hf_policy.py b/nemo_reinforcer/models/policy/hf_policy.py index 507c4a42c9..c0f67ee62d 100644 --- a/nemo_reinforcer/models/policy/hf_policy.py +++ b/nemo_reinforcer/models/policy/hf_policy.py @@ -974,7 +974,7 @@ def train( # Shard and replicate the batch shards = self.dp_size sharded_data = data.shard_by_batch_size( - shards, 
batch_size=self.cfg["train_global_batch_size"] + shards, batch_size=gbs or self.cfg["train_global_batch_size"] ) # Train each shard in parallel From dbe01d01703e2f5cdc1bb572cb58b6f9cc4b75d0 Mon Sep 17 00:00:00 2001 From: Anna Shors Date: Sat, 12 Apr 2025 13:14:58 -0700 Subject: [PATCH 08/57] changes for convergence testing Signed-off-by: Anna Shors --- examples/configs/dpo.yaml | 23 ++++++++++---------- nemo_reinforcer/algorithms/dpo.py | 4 ++-- nemo_reinforcer/algorithms/loss_functions.py | 11 +++++----- 3 files changed, 20 insertions(+), 18 deletions(-) diff --git a/examples/configs/dpo.yaml b/examples/configs/dpo.yaml index 22f6dac102..763143272b 100644 --- a/examples/configs/dpo.yaml +++ b/examples/configs/dpo.yaml @@ -1,14 +1,14 @@ # DPO Algorithm Configuration dpo: - max_num_steps: 60 - val_period: 10 + max_num_steps: 500 + val_period: 5 val_batches: 8 - val_global_batch_size: 32 - val_micro_batch_size: 2 + val_global_batch_size: 8 + val_micro_batch_size: 1 val_at_start: true seed: 42 - reference_policy_kl_penalty: 0.2 + reference_policy_kl_penalty: 0.1 ## TODO: support below #preference_average_log_probs: False # whether normalizing log probs according to the sequence length in preference_loss #sft_average_log_probs: ${.preference_average_log_probs} # whether normalizing log probs according to the sequence length in sft_loss @@ -23,17 +23,18 @@ checkpointing: metric_name: "val_loss" higher_is_better: false keep_top_k: 3 - save_period: 10 + save_period: 10000 -## TODO: OOM with mbs 2?? policy: - model_name: "meta-llama/Llama-3.2-1B-Instruct" - tokenizer_name: "meta-llama/Llama-3.2-1B-Instruct" - train_global_batch_size: 2 + model_name: "meta-llama/Meta-Llama-3.1-8B" + tokenizer_name: "meta-llama/Meta-Llama-3.1-8B" + train_global_batch_size: 256 train_micro_batch_size: 1 logprob_batch_size: ${policy.train_micro_batch_size} max_total_sequence_length: 1024 precision: "float32" + fsdp_offload_enabled: false + activation_checkpointing_enabled: false optimizer: name: "torch.optim.AdamW" @@ -63,5 +64,5 @@ logger: flush_interval: 10 # How often to flush GPU usage metrics to the loggers (in seconds) cluster: - gpus_per_node: 1 + gpus_per_node: 8 num_nodes: 1 diff --git a/nemo_reinforcer/algorithms/dpo.py b/nemo_reinforcer/algorithms/dpo.py index 9c1d68e242..8783a8a2e2 100755 --- a/nemo_reinforcer/algorithms/dpo.py +++ b/nemo_reinforcer/algorithms/dpo.py @@ -274,8 +274,8 @@ def validate( val_batch, loss_fn, eval_mode=True, - gbs=val_batch_size, - mbs=val_mbs, + gbs=val_batch_size * 2, + mbs=val_mbs * 2, ) ## TODO: this should already be averaged across microbatches.. why isn't it? diff --git a/nemo_reinforcer/algorithms/loss_functions.py b/nemo_reinforcer/algorithms/loss_functions.py index 7a08909150..2498e8c41f 100644 --- a/nemo_reinforcer/algorithms/loss_functions.py +++ b/nemo_reinforcer/algorithms/loss_functions.py @@ -237,11 +237,12 @@ def preference_loss( ## TODO: provide the option to average over tokens rewards = diff.sum(-1) - ## TODO: make configurable - if True: # average_log_probs: - # need to guard against divide by zero in case labels are all -100 - num_tokens_for_loss = token_mask.sum(-1) - rewards = rewards / num_tokens_for_loss.clamp(min=1) + ## TODO: make configurable. 
For now, do not average across sequences + ## --> matches Aligner's default + #if average_log_probs: + # # need to guard against divide by zero in case labels are all -100 + # num_tokens_for_loss = token_mask.sum(-1) + # rewards = rewards / num_tokens_for_loss.clamp(min=1) ## ignore the batches whose sample_mask is 0 #rewards = rewards[data["sample_mask"] == 1] From 3ff1a7e9761c386f6b4378f64b0d6f653f7c241b Mon Sep 17 00:00:00 2001 From: Yi-Fu Wu Date: Sat, 12 Apr 2025 21:01:04 -0700 Subject: [PATCH 09/57] no_grad and model.eval when in eval_mode Signed-off-by: Yi-Fu Wu --- nemo_reinforcer/models/policy/hf_policy.py | 181 +++++++++++---------- 1 file changed, 94 insertions(+), 87 deletions(-) diff --git a/nemo_reinforcer/models/policy/hf_policy.py b/nemo_reinforcer/models/policy/hf_policy.py index c0f67ee62d..039d433389 100644 --- a/nemo_reinforcer/models/policy/hf_policy.py +++ b/nemo_reinforcer/models/policy/hf_policy.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import contextlib import gc import warnings import os @@ -265,97 +266,103 @@ def train( local_gbs = gbs // torch.distributed.get_world_size() dataset_size = data.get("input_ids").shape[0] - # Ensure model is in training mode - self.model.train() - - # Get data from batch and move to device - data.to("cuda") - - losses = [] - all_mb_metrics = [] - for gb_start in range(0, dataset_size, local_gbs): - self.optimizer.zero_grad() - mb_losses = [] - - # Calculate number of microbatches to process - # make_microbatch_iterator assumes that the batch size is a multiple of the microbatch size - # so its safe to not check for the case where the last data slice is smaller than mbs - num_microbatches = min(local_gbs, dataset_size - gb_start) // mbs - - for mb in data.slice( - gb_start, gb_start + local_gbs - ).make_microbatch_iterator(mbs): - input_ids = mb.get("input_ids") + if eval_mode: + ctx = torch.no_grad() + self.model.eval() + else: + ctx = contextlib.nullcontext() + # Ensure model is in training mode + self.model.train() - input_lengths = mb.get("input_lengths") - batch_size, seq_len = input_ids.shape - attention_mask = torch.ones( - (batch_size, seq_len), dtype=torch.long, device=input_ids.device - ) - for i, length in enumerate(input_lengths): - # For right-padded sequence, set 1s at the beginning of the sequence - attention_mask[i, :length] = 1 + with ctx: + # Get data from batch and move to device + data.to("cuda") - with torch.autocast(device_type="cuda", dtype=self.dtype): - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - use_cache=False, + losses = [] + all_mb_metrics = [] + for gb_start in range(0, dataset_size, local_gbs): + self.optimizer.zero_grad() + mb_losses = [] + + # Calculate number of microbatches to process + # make_microbatch_iterator assumes that the batch size is a multiple of the microbatch size + # so its safe to not check for the case where the last data slice is smaller than mbs + num_microbatches = min(local_gbs, dataset_size - gb_start) // mbs + + for mb in data.slice( + gb_start, gb_start + local_gbs + ).make_microbatch_iterator(mbs): + input_ids = mb.get("input_ids") + + input_lengths = mb.get("input_lengths") + batch_size, seq_len = input_ids.shape + attention_mask = torch.ones( + (batch_size, seq_len), dtype=torch.long, device=input_ids.device ) - # Get logprobs - if not hasattr(outputs, "logits"): - logits = 
self.model.lm_head(outputs.last_hidden_state) - else: - logits = outputs.logits - - loss, loss_metrics = loss_fn(logits, mb) - loss_metrics["lr"] = self.optimizer.param_groups[0]["lr"] - - # Backward pass - - # Loss is accumulated across microbatches, so we need to scale by the number of microbatches - # loss = loss / num_microbatches - - ## TODO: improve this - ## loss = 0 indicates that there are no valid examples in the microbatch - ## we should probably use a reserved value here - #if loss != 0: + for i, length in enumerate(input_lengths): + # For right-padded sequence, set 1s at the beginning of the sequence + attention_mask[i, :length] = 1 + + with torch.autocast(device_type="cuda", dtype=self.dtype): + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + use_cache=False, + ) + # Get logprobs + if not hasattr(outputs, "logits"): + logits = self.model.lm_head(outputs.last_hidden_state) + else: + logits = outputs.logits + + loss, loss_metrics = loss_fn(logits, mb) + loss_metrics["lr"] = self.optimizer.param_groups[0]["lr"] + + # Backward pass + + # Loss is accumulated across microbatches, so we need to scale by the number of microbatches + # loss = loss / num_microbatches + + ## TODO: improve this + ## loss = 0 indicates that there are no valid examples in the microbatch + ## we should probably use a reserved value here + #if loss != 0: + if not eval_mode: + loss.backward() + mb_losses.append(loss.item()) + all_mb_metrics.append(loss_metrics) + + # Clip gradients if not eval_mode: - loss.backward() - mb_losses.append(loss.item()) - all_mb_metrics.append(loss_metrics) - - # Clip gradients - if not eval_mode: - torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0) - - # Update parameters - self.optimizer.step() - self.scheduler.step() - - losses.append(torch.tensor(mb_losses).mean().item()) - - # Compute global loss across all ranks - with torch.no_grad(): - local_loss = torch.tensor(losses, device="cuda") - global_loss = torch.zeros_like(local_loss) - torch.distributed.all_reduce(local_loss) - global_loss = local_loss / torch.distributed.get_world_size() - - # Aggregate metrics across all microbatches - mb_metrics = defaultdict(list) - for m in all_mb_metrics: - for k, v in m.items(): - mb_metrics[k].append(v) - - metrics = { - "global_loss": global_loss.cpu(), - "local_loss": local_loss.cpu(), - "rank": torch.distributed.get_rank(), - "all_mb_metrics": dict(mb_metrics), - } - - return metrics + torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0) + + # Update parameters + self.optimizer.step() + self.scheduler.step() + + losses.append(torch.tensor(mb_losses).mean().item()) + + # Compute global loss across all ranks + with torch.no_grad(): + local_loss = torch.tensor(losses, device="cuda") + global_loss = torch.zeros_like(local_loss) + torch.distributed.all_reduce(local_loss) + global_loss = local_loss / torch.distributed.get_world_size() + + # Aggregate metrics across all microbatches + mb_metrics = defaultdict(list) + for m in all_mb_metrics: + for k, v in m.items(): + mb_metrics[k].append(v) + + metrics = { + "global_loss": global_loss.cpu(), + "local_loss": local_loss.cpu(), + "rank": torch.distributed.get_rank(), + "all_mb_metrics": dict(mb_metrics), + } + + return metrics def get_logprobs( self, data: BatchedDataDict, micro_batch_size: int = None From b2b3c66d7d10a220d2ef4234bb196bae0b0d58c7 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Sun, 13 Apr 2025 17:07:01 -0700 Subject: [PATCH 10/57] clean up, add support for 
average_log_probs Signed-off-by: ashors1 --- examples/configs/dpo.yaml | 10 +- nemo_reinforcer/algorithms/dpo.py | 200 ++++++++++--------- nemo_reinforcer/algorithms/loss_functions.py | 36 +--- 3 files changed, 119 insertions(+), 127 deletions(-) diff --git a/examples/configs/dpo.yaml b/examples/configs/dpo.yaml index 763143272b..7a3d636031 100644 --- a/examples/configs/dpo.yaml +++ b/examples/configs/dpo.yaml @@ -1,5 +1,6 @@ # DPO Algorithm Configuration dpo: + max_num_epochs: 1 max_num_steps: 500 val_period: 5 val_batches: 8 @@ -9,11 +10,12 @@ dpo: seed: 42 reference_policy_kl_penalty: 0.1 - ## TODO: support below - #preference_average_log_probs: False # whether normalizing log probs according to the sequence length in preference_loss - #sft_average_log_probs: ${.preference_average_log_probs} # whether normalizing log probs according to the sequence length in sft_loss - #gt_reward_scale: 1. # the scale of the rewards in RPO + preference_average_log_probs: False # whether normalizing log probs according to the sequence length in preference_loss + sft_average_log_probs: ${.preference_average_log_probs} # whether normalizing log probs according to the sequence length in sft_loss + + ## TODO(@ashors) support other loss functions #preference_loss: dpo # the preference loss, we support dpo, ipo, rpo_sq, rpo_bwd_kl, rpo_fwd_kl + #gt_reward_scale: 1. # the scale of the rewards in RPO preference_loss_weight: 1 # the coefficient of the preference loss sft_loss_weight: 0 # the coefficient of the SFT loss diff --git a/nemo_reinforcer/algorithms/dpo.py b/nemo_reinforcer/algorithms/dpo.py index 8783a8a2e2..7a8d13ee25 100755 --- a/nemo_reinforcer/algorithms/dpo.py +++ b/nemo_reinforcer/algorithms/dpo.py @@ -56,6 +56,7 @@ def _default_dpo_save_state() -> DPOSaveState: class DPOConfig(TypedDict): + max_num_epochs: int max_num_steps: int val_period: int val_batches: int @@ -65,11 +66,11 @@ class DPOConfig(TypedDict): seed: int reference_policy_kl_penalty: float - ## TODO: support below - # preference_average_log_probs: False - # sft_average_log_probs: ${.preference_average_log_probs} - # gt_reward_scale: 1. 
- # preference_loss: dpo + preference_average_log_probs: bool + sft_average_log_probs: bool + ## TODO(@ashors) support other loss functions + # preference_loss: str + # gt_reward_scale: float preference_loss_weight: float sft_loss_weight: float @@ -359,102 +360,105 @@ def dpo_train( policy.prepare_for_training() - for batch in augment_dataloader(train_dataloader, policy, master_config): - print(f"\n{'=' * 25} Step {step + 1}/{len(train_dataloader)} {'=' * 25}") + for epoch in range(master_config["dpo"]["max_num_epochs"]): + for batch in augment_dataloader(train_dataloader, policy, master_config): + print(f"\n{'=' * 25} Step {step + 1}/{len(train_dataloader)} {'=' * 25}") - with timer.time("total_step_time"): - print("ā–¶ Taking a training step...") - train_results = policy.train( - batch, - loss_fn, - eval_mode=False, - gbs=master_config["policy"]["train_global_batch_size"] * 2, - mbs=master_config["policy"]["train_micro_batch_size"] * 2, - ) - - # Run validation if it's a validation step - if val_period > 0 and (step + 1) % val_period == 0: - val_metrics, validation_timings = validate( - policy, - val_dataloader, - tokenizer, + with timer.time("total_step_time"): + print("ā–¶ Taking a training step...") + train_results = policy.train( + batch, loss_fn, - step=step + 1, - master_config=master_config, - dpo_task_spec=dpo_task_spec, - val_batches=dpo_config["val_batches"], - val_batch_size=dpo_config["val_global_batch_size"], - val_mbs=dpo_config["val_micro_batch_size"], - ) - logger.log_metrics( - validation_timings, step + 1, prefix="timing/validation" + eval_mode=False, + gbs=master_config["policy"]["train_global_batch_size"] * 2, + mbs=master_config["policy"]["train_micro_batch_size"] * 2, ) - logger.log_metrics(val_metrics, step + 1, prefix="validation") - - ## Checkpointing - dpo_save_state["consumed_samples"] += master_config["policy"][ - "train_global_batch_size" - ] - if ( - master_config["checkpointing"]["enabled"] - and (step + 1) % master_config["checkpointing"]["save_period"] == 0 - ): # +1 because step is 0-indexed - dpo_save_state["step"] = step + 1 - dpo_save_state["val_loss"] = val_metrics["loss"] - with timer.time("checkpointing"): - print(f"Saving checkpoint for step {step + 1}...") - is_last_checkpoint = ( - min( - len(train_dataloader), master_config["dpo"]["max_num_steps"] - ) - - (step + 1) - < master_config["checkpointing"]["save_period"] - ) - checkpoint_path = checkpointer.init_tmp_checkpoint( - step + 1, dpo_save_state, master_config - ) - ## TODO: move checkpointing logic elsewhere? - policy.save_checkpoint( - os.path.join(checkpoint_path, "policy.pt"), - os.path.join(checkpoint_path, "policy_optimizer.pt"), - save_hf=is_last_checkpoint, + + # Run validation if it's a validation step + if val_period > 0 and (step + 1) % val_period == 0: + val_metrics, validation_timings = validate( + policy, + val_dataloader, + tokenizer, + loss_fn, + step=step + 1, + master_config=master_config, + dpo_task_spec=dpo_task_spec, + val_batches=dpo_config["val_batches"], + val_batch_size=dpo_config["val_global_batch_size"], + val_mbs=dpo_config["val_micro_batch_size"], ) - torch.save( - train_dataloader.state_dict(), - os.path.join(checkpoint_path, "train_dataloader.pt"), + logger.log_metrics( + validation_timings, step + 1, prefix="timing/validation" ) - checkpointer.finalize_checkpoint(checkpoint_path) - - ## TODO: add more DPO metrics - ## accuracy, sft loss, preference loss, etc. 
- losses = train_results["loss"] - metrics = { - "loss": train_results["loss"].numpy(), - } - metrics.update(train_results["all_mb_metrics"]) - metrics = {k: np.mean(v).item() for k, v in metrics.items()} - timing_metrics = timer.get_timing_metrics(reduction_op="sum") - - print("\nšŸ“Š Training Results:") - print(f" • Loss: {float(metrics['loss']):.4f}") - print("\nā±ļø Timing:") - # Display total time first, separately - total_time = timing_metrics.get("total_step_time", 0) - print(f" • Total step time: {total_time:.2f}s") - - # Display all other timing metrics (if any) - for k, v in sorted( - timing_metrics.items(), key=lambda item: item[1], reverse=True - ): - if k != "total_step_time": - percent = (v / total_time * 100) if total_time > 0 else 0 - print(f" • {k}: {v:.2f}s ({percent:.1f}%)") - - logger.log_metrics(metrics, step + 1, prefix="train") - logger.log_metrics(timing_metrics, step + 1, prefix="timing/train") - - timer.reset() - step += 1 + logger.log_metrics(val_metrics, step + 1, prefix="validation") + + ## Checkpointing + dpo_save_state["consumed_samples"] += master_config["policy"][ + "train_global_batch_size" + ] + if ( + master_config["checkpointing"]["enabled"] + and (step + 1) % master_config["checkpointing"]["save_period"] == 0 + ): # +1 because step is 0-indexed + dpo_save_state["step"] = step + 1 + dpo_save_state["val_loss"] = val_metrics["loss"] + with timer.time("checkpointing"): + print(f"Saving checkpoint for step {step + 1}...") + is_last_checkpoint = ( + min( + len(train_dataloader) + * master_config["dpo"]["max_num_epochs"], + master_config["dpo"]["max_num_steps"], + ) + - (step + 1) + < master_config["checkpointing"]["save_period"] + ) + checkpoint_path = checkpointer.init_tmp_checkpoint( + step + 1, dpo_save_state, master_config + ) + ## TODO: move checkpointing logic elsewhere? + policy.save_checkpoint( + os.path.join(checkpoint_path, "policy.pt"), + os.path.join(checkpoint_path, "policy_optimizer.pt"), + save_hf=is_last_checkpoint, + ) + torch.save( + train_dataloader.state_dict(), + os.path.join(checkpoint_path, "train_dataloader.pt"), + ) + checkpointer.finalize_checkpoint(checkpoint_path) - if step >= master_config["dpo"]["max_num_steps"]: - break + ## TODO: add more DPO metrics + ## accuracy, sft loss, preference loss, etc. 
+ losses = train_results["loss"] + metrics = { + "loss": train_results["loss"].numpy(), + } + metrics.update(train_results["all_mb_metrics"]) + metrics = {k: np.mean(v).item() for k, v in metrics.items()} + timing_metrics = timer.get_timing_metrics(reduction_op="sum") + + print("\nšŸ“Š Training Results:") + print(f" • Loss: {float(metrics['loss']):.4f}") + print("\nā±ļø Timing:") + # Display total time first, separately + total_time = timing_metrics.get("total_step_time", 0) + print(f" • Total step time: {total_time:.2f}s") + + # Display all other timing metrics (if any) + for k, v in sorted( + timing_metrics.items(), key=lambda item: item[1], reverse=True + ): + if k != "total_step_time": + percent = (v / total_time * 100) if total_time > 0 else 0 + print(f" • {k}: {v:.2f}s ({percent:.1f}%)") + + logger.log_metrics(metrics, step + 1, prefix="train") + logger.log_metrics(timing_metrics, step + 1, prefix="timing/train") + + timer.reset() + step += 1 + + if step >= master_config["dpo"]["max_num_steps"]: + break diff --git a/nemo_reinforcer/algorithms/loss_functions.py b/nemo_reinforcer/algorithms/loss_functions.py index 2498e8c41f..655c6326d1 100644 --- a/nemo_reinforcer/algorithms/loss_functions.py +++ b/nemo_reinforcer/algorithms/loss_functions.py @@ -152,7 +152,7 @@ def __call__( self, next_token_logits: torch.Tensor, data: BatchedDataDict, - reduce_across_batch: bool = True, + average_log_probs: bool = False, ) -> Tuple[torch.Tensor, dict]: # logits shape: [batch_size, seq_len, vocab_size] # Get the next token logits for each position @@ -174,11 +174,11 @@ def __call__( num_unmasked_tokens = torch.sum(mask, -1) num_unmasked_tokens[num_unmasked_tokens == 0] = 1 - if reduce_across_batch: + if average_log_probs: num_unmasked_tokens = num_unmasked_tokens.sum().item() loss = (-torch.sum(token_logprobs * mask) / num_unmasked_tokens).item() else: - loss = -torch.sum(token_logprobs * mask, dim=-1) / num_unmasked_tokens + loss = -torch.sum(token_logprobs * mask, dim=-1) return loss, { "loss": loss, @@ -191,6 +191,8 @@ class DPOLossConfig(TypedDict): reference_policy_kl_penalty: float preference_loss_weight: float = 1.0 sft_loss_weight: float = 0.0 + preference_average_log_probs: bool = False + sft_average_log_probs: bool = False class DPOLossDataDict(TypedDict): @@ -207,6 +209,8 @@ def __init__(self, cfg: DPOLossConfig): self.reference_policy_kl_penalty = cfg["reference_policy_kl_penalty"] self.preference_loss_weight = cfg["preference_loss_weight"] self.sft_loss_weight = cfg["sft_loss_weight"] + self.preference_average_log_probs = cfg["preference_average_log_probs"] + self.sft_average_log_probs = cfg["sft_average_log_probs"] self.sft_loss = NLLLoss() def split_output_tensor(self, tensor: torch.Tensor): @@ -234,21 +238,10 @@ def preference_loss( ref_logprobs = data["reference_policy_logprobs"][:, :-1] diff = (token_logprobs - ref_logprobs) * token_mask - ## TODO: provide the option to average over tokens - rewards = diff.sum(-1) - - ## TODO: make configurable. 
For now, do not average across sequences - ## --> matches Aligner's default - #if average_log_probs: - # # need to guard against divide by zero in case labels are all -100 - # num_tokens_for_loss = token_mask.sum(-1) - # rewards = rewards / num_tokens_for_loss.clamp(min=1) - - ## ignore the batches whose sample_mask is 0 - #rewards = rewards[data["sample_mask"] == 1] - #if len(rewards) == 0: - # return torch.tensor(0.0), torch.tensor(0.0) + rewards = diff.sum(-1) + if self.preference_average_log_probs: + rewards = rewards / mask.sum(-1).clamp(min=1) rewards_chosen, rewards_rejected = self.split_output_tensor(rewards) rewards_delta = rewards_chosen - rewards_rejected @@ -263,16 +256,9 @@ def __call__( sft_loss_chosen = torch.tensor(0.0) if self.sft_loss_weight > 0: sft_loss, _ = self.sft_loss( - next_token_logits, data, reduce_across_batch=False + next_token_logits, data, average_log_probs=self.sft_average_log_probs ) - ## ignore the batches whose sample_mask is 0 - #sft_loss = sft_loss[data["sample_mask"] == 1] sft_loss_chosen, sft_loss_rejected = self.split_output_tensor(sft_loss) - - #if len(sft_loss_chosen) == 0: - # sft_loss_chosen = torch.tensor(0.0) - - ## average over the batch dimension sft_loss_chosen = sft_loss_chosen.mean(0) preference_loss, accuracy = self.preference_loss(next_token_logits, data) From 8145e4cab4ce22f779e5bd5d37bc5d14f53855fa Mon Sep 17 00:00:00 2001 From: ashors1 Date: Sun, 13 Apr 2025 17:08:54 -0700 Subject: [PATCH 11/57] drop_last during training Signed-off-by: ashors1 --- nemo_reinforcer/algorithms/dpo.py | 1 + 1 file changed, 1 insertion(+) diff --git a/nemo_reinforcer/algorithms/dpo.py b/nemo_reinforcer/algorithms/dpo.py index 7a8d13ee25..715fc04334 100755 --- a/nemo_reinforcer/algorithms/dpo.py +++ b/nemo_reinforcer/algorithms/dpo.py @@ -153,6 +153,7 @@ def setup( batch_size=policy_config["train_global_batch_size"], shuffle=True, collate_fn=partial(dpo_collate_fn, tokenizer=tokenizer), + drop_last=True, ) if last_checkpoint_path is not None: From 78ae2eb9133e53d525d2797ff003e932a892ec8b Mon Sep 17 00:00:00 2001 From: ashors1 Date: Sun, 13 Apr 2025 17:36:07 -0700 Subject: [PATCH 12/57] small fixes for checkpointing and multi-epoch Signed-off-by: ashors1 --- nemo_reinforcer/algorithms/dpo.py | 97 +++++++++++++++++++------------ 1 file changed, 59 insertions(+), 38 deletions(-) mode change 100755 => 100644 nemo_reinforcer/algorithms/dpo.py diff --git a/nemo_reinforcer/algorithms/dpo.py b/nemo_reinforcer/algorithms/dpo.py old mode 100755 new mode 100644 index 715fc04334..b19d3d612e --- a/nemo_reinforcer/algorithms/dpo.py +++ b/nemo_reinforcer/algorithms/dpo.py @@ -43,14 +43,18 @@ class DPOSaveState(TypedDict): - step: int + epoch: int # Track current epoch + step: int # Track step within current epoch + total_steps: int # Track total number of steps across all epochs val_loss: float consumed_samples: int def _default_dpo_save_state() -> DPOSaveState: return { + "epoch": 0, "step": 0, + "total_steps": 0, "consumed_samples": 0, } @@ -320,7 +324,7 @@ def dpo_train( loss_fn, master_config, logger, - dpo_task_spec, ### TODO: flesh out + dpo_task_spec, ## TODO: do we need? 
checkpointer, dpo_save_state, ): @@ -329,20 +333,23 @@ def dpo_train( if dpo_save_state is None: dpo_save_state = _default_dpo_save_state() - step = 0 + current_epoch = 0 + current_step = 0 + total_steps = 0 else: - step = dpo_save_state["step"] + current_epoch = dpo_save_state["epoch"] + current_step = dpo_save_state["step"] + total_steps = dpo_save_state["total_steps"] dpo_config = master_config["dpo"] # Validation configuration val_period = dpo_config["val_period"] val_at_start = dpo_config["val_at_start"] + max_num_epochs = dpo_config["max_num_epochs"] # Run validation at the start if configured - if val_at_start and step == 0: + if val_at_start and total_steps == 0: print("\nšŸ” Running initial validation...") - - ## TODO val_metrics, validation_timings = validate( policy, val_dataloader, @@ -356,14 +363,18 @@ def dpo_train( val_mbs=dpo_config["val_micro_batch_size"], ) - logger.log_metrics(val_metrics, step, prefix="validation") - logger.log_metrics(validation_timings, step, prefix="timing/validation") + logger.log_metrics(val_metrics, total_steps, prefix="validation") + logger.log_metrics(validation_timings, total_steps, prefix="timing/validation") policy.prepare_for_training() - for epoch in range(master_config["dpo"]["max_num_epochs"]): + while current_epoch < max_num_epochs: + print(f"\n{'=' * 25} Epoch {current_epoch + 1}/{max_num_epochs} {'=' * 25}") + for batch in augment_dataloader(train_dataloader, policy, master_config): - print(f"\n{'=' * 25} Step {step + 1}/{len(train_dataloader)} {'=' * 25}") + print( + f"\n{'=' * 25} Step {current_step + 1}/{min(len(train_dataloader), master_config['dpo']['max_num_steps'])} {'=' * 25}" + ) with timer.time("total_step_time"): print("ā–¶ Taking a training step...") @@ -376,13 +387,13 @@ def dpo_train( ) # Run validation if it's a validation step - if val_period > 0 and (step + 1) % val_period == 0: + if val_period > 0 and (total_steps + 1) % val_period == 0: val_metrics, validation_timings = validate( policy, val_dataloader, tokenizer, loss_fn, - step=step + 1, + step=total_steps + 1, master_config=master_config, dpo_task_spec=dpo_task_spec, val_batches=dpo_config["val_batches"], @@ -390,9 +401,11 @@ def dpo_train( val_mbs=dpo_config["val_micro_batch_size"], ) logger.log_metrics( - validation_timings, step + 1, prefix="timing/validation" + validation_timings, total_steps + 1, prefix="timing/validation" + ) + logger.log_metrics( + val_metrics, total_steps + 1, prefix="validation" ) - logger.log_metrics(val_metrics, step + 1, prefix="validation") ## Checkpointing dpo_save_state["consumed_samples"] += master_config["policy"][ @@ -400,28 +413,34 @@ def dpo_train( ] if ( master_config["checkpointing"]["enabled"] - and (step + 1) % master_config["checkpointing"]["save_period"] == 0 + and (total_steps + 1) + % master_config["checkpointing"]["save_period"] + == 0 ): # +1 because step is 0-indexed - dpo_save_state["step"] = step + 1 + is_last_checkpoint = ( + min( + len(train_dataloader) * max_num_epochs, + master_config["dpo"]["max_num_steps"], + ) + - (total_steps + 1) + < master_config["checkpointing"]["save_period"] + ) + dpo_save_state["step"] = (current_step + 1) % len(train_dataloader) + dpo_save_state["total_steps"] = total_steps + 1 + dpo_save_state["epoch"] = current_epoch dpo_save_state["val_loss"] = val_metrics["loss"] with timer.time("checkpointing"): - print(f"Saving checkpoint for step {step + 1}...") - is_last_checkpoint = ( - min( - len(train_dataloader) - * master_config["dpo"]["max_num_epochs"], - 
master_config["dpo"]["max_num_steps"], - ) - - (step + 1) - < master_config["checkpointing"]["save_period"] - ) + print(f"Saving checkpoint for step {total_steps + 1}...") checkpoint_path = checkpointer.init_tmp_checkpoint( - step + 1, dpo_save_state, master_config + total_steps + 1, dpo_save_state, master_config ) - ## TODO: move checkpointing logic elsewhere? policy.save_checkpoint( - os.path.join(checkpoint_path, "policy.pt"), - os.path.join(checkpoint_path, "policy_optimizer.pt"), + weights_path=os.path.join( + checkpoint_path, "policy", "weights" + ), + optimizer_path=os.path.join( + checkpoint_path, "policy", "optimizer" + ), save_hf=is_last_checkpoint, ) torch.save( @@ -430,8 +449,6 @@ def dpo_train( ) checkpointer.finalize_checkpoint(checkpoint_path) - ## TODO: add more DPO metrics - ## accuracy, sft loss, preference loss, etc. losses = train_results["loss"] metrics = { "loss": train_results["loss"].numpy(), @@ -455,11 +472,15 @@ def dpo_train( percent = (v / total_time * 100) if total_time > 0 else 0 print(f" • {k}: {v:.2f}s ({percent:.1f}%)") - logger.log_metrics(metrics, step + 1, prefix="train") - logger.log_metrics(timing_metrics, step + 1, prefix="timing/train") + logger.log_metrics(metrics, total_steps + 1, prefix="train") + logger.log_metrics(timing_metrics, total_steps + 1, prefix="timing/train") timer.reset() - step += 1 + current_step += 1 + total_steps += 1 - if step >= master_config["dpo"]["max_num_steps"]: - break + if total_steps >= master_config["dpo"]["max_num_steps"]: + return + + current_epoch += 1 + current_step = 0 # Reset step counter for new epoch From c039b7d914093a44350d1a4bf36b1e24021d392b Mon Sep 17 00:00:00 2001 From: ashors1 Date: Mon, 14 Apr 2025 09:42:06 -0700 Subject: [PATCH 13/57] add copyright Signed-off-by: ashors1 --- nemo_reinforcer/data/hf_datasets/dpo.py | 13 +++++++++++++ nemo_reinforcer/data/hf_datasets/helpsteer3.py | 13 +++++++++++++ 2 files changed, 26 insertions(+) diff --git a/nemo_reinforcer/data/hf_datasets/dpo.py b/nemo_reinforcer/data/hf_datasets/dpo.py index 0131f6cd5a..bbf0af3e65 100644 --- a/nemo_reinforcer/data/hf_datasets/dpo.py +++ b/nemo_reinforcer/data/hf_datasets/dpo.py @@ -1,3 +1,16 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from datasets import load_dataset from nemo_reinforcer.data.hf_datasets.interfaces import HfDataset diff --git a/nemo_reinforcer/data/hf_datasets/helpsteer3.py b/nemo_reinforcer/data/hf_datasets/helpsteer3.py index 991a82f4e1..3e82362530 100644 --- a/nemo_reinforcer/data/hf_datasets/helpsteer3.py +++ b/nemo_reinforcer/data/hf_datasets/helpsteer3.py @@ -1,3 +1,16 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from datasets import load_dataset from nemo_reinforcer.data.hf_datasets.interfaces import HfDataset From 0adec9be9f3f41ebe52572e70d4d78f328797f26 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Mon, 14 Apr 2025 09:52:39 -0700 Subject: [PATCH 14/57] clean up loss Signed-off-by: ashors1 --- nemo_reinforcer/algorithms/loss_functions.py | 32 +++++++++++++------- 1 file changed, 21 insertions(+), 11 deletions(-) diff --git a/nemo_reinforcer/algorithms/loss_functions.py b/nemo_reinforcer/algorithms/loss_functions.py index 655c6326d1..d998be6eaf 100644 --- a/nemo_reinforcer/algorithms/loss_functions.py +++ b/nemo_reinforcer/algorithms/loss_functions.py @@ -152,7 +152,8 @@ def __call__( self, next_token_logits: torch.Tensor, data: BatchedDataDict, - average_log_probs: bool = False, + dpo_loss: bool = False, + dpo_average_log_probs: bool = False, ) -> Tuple[torch.Tensor, dict]: # logits shape: [batch_size, seq_len, vocab_size] # Get the next token logits for each position @@ -169,16 +170,22 @@ def __call__( dim=-1, index=next_tokens.unsqueeze(-1) ).squeeze(-1) - # Only compute loss on generated tokens (not input tokens) - # by applying the token_loss_mask (shifted by 1 since we're predicting next tokens) - num_unmasked_tokens = torch.sum(mask, -1) - num_unmasked_tokens[num_unmasked_tokens == 0] = 1 - - if average_log_probs: - num_unmasked_tokens = num_unmasked_tokens.sum().item() - loss = (-torch.sum(token_logprobs * mask) / num_unmasked_tokens).item() - else: + if dpo_loss: + ## shape: [batch_size] + num_unmasked_tokens = torch.sum(mask, -1) loss = -torch.sum(token_logprobs * mask, dim=-1) + if dpo_average_log_probs: + loss = loss / num_unmasked_tokens.clamp(min=1) + else: + ## single scalar loss + # Only compute loss on generated tokens (not input tokens) + # by applying the token_loss_mask + num_unmasked_tokens = torch.sum(mask) + if num_unmasked_tokens == 0: + # prevent division by zero + num_unmasked_tokens = torch.tensor(1) + loss = (-torch.sum(token_logprobs * mask) / num_unmasked_tokens).item() + num_unmasked_tokens = num_unmasked_tokens.item() return loss, { "loss": loss, @@ -256,7 +263,10 @@ def __call__( sft_loss_chosen = torch.tensor(0.0) if self.sft_loss_weight > 0: sft_loss, _ = self.sft_loss( - next_token_logits, data, average_log_probs=self.sft_average_log_probs + next_token_logits, + data, + dpo_loss=True, + dpo_average_log_probs=self.sft_average_log_probs, ) sft_loss_chosen, sft_loss_rejected = self.split_output_tensor(sft_loss) sft_loss_chosen = sft_loss_chosen.mean(0) From 63bec20f03ca364fcc624849943fa0d55d499d17 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Mon, 14 Apr 2025 11:06:26 -0700 Subject: [PATCH 15/57] fixes for helpsteer dataset Signed-off-by: ashors1 --- examples/run_dpo.py | 24 +++++++------------ nemo_reinforcer/algorithms/dpo.py | 6 ----- nemo_reinforcer/data/hf_datasets/__init__.py | 4 +++- .../data/hf_datasets/helpsteer3.py | 1 - nemo_reinforcer/data/llm_message_utils.py | 1 - 5 files changed, 12 insertions(+), 24 deletions(-) diff --git a/examples/run_dpo.py b/examples/run_dpo.py index 638bb77b12..018430b015 100644 --- a/examples/run_dpo.py +++ 
b/examples/run_dpo.py @@ -30,9 +30,6 @@ from transformers import AutoTokenizer from nemo_reinforcer.models.policy import PolicyConfig -# from nemo_reinforcer.data.hf_datasets.helpsteer3 import HelpSteer3Dataset -from nemo_reinforcer.data.hf_datasets.dpo import DPODataset - def parse_args(): """Parse command line arguments.""" @@ -62,8 +59,8 @@ def dpo_preprocessor( ) -> DatumSpec: """Process a datum dictionary for DPO training.""" if isinstance(datum_dict["prompt"], list): - messages_chosen = datum_dict["prompt"] - messages_rejected = datum_dict["prompt"] + messages_chosen = datum_dict["prompt"].copy() + messages_rejected = datum_dict["prompt"].copy() else: messages_chosen = [ { @@ -78,8 +75,6 @@ def dpo_preprocessor( }, ] - ## TODO: sometimes the context above includes assistant, but we don't want to train - ## on that. Only want to train on the chosen and rejected responses... right? How do we ensure this? messages_chosen.append( { "role": "assistant", @@ -94,7 +89,6 @@ def dpo_preprocessor( }, ) - ## TODO: DO NOT APPLY CHAT TEMPLATE! message_log_chosen = get_formatted_message_log( messages_chosen, tokenizer, task_data_spec ) @@ -132,14 +126,14 @@ def dpo_preprocessor( def setup_data(data_config: DataConfig, policy_config: PolicyConfig): print("\nā–¶ Setting up data...") - # data = HelpSteer3Dataset() - # train_dataset = data.formatted_ds["train"] - # val_dataset = data.formatted_ds["validation"] - data = DPODataset( - train_data_path=data_config["train_data_path"], - val_data_path=data_config["val_data_path"], - ) + if data_config["dataset_name"] == "HelpSteer3": + data = hf_datasets.HelpSteer3Dataset() + else: + data = hf_datasets.DPODataset( + train_data_path=data_config["train_data_path"], + val_data_path=data_config["val_data_path"], + ) train_dataset = data.formatted_ds["train"] val_dataset = data.formatted_ds["validation"] diff --git a/nemo_reinforcer/algorithms/dpo.py b/nemo_reinforcer/algorithms/dpo.py index b19d3d612e..bdd592c97f 100644 --- a/nemo_reinforcer/algorithms/dpo.py +++ b/nemo_reinforcer/algorithms/dpo.py @@ -11,9 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import os from functools import partial -from pathlib import Path from typing import Optional, Tuple, TypedDict from tqdm import tqdm @@ -28,10 +26,6 @@ from nemo_reinforcer.data import DataConfig from nemo_reinforcer.data.datasets import AllTaskProcessedDataset, dpo_collate_fn from nemo_reinforcer.data.interfaces import TaskDataSpec -from nemo_reinforcer.data.llm_message_utils import ( - add_dpo_loss_mask_to_message_log, - batched_message_log_to_flat_message, -) from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict from nemo_reinforcer.distributed.virtual_cluster import ClusterConfig, RayVirtualCluster from nemo_reinforcer.models.interfaces import PolicyInterface diff --git a/nemo_reinforcer/data/hf_datasets/__init__.py b/nemo_reinforcer/data/hf_datasets/__init__.py index df1227140d..6dc864e0d1 100644 --- a/nemo_reinforcer/data/hf_datasets/__init__.py +++ b/nemo_reinforcer/data/hf_datasets/__init__.py @@ -14,5 +14,7 @@ from nemo_reinforcer.data.hf_datasets.oasst import OasstDataset from nemo_reinforcer.data.hf_datasets.squad import SquadDataset +from nemo_reinforcer.data.hf_datasets.dpo import DPODataset +from nemo_reinforcer.data.hf_datasets.helpsteer3 import HelpSteer3Dataset -__all__ = ["OasstDataset", "SquadDataset"] +__all__ = ["OasstDataset", "SquadDataset", "DPODataset", "HelpSteer3Dataset"] diff --git a/nemo_reinforcer/data/hf_datasets/helpsteer3.py b/nemo_reinforcer/data/hf_datasets/helpsteer3.py index 3e82362530..fc6b9fc4c4 100644 --- a/nemo_reinforcer/data/hf_datasets/helpsteer3.py +++ b/nemo_reinforcer/data/hf_datasets/helpsteer3.py @@ -16,7 +16,6 @@ def format_helpsteer3(data): - context = data["context"] response_1 = data["response1"] response_2 = data["response2"] overall_preference = data["overall_preference"] diff --git a/nemo_reinforcer/data/llm_message_utils.py b/nemo_reinforcer/data/llm_message_utils.py index f211632817..f4faf6d14b 100644 --- a/nemo_reinforcer/data/llm_message_utils.py +++ b/nemo_reinforcer/data/llm_message_utils.py @@ -127,7 +127,6 @@ def add_loss_mask_to_message_log( sentence["token_loss_mask"] = torch.zeros_like(sentence["token_ids"]) -## TODO: VERIFY def add_dpo_loss_mask_to_message_log( message_log: LLMMessageLogType, ) -> None: From 444d5db4e59d43f8260f655ac0699305153a4d28 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Mon, 14 Apr 2025 12:11:46 -0700 Subject: [PATCH 16/57] add loss function unit test Signed-off-by: ashors1 --- nemo_reinforcer/algorithms/loss_functions.py | 8 +- tests/functional/dpo.sh | 38 ++ tests/unit/algorithms/test_loss_functions.py | 413 ++++++++++++++++++- 3 files changed, 450 insertions(+), 9 deletions(-) create mode 100755 tests/functional/dpo.sh diff --git a/nemo_reinforcer/algorithms/loss_functions.py b/nemo_reinforcer/algorithms/loss_functions.py index d998be6eaf..5421d6c727 100644 --- a/nemo_reinforcer/algorithms/loss_functions.py +++ b/nemo_reinforcer/algorithms/loss_functions.py @@ -184,11 +184,11 @@ def __call__( if num_unmasked_tokens == 0: # prevent division by zero num_unmasked_tokens = torch.tensor(1) - loss = (-torch.sum(token_logprobs * mask) / num_unmasked_tokens).item() + loss = -torch.sum(token_logprobs * mask) / num_unmasked_tokens num_unmasked_tokens = num_unmasked_tokens.item() return loss, { - "loss": loss, + "loss": loss.item() if loss.ndim == 0 else loss, "num_unmasked_tokens": num_unmasked_tokens, "total_tokens": mask.numel(), } @@ -226,8 +226,6 @@ def split_output_tensor(self, tensor: torch.Tensor): def preference_loss( self, next_token_logits: torch.Tensor, data: 
BatchedDataDict[DPOLossDataDict] ) -> torch.Tensor: - ## TODO: make sure this token mask only includes the chosen / rejected responses - ## and not prior assistant tokens ## TODO: there's some duplicate code here with the NLLLoss function. We should refactor token_mask = data["token_mask"][:, 1:] sample_mask = data["sample_mask"] @@ -268,7 +266,9 @@ def __call__( dpo_loss=True, dpo_average_log_probs=self.sft_average_log_probs, ) + print(f"{sft_loss.shape=}") sft_loss_chosen, sft_loss_rejected = self.split_output_tensor(sft_loss) + print(f"{sft_loss_chosen.shape=}") sft_loss_chosen = sft_loss_chosen.mean(0) preference_loss, accuracy = self.preference_loss(next_token_logits, data) diff --git a/tests/functional/dpo.sh b/tests/functional/dpo.sh new file mode 100755 index 0000000000..9c6dca6763 --- /dev/null +++ b/tests/functional/dpo.sh @@ -0,0 +1,38 @@ +#!/bin/bash + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +PROJECT_ROOT=$(realpath $SCRIPT_DIR/../..) +# Mark the current repo as safe, since wandb fetches metadata about the repo +git config --global --add safe.directory $PROJECT_ROOT + +set -eou pipefail + +LOG_DIR=$SCRIPT_DIR/$(basename $0 .sh)-logs +JSON_METRICS=$LOG_DIR/$(basename $0 .sh).json +RUN_LOG=$LOG_DIR/$(basename $0 .sh).log +export RAY_DEDUP_LOGS=0 +export UV_CACHE_DIR=${UV_CACHE_DIR:-$PROJECT_ROOT/uv_cache} +export PYTHONPATH=${PROJECT_ROOT}:${PYTHONPATH:-} + +rm -rf $LOG_DIR +mkdir -p $LOG_DIR + +cd $PROJECT_ROOT +python -u $PROJECT_ROOT/examples/run_dpo.py \ + cluster.gpus_per_node=2 \ + dpo.max_num_steps=10 \ + dpo.val_batches=1 \ + logger.tensorboard_enabled=true \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=false \ + checkpointing.enabled=false \ + $@ \ + 2>&1 | tee $RUN_LOG + +cd $SCRIPT_DIR +python json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS + +## TODO: add loss check once config is finalized +#python check_metrics.py $JSON_METRICS \ +# 'data["train/loss"]["9"] < 1500' \ + diff --git a/tests/unit/algorithms/test_loss_functions.py b/tests/unit/algorithms/test_loss_functions.py index fe874ecc26..d87dc34240 100644 --- a/tests/unit/algorithms/test_loss_functions.py +++ b/tests/unit/algorithms/test_loss_functions.py @@ -13,7 +13,49 @@ # limitations under the License. 
import pytest import torch -from nemo_reinforcer.algorithms.loss_functions import NLLLoss +import numpy as np + +from nemo_reinforcer.algorithms.loss_functions import ( + NLLLoss, + ClippedPGLossFn, + DPOLossFn, +) +from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict +from nemo_reinforcer.algorithms.utils import ( + calculate_kl_penalty_joschu2020, + masked_mean, +) + + +def setup_dpo_loss_test_data(vocab_size=16, batch_size=1): + data = { + "input_ids": torch.arange(vocab_size / 2) + .reshape(2 * batch_size, 4) + .to(torch.int64) + .to("cuda"), + "token_mask": torch.tensor([[0, 0, 1, 1], [0, 0, 1, 1]]).to("cuda"), + "sample_mask": torch.tensor([1, 1]).to("cuda"), + "reference_policy_logprobs": torch.zeros((2 * batch_size, 4)).to("cuda"), + } + + next_token_logits = torch.tensor( + [ + [ + [0.0] * vocab_size, + [0.0] * vocab_size, + [0.0] * vocab_size, + [0.0] * vocab_size, + ], + [ + [0.0] * vocab_size, + [0.0] * vocab_size, + [0.0] * vocab_size, + [0.0] * vocab_size, + ], + ] + ).to("cuda") + + return data, next_token_logits def test_nll_loss(): @@ -46,7 +88,7 @@ def test_nll_loss(): .to("cuda") ) loss, metrics_dict = loss_fn(next_token_logits, data) - torch.testing.assert_allclose(loss.cpu(), torch.tensor(0.0)) + torch.testing.assert_close(loss.cpu(), torch.tensor(0.0)) # Check the metrics dictionary contains the expected values assert metrics_dict["num_unmasked_tokens"] == 2 assert metrics_dict["total_tokens"] == 3 @@ -66,8 +108,369 @@ def test_nll_loss(): ) loss, metrics_dict = loss_fn(next_token_logits, data) ## loss per token is 999, and we have two unmasked tokens - ## with the updated loss function, we now average the loss over unmasked tokens - torch.testing.assert_allclose(loss.cpu(), torch.tensor(999.0)) - # Check the metrics dictionary contains the expected values + ## NLLLoss averages the loss over unmasked tokens + torch.testing.assert_close(loss.cpu(), torch.tensor(999.0)) assert metrics_dict["num_unmasked_tokens"] == 2 assert metrics_dict["total_tokens"] == 3 + + +def test_dpo_loss(): + vocab_size = 16 + batch_size = 1 + num_unmasked_tokens = 2 + data, next_token_logits = setup_dpo_loss_test_data( + vocab_size=vocab_size, + batch_size=batch_size, + ) + print(f"dpo {next_token_logits=}") + loss_fn = DPOLossFn( + cfg={ + "reference_policy_kl_penalty": 0.0, + "preference_loss_weight": 1.0, + "sft_loss_weight": 0.0, + "preference_average_log_probs": False, + "sft_average_log_probs": False, + } + ) + + loss, metrics_dict = loss_fn( + next_token_logits, + data, + ) + + ## chosen and rejected errors are the same, so difference between them is 0 + assert torch.isclose(loss.cpu(), -torch.nn.functional.logsigmoid(torch.tensor(0.0))) + + loss_fn_with_sft = DPOLossFn( + cfg={ + "reference_policy_kl_penalty": 0.0, + "preference_loss_weight": 1.0, + "sft_loss_weight": 0.5, + "preference_average_log_probs": False, + "sft_average_log_probs": False, + } + ) + + expected_sft_loss = ( + -( + torch.nn.functional.log_softmax(torch.tensor([[0.0] * vocab_size]), dim=-1)[ + :, 0 + ].sum() + ) + * num_unmasked_tokens + * batch_size + ) + expected_preference_loss = -torch.nn.functional.logsigmoid(torch.tensor(0.0)) + assert torch.isclose( + loss_fn_with_sft(next_token_logits, data)[0].cpu(), + 0.5 * expected_sft_loss + expected_preference_loss, + ) + + +def _setup_clipped_pg_test_data(batch_size=1, seq_len=4, vocab_size=8, device="cuda"): + """Sets up basic mock data structure. 
Tests should fill values.""" + input_ids = torch.randint( # Input IDs only needed if original loss fn used + 0, vocab_size, (batch_size, seq_len), dtype=torch.int64, device=device + ) + # Default mask: Mask first token [[0, 1, 1, 1]] + token_mask = torch.ones((batch_size, seq_len), dtype=torch.int64, device=device) + token_mask[:, 0] = 0 + # sample_mask needs shape [B] + sample_mask = torch.ones(batch_size, dtype=torch.int64, device=device) + + # Simple default values, tests overwrite these + advantages = torch.zeros((batch_size, seq_len), device=device) + prev_logprobs = torch.zeros((batch_size, seq_len), device=device) + reference_policy_logprobs = torch.zeros((batch_size, seq_len), device=device) + generation_logprobs = torch.zeros((batch_size, seq_len), device=device) + + data = BatchedDataDict( + { + "input_ids": input_ids, # Include for completeness + "token_mask": token_mask, + "sample_mask": sample_mask, + "advantages": advantages, + "prev_logprobs": prev_logprobs, + "reference_policy_logprobs": reference_policy_logprobs, + "generation_logprobs": generation_logprobs, + } + ) + # Return seq_len and vocab_size needed by tests + return data, seq_len, vocab_size + + +# Helper to create logits that yield specific target log probs after log_softmax +def _create_exact_logits(target_curr_lp_masked, input_ids, seq_len, vocab_size, device): + """Constructs logits such that log_softmax results in target_curr_lp_masked.""" + dummy_logits = torch.full( + (1, seq_len, vocab_size), -100.0, device=device + ) # Start very low + + # Loss fn uses logits[:, :-1] and gathers based on next_tokens = input_ids[:, 1:] + # We need to set logits for indices i=0..S-2 of the sliced logits tensor. + # These correspond to target logprobs at indices 0..S-2 of target_curr_lp_masked. 
+ num_effective_pos = target_curr_lp_masked.shape[1] + for i in range(num_effective_pos): + logit_idx = i # Index in the sliced logits tensor (dummy_logits[:, 0:S-1, :]) + data_idx = i + 1 # Index in the original input_ids to find the target token + + target_token_id = input_ids[0, data_idx].item() + # Keep target_lp as a 0-dim tensor for torch ops + target_lp = target_curr_lp_masked[0, i] + + # Handle target_lp = 0 case separately + if torch.isclose(target_lp, torch.tensor(0.0, device=device)): + dummy_logits[0, logit_idx, target_token_id] = 100.0 # Large positive logit + elif target_lp < 0: + # Set target token logit to 0 + dummy_logits[0, logit_idx, target_token_id] = 0.0 + # Set one distractor token logit using the formula + distractor_token_id = (target_token_id + 1) % vocab_size + # Ensure distractor isn't same as target if vocab_size=1 (edge case) + if distractor_token_id == target_token_id: + distractor_token_id = (target_token_id + 2) % vocab_size + distractor_logit = torch.log(torch.exp(-target_lp) - 1.0) + dummy_logits[0, logit_idx, distractor_token_id] = distractor_logit + else: # target_lp > 0 is not supported by this method + raise ValueError( + "Target log probability must be negative or zero for this construction" + ) + return dummy_logits + + +# Simplified PPO Clipping Test using original Loss +def test_clipped_pg_loss_ppo_clipping(): + """Tests PPO clipping calculations directly.""" + if not torch.cuda.is_available(): + pytest.skip("No GPU available") + + device = "cuda" + data, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device) + + ratio_eps = 0.2 + cfg = { + "ratio_eps_min": ratio_eps, + "ratio_eps_max": ratio_eps, + "reference_policy_kl_penalty": 0.0, # Disable KL + "disable_ppo_ratio": False, + } + loss_fn = ClippedPGLossFn(cfg) + + adv_masked = torch.tensor([[1.0, -1.0, 2.0]], device=device) + # Use non-zero prev_lp to allow ratios > 1 with valid curr_lp <= 0 + prev_lp_masked = torch.tensor([[-1.0, -1.0, -1.0]], device=device) + # Target Curr logprobs (masked pos 1, 2, 3) - design for clipping + # Target ratios: 0.5 (<0.8), 1.0 (in [0.8, 1.2]), 1.5 (>1.2) + # Curr = log(Ratio) + Prev + curr_lp_masked = torch.tensor( + [[-1.69315, -1.0, -0.59453]], device=device + ) # approx log(0.5)-1, log(1)-1, log(1.5)-1 + + # Fill full tensors (only need first dim for B=1) + data["advantages"][0, 1:] = adv_masked + data["prev_logprobs"][0, 1:] = prev_lp_masked + + # --- Hand Calculation --- + ratios = torch.exp(curr_lp_masked - prev_lp_masked) # approx [0.5, 1.0, 1.5] + ratios_clamped = torch.clamp( + ratios, 1.0 - ratio_eps, 1.0 + ratio_eps + ) # [0.8, 1.0, 1.2] + loss1 = -adv_masked * ratios # approx -[1*0.5, -1*1.0, 2*1.5] = [-0.5, 1.0, -3.0] + loss2 = -adv_masked * ratios_clamped # -[1*0.8, -1*1.0, 2*1.2] = [-0.8, 1.0, -2.4] + max_loss = torch.maximum(loss1, loss2) # approx [-0.5, 1.0, -2.4] + expected_loss = torch.mean( + max_loss + ) # approx (-0.5 + 1.0 - 2.4) / 3 = -1.9 / 3 = -0.6333 + + input_ids = data["input_ids"] + dummy_logits = _create_exact_logits( + curr_lp_masked, input_ids, seq_len, vocab_size, device + ) + + actual_loss, _ = loss_fn(dummy_logits, data) + torch.testing.assert_close(actual_loss, expected_loss) + + +# Simplified REINFORCE Test using original Loss +def test_clipped_pg_loss_reinforce_mode(): + """Tests REINFORCE mode calculations directly.""" + if not torch.cuda.is_available(): + pytest.skip("No GPU available") + + device = "cuda" + data, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device) + + cfg = { + "disable_ppo_ratio": 
True, + "reference_policy_kl_penalty": 0.0, + "ratio_eps_min": 0.0, # Placeholder, ignored + "ratio_eps_max": 0.0, # Placeholder, ignored + } + loss_fn = ClippedPGLossFn(cfg) + + adv_masked = torch.tensor([[1.0, -1.0, 2.0]], device=device) + curr_lp_masked = torch.tensor([[-0.5, -1.0, -1.5]], device=device) + + data["advantages"][0, 1:] = adv_masked + data["_test_curr_logprobs"] = curr_lp_masked + data["prev_logprobs"][0, 1:] = torch.zeros_like(curr_lp_masked) + + # --- Hand Calculation --- + expected_loss_per_token = -adv_masked * curr_lp_masked # [0.5, -1.0, 3.0] + expected_loss = torch.mean(expected_loss_per_token) # 2.5 / 3 = 0.8333 + + input_ids = data["input_ids"] + dummy_logits = _create_exact_logits( + curr_lp_masked, input_ids, seq_len, vocab_size, device + ) + + actual_loss, _ = loss_fn(dummy_logits, data) + torch.testing.assert_close(actual_loss, expected_loss) + + +# Simplified KL Penalty Test using original Loss +def test_clipped_pg_loss_kl_penalty(): + """Tests KL penalty calculations directly.""" + if not torch.cuda.is_available(): + pytest.skip("No GPU available") + + device = "cuda" + data, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device) + + # --- Test Setup --- + kl_beta = 0.1 + cfg = { + "reference_policy_kl_penalty": kl_beta, + "ratio_eps_min": 0.2, + "ratio_eps_max": 0.2, + "disable_ppo_ratio": False, + } + loss_fn = ClippedPGLossFn(cfg) + + adv_masked = torch.tensor([[0.0, 0.0, 0.0]], device=device) + curr_lp_masked = torch.tensor([[0.0, -1.0, -2.0]], device=device) + ref_lp_masked = torch.tensor([[-1.0, -1.0, -1.0]], device=device) + prev_lp_masked = torch.tensor([[0.0, 0.0, 0.0]], device=device) + + data["advantages"][0, 1:] = adv_masked + data["reference_policy_logprobs"][0, 1:] = ref_lp_masked + data["prev_logprobs"][0, 1:] = prev_lp_masked + data["_test_curr_logprobs"] = curr_lp_masked + + # --- Hand Calculation --- + # Actor loss is 0. Total loss = kl_beta * mean(kl_term) + # kl_term = exp(ref - curr) - (ref - curr) - 1 + r = ref_lp_masked - curr_lp_masked # [-1.0, 0.0, 1.0] + kl_term_per_token = torch.exp(r) - r - 1 # [0.368, 0.0, 0.718] + expected_kl_mean = torch.mean(kl_term_per_token) # 0.362 + expected_loss = kl_beta * expected_kl_mean # 0.0362 + + input_ids = data["input_ids"] + dummy_logits = _create_exact_logits( + curr_lp_masked, input_ids, seq_len, vocab_size, device + ) + + actual_loss, _ = loss_fn(dummy_logits, data) + torch.testing.assert_close(actual_loss, expected_loss) + + +# Masking tests - Should work with original Loss Fn if needed, but less critical +def test_clipped_pg_loss_masking(): + """Tests the effect of token_mask and sample_mask.""" + if not torch.cuda.is_available(): + pytest.skip("No GPU available") + + batch_size = 2 + seq_len = 4 + device = "cuda" + # Use original loss function for masking tests, as it involves interactions + # that the Testable class might obscure slightly. 
+ data, seq_len, vocab_size = _setup_clipped_pg_test_data( + batch_size=batch_size, seq_len=seq_len, device=device + ) + # Need some realistic-ish logits and logprobs for masking test + dummy_logits = torch.randn(batch_size, seq_len, vocab_size, device=device) + # Ensure logprobs used by the loss fn make sense relative to advantages + data["prev_logprobs"] = torch.randn_like(data["prev_logprobs"]) * 0.1 + data["reference_policy_logprobs"] = ( + torch.randn_like(data["reference_policy_logprobs"]) * 0.1 + ) + # Make advantages non-zero + data["advantages"] = torch.randn_like(data["advantages"]) + 1.0 + + cfg = { + "ratio_eps_min": 0.2, + "ratio_eps_max": 0.2, + "reference_policy_kl_penalty": 0.1, + "disable_ppo_ratio": False, + } + loss_fn = ClippedPGLossFn(cfg) # Use original loss fn + + # --- Test 1: Token Mask --- + # Default mask: [[0, 1, 1, 1], [0, 1, 1, 1]] -> 3 tokens per sample + loss_default, _ = loss_fn(dummy_logits, data) + + # Modify token_mask for batch item 0 to mask one more token (pos 1) + data_mod_token = data.copy() + data_mod_token["token_mask"] = data["token_mask"].clone() + data_mod_token["token_mask"][0, 1] = ( + 0 # New mask: [[0, 0, 1, 1], [0, 1, 1, 1]] -> 2 tokens sample 0, 3 tokens sample 1 + ) + + loss_token_masked, _ = loss_fn(dummy_logits, data_mod_token) + # Loss should change if a potentially contributing token is masked + assert not torch.isclose(loss_default, loss_token_masked, atol=1e-4), ( + "Token mask did not change loss as expected" + ) + + # --- Test 2: Sample Mask --- + data_mod_sample = data.copy() + data_mod_sample["sample_mask"] = torch.tensor( + [1, 0], dtype=torch.int64, device=device + ) # Ignore item 1 + + loss_sample_masked, _ = loss_fn(dummy_logits, data_mod_sample) + + # Manually create data dict for only batch 0 + data_only_b0_dict = {} + for key, value in data.items(): + if isinstance(value, torch.Tensor): + if key == "sample_mask": + data_only_b0_dict[key] = value[0:1] + else: + data_only_b0_dict[key] = value[0:1] + else: + data_only_b0_dict[key] = value + data_only_b0 = BatchedDataDict(data_only_b0_dict) + + logits_only_b0 = dummy_logits[0:1] + loss_only_b0, _ = loss_fn(logits_only_b0, data_only_b0) + + torch.testing.assert_close(loss_sample_masked, loss_only_b0) + + +def test_clipped_pg_loss_zero_mask(): + """Tests the case where the combined mask sum is zero.""" + if not torch.cuda.is_available(): + pytest.skip("No GPU available") + + device = "cuda" + data, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device) + # Need dummy logits + dummy_logits = torch.randn(1, seq_len, vocab_size, device=device) + + cfg = { + "ratio_eps_min": 0.2, + "ratio_eps_max": 0.2, + "reference_policy_kl_penalty": 0.1, + "disable_ppo_ratio": False, + } + loss_fn = ClippedPGLossFn(cfg) # Use original loss fn + + # Set token mask to all zeros + data["token_mask"] = torch.zeros_like(data["token_mask"]) + + loss, _ = loss_fn(dummy_logits, data) + + # Loss should be exactly zero + torch.testing.assert_close(loss, torch.tensor(0.0, device=device)) From e9dae56044335cb0a2814b06987a324a539a1c10 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Mon, 14 Apr 2025 12:22:34 -0700 Subject: [PATCH 17/57] cleanup Signed-off-by: ashors1 --- nemo_reinforcer/algorithms/loss_functions.py | 2 -- tests/unit/algorithms/test_loss_functions.py | 24 ++++---------------- 2 files changed, 5 insertions(+), 21 deletions(-) diff --git a/nemo_reinforcer/algorithms/loss_functions.py b/nemo_reinforcer/algorithms/loss_functions.py index 5421d6c727..c683a76703 100644 --- 
a/nemo_reinforcer/algorithms/loss_functions.py +++ b/nemo_reinforcer/algorithms/loss_functions.py @@ -266,9 +266,7 @@ def __call__( dpo_loss=True, dpo_average_log_probs=self.sft_average_log_probs, ) - print(f"{sft_loss.shape=}") sft_loss_chosen, sft_loss_rejected = self.split_output_tensor(sft_loss) - print(f"{sft_loss_chosen.shape=}") sft_loss_chosen = sft_loss_chosen.mean(0) preference_loss, accuracy = self.preference_loss(next_token_logits, data) diff --git a/tests/unit/algorithms/test_loss_functions.py b/tests/unit/algorithms/test_loss_functions.py index d87dc34240..c8e989cd88 100644 --- a/tests/unit/algorithms/test_loss_functions.py +++ b/tests/unit/algorithms/test_loss_functions.py @@ -28,6 +28,7 @@ def setup_dpo_loss_test_data(vocab_size=16, batch_size=1): + seq_len = 4 data = { "input_ids": torch.arange(vocab_size / 2) .reshape(2 * batch_size, 4) @@ -35,26 +36,10 @@ def setup_dpo_loss_test_data(vocab_size=16, batch_size=1): .to("cuda"), "token_mask": torch.tensor([[0, 0, 1, 1], [0, 0, 1, 1]]).to("cuda"), "sample_mask": torch.tensor([1, 1]).to("cuda"), - "reference_policy_logprobs": torch.zeros((2 * batch_size, 4)).to("cuda"), + "reference_policy_logprobs": torch.zeros((2 * batch_size, seq_len)).to("cuda"), } - next_token_logits = torch.tensor( - [ - [ - [0.0] * vocab_size, - [0.0] * vocab_size, - [0.0] * vocab_size, - [0.0] * vocab_size, - ], - [ - [0.0] * vocab_size, - [0.0] * vocab_size, - [0.0] * vocab_size, - [0.0] * vocab_size, - ], - ] - ).to("cuda") - + next_token_logits = torch.zeros((2 * batch_size, seq_len, vocab_size)).to("cuda") return data, next_token_logits @@ -122,7 +107,6 @@ def test_dpo_loss(): vocab_size=vocab_size, batch_size=batch_size, ) - print(f"dpo {next_token_logits=}") loss_fn = DPOLossFn( cfg={ "reference_policy_kl_penalty": 0.0, @@ -166,6 +150,8 @@ def test_dpo_loss(): 0.5 * expected_sft_loss + expected_preference_loss, ) + ## TODO: test with a batch of varying sequence lengths + def _setup_clipped_pg_test_data(batch_size=1, seq_len=4, vocab_size=8, device="cuda"): """Sets up basic mock data structure. Tests should fill values.""" From 8045ab9c1ba88e2e975cde4640705c6be604b628 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Mon, 14 Apr 2025 12:35:45 -0700 Subject: [PATCH 18/57] add test for augment_dataloader Signed-off-by: ashors1 --- tests/unit/algorithms/test_dpo.py | 75 +++++++++++++++++++++++++++++++ 1 file changed, 75 insertions(+) create mode 100644 tests/unit/algorithms/test_dpo.py diff --git a/tests/unit/algorithms/test_dpo.py b/tests/unit/algorithms/test_dpo.py new file mode 100644 index 0000000000..8c39e7e4d2 --- /dev/null +++ b/tests/unit/algorithms/test_dpo.py @@ -0,0 +1,75 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
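+
+# Context for the test below: `augment_dataloader` (added to
+# nemo_reinforcer/algorithms/dpo.py in this patch series) wraps a dataloader,
+# asks the reference policy for per-token logprobs on each batch, shifts them
+# left by one position so that entry i holds the logprob of the next token, and
+# attaches the result as "reference_policy_logprobs". A rough sketch of the
+# behavior this mock-based test relies on:
+#
+#     def augment_dataloader(dataloader, policy, master_config):
+#         for batch in dataloader:
+#             logprobs = policy.get_reference_policy_logprobs(
+#                 batch,
+#                 micro_batch_size=master_config["policy"]["train_micro_batch_size"] * 2,
+#             )["reference_logprobs"].to("cpu")
+#             batch["reference_policy_logprobs"] = torch.roll(logprobs, -1, dims=-1)
+#             yield batch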
+ +import pytest +import torch +from unittest.mock import MagicMock, patch + +from nemo_reinforcer.algorithms.dpo import augment_dataloader + + +class MockPolicy: + def __init__(self, logprobs): + self.logprobs = logprobs + + def get_reference_policy_logprobs(self, batch, micro_batch_size): + return {"reference_logprobs": self.logprobs} + + +def test_augment_dataloader(): + """Test that augment_dataloader correctly adds reference policy logprobs to batches.""" + # Create mock data + batch_size = 2 + seq_len = 4 + vocab_size = 16 + + # Create a mock batch + mock_batch = { + "input_ids": torch.randint(0, vocab_size, (batch_size, seq_len)), + "attention_mask": torch.ones(batch_size, seq_len), + } + + # Create mock logprobs that will be returned by the policy + mock_logprobs = torch.randn(batch_size, seq_len) + + # Create a mock dataloader that yields our mock batch + mock_dataloader = MagicMock() + mock_dataloader.__iter__.return_value = iter([mock_batch]) + + # Create a mock policy that returns our mock logprobs + mock_policy = MockPolicy(mock_logprobs) + + # Create a mock master config + mock_master_config = {"policy": {"train_micro_batch_size": 1}} + + # Get the augmented batches + augmented_batches = list( + augment_dataloader(mock_dataloader, mock_policy, mock_master_config) + ) + + # Verify we got exactly one batch + assert len(augmented_batches) == 1 + augmented_batch = augmented_batches[0] + + # Verify the original batch data is preserved + assert torch.equal(augmented_batch["input_ids"], mock_batch["input_ids"]) + assert torch.equal(augmented_batch["attention_mask"], mock_batch["attention_mask"]) + + # Verify the reference policy logprobs were added correctly + assert "reference_policy_logprobs" in augmented_batch + assert augmented_batch["reference_policy_logprobs"].shape == (batch_size, seq_len) + + # Verify the logprobs were rolled by -1 as expected + expected_logprobs = torch.roll(mock_logprobs, -1, dims=-1) + assert torch.equal(augmented_batch["reference_policy_logprobs"], expected_logprobs) From db03bcc11cccff58e6579d72533164210b46de10 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Mon, 14 Apr 2025 12:49:55 -0700 Subject: [PATCH 19/57] add dpo collate test Signed-off-by: ashors1 --- nemo_reinforcer/data/datasets.py | 10 +-- tests/unit/data/test_datasets.py | 146 +++++++++++++++++++++++++++++++ 2 files changed, 150 insertions(+), 6 deletions(-) create mode 100644 tests/unit/data/test_datasets.py diff --git a/nemo_reinforcer/data/datasets.py b/nemo_reinforcer/data/datasets.py index 6ef060164f..b8688b022a 100644 --- a/nemo_reinforcer/data/datasets.py +++ b/nemo_reinforcer/data/datasets.py @@ -21,6 +21,10 @@ TaskDataProcessFnCallable, DatumSpec, ) +from nemo_reinforcer.data.llm_message_utils import ( + add_dpo_loss_mask_to_message_log, + batched_message_log_to_flat_message, +) from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict @@ -183,12 +187,6 @@ def eval_collate_fn(data_batch: List[DatumSpec]) -> BatchedDataDict: return output -from nemo_reinforcer.data.llm_message_utils import ( - add_dpo_loss_mask_to_message_log, - batched_message_log_to_flat_message, -) - - def dpo_collate_fn(data_batch: List[DatumSpec], tokenizer) -> BatchedDataDict: """Collate function for DPO training. diff --git a/tests/unit/data/test_datasets.py b/tests/unit/data/test_datasets.py new file mode 100644 index 0000000000..aafbad9405 --- /dev/null +++ b/tests/unit/data/test_datasets.py @@ -0,0 +1,146 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch +from unittest.mock import MagicMock + +from nemo_reinforcer.data.datasets import dpo_collate_fn +from nemo_reinforcer.data.interfaces import DatumSpec +from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict + + +def test_dpo_collate_fn(): + """Test that dpo_collate_fn correctly processes DPO training data.""" + # Create mock tokenizer + mock_tokenizer = MagicMock() + mock_tokenizer.eos_token_id = 0 + + # Create test data with varying sequence lengths + data_batch = [ + DatumSpec( + message_log_chosen=[ + { + "role": "user", + "content": "Hello", + "token_ids": torch.tensor([1, 2, 3]), + }, + { + "role": "assistant", + "content": "Hi there", + "token_ids": torch.tensor([4, 5, 6, 7]), + }, + ], + message_log_rejected=[ + { + "role": "user", + "content": "Hello", + "token_ids": torch.tensor([1, 2, 3]), + }, + { + "role": "assistant", + "content": "Bye", + "token_ids": torch.tensor([8, 9]), + }, + ], + length_chosen=7, + length_rejected=5, + loss_multiplier=1.0, + idx=0, + task_name="test_task", + ), + DatumSpec( + message_log_chosen=[ + { + "role": "user", + "content": "How are you?", + "token_ids": torch.tensor([10, 11, 12]), + }, + { + "role": "assistant", + "content": "I'm good", + "token_ids": torch.tensor([13, 14, 15]), + }, + ], + message_log_rejected=[ + { + "role": "user", + "content": "How are you?", + "token_ids": torch.tensor([10, 11, 12]), + }, + { + "role": "assistant", + "content": "Not great", + "token_ids": torch.tensor([16, 17, 18, 19]), + }, + ], + length_chosen=6, + length_rejected=7, + loss_multiplier=0, + idx=1, + task_name="test_task", + ), + ] + + # Call dpo_collate_fn + train_data = dpo_collate_fn(data_batch, mock_tokenizer) + + # Verify the output structure + assert isinstance(train_data, BatchedDataDict) + assert "input_ids" in train_data + assert "input_lengths" in train_data + assert "token_mask" in train_data + assert "sample_mask" in train_data + + # Verify batch size is doubled (chosen + rejected for each example) + assert train_data["input_ids"].shape[0] == 4 # 2 examples * 2 (chosen + rejected) + + # Verify input_ids shape and padding + max_length = 7 # max of all sequence lengths + assert train_data["input_ids"].shape == (4, max_length) + + # Verify input_lengths + expected_lengths = [7, 5, 6, 7] # chosen1, rejected1, chosen2, rejected2 + assert torch.equal(train_data["input_lengths"], torch.tensor(expected_lengths)) + + # Verify token_mask + assert train_data["token_mask"].shape == (4, max_length) + # First example chosen (length 7) + assert torch.all(train_data["token_mask"][0][0:3] == 0) + assert torch.all(train_data["token_mask"][0][3:7] == 1) + # First example rejected (length 5) + assert torch.all(train_data["token_mask"][1][0:3] == 0) + assert torch.all(train_data["token_mask"][1][3:5] == 1) + assert torch.all(train_data["token_mask"][1][5:] == 0) + + # Verify sample_mask + expected_sample_mask = [ + 1.0, + 1.0, + 0.0, + 0.0, + ] # loss_multiplier repeated 
for chosen/rejected + assert torch.equal(train_data["sample_mask"], torch.tensor(expected_sample_mask)) + + # Verify message content is preserved + # First example chosen + assert torch.equal(train_data["input_ids"][0][0:3], torch.tensor([1, 2, 3])) # user + assert torch.equal( + train_data["input_ids"][0][3:7], torch.tensor([4, 5, 6, 7]) + ) # assistant + # First example rejected + assert torch.equal(train_data["input_ids"][1][0:3], torch.tensor([1, 2, 3])) # user + assert torch.equal( + train_data["input_ids"][1][3:5], torch.tensor([8, 9]) + ) # assistant From 56044abd5f17ff3bab012436663fb16060a5e746 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Mon, 14 Apr 2025 13:14:19 -0700 Subject: [PATCH 20/57] add dataset unit tests Signed-off-by: ashors1 --- tests/unit/data/hf_datasets/test_dpo.py | 106 ++++++++++++++++++ tests/unit/data/hf_datasets/test_helpsteer.py | 72 ++++++++++++ 2 files changed, 178 insertions(+) create mode 100644 tests/unit/data/hf_datasets/test_dpo.py create mode 100644 tests/unit/data/hf_datasets/test_helpsteer.py diff --git a/tests/unit/data/hf_datasets/test_dpo.py b/tests/unit/data/hf_datasets/test_dpo.py new file mode 100644 index 0000000000..476b333d33 --- /dev/null +++ b/tests/unit/data/hf_datasets/test_dpo.py @@ -0,0 +1,106 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
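+
+# DPODataset (nemo_reinforcer/data/hf_datasets/dpo.py) reads its train and
+# validation files with `load_dataset("json", ...)`, so either a JSON array (as
+# written by the fixture below) or JSON-lines input should work. Each record is
+# expected to provide the three keys used throughout DPO training, e.g.:
+#
+#     {
+#         "prompt": "What is 2+2?",
+#         "chosen_response": "The answer is 4.",
+#         "rejected_response": "I don't know."
+#     }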
+ +import os +import tempfile +import json +import pytest +from unittest.mock import patch, MagicMock + +from nemo_reinforcer.data.hf_datasets.dpo import DPODataset + + +@pytest.fixture +def mock_dpo_data(): + """Create temporary DPO dataset files with sample data.""" + train_data = [ + { + "prompt": "What is 2+2?", + "chosen_response": "The answer is 4.", + "rejected_response": "I don't know.", + }, + { + "prompt": "What is the capital of France?", + "chosen_response": "The capital of France is Paris.", + "rejected_response": "The capital of France is London.", + }, + ] + + val_data = [ + { + "prompt": "What is 3*3?", + "chosen_response": "The answer is 9.", + "rejected_response": "The answer is 6.", + } + ] + + with tempfile.NamedTemporaryFile( + mode="w", suffix=".json", delete=False + ) as train_file: + json.dump(train_data, train_file) + train_path = train_file.name + + with tempfile.NamedTemporaryFile( + mode="w", suffix=".json", delete=False + ) as val_file: + json.dump(val_data, val_file) + val_path = val_file.name + + yield train_path, val_path + + # Cleanup + os.unlink(train_path) + os.unlink(val_path) + + +def test_dpo_dataset_initialization(mock_dpo_data): + """Test that DPODataset initializes correctly with valid data files.""" + train_path, val_path = mock_dpo_data + + dataset = DPODataset(train_data_path=train_path, val_data_path=val_path) + + # Verify dataset initialization + assert dataset.task_spec.task_name == "dpo" + assert dataset.task_spec.custom_template is not None + assert "messages" in dataset.task_spec.custom_template + + # Verify formatted_ds structure + assert "train" in dataset.formatted_ds + assert "validation" in dataset.formatted_ds + + assert len(dataset.formatted_ds["train"]) == 2 + assert len(dataset.formatted_ds["validation"]) == 1 + + +def test_dpo_dataset_invalid_files(): + """Test that DPODataset raises appropriate errors with invalid files.""" + with pytest.raises(FileNotFoundError): + DPODataset(train_data_path="nonexistent.json", val_data_path="nonexistent.json") + + +def test_dpo_dataset_data_format(mock_dpo_data): + """Test that DPODataset correctly formats the data.""" + train_path, val_path = mock_dpo_data + dataset = DPODataset(train_data_path=train_path, val_data_path=val_path) + + # Verify data format + train_sample = dataset.formatted_ds["train"][0] + assert "prompt" in train_sample + assert "chosen_response" in train_sample + assert "rejected_response" in train_sample + + # Verify data content + assert train_sample["prompt"] == "What is 2+2?" + assert train_sample["chosen_response"] == "The answer is 4." + assert train_sample["rejected_response"] == "I don't know." diff --git a/tests/unit/data/hf_datasets/test_helpsteer.py b/tests/unit/data/hf_datasets/test_helpsteer.py new file mode 100644 index 0000000000..b139671dad --- /dev/null +++ b/tests/unit/data/hf_datasets/test_helpsteer.py @@ -0,0 +1,72 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
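+
+# format_helpsteer3 maps raw HelpSteer3 preference records onto the generic DPO
+# format: overall_preference < 0 marks response1 as the chosen response, anything
+# else marks response2 as chosen. Ties (overall_preference == 0) are not
+# exercised below; under the mapping in helpsteer3.py they fall through to
+# response2 being treated as chosen.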
+ +import pytest +from nemo_reinforcer.data.hf_datasets.helpsteer3 import ( + HelpSteer3Dataset, + format_helpsteer3, +) + + +def test_format_helpsteer3(): + """Test the format_helpsteer3 function with different preference values.""" + # Test case 1: response1 is preferred (overall_preference < 0) + data1 = { + "context": "What is 2+2?", + "response1": "The answer is 4.", + "response2": "I don't know.", + "overall_preference": -1, + } + result1 = format_helpsteer3(data1) + assert result1["prompt"] == "What is 2+2?" + assert result1["chosen_response"] == "The answer is 4." + assert result1["rejected_response"] == "I don't know." + + # Test case 2: response2 is preferred (overall_preference > 0) + data2 = { + "context": "What is the capital of France?", + "response1": "The capital of France is London.", + "response2": "The capital of France is Paris.", + "overall_preference": 1, + } + result2 = format_helpsteer3(data2) + assert result2["prompt"] == "What is the capital of France?" + assert result2["chosen_response"] == "The capital of France is Paris." + assert result2["rejected_response"] == "The capital of France is London." + + +def test_helpsteer3_dataset_initialization(): + """Test that HelpSteer3Dataset initializes correctly.""" + + dataset = HelpSteer3Dataset() + + # Verify dataset initialization + assert dataset.task_spec.task_name == "HelpSteer3" + assert dataset.task_spec.custom_template is None # Should use tokenizer's template + + +def test_helpsteer3_dataset_data_format(): + """Test that HelpSteer3Dataset correctly formats the data.""" + + dataset = HelpSteer3Dataset() + + assert isinstance(dataset.formatted_ds, dict) + assert "train" in dataset.formatted_ds + assert "validation" in dataset.formatted_ds + + # Verify data format + sample = dataset.formatted_ds["train"][0] + assert "prompt" in sample + assert "chosen_response" in sample + assert "rejected_response" in sample From 0645b48ebb1690d653ff1759cbdfcf3e99663ec5 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Mon, 14 Apr 2025 15:34:36 -0700 Subject: [PATCH 21/57] add DPO documentation Signed-off-by: ashors1 --- README.md | 75 ++++++++++++++++++++++++++++- docs/guides/dpo.md | 117 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 190 insertions(+), 2 deletions(-) create mode 100644 docs/guides/dpo.md diff --git a/README.md b/README.md index 044c9cd954..2fdcf049ef 100644 --- a/README.md +++ b/README.md @@ -8,9 +8,12 @@ - [SFT](#sft) - [Single Node](#single-node) - [Multi-node](#multi-node) - - [GRPO](#grpo) + - [DPO](#dpo) - [Single Node](#single-node-1) - [Multi-node](#multi-node-1) + - [GRPO](#grpo) + - [Single Node](#single-node-2) + - [Multi-node](#multi-node-2) - [Cluster Start](#cluster-start) **Nemo-Reinforcer** is a scalable and efficient post-training library designed for models ranging from 1 GPU to thousands, and from tiny to over 100 billion parameters. @@ -33,10 +36,10 @@ _āœ… Available now | šŸ”œ Coming in v0.2_ - āœ… **Environment Support** - Support for multi-environment training. 
- āœ… **Learning Algorithms** - GRPO (Group Relative Policy Optimization) and SFT (Supervised Fine-Tuning) - āœ… **Worker Isolation** - Process isolation between RL Actors (no worries about global state) +- āœ… **DPO Algorithm** - Direct Preference Optimization for alignment - šŸ”œ **Larger Model Support** - Native PyTorch support for models up to 70B parameters - šŸ”œ **Advanced Parallelism** - FSDP2, TP, SP, and sequence packing for efficient training - šŸ”œ **Environment Isolation** - Dependency isolation between components -- šŸ”œ **DPO Algorithm** - Direct Preference Optimization for alignment ## Installation @@ -118,6 +121,74 @@ sbatch \ ray.sub ``` +### DPO + +We provide a sample DPO experiment that uses the [HelpSteer3 dataset](https://huggingface.co/datasets/nvidia/HelpSteer3) for preference-based training. + +#### Single Node + +The default DPO experiment is configured to run on a single GPU. To launch the experiment: + +```sh +uv run python examples/run_dpo.py +``` + +This trains `Llama3.2-1B-Instruct` on one GPU. + +If you have access to more GPUs, you can update the experiment accordingly. To run on 8 GPUs, we update the cluster configuration and switch to an 8B Llama base model: + +```sh +uv run python examples/run_dpo.py \ + policy.model_name="meta-llama/Meta-Llama-3-8B-Instruct" \ + policy.train_global_batch_size=128 \ + dpo.val_global_batch_size=128 \ + cluster.gpus_per_node=8 +``` + +Any of the DPO parameters can be customized from the command line. For example: + +```sh +uv run python examples/run_dpo.py \ + dpo.sft_loss_weight=0.1 \ + dpo.preference_average_log_probs=True \ + checkpointing.checkpoint_dir="results/llama_dpo_sft" \ + logger.wandb_enabled=True \ + logger.wandb.name="llama-dpo-sft" +``` + +Refer to [dpo.yaml](examples/configs/dpo.yaml) for a full list of parameters that can be overridden. For an in-depth explanation of how to add your own DPO dataset, refer to the [DPO documentation](docs/guides/dpo.md) + +#### Multi-node + +For distributed DPO training across multiple nodes: + +Set `UV_CACHE_DIR` to a directory that can be read from all workers before running any uv run command. +```sh +export UV_CACHE_DIR=/path/that/all/workers/can/access/uv_cache +``` + +```sh +# Run from the root of NeMo-Reinforcer repo +NUM_ACTOR_NODES=2 +# Add a timestamp to make each job name unique +TIMESTAMP=$(date +%Y%m%d_%H%M%S) + +# DPO experiment uses Llama-3.1-8B model +COMMAND="uv pip install -e .; uv run ./examples/run_dpo.py --config examples/configs/dpo.yaml cluster.num_nodes=2 cluster.gpus_per_node=8 checkpointing.checkpoint_dir='results/dpo_llama8b_2nodes' logger.wandb_enabled=True logger.wandb.name='dpo-llama8b'" \ +RAY_DEDUP_LOGS=0 \ +UV_CACHE_DIR=YOUR_UV_CACHE_DIR \ +CONTAINER=YOUR_CONTAINER \ +MOUNTS="$PWD:$PWD" \ +sbatch \ + --nodes=${NUM_ACTOR_NODES} \ + --account=YOUR_ACCOUNT \ + --job-name=YOUR_JOBNAME \ + --partition=YOUR_PARTITION \ + --time=4:0:0 \ + --gres=gpu:8 \ + ray.sub +``` + ### GRPO We have a reference GRPO experiment config set up trained for math benchmarks using the [OpenInstructMath2](https://huggingface.co/datasets/nvidia/OpenMathInstruct-2) dataset. diff --git a/docs/guides/dpo.md b/docs/guides/dpo.md new file mode 100644 index 0000000000..dc053ad53c --- /dev/null +++ b/docs/guides/dpo.md @@ -0,0 +1,117 @@ +# Direct Preference Optimization in Reinforcer + +## Launch a DPO Run + +The script [examples/run_dpo.py](../../examples/run_dpo.py) can be used to launch a DPO experiment. This script can either be launched locally or via Slurm. 
For details on how to set up Ray and launch a job using Slurm, refer to the [cluster documentation](../cluster.md). + +Be sure to launch the job using `uv`. The command to launch a DPO job is as follows: +```bash +uv run examples/run_dpo.py --config +``` +If not specified, `config` will default to [examples/configs/dpo.yaml](../../examples/configs/dpo.yaml). + +## Configuration + +Reinforcer allows users to configure DPO experiments using `yaml` config files. An example DPO configuration file can be found [here](../../examples/configs/dpo.yaml). + +To override a value in the config, either update the value in the `yaml` file directly, or pass the override via the command line. For example: + +```bash +uv run examples/run_dpo.py \ + cluster.gpus_per_node=8 \ + dpo.sft_loss_weight=0.1 \ + dpo.preference_average_log_probs=True \ + logger.wandb.name="dpo-dev-8-gpu" +``` + +**Reminder**: Don't forget to set your HF_HOME and WANDB_API_KEY (if needed). You'll need to do a `huggingface-cli login` as well for Llama models. + +## Datasets + +Each class representing a Reinforcer DPO dataset is expected to have the following attributes: +1. `formatted_ds`: The dictionary of formatted datasets. This dictionary should contain `train` and `validation` splits, and each split should conform to the format described below. +2. `task_spec`: The `TaskDataSpec` for this dataset. This should specify the name you choose for this dataset. + +DPO datasets are expected to follow a specific format with three key fields: +- `prompt`: The input prompt/context +- `chosen_response`: The preferred/winning response +- `rejected_response`: The non-preferred/losing response + +[data/hf_datasets/helpsteer3.py](../../nemo_reinforcer/data/hf_datasets/helpsteer3.py) provides an example of how to format data for DPO: + +```python +def format_helpsteer3(data): + response_1 = data["response1"] + response_2 = data["response2"] + overall_preference = data["overall_preference"] + + return { + "prompt": data["context"], + "chosen_response": response_1 if overall_preference < 0 else response_2, + "rejected_response": response_2 if overall_preference < 0 else response_1, + } +``` + +We also provide a [DPODataset](../../nemo_reinforcer/data/hf_datasets/dpo.py) class that is compatible with jsonl-formatted preference datsets. This class assumes train and validation datasets have been split and processed into the expected format offline. The jsonl files should consist of examples with `prompt`, `chosen_response`, and `rejected_response` keys. + +## Adding Custom DPO Datasets + +Adding a new DPO dataset is straightforward. Your custom dataset class should: +1. Implement the required format conversion in the constructor +2. 
Set up the appropriate `task_spec` + +Here's a minimal example which simply re-keys an existing jsonl dataset: + +```python +from nemo_reinforcer.data.hf_datasets.interfaces import HfDataset +from nemo_reinforcer.data.interfaces import TaskDataSpec + +class CustomDPODataset: + def preprocess_dataset( + self, + data, + prompt_key: str = "context", + chosen_key: str = "chosen", + rejected_key: str = "rejected" + ): + + return { + "prompt": data[prompt_key], + "chosen_response": data[chosen_key], + "rejected_response": data[rejected_key], + } + + + def __init__( + self, + train_data_path: str, + val_data_path: str, + prompt_key: str, + chosen_key: str, + rejected_key: str, + ): + # Load and format your dataset + formatted_ds = { + "train": load_dataset("json", data_files=train_data_path, split="train"), + "validation": load_dataset("json", data_files=val_data_path, split="train"), + } + + self.formatted_ds = formatted_ds.map() + + # Initialize task spec with dataset name + self.task_spec = TaskDataSpec( + dataset_name="custom_dpo", + ) +``` + +## DPO-Specific Parameters + +The DPO implementation in Reinforcer supports several key parameters that can be adjusted: + +- `dpo.reference_policy_kl_penalty`: Controls the strength of the KL penalty term +- `dpo.preference_loss_weight`: Weight for the preference loss +- `dpo.sft_loss_weight`: Weight for the auxiliary SFT loss +- `dpo.preference_average_log_probs`: Whether to average log probabilities over tokens in the preference loss term +- `dpo.sft_average_log_probs`: Whether to average log probabilities over tokens in the SFT loss term + +These parameters can be adjusted in the config file or via command-line overrides to optimize training for your specific use case. From bbcba50fe083c3bda30e4dfbea60c7eefd3d29ed Mon Sep 17 00:00:00 2001 From: ashors1 Date: Mon, 14 Apr 2025 15:40:25 -0700 Subject: [PATCH 22/57] fix example Signed-off-by: ashors1 --- docs/guides/dpo.md | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/docs/guides/dpo.md b/docs/guides/dpo.md index dc053ad53c..b1fac0b5aa 100644 --- a/docs/guides/dpo.md +++ b/docs/guides/dpo.md @@ -63,6 +63,7 @@ Adding a new DPO dataset is straightforward. 
Your custom dataset class should: Here's a minimal example which simply re-keys an existing jsonl dataset: ```python +from datasets import load_dataset from nemo_reinforcer.data.hf_datasets.interfaces import HfDataset from nemo_reinforcer.data.interfaces import TaskDataSpec @@ -90,13 +91,23 @@ class CustomDPODataset: chosen_key: str, rejected_key: str, ): + # Load and format your dataset + fn_kwargs={ + "prompt_key": prompt_key, + "chosen_key": chosen_key, + "rejected_key": rejected_key + } formatted_ds = { - "train": load_dataset("json", data_files=train_data_path, split="train"), - "validation": load_dataset("json", data_files=val_data_path, split="train"), + "train": load_dataset("json", data_files=train_data_path, split="train").map( + self.preprocess_dataset, + fn_kwargs=fn_kwargs, + ), + "validation": load_dataset("json", data_files=val_data_path, split="train").map( + self.preprocess_dataset, + fn_kwargs=fn_kwargs, + ), } - - self.formatted_ds = formatted_ds.map() # Initialize task spec with dataset name self.task_spec = TaskDataSpec( From 5d58614da216f630e542c187b62ac30a398a5ce3 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Mon, 14 Apr 2025 16:20:40 -0700 Subject: [PATCH 23/57] cleanup and add docstrings Signed-off-by: ashors1 --- examples/run_dpo.py | 3 +- nemo_reinforcer/algorithms/loss_functions.py | 59 ++++++++++++++++++- nemo_reinforcer/data/datasets.py | 8 +-- nemo_reinforcer/data/hf_datasets/dpo.py | 27 +++++---- .../data/hf_datasets/helpsteer3.py | 2 + nemo_reinforcer/data/llm_message_utils.py | 8 ++- nemo_reinforcer/models/policy/hf_policy.py | 13 ++-- tests/functional/dpo.sh | 2 +- 8 files changed, 92 insertions(+), 30 deletions(-) diff --git a/examples/run_dpo.py b/examples/run_dpo.py index 018430b015..d7070ec5e8 100644 --- a/examples/run_dpo.py +++ b/examples/run_dpo.py @@ -20,6 +20,7 @@ from omegaconf import OmegaConf from nemo_reinforcer.algorithms.dpo import MasterConfig, dpo_train, setup +from nemo_reinforcer.algorithms.utils import get_tokenizer from nemo_reinforcer.distributed.virtual_cluster import init_ray from nemo_reinforcer.utils.config import load_config from nemo_reinforcer.utils.logger import get_next_experiment_dir @@ -139,7 +140,7 @@ def setup_data(data_config: DataConfig, policy_config: PolicyConfig): dpo_task_spec = data.task_spec - tokenizer = AutoTokenizer.from_pretrained(policy_config["model_name"]) + tokenizer = get_tokenizer(policy_config["model_name"]) train_dataset = AllTaskProcessedDataset( train_dataset, diff --git a/nemo_reinforcer/algorithms/loss_functions.py b/nemo_reinforcer/algorithms/loss_functions.py index c683a76703..7d83b909e3 100644 --- a/nemo_reinforcer/algorithms/loss_functions.py +++ b/nemo_reinforcer/algorithms/loss_functions.py @@ -148,6 +148,8 @@ def __call__( class NLLLoss(LossFunction): + """Negative Log Likelihood Loss function.""" + def __call__( self, next_token_logits: torch.Tensor, @@ -212,6 +214,61 @@ class DPOLossDataDict(TypedDict): class DPOLossFn(LossFunction): + """Direct Preference Optimization (DPO) loss function. + + This loss function implements the DPO algorithm as described in: + "Direct Preference Optimization: Your Language Model is Secretly a Reward Model" + (https://arxiv.org/abs/2305.18290) + + The loss combines two main components: + 1. Preference Loss: Optimizes the model to prefer chosen responses over rejected ones + 2. 
SFT Loss (optional): Auxiliary supervised fine-tuning loss on chosen responses + + The total loss is computed as: + L(Īø) = w_p * L_pref(Īø) + w_s * L_sft(Īø) + + where: + - w_p is the preference_loss_weight + - w_s is the sft_loss_weight + - L_pref(Īø) is the preference loss term + - L_sft(Īø) is the supervised fine-tuning loss term + + The preference loss term is computed as: + L_pref(Īø) = -E[log(σ(β * (r_chosen - r_rejected)))] + + where: + - σ is the sigmoid function + - β is the reference_policy_kl_penalty + - r_chosen and r_rejected are the rewards for chosen and rejected responses + - The rewards are computed as the sum of log probability differences between + the current policy and reference policy + + If preference_average_log_probs is True, the rewards are averaged over tokens: + r = (1/n) * Ī£_t (log Ļ€_Īø(a_t|s_t) - log Ļ€_ref(a_t|s_t)) + + Otherwise, the rewards are summed over tokens. + + The SFT loss term is a standard negative log likelihood loss on the chosen responses. + If sft_average_log_probs is True, the loss is averaged over tokens. + + Args: + cfg (DPOLossConfig): Configuration dictionary containing: + - reference_policy_kl_penalty (float): Strength of the KL penalty term (β) + - preference_loss_weight (float): Weight for the preference loss term (w_p) + - sft_loss_weight (float): Weight for the SFT loss term (w_s) + - preference_average_log_probs (bool): Whether to average log probs across tokens in preference loss + - sft_average_log_probs (bool): Whether to average log probs across tokens in SFT loss + + Returns: + Tuple[torch.Tensor, dict]: A tuple containing: + - The total loss value + - A dictionary with metrics including: + - loss: Total loss value + - sft_loss: SFT loss component + - preference_loss: Preference loss component + - accuracy: Fraction of examples where chosen response has higher reward + """ + def __init__(self, cfg: DPOLossConfig): self.reference_policy_kl_penalty = cfg["reference_policy_kl_penalty"] self.preference_loss_weight = cfg["preference_loss_weight"] @@ -226,7 +283,7 @@ def split_output_tensor(self, tensor: torch.Tensor): def preference_loss( self, next_token_logits: torch.Tensor, data: BatchedDataDict[DPOLossDataDict] ) -> torch.Tensor: - ## TODO: there's some duplicate code here with the NLLLoss function. We should refactor + ## TODO(@ashors): there's some duplicate code here with the NLLLoss function. We should refactor token_mask = data["token_mask"][:, 1:] sample_mask = data["sample_mask"] mask = token_mask * sample_mask.unsqueeze(-1) diff --git a/nemo_reinforcer/data/datasets.py b/nemo_reinforcer/data/datasets.py index b8688b022a..282d1b5a64 100644 --- a/nemo_reinforcer/data/datasets.py +++ b/nemo_reinforcer/data/datasets.py @@ -191,7 +191,7 @@ def dpo_collate_fn(data_batch: List[DatumSpec], tokenizer) -> BatchedDataDict: """Collate function for DPO training. This function separates the chosen and rejected responses to create - two examples per prompt. The chosen and rejected examples are concatenated + two examples per prompt. The chosen and rejected examples are interleaved along the batch dimension, resulting in a batch size of 2 * len(data_batch). 
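+
+    For example, with data_batch = [example_0, example_1] the rows of the
+    returned batch are ordered [chosen_0, rejected_0, chosen_1, rejected_1],
+    so chosen responses sit at even indices and rejected responses at odd
+    indices.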
""" message_log = [] @@ -211,9 +211,6 @@ def dpo_collate_fn(data_batch: List[DatumSpec], tokenizer) -> BatchedDataDict: length = torch.tensor(length) loss_multiplier = torch.tensor(loss_multiplier) - ## TODO - # extra_env_info = [datum_spec["extra_env_info"] for datum_spec in data_batch] - batch_max_length = torch.ones_like(length) * length.max() batch = BatchedDataDict( @@ -233,8 +230,7 @@ def dpo_collate_fn(data_batch: List[DatumSpec], tokenizer) -> BatchedDataDict: cat_and_padded, input_lengths = batched_message_log_to_flat_message( batch["message_log"], - ## TODO: update pad value - pad_value_dict={"token_ids": tokenizer.eos_token_id}, + pad_value_dict={"token_ids": tokenizer.pad_token_id}, ) train_data: BatchedDataDict = BatchedDataDict( diff --git a/nemo_reinforcer/data/hf_datasets/dpo.py b/nemo_reinforcer/data/hf_datasets/dpo.py index bbf0af3e65..03635374af 100644 --- a/nemo_reinforcer/data/hf_datasets/dpo.py +++ b/nemo_reinforcer/data/hf_datasets/dpo.py @@ -16,19 +16,24 @@ from nemo_reinforcer.data.hf_datasets.interfaces import HfDataset -## assumptions about DPO dataset: -## json files should have the following keys: -## "prompt" -## "chosen_response" -## "rejected_response" class DPODataset(HfDataset): - def __init__(self, train_data_path: str, val_data_path: str): - ## TODO: assuming for now that data has been split into train and val - ## as an offline preprocessing step + """Dataset class for Direct Preference Optimization (DPO) training. + + This class handles loading of preference data for DPO training. + The input JSON files should contain examples with the following structure: + { + "prompt": str, # The input prompt/context + "chosen_response": str, # The preferred/winning response + "rejected_response": str # The non-preferred/losing response + } + + Args: + train_data_path (str): Path to the JSON file containing training data + val_data_path (str): Path to the JSON file containing validation data - ## TODO: update the keys to match with what's expected from apply_chat_template - ## we need to do this outisde of the data class because we want to keep - ## chosen and rejected responses for a given prompt together when shuffling + """ + + def __init__(self, train_data_path: str, val_data_path: str): self.formatted_ds = { "train": load_dataset("json", data_files=train_data_path, split="train"), "validation": load_dataset("json", data_files=val_data_path, split="train"), diff --git a/nemo_reinforcer/data/hf_datasets/helpsteer3.py b/nemo_reinforcer/data/hf_datasets/helpsteer3.py index fc6b9fc4c4..7c877e2301 100644 --- a/nemo_reinforcer/data/hf_datasets/helpsteer3.py +++ b/nemo_reinforcer/data/hf_datasets/helpsteer3.py @@ -28,6 +28,8 @@ def format_helpsteer3(data): class HelpSteer3Dataset(HfDataset): + """HelpSteer3 preference dataset for DPO training.""" + def __init__(self): ds = load_dataset("nvidia/HelpSteer3", "preference") self.formatted_ds = ds.map(format_helpsteer3) diff --git a/nemo_reinforcer/data/llm_message_utils.py b/nemo_reinforcer/data/llm_message_utils.py index f4faf6d14b..9303e77875 100644 --- a/nemo_reinforcer/data/llm_message_utils.py +++ b/nemo_reinforcer/data/llm_message_utils.py @@ -130,7 +130,13 @@ def add_loss_mask_to_message_log( def add_dpo_loss_mask_to_message_log( message_log: LLMMessageLogType, ) -> None: - """Only unmask the final assistant message in the log.""" + """Add token-level loss masks to each message in a message log. 
+ + This function differs from add_loss_mask_to_message_log in that it only unmasks the final assistant message in the log. + + Args: + message_log (LLMMessageLogType): List of message dictionaries containing token IDs and metadata + """ for message in message_log: for i, sentence in enumerate(message): if i == len(message) - 1: diff --git a/nemo_reinforcer/models/policy/hf_policy.py b/nemo_reinforcer/models/policy/hf_policy.py index 039d433389..26e3d376cf 100644 --- a/nemo_reinforcer/models/policy/hf_policy.py +++ b/nemo_reinforcer/models/policy/hf_policy.py @@ -320,13 +320,7 @@ def train( # Backward pass - # Loss is accumulated across microbatches, so we need to scale by the number of microbatches - # loss = loss / num_microbatches - - ## TODO: improve this - ## loss = 0 indicates that there are no valid examples in the microbatch - ## we should probably use a reserved value here - #if loss != 0: + ## TODO(@ashors): improve this if not eval_mode: loss.backward() mb_losses.append(loss.item()) @@ -334,7 +328,9 @@ def train( # Clip gradients if not eval_mode: - torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0) + torch.nn.utils.clip_grad_norm_( + self.model.parameters(), max_norm=1.0 + ) # Update parameters self.optimizer.step() @@ -488,7 +484,6 @@ def get_reference_policy_logprobs( We use the convention that the logprob of the first token is 0 so that the sequence length is maintained. The logprob of input token i is specified at position i in the output logprobs tensor. """ - ## TODO: investigate this. This is super slow with self.use_reference_model(): reference_logprobs = self.get_logprobs(data, micro_batch_size) diff --git a/tests/functional/dpo.sh b/tests/functional/dpo.sh index 9c6dca6763..4ca9ac944b 100755 --- a/tests/functional/dpo.sh +++ b/tests/functional/dpo.sh @@ -20,7 +20,7 @@ mkdir -p $LOG_DIR cd $PROJECT_ROOT python -u $PROJECT_ROOT/examples/run_dpo.py \ cluster.gpus_per_node=2 \ - dpo.max_num_steps=10 \ + dpo.max_num_steps=5 \ dpo.val_batches=1 \ logger.tensorboard_enabled=true \ logger.log_dir=$LOG_DIR \ From d8b767a3384b44019c7fe9871539aaf711ede5d6 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Mon, 14 Apr 2025 17:17:11 -0700 Subject: [PATCH 24/57] add another test, clean up Signed-off-by: ashors1 --- examples/configs/dpo.yaml | 20 +++-- examples/run_dpo.py | 1 - nemo_reinforcer/algorithms/dpo.py | 5 +- tests/unit/algorithms/test_loss_functions.py | 81 +++++++++++++++++++- 4 files changed, 91 insertions(+), 16 deletions(-) mode change 100644 => 100755 examples/configs/dpo.yaml diff --git a/examples/configs/dpo.yaml b/examples/configs/dpo.yaml old mode 100644 new mode 100755 index 7a3d636031..2afe3debb3 --- a/examples/configs/dpo.yaml +++ b/examples/configs/dpo.yaml @@ -2,7 +2,7 @@ dpo: max_num_epochs: 1 max_num_steps: 500 - val_period: 5 + val_period: 10 val_batches: 8 val_global_batch_size: 8 val_micro_batch_size: 1 @@ -25,13 +25,13 @@ checkpointing: metric_name: "val_loss" higher_is_better: false keep_top_k: 3 - save_period: 10000 + save_period: 50 policy: - model_name: "meta-llama/Meta-Llama-3.1-8B" - tokenizer_name: "meta-llama/Meta-Llama-3.1-8B" - train_global_batch_size: 256 - train_micro_batch_size: 1 + model_name: "meta-llama/Llama-3.2-1B-Instruct" + tokenizer_name: "meta-llama/Llama-3.2-1B-Instruct" + train_global_batch_size: 16 + train_micro_batch_size: 2 logprob_batch_size: ${policy.train_micro_batch_size} max_total_sequence_length: 1024 precision: "float32" @@ -41,16 +41,14 @@ policy: optimizer: name: "torch.optim.AdamW" kwargs: - lr: 5.0e-6 + 
lr: 5.0e-7 weight_decay: 0.1 betas: [0.9, 0.98] eps: 1e-5 data: + dataset_name: "HelpSteer3" max_input_seq_length: ${policy.max_total_sequence_length} - train_data_path: "/path/to/train.jsonl" - val_data_path: "/path/to/val.jsonl" - logger: log_dir: "logs" # Base directory for all logs wandb_enabled: false # Make sure you do a ``wandb login [Your API key]'' before running @@ -66,5 +64,5 @@ logger: flush_interval: 10 # How often to flush GPU usage metrics to the loggers (in seconds) cluster: - gpus_per_node: 8 + gpus_per_node: 1 num_nodes: 1 diff --git a/examples/run_dpo.py b/examples/run_dpo.py index d7070ec5e8..03a23281a8 100644 --- a/examples/run_dpo.py +++ b/examples/run_dpo.py @@ -141,7 +141,6 @@ def setup_data(data_config: DataConfig, policy_config: PolicyConfig): dpo_task_spec = data.task_spec tokenizer = get_tokenizer(policy_config["model_name"]) - train_dataset = AllTaskProcessedDataset( train_dataset, tokenizer, diff --git a/nemo_reinforcer/algorithms/dpo.py b/nemo_reinforcer/algorithms/dpo.py index bdd592c97f..a790d543ec 100644 --- a/nemo_reinforcer/algorithms/dpo.py +++ b/nemo_reinforcer/algorithms/dpo.py @@ -17,12 +17,11 @@ import numpy as np import torch -from transformers import AutoTokenizer from torchdata.stateful_dataloader import StatefulDataLoader from nemo_reinforcer.algorithms.loss_functions import ( DPOLossFn, ) -from nemo_reinforcer.algorithms.utils import set_seed +from nemo_reinforcer.algorithms.utils import set_seed, get_tokenizer from nemo_reinforcer.data import DataConfig from nemo_reinforcer.data.datasets import AllTaskProcessedDataset, dpo_collate_fn from nemo_reinforcer.data.interfaces import TaskDataSpec @@ -145,7 +144,7 @@ def setup( # Data # ========================== ## TODO: clean up - tokenizer = AutoTokenizer.from_pretrained(policy_config["model_name"]) + tokenizer = get_tokenizer(policy_config["model_name"]) train_dataloader = StatefulDataLoader( train_dataset, batch_size=policy_config["train_global_batch_size"], diff --git a/tests/unit/algorithms/test_loss_functions.py b/tests/unit/algorithms/test_loss_functions.py index c8e989cd88..5705a27260 100644 --- a/tests/unit/algorithms/test_loss_functions.py +++ b/tests/unit/algorithms/test_loss_functions.py @@ -150,7 +150,86 @@ def test_dpo_loss(): 0.5 * expected_sft_loss + expected_preference_loss, ) - ## TODO: test with a batch of varying sequence lengths + +def test_dpo_loss_varying_sequence_lengths(): + """Test DPO loss with varying sequence lengths and preference_average_log_probs=True.""" + # Create DPO loss function with preference_average_log_probs=True + dpo_loss_fn_no_avg = DPOLossFn( + { + "reference_policy_kl_penalty": 0.1, + "preference_loss_weight": 1.0, + "sft_loss_weight": 0.5, + "preference_average_log_probs": False, + "sft_average_log_probs": False, + } + ) + dpo_loss_fn_avg = DPOLossFn( + { + "reference_policy_kl_penalty": 0.1, + "preference_loss_weight": 1.0, + "sft_loss_weight": 0.5, + "preference_average_log_probs": True, + "sft_average_log_probs": True, + } + ) + + # Create test data with varying sequence lengths + # Batch size 4 (2 pairs of chosen/rejected) + # Sequence lengths: [3, 5, 4, 6] + batch_size = 4 + max_seq_len = 6 + vocab_size = 10 + + # Create input_ids with varying lengths + input_ids = torch.zeros((batch_size, max_seq_len), dtype=torch.long).to("cuda") + input_ids[0, :3] = torch.arange(3) # length 3 + input_ids[1, :5] = torch.arange(5) # length 5 + input_ids[2, :4] = torch.arange(4) # length 4 + input_ids[3, :6] = torch.arange(6) # length 6 + + # Create token masks 
based on sequence lengths + token_mask = torch.zeros((batch_size, max_seq_len)).to("cuda") + token_mask[0, :3] = 1.0 + token_mask[1, :5] = 1.0 + token_mask[2, :4] = 1.0 + token_mask[3, :6] = 1.0 + + # Create sample mask (all valid) + sample_mask = torch.ones(batch_size).to("cuda") + + # Create reference policy logprobs + # Make chosen responses have slightly higher logprobs than rejected + reference_policy_logprobs = torch.zeros((batch_size, max_seq_len)).to("cuda") + # Create next token logits + next_token_logits = torch.zeros((batch_size, max_seq_len, vocab_size)).to("cuda") + + # Create batched data dictionary + data = BatchedDataDict( + { + "input_ids": input_ids, + "reference_policy_logprobs": reference_policy_logprobs, + "token_mask": token_mask, + "sample_mask": sample_mask, + } + ) + + # Compute loss + loss, metrics = dpo_loss_fn_no_avg(next_token_logits, data) + loss_avg, metrics_avg = dpo_loss_fn_avg(next_token_logits, data) + + num_unmasked_tokens = token_mask[:, 1:][::2].sum().item() + logprobs = torch.nn.functional.log_softmax(next_token_logits[:, 1:], dim=-1) + token_logprobs = logprobs.gather( + dim=-1, index=input_ids[:, 1:].unsqueeze(-1) + ).squeeze(-1) + expected_per_token_sft_loss = -(token_logprobs[::2] * token_mask[:, 1:][::2]) + ## sum across tokens in an example, average across examples + expected_sft_loss_no_avg = expected_per_token_sft_loss.sum(-1).mean() + ## average across tokens in an example, then average across examples + expected_sft_loss_avg = expected_per_token_sft_loss.sum() / num_unmasked_tokens + + assert torch.isclose(torch.tensor(metrics["sft_loss"]), expected_sft_loss_no_avg) + assert torch.isclose(torch.tensor(metrics_avg["sft_loss"]), expected_sft_loss_avg) def _setup_clipped_pg_test_data(batch_size=1, seq_len=4, vocab_size=8, device="cuda"): From 7d0312701be43442b778dc937b7bb6a264eed8bd Mon Sep 17 00:00:00 2001 From: ashors1 Date: Mon, 14 Apr 2025 18:23:14 -0700 Subject: [PATCH 25/57] update config Signed-off-by: ashors1 --- examples/configs/dpo.yaml | 15 ++++++++++++++- nemo_reinforcer/algorithms/dpo.py | 8 ++++---- nemo_reinforcer/algorithms/loss_functions.py | 20 ++++++++++++++++---- nemo_reinforcer/data/datasets.py | 1 - nemo_reinforcer/models/policy/hf_policy.py | 3 --- tests/functional/dpo.sh | 5 ++--- 6 files changed, 36 insertions(+), 16 deletions(-) diff --git a/examples/configs/dpo.yaml b/examples/configs/dpo.yaml index 2afe3debb3..9cfa1f865a 100755 --- a/examples/configs/dpo.yaml +++ b/examples/configs/dpo.yaml @@ -32,7 +32,8 @@ policy: tokenizer_name: "meta-llama/Llama-3.2-1B-Instruct" train_global_batch_size: 16 train_micro_batch_size: 2 - logprob_batch_size: ${policy.train_micro_batch_size} + ## TODO(@ashors) support + #logprob_batch_size: ${policy.train_micro_batch_size} max_total_sequence_length: 1024 precision: "float32" fsdp_offload_enabled: false @@ -46,6 +47,18 @@ policy: betas: [0.9, 0.98] eps: 1e-5 + scheduler: + - name: "torch.optim.lr_scheduler.LinearLR" + kwargs: + start_factor: 0.01 + end_factor: 1.0 + total_iters: 50 + - name: "torch.optim.lr_scheduler.ConstantLR" + kwargs: + factor: 1.0 + total_iters: 10000000000 + - milestones: [50] + data: dataset_name: "HelpSteer3" max_input_seq_length: ${policy.max_total_sequence_length} diff --git a/nemo_reinforcer/algorithms/dpo.py b/nemo_reinforcer/algorithms/dpo.py index a790d543ec..455f9ba117 100644 --- a/nemo_reinforcer/algorithms/dpo.py +++ b/nemo_reinforcer/algorithms/dpo.py @@ -188,10 +188,10 @@ def setup( policy = HfPolicy( cluster=cluster, config=policy_config, - 
weights_path=Path(last_checkpoint_path) / "policy.pt" + weights_path=Path(last_checkpoint_path) / "policy" / "weights" if last_checkpoint_path else None, - optimizer_path=Path(last_checkpoint_path) / "policy_optimizer.pt" + optimizer_path=Path(last_checkpoint_path) / "policy" / "optimizer" if last_checkpoint_path else None, init_optimizer=True, @@ -226,9 +226,10 @@ def augment_dataloader(dataloader, policy, master_config): ## append ref policy logprobs to batch logprobs = policy.get_reference_policy_logprobs( batch, - ## TODO: make more robust micro_batch_size=master_config["policy"]["train_micro_batch_size"] * 2, )["reference_logprobs"].to("cpu") + ## want logprobs for batch to correspond to the log probabilities of the next tokens + ## so we roll the logprobs to the left by one batch["reference_policy_logprobs"] = torch.roll(logprobs, -1, dims=-1) yield batch @@ -277,7 +278,6 @@ def validate( mbs=val_mbs * 2, ) - ## TODO: this should already be averaged across microbatches.. why isn't it? val_metrics = { "loss": val_results["loss"].numpy(), } diff --git a/nemo_reinforcer/algorithms/loss_functions.py b/nemo_reinforcer/algorithms/loss_functions.py index 7d83b909e3..d52b3e7b38 100644 --- a/nemo_reinforcer/algorithms/loss_functions.py +++ b/nemo_reinforcer/algorithms/loss_functions.py @@ -308,9 +308,14 @@ def preference_loss( rewards_chosen, rewards_rejected = self.split_output_tensor(rewards) rewards_delta = rewards_chosen - rewards_rejected - return -torch.nn.functional.logsigmoid( - self.reference_policy_kl_penalty * rewards_delta - ).mean(0), (rewards_chosen > rewards_rejected).float().mean(0) + return ( + -torch.nn.functional.logsigmoid( + self.reference_policy_kl_penalty * rewards_delta + ).mean(0), + (rewards_chosen > rewards_rejected).float().mean(0), + rewards_chosen.mean(), + rewards_rejected.mean(), + ) def __call__( self, next_token_logits: torch.Tensor, data: BatchedDataDict[DPOLossDataDict] @@ -326,7 +331,12 @@ def __call__( sft_loss_chosen, sft_loss_rejected = self.split_output_tensor(sft_loss) sft_loss_chosen = sft_loss_chosen.mean(0) - preference_loss, accuracy = self.preference_loss(next_token_logits, data) + ( + preference_loss, + accuracy, + rewards_chosen_mean, + rewards_rejected_mean, + ) = self.preference_loss(next_token_logits, data) dpo_loss = ( self.sft_loss_weight * sft_loss_chosen @@ -338,4 +348,6 @@ def __call__( "sft_loss": sft_loss_chosen.item(), "preference_loss": preference_loss.item(), "accuracy": accuracy.item(), + "rewards_chosen_mean": rewards_chosen_mean.item(), + "rewards_rejected_mean": rewards_rejected_mean.item(), } diff --git a/nemo_reinforcer/data/datasets.py b/nemo_reinforcer/data/datasets.py index 282d1b5a64..a0335ad497 100644 --- a/nemo_reinforcer/data/datasets.py +++ b/nemo_reinforcer/data/datasets.py @@ -217,7 +217,6 @@ def dpo_collate_fn(data_batch: List[DatumSpec], tokenizer) -> BatchedDataDict: message_log=message_log, length=length, loss_multiplier=loss_multiplier, - # extra_env_info=extra_env_info, task_name=task_names, idx=idx, batch_max_length=batch_max_length, diff --git a/nemo_reinforcer/models/policy/hf_policy.py b/nemo_reinforcer/models/policy/hf_policy.py index e4a5382556..6d5fb13101 100644 --- a/nemo_reinforcer/models/policy/hf_policy.py +++ b/nemo_reinforcer/models/policy/hf_policy.py @@ -449,7 +449,6 @@ def use_reference_model(self): On exit: Restores original references and re-flips cuda/cpu """ - # yield try: # Save original references original_model = self.model @@ -828,14 +827,12 @@ def offload_after_refit(self): ) def 
move_to_cpu(self, model): - ## this is the slowest part for param in model.parameters(): param.data = param.data.to("cpu", non_blocking=True, copy=True) for buffer in model.buffers(): buffer.data = buffer.data.to("cpu", non_blocking=True, copy=True) - ## commenting this out improves perf by 3x if hasattr(model, "_fsdp_wrapped_module"): model._fsdp_wrapped_module.to("cpu") return model diff --git a/tests/functional/dpo.sh b/tests/functional/dpo.sh index 4ca9ac944b..7d03cf32d4 100755 --- a/tests/functional/dpo.sh +++ b/tests/functional/dpo.sh @@ -32,7 +32,6 @@ python -u $PROJECT_ROOT/examples/run_dpo.py \ cd $SCRIPT_DIR python json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS -## TODO: add loss check once config is finalized -#python check_metrics.py $JSON_METRICS \ -# 'data["train/loss"]["9"] < 1500' \ +python check_metrics.py $JSON_METRICS \ + 'data["train/loss"]["4"] < 0.694' \ From 18faa44ebd1c2f267f2d948658ddebbbf93826eb Mon Sep 17 00:00:00 2001 From: ashors1 Date: Mon, 14 Apr 2025 21:09:02 -0700 Subject: [PATCH 26/57] rename test Signed-off-by: ashors1 --- tests/unit/data/hf_datasets/{test_dpo.py => test_dpo_dataset.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/unit/data/hf_datasets/{test_dpo.py => test_dpo_dataset.py} (100%) diff --git a/tests/unit/data/hf_datasets/test_dpo.py b/tests/unit/data/hf_datasets/test_dpo_dataset.py similarity index 100% rename from tests/unit/data/hf_datasets/test_dpo.py rename to tests/unit/data/hf_datasets/test_dpo_dataset.py From 7d97f30a5412a259cdda324cb414d019d2df54d6 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Mon, 14 Apr 2025 22:25:07 -0700 Subject: [PATCH 27/57] fix test Signed-off-by: ashors1 --- nemo_reinforcer/algorithms/dpo.py | 1 + tests/unit/data/test_datasets.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/nemo_reinforcer/algorithms/dpo.py b/nemo_reinforcer/algorithms/dpo.py index 455f9ba117..1248012dba 100644 --- a/nemo_reinforcer/algorithms/dpo.py +++ b/nemo_reinforcer/algorithms/dpo.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import os from functools import partial from typing import Optional, Tuple, TypedDict from tqdm import tqdm diff --git a/tests/unit/data/test_datasets.py b/tests/unit/data/test_datasets.py index aafbad9405..ee80c5e934 100644 --- a/tests/unit/data/test_datasets.py +++ b/tests/unit/data/test_datasets.py @@ -25,7 +25,7 @@ def test_dpo_collate_fn(): """Test that dpo_collate_fn correctly processes DPO training data.""" # Create mock tokenizer mock_tokenizer = MagicMock() - mock_tokenizer.eos_token_id = 0 + mock_tokenizer.pad_token_id = 0 # Create test data with varying sequence lengths data_batch = [ From 7a1e1901fc90ed71638c7f6ccd17461f2484758d Mon Sep 17 00:00:00 2001 From: ashors1 Date: Tue, 15 Apr 2025 10:08:52 -0700 Subject: [PATCH 28/57] address comments and update config Signed-off-by: ashors1 --- docs/guides/dpo.md | 4 ++++ examples/configs/dpo.yaml | 16 ++++++++-------- examples/run_dpo.py | 4 ++++ nemo_reinforcer/algorithms/dpo.py | 12 ++++-------- nemo_reinforcer/data/hf_datasets/helpsteer3.py | 2 +- tests/functional/dpo.sh | 2 +- 6 files changed, 22 insertions(+), 18 deletions(-) diff --git a/docs/guides/dpo.md b/docs/guides/dpo.md index b1fac0b5aa..1cc9dc238d 100644 --- a/docs/guides/dpo.md +++ b/docs/guides/dpo.md @@ -1,5 +1,9 @@ # Direct Preference Optimization in Reinforcer +[Direct Preference Optimization (DPO)](https://arxiv.org/pdf/2305.18290) is an RL-free alignment algorithm that operates on preference data. Given a prompt and a pair of chosen and rejected responses, DPO aims +to increase the probability of the chosen response and decrease the probability of the rejected response relative to a frozen reference model. The actor is initialized using the reference model. For more details, refer to the +[DPO paper](https://arxiv.org/pdf/2305.18290). + ## Launch a DPO Run The script [examples/run_dpo.py](../../examples/run_dpo.py) can be used to launch a DPO experiment. This script can either be launched locally or via Slurm. For details on how to set up Ray and launch a job using Slurm, refer to the [cluster documentation](../cluster.md). 
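As a companion to the DPO description added to this guide, the objective can be summarized in a few lines. The sketch below is illustrative only and is not code from this patch series — the function name, the interleaved `[chosen, rejected]` batch layout, and the dummy values are assumptions — but it mirrors the quantities `DPOLossFn` works with: per-sequence policy/reference log-prob differences scaled by `reference_policy_kl_penalty`.

```python
# Minimal, illustrative DPO preference loss (editor's sketch, not the repository implementation).
import torch
import torch.nn.functional as F

def dpo_preference_loss(
    policy_logprobs: torch.Tensor,     # [2B] summed token log-probs under the policy
    reference_logprobs: torch.Tensor,  # [2B] summed token log-probs under the frozen reference
    beta: float = 0.05,                # plays the role of dpo.reference_policy_kl_penalty
) -> torch.Tensor:
    # Implicit "reward": how far the policy has moved from the reference on each sequence.
    rewards = policy_logprobs - reference_logprobs
    # Pairs are assumed interleaved as [chosen_0, rejected_0, chosen_1, rejected_1, ...].
    chosen, rejected = rewards[0::2], rewards[1::2]
    # Push the chosen reward above the rejected one; -logsigmoid penalizes small or negative margins.
    return -F.logsigmoid(beta * (chosen - rejected)).mean()

# Toy usage with two preference pairs (dummy numbers).
policy_lp = torch.tensor([-12.0, -15.0, -9.5, -11.0])
reference_lp = torch.tensor([-13.0, -14.0, -10.0, -10.5])
print(dpo_preference_loss(policy_lp, reference_lp))  # scalar loss; shrinks as chosen - rejected margin grows
```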
diff --git a/examples/configs/dpo.yaml b/examples/configs/dpo.yaml index 9cfa1f865a..b9d3b68587 100755 --- a/examples/configs/dpo.yaml +++ b/examples/configs/dpo.yaml @@ -1,15 +1,15 @@ # DPO Algorithm Configuration dpo: max_num_epochs: 1 - max_num_steps: 500 - val_period: 10 + max_num_steps: 250 + val_period: 25 val_batches: 8 val_global_batch_size: 8 val_micro_batch_size: 1 val_at_start: true seed: 42 - reference_policy_kl_penalty: 0.1 + reference_policy_kl_penalty: 0.05 preference_average_log_probs: False # whether normalizing log probs according to the sequence length in preference_loss sft_average_log_probs: ${.preference_average_log_probs} # whether normalizing log probs according to the sequence length in sft_loss @@ -30,7 +30,7 @@ checkpointing: policy: model_name: "meta-llama/Llama-3.2-1B-Instruct" tokenizer_name: "meta-llama/Llama-3.2-1B-Instruct" - train_global_batch_size: 16 + train_global_batch_size: 128 train_micro_batch_size: 2 ## TODO(@ashors) support #logprob_batch_size: ${policy.train_micro_batch_size} @@ -42,7 +42,7 @@ policy: optimizer: name: "torch.optim.AdamW" kwargs: - lr: 5.0e-7 + lr: 5.0e-6 weight_decay: 0.1 betas: [0.9, 0.98] eps: 1e-5 @@ -50,14 +50,14 @@ policy: scheduler: - name: "torch.optim.lr_scheduler.LinearLR" kwargs: - start_factor: 0.01 + start_factor: 0.1 end_factor: 1.0 - total_iters: 50 + total_iters: 20 - name: "torch.optim.lr_scheduler.ConstantLR" kwargs: factor: 1.0 total_iters: 10000000000 - - milestones: [50] + - milestones: [20] data: dataset_name: "HelpSteer3" diff --git a/examples/run_dpo.py b/examples/run_dpo.py index 03a23281a8..1a71deb466 100644 --- a/examples/run_dpo.py +++ b/examples/run_dpo.py @@ -15,6 +15,7 @@ import argparse import os import pprint +import warnings from typing import Dict, Any from omegaconf import OmegaConf @@ -102,6 +103,9 @@ def dpo_preprocessor( loss_multiplier = 1.0 if max(length_chosen, length_rejected) > max_seq_length: + warnings.warn( + f"Sequence length {max(length_chosen, length_rejected)} exceeds max_seq_length {max_seq_length}. Ignoring example." 
+ ) # make smaller and mask out for message in message_log_chosen: message["token_ids"] = message["token_ids"][ diff --git a/nemo_reinforcer/algorithms/dpo.py b/nemo_reinforcer/algorithms/dpo.py index 1248012dba..f9206782ee 100644 --- a/nemo_reinforcer/algorithms/dpo.py +++ b/nemo_reinforcer/algorithms/dpo.py @@ -67,6 +67,7 @@ class DPOConfig(TypedDict): preference_average_log_probs: bool sft_average_log_probs: bool ## TODO(@ashors) support other loss functions + ## https://github.com/NVIDIA/reinforcer/issues/193 # preference_loss: str # gt_reward_scale: float preference_loss_weight: float @@ -87,7 +88,7 @@ class MasterConfig(TypedDict): # ======================================================= def setup( master_config: MasterConfig, - train_dataset: AllTaskProcessedDataset, ## TODO: figure out dataset stuff for DPO + train_dataset: AllTaskProcessedDataset, val_dataset: AllTaskProcessedDataset, ) -> Tuple[ HfPolicy, @@ -249,7 +250,6 @@ def validate( loss_fn, step: int, master_config: MasterConfig, - dpo_task_spec: TaskDataSpec, val_batches: int, val_batch_size: int, val_mbs: int, @@ -264,9 +264,6 @@ def validate( with timer.time("total_validation_time"): print(f"ā–¶ Starting validation at step {step}...") - # Show a progress indicator for validation - # val_total = len(val_dataloader) - for batch_idx, val_batch in enumerate( augment_dataloader(val_dataloader, policy, master_config) ): @@ -318,7 +315,6 @@ def dpo_train( loss_fn, master_config, logger, - dpo_task_spec, ## TODO: do we need? checkpointer, dpo_save_state, ): @@ -351,7 +347,6 @@ def dpo_train( loss_fn, step=0, master_config=master_config, - dpo_task_spec=dpo_task_spec, val_batches=dpo_config["val_batches"], val_batch_size=dpo_config["val_global_batch_size"], val_mbs=dpo_config["val_micro_batch_size"], @@ -376,6 +371,8 @@ def dpo_train( batch, loss_fn, eval_mode=False, + ## NOTE: we double the batch size here because each preference example corresponds to a pair of + ## examples, chosen and rejected, and the pair needs to be processed as part of the same microbatch. 
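+                ## Editor's illustration (assumed values, not part of the original patch): with
+                ## train_global_batch_size=4 preference pairs, dpo_collate_fn yields 8 interleaved rows
+                ## [chosen_0, rejected_0, ..., chosen_3, rejected_3], so the policy is called with gbs=8
+                ## and a doubled mbs, keeping each chosen/rejected pair inside one microbatch.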
gbs=master_config["policy"]["train_global_batch_size"] * 2, mbs=master_config["policy"]["train_micro_batch_size"] * 2, ) @@ -389,7 +386,6 @@ def dpo_train( loss_fn, step=total_steps + 1, master_config=master_config, - dpo_task_spec=dpo_task_spec, val_batches=dpo_config["val_batches"], val_batch_size=dpo_config["val_global_batch_size"], val_mbs=dpo_config["val_micro_batch_size"], diff --git a/nemo_reinforcer/data/hf_datasets/helpsteer3.py b/nemo_reinforcer/data/hf_datasets/helpsteer3.py index 7c877e2301..141132bbb9 100644 --- a/nemo_reinforcer/data/hf_datasets/helpsteer3.py +++ b/nemo_reinforcer/data/hf_datasets/helpsteer3.py @@ -22,7 +22,7 @@ def format_helpsteer3(data): return { "prompt": data["context"], - "chosen_response": response_1 if overall_preference < 0 else response_2, + "chosen_response": response_1 if overall_preference <= 0 else response_2, "rejected_response": response_2 if overall_preference < 0 else response_1, } diff --git a/tests/functional/dpo.sh b/tests/functional/dpo.sh index 7d03cf32d4..3cb78516ea 100755 --- a/tests/functional/dpo.sh +++ b/tests/functional/dpo.sh @@ -20,7 +20,7 @@ mkdir -p $LOG_DIR cd $PROJECT_ROOT python -u $PROJECT_ROOT/examples/run_dpo.py \ cluster.gpus_per_node=2 \ - dpo.max_num_steps=5 \ + dpo.max_num_steps=3 \ dpo.val_batches=1 \ logger.tensorboard_enabled=true \ logger.log_dir=$LOG_DIR \ From 0f70438a54f760e4fcdcc378cc2e72193d6ee681 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Tue, 15 Apr 2025 10:19:20 -0700 Subject: [PATCH 29/57] add one more unit test Signed-off-by: ashors1 --- tests/unit/data/hf_datasets/test_helpsteer.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/unit/data/hf_datasets/test_helpsteer.py b/tests/unit/data/hf_datasets/test_helpsteer.py index b139671dad..2b3a571881 100644 --- a/tests/unit/data/hf_datasets/test_helpsteer.py +++ b/tests/unit/data/hf_datasets/test_helpsteer.py @@ -45,6 +45,20 @@ def test_format_helpsteer3(): assert result2["chosen_response"] == "The capital of France is Paris." assert result2["rejected_response"] == "The capital of France is London." + # Test case 3: no preference (overall_preference = 0) + data3 = { + "context": "What is the weather like?", + "response1": "It's sunny today.", + "response2": "The weather is sunny.", + "overall_preference": 0, + } + result3 = format_helpsteer3(data3) + assert result3["prompt"] == "What is the weather like?" + # When preference is 0, neither response is preferred, so + # response 1 is used for both chosen and rejected + assert result3["chosen_response"] == "It's sunny today." + assert result3["rejected_response"] == "It's sunny today." + def test_helpsteer3_dataset_initialization(): """Test that HelpSteer3Dataset initializes correctly.""" From 21fd4383a5014da8e2c14acc2d8ab63fbf853b81 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Tue, 15 Apr 2025 10:23:13 -0700 Subject: [PATCH 30/57] minor readme update Signed-off-by: ashors1 --- README.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/README.md b/README.md index f9ba6e1db4..202bf82fc5 100644 --- a/README.md +++ b/README.md @@ -83,8 +83,7 @@ If you have access to more GPUs, you can update the experiment accordingly. 
To r ```sh uv run python examples/run_sft.py \ policy.model_name="meta-llama/Meta-Llama-3-8B" \ - policy.train_global_batch_size=128 \ - sft.val_global_batch_size=128 \ + policy.train_global_batch_size=256 \ cluster.gpus_per_node=8 ``` From 398b17b0b0adfb22cdaa6906278c88447186912d Mon Sep 17 00:00:00 2001 From: ashors1 Date: Tue, 15 Apr 2025 10:28:34 -0700 Subject: [PATCH 31/57] add note on gbs to config Signed-off-by: ashors1 --- examples/configs/dpo.yaml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/examples/configs/dpo.yaml b/examples/configs/dpo.yaml index b9d3b68587..d6b07c9f6b 100755 --- a/examples/configs/dpo.yaml +++ b/examples/configs/dpo.yaml @@ -30,8 +30,13 @@ checkpointing: policy: model_name: "meta-llama/Llama-3.2-1B-Instruct" tokenizer_name: "meta-llama/Llama-3.2-1B-Instruct" + + # number of preference samples per batch + # each preference sample corresponds to a pair of chosen and rejected responses + # so the actual batch size processed by the model is train_global_batch_size * 2 train_global_batch_size: 128 train_micro_batch_size: 2 + ## TODO(@ashors) support #logprob_batch_size: ${policy.train_micro_batch_size} max_total_sequence_length: 1024 From 3f5276a0598ee9b6db45ee59b12072ce42542ea1 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Tue, 15 Apr 2025 10:51:57 -0700 Subject: [PATCH 32/57] fix functional test Signed-off-by: ashors1 --- tests/functional/dpo.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/functional/dpo.sh b/tests/functional/dpo.sh index 3cb78516ea..1431f17e61 100755 --- a/tests/functional/dpo.sh +++ b/tests/functional/dpo.sh @@ -33,5 +33,5 @@ cd $SCRIPT_DIR python json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS python check_metrics.py $JSON_METRICS \ - 'data["train/loss"]["4"] < 0.694' \ + 'data["train/loss"]["2"] < 0.694' \ From 6e870a196922016f3ae8eaf4e2bb3e28132b776d Mon Sep 17 00:00:00 2001 From: ashors1 Date: Tue, 15 Apr 2025 12:32:29 -0700 Subject: [PATCH 33/57] small fixes Signed-off-by: ashors1 --- examples/run_dpo.py | 1 - nemo_reinforcer/algorithms/dpo.py | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/run_dpo.py b/examples/run_dpo.py index 1a71deb466..8a1312d1b6 100644 --- a/examples/run_dpo.py +++ b/examples/run_dpo.py @@ -217,7 +217,6 @@ def main(): loss_fn, master_config, logger, - dpo_task_spec, checkpointer, dpo_save_state, ) diff --git a/nemo_reinforcer/algorithms/dpo.py b/nemo_reinforcer/algorithms/dpo.py index f9206782ee..573e55aac3 100644 --- a/nemo_reinforcer/algorithms/dpo.py +++ b/nemo_reinforcer/algorithms/dpo.py @@ -13,6 +13,7 @@ # limitations under the License. 
import os from functools import partial +from pathlib import Path from typing import Optional, Tuple, TypedDict from tqdm import tqdm From 38e52642b902a3ecbd638f0d192da5e913a238c5 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Wed, 16 Apr 2025 08:40:33 -0700 Subject: [PATCH 34/57] decrease num steps Signed-off-by: ashors1 --- examples/configs/dpo.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/configs/dpo.yaml b/examples/configs/dpo.yaml index d6b07c9f6b..cf71d2c779 100755 --- a/examples/configs/dpo.yaml +++ b/examples/configs/dpo.yaml @@ -1,7 +1,7 @@ # DPO Algorithm Configuration dpo: max_num_epochs: 1 - max_num_steps: 250 + max_num_steps: 200 val_period: 25 val_batches: 8 val_global_batch_size: 8 From e7deda81a0c5dd924e6f3e7164f7e6bd58dcff23 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Wed, 16 Apr 2025 14:21:44 -0700 Subject: [PATCH 35/57] fix DPO validation and correctly handle samples that are longer than max_seq_length Signed-off-by: ashors1 --- nemo_reinforcer/algorithms/dpo.py | 48 ++++++++++++++------ nemo_reinforcer/algorithms/loss_functions.py | 33 ++++++++++---- nemo_reinforcer/algorithms/sft.py | 2 +- nemo_reinforcer/models/policy/hf_policy.py | 12 ++--- 4 files changed, 66 insertions(+), 29 deletions(-) diff --git a/nemo_reinforcer/algorithms/dpo.py b/nemo_reinforcer/algorithms/dpo.py index 573e55aac3..41ec26cd6d 100644 --- a/nemo_reinforcer/algorithms/dpo.py +++ b/nemo_reinforcer/algorithms/dpo.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +import warnings +from collections import defaultdict from functools import partial from pathlib import Path from typing import Optional, Tuple, TypedDict @@ -265,6 +267,8 @@ def validate( with timer.time("total_validation_time"): print(f"ā–¶ Starting validation at step {step}...") + val_metrics = defaultdict(lambda: 0.0) + num_valid_batches = 0 for batch_idx, val_batch in enumerate( augment_dataloader(val_dataloader, policy, master_config) ): @@ -277,14 +281,23 @@ def validate( mbs=val_mbs * 2, ) - val_metrics = { - "loss": val_results["loss"].numpy(), - } - val_metrics.update(val_results["all_mb_metrics"]) - val_metrics = {k: np.mean(v).item() for k, v in val_metrics.items()} - if val_batches > 0 and batch_idx >= val_batches: + if len(val_results["all_mb_metrics"]) == 0: + warnings.warn( + "No validation metrics were collected for this batch." + " This is likely because there were no valid samples." + ) + + else: + for k, v in val_results["all_mb_metrics"].items(): + val_metrics[k] += np.mean(v).item() + num_valid_batches += 1 + + if val_batches > 0 and batch_idx >= val_batches - 1: break + for k, v in val_metrics.items(): + val_metrics[k] /= num_valid_batches + # Calculate validation metrics policy.prepare_for_training() @@ -292,15 +305,22 @@ def validate( timing_metrics = timer.get_timing_metrics(reduction_op="sum") validation_time = timing_metrics.get("total_validation_time", 0) - # Print summary of validation results - print("\nšŸ“Š Validation Results:") - print(f" • Validation loss: {float(val_metrics['loss']):.4f}") - print(f" • Validation accuracy: {float(val_metrics['accuracy']):.4f}") + if len(val_metrics) == 0: + warnings.warn( + "No validation metrics were collected." + " This is likely because there were no valid samples in the validation set." 
+ ) - # Print timing information - print("\n ā±ļø Validation Timing:") - validation_time = timing_metrics.get("total_validation_time", 0) - print(f" • Total validation time: {validation_time:.2f}s") + else: + # Print summary of validation results + print("\nšŸ“Š Validation Results:") + print(f" • Validation loss: {float(val_metrics['loss']):.4f}") + print(f" • Validation accuracy: {float(val_metrics['accuracy']):.4f}") + + # Print timing information + print("\n ā±ļø Validation Timing:") + validation_time = timing_metrics.get("total_validation_time", 0) + print(f" • Total validation time: {validation_time:.2f}s") # Make sure to reset the timer after validation timer.reset() diff --git a/nemo_reinforcer/algorithms/loss_functions.py b/nemo_reinforcer/algorithms/loss_functions.py index d52b3e7b38..c8abbd291c 100644 --- a/nemo_reinforcer/algorithms/loss_functions.py +++ b/nemo_reinforcer/algorithms/loss_functions.py @@ -143,6 +143,7 @@ def __call__( "probs_ratio_clamped": probs_ratio_clamped, "kl_penalty": kl.item() / self.reference_policy_kl_penalty if kl else 0, "token_mult_prob_error": mult_prob_error, + "num_valid_samples": sample_mask.sum().item(), }, ) @@ -175,8 +176,9 @@ def __call__( if dpo_loss: ## shape: [batch_size] num_unmasked_tokens = torch.sum(mask, -1) - loss = -torch.sum(token_logprobs * mask, dim=-1) + loss = -torch.sum(token_logprobs * mask, dim=-1) * sample_mask if dpo_average_log_probs: + ## multiple by sample_mask to zero out invalid samples loss = loss / num_unmasked_tokens.clamp(min=1) else: ## single scalar loss @@ -193,6 +195,7 @@ def __call__( "loss": loss.item() if loss.ndim == 0 else loss, "num_unmasked_tokens": num_unmasked_tokens, "total_tokens": mask.numel(), + "num_valid_samples": sample_mask.sum().item(), } @@ -213,6 +216,10 @@ class DPOLossDataDict(TypedDict): sample_mask: torch.Tensor +def average_valid_samples(tensor: torch.Tensor, sample_mask: torch.Tensor): + return tensor.sum(-1) / sample_mask.sum(-1).clamp(min=1) + + class DPOLossFn(LossFunction): """Direct Preference Optimization (DPO) loss function. @@ -286,7 +293,6 @@ def preference_loss( ## TODO(@ashors): there's some duplicate code here with the NLLLoss function. 
We should refactor token_mask = data["token_mask"][:, 1:] sample_mask = data["sample_mask"] - mask = token_mask * sample_mask.unsqueeze(-1) next_tokens = data.get("input_ids")[:, 1:].cuda() # Skip first token next_token_logprobs = torch.nn.functional.log_softmax(next_token_logits, dim=-1) @@ -303,18 +309,23 @@ def preference_loss( rewards = diff.sum(-1) if self.preference_average_log_probs: - rewards = rewards / mask.sum(-1).clamp(min=1) + rewards = rewards / token_mask.sum(-1).clamp(min=1) rewards_chosen, rewards_rejected = self.split_output_tensor(rewards) rewards_delta = rewards_chosen - rewards_rejected - return ( + per_sample_loss = ( -torch.nn.functional.logsigmoid( self.reference_policy_kl_penalty * rewards_delta - ).mean(0), + ) + * sample_mask[::2] + ) ## zero out invalid samples + + return ( + per_sample_loss.mean(0), (rewards_chosen > rewards_rejected).float().mean(0), - rewards_chosen.mean(), - rewards_rejected.mean(), + average_valid_samples(rewards_chosen, sample_mask[::2]), + average_valid_samples(rewards_rejected, sample_mask[1::2]), ) def __call__( @@ -329,7 +340,9 @@ def __call__( dpo_average_log_probs=self.sft_average_log_probs, ) sft_loss_chosen, sft_loss_rejected = self.split_output_tensor(sft_loss) - sft_loss_chosen = sft_loss_chosen.mean(0) + sft_loss_chosen = average_valid_samples( + sft_loss_chosen, data["sample_mask"] + ) ( preference_loss, @@ -343,6 +356,9 @@ def __call__( + self.preference_loss_weight * preference_loss ) + ## divide by 2 because we're summing over (chosen, rejected) pairs + num_valid_samples = data["sample_mask"].sum() / 2 + return dpo_loss, { "loss": dpo_loss.item(), "sft_loss": sft_loss_chosen.item(), @@ -350,4 +366,5 @@ def __call__( "accuracy": accuracy.item(), "rewards_chosen_mean": rewards_chosen_mean.item(), "rewards_rejected_mean": rewards_rejected_mean.item(), + "num_valid_samples": num_valid_samples.item(), } diff --git a/nemo_reinforcer/algorithms/sft.py b/nemo_reinforcer/algorithms/sft.py index b5bb41aec5..57918a540e 100644 --- a/nemo_reinforcer/algorithms/sft.py +++ b/nemo_reinforcer/algorithms/sft.py @@ -265,7 +265,7 @@ def validate( ) val_metrics["val_loss"] += float(val_results["loss"]) - if val_batches > 0 and batch_idx >= val_batches: + if val_batches > 0 and batch_idx >= val_batches - 1: break val_metrics["val_loss"] /= val_batches diff --git a/nemo_reinforcer/models/policy/hf_policy.py b/nemo_reinforcer/models/policy/hf_policy.py index 6d5fb13101..a1bfc7d00b 100644 --- a/nemo_reinforcer/models/policy/hf_policy.py +++ b/nemo_reinforcer/models/policy/hf_policy.py @@ -316,15 +316,15 @@ def train( logits = outputs.logits loss, loss_metrics = loss_fn(logits, mb) + num_valid_samples = loss_metrics["num_valid_samples"] loss_metrics["lr"] = self.optimizer.param_groups[0]["lr"] # Backward pass - - ## TODO(@ashors): improve this - if not eval_mode: - loss.backward() - mb_losses.append(loss.item()) - all_mb_metrics.append(loss_metrics) + if num_valid_samples > 0: + if not eval_mode: + loss.backward() + mb_losses.append(loss.item()) + all_mb_metrics.append(loss_metrics) # Clip gradients if not eval_mode: From 7fbf8476afdb99f97a1484557dffc6c2702354a9 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Wed, 16 Apr 2025 14:24:09 -0700 Subject: [PATCH 36/57] fix comment Signed-off-by: ashors1 --- nemo_reinforcer/algorithms/loss_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo_reinforcer/algorithms/loss_functions.py b/nemo_reinforcer/algorithms/loss_functions.py index c8abbd291c..afc50e5eee 100644 --- 
a/nemo_reinforcer/algorithms/loss_functions.py +++ b/nemo_reinforcer/algorithms/loss_functions.py @@ -176,9 +176,9 @@ def __call__( if dpo_loss: ## shape: [batch_size] num_unmasked_tokens = torch.sum(mask, -1) + ## multiply by sample_mask to zero out invalid samples loss = -torch.sum(token_logprobs * mask, dim=-1) * sample_mask if dpo_average_log_probs: - ## multiple by sample_mask to zero out invalid samples loss = loss / num_unmasked_tokens.clamp(min=1) else: ## single scalar loss From fb76ff42da199a47ade739c9030b753601785e94 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Wed, 16 Apr 2025 14:25:25 -0700 Subject: [PATCH 37/57] fix reduction over valid samples Signed-off-by: ashors1 --- nemo_reinforcer/algorithms/loss_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo_reinforcer/algorithms/loss_functions.py b/nemo_reinforcer/algorithms/loss_functions.py index afc50e5eee..1a006b5bd9 100644 --- a/nemo_reinforcer/algorithms/loss_functions.py +++ b/nemo_reinforcer/algorithms/loss_functions.py @@ -217,7 +217,7 @@ class DPOLossDataDict(TypedDict): def average_valid_samples(tensor: torch.Tensor, sample_mask: torch.Tensor): - return tensor.sum(-1) / sample_mask.sum(-1).clamp(min=1) + return tensor.sum() / sample_mask.sum().clamp(min=1) class DPOLossFn(LossFunction): From 7472f6b19c4c0bf2a9e8f29ab7a03ecd93bca0e0 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Wed, 16 Apr 2025 17:28:34 -0700 Subject: [PATCH 38/57] small bug fixes Signed-off-by: ashors1 --- nemo_reinforcer/algorithms/loss_functions.py | 8 ++++---- nemo_reinforcer/models/policy/hf_policy.py | 9 ++++++--- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/nemo_reinforcer/algorithms/loss_functions.py b/nemo_reinforcer/algorithms/loss_functions.py index 1a006b5bd9..87bb3851c0 100644 --- a/nemo_reinforcer/algorithms/loss_functions.py +++ b/nemo_reinforcer/algorithms/loss_functions.py @@ -177,7 +177,7 @@ def __call__( ## shape: [batch_size] num_unmasked_tokens = torch.sum(mask, -1) ## multiply by sample_mask to zero out invalid samples - loss = -torch.sum(token_logprobs * mask, dim=-1) * sample_mask + loss = -torch.sum(token_logprobs * mask, dim=-1) if dpo_average_log_probs: loss = loss / num_unmasked_tokens.clamp(min=1) else: @@ -322,7 +322,7 @@ def preference_loss( ) ## zero out invalid samples return ( - per_sample_loss.mean(0), + average_valid_samples(per_sample_loss, sample_mask[::2]), (rewards_chosen > rewards_rejected).float().mean(0), average_valid_samples(rewards_chosen, sample_mask[::2]), average_valid_samples(rewards_rejected, sample_mask[1::2]), @@ -341,7 +341,7 @@ def __call__( ) sft_loss_chosen, sft_loss_rejected = self.split_output_tensor(sft_loss) sft_loss_chosen = average_valid_samples( - sft_loss_chosen, data["sample_mask"] + sft_loss_chosen, data["sample_mask"][::2] ) ( @@ -366,5 +366,5 @@ def __call__( "accuracy": accuracy.item(), "rewards_chosen_mean": rewards_chosen_mean.item(), "rewards_rejected_mean": rewards_rejected_mean.item(), - "num_valid_samples": num_valid_samples.item(), + "num_valid_samples_per_mb": num_valid_samples.item(), } diff --git a/nemo_reinforcer/models/policy/hf_policy.py b/nemo_reinforcer/models/policy/hf_policy.py index a1bfc7d00b..d2fd312ddb 100644 --- a/nemo_reinforcer/models/policy/hf_policy.py +++ b/nemo_reinforcer/models/policy/hf_policy.py @@ -316,13 +316,16 @@ def train( logits = outputs.logits loss, loss_metrics = loss_fn(logits, mb) - num_valid_samples = loss_metrics["num_valid_samples"] + num_valid_samples = 
loss_metrics["num_valid_samples_per_mb"] loss_metrics["lr"] = self.optimizer.param_groups[0]["lr"] # Backward pass + if not eval_mode: + ## NOTE: invalid samples should be multiplied + ## by zero in the loss function to prevent them + ## from affecting the gradient + loss.backward() if num_valid_samples > 0: - if not eval_mode: - loss.backward() mb_losses.append(loss.item()) all_mb_metrics.append(loss_metrics) From ace13fa399e86c2e8ee1437aad0ebf99ae921b54 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Wed, 16 Apr 2025 17:37:23 -0700 Subject: [PATCH 39/57] log sum of valid samples rather than average Signed-off-by: ashors1 --- nemo_reinforcer/algorithms/dpo.py | 8 +++++++- nemo_reinforcer/algorithms/loss_functions.py | 2 +- nemo_reinforcer/models/policy/hf_policy.py | 2 +- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/nemo_reinforcer/algorithms/dpo.py b/nemo_reinforcer/algorithms/dpo.py index 41ec26cd6d..eb546dedc2 100644 --- a/nemo_reinforcer/algorithms/dpo.py +++ b/nemo_reinforcer/algorithms/dpo.py @@ -296,6 +296,8 @@ def validate( break for k, v in val_metrics.items(): + if k == "num_valid_samples": + continue val_metrics[k] /= num_valid_batches # Calculate validation metrics @@ -465,7 +467,11 @@ def dpo_train( "loss": train_results["loss"].numpy(), } metrics.update(train_results["all_mb_metrics"]) - metrics = {k: np.mean(v).item() for k, v in metrics.items()} + for k, v in metrics.items(): + if k == "num_valid_samples": + metrics[k] = np.sum(v).item() + else: + metrics[k] = np.mean(v).item() timing_metrics = timer.get_timing_metrics(reduction_op="sum") print("\nšŸ“Š Training Results:") diff --git a/nemo_reinforcer/algorithms/loss_functions.py b/nemo_reinforcer/algorithms/loss_functions.py index 87bb3851c0..a015850d66 100644 --- a/nemo_reinforcer/algorithms/loss_functions.py +++ b/nemo_reinforcer/algorithms/loss_functions.py @@ -366,5 +366,5 @@ def __call__( "accuracy": accuracy.item(), "rewards_chosen_mean": rewards_chosen_mean.item(), "rewards_rejected_mean": rewards_rejected_mean.item(), - "num_valid_samples_per_mb": num_valid_samples.item(), + "num_valid_samples": num_valid_samples.item(), } diff --git a/nemo_reinforcer/models/policy/hf_policy.py b/nemo_reinforcer/models/policy/hf_policy.py index d2fd312ddb..9ef9f38bf4 100644 --- a/nemo_reinforcer/models/policy/hf_policy.py +++ b/nemo_reinforcer/models/policy/hf_policy.py @@ -316,7 +316,7 @@ def train( logits = outputs.logits loss, loss_metrics = loss_fn(logits, mb) - num_valid_samples = loss_metrics["num_valid_samples_per_mb"] + num_valid_samples = loss_metrics["num_valid_samples"] loss_metrics["lr"] = self.optimizer.param_groups[0]["lr"] # Backward pass From b5c66ba3a036fe3c346c179455de278a9dd6d823 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Wed, 16 Apr 2025 18:52:00 -0700 Subject: [PATCH 40/57] address some comments Signed-off-by: ashors1 --- README.md | 18 +++++------- docs/guides/dpo.md | 38 ++++++++++++++++++++---- docs/helpers.py | 27 +++++++++++++++++ docs/index.md | 1 + examples/run_dpo.py | 48 ++++++++++++++++++++++++++++++- nemo_reinforcer/algorithms/dpo.py | 15 ++++++---- 6 files changed, 124 insertions(+), 23 deletions(-) create mode 100755 docs/helpers.py diff --git a/README.md b/README.md index 202bf82fc5..9a4667058f 100644 --- a/README.md +++ b/README.md @@ -83,7 +83,8 @@ If you have access to more GPUs, you can update the experiment accordingly. 
To r ```sh uv run python examples/run_sft.py \ policy.model_name="meta-llama/Meta-Llama-3-8B" \ - policy.train_global_batch_size=256 \ + policy.train_global_batch_size=128 \ + sft.val_global_batch_size=128 \ cluster.gpus_per_node=8 ``` @@ -135,13 +136,12 @@ uv run python examples/run_dpo.py This trains `Llama3.2-1B-Instruct` on one GPU. -If you have access to more GPUs, you can update the experiment accordingly. To run on 8 GPUs, we update the cluster configuration and switch to an 8B Llama base model: +If you have access to more GPUs, you can update the experiment accordingly. To run on 8 GPUs, we update the cluster configuration and switch to an 8B Llama3.2 Instruct model: ```sh uv run python examples/run_dpo.py \ policy.model_name="meta-llama/Meta-Llama-3-8B-Instruct" \ - policy.train_global_batch_size=128 \ - dpo.val_global_batch_size=128 \ + policy.train_global_batch_size=256 \ cluster.gpus_per_node=8 ``` @@ -160,21 +160,17 @@ Refer to [dpo.yaml](examples/configs/dpo.yaml) for a full list of parameters tha #### Multi-node -For distributed DPO training across multiple nodes: - -Set `UV_CACHE_DIR` to a directory that can be read from all workers before running any uv run command. -```sh -export UV_CACHE_DIR=/path/that/all/workers/can/access/uv_cache -``` +For distributed DPO training across multiple nodes, modify the following script for your use case: ```sh # Run from the root of NeMo-Reinforcer repo +## number of nodes to use for your job NUM_ACTOR_NODES=2 # Add a timestamp to make each job name unique TIMESTAMP=$(date +%Y%m%d_%H%M%S) # DPO experiment uses Llama-3.1-8B model -COMMAND="uv pip install -e .; uv run ./examples/run_dpo.py --config examples/configs/dpo.yaml cluster.num_nodes=2 cluster.gpus_per_node=8 checkpointing.checkpoint_dir='results/dpo_llama8b_2nodes' logger.wandb_enabled=True logger.wandb.name='dpo-llama8b'" \ +COMMAND="uv run ./examples/run_dpo.py --config examples/configs/dpo.yaml cluster.num_nodes=2 cluster.gpus_per_node=8 checkpointing.checkpoint_dir='results/dpo_llama8b_2nodes' logger.wandb_enabled=True logger.wandb.name='dpo-llama8b'" \ RAY_DEDUP_LOGS=0 \ UV_CACHE_DIR=YOUR_UV_CACHE_DIR \ CONTAINER=YOUR_CONTAINER \ diff --git a/docs/guides/dpo.md b/docs/guides/dpo.md index 1cc9dc238d..3001daf569 100644 --- a/docs/guides/dpo.md +++ b/docs/guides/dpo.md @@ -66,10 +66,11 @@ Adding a new DPO dataset is straightforward. 
Your custom dataset class should: Here's a minimal example which simply re-keys an existing jsonl dataset: -```python +```{testcode} from datasets import load_dataset from nemo_reinforcer.data.hf_datasets.interfaces import HfDataset from nemo_reinforcer.data.interfaces import TaskDataSpec +from docs.helpers import make_dpo_dataset class CustomDPODataset: def preprocess_dataset( @@ -79,13 +80,11 @@ class CustomDPODataset: chosen_key: str = "chosen", rejected_key: str = "rejected" ): - return { "prompt": data[prompt_key], "chosen_response": data[chosen_key], "rejected_response": data[rejected_key], } - def __init__( self, @@ -95,7 +94,6 @@ class CustomDPODataset: chosen_key: str, rejected_key: str, ): - # Load and format your dataset fn_kwargs={ "prompt_key": prompt_key, @@ -115,8 +113,38 @@ class CustomDPODataset: # Initialize task spec with dataset name self.task_spec = TaskDataSpec( - dataset_name="custom_dpo", + task_name="custom_dpo", ) + self.formatted_ds = formatted_ds + +# Create temporary files using helper function +train_file, val_file = make_dpo_dataset() + +# Initialize dataset +dataset = CustomDPODataset( + train_data_path=train_file.name, + val_data_path=val_file.name, + prompt_key="context", + chosen_key="chosen", + rejected_key="rejected" +) + +# Test dataset properties +print(f"Task name: {dataset.task_spec.task_name}") +print(f"Train examples: {len(dataset.formatted_ds['train'])}") +print(f"Validation examples: {len(dataset.formatted_ds['validation'])}") +print(f"First train example prompt: {dataset.formatted_ds['train'][0]['prompt']}") +print(f"First train example chosen response: {dataset.formatted_ds['train'][0]['chosen_response']}") +print(f"First train example rejected response: {dataset.formatted_ds['train'][0]['rejected_response']}") +``` + +```{testoutput} +Task name: custom_dpo +Train examples: 2 +Validation examples: 2 +First train example prompt: What is 2+2? +First train example chosen response: 4 +First train example rejected response: 5 ``` ## DPO-Specific Parameters diff --git a/docs/helpers.py b/docs/helpers.py new file mode 100755 index 0000000000..9bd28bc96a --- /dev/null +++ b/docs/helpers.py @@ -0,0 +1,27 @@ +import tempfile +import json + + +def make_dpo_dataset(): + train_file = tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) + val_file = tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) + + # Write train data + train_data = [ + {"context": "What is 2+2?", "chosen": "4", "rejected": "5"}, + {"context": "What is 3*3?", "chosen": "9", "rejected": "6"}, + ] + for item in train_data: + lines = train_file.write(json.dumps(item) + "\n") + train_file.flush() + + # Write validation data + val_data = [ + {"context": "What is 4+4?", "chosen": "8", "rejected": "7"}, + {"context": "What is 5*5?", "chosen": "25", "rejected": "20"}, + ] + for item in val_data: + lines = val_file.write(json.dumps(item) + "\n") + val_file.flush() + + return train_file, val_file diff --git a/docs/index.md b/docs/index.md index 553778ff98..ed951e9648 100644 --- a/docs/index.md +++ b/docs/index.md @@ -17,6 +17,7 @@ cluster.md adding_new_models.md guides/sft.md +guides/dpo.md guides/grpo.md guides/eval.md ``` diff --git a/examples/run_dpo.py b/examples/run_dpo.py index 8a1312d1b6..0e98b0e1bd 100644 --- a/examples/run_dpo.py +++ b/examples/run_dpo.py @@ -59,7 +59,53 @@ def dpo_preprocessor( max_seq_length: int, idx: int, ) -> DatumSpec: - """Process a datum dictionary for DPO training.""" + """Process a datum dictionary for DPO training. 
+ + Examples: + ```{doctest} + >>> from transformers import AutoTokenizer + >>> from nemo_reinforcer.data.interfaces import TaskDataSpec + >>> + >>> # Initialize tokenizer and task spec + >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct") + >>> ## set a passthrough chat template for simplicity + >>> tokenizer.chat_template = "{% for message in messages %}{{ message['content'] }}{% endfor %}" + >>> task_spec = TaskDataSpec(task_name="test_dpo") + >>> + >>> datum = { + ... "prompt": "What is 2+2?", + ... "chosen_response": "4", + ... "rejected_response": "5" + ... } + >>> + >>> processed = dpo_preprocessor(datum, task_spec, tokenizer, max_seq_length=128, idx=0) + >>> len(processed["message_log_chosen"]) + 2 + >>> processed["message_log_chosen"][0]["content"] + '<|begin_of_text|>What is 2+2?' + >>> processed["message_log_chosen"][-1]["content"] + '4<|eot_id|>' + >>> processed["message_log_rejected"][-1]["content"] + '5<|eot_id|>' + >>> + >>> # prompt can also be a list with multiple messages + >>> datum = { + ... "prompt": [{"role": "user", "content": "I have a question."}, {"role": "assistant", "content": "Sure!"}, {"role": "user", "content": "What is 2+2?"}], + ... "chosen_response": "4", + ... "rejected_response": "5" + ... } + >>> processed = dpo_preprocessor(datum, task_spec, tokenizer, max_seq_length=128, idx=0) + >>> len(processed["message_log_chosen"]) + 4 + >>> processed["message_log_chosen"][1]["content"] + 'Sure!' + >>> processed["message_log_chosen"][-1]["content"] + '4<|eot_id|>' + >>> processed["message_log_rejected"][-1]["content"] + '5<|eot_id|>' + + ``` + """ if isinstance(datum_dict["prompt"], list): messages_chosen = datum_dict["prompt"].copy() messages_rejected = datum_dict["prompt"].copy() diff --git a/nemo_reinforcer/algorithms/dpo.py b/nemo_reinforcer/algorithms/dpo.py index eb546dedc2..87ad595ea7 100644 --- a/nemo_reinforcer/algorithms/dpo.py +++ b/nemo_reinforcer/algorithms/dpo.py @@ -148,7 +148,7 @@ def setup( # ========================== # Data # ========================== - ## TODO: clean up + ## TODO(@ashors) reduce boilerplate and move reused code into utils tokenizer = get_tokenizer(policy_config["model_name"]) train_dataloader = StatefulDataLoader( train_dataset, @@ -222,7 +222,7 @@ def setup( ) -def augment_dataloader(dataloader, policy, master_config): +def add_ref_logprobs_to_data(dataloader, policy, master_config): dataloader_iter = iter(dataloader) while True: try: @@ -232,7 +232,7 @@ def augment_dataloader(dataloader, policy, master_config): logprobs = policy.get_reference_policy_logprobs( batch, micro_batch_size=master_config["policy"]["train_micro_batch_size"] * 2, - )["reference_logprobs"].to("cpu") + )["reference_logprobs"] ## want logprobs for batch to correspond to the log probabilities of the next tokens ## so we roll the logprobs to the left by one batch["reference_policy_logprobs"] = torch.roll(logprobs, -1, dims=-1) @@ -270,7 +270,7 @@ def validate( val_metrics = defaultdict(lambda: 0.0) num_valid_batches = 0 for batch_idx, val_batch in enumerate( - augment_dataloader(val_dataloader, policy, master_config) + add_ref_logprobs_to_data(val_dataloader, policy, master_config) ): ## just run model fwd val_results = policy.train( @@ -380,10 +380,13 @@ def dpo_train( policy.prepare_for_training() - while current_epoch < max_num_epochs: + while ( + current_epoch < max_num_epochs + and total_steps < master_config["dpo"]["max_num_steps"] + ): print(f"\n{'=' * 25} Epoch {current_epoch + 1}/{max_num_epochs} {'=' * 25}") - for 
batch in augment_dataloader(train_dataloader, policy, master_config): + for batch in add_ref_logprobs_to_data(train_dataloader, policy, master_config): print( f"\n{'=' * 25} Step {current_step + 1}/{min(len(train_dataloader), master_config['dpo']['max_num_steps'])} {'=' * 25}" ) From 6124416b8a4bcd1da8aa082c8c3c05cdded85718 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Wed, 16 Apr 2025 19:00:04 -0700 Subject: [PATCH 41/57] small gbs and mbs fix, add copyright Signed-off-by: ashors1 --- docs/helpers.py | 14 ++++++++++++++ nemo_reinforcer/models/policy/hf_policy.py | 10 +++++----- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/docs/helpers.py b/docs/helpers.py index 9bd28bc96a..805d5877d1 100755 --- a/docs/helpers.py +++ b/docs/helpers.py @@ -1,3 +1,17 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import tempfile import json diff --git a/nemo_reinforcer/models/policy/hf_policy.py b/nemo_reinforcer/models/policy/hf_policy.py index 83ee8371c9..20578ae6cd 100644 --- a/nemo_reinforcer/models/policy/hf_policy.py +++ b/nemo_reinforcer/models/policy/hf_policy.py @@ -1036,11 +1036,11 @@ def train( mbs: Optional[int] = None, ): """Train the policy on a batch of data with a given loss function.""" + batch_size = gbs or self.cfg["train_global_batch_size"] + micro_batch_size = mbs or self.cfg["train_micro_batch_size"] # Shard and replicate the batch shards = self.dp_size - sharded_data = data.shard_by_batch_size( - shards, batch_size=gbs or self.cfg["train_global_batch_size"] - ) + sharded_data = data.shard_by_batch_size(shards, batch_size=batch_size) # Train each shard in parallel futures = self.worker_group.run_all_workers_multiple_data( @@ -1049,8 +1049,8 @@ def train( common_kwargs={ "loss_fn": loss_fn, "eval_mode": eval_mode, - "gbs": gbs, - "mbs": mbs, + "batch_size": batch_size, + "micro_batch_size": micro_batch_size, }, ) results = self.worker_group.get_all_worker_results(futures) From 6962df1777f490724f4b1d7079e5809d3f2da01d Mon Sep 17 00:00:00 2001 From: ashors1 Date: Wed, 16 Apr 2025 20:35:43 -0700 Subject: [PATCH 42/57] small fixes following rebase Signed-off-by: ashors1 --- examples/configs/dpo.yaml | 3 ++- examples/run_dpo.py | 2 +- nemo_reinforcer/algorithms/dpo.py | 3 ++- nemo_reinforcer/models/policy/hf_policy.py | 5 +++-- 4 files changed, 8 insertions(+), 5 deletions(-) diff --git a/examples/configs/dpo.yaml b/examples/configs/dpo.yaml index cf71d2c779..b6ac194942 100755 --- a/examples/configs/dpo.yaml +++ b/examples/configs/dpo.yaml @@ -29,7 +29,8 @@ checkpointing: policy: model_name: "meta-llama/Llama-3.2-1B-Instruct" - tokenizer_name: "meta-llama/Llama-3.2-1B-Instruct" + tokenizer: + name: "meta-llama/Llama-3.2-1B-Instruct" # number of preference samples per batch # each preference sample corresponds to a pair of chosen and rejected responses diff --git a/examples/run_dpo.py b/examples/run_dpo.py index 0e98b0e1bd..566a7dfca3 100644 --- a/examples/run_dpo.py +++ b/examples/run_dpo.py @@ -190,7 +190,7 @@ 
def setup_data(data_config: DataConfig, policy_config: PolicyConfig): dpo_task_spec = data.task_spec - tokenizer = get_tokenizer(policy_config["model_name"]) + tokenizer = get_tokenizer(policy_config["tokenizer"]) train_dataset = AllTaskProcessedDataset( train_dataset, tokenizer, diff --git a/nemo_reinforcer/algorithms/dpo.py b/nemo_reinforcer/algorithms/dpo.py index 87ad595ea7..c2f82a2756 100644 --- a/nemo_reinforcer/algorithms/dpo.py +++ b/nemo_reinforcer/algorithms/dpo.py @@ -149,7 +149,7 @@ def setup( # Data # ========================== ## TODO(@ashors) reduce boilerplate and move reused code into utils - tokenizer = get_tokenizer(policy_config["model_name"]) + tokenizer = get_tokenizer(policy_config["tokenizer"]) train_dataloader = StatefulDataLoader( train_dataset, batch_size=policy_config["train_global_batch_size"], @@ -193,6 +193,7 @@ def setup( policy = HfPolicy( cluster=cluster, config=policy_config, + tokenizer=tokenizer, weights_path=Path(last_checkpoint_path) / "policy" / "weights" if last_checkpoint_path else None, diff --git a/nemo_reinforcer/models/policy/hf_policy.py b/nemo_reinforcer/models/policy/hf_policy.py index 20578ae6cd..ecce3470b6 100644 --- a/nemo_reinforcer/models/policy/hf_policy.py +++ b/nemo_reinforcer/models/policy/hf_policy.py @@ -338,6 +338,7 @@ def train( num_valid_samples = loss_metrics["num_valid_samples"] loss_metrics["lr"] = self.optimizer.param_groups[0]["lr"] + loss = loss / num_microbatches # Backward pass if not eval_mode: ## NOTE: invalid samples should be multiplied @@ -1049,8 +1050,8 @@ def train( common_kwargs={ "loss_fn": loss_fn, "eval_mode": eval_mode, - "batch_size": batch_size, - "micro_batch_size": micro_batch_size, + "gbs": batch_size, + "mbs": micro_batch_size, }, ) results = self.worker_group.get_all_worker_results(futures) From d38891af11694deb9d2af90da427e688d2fb61b5 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Wed, 16 Apr 2025 22:32:29 -0700 Subject: [PATCH 43/57] address comments, fix tests after rebase Signed-off-by: ashors1 --- examples/run_dpo.py | 4 +-- nemo_reinforcer/algorithms/dpo.py | 3 +- nemo_reinforcer/algorithms/loss_functions.py | 14 +++----- nemo_reinforcer/algorithms/sft.py | 35 ++++++++++++++----- .../unit/data/hf_datasets/test_dpo_dataset.py | 10 +++--- tests/unit/data/hf_datasets/test_helpsteer.py | 1 - 6 files changed, 38 insertions(+), 29 deletions(-) diff --git a/examples/run_dpo.py b/examples/run_dpo.py index 566a7dfca3..0f4922fc5d 100644 --- a/examples/run_dpo.py +++ b/examples/run_dpo.py @@ -241,7 +241,7 @@ def main(): init_ray() # setup data - dataset, val_dataset, tokenizer, dpo_task_spec = setup_data( + train_dataset, val_dataset, tokenizer, dpo_task_spec = setup_data( config["data"], config["policy"] ) ( @@ -254,7 +254,7 @@ def main(): checkpointer, dpo_save_state, master_config, - ) = setup(config, dataset, val_dataset) + ) = setup(config, tokenizer, train_dataset, val_dataset) dpo_train( policy, train_dataloader, diff --git a/nemo_reinforcer/algorithms/dpo.py b/nemo_reinforcer/algorithms/dpo.py index c2f82a2756..27dfd5e507 100644 --- a/nemo_reinforcer/algorithms/dpo.py +++ b/nemo_reinforcer/algorithms/dpo.py @@ -16,6 +16,7 @@ from collections import defaultdict from functools import partial from pathlib import Path +from transformers import AutoTokenizer from typing import Optional, Tuple, TypedDict from tqdm import tqdm @@ -91,6 +92,7 @@ class MasterConfig(TypedDict): # ======================================================= def setup( master_config: MasterConfig, + tokenizer: AutoTokenizer, 
train_dataset: AllTaskProcessedDataset, val_dataset: AllTaskProcessedDataset, ) -> Tuple[ @@ -149,7 +151,6 @@ def setup( # Data # ========================== ## TODO(@ashors) reduce boilerplate and move reused code into utils - tokenizer = get_tokenizer(policy_config["tokenizer"]) train_dataloader = StatefulDataLoader( train_dataset, batch_size=policy_config["train_global_batch_size"], diff --git a/nemo_reinforcer/algorithms/loss_functions.py b/nemo_reinforcer/algorithms/loss_functions.py index b1b9068c16..8ded00b284 100644 --- a/nemo_reinforcer/algorithms/loss_functions.py +++ b/nemo_reinforcer/algorithms/loss_functions.py @@ -211,10 +211,6 @@ class DPOLossDataDict(TypedDict): sample_mask: torch.Tensor -def average_valid_samples(tensor: torch.Tensor, sample_mask: torch.Tensor): - return tensor.sum() / sample_mask.sum().clamp(min=1) - - class DPOLossFn(LossFunction): """Direct Preference Optimization (DPO) loss function. @@ -317,10 +313,10 @@ def preference_loss( ) ## zero out invalid samples return ( - average_valid_samples(per_sample_loss, sample_mask[::2]), + masked_mean(per_sample_loss, sample_mask[::2]), (rewards_chosen > rewards_rejected).float().mean(0), - average_valid_samples(rewards_chosen, sample_mask[::2]), - average_valid_samples(rewards_rejected, sample_mask[1::2]), + masked_mean(rewards_chosen, sample_mask[::2]), + masked_mean(rewards_rejected, sample_mask[1::2]), ) def __call__( @@ -335,9 +331,7 @@ def __call__( dpo_average_log_probs=self.sft_average_log_probs, ) sft_loss_chosen, sft_loss_rejected = self.split_output_tensor(sft_loss) - sft_loss_chosen = average_valid_samples( - sft_loss_chosen, data["sample_mask"][::2] - ) + sft_loss_chosen = masked_mean(sft_loss_chosen, data["sample_mask"][::2]) ( preference_loss, diff --git a/nemo_reinforcer/algorithms/sft.py b/nemo_reinforcer/algorithms/sft.py index 3f5036d0fb..376b50db41 100644 --- a/nemo_reinforcer/algorithms/sft.py +++ b/nemo_reinforcer/algorithms/sft.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +import warnings from transformers import AutoTokenizer from pathlib import Path from typing import Optional, Tuple, TypedDict @@ -236,6 +237,7 @@ def validate( # val_total = len(val_dataloader) val_metrics = {"val_loss": 0.0} + num_valid_batches = 0 for batch_idx, val_batch in enumerate(val_dataloader): ## add loss mask based on role to every message @@ -266,12 +268,26 @@ def validate( gbs=val_batch_size, mbs=val_mbs, ) - val_metrics["val_loss"] += float(val_results["loss"]) + + if len(val_results["all_mb_metrics"]) == 0: + warnings.warn( + "No validation metrics were collected for this batch." + " This is likely because there were no valid samples." + ) + else: + val_metrics["val_loss"] += float(val_results["loss"]) + num_valid_batches += 1 if val_batches > 0 and batch_idx >= val_batches - 1: break - val_metrics["val_loss"] /= val_batches + if num_valid_batches > 0: + val_metrics["val_loss"] /= num_valid_batches + else: + warnings.warn( + "No validation metrics were collected." + " This is likely because there were no valid samples in the validation set." 
+ ) # Calculate validation metrics policy.prepare_for_training() @@ -280,14 +296,15 @@ def validate( timing_metrics = timer.get_timing_metrics(reduction_op="sum") validation_time = timing_metrics.get("total_validation_time", 0) - # Print summary of validation results - print("\nšŸ“Š Validation Results:") - print(f" • Validation loss: {val_metrics['val_loss']:.4f}") + if num_valid_batches > 0: + # Print summary of validation results + print("\nšŸ“Š Validation Results:") + print(f" • Validation loss: {val_metrics['val_loss']:.4f}") - # Print timing information - print("\n ā±ļø Validation Timing:") - validation_time = timing_metrics.get("total_validation_time", 0) - print(f" • Total validation time: {validation_time:.2f}s") + # Print timing information + print("\n ā±ļø Validation Timing:") + validation_time = timing_metrics.get("total_validation_time", 0) + print(f" • Total validation time: {validation_time:.2f}s") # Make sure to reset the timer after validation timer.reset() diff --git a/tests/unit/data/hf_datasets/test_dpo_dataset.py b/tests/unit/data/hf_datasets/test_dpo_dataset.py index 476b333d33..19d9d45ef6 100644 --- a/tests/unit/data/hf_datasets/test_dpo_dataset.py +++ b/tests/unit/data/hf_datasets/test_dpo_dataset.py @@ -45,20 +45,20 @@ def mock_dpo_data(): } ] + train_ctx = tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) + val_ctx = tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) + with tempfile.NamedTemporaryFile( mode="w", suffix=".json", delete=False ) as train_file: json.dump(train_data, train_file) train_path = train_file.name - with tempfile.NamedTemporaryFile( mode="w", suffix=".json", delete=False ) as val_file: json.dump(val_data, val_file) val_path = val_file.name - yield train_path, val_path - # Cleanup os.unlink(train_path) os.unlink(val_path) @@ -71,9 +71,7 @@ def test_dpo_dataset_initialization(mock_dpo_data): dataset = DPODataset(train_data_path=train_path, val_data_path=val_path) # Verify dataset initialization - assert dataset.task_spec.task_name == "dpo" - assert dataset.task_spec.custom_template is not None - assert "messages" in dataset.task_spec.custom_template + assert dataset.task_spec.task_name == "DPO" # Verify formatted_ds structure assert "train" in dataset.formatted_ds diff --git a/tests/unit/data/hf_datasets/test_helpsteer.py b/tests/unit/data/hf_datasets/test_helpsteer.py index 2b3a571881..304fd5d2ad 100644 --- a/tests/unit/data/hf_datasets/test_helpsteer.py +++ b/tests/unit/data/hf_datasets/test_helpsteer.py @@ -67,7 +67,6 @@ def test_helpsteer3_dataset_initialization(): # Verify dataset initialization assert dataset.task_spec.task_name == "HelpSteer3" - assert dataset.task_spec.custom_template is None # Should use tokenizer's template def test_helpsteer3_dataset_data_format(): From a7b2decf6d335ac71d8722c5326a38aaa213b2f1 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Wed, 16 Apr 2025 22:35:05 -0700 Subject: [PATCH 44/57] add hydra-style overrides Signed-off-by: ashors1 --- examples/run_dpo.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/examples/run_dpo.py b/examples/run_dpo.py index 0f4922fc5d..f780933310 100644 --- a/examples/run_dpo.py +++ b/examples/run_dpo.py @@ -23,7 +23,7 @@ from nemo_reinforcer.algorithms.dpo import MasterConfig, dpo_train, setup from nemo_reinforcer.algorithms.utils import get_tokenizer from nemo_reinforcer.distributed.virtual_cluster import init_ray -from nemo_reinforcer.utils.config import load_config +from nemo_reinforcer.utils.config import 
load_config, parse_hydra_overrides from nemo_reinforcer.utils.logger import get_next_experiment_dir from nemo_reinforcer.data import DataConfig, hf_datasets from nemo_reinforcer.data.datasets import AllTaskProcessedDataset @@ -41,10 +41,7 @@ def parse_args(): ) # Parse known args for the script - args, remaining = parser.parse_known_args() - - # Convert remaining args to OmegaConf format - overrides = OmegaConf.from_dotlist(remaining) + args, overrides = parser.parse_known_args() return args, overrides @@ -222,7 +219,7 @@ def main(): if overrides: print(f"Overrides: {overrides}") - config = OmegaConf.merge(config, overrides) + config = parse_hydra_overrides(config, overrides) config: MasterConfig = OmegaConf.to_container(config, resolve=True) print("Applied CLI overrides") From 4ac99e3760e19728dc681638e9fc7a0dfe60cc98 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Wed, 16 Apr 2025 22:48:24 -0700 Subject: [PATCH 45/57] sum valid samples across batch Signed-off-by: ashors1 --- nemo_reinforcer/algorithms/grpo.py | 6 +++++- nemo_reinforcer/algorithms/sft.py | 6 +++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/nemo_reinforcer/algorithms/grpo.py b/nemo_reinforcer/algorithms/grpo.py index eb0947011f..cd7333dce0 100644 --- a/nemo_reinforcer/algorithms/grpo.py +++ b/nemo_reinforcer/algorithms/grpo.py @@ -648,7 +648,11 @@ def grpo_train( "reward": rewards.numpy(), } metrics.update(train_results["all_mb_metrics"]) - metrics = {k: np.mean(v).item() for k, v in metrics.items()} + for k, v in metrics.items(): + if k == "num_valid_samples": + metrics[k] = np.sum(v).item() + else: + metrics[k] = np.mean(v).item() metrics.update(gen_metrics) timing_metrics = timer.get_timing_metrics(reduction_op="sum") diff --git a/nemo_reinforcer/algorithms/sft.py b/nemo_reinforcer/algorithms/sft.py index 376b50db41..3e2595872a 100644 --- a/nemo_reinforcer/algorithms/sft.py +++ b/nemo_reinforcer/algorithms/sft.py @@ -452,7 +452,11 @@ def sft_train( "loss": train_results["loss"].numpy(), } metrics.update(train_results["all_mb_metrics"]) - metrics = {k: np.mean(v).item() for k, v in metrics.items()} + for k, v in metrics.items(): + if k == "num_valid_samples": + metrics[k] = np.sum(v).item() + else: + metrics[k] = np.mean(v).item() timing_metrics = timer.get_timing_metrics(reduction_op="sum") print("\nšŸ“Š Training Results:") From 82e8e98d0e7e40ce9c341b3a401abd76ce9f4ec5 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Thu, 17 Apr 2025 10:38:23 -0700 Subject: [PATCH 46/57] decrease max steps and fix test Signed-off-by: ashors1 --- examples/configs/dpo.yaml | 2 +- tests/unit/algorithms/test_dpo.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/configs/dpo.yaml b/examples/configs/dpo.yaml index b6ac194942..0ad6bf36f6 100755 --- a/examples/configs/dpo.yaml +++ b/examples/configs/dpo.yaml @@ -1,7 +1,7 @@ # DPO Algorithm Configuration dpo: max_num_epochs: 1 - max_num_steps: 200 + max_num_steps: 150 val_period: 25 val_batches: 8 val_global_batch_size: 8 diff --git a/tests/unit/algorithms/test_dpo.py b/tests/unit/algorithms/test_dpo.py index 8c39e7e4d2..fa924a745d 100644 --- a/tests/unit/algorithms/test_dpo.py +++ b/tests/unit/algorithms/test_dpo.py @@ -16,7 +16,7 @@ import torch from unittest.mock import MagicMock, patch -from nemo_reinforcer.algorithms.dpo import augment_dataloader +from nemo_reinforcer.algorithms.dpo import add_ref_logprobs_to_data class MockPolicy: @@ -27,8 +27,8 @@ def get_reference_policy_logprobs(self, batch, micro_batch_size): return 
{"reference_logprobs": self.logprobs} -def test_augment_dataloader(): - """Test that augment_dataloader correctly adds reference policy logprobs to batches.""" +def test_add_logprobs_to_batch(): + """Test that add_ref_logprobs_to_data correctly adds reference policy logprobs to batches.""" # Create mock data batch_size = 2 seq_len = 4 @@ -55,7 +55,7 @@ def test_augment_dataloader(): # Get the augmented batches augmented_batches = list( - augment_dataloader(mock_dataloader, mock_policy, mock_master_config) + add_ref_logprobs_to_data(mock_dataloader, mock_policy, mock_master_config) ) # Verify we got exactly one batch From 233a9abd8c7b4fb55484d7046c94fafa3cfe3baf Mon Sep 17 00:00:00 2001 From: ashors1 Date: Thu, 17 Apr 2025 12:30:04 -0700 Subject: [PATCH 47/57] address remaining comments Signed-off-by: ashors1 --- nemo_reinforcer/data/datasets.py | 5 +-- .../data/hf_datasets/helpsteer3.py | 20 +++++++++-- nemo_reinforcer/data/llm_message_utils.py | 36 ++++++++----------- tests/unit/data/test_llm_message_utils.py | 16 +++++++++ 4 files changed, 52 insertions(+), 25 deletions(-) diff --git a/nemo_reinforcer/data/datasets.py b/nemo_reinforcer/data/datasets.py index a0335ad497..3dcef5e70a 100644 --- a/nemo_reinforcer/data/datasets.py +++ b/nemo_reinforcer/data/datasets.py @@ -22,7 +22,7 @@ DatumSpec, ) from nemo_reinforcer.data.llm_message_utils import ( - add_dpo_loss_mask_to_message_log, + add_loss_mask_to_message_log, batched_message_log_to_flat_message, ) from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict @@ -223,8 +223,9 @@ def dpo_collate_fn(data_batch: List[DatumSpec], tokenizer) -> BatchedDataDict: ) ## add loss mask based on role to every message - add_dpo_loss_mask_to_message_log( + add_loss_mask_to_message_log( batch["message_log"], + only_unmask_final=True, ) cat_and_padded, input_lengths = batched_message_log_to_flat_message( diff --git a/nemo_reinforcer/data/hf_datasets/helpsteer3.py b/nemo_reinforcer/data/hf_datasets/helpsteer3.py index 7ade565d42..0ad0263c30 100644 --- a/nemo_reinforcer/data/hf_datasets/helpsteer3.py +++ b/nemo_reinforcer/data/hf_datasets/helpsteer3.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from datasets import load_dataset +from absl import logging from nemo_reinforcer.data.interfaces import TaskDataSpec @@ -21,10 +22,25 @@ def format_helpsteer3(data): response_2 = data["response2"] overall_preference = data["overall_preference"] + if overall_preference < 0: + chosen = response_1 + rejected = response_2 + elif overall_preference == 0: + logging.log_every_n( + logging.WARNING, + "Preference is 0 for some examples! 
Setting chosen and rejected to response 1 since we don't know which response is better", + 1000, + ) + chosen = response_1 + rejected = response_1 + else: + chosen = response_2 + rejected = response_1 + return { "prompt": data["context"], - "chosen_response": response_1 if overall_preference <= 0 else response_2, - "rejected_response": response_2 if overall_preference < 0 else response_1, + "chosen_response": chosen, + "rejected_response": rejected, } diff --git a/nemo_reinforcer/data/llm_message_utils.py b/nemo_reinforcer/data/llm_message_utils.py index e425ce9ed4..cc2e1f8ed1 100644 --- a/nemo_reinforcer/data/llm_message_utils.py +++ b/nemo_reinforcer/data/llm_message_utils.py @@ -110,40 +110,34 @@ def get_keys_from_message_log( def add_loss_mask_to_message_log( message_log: LLMMessageLogType, roles_to_train_on: List[str] = ["assistant"], + only_unmask_final: bool = False, ) -> None: """Add token-level loss masks to each message in a message log. Args: message_log (LLMMessageLogType): List of message dictionaries containing token IDs and metadata roles_to_train_on (List[str]): List of strings indicating which speakers to unmask. Default: ["assistant"] + only_unmask_final (bool): If True, only unmask the final message in the log. Default: False """ for i, role in enumerate(roles_to_train_on): roles_to_train_on[i] = role.lower() - for message in message_log: - for sentence in message: - if sentence["role"] in roles_to_train_on: - sentence["token_loss_mask"] = torch.ones_like(sentence["token_ids"]) - else: - sentence["token_loss_mask"] = torch.zeros_like(sentence["token_ids"]) - - -def add_dpo_loss_mask_to_message_log( - message_log: LLMMessageLogType, -) -> None: - """Add token-level loss masks to each message in a message log. - - This function differs from add_loss_mask_to_message_log in that it only unmasks the final assistant message in the log. 
- - Args: - message_log (LLMMessageLogType): List of message dictionaries containing token IDs and metadata - """ for message in message_log: for i, sentence in enumerate(message): - if i == len(message) - 1: - sentence["token_loss_mask"] = torch.ones_like(sentence["token_ids"]) + if only_unmask_final: + if i == len(message) - 1: + sentence["token_loss_mask"] = torch.ones_like(sentence["token_ids"]) + else: + sentence["token_loss_mask"] = torch.zeros_like( + sentence["token_ids"] + ) else: - sentence["token_loss_mask"] = torch.zeros_like(sentence["token_ids"]) + if sentence["role"] in roles_to_train_on: + sentence["token_loss_mask"] = torch.ones_like(sentence["token_ids"]) + else: + sentence["token_loss_mask"] = torch.zeros_like( + sentence["token_ids"] + ) def _pad_tensor( diff --git a/tests/unit/data/test_llm_message_utils.py b/tests/unit/data/test_llm_message_utils.py index 87bba58015..2528b900c6 100644 --- a/tests/unit/data/test_llm_message_utils.py +++ b/tests/unit/data/test_llm_message_utils.py @@ -422,6 +422,22 @@ def test_add_loss_mask_to_chat_message_log( tokenized_chat_message_log[0][2]["token_loss_mask"], torch.tensor([1, 1]) ) + ## test only unmasking final message + add_loss_mask_to_message_log( + tokenized_chat_message_log, + only_unmask_final=True, + ) + assert torch.equal( + tokenized_chat_message_log[0][0]["token_loss_mask"], + torch.tensor([0, 0, 0, 0, 0, 0]), + ) + assert torch.equal( + tokenized_chat_message_log[0][1]["token_loss_mask"], torch.tensor([0, 0, 0]) + ) + assert torch.equal( + tokenized_chat_message_log[0][2]["token_loss_mask"], torch.tensor([1, 1]) + ) + def test_get_first_index_that_differs(): assert get_first_index_that_differs("hello", "hello") == 5 From c22bc60ef01a65ef4b7da7ec2400b6988127836c Mon Sep 17 00:00:00 2001 From: ashors1 Date: Thu, 17 Apr 2025 13:14:49 -0700 Subject: [PATCH 48/57] support dtensor with dpo Signed-off-by: ashors1 --- examples/configs/dpo.yaml | 16 +++++++ nemo_reinforcer/algorithms/dpo.py | 16 ++++++- nemo_reinforcer/algorithms/loss_functions.py | 44 +++++++++++++------- nemo_reinforcer/data/datasets.py | 5 ++- 4 files changed, 62 insertions(+), 19 deletions(-) diff --git a/examples/configs/dpo.yaml b/examples/configs/dpo.yaml index 0ad6bf36f6..c0e935b952 100755 --- a/examples/configs/dpo.yaml +++ b/examples/configs/dpo.yaml @@ -45,6 +45,18 @@ policy: fsdp_offload_enabled: false activation_checkpointing_enabled: false + dtensor_cfg: + enabled: false + cpu_offload: False + sequence_parallel: false + activation_checkpointing: false + tensor_parallel_size: 1 + + # makes the training sequence length divisible by the tensor parallel size + # this is useful for sequence parallel training + make_sequence_length_divisible_by: ${policy.dtensor_cfg.tensor_parallel_size} + max_grad_norm: 1.0 + optimizer: name: "torch.optim.AdamW" kwargs: @@ -52,6 +64,10 @@ policy: weight_decay: 0.1 betas: [0.9, 0.98] eps: 1e-5 + # when using Dtensor, we need to set foreach + # and fused to False + foreach: False + fused: False scheduler: - name: "torch.optim.lr_scheduler.LinearLR" diff --git a/nemo_reinforcer/algorithms/dpo.py b/nemo_reinforcer/algorithms/dpo.py index 27dfd5e507..64d14164d5 100644 --- a/nemo_reinforcer/algorithms/dpo.py +++ b/nemo_reinforcer/algorithms/dpo.py @@ -155,7 +155,13 @@ def setup( train_dataset, batch_size=policy_config["train_global_batch_size"], shuffle=True, - collate_fn=partial(dpo_collate_fn, tokenizer=tokenizer), + collate_fn=partial( + dpo_collate_fn, + tokenizer=tokenizer, + 
make_sequence_length_divisible_by=policy_config[ + "make_sequence_length_divisible_by" + ], + ), drop_last=True, ) @@ -169,7 +175,13 @@ def setup( val_dataset, batch_size=dpo_config["val_global_batch_size"], shuffle=False, - collate_fn=partial(dpo_collate_fn, tokenizer=tokenizer), + collate_fn=partial( + dpo_collate_fn, + tokenizer=tokenizer, + make_sequence_length_divisible_by=policy_config[ + "make_sequence_length_divisible_by" + ], + ), drop_last=True, ) diff --git a/nemo_reinforcer/algorithms/loss_functions.py b/nemo_reinforcer/algorithms/loss_functions.py index 299408d775..abdb19b05e 100644 --- a/nemo_reinforcer/algorithms/loss_functions.py +++ b/nemo_reinforcer/algorithms/loss_functions.py @@ -175,14 +175,20 @@ def __call__( sample_mask = data["sample_mask"] mask = token_mask * sample_mask.unsqueeze(-1) - next_tokens = data.get("input_ids")[:, 1:].cuda() # Skip first token - next_token_logprobs = torch.nn.functional.log_softmax(next_token_logits, dim=-1) - logprobs = next_token_logprobs[:, :-1] # Remove last position's logits - - # Gather the logprobs for the actual next tokens - token_logprobs = logprobs.gather( - dim=-1, index=next_tokens.unsqueeze(-1) - ).squeeze(-1) + next_token_logits = next_token_logits.to(torch.float32) + if isinstance(next_token_logits, torch.distributed.tensor.DTensor): + token_logprobs = get_logprobs_from_vocab_parallel_logits( + next_token_logits, data["input_ids"] + ) + else: + next_tokens = data.get("input_ids")[:, 1:].cuda() # Skip first token + next_token_logprobs = torch.nn.functional.log_softmax( + next_token_logits, dim=-1 + ) + logprobs = next_token_logprobs[:, :-1] # Remove last position's logits + token_logprobs = logprobs.gather( + dim=-1, index=next_tokens.unsqueeze(-1) + ).squeeze(-1) if dpo_loss: ## shape: [batch_size] @@ -301,14 +307,20 @@ def preference_loss( token_mask = data["token_mask"][:, 1:] sample_mask = data["sample_mask"] - next_tokens = data.get("input_ids")[:, 1:].cuda() # Skip first token - next_token_logprobs = torch.nn.functional.log_softmax(next_token_logits, dim=-1) - logprobs = next_token_logprobs[:, :-1] # Remove last position's logits - - # Gather the logprobs for the actual next tokens - token_logprobs = logprobs.gather( - dim=-1, index=next_tokens.unsqueeze(-1) - ).squeeze(-1) + next_token_logits = next_token_logits.to(torch.float32) + if isinstance(next_token_logits, torch.distributed.tensor.DTensor): + token_logprobs = get_logprobs_from_vocab_parallel_logits( + next_token_logits, data["input_ids"] + ) + else: + next_tokens = data.get("input_ids")[:, 1:].cuda() # Skip first token + next_token_logprobs = torch.nn.functional.log_softmax( + next_token_logits, dim=-1 + ) + logprobs = next_token_logprobs[:, :-1] # Remove last position's logits + token_logprobs = logprobs.gather( + dim=-1, index=next_tokens.unsqueeze(-1) + ).squeeze(-1) ref_logprobs = data["reference_policy_logprobs"][:, :-1] diff --git a/nemo_reinforcer/data/datasets.py b/nemo_reinforcer/data/datasets.py index 3dcef5e70a..8d8ca78371 100644 --- a/nemo_reinforcer/data/datasets.py +++ b/nemo_reinforcer/data/datasets.py @@ -187,7 +187,9 @@ def eval_collate_fn(data_batch: List[DatumSpec]) -> BatchedDataDict: return output -def dpo_collate_fn(data_batch: List[DatumSpec], tokenizer) -> BatchedDataDict: +def dpo_collate_fn( + data_batch: List[DatumSpec], tokenizer, make_sequence_length_divisible_by: int +) -> BatchedDataDict: """Collate function for DPO training. 
This function separates the chosen and rejected responses to create @@ -231,6 +233,7 @@ def dpo_collate_fn(data_batch: List[DatumSpec], tokenizer) -> BatchedDataDict: cat_and_padded, input_lengths = batched_message_log_to_flat_message( batch["message_log"], pad_value_dict={"token_ids": tokenizer.pad_token_id}, + make_sequence_length_divisible_by=make_sequence_length_divisible_by, ) train_data: BatchedDataDict = BatchedDataDict( From 96c079c94a933c225a297d9fc97166bd35ec4b51 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Thu, 17 Apr 2025 15:34:20 -0700 Subject: [PATCH 49/57] minor bug fixes Signed-off-by: ashors1 --- nemo_reinforcer/models/policy/dtensor_policy_worker.py | 8 +++++--- nemo_reinforcer/models/policy/fsdp1_policy_worker.py | 4 ++-- tests/unit/data/test_datasets.py | 6 ++++-- tests/unit/test_utils.py | 5 ++++- 4 files changed, 15 insertions(+), 8 deletions(-) mode change 100644 => 100755 tests/unit/data/test_datasets.py diff --git a/nemo_reinforcer/models/policy/dtensor_policy_worker.py b/nemo_reinforcer/models/policy/dtensor_policy_worker.py index 4cfd86398d..30d67e5805 100644 --- a/nemo_reinforcer/models/policy/dtensor_policy_worker.py +++ b/nemo_reinforcer/models/policy/dtensor_policy_worker.py @@ -15,7 +15,7 @@ import gc from collections import defaultdict -from contextlib import contextmanager +from contextlib import contextmanager, nullcontext from typing import Any, Dict, Optional import ray @@ -259,7 +259,7 @@ def train( ctx = torch.no_grad() self.model.eval() else: - ctx = contextlib.nullcontext() + ctx = nullcontext() # Ensure model is in training mode self.model.train() @@ -372,11 +372,13 @@ def train( metrics = { "global_loss": global_loss.cpu(), "local_loss": local_loss.cpu(), - "grad_norm": grad_norm, "rank": torch.distributed.get_rank(), "all_mb_metrics": dict(mb_metrics), } + if not eval_mode: + metrics["grad_norm"] = grad_norm + return metrics def get_logprobs( diff --git a/nemo_reinforcer/models/policy/fsdp1_policy_worker.py b/nemo_reinforcer/models/policy/fsdp1_policy_worker.py index 3e757d8ac3..bed5d1a7f3 100644 --- a/nemo_reinforcer/models/policy/fsdp1_policy_worker.py +++ b/nemo_reinforcer/models/policy/fsdp1_policy_worker.py @@ -15,7 +15,7 @@ import gc import warnings from collections import defaultdict -from contextlib import contextmanager +from contextlib import contextmanager, nullcontext from typing import Any, Dict, Optional import ray @@ -228,7 +228,7 @@ def train( ctx = torch.no_grad() self.model.eval() else: - ctx = contextlib.nullcontext() + ctx = nullcontext() # Ensure model is in training mode self.model.train() diff --git a/tests/unit/data/test_datasets.py b/tests/unit/data/test_datasets.py old mode 100644 new mode 100755 index ee80c5e934..7486e025c4 --- a/tests/unit/data/test_datasets.py +++ b/tests/unit/data/test_datasets.py @@ -94,7 +94,9 @@ def test_dpo_collate_fn(): ] # Call dpo_collate_fn - train_data = dpo_collate_fn(data_batch, mock_tokenizer) + train_data = dpo_collate_fn( + data_batch, mock_tokenizer, make_sequence_length_divisible_by=16 + ) # Verify the output structure assert isinstance(train_data, BatchedDataDict) @@ -107,7 +109,7 @@ def test_dpo_collate_fn(): assert train_data["input_ids"].shape[0] == 4 # 2 examples * 2 (chosen + rejected) # Verify input_ids shape and padding - max_length = 7 # max of all sequence lengths + max_length = 16 # max of all sequence lengths, padded to be divisible by 16 assert train_data["input_ids"].shape == (4, max_length) # Verify input_lengths diff --git a/tests/unit/test_utils.py 
b/tests/unit/test_utils.py index 2773fd20f2..c68f52634d 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -22,7 +22,10 @@ def simple_loss( ) -> Tuple[torch.Tensor, Dict[str, Any]]: # Just return mean of logprobs as the loss for testing loss = next_token_logits.mean() - metrics = {"test_metric": loss.item() * 0.5} + metrics = { + "test_metric": loss.item() * 0.5, + "num_valid_samples": 1, + } return loss, metrics From 5a81c637072c9ffcba7aa2fda7b6cc70f8c4f6ac Mon Sep 17 00:00:00 2001 From: ashors1 Date: Thu, 17 Apr 2025 17:12:09 -0700 Subject: [PATCH 50/57] fix test loss fn Signed-off-by: ashors1 --- tests/unit/test_utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index c68f52634d..9972c1a1b6 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -49,4 +49,7 @@ def nll_loss( token_loss_mask = data.get("token_loss_mask")[:, 1:].cuda() loss = -torch.sum(token_logprobs * token_loss_mask) - return loss, {"loss": loss.item()} + return loss, { + "loss": loss.item(), + "num_valid_samples": 1, + } From a8b3efc206909851de9ed98b295ddfe8f6f413ee Mon Sep 17 00:00:00 2001 From: ashors1 Date: Thu, 17 Apr 2025 17:57:11 -0700 Subject: [PATCH 51/57] update dpo docs Signed-off-by: ashors1 --- docs/guides/dpo.md | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/docs/guides/dpo.md b/docs/guides/dpo.md index 3001daf569..249ca58491 100644 --- a/docs/guides/dpo.md +++ b/docs/guides/dpo.md @@ -49,10 +49,20 @@ def format_helpsteer3(data): response_2 = data["response2"] overall_preference = data["overall_preference"] + if overall_preference < 0: + chosen = response_1 + rejected = response_2 + elif overall_preference == 0: + chosen = response_1 + rejected = response_1 + else: + chosen = response_2 + rejected = response_1 + return { "prompt": data["context"], - "chosen_response": response_1 if overall_preference < 0 else response_2, - "rejected_response": response_2 if overall_preference < 0 else response_1, + "chosen_response": chosen, + "rejected_response": rejected, } ``` @@ -68,7 +78,6 @@ Here's a minimal example which simply re-keys an existing jsonl dataset: ```{testcode} from datasets import load_dataset -from nemo_reinforcer.data.hf_datasets.interfaces import HfDataset from nemo_reinforcer.data.interfaces import TaskDataSpec from docs.helpers import make_dpo_dataset From 4c90c44d1af764ad625d163f869cb2ad0f1c92f0 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Fri, 18 Apr 2025 09:57:17 -0700 Subject: [PATCH 52/57] fix indentation Signed-off-by: ashors1 --- .../models/policy/dtensor_policy_worker.py | 48 +++++++++---------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/nemo_reinforcer/models/policy/dtensor_policy_worker.py b/nemo_reinforcer/models/policy/dtensor_policy_worker.py index 30d67e5805..02e9f55330 100644 --- a/nemo_reinforcer/models/policy/dtensor_policy_worker.py +++ b/nemo_reinforcer/models/policy/dtensor_policy_worker.py @@ -356,30 +356,30 @@ def train( losses.append(torch.tensor(mb_losses).sum().item()) - # Compute global loss across all ranks - with torch.no_grad(): - local_loss = torch.tensor(losses, device="cuda") - global_loss = torch.zeros_like(local_loss) - torch.distributed.all_reduce(local_loss) - global_loss = local_loss / self.dp_size - - # Aggregate metrics across all microbatches - mb_metrics = defaultdict(list) - for m in all_mb_metrics: - for k, v in m.items(): - mb_metrics[k].append(v) - - metrics = { - "global_loss": 
global_loss.cpu(), - "local_loss": local_loss.cpu(), - "rank": torch.distributed.get_rank(), - "all_mb_metrics": dict(mb_metrics), - } - - if not eval_mode: - metrics["grad_norm"] = grad_norm - - return metrics + # Compute global loss across all ranks + with torch.no_grad(): + local_loss = torch.tensor(losses, device="cuda") + global_loss = torch.zeros_like(local_loss) + torch.distributed.all_reduce(local_loss) + global_loss = local_loss / self.dp_size + + # Aggregate metrics across all microbatches + mb_metrics = defaultdict(list) + for m in all_mb_metrics: + for k, v in m.items(): + mb_metrics[k].append(v) + + metrics = { + "global_loss": global_loss.cpu(), + "local_loss": local_loss.cpu(), + "rank": torch.distributed.get_rank(), + "all_mb_metrics": dict(mb_metrics), + } + + if not eval_mode: + metrics["grad_norm"] = grad_norm + + return metrics def get_logprobs( self, data: BatchedDataDict, micro_batch_size: int = None From 680dfbc12003b71b19ae5a8b19bea98ad4faac4a Mon Sep 17 00:00:00 2001 From: ashors1 Date: Fri, 18 Apr 2025 11:07:17 -0700 Subject: [PATCH 53/57] address remaining comments Signed-off-by: ashors1 --- README.md | 124 +++++++++---------- docs/guides/dpo.md | 2 +- examples/configs/dpo.yaml | 2 + tests/unit/algorithms/test_loss_functions.py | 61 +++++++++ 4 files changed, 122 insertions(+), 67 deletions(-) diff --git a/README.md b/README.md index 2800335e2c..b13081e55f 100644 --- a/README.md +++ b/README.md @@ -5,15 +5,15 @@ - [Features](#features) - [Installation](#installation) - [Quick start](#quick-start) + - [GRPO](#grpo) + - [Single Node](#single-node-2) + - [Multi-node](#multi-node-2) - [SFT](#sft) - [Single Node](#single-node) - [Multi-node](#multi-node) - [DPO](#dpo) - [Single Node](#single-node-1) - [Multi-node](#multi-node-1) - - [GRPO](#grpo) - - [Single Node](#single-node-2) - - [Multi-node](#multi-node-2) - [Cluster Start](#cluster-start) **Nemo-Reinforcer** is a scalable and efficient post-training library designed for models ranging from 1 GPU to thousands, and from tiny to over 100 billion parameters. @@ -64,6 +64,61 @@ uv pip install -e '.[dev,test]' **Reminder**: Don't forget to set your HF_HOME and WANDB_API_KEY (if needed). You'll need to do a `huggingface-cli login` as well for Llama models. +### GRPO + +We have a reference GRPO experiment config set up trained for math benchmarks using the [OpenInstructMath2](https://huggingface.co/datasets/nvidia/OpenMathInstruct-2) dataset. + +#### Single Node + +To run GRPO on a single GPU for `Llama-3.2-1B-Instruct`: + +```sh +# Run the GRPO math example using a 1B parameter model +uv run python examples/run_grpo_math.py +``` + +By default, this uses the configuration in `examples/configs/grpo_math_1B.yaml`. You can customize parameters with command-line overrides. For example, to run on 8 gpus, + +```sh +# Run the GRPO math example using a 1B parameter model using 8 GPUs +uv run python examples/run_grpo_math.py \ + cluster.gpus_per_node=8 +``` + +You can override any of the parameters listed in the yaml configuration file. 
For example, + +```sh +uv run python examples/run_grpo_math.py \ + policy.model_name="Qwen/Qwen2-1.5B" \ + checkpointing.checkpoint_dir="results/qwen1_5b_math" \ + logger.wandb_enabled=True \ + logger.wandb.name="grpo-qwen1_5b_math" \ + logger.num_val_samples_to_print=10 \ +``` + +#### Multi-node + +```sh +# Run from the root of NeMo-Reinforcer repo +NUM_ACTOR_NODES=2 +# Add a timestamp to make each job name unique +TIMESTAMP=$(date +%Y%m%d_%H%M%S) + +# grpo_math_8b uses Llama-3.1-8B-Instruct model +COMMAND="uv pip install -e .; uv run ./examples/run_grpo_math.py --config examples/configs/grpo_math_8B.yaml cluster.num_nodes=2 checkpointing.checkpoint_dir='results/llama8b_2nodes' logger.wandb_enabled=True logger.wandb.name='grpo-llama8b_math'" \ +UV_CACHE_DIR=YOUR_UV_CACHE_DIR \ +CONTAINER=YOUR_CONTAINER \ +MOUNTS="$PWD:$PWD" \ +sbatch \ + --nodes=${NUM_ACTOR_NODES} \ + --account=YOUR_ACCOUNT \ + --job-name=YOUR_JOBNAME \ + --partition=YOUR_PARTITION \ + --time=4:0:0 \ + --gres=gpu:8 \ + ray.sub +``` + ### SFT We provide a sample SFT experiment that uses the [SQuAD dataset](https://rajpurkar.github.io/SQuAD-explorer/). @@ -92,14 +147,6 @@ Refer to `examples/configs/sft.yaml` for a full list of parameters that can be o #### Multi-node -For distributed training across multiple nodes: - -Set `UV_CACHE_DIR` to a directory that can be read from all workers before running any uv run command. - -```sh -export UV_CACHE_DIR=/path/that/all/workers/can/access/uv_cache -``` - ```sh # Run from the root of NeMo-Reinforcer repo NUM_ACTOR_NODES=2 @@ -184,61 +231,6 @@ sbatch \ ray.sub ``` -### GRPO - -We have a reference GRPO experiment config set up trained for math benchmarks using the [OpenInstructMath2](https://huggingface.co/datasets/nvidia/OpenMathInstruct-2) dataset. - -#### Single Node - -To run GRPO on a single GPU for `Llama-3.2-1B-Instruct`: - -```sh -# Run the GRPO math example using a 1B parameter model -uv run python examples/run_grpo_math.py -``` - -By default, this uses the configuration in `examples/configs/grpo_math_1B.yaml`. You can customize parameters with command-line overrides. For example, to run on 8 gpus, - -```sh -# Run the GRPO math example using a 1B parameter model using 8 GPUs -uv run python examples/run_grpo_math.py \ - cluster.gpus_per_node=8 -``` - -You can override any of the parameters listed in the yaml configuration file. For example, - -```sh -uv run python examples/run_grpo_math.py \ - policy.model_name="Qwen/Qwen2-1.5B" \ - checkpointing.checkpoint_dir="results/qwen1_5b_math" \ - logger.wandb_enabled=True \ - logger.wandb.name="grpo-qwen1_5b_math" \ - logger.num_val_samples_to_print=10 \ -``` - -#### Multi-node - -```sh -# Run from the root of NeMo-Reinforcer repo -NUM_ACTOR_NODES=2 -# Add a timestamp to make each job name unique -TIMESTAMP=$(date +%Y%m%d_%H%M%S) - -# grpo_math_8b uses Llama-3.1-8B-Instruct model -COMMAND="uv pip install -e .; uv run ./examples/run_grpo_math.py --config examples/configs/grpo_math_8B.yaml cluster.num_nodes=2 checkpointing.checkpoint_dir='results/llama8b_2nodes' logger.wandb_enabled=True logger.wandb.name='grpo-llama8b_math'" \ -UV_CACHE_DIR=YOUR_UV_CACHE_DIR \ -CONTAINER=YOUR_CONTAINER \ -MOUNTS="$PWD:$PWD" \ -sbatch \ - --nodes=${NUM_ACTOR_NODES} \ - --account=YOUR_ACCOUNT \ - --job-name=YOUR_JOBNAME \ - --partition=YOUR_PARTITION \ - --time=4:0:0 \ - --gres=gpu:8 \ - ray.sub -``` - ## Cluster Start Please visit [Cluster Start](docs/cluster.md) for how to get started on Slurm or Kubernetes. 
diff --git a/docs/guides/dpo.md b/docs/guides/dpo.md index 249ca58491..17e3bd303f 100644 --- a/docs/guides/dpo.md +++ b/docs/guides/dpo.md @@ -28,7 +28,7 @@ uv run examples/run_dpo.py \ logger.wandb.name="dpo-dev-8-gpu" ``` -**Reminder**: Don't forget to set your HF_HOME and WANDB_API_KEY (if needed). You'll need to do a `huggingface-cli login` as well for Llama models. +**Reminder**: Don't forget to set your `HF_HOME`, `WANDB_API_KEY`, and `HF_DATASETS_CACHE` (if needed). You'll need to do a `huggingface-cli login` as well for Llama models. ## Datasets diff --git a/examples/configs/dpo.yaml b/examples/configs/dpo.yaml index c0e935b952..f4b4b41c27 100755 --- a/examples/configs/dpo.yaml +++ b/examples/configs/dpo.yaml @@ -55,6 +55,8 @@ policy: # makes the training sequence length divisible by the tensor parallel size # this is useful for sequence parallel training make_sequence_length_divisible_by: ${policy.dtensor_cfg.tensor_parallel_size} + ## NOTE: there is a known issue with gradient clipping when using Dtensor + ## if using dtensor, set max_grad_norm to NULL max_grad_norm: 1.0 optimizer: diff --git a/tests/unit/algorithms/test_loss_functions.py b/tests/unit/algorithms/test_loss_functions.py index 866f78dc00..abefb2b251 100644 --- a/tests/unit/algorithms/test_loss_functions.py +++ b/tests/unit/algorithms/test_loss_functions.py @@ -100,6 +100,9 @@ def test_nll_loss(): def test_dpo_loss(): + if not torch.cuda.is_available(): + pytest.skip("No GPU available") + vocab_size = 16 batch_size = 1 num_unmasked_tokens = 2 @@ -153,6 +156,9 @@ def test_dpo_loss(): def test_dpo_loss_varying_sequence_lengths(): """Test DPO loss with varying sequence lengths and preference_average_log_probs=True.""" + if not torch.cuda.is_available(): + pytest.skip("No GPU available") + # Create DPO loss function with preference_average_log_probs=True dpo_loss_fn_no_avg = DPOLossFn( { @@ -232,6 +238,61 @@ def test_dpo_loss_varying_sequence_lengths(): assert torch.isclose(torch.tensor(metrics_avg["sft_loss"]), expected_sft_loss_avg) +def test_dpo_sft_matches_nll_loss(): + """Test that DPO SFT loss matches NLL loss when preference_loss_weight=0.""" + if not torch.cuda.is_available(): + pytest.skip("No GPU available") + + # Setup test data + vocab_size = 8 + batch_size = 2 + dpo_data = { + "input_ids": torch.randint(0, vocab_size, (batch_size * 2, 5)) + .to(torch.int64) + .to("cuda"), + "token_mask": torch.tensor( + [[0, 0, 1, 1, 0], [0, 0, 1, 1, 1], [0, 1, 1, 1, 1], [0, 1, 1, 1, 0]] + ).to("cuda"), + "sample_mask": torch.tensor([1, 1, 1, 1]).to("cuda"), + "reference_policy_logprobs": torch.randn((4, 5)).to("cuda"), + } + + ## when computing the sft loss in DPO, we only use the chosen samples + sft_data = { + "input_ids": dpo_data["input_ids"][::2], + "token_mask": dpo_data["token_mask"][::2], + "sample_mask": dpo_data["sample_mask"][::2], + } + + # Create next token logits that will give non-zero loss + ## * 2 for chosen/rejected + next_token_logits = torch.randn((batch_size * 2, 5, vocab_size)).to("cuda") + + # Compute NLL loss + nll_loss_fn = NLLLoss() + nll_loss, nll_metrics = nll_loss_fn(next_token_logits[::2], sft_data) + + # Compute DPO loss with preference_loss_weight=0 + dpo_loss_fn = DPOLossFn( + cfg={ + "reference_policy_kl_penalty": 0.0, + "preference_loss_weight": 0.0, # Disable preference loss + "sft_loss_weight": 1.0, # Only use SFT loss + "preference_average_log_probs": False, + "sft_average_log_probs": False, + } + ) + dpo_loss, dpo_metrics = dpo_loss_fn(next_token_logits, dpo_data) + + # Verify losses 
match + ## since DPO SFT loss just sums across tokens in a batch and then averages over the batch, + ## we need to re-normalize by multiplying by the batch size and dividing by the total number + ## of unmasked chosen tokens + torch.testing.assert_close( + dpo_loss / torch.sum(dpo_data["token_mask"][::2]) * batch_size, nll_loss + ) + + def _setup_clipped_pg_test_data(batch_size=1, seq_len=4, vocab_size=8, device="cuda"): """Sets up basic mock data structure. Tests should fill values.""" input_ids = torch.randint( # Input IDs only needed if original loss fn used From b54b6c39099f326fcdda2aad233d395dc6d73849 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Fri, 18 Apr 2025 11:10:33 -0700 Subject: [PATCH 54/57] fix hyperlinks Signed-off-by: ashors1 --- README.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index b13081e55f..f7cc3979aa 100644 --- a/README.md +++ b/README.md @@ -6,14 +6,14 @@ - [Installation](#installation) - [Quick start](#quick-start) - [GRPO](#grpo) - - [Single Node](#single-node-2) - - [Multi-node](#multi-node-2) - - [SFT](#sft) - [Single Node](#single-node) - [Multi-node](#multi-node) - - [DPO](#dpo) + - [SFT](#sft) - [Single Node](#single-node-1) - [Multi-node](#multi-node-1) + - [DPO](#dpo) + - [Single Node](#single-node-2) + - [Multi-node](#multi-node-2) - [Cluster Start](#cluster-start) **Nemo-Reinforcer** is a scalable and efficient post-training library designed for models ranging from 1 GPU to thousands, and from tiny to over 100 billion parameters. @@ -202,7 +202,7 @@ uv run python examples/run_dpo.py \ logger.wandb.name="llama-dpo-sft" ``` -Refer to [dpo.yaml](examples/configs/dpo.yaml) for a full list of parameters that can be overridden. For an in-depth explanation of how to add your own DPO dataset, refer to the [DPO documentation](docs/guides/dpo.md) +Refer to [dpo.yaml](examples/configs/dpo.yaml) for a full list of parameters that can be overridden. For an in-depth explanation of how to add your own DPO dataset, refer to the [DPO documentation](docs/guides/dpo.md). #### Multi-node From 4612a7a0a79a5545f2f8bb59638771570b0f94f3 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Fri, 18 Apr 2025 11:19:21 -0700 Subject: [PATCH 55/57] small readme fixes Signed-off-by: ashors1 Signed-off-by: ashors1 --- README.md | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index f7cc3979aa..4a747738db 100644 --- a/README.md +++ b/README.md @@ -153,8 +153,7 @@ NUM_ACTOR_NODES=2 # Add a timestamp to make each job name unique TIMESTAMP=$(date +%Y%m%d_%H%M%S) -# SFT experiment uses Llama-3.1-8B model -COMMAND="uv pip install -e .; uv run ./examples/run_sft.py --config examples/configs/sft.yaml cluster.num_nodes=2 cluster.gpus_per_node=8 checkpointing.checkpoint_dir='results/sft_llama8b_2nodes' logger.wandb_enabled=True logger.wandb.name='sft-llama8b'" \ +COMMAND="uv pip install -e .; uv run ./examples/run_sft.py --config examples/configs/sft.yaml cluster.num_nodes=2 cluster.gpus_per_node=8 checkpointing.checkpoint_dir='results/sft_llama1b_2nodes' logger.wandb_enabled=True logger.wandb.name='sft-llama1b'" \ UV_CACHE_DIR=YOUR_UV_CACHE_DIR \ CONTAINER=YOUR_CONTAINER \ MOUNTS="$PWD:$PWD" \ @@ -182,11 +181,11 @@ uv run python examples/run_dpo.py This trains `Llama3.2-1B-Instruct` on one GPU. -If you have access to more GPUs, you can update the experiment accordingly. 
To run on 8 GPUs, we update the cluster configuration and switch to an 8B Llama3.2 Instruct model: +If you have access to more GPUs, you can update the experiment accordingly. To run on 8 GPUs, we update the cluster configuration and switch to an 8B Llama3.1 Instruct model: ```sh uv run python examples/run_dpo.py \ - policy.model_name="meta-llama/Meta-Llama-3-8B-Instruct" \ + policy.model_name="meta-llama/Llama-3.1-8B-Instruct" \ policy.train_global_batch_size=256 \ cluster.gpus_per_node=8 ``` @@ -215,8 +214,7 @@ NUM_ACTOR_NODES=2 # Add a timestamp to make each job name unique TIMESTAMP=$(date +%Y%m%d_%H%M%S) -# DPO experiment uses Llama-3.1-8B model -COMMAND="uv run ./examples/run_dpo.py --config examples/configs/dpo.yaml cluster.num_nodes=2 cluster.gpus_per_node=8 checkpointing.checkpoint_dir='results/dpo_llama8b_2nodes' logger.wandb_enabled=True logger.wandb.name='dpo-llama8b'" \ +COMMAND="uv run ./examples/run_dpo.py --config examples/configs/dpo.yaml cluster.num_nodes=2 cluster.gpus_per_node=8 dpo.val_global_batch_size=32 checkpointing.checkpoint_dir='results/dpo_llama81_2nodes' logger.wandb_enabled=True logger.wandb.name='dpo-llama1b'" \ RAY_DEDUP_LOGS=0 \ UV_CACHE_DIR=YOUR_UV_CACHE_DIR \ CONTAINER=YOUR_CONTAINER \ From fbc1f6dbfdc192897e3c70b63e406b6486d85433 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Mon, 21 Apr 2025 17:45:14 -0700 Subject: [PATCH 56/57] fix issues with merge Signed-off-by: ashors1 --- README.md | 1 - .../models/policy/dtensor_policy_worker.py | 87 +++++++++---------- 2 files changed, 41 insertions(+), 47 deletions(-) diff --git a/README.md b/README.md index 03aaa9127c..eef9fc1b39 100644 --- a/README.md +++ b/README.md @@ -151,7 +151,6 @@ NUM_ACTOR_NODES=2 # Add a timestamp to make each job name unique TIMESTAMP=$(date +%Y%m%d_%H%M%S) -# SFT experiment uses Llama-3.1-8B model COMMAND="uv run ./examples/run_sft.py --config examples/configs/sft.yaml cluster.num_nodes=2 cluster.gpus_per_node=8 checkpointing.checkpoint_dir='results/sft_llama8b_2nodes' logger.wandb_enabled=True logger.wandb.name='sft-llama8b'" \ CONTAINER=YOUR_CONTAINER \ MOUNTS="$PWD:$PWD" \ diff --git a/nemo_reinforcer/models/policy/dtensor_policy_worker.py b/nemo_reinforcer/models/policy/dtensor_policy_worker.py index 579efe824e..526f0f093a 100644 --- a/nemo_reinforcer/models/policy/dtensor_policy_worker.py +++ b/nemo_reinforcer/models/policy/dtensor_policy_worker.py @@ -334,22 +334,10 @@ def train( mb_losses.append(loss.item()) all_mb_metrics.append(loss_metrics) + grad_norm = None if not eval_mode: - loss.backward() - mb_losses.append(loss.item()) - all_mb_metrics.append(loss_metrics) - - grad_norm = None - if not eval_mode: - with torch.no_grad(): - grad_norm = get_grad_norm( - self.model.parameters(), - dp_group=self.dp_mesh.get_group(), - tp_group=self.tp_mesh.get_group(), - dtype=torch.float32, - ) - if self.max_grad_norm is not None: - clip_grad_by_total_norm_( + with torch.no_grad(): + grad_norm = get_grad_norm( self.model.parameters(), dp_group=self.dp_mesh.get_group(), tp_group=self.tp_mesh.get_group(), @@ -358,39 +346,46 @@ def train( if self.max_grad_norm is not None: clip_grad_by_total_norm_( self.model.parameters(), - max_grad_norm=self.max_grad_norm, - total_norm=grad_norm, + dp_group=self.dp_mesh.get_group(), + tp_group=self.tp_mesh.get_group(), dtype=torch.float32, ) - - # Update parameters - self.optimizer.step() - self.scheduler.step() - - losses.append(torch.tensor(mb_losses).sum().item()) - - # Compute global loss across all ranks - with torch.no_grad(): - local_loss = 
torch.tensor(losses, device="cuda") - global_loss = torch.zeros_like(local_loss) - torch.distributed.all_reduce(local_loss, group=self.dp_mesh.get_group()) - global_loss = local_loss / self.dp_size - - # Aggregate metrics across all microbatches - mb_metrics = defaultdict(list) - for m in all_mb_metrics: - for k, v in m.items(): - mb_metrics[k].append(v) - - metrics = { - "global_loss": global_loss.cpu(), - "local_loss": local_loss.cpu(), - "grad_norm": grad_norm, - "rank": torch.distributed.get_rank(), - "all_mb_metrics": dict(mb_metrics), - } - - return metrics + if self.max_grad_norm is not None: + clip_grad_by_total_norm_( + self.model.parameters(), + max_grad_norm=self.max_grad_norm, + total_norm=grad_norm, + dtype=torch.float32, + ) + + # Update parameters + self.optimizer.step() + self.scheduler.step() + + losses.append(torch.tensor(mb_losses).sum().item()) + + # Compute global loss across all ranks + with torch.no_grad(): + local_loss = torch.tensor(losses, device="cuda") + global_loss = torch.zeros_like(local_loss) + torch.distributed.all_reduce(local_loss, group=self.dp_mesh.get_group()) + global_loss = local_loss / self.dp_size + + # Aggregate metrics across all microbatches + mb_metrics = defaultdict(list) + for m in all_mb_metrics: + for k, v in m.items(): + mb_metrics[k].append(v) + + metrics = { + "global_loss": global_loss.cpu(), + "local_loss": local_loss.cpu(), + "grad_norm": grad_norm, + "rank": torch.distributed.get_rank(), + "all_mb_metrics": dict(mb_metrics), + } + + return metrics def get_logprobs(self, data: BatchedDataDict) -> BatchedDataDict: """Get the logprobs of the model for a batch of data. From 01c4f9012139acf1fb117de75ce1ac86e6c3cb16 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Mon, 21 Apr 2025 19:05:24 -0700 Subject: [PATCH 57/57] fix issue with rebase, add functional dpo test to ci Signed-off-by: ashors1 --- .github/workflows/cicd-main.yml | 1 + .../models/policy/dtensor_policy_worker.py | 15 +++++---------- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index 7a80e5dcd7..54d81f46ee 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -170,6 +170,7 @@ jobs: if [[ "${{ needs.pre-flight.outputs.test_level }}" =~ ^(L1|L2)$ ]]; then uv run --no-sync bash ./tests/functional/sft.sh uv run --no-sync bash ./tests/functional/grpo.sh + uv run --no-sync bash ./tests/functional/dpo.sh else echo Skipping functional tests for level ${{ needs.pre-flight.outputs.test_level }} fi diff --git a/nemo_reinforcer/models/policy/dtensor_policy_worker.py b/nemo_reinforcer/models/policy/dtensor_policy_worker.py index 526f0f093a..2c4bd78efd 100644 --- a/nemo_reinforcer/models/policy/dtensor_policy_worker.py +++ b/nemo_reinforcer/models/policy/dtensor_policy_worker.py @@ -346,17 +346,10 @@ def train( if self.max_grad_norm is not None: clip_grad_by_total_norm_( self.model.parameters(), - dp_group=self.dp_mesh.get_group(), - tp_group=self.tp_mesh.get_group(), + max_grad_norm=self.max_grad_norm, + total_norm=grad_norm, dtype=torch.float32, ) - if self.max_grad_norm is not None: - clip_grad_by_total_norm_( - self.model.parameters(), - max_grad_norm=self.max_grad_norm, - total_norm=grad_norm, - dtype=torch.float32, - ) # Update parameters self.optimizer.step() @@ -387,7 +380,9 @@ def train( return metrics - def get_logprobs(self, data: BatchedDataDict) -> BatchedDataDict: + def get_logprobs( + self, data: BatchedDataDict, micro_batch_size: int = None + ) -> BatchedDataDict: 
"""Get the logprobs of the model for a batch of data. Uses the configured logprob_batch_size to do microbatching.