Allow kernel modules to declare their preferred mask function #44680

Open
dacorvo wants to merge 4 commits into main from kernel-mask-function

Conversation

@dacorvo (Contributor) commented Mar 13, 2026

Fixes #44679

Summary

  • Custom attention kernels registered via load_and_register_attn_kernel currently get hardcoded flash_attention_2 mask dispatch, which produces 2D or None masks
  • Kernels that need SDPA-style 4D boolean masks (e.g., device-specific SDPA implementations) have no way to declare this
  • Add support for a MASK_FUNCTION module-level attribute on kernel packages; falls back to "flash_attention_2" for backward compatibility (a sketch of an opting-in kernel package follows this list)
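A minimal sketch of what an opting-in kernel package could look like under this proposal; the attention function name, signature, and body are illustrative assumptions, not code from an actual kernel:

# __init__.py of a hypothetical hub kernel package (illustrative names).
# The module-level constant asks Transformers to register the SDPA-style
# mask function for this kernel instead of the flash_attention_2 one.
MASK_FUNCTION = "sdpa"

def attention(module, query, key, value, attention_mask=None, **kwargs):
    # Device-specific attention implementation would go here; with the
    # declaration above it receives a 4D boolean mask instead of 2D/None.
    ...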

Changes

  • src/transformers/integrations/hub_kernels.py: check getattr(kernel, "MASK_FUNCTION", "flash_attention_2") instead of hardcoding "flash_attention_2" (see the sketch after this list)
  • tests/kernels/test_kernels.py: 2 new tests covering the default fallback and the custom MASK_FUNCTION="sdpa" dispatch
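Assembled from the diff excerpts quoted in the review threads below, the changed registration path reads roughly as follows; this is a simplified fragment from inside load_and_register_attn_kernel, not the full function:

# Read the kernel's preferred mask type, defaulting for backward compatibility.
mask_type = getattr(kernel, "MASK_FUNCTION", "flash_attention_2")
# Register the matching mask function under the kernel's implementation name.
ALL_MASK_ATTENTION_FUNCTIONS.register(attn_implementation, ALL_MASK_ATTENTION_FUNCTIONS[mask_type])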

Test plan

  • python -m pytest tests/kernels/test_kernels.py::TestAttentionKernelRegistration -xvs: all 5 tests pass (3 existing + 2 new); the core assertion of the new dispatch test is sketched after this list
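Reconstructed from the test excerpt quoted later in this thread, the heart of the custom-mask test looks roughly like the following; the use of assertIs (rather than assertEqual) is an assumption:

# After registering a kernel whose package sets MASK_FUNCTION = "sdpa",
# the mask function stored under the kernel's name should be the sdpa one.
self.assertIs(
    ALL_MASK_ATTENTION_FUNCTIONS[attn_impl],
    ALL_MASK_ATTENTION_FUNCTIONS["sdpa"],
)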

AI disclosure

This PR was developed with AI assistance (Claude). All changes reviewed and validated by a human contributor.

🤖 Generated with Claude Code

`load_and_register_attn_kernel` hardcodes the mask function to
`flash_attention_2` for all custom attention kernels. This is incorrect
for kernels that need a different mask type (e.g., SDPA-style masks).

Add support for a `MASK_FUNCTION` module-level attribute on kernel
packages. If present, it specifies which mask type to use (e.g.,
"sdpa", "eager"). Falls back to "flash_attention_2" for backward
compatibility when the attribute is absent.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Copilot AI review requested due to automatic review settings (March 13, 2026 17:55)
@dacorvo requested a review from danieldk (March 13, 2026 17:57)
Copilot AI (Contributor) left a comment

Pull request overview

Adds a way for hub-loaded custom attention kernels to choose which attention-mask factory function Transformers should use, instead of always defaulting to the flash_attention_2 mask behavior.

Changes:

  • Update hub kernel registration to read an optional MASK_FUNCTION attribute from the kernel module, defaulting to "flash_attention_2".
  • Add tests covering the default fallback behavior and the custom MASK_FUNCTION="sdpa" dispatch.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.

Changed files:

  • src/transformers/integrations/hub_kernels.py: uses the kernel-provided MASK_FUNCTION to select the mask interface at registration time.
  • tests/kernels/test_kernels.py: adds unit tests ensuring the new dispatch behavior works and remains backward compatible.

Comment thread: tests/kernels/test_kernels.py (Outdated)
Comment on lines +360 to 361
ALL_MASK_ATTENTION_FUNCTIONS.register(attn_implementation, ALL_MASK_ATTENTION_FUNCTIONS[mask_type])

dacorvo (Contributor, Author):

This looks overly cautious: the pull request just replicates the way the kernel_function is registered.

Comment thread: tests/kernels/test_kernels.py (Outdated)
@HuggingFaceDocBuilderDev commented:
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@vasqu (Contributor) left a comment

Some initial comments from my side. I'm unsure if we really want to save the mask function within the kernel; maybe an optional kwarg would do the job in a better way.

Comment on lines +433 to +436
try:
    ALL_ATTENTION_FUNCTIONS.pop(attn_impl, None)
    ALL_MASK_ATTENTION_FUNCTIONS.pop(attn_impl, None)
except Exception as e:
vasqu (Contributor):

Can we clean up as in the other tests?

# Cleanup registration to avoid leaking functions across tests
try:
    ALL_ATTENTION_FUNCTIONS.pop(attn_impl, None)
except Exception as e:
    print(f"Could not clean up `ALL_ATTENTION_FUNCTIONS`: {e}")
try:
    ALL_MASK_ATTENTION_FUNCTIONS.pop(attn_impl, None)
except Exception as e:
    print(f"Could not clean up `ALL_MASK_ATTENTION_FUNCTIONS`: {e}")

just for consistency's sake

    ALL_MASK_ATTENTION_FUNCTIONS[attn_impl],
    ALL_MASK_ATTENTION_FUNCTIONS["sdpa"],
)
try:
vasqu (Contributor):

Same here

Comment thread: tests/kernels/test_kernels.py (Outdated)

# Allow the kernel module to declare its preferred mask function (e.g., MASK_FUNCTION = "sdpa").
# Falls back to "flash_attention_2" for backward compatibility with existing kernels.
mask_type = getattr(kernel, "MASK_FUNCTION", "flash_attention_2")
vasqu (Contributor):

This is a very heavy restriction that relies on the kernels having this defined in their __init__ as a constant. Wouldn't it make more sense to let callers pass this as an optional kwarg?

Is this the way we want to go with this, @danieldk @drbh?

dacorvo (Contributor, Author) commented Mar 19, 2026:

If the MASK_FUNCTION is not defined, it will default to "flash_attention_2", so it is compatible with existing kernels. But maybe you meant something else?

vasqu (Contributor):

The fallback is fine. The issue is how we let kernels "register" their mask. The way this is currently done, it is expected to be integrated within the kernel, or more explicitly: in the __init__ of the kernel with the exact constant MASK_FUNCTION = "your_preferred_attn_mask_type".

This seems a bit extreme to me; maybe the proper way is to allow a kwarg (within this function) instead and register a new mask that way (a sketch of this alternative follows).
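A minimal sketch of the kwarg-based alternative vasqu describes; the parameter name mask_function and the simplified signature are hypothetical, not code from this PR:

# Hypothetical alternative: the caller chooses the mask type at registration
# time instead of the kernel declaring a MASK_FUNCTION constant.
def load_and_register_attn_kernel(attn_implementation, mask_function="flash_attention_2"):
    ...  # load the kernel from the hub and register its attention function
    ALL_MASK_ATTENTION_FUNCTIONS.register(attn_implementation, ALL_MASK_ATTENTION_FUNCTIONS[mask_function])

A model author could then call load_and_register_attn_kernel(..., mask_function="sdpa") without the kernel package changing at all.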

dacorvo (Contributor, Author) commented Apr 14, 2026:

@vasqu @danieldk how can we make progress on this one?

vasqu (Contributor) commented Apr 14, 2026:

@dacorvo Imo, the current solution is too reliant on the kernel having the mask included as a constant. Would it be possible to rewrite this so an optional kwarg could be used instead?
