[PyTorch] Support pickling Float8Tensor#529
Conversation
Avoid FP8 casts when copying between Float8Tensors. Make make_like a class function. Signed-off-by: Tim Moon <tmoon@nvidia.com>
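For reference, a minimal sketch of the copy optimization this commit describes, assuming both tensors expose a raw FP8 payload and a scale-inv tensor (the attribute names `_data` and `_scale_inv` here are hypothetical, not necessarily the actual Float8Tensor internals):

```python
import torch

def copy_float8_(dst, src):
    """Sketch: copy between two FP8-like tensors without an intermediate
    dequantize/requantize round trip (hypothetical attribute names)."""
    if hasattr(dst, "_data") and hasattr(src, "_data"):
        # Both tensors carry raw FP8 bytes: copy payload and scale-inv directly,
        # avoiding a cast to higher precision and back.
        dst._data.copy_(src._data)
        dst._scale_inv.copy_(src._scale_inv)
    else:
        # Mixed case: fall back to a regular copy, which casts through the
        # destination's nominal dtype.
        dst.copy_(src)
    return dst
```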
After discussion with @sudhakarsingh27, I think this is the cleanest approach to handle this. I'll go into more detail below. We generally expect users to use ...
/te-ci pytorch
Debugged pickling and copy functions. Signed-off-by: Tim Moon <tmoon@nvidia.com>
/te-ci pytorch
We've experienced some problems when trying to checkpoint FP8 models (NVIDIA-NeMo/NeMo#7909 (comment)). The root cause is that we cast FP8 params to higher precision when checkpointing TE modules:
TransformerEngine/transformer_engine/pytorch/module/base.py, line 831 (commit 8864983)
This messes with some of the bookkeeping for checkpointing in Megatron-core, e.g. figuring out the corresponding tensors in the model and optimizer state_dicts. I've modified the behavior so that pickling Float8Tensors will save the FP8 data, dtype, and scale-inv (but not fp8_meta). This fixes the error for me when I run some quick tests. This is built on top of #524. Closes #524.