
Fixing RotaryEmbedding.forward to return float16 values in float16 precision mode.#24262

Closed
kikutakou wants to merge 2 commits into huggingface:main from kikutakou:ko_gptneox_fp16_fix

Conversation

@kikutakou

What does this PR do?

RotaryEmbedding.forward() returns values in float32 precision even in float16 precision mode.
This affects the subsequent calculations and increases GPU memory usage.
This PR fixes that problem.

Fixes #24261
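For illustration, a minimal way to observe the problem (a sketch; the checkpoint name and module path are assumptions, not part of the PR, and attribute names follow the GPT-NeoX modeling code of this transformers version):

import torch
from transformers import GPTNeoXForCausalLM

model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/pythia-70m", torch_dtype=torch.float16)
rotary = model.gpt_neox.layers[0].attention.rotary_emb  # assumed module path
print(rotary.cos_cached.dtype)  # float32 before this fix, despite torch_dtype=torch.float16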

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Contributor

@amyeroberts left a comment


Thanks for fixing!

Just a comment on the creation of the embeddings.

Contributor


I think we might also want to control the type when creating the weights e.g. like here for Llama

cc @younesbelkada who knows more about this

Comment on lines 278 to 279
Contributor


nit: you can set dtype and device in one to call

Suggested change
-   cos = self.cos_cached[:seq_len, ...].to(x.device).to(x.dtype)
-   sin = self.sin_cached[:seq_len, ...].to(x.device).to(x.dtype)
+   cos = self.cos_cached[:seq_len, ...].to(x.device, dtype=x.dtype)
+   sin = self.sin_cached[:seq_len, ...].to(x.device, dtype=x.dtype)
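For reference, a quick check (plain PyTorch, nothing PR-specific) that a single .to() call can set device and dtype together:

import torch

t = torch.zeros(2)
print(t.to("cpu", dtype=torch.float16).dtype)  # torch.float16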

Author


Thanks for the comment! It's reflected in the patch!

Collaborator

@ArthurZucker left a comment


As mentioned by @amyeroberts, I believe the issue is rather with the initialization, since calling model.half() will probably do the operation in float16. This means that the initialization with torch.float16 as an argument of from_pretrained is not really doing its job. I would be more in favor of fixing the init rather than changing the forward!
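A rough sketch of what an init-side fix could look like (an assumption-laden illustration, not the merged change): keep the cos/sin computation in float32, then store the caches in the dtype that is active at construction time.

import torch

class GPTNeoXRotaryEmbeddingSketch(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings, base=10000, device=None):
        super().__init__()
        # Under from_pretrained(torch_dtype=torch.float16) the default dtype is float16 here.
        target_dtype = torch.get_default_dtype()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Compute the caches in float32 (half cos/sin may be unsupported on CPU), then cast.
        t = torch.arange(max_position_embeddings, device=device, dtype=torch.float32)
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(target_dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(target_dtype), persistent=False)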

@huggingface deleted a comment from the github-actions bot on Jul 20, 2023
@ArthurZucker
Collaborator

I will investigate whether or not this is the source of instabilities in Llama2! If so, I will address it.

@huggingface deleted a comment from the github-actions bot on Aug 16, 2023
@ArthurZucker
Collaborator

No time to deep dive into this at the moment! If someone wants to check this feel free to do so! 😉

@kikutakou force-pushed the ko_gptneox_fp16_fix branch from 25a6a94 to aced0ab on August 18, 2023 04:03
@kikutakou force-pushed the ko_gptneox_fp16_fix branch from aced0ab to 3b4944b on August 23, 2023 02:34
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
Author


float() always converts a tensor to float32. This is why initialisation with dtype=float16 didn't work.
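A tiny demonstration of that behaviour (plain PyTorch, independent of this PR):

import torch

torch.set_default_dtype(torch.float16)
x = torch.arange(0, 8, 2)                      # integer tensor
print(x.float().dtype)                         # torch.float32, ignores the default dtype
print(x.to(torch.get_default_dtype()).dtype)   # torch.float16
torch.set_default_dtype(torch.float32)         # restore the usual default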

def _set_cos_sin_cache(self, seq_len, device):
    self.max_seq_len_cached = seq_len
-   t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
+   t = torch.arange(self.max_seq_len_cached, device=device).float()
Author


Since emb.cos() and emb.sin() at line 314 can only be calculated in float32 on CPU, this variable must be float32.

If this t is float16 and emb.cos() is calculated on CPU, the following error will be raised:

RuntimeError: "cos_vml_cpu" not implemented for 'Half'
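Minimal repro of that error (whether it actually raises depends on the PyTorch build; older CPU builds lack half-precision cos):

import torch

t = torch.arange(4, dtype=torch.float16)  # half-precision tensor on CPU
try:
    print(t.cos())
except RuntimeError as err:
    print(err)  # e.g. "cos_vml_cpu" not implemented for 'Half'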

@kikutakou
Author

@ArthurZucker

Thanks for the comment.

The initialization with torch.float16 as an argument of from_pretrained is not really doing its job.

I've investigated and changed the patch to fix this issue.
Could you have a look at this patch?

from_pretrained changes the torch default dtype to the specified dtype and then initializes all weights.
GPTNeoXRotaryEmbedding.__init__() calls float(), which always returns float32 even when the default dtype is float16.
This was the reason.
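A simplified illustration of that mechanism (not the actual transformers implementation; the helper name is made up):

import torch

def build_under_dtype(module_cls, dtype, *args, **kwargs):
    # from_pretrained temporarily switches the default dtype while the model is built,
    # so any submodule that hard-codes .float() escapes the requested precision.
    previous = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        return module_cls(*args, **kwargs)
    finally:
        torch.set_default_dtype(previous)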

@ArthurZucker
Collaborator

This was actually fixed by #25830!

@github-actions
Contributor

github-actions Bot commented Nov 9, 2023

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions Bot closed this Nov 17, 2023


Development

Successfully merging this pull request may close these issues.

GPTNeoXAttention takes extra GPU memory footprint in torch.float16 precision mode.

3 participants