From a3bdbd9f69a2df6e24bf6094b742d2d0779a47c1 Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Sun, 20 Apr 2025 00:10:20 +0800 Subject: [PATCH 1/7] Update base_fx_graph_translator.py --- .../frontend/torch/base_fx_graph_translator.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index ae4c918900ec..61453494919a 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1416,6 +1416,19 @@ def _empty_like(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] return self.block_builder.emit(relax.op.zeros_like(x)) + def _eye(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + if len(args) == 1: + n = args[0] + m = n + elif len(args) == 2: + n = args[0] + m = args[1] + else: + raise ValueError("Invalid number of arguments for eye") + dtype = self._convert_data_type(str(node.kwargs["dtype"]), self.env) + return self.block_builder.emit(relax.op.eye(n, m, dtype=dtype)) + def _fill(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) x = args[0] From 3e194c886c208cf5e39270c9162e24b8e7dc50a7 Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Sun, 20 Apr 2025 00:11:35 +0800 Subject: [PATCH 2/7] Update exported_program_translator.py --- python/tvm/relax/frontend/torch/exported_program_translator.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 932607287571..6bd5b893b811 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -453,6 +453,7 @@ def create_convert_map( "clone.default": lambda node: self.env[node.args[0]], "empty.memory_format": self._empty, "empty_like.default": self._empty_like, + "eye.m": self._eye, "fill.Scalar": self._fill, "full.default": self._full, "full_like.default": self._full_like, From 7a43d6078efa19d1816dc0b8507442f4e8492464 Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Sun, 20 Apr 2025 00:12:32 +0800 Subject: [PATCH 3/7] Update test_frontend_from_exported_program.py --- .../test_frontend_from_exported_program.py | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 80c0bd5fb4f5..3da550c7c08a 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -4377,5 +4377,26 @@ def main( verify_model(Narrow(), example_args, {}, Expected) +def test_eye(): + class Eye(Module): + def forward(self, input): + return torch.eye(3, 5, dtype=torch.float32) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + input: R.Tensor((3, 5), dtype="float32") + ) -> R.Tuple(R.Tensor((3, 5), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((3, 5), dtype="float32") = R.eye(3, 5, dtype="float32") + gv: R.Tuple(R.Tensor((3, 5), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(3, 5, dtype=torch.float32),) + verify_model(Eye(), example_args, {}, Expected) + + if __name__ == "__main__": tvm.testing.main() From 16a68b88feb60b244ccd9a925009fe600b02935d Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Mon, 21 Apr 2025 17:25:36 +0800 Subject: [PATCH 4/7] Update exported_program_translator.py --- python/tvm/relax/frontend/torch/exported_program_translator.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 6bd5b893b811..af1393329e1f 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -453,6 +453,7 @@ def create_convert_map( "clone.default": lambda node: self.env[node.args[0]], "empty.memory_format": self._empty, "empty_like.default": self._empty_like, + "eye.default": self._eye, "eye.m": self._eye, "fill.Scalar": self._fill, "full.default": self._full, From f4ed22eea25d2e0087c6303ea8bb2f4b062a4235 Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Mon, 21 Apr 2025 17:26:21 +0800 Subject: [PATCH 5/7] Update base_fx_graph_translator.py --- .../relax/frontend/torch/base_fx_graph_translator.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 61453494919a..733a5d6b1a87 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1418,14 +1418,8 @@ def _empty_like(self, node: fx.Node) -> relax.Var: def _eye(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) - if len(args) == 1: - n = args[0] - m = n - elif len(args) == 2: - n = args[0] - m = args[1] - else: - raise ValueError("Invalid number of arguments for eye") + n = args[0] + m = args[1] if len(args) > 1 else n dtype = self._convert_data_type(str(node.kwargs["dtype"]), self.env) return self.block_builder.emit(relax.op.eye(n, m, dtype=dtype)) From ec2dd85f55a00d92c1418c061aee8bfceca0a851 Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Mon, 21 Apr 2025 17:27:32 +0800 Subject: [PATCH 6/7] add a testcase where only n is given --- .../test_frontend_from_exported_program.py | 28 ++++++++++++++++--- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 3da550c7c08a..346e11e31a3b 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -4378,12 +4378,13 @@ def main( def test_eye(): - class Eye(Module): + class Eye1(Module): def forward(self, input): return torch.eye(3, 5, dtype=torch.float32) + @tvm.script.ir_module - class Expected: + class Expected1: @R.function def main( input: R.Tensor((3, 5), dtype="float32") @@ -4394,8 +4395,27 @@ def main( R.output(gv) return gv - example_args = (torch.randn(3, 5, dtype=torch.float32),) - verify_model(Eye(), example_args, {}, Expected) + class Eye2(Module): + def forward(self, input): + return torch.eye(5, dtype=torch.float32) + + @tvm.script.ir_module + class Expected2: + @R.function + def main( + input: R.Tensor((5,), dtype="float32") + ) -> R.Tuple(R.Tensor((5, 5), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((5, 5), dtype="float32") = R.eye(5, dtype="float32") + gv: R.Tuple(R.Tensor((5, 5), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args1 = (torch.randn(3, 5, dtype=torch.float32),) + verify_model(Eye1(), example_args1, {}, Expected1) + + example_args2 = (torch.randn(5, dtype=torch.float32),) + verify_model(Eye2(), example_args2, {}, Expected2) if __name__ == "__main__": From d4679f19e477260da95b73974e783e305efc0b31 Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Mon, 21 Apr 2025 17:48:42 +0800 Subject: [PATCH 7/7] fix lint --- tests/python/relax/test_frontend_from_exported_program.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 346e11e31a3b..ce68089048a1 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -4382,7 +4382,6 @@ class Eye1(Module): def forward(self, input): return torch.eye(3, 5, dtype=torch.float32) - @tvm.script.ir_module class Expected1: @R.function