diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 952fb6f97111..9b835b0eeea7 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -376,6 +376,16 @@ def _max_pool2d_module(self, node: fx.Node) -> relax.Var:
 
         return self._max_pool2d_impl(x, kernel_size, stride, padding, dilation, ceil_mode)
 
+    ########## Linear Interpolation ##########
+
+    def _lerp(self, node: fx.Node) -> relax.Var:
+        start = self.env[node.args[0]]
+        end = self.env[node.args[1]]
+        weight = self.env[node.args[2]]
+        return self.block_builder.emit(
+            relax.op.add(start, relax.op.multiply(weight, relax.op.subtract(end, start)))
+        )
+
     ########## Manipulation ##########
 
     def _chunk(self, node: fx.Node) -> relax.Var:
@@ -414,6 +424,12 @@ def _numel(self, node: fx.Node) -> relax.Var:
         shape = self.shape_of(x)
         return relax.const(reduce(lambda x, y: x * y, [s.value for s in shape]), "int32")
 
+    def _select(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        dim = node.args[1]
+        index = relax.const(node.args[2], "int64")
+        return self.block_builder.emit(relax.op.take(x, index, dim))
+
     def _size(self, node: fx.Node) -> relax.Expr:
         x = self.env[node.args[0]]
         shape = self.shape_of(x)
@@ -737,6 +753,8 @@ def create_convert_map(
             "scaled_dot_product_attention": self._scaled_dot_product_attention,
             "stochastic_depth": lambda node: self.env[node.args[0]],
             "unbind": self._unbind,
+            # linear interpolation
+            "lerp": self._lerp,
             # statistical
             "mean": self._mean,
             "sum": self._sum,
@@ -759,6 +777,7 @@
             "repeat": self._repeat,
             "reshape": self._reshape,
             "scatter": self._scatter,
+            "select": self._select,
             "size": self._size,
             "split": self._split,
             "squeeze": self._squeeze,
@@ -772,6 +791,7 @@
             "view": self._reshape,
             # tensor creation
             "arange": self._arange,
+            "clone": lambda node: self.env[node.args[0]],
             "empty": self._empty,
             "empty_like": self._empty_like,
             "fill_": self._inplace_fill,
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index fbea8b7388ed..f06ce7a75373 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -4125,5 +4125,70 @@ def main(
     verify_model(Numel(), [([5, 3], "float32")], {}, Expected)
 
 
+def test_select():
+    class Select(Module):
+        def forward(self, data):
+            return torch.select(data, 0, 1)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            inp_0: R.Tensor((5, 3), dtype="float32"),
+        ) -> R.Tensor((3,), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((3,), dtype="float32") = R.take(inp_0, R.const(1, "int64"), axis=0)
+                gv: R.Tensor((3,), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Select(), [([5, 3], "float32")], {}, Expected)
+
+
+def test_clone():
+    class Clone(Module):
+        def forward(self, x):
+            return x.clone()
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            inp_0: R.Tensor((5, 3), dtype="float32"),
+        ) -> R.Tensor((5, 3), dtype="float32"):
+            with R.dataflow():
+                gv: R.Tensor((5, 3), dtype="float32") = inp_0
+                R.output(gv)
+            return gv
+
+    verify_model(Clone(), [([5, 3], "float32")], {}, Expected)
+
+
+def test_lerp():
+    class Lerp(Module):
+        def forward(self, start, end, weight):
+            return torch.lerp(start, end, weight)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            inp_0: R.Tensor((5, 3), dtype="float32"),
+            inp_1: R.Tensor((5, 3), dtype="float32"),
+            inp_2: R.Tensor((5, 3), dtype="float32"),
+        ) -> R.Tensor((5, 3), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((5, 3), dtype="float32") = R.add(
+                    inp_0, R.multiply(inp_2, R.subtract(inp_1, inp_0))
+                )
+                gv: R.Tensor((5, 3), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(
+        Lerp(), [([5, 3], "float32"), ([5, 3], "float32"), ([5, 3], "float32")], {}, Expected
+    )
+
+
 if __name__ == "__main__":
     tvm.testing.main()