diff --git a/nemo_rl/distributed/model_utils.py b/nemo_rl/distributed/model_utils.py
index 29cc5eb6b7..491a2ea36e 100644
--- a/nemo_rl/distributed/model_utils.py
+++ b/nemo_rl/distributed/model_utils.py
@@ -18,7 +18,6 @@
 from torch.distributed.tensor import DTensor, distribute_tensor
 
 
-@torch.no_grad()
 def _compute_distributed_log_softmax(
     vocab_parallel_logits: torch.Tensor, group: torch.distributed.ProcessGroup
 ) -> torch.Tensor:
@@ -36,25 +35,67 @@ def _compute_distributed_log_softmax(
         log probabilities normalized across the full vocabulary dimension.
     """
     logits_max = torch.amax(vocab_parallel_logits, dim=-1, keepdim=True)
-    torch.distributed.all_reduce(
-        logits_max,
-        op=torch.distributed.ReduceOp.MAX,
-        group=group,
-    )
+
+    if group is not None:
+        torch.distributed.all_reduce(
+            logits_max,
+            op=torch.distributed.ReduceOp.MAX,
+            group=group,
+        )
 
     # Subtract the maximum value.
     vocab_parallel_logits = vocab_parallel_logits - logits_max
 
     sum_exp_logits = vocab_parallel_logits.exp().sum(-1, keepdim=True).float()
-    torch.distributed.all_reduce(
-        sum_exp_logits,
-        op=torch.distributed.ReduceOp.SUM,
-        group=group,
-    )
+    if group is not None:
+        torch.distributed.all_reduce(
+            sum_exp_logits,
+            op=torch.distributed.ReduceOp.SUM,
+            group=group,
+        )
 
     return vocab_parallel_logits - sum_exp_logits.log_().to(vocab_parallel_logits.dtype)
 
 
+@torch.compile(fullgraph=True)
+def distributed_logprob_forward(vocab_parallel_logits, target, vocab_start_index, vocab_end_index, tp_group):
+    # Create a mask of valid vocab ids (1 means it needs to be masked).
+    target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
+    masked_target = target - vocab_start_index
+    masked_target[target_mask] = 0
+
+    vocab_parallel_logits = vocab_parallel_logits.to(dtype=torch.float32)
+
+    log_probs = _compute_distributed_log_softmax(vocab_parallel_logits, group=tp_group)
+    softmax_output = log_probs.exp()
+
+    log_probs = torch.gather(log_probs, -1, masked_target.unsqueeze(-1)).squeeze(-1)
+    log_probs[target_mask] = 0.0
+
+    if tp_group is not None:
+        torch.distributed.all_reduce(
+            log_probs,
+            op=torch.distributed.ReduceOp.SUM,
+            group=tp_group,
+        )
+    return log_probs, softmax_output
+
+
+@torch.compile(fullgraph=True)
+def distributed_logprob_backward(grad_output, target, softmax, vocab_start_index, vocab_end_index):
+    V = softmax.size(-1)
+
+    target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
+    masked_target = target - vocab_start_index
+    masked_target[target_mask] = 0
+
+    is_chosen = (~target_mask).unsqueeze(-1) * torch.nn.functional.one_hot(
+        masked_target, num_classes=V
+    )
+
+    grad_input = is_chosen.float().sub_(softmax)
+    grad_input.mul_(grad_output.unsqueeze(-1))
+
+    return grad_input
 
 
 class DistributedLogprob(torch.autograd.Function):
     """Custom autograd function for computing log probabilities in a distributed setting.
 
@@ -72,28 +113,17 @@ def forward(  # pyrefly: ignore[bad-override] Always ignore torch.autograd.Func
         group: torch.distributed.ProcessGroup,
         inference_only: bool = False,
     ) -> torch.Tensor:
-        # Create a mask of valid vocab ids (1 means it needs to be masked).
-        target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
-        masked_target = target - vocab_start_index
-        masked_target[target_mask] = 0
-
-        vocab_parallel_logits = vocab_parallel_logits.to(dtype=torch.float32)
-
-        log_probs = _compute_distributed_log_softmax(vocab_parallel_logits, group=group)
-        softmax_output = log_probs.exp()
-
-        log_probs = torch.gather(log_probs, -1, masked_target.unsqueeze(-1)).squeeze(-1)
-        log_probs[target_mask] = 0.0
-
-        torch.distributed.all_reduce(
-            log_probs,
-            op=torch.distributed.ReduceOp.SUM,
-            group=group,
+
+        log_probs, softmax = distributed_logprob_forward(
+            vocab_parallel_logits, target, vocab_start_index, vocab_end_index, group
         )
 
-        if not inference_only:  # only save for backward when we have inference only=False
-            ctx.save_for_backward(softmax_output, target_mask, masked_target)
+        if not inference_only:  # only save for backward when inference_only=False
+            ctx.save_for_backward(softmax, target)
+            ctx.vocab_start_index = vocab_start_index
+            ctx.vocab_end_index = vocab_end_index
+
         return log_probs
 
@@ -103,40 +133,18 @@ def backward(
         *grad_outputs: torch.Tensor,
     ) -> tuple[torch.Tensor, None, None, None, None, None, None]:
         grad_output = grad_outputs[0]
-        softmax, target_mask, masked_target = ctx.saved_tensors
+        softmax, target = ctx.saved_tensors
+        vocab_start_index = ctx.vocab_start_index
+        vocab_end_index = ctx.vocab_end_index
 
-        if softmax.ndim == 3:
-            B, S, V = softmax.shape
-
-            # skip `torch.nn.functional.one_hot`
-            row = (
-                torch.arange(B, device=softmax.device)
-                .view(-1, 1)
-                .expand(-1, S)
-                .reshape(-1)
-            )
-            col = torch.arange(S, device=softmax.device).expand(B, -1).reshape(-1)
-            flat_idx = (row * S + col) * V
-
-            flat_chosen = flat_idx.masked_select(
-                ~target_mask.reshape(-1)
-            ) + masked_target.masked_select(~target_mask)
-
-            # `neg` is zero-copy
-            grad_input = softmax.neg()
-            grad_input = grad_input.mul_(grad_output.unsqueeze(-1))
-
-            grad_output_selected = grad_output.masked_select(~target_mask)
-            grad_input.view(-1).scatter_add_(0, flat_chosen, grad_output_selected)
-        else:
-            V = softmax.size(-1)
-            is_chosen = (~target_mask).unsqueeze(-1) * torch.nn.functional.one_hot(
-                masked_target, num_classes=V
-            )
-            grad_input = is_chosen.float().sub_(softmax)
-            grad_input.mul_(grad_output.unsqueeze(-1))
+        grad_input = distributed_logprob_backward(
+            grad_output,
+            target,
+            softmax,
+            vocab_start_index,
+            vocab_end_index,
+        )
 
-        # if you add an argument to the forward method, then you must add a corresponding None here
         return grad_input, None, None, None, None, None, None
 
@@ -162,11 +170,6 @@ def forward(  # pyrefly: ignore[bad-override] Always ignore torch.autograd.Func
         tp_group: torch.distributed.ProcessGroup,
         inference_only: bool = False,
     ) -> torch.Tensor:
-        # Create a mask of valid vocab ids (1 means it needs to be masked).
-        target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
-        masked_target = target - vocab_start_index
-        masked_target[target_mask] = 0
-
         seq_size = int(vocab_parallel_logits.shape[1])
         num_chunks = (seq_size + chunk_size - 1) // chunk_size
         all_log_probs = []
@@ -175,6 +178,22 @@ def forward(  # pyrefly: ignore[bad-override] Always ignore torch.autograd.Func
             chunk_start = chunk_idx * chunk_size
             chunk_end = min(seq_size, (chunk_idx + 1) * chunk_size)
 
+            vocab_parallel_logits_chunk = vocab_parallel_logits[:, chunk_start:chunk_end, :]
+            target_chunk = target[:, chunk_start:chunk_end]
+
+            # Create a mask of valid vocab ids (1 means it needs to be masked).
+            target_mask_chunk = (target_chunk < vocab_start_index) | (target_chunk >= vocab_end_index)
+            masked_target_chunk = target_chunk - vocab_start_index
+            masked_target_chunk[target_mask_chunk] = 0
+
+            log_probs_chunk, softmax_chunk = distributed_logprob_forward(
+                vocab_parallel_logits_chunk,
+                target_chunk,
+                vocab_start_index,
+                vocab_end_index,
+                tp_group,
+            )
+
             logits = vocab_parallel_logits[:, chunk_start:chunk_end, :]
             logits = logits.to(dtype=torch.float32)
 
@@ -229,7 +248,10 @@ def backward(
         logits = vocab_parallel_logits[:, chunk_start:chunk_end, :]
         logits = logits.to(dtype=torch.float32)
 
-        softmax_output = _compute_distributed_log_softmax(
+        _, softmax_output = distributed_logprob_forward(
             logits,
-            group=tp_group,
+            target_chunk,
+            vocab_start_index,
+            vocab_end_index,
+            tp_group,
         )
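
Notes for review. The patch keeps the numerically stable log-softmax in _compute_distributed_log_softmax: each tensor-parallel rank holds a vocab shard, and the two all-reduces recover the global per-position max and the global normalizer. Below is a minimal single-process sketch of that math, emulating two vocab shards with plain tensor ops; the shapes and the even shard split are illustrative, not the module's API.

import torch

torch.manual_seed(0)
B, S, V = 2, 4, 10
logits = torch.randn(B, S, V)
shards = [logits[..., : V // 2], logits[..., V // 2 :]]

# all-reduce(MAX) across ranks -> global per-position max
global_max = torch.maximum(
    shards[0].amax(-1, keepdim=True), shards[1].amax(-1, keepdim=True)
)
# all-reduce(SUM) across ranks -> global normalizer
sum_exp = sum((s - global_max).exp().sum(-1, keepdim=True) for s in shards)

# each "rank" normalizes its local shard using the global statistics
local_log_probs = [(s - global_max) - sum_exp.log() for s in shards]

torch.testing.assert_close(
    torch.cat(local_log_probs, dim=-1), torch.log_softmax(logits, dim=-1)
)

Subtracting the global max before exponentiating is what keeps exp() from overflowing in float32; the guarded all-reduces added here simply skip the collective when there is no process group.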
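
The fused distributed_logprob_backward relies on the standard identity that the gradient of a gathered log-softmax with respect to the logits is one_hot(target) - softmax, scaled by the incoming gradient. A quick autograd cross-check of that identity in plain PyTorch (all names are local to this sketch):

import torch
import torch.nn.functional as F

B, S, V = 2, 3, 7
logits = torch.randn(B, S, V, requires_grad=True)
target = torch.randint(0, V, (B, S))

# Gather the target log-probs and backprop a ones grad_output.
log_probs = F.log_softmax(logits, dim=-1)
log_probs.gather(-1, target.unsqueeze(-1)).squeeze(-1).sum().backward()

# Closed form used by the fused backward: one_hot(target) - softmax.
with torch.no_grad():
    manual = F.one_hot(target, num_classes=V).float() - F.softmax(logits, dim=-1)

torch.testing.assert_close(logits.grad, manual)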
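
Because both all-reduces are now guarded by "if group is not None", the compiled helpers can be smoke-tested on a single process by treating the full vocabulary as one shard. A hedged sanity check, assuming the helpers import from the patched module path and that torch.compile(fullgraph=True) supports the masked in-place writes on your PyTorch build:

import torch
import torch.nn.functional as F

from nemo_rl.distributed.model_utils import (
    distributed_logprob_forward,
    distributed_logprob_backward,
)

B, S, V = 2, 5, 11
logits = torch.randn(B, S, V)
target = torch.randint(0, V, (B, S))

# tp_group=None: the "partition" is the whole vocab [0, V), so no collectives run.
log_probs, softmax = distributed_logprob_forward(logits, target, 0, V, None)

ref = F.log_softmax(logits.float(), dim=-1)
torch.testing.assert_close(
    log_probs, ref.gather(-1, target.unsqueeze(-1)).squeeze(-1)
)

# The backward helper should reproduce one_hot(target) - softmax, scaled by grad_output.
grad_out = torch.ones(B, S)
grad_in = distributed_logprob_backward(grad_out, target, softmax, 0, V)
torch.testing.assert_close(grad_in, F.one_hot(target, num_classes=V).float() - softmax)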