From 7d74f1a1a744d34503fafdfea05c165953987c9c Mon Sep 17 00:00:00 2001
From: Masahiro Hiramori
Date: Thu, 5 Sep 2024 23:17:48 +0900
Subject: [PATCH 1/3] add a test for `torch.ops.aten.sym_size.int`

---
 tests/python/relax/test_frontend_from_fx.py | 25 +++++++++++++++++++++
 1 file changed, 25 insertions(+)

diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 35a9bc71bf98..ca623c3c5762 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -3929,5 +3929,30 @@ def main(
     )
 
 
+def test_sym_size_int():
+    class SymSizeInt1(Module):
+        def forward(self, x):
+            return torch.ops.aten.sym_size.int(x, 1)
+
+    class SymSizeInt2(Module):
+        def forward(self, x):
+            return torch.ops.aten.sym_size.int(x, -2)
+
+    @I.ir_module
+    class Expected1:
+        @R.function
+        def main(
+            inp_0: R.Tensor((1, 3, 4), dtype="float32"),
+        ) -> R.Tensor((), dtype="int32"):
+            with R.dataflow():
+                lv: R.Tensor((), dtype="int32") = R.const(3, "int32")
+                gv: R.Tensor((), dtype="int32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(SymSizeInt1(), [([1, 3, 4], "float32")], {}, Expected1)
+    verify_model(SymSizeInt2(), [([1, 3, 4], "float32")], {}, Expected1)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

From 3c775fbd0f9f2566f92baebe7e12dc8fa31ea624 Mon Sep 17 00:00:00 2001
From: Masahiro Hiramori
Date: Thu, 5 Sep 2024 23:18:14 +0900
Subject: [PATCH 2/3] add support for `torch.ops.aten.sym_size.int`

---
 python/tvm/relax/frontend/torch/fx_translator.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 21a0b2d5642a..162c81dd7e0b 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -1465,6 +1465,12 @@ def _scaled_dot_product_attention(self, node: fx.node.Node) -> relax.Var:
 
     ########## Others ##########
 
+    def _sym_size_int(self, node: fx.node.Node) -> relax.Expr:
+        x = self.env[node.args[0]]
+        shape = self.shape_of(x)
+        idx = node.args[1]
+        return self.block_builder.emit(relax.const(shape[idx].value, "int32"))
+
     def _size(self, node: fx.node.Node) -> relax.Expr:
         x = self.env[node.args[0]]
         shape = self.shape_of(x)
@@ -1681,6 +1687,7 @@ def create_convert_map(self):
             "hardsigmoid": self._hardsigmoid,
             "hardswish": self._hardswish,
             "interpolate": self._interpolate,
+            "sym_size.int": self._sym_size_int,
             "size": self._size,
             "getattr": self._getattr,
             "getitem": self._getitem,

From cb88841b6e02eecfca3b51b5c5278fdafccc0ea7 Mon Sep 17 00:00:00 2001
From: Masahiro Hiramori
Date: Thu, 5 Sep 2024 23:36:27 +0900
Subject: [PATCH 3/3] cleanup

---
 tests/python/relax/test_frontend_from_fx.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index ca623c3c5762..78fc7abdf748 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -3931,27 +3931,27 @@
 
 def test_sym_size_int():
     class SymSizeInt1(Module):
-        def forward(self, x):
-            return torch.ops.aten.sym_size.int(x, 1)
+        def __init__(self, dim):
+            super().__init__()
+            self.dim = dim
 
-    class SymSizeInt2(Module):
         def forward(self, x):
-            return torch.ops.aten.sym_size.int(x, -2)
+            return torch.ops.aten.sym_size.int(x, self.dim)
 
     @I.ir_module
     class Expected1:
         @R.function
         def main(
             inp_0: R.Tensor((1, 3, 4), dtype="float32"),
         ) -> R.Tensor((), dtype="int32"):
             with R.dataflow():
                 lv: R.Tensor((), dtype="int32") = R.const(3, "int32")
                 gv: R.Tensor((), dtype="int32") = lv
                 R.output(gv)
             return gv
 
-    verify_model(SymSizeInt1(), [([1, 3, 4], "float32")], {}, Expected1)
-    verify_model(SymSizeInt2(), [([1, 3, 4], "float32")], {}, Expected1)
+    verify_model(SymSizeInt1(dim=1), [([1, 3, 4], "float32")], {}, Expected1)
+    verify_model(SymSizeInt1(dim=-2), [([1, 3, 4], "float32")], {}, Expected1)
 
 
 if __name__ == "__main__":