fix(utils): Make torch_compilable_check compatible with torch.export strict mode #44266
[For maintainers] Suggested jobs to run (before merge) run-slow: deformable_detr |
ArthurZucker
left a comment
thanks a lot for fixing!
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
|
Thank you so much for your time @ArthurZucker! I was wondering if I may ask you to check in on this PR as it has been pending review for some time now; at your convenience and sorry if this is a bother :) |
|
Hi ! the test you mention in the PR uses strict mode which we don't wanna support in general because it's too constraining, non-strict mode is generally enough for model export (even onnx export doesn't require strict mode). |
|
Good day @IlyasMoutawwakil! |
|
unfortunately the change actually does break a couple models 😅, we have a more general test for torch.export (it's marked as slow so it doesn't appear on fast ci) and some models started failing after this change because we also use torch._check to give dynamo compiler hints, basically telling dynamo that it should expect a condition to be true at compile time instead of evaluating it. I document that in the function's docstring but it might be easy to miss there. the failing tests can be seen in here for example https://github.com/huggingface/transformers/actions/runs/22426245229/job/64935308903 |
|
to sum it up:
I would argue that this requires fixing in pytorch because the torch._check api is what torch.export suggests using when you get data dependency failures, so it's weird for it to not be compatible with strict mode. see this traceback from a failing export due to a data dependency: |
|
That makes sense to me; thanks for the explanation of the revert and the context on the |
What does this PR do?
The following issue was identified and fixed in this PR:
→ Reasoning: The impact of this fix goes beyond Mask2Former and DeformableDetr and should fix any model that uses torch_compilable_check. Most users run models in eager mode, where the function works just fine; torch._check_with only breaks during export tracing. Mask2Former already has test coverage for torch.export.export() functionality and it currently fails (check CI for the failing test_modeling_mask2former.py::Mask2FormerModelIntegrationTest::test_export); this should fix that.
→ I decided to pick one other pattern to prove that this fix generalizes, without regressions, to other models using torch_compilable_check, and to add test coverage for it. I picked DeformableDetr because it uses the same deformable attention pattern as Mask2Former (the same torch_compilable_check() call in DeformableDetrMultiscaleDeformableAttention as in Mask2FormerPixelDecoderEncoderMultiscaleDeformableAttention) and didn't have coverage for an export test yet. For the test as well, I followed the canonical pattern from Mask2FormerModelIntegrationTest::test_export.
→ For more details on reproducing the bug and the output screenshots, please visit the linked issue!
Fixes #44265.
Added DeformableDetr test before the fix (feel free to cross-check; this error is reproducible):
Added DeformableDetr test after the fix (feel free to cross-check):