From 44f8eeff46de013e34ccf3c0063faeb3096a67e8 Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Tue, 8 Apr 2025 20:05:11 +0800 Subject: [PATCH 1/4] Update fx_translator.py --- python/tvm/relax/frontend/torch/fx_translator.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index a151a57ae659..2f2126cc43ac 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -422,6 +422,13 @@ def _flatten_module(self, node: fx.Node) -> relax.Var: end_dim = module.end_dim return self._flatten_impl(x, start_dim, end_dim) + def _narrow(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + dim = node.args[1] + start = node.args[2] + length = node.args[3] + return self.block_builder.emit(relax.op.strided_slice(x, [dim], [start], [length])) + def _numel(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] shape = self.shape_of(x) @@ -755,6 +762,7 @@ def create_convert_map( "where": self._where, # tensor manipulation "argsort": self._argsort, + "broadcast_to": self._broadcast_to, "cat": self._cat, "chunk": self._chunk, "concat": self._cat, @@ -766,6 +774,7 @@ def create_convert_map( "flatten": self._flatten, "flip": self._flip, "gather": self._gather, + "narrow": self._narrow, "numel": self._numel, "permute": self._permute, "repeat": self._repeat, From eb55adf11305ba370bcb122c5c3c2c3f628b9804 Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Tue, 8 Apr 2025 20:07:18 +0800 Subject: [PATCH 2/4] Update base_fx_graph_translator.py --- python/tvm/relax/frontend/torch/base_fx_graph_translator.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index affbd81e1c28..7660c1f5756c 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -966,6 +966,12 @@ def _argsort(self, node: fx.Node) -> relax.Var: descending = node.args[2] if len(node.args) > 2 else node.kwargs.get("descending", False) return self.block_builder.emit(relax.op.argsort(x, dim, descending)) + def _broadcast_to(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + shape = args[1] if len(args) > 1 else args[0] + return self.block_builder.emit(relax.op.broadcast_to(x, shape)) + def _cat(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) axis = args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0) From 10bd1dcd2f283bbf8ae88634ade9a37ea8bc872e Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Tue, 8 Apr 2025 20:08:25 +0800 Subject: [PATCH 3/4] Update test_frontend_from_fx.py --- tests/python/relax/test_frontend_from_fx.py | 40 +++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index 2c5560b577c4..fd5a157e40a3 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -4430,5 +4430,45 @@ def main( verify_model(Topk(), [([5, 3], "float32")], {}, Expected) +def test_broadcast_to(): + class BroadcastTo(Module): + def forward(self, x): + return torch.broadcast_to(x, (5, 3)) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + inp_0: R.Tensor((5, 1), dtype="float32"), + ) -> R.Tensor((5, 3), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((5, 3), dtype="float32") = R.broadcast_to(inp_0, (5, 3)) + gv: R.Tensor((5, 3), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(BroadcastTo(), [([5, 1], "float32")], {}, Expected) + + +def test_narrow(): + class Narrow(Module): + def forward(self, x): + return torch.narrow(x, 1, 0, 2) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + inp_0: R.Tensor((5, 3), dtype="float32"), + ) -> R.Tensor((5, 2), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((5, 2), dtype="float32") = R.strided_slice(inp_0, axes=[1], begin=[0], end=[2]) + gv: R.Tensor((5, 2), dtype="float32") = lv + R.output(gv) + + return gv + verify_model(Narrow(), [([5, 3], "float32")], {}, Expected) + + if __name__ == "__main__": tvm.testing.main() From 5e6cf962bd860e31a540a8b1ced0774cdcd219eb Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Wed, 9 Apr 2025 23:53:58 +0800 Subject: [PATCH 4/4] Update test_frontend_from_fx.py --- tests/python/relax/test_frontend_from_fx.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index fd5a157e40a3..9505356fcefd 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -4439,7 +4439,7 @@ def forward(self, x): class Expected: @R.function def main( - inp_0: R.Tensor((5, 1), dtype="float32"), + inp_0: R.Tensor((5, 1), dtype="float32"), ) -> R.Tensor((5, 3), dtype="float32"): with R.dataflow(): lv: R.Tensor((5, 3), dtype="float32") = R.broadcast_to(inp_0, (5, 3)) @@ -4462,11 +4462,14 @@ def main( inp_0: R.Tensor((5, 3), dtype="float32"), ) -> R.Tensor((5, 2), dtype="float32"): with R.dataflow(): - lv: R.Tensor((5, 2), dtype="float32") = R.strided_slice(inp_0, axes=[1], begin=[0], end=[2]) + lv: R.Tensor((5, 2), dtype="float32") = R.strided_slice( + inp_0, axes=[1], begin=[0], end=[2] + ) gv: R.Tensor((5, 2), dtype="float32") = lv R.output(gv) return gv + verify_model(Narrow(), [([5, 3], "float32")], {}, Expected)