From c813cc3977a5e3f876b298a815f324dee32e1cd1 Mon Sep 17 00:00:00 2001
From: Shushi Hong <820958424@qq.com>
Date: Thu, 13 Mar 2025 19:38:54 +0800
Subject: [PATCH 1/4] Update fx_translator.py

---
 .../tvm/relax/frontend/torch/fx_translator.py | 18 ++++++++++++++++++
 1 file changed, 18 insertions(+)

diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 952fb6f97111..4abc1dcac16e 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -376,6 +376,14 @@ def _max_pool2d_module(self, node: fx.Node) -> relax.Var:

         return self._max_pool2d_impl(x, kernel_size, stride, padding, dilation, ceil_mode)

+    ########## Linear Interpolation ##########
+
+    def _lerp(self, node: fx.Node) -> relax.Var:
+        start = self.env[node.args[0]]
+        end = self.env[node.args[1]]
+        weight = self.env[node.args[2]]
+        return self.block_builder.emit(relax.op.add(start, relax.op.multiply(weight, relax.op.subtract(end, start))))
+
     ########## Manipulation ##########

     def _chunk(self, node: fx.Node) -> relax.Var:
@@ -414,6 +422,12 @@ def _numel(self, node: fx.Node) -> relax.Var:
         shape = self.shape_of(x)
         return relax.const(reduce(lambda x, y: x * y, [s.value for s in shape]), "int32")

+    def _select(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        dim = node.args[1]
+        index = relax.const(node.args[2], "int64")
+        return self.block_builder.emit(relax.op.take(x, index, dim))
+
     def _size(self, node: fx.Node) -> relax.Expr:
         x = self.env[node.args[0]]
         shape = self.shape_of(x)
@@ -737,6 +751,8 @@ def create_convert_map(
             "scaled_dot_product_attention": self._scaled_dot_product_attention,
             "stochastic_depth": lambda node: self.env[node.args[0]],
             "unbind": self._unbind,
+            # linear interpolation
+            "lerp": self._lerp,
             # statistical
             "mean": self._mean,
             "sum": self._sum,
@@ -759,6 +775,7 @@ def create_convert_map(
             "repeat": self._repeat,
             "reshape": self._reshape,
             "scatter": self._scatter,
+            "select": self._select,
             "size": self._size,
             "split": self._split,
             "squeeze": self._squeeze,
@@ -772,6 +789,7 @@ def create_convert_map(
             "view": self._reshape,
             # tensor creation
             "arange": self._arange,
+            "clone": lambda node: self.env[node.args[0]],
             "empty": self._empty,
             "empty_like": self._empty_like,
             "fill_": self._inplace_fill,

From 46024e9b434864fd6296d4ee7a6d066787862107 Mon Sep 17 00:00:00 2001
From: Shushi Hong <820958424@qq.com>
Date: Thu, 13 Mar 2025 19:41:01 +0800
Subject: [PATCH 2/4] Update test_frontend_from_fx.py

---
 tests/python/relax/test_frontend_from_fx.py | 61 +++++++++++++++++++++
 1 file changed, 61 insertions(+)

diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index fbea8b7388ed..063190bf39d7 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -4125,5 +4125,66 @@ def main(
     verify_model(Numel(), [([5, 3], "float32")], {}, Expected)


+def test_select():
+    class Select(Module):
+        def forward(self, data):
+            return torch.select(data, 0, 1)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            inp_0: R.Tensor((5, 3), dtype="float32"),
+        ) -> R.Tensor((3,), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((3,), dtype="float32") = R.take(inp_0, R.const(1, "int64"), axis=0)
+                gv: R.Tensor((3,), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Select(), [([5, 3], "float32")], {}, Expected)
+
+
+def test_clone():
+    class Clone(Module):
+        def forward(self, x):
+            return x.clone()
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            inp_0: R.Tensor((5, 3), dtype="float32"),
+        ) -> R.Tensor((5, 3), dtype="float32"):
+            with R.dataflow():
+                gv: R.Tensor((5, 3), dtype="float32") = inp_0
+                R.output(gv)
+            return gv
+
+    verify_model(Clone(), [([5, 3], "float32")], {}, Expected)
+
+
+def test_lerp():
+    class Lerp(Module):
+        def forward(self, start, end, weight):
+            return torch.lerp(start, end, weight)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            inp_0: R.Tensor((5, 3), dtype="float32"),
+            inp_1: R.Tensor((5, 3), dtype="float32"),
+            inp_2: R.Tensor((5, 3), dtype="float32"),
+        ) -> R.Tensor((5, 3), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((5, 3), dtype="float32") = R.add(inp_0, R.multiply(inp_2, R.subtract(inp_1, inp_0)))
+                gv: R.Tensor((5, 3), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Lerp(), [([5, 3], "float32"), ([5, 3], "float32"), ([5, 3], "float32")], {}, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

From 51724a56c78b4189160f67e57729a5c559d03631 Mon Sep 17 00:00:00 2001
From: Shushi Hong <820958424@qq.com>
Date: Mon, 17 Mar 2025 22:58:18 +0800
Subject: [PATCH 3/4] Update fx_translator.py

---
 python/tvm/relax/frontend/torch/fx_translator.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 4abc1dcac16e..9b835b0eeea7 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -377,12 +377,14 @@ def _max_pool2d_module(self, node: fx.Node) -> relax.Var:
         return self._max_pool2d_impl(x, kernel_size, stride, padding, dilation, ceil_mode)

     ########## Linear Interpolation ##########
-
+
     def _lerp(self, node: fx.Node) -> relax.Var:
         start = self.env[node.args[0]]
         end = self.env[node.args[1]]
         weight = self.env[node.args[2]]
-        return self.block_builder.emit(relax.op.add(start, relax.op.multiply(weight, relax.op.subtract(end, start))))
+        return self.block_builder.emit(
+            relax.op.add(start, relax.op.multiply(weight, relax.op.subtract(end, start)))
+        )

     ########## Manipulation ##########


From 64521845c4d3fd40a3490b0cb5d77c984954a5f5 Mon Sep 17 00:00:00 2001
From: Shushi Hong <820958424@qq.com>
Date: Mon, 17 Mar 2025 22:59:25 +0800
Subject: [PATCH 4/4] lint

---
 tests/python/relax/test_frontend_from_fx.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 063190bf39d7..f06ce7a75373 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -4178,12 +4178,16 @@ def main(
             inp_2: R.Tensor((5, 3), dtype="float32"),
         ) -> R.Tensor((5, 3), dtype="float32"):
             with R.dataflow():
-                lv: R.Tensor((5, 3), dtype="float32") = R.add(inp_0, R.multiply(inp_2, R.subtract(inp_1, inp_0)))
+                lv: R.Tensor((5, 3), dtype="float32") = R.add(
+                    inp_0, R.multiply(inp_2, R.subtract(inp_1, inp_0))
+                )
                 gv: R.Tensor((5, 3), dtype="float32") = lv
                 R.output(gv)
             return gv

-    verify_model(Lerp(), [([5, 3], "float32"), ([5, 3], "float32"), ([5, 3], "float32")], {}, Expected)
+    verify_model(
+        Lerp(), [([5, 3], "float32"), ([5, 3], "float32"), ([5, 3], "float32")], {}, Expected
+    )


 if __name__ == "__main__":
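
Note for reviewers (not part of the patch series): the sketch below is one way the new lerp/select/clone mappings could be exercised end to end. It is untested and assumes this series is applied on top of a TVM build that provides tvm.relax.frontend.torch.from_fx; the Demo module and the input shapes are made up for illustration and mirror the verify_model pattern used in test_frontend_from_fx.py.

    # Hypothetical smoke test for the new conversions (assumes the patches above
    # are applied). torch.lerp should lower to add/multiply/subtract, .clone()
    # to an identity binding, and .select() to R.take.
    import torch
    from torch import fx, nn

    from tvm.relax.frontend.torch import from_fx


    class Demo(nn.Module):
        def forward(self, start, end, weight):
            return torch.lerp(start, end, weight).clone().select(0, 1)


    graph_model = fx.symbolic_trace(Demo())
    # input_info entries are (shape, dtype) pairs, as in the test file's helpers.
    input_info = [([5, 3], "float32")] * 3
    mod = from_fx(graph_model, input_info)
    mod.show()  # prints the translated Relax IRModule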