diff --git a/nemo_rl/algorithms/loss_functions.py b/nemo_rl/algorithms/loss_functions.py
index ed778dc392..d0a7b05be7 100644
--- a/nemo_rl/algorithms/loss_functions.py
+++ b/nemo_rl/algorithms/loss_functions.py
@@ -325,6 +325,7 @@ def __call__(
         token_mask = data["token_mask"][:, 1:]
         sample_mask = data["sample_mask"]
         mask = token_mask * sample_mask.unsqueeze(-1)
+        seq_index = data.get("seq_index", None)

         next_token_logits = next_token_logits.to(torch.float32)

@@ -346,7 +347,7 @@ def __call__(
             token_logprobs = token_logprobs[:, : data["input_ids"].shape[1] - 1]
         elif isinstance(next_token_logits, torch.distributed.tensor.DTensor):
             token_logprobs = get_logprobs_from_vocab_parallel_logits(
-                next_token_logits, data["input_ids"]
+                next_token_logits, data["input_ids"], seq_index=seq_index
             )
         else:
             next_tokens = data["input_ids"][:, 1:].cuda()  # Skip first token
@@ -580,6 +581,7 @@ def _dpo_loss(
         ## TODO(@ashors): there's some duplicate code here with the NLLLoss function. We should refactor
         token_mask = data["token_mask"][:, 1:]
         sample_mask = data["sample_mask"]
+        seq_index = data.get("seq_index", None)

         next_token_logits = next_token_logits.to(torch.float32)
         if vocab_parallel_group is not None:
@@ -599,7 +601,7 @@ def _dpo_loss(
             token_logprobs = token_logprobs[:, : data["input_ids"].shape[1] - 1]
         elif isinstance(next_token_logits, torch.distributed.tensor.DTensor):
             token_logprobs = get_logprobs_from_vocab_parallel_logits(
-                next_token_logits, data["input_ids"]
+                next_token_logits, data["input_ids"], seq_index=seq_index
             )
         else:
             next_tokens = data["input_ids"][:, 1:].cuda()  # Skip first token
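
Note: both call sites read an optional `seq_index` from the batch via `data.get("seq_index", None)`, so batches that do not carry the key keep their previous behavior, and only the DTensor (vocab-parallel) path forwards it. As a rough illustration of the call pattern, here is a minimal single-device sketch; the real `get_logprobs_from_vocab_parallel_logits` operates on vocab-sharded DTensors, and the `seq_index` reordering shown here is an assumption for illustration, not the library's actual implementation.

# Minimal single-device sketch of the call pattern above (assumed semantics).
# seq_index is treated here as a hypothetical reordering of the sequence
# dimension (e.g. undoing a packing/rearrangement step).
import torch

def get_logprobs_sketch(
    next_token_logits: torch.Tensor,  # [batch, seq, vocab]
    input_ids: torch.Tensor,          # [batch, seq]
    seq_index: torch.Tensor | None = None,  # optional 1-D LongTensor of length seq
) -> torch.Tensor:
    if seq_index is not None:
        # Hypothetical: put logits back into the original token order
        # before pairing them with input_ids.
        next_token_logits = next_token_logits.index_select(1, seq_index)
    logprobs = torch.log_softmax(next_token_logits.float(), dim=-1)
    # Logits at position t score the token at position t + 1, so drop the
    # last logit and the first input id, mirroring the [:, 1:] slicing above.
    targets = input_ids[:, 1:].unsqueeze(-1)
    return logprobs[:, :-1].gather(-1, targets).squeeze(-1)  # [batch, seq - 1]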