From df8c20851075fc0b3f457ac4e727680c85cf3b1c Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Sat, 7 Feb 2026 00:10:49 -0800 Subject: [PATCH 01/15] update ClippedPGLossFn, NLLLoss, DPOLossFn Signed-off-by: Yuki Huang --- nemo_rl/algorithms/loss_functions.py | 135 ++------------------------- nemo_rl/distributed/model_utils.py | 44 +++++++++ nemo_rl/models/automodel/train.py | 11 ++- 3 files changed, 60 insertions(+), 130 deletions(-) diff --git a/nemo_rl/algorithms/loss_functions.py b/nemo_rl/algorithms/loss_functions.py index 1a275146d2..7b1d4d16d0 100755 --- a/nemo_rl/algorithms/loss_functions.py +++ b/nemo_rl/algorithms/loss_functions.py @@ -25,9 +25,7 @@ ChunkedDistributedGatherLogprob, _get_tokens_on_this_cp_rank, allgather_cp_sharded_tensor, - from_parallel_logits_to_logprobs, gather_logits_at_global_indices, - get_logprobs_from_vocab_parallel_logits, ) Tensor = TypeVar("Tensor", bound=torch.Tensor) @@ -198,13 +196,10 @@ def __init__(self, cfg: ClippedPGLossConfig): def __call__( self, - next_token_logits: Tensor, + curr_logprobs: Tensor, data: BatchedDataDict[ClippedPGLossDataDict], global_valid_seqs: torch.Tensor, global_valid_toks: torch.Tensor, - vocab_parallel_rank: Optional[int] = None, - vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None, - context_parallel_group: Optional[torch.distributed.ProcessGroup] = None, ) -> tuple[torch.Tensor, dict]: """Clipped Policy Gradient RL loss function.""" token_mask = data["token_mask"][:, 1:] @@ -214,7 +209,6 @@ def __call__( generation_logprobs = data["generation_logprobs"][:, 1:] if self.reference_policy_kl_penalty != 0: reference_policy_logprobs = data["reference_policy_logprobs"][:, 1:] - seq_index = data.get("seq_index", None) mask = token_mask * sample_mask.unsqueeze(-1) @@ -282,39 +276,6 @@ def __call__( global_normalization_factor=global_valid_toks, ).item() - next_token_logits = next_token_logits.to(torch.float32) - - if vocab_parallel_group is not None: - assert vocab_parallel_rank is not None, ( - "vocab_parallel_rank must be provided when vocab_parallel_group is provided" - ) - curr_logprobs = from_parallel_logits_to_logprobs( - next_token_logits, - data["input_ids"], - vocab_start_index=vocab_parallel_rank * next_token_logits.shape[-1], - vocab_end_index=(vocab_parallel_rank + 1) * next_token_logits.shape[-1], - tp_group=vocab_parallel_group, - inference_only=False, - cp_group=context_parallel_group, - ) - # slice off to the correct length to remove potential CP padding - curr_logprobs = curr_logprobs[:, : data["input_ids"].shape[1] - 1] - elif isinstance(next_token_logits, torch.distributed.tensor.DTensor): - curr_logprobs = get_logprobs_from_vocab_parallel_logits( - next_token_logits, data["input_ids"], seq_index=seq_index - ) - else: - next_token_logits_wo_last = next_token_logits[ - :, :-1 - ] # Remove last position's logits - next_token_logprobs = torch.nn.functional.log_softmax( - next_token_logits_wo_last, dim=-1 - ) - next_tokens = data["input_ids"][:, 1:].cuda() # Skip first token - curr_logprobs = next_token_logprobs.gather( - dim=-1, index=next_tokens.unsqueeze(-1) - ).squeeze(-1) - # Calculate KL regularization. if self.reference_policy_kl_penalty != 0: if self.use_on_policy_kl_approximation: @@ -606,13 +567,10 @@ class NLLLoss(LossFunction): def __call__( self, - next_token_logits: Tensor, + token_logprobs: Tensor, data: BatchedDataDict[Any], global_valid_seqs: Tensor | None, global_valid_toks: Tensor, - vocab_parallel_rank: Optional[int] = None, - vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None, - context_parallel_group: Optional[torch.distributed.ProcessGroup] = None, dpo_loss: bool = False, dpo_average_log_probs: bool = False, ) -> tuple[torch.Tensor, dict[str, Any]]: @@ -621,39 +579,6 @@ def __call__( token_mask = data["token_mask"][:, 1:] sample_mask = data["sample_mask"] mask = token_mask * sample_mask.unsqueeze(-1) - seq_index = data.get("seq_index", None) - - next_token_logits = next_token_logits.to(torch.float32) - - # Gather the logprobs for the actual next tokens - if vocab_parallel_group is not None: - assert vocab_parallel_rank is not None, ( - "vocab_parallel_rank must be provided when vocab_parallel_group is provided" - ) - token_logprobs = from_parallel_logits_to_logprobs( - next_token_logits, - data["input_ids"], - vocab_start_index=vocab_parallel_rank * next_token_logits.shape[-1], - vocab_end_index=(vocab_parallel_rank + 1) * next_token_logits.shape[-1], - tp_group=vocab_parallel_group, - inference_only=False, - cp_group=context_parallel_group, - ) - # slice off to the correct length to remove potential CP padding - token_logprobs = token_logprobs[:, : data["input_ids"].shape[1] - 1] - elif isinstance(next_token_logits, torch.distributed.tensor.DTensor): - token_logprobs = get_logprobs_from_vocab_parallel_logits( - next_token_logits, data["input_ids"], seq_index=seq_index - ) - else: - next_tokens = data["input_ids"][:, 1:].cuda() # Skip first token - next_token_logprobs = torch.nn.functional.log_softmax( - next_token_logits, dim=-1 - ) - logprobs = next_token_logprobs[:, :-1] # Remove last position's logits - token_logprobs = logprobs.gather( - dim=-1, index=next_tokens.unsqueeze(-1) - ).squeeze(-1) if dpo_loss: ## shape: [batch_size] @@ -867,50 +792,15 @@ def __init__(self, cfg: DPOLossConfig): def _dpo_loss( self, - next_token_logits: Tensor, + token_logprobs: Tensor, data: BatchedDataDict[DPOLossDataDict], global_valid_seqs: Tensor, - vocab_parallel_rank: Optional[int] = None, - vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None, - context_parallel_group: Optional[torch.distributed.ProcessGroup] = None, ) -> tuple[Tensor, Tensor, Tensor, Tensor]: ## TODO(@ashors): there's some duplicate code here with the NLLLoss function. We should refactor token_mask = data["token_mask"][:, 1:] sample_mask = data["sample_mask"] - seq_index = data.get("seq_index", None) - - next_token_logits = next_token_logits.to(torch.float32) - if vocab_parallel_group is not None: - assert vocab_parallel_rank is not None, ( - "vocab_parallel_rank must be provided when vocab_parallel_group is provided" - ) - token_logprobs = from_parallel_logits_to_logprobs( - next_token_logits, - data["input_ids"], - vocab_start_index=vocab_parallel_rank * next_token_logits.shape[-1], - vocab_end_index=(vocab_parallel_rank + 1) * next_token_logits.shape[-1], - tp_group=vocab_parallel_group, - inference_only=False, - cp_group=context_parallel_group, - ) - # slice off to the correct length to remove potential CP padding - token_logprobs = token_logprobs[:, : data["input_ids"].shape[1] - 1] - elif isinstance(next_token_logits, torch.distributed.tensor.DTensor): - token_logprobs = get_logprobs_from_vocab_parallel_logits( - next_token_logits, data["input_ids"], seq_index=seq_index - ) - else: - next_tokens = data["input_ids"][:, 1:].cuda() # Skip first token - next_token_logprobs = torch.nn.functional.log_softmax( - next_token_logits, dim=-1 - ) - logprobs = next_token_logprobs[:, :-1] # Remove last position's logits - token_logprobs = logprobs.gather( - dim=-1, index=next_tokens.unsqueeze(-1) - ).squeeze(-1) ref_logprobs = data["reference_policy_logprobs"][:, :-1] - diff = (token_logprobs - ref_logprobs) * token_mask rewards = diff.sum(-1) @@ -924,13 +814,10 @@ def _dpo_loss( # TODO a cleaner typing fix would be required (probably that DPOLossFn should not inherit from PreferenceLoss) def __call__( # type: ignore self, - next_token_logits: Tensor, + token_logprobs: Tensor, data: BatchedDataDict[DPOLossDataDict], global_valid_seqs: Tensor, global_valid_toks: Tensor | None, - vocab_parallel_rank: Optional[int] = None, - vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None, - context_parallel_group: Optional[torch.distributed.ProcessGroup] = None, ) -> tuple[torch.Tensor, dict[str, Any]]: sft_loss_chosen = torch.tensor(0.0) if self.sft_loss_weight > 0: @@ -938,13 +825,10 @@ def __call__( # type: ignore "global_valid_toks must be provided for SFT loss" ) sft_loss, _ = self.sft_loss( - next_token_logits, + token_logprobs, data, global_valid_seqs=global_valid_seqs, global_valid_toks=global_valid_toks, ## unused because sft loss returned is at the sample level - vocab_parallel_rank=vocab_parallel_rank, - vocab_parallel_group=vocab_parallel_group, - context_parallel_group=context_parallel_group, dpo_loss=True, dpo_average_log_probs=self.sft_average_log_probs, ) @@ -960,14 +844,7 @@ def __call__( # type: ignore accuracy, rewards_chosen_mean, rewards_rejected_mean, - ) = self._dpo_loss( - next_token_logits, - data, - global_valid_seqs, - vocab_parallel_rank=vocab_parallel_rank, - vocab_parallel_group=vocab_parallel_group, - context_parallel_group=context_parallel_group, - ) + ) = self._dpo_loss(token_logprobs, data, global_valid_seqs) dpo_loss = ( self.sft_loss_weight * sft_loss_chosen diff --git a/nemo_rl/distributed/model_utils.py b/nemo_rl/distributed/model_utils.py index fb17ee1661..b012777279 100644 --- a/nemo_rl/distributed/model_utils.py +++ b/nemo_rl/distributed/model_utils.py @@ -825,6 +825,50 @@ def get_logprobs_from_vocab_parallel_logits( ) +def get_logprobs_from_logits( + input_ids: torch.Tensor, + next_token_logits: torch.Tensor, + seq_index: Optional[torch.Tensor] = None, + vocab_parallel_rank: Optional[int] = None, + vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None, + context_parallel_group: Optional[torch.distributed.ProcessGroup] = None, +): + """Computes log probabilities from logits.""" + next_token_logits = next_token_logits.to(torch.float32) + + if vocab_parallel_group is not None: + assert vocab_parallel_rank is not None, ( + "vocab_parallel_rank must be provided when vocab_parallel_group is provided" + ) + logprobs = from_parallel_logits_to_logprobs( + next_token_logits, + input_ids, + vocab_start_index=vocab_parallel_rank * next_token_logits.shape[-1], + vocab_end_index=(vocab_parallel_rank + 1) * next_token_logits.shape[-1], + tp_group=vocab_parallel_group, + inference_only=False, + cp_group=context_parallel_group, + ) + # slice off to the correct length to remove potential CP padding + logprobs = logprobs[:, : input_ids.shape[1] - 1] + elif isinstance(next_token_logits, torch.distributed.tensor.DTensor): + logprobs = get_logprobs_from_vocab_parallel_logits( + next_token_logits, input_ids, seq_index=seq_index + ) + else: + # Remove last position's logits + next_token_logits_wo_last = next_token_logits[:, :-1] + next_token_logprobs = torch.nn.functional.log_softmax( + next_token_logits_wo_last, dim=-1 + ) + next_tokens = input_ids[:, 1:].cuda() # Skip first token + logprobs = next_token_logprobs.gather( + dim=-1, index=next_tokens.unsqueeze(-1) + ).squeeze(-1) + + return logprobs + + @torch.no_grad() def distributed_vocab_topk( vocab_parallel_logits: torch.Tensor, diff --git a/nemo_rl/models/automodel/train.py b/nemo_rl/models/automodel/train.py index acbfec711e..6c7892e0ef 100644 --- a/nemo_rl/models/automodel/train.py +++ b/nemo_rl/models/automodel/train.py @@ -37,6 +37,7 @@ from nemo_rl.distributed.model_utils import ( allgather_cp_sharded_tensor, distributed_vocab_topk, + get_logprobs_from_logits, get_logprobs_from_vocab_parallel_logits, ) from nemo_rl.models.automodel.data import ProcessedInputs, ProcessedMicrobatch @@ -513,6 +514,14 @@ def __call__( logits, self.device_mesh, self.cp_mesh, sequence_dim ) + # Compute logprobs from logits + logprobs = get_logprobs_from_logits( + input_ids=mb["input_ids"], + next_token_logits=logits, + seq_index=mb.get("seq_index", None), + ) + del logits + # Wrap loss function for sequence packing if needed if self.enable_seq_packing: loss_fn_ = SequencePackingLossWrapper( @@ -524,7 +533,7 @@ def __call__( loss_fn_ = self.loss_fn loss, loss_metrics = loss_fn_( - logits, + logprobs, mb, global_valid_seqs, global_valid_toks, From 30b91f6d5c2c5ca3a0270a8cfdef8f20d94eab55 Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Mon, 9 Feb 2026 06:14:39 -0800 Subject: [PATCH 02/15] fix seq packing Signed-off-by: Yuki Huang --- nemo_rl/algorithms/loss_functions.py | 25 ++++++++------ nemo_rl/models/automodel/train.py | 51 ++++++++++++++++++++-------- 2 files changed, 51 insertions(+), 25 deletions(-) diff --git a/nemo_rl/algorithms/loss_functions.py b/nemo_rl/algorithms/loss_functions.py index 7b1d4d16d0..257b3818a4 100755 --- a/nemo_rl/algorithms/loss_functions.py +++ b/nemo_rl/algorithms/loss_functions.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import math -from typing import Any, NotRequired, Optional, TypedDict, TypeVar +from typing import Any, Callable, NotRequired, Optional, TypedDict, TypeVar import torch import torch.distributed @@ -869,12 +869,20 @@ class SequencePackingLossWrapper: def __init__( self, loss_fn: LossFunction, + prepare_fn: Callable[Any, Any], cu_seqlens_q: Tensor, cu_seqlens_q_padded: Optional[Tensor] = None, + vocab_parallel_rank: Optional[int] = None, + vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None, + context_parallel_group: Optional[torch.distributed.ProcessGroup] = None, ): self.loss_fn = loss_fn + self.prepare_fn = prepare_fn self.cu_seqlens_q = cu_seqlens_q self.cu_seqlens_q_padded = cu_seqlens_q_padded + self.vocab_parallel_rank = vocab_parallel_rank + self.vocab_parallel_group = vocab_parallel_group + self.context_parallel_group = context_parallel_group def __call__( self, @@ -882,9 +890,6 @@ def __call__( data: BatchedDataDict[Any], global_valid_seqs: Tensor | None, global_valid_toks: Tensor | None, - vocab_parallel_rank: Optional[int] = None, - vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None, - context_parallel_group: Optional[torch.distributed.ProcessGroup] = None, ) -> tuple[Tensor, dict[str, Any]]: """Wraps a loss function to handle sequence packing by doing one sequence at a time to avoid excessive padding.""" unpadded_cu_seqlens = self.cu_seqlens_q @@ -918,8 +923,8 @@ def __call__( # get next_token_logits cp_size = ( 1 - if context_parallel_group is None - else torch.distributed.get_world_size(context_parallel_group) + if self.context_parallel_group is None + else torch.distributed.get_world_size(self.context_parallel_group) ) logit_start = seq_start // cp_size logit_end = (seq_start + padded_seq_lengths[seq_idx]) // cp_size @@ -928,14 +933,14 @@ def __call__( 1, logit_start, logit_length ) + # prepare data for loss function + loss_fn_args = self.prepare_fn(next_token_logits_slice, unpadded_seq_data) + loss, metrics = self.loss_fn( - next_token_logits_slice, + *loss_fn_args, unpadded_seq_data, global_valid_seqs, global_valid_toks, - vocab_parallel_rank=vocab_parallel_rank, - vocab_parallel_group=vocab_parallel_group, - context_parallel_group=context_parallel_group, ) loss_accum += loss for k, v in metrics.items(): diff --git a/nemo_rl/models/automodel/train.py b/nemo_rl/models/automodel/train.py index 6c7892e0ef..cb8cc1c939 100644 --- a/nemo_rl/models/automodel/train.py +++ b/nemo_rl/models/automodel/train.py @@ -505,6 +505,12 @@ def __call__( Returns: Tuple of (loss, metrics) """ + from nemo_rl.algorithms.loss_functions import ( + ClippedPGLossFn, + DPOLossFn, + NLLLoss, + ) + # Handle CP redistribution if self.cp_size > 1: _, mb = prepare_data_for_cp( @@ -514,30 +520,45 @@ def __call__( logits, self.device_mesh, self.cp_mesh, sequence_dim ) - # Compute logprobs from logits - logprobs = get_logprobs_from_logits( - input_ids=mb["input_ids"], - next_token_logits=logits, - seq_index=mb.get("seq_index", None), - ) - del logits + # Prepare data for loss function + def prepare_for_loss_fn( + logits: torch.Tensor, mb: BatchedDataDict[Any] + ) -> tuple[Any]: + if isinstance(self.loss_fn, (ClippedPGLossFn, NLLLoss, DPOLossFn)): + logprobs = get_logprobs_from_logits( + input_ids=mb["input_ids"], + next_token_logits=logits, + seq_index=mb.get("seq_index", None), + ) + + loss_fn_args = (logprobs,) + + # TODO: PreferenceLoss, DistillationLossFn + + return loss_fn_args # Wrap loss function for sequence packing if needed if self.enable_seq_packing: loss_fn_ = SequencePackingLossWrapper( loss_fn=self.loss_fn, + prepare_fn=prepare_for_loss_fn, cu_seqlens_q=processed_inputs.flash_attn_kwargs.cu_seqlens_q, cu_seqlens_q_padded=processed_inputs.flash_attn_kwargs.cu_seqlens_q, ) + loss, loss_metrics = loss_fn_( + logits, + mb, + global_valid_seqs, + global_valid_toks, + ) else: - loss_fn_ = self.loss_fn - - loss, loss_metrics = loss_fn_( - logprobs, - mb, - global_valid_seqs, - global_valid_toks, - ) + loss_fn_args = prepare_for_loss_fn(logits, mb) + loss, loss_metrics = self.loss_fn( + *loss_fn_args, + mb, + global_valid_seqs, + global_valid_toks, + ) return loss, loss_metrics From c7de9e185d43febd285cb91fc76cc1cef026c50e Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Mon, 9 Feb 2026 22:04:30 -0800 Subject: [PATCH 03/15] update unit test for sft/rl/dpo and add value check for distillation Signed-off-by: Yuki Huang --- tests/unit/algorithms/test_loss_functions.py | 245 +++++++------------ 1 file changed, 92 insertions(+), 153 deletions(-) diff --git a/tests/unit/algorithms/test_loss_functions.py b/tests/unit/algorithms/test_loss_functions.py index fbec4c8504..7e452804f5 100644 --- a/tests/unit/algorithms/test_loss_functions.py +++ b/tests/unit/algorithms/test_loss_functions.py @@ -26,6 +26,7 @@ ) from nemo_rl.algorithms.utils import calculate_kl, masked_mean from nemo_rl.distributed.batched_data_dict import BatchedDataDict +from nemo_rl.distributed.model_utils import get_logprobs_from_logits basic_pg_loss_test_config: ClippedPGLossConfig = { "ratio_clip_min": 0.2, @@ -91,8 +92,9 @@ def test_nll_loss(): .unsqueeze(0) .to("cuda") ) + token_logprobs = get_logprobs_from_logits(data["input_ids"], next_token_logits) loss, metrics_dict = loss_fn( - next_token_logits, + token_logprobs, data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum( @@ -116,8 +118,9 @@ def test_nll_loss(): .unsqueeze(0) .to("cuda") ) + token_logprobs = get_logprobs_from_logits(data["input_ids"], next_token_logits) loss, metrics_dict = loss_fn( - next_token_logits, + token_logprobs, data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum( @@ -151,8 +154,9 @@ def test_dpo_loss(): } ) + token_logprobs = get_logprobs_from_logits(data["input_ids"], next_token_logits) loss, metrics_dict = loss_fn( - next_token_logits, + token_logprobs, data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum( @@ -185,7 +189,7 @@ def test_dpo_loss(): expected_preference_loss = -torch.nn.functional.logsigmoid(torch.tensor(0.0)) assert torch.isclose( loss_fn_with_sft( - next_token_logits, + token_logprobs, data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum( @@ -260,16 +264,17 @@ def test_dpo_loss_varying_sequence_lengths(): "sample_mask": sample_mask, } ) + token_logprobs = get_logprobs_from_logits(data["input_ids"], next_token_logits) # Compute loss loss, metrics = dpo_loss_fn_no_avg( - next_token_logits, + token_logprobs, data, global_valid_seqs=torch.sum(sample_mask), global_valid_toks=torch.sum(sample_mask.unsqueeze(-1) * token_mask), ) loss_avg, metrics_avg = dpo_loss_fn_avg( - next_token_logits, + token_logprobs, data, global_valid_seqs=torch.sum(sample_mask), global_valid_toks=torch.sum(sample_mask.unsqueeze(-1) * token_mask), @@ -322,8 +327,11 @@ def test_dpo_sft_matches_nll_loss(): # Compute NLL loss nll_loss_fn = NLLLoss() + token_logprobs = get_logprobs_from_logits( + sft_data["input_ids"], next_token_logits[::2] + ) nll_loss, nll_metrics = nll_loss_fn( - next_token_logits[::2], + token_logprobs, sft_data, global_valid_seqs=None, global_valid_toks=torch.sum( @@ -341,8 +349,9 @@ def test_dpo_sft_matches_nll_loss(): "sft_average_log_probs": False, } ) + token_logprobs = get_logprobs_from_logits(dpo_data["input_ids"], next_token_logits) dpo_loss, dpo_metrics = dpo_loss_fn( - next_token_logits, + token_logprobs, dpo_data, global_valid_seqs=torch.sum(dpo_data["sample_mask"]), global_valid_toks=torch.sum( @@ -504,9 +513,10 @@ def test_clipped_pg_loss_ppo_clipping(): dummy_logits = _create_exact_logits( curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device ) + current_logprobs = get_logprobs_from_logits(input_ids, dummy_logits) actual_loss, _ = loss_fn( - dummy_logits, + current_logprobs, data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum(data["sample_mask"] * data["token_mask"]), @@ -551,9 +561,10 @@ def test_clipped_pg_loss_reinforce_mode(): dummy_logits = _create_exact_logits( curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device ) + current_logprobs = get_logprobs_from_logits(input_ids, dummy_logits) actual_loss, _ = loss_fn( - dummy_logits, + current_logprobs, data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum( @@ -596,9 +607,10 @@ def test_clipped_pg_loss_force_on_policy_ratio(): dummy_logits = _create_exact_logits( curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device ) + current_logprobs = get_logprobs_from_logits(input_ids, dummy_logits) actual_loss, metrics = loss_fn( - dummy_logits, + current_logprobs, data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum( @@ -706,9 +718,10 @@ def test_clipped_pg_loss_kl_penalty(): dummy_logits = _create_exact_logits( curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device ) + current_logprobs = get_logprobs_from_logits(input_ids, dummy_logits) actual_loss, _ = loss_fn( - dummy_logits, + current_logprobs, data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum( @@ -734,6 +747,8 @@ def test_clipped_pg_loss_masking(): ) # Need some realistic-ish logits and logprobs for masking test dummy_logits = torch.randn(batch_size, seq_len, vocab_size, device=device) + current_logprobs = get_logprobs_from_logits(data["input_ids"], dummy_logits) + # Ensure logprobs used by the loss fn make sense relative to advantages data["prev_logprobs"] = torch.randn_like(data["prev_logprobs"]) * 0.1 data["reference_policy_logprobs"] = ( @@ -749,7 +764,7 @@ def test_clipped_pg_loss_masking(): # --- Test 1: Token Mask --- # Default mask: [[0, 1, 1, 1], [0, 1, 1, 1]] -> 3 tokens per sample loss_default, _ = loss_fn( - dummy_logits, + current_logprobs, data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum( @@ -765,7 +780,7 @@ def test_clipped_pg_loss_masking(): ) loss_token_masked, _ = loss_fn( - dummy_logits, + current_logprobs, data_mod_token, global_valid_seqs=torch.sum(data_mod_token["sample_mask"]), global_valid_toks=torch.sum( @@ -784,7 +799,7 @@ def test_clipped_pg_loss_masking(): ) # Ignore item 1 loss_sample_masked, _ = loss_fn( - dummy_logits, + current_logprobs, data_mod_sample, global_valid_seqs=torch.sum(data_mod_sample["sample_mask"]), global_valid_toks=torch.sum( @@ -805,8 +820,11 @@ def test_clipped_pg_loss_masking(): data_only_b0 = BatchedDataDict(data_only_b0_dict) logits_only_b0 = dummy_logits[0:1] + current_logprobs_only_b0 = get_logprobs_from_logits( + data_only_b0["input_ids"], logits_only_b0 + ) loss_only_b0, _ = loss_fn( - logits_only_b0, + current_logprobs_only_b0, data_only_b0, global_valid_seqs=torch.sum(data_only_b0["sample_mask"]), global_valid_toks=torch.sum( @@ -826,6 +844,7 @@ def test_clipped_pg_loss_zero_mask(): data, batch_size, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device) # Need dummy logits dummy_logits = torch.randn(1, seq_len, vocab_size, device=device) + current_logprobs = get_logprobs_from_logits(data["input_ids"], dummy_logits) cfg = deepcopy(basic_pg_loss_test_config) cfg["reference_policy_kl_penalty"] = 0.1 @@ -835,7 +854,7 @@ def test_clipped_pg_loss_zero_mask(): data["token_mask"] = torch.zeros_like(data["token_mask"]) loss, _ = loss_fn( - dummy_logits, + current_logprobs, data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum( @@ -980,9 +999,10 @@ def test_clipped_pg_loss_on_policy_kl_importance_sampling(): dummy_logits = _create_exact_logits( curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device ) + current_logprobs = get_logprobs_from_logits(input_ids, dummy_logits) actual_loss, _ = loss_fn( - dummy_logits, + current_logprobs, data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum(data["sample_mask"] * data["token_mask"]), @@ -1112,9 +1132,10 @@ def test_clipped_pg_loss_on_policy_truncated_importance_sampling( dummy_logits = _create_exact_logits( curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device ) + current_logprobs = get_logprobs_from_logits(input_ids, dummy_logits) actual_loss, _ = loss_fn( - dummy_logits, + current_logprobs, data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum(data["sample_mask"] * data["token_mask"]), @@ -1333,9 +1354,10 @@ def test_clipped_pg_loss_dual_clip(): dummy_logits = _create_exact_logits( curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device ) + current_logprobs = get_logprobs_from_logits(input_ids, dummy_logits) actual_loss, _ = loss_fn( - dummy_logits, + current_logprobs, data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum( @@ -1382,8 +1404,10 @@ def test_clipped_pg_loss_entropy(): dummy_logits = _create_exact_logits( curr_lp_masked, data["input_ids"], batch_size, seq_len, vocab_size, device ) + current_logprobs = get_logprobs_from_logits(data["input_ids"], dummy_logits) + _, metrics = loss_fn( - dummy_logits, + current_logprobs, data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum(data["sample_mask"] * data["token_mask"]), @@ -1465,9 +1489,10 @@ def test_clipped_pg_loss_gspo(): dummy_logits = _create_exact_logits( curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device ) + current_logprobs = get_logprobs_from_logits(input_ids, dummy_logits) actual_loss, _ = loss_fn( - dummy_logits, + current_logprobs, data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum(data["sample_mask"] * data["token_mask"]), @@ -1563,9 +1588,10 @@ def test_clipped_pg_loss_gspo_batch_size_2(): dummy_logits = _create_exact_logits( curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device ) + current_logprobs = get_logprobs_from_logits(input_ids, dummy_logits) actual_loss, _ = loss_fn( - dummy_logits, + current_logprobs, data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum( @@ -1664,9 +1690,10 @@ def test_clipped_pg_loss_gspo_importance_sampling_correction(): dummy_logits = _create_exact_logits( curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device ) + current_logprobs = get_logprobs_from_logits(input_ids, dummy_logits) actual_loss, _ = loss_fn( - dummy_logits, + current_logprobs, data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum(data["sample_mask"] * data["token_mask"]), @@ -1681,6 +1708,10 @@ def setup_distillation_test_data(batch_size=2, seq_len=4, vocab_size=8, topk=64) device = "cuda" + # Set seed for reproducibility + torch.manual_seed(42) + torch.cuda.manual_seed_all(42) + # Create input data input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=device) input_lengths = torch.tensor([seq_len] * batch_size, device=device) @@ -1708,15 +1739,17 @@ def setup_distillation_test_data(batch_size=2, seq_len=4, vocab_size=8, topk=64) return data, student_logits -def test_distillation_loss_forward_kl(): - """Test forward KL divergence loss calculation.""" +@pytest.mark.parametrize("kl_type", ["forward", "reverse", "mixed"]) +@pytest.mark.parametrize("zero_outside_topk", [True, False]) +def test_distillation_loss_different_settings(kl_type, zero_outside_topk): + """Test different distillation loss settings.""" data, student_logits = setup_distillation_test_data() loss_fn = DistillationLossFn( { - "kl_type": "forward", - "mixed_kl_weight": 0.5, - "zero_outside_topk": False, + "kl_type": kl_type, + "mixed_kl_weight": 0.3, + "zero_outside_topk": zero_outside_topk, } ) @@ -1729,57 +1762,38 @@ def test_distillation_loss_forward_kl(): ), ) - # Verify loss is a scalar tensor - assert loss.dim() == 0 - assert not torch.isnan(loss) - assert not torch.isinf(loss) + # Verify loss + if zero_outside_topk: + if kl_type == "forward": + assert torch.allclose(loss, torch.tensor(-0.9636520743370056)) + elif kl_type == "reverse": + assert torch.allclose(loss, torch.tensor(-490.5150451660156)) + elif kl_type == "mixed": + assert torch.allclose(loss, torch.tensor(-343.6496276855469)) + else: + if kl_type == "forward": + assert torch.allclose(loss, torch.tensor(0.5783048868179321)) + elif kl_type == "reverse": + assert torch.allclose(loss, torch.tensor(0.5811167359352112)) + elif kl_type == "mixed": + assert torch.allclose(loss, torch.tensor(0.5802732110023499)) # Verify metrics dictionary assert isinstance(metrics, dict) assert "loss" in metrics -def test_distillation_loss_reverse_kl(): - """Test reverse KL divergence loss calculation.""" - data, student_logits = setup_distillation_test_data() +@pytest.mark.parametrize("k", [1, 32, 64, 1000000]) +@pytest.mark.parametrize("zero_outside_topk", [True, False]) +def test_distillation_loss_topk_filtering(k, zero_outside_topk): + """Test top-k filtering functionality with various k values.""" + data, student_logits = setup_distillation_test_data(topk=k) loss_fn = DistillationLossFn( { - "kl_type": "reverse", + "kl_type": "forward", "mixed_kl_weight": 0.5, - "zero_outside_topk": False, - } - ) - - loss, metrics = loss_fn( - student_logits, - data, - global_valid_seqs=torch.sum(data["sample_mask"]), - global_valid_toks=torch.sum( - data["sample_mask"].unsqueeze(-1) * data["token_mask"] - ), - ) - - # Verify loss is a scalar tensor - assert loss.dim() == 0 - assert not torch.isnan(loss) - assert not torch.isinf(loss) - - # Verify metrics dictionary - assert isinstance(metrics, dict) - assert "loss" in metrics - - -def test_distillation_loss_mixed_kl(): - """Test mixed KL divergence loss calculation.""" - data, student_logits = setup_distillation_test_data() - - mixed_kl_weight = 0.3 - loss_fn = DistillationLossFn( - { - "kl_type": "mixed", - "mixed_kl_weight": mixed_kl_weight, - "zero_outside_topk": False, + "zero_outside_topk": zero_outside_topk, } ) @@ -1792,54 +1806,19 @@ def test_distillation_loss_mixed_kl(): ), ) - # Verify loss is a scalar tensor + # Verify loss is calculated correctly with top-k filtering assert loss.dim() == 0 assert not torch.isnan(loss) assert not torch.isinf(loss) - # Verify metrics dictionary - assert isinstance(metrics, dict) - assert "loss" in metrics + # For k=1, we expect only the top-1 token to be considered + if k == 1: + assert isinstance(loss, torch.Tensor) - -def test_distillation_loss_topk_filtering(): - """Test top-k filtering functionality with various k values.""" - # Test with different k values (excluding k=0 which should be invalid) - k_values = [1, 32, 64, 1000000] # Valid k values - - for k in k_values: - data, student_logits = setup_distillation_test_data(topk=k) - - loss_fn = DistillationLossFn( - { - "kl_type": "forward", - "mixed_kl_weight": 0.5, - "zero_outside_topk": False, - } - ) - - loss, metrics = loss_fn( - student_logits, - data, - global_valid_seqs=torch.sum(data["sample_mask"]), - global_valid_toks=torch.sum( - data["sample_mask"].unsqueeze(-1) * data["token_mask"] - ), - ) - - # Verify loss is calculated correctly with top-k filtering - assert loss.dim() == 0 - assert not torch.isnan(loss) - assert not torch.isinf(loss) - - # For k=1, we expect only the top-1 token to be considered - if k == 1: - assert isinstance(loss, torch.Tensor) - - # For large k values, we expect normal behavior - if k >= 32: - assert isinstance(loss, torch.Tensor) - assert loss.item() != 0.0 # Should have some meaningful loss + # For large k values, we expect normal behavior + if k >= 32: + assert isinstance(loss, torch.Tensor) + assert loss.item() != 0.0 # Should have some meaningful loss def test_distillation_loss_invalid_k_zero(): @@ -1867,46 +1846,6 @@ def test_distillation_loss_invalid_k_zero(): ) -def test_distillation_loss_zero_outside_topk(): - """Test zeroing outside top-k functionality with various k values.""" - # Test with different k values for zero_outside_topk (excluding k=0 which should be invalid) - k_values = [1, 32, 64, 1000000] # Valid k values - - for k in k_values: - data, student_logits = setup_distillation_test_data(topk=k) - - loss_fn = DistillationLossFn( - { - "kl_type": "forward", - "mixed_kl_weight": 0.5, - "zero_outside_topk": True, - } - ) - - loss, metrics = loss_fn( - student_logits, - data, - global_valid_seqs=torch.sum(data["sample_mask"]), - global_valid_toks=torch.sum( - data["sample_mask"].unsqueeze(-1) * data["token_mask"] - ), - ) - - # Verify loss is calculated correctly with zeroing - assert loss.dim() == 0 - assert not torch.isnan(loss) - assert not torch.isinf(loss) - - # For k=1, only top-1 token should remain non-zero - if k == 1: - assert isinstance(loss, torch.Tensor) - - # For large k values, most tokens should remain non-zero - if k >= 32: - assert isinstance(loss, torch.Tensor) - assert loss.item() != 0.0 # Should have some meaningful loss - - def test_distillation_loss_gradient_flow(): """Test gradient flow in distillation loss function.""" data, student_logits = setup_distillation_test_data() From cfe8e50395d300338cb33c1ae24e2aeaeebd52fe Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Mon, 9 Feb 2026 23:48:39 -0800 Subject: [PATCH 04/15] update PreferenceLoss and DistillationLossFn Signed-off-by: Yuki Huang --- nemo_rl/algorithms/loss_functions.py | 166 +----------------- nemo_rl/distributed/model_utils.py | 167 +++++++++++++++++++ nemo_rl/models/automodel/train.py | 25 ++- tests/unit/algorithms/test_loss_functions.py | 96 +++++++++-- 4 files changed, 275 insertions(+), 179 deletions(-) diff --git a/nemo_rl/algorithms/loss_functions.py b/nemo_rl/algorithms/loss_functions.py index 257b3818a4..fcc0b267fe 100755 --- a/nemo_rl/algorithms/loss_functions.py +++ b/nemo_rl/algorithms/loss_functions.py @@ -20,13 +20,6 @@ from nemo_rl.algorithms.interfaces import LossFunction, LossType from nemo_rl.algorithms.utils import calculate_kl, masked_mean from nemo_rl.distributed.batched_data_dict import BatchedDataDict -from nemo_rl.distributed.model_utils import ( - ChunkedDistributedEntropy, - ChunkedDistributedGatherLogprob, - _get_tokens_on_this_cp_rank, - allgather_cp_sharded_tensor, - gather_logits_at_global_indices, -) Tensor = TypeVar("Tensor", bound=torch.Tensor) @@ -999,165 +992,14 @@ def __init__(self, cfg: DistillationLossConfig): def __call__( self, - next_token_logits: torch.Tensor, + student_topk_logprobs: torch.Tensor, + teacher_topk_logprobs: torch.Tensor, + H_all: torch.Tensor | None, data: DistillationLossDataDict, global_valid_seqs: torch.Tensor, global_valid_toks: torch.Tensor, - vocab_parallel_rank: Optional[int] = None, - vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None, - context_parallel_group: Optional[torch.distributed.ProcessGroup] = None, ) -> tuple[torch.Tensor, dict[str, Any]]: """Compute distillation loss between teacher and student logits.""" - # Basic shapes - input_ids = data["input_ids"] - batch_size = input_ids.shape[0] - - # CP support: get CP group and size - cp_group = context_parallel_group - cp_size = 1 if cp_group is None else torch.distributed.get_world_size(cp_group) - - # Ensure float32 for stability (match other losses) - next_token_logits = next_token_logits.to(torch.float32) - per_token_kl = None - # Preferred truncated-KL path: teacher provides top-k support per position - teacher_topk_logits = data["teacher_topk_logits"] # [B, S, k] - teacher_topk_indices = data["teacher_topk_indices"] # [B, S, k] - - if teacher_topk_indices.shape[-1] <= 0: - raise ValueError( - f"topk must be positive, got {teacher_topk_indices.shape[-1]}. " - "topk=0 is not supported as it would result in empty tensor operations." - ) - - # Determine processing path and setup variables - if vocab_parallel_group is not None: - assert vocab_parallel_rank is not None, ( - "vocab_parallel_rank must be provided when vocab_parallel_group is provided" - ) - V_local = int(next_token_logits.shape[-1]) - vocab_start_index = vocab_parallel_rank * V_local - vocab_end_index = (vocab_parallel_rank + 1) * V_local - parallel_group = vocab_parallel_group - logits_tensor = next_token_logits - elif isinstance(next_token_logits, torch.distributed.tensor.DTensor): - device_mesh = next_token_logits.device_mesh - tp_group = device_mesh.get_group("tp") - tp_rank = tp_group.rank() - local_student_logits = next_token_logits.to_local() - V_local = int(local_student_logits.shape[-1]) - vocab_start_index = tp_rank * V_local - vocab_end_index = (tp_rank + 1) * V_local - parallel_group = tp_group - logits_tensor = local_student_logits - teacher_topk_indices = teacher_topk_indices.to(local_student_logits.device) - # For DTensor, derive CP group/size from the device mesh to ensure CP-aware alignment - if ( - device_mesh.mesh_dim_names is not None - and "cp" in device_mesh.mesh_dim_names - ): - cp_group = device_mesh.get_group("cp") - cp_size = cp_group.size() - else: - cp_group = None - cp_size = 1 - else: - parallel_group = None - logits_tensor = next_token_logits - - # Process based on zero_outside_topk setting - if self.zero_outside_topk and parallel_group is not None: - # Distributed processing with chunking - indices_local = teacher_topk_indices - pad_len = 0 - if cp_size > 1: - pad_len = logits_tensor.shape[1] * cp_size - indices_local.shape[1] - if pad_len > 0: - indices_local = torch.nn.functional.pad( - indices_local, (0, 0, 0, pad_len), value=0 - ) - cp_rank = torch.distributed.get_rank(cp_group) - indices_local = _get_tokens_on_this_cp_rank( - indices_local, cp_rank, cp_size, seq_dim=1 - ) - - S_local = int(logits_tensor.shape[1]) - chunk_size = max(1, min(S_local, 1024)) - student_topk_logprobs = ChunkedDistributedGatherLogprob.apply( # type: ignore - logits_tensor, - indices_local, - vocab_start_index, - vocab_end_index, - chunk_size, - parallel_group, - False, - ) - - if self.kl_type != "forward": - H_all = ChunkedDistributedEntropy.apply( # type: ignore - logits_tensor, - chunk_size, - parallel_group, - False, - ) - - if cp_size > 1: - student_topk_logprobs = allgather_cp_sharded_tensor( - student_topk_logprobs, cp_group, seq_dim=1 - ) - if self.kl_type != "forward": - H_all = allgather_cp_sharded_tensor(H_all, cp_group, seq_dim=1) - if pad_len > 0: - student_topk_logprobs = student_topk_logprobs[:, :-pad_len, :] - if self.kl_type != "forward": - H_all = H_all[:, :-pad_len] - elif self.zero_outside_topk: - # Non-distributed processing - student_logprobs = torch.nn.functional.log_softmax(logits_tensor, dim=-1) - student_topk_logprobs = student_logprobs.gather( - dim=-1, index=teacher_topk_indices.to(student_logprobs.device) - ) - if self.kl_type != "forward": - H_all = (student_logprobs.exp() * student_logprobs).sum(-1) - else: - # Gather logits at global indices - if (parallel_group is not None) or (cp_size > 1): - student_topk_logits = gather_logits_at_global_indices( - logits_tensor, - teacher_topk_indices, - tp_group=parallel_group, - cp_group=cp_group, - vocab_start_index=( - vocab_start_index if parallel_group is not None else 0 - ), - vocab_end_index=( - vocab_end_index - if parallel_group is not None - else int(logits_tensor.shape[-1]) - ), - ) - else: - student_topk_logits = logits_tensor.gather( - dim=-1, index=teacher_topk_indices.to(logits_tensor.device) - ) - student_topk_logprobs = torch.nn.functional.log_softmax( - student_topk_logits, dim=-1 - ) - - # Move teacher tensors to the same device/dtype as student_topk_logits - teacher_topk_logits = teacher_topk_logits.to( - student_topk_logprobs.device, dtype=student_topk_logprobs.dtype - ) - teacher_topk_logprobs = torch.nn.functional.log_softmax( - teacher_topk_logits, dim=-1 - ) - - # Single point of next-token alignment after TP/CP processing - teacher_topk_logprobs = teacher_topk_logprobs[:, :-1, :] - student_topk_logprobs = student_topk_logprobs[:, :-1, :] - if self.zero_outside_topk and self.kl_type != "forward": - # Align H_all with next-token prediction - H_all = H_all[:, :-1] - student_probs = student_topk_logprobs.exp() # [B, S-1, k] teacher_probs = teacher_topk_logprobs.exp() # [B, S-1, k] @@ -1210,7 +1052,7 @@ def __call__( metrics = { "loss": float(kl_loss.item()) if kl_loss.ndim == 0 else kl_loss, - "num_valid_samples": int(batch_size), + "num_valid_samples": data["input_ids"].shape[0], } return kl_loss, metrics diff --git a/nemo_rl/distributed/model_utils.py b/nemo_rl/distributed/model_utils.py index b012777279..50ffb1a28d 100644 --- a/nemo_rl/distributed/model_utils.py +++ b/nemo_rl/distributed/model_utils.py @@ -1026,6 +1026,173 @@ def gather_logits_at_global_indices( return gathered_logits +def get_distilllation_topk_logprobs_from_logits( + student_logits: torch.Tensor, + teacher_topk_logits: torch.Tensor, + teacher_topk_indices: torch.Tensor, + zero_outside_topk: bool, + calculate_entropy: bool, + vocab_parallel_rank: Optional[int] = None, + vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None, + context_parallel_group: Optional[torch.distributed.ProcessGroup] = None, +): + """Compute top-k log probabilities from logits.""" + if teacher_topk_indices.shape[-1] <= 0: + raise ValueError( + f"topk must be positive, got {teacher_topk_indices.shape[-1]}. " + "topk=0 is not supported as it would result in empty tensor operations." + ) + + # Ensure float32 for stability + student_logits = student_logits.to(torch.float32) + # Move teacher topk indices to the same device as student logits + teacher_topk_indices = teacher_topk_indices.to(student_logits.device) + + # CP support: get CP group and size + cp_group = context_parallel_group + cp_size = 1 if cp_group is None else torch.distributed.get_world_size(cp_group) + + # Process based on the student logits type + if vocab_parallel_group is not None: + assert vocab_parallel_rank is not None, ( + "vocab_parallel_rank must be provided when vocab_parallel_group is provided" + ) + student_logits = student_logits + parallel_group = vocab_parallel_group + + V_local = int(student_logits.shape[-1]) + vocab_start_index = vocab_parallel_rank * V_local + vocab_end_index = (vocab_parallel_rank + 1) * V_local + + elif isinstance(student_logits, torch.distributed.tensor.DTensor): + device_mesh = student_logits.device_mesh + tp_group = device_mesh.get_group("tp") + + student_logits = student_logits.to_local() + parallel_group = tp_group + + tp_rank = tp_group.rank() + V_local = int(student_logits.shape[-1]) + vocab_start_index = tp_rank * V_local + vocab_end_index = (tp_rank + 1) * V_local + + # For DTensor, derive CP group/size from the device mesh to ensure CP-aware alignment + if ( + device_mesh.mesh_dim_names is not None + and "cp" in device_mesh.mesh_dim_names + ): + cp_group = device_mesh.get_group("cp") + cp_size = cp_group.size() + else: + cp_group = None + cp_size = 1 + + else: + student_logits = student_logits + parallel_group = None + + # Process based on the zero_outside_topk setting + H_all = None + if zero_outside_topk: + # Distributed processing + if parallel_group is not None: + indices_local = teacher_topk_indices + pad_len = 0 + + if cp_size > 1: + pad_len = student_logits.shape[1] * cp_size - indices_local.shape[1] + if pad_len > 0: + indices_local = torch.nn.functional.pad( + indices_local, (0, 0, 0, pad_len), value=0 + ) + cp_rank = torch.distributed.get_rank(cp_group) + indices_local = _get_tokens_on_this_cp_rank( + indices_local, cp_rank, cp_size, seq_dim=1 + ) + + seq_len_local = int(student_logits.shape[1]) + chunk_size = max(1, min(seq_len_local, 1024)) + student_topk_logprobs = ChunkedDistributedGatherLogprob.apply( # type: ignore + student_logits, + indices_local, + vocab_start_index, + vocab_end_index, + chunk_size, + parallel_group, + False, + ) + + if calculate_entropy: + H_all = ChunkedDistributedEntropy.apply( # type: ignore + student_logits, + chunk_size, + parallel_group, + False, + ) + + if cp_size > 1: + student_topk_logprobs = allgather_cp_sharded_tensor( + student_topk_logprobs, cp_group, seq_dim=1 + ) + if calculate_entropy: + H_all = allgather_cp_sharded_tensor(H_all, cp_group, seq_dim=1) + if pad_len > 0: + student_topk_logprobs = student_topk_logprobs[:, :-pad_len, :] + if calculate_entropy: + H_all = H_all[:, :-pad_len] + + # Non-distributed processing + else: + student_logprobs = torch.nn.functional.log_softmax(student_logits, dim=-1) + student_topk_logprobs = student_logprobs.gather( + dim=-1, index=teacher_topk_indices + ) + + if calculate_entropy: + H_all = (student_logprobs.exp() * student_logprobs).sum(-1) + + else: + # Distributed processing + if parallel_group is not None or cp_size > 1: + if parallel_group is None: + vocab_start_index = 0 + vocab_end_index = int(student_logits.shape[-1]) + + student_topk_logits = gather_logits_at_global_indices( + student_logits, + teacher_topk_indices, + tp_group=parallel_group, + cp_group=cp_group, + vocab_start_index=vocab_start_index, + vocab_end_index=vocab_end_index, + ) + + # Non-distributed processing + else: + student_topk_logits = student_logits.gather( + dim=-1, index=teacher_topk_indices + ) + + student_topk_logprobs = torch.nn.functional.log_softmax( + student_topk_logits, dim=-1 + ) + + # Move teacher tensors to the same device/dtype as student_topk_logits + teacher_topk_logits = teacher_topk_logits.to( + student_topk_logprobs.device, dtype=student_topk_logprobs.dtype + ) + teacher_topk_logprobs = torch.nn.functional.log_softmax(teacher_topk_logits, dim=-1) + + # Single point of next-token alignment after TP/CP processing + teacher_topk_logprobs = teacher_topk_logprobs[:, :-1, :] + student_topk_logprobs = student_topk_logprobs[:, :-1, :] + + if calculate_entropy: + H_all = H_all[:, :-1] + + return student_topk_logprobs, teacher_topk_logprobs, H_all + + class ChunkedDistributedEntropy(torch.autograd.Function): """Compute H_all = sum_v p_v log p_v across TP with chunking over sequence. diff --git a/nemo_rl/models/automodel/train.py b/nemo_rl/models/automodel/train.py index cb8cc1c939..d7a311bbf1 100644 --- a/nemo_rl/models/automodel/train.py +++ b/nemo_rl/models/automodel/train.py @@ -37,6 +37,7 @@ from nemo_rl.distributed.model_utils import ( allgather_cp_sharded_tensor, distributed_vocab_topk, + get_distilllation_topk_logprobs_from_logits, get_logprobs_from_logits, get_logprobs_from_vocab_parallel_logits, ) @@ -507,8 +508,10 @@ def __call__( """ from nemo_rl.algorithms.loss_functions import ( ClippedPGLossFn, + DistillationLossFn, DPOLossFn, NLLLoss, + PreferenceLoss, ) # Handle CP redistribution @@ -533,7 +536,27 @@ def prepare_for_loss_fn( loss_fn_args = (logprobs,) - # TODO: PreferenceLoss, DistillationLossFn + elif isinstance(self.loss_fn, PreferenceLoss): + loss_fn_args = (logits,) + + elif isinstance(self.loss_fn, DistillationLossFn): + calculate_entropy = ( + self.loss_fn.zero_outside_topk and self.loss_fn.kl_type != "forward" + ) + student_topk_logprobs, teacher_topk_logprobs, H_all = ( + get_distilllation_topk_logprobs_from_logits( + student_logits=logits, + teacher_topk_logits=mb["teacher_topk_logits"], + teacher_topk_indices=mb["teacher_topk_indices"], + zero_outside_topk=self.loss_fn.zero_outside_topk, + calculate_entropy=calculate_entropy, + ) + ) + + loss_fn_args = (student_topk_logprobs, teacher_topk_logprobs, H_all) + + else: + raise ValueError(f"Unknown loss function type: {type(self.loss_fn)}") return loss_fn_args diff --git a/tests/unit/algorithms/test_loss_functions.py b/tests/unit/algorithms/test_loss_functions.py index 7e452804f5..52127a459c 100644 --- a/tests/unit/algorithms/test_loss_functions.py +++ b/tests/unit/algorithms/test_loss_functions.py @@ -26,7 +26,10 @@ ) from nemo_rl.algorithms.utils import calculate_kl, masked_mean from nemo_rl.distributed.batched_data_dict import BatchedDataDict -from nemo_rl.distributed.model_utils import get_logprobs_from_logits +from nemo_rl.distributed.model_utils import ( + get_distilllation_topk_logprobs_from_logits, + get_logprobs_from_logits, +) basic_pg_loss_test_config: ClippedPGLossConfig = { "ratio_clip_min": 0.2, @@ -1753,8 +1756,17 @@ def test_distillation_loss_different_settings(kl_type, zero_outside_topk): } ) + calculate_entropy = loss_fn.zero_outside_topk and loss_fn.kl_type != "forward" + loss_fn_args = get_distilllation_topk_logprobs_from_logits( + student_logits=student_logits, + teacher_topk_logits=data["teacher_topk_logits"], + teacher_topk_indices=data["teacher_topk_indices"], + zero_outside_topk=loss_fn.zero_outside_topk, + calculate_entropy=calculate_entropy, + ) + loss, metrics = loss_fn( - student_logits, + *loss_fn_args, data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum( @@ -1797,8 +1809,17 @@ def test_distillation_loss_topk_filtering(k, zero_outside_topk): } ) - loss, metrics = loss_fn( - student_logits, + calculate_entropy = loss_fn.zero_outside_topk and loss_fn.kl_type != "forward" + loss_fn_args = get_distilllation_topk_logprobs_from_logits( + student_logits=student_logits, + teacher_topk_logits=data["teacher_topk_logits"], + teacher_topk_indices=data["teacher_topk_indices"], + zero_outside_topk=loss_fn.zero_outside_topk, + calculate_entropy=calculate_entropy, + ) + + loss, _ = loss_fn( + *loss_fn_args, data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum( @@ -1836,13 +1857,13 @@ def test_distillation_loss_invalid_k_zero(): # This should raise a ValueError for k=0 with pytest.raises(ValueError, match="topk must be positive"): - loss_fn( - student_logits, - data, - global_valid_seqs=torch.sum(data["sample_mask"]), - global_valid_toks=torch.sum( - data["sample_mask"].unsqueeze(-1) * data["token_mask"] - ), + calculate_entropy = loss_fn.zero_outside_topk and loss_fn.kl_type != "forward" + _ = get_distilllation_topk_logprobs_from_logits( + student_logits=student_logits, + teacher_topk_logits=data["teacher_topk_logits"], + teacher_topk_indices=data["teacher_topk_indices"], + zero_outside_topk=loss_fn.zero_outside_topk, + calculate_entropy=calculate_entropy, ) @@ -1861,8 +1882,17 @@ def test_distillation_loss_gradient_flow(): } ) + calculate_entropy = loss_fn.zero_outside_topk and loss_fn.kl_type != "forward" + loss_fn_args = get_distilllation_topk_logprobs_from_logits( + student_logits=student_logits, + teacher_topk_logits=data["teacher_topk_logits"], + teacher_topk_indices=data["teacher_topk_indices"], + zero_outside_topk=loss_fn.zero_outside_topk, + calculate_entropy=calculate_entropy, + ) + loss, _ = loss_fn( - student_logits, + *loss_fn_args, data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum( @@ -1894,8 +1924,17 @@ def test_distillation_loss_edge_cases(): # Test with all-zero logits zero_logits = torch.zeros_like(student_logits) + calculate_entropy = loss_fn.zero_outside_topk and loss_fn.kl_type != "forward" + loss_fn_args = get_distilllation_topk_logprobs_from_logits( + student_logits=zero_logits, + teacher_topk_logits=data["teacher_topk_logits"], + teacher_topk_indices=data["teacher_topk_indices"], + zero_outside_topk=loss_fn.zero_outside_topk, + calculate_entropy=calculate_entropy, + ) + loss, _ = loss_fn( - zero_logits, + *loss_fn_args, data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum( @@ -1907,8 +1946,16 @@ def test_distillation_loss_edge_cases(): # Test with very large logits large_logits = torch.ones_like(student_logits) * 100.0 + loss_fn_args = get_distilllation_topk_logprobs_from_logits( + student_logits=large_logits, + teacher_topk_logits=data["teacher_topk_logits"], + teacher_topk_indices=data["teacher_topk_indices"], + zero_outside_topk=loss_fn.zero_outside_topk, + calculate_entropy=calculate_entropy, + ) + loss, _ = loss_fn( - large_logits, + *loss_fn_args, data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum( @@ -1920,8 +1967,16 @@ def test_distillation_loss_edge_cases(): # Test with very small logits small_logits = torch.ones_like(student_logits) * -100.0 + loss_fn_args = get_distilllation_topk_logprobs_from_logits( + student_logits=small_logits, + teacher_topk_logits=data["teacher_topk_logits"], + teacher_topk_indices=data["teacher_topk_indices"], + zero_outside_topk=loss_fn.zero_outside_topk, + calculate_entropy=calculate_entropy, + ) + loss, _ = loss_fn( - small_logits, + *loss_fn_args, data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum( @@ -1969,8 +2024,17 @@ def test_distillation_loss_fn_call(): } ) + calculate_entropy = loss_fn.zero_outside_topk and loss_fn.kl_type != "forward" + loss_fn_args = get_distilllation_topk_logprobs_from_logits( + student_logits=student_logits, + teacher_topk_logits=data["teacher_topk_logits"], + teacher_topk_indices=data["teacher_topk_indices"], + zero_outside_topk=loss_fn.zero_outside_topk, + calculate_entropy=calculate_entropy, + ) + loss, metrics = loss_fn( - student_logits, + *loss_fn_args, data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum( From 621bc097e05b5d8a1c14733cd8b41834b8e5d7fc Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Wed, 25 Feb 2026 22:39:38 -0800 Subject: [PATCH 05/15] add LossInputType Signed-off-by: Yuki Huang --- nemo_rl/algorithms/interfaces.py | 7 ++++ nemo_rl/algorithms/loss_functions.py | 18 ++++++--- nemo_rl/models/automodel/train.py | 24 ++++-------- .../models/generation/test_vllm_generation.py | 2 +- tests/unit/test_utils.py | 39 ++++++------------- 5 files changed, 40 insertions(+), 50 deletions(-) diff --git a/nemo_rl/algorithms/interfaces.py b/nemo_rl/algorithms/interfaces.py index d7b6bfe67b..58b4c3f48f 100644 --- a/nemo_rl/algorithms/interfaces.py +++ b/nemo_rl/algorithms/interfaces.py @@ -25,6 +25,12 @@ class LossType(enum.Enum): SEQUENCE_LEVEL = "sequence_level" +class LossInputType(enum.Enum): + LOGIT = "logit" + LOGPROB = "logprob" + DISTILLATION = "distillation" + + class LossFunction(Protocol): """Signature for loss functions used in reinforcement learning algorithms. @@ -33,6 +39,7 @@ class LossFunction(Protocol): """ loss_type: LossType + input_type: LossInputType def __call__( self, diff --git a/nemo_rl/algorithms/loss_functions.py b/nemo_rl/algorithms/loss_functions.py index fcc0b267fe..25524d6560 100755 --- a/nemo_rl/algorithms/loss_functions.py +++ b/nemo_rl/algorithms/loss_functions.py @@ -17,7 +17,7 @@ import torch import torch.distributed -from nemo_rl.algorithms.interfaces import LossFunction, LossType +from nemo_rl.algorithms.interfaces import LossFunction, LossInputType, LossType from nemo_rl.algorithms.utils import calculate_kl, masked_mean from nemo_rl.distributed.batched_data_dict import BatchedDataDict @@ -113,6 +113,8 @@ class ClippedPGLossFn(LossFunction): Due to potential numerical instability, we cast the logits to float32 before computing the loss. """ + input_type = LossInputType.LOGPROB + def __init__(self, cfg: ClippedPGLossConfig): self.ratio_clip_min = cfg["ratio_clip_min"] self.ratio_clip_max = cfg["ratio_clip_max"] @@ -557,6 +559,7 @@ class NLLLoss(LossFunction): """Negative Log Likelihood Loss function.""" loss_type = LossType.TOKEN_LEVEL + input_type = LossInputType.LOGPROB def __call__( self, @@ -625,8 +628,8 @@ class PreferenceLoss(LossFunction): - accuracy: Fraction of examples where chosen response has higher reward """ - def __init__(self): - self.loss_type = LossType.SEQUENCE_LEVEL + loss_type = LossType.SEQUENCE_LEVEL + input_type = LossInputType.LOGIT def split_output_tensor(self, tensor: Tensor) -> tuple[Tensor, Tensor]: # tensor is of shape (2*micro_batch_size,) @@ -773,6 +776,9 @@ class DPOLossFn(PreferenceLoss): - accuracy: Fraction of examples where chosen response has higher reward """ + loss_type = LossType.SEQUENCE_LEVEL + input_type = LossInputType.LOGPROB + def __init__(self, cfg: DPOLossConfig): self.reference_policy_kl_penalty = cfg["reference_policy_kl_penalty"] self.preference_loss_weight = cfg["preference_loss_weight"] @@ -781,8 +787,6 @@ def __init__(self, cfg: DPOLossConfig): self.sft_average_log_probs = cfg["sft_average_log_probs"] self.sft_loss = NLLLoss() - self.loss_type = LossType.SEQUENCE_LEVEL - def _dpo_loss( self, token_logprobs: Tensor, @@ -978,12 +982,14 @@ class DistillationLossDataDict(TypedDict): class DistillationLossFn(LossFunction): """Distillation loss function.""" + loss_type = LossType.TOKEN_LEVEL + input_type = LossInputType.DISTILLATION + def __init__(self, cfg: DistillationLossConfig): self.kl_type = cfg["kl_type"] self.mixed_kl_weight = cfg["mixed_kl_weight"] self.zero_outside_topk = cfg["zero_outside_topk"] self.log_infinitesimal = -100 - self.loss_type = LossType.TOKEN_LEVEL assert self.kl_type in ["forward", "reverse", "mixed"], "Invalid KL type" assert self.mixed_kl_weight >= 0 and self.mixed_kl_weight <= 1, ( diff --git a/nemo_rl/models/automodel/train.py b/nemo_rl/models/automodel/train.py index d7a311bbf1..931f8052dc 100644 --- a/nemo_rl/models/automodel/train.py +++ b/nemo_rl/models/automodel/train.py @@ -31,7 +31,7 @@ from torch import nn from torch.distributed.tensor import DTensor, Shard -from nemo_rl.algorithms.interfaces import LossFunction +from nemo_rl.algorithms.interfaces import LossFunction, LossInputType from nemo_rl.algorithms.loss_functions import SequencePackingLossWrapper from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.distributed.model_utils import ( @@ -475,8 +475,8 @@ def __init__( dp_size: Data parallel size enable_seq_packing: Whether sequence packing is enabled """ - self.loss_fn = loss_fn - self.cfg = cfg + self.loss_fn: LossFunction = loss_fn + self.cfg: PolicyConfig = cfg self.device_mesh = device_mesh self.cp_mesh = cp_mesh self.tp_mesh = tp_mesh @@ -506,14 +506,6 @@ def __call__( Returns: Tuple of (loss, metrics) """ - from nemo_rl.algorithms.loss_functions import ( - ClippedPGLossFn, - DistillationLossFn, - DPOLossFn, - NLLLoss, - PreferenceLoss, - ) - # Handle CP redistribution if self.cp_size > 1: _, mb = prepare_data_for_cp( @@ -527,7 +519,10 @@ def __call__( def prepare_for_loss_fn( logits: torch.Tensor, mb: BatchedDataDict[Any] ) -> tuple[Any]: - if isinstance(self.loss_fn, (ClippedPGLossFn, NLLLoss, DPOLossFn)): + if self.loss_fn.input_type == LossInputType.LOGIT: + loss_fn_args = (logits,) + + elif self.loss_fn.input_type == LossInputType.LOGPROB: logprobs = get_logprobs_from_logits( input_ids=mb["input_ids"], next_token_logits=logits, @@ -536,10 +531,7 @@ def prepare_for_loss_fn( loss_fn_args = (logprobs,) - elif isinstance(self.loss_fn, PreferenceLoss): - loss_fn_args = (logits,) - - elif isinstance(self.loss_fn, DistillationLossFn): + elif self.loss_fn.input_type == LossInputType.DISTILLATION: calculate_entropy = ( self.loss_fn.zero_outside_topk and self.loss_fn.kl_type != "forward" ) diff --git a/tests/unit/models/generation/test_vllm_generation.py b/tests/unit/models/generation/test_vllm_generation.py index c27a183b5c..761e3d24a1 100644 --- a/tests/unit/models/generation/test_vllm_generation.py +++ b/tests/unit/models/generation/test_vllm_generation.py @@ -848,7 +848,7 @@ async def run_hf_train_process( { "input_ids": train_input_ids, "input_lengths": generation_results["unpadded_sequence_lengths"], - "token_loss_mask": token_loss_mask, + "token_mask": token_loss_mask, "sample_mask": torch.ones(train_input_ids.shape[0]), } ) diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 11515ec661..8a9adf80ec 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -11,29 +11,27 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional +from typing import Any import torch -from nemo_rl.algorithms.interfaces import LossType +from nemo_rl.algorithms.interfaces import LossInputType, LossType from nemo_rl.distributed.batched_data_dict import BatchedDataDict class SimpleLoss: loss_type = LossType.SEQUENCE_LEVEL + input_type = LossInputType.LOGIT def __call__( self, - next_token_logits: torch.Tensor, + logits: torch.Tensor, data: BatchedDataDict, global_valid_seqs: torch.Tensor | None, global_valid_toks: torch.Tensor | None, - vocab_parallel_rank: Optional[int] = None, - vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None, - context_parallel_group: Optional[torch.distributed.ProcessGroup] = None, ) -> tuple[torch.Tensor, dict[str, Any]]: - # Just return mean of logprobs as the loss for testing - loss = next_token_logits.mean() + # Just return mean of logits as the loss for testing + loss = logits.mean() metrics = { "loss": loss.item(), "test_metric": loss.item() * 0.5, @@ -44,33 +42,20 @@ def __call__( # Create a simple masked NLL loss function class SimpleNLLLoss: - loss_type = LossType.SEQUENCE_LEVEL + loss_type = LossType.TOKEN_LEVEL + input_type = LossInputType.LOGPROB def __call__( self, - next_token_logits: torch.Tensor, + token_logprobs: torch.Tensor, data: BatchedDataDict, global_valid_seqs: torch.Tensor | None, global_valid_toks: torch.Tensor | None, - vocab_parallel_rank: Optional[int] = None, - vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None, - context_parallel_group: Optional[torch.distributed.ProcessGroup] = None, ) -> tuple[torch.Tensor, dict[str, Any]]: - # logits shape: [batch_size, seq_len, vocab_size] - # Get the next token logits for each position - next_tokens = data["input_ids"][:, 1:].cuda() # Skip first token - next_token_logprobs = torch.nn.functional.log_softmax(next_token_logits, dim=-1) - logprobs = next_token_logprobs[:, :-1] # Remove last position's logits - - # Gather the logprobs for the actual next tokens - token_logprobs = logprobs.gather( - dim=-1, index=next_tokens.unsqueeze(-1) - ).squeeze(-1) - # Only compute loss on generated tokens (not input tokens) - # by applying the token_loss_mask (shifted by 1 since we're predicting next tokens) - token_loss_mask = data["token_loss_mask"][:, 1:].cuda() - loss = -torch.sum(token_logprobs * token_loss_mask) + # by applying the token_mask (shifted by 1 since we're predicting next tokens) + mask = data["token_mask"][:, 1:].cuda() + loss = -torch.sum(token_logprobs * mask) return loss, { "loss": loss.item(), From cde9c3e091fc3fd71da8f14258027653330667be Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Wed, 25 Feb 2026 23:01:56 -0800 Subject: [PATCH 06/15] args -> kwargs Signed-off-by: Yuki Huang --- nemo_rl/algorithms/loss_functions.py | 19 ++++++++++--------- nemo_rl/models/automodel/train.py | 24 ++++++++++++++---------- tests/unit/test_utils.py | 4 ++-- 3 files changed, 26 insertions(+), 21 deletions(-) diff --git a/nemo_rl/algorithms/loss_functions.py b/nemo_rl/algorithms/loss_functions.py index 25524d6560..d34e6a6dd6 100755 --- a/nemo_rl/algorithms/loss_functions.py +++ b/nemo_rl/algorithms/loss_functions.py @@ -191,12 +191,13 @@ def __init__(self, cfg: ClippedPGLossConfig): def __call__( self, - curr_logprobs: Tensor, + next_token_logprobs: Tensor, data: BatchedDataDict[ClippedPGLossDataDict], global_valid_seqs: torch.Tensor, global_valid_toks: torch.Tensor, ) -> tuple[torch.Tensor, dict]: """Clipped Policy Gradient RL loss function.""" + curr_logprobs = next_token_logprobs token_mask = data["token_mask"][:, 1:] sample_mask = data["sample_mask"] advantages = data["advantages"][:, 1:] @@ -563,7 +564,7 @@ class NLLLoss(LossFunction): def __call__( self, - token_logprobs: Tensor, + next_token_logprobs: Tensor, data: BatchedDataDict[Any], global_valid_seqs: Tensor | None, global_valid_toks: Tensor, @@ -580,14 +581,14 @@ def __call__( ## shape: [batch_size] num_unmasked_tokens = torch.sum(mask, -1) ## multiply by sample_mask to zero out invalid samples - loss = -torch.sum(token_logprobs * mask, dim=-1) + loss = -torch.sum(next_token_logprobs * mask, dim=-1) if dpo_average_log_probs: loss = loss / num_unmasked_tokens.clamp(min=1) else: ## single scalar loss ## scale by the total number of tokens in the batch loss = -masked_mean( - token_logprobs, + next_token_logprobs, mask, global_normalization_factor=global_valid_toks, ) @@ -789,7 +790,7 @@ def __init__(self, cfg: DPOLossConfig): def _dpo_loss( self, - token_logprobs: Tensor, + next_token_logprobs: Tensor, data: BatchedDataDict[DPOLossDataDict], global_valid_seqs: Tensor, ) -> tuple[Tensor, Tensor, Tensor, Tensor]: @@ -798,7 +799,7 @@ def _dpo_loss( sample_mask = data["sample_mask"] ref_logprobs = data["reference_policy_logprobs"][:, :-1] - diff = (token_logprobs - ref_logprobs) * token_mask + diff = (next_token_logprobs - ref_logprobs) * token_mask rewards = diff.sum(-1) if self.preference_average_log_probs: @@ -811,7 +812,7 @@ def _dpo_loss( # TODO a cleaner typing fix would be required (probably that DPOLossFn should not inherit from PreferenceLoss) def __call__( # type: ignore self, - token_logprobs: Tensor, + next_token_logprobs: Tensor, data: BatchedDataDict[DPOLossDataDict], global_valid_seqs: Tensor, global_valid_toks: Tensor | None, @@ -822,7 +823,7 @@ def __call__( # type: ignore "global_valid_toks must be provided for SFT loss" ) sft_loss, _ = self.sft_loss( - token_logprobs, + next_token_logprobs, data, global_valid_seqs=global_valid_seqs, global_valid_toks=global_valid_toks, ## unused because sft loss returned is at the sample level @@ -841,7 +842,7 @@ def __call__( # type: ignore accuracy, rewards_chosen_mean, rewards_rejected_mean, - ) = self._dpo_loss(token_logprobs, data, global_valid_seqs) + ) = self._dpo_loss(next_token_logprobs, data, global_valid_seqs) dpo_loss = ( self.sft_loss_weight * sft_loss_chosen diff --git a/nemo_rl/models/automodel/train.py b/nemo_rl/models/automodel/train.py index 931f8052dc..34cb65990d 100644 --- a/nemo_rl/models/automodel/train.py +++ b/nemo_rl/models/automodel/train.py @@ -518,9 +518,9 @@ def __call__( # Prepare data for loss function def prepare_for_loss_fn( logits: torch.Tensor, mb: BatchedDataDict[Any] - ) -> tuple[Any]: + ) -> dict[str, Any]: if self.loss_fn.input_type == LossInputType.LOGIT: - loss_fn_args = (logits,) + loss_input = {"logits": logits} elif self.loss_fn.input_type == LossInputType.LOGPROB: logprobs = get_logprobs_from_logits( @@ -529,7 +529,7 @@ def prepare_for_loss_fn( seq_index=mb.get("seq_index", None), ) - loss_fn_args = (logprobs,) + loss_input = {"next_token_logprobs": logprobs} elif self.loss_fn.input_type == LossInputType.DISTILLATION: calculate_entropy = ( @@ -545,12 +545,16 @@ def prepare_for_loss_fn( ) ) - loss_fn_args = (student_topk_logprobs, teacher_topk_logprobs, H_all) + loss_input = { + "student_topk_logprobs": student_topk_logprobs, + "teacher_topk_logprobs": teacher_topk_logprobs, + "H_all": H_all, + } else: raise ValueError(f"Unknown loss function type: {type(self.loss_fn)}") - return loss_fn_args + return loss_input # Wrap loss function for sequence packing if needed if self.enable_seq_packing: @@ -567,12 +571,12 @@ def prepare_for_loss_fn( global_valid_toks, ) else: - loss_fn_args = prepare_for_loss_fn(logits, mb) + loss_input = prepare_for_loss_fn(logits, mb) loss, loss_metrics = self.loss_fn( - *loss_fn_args, - mb, - global_valid_seqs, - global_valid_toks, + data=mb, + global_valid_seqs=global_valid_seqs, + global_valid_toks=global_valid_toks, + **loss_input, ) return loss, loss_metrics diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 8a9adf80ec..8b9b7206fb 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -47,7 +47,7 @@ class SimpleNLLLoss: def __call__( self, - token_logprobs: torch.Tensor, + next_token_logprobs: torch.Tensor, data: BatchedDataDict, global_valid_seqs: torch.Tensor | None, global_valid_toks: torch.Tensor | None, @@ -55,7 +55,7 @@ def __call__( # Only compute loss on generated tokens (not input tokens) # by applying the token_mask (shifted by 1 since we're predicting next tokens) mask = data["token_mask"][:, 1:].cuda() - loss = -torch.sum(token_logprobs * mask) + loss = -torch.sum(next_token_logprobs * mask) return loss, { "loss": loss.item(), From 2fb40ed350d2495f8e82dcaf055d94906d69ce5a Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Wed, 25 Feb 2026 23:21:09 -0800 Subject: [PATCH 07/15] typo Signed-off-by: Yuki Huang --- nemo_rl/distributed/model_utils.py | 6 +- nemo_rl/models/automodel/train.py | 8 +-- tests/unit/algorithms/test_loss_functions.py | 76 ++++++++++++-------- 3 files changed, 53 insertions(+), 37 deletions(-) diff --git a/nemo_rl/distributed/model_utils.py b/nemo_rl/distributed/model_utils.py index 50ffb1a28d..3eeff238e6 100644 --- a/nemo_rl/distributed/model_utils.py +++ b/nemo_rl/distributed/model_utils.py @@ -825,7 +825,7 @@ def get_logprobs_from_vocab_parallel_logits( ) -def get_logprobs_from_logits( +def get_next_token_logprobs_from_logits( input_ids: torch.Tensor, next_token_logits: torch.Tensor, seq_index: Optional[torch.Tensor] = None, @@ -833,7 +833,7 @@ def get_logprobs_from_logits( vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None, context_parallel_group: Optional[torch.distributed.ProcessGroup] = None, ): - """Computes log probabilities from logits.""" + """Computes next token log probabilities from logits.""" next_token_logits = next_token_logits.to(torch.float32) if vocab_parallel_group is not None: @@ -1026,7 +1026,7 @@ def gather_logits_at_global_indices( return gathered_logits -def get_distilllation_topk_logprobs_from_logits( +def get_distillation_topk_logprobs_from_logits( student_logits: torch.Tensor, teacher_topk_logits: torch.Tensor, teacher_topk_indices: torch.Tensor, diff --git a/nemo_rl/models/automodel/train.py b/nemo_rl/models/automodel/train.py index 34cb65990d..f563fcefb4 100644 --- a/nemo_rl/models/automodel/train.py +++ b/nemo_rl/models/automodel/train.py @@ -37,9 +37,9 @@ from nemo_rl.distributed.model_utils import ( allgather_cp_sharded_tensor, distributed_vocab_topk, - get_distilllation_topk_logprobs_from_logits, - get_logprobs_from_logits, + get_distillation_topk_logprobs_from_logits, get_logprobs_from_vocab_parallel_logits, + get_next_token_logprobs_from_logits, ) from nemo_rl.models.automodel.data import ProcessedInputs, ProcessedMicrobatch from nemo_rl.models.policy import PolicyConfig @@ -523,7 +523,7 @@ def prepare_for_loss_fn( loss_input = {"logits": logits} elif self.loss_fn.input_type == LossInputType.LOGPROB: - logprobs = get_logprobs_from_logits( + logprobs = get_next_token_logprobs_from_logits( input_ids=mb["input_ids"], next_token_logits=logits, seq_index=mb.get("seq_index", None), @@ -536,7 +536,7 @@ def prepare_for_loss_fn( self.loss_fn.zero_outside_topk and self.loss_fn.kl_type != "forward" ) student_topk_logprobs, teacher_topk_logprobs, H_all = ( - get_distilllation_topk_logprobs_from_logits( + get_distillation_topk_logprobs_from_logits( student_logits=logits, teacher_topk_logits=mb["teacher_topk_logits"], teacher_topk_indices=mb["teacher_topk_indices"], diff --git a/tests/unit/algorithms/test_loss_functions.py b/tests/unit/algorithms/test_loss_functions.py index 52127a459c..81a280774d 100644 --- a/tests/unit/algorithms/test_loss_functions.py +++ b/tests/unit/algorithms/test_loss_functions.py @@ -27,8 +27,8 @@ from nemo_rl.algorithms.utils import calculate_kl, masked_mean from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.distributed.model_utils import ( - get_distilllation_topk_logprobs_from_logits, - get_logprobs_from_logits, + get_distillation_topk_logprobs_from_logits, + get_next_token_logprobs_from_logits, ) basic_pg_loss_test_config: ClippedPGLossConfig = { @@ -95,7 +95,9 @@ def test_nll_loss(): .unsqueeze(0) .to("cuda") ) - token_logprobs = get_logprobs_from_logits(data["input_ids"], next_token_logits) + token_logprobs = get_next_token_logprobs_from_logits( + data["input_ids"], next_token_logits + ) loss, metrics_dict = loss_fn( token_logprobs, data, @@ -121,7 +123,9 @@ def test_nll_loss(): .unsqueeze(0) .to("cuda") ) - token_logprobs = get_logprobs_from_logits(data["input_ids"], next_token_logits) + token_logprobs = get_next_token_logprobs_from_logits( + data["input_ids"], next_token_logits + ) loss, metrics_dict = loss_fn( token_logprobs, data, @@ -157,7 +161,9 @@ def test_dpo_loss(): } ) - token_logprobs = get_logprobs_from_logits(data["input_ids"], next_token_logits) + token_logprobs = get_next_token_logprobs_from_logits( + data["input_ids"], next_token_logits + ) loss, metrics_dict = loss_fn( token_logprobs, data, @@ -267,7 +273,9 @@ def test_dpo_loss_varying_sequence_lengths(): "sample_mask": sample_mask, } ) - token_logprobs = get_logprobs_from_logits(data["input_ids"], next_token_logits) + token_logprobs = get_next_token_logprobs_from_logits( + data["input_ids"], next_token_logits + ) # Compute loss loss, metrics = dpo_loss_fn_no_avg( @@ -330,7 +338,7 @@ def test_dpo_sft_matches_nll_loss(): # Compute NLL loss nll_loss_fn = NLLLoss() - token_logprobs = get_logprobs_from_logits( + token_logprobs = get_next_token_logprobs_from_logits( sft_data["input_ids"], next_token_logits[::2] ) nll_loss, nll_metrics = nll_loss_fn( @@ -352,7 +360,9 @@ def test_dpo_sft_matches_nll_loss(): "sft_average_log_probs": False, } ) - token_logprobs = get_logprobs_from_logits(dpo_data["input_ids"], next_token_logits) + token_logprobs = get_next_token_logprobs_from_logits( + dpo_data["input_ids"], next_token_logits + ) dpo_loss, dpo_metrics = dpo_loss_fn( token_logprobs, dpo_data, @@ -516,7 +526,7 @@ def test_clipped_pg_loss_ppo_clipping(): dummy_logits = _create_exact_logits( curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device ) - current_logprobs = get_logprobs_from_logits(input_ids, dummy_logits) + current_logprobs = get_next_token_logprobs_from_logits(input_ids, dummy_logits) actual_loss, _ = loss_fn( current_logprobs, @@ -564,7 +574,7 @@ def test_clipped_pg_loss_reinforce_mode(): dummy_logits = _create_exact_logits( curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device ) - current_logprobs = get_logprobs_from_logits(input_ids, dummy_logits) + current_logprobs = get_next_token_logprobs_from_logits(input_ids, dummy_logits) actual_loss, _ = loss_fn( current_logprobs, @@ -610,7 +620,7 @@ def test_clipped_pg_loss_force_on_policy_ratio(): dummy_logits = _create_exact_logits( curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device ) - current_logprobs = get_logprobs_from_logits(input_ids, dummy_logits) + current_logprobs = get_next_token_logprobs_from_logits(input_ids, dummy_logits) actual_loss, metrics = loss_fn( current_logprobs, @@ -721,7 +731,7 @@ def test_clipped_pg_loss_kl_penalty(): dummy_logits = _create_exact_logits( curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device ) - current_logprobs = get_logprobs_from_logits(input_ids, dummy_logits) + current_logprobs = get_next_token_logprobs_from_logits(input_ids, dummy_logits) actual_loss, _ = loss_fn( current_logprobs, @@ -750,7 +760,9 @@ def test_clipped_pg_loss_masking(): ) # Need some realistic-ish logits and logprobs for masking test dummy_logits = torch.randn(batch_size, seq_len, vocab_size, device=device) - current_logprobs = get_logprobs_from_logits(data["input_ids"], dummy_logits) + current_logprobs = get_next_token_logprobs_from_logits( + data["input_ids"], dummy_logits + ) # Ensure logprobs used by the loss fn make sense relative to advantages data["prev_logprobs"] = torch.randn_like(data["prev_logprobs"]) * 0.1 @@ -823,7 +835,7 @@ def test_clipped_pg_loss_masking(): data_only_b0 = BatchedDataDict(data_only_b0_dict) logits_only_b0 = dummy_logits[0:1] - current_logprobs_only_b0 = get_logprobs_from_logits( + current_logprobs_only_b0 = get_next_token_logprobs_from_logits( data_only_b0["input_ids"], logits_only_b0 ) loss_only_b0, _ = loss_fn( @@ -847,7 +859,9 @@ def test_clipped_pg_loss_zero_mask(): data, batch_size, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device) # Need dummy logits dummy_logits = torch.randn(1, seq_len, vocab_size, device=device) - current_logprobs = get_logprobs_from_logits(data["input_ids"], dummy_logits) + current_logprobs = get_next_token_logprobs_from_logits( + data["input_ids"], dummy_logits + ) cfg = deepcopy(basic_pg_loss_test_config) cfg["reference_policy_kl_penalty"] = 0.1 @@ -1002,7 +1016,7 @@ def test_clipped_pg_loss_on_policy_kl_importance_sampling(): dummy_logits = _create_exact_logits( curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device ) - current_logprobs = get_logprobs_from_logits(input_ids, dummy_logits) + current_logprobs = get_next_token_logprobs_from_logits(input_ids, dummy_logits) actual_loss, _ = loss_fn( current_logprobs, @@ -1135,7 +1149,7 @@ def test_clipped_pg_loss_on_policy_truncated_importance_sampling( dummy_logits = _create_exact_logits( curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device ) - current_logprobs = get_logprobs_from_logits(input_ids, dummy_logits) + current_logprobs = get_next_token_logprobs_from_logits(input_ids, dummy_logits) actual_loss, _ = loss_fn( current_logprobs, @@ -1357,7 +1371,7 @@ def test_clipped_pg_loss_dual_clip(): dummy_logits = _create_exact_logits( curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device ) - current_logprobs = get_logprobs_from_logits(input_ids, dummy_logits) + current_logprobs = get_next_token_logprobs_from_logits(input_ids, dummy_logits) actual_loss, _ = loss_fn( current_logprobs, @@ -1407,7 +1421,9 @@ def test_clipped_pg_loss_entropy(): dummy_logits = _create_exact_logits( curr_lp_masked, data["input_ids"], batch_size, seq_len, vocab_size, device ) - current_logprobs = get_logprobs_from_logits(data["input_ids"], dummy_logits) + current_logprobs = get_next_token_logprobs_from_logits( + data["input_ids"], dummy_logits + ) _, metrics = loss_fn( current_logprobs, @@ -1492,7 +1508,7 @@ def test_clipped_pg_loss_gspo(): dummy_logits = _create_exact_logits( curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device ) - current_logprobs = get_logprobs_from_logits(input_ids, dummy_logits) + current_logprobs = get_next_token_logprobs_from_logits(input_ids, dummy_logits) actual_loss, _ = loss_fn( current_logprobs, @@ -1591,7 +1607,7 @@ def test_clipped_pg_loss_gspo_batch_size_2(): dummy_logits = _create_exact_logits( curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device ) - current_logprobs = get_logprobs_from_logits(input_ids, dummy_logits) + current_logprobs = get_next_token_logprobs_from_logits(input_ids, dummy_logits) actual_loss, _ = loss_fn( current_logprobs, @@ -1693,7 +1709,7 @@ def test_clipped_pg_loss_gspo_importance_sampling_correction(): dummy_logits = _create_exact_logits( curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device ) - current_logprobs = get_logprobs_from_logits(input_ids, dummy_logits) + current_logprobs = get_next_token_logprobs_from_logits(input_ids, dummy_logits) actual_loss, _ = loss_fn( current_logprobs, @@ -1757,7 +1773,7 @@ def test_distillation_loss_different_settings(kl_type, zero_outside_topk): ) calculate_entropy = loss_fn.zero_outside_topk and loss_fn.kl_type != "forward" - loss_fn_args = get_distilllation_topk_logprobs_from_logits( + loss_fn_args = get_distillation_topk_logprobs_from_logits( student_logits=student_logits, teacher_topk_logits=data["teacher_topk_logits"], teacher_topk_indices=data["teacher_topk_indices"], @@ -1810,7 +1826,7 @@ def test_distillation_loss_topk_filtering(k, zero_outside_topk): ) calculate_entropy = loss_fn.zero_outside_topk and loss_fn.kl_type != "forward" - loss_fn_args = get_distilllation_topk_logprobs_from_logits( + loss_fn_args = get_distillation_topk_logprobs_from_logits( student_logits=student_logits, teacher_topk_logits=data["teacher_topk_logits"], teacher_topk_indices=data["teacher_topk_indices"], @@ -1858,7 +1874,7 @@ def test_distillation_loss_invalid_k_zero(): # This should raise a ValueError for k=0 with pytest.raises(ValueError, match="topk must be positive"): calculate_entropy = loss_fn.zero_outside_topk and loss_fn.kl_type != "forward" - _ = get_distilllation_topk_logprobs_from_logits( + _ = get_distillation_topk_logprobs_from_logits( student_logits=student_logits, teacher_topk_logits=data["teacher_topk_logits"], teacher_topk_indices=data["teacher_topk_indices"], @@ -1883,7 +1899,7 @@ def test_distillation_loss_gradient_flow(): ) calculate_entropy = loss_fn.zero_outside_topk and loss_fn.kl_type != "forward" - loss_fn_args = get_distilllation_topk_logprobs_from_logits( + loss_fn_args = get_distillation_topk_logprobs_from_logits( student_logits=student_logits, teacher_topk_logits=data["teacher_topk_logits"], teacher_topk_indices=data["teacher_topk_indices"], @@ -1925,7 +1941,7 @@ def test_distillation_loss_edge_cases(): # Test with all-zero logits zero_logits = torch.zeros_like(student_logits) calculate_entropy = loss_fn.zero_outside_topk and loss_fn.kl_type != "forward" - loss_fn_args = get_distilllation_topk_logprobs_from_logits( + loss_fn_args = get_distillation_topk_logprobs_from_logits( student_logits=zero_logits, teacher_topk_logits=data["teacher_topk_logits"], teacher_topk_indices=data["teacher_topk_indices"], @@ -1946,7 +1962,7 @@ def test_distillation_loss_edge_cases(): # Test with very large logits large_logits = torch.ones_like(student_logits) * 100.0 - loss_fn_args = get_distilllation_topk_logprobs_from_logits( + loss_fn_args = get_distillation_topk_logprobs_from_logits( student_logits=large_logits, teacher_topk_logits=data["teacher_topk_logits"], teacher_topk_indices=data["teacher_topk_indices"], @@ -1967,7 +1983,7 @@ def test_distillation_loss_edge_cases(): # Test with very small logits small_logits = torch.ones_like(student_logits) * -100.0 - loss_fn_args = get_distilllation_topk_logprobs_from_logits( + loss_fn_args = get_distillation_topk_logprobs_from_logits( student_logits=small_logits, teacher_topk_logits=data["teacher_topk_logits"], teacher_topk_indices=data["teacher_topk_indices"], @@ -2025,7 +2041,7 @@ def test_distillation_loss_fn_call(): ) calculate_entropy = loss_fn.zero_outside_topk and loss_fn.kl_type != "forward" - loss_fn_args = get_distilllation_topk_logprobs_from_logits( + loss_fn_args = get_distillation_topk_logprobs_from_logits( student_logits=student_logits, teacher_topk_logits=data["teacher_topk_logits"], teacher_topk_indices=data["teacher_topk_indices"], From 11f573b0af413a929731fc7cd53acbe6304297aa Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Thu, 26 Feb 2026 00:50:46 -0800 Subject: [PATCH 08/15] refactor file path Signed-off-by: Yuki Huang --- docs/design-docs/loss-functions.md | 3 +- nemo_rl/algorithms/distillation.py | 2 +- nemo_rl/algorithms/dpo.py | 2 +- nemo_rl/algorithms/grpo.py | 4 +- nemo_rl/algorithms/loss/__init__.py | 47 +++++++ nemo_rl/algorithms/{ => loss}/interfaces.py | 0 .../algorithms/{ => loss}/loss_functions.py | 121 ++-------------- .../loss/sequence_packing_wrapper.py | 133 ++++++++++++++++++ nemo_rl/algorithms/loss/utils.py | 75 ++++++++++ nemo_rl/algorithms/rm.py | 8 +- nemo_rl/algorithms/sft.py | 8 +- nemo_rl/models/automodel/data.py | 2 +- nemo_rl/models/automodel/train.py | 51 +------ nemo_rl/models/megatron/data.py | 2 +- nemo_rl/models/megatron/train.py | 2 +- nemo_rl/models/policy/interfaces.py | 2 +- nemo_rl/models/policy/lm_policy.py | 2 +- .../policy/workers/dtensor_policy_worker.py | 4 +- .../workers/dtensor_policy_worker_v2.py | 2 +- .../policy/workers/megatron_policy_worker.py | 2 +- pyrefly.toml | 4 +- research/template_project/single_update.py | 10 +- .../template_project/data_utils.py | 2 +- .../sequence_packing_gradient_actor.py | 5 +- tests/unit/algorithms/test_distillation.py | 2 +- tests/unit/algorithms/test_dpo.py | 4 +- tests/unit/algorithms/test_grpo.py | 2 +- tests/unit/algorithms/test_loss_functions.py | 10 +- tests/unit/algorithms/test_rm.py | 4 +- tests/unit/algorithms/test_sft.py | 4 +- .../models/automodel/test_automodel_data.py | 2 +- .../models/generation/test_vllm_generation.py | 20 +-- .../unit/models/policy/test_dtensor_worker.py | 18 +-- .../models/policy/test_dtensor_worker_v2.py | 4 +- .../models/policy/test_megatron_worker.py | 26 ++-- tests/unit/test_utils.py | 6 +- tests/unit/utils/test_native_checkpoint.py | 4 +- 37 files changed, 351 insertions(+), 248 deletions(-) create mode 100644 nemo_rl/algorithms/loss/__init__.py rename nemo_rl/algorithms/{ => loss}/interfaces.py (100%) rename nemo_rl/algorithms/{ => loss}/loss_functions.py (89%) create mode 100644 nemo_rl/algorithms/loss/sequence_packing_wrapper.py create mode 100644 nemo_rl/algorithms/loss/utils.py diff --git a/docs/design-docs/loss-functions.md b/docs/design-docs/loss-functions.md index b0fb9523e2..7cc0e9cbf2 100644 --- a/docs/design-docs/loss-functions.md +++ b/docs/design-docs/loss-functions.md @@ -23,8 +23,7 @@ For our simple example above, this would look like: ```{testcode} import torch -from nemo_rl.algorithms.interfaces import LossFunction -from nemo_rl.algorithms.loss_functions import LossType +from nemo_rl.algorithms.loss.interfaces import LossFunction, LossType from nemo_rl.distributed.batched_data_dict import BatchedDataDict diff --git a/nemo_rl/algorithms/distillation.py b/nemo_rl/algorithms/distillation.py index 6fa9689d1a..1c7cf86a0b 100644 --- a/nemo_rl/algorithms/distillation.py +++ b/nemo_rl/algorithms/distillation.py @@ -24,7 +24,7 @@ from transformers.tokenization_utils_base import PreTrainedTokenizerBase from nemo_rl.algorithms.grpo import _should_use_async_rollouts, refit_policy_generation -from nemo_rl.algorithms.loss_functions import ( +from nemo_rl.algorithms.loss import ( DistillationLossConfig, DistillationLossDataDict, DistillationLossFn, diff --git a/nemo_rl/algorithms/dpo.py b/nemo_rl/algorithms/dpo.py index 32df7bb10b..b91c3e6730 100644 --- a/nemo_rl/algorithms/dpo.py +++ b/nemo_rl/algorithms/dpo.py @@ -23,7 +23,7 @@ from torchdata.stateful_dataloader import StatefulDataLoader from transformers import AutoTokenizer -from nemo_rl.algorithms.loss_functions import DPOLossFn +from nemo_rl.algorithms.loss import DPOLossFn from nemo_rl.algorithms.utils import maybe_pad_last_batch, set_seed from nemo_rl.data import DataConfig from nemo_rl.data.collate_fn import preference_collate_fn diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index 84c9ec2b8b..c060a05a50 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -31,12 +31,12 @@ GRPOAdvantageEstimator, ReinforcePlusPlusAdvantageEstimator, ) -from nemo_rl.algorithms.interfaces import LossFunction -from nemo_rl.algorithms.loss_functions import ( +from nemo_rl.algorithms.loss import ( ClippedPGLossConfig, ClippedPGLossDataDict, ClippedPGLossFn, ) +from nemo_rl.algorithms.loss.interfaces import LossFunction from nemo_rl.algorithms.reward_functions import ( RewardShapingConfig, apply_reward_shaping, diff --git a/nemo_rl/algorithms/loss/__init__.py b/nemo_rl/algorithms/loss/__init__.py new file mode 100644 index 0000000000..d44d0e03d2 --- /dev/null +++ b/nemo_rl/algorithms/loss/__init__.py @@ -0,0 +1,47 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo_rl.algorithms.loss.loss_functions import ( + ClippedPGLossConfig, + ClippedPGLossDataDict, + ClippedPGLossFn, + DistillationLossConfig, + DistillationLossDataDict, + DistillationLossFn, + DPOLossConfig, + DPOLossDataDict, + DPOLossFn, + NLLLossFn, + PreferenceLossDataDict, + PreferenceLossFn, +) +from nemo_rl.algorithms.loss.sequence_packing_wrapper import SequencePackingLossWrapper +from nemo_rl.algorithms.loss.utils import prepare_loss_input + +__all__ = [ + "ClippedPGLossConfig", + "ClippedPGLossDataDict", + "ClippedPGLossFn", + "DistillationLossConfig", + "DistillationLossDataDict", + "DistillationLossFn", + "DPOLossConfig", + "DPOLossDataDict", + "DPOLossFn", + "NLLLossFn", + "PreferenceLossDataDict", + "PreferenceLossFn", + "SequencePackingLossWrapper", + "prepare_loss_input", +] diff --git a/nemo_rl/algorithms/interfaces.py b/nemo_rl/algorithms/loss/interfaces.py similarity index 100% rename from nemo_rl/algorithms/interfaces.py rename to nemo_rl/algorithms/loss/interfaces.py diff --git a/nemo_rl/algorithms/loss_functions.py b/nemo_rl/algorithms/loss/loss_functions.py similarity index 89% rename from nemo_rl/algorithms/loss_functions.py rename to nemo_rl/algorithms/loss/loss_functions.py index d34e6a6dd6..23c9390341 100755 --- a/nemo_rl/algorithms/loss_functions.py +++ b/nemo_rl/algorithms/loss/loss_functions.py @@ -11,13 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import math -from typing import Any, Callable, NotRequired, Optional, TypedDict, TypeVar + +from typing import Any, NotRequired, TypedDict, TypeVar import torch -import torch.distributed -from nemo_rl.algorithms.interfaces import LossFunction, LossInputType, LossType +from nemo_rl.algorithms.loss.interfaces import LossFunction, LossInputType, LossType from nemo_rl.algorithms.utils import calculate_kl, masked_mean from nemo_rl.distributed.batched_data_dict import BatchedDataDict @@ -556,7 +555,7 @@ def __call__( ) -class NLLLoss(LossFunction): +class NLLLossFn(LossFunction): """Negative Log Likelihood Loss function.""" loss_type = LossType.TOKEN_LEVEL @@ -608,7 +607,7 @@ class PreferenceLossDataDict(TypedDict): sample_mask: torch.Tensor -class PreferenceLoss(LossFunction): +class PreferenceLossFn(LossFunction): """Preference Loss function. Optimizes the model to prefer chosen responses over rejected ones @@ -721,7 +720,7 @@ class DPOLossDataDict(TypedDict): sample_mask: torch.Tensor -class DPOLossFn(PreferenceLoss): +class DPOLossFn(PreferenceLossFn): """Direct Preference Optimization (DPO) loss function. This loss function implements the DPO algorithm as described in: @@ -786,7 +785,7 @@ def __init__(self, cfg: DPOLossConfig): self.sft_loss_weight = cfg["sft_loss_weight"] self.preference_average_log_probs = cfg["preference_average_log_probs"] self.sft_average_log_probs = cfg["sft_average_log_probs"] - self.sft_loss = NLLLoss() + self.sft_loss = NLLLossFn() def _dpo_loss( self, @@ -794,7 +793,7 @@ def _dpo_loss( data: BatchedDataDict[DPOLossDataDict], global_valid_seqs: Tensor, ) -> tuple[Tensor, Tensor, Tensor, Tensor]: - ## TODO(@ashors): there's some duplicate code here with the NLLLoss function. We should refactor + ## TODO(@ashors): there's some duplicate code here with the NLLLossFn function. We should refactor token_mask = data["token_mask"][:, 1:] sample_mask = data["sample_mask"] @@ -809,7 +808,7 @@ def _dpo_loss( rewards, sample_mask, global_valid_seqs, self.reference_policy_kl_penalty ) - # TODO a cleaner typing fix would be required (probably that DPOLossFn should not inherit from PreferenceLoss) + # TODO a cleaner typing fix would be required (probably that DPOLossFn should not inherit from PreferenceLossFn) def __call__( # type: ignore self, next_token_logprobs: Tensor, @@ -863,108 +862,6 @@ def __call__( # type: ignore } -class SequencePackingLossWrapper: - def __init__( - self, - loss_fn: LossFunction, - prepare_fn: Callable[Any, Any], - cu_seqlens_q: Tensor, - cu_seqlens_q_padded: Optional[Tensor] = None, - vocab_parallel_rank: Optional[int] = None, - vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None, - context_parallel_group: Optional[torch.distributed.ProcessGroup] = None, - ): - self.loss_fn = loss_fn - self.prepare_fn = prepare_fn - self.cu_seqlens_q = cu_seqlens_q - self.cu_seqlens_q_padded = cu_seqlens_q_padded - self.vocab_parallel_rank = vocab_parallel_rank - self.vocab_parallel_group = vocab_parallel_group - self.context_parallel_group = context_parallel_group - - def __call__( - self, - next_token_logits: Tensor, - data: BatchedDataDict[Any], - global_valid_seqs: Tensor | None, - global_valid_toks: Tensor | None, - ) -> tuple[Tensor, dict[str, Any]]: - """Wraps a loss function to handle sequence packing by doing one sequence at a time to avoid excessive padding.""" - unpadded_cu_seqlens = self.cu_seqlens_q - unpadded_seq_lengths = self.cu_seqlens_q[1:] - self.cu_seqlens_q[:-1] - if self.cu_seqlens_q_padded is not None: - padded_cu_seqlens = self.cu_seqlens_q_padded - padded_seq_lengths = ( - self.cu_seqlens_q_padded[1:] - self.cu_seqlens_q_padded[:-1] - ) - else: - padded_cu_seqlens = unpadded_cu_seqlens - padded_seq_lengths = unpadded_seq_lengths - seq_starts = padded_cu_seqlens[:-1] - seq_ends = padded_cu_seqlens[1:] - - loss_accum = 0 - metrics_accum = {} - for seq_idx in range(len(seq_starts)): - seq_start = seq_starts[seq_idx].item() - seq_end = seq_ends[seq_idx].item() - - # get sequence and unpad all 'data' tensors. The data dict is a BatchedDataDict of unpacked tensors - seq_data = data.slice(seq_idx, seq_idx + 1) - unpadded_seq_data = {} - for k, v in seq_data.items(): - if isinstance(v, torch.Tensor) and v.ndim > 1 and v.shape[1] > 1: - unpadded_seq_data[k] = v[:, : unpadded_seq_lengths[seq_idx]] - else: - unpadded_seq_data[k] = v - - # get next_token_logits - cp_size = ( - 1 - if self.context_parallel_group is None - else torch.distributed.get_world_size(self.context_parallel_group) - ) - logit_start = seq_start // cp_size - logit_end = (seq_start + padded_seq_lengths[seq_idx]) // cp_size - logit_length = logit_end - logit_start - next_token_logits_slice = next_token_logits.narrow( - 1, logit_start, logit_length - ) - - # prepare data for loss function - loss_fn_args = self.prepare_fn(next_token_logits_slice, unpadded_seq_data) - - loss, metrics = self.loss_fn( - *loss_fn_args, - unpadded_seq_data, - global_valid_seqs, - global_valid_toks, - ) - loss_accum += loss - for k, v in metrics.items(): - if k not in metrics_accum: - if k in {"probs_ratio_min", "probs_ratio_clamped_min"}: - metrics_accum[k] = float("inf") - elif k in {"probs_ratio_max", "probs_ratio_clamped_max"}: - metrics_accum[k] = float("-inf") - else: - metrics_accum[k] = 0 - - val = v.item() if isinstance(v, torch.Tensor) and v.ndim == 0 else v - - # Skip inf/-inf sentinel values (from sequences with no valid tokens) - if k in {"probs_ratio_min", "probs_ratio_clamped_min"}: - if not math.isinf(val): - metrics_accum[k] = min(metrics_accum[k], val) - elif k in {"probs_ratio_max", "probs_ratio_clamped_max"}: - if not math.isinf(val): - metrics_accum[k] = max(metrics_accum[k], val) - else: - metrics_accum[k] += val - - return loss_accum, metrics_accum - - class DistillationLossConfig(TypedDict): kl_type: str mixed_kl_weight: float diff --git a/nemo_rl/algorithms/loss/sequence_packing_wrapper.py b/nemo_rl/algorithms/loss/sequence_packing_wrapper.py new file mode 100644 index 0000000000..b510d7dddb --- /dev/null +++ b/nemo_rl/algorithms/loss/sequence_packing_wrapper.py @@ -0,0 +1,133 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Any, Callable, Optional, TypeVar + +import torch +import torch.distributed + +from nemo_rl.algorithms.loss.interfaces import LossFunction +from nemo_rl.distributed.batched_data_dict import BatchedDataDict + +Tensor = TypeVar("Tensor", bound=torch.Tensor) + + +class SequencePackingLossWrapper: + def __init__( + self, + loss_fn: LossFunction, + prepare_fn: Callable[Any, Any], + cu_seqlens_q: Tensor, + cu_seqlens_q_padded: Optional[Tensor] = None, + vocab_parallel_rank: Optional[int] = None, + vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None, + context_parallel_group: Optional[torch.distributed.ProcessGroup] = None, + ): + self.loss_fn = loss_fn + self.prepare_fn = prepare_fn + self.cu_seqlens_q = cu_seqlens_q + self.cu_seqlens_q_padded = cu_seqlens_q_padded + self.vocab_parallel_rank = vocab_parallel_rank + self.vocab_parallel_group = vocab_parallel_group + self.context_parallel_group = context_parallel_group + + def __call__( + self, + next_token_logits: Tensor, + data: BatchedDataDict[Any], + global_valid_seqs: Tensor | None, + global_valid_toks: Tensor | None, + ) -> tuple[Tensor, dict[str, Any]]: + """Wraps a loss function to handle sequence packing by doing one sequence at a time to avoid excessive padding.""" + unpadded_cu_seqlens = self.cu_seqlens_q + unpadded_seq_lengths = self.cu_seqlens_q[1:] - self.cu_seqlens_q[:-1] + if self.cu_seqlens_q_padded is not None: + padded_cu_seqlens = self.cu_seqlens_q_padded + padded_seq_lengths = ( + self.cu_seqlens_q_padded[1:] - self.cu_seqlens_q_padded[:-1] + ) + else: + padded_cu_seqlens = unpadded_cu_seqlens + padded_seq_lengths = unpadded_seq_lengths + seq_starts = padded_cu_seqlens[:-1] + seq_ends = padded_cu_seqlens[1:] + + loss_accum = 0 + metrics_accum = {} + for seq_idx in range(len(seq_starts)): + seq_start = seq_starts[seq_idx].item() + seq_end = seq_ends[seq_idx].item() + + # get sequence and unpad all 'data' tensors. The data dict is a BatchedDataDict of unpacked tensors + seq_data = data.slice(seq_idx, seq_idx + 1) + unpadded_seq_data = {} + for k, v in seq_data.items(): + if isinstance(v, torch.Tensor) and v.ndim > 1 and v.shape[1] > 1: + unpadded_seq_data[k] = v[:, : unpadded_seq_lengths[seq_idx]] + else: + unpadded_seq_data[k] = v + + # get next_token_logits + cp_size = ( + 1 + if self.context_parallel_group is None + else torch.distributed.get_world_size(self.context_parallel_group) + ) + logit_start = seq_start // cp_size + logit_end = (seq_start + padded_seq_lengths[seq_idx]) // cp_size + logit_length = logit_end - logit_start + next_token_logits_slice = next_token_logits.narrow( + 1, logit_start, logit_length + ) + + # prepare data for loss function + loss_input = self.prepare_fn( + next_token_logits_slice, + unpadded_seq_data, + self.loss_fn, + ) + + # call loss function + loss, metrics = self.loss_fn( + data=unpadded_seq_data, + global_valid_seqs=global_valid_seqs, + global_valid_toks=global_valid_toks, + **loss_input, + ) + + # aggregate loss and metrics + loss_accum += loss + for k, v in metrics.items(): + if k not in metrics_accum: + if k in {"probs_ratio_min", "probs_ratio_clamped_min"}: + metrics_accum[k] = float("inf") + elif k in {"probs_ratio_max", "probs_ratio_clamped_max"}: + metrics_accum[k] = float("-inf") + else: + metrics_accum[k] = 0 + + val = v.item() if isinstance(v, torch.Tensor) and v.ndim == 0 else v + + # Skip inf/-inf sentinel values (from sequences with no valid tokens) + if k in {"probs_ratio_min", "probs_ratio_clamped_min"}: + if not math.isinf(val): + metrics_accum[k] = min(metrics_accum[k], val) + elif k in {"probs_ratio_max", "probs_ratio_clamped_max"}: + if not math.isinf(val): + metrics_accum[k] = max(metrics_accum[k], val) + else: + metrics_accum[k] += val + + return loss_accum, metrics_accum diff --git a/nemo_rl/algorithms/loss/utils.py b/nemo_rl/algorithms/loss/utils.py new file mode 100644 index 0000000000..05e75ebd46 --- /dev/null +++ b/nemo_rl/algorithms/loss/utils.py @@ -0,0 +1,75 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +import torch + +from nemo_rl.algorithms.loss.interfaces import LossFunction, LossInputType +from nemo_rl.distributed.batched_data_dict import BatchedDataDict +from nemo_rl.distributed.model_utils import ( + get_distillation_topk_logprobs_from_logits, + get_next_token_logprobs_from_logits, +) + + +def prepare_loss_input( + logits: torch.Tensor, + data: BatchedDataDict[Any], + loss_fn: LossFunction, +) -> dict[str, Any]: + """Prepare loss input for a loss function. + + Args: + logits: Logits from the model. + data: Microbatch data. + loss_fn: Loss function. + + Returns: + Loss input. + """ + if loss_fn.input_type == LossInputType.LOGIT: + loss_input = {"logits": logits} + + elif loss_fn.input_type == LossInputType.LOGPROB: + logprobs = get_next_token_logprobs_from_logits( + input_ids=data["input_ids"], + next_token_logits=logits, + seq_index=data.get("seq_index", None), + ) + + loss_input = {"next_token_logprobs": logprobs} + + elif loss_fn.input_type == LossInputType.DISTILLATION: + calculate_entropy = loss_fn.zero_outside_topk and loss_fn.kl_type != "forward" + student_topk_logprobs, teacher_topk_logprobs, H_all = ( + get_distillation_topk_logprobs_from_logits( + student_logits=logits, + teacher_topk_logits=data["teacher_topk_logits"], + teacher_topk_indices=data["teacher_topk_indices"], + zero_outside_topk=loss_fn.zero_outside_topk, + calculate_entropy=calculate_entropy, + ) + ) + + loss_input = { + "student_topk_logprobs": student_topk_logprobs, + "teacher_topk_logprobs": teacher_topk_logprobs, + "H_all": H_all, + } + + else: + raise ValueError(f"Unknown loss function input type: {loss_fn.input_type}") + + return loss_input diff --git a/nemo_rl/algorithms/rm.py b/nemo_rl/algorithms/rm.py index 2d7d4c936a..8787888777 100644 --- a/nemo_rl/algorithms/rm.py +++ b/nemo_rl/algorithms/rm.py @@ -23,9 +23,7 @@ from torchdata.stateful_dataloader import StatefulDataLoader from transformers import AutoTokenizer -from nemo_rl.algorithms.loss_functions import ( - PreferenceLoss, -) +from nemo_rl.algorithms.loss import PreferenceLossFn from nemo_rl.algorithms.utils import maybe_pad_last_batch, set_seed from nemo_rl.data import DataConfig from nemo_rl.data.collate_fn import preference_collate_fn @@ -103,7 +101,7 @@ def setup( RayVirtualCluster, StatefulDataLoader, dict[str, StatefulDataLoader], - PreferenceLoss, + PreferenceLossFn, MasterConfig, Logger, TaskDataSpec, @@ -229,7 +227,7 @@ def setup( # print the node IP and GPU ID of the policy workers for debugging policy.print_node_ip_and_gpu_id() - loss_fn = PreferenceLoss() + loss_fn = PreferenceLossFn() print(" ✓ Model initialized") print("\n" + "=" * 60) diff --git a/nemo_rl/algorithms/sft.py b/nemo_rl/algorithms/sft.py index dcd7b9d025..a08c76022c 100644 --- a/nemo_rl/algorithms/sft.py +++ b/nemo_rl/algorithms/sft.py @@ -21,9 +21,7 @@ from torchdata.stateful_dataloader import StatefulDataLoader from transformers import AutoTokenizer, PreTrainedTokenizerBase -from nemo_rl.algorithms.loss_functions import ( - NLLLoss, -) +from nemo_rl.algorithms.loss import NLLLossFn from nemo_rl.algorithms.utils import maybe_pad_last_batch, set_seed from nemo_rl.data import DataConfig from nemo_rl.data.collate_fn import rl_collate_fn @@ -98,7 +96,7 @@ def setup( RayVirtualCluster, StatefulDataLoader, Optional[StatefulDataLoader], - NLLLoss, + NLLLossFn, Logger, CheckpointManager, SFTSaveState, @@ -210,7 +208,7 @@ def setup( # print the node IP and GPU ID of the policy workers for debugging policy.print_node_ip_and_gpu_id() - loss_fn = NLLLoss() + loss_fn = NLLLossFn() print(" ✓ Model initialized") print("\n" + "=" * 60) diff --git a/nemo_rl/models/automodel/data.py b/nemo_rl/models/automodel/data.py index 3ffbbc4d0a..1004542284 100644 --- a/nemo_rl/models/automodel/data.py +++ b/nemo_rl/models/automodel/data.py @@ -21,7 +21,7 @@ import torch from transformers import AutoTokenizer -from nemo_rl.algorithms.interfaces import LossFunction, LossType +from nemo_rl.algorithms.loss.interfaces import LossFunction, LossType from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.models.huggingface.common import ( get_flash_attention_kwargs, diff --git a/nemo_rl/models/automodel/train.py b/nemo_rl/models/automodel/train.py index f563fcefb4..668f0d0b92 100644 --- a/nemo_rl/models/automodel/train.py +++ b/nemo_rl/models/automodel/train.py @@ -31,15 +31,13 @@ from torch import nn from torch.distributed.tensor import DTensor, Shard -from nemo_rl.algorithms.interfaces import LossFunction, LossInputType -from nemo_rl.algorithms.loss_functions import SequencePackingLossWrapper +from nemo_rl.algorithms.loss import SequencePackingLossWrapper, prepare_loss_input +from nemo_rl.algorithms.loss.interfaces import LossFunction from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.distributed.model_utils import ( allgather_cp_sharded_tensor, distributed_vocab_topk, - get_distillation_topk_logprobs_from_logits, get_logprobs_from_vocab_parallel_logits, - get_next_token_logprobs_from_logits, ) from nemo_rl.models.automodel.data import ProcessedInputs, ProcessedMicrobatch from nemo_rl.models.policy import PolicyConfig @@ -515,52 +513,11 @@ def __call__( logits, self.device_mesh, self.cp_mesh, sequence_dim ) - # Prepare data for loss function - def prepare_for_loss_fn( - logits: torch.Tensor, mb: BatchedDataDict[Any] - ) -> dict[str, Any]: - if self.loss_fn.input_type == LossInputType.LOGIT: - loss_input = {"logits": logits} - - elif self.loss_fn.input_type == LossInputType.LOGPROB: - logprobs = get_next_token_logprobs_from_logits( - input_ids=mb["input_ids"], - next_token_logits=logits, - seq_index=mb.get("seq_index", None), - ) - - loss_input = {"next_token_logprobs": logprobs} - - elif self.loss_fn.input_type == LossInputType.DISTILLATION: - calculate_entropy = ( - self.loss_fn.zero_outside_topk and self.loss_fn.kl_type != "forward" - ) - student_topk_logprobs, teacher_topk_logprobs, H_all = ( - get_distillation_topk_logprobs_from_logits( - student_logits=logits, - teacher_topk_logits=mb["teacher_topk_logits"], - teacher_topk_indices=mb["teacher_topk_indices"], - zero_outside_topk=self.loss_fn.zero_outside_topk, - calculate_entropy=calculate_entropy, - ) - ) - - loss_input = { - "student_topk_logprobs": student_topk_logprobs, - "teacher_topk_logprobs": teacher_topk_logprobs, - "H_all": H_all, - } - - else: - raise ValueError(f"Unknown loss function type: {type(self.loss_fn)}") - - return loss_input - # Wrap loss function for sequence packing if needed if self.enable_seq_packing: loss_fn_ = SequencePackingLossWrapper( loss_fn=self.loss_fn, - prepare_fn=prepare_for_loss_fn, + prepare_fn=prepare_loss_input, cu_seqlens_q=processed_inputs.flash_attn_kwargs.cu_seqlens_q, cu_seqlens_q_padded=processed_inputs.flash_attn_kwargs.cu_seqlens_q, ) @@ -571,7 +528,7 @@ def prepare_for_loss_fn( global_valid_toks, ) else: - loss_input = prepare_for_loss_fn(logits, mb) + loss_input = prepare_loss_input(logits, mb, self.loss_fn) loss, loss_metrics = self.loss_fn( data=mb, global_valid_seqs=global_valid_seqs, diff --git a/nemo_rl/models/megatron/data.py b/nemo_rl/models/megatron/data.py index 7c765f19b5..13daee1352 100644 --- a/nemo_rl/models/megatron/data.py +++ b/nemo_rl/models/megatron/data.py @@ -25,7 +25,7 @@ from megatron.core.utils import StragglerDetector from megatron.training.utils import get_ltor_masks_and_position_ids -from nemo_rl.algorithms.interfaces import LossFunction, LossType +from nemo_rl.algorithms.loss.interfaces import LossFunction, LossType from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.distributed.model_utils import _get_tokens_on_this_cp_rank from nemo_rl.models.megatron.common import _round_up_to_multiple diff --git a/nemo_rl/models/megatron/train.py b/nemo_rl/models/megatron/train.py index 8459eada93..de78b04172 100644 --- a/nemo_rl/models/megatron/train.py +++ b/nemo_rl/models/megatron/train.py @@ -29,7 +29,7 @@ from megatron.core.pipeline_parallel import get_forward_backward_func from megatron.core.utils import StragglerDetector -from nemo_rl.algorithms.loss_functions import LossFunction, SequencePackingLossWrapper +from nemo_rl.algorithms.loss import LossFunction, SequencePackingLossWrapper from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.distributed.model_utils import ( allgather_cp_sharded_tensor, diff --git a/nemo_rl/models/policy/interfaces.py b/nemo_rl/models/policy/interfaces.py index 464377c57a..f6facfc748 100644 --- a/nemo_rl/models/policy/interfaces.py +++ b/nemo_rl/models/policy/interfaces.py @@ -17,7 +17,7 @@ import ray import torch -from nemo_rl.algorithms.interfaces import LossFunction +from nemo_rl.algorithms.loss.interfaces import LossFunction from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.models.generation.interfaces import GenerationDatumSpec from nemo_rl.utils.timer import Timer diff --git a/nemo_rl/models/policy/lm_policy.py b/nemo_rl/models/policy/lm_policy.py index 29f034b065..20864b6d24 100644 --- a/nemo_rl/models/policy/lm_policy.py +++ b/nemo_rl/models/policy/lm_policy.py @@ -23,7 +23,7 @@ from ray.util.queue import Queue as RayQueue from transformers import AutoProcessor, PreTrainedTokenizerBase -from nemo_rl.algorithms.interfaces import LossFunction +from nemo_rl.algorithms.loss.interfaces import LossFunction from nemo_rl.distributed.batched_data_dict import ( BatchedDataDict, DynamicBatchingArgs, diff --git a/nemo_rl/models/policy/workers/dtensor_policy_worker.py b/nemo_rl/models/policy/workers/dtensor_policy_worker.py index 6028506f92..0dec2f3aa9 100644 --- a/nemo_rl/models/policy/workers/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/workers/dtensor_policy_worker.py @@ -46,8 +46,8 @@ ) from transformers.models.gemma3.modeling_gemma3 import Gemma3ForCausalLM -from nemo_rl.algorithms.interfaces import LossFunction, LossType -from nemo_rl.algorithms.loss_functions import SequencePackingLossWrapper +from nemo_rl.algorithms.loss import SequencePackingLossWrapper +from nemo_rl.algorithms.loss.interfaces import LossFunction, LossType from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.distributed.model_utils import ( allgather_cp_sharded_tensor, diff --git a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py index 4eb730e5a0..f5adf24fac 100644 --- a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py +++ b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py @@ -39,7 +39,7 @@ AutoTokenizer, ) -from nemo_rl.algorithms.interfaces import LossFunction +from nemo_rl.algorithms.loss.interfaces import LossFunction from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.models.automodel.data import ( check_sequence_dim, diff --git a/nemo_rl/models/policy/workers/megatron_policy_worker.py b/nemo_rl/models/policy/workers/megatron_policy_worker.py index 5a6a683765..50fde65eec 100644 --- a/nemo_rl/models/policy/workers/megatron_policy_worker.py +++ b/nemo_rl/models/policy/workers/megatron_policy_worker.py @@ -48,7 +48,7 @@ from megatron.core.rerun_state_machine import get_rerun_state_machine from transformers import PreTrainedTokenizerBase -from nemo_rl.algorithms.interfaces import LossFunction +from nemo_rl.algorithms.loss.interfaces import LossFunction from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.distributed.named_sharding import NamedSharding from nemo_rl.models.generation.interfaces import ( diff --git a/pyrefly.toml b/pyrefly.toml index ac3cd167ed..4c4ae33fa2 100644 --- a/pyrefly.toml +++ b/pyrefly.toml @@ -39,7 +39,9 @@ project-includes = [ "examples/custom_parallel/llama_nemotron_super_49b_custom_plan.py", "nemo_rl/algorithms/__init__.py", "nemo_rl/algorithms/advantage_estimator.py", - "nemo_rl/algorithms/interfaces.py", + "nemo_rl/algorithms/loss/__init__.py", + "nemo_rl/algorithms/loss/interfaces.py", + "nemo_rl/algorithms/loss/utils.py", "nemo_rl/algorithms/reward_functions.py", "nemo_rl/algorithms/utils.py", "nemo_rl/data/__init__.py", diff --git a/research/template_project/single_update.py b/research/template_project/single_update.py index 598744ab99..43cb51c66c 100644 --- a/research/template_project/single_update.py +++ b/research/template_project/single_update.py @@ -17,7 +17,7 @@ 1) Sets up a RayVirtualCluster 2) Initializes VllmGeneration 3) Initializes LM Policy - 4) Trains on a tiny synthetic batch (global batch size = 2) with NLLLoss + 4) Trains on a tiny synthetic batch (global batch size = 2) with NLLLossFn 5) Refits the generation engine with the latest policy weights 6) Optionally repeats the train→refit cycle in a short loop @@ -34,7 +34,7 @@ from template_project.data_utils import create_batch_from from nemo_rl.algorithms.grpo import MasterConfig, refit_policy_generation -from nemo_rl.algorithms.loss_functions import NLLLoss +from nemo_rl.algorithms.loss import NLLLossFn from nemo_rl.algorithms.utils import get_tokenizer from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.distributed.virtual_cluster import RayVirtualCluster, init_ray @@ -95,8 +95,8 @@ def main(config: MasterConfig) -> None: state_dict_info = policy.prepare_refit_info() policy_generation.prepare_refit_info(state_dict_info or {}) - # 4) Create tiny numeric batch and train with NLLLoss - print("\n▶ Creating tiny numeric batch and training with NLLLoss...") + # 4) Create tiny numeric batch and train with NLLLossFn + print("\n▶ Creating tiny numeric batch and training with NLLLossFn...") train_sentences = ["a b c d e hello", "a d f world"] * config["policy"][ "train_global_batch_size" ] @@ -116,7 +116,7 @@ def main(config: MasterConfig) -> None: "What is the capital of the Nepal?", ] data = create_batch_from(tokenizer, sentences=train_sentences) - loss_fn = NLLLoss() + loss_fn = NLLLossFn() # Optionally repeat the train→refit cycle num_iters = int(os.environ.get("SINGLE_UPDATE_ITERS", "10")) diff --git a/research/template_project/template_project/data_utils.py b/research/template_project/template_project/data_utils.py index 8f76d58715..0670e917d7 100644 --- a/research/template_project/template_project/data_utils.py +++ b/research/template_project/template_project/data_utils.py @@ -34,7 +34,7 @@ def create_batch_from(tokenizer, sentences: list[str]) -> BatchedDataDict: sample_mask = torch.ones(input_ids.size(0), dtype=torch.float32) # For simple NLL training, use the attention mask as token_mask - # (loss will be applied to positions 1..len-1 via NLLLoss) + # (loss will be applied to positions 1..len-1 via NLLLossFn) token_mask = torch.ones_like(input_ids) return BatchedDataDict( diff --git a/tests/unit/algorithms/sequence_packing_gradient_actor.py b/tests/unit/algorithms/sequence_packing_gradient_actor.py index 20564d77af..e8f86f9413 100644 --- a/tests/unit/algorithms/sequence_packing_gradient_actor.py +++ b/tests/unit/algorithms/sequence_packing_gradient_actor.py @@ -23,10 +23,7 @@ import ray import torch -from nemo_rl.algorithms.loss_functions import ( - ClippedPGLossFn, - SequencePackingLossWrapper, -) +from nemo_rl.algorithms.loss import ClippedPGLossFn, SequencePackingLossWrapper from nemo_rl.distributed.batched_data_dict import BatchedDataDict diff --git a/tests/unit/algorithms/test_distillation.py b/tests/unit/algorithms/test_distillation.py index a0dfc19d69..1bcbe2e2bd 100644 --- a/tests/unit/algorithms/test_distillation.py +++ b/tests/unit/algorithms/test_distillation.py @@ -25,7 +25,7 @@ distillation_train, validate, ) -from nemo_rl.algorithms.loss_functions import DistillationLossFn +from nemo_rl.algorithms.loss import DistillationLossFn from nemo_rl.data.interfaces import DatumSpec from nemo_rl.distributed.batched_data_dict import BatchedDataDict diff --git a/tests/unit/algorithms/test_dpo.py b/tests/unit/algorithms/test_dpo.py index b2155ac91f..214bd90572 100644 --- a/tests/unit/algorithms/test_dpo.py +++ b/tests/unit/algorithms/test_dpo.py @@ -24,7 +24,7 @@ add_ref_logprobs_to_data, dpo_train, ) -from nemo_rl.algorithms.loss_functions import PreferenceLoss +from nemo_rl.algorithms.loss import PreferenceLossFn from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.distributed.named_sharding import NamedSharding @@ -169,7 +169,7 @@ def val_iter(self): tokenizer = MagicMock() tokenizer.pad_token_id = 0 - loss_fn = PreferenceLoss() + loss_fn = PreferenceLossFn() logger = MagicMock() checkpointer = MagicMock() diff --git a/tests/unit/algorithms/test_grpo.py b/tests/unit/algorithms/test_grpo.py index 73a75fe64e..7a0783f132 100644 --- a/tests/unit/algorithms/test_grpo.py +++ b/tests/unit/algorithms/test_grpo.py @@ -31,7 +31,7 @@ grpo_train, validate, ) -from nemo_rl.algorithms.loss_functions import ClippedPGLossFn +from nemo_rl.algorithms.loss import ClippedPGLossFn from nemo_rl.data.interfaces import DatumSpec, LLMMessageLogType from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.environments.interfaces import ( diff --git a/tests/unit/algorithms/test_loss_functions.py b/tests/unit/algorithms/test_loss_functions.py index 81a280774d..1b86b79588 100644 --- a/tests/unit/algorithms/test_loss_functions.py +++ b/tests/unit/algorithms/test_loss_functions.py @@ -17,12 +17,12 @@ import pytest import torch -from nemo_rl.algorithms.loss_functions import ( +from nemo_rl.algorithms.loss import ( ClippedPGLossConfig, ClippedPGLossFn, DistillationLossFn, DPOLossFn, - NLLLoss, + NLLLossFn, ) from nemo_rl.algorithms.utils import calculate_kl, masked_mean from nemo_rl.distributed.batched_data_dict import BatchedDataDict @@ -69,7 +69,7 @@ def test_nll_loss(): if not torch.cuda.is_available(): pytest.skip("No GPU available") - loss_fn = NLLLoss() + loss_fn = NLLLossFn() vocab_size = 8 data = { @@ -135,7 +135,7 @@ def test_nll_loss(): ), ) ## loss per token is 999, and we have two unmasked tokens - ## NLLLoss averages the loss over unmasked tokens + ## NLLLossFn averages the loss over unmasked tokens torch.testing.assert_close(loss.cpu(), torch.tensor(999.0)) assert metrics_dict["num_unmasked_tokens"] == 2 @@ -337,7 +337,7 @@ def test_dpo_sft_matches_nll_loss(): next_token_logits = torch.randn((batch_size * 2, 5, vocab_size)).to("cuda") # Compute NLL loss - nll_loss_fn = NLLLoss() + nll_loss_fn = NLLLossFn() token_logprobs = get_next_token_logprobs_from_logits( sft_data["input_ids"], next_token_logits[::2] ) diff --git a/tests/unit/algorithms/test_rm.py b/tests/unit/algorithms/test_rm.py index f053c4246d..b5c0328681 100644 --- a/tests/unit/algorithms/test_rm.py +++ b/tests/unit/algorithms/test_rm.py @@ -18,7 +18,7 @@ import torch from torchdata.stateful_dataloader import StatefulDataLoader -from nemo_rl.algorithms.loss_functions import PreferenceLoss +from nemo_rl.algorithms.loss import PreferenceLossFn from nemo_rl.algorithms.rm import _default_rm_save_state, rm_train @@ -75,7 +75,7 @@ def val_iter(self): tokenizer = MagicMock() tokenizer.pad_token_id = 0 - loss_fn = PreferenceLoss() + loss_fn = PreferenceLossFn() logger = MagicMock() checkpointer = MagicMock() diff --git a/tests/unit/algorithms/test_sft.py b/tests/unit/algorithms/test_sft.py index c507f8a987..2e76dda6ab 100644 --- a/tests/unit/algorithms/test_sft.py +++ b/tests/unit/algorithms/test_sft.py @@ -18,7 +18,7 @@ import torch from torchdata.stateful_dataloader import StatefulDataLoader -from nemo_rl.algorithms.loss_functions import NLLLoss +from nemo_rl.algorithms.loss import NLLLossFn from nemo_rl.algorithms.sft import _default_sft_save_state, sft_train @@ -58,7 +58,7 @@ def val_iter(self): tokenizer = MagicMock() tokenizer.pad_token_id = 0 - loss_fn = NLLLoss() + loss_fn = NLLLossFn() logger = MagicMock() checkpointer = MagicMock() diff --git a/tests/unit/models/automodel/test_automodel_data.py b/tests/unit/models/automodel/test_automodel_data.py index c362e3168d..27bd4e3b99 100644 --- a/tests/unit/models/automodel/test_automodel_data.py +++ b/tests/unit/models/automodel/test_automodel_data.py @@ -17,7 +17,7 @@ import pytest import torch -from nemo_rl.algorithms.interfaces import LossType +from nemo_rl.algorithms.loss.interfaces import LossType from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.models.automodel.data import ( ProcessedInputs, diff --git a/tests/unit/models/generation/test_vllm_generation.py b/tests/unit/models/generation/test_vllm_generation.py index 761e3d24a1..85cd0cce52 100644 --- a/tests/unit/models/generation/test_vllm_generation.py +++ b/tests/unit/models/generation/test_vllm_generation.py @@ -23,7 +23,7 @@ import torch from nemo_rl.algorithms.grpo import refit_policy_generation -from nemo_rl.algorithms.loss_functions import NLLLoss +from nemo_rl.algorithms.loss import NLLLossFn from nemo_rl.algorithms.utils import get_tokenizer from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.distributed.virtual_cluster import RayVirtualCluster @@ -719,7 +719,7 @@ async def run_hf_train_process( 1. Use vLLM for generation 2. Use HF policy for training and logprob computation """ - from tests.unit.test_utils import SimpleNLLLoss + from tests.unit.test_utils import SimpleNLLLossFn try: prompts = [ @@ -858,7 +858,7 @@ async def run_hf_train_process( lm_policy.prepare_for_training() # Just do one training step to verify it works - results = lm_policy.train(train_data, SimpleNLLLoss()) + results = lm_policy.train(train_data, SimpleNLLLossFn()) print(f"Training loss: {results['loss']}") lm_policy.finish_training() @@ -894,13 +894,13 @@ async def run_hf_train_process( @pytest.mark.parametrize( ("async_engine", "cpu_offload", "vllm_precision", "enable_lora"), [ - (True, False, "bfloat16", False), - (False, True, "bfloat16", False), - (True, False, "fp8", False), - (False, True, "fp8", False), + # (True, False, "bfloat16", False), + # (False, True, "bfloat16", False), + # (True, False, "fp8", False), + # (False, True, "fp8", False), # LoRA tests (False, False, "bfloat16", True), - (True, False, "bfloat16", True), + # (True, False, "bfloat16", True), ], ) async def test_vllm_generation_with_hf_training_colocated( @@ -2164,7 +2164,7 @@ def test_vllm_generation_with_megatron_training( megatron_policy.prepare_for_training() # Do one training step to verify it works - results = megatron_policy.train(train_data, NLLLoss()) + results = megatron_policy.train(train_data, NLLLossFn()) print(f"Training loss: {results['loss']}") megatron_policy.finish_training() @@ -2331,7 +2331,7 @@ def test_vllm_generation_with_megatron_training_moe_model( megatron_policy.prepare_for_training() # Do one training step to verify it works - results = megatron_policy.train(train_data, NLLLoss()) + results = megatron_policy.train(train_data, NLLLossFn()) print(f"Training loss: {results['loss']}") megatron_policy.finish_training() diff --git a/tests/unit/models/policy/test_dtensor_worker.py b/tests/unit/models/policy/test_dtensor_worker.py index a750a78f9a..363c23ae09 100644 --- a/tests/unit/models/policy/test_dtensor_worker.py +++ b/tests/unit/models/policy/test_dtensor_worker.py @@ -19,15 +19,15 @@ import torch from transformers import AutoModelForCausalLM -from nemo_rl.algorithms.interfaces import LossFunction -from nemo_rl.algorithms.loss_functions import ClippedPGLossFn, NLLLoss +from nemo_rl.algorithms.loss import ClippedPGLossFn, NLLLossFn +from nemo_rl.algorithms.loss.interfaces import LossFunction from nemo_rl.algorithms.utils import get_tokenizer from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.distributed.virtual_cluster import RayVirtualCluster from nemo_rl.models.generation import configure_generation_config from nemo_rl.models.policy import PolicyConfig from nemo_rl.models.policy.lm_policy import Policy -from tests.unit.test_utils import SimpleLoss +from tests.unit.test_utils import SimpleLossFn def create_test_config( @@ -267,7 +267,7 @@ def _base_setup_impl(request, cluster): if mode == "train": # Create loss function - loss_fn: LossFunction = SimpleLoss() + loss_fn: LossFunction = SimpleLossFn() yield policy, data, loss_fn elif mode == "logprob": token_logprobs = calculate_token_logprobs(model_name, data) @@ -424,7 +424,7 @@ def test_dtensor_single_gpu_training( # Create test batch data = create_test_batch(mode="train") - loss_fn = SimpleLoss() + loss_fn = SimpleLossFn() # Test training policy.prepare_for_training() @@ -977,8 +977,8 @@ def test_dtensor_loss_independent_of_microbatch_size_two_gpus( tokenizer=tokenizer, ) - # Test NLLLoss and ClippedPGLossFn with mbs=1 - nll_loss_fn = NLLLoss() + # Test NLLLossFn and ClippedPGLossFn with mbs=1 + nll_loss_fn = NLLLossFn() pg_loss_fn = ClippedPGLossFn( { "ratio_clip_min": 0.2, @@ -1022,7 +1022,7 @@ def test_dtensor_loss_independent_of_microbatch_size_two_gpus( tokenizer=tokenizer, ) - # Test NLLLoss and ClippedPGLossFn with mbs=2 + # Test NLLLossFn and ClippedPGLossFn with mbs=2 policy_mbs2.prepare_for_training() mbs2_nll_results = policy_mbs2.train(data, nll_loss_fn) mbs2_nll_loss = mbs2_nll_results["loss"] @@ -1087,7 +1087,7 @@ def test_dtensor_v1_policy_flops_range_check( ) # Create loss function - loss_fn = SimpleLoss() + loss_fn = SimpleLossFn() try: # Prepare for training diff --git a/tests/unit/models/policy/test_dtensor_worker_v2.py b/tests/unit/models/policy/test_dtensor_worker_v2.py index 0a257baa86..648cbbaa1b 100644 --- a/tests/unit/models/policy/test_dtensor_worker_v2.py +++ b/tests/unit/models/policy/test_dtensor_worker_v2.py @@ -27,7 +27,7 @@ from nemo_rl.distributed.virtual_cluster import RayVirtualCluster from nemo_rl.models.policy import AutomodelKwargs, PolicyConfig from nemo_rl.models.policy.lm_policy import Policy -from tests.unit.test_utils import SimpleLoss +from tests.unit.test_utils import SimpleLossFn try: from nemo_rl.models.policy.workers.dtensor_policy_worker_v2 import ( @@ -423,7 +423,7 @@ def test_dtensor_v2_mixed_precision_training_and_logprobs( try: # --- Test Training --- train_data = create_test_batch(mode="train") - loss_fn = SimpleLoss() + loss_fn = SimpleLossFn() policy.prepare_for_training() results = policy.train(train_data, loss_fn) diff --git a/tests/unit/models/policy/test_megatron_worker.py b/tests/unit/models/policy/test_megatron_worker.py index 7d329ab411..3c4625ba23 100644 --- a/tests/unit/models/policy/test_megatron_worker.py +++ b/tests/unit/models/policy/test_megatron_worker.py @@ -21,20 +21,20 @@ import ray import torch -from nemo_rl.algorithms.interfaces import LossFunction -from nemo_rl.algorithms.loss_functions import ( +from nemo_rl.algorithms.loss import ( ClippedPGLossConfig, ClippedPGLossFn, DPOLossFn, - NLLLoss, + NLLLossFn, ) +from nemo_rl.algorithms.loss.interfaces import LossFunction from nemo_rl.algorithms.utils import get_tokenizer from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.distributed.virtual_cluster import RayVirtualCluster from nemo_rl.models.generation import configure_generation_config from nemo_rl.models.policy import PolicyConfig from nemo_rl.models.policy.lm_policy import Policy -from tests.unit.test_utils import SimpleLoss +from tests.unit.test_utils import SimpleLossFn basic_pg_loss_test_config: ClippedPGLossConfig = { "ratio_clip_min": 0.2, @@ -347,7 +347,7 @@ def training_setup(request): ) # Create loss function - loss_fn: LossFunction = SimpleLoss() + loss_fn: LossFunction = SimpleLossFn() yield policy, cluster, data, loss_fn @@ -824,7 +824,7 @@ def test_megatron_loss_independent_of_microbatch_size(tiny_llama_model_path): ) # Test loss functions - nll_loss_fn = NLLLoss() + nll_loss_fn = NLLLossFn() pg_loss_fn = ClippedPGLossFn(basic_pg_loss_test_config) policy1.prepare_for_training() @@ -902,7 +902,7 @@ def test_megatron_grad_norm_invariant_to_number_of_microbatches(tiny_llama_model ) tokenizer = get_tokenizer({"name": tiny_llama_model_path}) - nll_loss_fn = NLLLoss() + nll_loss_fn = NLLLossFn() cluster1 = RayVirtualCluster( name="test-gradnorm-mbs1", @@ -1032,7 +1032,7 @@ def test_megatron_reference_policy_functionality(tiny_llama_model_path): } ) - loss_fn = SimpleLoss() + loss_fn = SimpleLossFn() policy.prepare_for_training() # Train for more steps and monitor loss to ensure training is working @@ -1147,7 +1147,7 @@ def test_megatron_checkpoint_save_kill_and_restore( } ) - loss_fn = SimpleLoss() + loss_fn = SimpleLossFn() # Train for several steps to modify model state significantly policy1.prepare_for_training() @@ -1842,7 +1842,7 @@ def test_megatron_sft_training(tiny_llama_model_path): ) # Create NLL loss function for SFT - sft_loss_fn = NLLLoss() + sft_loss_fn = NLLLossFn() try: # Prepare for training @@ -2358,8 +2358,8 @@ def test_megatron_gradient_norm_consistency_across_parallelism(tiny_llama_model_ init_reference_model=False, ) - # Use SimpleLoss for consistent comparison - loss_fn = NLLLoss() + # Use SimpleLossFn for consistent comparison + loss_fn = NLLLossFn() try: # Prepare for training @@ -2532,7 +2532,7 @@ def test_megatron_policy_flops_range_check(tiny_llama_model_path): ) # Create loss function - loss_fn = SimpleLoss() + loss_fn = SimpleLossFn() try: # Prepare for training diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 8b9b7206fb..3369761655 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -15,11 +15,11 @@ import torch -from nemo_rl.algorithms.interfaces import LossInputType, LossType +from nemo_rl.algorithms.loss.interfaces import LossInputType, LossType from nemo_rl.distributed.batched_data_dict import BatchedDataDict -class SimpleLoss: +class SimpleLossFn: loss_type = LossType.SEQUENCE_LEVEL input_type = LossInputType.LOGIT @@ -41,7 +41,7 @@ def __call__( # Create a simple masked NLL loss function -class SimpleNLLLoss: +class SimpleNLLLossFn: loss_type = LossType.TOKEN_LEVEL input_type = LossInputType.LOGPROB diff --git a/tests/unit/utils/test_native_checkpoint.py b/tests/unit/utils/test_native_checkpoint.py index ad4c9b7728..f94da69c54 100755 --- a/tests/unit/utils/test_native_checkpoint.py +++ b/tests/unit/utils/test_native_checkpoint.py @@ -29,7 +29,7 @@ load_checkpoint, save_checkpoint, ) -from tests.unit.test_utils import SimpleLoss +from tests.unit.test_utils import SimpleLossFn # Define basic test config simple_policy_config = { @@ -310,7 +310,7 @@ def test_convert_dcp_to_hf(policy, num_gpus, request): "sample_mask": torch.ones(input_ids.shape[0]), } ) - policy.train(dummy_fwd_dict, SimpleLoss()) + policy.train(dummy_fwd_dict, SimpleLossFn()) policy_version_is_v2 = request.node.callspec.params["policy"] with TemporaryDirectory() as tmp_dir: From 68af7fe2ab8a6756255878bcbcf0a7376b97f389 Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Thu, 26 Feb 2026 02:28:15 -0800 Subject: [PATCH 09/15] fix test_loss_functions Signed-off-by: Yuki Huang --- tests/unit/algorithms/test_loss_functions.py | 313 +++++++------------ 1 file changed, 119 insertions(+), 194 deletions(-) diff --git a/tests/unit/algorithms/test_loss_functions.py b/tests/unit/algorithms/test_loss_functions.py index 1b86b79588..84ee67eb93 100644 --- a/tests/unit/algorithms/test_loss_functions.py +++ b/tests/unit/algorithms/test_loss_functions.py @@ -23,13 +23,10 @@ DistillationLossFn, DPOLossFn, NLLLossFn, + prepare_loss_input, ) from nemo_rl.algorithms.utils import calculate_kl, masked_mean from nemo_rl.distributed.batched_data_dict import BatchedDataDict -from nemo_rl.distributed.model_utils import ( - get_distillation_topk_logprobs_from_logits, - get_next_token_logprobs_from_logits, -) basic_pg_loss_test_config: ClippedPGLossConfig = { "ratio_clip_min": 0.2, @@ -95,16 +92,14 @@ def test_nll_loss(): .unsqueeze(0) .to("cuda") ) - token_logprobs = get_next_token_logprobs_from_logits( - data["input_ids"], next_token_logits - ) + loss_input = prepare_loss_input(next_token_logits, data, loss_fn) loss, metrics_dict = loss_fn( - token_logprobs, - data, + data=data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum( data["token_mask"] * data["sample_mask"].unsqueeze(-1) ), + **loss_input, ) torch.testing.assert_close(loss.cpu(), torch.tensor(0.0)) # Check the metrics dictionary contains the expected values @@ -123,16 +118,14 @@ def test_nll_loss(): .unsqueeze(0) .to("cuda") ) - token_logprobs = get_next_token_logprobs_from_logits( - data["input_ids"], next_token_logits - ) + loss_input = prepare_loss_input(next_token_logits, data, loss_fn) loss, metrics_dict = loss_fn( - token_logprobs, - data, + data=data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum( data["token_mask"] * data["sample_mask"].unsqueeze(-1) ), + **loss_input, ) ## loss per token is 999, and we have two unmasked tokens ## NLLLossFn averages the loss over unmasked tokens @@ -161,16 +154,14 @@ def test_dpo_loss(): } ) - token_logprobs = get_next_token_logprobs_from_logits( - data["input_ids"], next_token_logits - ) - loss, metrics_dict = loss_fn( - token_logprobs, - data, + loss_input = prepare_loss_input(next_token_logits, data, loss_fn) + loss, _ = loss_fn( + data=data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum( - data["sample_mask"].unsqueeze(-1) * data["token_mask"] + data["token_mask"] * data["sample_mask"].unsqueeze(-1) ), + **loss_input, ) ## chosen and rejected errors are the same, so difference between them is 0 @@ -186,6 +177,16 @@ def test_dpo_loss(): } ) + loss_input = prepare_loss_input(next_token_logits, data, loss_fn_with_sft) + loss_sft, _ = loss_fn_with_sft( + data=data, + global_valid_seqs=torch.sum(data["sample_mask"]), + global_valid_toks=torch.sum( + data["sample_mask"].unsqueeze(-1) * data["token_mask"] + ), + **loss_input, + ) + expected_sft_loss = ( -( torch.nn.functional.log_softmax(torch.tensor([[0.0] * vocab_size]), dim=-1)[ @@ -197,14 +198,7 @@ def test_dpo_loss(): ) expected_preference_loss = -torch.nn.functional.logsigmoid(torch.tensor(0.0)) assert torch.isclose( - loss_fn_with_sft( - token_logprobs, - data, - global_valid_seqs=torch.sum(data["sample_mask"]), - global_valid_toks=torch.sum( - data["sample_mask"].unsqueeze(-1) * data["token_mask"] - ), - )[0].cpu(), + loss_sft.cpu(), 0.5 * expected_sft_loss + expected_preference_loss, ) @@ -273,24 +267,26 @@ def test_dpo_loss_varying_sequence_lengths(): "sample_mask": sample_mask, } ) - token_logprobs = get_next_token_logprobs_from_logits( - data["input_ids"], next_token_logits - ) - # Compute loss - loss, metrics = dpo_loss_fn_no_avg( - token_logprobs, - data, + # Compute no averaging loss + loss_input = prepare_loss_input(next_token_logits, data, dpo_loss_fn_no_avg) + _, metrics = dpo_loss_fn_no_avg( + data=data, global_valid_seqs=torch.sum(sample_mask), global_valid_toks=torch.sum(sample_mask.unsqueeze(-1) * token_mask), + **loss_input, ) - loss_avg, metrics_avg = dpo_loss_fn_avg( - token_logprobs, - data, + + # Compute averaging loss + loss_input = prepare_loss_input(next_token_logits, data, dpo_loss_fn_avg) + _, metrics_avg = dpo_loss_fn_avg( + data=data, global_valid_seqs=torch.sum(sample_mask), global_valid_toks=torch.sum(sample_mask.unsqueeze(-1) * token_mask), + **loss_input, ) + # Compute expected losses num_unmasked_tokens = token_mask[:, 1:][::2].sum().item() logprobs = torch.nn.functional.log_softmax(next_token_logits[:, 1:], dim=-1) token_logprobs = logprobs.gather( @@ -338,16 +334,14 @@ def test_dpo_sft_matches_nll_loss(): # Compute NLL loss nll_loss_fn = NLLLossFn() - token_logprobs = get_next_token_logprobs_from_logits( - sft_data["input_ids"], next_token_logits[::2] - ) - nll_loss, nll_metrics = nll_loss_fn( - token_logprobs, - sft_data, + loss_input = prepare_loss_input(next_token_logits[::2], sft_data, nll_loss_fn) + nll_loss, _ = nll_loss_fn( + data=sft_data, global_valid_seqs=None, global_valid_toks=torch.sum( sft_data["sample_mask"].unsqueeze(-1) * torch.sum(sft_data["token_mask"]) ), + **loss_input, ) # Compute DPO loss with preference_loss_weight=0 @@ -360,16 +354,14 @@ def test_dpo_sft_matches_nll_loss(): "sft_average_log_probs": False, } ) - token_logprobs = get_next_token_logprobs_from_logits( - dpo_data["input_ids"], next_token_logits - ) - dpo_loss, dpo_metrics = dpo_loss_fn( - token_logprobs, - dpo_data, + loss_input = prepare_loss_input(next_token_logits, dpo_data, dpo_loss_fn) + dpo_loss, _ = dpo_loss_fn( + data=dpo_data, global_valid_seqs=torch.sum(dpo_data["sample_mask"]), global_valid_toks=torch.sum( dpo_data["sample_mask"].unsqueeze(-1) * dpo_data["token_mask"] ), + **loss_input, ) # Verify losses match @@ -526,13 +518,13 @@ def test_clipped_pg_loss_ppo_clipping(): dummy_logits = _create_exact_logits( curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device ) - current_logprobs = get_next_token_logprobs_from_logits(input_ids, dummy_logits) + loss_input = prepare_loss_input(dummy_logits, data, loss_fn) actual_loss, _ = loss_fn( - current_logprobs, - data, + data=data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum(data["sample_mask"] * data["token_mask"]), + **loss_input, ) torch.testing.assert_close(actual_loss, expected_loss) @@ -574,15 +566,15 @@ def test_clipped_pg_loss_reinforce_mode(): dummy_logits = _create_exact_logits( curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device ) - current_logprobs = get_next_token_logprobs_from_logits(input_ids, dummy_logits) + loss_input = prepare_loss_input(dummy_logits, data, loss_fn) actual_loss, _ = loss_fn( - current_logprobs, - data, + data=data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum( data["sample_mask"].unsqueeze(-1) * data["token_mask"] ), + **loss_input, ) torch.testing.assert_close(actual_loss, expected_loss) @@ -620,15 +612,15 @@ def test_clipped_pg_loss_force_on_policy_ratio(): dummy_logits = _create_exact_logits( curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device ) - current_logprobs = get_next_token_logprobs_from_logits(input_ids, dummy_logits) + loss_input = prepare_loss_input(dummy_logits, data, loss_fn) actual_loss, metrics = loss_fn( - current_logprobs, - data, + data=data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum( data["sample_mask"].unsqueeze(-1) * data["token_mask"] ), + **loss_input, ) # Loss should match the on-policy expectation @@ -731,15 +723,15 @@ def test_clipped_pg_loss_kl_penalty(): dummy_logits = _create_exact_logits( curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device ) - current_logprobs = get_next_token_logprobs_from_logits(input_ids, dummy_logits) + loss_input = prepare_loss_input(dummy_logits, data, loss_fn) actual_loss, _ = loss_fn( - current_logprobs, - data, + data=data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum( data["sample_mask"].unsqueeze(-1) * data["token_mask"] ), + **loss_input, ) torch.testing.assert_close(actual_loss, expected_loss) @@ -760,9 +752,6 @@ def test_clipped_pg_loss_masking(): ) # Need some realistic-ish logits and logprobs for masking test dummy_logits = torch.randn(batch_size, seq_len, vocab_size, device=device) - current_logprobs = get_next_token_logprobs_from_logits( - data["input_ids"], dummy_logits - ) # Ensure logprobs used by the loss fn make sense relative to advantages data["prev_logprobs"] = torch.randn_like(data["prev_logprobs"]) * 0.1 @@ -775,16 +764,17 @@ def test_clipped_pg_loss_masking(): cfg = deepcopy(basic_pg_loss_test_config) cfg["reference_policy_kl_penalty"] = 0.1 loss_fn = ClippedPGLossFn(cfg) # Use original loss fn + loss_input = prepare_loss_input(dummy_logits, data, loss_fn) # --- Test 1: Token Mask --- # Default mask: [[0, 1, 1, 1], [0, 1, 1, 1]] -> 3 tokens per sample loss_default, _ = loss_fn( - current_logprobs, - data, + data=data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum( data["sample_mask"].unsqueeze(-1) * data["token_mask"] ), + **loss_input, ) # Modify token_mask for batch item 0 to mask one more token (pos 1) @@ -795,12 +785,12 @@ def test_clipped_pg_loss_masking(): ) loss_token_masked, _ = loss_fn( - current_logprobs, - data_mod_token, + data=data_mod_token, global_valid_seqs=torch.sum(data_mod_token["sample_mask"]), global_valid_toks=torch.sum( data_mod_token["sample_mask"].unsqueeze(-1) * data_mod_token["token_mask"] ), + **loss_input, ) # Loss should change if a potentially contributing token is masked assert not torch.isclose(loss_default, loss_token_masked, atol=1e-4), ( @@ -814,12 +804,12 @@ def test_clipped_pg_loss_masking(): ) # Ignore item 1 loss_sample_masked, _ = loss_fn( - current_logprobs, - data_mod_sample, + data=data_mod_sample, global_valid_seqs=torch.sum(data_mod_sample["sample_mask"]), global_valid_toks=torch.sum( data_mod_sample["sample_mask"].unsqueeze(-1) * data_mod_sample["token_mask"] ), + **loss_input, ) # Manually create data dict for only batch 0 @@ -835,16 +825,14 @@ def test_clipped_pg_loss_masking(): data_only_b0 = BatchedDataDict(data_only_b0_dict) logits_only_b0 = dummy_logits[0:1] - current_logprobs_only_b0 = get_next_token_logprobs_from_logits( - data_only_b0["input_ids"], logits_only_b0 - ) + loss_input = prepare_loss_input(logits_only_b0, data_only_b0, loss_fn) loss_only_b0, _ = loss_fn( - current_logprobs_only_b0, - data_only_b0, + data=data_only_b0, global_valid_seqs=torch.sum(data_only_b0["sample_mask"]), global_valid_toks=torch.sum( data_only_b0["sample_mask"].unsqueeze(-1) * data_only_b0["token_mask"] ), + **loss_input, ) torch.testing.assert_close(loss_sample_masked, loss_only_b0) @@ -859,24 +847,22 @@ def test_clipped_pg_loss_zero_mask(): data, batch_size, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device) # Need dummy logits dummy_logits = torch.randn(1, seq_len, vocab_size, device=device) - current_logprobs = get_next_token_logprobs_from_logits( - data["input_ids"], dummy_logits - ) cfg = deepcopy(basic_pg_loss_test_config) cfg["reference_policy_kl_penalty"] = 0.1 loss_fn = ClippedPGLossFn(cfg) # Use original loss fn + loss_input = prepare_loss_input(dummy_logits, data, loss_fn) # Set token mask to all zeros data["token_mask"] = torch.zeros_like(data["token_mask"]) loss, _ = loss_fn( - current_logprobs, - data, + data=data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum( data["sample_mask"].unsqueeze(-1) * data["token_mask"] ), + **loss_input, ) # Loss should be exactly zero @@ -1016,13 +1002,13 @@ def test_clipped_pg_loss_on_policy_kl_importance_sampling(): dummy_logits = _create_exact_logits( curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device ) - current_logprobs = get_next_token_logprobs_from_logits(input_ids, dummy_logits) + loss_input = prepare_loss_input(dummy_logits, data, loss_fn) actual_loss, _ = loss_fn( - current_logprobs, - data, + data=data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum(data["sample_mask"] * data["token_mask"]), + **loss_input, ) torch.testing.assert_close(actual_loss, expected_total_loss, atol=1e-4, rtol=1e-3) @@ -1149,13 +1135,13 @@ def test_clipped_pg_loss_on_policy_truncated_importance_sampling( dummy_logits = _create_exact_logits( curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device ) - current_logprobs = get_next_token_logprobs_from_logits(input_ids, dummy_logits) + loss_input = prepare_loss_input(dummy_logits, data, loss_fn) actual_loss, _ = loss_fn( - current_logprobs, - data, + data=data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum(data["sample_mask"] * data["token_mask"]), + **loss_input, ) torch.testing.assert_close(actual_loss, expected_loss, atol=1e-4, rtol=1e-3) @@ -1197,11 +1183,12 @@ def test_clipped_pg_loss_icepop_importance_sampling(): dummy_logits = _create_exact_logits( prev_lp, data["input_ids"], batch_size, seq_len, vocab_size, device ) + loss_input = prepare_loss_input(dummy_logits, data, loss_fn) actual_loss, _ = loss_fn( - dummy_logits, - data, + data=data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum(data["sample_mask"] * data["token_mask"]), + **loss_input, ) torch.testing.assert_close(actual_loss, expected_loss, atol=1e-4, rtol=1e-3) @@ -1240,21 +1227,22 @@ def test_clipped_pg_loss_seq_mask_tis(): dummy_logits = _create_exact_logits( prev_lp, data["input_ids"], batch_size, seq_len, vocab_size, device ) + loss_input = prepare_loss_input(dummy_logits, data, loss_fn) actual_loss, _ = loss_fn( - dummy_logits, - data, + data=data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum(data["sample_mask"] * data["token_mask"]), + **loss_input, ) torch.testing.assert_close(actual_loss, expected_loss, atol=1e-4, rtol=1e-3) # nan_to_num: inject -inf → loss must stay finite data["generation_logprobs"][0, 2] = float("-inf") actual_loss2, _ = loss_fn( - dummy_logits, - data, + data=data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum(data["sample_mask"] * data["token_mask"]), + **loss_input, ) assert not torch.isnan(actual_loss2), "Loss is NaN — nan_to_num fix not working" assert not torch.isinf(actual_loss2), "Loss is inf — nan_to_num fix not working" @@ -1371,15 +1359,15 @@ def test_clipped_pg_loss_dual_clip(): dummy_logits = _create_exact_logits( curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device ) - current_logprobs = get_next_token_logprobs_from_logits(input_ids, dummy_logits) + loss_input = prepare_loss_input(dummy_logits, data, loss_fn) actual_loss, _ = loss_fn( - current_logprobs, - data, + data=data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum( data["sample_mask"].unsqueeze(-1) * data["token_mask"] ), + **loss_input, ) torch.testing.assert_close(actual_loss, expected_loss) @@ -1421,15 +1409,13 @@ def test_clipped_pg_loss_entropy(): dummy_logits = _create_exact_logits( curr_lp_masked, data["input_ids"], batch_size, seq_len, vocab_size, device ) - current_logprobs = get_next_token_logprobs_from_logits( - data["input_ids"], dummy_logits - ) + loss_input = prepare_loss_input(dummy_logits, data, loss_fn) _, metrics = loss_fn( - current_logprobs, - data, + data=data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum(data["sample_mask"] * data["token_mask"]), + **loss_input, ) torch.testing.assert_close( @@ -1508,13 +1494,13 @@ def test_clipped_pg_loss_gspo(): dummy_logits = _create_exact_logits( curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device ) - current_logprobs = get_next_token_logprobs_from_logits(input_ids, dummy_logits) + loss_input = prepare_loss_input(dummy_logits, data, loss_fn) actual_loss, _ = loss_fn( - current_logprobs, - data, + data=data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum(data["sample_mask"] * data["token_mask"]), + **loss_input, ) torch.testing.assert_close(actual_loss, expected_loss) @@ -1607,15 +1593,15 @@ def test_clipped_pg_loss_gspo_batch_size_2(): dummy_logits = _create_exact_logits( curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device ) - current_logprobs = get_next_token_logprobs_from_logits(input_ids, dummy_logits) + loss_input = prepare_loss_input(dummy_logits, data, loss_fn) actual_loss, _ = loss_fn( - current_logprobs, - data, + data=data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum( data["sample_mask"].unsqueeze(1) * data["token_mask"] ), + **loss_input, ) torch.testing.assert_close(actual_loss, expected_loss) @@ -1709,13 +1695,13 @@ def test_clipped_pg_loss_gspo_importance_sampling_correction(): dummy_logits = _create_exact_logits( curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device ) - current_logprobs = get_next_token_logprobs_from_logits(input_ids, dummy_logits) + loss_input = prepare_loss_input(dummy_logits, data, loss_fn) actual_loss, _ = loss_fn( - current_logprobs, - data, + data=data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum(data["sample_mask"] * data["token_mask"]), + **loss_input, ) torch.testing.assert_close(actual_loss, expected_actor_loss, atol=1e-4, rtol=1e-3) @@ -1772,22 +1758,14 @@ def test_distillation_loss_different_settings(kl_type, zero_outside_topk): } ) - calculate_entropy = loss_fn.zero_outside_topk and loss_fn.kl_type != "forward" - loss_fn_args = get_distillation_topk_logprobs_from_logits( - student_logits=student_logits, - teacher_topk_logits=data["teacher_topk_logits"], - teacher_topk_indices=data["teacher_topk_indices"], - zero_outside_topk=loss_fn.zero_outside_topk, - calculate_entropy=calculate_entropy, - ) - + loss_input = prepare_loss_input(student_logits, data, loss_fn) loss, metrics = loss_fn( - *loss_fn_args, - data, + data=data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum( data["sample_mask"].unsqueeze(-1) * data["token_mask"] ), + **loss_input, ) # Verify loss @@ -1825,22 +1803,14 @@ def test_distillation_loss_topk_filtering(k, zero_outside_topk): } ) - calculate_entropy = loss_fn.zero_outside_topk and loss_fn.kl_type != "forward" - loss_fn_args = get_distillation_topk_logprobs_from_logits( - student_logits=student_logits, - teacher_topk_logits=data["teacher_topk_logits"], - teacher_topk_indices=data["teacher_topk_indices"], - zero_outside_topk=loss_fn.zero_outside_topk, - calculate_entropy=calculate_entropy, - ) - + loss_input = prepare_loss_input(student_logits, data, loss_fn) loss, _ = loss_fn( - *loss_fn_args, - data, + data=data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum( data["sample_mask"].unsqueeze(-1) * data["token_mask"] ), + **loss_input, ) # Verify loss is calculated correctly with top-k filtering @@ -1873,14 +1843,7 @@ def test_distillation_loss_invalid_k_zero(): # This should raise a ValueError for k=0 with pytest.raises(ValueError, match="topk must be positive"): - calculate_entropy = loss_fn.zero_outside_topk and loss_fn.kl_type != "forward" - _ = get_distillation_topk_logprobs_from_logits( - student_logits=student_logits, - teacher_topk_logits=data["teacher_topk_logits"], - teacher_topk_indices=data["teacher_topk_indices"], - zero_outside_topk=loss_fn.zero_outside_topk, - calculate_entropy=calculate_entropy, - ) + _ = prepare_loss_input(student_logits, data, loss_fn) def test_distillation_loss_gradient_flow(): @@ -1898,22 +1861,14 @@ def test_distillation_loss_gradient_flow(): } ) - calculate_entropy = loss_fn.zero_outside_topk and loss_fn.kl_type != "forward" - loss_fn_args = get_distillation_topk_logprobs_from_logits( - student_logits=student_logits, - teacher_topk_logits=data["teacher_topk_logits"], - teacher_topk_indices=data["teacher_topk_indices"], - zero_outside_topk=loss_fn.zero_outside_topk, - calculate_entropy=calculate_entropy, - ) - + loss_input = prepare_loss_input(student_logits, data, loss_fn) loss, _ = loss_fn( - *loss_fn_args, - data, + data=data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum( data["sample_mask"].unsqueeze(-1) * data["token_mask"] ), + **loss_input, ) # Compute gradients @@ -1940,64 +1895,42 @@ def test_distillation_loss_edge_cases(): # Test with all-zero logits zero_logits = torch.zeros_like(student_logits) - calculate_entropy = loss_fn.zero_outside_topk and loss_fn.kl_type != "forward" - loss_fn_args = get_distillation_topk_logprobs_from_logits( - student_logits=zero_logits, - teacher_topk_logits=data["teacher_topk_logits"], - teacher_topk_indices=data["teacher_topk_indices"], - zero_outside_topk=loss_fn.zero_outside_topk, - calculate_entropy=calculate_entropy, - ) - + loss_input = prepare_loss_input(zero_logits, data, loss_fn) loss, _ = loss_fn( - *loss_fn_args, - data, + data=data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum( data["sample_mask"].unsqueeze(-1) * data["token_mask"] ), + **loss_input, ) assert not torch.isnan(loss) assert not torch.isinf(loss) # Test with very large logits large_logits = torch.ones_like(student_logits) * 100.0 - loss_fn_args = get_distillation_topk_logprobs_from_logits( - student_logits=large_logits, - teacher_topk_logits=data["teacher_topk_logits"], - teacher_topk_indices=data["teacher_topk_indices"], - zero_outside_topk=loss_fn.zero_outside_topk, - calculate_entropy=calculate_entropy, - ) - + loss_input = prepare_loss_input(large_logits, data, loss_fn) loss, _ = loss_fn( - *loss_fn_args, - data, + data=data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum( data["sample_mask"].unsqueeze(-1) * data["token_mask"] ), + **loss_input, ) assert not torch.isnan(loss) assert not torch.isinf(loss) # Test with very small logits small_logits = torch.ones_like(student_logits) * -100.0 - loss_fn_args = get_distillation_topk_logprobs_from_logits( - student_logits=small_logits, - teacher_topk_logits=data["teacher_topk_logits"], - teacher_topk_indices=data["teacher_topk_indices"], - zero_outside_topk=loss_fn.zero_outside_topk, - calculate_entropy=calculate_entropy, - ) - + loss_input = prepare_loss_input(small_logits, data, loss_fn) loss, _ = loss_fn( - *loss_fn_args, - data, + data=data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum( data["sample_mask"].unsqueeze(-1) * data["token_mask"] ), + **loss_input, ) assert not torch.isnan(loss) assert not torch.isinf(loss) @@ -2040,22 +1973,14 @@ def test_distillation_loss_fn_call(): } ) - calculate_entropy = loss_fn.zero_outside_topk and loss_fn.kl_type != "forward" - loss_fn_args = get_distillation_topk_logprobs_from_logits( - student_logits=student_logits, - teacher_topk_logits=data["teacher_topk_logits"], - teacher_topk_indices=data["teacher_topk_indices"], - zero_outside_topk=loss_fn.zero_outside_topk, - calculate_entropy=calculate_entropy, - ) - + loss_input = prepare_loss_input(student_logits, data, loss_fn) loss, metrics = loss_fn( - *loss_fn_args, - data, + data=data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum( data["sample_mask"].unsqueeze(-1) * data["token_mask"] ), + **loss_input, ) # Verify return types From c450c510c9de83fa2a4f2aec757ddd475731ecae Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Thu, 26 Feb 2026 07:53:43 -0800 Subject: [PATCH 10/15] update megatron Signed-off-by: Yuki Huang --- nemo_rl/algorithms/loss/__init__.py | 6 ++- nemo_rl/algorithms/loss/utils.py | 11 ++++- ...sequence_packing_wrapper.py => wrapper.py} | 42 +++++++++++++++++-- nemo_rl/models/automodel/train.py | 4 +- nemo_rl/models/megatron/train.py | 32 ++++++++++---- 5 files changed, 79 insertions(+), 16 deletions(-) rename nemo_rl/algorithms/loss/{sequence_packing_wrapper.py => wrapper.py} (78%) diff --git a/nemo_rl/algorithms/loss/__init__.py b/nemo_rl/algorithms/loss/__init__.py index d44d0e03d2..d794d86c2e 100644 --- a/nemo_rl/algorithms/loss/__init__.py +++ b/nemo_rl/algorithms/loss/__init__.py @@ -26,8 +26,11 @@ PreferenceLossDataDict, PreferenceLossFn, ) -from nemo_rl.algorithms.loss.sequence_packing_wrapper import SequencePackingLossWrapper from nemo_rl.algorithms.loss.utils import prepare_loss_input +from nemo_rl.algorithms.loss.wrapper import ( + SequencePackingLossWrapper, + wrap_loss_fn_with_input_preparation, +) __all__ = [ "ClippedPGLossConfig", @@ -44,4 +47,5 @@ "PreferenceLossFn", "SequencePackingLossWrapper", "prepare_loss_input", + "wrap_loss_fn_with_input_preparation", ] diff --git a/nemo_rl/algorithms/loss/utils.py b/nemo_rl/algorithms/loss/utils.py index 05e75ebd46..569c958bd3 100644 --- a/nemo_rl/algorithms/loss/utils.py +++ b/nemo_rl/algorithms/loss/utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any +from typing import Any, Optional import torch @@ -28,6 +28,9 @@ def prepare_loss_input( logits: torch.Tensor, data: BatchedDataDict[Any], loss_fn: LossFunction, + vocab_parallel_rank: Optional[int] = None, + vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None, + context_parallel_group: Optional[torch.distributed.ProcessGroup] = None, ) -> dict[str, Any]: """Prepare loss input for a loss function. @@ -47,6 +50,9 @@ def prepare_loss_input( input_ids=data["input_ids"], next_token_logits=logits, seq_index=data.get("seq_index", None), + vocab_parallel_rank=vocab_parallel_rank, + vocab_parallel_group=vocab_parallel_group, + context_parallel_group=context_parallel_group, ) loss_input = {"next_token_logprobs": logprobs} @@ -60,6 +66,9 @@ def prepare_loss_input( teacher_topk_indices=data["teacher_topk_indices"], zero_outside_topk=loss_fn.zero_outside_topk, calculate_entropy=calculate_entropy, + vocab_parallel_rank=vocab_parallel_rank, + vocab_parallel_group=vocab_parallel_group, + context_parallel_group=context_parallel_group, ) ) diff --git a/nemo_rl/algorithms/loss/sequence_packing_wrapper.py b/nemo_rl/algorithms/loss/wrapper.py similarity index 78% rename from nemo_rl/algorithms/loss/sequence_packing_wrapper.py rename to nemo_rl/algorithms/loss/wrapper.py index b510d7dddb..f8ac4819d9 100644 --- a/nemo_rl/algorithms/loss/sequence_packing_wrapper.py +++ b/nemo_rl/algorithms/loss/wrapper.py @@ -94,9 +94,12 @@ def __call__( # prepare data for loss function loss_input = self.prepare_fn( - next_token_logits_slice, - unpadded_seq_data, - self.loss_fn, + logits=next_token_logits_slice, + data=unpadded_seq_data, + loss_fn=self.loss_fn, + vocab_parallel_rank=self.vocab_parallel_rank, + vocab_parallel_group=self.vocab_parallel_group, + context_parallel_group=self.context_parallel_group, ) # call loss function @@ -131,3 +134,36 @@ def __call__( metrics_accum[k] += val return loss_accum, metrics_accum + + +def wrap_loss_fn_with_input_preparation( + next_token_logits: Tensor, + data: BatchedDataDict[Any], + global_valid_seqs: Tensor | None, + global_valid_toks: Tensor | None, + loss_fn: LossFunction, + prepare_fn: Callable[Any, Any], + vocab_parallel_rank: Optional[int] = None, + vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None, + context_parallel_group: Optional[torch.distributed.ProcessGroup] = None, +) -> tuple[Tensor, dict[str, Any]]: + """Wraps a loss function to handle input preparation for megatron policy worker.""" + # prepare loss input + loss_input = prepare_fn( + logits=next_token_logits, + data=data, + loss_fn=loss_fn, + vocab_parallel_rank=vocab_parallel_rank, + vocab_parallel_group=vocab_parallel_group, + context_parallel_group=context_parallel_group, + ) + + # call loss function + loss, loss_metrics = loss_fn( + data=data, + global_valid_seqs=global_valid_seqs, + global_valid_toks=global_valid_toks, + **loss_input, + ) + + return loss, loss_metrics diff --git a/nemo_rl/models/automodel/train.py b/nemo_rl/models/automodel/train.py index 668f0d0b92..32e3386ab6 100644 --- a/nemo_rl/models/automodel/train.py +++ b/nemo_rl/models/automodel/train.py @@ -515,13 +515,13 @@ def __call__( # Wrap loss function for sequence packing if needed if self.enable_seq_packing: - loss_fn_ = SequencePackingLossWrapper( + loss_fn = SequencePackingLossWrapper( loss_fn=self.loss_fn, prepare_fn=prepare_loss_input, cu_seqlens_q=processed_inputs.flash_attn_kwargs.cu_seqlens_q, cu_seqlens_q_padded=processed_inputs.flash_attn_kwargs.cu_seqlens_q, ) - loss, loss_metrics = loss_fn_( + loss, loss_metrics = loss_fn( logits, mb, global_valid_seqs, diff --git a/nemo_rl/models/megatron/train.py b/nemo_rl/models/megatron/train.py index de78b04172..5535c9025c 100644 --- a/nemo_rl/models/megatron/train.py +++ b/nemo_rl/models/megatron/train.py @@ -29,7 +29,12 @@ from megatron.core.pipeline_parallel import get_forward_backward_func from megatron.core.utils import StragglerDetector -from nemo_rl.algorithms.loss import LossFunction, SequencePackingLossWrapper +from nemo_rl.algorithms.loss import ( + SequencePackingLossWrapper, + prepare_loss_input, + wrap_loss_fn_with_input_preparation, +) +from nemo_rl.algorithms.loss.interfaces import LossFunction from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.distributed.model_utils import ( allgather_cp_sharded_tensor, @@ -302,24 +307,33 @@ def __call__( Returns: Callable: Function that takes output tensor and returns (loss, metrics) tuple """ - loss_fn = self.loss_fn + # wrap loss function with loss input preparation pack_sequences = self.cfg["sequence_packing"]["enabled"] if pack_sequences and packed_seq_params is not None: - # remove padding - loss_fn = SequencePackingLossWrapper( - loss_fn=loss_fn, + loss_fn_wrapped = SequencePackingLossWrapper( + loss_fn=self.loss_fn, + prepare_fn=prepare_loss_input, cu_seqlens_q=packed_seq_params.cu_seqlens_q, cu_seqlens_q_padded=packed_seq_params.cu_seqlens_q_padded, + vocab_parallel_rank=get_tensor_model_parallel_rank(), + vocab_parallel_group=get_tensor_model_parallel_group(), + context_parallel_group=get_context_parallel_group(), + ) + else: + loss_fn_wrapped = partial( + wrap_loss_fn_with_input_preparation, + loss_fn=self.loss_fn, + prepare_fn=prepare_loss_input, + vocab_parallel_rank=get_tensor_model_parallel_rank(), + vocab_parallel_group=get_tensor_model_parallel_group(), + context_parallel_group=get_context_parallel_group(), ) loss_fn_wrapped = partial( - loss_fn, + loss_fn_wrapped, data=data_dict, global_valid_seqs=global_valid_seqs, global_valid_toks=global_valid_toks, - vocab_parallel_rank=get_tensor_model_parallel_rank(), - vocab_parallel_group=get_tensor_model_parallel_group(), - context_parallel_group=get_context_parallel_group(), ) if self.cp_normalize: From 0e48e746ccee170c7e2f5335e6c3cca77cac1f05 Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Thu, 26 Feb 2026 07:58:31 -0800 Subject: [PATCH 11/15] update dtensor v1 Signed-off-by: Yuki Huang --- .../policy/workers/dtensor_policy_worker.py | 24 ++++++++++++------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/nemo_rl/models/policy/workers/dtensor_policy_worker.py b/nemo_rl/models/policy/workers/dtensor_policy_worker.py index 0dec2f3aa9..661254da23 100644 --- a/nemo_rl/models/policy/workers/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/workers/dtensor_policy_worker.py @@ -46,7 +46,7 @@ ) from transformers.models.gemma3.modeling_gemma3 import Gemma3ForCausalLM -from nemo_rl.algorithms.loss import SequencePackingLossWrapper +from nemo_rl.algorithms.loss import SequencePackingLossWrapper, prepare_loss_input from nemo_rl.algorithms.loss.interfaces import LossFunction, LossType from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.distributed.model_utils import ( @@ -776,20 +776,28 @@ def train( placements=[Shard(sequence_dim), Shard(-1)], ) + # Wrap loss function for sequence packing if needed if self.enable_seq_packing: loss_fn_ = SequencePackingLossWrapper( loss_fn=loss_fn, + prepare_fn=prepare_loss_input, cu_seqlens_q=flash_attn_kwargs.cu_seqlens_q, cu_seqlens_q_padded=flash_attn_kwargs.cu_seqlens_q, ) + loss, loss_metrics = loss_fn_( + logits, + mb, + global_valid_seqs, + global_valid_toks, + ) else: - loss_fn_ = loss_fn - loss, loss_metrics = loss_fn_( - logits, - mb, - global_valid_seqs, - global_valid_toks, - ) + loss_input = prepare_loss_input(logits, mb, loss_fn) + loss, loss_metrics = loss_fn( + data=mb, + global_valid_seqs=global_valid_seqs, + global_valid_toks=global_valid_toks, + **loss_input, + ) del logits # skip the update for dummy batches From 000231e8a1e0c9f03c2eda65ba8b19b91519b52d Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Thu, 26 Feb 2026 20:31:53 -0800 Subject: [PATCH 12/15] fix test Signed-off-by: Yuki Huang --- docs/design-docs/loss-functions.md | 2 +- docs/guides/grpo.md | 2 +- docs/guides/prorlv2.md | 4 ++-- .../sequence_packing_gradient_actor.py | 24 ++++++++++--------- 4 files changed, 17 insertions(+), 15 deletions(-) diff --git a/docs/design-docs/loss-functions.md b/docs/design-docs/loss-functions.md index 7cc0e9cbf2..37cf868ad8 100644 --- a/docs/design-docs/loss-functions.md +++ b/docs/design-docs/loss-functions.md @@ -17,7 +17,7 @@ $$ which is, in general, not equivalent to the full-batch loss. To fix this, we need each microbatch to have information about how many tokens are in the other microbatches in the global batch. -In NeMo RL, this information is passed to the loss function directly. Each loss function is expected to fall into one of two categories, token-level or sequence-level, which is an attribute of the loss function itself (see [loss_functions.py](../../nemo_rl/algorithms/loss_functions.py) for some examples). The policy then uses this information to compute the global normalization factor using the full batch (for token-level losses, this is the total number of tokens in the batch. For sequence-level losses, this is the number of valid sequences in the batch). The normalization factor is then passed to the loss function, which uses it to normalize the microbatch loss. To get the loss for the global batch, the policy simply sums across all microbatch losses. +In NeMo RL, this information is passed to the loss function directly. Each loss function is expected to fall into one of two categories, token-level or sequence-level, which is an attribute of the loss function itself (see [loss_functions.py](../../nemo_rl/algorithms/loss/loss_functions.py) for some examples). The policy then uses this information to compute the global normalization factor using the full batch (for token-level losses, this is the total number of tokens in the batch. For sequence-level losses, this is the number of valid sequences in the batch). The normalization factor is then passed to the loss function, which uses it to normalize the microbatch loss. To get the loss for the global batch, the policy simply sums across all microbatch losses. For our simple example above, this would look like: diff --git a/docs/guides/grpo.md b/docs/guides/grpo.md index 2576b93303..2e7b410de9 100755 --- a/docs/guides/grpo.md +++ b/docs/guides/grpo.md @@ -343,7 +343,7 @@ The function, [grpo_train](../../nemo_rl/algorithms/grpo.py), contains the core RL generations typically produce highly variable sequence lengths, which result in a significant amount of padding if approached naively. We address this with Sequence Packing and Dynamic Batching, which are techniques to reduce the amount of padding required. You can read more about these in the [design doc](../design-docs/sequence-packing-and-dynamic-batching.md). ## Loss -We use the [ClippedPGLossFn](../../nemo_rl/algorithms/loss_functions.py) to calculate the loss for GRPO. Formally, +We use the [ClippedPGLossFn](../../nemo_rl/algorithms/loss/loss_functions.py) to calculate the loss for GRPO. Formally, $$ L(\theta) = E_{x \sim \pi_{\theta_{\text{old}}}} \Big[ \min \Big(\frac{\pi_\theta(x)}{\pi_{\theta_{\text{old}}}(x)}A_t, \text{clip} \big( \frac{\pi_\theta(x)}{\pi_{\theta_{\text{old}}}(x)}, 1 - \varepsilon, 1 + \varepsilon \big) A_t \Big) \Big] - \beta D_{\text{KL}} (\pi_\theta \| \pi_\text{ref}) diff --git a/docs/guides/prorlv2.md b/docs/guides/prorlv2.md index 795bb0d08a..98d6142c12 100644 --- a/docs/guides/prorlv2.md +++ b/docs/guides/prorlv2.md @@ -106,7 +106,7 @@ loss_fn: This keeps PPO/GRPO-style clipping behavior but allows a larger expansion region than the contraction region, which can help exploration and reduce early collapse. -- **Implementation**: `ClippedPGLossFn` documents decoupled clipping in [`nemo_rl/algorithms/loss_functions.py`](../../nemo_rl/algorithms/loss_functions.py). +- **Implementation**: `ClippedPGLossFn` documents decoupled clipping in [`nemo_rl/algorithms/loss/loss_functions.py`](../../nemo_rl/algorithms/loss/loss_functions.py). ## Loss: Token-level Policy Gradient @@ -153,7 +153,7 @@ loss_fn: - `"icepop"`: set weights outside \([min, max]\) to zero (filter outliers) - `"seq-mask-tis"`: sequence-level geometric-mean mask + non-truncated token-level IS correction (see below) -- **Implementation**: see `ClippedPGLossFn` init-time checks and logic in [`nemo_rl/algorithms/loss_functions.py`](../../nemo_rl/algorithms/loss_functions.py). +- **Implementation**: see `ClippedPGLossFn` init-time checks and logic in [`nemo_rl/algorithms/loss/loss_functions.py`](../../nemo_rl/algorithms/loss/loss_functions.py). ### Seq-mask-tis: Sequence-level Geometric-Mean Mask diff --git a/tests/unit/algorithms/sequence_packing_gradient_actor.py b/tests/unit/algorithms/sequence_packing_gradient_actor.py index e8f86f9413..a5e750d358 100644 --- a/tests/unit/algorithms/sequence_packing_gradient_actor.py +++ b/tests/unit/algorithms/sequence_packing_gradient_actor.py @@ -24,6 +24,7 @@ import torch from nemo_rl.algorithms.loss import ClippedPGLossFn, SequencePackingLossWrapper +from nemo_rl.algorithms.loss.utils import prepare_loss_input from nemo_rl.distributed.batched_data_dict import BatchedDataDict @@ -149,11 +150,12 @@ def test_sequence_packing_gradients(self): global_valid_seqs = torch.tensor(batch_size, dtype=torch.float, device="cuda") # Forward pass - baseline_loss, baseline_metrics = base_loss_fn( - baseline_logits, - data_dict, - global_valid_seqs, - global_valid_toks, + loss_input = prepare_loss_input(baseline_logits, data_dict, base_loss_fn) + baseline_loss, _ = base_loss_fn( + data=data_dict, + global_valid_seqs=global_valid_seqs, + global_valid_toks=global_valid_toks, + **loss_input, ) # Backward pass @@ -215,26 +217,26 @@ def make_packed_logits(logits): packed_logits = make_packed_logits(baseline_logits) # Create sequence packing wrapper + tp_group = torch.distributed.new_group(ranks=[rank]) wrapper = SequencePackingLossWrapper( loss_fn=base_loss_fn, + prepare_fn=prepare_loss_input, cu_seqlens_q=cu_seqlens, cu_seqlens_q_padded=cu_seqlens_padded, + vocab_parallel_rank=0, + vocab_parallel_group=tp_group, + context_parallel_group=cp_group, ) # Create data dict for packed sequences packed_data_dict = BatchedDataDict(original_data) - tp_group = torch.distributed.new_group(ranks=[rank]) - # Forward pass - packed_loss, packed_metrics = wrapper( + packed_loss, _ = wrapper( packed_logits, packed_data_dict, global_valid_seqs, global_valid_toks, - vocab_parallel_rank=0, - vocab_parallel_group=tp_group, - context_parallel_group=cp_group, ) # Backward pass From aa9b1f4adeb1ef395da0730ee39eb254f113d646 Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Fri, 27 Feb 2026 01:35:08 -0800 Subject: [PATCH 13/15] fix PreferenceLossFn and unit test Signed-off-by: Yuki Huang --- nemo_rl/algorithms/loss/loss_functions.py | 4 ++-- tests/unit/models/megatron/test_train.py | 4 ++++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/nemo_rl/algorithms/loss/loss_functions.py b/nemo_rl/algorithms/loss/loss_functions.py index 23c9390341..812d2917b6 100755 --- a/nemo_rl/algorithms/loss/loss_functions.py +++ b/nemo_rl/algorithms/loss/loss_functions.py @@ -675,14 +675,14 @@ def _preference_loss( def __call__( self, - rewards: Tensor, + logits: Tensor, data: BatchedDataDict[PreferenceLossDataDict], global_valid_seqs: Tensor, global_valid_toks: Tensor | None, ) -> tuple[torch.Tensor, dict[str, Any]]: sample_mask = data["sample_mask"] - rewards = rewards.squeeze(-1) + rewards = logits.squeeze(-1) ( preference_loss, diff --git a/tests/unit/models/megatron/test_train.py b/tests/unit/models/megatron/test_train.py index 24dda67eec..b80d1a7986 100644 --- a/tests/unit/models/megatron/test_train.py +++ b/tests/unit/models/megatron/test_train.py @@ -27,6 +27,8 @@ import pytest import torch +from nemo_rl.algorithms.loss.interfaces import LossInputType + class TestModelForward: """Tests for model_forward function.""" @@ -685,6 +687,7 @@ def test_loss_post_processor_no_packing( from nemo_rl.models.megatron.train import LossPostProcessor mock_loss_fn = MagicMock(return_value=(torch.tensor(0.5), {"loss": 0.5})) + mock_loss_fn.input_type = LossInputType.LOGIT cfg = {"sequence_packing": {"enabled": False}} processor = LossPostProcessor(loss_fn=mock_loss_fn, cfg=cfg, cp_normalize=False) @@ -723,6 +726,7 @@ def test_loss_post_processor_with_cp_normalize( from nemo_rl.models.megatron.train import LossPostProcessor mock_loss_fn = MagicMock(return_value=(torch.tensor(1.0), {})) + mock_loss_fn.input_type = LossInputType.LOGIT cfg = {"sequence_packing": {"enabled": False}} processor = LossPostProcessor( From 0eddfa024825d96c532f1ebf041695bc96c3250f Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Fri, 27 Feb 2026 07:57:54 -0800 Subject: [PATCH 14/15] fix unit test Signed-off-by: Yuki Huang --- nemo_rl/algorithms/loss/__init__.py | 2 +- .../unit/models/automodel/test_automodel_train.py | 14 +++++++++----- .../unit/models/generation/test_vllm_generation.py | 10 +++++----- 3 files changed, 15 insertions(+), 11 deletions(-) diff --git a/nemo_rl/algorithms/loss/__init__.py b/nemo_rl/algorithms/loss/__init__.py index d794d86c2e..163ce71a24 100644 --- a/nemo_rl/algorithms/loss/__init__.py +++ b/nemo_rl/algorithms/loss/__init__.py @@ -45,7 +45,7 @@ "NLLLossFn", "PreferenceLossDataDict", "PreferenceLossFn", - "SequencePackingLossWrapper", "prepare_loss_input", + "SequencePackingLossWrapper", "wrap_loss_fn_with_input_preparation", ] diff --git a/tests/unit/models/automodel/test_automodel_train.py b/tests/unit/models/automodel/test_automodel_train.py index a2dfddf9e3..9eb5d72bb9 100644 --- a/tests/unit/models/automodel/test_automodel_train.py +++ b/tests/unit/models/automodel/test_automodel_train.py @@ -24,6 +24,7 @@ except ImportError: pytest.skip("nemo_automodel not available", allow_module_level=True) +from nemo_rl.algorithms.loss.interfaces import LossInputType from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.models.automodel.data import ( ProcessedInputs, @@ -63,6 +64,7 @@ def mock_model(): def mock_loss_fn(): loss_fn = MagicMock() loss_fn.return_value = (torch.tensor(0.5), {"loss": 0.5}) + loss_fn.input_type = LossInputType.LOGIT return loss_fn @@ -310,10 +312,10 @@ def test_basic_loss_computation( # Verify loss function was called mock_loss_fn.assert_called_once() - call_args = mock_loss_fn.call_args[0] - assert torch.is_tensor(call_args[0]) # logits - assert call_args[2] == global_valid_seqs # global_valid_seqs - assert call_args[3] == global_valid_toks # global_valid_toks + call_kwargs = mock_loss_fn.call_args[1] + assert torch.is_tensor(call_kwargs["logits"]) + assert call_kwargs["global_valid_seqs"] == global_valid_seqs + assert call_kwargs["global_valid_toks"] == global_valid_toks @patch("nemo_rl.models.automodel.train.SequencePackingLossWrapper") def test_loss_with_sequence_packing( @@ -1896,10 +1898,12 @@ def forward(self, input_ids, **kwargs): ) # Create loss function that returns requires_grad tensor - def loss_fn(logits, mb, global_valid_seqs, global_valid_toks): + def loss_fn(logits, data, global_valid_seqs, global_valid_toks): loss = logits.mean() return loss, {"loss": loss.item()} + loss_fn.input_type = LossInputType.LOGIT + # Create loss post-processor loss_post_processor = LossPostProcessor( loss_fn=loss_fn, diff --git a/tests/unit/models/generation/test_vllm_generation.py b/tests/unit/models/generation/test_vllm_generation.py index 85cd0cce52..ac5d2484ab 100644 --- a/tests/unit/models/generation/test_vllm_generation.py +++ b/tests/unit/models/generation/test_vllm_generation.py @@ -894,13 +894,13 @@ async def run_hf_train_process( @pytest.mark.parametrize( ("async_engine", "cpu_offload", "vllm_precision", "enable_lora"), [ - # (True, False, "bfloat16", False), - # (False, True, "bfloat16", False), - # (True, False, "fp8", False), - # (False, True, "fp8", False), + (True, False, "bfloat16", False), + (False, True, "bfloat16", False), + (True, False, "fp8", False), + (False, True, "fp8", False), # LoRA tests (False, False, "bfloat16", True), - # (True, False, "bfloat16", True), + (True, False, "bfloat16", True), ], ) async def test_vllm_generation_with_hf_training_colocated( From 443d7ad0205b656f083736e24ba7bb2a75def50a Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Sun, 1 Mar 2026 18:48:32 -0800 Subject: [PATCH 15/15] address comments Signed-off-by: Yuki Huang --- nemo_rl/algorithms/loss/interfaces.py | 13 ++++++------- nemo_rl/algorithms/loss/utils.py | 5 +++++ nemo_rl/algorithms/loss/wrapper.py | 16 ++++++++++++++++ 3 files changed, 27 insertions(+), 7 deletions(-) diff --git a/nemo_rl/algorithms/loss/interfaces.py b/nemo_rl/algorithms/loss/interfaces.py index 58b4c3f48f..f1c0db3e35 100644 --- a/nemo_rl/algorithms/loss/interfaces.py +++ b/nemo_rl/algorithms/loss/interfaces.py @@ -43,30 +43,29 @@ class LossFunction(Protocol): def __call__( self, - next_token_logits: torch.Tensor, data: BatchedDataDict, global_valid_seqs: torch.Tensor, global_valid_toks: torch.Tensor, + **kwargs: Any, ) -> tuple[torch.Tensor, dict[str, Any]]: """Compute loss and metrics from logprobs and other data. Args: - next_token_logits: Logits from the model, typically with shape [batch_size, seq_len, vocab_size]. - For each position (b, i), contains the logit distribution over the entire vocabulary - for predicting the next token (at position i+1). For example, if processing "The cat sat on", - then next_token_logits[b, 3] would contain the logits for predicting the word - that follows "on". data: Dictionary containing all relevant data for loss computation such as rewards, values, actions, advantages, masks, and other algorithm-specific information needed for the particular loss calculation. global_valid_seqs: torch.Tensor - this tensor should contain the number of valid sequences in the microbatch. + This tensor should contain the number of valid sequences in the microbatch. It's used for global normalization for losses/metrics that are computed at the sequence level and needs to be aggregated across all microbatches. global_valid_toks: torch.Tensor This tensor should contain the number of valid tokens in the microbatch. It's used for global normalization for losses/metrics that are computed at the token level and needs to be aggregated across all microbatches. + **kwargs: Loss function input, which varies by input_type: + - For LossInputType.LOGPROB: next_token_logprobs (torch.Tensor) + - For LossInputType.LOGIT: logits (torch.Tensor) + - For LossInputType.DISTILLATION: student_topk_logprobs, teacher_topk_logprobs, H_all (torch.Tensor) Returns: tuple: (loss, metrics) diff --git a/nemo_rl/algorithms/loss/utils.py b/nemo_rl/algorithms/loss/utils.py index 569c958bd3..359641ae09 100644 --- a/nemo_rl/algorithms/loss/utils.py +++ b/nemo_rl/algorithms/loss/utils.py @@ -38,6 +38,11 @@ def prepare_loss_input( logits: Logits from the model. data: Microbatch data. loss_fn: Loss function. + vocab_parallel_rank: Vocab parallel rank. + vocab_parallel_group: Vocab parallel group. + context_parallel_group: Context parallel group. + + vocab_parallel_rank, vocab_parallel_group, context_parallel_group are only used for megatron policy worker. Returns: Loss input. diff --git a/nemo_rl/algorithms/loss/wrapper.py b/nemo_rl/algorithms/loss/wrapper.py index f8ac4819d9..39e8b12814 100644 --- a/nemo_rl/algorithms/loss/wrapper.py +++ b/nemo_rl/algorithms/loss/wrapper.py @@ -35,6 +35,22 @@ def __init__( vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None, context_parallel_group: Optional[torch.distributed.ProcessGroup] = None, ): + """Wrap a loss function to handle sequence packing. + + Args: + loss_fn: Loss function. + prepare_fn: Prepare function. + cu_seqlens_q: Unpadded cu seqlens q. + cu_seqlens_q_padded: Padded cu seqlens q. + vocab_parallel_rank: Vocab parallel rank. + vocab_parallel_group: Vocab parallel group. + context_parallel_group: Context parallel group. + + vocab_parallel_rank, vocab_parallel_group, context_parallel_group are only used for megatron policy worker. + + Returns: + Sequence packing loss wrapper. + """ self.loss_fn = loss_fn self.prepare_fn = prepare_fn self.cu_seqlens_q = cu_seqlens_q