Fix executorch export with dynamic shapes #41559
justinchuby wants to merge 12 commits into huggingface:main
Conversation
Add a shape compatibility check for the attention mask to ensure torch.export can reason about the logic without failing.
Is fx tracing still supported? It doesn’t seem to be compatible with torch._check here?
Hey @justinchuby! Sorry but I don't get why this is needed.
@Cyrilvallez thanks for the suggestion - I can test with the executorch integration. For context, we are looking to enable onnx export via torch.export. Could you share guidance on a potential "onnx" integration directory next to the current executorch directory, for correct usage of onnx export calls? I assume we can follow a similar model to executorch, where we have an integration and a separate optimum-onnx project. torch.export recently resolved the vmap issues, and torch._check was useful for the tracer's dynamic shapes engine to understand the equivalence. I agree that it should be part of the torch.export integration.
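For readers following along, roughly the flow being described looks like the sketch below (illustrative only: the checkpoint id is a placeholder, and whether a given model exports cleanly with dynamic shapes is exactly what the fixes discussed in this PR affect; `dynamo=True` needs a recent PyTorch).

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder checkpoint; any decoder-only model using sdpa attention would do.
model_id = "hf-internal-testing/tiny-random-LlamaForCausalLM"
model = AutoModelForCausalLM.from_pretrained(model_id).eval()
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer("Hello world", return_tensors="pt")

# Mark the sequence dimension as dynamic so the exported graph is not
# specialized to this particular prompt length.
seq = torch.export.Dim("sequence", max=512)
dynamic_shapes = {"input_ids": {1: seq}, "attention_mask": {1: seq}}

# torch.onnx.export with dynamo=True goes through torch.export under the hood,
# which is where the dynamic-shape reasoning discussed in this PR matters.
onnx_program = torch.onnx.export(
    model,
    args=(),
    kwargs={"input_ids": inputs["input_ids"], "attention_mask": inputs["attention_mask"]},
    dynamic_shapes=dynamic_shapes,
    dynamo=True,
)
onnx_program.save("model.onnx")
```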
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
I updated the PR to fix the executorch integration. Please take another look. Thanks!
Ouuhhh that's very very nice, I wasn't aware of it! |
```python
def sdpa_attention_forward_for_export(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    dropout: float = 0.0,
    scaling: Optional[float] = None,
    is_causal: Optional[bool] = None,
    **kwargs,
) -> tuple[torch.Tensor, None]:
    # This is the same as sdpa_attention_forward, but simplified and with torch._check added
    # for torch.export dynamic shapes support
    if kwargs.get("output_attentions", False):
        logger.warning_once(
            "`sdpa` attention does not support `output_attentions=True`."
            " Please set your attention to `eager` if you want any of these features."
        )
    sdpa_kwargs = {}
    if hasattr(module, "num_key_value_groups"):
        # Always use enable_gqa for grouped query attention, which is supported by torch.export
        sdpa_kwargs = {"enable_gqa": True}

    if attention_mask is not None and attention_mask.ndim == 4:
        attention_mask = attention_mask[:, :, :, : key.shape[-2]]
        # torch._check is used to inform torch.export of the shape relationship
        torch._check(attention_mask.shape[-1] == key.shape[-2])

    # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
    # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
    # Note that it is important to check first for the shape, otherwise compile will fail with `argument 'is_causal' must be bool, not SymBool`
    if is_causal is None:
        # The last condition is for encoder (decoder) models which specify this by passing their own `is_causal` flag
        # This is mainly due to those models having mixed implementations for encoder, decoder, and encoder-decoder attns
        is_causal = query.shape[2] > 1 and attention_mask is None and getattr(module, "is_causal", True)

    attn_output = torch.nn.functional.scaled_dot_product_attention(
        query,
        key,
        value,
        attn_mask=attention_mask,
        dropout_p=dropout,
        scale=scaling,
        is_causal=is_causal,
        **sdpa_kwargs,
    )
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, None
```
Can we instead simply do the torch._check and then call sdpa_attention_forward? Would be much simpler!
The check is needed because of this line: `attention_mask = attention_mask[:, :, :, : key.shape[-2]]`. Do you have a suggestion on how it can be avoided or moved out, maybe? The check needs to follow this slice operation. The reason is that when slicing, and when key.shape[-2] is dynamic, torch.export wouldn't know whether it is actually getting the exact full slice.
I also simplified the gqa logic to always use enable_gqa instead of repeat_interleave because the option is now supported.
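To make the pattern concrete, here is a stripped-down, self-contained sketch (toy module and shapes, not the actual transformers code): the mask arrives with its own kv-length dimension, gets sliced down to the key length, and the torch._check tells export that the two lengths coincide so downstream shape checks can be resolved. This may or may not reproduce the exact guard failure depending on the torch version, but it shows the shape relationship being asserted.

```python
import torch


class ToyMaskedAttention(torch.nn.Module):
    def forward(self, query, key, value, attention_mask):
        # Same pattern as in the diff above: the mask comes in with its own
        # kv-length dimension and is sliced down to the key length.
        attention_mask = attention_mask[:, :, :, : key.shape[-2]]
        # Without this, export only knows the sliced length is at most the key length;
        # torch._check records that they are in fact equal.
        torch._check(attention_mask.shape[-1] == key.shape[-2])
        return torch.nn.functional.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask
        )


batch, heads, seq, head_dim = 1, 4, 8, 16
example_inputs = (
    torch.randn(batch, heads, seq, head_dim),          # query
    torch.randn(batch, heads, seq, head_dim),          # key
    torch.randn(batch, heads, seq, head_dim),          # value
    torch.ones(batch, 1, seq, seq, dtype=torch.bool),  # 4D attention mask
)
seq_dim = torch.export.Dim("seq", max=128)
kv_dim = torch.export.Dim("kv", max=128)
dynamic_shapes = ({2: seq_dim}, {2: seq_dim}, {2: seq_dim}, {2: seq_dim, 3: kv_dim})
exported = torch.export.export(ToyMaskedAttention(), example_inputs, dynamic_shapes=dynamic_shapes)
print(exported)
```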
Ohhh, I see - actually this slicing op should not be here at all, i.e. the mask should be correctly prepared upstream, which is the case for all recent models using the nice mask primitives.
I've been meaning to check at some point whether we still have older models for which it's still necessary, here and in the eager_attention_forward attentions.
And if you start by slicing the mask, then _check, and then call sdpa_attention_forward, would it work? It would get re-sliced in the call, but to the exact same length, not sure if the _check would be lost then?
> And if you start by slicing the mask, then _check, and then call sdpa_attention_forward, would it work?
I can check that - but it is going to create a duplicated slice op which is not ideal. I also want to make sure enable_gqa is used. Maybe for that we can update sdpa_attention_forward to always use enable_gqa instead?
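For reference, a quick self-contained check (toy shapes; `enable_gqa` requires a sufficiently recent PyTorch) that SDPA's native GQA path matches the manual KV-head expansion that repeat_interleave performs:

```python
import torch
import torch.nn.functional as F

batch, q_heads, kv_heads, seq, head_dim = 2, 8, 2, 16, 64
query = torch.randn(batch, q_heads, seq, head_dim)
key = torch.randn(batch, kv_heads, seq, head_dim)
value = torch.randn(batch, kv_heads, seq, head_dim)

# Manual GQA: expand the KV heads to match the query heads before calling SDPA
# (what the repeat_interleave / repeat_kv path does today).
n_rep = q_heads // kv_heads
out_manual = F.scaled_dot_product_attention(
    query,
    key.repeat_interleave(n_rep, dim=1),
    value.repeat_interleave(n_rep, dim=1),
)

# Native GQA: let SDPA broadcast the KV heads itself.
out_native = F.scaled_dot_product_attention(query, key, value, enable_gqa=True)

# The two should agree up to kernel numerics.
torch.testing.assert_close(out_manual, out_native)
```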
If we can get slicing out from the forward function, and ensure enable_gqa when exporting, then this patch can be simplified to _check then call sdpa_attention_forward
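Something like the following is presumably what's meant (a minimal sketch: it assumes the mask already has the right kv length upstream, that sdpa_attention_forward takes care of enable_gqa, and that the import path matches current transformers):

```python
import torch

from transformers.integrations.sdpa_attention import sdpa_attention_forward


def sdpa_attention_forward_for_export(module, query, key, value, attention_mask, **kwargs):
    # Sketch of the suggested simplification: only record the mask/key shape
    # relationship for torch.export, then delegate to the regular SDPA forward.
    if attention_mask is not None and attention_mask.ndim == 4:
        torch._check(attention_mask.shape[-1] == key.shape[-2])
    return sdpa_attention_forward(module, query, key, value, attention_mask, **kwargs)
```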
What condition would you suggest I change to? Thanks
I don't have a good suggestion tbh, is there a way to tell we are tracing with export only? The issue is that we really can't expect people to only use export and falling back to the math kernel is likely more expensive than the manual repeats.
Maybe I can instead remove the slice and see what breaks?
Responding here but saw the other PR. I think that's the right way but let's wait on @Cyrilvallez to come back (next week).
Yes, we can use torch.compiler.is_exporting() to check that.
Then let's add a condition on the gqa function (i.e. here) to allow GQA in any case when we detect exporting. Is it limited to a specific version of export? Might need to double-check the torch version.
Side note, semi-relevant: the masks are refactored a bit in #41852, so we won't need a workaround for export in the future.
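Roughly something along these lines, presumably (sketch only: the real `use_gqa_in_sdpa` helper has additional eligibility checks that are elided here and its signature may differ; the `hasattr` guard covers older torch versions that don't have `torch.compiler.is_exporting()`):

```python
import torch


def use_gqa_in_sdpa(attention_mask, key) -> bool:
    # Sketch: when torch.export is tracing, always let SDPA handle grouped-query
    # attention natively (enable_gqa=True) instead of repeating the KV heads.
    if hasattr(torch.compiler, "is_exporting") and torch.compiler.is_exporting():
        return True
    # ... existing eligibility checks (torch version, mask type, etc.) would go here ...
    return False
```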
Great! I can create a separate update for use_gqa_in_sdpa. The most important fix has been removing the slice on masks.
BTW, we will probably add a non-vmap path to the sdpa mask creation due to vmap being much slower, and the impact being noticeable for small models (#41639). This PR is still welcome though, as its focus is on the sdpa attention, not sdpa mask, but I thought I'd mention it just in case!
Concerning this, the answer is no: #41683!
I've tried this PR and found it didn't solve the performance regression in #41639. Will it be fixed in the next PR?
I believe this PR is unrelated to the issue you linked |
I think we can merge #41900, then I will create a separate PR to guard torch.export on dynamic shapes.
What does this PR do?
This PR adds a shape compatibility check for the attention mask to ensure torch.export can reason about the logic without failing when exporting with dynamic shapes.
It additionally simplifies the sdpa forward function for export-only usage.
Who can review?
@Cyrilvallez @jackzhxng @guangy10