Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -192,10 +192,12 @@ def create_convert_map(
"atanh.default": self._unary_op(relax.op.atanh),
"bitwise_not.default": self._unary_op(relax.op.bitwise_not),
"ceil.default": self._unary_op(relax.op.ceil),
"celu.default": self._celu,
"clamp.default": self._clamp,
"cos.default": self._unary_op(relax.op.cos),
"cosh.default": self._unary_op(relax.op.cosh),
"dropout.default": lambda node: self.env[node.args[0]],
"elu.default": self._elu,
"erf.default": self._unary_op(relax.op.erf),
"exp.default": self._unary_op(relax.op.exp),
"floor.default": self._unary_op(relax.op.floor),
Expand All @@ -213,6 +215,7 @@ def create_convert_map(
"relu.default": self._unary_op(relax.op.nn.relu),
"round.default": self._round,
"rsqrt.default": self._unary_op(relax.op.rsqrt),
"selu.default": self._selu,
"sigmoid.default": self._unary_op(relax.op.sigmoid),
"sign.default": self._unary_op(relax.op.sign),
"silu.default": self._unary_op(relax.op.nn.silu),
Expand Down
123 changes: 123 additions & 0 deletions tests/python/relax/test_frontend_from_exported_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,49 @@ def main(
def test_extended_unary_ops():
example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)

# celu
class Celu1(Module):
def __init__(self):
super().__init__()
self.celu = torch.nn.CELU()

def forward(self, input):
return self.celu(input)

class Celu2(Module):
def forward(self, input):
return torch.nn.functional.celu(input)

# alpha * min(0, exp(x / alpha) - 1) + max(0, x)
@tvm.script.ir_module
class expected_celu:
@R.function
def main(
input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
with R.dataflow():
lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.exp(input_1)
lv_div: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide(
lv, R.const(1.0, "float32")
)
lv_sub: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract(
lv_div, R.const(1.0, "float32")
)
lv_min: R.Tensor((1, 3, 10, 10), dtype="float32") = R.minimum(
R.const(0.0, "float32"), lv_sub
)
lv_scaled: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(
R.const(1.0, "float32"), lv_min
)
lv_relu_x: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.relu(input_1)
lv_celu: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(lv_scaled, lv_relu_x)
gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv_celu,)
R.output(gv)
return gv

verify_model(Celu1(), example_args, {}, expected_celu)
verify_model(Celu2(), example_args, {}, expected_celu)

# clamp
class Clamp(Module):
def forward(self, input):
Expand Down Expand Up @@ -174,6 +217,46 @@ def main(
verify_model(Dropout1(), example_args, {}, expected_dropout)
verify_model(Dropout2(), example_args, {}, expected_dropout)

# elu
class Elu(Module):
def __init__(self):
super().__init__()
self.elu = torch.nn.ELU()

def forward(self, input):
return self.elu(input)

class Elu2(Module):
def forward(self, input):
return torch.nn.functional.elu(input)

@tvm.script.ir_module
class expected_elu:
@R.function
def main(
input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
# block 0
with R.dataflow():
lv_exp: R.Tensor((1, 3, 10, 10), dtype="float32") = R.exp(input_1)
lv_one_minus_exp: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract(
R.const(1.0, dtype="float32"), lv_exp
)
lv_relu_one_minus_exp: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.relu(
lv_one_minus_exp
)
lv_scaled: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(
R.const(-1.0, dtype="float32"), lv_relu_one_minus_exp
)
lv_relu_x: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.relu(input_1)
lv_elu: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(lv_scaled, lv_relu_x)
gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv_elu,)
R.output(gv)
return gv

verify_model(Elu(), example_args, {}, expected_elu)
verify_model(Elu2(), example_args, {}, expected_elu)

# gelu
class Gelu(Module):
def __init__(self):
Expand Down Expand Up @@ -306,6 +389,46 @@ def main(
verify_model(ReLU0(), example_args, {}, expected_relu)
verify_model(ReLU1(), example_args, {}, expected_relu)

# selu
class Selu1(Module):
def __init__(self):
super().__init__()
self.selu = torch.nn.SELU()

def forward(self, input):
return self.selu(input)

class Selu2(Module):
def forward(self, input):
return torch.nn.functional.selu(input)

@tvm.script.ir_module
class expected_selu:
@R.function
def main(
input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
with R.dataflow():
lv_relu: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.relu(input_1)
lv_exp: R.Tensor((1, 3, 10, 10), dtype="float32") = R.exp(input_1)
lv_sub: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract(
lv_exp, R.const(1.0, "float32")
)
lv_scaled: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(
R.const(1.6732631921768188, "float32"), lv_sub
)
lv_add: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(lv_relu, lv_scaled)
lv_selu: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(
R.const(1.0507010221481323, "float32"), lv_add
)
gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv_selu,)
R.output(gv)

return gv

verify_model(Selu1(), example_args, {}, expected_selu)
verify_model(Selu2(), example_args, {}, expected_selu)

# sigmoid
class Sigmoid(Module):
def __init__(self):
Expand Down