
[PyTorch] Support dtype casting in fused adam #977

Merged
timmoon10 merged 12 commits into NVIDIA:main from Wong4j:jaywan/add_fused_adam on Aug 16, 2024

Conversation

@Wong4j (Contributor) commented Jul 1, 2024

Description

FusedAdam currently updates the params in-place.
This PR adds dtype casting to the FusedAdam kernel: in addition to updating the master params in-place, it can also update extra model params, which may be of bf16, fp16, or fp8 type.

Update:
I have validated the convergence using GPT training in Megatron-LM. The losses before and after enabling this feature are bit-wise identical.
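
For illustration, a rough usage sketch in plain PyTorch (the import path and the master_weights keyword follow the API discussed later in this thread and should be treated as assumptions, not the exact merged signature):

import torch
from transformer_engine.pytorch.optimizers import FusedAdam  # assumed import path

# Model params kept in bf16, with separately allocated fp32 master copies.
model_params = [torch.nn.Parameter(torch.randn(1024, 1024, device="cuda", dtype=torch.bfloat16))]
master_weights = [p.detach().float() for p in model_params]

# Hypothetical call: the fused kernel updates the fp32 master weights and, in
# the same pass, casts the results back into the bf16 (or fp16/fp8) model params.
optim = FusedAdam(model_params, lr=1e-3, master_weights=master_weights)

model_params[0].sum().backward()
optim.step()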

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactor

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@Wong4j changed the title from "Support dtype casting in fused adam" to "[PyTorch] Support dtype casting in fused adam" on Jul 1, 2024
@Wong4j force-pushed the jaywan/add_fused_adam branch from aa11601 to 4277bd1 on July 1, 2024 08:47
@Wong4j changed the title from "[PyTorch] Support dtype casting in fused adam" to "[WIP] [PyTorch] Support dtype casting in fused adam" on Jul 1, 2024
@Wong4j force-pushed the jaywan/add_fused_adam branch from fd68cdd to f65a320 on July 12, 2024 03:01
@Wong4j changed the title from "[WIP] [PyTorch] Support dtype casting in fused adam" to "[PyTorch] Support dtype casting in fused adam" on Jul 12, 2024
@Wong4j force-pushed the jaywan/add_fused_adam branch from f6f7c49 to b4c90a8 on July 12, 2024 15:29
@Wong4j (Contributor, Author) commented Jul 12, 2024

@timmoon10 Could you please take a look?
The corresponding changes to Megatron-LM are in our internal GitLab MR#1736.

@timmoon10 (Collaborator):

/te-ci pytorch

@timmoon10 self-requested a review on July 12, 2024 20:26
@Wong4j (Contributor, Author) commented Jul 15, 2024

Hi @timmoon10, I encountered an issue when trying to update scale_inv inside the Adam kernel using *scale_inv_ptr = 1.0f / scale: the loss was no longer bit-wise aligned. The reason is that TE/PyTorch compilation uses --use_fast_math, which compiles the reciprocal into a single MUFU.RCP instruction, producing an approximate rather than an exact result.
To keep the loss bit-wise aligned, I had to update scale_inv outside the Adam kernel, which also leads to suboptimal performance. Do you have any suggestions to address this?
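
For context, a minimal sketch of the workaround described above, kept in plain PyTorch (the function name and the idea of tracking scale/scale_inv as fp32 tensors are illustrative assumptions, not TE internals):

import torch

def refresh_scale_inv(scales: list, scale_invs: list) -> None:
    """Recompute scale_inv for FP8 params outside the fused Adam kernel.

    PyTorch ops are not compiled with --use_fast_math, so this division does
    not go through the MUFU.RCP approximation, keeping the loss bit-wise
    reproducible (at the cost of an extra pass over the scales).
    """
    for scale, scale_inv in zip(scales, scale_invs):
        scale_inv.copy_(1.0 / scale)

# Hypothetical usage after each optimizer step:
# optim.step()
# refresh_scale_inv(fp8_scales, fp8_scale_invs)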

@zlsh80826 (Collaborator):

/te-ci pytorch

Collaborator review comment:


I notice now that this file uses unittest, while the CI infrastructure uses pytest:

pytest -v -s $TE_PATH/tests/pytorch/test_fused_optimizer.py

It may be better to fix that in a separate PR.

@Wong4j force-pushed the jaywan/add_fused_adam branch from 4c2d42e to 47a448b on July 16, 2024 15:33
@timmoon10 self-requested a review on July 16, 2024 18:27
@Wong4j force-pushed the jaywan/add_fused_adam branch 2 times, most recently from c47909c to 5ae1573 on July 19, 2024 06:35
@timmoon10 (Collaborator):

/te-ci pytorch

@timmoon10 (Collaborator) commented Jul 24, 2024

Based on a discussion with @ptrendx, I think we should give more thought to the API. While this is primarily targeting Megatron-LM, it's important that other TE users can use it easily without relying on Mcore infrastructure.

@ptrendx's preferred API is for the optimizer to hold the model weights (including Float8Tensors) and to treat the master weights as optimizer state (similar to exp_avg and exp_avg_sq). This is similar to Option 1 in #977 (comment). The workflow should look like:

model = MyModel()  # Mix of fp32, bf16, fp8 params
optim = FusedAdam(model.parameters(), dtype=torch.float32)  # Create FP32 master weights for each non-fp32 param
optim.step()
# optim.state[bf16_param]["exp_avg"] is fp32 tensor
# optim.state[bf16_param]["exp_avg_sq"] is fp32 tensor
# optim.state[bf16_param]["master_param"] is fp32 tensor
# optim.state[fp32_param]["master_param"] is None

This API is more natural for standard PyTorch workflows and it doesn't require maintaining separate model weights/master weights like in Megatron-LM. That said, I can see value in keeping master_weights as an optional kwarg since Megatron-LM already allocates them:

model = MyModel()  # Mix of fp32, bf16, fp8 params
master_weights = [param.float() for param in model.parameters()]
optim = FusedAdam(model.parameters(), dtype=torch.float32, master_weights=master_weights)
# optim.state[param]["master_param"] is from my_master_weights

@Wong4j force-pushed the jaywan/add_fused_adam branch 2 times, most recently from 541e6e7 to ae375cd on August 6, 2024 14:21
@Wong4j (Contributor, Author) commented Aug 6, 2024

Hi @timmoon10, I have made modifications to the FusedAdam API based on your suggestions. I have already tested my changes in Megatron-LM, and the training loss matches the previous results exactly.
However, there are still some issues that need to be discussed:

  1. I have required that master_weights be provided by the user as a list of tensors. If the user does not provide master_weights (i.e., master_weights=None), only the model weights are updated. Is this approach reasonable?

  2. In Megatron-LM, master_weights are created in the __init__ method of the distributed optimizer (dist opt), while FusedAdam is created earlier. Therefore, I had to initially set master_weights to None and then assign optimizer.master_weights inside the __init__ method of dist opt, as in the following code:

# create optimizer
optimizer = FusedAdam(param_groups, ... , master_weights=None)
optimizer = DistributedOptimizer(optimizer, *other_args)

# inside __init__ of dist opt
master_weights = list(itertools.chain(*self.shard_fp32_from_float16_groups))
self.optimizer.master_weights = master_weights  # self.optimizer is FusedAdam

This usage is somewhat awkward, but not entirely unusual. Any suggestions?

  3. Kunlun is currently implementing MX-FP16. After some discussion, we believe it is more reasonable to create master_weights inside FusedAdam: exp_avg, exp_avg_sq, and master_weight are all optimizer states, and since exp_avg and exp_avg_sq are created and updated within FusedAdam, master_weight should be handled the same way (see the sketch after this list). However, this change would conflict with Megatron-LM's design.
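
As a point of reference for option 3, here is a minimal, unfused sketch of what holding master weights as optimizer state could look like, mirroring how exp_avg/exp_avg_sq are lazily created (the class name, state keys, and the pure-PyTorch update are illustrative assumptions, not the merged FusedAdam implementation):

import torch

class AdamWithMasterWeights(torch.optim.Optimizer):
    """Sketch: master weights held as optimizer state, like exp_avg/exp_avg_sq."""

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8):
        super().__init__(params, dict(lr=lr, betas=betas, eps=eps))

    @torch.no_grad()
    def step(self):
        for group in self.param_groups:
            beta1, beta2 = group["betas"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                state = self.state[p]
                if len(state) == 0:
                    state["step"] = 0
                    state["exp_avg"] = torch.zeros_like(p, dtype=torch.float32)
                    state["exp_avg_sq"] = torch.zeros_like(p, dtype=torch.float32)
                    # The proposal: create the fp32 master copy here too,
                    # exactly where the other optimizer states are created.
                    state["master_param"] = (
                        p.detach().clone().float() if p.dtype != torch.float32 else None
                    )
                state["step"] += 1
                grad = p.grad.float()
                state["exp_avg"].mul_(beta1).add_(grad, alpha=1 - beta1)
                state["exp_avg_sq"].mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                bias1 = 1 - beta1 ** state["step"]
                bias2 = 1 - beta2 ** state["step"]
                denom = (state["exp_avg_sq"] / bias2).sqrt().add_(group["eps"])
                update = (state["exp_avg"] / bias1) / denom
                master = state["master_param"] if state["master_param"] is not None else p
                master.add_(update, alpha=-group["lr"])
                if state["master_param"] is not None:
                    # Cast the fp32 master weights back into the low-precision
                    # model param (the dtype cast the fused kernel does in one pass).
                    p.copy_(master.to(p.dtype))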

@Wong4j force-pushed the jaywan/add_fused_adam branch 5 times, most recently from ccd54cd to 267c90a on August 13, 2024 02:10
@Wong4j (Contributor, Author) commented Aug 13, 2024

@timmoon10 Could you please take a look?

Wong4j added 9 commits on August 16, 2024 15:30 (each Signed-off-by: Shijie Wang <jaywan@nvidia.com>)
@Wong4j force-pushed the jaywan/add_fused_adam branch from 267c90a to 44dca61 on August 16, 2024 07:30
@timmoon10 (Collaborator) left a comment:


LGTM. Thanks for implementing all the API changes; this is much cleaner and easier to reason about. I think there are still some things that could be improved (options to construct master weights internally, cleaning up how master weights are specified, mixed FP16/BF16 support, fixing the tests), but those are internal changes that can be worked on later.

@timmoon10 (Collaborator):

/te-ci pytorch

Signed-off-by: Tim Moon <tmoon@nvidia.com> (2 commits)
@timmoon10 (Collaborator):

/te-ci pytorch
