Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 27 additions & 8 deletions python/tvm/relax/frontend/torch/exported_program_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1190,7 +1190,7 @@ def _process_derived_symbol(
if isinstance(symbol, sympy.Symbol):
return str(symbol), None

if not isinstance(symbol, sympy.Add):
if not isinstance(symbol, (sympy.Add, sympy.Mul)):
return str(symbol), None

tir_expr = None
Expand All @@ -1206,13 +1206,24 @@ def _process_derived_symbol(

if term is None:
return str(symbol), None
tir_expr = term if tir_expr is None else tir_expr + term

if tir_expr is None:
tir_expr = term
elif isinstance(symbol, sympy.Mul):
tir_expr = tir_expr * term
elif isinstance(symbol, sympy.Add):
tir_expr = tir_expr + term

if isinstance(tir_expr, tvm.tir.Add):
for const, var in [(tir_expr.a, tir_expr.b), (tir_expr.b, tir_expr.a)]:
if isinstance(const, tvm.tir.IntImm) and isinstance(var, tvm.tir.Var):
return f"{var.name}___{const.value}", tir_expr

if isinstance(tir_expr, tvm.tir.Mul):
for const, var in [(tir_expr.a, tir_expr.b), (tir_expr.b, tir_expr.a)]:
if isinstance(const, tvm.tir.IntImm) and isinstance(var, tvm.tir.Var):
return f"{var.name}_{const.value}", tir_expr

return str(symbol), tir_expr

def create_input_vars(
Expand Down Expand Up @@ -1255,12 +1266,20 @@ def create_input_vars(
torch_shape = exported_program.state_dict[spec.target].shape
torch_dtype = exported_program.state_dict[spec.target].dtype

relax_shape = [
torch_symbol_to_relax_var.setdefault(str(s), tvm.tir.SizeVar(str(s), "int64"))
if isinstance(s, torch.SymInt)
else s
for s in torch_shape
]
relax_shape = []
for s in torch_shape:
if isinstance(s, torch.SymInt):
sympy_node = s.node.expr if hasattr(s.node, "expr") else s.node
symbol_name, _ = self._process_derived_symbol(
sympy_node, torch_symbol_to_relax_var
)

size_var = torch_symbol_to_relax_var.setdefault(
symbol_name, tvm.tir.SizeVar(symbol_name, "int64")
)
relax_shape.append(size_var)
else:
relax_shape.append(s)
dtype = self._convert_data_type(torch_dtype)

relax_var = relax.Var(name_hint, relax.TensorStructInfo(relax_shape, dtype))
Expand Down
70 changes: 69 additions & 1 deletion tests/python/relax/test_frontend_from_exported_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -7000,7 +7000,7 @@ def main(
tvm.ir.assert_structural_equal(mod, Expected)


def test_dynamic_shape_with_derived_range_constraints():
def test_dynamic_shape_with_addition_constraints():
class ConcatModel(torch.nn.Module):
def forward(self, x, y):
return torch.cat([x, y], dim=0)
Expand Down Expand Up @@ -7034,5 +7034,73 @@ def main(
tvm.ir.assert_structural_equal(mod, Expected, map_free_vars=True)


def test_dynamic_shape_with_subtraction_constraints():
class ConcatModel(torch.nn.Module):
def forward(self, x, y):
return torch.cat([x, y], dim=0)

@I.ir_module
class Expected:
@R.function
def main(
x: R.Tensor(("s1___1", 4), dtype="float32"), y: R.Tensor(("s1", 4), dtype="float32")
) -> R.Tuple(R.Tensor(("s1___1 + s1", 4), dtype="float32")):
s1___1 = T.int64(is_size_var=True)
s1 = T.int64(is_size_var=True)
R.func_attr(
{
"tir_var_lower_bound": {"s1": 0, "s1___1": 1},
"tir_var_upper_bound": {"s1": 63, "s1___1": 64},
}
)
with R.dataflow():
lv: R.Tensor((s1___1 + s1, 4), dtype="float32") = R.concat((x, y), axis=0)
gv: R.Tuple(R.Tensor((s1___1 + s1, 4), dtype="float32")) = (lv,)
R.output(gv)
return gv

batch = torch.export.Dim("batch", min=1, max=64)
example_args = (torch.randn(8, 4), torch.randn(7, 4))
dynamic_shapes = {"x": {0: batch}, "y": {0: batch - 1}}
exported_program = export(ConcatModel(), args=example_args, dynamic_shapes=dynamic_shapes)

mod = from_exported_program(exported_program, run_ep_decomposition=True)
tvm.ir.assert_structural_equal(mod, Expected, map_free_vars=True)


def test_dynamic_shape_with_multiplication_constraints():
class ConcatModel(torch.nn.Module):
def forward(self, x, y):
return torch.cat([x, y], dim=0)

@I.ir_module
class Expected:
@R.function
def main(
x: R.Tensor(("s0", 4), dtype="float32"), y: R.Tensor(("s0_2", 4), dtype="float32")
) -> R.Tuple(R.Tensor(("s0 + s0_2", 4), dtype="float32")):
s0 = T.int64(is_size_var=True)
s0_2 = T.int64(is_size_var=True)
R.func_attr(
{
"tir_var_lower_bound": {"s0": 1, "s0_2": 2},
"tir_var_upper_bound": {"s0": 64, "s0_2": 128},
}
)
with R.dataflow():
lv: R.Tensor((s0 + s0_2, 4), dtype="float32") = R.concat((x, y), axis=0)
gv: R.Tuple(R.Tensor((s0 + s0_2, 4), dtype="float32")) = (lv,)
R.output(gv)
return gv

batch = torch.export.Dim("batch", min=1, max=64)
example_args = (torch.randn(8, 4), torch.randn(16, 4))
dynamic_shapes = {"x": {0: batch}, "y": {0: batch * 2}}
exported_program = export(ConcatModel(), args=example_args, dynamic_shapes=dynamic_shapes)

mod = from_exported_program(exported_program, run_ep_decomposition=True)
tvm.ir.assert_structural_equal(mod, Expected, map_free_vars=True)


if __name__ == "__main__":
tvm.testing.main()
Loading