draft: feat: fused loss and logit to logprob conversion #994

base: main
```diff
@@ -18,7 +18,6 @@
 from torch.distributed.tensor import DTensor, distribute_tensor


-@torch.no_grad()
 def _compute_distributed_log_softmax(
     vocab_parallel_logits: torch.Tensor, group: torch.distributed.ProcessGroup
 ) -> torch.Tensor:
```
```diff
@@ -36,25 +35,67 @@ def _compute_distributed_log_softmax(
     log probabilities normalized across the full vocabulary dimension.
     """
     logits_max = torch.amax(vocab_parallel_logits, dim=-1, keepdim=True)
-    torch.distributed.all_reduce(
-        logits_max,
-        op=torch.distributed.ReduceOp.MAX,
-        group=group,
-    )
+    if group is not None:
+        torch.distributed.all_reduce(
+            logits_max,
+            op=torch.distributed.ReduceOp.MAX,
+            group=group,
+        )

     # Subtract the maximum value.
     vocab_parallel_logits = vocab_parallel_logits - logits_max

     sum_exp_logits = vocab_parallel_logits.exp().sum(-1, keepdim=True).float()

-    torch.distributed.all_reduce(
-        sum_exp_logits,
-        op=torch.distributed.ReduceOp.SUM,
-        group=group,
-    )
+    if group is not None:
+        torch.distributed.all_reduce(
+            sum_exp_logits,
+            op=torch.distributed.ReduceOp.SUM,
+            group=group,
+        )

     return vocab_parallel_logits - sum_exp_logits.log_().to(vocab_parallel_logits.dtype)
```
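For intuition: the helper computes log_softmax(x) = (x - max) - log(sum(exp(x - max))), where only the global max and the exp-sum need to cross ranks. Below is a minimal single-process sketch of that identity, simulating two vocab shards with plain `max`/`sum` in place of the collectives; shapes and variable names are illustrative, not from this PR.

```python
import torch

torch.manual_seed(0)
full_logits = torch.randn(2, 3, 8)            # (batch, seq, vocab)
shards = full_logits.chunk(2, dim=-1)         # vocab split across two "ranks"

# all_reduce(MAX) equivalent: global max over shards.
global_max = torch.maximum(*(s.amax(-1, keepdim=True) for s in shards))
# all_reduce(SUM) equivalent: global sum of exp over shards.
sum_exp = sum((s - global_max).exp().sum(-1, keepdim=True) for s in shards)

local_log_probs = [(s - global_max) - sum_exp.log() for s in shards]
reference = torch.log_softmax(full_logits, dim=-1)
assert torch.allclose(torch.cat(local_log_probs, -1), reference, atol=1e-6)
```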
```diff
+@torch.compile(fullgraph=True)
+def distributed_logprob_forward(vocab_parallel_logits, target, vocab_start_index, vocab_end_index, tp_group):
+    # Create a mask of valid vocab ids (1 means it needs to be masked).
+    target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
+    masked_target = target - vocab_start_index
+    masked_target[target_mask] = 0
+
+    vocab_parallel_logits = vocab_parallel_logits.to(dtype=torch.float32)
+
+    log_probs = _compute_distributed_log_softmax(vocab_parallel_logits, group=tp_group)
+    softmax_output = log_probs.exp()
+
+    log_probs = torch.gather(log_probs, -1, masked_target.unsqueeze(-1)).squeeze(-1)
+    log_probs[target_mask] = 0.0
+
+    if tp_group is not None:
+        torch.distributed.all_reduce(
+            log_probs,
+            op=torch.distributed.ReduceOp.SUM,
+            group=tp_group,
+        )
+    return log_probs, softmax_output
```
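Each rank gathers the target log-prob only for tokens whose ids fall inside its vocab shard and contributes exactly zero otherwise, so the final SUM all-reduce reconstructs the full-vocabulary value. A hedged single-process illustration with two simulated shards (variable names are mine, not the PR's):

```python
import torch

torch.manual_seed(0)
V, shard = 8, 4                               # full vocab, per-"rank" shard width
log_probs = torch.log_softmax(torch.randn(5, V), dim=-1)
target = torch.randint(0, V, (5,))

partial_sum = torch.zeros(5)
for rank in range(2):                         # simulate two tensor-parallel ranks
    start, end = rank * shard, (rank + 1) * shard
    mask = (target < start) | (target >= end)   # tokens this rank does not own
    local_idx = (target - start).masked_fill(mask, 0)
    picked = log_probs[:, start:end].gather(-1, local_idx.unsqueeze(-1)).squeeze(-1)
    picked[mask] = 0.0                        # contribute zero for foreign tokens
    partial_sum += picked                     # stands in for all_reduce(SUM)

expected = log_probs.gather(-1, target.unsqueeze(-1)).squeeze(-1)
assert torch.allclose(partial_sum, expected)
```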
```diff
+@torch.compile(fullgraph=True)
+def distributed_logprob_backward(grad_output, target, softmax, vocab_start_index, vocab_end_index):
+    V = softmax.size(-1)
+
+    target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
+    masked_target = target - vocab_start_index
+    masked_target[target_mask] = 0
+
+    is_chosen = (~target_mask).unsqueeze(-1) * torch.nn.functional.one_hot(
+        masked_target, num_classes=V
+    )
+
+    grad_input = is_chosen.float().sub_(softmax)
+    grad_input.mul_(grad_output.unsqueeze(-1))
+
+    return grad_input


 class DistributedLogprob(torch.autograd.Function):
     """Custom autograd function for computing log probabilities in a distributed setting.
```
```diff
@@ -72,28 +113,16 @@ def forward(  # pyrefly: ignore[bad-override] Always ignore torch.autograd.Func
         group: torch.distributed.ProcessGroup,
         inference_only: bool = False,
     ) -> torch.Tensor:
-        # Create a mask of valid vocab ids (1 means it needs to be masked).
-        target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
-        masked_target = target - vocab_start_index
-        masked_target[target_mask] = 0
-
-        vocab_parallel_logits = vocab_parallel_logits.to(dtype=torch.float32)
-
-        log_probs = _compute_distributed_log_softmax(vocab_parallel_logits, group=group)
-        softmax_output = log_probs.exp()
-
-        log_probs = torch.gather(log_probs, -1, masked_target.unsqueeze(-1)).squeeze(-1)
-        log_probs[target_mask] = 0.0
-
-        torch.distributed.all_reduce(
-            log_probs,
-            op=torch.distributed.ReduceOp.SUM,
-            group=group,
+        log_probs, softmax = distributed_logprob_forward(
+            vocab_parallel_logits, target, vocab_start_index, vocab_end_index, group
         )

         if not inference_only:
             # only save for backward when we have inference_only=False
-            ctx.save_for_backward(softmax_output, target_mask, masked_target)
+            ctx.save_for_backward(softmax, target)
+            ctx.vocab_start_index = vocab_start_index
+            ctx.vocab_end_index = vocab_end_index

         return log_probs
```
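Saving `(softmax, target)` instead of `(softmax_output, target_mask, masked_target)` drops two (batch, seq)-shaped tensors from the autograd graph at the cost of recomputing the masks inside the compiled backward. A toy analog of this save-less/recompute pattern (the `MaskedPick` class is hypothetical, not the PR's code):

```python
import torch

class MaskedPick(torch.autograd.Function):
    """Toy example: save only `target`; rebuild the masks in backward."""

    @staticmethod
    def forward(ctx, x, target, lo, hi):
        mask = (target < lo) | (target >= hi)
        idx = (target - lo).masked_fill(mask, 0)
        out = x.gather(-1, idx.unsqueeze(-1)).squeeze(-1).masked_fill(mask, 0.0)
        ctx.save_for_backward(target)            # mask tensors are not stored
        ctx.lo, ctx.hi, ctx.vocab = lo, hi, x.size(-1)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        (target,) = ctx.saved_tensors
        mask = (target < ctx.lo) | (target >= ctx.hi)   # recomputed, cheap
        idx = (target - ctx.lo).masked_fill(mask, 0)
        grad_x = grad_out.new_zeros(*target.shape, ctx.vocab)
        grad_x.scatter_(-1, idx.unsqueeze(-1),
                        grad_out.masked_fill(mask, 0.0).unsqueeze(-1))
        return grad_x, None, None, None

x = torch.randn(4, 8, requires_grad=True)
t = torch.randint(0, 8, (4,))
MaskedPick.apply(x, t, 0, 8).sum().backward()
print(x.grad.sum())   # 4.0: each row contributes exactly one 1.0
```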
```diff
@@ -103,40 +132,18 @@ def backward(
         *grad_outputs: torch.Tensor,
     ) -> tuple[torch.Tensor, None, None, None, None, None, None]:
         grad_output = grad_outputs[0]
-        softmax, target_mask, masked_target = ctx.saved_tensors
+        softmax, target = ctx.saved_tensors
+        vocab_start_index = ctx.vocab_start_index
+        vocab_end_index = ctx.vocab_end_index

-        if softmax.ndim == 3:
-            B, S, V = softmax.shape
-
-            # skip `torch.nn.functional.one_hot`
-            row = (
-                torch.arange(B, device=softmax.device)
-                .view(-1, 1)
-                .expand(-1, S)
-                .reshape(-1)
-            )
-            col = torch.arange(S, device=softmax.device).expand(B, -1).reshape(-1)
-            flat_idx = (row * S + col) * V
-
-            flat_chosen = flat_idx.masked_select(
-                ~target_mask.reshape(-1)
-            ) + masked_target.masked_select(~target_mask)
-
-            # `neg` is zero-copy
-            grad_input = softmax.neg()
-            grad_input = grad_input.mul_(grad_output.unsqueeze(-1))
-
-            grad_output_selected = grad_output.masked_select(~target_mask)
-            grad_input.view(-1).scatter_add_(0, flat_chosen, grad_output_selected)
-        else:
-            V = softmax.size(-1)
-            is_chosen = (~target_mask).unsqueeze(-1) * torch.nn.functional.one_hot(
-                masked_target, num_classes=V
-            )
-            grad_input = is_chosen.float().sub_(softmax)
-            grad_input.mul_(grad_output.unsqueeze(-1))
+        grad_input = distributed_logprob_backward(
+            grad_output,
+            target,
+            softmax,
+            vocab_start_index,
+            vocab_end_index,
+        )

         # if you add an argument to the forward method, then you must add a corresponding None here
         return grad_input, None, None, None, None, None, None
```
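The deleted `scatter_add_` branch existed to avoid materializing `one_hot` in eager mode; under `@torch.compile(fullgraph=True)` the plain one_hot formulation can fuse into a few kernels, which appears to be the rationale for keeping a single code path. A rough, hedged timing sketch, with the sharding mask omitted and shapes chosen arbitrarily (results will vary by hardware and PyTorch version):

```python
import time
import torch
import torch.nn.functional as F

def backward_onehot(grad_output, target, softmax):
    # Same formulation as distributed_logprob_backward (sharding mask omitted).
    is_chosen = F.one_hot(target, num_classes=softmax.size(-1))
    grad_input = is_chosen.float().sub_(softmax)
    return grad_input.mul_(grad_output.unsqueeze(-1))

backward_compiled = torch.compile(backward_onehot, fullgraph=True)

B, S, V = 2, 128, 4096
softmax = torch.softmax(torch.randn(B, S, V), dim=-1)
target = torch.randint(0, V, (B, S))
grad_output = torch.randn(B, S)

backward_compiled(grad_output, target, softmax)          # warm-up (compilation)
for name, fn in [("eager", backward_onehot), ("compiled", backward_compiled)]:
    t0 = time.perf_counter()
    for _ in range(20):
        fn(grad_output, target, softmax)
    print(name, f"{time.perf_counter() - t0:.4f}s")
```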
```diff
@@ -162,11 +169,6 @@ def forward(  # pyrefly: ignore[bad-override] Always ignore torch.autograd.Func
         tp_group: torch.distributed.ProcessGroup,
         inference_only: bool = False,
     ) -> torch.Tensor:
-        # Create a mask of valid vocab ids (1 means it needs to be masked).
-        target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
-        masked_target = target - vocab_start_index
-        masked_target[target_mask] = 0
-
         seq_size = int(vocab_parallel_logits.shape[1])
         num_chunks = (seq_size + chunk_size - 1) // chunk_size
         all_log_probs = []
```
```diff
@@ -175,6 +177,22 @@ def forward(  # pyrefly: ignore[bad-override] Always ignore torch.autograd.Func
             chunk_start = chunk_idx * chunk_size
             chunk_end = min(seq_size, (chunk_idx + 1) * chunk_size)

+            vocab_parallel_logits_chunk = vocab_parallel_logits[:, chunk_start:chunk_end, :]
+            target_chunk = target[:, chunk_start:chunk_end]
+
+            # Create a mask of valid vocab ids (1 means it needs to be masked).
+            target_mask_chunk = (target_chunk < vocab_start_index) | (target_chunk >= vocab_end_index)
+            masked_target_chunk = target_chunk - vocab_start_index
+            masked_target_chunk[target_mask_chunk] = 0
+
+            distributed_logprob_forward(
+                vocab_parallel_logits_chunk,
+                target_chunk,
+                vocab_start_index,
+                vocab_end_index,
+                tp_group,
+            )

             logits = vocab_parallel_logits[:, chunk_start:chunk_end, :]
             logits = logits.to(dtype=torch.float32)
```

**Contributor** commented on the `distributed_logprob_forward(` call:
> why is this not returning anything? and is the consequent line 196-217 necessary or should be removed?
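The loop is plain ceiling-division chunking over the sequence dimension, so the last chunk may be shorter than `chunk_size`:

```python
seq_size, chunk_size = 1000, 384
num_chunks = (seq_size + chunk_size - 1) // chunk_size   # ceil(1000 / 384) = 3

for chunk_idx in range(num_chunks):
    chunk_start = chunk_idx * chunk_size
    chunk_end = min(seq_size, (chunk_idx + 1) * chunk_size)
    print(chunk_idx, chunk_start, chunk_end)   # (0, 0, 384) (1, 384, 768) (2, 768, 1000)
```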
```diff
@@ -229,7 +247,7 @@ def backward(
             logits = vocab_parallel_logits[:, chunk_start:chunk_end, :]
             logits = logits.to(dtype=torch.float32)

-            softmax_output = _compute_distributed_log_softmax(
+            softmax_output = _distributed_logprob_forward(
                 logits,
                 group=tp_group,
             )
```

**Contributor** commented on `_distributed_logprob_forward(`:
> this name is not found, is it

**Contributor** commented:
> is it necessary to remove the `@torch.no_grad()`?