diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 2103365c6c60..fbeb3045ff1a 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -192,10 +192,12 @@ def create_convert_map(
             "atanh.default": self._unary_op(relax.op.atanh),
             "bitwise_not.default": self._unary_op(relax.op.bitwise_not),
             "ceil.default": self._unary_op(relax.op.ceil),
+            "celu.default": self._celu,
             "clamp.default": self._clamp,
             "cos.default": self._unary_op(relax.op.cos),
             "cosh.default": self._unary_op(relax.op.cosh),
             "dropout.default": lambda node: self.env[node.args[0]],
+            "elu.default": self._elu,
             "erf.default": self._unary_op(relax.op.erf),
             "exp.default": self._unary_op(relax.op.exp),
             "floor.default": self._unary_op(relax.op.floor),
@@ -213,6 +215,7 @@ def create_convert_map(
             "relu.default": self._unary_op(relax.op.nn.relu),
             "round.default": self._round,
             "rsqrt.default": self._unary_op(relax.op.rsqrt),
+            "selu.default": self._selu,
             "sigmoid.default": self._unary_op(relax.op.sigmoid),
             "sign.default": self._unary_op(relax.op.sign),
             "silu.default": self._unary_op(relax.op.nn.silu),
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 6406610bf53e..2e9f1fbd1c76 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -126,6 +126,49 @@ def main(
 def test_extended_unary_ops():
     example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
 
+    # celu
+    class Celu1(Module):
+        def __init__(self):
+            super().__init__()
+            self.celu = torch.nn.CELU()
+
+        def forward(self, input):
+            return self.celu(input)
+
+    class Celu2(Module):
+        def forward(self, input):
+            return torch.nn.functional.celu(input)
+
+    # alpha * min(0, exp(x / alpha) - 1) + max(0, x)
+    @tvm.script.ir_module
+    class expected_celu:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.exp(input_1)
+                lv_div: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide(
+                    lv, R.const(1.0, "float32")
+                )
+                lv_sub: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract(
+                    lv_div, R.const(1.0, "float32")
+                )
+                lv_min: R.Tensor((1, 3, 10, 10), dtype="float32") = R.minimum(
+                    R.const(0.0, "float32"), lv_sub
+                )
+                lv_scaled: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(
+                    R.const(1.0, "float32"), lv_min
+                )
+                lv_relu_x: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.relu(input_1)
+                lv_celu: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(lv_scaled, lv_relu_x)
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv_celu,)
+                R.output(gv)
+            return gv
+
+    verify_model(Celu1(), example_args, {}, expected_celu)
+    verify_model(Celu2(), example_args, {}, expected_celu)
+
     # clamp
     class Clamp(Module):
         def forward(self, input):
@@ -174,6 +217,46 @@ def main(
     verify_model(Dropout1(), example_args, {}, expected_dropout)
     verify_model(Dropout2(), example_args, {}, expected_dropout)
 
+    # elu
+    class Elu(Module):
+        def __init__(self):
+            super().__init__()
+            self.elu = torch.nn.ELU()
+
+        def forward(self, input):
+            return self.elu(input)
+
+    class Elu2(Module):
+        def forward(self, input):
+            return torch.nn.functional.elu(input)
+
+    @tvm.script.ir_module
+    class expected_elu:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv_exp: R.Tensor((1, 3, 10, 10), dtype="float32") = R.exp(input_1)
+                lv_one_minus_exp: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract(
+                    R.const(1.0, dtype="float32"), lv_exp
+                )
+                lv_relu_one_minus_exp: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.relu(
+                    lv_one_minus_exp
+                )
+                lv_scaled: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(
+                    R.const(-1.0, dtype="float32"), lv_relu_one_minus_exp
+                )
+                lv_relu_x: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.relu(input_1)
+                lv_elu: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(lv_scaled, lv_relu_x)
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv_elu,)
+                R.output(gv)
+            return gv
+
+    verify_model(Elu(), example_args, {}, expected_elu)
+    verify_model(Elu2(), example_args, {}, expected_elu)
+
     # gelu
     class Gelu(Module):
         def __init__(self):
@@ -306,6 +389,46 @@ def main(
     verify_model(ReLU0(), example_args, {}, expected_relu)
     verify_model(ReLU1(), example_args, {}, expected_relu)
 
+    # selu
+    class Selu1(Module):
+        def __init__(self):
+            super().__init__()
+            self.selu = torch.nn.SELU()
+
+        def forward(self, input):
+            return self.selu(input)
+
+    class Selu2(Module):
+        def forward(self, input):
+            return torch.nn.functional.selu(input)
+
+    @tvm.script.ir_module
+    class expected_selu:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+            with R.dataflow():
+                lv_relu: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.relu(input_1)
+                lv_exp: R.Tensor((1, 3, 10, 10), dtype="float32") = R.exp(input_1)
+                lv_sub: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract(
+                    lv_exp, R.const(1.0, "float32")
+                )
+                lv_scaled: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(
+                    R.const(1.6732631921768188, "float32"), lv_sub
+                )
+                lv_add: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(lv_relu, lv_scaled)
+                lv_selu: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(
+                    R.const(1.0507010221481323, "float32"), lv_add
+                )
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv_selu,)
+                R.output(gv)
+
+            return gv
+
+    verify_model(Selu1(), example_args, {}, expected_selu)
+    verify_model(Selu2(), example_args, {}, expected_selu)
+
     # sigmoid
     class Sigmoid(Module):
         def __init__(self):
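Not part of the patch, just an illustrative cross-check written for this note: the short PyTorch snippet below replays the celu and elu decompositions that expected_celu and expected_elu encode and compares them against torch's reference activations for the default alpha = 1.0 exercised by these tests (variable names and tolerance are my own).

    import torch

    x = torch.randn(1, 3, 10, 10, dtype=torch.float32)

    # celu with alpha = 1.0, as in expected_celu: min(0, exp(x) - 1) + relu(x)
    celu_decomposed = torch.minimum(torch.zeros_like(x), torch.exp(x) - 1.0) + torch.relu(x)
    assert torch.allclose(celu_decomposed, torch.nn.functional.celu(x), atol=1e-6)

    # elu with alpha = 1.0, as in expected_elu: -1 * relu(1 - exp(x)) + relu(x)
    elu_decomposed = -1.0 * torch.relu(1.0 - torch.exp(x)) + torch.relu(x)
    assert torch.allclose(elu_decomposed, torch.nn.functional.elu(x), atol=1e-6)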