From cbc1c938fd98b21757dafcacd9150d4508c1cfeb Mon Sep 17 00:00:00 2001 From: Jimmy Zhang Date: Tue, 26 Aug 2025 19:48:23 -0700 Subject: [PATCH 1/2] fused loss Signed-off-by: Jimmy Zhang --- nemo_rl/distributed/model_utils.py | 147 ++++++++++++++++++----------- 1 file changed, 90 insertions(+), 57 deletions(-) diff --git a/nemo_rl/distributed/model_utils.py b/nemo_rl/distributed/model_utils.py index 29cc5eb6b7..1ae43da4c2 100644 --- a/nemo_rl/distributed/model_utils.py +++ b/nemo_rl/distributed/model_utils.py @@ -55,6 +55,61 @@ def _compute_distributed_log_softmax( return vocab_parallel_logits - sum_exp_logits.log_().to(vocab_parallel_logits.dtype) +@torch.compile +def distributed_logprob_forward(vocab_parallel_logits, target, vocab_start_index, vocab_end_index, tp_group): + # Create a mask of valid vocab ids (1 means it needs to be masked). + target_mask = (target < vocab_start_index) | (target >= vocab_end_index) + masked_target = target - vocab_start_index + masked_target[target_mask] = 0 + + vocab_parallel_logits = vocab_parallel_logits.to(dtype=torch.float32) + + log_probs = _compute_distributed_log_softmax(vocab_parallel_logits, group=tp_group) + softmax_output = log_probs.exp() + + log_probs = torch.gather(log_probs, -1, masked_target.unsqueeze(-1)).squeeze(-1) + log_probs[target_mask] = 0.0 + + torch.distributed.all_reduce( + log_probs, + op=torch.distributed.ReduceOp.SUM, + group=tp_group, + ) + return log_probs, softmax_output + + +@torch.compile +def distributed_logprob_backward(softmax, target, grad_output, vocab_start_index, vocab_end_index): + B, S, V = softmax.shape + + target_mask = (target < vocab_start_index) | (target >= vocab_end_index) + masked_target = target - vocab_start_index + masked_target[target_mask] = 0 + + # skip `torch.nn.functional.one_hot` + row = ( + torch.arange(B, device=softmax.device) + .view(-1, 1) + .expand(-1, S) + .reshape(-1) + ) + col = torch.arange(S, device=softmax.device).expand(B, -1).reshape(-1) + flat_idx = (row * S 
+ col) * V + + flat_chosen = flat_idx.masked_select( + ~target_mask.reshape(-1) + ) + masked_target.masked_select(~target_mask) + + # `neg` is zero-copy + grad_input = softmax.neg() + grad_input = grad_input.mul_(grad_output.unsqueeze(-1)) + + grad_output_selected = grad_output.masked_select(~target_mask) + grad_input.view(-1).scatter_add_(0, flat_chosen, grad_output_selected) + + return grad_input + + class DistributedLogprob(torch.autograd.Function): """Custom autograd function for computing log probabilities in a distributed setting. @@ -72,28 +127,16 @@ def forward( # pyrefly: ignore[bad-override] Always ignore torch.autograd.Func group: torch.distributed.ProcessGroup, inference_only: bool = False, ) -> torch.Tensor: - # Create a mask of valid vocab ids (1 means it needs to be masked). - target_mask = (target < vocab_start_index) | (target >= vocab_end_index) - masked_target = target - vocab_start_index - masked_target[target_mask] = 0 - - vocab_parallel_logits = vocab_parallel_logits.to(dtype=torch.float32) - - log_probs = _compute_distributed_log_softmax(vocab_parallel_logits, group=group) - softmax_output = log_probs.exp() - - log_probs = torch.gather(log_probs, -1, masked_target.unsqueeze(-1)).squeeze(-1) - log_probs[target_mask] = 0.0 - - torch.distributed.all_reduce( - log_probs, - op=torch.distributed.ReduceOp.SUM, - group=group, + + log_probs, softmax_output = distributed_logprob_forward( + vocab_parallel_logits, target, vocab_start_index, vocab_end_index, group ) - if not inference_only: # only save for backward when we have inference only=False - ctx.save_for_backward(softmax_output, target_mask, masked_target) + ctx.save_for_backward(softmax_output, target) + ctx.vocab_start_index = vocab_start_index + ctx.vocab_end_index = vocab_end_index + return log_probs @@ -103,39 +146,18 @@ def backward( *grad_outputs: torch.Tensor, ) -> tuple[torch.Tensor, None, None, None, None, None, None]: grad_output = grad_outputs[0] - softmax, target_mask, masked_target 
= ctx.saved_tensors - - if softmax.ndim == 3: - B, S, V = softmax.shape - - # skip `torch.nn.functional.one_hot` - row = ( - torch.arange(B, device=softmax.device) - .view(-1, 1) - .expand(-1, S) - .reshape(-1) - ) - col = torch.arange(S, device=softmax.device).expand(B, -1).reshape(-1) - flat_idx = (row * S + col) * V - - flat_chosen = flat_idx.masked_select( - ~target_mask.reshape(-1) - ) + masked_target.masked_select(~target_mask) - - # `neg` is zero-copy - grad_input = softmax.neg() - grad_input = grad_input.mul_(grad_output.unsqueeze(-1)) - - grad_output_selected = grad_output.masked_select(~target_mask) - grad_input.view(-1).scatter_add_(0, flat_chosen, grad_output_selected) - else: - V = softmax.size(-1) - is_chosen = (~target_mask).unsqueeze(-1) * torch.nn.functional.one_hot( - masked_target, num_classes=V - ) - grad_input = is_chosen.float().sub_(softmax) - grad_input.mul_(grad_output.unsqueeze(-1)) + softmax, target = ctx.saved_tensors + vocab_start_index = ctx.vocab_start_index + vocab_end_index = ctx.vocab_end_index + assert softmax.ndim == 3 + grad_input = distributed_logprob_backward( + softmax, + target, + grad_output, + vocab_start_index, + vocab_end_index, + ) # if you add an argument to the forward method, then you must add a corresponding None here return grad_input, None, None, None, None, None, None @@ -162,11 +184,6 @@ def forward( # pyrefly: ignore[bad-override] Always ignore torch.autograd.Func tp_group: torch.distributed.ProcessGroup, inference_only: bool = False, ) -> torch.Tensor: - # Create a mask of valid vocab ids (1 means it needs to be masked). 
- target_mask = (target < vocab_start_index) | (target >= vocab_end_index) - masked_target = target - vocab_start_index - masked_target[target_mask] = 0 - seq_size = int(vocab_parallel_logits.shape[1]) num_chunks = (seq_size + chunk_size - 1) // chunk_size all_log_probs = [] @@ -175,6 +192,22 @@ def forward( # pyrefly: ignore[bad-override] Always ignore torch.autograd.Func chunk_start = chunk_idx * chunk_size chunk_end = min(seq_size, (chunk_idx + 1) * chunk_size) + vocab_parallel_logits_chunk = vocab_parallel_logits[:, chunk_start:chunk_end, :] + target_chunk = target[:, chunk_start:chunk_end] + + # Create a mask of valid vocab ids (1 means it needs to be masked). + target_mask_chunk = (target_chunk < vocab_start_index) | (target_chunk >= vocab_end_index) + masked_target_chunk = target_chunk - vocab_start_index + masked_target_chunk[target_mask_chunk] = 0 + + distributed_logprob_forward( + vocab_parallel_logits_chunk, + target_chunk, + vocab_start_index, + vocab_end_index, + tp_group, + ) + logits = vocab_parallel_logits[:, chunk_start:chunk_end, :] logits = logits.to(dtype=torch.float32) @@ -229,7 +262,7 @@ def backward( logits = vocab_parallel_logits[:, chunk_start:chunk_end, :] logits = logits.to(dtype=torch.float32) - softmax_output = _compute_distributed_log_softmax( + softmax_output = _compute_distributed_log_softmax( logits, group=tp_group, ) From d47e43cf8434d1f2a44d1903b801e472a414fb36 Mon Sep 17 00:00:00 2001 From: Jimmy Zhang Date: Wed, 27 Aug 2025 13:13:56 -0700 Subject: [PATCH 2/2] fully fused logprob Signed-off-by: Jimmy Zhang --- nemo_rl/distributed/model_utils.py | 81 ++++++++++++------------ 1 file changed, 33 insertions(+), 48 deletions(-) diff --git a/nemo_rl/distributed/model_utils.py b/nemo_rl/distributed/model_utils.py index 1ae43da4c2..491a2ea36e 100644 --- a/nemo_rl/distributed/model_utils.py +++ b/nemo_rl/distributed/model_utils.py @@ -18,7 +18,6 @@ from torch.distributed.tensor import DTensor, distribute_tensor -@torch.no_grad() def 
_compute_distributed_log_softmax( vocab_parallel_logits: torch.Tensor, group: torch.distributed.ProcessGroup ) -> torch.Tensor: @@ -36,26 +35,29 @@ def _compute_distributed_log_softmax( log probabilities normalized across the full vocabulary dimension. """ logits_max = torch.amax(vocab_parallel_logits, dim=-1, keepdim=True) - torch.distributed.all_reduce( - logits_max, - op=torch.distributed.ReduceOp.MAX, - group=group, - ) + + if group is not None: + torch.distributed.all_reduce( + logits_max, + op=torch.distributed.ReduceOp.MAX, + group=group, + ) # Subtract the maximum value. vocab_parallel_logits = vocab_parallel_logits - logits_max sum_exp_logits = vocab_parallel_logits.exp().sum(-1, keepdim=True).float() - torch.distributed.all_reduce( - sum_exp_logits, - op=torch.distributed.ReduceOp.SUM, - group=group, - ) + if group is not None: + torch.distributed.all_reduce( + sum_exp_logits, + op=torch.distributed.ReduceOp.SUM, + group=group, + ) return vocab_parallel_logits - sum_exp_logits.log_().to(vocab_parallel_logits.dtype) -@torch.compile +@torch.compile(fullgraph=True) def distributed_logprob_forward(vocab_parallel_logits, target, vocab_start_index, vocab_end_index, tp_group): # Create a mask of valid vocab ids (1 means it needs to be masked). 
target_mask = (target < vocab_start_index) | (target >= vocab_end_index) @@ -70,46 +72,30 @@ def distributed_logprob_forward(vocab_parallel_logits, target, vocab_start_index log_probs = torch.gather(log_probs, -1, masked_target.unsqueeze(-1)).squeeze(-1) log_probs[target_mask] = 0.0 - torch.distributed.all_reduce( - log_probs, - op=torch.distributed.ReduceOp.SUM, - group=tp_group, - ) + if tp_group is not None: + torch.distributed.all_reduce( + log_probs, + op=torch.distributed.ReduceOp.SUM, + group=tp_group, + ) return log_probs, softmax_output - -@torch.compile -def distributed_logprob_backward(softmax, target, grad_output, vocab_start_index, vocab_end_index): - B, S, V = softmax.shape +@torch.compile(fullgraph=True) +def distributed_logprob_backward(grad_output, target, softmax, vocab_start_index, vocab_end_index): + V = softmax.size(-1) target_mask = (target < vocab_start_index) | (target >= vocab_end_index) masked_target = target - vocab_start_index masked_target[target_mask] = 0 - - # skip `torch.nn.functional.one_hot` - row = ( - torch.arange(B, device=softmax.device) - .view(-1, 1) - .expand(-1, S) - .reshape(-1) + + is_chosen = (~target_mask).unsqueeze(-1) * torch.nn.functional.one_hot( + masked_target, num_classes=V ) - col = torch.arange(S, device=softmax.device).expand(B, -1).reshape(-1) - flat_idx = (row * S + col) * V - - flat_chosen = flat_idx.masked_select( - ~target_mask.reshape(-1) - ) + masked_target.masked_select(~target_mask) - - # `neg` is zero-copy - grad_input = softmax.neg() - grad_input = grad_input.mul_(grad_output.unsqueeze(-1)) - grad_output_selected = grad_output.masked_select(~target_mask) - grad_input.view(-1).scatter_add_(0, flat_chosen, grad_output_selected) + grad_input = is_chosen.float().sub_(softmax) + grad_input.mul_(grad_output.unsqueeze(-1)) return grad_input - - class DistributedLogprob(torch.autograd.Function): """Custom autograd function for computing log probabilities in a distributed setting. 
@@ -128,12 +114,12 @@ def forward( # pyrefly: ignore[bad-override] Always ignore torch.autograd.Func inference_only: bool = False, ) -> torch.Tensor: - log_probs, softmax_output = distributed_logprob_forward( + log_probs, softmax = distributed_logprob_forward( vocab_parallel_logits, target, vocab_start_index, vocab_end_index, group ) if not inference_only: # only save for backward when we have inference only=False - ctx.save_for_backward(softmax_output, target) + ctx.save_for_backward(softmax, target) ctx.vocab_start_index = vocab_start_index ctx.vocab_end_index = vocab_end_index @@ -150,15 +136,14 @@ def backward( vocab_start_index = ctx.vocab_start_index vocab_end_index = ctx.vocab_end_index - assert softmax.ndim == 3 grad_input = distributed_logprob_backward( - softmax, + grad_output, target, - grad_output, + softmax, vocab_start_index, vocab_end_index, ) - # if you add an argument to the forward method, then you must add a corresponding None here + return grad_input, None, None, None, None, None, None