
Conversation

@guan404ming (Member) commented Nov 14, 2025

Related Issue

Why

  • Add support for addition expressions (e.g., s0 + 1) in PyTorch dynamic shape constraints (a usage sketch follows below)
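
As a hedged usage sketch (not taken from the PR itself), a constraint of this shape can be declared with torch.export's derived-Dim arithmetic; the module class and sample inputs below are illustrative assumptions:

import torch
from torch.export import Dim, export

# Hypothetical module: concatenates two tensors along the dynamic batch axis.
class ConcatModel(torch.nn.Module):
    def forward(self, x, y):
        return (torch.cat((x, y), dim=0),)

batch = Dim("batch", min=1, max=64)
# y's batch dimension is the derived dim batch + 1, which torch.export
# records in range_constraints as an addition expression (s0 + 1).
dynamic_shapes = {"x": {0: batch}, "y": {0: batch + 1}}
exported = export(
    ConcatModel(),
    (torch.randn(2, 4), torch.randn(3, 4)),
    dynamic_shapes=dynamic_shapes,
)
print(exported.range_constraints)  # bounds for s0 (1..64) and s0 + 1 (2..65)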

How

  • Parse SymPy addition expressions from PyTorch's range_constraints into TIR PrimExprs (see the conversion sketch below)
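
A minimal conversion sketch, assuming a hypothetical helper _sympy_to_tir and a symbol-to-SizeVar map; the PR's actual parser in the Relax PyTorch frontend may be structured differently:

import sympy
from tvm import tir

def _sympy_to_tir(expr, var_map):
    # Recursively lower a sympy expression to a tir.PrimExpr.
    # var_map maps sympy Symbols (e.g. s0) to existing tir.SizeVar objects.
    if isinstance(expr, sympy.Symbol):
        return var_map[expr]
    if isinstance(expr, sympy.Integer):
        # int() avoids handing a sympy Integer to an API expecting a Python int.
        return tir.IntImm("int64", int(expr))
    if isinstance(expr, sympy.Add):
        args = [_sympy_to_tir(a, var_map) for a in expr.args]
        result = args[0]
        for a in args[1:]:
            result = result + a
        return result
    raise NotImplementedError(f"unsupported sympy expression: {expr}")

s0 = sympy.Symbol("s0")
prim = _sympy_to_tir(s0 + 1, {s0: tir.SizeVar("s0", "int64")})  # -> s0 + 1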

@guan404ming changed the title from "[Relax][PyTorch] Support advanced range constraints (addition)simpl" to "[Relax][PyTorch] Support advanced range constraints (addition)" Nov 14, 2025
@guan404ming force-pushed the support-advanced-range-addition branch 4 times, most recently from 9e283e1 to aed2500 on November 15, 2025 18:13
@guan404ming marked this pull request as ready for review November 15, 2025 22:31
@guan404ming (Member, Author) commented:

cc @mshr-h

@mshr-h self-requested a review November 16, 2025 03:17
@mshr-h (Contributor) commented Nov 16, 2025

/gemini review

@gemini-code-assist bot (Contributor) left a comment


Code Review

This pull request adds support for advanced range constraints with addition expressions from PyTorch's dynamic shapes. The changes involve parsing sympy expressions into TIR PrimExprs and storing them as a new function attribute. A new test case is added to verify this functionality. My review includes a high-severity fix for a potential TypeError in the sympy expression parser and a medium-severity comment on a potential inconsistency in the new test case's expected output.
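
As an illustrative sketch of the attribute-storage idea only (the attribute names mirror the tir_var_expr and bound attributes in the test below; the helper itself is hypothetical):

from tvm import relax, tir

def attach_range_attrs(func: relax.Function, s0: tir.SizeVar) -> relax.Function:
    # Record the derived expression under its printed key, plus shifted bounds.
    func = func.with_attr("tir_var_expr", {"s0 + 1": s0 + 1})
    func = func.with_attr("tir_var_lower_bound", {"s0": 1, "s0 + 1": 2})
    func = func.with_attr("tir_var_upper_bound", {"s0": 64, "s0 + 1": 65})
    return func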

Comment on lines 6929 to 6985
class Expected:
    @R.function
    def main(
        x: R.Tensor(("s0", 4), dtype="float32"), y: R.Tensor(("s0___1", 4), dtype="float32")
    ) -> R.Tuple(R.Tensor(("s0 + s0___1", 4), dtype="float32")):
        s0 = T.int64(is_size_var=True)
        s0___1 = T.int64(is_size_var=True)
        R.func_attr(
            {
                "tir_var_expr": {"s0 + 1": 1 + s0},
                "tir_var_lower_bound": {"s0": 1, "s0 + 1": 2},
                "tir_var_upper_bound": {"s0": 64, "s0 + 1": 65},
            }
        )
        with R.dataflow():
            lv: R.Tensor((s0 + s0___1, 4), dtype="float32") = R.concat((x, y), axis=0)
            gv: R.Tuple(R.Tensor((s0 + s0___1, 4), dtype="float32")) = (lv,)
            R.output(gv)
        return gv
@gemini-code-assist bot (Contributor) commented, severity: medium

The Expected IRModule seems to have some inconsistencies with what the translator is expected to generate. Given dynamic_shapes = {"x": {0: batch}, "y": {0: batch + 1}}, the translator should generate a SizeVar named "s0 + 1" for the dynamic dimension of y. Therefore, the y parameter in main should probably have the shape R.Tensor(("s0 + 1", 4), ...) instead of R.Tensor(("s0___1", 4), ...). Consequently, the output tensor shape would be R.Tensor(("s0 + s0 + 1", 4), ...) and the free variable s0___1 would not be needed. Could you please double-check the Expected module definition?
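
To make the naming point concrete, here is a tiny hedged illustration (variable names assumed) of the two readings of y's dynamic dimension:

from tvm import tir

s0 = tir.SizeVar("s0", "int64")
# Reviewer's expectation: one SizeVar literally named "s0 + 1" for y's dim...
s0_plus_1 = tir.SizeVar("s0 + 1", "int64")
# ...rather than an unrelated fresh variable s0___1.
s0___1 = tir.SizeVar("s0___1", "int64")
# Under the first reading, the concat output extent is s0 + (s0 + 1):
out_extent = s0 + s0_plus_1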

@guan404ming force-pushed the support-advanced-range-addition branch 2 times, most recently from 3b3e36a to a061f94 on November 16, 2025 15:36
@guan404ming force-pushed the support-advanced-range-addition branch from a061f94 to a232be6 on November 16, 2025 16:40
@mshr-h merged commit ea89f21 into apache:main Nov 17, 2025
13 checks passed
@guan404ming deleted the support-advanced-range-addition branch November 17, 2025 09:11