[PyTorch] Experimental FP8 tensor class #452
Conversation
/te-ci
/te-ci
/te-ci pytorch
Force-pushed from de20156 to 4315115
Co-authored-by: Tim Moon <tmoon@nvidia.com> Co-authored-by: Sudhakar Singh <sudhakars@nvidia.com> Co-authored-by: Przemyslaw Tredak <ptrendx@gmail.com> Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Force-pushed from 67f7cd3 to b6bfddb
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
/te-ci pytorch
/te-ci pytorch
/te-ci pytorch
| handled outside this class. If a tensor is initialized with an FP8
| metadata object, it extracts the information it needs so it isn't
| affected by later changes in the FP8 metadata (although its design
| does cause us to leak some subtle side-effects into FP8 metadata).
This doc is not really correct since we are holding a view to the meta, right?
Ops using the tensor class's __torch_dispatch__ are insensitive to external changes in the meta since we cache scale_inv. However, all bets are off when we extract _data and pass it to external ops like tex.fp8_gemm.
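To make the caching behavior described above concrete, here is a minimal sketch (illustrative only, not TransformerEngine's actual Float8Tensor; the class and attribute names are made up, and it relies on PyTorch's native torch.float8_e4m3fn dtype from 2.1+). scale_inv is copied out of the metadata object at construction, so anything routed through __torch_dispatch__ keeps seeing the cached value, while code that pulls out _data and reads a live scale from the metadata bypasses the cache:

```python
import torch
from torch.utils._pytree import tree_map

class MiniFP8Tensor(torch.Tensor):
    """Toy wrapper: FP8 payload in _data plus a scale_inv cached at construction."""

    def __new__(cls, data: torch.Tensor, meta: dict):
        self = torch.Tensor._make_wrapper_subclass(
            cls, data.shape, dtype=torch.float32, device=data.device
        )
        self._data = data                            # raw FP8 payload (e.g. float8_e4m3fn)
        self._scale_inv = meta["scale_inv"].clone()  # cached copy, detached from the meta
        return self

    def dequantize(self) -> torch.Tensor:
        # Uses the cached scale_inv, so later changes to the meta are invisible here.
        return self._data.float() * self._scale_inv

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # Any torch op hitting this subclass sees dequantized FP32 values.
        unwrap = lambda t: t.dequantize() if isinstance(t, MiniFP8Tensor) else t
        return func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {}))

meta = {"scale_inv": torch.tensor([0.5])}
t = MiniFP8Tensor(torch.randn(4).to(torch.float8_e4m3fn), meta)
meta["scale_inv"].fill_(100.0)               # external update to the FP8 metadata
print(torch.sum(t))                          # still uses the cached 0.5
print(t._data.float() * meta["scale_inv"])   # raw-data path sees the new value
```

In this sketch, mutating meta["scale_inv"] after construction does not change what ops going through __torch_dispatch__ compute, which is the insensitivity described above; only the raw-_data path (analogous to handing _data to an external op like tex.fp8_gemm) observes the change.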
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Force-pushed from 8ff9e05 to dfcbcf1
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Handle case where transpose cache is updated externally. Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Easier for multiple tensors to share, e.g. detached tensors. Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Force-pushed from 718d284 to 94848da
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
…10/TransformerEngine into float8tensor_experiments
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
ptrendx left a comment:
Approving as experimental. We will iterate upon this in the next release.
* Experimental FP8 tensor Co-authored-by: Tim Moon <tmoon@nvidia.com> Co-authored-by: Sudhakar Singh <sudhakars@nvidia.com> Co-authored-by: Przemyslaw Tredak <ptrendx@gmail.com> Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Add fp8 tensor to ci test Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* review comments and tests Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Minor changes Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Default to FP8 usage Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Fix docs Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Naming changes Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* minor fix Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Fix transpose caching Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Debug transpose caching: handle case where transpose cache is updated externally. Signed-off-by: Tim Moon <tmoon@nvidia.com>
* Rename FP8GlobalStateManager.with_fp8_parameters Signed-off-by: Tim Moon <tmoon@nvidia.com>
* remove Float8Tensor from import API Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Avoid caching FP8 transposes if not required Signed-off-by: Tim Moon <tmoon@nvidia.com>
* Fix import error in FP8 tensor tests Signed-off-by: Tim Moon <tmoon@nvidia.com>
* Fix transpose caching and checkpointing bug Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Improve caching and fix distopt case Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Update transformer_engine/pytorch/float8_tensor.py Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
* Remove recursive logic Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Fix cache reset bug Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Store FP8 attributes in dict: easier for multiple tensors to share, e.g. detached tensors. Signed-off-by: Tim Moon <tmoon@nvidia.com>
* Make sure scale_inv is 1D tensor Signed-off-by: Tim Moon <tmoon@nvidia.com>
* Make sure scale_inv is 1D tensor Signed-off-by: Tim Moon <tmoon@nvidia.com>
* Fixes and detach recipe Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Set default fp8 data type Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Sudhakar Singh <sudhakars@nvidia.com>
Co-authored-by: Przemyslaw Tredak <ptrendx@gmail.com>
| * full model training using optimizer with master weights, where the high
| precision copies of weights are already present in the optimizer.
How does this look in practice? If the model will be initialized directly with fp8 weights, how does the optimizer get high-precision copies?
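For illustration, one way the master-weight flow being asked about can look (a minimal sketch, assuming the optimizer is allowed to build its own FP32 masters by upcasting the FP8 weights once at construction; the class below is made up and is not Apex DistributedFusedAdam's actual implementation):

```python
import torch

class MasterWeightSGD:
    """Toy optimizer that owns the high-precision copies of FP8 parameters."""

    def __init__(self, params, lr=1e-3):
        self.lr = lr
        self.params = list(params)
        # Upcast once here: the model never needs to hold FP32 weights itself.
        self.masters = [p.detach().float().clone() for p in self.params]

    @torch.no_grad()
    def step(self):
        for p, m in zip(self.params, self.masters):
            if p.grad is None:
                continue
            m -= self.lr * p.grad.float()   # the update happens in full precision
            p.copy_(m)                      # write back (re-quantize) into the FP8 param
```

The reply below points at the optimizer integration this class actually targets.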
This FP8 tensor class is based on the implementation at https://github.com/facebookexperimental/protoquant/tree/fp8_poc and is primarily oriented toward enabling efficient FP8 support in Apex's DistributedFusedAdam. See NVIDIA-NeMo/NeMo#7469 and NVIDIA-NeMo/NeMo#7565. CC @sudhakarsingh27 @ksivaman
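For background, the bookkeeping such a tensor class wraps boils down to a quantized payload plus its inverse scale. A small self-contained sketch using PyTorch's native torch.float8_e4m3fn dtype and simple per-tensor scaling (the function names and scaling recipe here are illustrative, not TransformerEngine's API):

```python
import torch

def quantize_fp8(x: torch.Tensor):
    """Per-tensor FP8 (E4M3) quantization; returns the payload and scale_inv."""
    fp8_max = torch.finfo(torch.float8_e4m3fn).max     # ~448 for E4M3
    amax = x.abs().max().clamp(min=1e-12)
    scale = fp8_max / amax                             # map amax onto the FP8 range
    data = (x * scale).to(torch.float8_e4m3fn)         # quantized payload
    scale_inv = (1.0 / scale).reshape(1)               # kept as a 1-D tensor
    return data, scale_inv

def dequantize_fp8(data: torch.Tensor, scale_inv: torch.Tensor) -> torch.Tensor:
    return data.float() * scale_inv

x = torch.randn(16, 16)
data, scale_inv = quantize_fp8(x)
err = (dequantize_fp8(data, scale_inv) - x).abs().max()
print(f"max quantization error: {err:.4f}")
```

Keeping scale_inv as a 1-D tensor echoes the "Make sure scale_inv is 1D tensor" commits in the log above.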