[PyTorch] Remove special handling for FP8 params in FP8 recipe infrastructure #1326
Merged
ksivaman merged 4 commits into NVIDIA:main on Nov 14, 2024
Conversation
Signed-off-by: Tim Moon <tmoon@nvidia.com>
for more information, see https://pre-commit.ci
Collaborator (Author)
/te-ci pytorch L1 L3
Member
/te-ci pytorch
Collaborator (Author)
The convergence tests in pipeline 20334396 timed out, but all the tests that did run passed.
ksivaman approved these changes on Nov 14, 2024
Description
#1142 exposed a very subtle bug that caused non-deterministic test failures in test_fusible_ops_with_userbuffers.py.
Bug description
test_fusible_ops_with_userbuffers.py runs multiple test cases at a time because launching a parallel job is expensive, so it constructs and destroys multiple TE models with FP8 parameters. Python IDs may be reused after an object is deallocated, so the Python ID of an FP8 tensor is sometimes reused. However, Float8Tensor.post_optimizer_step_fwd_amax_reduction uses Python IDs to decide whether to perform amax reductions and FP8 scale updates. I observed that this was causing FP8 scale updates at odd times, which corrupted Userbuffers (UB) buffers, which caused hangs. 🫠
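The failure mode is easy to reproduce in plain Python. Below is a minimal sketch of the pitfall (the cache and helper names are hypothetical, not TE code): keying state on id() is fragile because CPython may recycle an object's id() after it is garbage-collected.

```python
# Minimal sketch of the id()-reuse pitfall (hypothetical names, not TE code).
# CPython may recycle an object's id() after the object is garbage-collected,
# so a cache keyed on id() can mistake a brand-new tensor for an old FP8 param.
import torch

seen_param_ids = set()  # hypothetical cache keyed on id(), like the old callback


def register_fp8_param(param: torch.Tensor) -> None:
    seen_param_ids.add(id(param))


def needs_scale_update(tensor: torch.Tensor) -> bool:
    # Spuriously returns True for any tensor that happens to reuse a stale id()
    return id(tensor) in seen_param_ids


old_param = torch.empty(16)
register_fp8_param(old_param)
stale_id = id(old_param)
del old_param                 # freed; CPython may hand out the same id() again

new_tensor = torch.empty(16)  # unrelated tensor, possibly at the recycled id
if id(new_tensor) == stale_id:
    print("stale id reused:", needs_scale_update(new_tensor))  # prints True
```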
In short, the problem comes from this weird callback in Float8Tensor (transformer_engine/pytorch/tensor/float8_tensor.py, line 77 as of commit 2643ba1).
This hack was added in #575 so that we would properly update FP8 scales for FP8 params after the optimizer step. However, we've made improvements since then.
Thus, there's no need to do an FP8 scale update for the weights immediately after the optimizer step. We just need to do it sometime before the next optimizer step, with no change in numerics. In fact, these FP8 scales already participate in the forward-pass amax reduction and scale update, so dropping the post-step update removes redundant work and reduces runtime overhead. It also makes Float8Tensor saner and less tightly coupled with the FP8 recipe infrastructure.
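For context, here is a minimal training-loop sketch (assuming standard TE APIs such as fp8_model_init, fp8_autocast, and DelayedScaling; shapes and hyperparameters are illustrative) of where the FP8 weight scales are now updated: as part of the regular forward-pass amax reduction, rather than in a callback fired right after optimizer.step().

```python
# Illustrative sketch (not code from this PR): FP8 weight scales/amaxes are
# handled by the normal forward-pass amax reduction under fp8_autocast, so no
# separate scale update is required right after optimizer.step().
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=16)

with te.fp8_model_init(enabled=True):      # parameters stored as Float8Tensor
    model = te.Linear(768, 768, device="cuda")

optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

for _ in range(3):
    x = torch.randn(32, 768, device="cuda")
    with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
        out = model(x)                     # weight amaxes join this reduction
    out.sum().backward()
    optimizer.step()                       # no post-step FP8 scale update here
    optimizer.zero_grad()
```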
Type of change
Changes
Checklist: