diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 70e6263da67e..4866189ed872 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -3480,6 +3480,7 @@ def _get_convert_map(opset): "IsNaN": Renamer("isnan"), "Sqrt": Renamer("sqrt"), "Relu": Renamer("relu"), + "Celu": Celu.get_converter(opset), "LeakyRelu": Renamer("leaky_relu"), "Selu": Selu.get_converter(opset), "Elu": Elu.get_converter(opset), @@ -3925,6 +3926,19 @@ def _fix_outputs(self, op_name, outputs): return outputs +class Celu(OnnxOpConverter): + """Operator convereter for celu""" + + @classmethod + def _impl_v12(cls, inputs, attr, params): + x = inputs[0] + dtype = infer_type(x).checked_type.dtype + alpha = _op.const(attr.get("alpha", 1.0), dtype) + zero = _op.const(0, dtype) + one = _op.const(1, dtype) + return _op.maximum(zero, x) + _op.minimum(zero, alpha * (_op.exp(x / alpha) - one)) + + def from_onnx( model, shape=None, dtype="float32", opset=None, freeze_params=False, convert_config=None ): diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 8422cda42afc..0b83b8de363d 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -4668,7 +4668,6 @@ def verify_eyelike(indata): "test_cast_FLOAT_to_BFLOAT16", "test_cast_FLOAT_to_STRING", "test_cast_STRING_to_FLOAT", - "test_celu", "test_compress_0", "test_compress_1", "test_compress_default_axis",