Support using fp16 master weights and fp16/fp8 optimizer states in FusedAdam#1078
timmoon10 merged 3 commits into NVIDIA:main
Conversation
Signed-off-by: kunlunl <kunlunl@nvidia.com>
@timmoon10 Hello, I noticed no one has commented on this MR for a long time. Could you please take a look, or help find someone to review it?
Overall this looks good. It would be more general if we disentangled the state dtypes and state scaling (e.g. why not have scaled FP32 states or unscaled BF16 states?), but this does cover the specific cases in the MS-AMP paper.
For future reference, this PR adapts logic from NVIDIA/apex#1771. It is a proof-of-concept with several opportunities for future improvement:
- TE kernel for computing absmax and scale
- Fusing scale/unscale within Adam kernel
- Reduce memory usage in optimizer step, perhaps by processing params in chunks
- Reduce memory usage in checkpointing, perhaps by storing checkpoint buffers in CPU
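
To make the scale/unscale discussion above concrete, here is a minimal, unfused sketch of storing an optimizer state (e.g. `exp_avg`) in FP8 with a per-tensor scale derived from its absmax. It assumes `torch.float8_e4m3fn` is available (PyTorch >= 2.1); the helper names are made up for illustration, and the actual PR implements these steps in CUDA kernels (the bullets above describe fusing them further).

```python
import torch

FP8_E4M3_MAX = 448.0  # max representable magnitude of torch.float8_e4m3fn

def scale_to_fp8(state_fp32: torch.Tensor):
    """Store an FP32 optimizer state as FP8 plus a per-tensor scale (illustrative helper)."""
    absmax = state_fp32.abs().max().clamp(min=1e-12)
    scale = FP8_E4M3_MAX / absmax                  # map absmax onto the FP8 max magnitude
    state_fp8 = (state_fp32 * scale).to(torch.float8_e4m3fn)
    return state_fp8, scale

def unscale_from_fp8(state_fp8: torch.Tensor, scale: torch.Tensor):
    """Recover an FP32 view of the state before the Adam update (illustrative helper)."""
    return state_fp8.to(torch.float32) / scale

# Round trip for one state tensor:
exp_avg = torch.randn(1024, dtype=torch.float32)
fp8_state, scale = scale_to_fp8(exp_avg)
exp_avg_restored = unscale_from_fp8(fp8_state, scale)
```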
Yes, I'm aware of this issue. I talked with @Wong4j offline and invited him to review this PR.
His MR in MCore (fusing dtype casting) has not been merged yet, so I put the dtype-cast fusion into a new MCore MR, together with this precision-aware optimizer.
/te-ci pytorch
/te-ci pytorch
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Kunlun Li <94586211+kunlunl@users.noreply.github.com>
/te-ci pytorch
LGTM.
Hello @kunlunl @timmoon10. Memory:
Description
Add options to set the dtypes of the master weights, exp_avg, and exp_avg_sq in FusedAdam.
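
A hedged usage sketch of these options follows. The keyword names (`master_weights`, `master_weight_dtype`, `exp_avg_dtype`, `exp_avg_sq_dtype`) are inferred from this description and apex#1771, not quoted from the merged code, so check them against the actual FusedAdam signature before relying on them.

```python
import torch
from transformer_engine.pytorch.optimizers import FusedAdam

# Argument names below are illustrative; verify against the merged FusedAdam signature.
model = torch.nn.Linear(1024, 1024).cuda().half()
optimizer = FusedAdam(
    model.parameters(),
    lr=1e-3,
    master_weights=True,                # keep separate master weights
    master_weight_dtype=torch.float16,  # FP16 master weights instead of FP32
    exp_avg_dtype=torch.float16,        # FP16 first moment (the PR also targets FP8)
    exp_avg_sq_dtype=torch.float16,     # FP16 second moment
)

for _ in range(10):
    loss = model(torch.randn(8, 1024, device="cuda", dtype=torch.float16)).sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```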
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: