Fix executorch export with dynamic shapes #41559

Closed
justinchuby wants to merge 12 commits into huggingface:main from justinchuby:patch-2

Conversation

@justinchuby (Contributor) commented Oct 14, 2025

What does this PR do?

This PR adds a shape compatibility check for the attention mask to ensure torch.export can reason about the logic without failing when exporting with dynamic shapes.

It additionally simplifies the sdpa forward function for export-only usage.
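For illustration (a minimal sketch, not code from this PR, with made-up shapes): torch._check is a plain runtime assert in eager mode, but during torch.export it also records the relationship between the sliced mask length and the key length so the dynamic-shapes engine can discharge later guards.

import torch

key = torch.randn(1, 2, 8, 4)          # (batch, kv_heads, kv_len, head_dim)
mask = torch.ones(1, 1, 1, 16)         # mask padded to a longer maximum length
mask = mask[:, :, :, : key.shape[-2]]  # trim the mask to the actual kv length
# In eager mode this is just an assert; under torch.export it additionally tells the
# tracer that the trimmed mask length equals the key length.
torch._check(mask.shape[-1] == key.shape[-2])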

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@Cyrilvallez @jackzhxng @guangy10

Add shape compatibility check for attention mask to ensure torch.export can reason about the logic without failing.
@justinchuby (Contributor, Author)

Is fx tracing still supported? It doesn’t seem to be compatible with torch._check here?

@Cyrilvallez (Member)

Hey @justinchuby! Sorry, but I don't get why this is needed. torch._check calls are very internal dynamo checks, and I don't see why we would have them here. Moreover, torch.export is only supported with a mask through the executorch integration, as otherwise the mask creation breaks due to using vmap.

@justinchuby (Contributor, Author)

@Cyrilvallez thanks for the suggestion - I can test with the executorch integration. For context, we are looking to enable onnx export via torch.export.

Could you share guidance on adding a potential onnx integration directory next to the current executorch directory, for correct usage of ONNX export calls? I assume we can follow a model similar to executorch, where we have an integration and a separate optimum-onnx project.

torch.export recently resolved the vmap issues, and the torch._check was useful for the tracer's dynamic-shapes engine to understand the equivalence. I agree that it should be part of the torch.export integration.

justinchuby changed the title from "Ensure attention mask shape matches key tensor" to "Fix executorch export with dynamic shapes" on Oct 16, 2025
@justinchuby (Contributor, Author)

I updated the PR to fix the executorch integration. Please take another look. Thanks!

@Cyrilvallez (Member)

torch.export recently resolved the vmap issues

Ouuhhh that's very very nice, I wasn't aware of it!

@Cyrilvallez (Member) left a comment

Nice! I like this solution much more, as it's contained to export-specific functionality only! However, I believe we can simplify it a lot!

Comment on lines +1208 to +1256
def sdpa_attention_forward_for_export(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    dropout: float = 0.0,
    scaling: Optional[float] = None,
    is_causal: Optional[bool] = None,
    **kwargs,
) -> tuple[torch.Tensor, None]:
    # This is the same as sdpa_attention_forward but simplified, with torch._check added
    # for torch.export dynamic shapes support
    if kwargs.get("output_attentions", False):
        logger.warning_once(
            "`sdpa` attention does not support `output_attentions=True`."
            " Please set your attention to `eager` if you want any of these features."
        )
    sdpa_kwargs = {}
    if hasattr(module, "num_key_value_groups"):
        # Always use enable_gqa for grouped query attention, which is supported by torch.export
        sdpa_kwargs = {"enable_gqa": True}

    if attention_mask is not None and attention_mask.ndim == 4:
        attention_mask = attention_mask[:, :, :, : key.shape[-2]]
        # torch._check used to inform torch.export of the shape relationship
        torch._check(attention_mask.shape[-1] == key.shape[-2])

    # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
    # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
    # Note that it is important to check first for the shape, otherwise compile will fail with `argument 'is_causal' must be bool, not SymBool`
    if is_causal is None:
        # The last condition is for encoder (decoder) models which specify this by passing their own `is_causal` flag
        # This is mainly due to those models having mixed implementations for encoder, decoder, and encoder-decoder attns
        is_causal = query.shape[2] > 1 and attention_mask is None and getattr(module, "is_causal", True)

    attn_output = torch.nn.functional.scaled_dot_product_attention(
        query,
        key,
        value,
        attn_mask=attention_mask,
        dropout_p=dropout,
        scale=scaling,
        is_causal=is_causal,
        **sdpa_kwargs,
    )
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, None
Member

Can we instead simply do the torch._check and then call sdpa_attention_forward? Would be much simpler!

Contributor Author

The check is needed because of this line: `attention_mask = attention_mask[:, :, :, : key.shape[-2]]`. Do you have a suggestion on how it can be avoided or maybe moved out? The check needs to follow this slice operation. The reason is that when slicing, and when key.shape[-2] is dynamic, torch.export doesn't know whether it is actually getting the exact full slice.

I also simplified the GQA logic to always use enable_gqa instead of repeat_interleave, because the option is now supported.
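For context, a small sketch of the difference (illustrative shapes, not code from this PR): with enable_gqa, SDPA accepts key/value tensors that still have the smaller number of KV heads, so the manual repeat_interleave expansion is unnecessary on PyTorch versions that support the flag.

import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 16, 64)  # 8 query heads
k = torch.randn(1, 2, 16, 64)  # 2 shared key/value heads (GQA)
v = torch.randn(1, 2, 16, 64)

# Manual expansion path, as repeat_interleave-based code does it.
out_repeat = F.scaled_dot_product_attention(
    q, k.repeat_interleave(4, dim=1), v.repeat_interleave(4, dim=1)
)
# enable_gqa path: SDPA handles the head grouping itself.
out_gqa = F.scaled_dot_product_attention(q, k, v, enable_gqa=True)

torch.testing.assert_close(out_repeat, out_gqa)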

Member

Ohhh, I see - actually this slicing op should not be here at all, i.e. the mask should be correctly prepared upstream, which is the case for all recent models using the nice mask primitives.
I wanted to check at some point whether we still have older models for which it's still necessary, here and in the eager_attention_forward attentions.
And if you start by slicing the mask, then _check, and then call sdpa_attention_forward, would it work? It would get re-sliced in the call, but to the exact same length; not sure if the _check would be lost then?
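A sketch of what that suggestion could look like (assuming sdpa_attention_forward keeps the signature shown in the diff above and lives in the same module; this is just the proposed shape, not final code):

from typing import Optional

import torch


# Assumes this sits next to the existing sdpa_attention_forward helper.
def sdpa_attention_forward_for_export(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    **kwargs,
) -> tuple[torch.Tensor, None]:
    if attention_mask is not None and attention_mask.ndim == 4:
        attention_mask = attention_mask[:, :, :, : key.shape[-2]]
        # Record the equality so torch.export's dynamic-shapes engine knows the
        # sliced mask length matches the key length.
        torch._check(attention_mask.shape[-1] == key.shape[-2])
    # Delegate to the regular implementation; its re-slice is then a same-length no-op.
    return sdpa_attention_forward(module, query, key, value, attention_mask, **kwargs)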

Contributor Author

And if you start by slicing the mask, then _check, and then call sdpa_attention_forward, would it work?

I can check that - but it is going to create a duplicated slice op, which is not ideal. I also want to make sure enable_gqa is used. Maybe for that we can update sdpa_attention_forward to always use enable_gqa instead?

Contributor Author

If we can get the slicing out of the forward function and ensure enable_gqa when exporting, then this patch can be simplified to a _check followed by a call to sdpa_attention_forward.

Contributor

What condition would you suggest I change to? Thanks

I don't have a good suggestion tbh; is there a way to tell we are tracing with export only? The issue is that we really can't expect people to only use export, and falling back to the math kernel is likely more expensive than the manual repeats.

Maybe I can instead remove the slice and see what breaks?

Responding here but saw the other PR. I think that's the right way but let's wait on @Cyrilvallez to come back (next week).

Contributor Author

Yes, we can use torch.compiler.is_exporting() to check that.

@vasqu (Contributor) commented Oct 30, 2025

Then let's add a condition on the GQA function, i.e. here:

def use_gqa_in_sdpa(attention_mask: Optional[torch.Tensor], key: torch.Tensor) -> bool:

to allow GQA in any case when we detect exporting. Is it limited to a specific version of export? Might need to double-check the torch version.

Side note, semi-relevant: the masks are refactored a bit in #41852, so we won't need a workaround for export in the future.
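A minimal sketch of that idea, assuming use_gqa_in_sdpa keeps the signature quoted above; the getattr guard covers older torch versions without torch.compiler.is_exporting, and the final return is only a placeholder for the existing heuristic, not the real condition:

from typing import Optional

import torch


def use_gqa_in_sdpa(attention_mask: Optional[torch.Tensor], key: torch.Tensor) -> bool:
    # When torch.export is tracing, always allow SDPA's enable_gqa path.
    if getattr(torch.compiler, "is_exporting", lambda: False)():
        return True
    # Placeholder for the existing eager/compile heuristics, which stay unchanged.
    return False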

Contributor Author

Great! I can create a separate update for use_gqa_in_sdpa. The most important fix has been removing the slice on masks.

@Cyrilvallez (Member)

BTW, we will probably add a non-vmap path to the sdpa mask creation due to vmap being much slower, with the impact being noticeable for small models (#41639). This PR is still welcome though, as its focus is on the sdpa attention, not the sdpa mask, but I thought I'd mention it just in case!

@Cyrilvallez (Member)

Is fx tracing still supported? It doesn’t seem to be compatible with torch._check here?

Concerning this, the answer is no: #41683!

@jiqing-feng (Contributor) commented Oct 23, 2025

I've tried this PR and found it didn't solve the performance regression in #41639. Will it be fixed in the next PR?

cc @justinchuby @Cyrilvallez

@justinchuby (Contributor, Author)

I've tried this PR and found it didn't solve the performance regression in #41639. Will it be fixed in the next PR?

cc @justinchuby @Cyrilvallez

I believe this PR is unrelated to the issue you linked.

justinchuby marked this pull request as draft on October 28, 2025 at 21:25
@justinchuby (Contributor, Author)

I think we can merge #41900, then I will create a separate PR to guard torch.export on dynamic shapes.

Comment thread on tests/test_executorch.py