From e2e305776c6c0b9356cd5f752a06e6897caaf65a Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Fri, 27 Sep 2024 11:19:08 +0900 Subject: [PATCH 01/12] support more unary ops --- .../torch/exported_program_translator.py | 18 + .../test_frontend_from_exported_program.py | 389 +++++++++++++++++- 2 files changed, 401 insertions(+), 6 deletions(-) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 9af422d1c3ca..83a2f0ae9729 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -69,8 +69,26 @@ def create_convert_map( ) -> Dict[str, Callable[[fx.Node], relax.Var]]: return { # unary + "acos.default": self._unary_op(relax.op.acos), + "acosh.default": self._unary_op(relax.op.acosh), + "asin.default": self._unary_op(relax.op.asin), + "asinh.default": self._unary_op(relax.op.asinh), + "atan.default": self._unary_op(relax.op.atan), + "atanh.default": self._unary_op(relax.op.atanh), + "cos.default": self._unary_op(relax.op.cos), + "cosh.default": self._unary_op(relax.op.cosh), "dropout.default": lambda node: self.env[node.args[0]], + "exp.default": self._unary_op(relax.op.exp), + "neg.default": self._unary_op(relax.op.negative), "relu.default": self._unary_op(relax.op.nn.relu), + "rsqrt.default": self._unary_op(relax.op.rsqrt), + "sigmoid.default": self._unary_op(relax.op.sigmoid), + "silu.default": self._unary_op(relax.op.nn.silu), + "sin.default": self._unary_op(relax.op.sin), + "sinh.default": self._unary_op(relax.op.sinh), + "sqrt.default": self._unary_op(relax.op.sqrt), + "tan.default": self._unary_op(relax.op.tan), + "tanh.default": self._unary_op(relax.op.tanh), # neural network "adaptive_avg_pool2d.default": self._adaptive_avg_pool2d, "conv2d.default": self._conv2d, diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 112390fe6094..2d39208a4c1d 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -39,6 +39,166 @@ def verify_model(torch_model, example_args, binding, expected): def test_unary(): example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + # acos + class Acos(Module): + def forward(self, input): + return torch.acos(input) + + @tvm.script.ir_module + class expected_acos: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.acos(input_1) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(Acos(), example_args, {}, expected_acos) + + # acosh + class Acosh(Module): + def forward(self, input): + return torch.acosh(input) + + @tvm.script.ir_module + class expected_acosh: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.acosh(input_1) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(Acosh(), example_args, {}, expected_acosh) + + # asin + class Asin(Module): + def forward(self, input): + return torch.asin(input) + + @tvm.script.ir_module + class expected_asin: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.asin(input_1) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(Asin(), example_args, {}, expected_asin) + + # asinh + class Asinh(Module): + def forward(self, input): + return torch.asinh(input) + + @tvm.script.ir_module + class expected_asinh: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.asinh(input_1) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(Asinh(), example_args, {}, expected_asinh) + + # atan + class Atan(Module): + def forward(self, input): + return torch.atan(input) + + @tvm.script.ir_module + class expected_atan: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.atan(input_1) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(Atan(), example_args, {}, expected_atan) + + # atanh + class Atanh(Module): + def forward(self, input): + return torch.atanh(input) + + @tvm.script.ir_module + class expected_atanh: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.atanh(input_1) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(Atanh(), example_args, {}, expected_atanh) + + # cos + class Cos(Module): + def forward(self, input): + return torch.cos(input) + + @tvm.script.ir_module + class expected_cos: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.cos(input_1) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(Cos(), example_args, {}, expected_cos) + + # cosh + class Cosh(Module): + def forward(self, input): + return torch.cosh(input) + + @tvm.script.ir_module + class expected_cosh: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.cosh(input_1) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(Cosh(), example_args, {}, expected_cosh) + # dropout class Dropout1(Module): def __init__(self): @@ -53,7 +213,7 @@ def forward(self, input): return torch.dropout(input, 0.5, train=True) @tvm.script.ir_module - class expected1: + class expected_dropout: @R.function def main( input_1: R.Tensor((1, 3, 10, 10), dtype="float32") @@ -64,8 +224,47 @@ def main( R.output(gv) return gv - verify_model(Dropout1(), example_args, {}, expected1) - verify_model(Dropout2(), example_args, {}, expected1) + verify_model(Dropout1(), example_args, {}, expected_dropout) + verify_model(Dropout2(), example_args, {}, expected_dropout) + + # exp + class Exp(Module): + def forward(self, input): + return torch.exp(input) + + @tvm.script.ir_module + class expected_exp: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.exp(input_1) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(Exp(), example_args, {}, expected_exp) + + # neg + class Neg(Module): + def forward(self, input): + return -input + + @I.ir_module + class expected_neg: + @R.function + def main( + inp_0: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.negative(inp_0) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(Neg(), example_args, {}, expected_neg) # relu class ReLU0(Module): @@ -81,7 +280,7 @@ def forward(self, input): return torch.nn.functional.relu(input) @tvm.script.ir_module - class expected: + class expected_relu: @R.function def main( input_1: R.Tensor((1, 3, 10, 10), dtype="float32") @@ -93,8 +292,186 @@ def main( R.output(gv) return gv - verify_model(ReLU0(), example_args, {}, expected) - verify_model(ReLU1(), example_args, {}, expected) + verify_model(ReLU0(), example_args, {}, expected_relu) + verify_model(ReLU1(), example_args, {}, expected_relu) + + # rsqrt + class Rsqrt(Module): + def forward(self, input): + return torch.rsqrt(input) + + @I.ir_module + class expected_rsqrt: + @R.function + def main( + inp_0: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.rsqrt(inp_0) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + verify_model(Rsqrt(), example_args, {}, expected_rsqrt) + + # sigmoid + class Sigmoid(Module): + def __init__(self): + super().__init__() + self.sigmoid = torch.nn.Sigmoid() + + def forward(self, input): + return self.sigmoid(input) + + class Sigmoid2(Module): + def forward(self, input): + return torch.sigmoid(input) + + @tvm.script.ir_module + class expected_sigmoid: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.sigmoid(input_1) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(Sigmoid(), example_args, {}, expected_sigmoid) + verify_model(Sigmoid2(), example_args, {}, expected_sigmoid) + + # silu + class SiLU(Module): + def __init__(self): + super().__init__() + self.silu = torch.nn.SiLU() + + def forward(self, input): + return self.silu(input) + + class SiLU2(Module): + def forward(self, input): + return torch.nn.functional.silu(input) + + @tvm.script.ir_module + class expected_silu: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.silu(input_1) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(SiLU(), example_args, {}, expected_silu) + verify_model(SiLU2(), example_args, {}, expected_silu) + + # sin + class Sin(Module): + def forward(self, input: torch.Tensor): + return torch.sin(input) + + @tvm.script.ir_module + class expected_sin: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.sin(input_1) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(Sin(), example_args, {}, expected_sin) + + # sinh + class Sinh(Module): + def forward(self, input): + return torch.sinh(input) + + @tvm.script.ir_module + class expected_sinh: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.sinh(input_1) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(Sinh(), example_args, {}, expected_sinh) + + # sqrt + class Sqrt(Module): + def forward(self, input): + return torch.sqrt(input) + + @tvm.script.ir_module + class expected_sqrt: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.sqrt(input_1) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(Sqrt(), example_args, {}, expected_sqrt) + + # tan + class Tan(Module): + def forward(self, input): + return torch.tan(input) + + @tvm.script.ir_module + class expected_tan: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.tan(input_1) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(Tan(), example_args, {}, expected_tan) + + # tanh + class Tanh(Module): + def forward(self, input): + return torch.tanh(input) + + @tvm.script.ir_module + class expected_tanh: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.tanh(input_1) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(Tanh(), example_args, {}, expected_tanh) def test_adaptive_avgpool2d(): From bd57aaa53c4b7cde0ffa6bc4f5ef67708928d31a Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Fri, 27 Sep 2024 11:30:15 +0900 Subject: [PATCH 02/12] support clamp --- .../torch/base_fx_graph_translator.py | 16 +++++++++++++ .../torch/exported_program_translator.py | 1 + .../tvm/relax/frontend/torch/fx_translator.py | 16 ------------- .../test_frontend_from_exported_program.py | 23 +++++++++++++++++++ 4 files changed, 40 insertions(+), 16 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 6a001b5a047c..2298122d955d 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -111,6 +111,22 @@ def convert(node: fx.Node) -> relax.Var: return convert + def _clamp(self, node: fx.Node) -> relax.Expr: + args = self.retrieve_args(node) + a_min = args[1] if len(args) > 1 else node.kwargs["min"] + a_max = args[2] if len(args) > 2 else node.kwargs["max"] + if not isinstance(a_min, (int, float)): + raise ValueError( + f"TVM only supports constant min value for torch.clamp/clip, " + f"but got {a_min} with type {type(a_min)}" + ) + if not isinstance(a_max, (int, float)): + raise ValueError( + f"TVM only supports constant max value for torch.clamp/clip, " + f"but got {a_max} with type {type(a_max)}" + ) + return self.block_builder.emit(relax.op.clip(args[0], a_min, a_max)) + ########## Neural Network ########## def _adaptive_avg_pool2d(self, node: fx.Node) -> relax.Var: diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 83a2f0ae9729..364148824e45 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -75,6 +75,7 @@ def create_convert_map( "asinh.default": self._unary_op(relax.op.asinh), "atan.default": self._unary_op(relax.op.atan), "atanh.default": self._unary_op(relax.op.atanh), + "clamp.default": self._clamp, "cos.default": self._unary_op(relax.op.cos), "cosh.default": self._unary_op(relax.op.cosh), "dropout.default": lambda node: self.env[node.args[0]], diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index ec53cf23edc5..6b84035ca3d1 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -62,22 +62,6 @@ def _fetch_attr(self, model, target: str): ########## Unary Ops ########## - def _clamp(self, node: fx.Node) -> relax.Expr: - args = self.retrieve_args(node) - a_min = args[1] if len(args) > 1 else node.kwargs["min"] - a_max = args[2] if len(args) > 2 else node.kwargs["max"] - if not isinstance(a_min, (int, float)): - raise ValueError( - f"TVM only supports constant min value for torch.clamp/clip, " - f"but got {a_min} with type {type(a_min)}" - ) - if not isinstance(a_max, (int, float)): - raise ValueError( - f"TVM only supports constant max value for torch.clamp/clip, " - f"but got {a_max} with type {type(a_max)}" - ) - return self.block_builder.emit(relax.op.clip(args[0], a_min, a_max)) - def _gelu(self, node: fx.Node) -> relax.Expr: approximate = node.kwargs.get("approximate", "none") if approximate == "none": diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 2d39208a4c1d..da9e696ff7ed 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -474,6 +474,29 @@ def main( verify_model(Tanh(), example_args, {}, expected_tanh) +def test_clamp(): + class Clamp(Module): + def forward(self, input): + return torch.clamp(input, min=0.1, max=0.5) + + @tvm.script.ir_module + class expected_clamp: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(input_1, 0.1, 0.5) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + verify_model(Clamp(), example_args, {}, expected_clamp) + + + def test_adaptive_avgpool2d(): class AdaptiveAvgPool2d0(Module): def __init__(self): From 1dbb44ed86a7014b587b3075b9b2c4b752fa765c Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Fri, 27 Sep 2024 11:32:41 +0900 Subject: [PATCH 03/12] support gelu --- .../torch/base_fx_graph_translator.py | 9 ++++++ .../torch/exported_program_translator.py | 1 + .../tvm/relax/frontend/torch/fx_translator.py | 9 ------ .../test_frontend_from_exported_program.py | 30 +++++++++++++++++++ 4 files changed, 40 insertions(+), 9 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 2298122d955d..d253f5084548 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -127,6 +127,15 @@ def _clamp(self, node: fx.Node) -> relax.Expr: ) return self.block_builder.emit(relax.op.clip(args[0], a_min, a_max)) + def _gelu(self, node: fx.Node) -> relax.Expr: + approximate = node.kwargs.get("approximate", "none") + if approximate == "none": + return self.block_builder.emit(relax.op.nn.gelu(self.env[node.args[0]])) + elif approximate == "tanh": + return self.block_builder.emit(relax.op.nn.gelu_tanh(self.env[node.args[0]])) + else: + raise KeyError("Unregonized approximate algorithm for gelu: {}.".format(approximate)) + ########## Neural Network ########## def _adaptive_avg_pool2d(self, node: fx.Node) -> relax.Var: diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 364148824e45..8435441f7b9d 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -80,6 +80,7 @@ def create_convert_map( "cosh.default": self._unary_op(relax.op.cosh), "dropout.default": lambda node: self.env[node.args[0]], "exp.default": self._unary_op(relax.op.exp), + "gelu.default": self._gelu, "neg.default": self._unary_op(relax.op.negative), "relu.default": self._unary_op(relax.op.nn.relu), "rsqrt.default": self._unary_op(relax.op.rsqrt), diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 6b84035ca3d1..7d850d6024e1 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -62,15 +62,6 @@ def _fetch_attr(self, model, target: str): ########## Unary Ops ########## - def _gelu(self, node: fx.Node) -> relax.Expr: - approximate = node.kwargs.get("approximate", "none") - if approximate == "none": - return self.block_builder.emit(relax.op.nn.gelu(self.env[node.args[0]])) - elif approximate == "tanh": - return self.block_builder.emit(relax.op.nn.gelu_tanh(self.env[node.args[0]])) - else: - raise KeyError("Unregonized approximate algorithm for gelu: {}.".format(approximate)) - def _hardsigmoid(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) x = args[0] diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index da9e696ff7ed..00014a36d095 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -496,6 +496,36 @@ def main( verify_model(Clamp(), example_args, {}, expected_clamp) +def test_gelu(): + class Gelu(Module): + def __init__(self): + super().__init__() + self.gelu = torch.nn.GELU() + + def forward(self, input): + return self.gelu(input) + + class Gelu2(Module): + def forward(self, input): + return torch.nn.functional.gelu(input) + + @tvm.script.ir_module + class expected_gelu: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.gelu(input_1) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + verify_model(Gelu(), example_args, {}, expected_gelu) + verify_model(Gelu2(), example_args, {}, expected_gelu) + def test_adaptive_avgpool2d(): class AdaptiveAvgPool2d0(Module): From 565b44197c0e4a4182490f4eca295596e39d563d Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Fri, 27 Sep 2024 11:33:18 +0900 Subject: [PATCH 04/12] support hardsigmoid --- .../torch/base_fx_graph_translator.py | 8 +++++ .../torch/exported_program_translator.py | 1 + .../tvm/relax/frontend/torch/fx_translator.py | 8 ----- .../test_frontend_from_exported_program.py | 34 +++++++++++++++++++ 4 files changed, 43 insertions(+), 8 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index d253f5084548..62f67e6f1942 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -136,6 +136,14 @@ def _gelu(self, node: fx.Node) -> relax.Expr: else: raise KeyError("Unregonized approximate algorithm for gelu: {}.".format(approximate)) + def _hardsigmoid(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + dtype = x.struct_info.dtype + x0 = relax.op.add(x, relax.const(3, dtype)) + x1 = relax.op.clip(x0, 0, 6) + return self.block_builder.emit(relax.op.divide(x1, relax.const(6, dtype))) + ########## Neural Network ########## def _adaptive_avg_pool2d(self, node: fx.Node) -> relax.Var: diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 8435441f7b9d..936dee20a6de 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -81,6 +81,7 @@ def create_convert_map( "dropout.default": lambda node: self.env[node.args[0]], "exp.default": self._unary_op(relax.op.exp), "gelu.default": self._gelu, + "hardsigmoid.default": self._hardsigmoid, "neg.default": self._unary_op(relax.op.negative), "relu.default": self._unary_op(relax.op.nn.relu), "rsqrt.default": self._unary_op(relax.op.rsqrt), diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 7d850d6024e1..d737a853006b 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -62,14 +62,6 @@ def _fetch_attr(self, model, target: str): ########## Unary Ops ########## - def _hardsigmoid(self, node: fx.Node) -> relax.Var: - args = self.retrieve_args(node) - x = args[0] - dtype = x.struct_info.dtype - x0 = relax.op.add(x, relax.const(3, dtype)) - x1 = relax.op.clip(x0, 0, 6) - return self.block_builder.emit(relax.op.divide(x1, relax.const(6, dtype))) - def _hardswish(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) x = args[0] diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 00014a36d095..5d9823b383aa 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -527,6 +527,40 @@ def main( verify_model(Gelu2(), example_args, {}, expected_gelu) +def test_hardsigmoid(): + class Hardsigmoid(torch.nn.Module): + def __init__(self): + super().__init__() + self.hs = torch.nn.Hardsigmoid() + + def forward(self, input): + return self.hs(input) + + class Hardsigmoid2(torch.nn.Module): + def forward(self, input): + return torch.nn.functional.hardsigmoid(input) + + @tvm.script.ir_module + class expected_hardsigmoid: + @R.function + def main( + inp_0: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(inp_0, R.const(3, "float32")) + lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(lv, 0, 6) + lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide( + lv1, R.const(6, "float32") + ) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv2,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + verify_model(Hardsigmoid(), example_args, {}, expected_hardsigmoid) + verify_model(Hardsigmoid2(), example_args, {}, expected_hardsigmoid) + + def test_adaptive_avgpool2d(): class AdaptiveAvgPool2d0(Module): def __init__(self): From 2e96b759dc6cae5742641fa059d72cd94ba7d213 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Fri, 27 Sep 2024 11:34:34 +0900 Subject: [PATCH 05/12] support hardswish --- .../torch/base_fx_graph_translator.py | 9 +++++ .../torch/exported_program_translator.py | 1 + .../tvm/relax/frontend/torch/fx_translator.py | 9 ----- .../test_frontend_from_exported_program.py | 35 +++++++++++++++++++ 4 files changed, 45 insertions(+), 9 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 62f67e6f1942..832c90545b3d 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -144,6 +144,15 @@ def _hardsigmoid(self, node: fx.Node) -> relax.Var: x1 = relax.op.clip(x0, 0, 6) return self.block_builder.emit(relax.op.divide(x1, relax.const(6, dtype))) + def _hardswish(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + dtype = x.struct_info.dtype + x0 = relax.op.add(x, relax.const(3, dtype)) + x1 = relax.op.clip(x0, 0, 6) + x2 = relax.op.divide(x1, relax.const(6, dtype)) + return self.block_builder.emit(relax.op.multiply(x, x2)) + ########## Neural Network ########## def _adaptive_avg_pool2d(self, node: fx.Node) -> relax.Var: diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 936dee20a6de..7e9eba8f56a5 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -82,6 +82,7 @@ def create_convert_map( "exp.default": self._unary_op(relax.op.exp), "gelu.default": self._gelu, "hardsigmoid.default": self._hardsigmoid, + "hardswish.default": self._hardswish, "neg.default": self._unary_op(relax.op.negative), "relu.default": self._unary_op(relax.op.nn.relu), "rsqrt.default": self._unary_op(relax.op.rsqrt), diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index d737a853006b..500db245a25f 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -62,15 +62,6 @@ def _fetch_attr(self, model, target: str): ########## Unary Ops ########## - def _hardswish(self, node: fx.Node) -> relax.Var: - args = self.retrieve_args(node) - x = args[0] - dtype = x.struct_info.dtype - x0 = relax.op.add(x, relax.const(3, dtype)) - x1 = relax.op.clip(x0, 0, 6) - x2 = relax.op.divide(x1, relax.const(6, dtype)) - return self.block_builder.emit(relax.op.multiply(x, x2)) - def _leakyrelu(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] alpha = node.args[1] if len(node.args) > 1 else node.kwargs.get("negative_slope", 0.01) diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 5d9823b383aa..bf485dc7e2ff 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -561,6 +561,41 @@ def main( verify_model(Hardsigmoid2(), example_args, {}, expected_hardsigmoid) +def test_hardswish(): + class Hardswish(torch.nn.Module): + def __init__(self): + super().__init__() + self.hs = torch.nn.Hardswish() + + def forward(self, input): + return self.hs(input) + + class Hardswish2(torch.nn.Module): + def forward(self, input): + return torch.nn.functional.hardswish(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + inp_0: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(inp_0, R.const(3, "float32")) + lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(lv, 0, 6) + lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide( + lv1, R.const(6, "float32") + ) + lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(inp_0, lv2) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv3,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + verify_model(Hardswish(), example_args, {}, expected1) + verify_model(Hardswish2(), example_args, {}, expected1) + + def test_adaptive_avgpool2d(): class AdaptiveAvgPool2d0(Module): def __init__(self): From 5eaa75a399b2e548e3f13e79da44d8ad9f48b330 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Fri, 27 Sep 2024 11:46:45 +0900 Subject: [PATCH 06/12] support hardtanh --- .../torch/exported_program_translator.py | 10 ++++++ .../test_frontend_from_exported_program.py | 32 +++++++++++++++++++ 2 files changed, 42 insertions(+) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 7e9eba8f56a5..41714c6b76b3 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -64,6 +64,15 @@ def create_input_vars( return parameters_buffers_constants, user_inputs + ########## Unary Ops ########## + + def _hardtanh(self, node: fx.Node) -> relax.Expr: + args = self.retrieve_args(node) + x = args[0] + min_val = node.args[1] if len(args) > 1 else node.kwargs("min_val", -1.0) + max_val = node.args[2] if len(args) > 2 else node.kwargs("max_val", 1.0) + return self.block_builder.emit(relax.op.clip(x, min_val, max_val)) + def create_convert_map( self, ) -> Dict[str, Callable[[fx.Node], relax.Var]]: @@ -83,6 +92,7 @@ def create_convert_map( "gelu.default": self._gelu, "hardsigmoid.default": self._hardsigmoid, "hardswish.default": self._hardswish, + "hardtanh.default": self._hardtanh, "neg.default": self._unary_op(relax.op.negative), "relu.default": self._unary_op(relax.op.nn.relu), "rsqrt.default": self._unary_op(relax.op.rsqrt), diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index bf485dc7e2ff..f9a1c0883c85 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -596,6 +596,38 @@ def main( verify_model(Hardswish2(), example_args, {}, expected1) +def test_hardtanh(): + class Hardtanh(torch.nn.Module): + def __init__(self): + super().__init__() + self.ht = torch.nn.Hardtanh() + + def forward(self, input): + return self.ht(input) + + class Hardtanh2(torch.nn.Module): + def forward(self, input): + return torch.nn.functional.hardtanh(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + inp_0: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip( + inp_0, R.prim_value(T.float64(-1.0)), R.prim_value(T.float64(1.0)) + ) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + verify_model(Hardtanh(), example_args, {}, expected1) + verify_model(Hardtanh2(), example_args, {}, expected1) + + def test_adaptive_avgpool2d(): class AdaptiveAvgPool2d0(Module): def __init__(self): From a74b89bac6111aab6694b7fbf47af2cfbab954b5 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Fri, 27 Sep 2024 11:49:50 +0900 Subject: [PATCH 07/12] support leaky_relu --- .../torch/base_fx_graph_translator.py | 5 +++ .../torch/exported_program_translator.py | 1 + .../tvm/relax/frontend/torch/fx_translator.py | 5 --- .../test_frontend_from_exported_program.py | 36 +++++++++++++++++++ 4 files changed, 42 insertions(+), 5 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 832c90545b3d..d2ced05fe50f 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -153,6 +153,11 @@ def _hardswish(self, node: fx.Node) -> relax.Var: x2 = relax.op.divide(x1, relax.const(6, dtype)) return self.block_builder.emit(relax.op.multiply(x, x2)) + def _leakyrelu(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + alpha = node.args[1] if len(node.args) > 1 else node.kwargs.get("negative_slope", 0.01) + return self.block_builder.emit(relax.op.nn.leakyrelu(x, alpha)) + ########## Neural Network ########## def _adaptive_avg_pool2d(self, node: fx.Node) -> relax.Var: diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 41714c6b76b3..96dd0e1aed3f 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -93,6 +93,7 @@ def create_convert_map( "hardsigmoid.default": self._hardsigmoid, "hardswish.default": self._hardswish, "hardtanh.default": self._hardtanh, + "leaky_relu.default": self._leakyrelu, "neg.default": self._unary_op(relax.op.negative), "relu.default": self._unary_op(relax.op.nn.relu), "rsqrt.default": self._unary_op(relax.op.rsqrt), diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 500db245a25f..06413bfdad4a 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -62,11 +62,6 @@ def _fetch_attr(self, model, target: str): ########## Unary Ops ########## - def _leakyrelu(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - alpha = node.args[1] if len(node.args) > 1 else node.kwargs.get("negative_slope", 0.01) - return self.block_builder.emit(relax.op.nn.leakyrelu(x, alpha)) - def _leakyrelu_module(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index f9a1c0883c85..d315ce11bfa7 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -628,6 +628,42 @@ def main( verify_model(Hardtanh2(), example_args, {}, expected1) +def test_leakyrelu(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + + class LeakyReLU0(Module): + def __init__(self): + super().__init__() + self.leakyrelu = torch.nn.LeakyReLU(0.02) + + def forward(self, input): + return self.leakyrelu(input) + + class LeakyReLU1(Module): + def forward(self, input): + return torch.nn.functional.leaky_relu(input, 0.02) + + @tvm.script.ir_module + class expected: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.leakyrelu(input_1, 0.02) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + verify_model(LeakyReLU0(), example_args, {}, expected) + verify_model(LeakyReLU1(), example_args, {}, expected) + + def test_adaptive_avgpool2d(): class AdaptiveAvgPool2d0(Module): def __init__(self): From 0a7addb2fec7e6fe157d635b8c112c337d8cf3b6 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Fri, 27 Sep 2024 11:52:28 +0900 Subject: [PATCH 08/12] support log_softmax --- .../torch/base_fx_graph_translator.py | 5 +++ .../torch/exported_program_translator.py | 1 + .../tvm/relax/frontend/torch/fx_translator.py | 5 --- .../test_frontend_from_exported_program.py | 31 +++++++++++++++++++ 4 files changed, 37 insertions(+), 5 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index d2ced05fe50f..5a44c31406f6 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -158,6 +158,11 @@ def _leakyrelu(self, node: fx.Node) -> relax.Var: alpha = node.args[1] if len(node.args) > 1 else node.kwargs.get("negative_slope", 0.01) return self.block_builder.emit(relax.op.nn.leakyrelu(x, alpha)) + def _log_softmax(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", -1) + return self.block_builder.emit(relax.op.nn.log_softmax(x, dim)) + ########## Neural Network ########## def _adaptive_avg_pool2d(self, node: fx.Node) -> relax.Var: diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 96dd0e1aed3f..199dc6d890d1 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -94,6 +94,7 @@ def create_convert_map( "hardswish.default": self._hardswish, "hardtanh.default": self._hardtanh, "leaky_relu.default": self._leakyrelu, + "log_softmax.int": self._log_softmax, "neg.default": self._unary_op(relax.op.negative), "relu.default": self._unary_op(relax.op.nn.relu), "rsqrt.default": self._unary_op(relax.op.rsqrt), diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 06413bfdad4a..07f5c94d85e6 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -68,11 +68,6 @@ def _leakyrelu_module(self, node: fx.Node) -> relax.Var: alpha = module.negative_slope return self.block_builder.emit(relax.op.nn.leakyrelu(x, alpha)) - def _log_softmax(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", -1) - return self.block_builder.emit(relax.op.nn.log_softmax(x, dim)) - def _log_softmax_module(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index d315ce11bfa7..90b4b3b01c34 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -664,6 +664,37 @@ def main( verify_model(LeakyReLU1(), example_args, {}, expected) +def test_logsoftmax(): + class LogSoftmax(Module): + def __init__(self): + super().__init__() + self.lsm = torch.nn.LogSoftmax(dim=1) + + def forward(self, input): + return self.lsm(input) + + class LogSoftmax2(Module): + def forward(self, input): + return torch.nn.functional.log_softmax(input, dim=1) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.log_softmax(input_1, axis=1) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + verify_model(LogSoftmax(), example_args, {}, expected1) + verify_model(LogSoftmax2(), example_args, {}, expected1) + + def test_adaptive_avgpool2d(): class AdaptiveAvgPool2d0(Module): def __init__(self): From 9871b1a8c7ca927898ded9ea5873629eb68ebbfc Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Fri, 27 Sep 2024 11:54:59 +0900 Subject: [PATCH 09/12] support round --- .../torch/base_fx_graph_translator.py | 6 +++++ .../torch/exported_program_translator.py | 1 + .../tvm/relax/frontend/torch/fx_translator.py | 6 ----- .../test_frontend_from_exported_program.py | 22 +++++++++++++++++++ 4 files changed, 29 insertions(+), 6 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 5a44c31406f6..ebbbff466da3 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -163,6 +163,12 @@ def _log_softmax(self, node: fx.Node) -> relax.Var: dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", -1) return self.block_builder.emit(relax.op.nn.log_softmax(x, dim)) + def _round(self, node: fx.Node) -> relax.Expr: + if node.kwargs.get("decimals", 0) != 0: + raise ValueError("specifying decimals for round is not supported yet") + arg = self.env[node.args[0]] + return self.block_builder.emit(relax.op.round(arg)) + ########## Neural Network ########## def _adaptive_avg_pool2d(self, node: fx.Node) -> relax.Var: diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 199dc6d890d1..4f867d3ccfcd 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -97,6 +97,7 @@ def create_convert_map( "log_softmax.int": self._log_softmax, "neg.default": self._unary_op(relax.op.negative), "relu.default": self._unary_op(relax.op.nn.relu), + "round.default": self._round, "rsqrt.default": self._unary_op(relax.op.rsqrt), "sigmoid.default": self._unary_op(relax.op.sigmoid), "silu.default": self._unary_op(relax.op.nn.silu), diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 07f5c94d85e6..f904d8eb77c8 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -75,12 +75,6 @@ def _log_softmax_module(self, node: fx.Node) -> relax.Var: assert dim is not None return self.block_builder.emit(relax.op.nn.log_softmax(x, dim)) - def _round(self, node: fx.Node) -> relax.Expr: - if node.kwargs.get("decimals", 0) != 0: - raise ValueError("specifying decimals for round is not supported yet") - arg = self.env[node.args[0]] - return self.block_builder.emit(relax.op.round(arg)) - def _softmax(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", -1) diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 90b4b3b01c34..7f4247397654 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -695,6 +695,28 @@ def main( verify_model(LogSoftmax2(), example_args, {}, expected1) +def test_round(): + class Round(Module): + def forward(self, input): + return torch.round(input) + + @tvm.script.ir_module + class expected: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.round(input_1) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + verify_model(Round(), example_args, {}, expected) + + def test_adaptive_avgpool2d(): class AdaptiveAvgPool2d0(Module): def __init__(self): From 2cce028a6304e86dda6365a39eae5fed492f24f4 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Fri, 27 Sep 2024 11:56:33 +0900 Subject: [PATCH 10/12] support softmax --- .../torch/base_fx_graph_translator.py | 5 +++ .../torch/exported_program_translator.py | 1 + .../tvm/relax/frontend/torch/fx_translator.py | 5 --- .../test_frontend_from_exported_program.py | 31 +++++++++++++++++++ 4 files changed, 37 insertions(+), 5 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index ebbbff466da3..9385cc6fda69 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -169,6 +169,11 @@ def _round(self, node: fx.Node) -> relax.Expr: arg = self.env[node.args[0]] return self.block_builder.emit(relax.op.round(arg)) + def _softmax(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", -1) + return self.block_builder.emit(relax.op.nn.softmax(x, dim)) + ########## Neural Network ########## def _adaptive_avg_pool2d(self, node: fx.Node) -> relax.Var: diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 4f867d3ccfcd..c70583442230 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -103,6 +103,7 @@ def create_convert_map( "silu.default": self._unary_op(relax.op.nn.silu), "sin.default": self._unary_op(relax.op.sin), "sinh.default": self._unary_op(relax.op.sinh), + "softmax.int": self._softmax, "sqrt.default": self._unary_op(relax.op.sqrt), "tan.default": self._unary_op(relax.op.tan), "tanh.default": self._unary_op(relax.op.tanh), diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index f904d8eb77c8..8d6ad9cfc814 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -75,11 +75,6 @@ def _log_softmax_module(self, node: fx.Node) -> relax.Var: assert dim is not None return self.block_builder.emit(relax.op.nn.log_softmax(x, dim)) - def _softmax(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", -1) - return self.block_builder.emit(relax.op.nn.softmax(x, dim)) - def _softmax_module(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 7f4247397654..d9ca87ad89ed 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -717,6 +717,37 @@ def main( verify_model(Round(), example_args, {}, expected) +def test_softmax(): + class Softmax(Module): + def __init__(self): + super().__init__() + self.sm = torch.nn.Softmax(dim=1) + + def forward(self, input): + return self.sm(input) + + class Softmax2(Module): + def forward(self, input): + return torch.nn.functional.softmax(input, dim=1) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.softmax(input_1, axis=1) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + verify_model(Softmax(), example_args, {}, expected1) + verify_model(Softmax2(), example_args, {}, expected1) + + def test_adaptive_avgpool2d(): class AdaptiveAvgPool2d0(Module): def __init__(self): From b6c1a8f7ef4a06836472c021d45a607316ddf6d4 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Fri, 27 Sep 2024 12:00:06 +0900 Subject: [PATCH 11/12] support tril and triu --- .../torch/base_fx_graph_translator.py | 11 +++++ .../torch/exported_program_translator.py | 2 + .../tvm/relax/frontend/torch/fx_translator.py | 11 ----- .../test_frontend_from_exported_program.py | 42 +++++++++++++++++++ 4 files changed, 55 insertions(+), 11 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 9385cc6fda69..d52b3d598f89 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -174,6 +174,17 @@ def _softmax(self, node: fx.Node) -> relax.Var: dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", -1) return self.block_builder.emit(relax.op.nn.softmax(x, dim)) + def _tril_triu(self, op: Callable) -> Callable: + from torch import fx + + def convert(node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + k = node.args[1] if len(node.args) > 1 else node.kwargs.get("diagonal", 0) + assert isinstance(k, int) + return self.block_builder.emit(op(x, k)) + + return convert + ########## Neural Network ########## def _adaptive_avg_pool2d(self, node: fx.Node) -> relax.Var: diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index c70583442230..1ceddad7d79f 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -107,6 +107,8 @@ def create_convert_map( "sqrt.default": self._unary_op(relax.op.sqrt), "tan.default": self._unary_op(relax.op.tan), "tanh.default": self._unary_op(relax.op.tanh), + "tril.default": self._tril_triu(relax.op.tril), + "triu.default": self._tril_triu(relax.op.triu), # neural network "adaptive_avg_pool2d.default": self._adaptive_avg_pool2d, "conv2d.default": self._conv2d, diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 8d6ad9cfc814..6f7c6fa2c575 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -96,17 +96,6 @@ def convert(node: fx.Node) -> relax.Var: return convert - def _tril_triu(self, op: Callable) -> Callable: - from torch import fx - - def convert(node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - k = node.args[1] if len(node.args) > 1 else node.kwargs.get("diagonal", 0) - assert isinstance(k, int) - return self.block_builder.emit(op(x, k)) - - return convert - ########## Binary Ops ########## def _binary_op(self, relax_op: Callable, intrinsic_op: Callable) -> Callable: diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index d9ca87ad89ed..6c17d96004b6 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -748,6 +748,48 @@ def main( verify_model(Softmax2(), example_args, {}, expected1) +def test_tril_triu(): + example_args = (torch.randn(10, 10, dtype=torch.float32),) + + class Tril(Module): + def forward(self, input): + return torch.tril(input, 1) + + @tvm.script.ir_module + class expected_tril: + @R.function + def main( + input_1: R.Tensor((10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.tril(input_1, 1) + gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(Tril(), example_args, {}, expected_tril) + + class Triu(Module): + def forward(self, input): + return torch.triu(input, 1) + + @tvm.script.ir_module + class expected_triu: + @R.function + def main( + input_1: R.Tensor((10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.triu(input_1, 1) + gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(Triu(), example_args, {}, expected_triu) + + def test_adaptive_avgpool2d(): class AdaptiveAvgPool2d0(Module): def __init__(self): From 933065a86acf37a8188758a86894d29b124ce42c Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Fri, 27 Sep 2024 15:44:02 +0900 Subject: [PATCH 12/12] skip flaky test --- tests/python/relay/test_to_mixed_precision.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/python/relay/test_to_mixed_precision.py b/tests/python/relay/test_to_mixed_precision.py index ae5172f6caf0..a8032ce0d26d 100644 --- a/tests/python/relay/test_to_mixed_precision.py +++ b/tests/python/relay/test_to_mixed_precision.py @@ -98,6 +98,7 @@ def test_lstm(target_precision): ) +@pytest.mark.skip(reason="Flaky test") def test_lstm_float64(): """Tests if can handle other mixed precision types.