Allow kernel modules to declare their preferred mask function #44680
Conversation
`load_and_register_attn_kernel` hardcodes the mask function to `flash_attention_2` for all custom attention kernels. This is incorrect for kernels that need a different mask type (e.g., SDPA-style masks). Add support for a `MASK_FUNCTION` module-level attribute on kernel packages. If present, it specifies which mask type to use (e.g., "sdpa", "eager"). Falls back to "flash_attention_2" for backward compatibility when the attribute is absent. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
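For illustration, a kernel package opting into SDPA-style masks would only need to export the attribute at module level. This is a minimal sketch; the attention function name, signature, and body are placeholders and not part of this PR:

```python
# Hypothetical kernel package __init__.py (function name and body are illustrative).
# Declaring MASK_FUNCTION tells the registration code which mask factory to pair
# with this kernel instead of the default flash_attention_2 mask behavior.
MASK_FUNCTION = "sdpa"

def attention(module, query, key, value, attention_mask, **kwargs):
    # Custom attention implementation that expects an SDPA-style mask.
    ...
```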
Pull request overview
Adds a way for hub-loaded custom attention kernels to choose which attention-mask factory function Transformers should use, instead of always defaulting to the flash_attention_2 mask behavior.
Changes:
- Update hub kernel registration to read an optional `MASK_FUNCTION` attribute from the kernel module, defaulting to `"flash_attention_2"`.
- Add tests covering the default fallback behavior and the custom `MASK_FUNCTION = "sdpa"` dispatch.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| `src/transformers/integrations/hub_kernels.py` | Uses kernel-provided `MASK_FUNCTION` to select the mask interface at registration time. |
| `tests/kernels/test_kernels.py` | Adds unit tests ensuring the new dispatch behavior works and remains backward compatible. |
```python
ALL_MASK_ATTENTION_FUNCTIONS.register(attn_implementation, ALL_MASK_ATTENTION_FUNCTIONS[mask_type])
```
This looks overly cautious: the pull request just replicates the way the `kernel_function` is registered.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
vasqu left a comment
Some initial comments from my side. I'm unsure if we really want to save the mask function within the kernel; maybe an optional kwarg would do the job in a better way.
```python
try:
    ALL_ATTENTION_FUNCTIONS.pop(attn_impl, None)
    ALL_MASK_ATTENTION_FUNCTIONS.pop(attn_impl, None)
except Exception as e:
```
Can we clean up as in the other tests (transformers/tests/kernels/test_kernels.py, lines 412 to 420 in 16a5b09)? Just for consistency's sake.
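For reference, one way to match that pattern, assuming the referenced tests restore the global registries via `addCleanup` (this is a sketch of the suggestion, not the code at those lines):

```python
# Assumed cleanup pattern mirroring the rest of the file: schedule removal of
# the temporary registrations so the global registries are restored even if
# the assertions fail.
self.addCleanup(ALL_ATTENTION_FUNCTIONS.pop, attn_impl, None)
self.addCleanup(ALL_MASK_ATTENTION_FUNCTIONS.pop, attn_impl, None)
```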
```python
    ALL_MASK_ATTENTION_FUNCTIONS[attn_impl],
    ALL_MASK_ATTENTION_FUNCTIONS["sdpa"],
)
try:
```
```python
# Allow the kernel module to declare its preferred mask function (e.g., MASK_FUNCTION = "sdpa").
# Falls back to "flash_attention_2" for backward compatibility with existing kernels.
mask_type = getattr(kernel, "MASK_FUNCTION", "flash_attention_2")
```
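Putting this hunk together with the registration call quoted earlier, the dispatch path inside `load_and_register_attn_kernel` reads roughly as follows (a sketch assembled from the two diff excerpts; the surrounding function body is omitted):

```python
# Sketch assembled from the diff excerpts above; `kernel` and `attn_implementation`
# come from the surrounding load_and_register_attn_kernel body in hub_kernels.py.
mask_type = getattr(kernel, "MASK_FUNCTION", "flash_attention_2")
ALL_MASK_ATTENTION_FUNCTIONS.register(attn_implementation, ALL_MASK_ATTENTION_FUNCTIONS[mask_type])
```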
If `MASK_FUNCTION` is not defined, it will default to `"flash_attention_2"`, so it is compatible with existing kernels. But maybe you meant something else?
The fallback is fine. The issue is how we let kernels "register" their mask. The way this is currently done, the declaration is expected to be integrated within the kernel itself, i.e. in the `__init__` of the kernel with the exact constant `MASK_FUNCTION = "your_preferred_attn_mask_type"`.
This seems a bit extreme to me; maybe the proper way is to allow a kwarg (within this function) instead and register a new mask that way.
@dacorvo Imo, the current solution is too reliant on the kernel to have the mask included as a constant. Would it be possible to rewrite this to use an optional kwarg instead?
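For context, the kwarg-based alternative the reviewers describe could look something like the sketch below. The signature is hypothetical and only illustrates passing the mask type at call time instead of reading it from the kernel module:

```python
# Hypothetical shape of the kwarg-based alternative (not this PR's implementation):
# the caller picks the mask type when registering the kernel.
def load_and_register_attn_kernel(attn_implementation, mask_type="flash_attention_2"):
    kernel_function = ...  # load the hub kernel as before (omitted)
    ALL_ATTENTION_FUNCTIONS.register(attn_implementation, kernel_function)
    ALL_MASK_ATTENTION_FUNCTIONS.register(attn_implementation, ALL_MASK_ATTENTION_FUNCTIONS[mask_type])
```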
Fixes #44679
Summary
- `load_and_register_attn_kernel` currently gets a hardcoded `flash_attention_2` mask dispatch, which produces 2D or `None` masks
- Add support for a `MASK_FUNCTION` module-level attribute on kernel packages; falls back to `"flash_attention_2"` for backward compatibility

Changes
- `src/transformers/integrations/hub_kernels.py`: check `getattr(kernel, "MASK_FUNCTION", "flash_attention_2")` instead of hardcoding `"flash_attention_2"`
- `tests/kernels/test_kernels.py`: 2 new tests covering the default fallback and the custom `MASK_FUNCTION = "sdpa"` dispatch

Test plan
- `python -m pytest tests/kernels/test_kernels.py::TestAttentionKernelRegistration -xvs`: all 5 tests pass (3 existing + 2 new)

AI disclosure
This PR was developed with AI assistance (Claude). All changes reviewed and validated by a human contributor.
🤖 Generated with Claude Code