From 77685df498cbdbf0c342c0d88ca31eeef8ec27ea Mon Sep 17 00:00:00 2001 From: "Guan Ming(Wesley) Chiu" Date: Mon, 10 Nov 2025 21:51:56 +0800 Subject: [PATCH 1/5] Support basic range constraints --- .../torch/exported_program_translator.py | 51 +++++++++++++++---- .../test_frontend_from_exported_program.py | 24 +++++++++ 2 files changed, 65 insertions(+), 10 deletions(-) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 4f3132b8d8f2..3cf36552d469 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -1097,11 +1097,28 @@ def create_convert_map( def create_input_vars( self, exported_program: torch.export.ExportedProgram - ) -> Tuple[Dict[str, relax.Var], Dict[str, relax.Var]]: + ) -> Tuple[Dict[str, relax.Var], Dict[str, relax.Var], Dict[str, Tuple[int, int]]]: """Create relax input vars.""" parameters_buffers_constants = OrderedDict() user_inputs = OrderedDict() torch_symbol_to_relax_var: Dict[str, tvm.tir.Var] = {} + range_constraints = {} + + if hasattr(exported_program, "range_constraints"): + for symbol, value_range in exported_program.range_constraints.items(): + symbol_name = str(symbol) + if hasattr(value_range, "lower") and hasattr(value_range, "upper"): + try: + lower = int(value_range.lower) + except (OverflowError, AttributeError, TypeError): + continue + + try: + upper = int(value_range.upper) + except (OverflowError, AttributeError, TypeError): + continue + + range_constraints[symbol_name] = (lower, upper) for spec in exported_program.graph_signature.input_specs: name_hint = spec.arg.name @@ -1119,13 +1136,19 @@ def create_input_vars( torch_shape = exported_program.state_dict[spec.target].shape torch_dtype = exported_program.state_dict[spec.target].dtype - # TODO(mshr-h): Support range constraints - relax_shape = [ - torch_symbol_to_relax_var.setdefault(str(s), tvm.tir.SizeVar(str(s), "int64")) - if isinstance(s, torch.SymInt) - else s - for s in torch_shape - ] + # Create TIR variables for symbolic dimensions + relax_shape = [] + for s in torch_shape: + if isinstance(s, torch.SymInt): + symbol_name = str(s) + if symbol_name not in torch_symbol_to_relax_var: + torch_symbol_to_relax_var[symbol_name] = tvm.tir.SizeVar( + symbol_name, "int64" + ) + relax_shape.append(torch_symbol_to_relax_var[symbol_name]) + else: + relax_shape.append(s) + dtype = self._convert_data_type(torch_dtype) relax_var = relax.Var(name_hint, relax.TensorStructInfo(relax_shape, dtype)) @@ -1134,7 +1157,7 @@ def create_input_vars( else: parameters_buffers_constants[name_hint] = relax_var - return parameters_buffers_constants, user_inputs + return parameters_buffers_constants, user_inputs, range_constraints def from_exported_program( self, @@ -1147,7 +1170,11 @@ def from_exported_program( from torch import fx # type: ignore # Create input variables. 
- parameter_buffer_constant_vars, user_input_vars = self.create_input_vars(exported_program) + ( + parameter_buffer_constant_vars, + user_input_vars, + range_constraints, + ) = self.create_input_vars(exported_program) inputs_vars = user_input_vars.copy() inputs_vars.update(parameter_buffer_constant_vars) @@ -1155,6 +1182,10 @@ def from_exported_program( self.block_builder = relax.BlockBuilder() func_name = "main" func_attrs = {"num_input": len(user_input_vars)} if keep_params_as_input else None + if range_constraints: + if func_attrs is None: + func_attrs = {} + func_attrs["shape_var_constraints"] = range_constraints nodes: List[fx.Node] = exported_program.graph.nodes diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 44248c1c59f4..02e2602c4976 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -6527,5 +6527,29 @@ def forward(self, x): np.testing.assert_allclose(pytorch_output2.numpy(), tvm_output2_np, rtol=1e-4, atol=1e-5) +def test_dynamic_shape_with_range_constraints(): + class DynamicModel(torch.nn.Module): + def forward(self, x1, x2): + return torch.ops.aten.add.Tensor(x1, x2) + + example_args = (torch.randn(8, 4), torch.randn(8, 4)) + batch = torch.export.Dim("batch", min=1, max=64) + dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}} + exported_program = export(DynamicModel(), args=example_args, dynamic_shapes=dynamic_shapes) + + mod = from_exported_program(exported_program) + + main_func = mod["main"] + assert hasattr(main_func, "attrs"), "Function should have attributes" + + if "shape_var_constraints" in main_func.attrs: + constraints = main_func.attrs["shape_var_constraints"] + assert len(constraints) > 0, "Should have at least one constraint" + + for symbol_name, (min_val, max_val) in constraints.items(): + assert min_val == 1, f"Expected min=1 for {symbol_name}, got {min_val}" + assert max_val == 64, f"Expected max=64 for {symbol_name}, got {max_val}" + + if __name__ == "__main__": tvm.testing.main() From 1ad828fe909b4dd4f1b7b5a2f1d6022aadf732bd Mon Sep 17 00:00:00 2001 From: "Guan-Ming (Wesley) Chiu" <105915352+guan404ming@users.noreply.github.com> Date: Tue, 11 Nov 2025 23:39:21 +0800 Subject: [PATCH 2/5] Apply gemini-code-assist suggestions Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- .../relax/frontend/torch/exported_program_translator.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 3cf36552d469..aad0903fb351 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -1110,16 +1110,11 @@ def create_input_vars( if hasattr(value_range, "lower") and hasattr(value_range, "upper"): try: lower = int(value_range.lower) - except (OverflowError, AttributeError, TypeError): - continue - - try: upper = int(value_range.upper) + range_constraints[symbol_name] = (lower, upper) except (OverflowError, AttributeError, TypeError): continue - range_constraints[symbol_name] = (lower, upper) - for spec in exported_program.graph_signature.input_specs: name_hint = spec.arg.name if spec.kind is torch.export.graph_signature.InputKind.CONSTANT_TENSOR: From 34476a47451b9ead82b7f6cf2bfab2c2b8dca1d3 Mon Sep 17 00:00:00 2001 
From: "Guan-Ming (Wesley) Chiu" Date: Wed, 12 Nov 2025 00:32:06 +0800 Subject: [PATCH 3/5] Apply reviewer comments --- .../torch/exported_program_translator.py | 24 +++++++-------- .../test_frontend_from_exported_program.py | 29 ++++++++++--------- 2 files changed, 26 insertions(+), 27 deletions(-) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index aad0903fb351..2f02f8dfd0dc 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -1131,19 +1131,12 @@ def create_input_vars( torch_shape = exported_program.state_dict[spec.target].shape torch_dtype = exported_program.state_dict[spec.target].dtype - # Create TIR variables for symbolic dimensions - relax_shape = [] - for s in torch_shape: - if isinstance(s, torch.SymInt): - symbol_name = str(s) - if symbol_name not in torch_symbol_to_relax_var: - torch_symbol_to_relax_var[symbol_name] = tvm.tir.SizeVar( - symbol_name, "int64" - ) - relax_shape.append(torch_symbol_to_relax_var[symbol_name]) - else: - relax_shape.append(s) - + relax_shape = [ + torch_symbol_to_relax_var.setdefault(str(s), tvm.tir.SizeVar(str(s), "int64")) + if isinstance(s, torch.SymInt) + else s + for s in torch_shape + ] dtype = self._convert_data_type(torch_dtype) relax_var = relax.Var(name_hint, relax.TensorStructInfo(relax_shape, dtype)) @@ -1180,7 +1173,10 @@ def from_exported_program( if range_constraints: if func_attrs is None: func_attrs = {} - func_attrs["shape_var_constraints"] = range_constraints + tir_var_upper_bound = { + var_name: upper for var_name, (_, upper) in range_constraints.items() + } + func_attrs["tir_var_upper_bound"] = tir_var_upper_bound nodes: List[fx.Node] = exported_program.graph.nodes diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 02e2602c4976..3a2252476321 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -6532,24 +6532,27 @@ class DynamicModel(torch.nn.Module): def forward(self, x1, x2): return torch.ops.aten.add.Tensor(x1, x2) + @I.ir_module + class Expected: + @R.function + def main( + x1: R.Tensor(("s24", 4), dtype="float32"), x2: R.Tensor(("s24", 4), dtype="float32") + ) -> R.Tuple(R.Tensor(("s24", 4), dtype="float32")): + s24 = T.int64(is_size_var=True) + R.func_attr({"tir_var_upper_bound": {"s24": 64}}) + with R.dataflow(): + lv: R.Tensor((s24, 4), dtype="float32") = R.add(x1, x2) + gv: R.Tuple(R.Tensor((s24, 4), dtype="float32")) = (lv,) + R.output(gv) + return gv + example_args = (torch.randn(8, 4), torch.randn(8, 4)) batch = torch.export.Dim("batch", min=1, max=64) dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}} exported_program = export(DynamicModel(), args=example_args, dynamic_shapes=dynamic_shapes) - mod = from_exported_program(exported_program) - - main_func = mod["main"] - assert hasattr(main_func, "attrs"), "Function should have attributes" - - if "shape_var_constraints" in main_func.attrs: - constraints = main_func.attrs["shape_var_constraints"] - assert len(constraints) > 0, "Should have at least one constraint" - - for symbol_name, (min_val, max_val) in constraints.items(): - assert min_val == 1, f"Expected min=1 for {symbol_name}, got {min_val}" - assert max_val == 64, f"Expected max=64 for {symbol_name}, got {max_val}" - + mod = 
from_exported_program(exported_program, run_ep_decomposition=True) + tvm.ir.assert_structural_equal(mod, Expected) if __name__ == "__main__": tvm.testing.main() From 56e779224a18ba7607d543c326a1f1a745f7c66b Mon Sep 17 00:00:00 2001 From: "Guan-Ming (Wesley) Chiu" Date: Wed, 12 Nov 2025 00:51:58 +0800 Subject: [PATCH 4/5] Fix lint error --- tests/python/relax/test_frontend_from_exported_program.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 3a2252476321..0d6d6c64faac 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -6554,5 +6554,6 @@ def main( mod = from_exported_program(exported_program, run_ep_decomposition=True) tvm.ir.assert_structural_equal(mod, Expected) + if __name__ == "__main__": tvm.testing.main() From 2921bf0ef18d9d7bd6015a9a6b8d3d22af91bb45 Mon Sep 17 00:00:00 2001 From: "Guan-Ming (Wesley) Chiu" Date: Wed, 12 Nov 2025 10:54:30 +0800 Subject: [PATCH 5/5] Refactor frontend test to use consistent size variable --- .../relax/test_frontend_from_exported_program.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 0d6d6c64faac..b6df02c132fd 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -6536,13 +6536,13 @@ def forward(self, x1, x2): class Expected: @R.function def main( - x1: R.Tensor(("s24", 4), dtype="float32"), x2: R.Tensor(("s24", 4), dtype="float32") - ) -> R.Tuple(R.Tensor(("s24", 4), dtype="float32")): - s24 = T.int64(is_size_var=True) - R.func_attr({"tir_var_upper_bound": {"s24": 64}}) + x1: R.Tensor(("s0", 4), dtype="float32"), x2: R.Tensor(("s0", 4), dtype="float32") + ) -> R.Tuple(R.Tensor(("s0", 4), dtype="float32")): + s0 = T.int64(is_size_var=True) + R.func_attr({"tir_var_upper_bound": {"s0": 64}}) with R.dataflow(): - lv: R.Tensor((s24, 4), dtype="float32") = R.add(x1, x2) - gv: R.Tuple(R.Tensor((s24, 4), dtype="float32")) = (lv,) + lv: R.Tensor((s0, 4), dtype="float32") = R.add(x1, x2) + gv: R.Tuple(R.Tensor((s0, 4), dtype="float32")) = (lv,) R.output(gv) return gv
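
For reference, the behavior the series converges on can be exercised end to end as sketched below. This is a usage sketch assuming a TVM build that already includes these patches; the module name AddModel and the variable names (ep, mod) are illustrative, not taken from the patches. The forward body, the Dim(min=1, max=64) constraint, and the run_ep_decomposition=True argument mirror the test added in PATCH 3/5.

    # Sketch: exercising the range-constraint support added by this series.
    # Assumes a TVM build that includes these patches; AddModel is illustrative.
    import torch
    from torch.export import Dim, export

    from tvm.relax.frontend.torch import from_exported_program


    class AddModel(torch.nn.Module):
        def forward(self, x1, x2):
            return torch.ops.aten.add.Tensor(x1, x2)


    # Dim(min=..., max=...) is what populates exported_program.range_constraints,
    # which create_input_vars() reads on the TVM side.
    batch = Dim("batch", min=1, max=64)
    ep = export(
        AddModel(),
        args=(torch.randn(8, 4), torch.randn(8, 4)),
        dynamic_shapes={"x1": {0: batch}, "x2": {0: batch}},
    )

    mod = from_exported_program(ep, run_ep_decomposition=True)

    # After PATCH 3/5, only the upper bounds are recorded on the Relax function,
    # keyed by the TIR size-var name, e.g. {"s0": 64}.
    print(mod["main"].attrs["tir_var_upper_bound"])

Note that the lower bound is extracted in create_input_vars() but dropped when building the attribute: only the upper bound is kept, because the final design reuses the existing "tir_var_upper_bound" function attribute rather than introducing the two-sided "shape_var_constraints" attribute from PATCH 1/5.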