
Distributed optimizer infrastructure for FP8 parameters #1723

Merged: crcrpar merged 2 commits into NVIDIA:master from timmoon10:distopt-fp8 on Sep 29, 2023
Conversation

@timmoon10
Contributor

This PR refactors the distributed optimizer to enable support for FP8 parameters in NeMo. It adds the option to perform parameter all-gathers in integer dtypes and introduces two member functions, _check_params_shard_dtypes and _param_copy_fragments, which handle casting into and out of the all-gather buffer. For now, these functions either do a direct cast for floating-point dtypes or copy the most significant bytes for other dtypes. I plan to override them in the NeMo-derived class so that it casts parameters to FP8, performs the all-gather in UINT8, and unpacks the result into a custom FP8 tensor class.
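For illustration, a minimal sketch of the FP8-via-UINT8 pattern described above (not the Apex or NeMo implementation): the local shard is cast to FP8, its storage is reinterpreted as UINT8 for the collective, and the gathered buffer is viewed back as FP8. It assumes a PyTorch build that provides torch.float8_e4m3fn and an already-initialized process group; the helper name `allgather_fp8_shard` is hypothetical.

```python
# Sketch only: FP8 parameter all-gather routed through a UINT8 buffer.
# Assumes torch.distributed is initialized and torch.float8_e4m3fn exists.
import torch
import torch.distributed as dist

def allgather_fp8_shard(local_shard: torch.Tensor) -> torch.Tensor:
    """Gather one shard per rank and return the full parameter as FP8."""
    world_size = dist.get_world_size()
    # Cast the local shard to FP8, then reinterpret its bytes as UINT8 so the
    # collective only ever sees an integer dtype.
    local_fp8 = local_shard.to(torch.float8_e4m3fn)
    local_u8 = local_fp8.view(torch.uint8)
    # All-gather the raw bytes into one contiguous UINT8 buffer.
    gathered_u8 = torch.empty(
        world_size * local_u8.numel(), dtype=torch.uint8, device=local_u8.device
    )
    dist.all_gather_into_tensor(gathered_u8, local_u8)
    # Reinterpret the gathered bytes as FP8; a custom FP8 tensor class could
    # wrap this view instead of returning a plain tensor.
    return gathered_u8.view(torch.float8_e4m3fn)
```

Because the cast happens on each rank before the collective, the all-gather moves one byte per element instead of four.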

This PR depends on #1719 and #1721.

Co-authored-by: Masaki Kozuki <mkozuki@nvidia.com>
crcrpar merged commit 2386a91 into NVIDIA:master on Sep 29, 2023
minitu pushed a commit to minitu/apex that referenced this pull request on Sep 29, 2023
* Add distopt support for param syncs with non-floating-point dtypes

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Update apex/contrib/optimizers/distributed_fused_adam.py

Co-authored-by: Masaki Kozuki <mkozuki@nvidia.com>

---------

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Masaki Kozuki <mkozuki@nvidia.com>
crcrpar added a commit that referenced this pull request on Sep 30, 2023
* Add update_scale_hysteresis

* Fix compile errors

* Massively reduce LayerNorm/RMSNorm GPU memory usage in modern networks by tricking torch autograd (#1715)

* input grad checks out

* adding clamp gamma

* Both old and proposed implementation checks out

* 2 tests not yet passed due to numerical issues

* mem_eff works

* fast-layer-norm done

* Moving mem-eff to templates

* Relax tolerance for memory efficient backward

* Fix backward api of python

* Distributed optimizer infrastructure for FP8 parameters (#1723)

* Add distopt support for param syncs with non-floating-point dtypes

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Update apex/contrib/optimizers/distributed_fused_adam.py

Co-authored-by: Masaki Kozuki <mkozuki@nvidia.com>

---------

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Masaki Kozuki <mkozuki@nvidia.com>

* Add unit test

* Fix comment in unit test

* Remove unnecessary bits

---------

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Jaemin Choi <jaeminc@nvidia.com>
Co-authored-by: Rui Wang <rui@helixon.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Masaki Kozuki <mkozuki@nvidia.com>