
Fix Mllama torch.compile failure caused by new attention mask logic #44845

Closed
jiqing-feng wants to merge 6 commits into huggingface:main from jiqing-feng:compile


Conversation

@jiqing-feng
Contributor

What does this PR do?

Fixes torch.compile failure for Mllama after #42848 introduced a new unified attention mask creation path.

The root cause is a torch inductor C++ codegen bug: when padding_mask_function uses advanced tensor indexing (padding_mask[batch_idx, kv_idx]), the generated C++ boundary-check code references an undeclared variable (tmp2), causing g++ compilation to fail with CppCompileError.
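For reference, here is a minimal repro-style sketch of that pattern (assumed shapes, simplified names; not the actual transformers code). It only illustrates the broadcast advanced indexing that reportedly trips the inductor CPU backend:

```python
# Minimal sketch of the failing pattern (assumed shapes; simplified stand-in for
# padding_mask_function). The broadcast advanced indexing inside the callback is
# what reportedly makes the inductor CPU backend emit C++ that references an
# undeclared `tmp2`.
import torch

def padding_mask_function_like(padding_mask, batch_idx, kv_idx):
    # advanced indexing with two broadcast index tensors
    return padding_mask[batch_idx, kv_idx]

@torch.compile
def build_mask(padding_mask, batch, q_len, kv_len):
    batch_idx = torch.arange(batch).view(-1, 1, 1, 1)
    kv_idx = torch.arange(kv_len).view(1, 1, 1, -1)
    causal = torch.tril(torch.ones(q_len, kv_len, dtype=torch.bool))[None, None]
    return causal & padding_mask_function_like(padding_mask, batch_idx, kv_idx)

padding_mask = torch.ones(2, 16, dtype=torch.bool)
# build_mask(padding_mask, 2, 4, 16)  # reported to raise CppCompileError on the CPU backend
```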

This PR applies two changes:

  1. masking_utils.py: In the non-vmap sdpa_mask path, apply the padding mask separately using slice-based indexing (padding_mask[:, kv_offset : kv_offset + kv_length]) instead of merging it into the mask_function with advanced tensor indexing. This avoids the inductor codegen bug while producing identical results.

  2. modeling_mllama.py: Replace torch.arange-based fancy indexing with simple slice indexing when extracting cross_attention_mask and full_text_row_masked_out_mask for the current sequence position. This is semantically equivalent but more torch.compile-friendly (see the sketch after this list).
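For illustration, a small eager-mode sketch of that contrast (hypothetical tensor shapes and simplified names, not the actual modeling code):

```python
# Hedged illustration of the fancy-indexing vs. slice-indexing contrast
# (hypothetical shapes; variable names simplified). Both forms select the same
# rows for the current positions when cache_position is contiguous, but the
# slice form avoids torch.arange-based fancy indexing under torch.compile.
import torch

cross_attention_mask = torch.randn(1, 10, 2, 4)  # assumed (batch, seq_len, num_images, num_tiles) layout
cache_position = torch.arange(3, 5)               # current positions, assumed contiguous

# fancy indexing with an index tensor
fancy = cross_attention_mask[:, cache_position]

# equivalent slice-based indexing
start = int(cache_position[0])
sliced = cross_attention_mask[:, start : start + cache_position.shape[0]]

assert torch.equal(fancy, sliced)
```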

Fixes #44458

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
@jiqing-feng jiqing-feng marked this pull request as ready for review March 19, 2026 06:15
Comment on lines +526 to +530
# Apply padding mask separately using compile-friendly slice-based indexing
if _apply_padding_separately:
    attention_mask = attention_mask & padding_mask[:, kv_offset : kv_offset + kv_length].unsqueeze(
        1
    ).unsqueeze(1)
@vasqu
Contributor

This is super annoying tbh 😅 cuda seems less strict

I tried to take a look to avoid this (#44850) but mllama seems to be one of the weird models where batch sizes don't match across masks / inputs which might cause this. I am not really a fan of this because it essentially hides the problem of the model: Masks and inputs with different base shapes.

The way you used compile is also not the intended way to use compile with our generate logic --> you should pass a compile config and let us handle compilation.

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
@jiqing-feng
Contributor Author

Thanks for the review @vasqu

Re: compile usage - CompileConfig currently only supports CUDA and XPU devices (self.device.type in ["cuda", "xpu"]), so it skips compilation on CPU with a warning. Since the repro script runs on CPU, model.forward = torch.compile(model.forward) is the only way to trigger compilation. The underlying inductor codegen bug is the same regardless of how compile is invoked. Happy to update the repro script to use CompileConfig once CPU is supported or if a GPU repro is preferred.
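For context, a rough sketch of the CPU repro path described above (the checkpoint name, image, and prompt are placeholders, not taken from the original repro script):

```python
# Rough CPU repro sketch (placeholder checkpoint, image, and prompt; the real
# repro script may differ). CompileConfig is skipped on CPU, so forward is
# compiled directly, as described above.
import numpy as np
import torch
from PIL import Image
from transformers import AutoProcessor, MllamaForConditionalGeneration

model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"  # placeholder checkpoint
model = MllamaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.float32)
processor = AutoProcessor.from_pretrained(model_id)

# Manual compilation on CPU, since a passed CompileConfig would be skipped with a warning.
model.forward = torch.compile(model.forward)

image = Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8))
inputs = processor(images=image, text="<|image|>Describe this image.", return_tensors="pt")
out = model.generate(**inputs, max_new_tokens=8)  # compilation is triggered on the first forward
print(processor.decode(out[0], skip_special_tokens=True))
```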

Re: masking_utils.py change — This is a workaround for a torch inductor C++ codegen bug, not hiding a model-level problem. The issue is that padding_mask[batch_idx, kv_idx] (advanced indexing with two broadcast tensors) in padding_mask_function causes inductor to generate C++ boundary-check code referencing an undeclared variable tmp2, resulting in a g++ compilation failure. This happens regardless of whether batch sizes match — it's the advanced indexing pattern itself that triggers the codegen bug. The fix applies padding mask separately using slice-based indexing, which is semantically identical but generates simpler C++ code that doesn't trigger the bug. Only the non-vmap path is affected; the vmap path is unchanged.
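To make the "semantically identical" point concrete, a small eager-mode equivalence sketch (assumed shapes and simplified names; not the actual sdpa_mask code):

```python
# Eager-mode equivalence sketch (assumed shapes, simplified names). Merging the
# padding mask into the mask via broadcast advanced indexing and ANDing a slice
# afterwards produce the same boolean mask; only the compiled codegen differs.
import torch

batch, q_len, kv_len, kv_offset = 2, 4, 6, 0
padding_mask = torch.rand(batch, kv_len) > 0.3
causal = torch.tril(torch.ones(q_len, kv_len, dtype=torch.bool))[None, None]

# merged form: padding_mask[batch_idx, kv_idx] inside the mask function
batch_idx = torch.arange(batch).view(-1, 1, 1, 1)
kv_idx = torch.arange(kv_len).view(1, 1, 1, -1)
merged = causal & padding_mask[batch_idx, kv_idx]

# separate form: slice + unsqueeze, as in this PR
separate = causal & padding_mask[:, kv_offset : kv_offset + kv_len].unsqueeze(1).unsqueeze(1)

assert torch.equal(merged, separate)
```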

I also noticed your #44850 includes the same slice indexing fix for cross_attention_mask / full_text_row_masked_out_mask that I have in modeling_mllama.py. Happy to incorporate your reshape fix as well if needed.

@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: mllama

@vasqu
Contributor

vasqu commented Mar 25, 2026

Sorry but adding this workaround for just one model is not the way to go @jiqing-feng

  1. You can set other devices over here; CPU is just not guaranteed by default.
  2. I'm more for having this fixed in torch directly than adding a workaround for just 1 model (or we find a different fix within the modeling file instead).
  3. My PR was a quick experiment to check whether I could quickly fix it, but to no avail (first steps seemed to work but then crashed again).

Sorry about being strict here, but adding a workaround for just one model in something as core as the mask function is too much.

@jiqing-feng
Contributor Author

Hi @vasqu . Thanks for your clarification. I agree that we should fix it in pytorch. I've opened a pytorch issue pytorch/pytorch#178244 to track it. Feel free to close this PR.

@vasqu
Contributor

vasqu commented Mar 26, 2026

We can keep it open for now to have a reference for the torch team, but I'm moving it to draft instead and closing my PR.

@vasqu vasqu marked this pull request as draft March 26, 2026 13:01
@jiqing-feng jiqing-feng deleted the compile branch April 20, 2026 02:29
