From 018e7058c6770627a8bc3c40a60ea61d82b39744 Mon Sep 17 00:00:00 2001
From: Shushi Hong <820958424@qq.com>
Date: Wed, 5 Mar 2025 22:48:25 +0800
Subject: [PATCH 1/6] Update test_frontend_from_fx.py

---
 tests/python/relax/test_frontend_from_fx.py | 134 ++++++++++++++++++++
 1 file changed, 134 insertions(+)

diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index e9fa7965315a..72c2a2081815 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -3903,5 +3903,139 @@ def main(inp_0: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((), dtype="bool")
     verify_model(IsFloatingPoint(), [([2, 3], "float32")], {}, Expected)
 
 
+def test_gather():
+    class Gather0(Module):
+        def forward(self, data, indices):
+            return torch.gather(data, 0, indices)
+
+    class Gather1(Module):
+        def forward(self, data, indices):
+            return torch.gather(data, 1, indices)
+
+    class Gather2(Module):
+        def forward(self, data, indices):
+            return torch.gather(data, -1, indices)
+
+    class Gather3(Module):
+        def forward(self, data, indices):
+            return torch.gather(data, -2, indices)
+
+    @tvm.script.ir_module
+    class Expected0:
+        @R.function
+        def main(
+            inp_0: R.Tensor((2, 3), dtype="float32"),
+            inp_1: R.Tensor((2, 3), dtype="int32"),
+        ) -> R.Tensor((2, 3), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((2, 3), dtype="float32") = R.gather_elements(inp_0, inp_1, axis=0)
+                gv: R.Tensor((2, 3), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected1:
+        @R.function
+        def main(
+            inp_0: R.Tensor((2, 3), dtype="float32"),
+            inp_1: R.Tensor((2, 3), dtype="int32"),
+        ) -> R.Tensor((2, 3), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((2, 3), dtype="float32") = R.gather_elements(inp_0, inp_1, axis=1)
+                gv: R.Tensor((2, 3), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected2:
+        @R.function
+        def main(
+            inp_0: R.Tensor((2, 3), dtype="float32"),
+            inp_1: R.Tensor((2, 3), dtype="int32"),
+        ) -> R.Tensor((2, 3), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((2, 3), dtype="float32") = R.gather_elements(inp_0, inp_1, axis=-1)
+                gv: R.Tensor((2, 3), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected3:
+        @R.function
+        def main(
+            inp_0: R.Tensor((2, 3), dtype="float32"),
+            inp_1: R.Tensor((2, 3), dtype="int32"),
+        ) -> R.Tensor((2, 3), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((2, 3), dtype="float32") = R.gather_elements(inp_0, inp_1, axis=-2)
+                gv: R.Tensor((2, 3), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Gather0(), [([2, 3], "float32"), ([2, 3], "int32")], {}, Expected0)
+    verify_model(Gather1(), [([2, 3], "float32"), ([2, 3], "int32")], {}, Expected1)
+    verify_model(Gather2(), [([2, 3], "float32"), ([2, 3], "int32")], {}, Expected2)
+    verify_model(Gather3(), [([2, 3], "float32"), ([2, 3], "int32")], {}, Expected3)
+
+
+def test_flip():
+    class Flip0(Module):
+        def forward(self, data):
+            return torch.flip(data, [0])
+
+    class Flip1(Module):
+        def forward(self, data):
+            return torch.flip(data, [1])
+
+    @tvm.script.ir_module
+    class Expected0:
+        @R.function
+        def main(
+            inp_0: R.Tensor((2, 2), dtype="float32"),
+        ) -> R.Tensor((2, 2), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((2, 2), dtype="float32") = R.flip(inp_0, axis=0)
+                gv: R.Tensor((2, 2), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected1:
+        @R.function
+        def main(
+            inp_0: R.Tensor((2, 2), dtype="float32"),
+        ) -> R.Tensor((2, 2), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((2, 2), dtype="float32") = R.flip(inp_0, axis=1)
+                gv: R.Tensor((2, 2), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Flip0(), [([2, 2], "float32")], {}, Expected0)
+    verify_model(Flip1(), [([2, 2], "float32")], {}, Expected1)
+
+
+def test_take():
+    class Take(Module):
+        def forward(self, data, indices):
+            return torch.take(data, indices)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            inp_0: R.Tensor((5,), dtype="float32"),
+            inp_1: R.Tensor((3,), dtype="int32"),
+        ) -> R.Tensor((3,), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((3,), dtype="int32") = R.astype(inp_1, "int32")
+                lv1: R.Tensor((3,), dtype="float32") = R.take(inp_0, lv, axis=None)
+                gv: R.Tensor((3,), dtype="float32") = lv1
+                R.output(gv)
+            return gv
+
+    verify_model(Take(), [([5], "float32"), ([3], "int32")], {}, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

From f0508f31c95808f2ddcb565d93e65a7641ffbd52 Mon Sep 17 00:00:00 2001
From: Shushi Hong <820958424@qq.com>
Date: Wed, 5 Mar 2025 23:08:38 +0800
Subject: [PATCH 2/6] Update fx_translator.py

---
 python/tvm/relax/frontend/torch/fx_translator.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index af84f71bbf1e..ef98d3c02501 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -733,6 +733,8 @@ def create_convert_map(
             "cumsum": self._cumsum,
             "expand": self._expand,
             "flatten": self._flatten,
+            "flip": self._flip,
+            "gather": self._gather,
             "permute": self._permute,
             "repeat": self._repeat,
             "reshape": self._reshape,
@@ -741,6 +743,7 @@ def create_convert_map(
             "split": self._split,
             "squeeze": self._squeeze,
             "stack": self._stack,
+            "take": self._take,
             "tile": self._tile,
             "transpose": self._transpose,
             "unsqueeze": lambda node: self.block_builder.emit(

From cc004b8ab0835b0f1f14066c67438cfd3f64b325 Mon Sep 17 00:00:00 2001
From: Shushi Hong <820958424@qq.com>
Date: Wed, 5 Mar 2025 23:09:46 +0800
Subject: [PATCH 3/6] Update base_fx_graph_translator.py

---
 .../torch/base_fx_graph_translator.py         | 23 +++++++++++++++++++
 1 file changed, 23 insertions(+)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 4ce899685a7e..7e9c2f1a1292 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -847,6 +847,21 @@ def _expand(self, node: fx.Node) -> relax.Var:
                 broadcast_shape.append(i)
         return self.block_builder.emit(relax.op.broadcast_to(args[0], broadcast_shape))
 
+    def _flip(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        dims = node.args[1] if len(node.args) > 1 else node.kwargs.get("dims", None)
+        if isinstance(dims, (list, tuple)) and len(dims) > 0:
+            dims = dims[0]
+        elif not isinstance(dims, int):
+            raise TypeError(f"flip expects an integer axis, but got {type(dims)}: {dims}")
+        return self.block_builder.emit(relax.op.flip(x, dims))
+
+    def _gather(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0)
+        index = self.env[node.args[2]]
+        return self.block_builder.emit(relax.op.gather_elements(x, index, axis=dim))
+
     def _permute(self, node: fx.Node) -> relax.Var:
         import torch  # type: ignore
 
@@ -921,6 +936,14 @@ def _stack(self, node: fx.Node) -> relax.Var:
                 s_shape.append(s)
         return self.block_builder.emit(relax.op.reshape(cat, s_shape))
 
+    def _take(self, node: fx.Node, axis: Optional[int] = None) -> relax.Var:
+        x = self.env[node.args[0]]
+        indices = self.env[node.args[1]]
+        indices = self.block_builder.emit(relax.op.astype(indices, "int32"))
+        if axis is not None:
+            raise NotImplementedError("Relax's relax.op.take() does not fully support PyTorch's torch.take().")
+        return self.block_builder.emit(relax.op.take(x, indices, axis=axis))
+
     def _tile(self, node: fx.Node) -> relax.Var:
         import torch  # type: ignore
 

From b2ef6ba4168998ce46fd34d4c895e4d508fee6f8 Mon Sep 17 00:00:00 2001
From: Shushi Hong <820958424@qq.com>
Date: Wed, 5 Mar 2025 23:27:00 +0800
Subject: [PATCH 4/6] Update base_fx_graph_translator.py

---
 python/tvm/relax/frontend/torch/base_fx_graph_translator.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 7e9c2f1a1292..426bf52f9bda 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -941,7 +941,9 @@ def _take(self, node: fx.Node, axis: Optional[int] = None) -> relax.Var:
         indices = self.env[node.args[1]]
         indices = self.block_builder.emit(relax.op.astype(indices, "int32"))
         if axis is not None:
-            raise NotImplementedError("Relax's relax.op.take() does not fully support PyTorch's torch.take().")
+            raise NotImplementedError(
+                "Relax's relax.op.take() does not fully support PyTorch's torch.take()."
+            )
         return self.block_builder.emit(relax.op.take(x, indices, axis=axis))
 
     def _tile(self, node: fx.Node) -> relax.Var:

From 233510dd0eb08c379d040a2020f43ba262ec0291 Mon Sep 17 00:00:00 2001
From: Shushi Hong <820958424@qq.com>
Date: Fri, 7 Mar 2025 01:13:47 +0800
Subject: [PATCH 5/6] Update base_fx_graph_translator.py

---
 .../tvm/relax/frontend/torch/base_fx_graph_translator.py | 8 ++------
 1 file changed, 2 insertions(+), 6 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 426bf52f9bda..003ceebec6ff 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -936,15 +936,11 @@ def _stack(self, node: fx.Node) -> relax.Var:
                 s_shape.append(s)
         return self.block_builder.emit(relax.op.reshape(cat, s_shape))
 
-    def _take(self, node: fx.Node, axis: Optional[int] = None) -> relax.Var:
+    def _take(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         indices = self.env[node.args[1]]
         indices = self.block_builder.emit(relax.op.astype(indices, "int32"))
-        if axis is not None:
-            raise NotImplementedError(
-                "Relax's relax.op.take() does not fully support PyTorch's torch.take()."
-            )
-        return self.block_builder.emit(relax.op.take(x, indices, axis=axis))
+        return self.block_builder.emit(relax.op.take(x, indices))
 
     def _tile(self, node: fx.Node) -> relax.Var:
         import torch  # type: ignore

From cf5298638d2a8e175ca8bd5cf6c4540a137af82b Mon Sep 17 00:00:00 2001
From: Shushi Hong <820958424@qq.com>
Date: Fri, 7 Mar 2025 01:16:56 +0800
Subject: [PATCH 6/6] Update test_frontend_from_fx.py

---
 tests/python/relax/test_frontend_from_fx.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 72c2a2081815..0b4b34e0c9bb 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -4029,7 +4029,7 @@ def main(
             inp_0: R.Tensor((5,), dtype="float32"),
             inp_1: R.Tensor((3,), dtype="int32"),
         ) -> R.Tensor((3,), dtype="float32"):
             with R.dataflow():
                 lv: R.Tensor((3,), dtype="int32") = R.astype(inp_1, "int32")
-                lv1: R.Tensor((3,), dtype="float32") = R.take(inp_0, lv, axis=None)
+                lv1: R.Tensor((3,), dtype="float32") = R.take(inp_0, lv)
                 gv: R.Tensor((3,), dtype="float32") = lv1
                 R.output(gv)
             return gv
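
Usage sketch (not part of the patches above): assuming a TVM build that
includes this series and an installed PyTorch, the new conversions can be
exercised end-to-end through the FX importer as below. The module names and
input shapes are illustrative, chosen to mirror the new tests; only
torch.gather/torch.flip/torch.take, from_fx, and the converter behavior noted
in the comments come from the patches themselves.

import torch
from torch import fx, nn

from tvm.relax.frontend.torch import from_fx


class GatherFlip(nn.Module):
    def forward(self, data, indices):
        g = torch.gather(data, 1, indices)  # converted to R.gather_elements(..., axis=1)
        return torch.flip(g, [0])  # converted to R.flip(..., axis=0)


class Take(nn.Module):
    def forward(self, data, indices):
        # converted to R.astype(indices, "int32") followed by R.take(data, ...)
        return torch.take(data, indices)


# Shapes and dtypes mirror test_gather/test_flip/test_take.
gather_flip_mod = from_fx(
    fx.symbolic_trace(GatherFlip()), [([2, 3], "float32"), ([2, 3], "int32")]
)
take_mod = from_fx(fx.symbolic_trace(Take()), [([5], "float32"), ([3], "int32")])
gather_flip_mod.show()
take_mod.show()

As the PATCH 4/PATCH 5 history suggests, relax.op.take does not replicate
torch.take's implicit flattening of multi-dimensional inputs in general, which
is presumably why the new test sticks to one-dimensional data.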