[PyTorch] Use dummy amax for Float8Tensor cast #693
Conversation
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Right now, any change to the values in `Float8Tensor` will update the `amax_history` in `fp8_meta`. This is necessary to automatically update the amax after in-place operations (e.g. the optimizer step), and it makes it easier to reason about `fp8_meta` (`Float8Tensor` treats the contents of `fp8_meta` as ground truth). This happens in `Float8Tensor.to_float8`:
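Roughly, the behavior is as in the sketch below. This is a hedged illustration only: `fp8_meta` is modeled here as a plain dict holding an `amax_history` tensor, not Transformer Engine's real structure, and the function name is a stand-in.

```python
import torch

# Hedged sketch of the pattern described above, not Transformer Engine's actual
# code. `fp8_meta` is modeled as a dict with an `amax_history` tensor of shape
# [history_len, num_tensors]; the real layout differs.
def to_float8_with_amax_update(tensor, fp8_meta, fp8_meta_index, scale):
    # The cast also records the tensor's current amax in the shared fp8_meta
    # buffers, so fp8_meta stays the ground truth for the next scale update.
    amax = tensor.abs().max().float()
    fp8_meta["amax_history"][0, fp8_meta_index].copy_(amax)

    # Placeholder for the fused FP8 cast kernel: scale and clamp to the
    # representable E4M3 range (actual rounding to FP8 is omitted here).
    fp8_max = 448.0
    return torch.clamp(tensor.float() * scale, -fp8_max, fp8_max)


# The side effect on fp8_meta is the point being discussed.
fp8_meta = {"amax_history": torch.zeros(16, 4)}
data = torch.randn(8, 8)
_ = to_float8_with_amax_update(data, fp8_meta, fp8_meta_index=0, scale=torch.tensor(1.0))
print(fp8_meta["amax_history"][0, 0])  # now holds data.abs().max()
```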
It also happens in `Float8Tensor.__torch_dispatch__` with in-place operations.

This PR changes the behavior so `to_float8` ignores the `amax_history` in `fp8_meta`. I see some possibilities:

- We treat `fp8_meta` differently between `to_float8` and `__torch_dispatch__`. This is confusing.
- We change `__torch_dispatch__` so it also ignores the `amax_history` in `fp8_meta`. This means we no longer fuse the amax with the FP8 cast, but have to externally call an amax kernel after each in-place operation.
- If there is some localized bug when calling `to_float8`, we could pass in a dummy amax there instead of modifying `Float8Tensor` (see the sketch after this list).
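For the third option, a minimal sketch of what passing a dummy amax could look like, again using the simplified stand-in from above. The `amax_out` argument and buffer handling are hypothetical, not an existing Transformer Engine API.

```python
import torch

# Hypothetical sketch of option 3: the caller supplies a throwaway amax buffer,
# so the fused cast still has somewhere to write but fp8_meta's amax_history
# is left untouched.
def to_float8_with_dummy_amax(tensor, scale, amax_out=None):
    if amax_out is None:
        # Dummy buffer that is simply discarded after the cast.
        amax_out = torch.zeros(1, device=tensor.device)
    amax_out.copy_(tensor.abs().max().float())

    # Placeholder for the FP8 cast kernel, as in the earlier sketch.
    fp8_max = 448.0
    return torch.clamp(tensor.float() * scale, -fp8_max, fp8_max)


# fp8_meta is never written, so the initial weight amax does not enter the
# history (the concern raised in the review below).
data = torch.randn(8, 8)
_ = to_float8_with_dummy_amax(data, scale=torch.tensor(1.0))
```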
We should do the third one if possible. As far as I'm aware, the only place we call `to_float8` is at:
If we have to change `Float8Tensor`, I'd prefer if we reverted this branch and deleted these lines instead:

TransformerEngine/transformer_engine/pytorch/float8_tensor.py, lines 104 to 105 in b8eea8a

This makes it more obvious that `Float8Tensor` is ignoring the `amax_history` in `fp8_meta`. We also need to figure out and document our decision regarding `__torch_dispatch__`.
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
timmoon10 left a comment
This is much nicer than the previous approach. I'm not quite comfortable that the initial weight amax will never be included in the amax history. But I suppose it's a subtle point that only affects the first step and I'll approve if it makes #575 cleaner.
For reference, this also resolves the inconsistency in amax histories between weights and activations/grads, where the initial amax is included only in the weight history.
/te-ci pytorch