Update Torch version check for flex attention #45445
ZSLsherly wants to merge 1 commit into huggingface:main
Conversation
This commit corrects the PyTorch version check for importing `AuxRequest` from `torch.nn.attention.flex_attention`. The `AuxRequest` class was actually introduced in PyTorch 2.9.1, not 2.9.0. The current code attempts to import it for any version >= 2.9.0, which causes an `ImportError` in PyTorch 2.9.0 environments.
Pull request overview
Adjusts the flex-attention integration’s PyTorch feature-gating to avoid importing AuxRequest on unsupported PyTorch versions, preventing ImportError in PyTorch 2.9.0 environments.
Changes:
- Update `_TORCH_FLEX_USE_AUX` threshold from `torch>=2.9.0` to `torch>=2.9.1` for `AuxRequest` import gating.
- _TORCH_FLEX_USE_AUX = is_torch_greater_or_equal("2.9.0")
+ _TORCH_FLEX_USE_AUX = is_torch_greater_or_equal("2.9.1")
After updating the AuxRequest availability check to 2.9.1, the docstring in get_flex_attention_lse_kwargs still says the behavior changes in torch 2.9 and refers to the wrong argument names (mentions aux_request / “python version”, but the code uses return_aux and the decision is based on the torch version). Please update that docstring to match the new 2.9.1 threshold and the actual kwargs.
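As an aside, a try/except fallback is a common defensive alternative to version gating for exactly this kind of off-by-one in version metadata. This is a sketch only, not the pattern transformers actually uses (the codebase gates on the version check shown in the diff), and it degrades gracefully even when torch is absent:

```python
# Hedged sketch: names come from the PR discussion, but the exact
# surrounding code in transformers differs (it uses a version check).
try:
    from torch.nn.attention.flex_attention import AuxRequest
except ImportError:
    # torch missing, or a torch version without AuxRequest (e.g. 2.9.0)
    AuxRequest = None
```

A downside of this pattern is that it also silences unrelated import failures, which is one reason an explicit version threshold can be preferable.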
Rocketknight1 left a comment
Yep, sounds good to me!
No, not correct

Wait, is the issue purely hallucinated?

I've checked locally; it runs fine without errors. The docs just start from 2.9.1 (I guess that's the latest minor for that version).

I suspect it was purely based on the docs but wasn't run at all by the contributor.

My fault, I was also trusting the docs link!
Sorry, this one's on me. Retested with a clean environment and you're right—torch 2.9.0 works fine. Looks like my old env was just borked. Thanks for the spot |
|
No worries, glad it got resolved |
The `AuxRequest` import is gated at line 51. You can view the introduced version of AuxRequest at https://docs.pytorch.org/docs/2.9/nn.attention.flex_attention.html.
Fixes #45446