[PyTorch] Support pickling Float8Tensor #529

Merged
timmoon10 merged 13 commits into NVIDIA:main from timmoon10:float8tensor-pickle
Dec 7, 2023

Conversation

@timmoon10
Collaborator

We've experienced some problems when trying to checkpoint FP8 models (NVIDIA-NeMo/NeMo#7909 (comment)). The root cause is that we cast FP8 params to higher precision when checkpointing TE modules:

state[key] = val.from_float8()

This messes with some of the bookkeeping for checkpointing in Megatron-core, e.g. figuring out the corresponding tensors in the model and optimizer state_dicts. I've modified the behavior so that pickling Float8Tensors saves the FP8 data, dtype, and scale-inv (but not fp8_meta). This fixes the error for me when I run some quick tests.
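Roughly, the idea looks like the following toy sketch. This is not the actual Transformer Engine implementation; the class, the attribute names (_data, _fp8_dtype, _scale_inv), and the choice of __reduce_ex__ are assumptions used only to illustrate saving the raw FP8 state instead of upcasting:

```python
import pickle
import torch

class ToyFloat8Tensor:
    """Toy stand-in for Float8Tensor, illustrating the pickling scheme only."""

    def __init__(self, data, fp8_dtype, scale_inv, dtype):
        self._data = data            # raw FP8 payload stored as uint8
        self._fp8_dtype = fp8_dtype  # e.g. "E4M3" or "E5M2"
        self._scale_inv = scale_inv  # per-tensor scale inverse
        self.dtype = dtype           # nominal high-precision dtype
        self._fp8_meta = None        # deliberately not pickled

    def __reduce_ex__(self, protocol):
        # Pickle the FP8 data, FP8 dtype, and scale-inv; drop fp8_meta so the
        # unpickled tensor carries no stale scaling state.
        return (
            self.__class__,
            (self._data, self._fp8_dtype, self._scale_inv, self.dtype),
        )

# Round trip: the raw FP8 payload survives without an intermediate upcast.
t = ToyFloat8Tensor(
    torch.randint(0, 255, (4,), dtype=torch.uint8), "E4M3", torch.ones(1), torch.float32
)
t2 = pickle.loads(pickle.dumps(t))
assert torch.equal(t._data, t2._data) and torch.equal(t._scale_inv, t2._scale_inv)
```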

This is built on top of #524. Closes #524.

Avoid FP8 casts when copying between Float8Tensors. Make make_like a class function.

Signed-off-by: Tim Moon <tmoon@nvidia.com>
@timmoon10
Collaborator Author

After discussion with @sudhakarsingh27, I think the cleanest approach for handling fp8_meta when loading a checkpoint is to modify Float8Tensor.copy_ so that it copies just _data and _scale_inv when copying from another Float8Tensor (see the sketch below). This also makes these copies faster (avoiding an extra write and read in higher precision) and avoids introducing any additional rounding error.
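A rough sketch of the intended copy_ behavior; the function and helper names here are hypothetical, not the actual TE code:

```python
import torch

def float8_copy_(dst, src):
    """Hypothetical sketch of the Float8Tensor.copy_ change; not the TE source."""
    if type(src) is type(dst):
        # Float8Tensor -> Float8Tensor: copy the raw FP8 bytes and scale-inv
        # directly, skipping the decode-to-high-precision / re-encode round trip
        # (and any dependence on dst's possibly-uninitialized fp8_meta).
        dst._data.copy_(src._data)
        dst._scale_inv.copy_(src._scale_inv)
    else:
        # High-precision tensor -> Float8Tensor: fall back to an FP8 cast using
        # dst's current scaling state (hypothetical helper name).
        dst.cast_from_tensor_(src)
    return dst
```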

I'll go into more detail. We generally expect users to use torch.nn.Module.state_dict/torch.nn.Module.load_state_dict when checkpointing models. When loading a checkpoint, you typically initialize a model with junk weights, unpickle a file to get a state dict, copy the weight values from the state dict into the model, and then discard the state dict. Since we're going to throw away the unpickled Float8Tensors, it's fine if they don't have an fp8_meta as long as the corresponding model weights do.

However, the parameters are copied from the state dict before any extra state like fp8_meta (see the implementation of torch.nn.Module.load_state_dict). The existing implementation of Float8Tensor.copy_ tried to use the FP8 scale from fp8_meta (requiring a cast to high precision and back to FP8), but we have no reason to expect the initial scale is any good, and there could be numerical/convergence problems. This PR's change to copy_ works around this, since the _scale_inv for the loaded weight is presumably good. By the time we start training, we can expect that the Float8Tensor's fp8_meta has been properly configured, and it will take precedence for any future FP8 casts (e.g. in the optimizer step).

I'm not completely at ease, since there could be other use cases where we want an unpickled Float8Tensor to have access to fp8_meta (CPU offloading?), but I'm not aware of them and it's probably best not to overengineer a solution.
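For reference, this is the generic PyTorch checkpoint flow the reasoning above assumes; a plain nn.Linear stands in for a TE module, and there is nothing TE-specific here:

```python
import torch

# Save: state_dict values are pickled. With this PR, Float8Tensor params would
# keep their FP8 payload instead of being upcast first.
model = torch.nn.Linear(16, 16)
torch.save(model.state_dict(), "checkpoint.pt")

# Load: the model starts with junk weights, then load_state_dict copies each
# loaded value into the corresponding parameter (param.copy_(loaded_value))
# before restoring any extra state such as fp8_meta, which is why copy_ must
# not rely on fp8_meta being meaningful yet.
model = torch.nn.Linear(16, 16)
state_dict = torch.load("checkpoint.pt")
model.load_state_dict(state_dict)
```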

@timmoon10
Collaborator Author

/te-ci pytorch

@timmoon10 marked this pull request as draft December 4, 2023 07:40
timmoon10 and others added 2 commits December 5, 2023 17:31
Debugged pickling and copy functions.

Signed-off-by: Tim Moon <tmoon@nvidia.com>
@timmoon10 marked this pull request as ready for review December 5, 2023 17:33
@timmoon10
Collaborator Author

/te-ci pytorch

Collaborator

@sudhakarsingh27 left a comment


Lgtm!

@timmoon10 deleted the float8tensor-pickle branch February 2, 2024 01:32

Labels

bug Something isn't working
