diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 44e967ec0e42..b1f8c1667c3f 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -1190,7 +1190,7 @@ def _process_derived_symbol(
         if isinstance(symbol, sympy.Symbol):
             return str(symbol), None
 
-        if not isinstance(symbol, sympy.Add):
+        if not isinstance(symbol, (sympy.Add, sympy.Mul)):
             return str(symbol), None
 
         tir_expr = None
@@ -1206,13 +1206,24 @@ def _process_derived_symbol(
             if term is None:
                 return str(symbol), None
 
-            tir_expr = term if tir_expr is None else tir_expr + term
+
+            if tir_expr is None:
+                tir_expr = term
+            elif isinstance(symbol, sympy.Mul):
+                tir_expr = tir_expr * term
+            elif isinstance(symbol, sympy.Add):
+                tir_expr = tir_expr + term
 
         if isinstance(tir_expr, tvm.tir.Add):
             for const, var in [(tir_expr.a, tir_expr.b), (tir_expr.b, tir_expr.a)]:
                 if isinstance(const, tvm.tir.IntImm) and isinstance(var, tvm.tir.Var):
                     return f"{var.name}___{const.value}", tir_expr
 
+        if isinstance(tir_expr, tvm.tir.Mul):
+            for const, var in [(tir_expr.a, tir_expr.b), (tir_expr.b, tir_expr.a)]:
+                if isinstance(const, tvm.tir.IntImm) and isinstance(var, tvm.tir.Var):
+                    return f"{var.name}_{const.value}", tir_expr
+
         return str(symbol), tir_expr
 
     def create_input_vars(
@@ -1255,12 +1266,20 @@ def create_input_vars(
                 torch_shape = exported_program.state_dict[spec.target].shape
                 torch_dtype = exported_program.state_dict[spec.target].dtype
 
-            relax_shape = [
-                torch_symbol_to_relax_var.setdefault(str(s), tvm.tir.SizeVar(str(s), "int64"))
-                if isinstance(s, torch.SymInt)
-                else s
-                for s in torch_shape
-            ]
+            relax_shape = []
+            for s in torch_shape:
+                if isinstance(s, torch.SymInt):
+                    sympy_node = s.node.expr if hasattr(s.node, "expr") else s.node
+                    symbol_name, _ = self._process_derived_symbol(
+                        sympy_node, torch_symbol_to_relax_var
+                    )
+
+                    size_var = torch_symbol_to_relax_var.setdefault(
+                        symbol_name, tvm.tir.SizeVar(symbol_name, "int64")
+                    )
+                    relax_shape.append(size_var)
+                else:
+                    relax_shape.append(s)
 
             dtype = self._convert_data_type(torch_dtype)
             relax_var = relax.Var(name_hint, relax.TensorStructInfo(relax_shape, dtype))
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index ef2736778f54..7607e8f58175 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -7000,7 +7000,7 @@ def main(
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
-def test_dynamic_shape_with_derived_range_constraints():
+def test_dynamic_shape_with_addition_constraints():
     class ConcatModel(torch.nn.Module):
         def forward(self, x, y):
             return torch.cat([x, y], dim=0)
@@ -7034,5 +7034,73 @@ def main(
     tvm.ir.assert_structural_equal(mod, Expected, map_free_vars=True)
 
 
+def test_dynamic_shape_with_subtraction_constraints():
+    class ConcatModel(torch.nn.Module):
+        def forward(self, x, y):
+            return torch.cat([x, y], dim=0)
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor(("s1___1", 4), dtype="float32"), y: R.Tensor(("s1", 4), dtype="float32")
+        ) -> R.Tuple(R.Tensor(("s1___1 + s1", 4), dtype="float32")):
+            s1___1 = T.int64(is_size_var=True)
+            s1 = T.int64(is_size_var=True)
+            R.func_attr(
+                {
+                    "tir_var_lower_bound": {"s1": 0, "s1___1": 1},
+                    "tir_var_upper_bound": {"s1": 63, "s1___1": 64},
+                }
+            )
+            with R.dataflow():
+                lv: R.Tensor((s1___1 + s1, 4), dtype="float32") = R.concat((x, y), axis=0)
+                gv: R.Tuple(R.Tensor((s1___1 + s1, 4), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    batch = torch.export.Dim("batch", min=1, max=64)
+    example_args = (torch.randn(8, 4), torch.randn(7, 4))
+    dynamic_shapes = {"x": {0: batch}, "y": {0: batch - 1}}
+    exported_program = export(ConcatModel(), args=example_args, dynamic_shapes=dynamic_shapes)
+
+    mod = from_exported_program(exported_program, run_ep_decomposition=True)
+    tvm.ir.assert_structural_equal(mod, Expected, map_free_vars=True)
+
+
+def test_dynamic_shape_with_multiplication_constraints():
+    class ConcatModel(torch.nn.Module):
+        def forward(self, x, y):
+            return torch.cat([x, y], dim=0)
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor(("s0", 4), dtype="float32"), y: R.Tensor(("s0_2", 4), dtype="float32")
+        ) -> R.Tuple(R.Tensor(("s0 + s0_2", 4), dtype="float32")):
+            s0 = T.int64(is_size_var=True)
+            s0_2 = T.int64(is_size_var=True)
+            R.func_attr(
+                {
+                    "tir_var_lower_bound": {"s0": 1, "s0_2": 2},
+                    "tir_var_upper_bound": {"s0": 64, "s0_2": 128},
+                }
+            )
+            with R.dataflow():
+                lv: R.Tensor((s0 + s0_2, 4), dtype="float32") = R.concat((x, y), axis=0)
+                gv: R.Tuple(R.Tensor((s0 + s0_2, 4), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    batch = torch.export.Dim("batch", min=1, max=64)
+    example_args = (torch.randn(8, 4), torch.randn(16, 4))
+    dynamic_shapes = {"x": {0: batch}, "y": {0: batch * 2}}
+    exported_program = export(ConcatModel(), args=example_args, dynamic_shapes=dynamic_shapes)
+
+    mod = from_exported_program(exported_program, run_ep_decomposition=True)
+    tvm.ir.assert_structural_equal(mod, Expected, map_free_vars=True)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
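
Note: the sketch below is a standalone illustration of the multiplicative naming rule the patch introduces, not the translator's actual entry point. It re-implements the per-term folding inline under the assumption that sympy and tvm are importable; the symbol name "s0" is hypothetical, but the f"{var.name}_{const.value}" rule is the new tvm.tir.Mul branch above.

import sympy
import tvm

# What torch.export produces for a dimension declared as batch * 2:
s0 = sympy.Symbol("s0", positive=True, integer=True)
expr = s0 * 2  # a sympy.Mul whose args are (2, s0)

# Fold each sympy term into a TIR expression, mirroring the patched loop.
tir_expr = None
for arg in expr.args:
    if isinstance(arg, sympy.Integer):
        term = tvm.tir.IntImm("int64", int(arg))
    else:
        term = tvm.tir.SizeVar(str(arg), "int64")
    tir_expr = term if tir_expr is None else tir_expr * term

# Derive the SizeVar name for the IntImm * Var pattern.
if isinstance(tir_expr, tvm.tir.Mul):
    for const, var in [(tir_expr.a, tir_expr.b), (tir_expr.b, tir_expr.a)]:
        if isinstance(const, tvm.tir.IntImm) and isinstance(var, tvm.tir.Var):
            print(f"{var.name}_{const.value}")  # -> "s0_2"

This is why test_dynamic_shape_with_multiplication_constraints expects y's symbolic dimension to surface as the SizeVar "s0_2" rather than an opaque derived expression.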