diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 65e308a257e4..13609704ccb7 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1161,6 +1161,24 @@ def _impl_v1(cls, inputs, attr, params): return Gelu._impl_v1([inp], attr, params) +class Mish(OnnxOpConverter): + """Operator converter for Mish from Microsoft onnxruntime contrib opset. + + mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + e^{x})) + """ + + @classmethod + def _impl_v18(cls, inputs, attr, params): + x = inputs[0] + # Declare const + const_dtype = infer_type(x).checked_type.dtype + one = _expr.const(1.0, dtype=const_dtype) + + # Compute Mish + term1 = _op.log(one + _op.exp(x)) + return _op.multiply(x, _op.tanh(term1)) + + class LayerNormalization(OnnxOpConverter): """Operator converter for LayerNormalization from Microsoft onnxruntime contrib opset.""" @@ -6536,6 +6554,7 @@ def _get_convert_map(opset): "Gelu": Gelu.get_converter(opset), "FastGelu": FastGelu.get_converter(opset), "BiasGelu": BiasGelu.get_converter(opset), + "Mish": Mish.get_converter(opset), "LayerNormalization": LayerNormalization.get_converter(opset), # TODO: We need a better way to handle different domains, in case # of name collisions. EmbedLayerNormalization, SkipLayerNormalization, and Attention diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 72265e49818c..216732343028 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -2489,6 +2489,15 @@ def selu_x(x, alpha, gamma): ) +@pytest.mark.skip("Currently ONNX Runtime in CI does not support domain version of 18") +@tvm.testing.parametrize_targets +def test_mish(target, dev): + def mish_x(x): + return x * np.tanh(np.log1p(np.exp(x))) + + _test_onnx_op_elementwise(target, dev, (2, 4, 5, 6), mish_x, {}, "float64", "Mish", {}) + + @tvm.testing.parametrize_targets def test_prelu(target, dev): """test_prelu"""