
Fix XGLMModelLanguageGenerationTest.test_batched_nan_fp16 #19473

Merged
ydshieh merged 3 commits into huggingface:main from ydshieh:fix_pt_xglm_test
Oct 11, 2022

Conversation

Collaborator

@ydshieh ydshieh commented Oct 10, 2022

What does this PR do?

#18057 added this test to check that the model runs with fp16.

However, from_pretrained(model_name, torch_dtype=torch.float16) does not seem to change the dtype of the weights buffer registered below:

self.make_weights(num_positions + self.offset, embedding_dim, padding_idx)

def make_weights(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
    emb_weights = self.get_embedding(num_embeddings, embedding_dim, padding_idx)
    if hasattr(self, "weights"):
        # in forward put the weights on the correct dtype and device of the param
        emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device)
    self.register_buffer("weights", emb_weights)

and hidden_states becomes float32 again (because positions is float32) at

    hidden_states = inputs_embeds + positions

and the forward pass finally fails at hidden_states = self.self_attn_layer_norm(hidden_states) with

RuntimeError: expected scalar type Float but found Half
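The root cause can be reproduced with a minimal sketch (ToySinusoidalEmbedding is a hypothetical stand-in, not the real XGLM module): from_pretrained with torch_dtype works by temporarily switching the default floating dtype while the model is instantiated, but a buffer created with an explicit dtype=torch.float ignores that default.

```python
import torch
from torch import nn

class ToySinusoidalEmbedding(nn.Module):
    """Hypothetical stand-in for XGLM's sinusoidal positional embedding:
    the table is a buffer built with an explicit float32 dtype."""
    def __init__(self, num_positions: int = 8, dim: int = 4):
        super().__init__()
        # dtype=torch.float overrides the current default dtype, so the buffer
        # stays float32 no matter what torch_dtype was requested
        emb = torch.arange(num_positions, dtype=torch.float).unsqueeze(1).repeat(1, dim)
        self.register_buffer("weights", emb)

# emulate what from_pretrained(..., torch_dtype=torch.float16) does:
# switch the default floating dtype while instantiating the model
saved = torch.get_default_dtype()
torch.set_default_dtype(torch.float16)
model = ToySinusoidalEmbedding()
probe = torch.zeros(2)          # a tensor created with the default dtype
torch.set_default_dtype(saved)

print(probe.dtype)              # torch.float16: the default dtype was applied
print(model.weights.dtype)      # torch.float32: the explicit dtype won
```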

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Oct 10, 2022

The documentation is not available anymore as the PR was closed or merged.

- # embed positions
- positions = self.embed_positions(input_ids, inputs_embeds, past_key_values_length)
+ # embed positions, cast from float32 to `inputs_embeds.dtype`
+ positions = self.embed_positions(input_ids, inputs_embeds, past_key_values_length).to(inputs_embeds.dtype)
Collaborator Author


positions (and the embed_positions weights buffer) are in float32 even if we load the model in float16. This cast is needed to make the later LayerNorm layer work.
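The mechanism behind the cast is plain PyTorch type promotion; a minimal sketch (toy tensors, not the real model):

```python
import torch

inputs_embeds = torch.ones(2, 3, dtype=torch.float16)   # fp16 model activations
positions = torch.ones(2, 3, dtype=torch.float32)       # fp32 buffer output

# without the cast: promotion silently upcasts the sum to float32,
# which the later fp16 LayerNorm cannot consume
promoted = inputs_embeds + positions
print(promoted.dtype)       # torch.float32

# with the cast from this change: everything stays in float16
hidden_states = inputs_embeds + positions.to(inputs_embeds.dtype)
print(hidden_states.dtype)  # torch.float16
```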

Collaborator


I think the right fix would be to make sure the weights have the correct dtype: the embedding layer is the biggest one, so the memory saving is very important.

@ydshieh ydshieh requested a review from sgugger October 10, 2022 19:04
Collaborator

@sgugger sgugger left a comment


Thanks for looking into this. While this fixes the issue, I'm not sure if it's the right fix.


@ydshieh
Collaborator Author

ydshieh commented Oct 11, 2022

@sgugger OK, let me check if I can do something for the weights buffer (not a real parameter) defined by

self.register_buffer("weights", emb_weights)

  emb[padding_idx, :] = 0

- return emb
+ return emb.to(torch.get_default_dtype())
Collaborator Author


The test involved in this PR uses

from_pretrained(model_name, torch_dtype=torch.float16, ...)

but at init time, the model uses an explicit torch.float dtype to create the tensor that is then registered as a buffer (so it stays float32):

emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
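The adopted fix can be sketched as follows (build_position_table is a hypothetical, simplified version of the table construction; the real code computes sinusoids, but only the dtype handling matters here):

```python
import torch

def build_position_table(num_embeddings: int, dim: int) -> torch.Tensor:
    # build in float32 (explicit dtype, as the original code does)...
    emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1).repeat(1, dim)
    # ...then cast to the default dtype, which from_pretrained sets to
    # `torch_dtype` while the model is being instantiated
    return emb.to(torch.get_default_dtype())

# emulate loading with torch_dtype=torch.float16
saved = torch.get_default_dtype()
torch.set_default_dtype(torch.float16)
table = build_position_table(8, 4)
torch.set_default_dtype(saved)

print(table.dtype)  # torch.float16: the buffer now follows torch_dtype
```

This keeps the construction in float32 but makes the registered buffer follow the requested dtype, which is why the later fp16 LayerNorm no longer sees a float32 input.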

Collaborator

@sgugger sgugger left a comment


Thanks, this feels like a much better fix!

@ydshieh ydshieh merged commit c664661 into huggingface:main Oct 11, 2022
@ydshieh ydshieh deleted the fix_pt_xglm_test branch October 11, 2022 14:06