Gemma4: fix failed test cases #45568

Open

kaixuanliu wants to merge 8 commits into huggingface:main from kaixuanliu:gemma4-fix

Conversation

@kaixuanliu (Contributor) commented Apr 22, 2026

What does this PR do?

This PR does several things:

  1. Skip some test cases that are not suitable for the Gemma4 model
  2. Fix a bug when attention_mask is None (tests/models/gemma4/test_modeling_gemma4.py::Gemma4Audio2TextModelTest::test_eager_matches_fa2_generate)
  3. Fix some failed test cases related to test_flash_attn_x_from_config
  4. Add XPU-related Expectations (see the sketch below)
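
A rough sketch of what an XPU entry might look like, assuming the Expectations helper in transformers.testing_utils (the device keys and strings here are illustrative, not the actual values added by this PR):

import torch
from transformers.testing_utils import Expectations

# Expectations maps (device_type, major_version) keys to expected values,
# so XPU runners can carry their own reference outputs alongside CUDA ones.
expected_texts = Expectations(
    {
        ("cuda", 8): "expected output on A100-class GPUs",  # illustrative
        ("xpu", 3): "expected output on Intel XPU",         # illustrative
    }
)
EXPECTED_TEXT = expected_texts.get_expectation()  # picks the entry for the current device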

Fixes # (issue)

Code Agent Policy

  • I confirm that this is not a pure code agent PR.

Who can review?

@ydshieh, please help review.

@kaixuanliu changed the title from "Gemma4 fix" to "Gemma4: fix failed test cases" on Apr 22, 2026
@kaixuanliu marked this pull request as ready for review on April 22, 2026 at 09:25
@github-actions (Contributor) commented:

[For maintainers] Suggested jobs to run (before merge)

run-slow: gemma4

Comment on lines +1944 to +1945
if attention_mask is not None:
    attention_mask = self._convert_4d_mask_to_blocked_5d(attention_mask)
Collaborator commented:

@Cyrilvallez any opinion?

From the PR description:

  Fix a bug when attention_mask is None (tests/models/gemma4/test_modeling_gemma4.py::Gemma4Audio2TextModelTest::test_eager_matches_fa2_generate)
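
For context, a minimal sketch of the failure mode being fixed (names are illustrative stand-ins, not the actual Gemma4 modeling code): generate() under FA2 can run without a padding mask, so the conversion has to be guarded against None.

import torch

def convert_4d_mask_to_blocked_5d(attention_mask: torch.Tensor) -> torch.Tensor:
    # Illustrative stand-in for the real helper: any tensor op here
    # fails on a None input (None has no .shape).
    bsz, heads, q_len, kv_len = attention_mask.shape
    return attention_mask.view(bsz, heads, 1, q_len, kv_len)

def maybe_convert(attention_mask):
    # Mirrors the guard in the diff above: attention_mask is None when
    # generate() runs without padding (as in test_eager_matches_fa2_generate),
    # so only convert when a mask was actually provided.
    if attention_mask is not None:
        attention_mask = convert_4d_mask_to_blocked_5d(attention_mask)
    return attention_mask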

Comment on lines +135 to +147
@unittest.skip(
    "Under non-bf16 dtypes, MoE grouped_mm falls back to "
    "_grouped_mm_fallback_backward which is incompatible with torch.compile."
)
def test_flash_attn_2_can_compile_with_attention_mask_None_without_graph_break(self):
    pass

@unittest.skip(
    "Under non-bf16 dtypes, MoE grouped_mm falls back to "
    "_grouped_mm_fallback_backward which is incompatible with torch.compile."
)
def test_torch_compile_for_training(self):
    pass
Collaborator commented:

OK for me. Just let one of @Cyrilvallez or @vasqu also validate or comment.

Collaborator commented:

They are indeed failing on our CI too

Contributor commented:

Hmm, IIRC the fallback should be compile-compatible. cc @IlyasMoutawwakil

Contributor commented:

Which torch version is the CI on now, by the way?

Member commented:

It is torch.compile-compatible, just not in any mode that uses CUDA graphs (like max-autotune). torch.grouped_mm also fails on < sm90, where it likewise takes the fallback path.
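
A rough illustration of that constraint (hypothetical helper; the sm90 threshold comes from the comment above): the fast grouped_mm path needs a Hopper-class GPU, so anything relying on CUDA-graph compile modes should be gated on compute capability.

import torch

def grouped_mm_fast_path_available() -> bool:
    # Hypothetical gate: assume the fast grouped_mm kernel requires
    # compute capability >= 9.0 (sm90 / Hopper); older GPUs take the
    # fallback path, which breaks cuda-graph modes like max-autotune.
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability()
    return major >= 9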

Comment on lines +464 to +478
@require_flash_attn
@require_torch_accelerator
@mark.flash_attn_test
@slow
def test_flash_attn_2_from_config(self):
    # Gemma4 requires mm_token_type_ids in train mode, so we test in eval mode
    self.flash_attn_from_config(attn_implementation="flash_attention_2", test_fwd_in_train=False)

@require_flash_attn_3
@require_torch_gpu
@mark.flash_attn_3_test
@slow
def test_flash_attn_3_from_config(self):
    # Gemma4 requires mm_token_type_ids in train mode, so we test in eval mode
    self.flash_attn_from_config(attn_implementation="flash_attention_3", test_fwd_in_train=False)
Collaborator commented:

@kaixuanliu I didn't see these two failing on our Flash Attn CI job.

Could you share more info / error logs?

Contributor commented:

Our flash attn CI doesn't have FA3; I think it's hard to install because you need to compile from source, and it takes much longer than the FA2 build from source.

Maybe we could add a separate FA4 CI; not sure how stable it is though, since it's still in beta.

Contributor (Author) commented:

Well, FA3 and FA4 are skipped in my env as well. I can delete these two.

Contributor commented:

Ah no, see my comment below #45568 (comment)

Collaborator commented:

No, I mean for

test_flash_attn_2_from_config

our CI passes. So I am not sure why we need this fix, at least for FA2.

Our CI runner doesn't have FA3 or FA4, so those are skipped. But the question may still be valid: do we really need this fix?

    pass

@unittest.skip("The base test does not pass image_position_ids and mm_token_type_ids required by Gemma4")
def test_flash_attn_4_inference_equivalence_right_padding(self):
Contributor commented:

Can we have something like

def skip_non_greedy_generate(self):
    skippable_tests = [
        "test_sample_generate_dict_output",  # return sequences > 1
        "test_beam",
        "test_contrastive",
        "test_assisted",
        "test_prompt_lookup",
        "test_model_parallel_beam_search",
        "test_generate_without_input_ids",
    ]
    for test in skippable_tests:
        if self._testMethodName.startswith(test):
            self.skipTest(reason="Dia only supports greedy search / sampling with one sequence.")

but for FA? IMO it will always be quite a lot to skip these manually like that.
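
One possible shape for that, as a sketch only (the helper name and prefix list are assumptions, echoing the test names quoted in this thread rather than an actual Gemma4 test suite):

def skip_fa_tests_missing_multimodal_inputs(self):
    # Hypothetical helper mirroring skip_non_greedy_generate above: skip
    # FA tests whose base implementation does not pass the
    # image_position_ids / mm_token_type_ids inputs Gemma4 requires.
    skippable_prefixes = [
        "test_flash_attn_2_inference_equivalence",  # assumed prefix
        "test_flash_attn_3_inference_equivalence",  # assumed prefix
        "test_flash_attn_4_inference_equivalence",  # assumed prefix
    ]
    for prefix in skippable_prefixes:
        if self._testMethodName.startswith(prefix):
            self.skipTest(reason="Base FA test does not pass image_position_ids and mm_token_type_ids required by Gemma4.")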
