Fixing RotaryEmbedding.forward to return float16 values in float16 precision mode #24262
kikutakou wants to merge 2 commits into huggingface:main
Conversation
amyeroberts
left a comment
Thanks for fixing!
Just a comment on the creation of the embeddings.
I think we might also want to control the type when creating the weights e.g. like here for Llama
cc @younesbelkada who knows more about this
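A minimal sketch of that suggestion (the class name and signature here are illustrative, not the actual Llama code): accept a dtype at construction time, build the tables in float32, and store the cached buffers in the requested dtype.

```python
import torch
import torch.nn as nn


class RotaryCacheSketch(nn.Module):
    """Illustrative only: build the cos/sin cache in float32, store it in the requested dtype."""

    def __init__(self, dim, max_positions=2048, base=10000, device=None, dtype=torch.float32):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        t = torch.arange(max_positions, device=device).float()
        freqs = torch.outer(t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        # compute in float32 (cos/sin may not support half on CPU), then cast once for storage
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
```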
nit: you can set dtype and device in one `to` call
```diff
-        cos = self.cos_cached[:seq_len, ...].to(x.device).to(x.dtype)
-        sin = self.sin_cached[:seq_len, ...].to(x.device).to(x.dtype)
+        cos = self.cos_cached[:seq_len, ...].to(x.device, dtype=x.dtype)
+        sin = self.sin_cached[:seq_len, ...].to(x.device, dtype=x.dtype)
```
Thanks for the comment! It's reflected in the patch!
ArthurZucker
left a comment
As mentioned by @amyeroberts, I believe that the issue is rather with the initialization, since calling model.half() will probably do the operation in float16. This means that the initialization with torch.float16 as an argument of from_pretrained is not really doing its job. I would be more in favor of fixing the init rather than changing the forward!
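To illustrate the point with a toy module (not the transformers code itself): nn.Module.half() converts floating-point buffers after construction, whereas a buffer built with .float() during __init__ stays float32 no matter which dtype was requested at init time.

```python
import torch
import torch.nn as nn


class ToyCache(nn.Module):
    def __init__(self):
        super().__init__()
        # .float() forces float32 here, whatever dtype was asked for at init time
        self.register_buffer("cos_cached", torch.arange(4).float().cos(), persistent=False)


m = ToyCache()
print(m.cos_cached.dtype)  # torch.float32
m.half()                   # nn.Module.half() also casts floating-point buffers
print(m.cos_cached.dtype)  # torch.float16
```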
I will investigate whether or not this is the source of instabilities in Llama2! If so, I will address it!

No time to deep dive into this at the moment! If someone wants to check this feel free to do so! 😉
Force-pushed from 25a6a94 to aced0ab
Force-pushed from aced0ab to 3b4944b
```python
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
```
float() always turns a tensor into float32. This is why initialisation with dtype=float16 didn't work.
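A quick illustration of that point (plain PyTorch, outside transformers): .float() is an explicit float32 cast, so the half precision requested at init time never reaches inv_freq.

```python
import torch

t = torch.arange(0, 8, 2)  # int64 position indices
print(t.float().dtype)     # torch.float32 -- .float() always means float32
print(torch.ones(2, dtype=torch.float16).float().dtype)  # torch.float32, even starting from float16
```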
```diff
     def _set_cos_sin_cache(self, seq_len, device):
         self.max_seq_len_cached = seq_len
-        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
+        t = torch.arange(self.max_seq_len_cached, device=device).float()
```
Since emb.cos() and emb.sin() at line 314 can only be computed in float32 on CPU, this variable must be float32.
If t were float16 and emb.cos() were computed on the CPU, the following error would be raised:
RuntimeError: "cos_vml_cpu" not implemented for 'Half'
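A quick way to reproduce this (the exact behaviour depends on the PyTorch version; newer releases may implement cos for Half on CPU):

```python
import torch

t = torch.arange(8, device="cpu", dtype=torch.float16)
try:
    t.cos()  # on affected versions: RuntimeError: "cos_vml_cpu" not implemented for 'Half'
except RuntimeError as e:
    print(e)
print(t.float().cos().dtype)  # computing in float32 works everywhere
```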
Thanks for the comment.
I've investigated and changed the patch to fix this issue.
This was actually fixed by #25830!
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
What does this PR do?
RotaryEmbedding.forward() returns values with float32 precision even in float16 precision mode.
This affects the subsequent calculations and incurs extra GPU memory usage.
This PR fixes that problem.
Fixes #24261
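As a self-contained sketch of the idea (illustrative names, not the exact patch): keep the cache in float32 so it can be built on CPU, and cast cos/sin to the activation dtype and device when they are returned from forward().

```python
import torch
import torch.nn as nn


class RotaryEmbeddingSketch(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._set_cos_sin_cache(max_position_embeddings, device=device)

    def _set_cos_sin_cache(self, seq_len, device):
        self.max_seq_len_cached = seq_len
        # keep the cache in float32 so cos()/sin() also work on CPU
        t = torch.arange(seq_len, device=device).float()
        freqs = torch.outer(t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos(), persistent=False)
        self.register_buffer("sin_cached", emb.sin(), persistent=False)

    def forward(self, x, seq_len):
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len, device=x.device)
        # cast to the activation dtype/device so downstream math stays in float16
        cos = self.cos_cached[:seq_len, ...].to(x.device, dtype=x.dtype)
        sin = self.sin_cached[:seq_len, ...].to(x.device, dtype=x.dtype)
        return cos, sin


rope = RotaryEmbeddingSketch(dim=64)
q = torch.randn(1, 8, 16, 64, dtype=torch.float16)
cos, sin = rope(q, seq_len=16)
print(cos.dtype, sin.dtype)  # torch.float16 torch.float16
```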
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.