Incorrect PyTorch version check for AuxRequest import in flex_attention #45446

@ZSLsherly

Description
System Info

In src/transformers/integrations/flex_attention.py, the code currently checks for PyTorch version >= 2.9.0 before importing AuxRequest from torch.nn.attention.flex_attention. However, AuxRequest was only introduced in PyTorch 2.9.1.
According to the official PyTorch documentation, AuxRequest is available starting from version 2.9.1: https://docs.pytorch.org/docs/2.9/nn.attention.flex_attention.html

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

import torch

# This raises ImportError on PyTorch 2.9.0, where AuxRequest does not yet exist
from torch.nn.attention.flex_attention import AuxRequest

Expected behavior

The version check should be updated to require PyTorch >= 2.9.1, so that the AuxRequest import is not attempted on PyTorch 2.9.0.
