Fix gemma4 has flash-attention incompatbile head-dim=512#45202
Fix gemma4 has flash-attention incompatbile head-dim=512#45202Qubitium wants to merge 10 commits intohuggingface:mainfrom
Conversation
| "sliding_attention", | ||
| "full_attention", | ||
| ] # similarly we want to test sharing on both types | ||
| self.global_head_dim = self.head_dim # gemma4 use a different head_dim for full and sliding layers |
There was a problem hiding this comment.
Why was this overriden? This caused tests to pass but does not reflect reality.
My fix broke more gemma4 tests which assumed single head-dim for all the layers when all the gemm4 models switch between head-dim sizes in different layers. Fixing these failed tests now. |
| @unittest.skip("Float8 quantization + TP numerical noise exceeds match threshold") | ||
| def test_tp_generation_quantized(self): | ||
| pass |
There was a problem hiding this comment.
FP8 quantized is failing ci due to very low accuracy of output post-quantization. The fix is outside the scope of this minimal (disable FA2 support in Gemma4) pr.
| tensor_parallel_atol = 2e-4 | ||
| tensor_parallel_rtol = 2e-4 |
There was a problem hiding this comment.
Fixes ci test failing due to too strict of tolerance. Verified fix on A100.
|
All relevant Gemm4 ci tests passing. The failed tests is some tokenizer issue appears to be unrelated. Ready for full review and merge. The actual fix is 2 lines. The rest are fixing bad unit test assumptions. |
Gemma4 cannot use flash_attention_2 (global layers have head_dim=512, FA2 max is 256 — huggingface/transformers#45202). SDPA has a known NaN bug with sliding window + padding (huggingface/transformers#32390) that was fixed for Gemma2 but not ported to Gemma3/4. Eager attention handles all padding correctly. Also add padding_side: right to the 4 original configs that were missing it (the mock configs already had it). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Gemma4 cannot use flash_attention_2 (global layers have head_dim=512, FA2 max is 256 — huggingface/transformers#45202). SDPA has a known NaN bug with sliding window + padding (huggingface/transformers#32390) that was fixed for Gemma2 but not ported to Gemma3/4. Eager attention handles all padding correctly. Also add padding_side: right to the 4 original configs that were missing it (the mock configs already had it). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
|
A lot of has changed with the recent changes on gemma4 @Qubitium, I think the best way would be to maybe really have this layer-wise solution where we set granularly what attention type we want. It is non-trivial tho because it needs to touch some core code Dao-AILab/flash-attention#2427 suggests that FA3 and FA4 will at least soon ish have support for these head dims. So it might be smarter to wait for these instead and then only allow those FA versions to be used |
|
Tl;dr: Not sure about the timeline, if FA3/4 is faster I'd rather change the scope of allowed FAs then have the layer wise solution (because it has a few nontrivial things imo) |
Agree on the per layer control now that gemma has dual attention type layers. Another issue is that the current on/off FA toggle has no FA version scope so it is no longer whether it supports FA but if what versions are allowed for this model/layer config. |
|
We have transformers/src/transformers/modeling_utils.py Lines 1143 to 1144 in c585eea which allows to specify the compatible FAs to be used (in the future). |
Nice! I see you are the one who added the new toggle 2 months. =) In that case, we can just remove the boolean state and just only use the full control map? |
|
It's more of an accompanying flag: If it is set to None, we assume general FA support for all flavors, otherwise only for those listed there (and we raise an error if it is not compatible). |
Forward thinking wise, perhaps we should just make it generic and use enum mask/list to store all the compatible attention? FA is only one of many that will exist so instead of |
Cyrilvallez
left a comment
There was a problem hiding this comment.
Hey @Qubitium! Setting both head_dim to the same value in tests is simply an easy way to have usual tests pass as they expect the same shapes all the time. There's basically no way that the code can break or see a regression for different head_dim per layer, so that's why. Those tests run on a small version of the model, and FA technically works perfectly for gemma4 (only on released checkpoints, the head_dim is too large on full layers), so even if we updated the tests, they would still pass as we don't want a head_dim of 512 on the small models. It's a bit of an awkward situation, where the model supports FA, but FA itself does not support the values they used in real checkpoints.
Of course I thought about setting the flag to False, but I decided not to in the end (at least for now), as it's not a Transformers problem, it's a FA/checkpoints problem.
Also, I'm not in favor of adding a complicated per-layer attention system, as this will likely only ever be used for gemma4, and FA kernels will soon be updated to reflect gemma4's 512 head_dim, making the effort useless. It will only bloat attention dispatch IMO. Sdpa should not be really slower than FA in most cases.
| @unittest.skip("Gemma4 multimodal tiny test config exceeds the 1M common-test size cap") | ||
| def test_model_is_small(self): | ||
| pass |
There was a problem hiding this comment.
See this, we really cannot afford it on our CI, that's why we use smaller head dims
I have reverted all test related changes as you have shown the tests are valid in their ctx. My remaining point is only about the runtime contract exposed to users today. Released Gemma4 checkpoints use full-attention layers with global_head_dim=512, and currently released fa kernels do not support this. In this present tense situation, advertising Gemma4 as fa compatible is misconfiguration, even if future fa releases will fix it. I agree a blanket global disable is not ideal and FA may release a fix tomorrow for all we know. But until a released flash kernel supports Gemma4’s published checkpoint, the current |
* feat: add mock VLM dataset and Gemma4 pretokenize support Add build_mock_vlm_dataset for VLM benchmarking and testing without real data downloads. Generates random PIL images + dummy text in the standard Automodel conversation format. Key changes: - New mock VLM dataset with max_length-driven response generation - PreTokenizedDatasetWrapper: add truncate mode (labels built before truncation) and Gemma4 tensor support (image_position_ids, mm_token_type_ids) - pad_collate_fn: handle mm_token_type_ids padding and image_position_ids concatenation for Gemma4 - Recipe: auto-enable pretokenize/truncate when max_length is set on the dataset config - Gemma4 4B mock config with right-padding (fixes SDPA + left-padding NaN gradient bug) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: HuiyingLi <willwin.lee@gmail.com> * fix: rename test_mock.py to test_mock_vlm.py to avoid pytest collection conflict pytest uses flat module names when __init__.py is absent, causing a name collision between tests/unit_tests/datasets/llm/test_mock.py and tests/unit_tests/datasets/vlm/test_mock.py. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: HuiyingLi <willwin.lee@gmail.com> * fix: handle 1D tensor truncation and add truncate mode test Address Claude review comments: - Truncation dict comprehension now also slices 1D tensors whose length matches seq_len (e.g. mm_token_type_ids returned as (seq_len,)) - Add test_pretokenized_wrapper_truncate_mode to verify truncation clips input_ids, labels, attention_mask, and 1D mm_token_type_ids Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: HuiyingLi <willwin.lee@gmail.com> * fix: correct mock patch paths for local imports in truncate test Local imports inside __getitem__ bind from source modules at call time, so patches must target the source (collate_fns, fake_image) not the datasets module. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: HuiyingLi <willwin.lee@gmail.com> * fix: remove unreliable label-content assertion from truncate test The build_labels_from_template mock cannot reliably override the local import inside __getitem__. Drop the label-content check and keep the shape-based assertions which verify truncation works correctly. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: HuiyingLi <willwin.lee@gmail.com> * fix: remove PreTokenizedDatasetWrapper truncate test The test requires mocking too many internal imports (local imports inside __getitem__) to work reliably in CI. The 1D tensor truncation fix is a straightforward change covered by GPU integration tests. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: HuiyingLi <willwin.lee@gmail.com> * fix: switch Gemma4 configs to eager attention and right-padding Gemma4 cannot use flash_attention_2 (global layers have head_dim=512, FA2 max is 256 — huggingface/transformers#45202). SDPA has a known NaN bug with sliding window + padding (huggingface/transformers#32390) that was fixed for Gemma2 but not ported to Gemma3/4. Eager attention handles all padding correctly. Also add padding_side: right to the 4 original configs that were missing it (the mock configs already had it). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: HuiyingLi <willwin.lee@gmail.com> * fix: remove activation_checkpointing from mock configs to match originals The original gemma4_4b.yaml and gemma4_26b_a4b_moe.yaml configs do not use activation checkpointing, so the mock variants should not either. Also bump 4B mock nproc-per-node from 2 to 8 (needed without AC at seq_len=2048). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: HuiyingLi <willwin.lee@gmail.com> * test: add unit tests for PreTokenizedDatasetWrapper truncation and pad_collate_fn mm_token_type_ids - Test that truncate=True produces exact max_length shapes for input_ids, attention_mask, and labels - Test that labels are not all -100 after truncation - Test that mm_token_type_ids (1D) is truncated correctly - Test that pad_collate_fn pads and trims mm_token_type_ids with autoregressive shift for both 1D and 2D inputs Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: HuiyingLi <willwin.lee@gmail.com> * fix: assert mm_token_type_ids presence instead of conditional guard Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: HuiyingLi <willwin.lee@gmail.com> * fix: mock internal helpers in PreTokenizedDatasetWrapper tests The wrapper's __getitem__ calls _preload_media, _conversation_has_media, build_labels_from_template, etc. which need a full processor. Mock these internals so the tests can run without a real HF processor/model. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: HuiyingLi <willwin.lee@gmail.com> * fix: place mock labels in first quarter so they survive truncation Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: HuiyingLi <willwin.lee@gmail.com> * fix: compare inputs_embeds generate against input_ids generate, not uncached decode The Mamba mixer uses different CUDA kernels for cached (chunk_scan + selective_state_update) vs uncached (split_conv1d_scan_combined) paths. These are mathematically equivalent but not bit-identical in bf16, causing token divergence after the first step. Compare against generate(input_ids=...) instead, which uses the same cached kernel path. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: HuiyingLi <willwin.lee@gmail.com> * revert Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com> --------- Signed-off-by: HuiyingLi <willwin.lee@gmail.com> Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com> Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Co-authored-by: Alexandros Koumparoulis <153118171+akoumpa@users.noreply.github.com> Co-authored-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
|
[For maintainers] Suggested jobs to run (before merge) run-slow: gemma4 |
* feat: add mock VLM dataset and Gemma4 pretokenize support Add build_mock_vlm_dataset for VLM benchmarking and testing without real data downloads. Generates random PIL images + dummy text in the standard Automodel conversation format. Key changes: - New mock VLM dataset with max_length-driven response generation - PreTokenizedDatasetWrapper: add truncate mode (labels built before truncation) and Gemma4 tensor support (image_position_ids, mm_token_type_ids) - pad_collate_fn: handle mm_token_type_ids padding and image_position_ids concatenation for Gemma4 - Recipe: auto-enable pretokenize/truncate when max_length is set on the dataset config - Gemma4 4B mock config with right-padding (fixes SDPA + left-padding NaN gradient bug) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: HuiyingLi <willwin.lee@gmail.com> * fix: rename test_mock.py to test_mock_vlm.py to avoid pytest collection conflict pytest uses flat module names when __init__.py is absent, causing a name collision between tests/unit_tests/datasets/llm/test_mock.py and tests/unit_tests/datasets/vlm/test_mock.py. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: HuiyingLi <willwin.lee@gmail.com> * fix: handle 1D tensor truncation and add truncate mode test Address Claude review comments: - Truncation dict comprehension now also slices 1D tensors whose length matches seq_len (e.g. mm_token_type_ids returned as (seq_len,)) - Add test_pretokenized_wrapper_truncate_mode to verify truncation clips input_ids, labels, attention_mask, and 1D mm_token_type_ids Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: HuiyingLi <willwin.lee@gmail.com> * fix: correct mock patch paths for local imports in truncate test Local imports inside __getitem__ bind from source modules at call time, so patches must target the source (collate_fns, fake_image) not the datasets module. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: HuiyingLi <willwin.lee@gmail.com> * fix: remove unreliable label-content assertion from truncate test The build_labels_from_template mock cannot reliably override the local import inside __getitem__. Drop the label-content check and keep the shape-based assertions which verify truncation works correctly. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: HuiyingLi <willwin.lee@gmail.com> * fix: remove PreTokenizedDatasetWrapper truncate test The test requires mocking too many internal imports (local imports inside __getitem__) to work reliably in CI. The 1D tensor truncation fix is a straightforward change covered by GPU integration tests. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: HuiyingLi <willwin.lee@gmail.com> * fix: switch Gemma4 configs to eager attention and right-padding Gemma4 cannot use flash_attention_2 (global layers have head_dim=512, FA2 max is 256 — huggingface/transformers#45202). SDPA has a known NaN bug with sliding window + padding (huggingface/transformers#32390) that was fixed for Gemma2 but not ported to Gemma3/4. Eager attention handles all padding correctly. Also add padding_side: right to the 4 original configs that were missing it (the mock configs already had it). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: HuiyingLi <willwin.lee@gmail.com> * fix: remove activation_checkpointing from mock configs to match originals The original gemma4_4b.yaml and gemma4_26b_a4b_moe.yaml configs do not use activation checkpointing, so the mock variants should not either. Also bump 4B mock nproc-per-node from 2 to 8 (needed without AC at seq_len=2048). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: HuiyingLi <willwin.lee@gmail.com> * test: add unit tests for PreTokenizedDatasetWrapper truncation and pad_collate_fn mm_token_type_ids - Test that truncate=True produces exact max_length shapes for input_ids, attention_mask, and labels - Test that labels are not all -100 after truncation - Test that mm_token_type_ids (1D) is truncated correctly - Test that pad_collate_fn pads and trims mm_token_type_ids with autoregressive shift for both 1D and 2D inputs Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: HuiyingLi <willwin.lee@gmail.com> * fix: assert mm_token_type_ids presence instead of conditional guard Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: HuiyingLi <willwin.lee@gmail.com> * fix: mock internal helpers in PreTokenizedDatasetWrapper tests The wrapper's __getitem__ calls _preload_media, _conversation_has_media, build_labels_from_template, etc. which need a full processor. Mock these internals so the tests can run without a real HF processor/model. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: HuiyingLi <willwin.lee@gmail.com> * fix: place mock labels in first quarter so they survive truncation Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: HuiyingLi <willwin.lee@gmail.com> * fix: compare inputs_embeds generate against input_ids generate, not uncached decode The Mamba mixer uses different CUDA kernels for cached (chunk_scan + selective_state_update) vs uncached (split_conv1d_scan_combined) paths. These are mathematically equivalent but not bit-identical in bf16, causing token divergence after the first step. Compare against generate(input_ids=...) instead, which uses the same cached kernel path. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: HuiyingLi <willwin.lee@gmail.com> * revert Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com> --------- Signed-off-by: HuiyingLi <willwin.lee@gmail.com> Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com> Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Co-authored-by: Alexandros Koumparoulis <153118171+akoumpa@users.noreply.github.com> Co-authored-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
What does this PR do?
Disable FlashAttention support for Gemm4 which FA cannot suport due to global.head-dim=512.
I am very confused at the current code/test for Gemma4. I ran real inference using transformer
mainandfathrows head-dim errors yet I see there aretestsin transformers with comments noting thatFA2generatinggibberishand forcing changing head-dim value so the tests may pass. Is the Gemma4 models uploaded by Google have brokenconfig.jsonwhen it comes tohead-dimor did Transformers was testing on an unreleased/alpha version of Gemma4 or the test code was just bad? The tests should have never passed.Bad testing code force the the global_head_dim to match self.head_dim which allowed FA to run. Again, even the notes mentioned garbage output with FA2 enabled.
# TODO: raushan FA2 generates gibberish for no reason, check laterBefore submitting
Pull Request section?
to it if that's the case.
documentation guidelines, and
here are tips on formatting docstrings.
Who can review?
@ArthurZucker @Cyrilvallez @vasqu
@douglas-reid