Fix Mllama torch.compile failure caused by new attention mask logic #44845
jiqing-feng wants to merge 6 commits into huggingface:main
Conversation
```python
# Apply padding mask separately using compile-friendly slice-based indexing
if _apply_padding_separately:
    attention_mask = attention_mask & padding_mask[:, kv_offset : kv_offset + kv_length].unsqueeze(
        1
    ).unsqueeze(1)
```
This is super annoying tbh 😅 CUDA seems less strict.
I tried to take a look at avoiding this (#44850), but mllama seems to be one of the weird models where batch sizes don't match across masks / inputs, which might cause this. I am not really a fan of this approach because it essentially hides the problem of the model: masks and inputs with different base shapes.
The way you used compile is also not the intended way to use compile with our generate logic --> you should pass a compile config and let us handle compilation.
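For context, the intended pattern looks roughly like the following. This is a minimal sketch assuming a recent transformers version that exposes `CompileConfig`; the checkpoint id and generation arguments are placeholders, not taken from this PR:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, CompileConfig

# Placeholder checkpoint; any decoder model that supports a static cache works similarly.
model_id = "meta-llama/Llama-3.2-1B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# Attach a compile config and let generate() drive torch.compile internally,
# rather than wrapping model.forward in torch.compile by hand.
model.generation_config.compile_config = CompileConfig(mode="reduce-overhead")

inputs = tokenizer("Hello", return_tensors="pt")
# Compilation kicks in when a static cache implementation is used.
out = model.generate(**inputs, max_new_tokens=20, cache_implementation="static")
print(tokenizer.decode(out[0], skip_special_tokens=True))
```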
Thanks for the review @vasqu.

Re: compile usage:

Re: `masking_utils.py` change: this is a workaround for a torch inductor C++ codegen bug, not hiding a model-level problem. The issue is that the generated C++ boundary-check code references an undeclared variable (`tmp2`), so `g++` compilation fails.

I also noticed your #44850 includes the same slice indexing fix.
[For maintainers] Suggested jobs to run (before merge): run-slow: mllama
Sorry, but adding this workaround for just one model is not the way to go @jiqing-feng.
Sorry about being strict here, but adding a workaround for just one model in something as core as the mask function is too much.
Hi @vasqu, thanks for your clarification. I agree that we should fix it in PyTorch; I've opened a PyTorch issue, pytorch/pytorch#178244, to track it. Feel free to close this PR.
We can keep it open for now as a reference for the torch team, but I'm moving it to draft instead and closing my PR.
What does this PR do?
Fixes `torch.compile` failure for Mllama after #42848 introduced a new unified attention mask creation path.

The root cause is a torch inductor C++ codegen bug: when `padding_mask_function` uses advanced tensor indexing (`padding_mask[batch_idx, kv_idx]`), the generated C++ boundary-check code references an undeclared variable (`tmp2`), causing `g++` compilation to fail with `CppCompileError`.

This PR applies two changes:
- `masking_utils.py`: in the non-vmap `sdpa_mask` path, apply the padding mask separately using slice-based indexing (`padding_mask[:, kv_offset : kv_offset + kv_length]`) instead of merging it into the `mask_function` with advanced tensor indexing. This avoids the inductor codegen bug while producing identical results (see the sketch after this list).
- `modeling_mllama.py`: replace `torch.arange`-based fancy indexing with simple slice indexing when extracting `cross_attention_mask` and `full_text_row_masked_out_mask` for the current sequence position. This is semantically equivalent but more `torch.compile`-friendly.

Fixes #44458
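For illustration, here is a minimal sketch of the equivalence behind both changes; all shapes, names, and values below are made up for the example and are not taken from the actual model code:

```python
import torch

# Illustrative shapes only; names mirror the PR description, values are made up.
batch, heads, q_len = 2, 1, 1
kv_total, kv_offset, kv_length = 16, 4, 8
padding_mask = torch.rand(batch, kv_total) > 0.2
attention_mask = torch.ones(batch, heads, q_len, kv_length, dtype=torch.bool)

# (1) masking_utils.py: advanced tensor indexing vs. the slice-based form.
batch_idx = torch.arange(batch).view(-1, 1)
kv_idx = torch.arange(kv_offset, kv_offset + kv_length).view(1, -1)
fancy = attention_mask & padding_mask[batch_idx, kv_idx][:, None, None, :]
sliced = attention_mask & padding_mask[:, kv_offset : kv_offset + kv_length][:, None, None, :]
assert torch.equal(fancy, sliced)

# (2) modeling_mllama.py: selecting the current position with a slice
# instead of torch.arange-based fancy indexing.
cross_attention_mask = torch.rand(batch, heads, 10, kv_length)
pos = 7
via_arange = cross_attention_mask[:, :, torch.arange(pos, pos + 1), :]
via_slice = cross_attention_mask[:, :, pos : pos + 1, :]
assert torch.equal(via_arange, via_slice)
```

Both pairs produce identical tensors; the slice forms simply avoid the advanced-indexing paths that trigger the inductor C++ codegen bug described above.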