
[PyTorch] Use dummy amax for Float8Tensor cast #693

Merged
ksivaman merged 4 commits into NVIDIA:main from ksivaman:float8_tensor_amax_fix
Mar 1, 2024

Conversation

@ksivaman
Member

No description provided.

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
@ksivaman ksivaman requested a review from timmoon10 February 29, 2024 22:51
@ksivaman ksivaman self-assigned this Feb 29, 2024
@ksivaman ksivaman marked this pull request as draft February 29, 2024 22:51
Collaborator

@timmoon10 timmoon10 left a comment

Right now, any change to the values in Float8Tensor will update the amax_history in fp8_meta. This is necessary to automatically update the amax after in-place operations (e.g. the optimizer step) and it makes it easier to reason about fp8_meta (Float8Tensor treats the contents of fp8_meta as ground-truth). This happens in Float8Tensor.to_float8:

amax = fp8_meta[fp8_meta_key].amax_history[0][fp8_meta_index]

It also happens in Float8Tensor.__torch_dispatch__ with in-place operations:
amax = dst._fp8_meta[fp8_meta_key].amax_history[0][fp8_meta_index]
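
For illustration, here is a minimal, self-contained sketch (pure PyTorch with made-up names, not the actual Transformer Engine code) of the coupling described above: a cast that is fused with the amax computation writes its result straight into the shared amax_history row.

```python
import torch

# Stand-in for fp8_meta["scaling_fwd"]: one amax_history row per step,
# one column per tensor tracked by the module (names are illustrative).
class FakeFP8Meta:
    def __init__(self, num_tensors: int, history_len: int = 16):
        self.amax_history = torch.zeros(history_len, num_tensors)
        self.scale = torch.ones(num_tensors)

def fused_cast(x: torch.Tensor, meta: FakeFP8Meta, index: int) -> torch.Tensor:
    # The fused kernel computes the amax as a side effect of the cast and
    # records it in the most recent amax_history row.
    meta.amax_history[0, index] = x.abs().max()
    return x * meta.scale[index]  # placeholder for the real FP8 quantization

meta = FakeFP8Meta(num_tensors=3)
w = torch.randn(4, 4)
_ = fused_cast(w, meta, index=1)
print(meta.amax_history[0])  # slot 1 now holds w's amax
```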

This PR changes the behavior so to_float8 ignores the amax_history in fp8_meta. I see some possibilities:

  • We treat fp8_meta differently between to_float8 and __torch_dispatch__. This is confusing.
  • We change __torch_dispatch__ so it also ignores the amax_history in fp8_meta. This means we no longer fuse the amax with the FP8 cast, but instead have to call a separate amax kernel after each in-place operation.
  • If there is some localized bug when calling to_float8, we could pass in a dummy amax there instead of modifying Float8Tensor.

We should do the third one if possible. As far as I'm aware, the only place we call to_float8 is at:

param = Float8Tensor.to_float8(

If we have to change Float8Tensor, I'd prefer if we reverted this branch and deleted these lines instead:

if amax is None:
    amax = fp8_meta[fp8_meta_key].amax_history[0][fp8_meta_index]

This makes it more obvious that Float8Tensor is ignoring the amax_history in fp8_meta. We also need to figure out and document our decision regarding __torch_dispatch__.
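
For what it's worth, here is a self-contained sketch (again with made-up names, not Transformer Engine code) of the difference the dummy amax makes at that call site: the cast still computes an amax, but it lands in a throwaway buffer instead of the shared history.

```python
import torch
from typing import Optional

def to_float8_like(x: torch.Tensor, amax_history: torch.Tensor, index: int,
                   amax: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Mirrors the fallback under discussion: with no explicit amax buffer the
    # cast records into the shared history; with a dummy buffer it does not.
    if amax is None:
        amax = amax_history[0, index:index + 1]
    amax.copy_(x.abs().max())
    return x  # placeholder for the real FP8 quantization

history = torch.zeros(16, 3)  # amax_history: steps x tracked tensors
w = torch.randn(4, 4)

_ = to_float8_like(w, history, index=1, amax=torch.empty(1))  # dummy amax
print(history[0])  # all zeros: shared history left untouched

_ = to_float8_like(w, history, index=1)  # no dummy -> fallback path
print(history[0])  # slot 1 now holds w's amax
```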

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
@ksivaman ksivaman marked this pull request as ready for review March 1, 2024 01:05
Collaborator

@timmoon10 timmoon10 left a comment

This is much nicer than the previous approach. I'm not quite comfortable that the initial weight amax will never be included in the amax history. But I suppose it's a subtle point that only affects the first step and I'll approve if it makes #575 cleaner.

@ksivaman
Member Author

ksivaman commented Mar 1, 2024

For reference, this also resolves the inconsistency in amax histories between weights and activations/grads, where previously the initial amax was included only in the weight history.

@ksivaman
Member Author

ksivaman commented Mar 1, 2024

/te-ci pytorch

@ksivaman ksivaman merged commit 4e2ce51 into NVIDIA:main Mar 1, 2024