[Relax][PyTorch] Support advanced range constraints (addition) #18452
Conversation
(force-pushed from 9e283e1 to aed2500)
cc @mshr-h

/gemini review
Code Review
This pull request adds support for advanced range constraints with addition expressions from PyTorch's dynamic shapes. The changes involve parsing sympy expressions into TIR PrimExprs and storing them as a new function attribute. A new test case is added to verify this functionality. My review includes a high-severity fix for a potential TypeError in the sympy expression parser and a medium-severity comment on a potential inconsistency in the new test case's expected output.
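To make the bound propagation concrete, here is a hedged sketch (not the PR's actual implementation; the helper name is hypothetical) of how the range of an addition expression such as `s0 + 1` follows from the base symbol's range by shifting both bounds by the constant, matching the `tir_var_lower_bound`/`tir_var_upper_bound` attributes in the test case below:

```python
# Hypothetical sketch: derive the range of (symbol + constant)
# from the symbol's own [lower, upper] range. This is an
# illustration, NOT the code added by this pull request.

def shifted_bounds(lower, upper, constant):
    """Range of (symbol + constant) given symbol in [lower, upper]."""
    return lower + constant, upper + constant

# The exporter constrains the base symbol s0 to [1, 64].
tir_var_lower_bound = {"s0": 1}
tir_var_upper_bound = {"s0": 64}

# Derive the bounds of the dependent dimension "s0 + 1".
lo, hi = shifted_bounds(tir_var_lower_bound["s0"], tir_var_upper_bound["s0"], 1)
tir_var_lower_bound["s0 + 1"] = lo
tir_var_upper_bound["s0 + 1"] = hi

print(tir_var_lower_bound)  # {'s0': 1, 's0 + 1': 2}
print(tir_var_upper_bound)  # {'s0': 64, 's0 + 1': 65}
```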
```python
class Expected:
    @R.function
    def main(
        x: R.Tensor(("s0", 4), dtype="float32"), y: R.Tensor(("s0___1", 4), dtype="float32")
    ) -> R.Tuple(R.Tensor(("s0 + s0___1", 4), dtype="float32")):
        s0 = T.int64(is_size_var=True)
        s0___1 = T.int64(is_size_var=True)
        R.func_attr(
            {
                "tir_var_expr": {"s0 + 1": 1 + s0},
                "tir_var_lower_bound": {"s0": 1, "s0 + 1": 2},
                "tir_var_upper_bound": {"s0": 64, "s0 + 1": 65},
            }
        )
        with R.dataflow():
            lv: R.Tensor((s0 + s0___1, 4), dtype="float32") = R.concat((x, y), axis=0)
            gv: R.Tuple(R.Tensor((s0 + s0___1, 4), dtype="float32")) = (lv,)
            R.output(gv)
        return gv
```
The Expected IRModule seems to have some inconsistencies with what the translator is expected to generate. Given dynamic_shapes = {"x": {0: batch}, "y": {0: batch + 1}}, the translator should generate a SizeVar named "s0 + 1" for the dynamic dimension of y. Therefore, the y parameter in main should probably have the shape R.Tensor(("s0 + 1", 4), ...) instead of R.Tensor(("s0___1", 4), ...). Consequently, the output tensor shape would be R.Tensor(("s0 + s0 + 1", 4), ...) and the free variable s0___1 would not be needed. Could you please double-check the Expected module definition?
(force-pushed from 3b3e36a to a061f94)

(force-pushed from a061f94 to a232be6)
Related Issue
Why
How
Parse SymPy addition expressions from PyTorch's range_constraints into TIR PrimExprs and store them as function attributes.
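The translation step can be pictured as a small recursive walk over the expression tree. The sketch below uses a toy nested-tuple AST and plain strings in place of real SymPy nodes and TIR PrimExprs; every name here is hypothetical and stands in for the actual converter:

```python
# Toy sketch of converting a SymPy-style expression tree into a
# TIR-like expression. Real code would walk sympy.Add / sympy.Symbol /
# sympy.Integer nodes and build tvm.tir PrimExprs; here nested tuples
# and strings stand in for both (all names are hypothetical).

def convert(node, var_table):
    kind = node[0]
    if kind == "symbol":                  # e.g. sympy.Symbol("s0")
        return var_table[node[1]]         # look up the existing TIR var
    if kind == "integer":                 # e.g. sympy.Integer(1)
        return str(node[1])
    if kind == "add":                     # e.g. sympy.Add(s0, 1)
        parts = [convert(child, var_table) for child in node[1:]]
        return " + ".join(parts)
    raise NotImplementedError(f"unsupported expression node: {kind}")

# "s0 + 1" as a toy expression tree.
expr = ("add", ("symbol", "s0"), ("integer", 1))
print(convert(expr, {"s0": "s0"}))  # s0 + 1
```

Unsupported node kinds raise rather than silently producing a wrong shape, mirroring how a frontend converter should fail loudly on expressions it cannot translate.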