Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 0 additions & 31 deletions python/tvm/relax/frontend/torch/base_fx_graph_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,37 +340,6 @@ def _softshrink(self, node: fx.Node) -> relax.Var:
# Combine the positive and negative shrink results
return self.block_builder.emit(relax.op.add(shrink_pos, shrink_neg))

def _selu(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
alpha = node.args[1] if len(node.args) > 1 else node.kwargs.get("alpha", 1.6732631921768188)
gamma = node.args[2] if len(node.args) > 2 else node.kwargs.get("gamma", 1.0507009873554805)
dtype = x.struct_info.dtype

if isinstance(alpha, (int, float)):
alpha = relax.const(alpha, dtype)
else:
if not isinstance(alpha, relax.Var):
alpha = self.block_builder.emit(relax.const(alpha, dtype))

if isinstance(gamma, (int, float)):
gamma = relax.const(gamma, dtype)
else:
if not isinstance(gamma, relax.Var):
gamma = self.block_builder.emit(relax.const(gamma, dtype))

# gamma * (ReLU(x) + alpha * (exp(x) - 1))
return self.block_builder.emit(
relax.op.multiply(
gamma,
relax.op.add(
relax.op.nn.relu(x),
relax.op.multiply(
alpha, relax.op.subtract(relax.op.exp(x), relax.const(1, dtype))
),
),
)
)

def _tril_triu(self, op: Callable) -> Callable:
from torch import fx

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ def create_convert_map(
"relu.default": self._unary_op(relax.op.nn.relu),
"round.default": self._round,
"rsqrt.default": self._unary_op(relax.op.rsqrt),
"selu.default": self._selu,
"selu.default": self._unary_op(relax.op.nn.selu),
"sigmoid.default": self._unary_op(relax.op.sigmoid),
"sign.default": self._unary_op(relax.op.sign),
"silu.default": self._unary_op(relax.op.nn.silu),
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/relax/frontend/torch/fx_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,7 +650,7 @@ def create_convert_map(
relax.op.clip(self.env[node.args[0]], 0, 6)
),
nn.Sigmoid: self._unary_op(relax.op.sigmoid),
nn.SELU: self._selu,
nn.SELU: self._unary_op(relax.op.nn.selu),
nn.SiLU: self._unary_op(relax.op.nn.silu),
nn.Softmax: self._softmax_module,
nn.Tanh: self._unary_op(relax.op.tanh),
Expand Down Expand Up @@ -710,7 +710,7 @@ def create_convert_map(
"relu": self._unary_op(relax.op.nn.relu),
"round": self._round,
"rsqrt": self._unary_op(relax.op.rsqrt),
"selu": self._selu,
"selu": self._unary_op(relax.op.nn.selu),
"sigmoid": self._unary_op(relax.op.sigmoid),
"sign": self._unary_op(relax.op.sign),
"silu": self._unary_op(relax.op.nn.silu),
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relax/op/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
pad,
relu,
rms_norm,
selu,
silu,
softmax,
)
24 changes: 24 additions & 0 deletions python/tvm/relax/op/nn/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1304,6 +1304,30 @@ def gelu_tanh(data: Expr) -> Expr:
return _ffi_api.gelu_tanh(data) # type: ignore


def selu(data: Expr) -> Expr:
r"""Scaled Exponential Linear Unit (SELU).

.. math::
\text{SELU}(x) = \lambda \begin{cases}
x & \text{if } x > 0 \\
\alpha (e^x - 1) & \text{if } x \leq 0
\end{cases}

where :math:`\lambda \approx 1.0507` and :math:`\alpha \approx 1.6733`.

Parameters
----------
data : relax.Expr
The input data.

Returns
-------
result : relax.Expr
The computed result.
"""
return _ffi_api.selu(data)


def silu(data: Expr) -> Expr:
r"""Sigmoid Linear Unit function

Expand Down
18 changes: 18 additions & 0 deletions python/tvm/relax/transform/legalize_ops/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,6 +505,24 @@ def te_gelu_tanh(x: te.Tensor):
return bb.call_te(te_gelu_tanh, call.args[0], primfunc_name_hint="gelu_tanh")


@register_legalize("relax.nn.selu")
def _nn_selu(bb: BlockBuilder, call: Call) -> Expr:
def te_selu(x: te.Tensor):
dtype = x.dtype
alpha = tir.const(1.6732632423543772848170429916717, dtype)
scale = tir.const(1.0507009873554804934193349852946, dtype)

# Compute SELU
# SELU(x) = scale∗(max(0,x)+min(0,α∗(exp(x)−1)))
positive_part = topi.maximum(x, tir.const(0, dtype))
negative_part = topi.minimum(
tir.const(0, dtype), alpha * (topi.exp(x) - tir.const(1, dtype))
)
return scale * (positive_part + negative_part)

return bb.call_te(te_selu, call.args[0], primfunc_name_hint="selu")


@register_legalize("relax.nn.silu")
def _nn_silu(bb: BlockBuilder, call: Call) -> Expr:
def te_silu(x: te.Tensor):
Expand Down
3 changes: 3 additions & 0 deletions src/relax/op/nn/nn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ RELAX_REGISTER_UNARY_NN_OP_AND_IMPL(gelu, "nn.gelu", /*require_float_dtype=*/tru
/* relax.nn.gelu_tanh */
RELAX_REGISTER_UNARY_NN_OP_AND_IMPL(gelu_tanh, "nn.gelu_tanh", /*require_float_dtype=*/true);

/* relax.nn.selu */
RELAX_REGISTER_UNARY_NN_OP_AND_IMPL(selu, "nn.selu", /*require_float_dtype=*/true);

/* relax.nn.silu */
RELAX_REGISTER_UNARY_NN_OP_AND_IMPL(silu, "nn.silu", /*require_float_dtype=*/true);

Expand Down
3 changes: 3 additions & 0 deletions src/relax/op/nn/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ Expr gelu(Expr data);
/*! \brief Gaussian Error Linear Units function approximated by tanh. */
Expr gelu_tanh(Expr data);

/*! \brief Scaled Exponential Linear Unit function. */
Expr selu(Expr data);

/*! \brief Sigmoid Linear Unit function. */
Expr silu(Expr data);

Expand Down
18 changes: 3 additions & 15 deletions tests/python/relax/test_frontend_from_exported_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,24 +523,12 @@ def forward(self, input):
class expected_selu:
@R.function
def main(
input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
input: R.Tensor((1, 3, 10, 10), dtype="float32")
) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
with R.dataflow():
lv_relu: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.relu(input_1)
lv_exp: R.Tensor((1, 3, 10, 10), dtype="float32") = R.exp(input_1)
lv_sub: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract(
lv_exp, R.const(1.0, "float32")
)
lv_scaled: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(
R.const(1.6732631921768188, "float32"), lv_sub
)
lv_add: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(lv_relu, lv_scaled)
lv_selu: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(
R.const(1.0507010221481323, "float32"), lv_add
)
gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv_selu,)
lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.selu(input)
gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
R.output(gv)

return gv

verify_model(Selu1(), example_args, {}, expected_selu)
Expand Down
18 changes: 3 additions & 15 deletions tests/python/relax/test_frontend_from_fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -2429,23 +2429,11 @@ def forward(self, input):
class expected_selu:
@R.function
def main(
input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
inp_0: R.Tensor((1, 3, 10, 10), dtype="float32")
) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
# block 0
with R.dataflow():
lv_relu: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.relu(input_1)
lv_exp: R.Tensor((1, 3, 10, 10), dtype="float32") = R.exp(input_1)
lv_sub: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract(
lv_exp, R.const(1.0, "float32")
)
lv_scaled: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(
R.const(1.6732631921768188, "float32"), lv_sub
)
lv_add: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(lv_relu, lv_scaled)
lv_selu: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(
R.const(1.0507009873554805, "float32"), lv_add
)
gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv_selu
lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.selu(inp_0)
gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
R.output(gv)
return gv

Expand Down