diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 21a0b2d5642a..162c81dd7e0b 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -1465,6 +1465,12 @@ def _scaled_dot_product_attention(self, node: fx.node.Node) -> relax.Var: ########## Others ########## + def _sym_size_int(self, node: fx.node.Node) -> relax.Expr: + x = self.env[node.args[0]] + shape = self.shape_of(x) + idx = node.args[1] + return self.block_builder.emit(relax.const(shape[idx].value, "int32")) + def _size(self, node: fx.node.Node) -> relax.Expr: x = self.env[node.args[0]] shape = self.shape_of(x) @@ -1681,6 +1687,7 @@ def create_convert_map(self): "hardsigmoid": self._hardsigmoid, "hardswish": self._hardswish, "interpolate": self._interpolate, + "sym_size.int": self._sym_size_int, "size": self._size, "getattr": self._getattr, "getitem": self._getitem, diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index 35a9bc71bf98..78fc7abdf748 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -3929,5 +3929,30 @@ def main( ) +def test_sym_size_int(): + class SymSizeInt1(Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x): + return torch.ops.aten.sym_size.int(x, self.dim) + + @I.ir_module + class Expected1: + @R.function + def main( + inp_0: R.Tensor((1, 3, 4), dtype="float32"), + ) -> R.Tensor((), dtype="int32"): + with R.dataflow(): + lv: R.Tensor((), dtype="int32") = R.const(3, "int32") + gv: R.Tensor((), dtype="int32") = lv + R.output(gv) + return gv + + verify_model(SymSizeInt1(dim=1), [([1, 3, 4], "float32")], {}, Expected1) + verify_model(SymSizeInt1(dim=-2), [([1, 3, 4], "float32")], {}, Expected1) + + if __name__ == "__main__": tvm.testing.main()