From 2d34aa773192d713ebf0c34db953f1c905b5dae6 Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Tue, 11 Mar 2025 16:01:34 +0800 Subject: [PATCH 1/5] Update exported_program_translator.py --- python/tvm/relax/frontend/torch/exported_program_translator.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 4ff31ea1d772..5e633f7928bd 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -192,10 +192,12 @@ def create_convert_map( "atanh.default": self._unary_op(relax.op.atanh), "bitwise_not.default": self._unary_op(relax.op.bitwise_not), "ceil.default": self._unary_op(relax.op.ceil), + "celu.default": self._celu, "clamp.default": self._clamp, "cos.default": self._unary_op(relax.op.cos), "cosh.default": self._unary_op(relax.op.cosh), "dropout.default": lambda node: self.env[node.args[0]], + "elu.default": self._elu, "erf.default": self._unary_op(relax.op.erf), "exp.default": self._unary_op(relax.op.exp), "floor.default": self._unary_op(relax.op.floor), @@ -213,6 +215,7 @@ def create_convert_map( "relu.default": self._unary_op(relax.op.nn.relu), "round.default": self._round, "rsqrt.default": self._unary_op(relax.op.rsqrt), + "selu.default": self._selu, "sigmoid.default": self._unary_op(relax.op.sigmoid), "sign.default": self._unary_op(relax.op.sign), "silu.default": self._unary_op(relax.op.nn.silu), From 8487e966c287b90f5912a882b1b67678882f0cf9 Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Tue, 11 Mar 2025 16:04:00 +0800 Subject: [PATCH 2/5] Update test_frontend_from_exported_program.py --- .../test_frontend_from_exported_program.py | 116 ++++++++++++++++++ 1 file changed, 116 insertions(+) diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 6406610bf53e..621ddc7fa3b9 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -126,6 +126,43 @@ def main( def test_extended_unary_ops(): example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + # celu + class Celu1(Module): + def __init__(self): + super().__init__() + self.celu = torch.nn.CELU() + + def forward(self, input): + return self.celu(input) + + class Celu2(Module): + def forward(self, input): + return torch.nn.functional.celu(input) + + # alpha * min(0, exp(x / alpha) - 1) + max(0, x) + @tvm.script.ir_module + class expected_celu: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.exp(input_1) + lv_div: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide(lv, R.const(1.0, "float32")) + lv_sub: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract(lv_div, R.const(1.0, "float32")) + lv_min: R.Tensor((1, 3, 10, 10), dtype="float32") = R.minimum(R.const(0.0, "float32"), lv_sub) + lv_scaled: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(R.const(1.0, "float32"), lv_min) + lv_relu_x: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.relu(input_1) + lv_celu: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(lv_scaled, lv_relu_x) + + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv_celu,) + R.output(gv) + + return gv + + verify_model(Celu1(), example_args, {}, expected_celu) + verify_model(Celu2(), example_args, {}, expected_celu) + # clamp class Clamp(Module): def forward(self, input): @@ -174,6 +211,48 @@ def main( verify_model(Dropout1(), example_args, {}, expected_dropout) verify_model(Dropout2(), example_args, {}, expected_dropout) + # elu + class Elu(Module): + def __init__(self): + super().__init__() + self.elu = torch.nn.ELU() + + def forward(self, input): + return self.elu(input) + + class Elu2(Module): + def forward(self, input): + return torch.nn.functional.elu(input) + + @tvm.script.ir_module + class expected_elu: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv_exp: R.Tensor((1, 3, 10, 10), dtype="float32") = R.exp(input_1) + lv_one_minus_exp: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract( + R.const(1.0, dtype="float32"), lv_exp + ) + lv_relu_one_minus_exp: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.relu(lv_one_minus_exp) + + lv_scaled: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply( + R.const(-1.0, dtype="float32"), lv_relu_one_minus_exp + ) + + lv_relu_x: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.relu(input_1) + lv_elu: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(lv_scaled, lv_relu_x) + + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv_elu,) + R.output(gv) + + return gv + + verify_model(Elu(), example_args, {}, expected_elu) + verify_model(Elu2(), example_args, {}, expected_elu) + # gelu class Gelu(Module): def __init__(self): @@ -306,6 +385,43 @@ def main( verify_model(ReLU0(), example_args, {}, expected_relu) verify_model(ReLU1(), example_args, {}, expected_relu) + # selu + class Selu1(Module): + def __init__(self): + super().__init__() + self.selu = torch.nn.SELU() + + def forward(self, input): + return self.selu(input) + + class Selu2(Module): + def forward(self, input): + return torch.nn.functional.selu(input) + + @tvm.script.ir_module + class expected_selu: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + with R.dataflow(): + lv_relu: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.relu(input_1) + lv_exp: R.Tensor((1, 3, 10, 10), dtype="float32") = R.exp(input_1) + lv_sub: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract(lv_exp, R.const(1.0, "float32")) + lv_scaled: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply( + R.const(1.6732631921768188, "float32"), lv_sub) + lv_add: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(lv_relu, lv_scaled) + lv_selu: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(R.const(1.0507010221481323, "float32"), + lv_add) + + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv_selu,) + R.output(gv) + + return gv + + verify_model(Selu1(), example_args, {}, expected_selu) + verify_model(Selu2(), example_args, {}, expected_selu) + # sigmoid class Sigmoid(Module): def __init__(self): From 6d876a658c6d5f8da9850f4b6a7dc9aa9af07583 Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Tue, 11 Mar 2025 23:31:37 +0800 Subject: [PATCH 3/5] Update test_frontend_from_exported_program.py --- .../test_frontend_from_exported_program.py | 18 ++++-------------- 1 file changed, 4 insertions(+), 14 deletions(-) diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 621ddc7fa3b9..77fee3243b99 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -144,7 +144,7 @@ def forward(self, input): class expected_celu: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): with R.dataflow(): lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.exp(input_1) @@ -154,10 +154,8 @@ def main( lv_scaled: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(R.const(1.0, "float32"), lv_min) lv_relu_x: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.relu(input_1) lv_celu: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(lv_scaled, lv_relu_x) - gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv_celu,) R.output(gv) - return gv verify_model(Celu1(), example_args, {}, expected_celu) @@ -228,26 +226,18 @@ def forward(self, input): class expected_elu: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): # block 0 with R.dataflow(): lv_exp: R.Tensor((1, 3, 10, 10), dtype="float32") = R.exp(input_1) - lv_one_minus_exp: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract( - R.const(1.0, dtype="float32"), lv_exp - ) + lv_one_minus_exp: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract(R.const(1.0, dtype="float32"), lv_exp) lv_relu_one_minus_exp: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.relu(lv_one_minus_exp) - - lv_scaled: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply( - R.const(-1.0, dtype="float32"), lv_relu_one_minus_exp - ) - + lv_scaled: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(R.const(-1.0, dtype="float32"), lv_relu_one_minus_exp) lv_relu_x: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.relu(input_1) lv_elu: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(lv_scaled, lv_relu_x) - gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv_elu,) R.output(gv) - return gv verify_model(Elu(), example_args, {}, expected_elu) From 659aed5ffff103c37ecfe41a4f5aa066bd81a33c Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Tue, 11 Mar 2025 23:43:15 +0800 Subject: [PATCH 4/5] Update test_frontend_from_exported_program.py --- .../test_frontend_from_exported_program.py | 43 +++++++++++++------ 1 file changed, 30 insertions(+), 13 deletions(-) diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 77fee3243b99..26fa0033ade3 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -148,10 +148,18 @@ def main( ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): with R.dataflow(): lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.exp(input_1) - lv_div: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide(lv, R.const(1.0, "float32")) - lv_sub: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract(lv_div, R.const(1.0, "float32")) - lv_min: R.Tensor((1, 3, 10, 10), dtype="float32") = R.minimum(R.const(0.0, "float32"), lv_sub) - lv_scaled: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(R.const(1.0, "float32"), lv_min) + lv_div: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide( + lv, R.const(1.0, "float32") + ) + lv_sub: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract( + lv_div, R.const(1.0, "float32") + ) + lv_min: R.Tensor((1, 3, 10, 10), dtype="float32") = R.minimum( + R.const(0.0, "float32"), lv_sub + ) + lv_scaled: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply( + R.const(1.0, "float32"), lv_min + ) lv_relu_x: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.relu(input_1) lv_celu: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(lv_scaled, lv_relu_x) gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv_celu,) @@ -231,9 +239,15 @@ def main( # block 0 with R.dataflow(): lv_exp: R.Tensor((1, 3, 10, 10), dtype="float32") = R.exp(input_1) - lv_one_minus_exp: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract(R.const(1.0, dtype="float32"), lv_exp) - lv_relu_one_minus_exp: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.relu(lv_one_minus_exp) - lv_scaled: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(R.const(-1.0, dtype="float32"), lv_relu_one_minus_exp) + lv_one_minus_exp: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract( + R.const(1.0, dtype="float32"), lv_exp + ) + lv_relu_one_minus_exp: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.relu( + lv_one_minus_exp + ) + lv_scaled: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply( + R.const(-1.0, dtype="float32"), lv_relu_one_minus_exp + ) lv_relu_x: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.relu(input_1) lv_elu: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(lv_scaled, lv_relu_x) gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv_elu,) @@ -392,18 +406,21 @@ def forward(self, input): class expected_selu: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): with R.dataflow(): lv_relu: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.relu(input_1) lv_exp: R.Tensor((1, 3, 10, 10), dtype="float32") = R.exp(input_1) - lv_sub: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract(lv_exp, R.const(1.0, "float32")) + lv_sub: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract( + lv_exp, R.const(1.0, "float32") + ) lv_scaled: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply( - R.const(1.6732631921768188, "float32"), lv_sub) + R.const(1.6732631921768188, "float32"), lv_sub + ) lv_add: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(lv_relu, lv_scaled) - lv_selu: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(R.const(1.0507010221481323, "float32"), - lv_add) - + lv_selu: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply( + R.const(1.0507010221481323, "float32"),lv_add + ) gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv_selu,) R.output(gv) From 8278604186fe049a1e91d68b35613eb7258b599a Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Tue, 11 Mar 2025 23:50:55 +0800 Subject: [PATCH 5/5] Update test_frontend_from_exported_program.py --- tests/python/relax/test_frontend_from_exported_program.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 26fa0033ade3..2e9f1fbd1c76 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -419,7 +419,7 @@ def main( ) lv_add: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(lv_relu, lv_scaled) lv_selu: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply( - R.const(1.0507010221481323, "float32"),lv_add + R.const(1.0507010221481323, "float32"), lv_add ) gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv_selu,) R.output(gv)