Loading existing weights trained without TransformerEngine #199

@tylerweitzman

Description

Hi,
I get the following error:

RuntimeError: Error(s) in loading state_dict for Sequential:
	Missing key(s) in state_dict: "0._extra_state". 

I am trying to get a model that was already trained on an A100 GPU to run faster on an H100 GPU by leveraging TransformerEngine. However, since it was trained with vanilla nn.Linear, its checkpoint has no _extra_state key.

What exactly does _extra_state represent? Is it the amax values for FP8 recipes, or something different? Is there any way to create it retroactively? If so, can it be done with a forward pass, or does it require a backward pass? The model already performs well and I do not want to modify its weights, only to make it faster.

I am using nvcr.io/nvidia/pytorch:23.04-py3

The error is easy to replicate like this:

import torch
import torch.nn as nn

# Create a simple model with one linear layer
model = nn.Sequential(nn.Linear(10, 5))
print(model)

# Initialize the weights randomly
for param in model.parameters():
    nn.init.normal_(param, mean=0, std=1)

# Save the model's weights
torch.save(model.state_dict(), 'model_weights.pth')

# Load the same model using TransformerEngine.Linear instead of nn.Linear
del model
import transformer_engine.pytorch as te
model = nn.Sequential(te.Linear(in_features=10, out_features=5, bias=True))
model.load_state_dict(torch.load('model_weights.pth'))
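For context, a minimal sketch of the mechanism behind the error, using plain PyTorch only (no TransformerEngine). LinearWithExtraState is a hypothetical stand-in for te.Linear: any module that overrides get_extra_state/set_extra_state gets an "_extra_state" entry in its state_dict, and loading a checkpoint that lacks that entry fails under the default strict=True. The fp8_meta dict below is purely illustrative, not TE's actual internal layout.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for te.Linear: a Linear layer that also
# serializes extra, non-parameter state (e.g. FP8 scaling metadata).
class LinearWithExtraState(nn.Linear):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.fp8_meta = {"amax_history": []}  # illustrative placeholder

    def get_extra_state(self):
        # Stored in the state_dict under "<prefix>_extra_state"
        return self.fp8_meta

    def set_extra_state(self, state):
        self.fp8_meta = state

# Save weights from a vanilla layer: its state_dict has no _extra_state
plain = nn.Linear(10, 5)
sd = plain.state_dict()

te_like = LinearWithExtraState(10, 5)

# Default strict=True load fails, reproducing the reported error
try:
    te_like.load_state_dict(sd)
except RuntimeError as e:
    print("strict load failed:", e)

# Non-strict load succeeds: weights transfer, extra state keeps its default
te_like.load_state_dict(sd, strict=False)
assert torch.equal(te_like.weight, plain.weight)
```

Whether strict=False is actually safe for te.Linear depends on whether TransformerEngine can initialize its FP8 metadata from scratch on first use; that part is an open question here, not something this sketch settles.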

Thanks so much
