diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 74c620a33ddb..890f925079e0 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -340,37 +340,6 @@ def _softshrink(self, node: fx.Node) -> relax.Var:
         # Combine the positive and negative shrink results
         return self.block_builder.emit(relax.op.add(shrink_pos, shrink_neg))
 
-    def _selu(self, node: fx.Node) -> relax.Var:
-        x = self.env[node.args[0]]
-        alpha = node.args[1] if len(node.args) > 1 else node.kwargs.get("alpha", 1.6732631921768188)
-        gamma = node.args[2] if len(node.args) > 2 else node.kwargs.get("gamma", 1.0507009873554805)
-        dtype = x.struct_info.dtype
-
-        if isinstance(alpha, (int, float)):
-            alpha = relax.const(alpha, dtype)
-        else:
-            if not isinstance(alpha, relax.Var):
-                alpha = self.block_builder.emit(relax.const(alpha, dtype))
-
-        if isinstance(gamma, (int, float)):
-            gamma = relax.const(gamma, dtype)
-        else:
-            if not isinstance(gamma, relax.Var):
-                gamma = self.block_builder.emit(relax.const(gamma, dtype))
-
-        # gamma * (ReLU(x) + alpha * (exp(x) - 1))
-        return self.block_builder.emit(
-            relax.op.multiply(
-                gamma,
-                relax.op.add(
-                    relax.op.nn.relu(x),
-                    relax.op.multiply(
-                        alpha, relax.op.subtract(relax.op.exp(x), relax.const(1, dtype))
-                    ),
-                ),
-            )
-        )
-
     def _tril_triu(self, op: Callable) -> Callable:
         from torch import fx
 
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 97ccc6393cbb..003c333cffe5 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -275,7 +275,7 @@ def create_convert_map(
             "relu.default": self._unary_op(relax.op.nn.relu),
             "round.default": self._round,
             "rsqrt.default": self._unary_op(relax.op.rsqrt),
-            "selu.default": self._selu,
+            "selu.default": self._unary_op(relax.op.nn.selu),
             "sigmoid.default": self._unary_op(relax.op.sigmoid),
             "sign.default": self._unary_op(relax.op.sign),
             "silu.default": self._unary_op(relax.op.nn.silu),
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index c3d605a329b1..3ddf919c2ed1 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -650,7 +650,7 @@ def create_convert_map(
                 relax.op.clip(self.env[node.args[0]], 0, 6)
             ),
             nn.Sigmoid: self._unary_op(relax.op.sigmoid),
-            nn.SELU: self._selu,
+            nn.SELU: self._unary_op(relax.op.nn.selu),
             nn.SiLU: self._unary_op(relax.op.nn.silu),
             nn.Softmax: self._softmax_module,
             nn.Tanh: self._unary_op(relax.op.tanh),
@@ -710,7 +710,7 @@ def create_convert_map(
             "relu": self._unary_op(relax.op.nn.relu),
             "round": self._round,
             "rsqrt": self._unary_op(relax.op.rsqrt),
-            "selu": self._selu,
+            "selu": self._unary_op(relax.op.nn.selu),
             "sigmoid": self._unary_op(relax.op.sigmoid),
             "sign": self._unary_op(relax.op.sign),
             "silu": self._unary_op(relax.op.nn.silu),
diff --git a/python/tvm/relax/op/nn/__init__.py b/python/tvm/relax/op/nn/__init__.py
index 61212f33d882..e45982a0fed2 100644
--- a/python/tvm/relax/op/nn/__init__.py
+++ b/python/tvm/relax/op/nn/__init__.py
@@ -45,6 +45,7 @@
     pad,
     relu,
     rms_norm,
+    selu,
     silu,
     softmax,
 )
diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py
index 09a7df5149f9..5232eea047cf 100644
--- a/python/tvm/relax/op/nn/nn.py
+++ b/python/tvm/relax/op/nn/nn.py
@@ -1304,6 +1304,30 @@ def gelu_tanh(data: Expr) -> Expr:
     return _ffi_api.gelu_tanh(data)  # type: ignore
 
 
+def selu(data: Expr) -> Expr:
+    r"""Scaled Exponential Linear Unit (SELU).
+
+    .. math::
+        \text{SELU}(x) = \lambda \begin{cases}
+            x & \text{if } x > 0 \\
+            \alpha (e^x - 1) & \text{if } x \leq 0
+        \end{cases}
+
+    where :math:`\lambda \approx 1.0507` and :math:`\alpha \approx 1.6733`.
+
+    Parameters
+    ----------
+    data : relax.Expr
+        The input data.
+
+    Returns
+    -------
+    result : relax.Expr
+        The computed result.
+    """
+    return _ffi_api.selu(data)
+
+
 def silu(data: Expr) -> Expr:
     r"""Sigmoid Linear Unit function
 
diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py
index 4c8bdbc6615c..fd3db841e646 100644
--- a/python/tvm/relax/transform/legalize_ops/nn.py
+++ b/python/tvm/relax/transform/legalize_ops/nn.py
@@ -505,6 +505,24 @@ def te_gelu_tanh(x: te.Tensor):
     return bb.call_te(te_gelu_tanh, call.args[0], primfunc_name_hint="gelu_tanh")
 
 
+@register_legalize("relax.nn.selu")
+def _nn_selu(bb: BlockBuilder, call: Call) -> Expr:
+    def te_selu(x: te.Tensor):
+        dtype = x.dtype
+        alpha = tir.const(1.6732632423543772848170429916717, dtype)
+        scale = tir.const(1.0507009873554804934193349852946, dtype)
+
+        # Compute SELU
+        # SELU(x) = scale∗(max(0,x)+min(0,α∗(exp(x)−1)))
+        positive_part = topi.maximum(x, tir.const(0, dtype))
+        negative_part = topi.minimum(
+            tir.const(0, dtype), alpha * (topi.exp(x) - tir.const(1, dtype))
+        )
+        return scale * (positive_part + negative_part)
+
+    return bb.call_te(te_selu, call.args[0], primfunc_name_hint="selu")
+
+
 @register_legalize("relax.nn.silu")
 def _nn_silu(bb: BlockBuilder, call: Call) -> Expr:
     def te_silu(x: te.Tensor):
diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc
index c768ea19af7d..4a5a9a701612 100644
--- a/src/relax/op/nn/nn.cc
+++ b/src/relax/op/nn/nn.cc
@@ -34,6 +34,9 @@ RELAX_REGISTER_UNARY_NN_OP_AND_IMPL(gelu, "nn.gelu", /*require_float_dtype=*/tru
 /* relax.nn.gelu_tanh */
 RELAX_REGISTER_UNARY_NN_OP_AND_IMPL(gelu_tanh, "nn.gelu_tanh", /*require_float_dtype=*/true);
 
+/* relax.nn.selu */
+RELAX_REGISTER_UNARY_NN_OP_AND_IMPL(selu, "nn.selu", /*require_float_dtype=*/true);
+
 /* relax.nn.silu */
 RELAX_REGISTER_UNARY_NN_OP_AND_IMPL(silu, "nn.silu", /*require_float_dtype=*/true);
 
diff --git a/src/relax/op/nn/nn.h b/src/relax/op/nn/nn.h
index 28c14139b97b..d6db36aba50a 100644
--- a/src/relax/op/nn/nn.h
+++ b/src/relax/op/nn/nn.h
@@ -57,6 +57,9 @@ Expr gelu(Expr data);
 /*! \brief Gaussian Error Linear Units function approximated by tanh. */
 Expr gelu_tanh(Expr data);
 
+/*! \brief Scaled Exponential Linear Unit function. */
+Expr selu(Expr data);
+
 /*! \brief Sigmoid Linear Unit function. */
 Expr silu(Expr data);
 
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 98f0f1d9cac6..2f05a7e208a1 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -523,24 +523,12 @@ def forward(self, input):
     class expected_selu:
         @R.function
        def main(
-            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+            input: R.Tensor((1, 3, 10, 10), dtype="float32")
         ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
             with R.dataflow():
-                lv_relu: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.relu(input_1)
-                lv_exp: R.Tensor((1, 3, 10, 10), dtype="float32") = R.exp(input_1)
-                lv_sub: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract(
-                    lv_exp, R.const(1.0, "float32")
-                )
-                lv_scaled: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(
-                    R.const(1.6732631921768188, "float32"), lv_sub
-                )
-                lv_add: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(lv_relu, lv_scaled)
-                lv_selu: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(
-                    R.const(1.0507010221481323, "float32"), lv_add
-                )
-                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv_selu,)
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.selu(input)
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
                 R.output(gv)
-
             return gv
 
     verify_model(Selu1(), example_args, {}, expected_selu)
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index b8d7f0b14e5b..d913baf13a0d 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -2429,23 +2429,11 @@ def forward(self, input):
     class expected_selu:
         @R.function
         def main(
-            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+            inp_0: R.Tensor((1, 3, 10, 10), dtype="float32")
         ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
-            # block 0
             with R.dataflow():
-                lv_relu: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.relu(input_1)
-                lv_exp: R.Tensor((1, 3, 10, 10), dtype="float32") = R.exp(input_1)
-                lv_sub: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract(
-                    lv_exp, R.const(1.0, "float32")
-                )
-                lv_scaled: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(
-                    R.const(1.6732631921768188, "float32"), lv_sub
-                )
-                lv_add: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(lv_relu, lv_scaled)
-                lv_selu: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(
-                    R.const(1.0507009873554805, "float32"), lv_add
-                )
-                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv_selu
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.selu(inp_0)
+                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
                 R.output(gv)
             return gv
 
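
Note (not part of the patch): the TE lowering registered in legalize_ops/nn.py uses the algebraic form scale * (max(0, x) + min(0, alpha * (exp(x) - 1))), which matches the piecewise definition in the nn.py docstring because exactly one of the two terms is non-zero for any x. Below is a minimal sketch of how the new op could be exercised end to end against a NumPy reference. It assumes an LLVM-enabled TVM build with this change applied; the helper name build_selu_mod, the (2, 3) shape, the llvm target, and the tolerances are arbitrary illustration choices, and the build/run calls follow the BlockBuilder/VirtualMachine pattern used in the existing Relax tests rather than anything introduced by this patch.

    import numpy as np
    import tvm
    from tvm import relax


    def build_selu_mod():
        # Build a one-op Relax module: main(x) = relax.nn.selu(x)
        bb = relax.BlockBuilder()
        x = relax.Var("x", relax.TensorStructInfo([2, 3], "float32"))
        with bb.function("main", [x]):
            with bb.dataflow():
                y = bb.emit(relax.op.nn.selu(x))
                out = bb.emit_output(y)
            bb.emit_func_output(out)
        return bb.get()


    mod = build_selu_mod()
    # LegalizeOps lowers relax.nn.selu through the TE kernel added in this patch.
    mod = relax.transform.LegalizeOps()(mod)
    ex = relax.build(mod, target="llvm")
    vm = relax.VirtualMachine(ex, tvm.cpu())

    data = np.random.randn(2, 3).astype("float32")
    out = vm["main"](tvm.nd.array(data)).numpy()

    # Same alpha/scale constants as the legalization, truncated.
    alpha, scale = 1.6732632423543772, 1.0507009873554805
    ref = scale * (np.maximum(data, 0) + np.minimum(0, alpha * (np.exp(data) - 1)))
    np.testing.assert_allclose(out, ref, rtol=1e-6, atol=1e-6)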