logic to select tf32 API as per Pytorch version #42428

Merged
ArthurZucker merged 7 commits into huggingface:main from khushali9:tf32-api-deprecation
Dec 1, 2025
Conversation

@khushali9 (Contributor) commented Nov 26, 2025

What does this PR do?

This PR uses fp32_precision instead of allow_tf32 for PyTorch >= 2.9.0 on CUDA,
as pointed out in the docs mentioned in #42371.

I have also added test cases.

Fixes #42371

Can you review @Rocketknight1 @ArthurZucker?

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@khushali9 (Contributor, Author)

@Rocketknight1 here is the fresh clean PR, thanks for helping.

@wasertech (Contributor) commented Nov 26, 2025

As mentioned in #42371 (comment), you probably want to add a CI test. There should already be a simple test in .github/workflows, but it uses the previous version of PyTorch; it should be easy to duplicate and update it for the new test, then we shall see. nvm — I couldn't see the test results from the app, but you did it alright 😄👍 conftest is the way to go! Thanks for being so quick.

@khushali9 (Contributor, Author)

> As mentioned in #42371 (comment), you probably want to add a CI test. There should already be a simple test in .github/workflows, but it uses the previous version of PyTorch; it should be easy to duplicate and update it for the new test, then we shall see. nvm — I couldn't see the test results from the app, but you did it alright 😄👍 conftest is the way to go! Thanks for being so quick.

Thanks for checking. Yes, I already had tests, but they failed on CI because of the PyTorch version. I did test for both versions, so we are good to go with this change.

@Rocketknight1 (Member) left a comment


Yes, this looks good to me now! cc core maintainers @Cyrilvallez @ArthurZucker just in case, though, since this is touching a lot of core files

tl;dr for reviewers: the API in PyTorch for enabling TensorFloat32 changed in 2.9, and using the old API now triggers a UserWarning. This patch adds a helper function enable_tf32() that calls either the old or new API depending on the torch version, so it works on all versions without warnings.

PyTorch docs here: https://docs.pytorch.org/docs/2.9/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices
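The version dispatch described in the tl;dr can be sketched as follows. This is an illustrative re-statement, not the PR's actual code: the helper name `select_tf32_setting` and the stdlib version parsing are assumptions (the PR itself uses `version.parse`), and the attribute paths are returned as strings rather than set on `torch`, so the branch logic runs without any particular PyTorch install.

```python
# Hypothetical sketch of the version-based TF32 dispatch (not the PR's code).
# Returns which setting would be touched and its value instead of importing
# torch, so the branch itself can be exercised anywhere.

def _release_tuple(v: str) -> tuple:
    # "2.9.0+cu121" -> (2, 9, 0); local version tags are dropped
    return tuple(int(p) for p in v.split("+")[0].split(".") if p.isdigit())

def select_tf32_setting(torch_version: str, enable: bool):
    if _release_tuple(torch_version) >= (2, 9, 0):
        # New API (PyTorch >= 2.9.0): string-valued precision mode
        return "torch.backends.fp32_precision", ("tf32" if enable else "ieee")
    # Old API (< 2.9.0): boolean flag, deprecated in 2.9
    return "torch.backends.cuda.matmul.allow_tf32", enable
```

Note that a tuple comparison handles two-digit minors correctly: (2, 10, 0) >= (2, 9, 0), whereas a naive string comparison of "2.10.0" and "2.9.0" would not.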

Comment thread: src/transformers/utils/import_utils.py (outdated)
pytorch_version = version.parse(get_torch_version())
if pytorch_version >= version.parse("2.9.0"):
precision_mode = "tf32" if enable else "ieee"
torch.backends.cuda.matmul.fp32_precision = precision_mode
@Rocketknight1 (Member)


One thing I just noticed in the docs - can we also set torch.backends.fp32_precision = precision_mode? The docs are a bit unclear but I think that's the "global" setting.

@khushali9 (Contributor, Author) commented Nov 27, 2025


@Rocketknight1 can you point me to how I can see my PR changes at the documentation link? If you mean the PyTorch doc, then I saw it.

For torch.backends.fp32_precision = precision_mode:

I can make the change, but I do not see torch.backends.allow_tf32 being used anywhere in the code. I did check the PyTorch doc, though, and yes, torch.backends.fp32_precision is the global setting. I can update the PR with that.

One more question: this issue is for CUDA, but in our code I do see MUSA, like torch.backends.mudnn.allow_tf32, which my code does not touch. Do you want me to update that API as well? I was not sure about it, as the example only talked about CUDA and not MUSA, but I feel I should make that change too. Let me know what you think.

@Rocketknight1 (Member)


Hi @khushali9, I think what's happening is that torch.backends.fp32_precision is the "master" setting, which sets everything else by default. The lower-level settings like torch.backends.cuda.matmul.fp32_precision are only necessary if we want to override the master setting.

Therefore, I think torch.backends.fp32_precision is what we should do in future, and hopefully we won't need specific code for musa after that. However, the PyTorch documentation is a bit unclear about this. I agree that it's a bit different from our old API, but I think this is what we want the function to do!
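The parent/child relationship described here can be illustrated with a small toy model. This is not torch's implementation: the `PrecisionKnob` class and the `INHERIT` sentinel are assumptions loosely based on the PyTorch 2.9 fp32_precision notes, used only to show why setting the top-level knob would cover the backend-specific ones.

```python
# Toy model of a hierarchical precision setting (illustrative only; torch's
# real mechanism may differ). A child whose value is the INHERIT sentinel
# defers to its parent, so setting the root changes every unset child.
INHERIT = "inherit"

class PrecisionKnob:
    def __init__(self, parent=None):
        self.parent = parent
        self.value = INHERIT

    def effective(self) -> str:
        if self.value != INHERIT:
            return self.value
        return self.parent.effective() if self.parent else "ieee"

# In this analogy, torch.backends.fp32_precision is the root, and the
# cuda.matmul / mudnn knobs are children of it.
root = PrecisionKnob()
cuda_matmul = PrecisionKnob(parent=root)
mudnn = PrecisionKnob(parent=root)

root.value = "tf32"  # one global assignment reaches every unset child
```

With only the root set, both children report "tf32", which is the behavior the reviewer expects from setting torch.backends.fp32_precision alone; explicitly setting a child would override the global value for that backend only.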

@khushali9 (Contributor, Author) commented Nov 28, 2025


Yeah, I agree @Rocketknight1; that's why, as soon as we find out it is >= 2.9.0, I set the global setting, but to take care of our existing code lines I also added the MUSA- and CUDA-related changes in that block.

Shall I remove them and keep only the global setting in the >= 2.9.0 block, or is this PR good to go?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@khushali9 (Contributor, Author)

@Rocketknight1 can you review it again? I addressed the changes related to the global setting and also handling MUSA devices.

@Rocketknight1 (Member)

@khushali9 I think we can just remove everything else and keep the global torch.backends.fp32_precision line in the torch >= 2.9.0 path. However, it'd be cool if you or someone else could test that that actually behaves the way we want it to!

@khushali9 (Contributor, Author)

> @khushali9 I think we can just remove everything else and keep the global torch.backends.fp32_precision line in the torch >= 2.9.0 path. However, it'd be cool if you or someone else could test that that actually behaves the way we want it to!

Sure @Rocketknight1 I will remove those things, and test again.

@khushali9 (Contributor, Author)

@Rocketknight1 It is ready with the new changes, and I tested with 2.10.0, 2.9.0, and 2.8.0 with enable set to both True and False.
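A cross-version check like the one described (2.8.0 / 2.9.0 / 2.10.0, enable True and False) could be sketched like this. The fake torch.backends namespace and the `apply_tf32` helper are illustrative stand-ins, not the PR's actual test code, so the matrix runs without any torch install:

```python
# Sketch of a version x enable test matrix (stand-in code, not the PR's tests).
from types import SimpleNamespace

def apply_tf32(backends, torch_version: str, enable: bool) -> None:
    # Illustrative restatement of the dispatch under test; `backends` is a
    # stand-in for torch.backends so no torch install is needed.
    major, minor = (int(p) for p in torch_version.split(".")[:2])
    if (major, minor) >= (2, 9):
        backends.fp32_precision = "tf32" if enable else "ieee"
    else:
        backends.cuda.matmul.allow_tf32 = enable

def run_matrix() -> dict:
    results = {}
    for ver in ("2.8.0", "2.9.0", "2.10.0"):
        for enable in (True, False):
            fake = SimpleNamespace(
                fp32_precision=None,
                cuda=SimpleNamespace(matmul=SimpleNamespace(allow_tf32=None)),
            )
            apply_tf32(fake, ver, enable)
            results[(ver, enable)] = (fake.fp32_precision,
                                      fake.cuda.matmul.allow_tf32)
    return results
```

Each cell records which attribute was touched: on >= 2.9.0 only the global fp32_precision moves, and on 2.8.0 only the legacy allow_tf32 flag does.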

@ArthurZucker (Collaborator) left a comment


LGTM thanks 🤗

@ArthurZucker ArthurZucker merged commit 7f5c209 into huggingface:main Dec 1, 2025
21 checks passed
@khushali9 khushali9 deleted the tf32-api-deprecation branch December 1, 2025 15:14
sarathc-cerebras pushed a commit to sarathc-cerebras/transformers that referenced this pull request Dec 7, 2025
* logic to select tf32 API as per Pytorch version

* new method added into __all__

* make style and quality ran

* added global setting for tf32

* added support for MUSA as well

* make style and quality run

* cleared >= 2.9.0 torch version logic
SangbumChoi pushed a commit to SangbumChoi/transformers that referenced this pull request Jan 23, 2026


Development

Successfully merging this pull request may close these issues.

Please use the new API settings to control TF32 behavior, ...

6 participants