remove disabled autocast context and device type calc, as freqs are already force upcast to float precision (and fix https://github.com/pytorch/pytorch/issues/128394) #31959
lessw2020 wants to merge 2 commits into huggingface:main from lessw2020:remove_autocast
Conversation
The context manager was added to disable torch autocast because the upcasting was ignored by amp autocast in torch 2.2 and 2.3. I'm unsure if later versions of torch fixed it. During fine-tuning and inference, RoPE must upcast to float32 to preserve precision, since mixed precision training will downcast everything that is not explicitly marked. One way to test this is to add a print statement for the dtype: if it is float32, then the autocast context manager is unnecessary.
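A minimal sketch of that check (the shapes and the CPU bfloat16 autocast region are illustrative assumptions, not the transformers source):

```python
import torch

# Toy stand-ins for the RoPE inputs: a batched inv_freq column
# times a position row, as in a rotary embedding forward.
inv_freq = torch.randn(1, 64, 1)
position_ids = torch.randn(1, 1, 128)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    freqs = inv_freq.float() @ position_ids.float()
    # float32 here means autocast left the explicit upcast alone, so the
    # disabled-autocast guard is redundant; bfloat16 means autocast
    # overrode the upcast and the guard is still needed.
    print(freqs.dtype)
```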
For inference, yes, the autocast context manager is redundant; but when mixed precision is enabled, the autocast disabler had to be put in place - I'm unsure whether newer versions of torch avoid this issue.
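For reference, the workaround under discussion looks roughly like this (a sketch; the helper name and argument names are illustrative, not the exact transformers source):

```python
import torch

def rope_freqs(inv_freq_expanded, position_ids_expanded, device_type="cpu"):
    # Locally disabling autocast keeps the matmul in float32 even when the
    # surrounding forward pass runs under mixed precision.
    with torch.autocast(device_type=device_type, enabled=False):
        freqs = inv_freq_expanded.float() @ position_ids_expanded.float()
    return freqs
```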
👋 What I'm reading then is that we SHOULDN'T merge this PR until we confirm that pytorch doesn't change the type of an explicit upcast when running under autocast.
ArthurZucker left a comment
Mmmm AFAIK this will break what we tried to fix with #29285, no?
Hi all - we now have a much better solution: we changed how PP does the model tracing via PT export, and with that it handles the autocast issue directly, so no changes are needed here in the transformers code.
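For context, torch.export traces a module ahead of time into an ExportedProgram; a minimal sketch of the tracing entry point (the toy module is illustrative):

```python
import torch

class Toy(torch.nn.Module):
    def forward(self, x):
        # explicit float32 upcast, no autocast context in the traced graph
        return (x.float() * 2.0).cos()

# torch.export.export traces forward() into an ExportedProgram;
# an autocast context manager inside forward is the kind of construct
# the discussion above says trips this tracer.
exported = torch.export.export(Toy(), (torch.randn(4),))
print(exported)
```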
What does this PR do?
This PR:
1 - Removes the redundant autocast and device_type calculations around the RoPE frequency embeddings, both for code efficiency and, more importantly, to re-enable torch.export tracing of Llama models.
In order to ensure that the frequencies are cast to float32 (to avoid loss of precision over long contexts), a disabled autocast context was added along with a .float() upcast of the relevant variables (frequency and position).
However, this autocast is redundant, adds unneeded context overhead, and breaks torch.export for the Llama models, since autocast does not support the meta device.
Because the upcasts are already hardcoded to the float dtype, the disabled autocast is not needed; similarly, the device type checks are also not needed.
Thus, this PR simply removes the autocast context, which should resolve the outstanding torch.export issue here:
pytorch/pytorch#128394
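A sketch of the change, approximating the Llama rotary embedding forward (names and shapes follow the pattern described above; the exact transformers source may differ):

```python
import torch

def rope_forward(inv_freq, position_ids, x):
    # Both operands are explicitly upcast to float32 here.
    inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
    position_ids_expanded = position_ids[:, None, :].float()

    # Before: a disabled-autocast guard around the matmul, which required a
    # device_type calculation and broke torch.export on meta device:
    #   device_type = x.device.type if x.device.type != "mps" else "cpu"
    #   with torch.autocast(device_type=device_type, enabled=False):
    #       freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)

    # After (the PR's premise): the operands are already hardcoded to
    # float32 above, so the guard and the device_type check are dropped.
    freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
    emb = torch.cat((freqs, freqs), dim=-1)
    return emb.cos().to(x.dtype), emb.sin().to(x.dtype)
```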
Fixes pytorch/pytorch#128394
Before submitting
- This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- Did you read the contributor guideline, Pull Request section?
- [X] Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  RoPE loses precision for Llama / Gemma + Gemma logits.float() #29285 (comment)
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
- Did you write any new necessary tests?
Who can review?
@gante @kwen2501 @angelayi @danielhanchen
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.