Hi,
I get the following error:
RuntimeError: Error(s) in loading state_dict for Sequential:
    Missing key(s) in state_dict: "0._extra_state"

I am trying to get a model that was trained on an A100 GPU to run faster on an H100 GPU by leveraging Transformer Engine. However, since it was trained with vanilla nn.Linear, the saved checkpoint has no _extra_state entry.
What exactly does _extra_state represent? Is it the amax values for FP8 recipes, or something different? Is there a way to create it retroactively? If yes, is a forward pass enough, or does it require a backward pass? The model already performs well and I do not want to modify its weights, only to make it run faster.
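For what it's worth, the extra key seems to come from the te.Linear side rather than from my checkpoint. Printing the state_dict keys of a tiny module of each type (toy shapes, just to illustrate) should make that visible:

import torch.nn as nn
import transformer_engine.pytorch as te

# Keys saved by the vanilla model (this is what my checkpoint contains).
print(nn.Sequential(nn.Linear(10, 5)).state_dict().keys())
# odict_keys(['0.weight', '0.bias'])

# Keys the TE model expects to load; I believe this additionally lists
# '0._extra_state' on top of '0.weight' and '0.bias'.
print(nn.Sequential(te.Linear(10, 5)).state_dict().keys())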
I am using nvcr.io/nvidia/pytorch:23.04-py3
The error is easy to replicate like this:
import torch
import torch.nn as nn
# Create a simple model with one linear layer
model = nn.Sequential(nn.Linear(10, 5))
print(model)
# Initialize the weights randomly
for param in model.parameters():
    nn.init.normal_(param, mean=0, std=1)
# Save the model's weights
torch.save(model.state_dict(), 'model_weights.pth')
# Load the same model using TransformerEngine.Linear instead of nn.Linear
del model
import transformer_engine.pytorch as te
model = nn.Sequential(te.Linear(in_features=10, out_features=5, bias=True))
model.load_state_dict(torch.load('model_weights.pth'))
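One stopgap I thought about is loading with strict=False, so the missing key is only reported instead of raising, but I do not know whether that leaves the layer's FP8 metadata in a sane default state (hence the questions above):

# Tolerate the missing "0._extra_state" entry when loading the vanilla
# checkpoint into the TE model (untested beyond this toy example).
result = model.load_state_dict(torch.load('model_weights.pth'), strict=False)
print(result.missing_keys)     # I expect ['0._extra_state']
print(result.unexpected_keys)  # I expect []

If that is acceptable, would running a few forward passes be enough to populate the FP8 metadata, or is something else needed?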
Thanks so much!