[AMD CI] Fix torch.compile/export failures on AMD CI due to untraceable set.__contains__ #45282
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
tarekziade
left a comment
Thanks a lot, do you think there's a way to isolate the behavior in a small test?
Force-pushed from b21cdba to 4a61d02.
I added a test for that. Let me know what you think :)
# Regression test: on AMD CI (PyTorch 2.8.0+rocm), `set.__contains__` is not
# traceable by TorchDynamo. `_register_model_output_pytree_node` must return
# early when called inside a compiled context, before touching the set.
from dataclasses import dataclass
Any reason to do late imports for those?
No good reason actually. I'll move them to the top.
I will keep `_register_model_output_pytree_node` as the only late import. It's a private internal symbol, so keeping it local is reasonable, I think.
tarekziade
left a comment
Thanks for the test. LGTM, just the nit.
Who should we tag here for a force merge?
…le set.__contains__ (huggingface#45282)
* fix torch.compile/export failures on amd
* test
* move imports
`_register_model_output_pytree_node` was calling `set.__contains__` during TorchDynamo tracing, which is unsupported in PyTorch 2.8.0 (ROCm). Added an early return when `torch.compiler.is_compiling()` is `True`, since pytree nodes are already registered from the preceding eager run.

This should fix a bunch of new failures in AMD CI.