[PyTorch] Experimental FP8 tensor class#452

Merged
ksivaman merged 29 commits into NVIDIA:main from timmoon10:float8tensor_experiments
Oct 31, 2023

Conversation

@timmoon10 (Collaborator)

This FP8 tensor class is based on the implementation at https://github.com/facebookexperimental/protoquant/tree/fp8_poc and is primarily oriented toward enabling efficient FP8 support in Apex's DistributedFusedAdam. See NVIDIA-NeMo/NeMo#7469 and NVIDIA-NeMo/NeMo#7565.

CC @sudhakarsingh27 @ksivaman
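
As a rough illustration of the core layout (a minimal sketch, not the actual Float8Tensor API; it assumes PyTorch >= 2.1 for torch.float8_e4m3fn, and all names below are hypothetical): the raw FP8 bits live in a plain uint8 buffer alongside a cached reciprocal scale, and dequantization reinterprets the bits and reapplies scale_inv.

```python
import torch

class SimpleFP8Tensor:
    """Toy container: uint8 payload holding FP8 bits, plus a cached scale_inv."""

    def __init__(self, data_uint8: torch.Tensor, scale_inv: torch.Tensor):
        self._data = data_uint8      # raw FP8 bits stored as uint8
        self._scale_inv = scale_inv  # cached 1/scale, kept as a 1D tensor

    @classmethod
    def from_float(cls, tensor: torch.Tensor, scale: float) -> "SimpleFP8Tensor":
        # Scale, cast to FP8 (E4M3), then reinterpret the bytes as uint8.
        fp8 = (tensor * scale).to(torch.float8_e4m3fn)
        return cls(fp8.view(torch.uint8), torch.tensor([1.0 / scale]))

    def dequantize(self) -> torch.Tensor:
        # Reinterpret the uint8 payload as FP8 and undo the scaling.
        return self._data.view(torch.float8_e4m3fn).float() * self._scale_inv

x = torch.randn(4, 4)
x_fp8 = SimpleFP8Tensor.from_float(x, scale=16.0)
print((x - x_fp8.dequantize()).abs().max())  # small quantization error
```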

@timmoon10 timmoon10 added the enhancement New feature or request label Sep 29, 2023
@ptrendx (Member) commented Sep 30, 2023

/te-ci

@timmoon10 (Collaborator, Author)

/te-ci

@ptrendx ptrendx added the 1.0.0 label Oct 16, 2023
@sudhakarsingh27 (Collaborator)

/te-ci pytorch

@ksivaman ksivaman force-pushed the float8tensor_experiments branch from de20156 to 4315115 on October 16, 2023 21:08
Co-authored-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Sudhakar Singh <sudhakars@nvidia.com>
Co-authored-by: Przemyslaw Tredak <ptrendx@gmail.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
@ksivaman ksivaman force-pushed the float8tensor_experiments branch from 67f7cd3 to b6bfddb on October 19, 2023 23:45
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
@ksivaman ksivaman marked this pull request as ready for review October 20, 2023 04:47
@ksivaman (Member)

/te-ci pytorch

@ksivaman (Member)

/te-ci pytorch

@ksivaman (Member)

/te-ci pytorch

handled outside this class. If a tensor is initialized with an FP8
metadata object, it extracts the information it needs so it isn't
affected by later changes in the FP8 metadata (although its design
does cause us to leak some subtle side-effects into FP8 metadata).

Member:

This doc is not really correct since we are holding a view to the meta, right?

@timmoon10 (Collaborator, Author):

Ops using the tensor class's __torch_dispatch__ are insensitive to external changes in the meta since we cache scale_inv. However, all bets are off when we extract _data and pass it to external ops like tex.fp8_gemm.
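
A hedged sketch of that behavior (FP8Meta and CachedFP8 below are illustrative stand-ins, not the TE classes): scale_inv is copied out of the metadata at construction, so later changes to the meta do not affect ops that go through the cached value, while code that extracts _data and consults the live meta sees the new scale.

```python
import torch

class FP8Meta:
    """Stand-in for an FP8 metadata object holding a mutable scale."""
    def __init__(self, scale: float):
        self.scale = torch.tensor([scale])

class CachedFP8:
    def __init__(self, data: torch.Tensor, meta: FP8Meta):
        self._data = data
        # Cache the reciprocal scale *now*; __torch_dispatch__-style ops
        # only ever read this cached copy.
        self._scale_inv = meta.scale.reciprocal()

    def dequantize(self) -> torch.Tensor:
        return self._data.float() * self._scale_inv

meta = FP8Meta(scale=16.0)
t = CachedFP8(torch.full((2,), 8, dtype=torch.uint8), meta)
before = t.dequantize()
meta.scale.fill_(2.0)                        # external change to the meta...
assert torch.equal(before, t.dequantize())   # ...does not affect cached ops

# But handing t._data plus the *live* meta to an external kernel
# (e.g. tex.fp8_gemm) would see the new scale -- "all bets are off".
```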

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
@sudhakarsingh27 sudhakarsingh27 force-pushed the float8tensor_experiments branch from 8ff9e05 to dfcbcf1 on October 24, 2023 21:27
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
timmoon10 and others added 8 commits October 25, 2023 14:25
Handle case where transpose cache is updated externally.

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
timmoon10 and others added 2 commits October 26, 2023 17:34
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
ksivaman and others added 4 commits October 27, 2023 22:20
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Easier for multiple tensors to share, e.g. detached tensors.

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
@timmoon10 timmoon10 force-pushed the float8tensor_experiments branch from 718d284 to 94848da on October 31, 2023 00:34
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
@ptrendx (Member) left a comment:

Approving as experimental. We will iterate upon this in the next release.

@ksivaman ksivaman merged commit b1820c4 into NVIDIA:main Oct 31, 2023
ksivaman added a commit that referenced this pull request Oct 31, 2023
* Experimental FP8 tensor

Co-authored-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Sudhakar Singh <sudhakars@nvidia.com>
Co-authored-by: Przemyslaw Tredak <ptrendx@gmail.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add fp8 tensor to ci test

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* review comments and tests

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Minor changes

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Default to FP8 usage

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix docs

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Naming changes

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* minor fix

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix transpose caching

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Debug transpose caching

Handle case where transpose cache is updated externally.

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Rename FP8GlobalStateManager.with_fp8_parameters

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* remove Float8Tensor from import API

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Avoid caching FP8 transposes if not required

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix import error in FP8 tensor tests

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix transpose caching and checkpointing bug

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Improve caching and fix distopt case

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Update transformer_engine/pytorch/float8_tensor.py

Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

* Remove recursive logic

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix cache reset bug

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Store FP8 attributes in dict

Easier for multiple tensors to share, e.g. detached tensors.

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Make sure scale_inv is 1D tensor

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Make sure scale_inv is 1D tensor

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fixes and detach recipe

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Set default fp8 data type

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Sudhakar Singh <sudhakars@nvidia.com>
Co-authored-by: Przemyslaw Tredak <ptrendx@gmail.com>
Comment on lines +507 to +508
* full model training using optimizer with master weights, where the high
precision copies of weights are already present in the optimizer.
Contributor:

How does this look in practice? If the model will be initialized directly with fp8 weights, how does the optimizer get high-precision copies?
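
For illustration, a minimal sketch of one possible arrangement of the master-weights pattern the quoted doc refers to (the toy optimizer below is an assumption for illustration, not DistributedFusedAdam's API): the optimizer materializes FP32 copies from the parameters at construction, performs the update in FP32, and writes the result back into the low-precision parameters each step.

```python
import torch

class MasterWeightOptimizer:
    """Toy SGD keeping FP32 master copies of (possibly low-precision) params."""

    def __init__(self, params, lr: float = 1e-3):
        self.params = list(params)
        # High-precision copies are created once, at construction, so they
        # exist regardless of the precision the model weights are kept in.
        self.masters = [p.detach().float().clone() for p in self.params]
        self.lr = lr

    @torch.no_grad()
    def step(self):
        for p, m in zip(self.params, self.masters):
            # Update in FP32, then copy back into the low-precision param.
            m -= self.lr * p.grad.float()
            p.copy_(m.to(p.dtype))

model = torch.nn.Linear(8, 8)
opt = MasterWeightOptimizer(model.parameters())
model(torch.randn(1, 8)).sum().backward()
opt.step()
```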
