draft: feat: fused loss and logit to logprob conversion #994

Open
jiemingz wants to merge 2 commits into main from jiemingz/loss_funcs

Conversation

@jiemingz
Contributor

What does this PR do ?

Add a one line overview of what this PR aims to accomplish.

Issues

List issues that this PR closes (syntax):

Usage

  • You can potentially add a usage example below
# Add a code snippet demonstrating how to use this

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you run the unit tests and functional tests locally? Visit our Testing Guide for how to run tests
  • Did you add or update any necessary documentation? Visit our Document Development Guide for how to write, build and test the docs.

Additional Information

  • ...

Signed-off-by: Jimmy Zhang <jiemingz@nvidia.com>
@terrykong
Collaborator

Thanks! What gains can we expect from this? Also, is this related to #496?

Signed-off-by: Jimmy Zhang <jiemingz@nvidia.com>
@jiemingz
Contributor Author

Unrelated to #496, but we can expect the memory spikes seen at the loss functions to go away.
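For context, a minimal sketch of the kind of fused logit-to-logprob conversion being described, in the single-GPU (non-vocab-parallel) case; the function and tensor names here are illustrative, not this PR's actual API:

import torch

# Illustrative only: compute per-token log-probs without keeping a separate
# full [batch, seq, vocab] log_softmax buffer around.
# torch.compile(fullgraph=True) lets the normalization and gather fuse into one graph.
@torch.compile(fullgraph=True)
def fused_token_logprobs(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    logits = logits.to(torch.float32)                  # [batch, seq, vocab]
    logz = torch.logsumexp(logits, dim=-1)             # [batch, seq]
    target_logits = torch.gather(
        logits, dim=-1, index=targets.unsqueeze(-1)    # targets: [batch, seq] int64
    ).squeeze(-1)                                      # [batch, seq]
    return target_logits - logz                        # per-token log-probs

In the vocab-parallel path touched by this diff (the vocab_start_index masking and distributed_logprob_forward below), the logsumexp and gather would additionally need reductions across the vocab shards.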

@euronymous-aithal
Contributor

@guyueh1 can you please review this?

Contributor

@guyueh1 left a comment


Overall I like the idea of running a full-graph torch compile on the logprob function, but since the softmax output is still kept, memory use will still be large; we may need another path to completely remove the overhead. Is this only saving the memory for the mask tensor and the log_softmax tensor? Can you provide data on how much it saves?

Also, please resolve the comments.
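For a rough sense of scale, a back-of-envelope estimate of the log_softmax buffer being discussed (assumed shapes, not measurements from this PR):

# Illustrative estimate only; batch/seq/vocab sizes are assumptions.
batch, seq, vocab = 1, 4096, 131072
bytes_per_fp32 = 4
log_softmax_bytes = batch * seq * vocab * bytes_per_fp32
print(f"full fp32 log_softmax tensor: {log_softmax_bytes / 2**30:.1f} GiB")  # ~2.0 GiB per sequence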

from torch.distributed.tensor import DTensor, distribute_tensor


@torch.no_grad()

Is it necessary to remove the @torch.no_grad()?

masked_target_chunk = target - vocab_start_index
masked_target_chunk[target_mask_chunk] = 0

distributed_logprob_forward(

Why is this not returning anything? And are the subsequent lines 196-217 necessary, or should they be removed?

logits = logits.to(dtype=torch.float32)

softmax_output = _compute_distributed_log_softmax(
softmax_output = _distributed_logprob_forward(

This name is not found; should it be distributed_logprob_forward?

