From 74e31e1ef28164a9a600751e12283f05bbe9a9fe Mon Sep 17 00:00:00 2001 From: deivanayakisankaralingam Date: Tue, 1 Apr 2025 07:04:39 +0000 Subject: [PATCH 1/6] Integrate SELU into core ops for native R.nn.selu support --- .../torch/base_fx_graph_translator.py | 31 ------------------- .../torch/exported_program_translator.py | 2 +- python/tvm/relax/op/nn/__init__.py | 1 + python/tvm/relax/op/nn/nn.py | 24 ++++++++++++++ python/tvm/relax/transform/legalize_ops/nn.py | 16 ++++++++++ src/relax/op/nn/nn.cc | 3 ++ src/relax/op/nn/nn.h | 3 ++ .../test_frontend_from_exported_program.py | 20 +++--------- 8 files changed, 52 insertions(+), 48 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 74c620a33ddb..890f925079e0 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -340,37 +340,6 @@ def _softshrink(self, node: fx.Node) -> relax.Var: # Combine the positive and negative shrink results return self.block_builder.emit(relax.op.add(shrink_pos, shrink_neg)) - def _selu(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - alpha = node.args[1] if len(node.args) > 1 else node.kwargs.get("alpha", 1.6732631921768188) - gamma = node.args[2] if len(node.args) > 2 else node.kwargs.get("gamma", 1.0507009873554805) - dtype = x.struct_info.dtype - - if isinstance(alpha, (int, float)): - alpha = relax.const(alpha, dtype) - else: - if not isinstance(alpha, relax.Var): - alpha = self.block_builder.emit(relax.const(alpha, dtype)) - - if isinstance(gamma, (int, float)): - gamma = relax.const(gamma, dtype) - else: - if not isinstance(gamma, relax.Var): - gamma = self.block_builder.emit(relax.const(gamma, dtype)) - - # gamma * (ReLU(x) + alpha * (exp(x) - 1)) - return self.block_builder.emit( - relax.op.multiply( - gamma, - relax.op.add( - relax.op.nn.relu(x), - relax.op.multiply( - alpha, 
relax.op.subtract(relax.op.exp(x), relax.const(1, dtype)) - ), - ), - ) - ) - def _tril_triu(self, op: Callable) -> Callable: from torch import fx diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 97ccc6393cbb..003c333cffe5 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -275,7 +275,7 @@ def create_convert_map( "relu.default": self._unary_op(relax.op.nn.relu), "round.default": self._round, "rsqrt.default": self._unary_op(relax.op.rsqrt), - "selu.default": self._selu, + "selu.default": self._unary_op(relax.op.nn.selu), "sigmoid.default": self._unary_op(relax.op.sigmoid), "sign.default": self._unary_op(relax.op.sign), "silu.default": self._unary_op(relax.op.nn.silu), diff --git a/python/tvm/relax/op/nn/__init__.py b/python/tvm/relax/op/nn/__init__.py index 61212f33d882..e45982a0fed2 100644 --- a/python/tvm/relax/op/nn/__init__.py +++ b/python/tvm/relax/op/nn/__init__.py @@ -45,6 +45,7 @@ pad, relu, rms_norm, + selu, silu, softmax, ) diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py index 09a7df5149f9..5232eea047cf 100644 --- a/python/tvm/relax/op/nn/nn.py +++ b/python/tvm/relax/op/nn/nn.py @@ -1304,6 +1304,30 @@ def gelu_tanh(data: Expr) -> Expr: return _ffi_api.gelu_tanh(data) # type: ignore +def selu(data: Expr) -> Expr: + r"""Scaled Exponential Linear Unit (SELU). + + .. math:: + \text{SELU}(x) = \lambda \begin{cases} + x & \text{if } x > 0 \\ + \alpha (e^x - 1) & \text{if } x \leq 0 + \end{cases} + + where :math:`\lambda \approx 1.0507` and :math:`\alpha \approx 1.6733`. + + Parameters + ---------- + data : relax.Expr + The input data. + + Returns + ------- + result : relax.Expr + The computed result. 
+ """ + return _ffi_api.selu(data) + + def silu(data: Expr) -> Expr: r"""Sigmoid Linear Unit function diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py index 4c8bdbc6615c..65afb01a487d 100644 --- a/python/tvm/relax/transform/legalize_ops/nn.py +++ b/python/tvm/relax/transform/legalize_ops/nn.py @@ -505,6 +505,22 @@ def te_gelu_tanh(x: te.Tensor): return bb.call_te(te_gelu_tanh, call.args[0], primfunc_name_hint="gelu_tanh") +@register_legalize("relax.nn.selu") +def _nn_selu(bb: BlockBuilder, call: Call) -> Expr: + def te_selu(x: te.Tensor): + dtype = x.dtype + alpha = tir.const(1.6732632423543772848170429916717, dtype) + scale = tir.const(1.0507009873554804934193349852946, dtype) + + # Compute SELU + # SELU(x)=scale∗(max(0,x)+min(0,α∗(exp(x)−1))) + positive_part = topi.maximum(x, tir.const(0, dtype)) + negative_part = topi.minimum(tir.const(0, dtype), alpha * (topi.exp(x) - tir.const(1, dtype))) + return scale * (positive_part + negative_part) + + return bb.call_te(te_selu, call.args[0], primfunc_name_hint="selu") + + @register_legalize("relax.nn.silu") def _nn_silu(bb: BlockBuilder, call: Call) -> Expr: def te_silu(x: te.Tensor): diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc index c768ea19af7d..4a5a9a701612 100644 --- a/src/relax/op/nn/nn.cc +++ b/src/relax/op/nn/nn.cc @@ -34,6 +34,9 @@ RELAX_REGISTER_UNARY_NN_OP_AND_IMPL(gelu, "nn.gelu", /*require_float_dtype=*/tru /* relax.nn.gelu_tanh */ RELAX_REGISTER_UNARY_NN_OP_AND_IMPL(gelu_tanh, "nn.gelu_tanh", /*require_float_dtype=*/true); +/* relax.nn.selu */ +RELAX_REGISTER_UNARY_NN_OP_AND_IMPL(selu, "nn.selu", /*require_float_dtype=*/true); + /* relax.nn.silu */ RELAX_REGISTER_UNARY_NN_OP_AND_IMPL(silu, "nn.silu", /*require_float_dtype=*/true); diff --git a/src/relax/op/nn/nn.h b/src/relax/op/nn/nn.h index 28c14139b97b..d6db36aba50a 100644 --- a/src/relax/op/nn/nn.h +++ b/src/relax/op/nn/nn.h @@ -57,6 +57,9 @@ Expr gelu(Expr data); /*! 
\brief Gaussian Error Linear Units function approximated by tanh. */ Expr gelu_tanh(Expr data); +/*! \brief Scaled Exponential Linear Unit function. */ +Expr selu(Expr data); + /*! \brief Sigmoid Linear Unit function. */ Expr silu(Expr data); diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 98f0f1d9cac6..e272560c363d 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -518,29 +518,17 @@ def forward(self, input): class Selu2(Module): def forward(self, input): return torch.nn.functional.selu(input) - + @tvm.script.ir_module class expected_selu: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + input: R.Tensor((1, 3, 10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): with R.dataflow(): - lv_relu: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.relu(input_1) - lv_exp: R.Tensor((1, 3, 10, 10), dtype="float32") = R.exp(input_1) - lv_sub: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract( - lv_exp, R.const(1.0, "float32") - ) - lv_scaled: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply( - R.const(1.6732631921768188, "float32"), lv_sub - ) - lv_add: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(lv_relu, lv_scaled) - lv_selu: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply( - R.const(1.0507010221481323, "float32"), lv_add - ) - gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv_selu,) + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.selu(input) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) R.output(gv) - return gv verify_model(Selu1(), example_args, {}, expected_selu) From 938d6b8213a485b8a5d668d2809f47a57078eae1 Mon Sep 17 00:00:00 2001 From: deivanayakisankaralingam Date: Tue, 1 Apr 2025 07:16:23 +0000 Subject: [PATCH 2/6] fix trailing whitespace issue --- 
tests/python/relax/test_frontend_from_exported_program.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index e272560c363d..2f05a7e208a1 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -518,7 +518,7 @@ def forward(self, input): class Selu2(Module): def forward(self, input): return torch.nn.functional.selu(input) - + @tvm.script.ir_module class expected_selu: @R.function From 59ba9115d474c1ceae3d653eab970a6e3ec97926 Mon Sep 17 00:00:00 2001 From: deivanayakisankaralingam Date: Tue, 1 Apr 2025 11:31:29 +0000 Subject: [PATCH 3/6] fixing selu mapping issue in fx_graph and lint issue --- python/tvm/relax/frontend/torch/fx_translator.py | 4 ++-- python/tvm/relax/transform/legalize_ops/nn.py | 4 +++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index c3d605a329b1..3ddf919c2ed1 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -650,7 +650,7 @@ def create_convert_map( relax.op.clip(self.env[node.args[0]], 0, 6) ), nn.Sigmoid: self._unary_op(relax.op.sigmoid), - nn.SELU: self._selu, + nn.SELU: self._unary_op(relax.op.nn.selu), nn.SiLU: self._unary_op(relax.op.nn.silu), nn.Softmax: self._softmax_module, nn.Tanh: self._unary_op(relax.op.tanh), @@ -710,7 +710,7 @@ def create_convert_map( "relu": self._unary_op(relax.op.nn.relu), "round": self._round, "rsqrt": self._unary_op(relax.op.rsqrt), - "selu": self._selu, + "selu": self._unary_op(relax.op.nn.selu), "sigmoid": self._unary_op(relax.op.sigmoid), "sign": self._unary_op(relax.op.sign), "silu": self._unary_op(relax.op.nn.silu), diff --git a/python/tvm/relax/transform/legalize_ops/nn.py 
b/python/tvm/relax/transform/legalize_ops/nn.py index 65afb01a487d..c0b02df39576 100644 --- a/python/tvm/relax/transform/legalize_ops/nn.py +++ b/python/tvm/relax/transform/legalize_ops/nn.py @@ -515,7 +515,9 @@ def te_selu(x: te.Tensor): # Compute SELU # SELU(x)=scale∗(max(0,x)+min(0,α∗(exp(x)−1))) positive_part = topi.maximum(x, tir.const(0, dtype)) - negative_part = topi.minimum(tir.const(0, dtype), alpha * (topi.exp(x) - tir.const(1, dtype))) + negative_part = topi.minimum( + tir.const(0, dtype), alpha * (topi.exp(x) - tir.const(1, dtype)) + ) return scale * (positive_part + negative_part) return bb.call_te(te_selu, call.args[0], primfunc_name_hint="selu") From b3e4ec988865ada18c334dbd8666560dd2840801 Mon Sep 17 00:00:00 2001 From: deivanayakisankaralingam Date: Tue, 1 Apr 2025 17:14:21 +0000 Subject: [PATCH 4/6] update the test script of selu in fx graph --- tests/python/relax/test_frontend_from_fx.py | 20 ++++---------------- 1 file changed, 4 insertions(+), 16 deletions(-) diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index b8d7f0b14e5b..8cfabd03a425 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -2429,23 +2429,11 @@ def forward(self, input): class expected_selu: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32") - ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): - # block 0 + input: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): with R.dataflow(): - lv_relu: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.relu(input_1) - lv_exp: R.Tensor((1, 3, 10, 10), dtype="float32") = R.exp(input_1) - lv_sub: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract( - lv_exp, R.const(1.0, "float32") - ) - lv_scaled: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply( - R.const(1.6732631921768188, "float32"), lv_sub - ) - lv_add: R.Tensor((1, 3, 10, 10), dtype="float32") = 
R.add(lv_relu, lv_scaled) - lv_selu: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply( - R.const(1.0507009873554805, "float32"), lv_add - ) - gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv_selu + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.selu(input) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) R.output(gv) return gv From 34ad7330fb9da0ee7ce974136df7f01d28673fba Mon Sep 17 00:00:00 2001 From: deivanayakisankaralingam Date: Wed, 2 Apr 2025 01:41:06 +0000 Subject: [PATCH 5/6] modify test script to fix selu module check --- tests/python/relax/test_frontend_from_fx.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index 8cfabd03a425..d913baf13a0d 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -2429,11 +2429,11 @@ def forward(self, input): class expected_selu: @R.function def main( - input: R.Tensor((1, 3, 10, 10), dtype="float32") - ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + inp_0: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): with R.dataflow(): - lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.selu(input) - gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.selu(inp_0) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv R.output(gv) return gv From 7c23a0541d7c8f2e83f45566114d945834c36657 Mon Sep 17 00:00:00 2001 From: deivanayakisankaralingam Date: Wed, 2 Apr 2025 09:37:57 +0000 Subject: [PATCH 6/6] format documentations --- python/tvm/relax/transform/legalize_ops/nn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py index c0b02df39576..fd3db841e646 100644 --- a/python/tvm/relax/transform/legalize_ops/nn.py 
+++ b/python/tvm/relax/transform/legalize_ops/nn.py @@ -513,7 +513,7 @@ def te_selu(x: te.Tensor): scale = tir.const(1.0507009873554804934193349852946, dtype) # Compute SELU - # SELU(x)=scale∗(max(0,x)+min(0,α∗(exp(x)−1))) + # SELU(x) = scale * (max(0, x) + min(0, alpha * (exp(x) - 1))) positive_part = topi.maximum(x, tir.const(0, dtype)) negative_part = topi.minimum( tir.const(0, dtype), alpha * (topi.exp(x) - tir.const(1, dtype))