[PyTorch] Support dtype casting in fused adam (#977)
Conversation
@timmoon10 Could you please take a look?
/te-ci pytorch
Hi @timmoon10, I encountered an issue when trying to update scale_inv inside the Adam kernel using
/te-ci pytorch
I notice now that this file uses unittest, while the CI infrastructure uses pytest. It may be better to fix that in a separate PR.
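For reference, a minimal sketch of the same check written in both styles; the test names and the property being checked are hypothetical, purely to illustrate the difference:

```python
# Illustrative only: hypothetical test in unittest style vs. pytest style.
import unittest
import torch

class TestMasterWeights(unittest.TestCase):  # unittest style (what the file uses today)
    def test_master_dtype(self):
        param = torch.zeros(4, dtype=torch.bfloat16)
        self.assertEqual(param.float().dtype, torch.float32)

def test_master_dtype_pytest():  # pytest style (what the CI runs natively)
    param = torch.zeros(4, dtype=torch.bfloat16)
    assert param.float().dtype == torch.float32
```

Note that pytest can also collect unittest.TestCase subclasses, so the existing tests should still run under the CI; the cleanup is mainly about consistency.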
/te-ci pytorch
Based on a discussion with @ptrendx, I think we should give more thought to the API. While this is primarily targeting Megatron-LM, it's important that other TE users can use it easily without relying on Mcore infrastructure. @ptrendx's preferred API is for the optimizer to hold the model weights and create the FP32 master weights internally:

```python
model = MyModel()  # Mix of fp32, bf16, fp8 params
optim = FusedAdam(model.parameters(), dtype=torch.float32)  # Create FP32 master weights for each non-fp32 param
optim.step()
# optim.state[bf16_param]["exp_avg"] is fp32 tensor
# optim.state[bf16_param]["exp_avg_sq"] is fp32 tensor
# optim.state[bf16_param]["master_param"] is fp32 tensor
# optim.state[fp32_param]["master_param"] is None
```

This API is more natural for standard PyTorch workflows and it doesn't require maintaining separate model weights/master weights like in Megatron-LM. That said, I can see value in keeping the option to pass in externally managed master weights:

```python
model = MyModel()  # Mix of fp32, bf16, fp8 params
master_weights = [param.float() for param in model.parameters()]
optim = FusedAdam(model.parameters(), dtype=torch.float32, master_weights=master_weights)
# optim.state[param]["master_param"] is from master_weights
```
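For illustration, here is a minimal sketch of how an optimizer could build the per-param "master_param" state internally under the first API. Only the `dtype` kwarg and the `"master_param"` state key come from the proposal above; the class name and everything else are assumptions, not the actual FusedAdam implementation:

```python
import torch

class MasterWeightAdamSketch(torch.optim.Optimizer):
    """Illustrative sketch: keeps an FP32 master copy for every non-FP32 param."""

    def __init__(self, params, lr=1e-3, dtype=torch.float32):
        super().__init__(params, dict(lr=lr))
        self.dtype = dtype
        for group in self.param_groups:
            for p in group["params"]:
                # Non-FP32 params get an FP32 master copy; FP32 params do not.
                self.state[p]["master_param"] = (
                    p.detach().clone().to(dtype) if p.dtype != dtype else None
                )
```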
Hi @timmoon10, I have made modifications to the FusedAdam API based on your suggestions. I already tested my changes in Megatron-LM, and the training loss matches the previous results exactly. In Megatron-LM the usage looks like this:

```python
# create optimizer
optimizer = FusedAdam(param_groups, ..., master_weights=None)
optimizer = DistributedOptimizer(optimizer, *other_args)

# inside __init__ of dist opt
master_weights = list(itertools.chain(*self.shard_fp32_from_float16_groups))
self.optimizer.master_weights = master_weights  # self.optimizer is FusedAdam
```

This usage is somewhat uncomfortable, but not entirely unusual. Any suggestions?
@timmoon10 Could you please take a look?
Signed-off-by: Shijie Wang <jaywan@nvidia.com>
LGTM. Thanks for implementing all the API changes; this is much cleaner and easier to reason about. I think there are still some things that could be improved (options to construct master weights internally, cleaning up how to specify master weights, mixed FP16/BF16, fixing the tests), but those are internal changes that can be worked on later.
/te-ci pytorch
Signed-off-by: Tim Moon <tmoon@nvidia.com>
/te-ci pytorch
Description
Currently, FusedAdam updates the params in place.
This PR adds dtype casting to the FusedAdam kernel: in addition to updating the master params in place, the kernel can now also update extra model params, which may be of BF16, FP16, or FP8 type.
Update:
I have validated convergence with GPT training in Megatron-LM. The losses before and after enabling this feature are bitwise identical.
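As a usage sketch based on the snippets in this thread (the import path and exact keyword arguments may differ from the merged code):

```python
import torch
from transformer_engine.pytorch.optimizers import FusedAdam  # assumed import path; may vary by TE version

model = torch.nn.Linear(16, 16).bfloat16().cuda()                   # BF16 model params
master_weights = [p.detach().float() for p in model.parameters()]   # FP32 master copies

optim = FusedAdam(model.parameters(), lr=1e-3, master_weights=master_weights)

out = model(torch.randn(4, 16, dtype=torch.bfloat16, device="cuda"))
out.sum().backward()
optim.step()  # updates the FP32 masters and casts the results back into the BF16 model params
```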