[refactor] set attention implementation #38974

Merged

zucchini-nlp merged 20 commits into huggingface:main from zucchini-nlp:refactor-set-attn-impl on Jul 15, 2025

Conversation

@zucchini-nlp
Member

zucchini-nlp commented Jun 23, 2025

What does this PR do?

As per the title, this refactors how the attention implementation is set and makes it a public API. We should encourage users to call `model.set_attn_implementation()` whenever they want to change the implementation after loading the model, instead of setting the config's private attribute `model.config._attn_implementation = "sdpa"`.
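
For illustration, a minimal sketch of the intended usage (the checkpoint name is an arbitrary example):

```python
from transformers import AutoModelForCausalLM

# Load with whatever implementation gets resolved by default...
model = AutoModelForCausalLM.from_pretrained("gpt2")

# ...then switch through the new public API instead of mutating the
# private attribute (model.config._attn_implementation = "sdpa"):
model.set_attn_implementation("sdpa")
```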

After the clean-up, we will be calling the attention implementation setter only once per pretrained model class, when initializing the module. Since `from_pretrained`/`from_config` call init at the end, we don't need to keep the setter as a classmethod. Setting the attention implementation after init also lets us know which backbones support a given implementation and which do not, which might be useful if we want to raise errors early in future versions.

Also, removed redundant flags for FA2/FA3. Realized that we can use one flag for both versions :)
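
For context, a hedged sketch of what the unified flag looks like on a model class (flag names match the current transformers source; the class itself is illustrative):

```python
from transformers import PreTrainedModel

class MyFancyPreTrainedModel(PreTrainedModel):
    # One flag now advertises flash attention support for both FA2 and FA3,
    # replacing the separate _supports_flash_attn_2 / _supports_flash_attn_3 flags
    _supports_flash_attn = True
    _supports_sdpa = True
```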

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Comment thread src/transformers/models/dpt/modeling_dpt.py Outdated
@zucchini-nlp
Member Author

zucchini-nlp commented Jul 7, 2025

Failing tests are unrelated, ready for review

zucchini-nlp requested a review from Cyrilvallez July 7, 2025 10:04
@zucchini-nlp
Member Author

run-slow: bark, blip_2, instructblipvideo, modernbert, qwen2_5_vl, qwen2_vl, zamba

@github-actions
Contributor

github-actions Bot commented Jul 7, 2025

This comment contains run-slow, running the specified jobs:

models: ['models/bark', 'models/blip_2', 'models/instructblipvideo', 'models/modernbert', 'models/qwen2_5_vl', 'models/qwen2_vl', 'models/zamba']
quantizations: [] ...

@zucchini-nlp
Member Author

  • removed redundant flags for FA2/FA3. Realized that we can use one flag for both versions :)

Member

Cyrilvallez left a comment


Extremely nice and welcome PR! Super glad to see this become a public API, and to simplify how it's done overall! 🤗🚀

Comment thread src/transformers/modeling_utils.py Outdated
Comment thread src/transformers/modeling_utils.py Outdated
Comment on lines +2444 to +2467
```python
# package `flash-attn` can not be installed on Ascend NPU, ignore related validation logic
if importlib.util.find_spec("flash_attn") is None and not is_torch_npu_available():
    raise ImportError(f"{preface} the package flash_attn seems to be not installed. {install_message}")
else:
    # Check FA2 installed version compatibility
    flash_attention_version = version.parse(importlib.metadata.version("flash_attn"))
    if torch.version.cuda:
        if flash_attention_version < version.parse("2.1.0"):
            raise ImportError(
                f"{preface} you need flash_attn package version to be greater or equal than 2.1.0. Detected version {flash_attention_version}. {install_message}"
            )
        elif not torch.cuda.is_available():
            raise ValueError(
                f"{preface} Flash Attention 2 is not available on CPU. Please make sure torch can access a CUDA device."
            )
        else:
            raise ImportError(f"{preface} Flash Attention 2 is not available. {install_message}")
    elif torch.version.hip:
        if flash_attention_version < version.parse("2.0.4"):
            raise ImportError(
                f"{preface} you need flash_attn package version to be greater or equal than 2.0.4. Detected version {flash_attention_version}. {install_message}"
            )
        else:
            raise ImportError(f"{preface} Flash Attention 2 is not available. {install_message}")
```
Member


Don't we have a simple is_fa2_installed somewhere for all this? (same comment for fa3)

Member Author

zucchini-nlp commented Jul 11, 2025


Huh, indeed. I guess we are checking here for the exact reason why is_flash_attn_2_available() would return False, so that we can raise a proper and informative error. The import-checker utils like is_flash_attn_2_available() usually return a simple boolean and raise no errors.

We try to raise informative errors in modeling code only, so I don't think we need to move this helper into import_utils. WDYT?
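
For context, a sketch of the two styles being contrasted here (the raising helper is hypothetical, not the actual code in modeling_utils):

```python
from transformers.utils import is_flash_attn_2_available

# Import-checker util: a plain boolean, never raises
if not is_flash_attn_2_available():
    pass  # caller decides how to react

# Modeling-code style: diagnose the problem and raise an informative error
def check_fa2_or_raise():  # hypothetical helper, for illustration only
    if not is_flash_attn_2_available():
        raise ImportError(
            "Flash Attention 2 is unusable: check that flash_attn is installed, "
            "its version is recent enough, and torch can see a CUDA device."
        )
```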

Member


Alright, we can keep as-is, at least for now! This is already a big PR!

Comment on lines +1230 to +1232
```diff
-        attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min
-        attention_mask_tensor = (1.0 - attention_mask_tensor).int()
+        # Invert if floating, some attention interfaces pass already a boolean 4D mask
+        if attention_mask_tensor.is_floating_point():
+            attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min
+            attention_mask_tensor = (1.0 - attention_mask_tensor).int()
```
Member


Looks nasty indeed, but unrelated right?

Member Author


It is related, unfortunately. At some point the attention mask tensor comes in as a boolean 4D mask with SDPA, while in eager mode it is a floating-point mask. AFAIK we support both types of masks from users.

This part tries to revert the ops and get back a 2D boolean mask if a 4D mask is found. The 2D mask is later used by Qwen's special 3D position-ids constructor. Actually it's the same as #39333, I just realized.
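
A standalone sketch of that inversion, assuming the eager-style 4D additive float mask (0 where attended, dtype-min where masked):

```python
import torch

batch, q_len, kv_len = 2, 4, 4
mask_4d = torch.zeros(batch, 1, q_len, kv_len)
mask_4d[..., -1] = torch.finfo(mask_4d.dtype).min  # mask the last key position

if mask_4d.is_floating_point():
    # dtype-min / dtype-min -> 1 at masked spots; 1 - that -> 1 where attended
    mask_4d = (1.0 - mask_4d / torch.finfo(mask_4d.dtype).min).int()

# Recover a 2D (batch, kv_len) mask for the 3D position-ids constructor,
# e.g. from the last query row (an assumption about the collapsing step)
mask_2d = mask_4d[:, 0, -1, :]
print(mask_2d)  # 1s everywhere except the masked last column
```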

Member


I see, but what I meant is that it's still an issue independently of this refactor! Fine to fix it here though, no worries!

Comment on lines -2010 to +2018
```diff
-model = model_class(config).to(torch_device).to(dtype).eval()
+model = model_class(copy.deepcopy(config)).to(torch_device).to(dtype).eval()
```
Member


Can you clarify why we need to start deepcopying all the configs? 🤗

Member Author


Oh yeah, it was needed because when we do Model(config), the config's attention implementation is changed in-place. It was always like that tbh, and I added the deepcopy at some point when trying to remove the config._attn_implementation_autoset attribute.

Right now I reverted that back, since it caused more problems than I thought. We could revert the deepcopy as well, but I think it's more robust to keep it and ensure there are no in-place changes to the original config. Another option is to use only Model.from_config() in tests, which deepcopies the config internally.
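
A minimal sketch of the pitfall (BertModel stands in for any model class):

```python
import copy
from transformers import BertConfig, BertModel

config = BertConfig()

# Constructing the model resolves the attention implementation and writes it
# back onto `config` in-place; a deepcopy keeps the caller's config pristine.
model = BertModel(copy.deepcopy(config))

# Alternative mentioned above: AutoModel.from_config deepcopies internally.
```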

@zucchini-nlp
Member Author

We also need to write proper documentation about attention implementations; I am stuck on finding the right place. The existing docs are a bit dispersed and don't explain much about how the API works internally :(

Let's leave it for the next PR, we might need to change the doc structure slightly.

Comment thread src/transformers/modeling_utils.py Outdated
@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: aimv2, arcee, aria, audio_spectrogram_transformer, aya_vision, bamba, bark, bart, biogpt, bitnet, blenderbot, blenderbot_small, blip_2, chameleon, clip, cohere

Member

Cyrilvallez left a comment


Alright, let's merge it!! 🔥🔥

@Cyrilvallez
Member

(But indeed, nice documentation about this seems super important for users, let's keep it in mind somewhere!)

@zucchini-nlp
Member Author

Yep, doc coming in the next PR

zucchini-nlp merged commit 8d6259b into huggingface:main Jul 15, 2025
25 checks passed
Collaborator

ArthurZucker left a comment


Nice!

rjgleaton pushed a commit to rjgleaton/transformers that referenced this pull request Jul 17, 2025
* update

* fix some tests

* init from config, changes it in-place, add deepcopy in tests

* fix modernbert

* don't delete thsi config attr

* update

* style and copies

* skip tests in generation

* fix style

* accidentally removed flash-attn-3, revert

* docs

* forgot about flags set to False

* fix copies

* address a few comments

* fix copies

* custom code BC
@rangehow
Contributor

This PR does not seem to work as expected. It breaks the default attention mechanism of ModernBERT (it should be FA2 when the user does not specify one, but now it is SDPA). Moreover, this function is not actually called when the model is constructed through class initialization.

@zucchini-nlp @ArthurZucker
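
A quick way to check what the model actually resolved (checkpoint name per the ModernBERT release; the reported outcome is this commenter's claim):

```python
from transformers import AutoModel

model = AutoModel.from_pretrained("answerdotai/ModernBERT-base")
# The resolved backend is recorded on the config; per this report it now
# comes out as "sdpa" even though ModernBERT defaults to flash_attention_2.
print(model.config._attn_implementation)
```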

zaristei pushed a commit to zaristei/transformers that referenced this pull request Sep 9, 2025