Enable fx tracing for Mistral #30209
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Some other modeling files are based on Mistral, could you run `make fix-copies` so they stay in sync?
@michaelbenayoun Done! `make fix-copies` added tracing for the MoE models as well, which was a bit unexpected. Anyway, I just removed a line with dynamic control flow from the MoE models and checked that it is not necessary (even when `top_x` is an empty tensor).
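For context, here is a minimal sketch of why dropping that guard is safe (variable names such as `expert_mask`, `top_x`, and `final_hidden_states` mirror the Mixtral-style sparse MoE block; this is illustrative, not the exact diff):

```python
# Hedged sketch: the removed line was a data-dependent guard roughly like
#     if top_x.shape[0] == 0:
#         continue
# torch.fx cannot trace branches that depend on tensor values. Without the guard,
# indexing with an empty `top_x` and index_add_ with an empty index are no-ops.
import torch

num_tokens, hidden_dim, top_k = 3, 4, 2
hidden_states = torch.randn(num_tokens, hidden_dim)
final_hidden_states = torch.zeros(num_tokens, hidden_dim)

# An expert that received no tokens: its slice of the routing mask is all zeros.
expert_mask_slice = torch.zeros(top_k, num_tokens, dtype=torch.long)
idx, top_x = torch.where(expert_mask_slice)  # both tensors are empty

current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)  # shape (0, hidden_dim)
final_hidden_states.index_add_(0, top_x, current_state)             # no-op
assert torch.equal(final_hidden_states, torch.zeros(num_tokens, hidden_dim))
```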
michaelbenayoun
left a comment
LGTM on my side.
Let's see what @ArthurZucker or @amyeroberts have to say about the top_x change.
amyeroberts
left a comment
Thanks for working on this!
If you invoke the case where `top_x.shape[0] == 0`, e.g. by setting
`idx, top_x = torch.where(torch.zeros_like(expert_mask[expert_idx]))` in the lines above, does this still work in both the tracing and non-tracing case?
@amyeroberts Yes, it works fine for me when `top_x` is an empty tensor.
amyeroberts
left a comment
Thanks for adding this and confirming the `top_x` behaviour!
So, just confirming: we can merge with the removal of the `top_x` check?
ArthurZucker
left a comment
Cool! Yeah, let's remove it if that still produces correct behaviour and supports fx!
Merging now, since the removal of the `top_x` check is approved.
* tracing for mistral
* typo
* fix copies
This PR introduces a bug for Qwen2MoE GPTQ models; maybe revert it for the modeling_qwen2_moe.py file? @ArthurZucker
Can you share a reproducer and a stack trace that shows the error?
The code to reproduce the error is here: The error information is here, and the model generates successfully after I revert the change for modeling_qwen2_moe.py.
This seems to be using GPTQ and quantisation. Can you open a separate issue and ping @younesbelkada and @SunMarc?
What does this PR do?
Fixes #30083. As per the title, this enables fx tracing for Mistral. Apparently Mistral was already traceable with "sdpa" attention, as it is similar to Llama, which already works. I also enabled it for "eager" attention, which previously failed because Mistral uses a "sliding window" here
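As a rough illustration of what this enables, here is a hedged sketch (not the test from the PR; the tiny config values are arbitrary, and the way the attention implementation is selected is an assumption):

```python
# Hedged sketch of exercising fx tracing for Mistral with "eager" attention.
# The real coverage lives in the fx tests of the transformers test suite.
from transformers import MistralConfig, MistralForCausalLM
from transformers.utils.fx import symbolic_trace

config = MistralConfig(
    vocab_size=1000,
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=2,
    attn_implementation="eager",  # the attention path this PR makes traceable
)
model = MistralForCausalLM(config)

traced = symbolic_trace(model, input_names=["input_ids", "attention_mask"])
print(type(traced))  # a torch.fx GraphModule wrapping the Mistral forward pass
```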
Tests passing: