draft: feat: fused loss and logit to logprob conversion #994
Conversation
Signed-off-by: Jimmy Zhang <jiemingz@nvidia.com>

Thanks! What are the expected gains? Also, is this related to #496?

Signed-off-by: Jimmy Zhang <jiemingz@nvidia.com>

Unrelated to #496, but we can expect the memory spikes seen at the loss functions to go away.
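A quick way to check that claim is to compare peak CUDA memory around the loss call before and after this change. A minimal sketch (requires a CUDA device; `compute_loss` is a stand-in for the actual loss call, not a function from this PR):

```python
import torch

torch.cuda.reset_peak_memory_stats()

# loss = compute_loss(logits, targets)  # stand-in: run the loss fwd+bwd here
# loss.backward()

peak_gib = torch.cuda.max_memory_allocated() / 2**30
print(f"peak CUDA memory allocated: {peak_gib:.2f} GiB")
```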
@guyueh1 can you please review this?
guyueh1
left a comment
Overall I like the idea of doing a torch.compile full graph on the logprob function, but since the softmax output is still kept, the memory will still be huge; we may need another path to completely remove the overhead. Is this just saving the memory for the mask tensor and the log_softmax tensor? Can you provide data on how much it is saving?
Also, please resolve the comments.
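For reference, a minimal sketch of the idea under review, assuming the fused path compiles the whole logits-to-logprob conversion as one graph so the compiler can fuse the log_softmax and gather; the names here are illustrative, not the PR's actual API:

```python
import torch

# Sketch only: compiling the conversion as one graph lets fewer fp32
# intermediates stay alive at once, though (per the review comment) any
# log_softmax output kept for backward still dominates memory.
@torch.compile(fullgraph=True)
def logits_to_logprobs(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    # log_softmax in fp32 for numerical stability
    logprobs = torch.log_softmax(logits.float(), dim=-1)
    # pick the logprob of each target token: [batch, seq]
    return logprobs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)

# usage (shapes: logits [batch, seq, vocab], targets [batch, seq]):
# token_logprobs = logits_to_logprobs(logits, targets)
```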
from torch.distributed.tensor import DTensor, distribute_tensor

@torch.no_grad()
Is it necessary to remove the @torch.no_grad()?
masked_target_chunk = target - vocab_start_index
masked_target_chunk[target_mask_chunk] = 0
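For context, these lines follow the standard vocab-parallel masking pattern (as in Megatron-style parallel cross-entropy): each tensor-parallel rank holds only a slice of the vocabulary, so out-of-shard targets are clamped to a dummy index before the gather and zeroed afterwards. A minimal sketch under that assumption (helper name and the final all-reduce are illustrative, not the PR's exact code):

```python
import torch

def local_target_logits(
    logits_chunk: torch.Tensor,  # [tokens, local_vocab], this rank's shard
    target: torch.Tensor,        # [tokens], global vocab ids
    vocab_start_index: int,
    vocab_end_index: int,
) -> torch.Tensor:
    # True where the target id falls outside this rank's vocab shard
    target_mask_chunk = (target < vocab_start_index) | (target >= vocab_end_index)
    # shift to local ids; clamp out-of-shard targets to a safe dummy index 0
    masked_target_chunk = target - vocab_start_index
    masked_target_chunk[target_mask_chunk] = 0
    # gather the logit at each (possibly dummy) local target id
    picked = logits_chunk.gather(-1, masked_target_chunk.unsqueeze(-1)).squeeze(-1)
    # zero out-of-shard contributions; summing across tensor-parallel ranks
    # (e.g. torch.distributed.all_reduce) then recovers the true target logits
    picked[target_mask_chunk] = 0.0
    return picked
```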
distributed_logprob_forward(
Why is this not returning anything? And are the subsequent lines 196-217 necessary, or should they be removed?
logits = logits.to(dtype=torch.float32)

- softmax_output = _compute_distributed_log_softmax(
+ softmax_output = _distributed_logprob_forward(
This name is not found; should it be distributed_logprob_forward?
What does this PR do?
Add a one line overview of what this PR aims to accomplish.
Issues
List issues that this PR closes (syntax):
Usage
# Add a code snippet demonstrating how to use this
Before your PR is "Ready for review"
Pre checks:
Additional Information