
remove disabled autocast context and device type calc, as freqs are already force upcast to float precision (and fix https://github.com/pytorch/pytorch/issues/128394) #31959

Closed

lessw2020 wants to merge 2 commits into huggingface:main from lessw2020:remove_autocast
Conversation

@lessw2020

@lessw2020 lessw2020 commented Jul 14, 2024

What does this PR do?

This PR:
1 - Removes the redundant autocast context and device_type calculation around the RoPE frequency embeddings, both for code efficiency and, more importantly, to re-enable torch.export tracing of Llama models.

To ensure that the frequencies are cast to float32 (to avoid loss of precision over long context), a disabled autocast context was added along with .float() upcasts on the relevant variables (frequencies and positions).
However, this autocast context is redundant, adds unneeded context-manager overhead, and also breaks torch.export for the Llama models because autocast does not support the meta device.

Since the upcasts are already hardcoded to float32, the disabled autocast is not needed; the device-type checks then become unnecessary as well.

Thus, this PR simply removes the autocast context which should resolve the outstanding torch.export issue here:
pytorch/pytorch#128394
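
For illustration, a minimal sketch of the before/after shape of the change (this approximates the rotary-embedding forward in modeling_llama.py; exact names, shapes, and surrounding code differ upstream):

```python
import torch

# Before: a disabled autocast context guards the float32 matmul under mixed precision.
def rope_freqs_before(inv_freq, position_ids, device_type="cpu"):
    inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
    position_ids_expanded = position_ids[:, None, :].float()
    with torch.autocast(device_type=device_type, enabled=False):
        freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos(), emb.sin()

# After (the PR's proposal): the explicit .float() upcasts are assumed to pin the
# math to float32, so the context manager and device-type lookup are dropped.
def rope_freqs_after(inv_freq, position_ids):
    inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
    position_ids_expanded = position_ids[:, None, :].float()
    freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
    emb = torch.cat((freqs, freqs), dim=-1)
    return emb.cos(), emb.sin()
```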

Fixes pytorch/pytorch#128394

Before submitting

Who can review?

@gante @kwen2501 @angelayi @danielhanchen

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@danielhanchen
Contributor

The context manager was added to disable torch autocast because the upcasting was ignored by AMP autocast in torch 2.2 and 2.3. I'm unsure if later versions of torch fixed it.

During fine-tuning and inference, RoPE must upcast to float32 to preserve precision, since mixed precision training will downcast everything that is not explicitly marked.

One way to test this is to add a print statement for the dtype; if it is float32, then the autocast context manager is unnecessary.
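
A quick way to run that check (an illustrative standalone snippet, using CPU autocast with bfloat16 so it runs without a GPU):

```python
import torch

a = torch.randn(8, 8)
b = torch.randn(8, 8)
# If autocast overrides the explicit .float() upcast, the matmul result comes back
# in the low-precision dtype and the disabled-autocast guard is still needed;
# if it prints torch.float32, the guard is redundant on this torch version.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    out = a.float() @ b.float()
print(out.dtype)
```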

@danielhanchen
Contributor

For inference, yes, the autocast context manager is redundant, but when mixed precision is enabled, the disabled-autocast guard had to be added. I'm unsure whether newer versions of torch avoid this issue.

@gante
Contributor

gante commented Jul 16, 2024

👋

What I'm reading then is that we SHOULDN'T merge this PR until we confirm that pytorch doesn't change the type of an explicit .float() cast when autocast is active. Otherwise, we will get a regression (thank you @danielhanchen for confirming 🤗 )


Running make fix-copies on the transformers root folder and then pushing will make our CI green. In a nutshell, running that command will push the same change to other llama-like models.

Collaborator

@ArthurZucker ArthurZucker left a comment


Mmmm, AFAIK this will break what we tried to fix with #29285, no?

@lessw2020
Author

Hi all - we have a much better solution now: by changing how pipeline parallelism (PP) does the model tracing via torch.export, the autocast issue is handled directly there, so no changes are needed in the transformers code.
I'm therefore going to close this PR, as the proper fix lives in export/PP directly.
The PyTorch PR for reference:
pytorch/pytorch#130998

@lessw2020 lessw2020 closed this Jul 18, 2024
