[refactor] set attention implementation #38974
zucchini-nlp merged 20 commits into huggingface:main
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Failing tests are unrelated, ready for review.

run-slow: bark, blip_2, instructblipvideo, modernbert, qwen2_5_vl, qwen2_vl, zamba

This comment contains run-slow, running the specified jobs: models: ['models/bark', 'models/blip_2', 'models/instructblipvideo', 'models/modernbert', 'models/qwen2_5_vl', 'models/qwen2_vl', 'models/zamba']
Cyrilvallez
left a comment
Extremely nice and welcome PR! Super glad to see this become a public API, and to simplify how it's done overall! 🤗🚀
```python
# package `flash-attn` can not be installed on Ascend NPU, ignore the related validation logic
if importlib.util.find_spec("flash_attn") is None and not is_torch_npu_available():
    raise ImportError(f"{preface} the package flash_attn seems to be not installed. {install_message}")
else:
    # Check FA2 installed version compatibility
    flash_attention_version = version.parse(importlib.metadata.version("flash_attn"))
    if torch.version.cuda:
        if flash_attention_version < version.parse("2.1.0"):
            raise ImportError(
                f"{preface} you need flash_attn package version to be greater or equal than 2.1.0. Detected version {flash_attention_version}. {install_message}"
            )
        elif not torch.cuda.is_available():
            raise ValueError(
                f"{preface} Flash Attention 2 is not available on CPU. Please make sure torch can access a CUDA device."
            )
    elif torch.version.hip:
        if flash_attention_version < version.parse("2.0.4"):
            raise ImportError(
                f"{preface} you need flash_attn package version to be greater or equal than 2.0.4. Detected version {flash_attention_version}. {install_message}"
            )
    else:
        raise ImportError(f"{preface} Flash Attention 2 is not available. {install_message}")
```
Don't we have a simple `is_fa2_installed` somewhere for all this? (same comment for fa3)
Huh, indeed. We are checking here the exact same issue that `is_flash_attn_2_available()` checks, except we raise a proper and informative error. The import-checker utils like `is_flash_attn_2_available()` usually return a simple boolean and raise no errors.
We try to raise informative errors in modeling code only, so I don't think we need to move this helper into `import_utils`. WDYT?
Alright, we can keep as-is, at least for now! This is already a big PR!
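For reference, a minimal sketch of the split discussed in this thread: `is_flash_attn_2_available()` is the existing boolean helper in `transformers.utils`, while `require_flash_attn_2` below is a hypothetical wrapper illustrating how modeling code turns that boolean into an informative error.

```python
from transformers.utils import is_flash_attn_2_available


def require_flash_attn_2(preface: str, install_message: str) -> None:
    # The import-checker util only reports availability as a bool and never raises...
    if not is_flash_attn_2_available():
        # ...so modeling code wraps it to produce an informative error instead.
        raise ImportError(f"{preface} Flash Attention 2 is not available. {install_message}")
```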
```diff
-attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min
-attention_mask_tensor = (1.0 - attention_mask_tensor).int()
+# Invert if floating, some attention interfaces pass already a boolean 4D mask
+if attention_mask_tensor.is_floating_point():
+    attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min
+    attention_mask_tensor = (1.0 - attention_mask_tensor).int()
```
Looks nasty indeed, but unrelated right?
It is related, unfortunately. With SDPA the attention tensor at some point comes in as a boolean 4D mask, while in eager mode it is a floating point mask. AFAIK we support both types of masks from users.
This part tries to revert the ops and recover a 2D boolean mask if a 4D mask is found. The 2D mask is later used by Qwen's special 3D position ids constructor. Actually it is the same as #39333, just realized.
I see, but what I meant is that it's still an issue independently of this refactor! Fine to fix it here though, no worries!
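A small standalone demo of the inversion discussed above (toy values chosen for illustration): an additive float mask encodes "masked" as `finfo.min` and "attend" as `0.0`, so dividing by `finfo.min` and flipping recovers a boolean-style mask.

```python
import torch

# Additive float mask: 0.0 means "attend", finfo.min means "masked out".
dtype = torch.float32
mask = torch.tensor([[0.0, torch.finfo(dtype).min, 0.0]], dtype=dtype)

if mask.is_floating_point():
    # finfo.min / finfo.min -> 1 at masked positions, 0 / finfo.min -> 0 elsewhere.
    mask = mask / torch.finfo(mask.dtype).min
    # Flip so 1 means "attend" and 0 means "masked", like a boolean mask.
    mask = (1.0 - mask).int()

print(mask)  # tensor([[1, 0, 1]], dtype=torch.int32)
```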
```diff
-model = model_class(config).to(torch_device).to(dtype).eval()
+model = model_class(copy.deepcopy(config)).to(torch_device).to(dtype).eval()
```
Can you clarify why we need to start deepcopying all the configs? 🤗
Oh yeah, it was needed because when we do `Model(config)`, the config changes its attention implementation in-place. It was always like that tbh, and I added the deepcopy at some point when trying to remove the `config._attn_implementation_autoset` attribute.
Right now I reverted that back, as it caused more problems than I thought. We could revert the deepcopy as well, but I think it's more robust to keep it and ensure no in-place changes to the original config. Another option is to use only `Model.from_config()` in tests, which deepcopies the config internally.
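A toy repro of the in-place mutation being guarded against, assuming GPT-2 as a stand-in for the test's `model_class` (the printed values are illustrative):

```python
import copy

from transformers import GPT2Config, GPT2Model

config = GPT2Config()

# Instantiating the model resolves the attention backend during __init__ and
# can set config._attn_implementation in-place as a side effect, so the tests
# hand the constructor a deepcopy and keep the original config pristine.
model = GPT2Model(copy.deepcopy(config))

print(getattr(config, "_attn_implementation", None))        # left untouched
print(getattr(model.config, "_attn_implementation", None))  # e.g. "sdpa", set during init
```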
We also need to write proper documentation about attention implementations; I am stuck on finding the right place. Existing docs are a bit dispersed and don't explain much about how the API works internally :( Let's leave it for the next PR, might need to change the doc structure slightly.
[For maintainers] Suggested jobs to run (before merge) run-slow: aimv2, arcee, aria, audio_spectrogram_transformer, aya_vision, bamba, bark, bart, biogpt, bitnet, blenderbot, blenderbot_small, blip_2, chameleon, clip, cohere
Cyrilvallez
left a comment
Alright, let's merge it!! 🔥🔥
(But indeed, nice documentation about this seems super important for users, let's keep it in mind somewhere!)
Yep, doc coming in the next PR |
* update
* fix some tests
* init from config, changes it in-place, add deepcopy in tests
* fix modernbert
* don't delete this config attr
* update
* style and copies
* skip tests in generation
* fix style
* accidentally removed flash-attn-3, revert
* docs
* forgot about flags set to False
* fix copies
* address a few comments
* fix copies
* custom code BC
This PR does not seem to work as expected. It breaks the default attention mechanism of ModernBERT (it should be fa2 when the user does not specify it, but now it is sdpa). Moreover, this function is not actually called when the model is constructed through class initialization.
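For what it's worth, a hedged repro of this report (checkpoint name assumed; `_attn_implementation` is the private attribute discussed in this PR):

```python
from transformers import AutoModel

# ModernBERT is reported to have defaulted to flash_attention_2 when the user
# did not specify attn_implementation; after this PR it reportedly ends up as sdpa.
model = AutoModel.from_pretrained("answerdotai/ModernBERT-base")
print(model.config._attn_implementation)
```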
What does this PR do?
As per title, refactors attention implementation setting and makes it a public API. We should encourage users to call `model.set_attn_implementation()` whenever they want to change it after loading the model, instead of setting the config's private attribute `model.config._attn_implementation = "sdpa"`.

After the clean-up, we will be calling the attention implementation setter only once per pretrained model class, when initializing the module. Since `from_pretrained`/`from_config` call `__init__` at the end, we don't need to keep it as a `classmethod`. Also, setting attention after init lets us know which backbones support a given attention implementation and which do not, which might be useful if we want to raise errors early in future versions.

Also, removed the redundant flags for FA2/FA3. Realized that we can use one flag for both versions :)
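A quick usage sketch of the new public API (checkpoint and backend picked for illustration):

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")

# Public setter introduced by this PR: switch the attention backend after loading...
model.set_attn_implementation("sdpa")

# ...instead of mutating the private config attribute:
# model.config._attn_implementation = "sdpa"
```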