logic to select tf32 API as per PyTorch version #42428
ArthurZucker merged 7 commits into huggingface:main
Conversation
@Rocketknight1 here is the fresh clean PR, thanks for helping.
Thanks for checking. Yes, I already had tests, but they failed on CI because of the PyTorch version. I did test with both versions, so we are good to go with this change.
Yes, this looks good to me now! cc core maintainers @Cyrilvallez @ArthurZucker just in case, since this touches a lot of core files.
tl;dr for reviewers: the PyTorch API for enabling TensorFloat-32 changed in 2.9, and using the old API now triggers a UserWarning. This patch adds a helper function enable_tf32() that calls either the old or the new API depending on the torch version, so it works on all versions without warnings.
PyTorch docs here: https://docs.pytorch.org/docs/2.9/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices
    pytorch_version = version.parse(get_torch_version())
    if pytorch_version >= version.parse("2.9.0"):
        precision_mode = "tf32" if enable else "ieee"
        torch.backends.cuda.matmul.fp32_precision = precision_mode
One thing I just noticed in the docs: can we also set torch.backends.fp32_precision = precision_mode? The docs are a bit unclear, but I think that's the "global" setting.
@Rocketknight1 can you point me to how I can see my PR changes in the document link? If you are talking about the PyTorch doc, then I saw it.
For torch.backends.fp32_precision = precision_mode:
I can make the change, but I do not see torch.backends.allow_tf32 being used anywhere in the code. I did check the PyTorch doc, though, and yes, torch.backends.fp32_precision is the global setting; I can update the PR with that.
One more question: this issue is for CUDA, but in our code I also see MUSA, e.g. torch.backends.mudnn.allow_tf32, which my code does not touch. Do you want me to update that API as well? I was not sure, since the examples only talked about CUDA and not MUSA, but I feel I should make that change too. Let me know what you think.
Hi @khushali9, I think what's happening is that torch.backends.fp32_precision is the "master" setting, which sets everything else by default. The lower-level settings like torch.backends.cuda.matmul.fp32_precision are only necessary if we want to override the master setting.
Therefore, I think torch.backends.fp32_precision is what we should use going forward, and hopefully we won't need MUSA-specific code after that. The PyTorch documentation is a bit unclear about this, and I agree it's a bit different from our old API, but I think this is what we want the function to do!
Yeah, I agree @Rocketknight1. That's why, as soon as we detect the version is >= 2.9.0, I set the global setting; but to cover our existing code paths I also added the MUSA- and CUDA-specific changes in that block.
Shall I remove those and just keep the global setting in the >= 2.9.0 block, or is this PR good to go?
@Rocketknight1 Can you review it again? I addressed the changes related to the global setting and also handling MUSA devices.
@khushali9 I think we can just remove everything else and keep the global setting.
Sure @Rocketknight1, I will remove those things and test again.
@Rocketknight1 It is ready with the new changes. I tested with 2.10.0, 2.9.0, and 2.8.0, with enable set to both True and False.
* logic to select tf32 API as per Pytorch version
* new method added into __all__
* make style and quality ran
* added global setting for tf32
* added support for MUSA as well
* make style and quality run
* cleared >= 2.9.0 torch version logic
What does this PR do?
This PR uses fp32_precision instead of allow_tf32 for
PyTorch versions >= 2.9.0 on CUDA,
as pointed out in the doc mentioned in #42371.
I have also added test cases.
Fixes #42371
Can you review, @Rocketknight1 @ArthurZucker?
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.