From be54af45e6758a46f13b3990326ed24f705e4564 Mon Sep 17 00:00:00 2001 From: Aarsh2001 Date: Wed, 26 Jul 2023 22:22:36 +0530 Subject: [PATCH 1/7] added mish operator to onnx frontend --- python/tvm/relay/frontend/onnx.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 65e308a257e4..9681408a8fc5 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1160,6 +1160,24 @@ def _impl_v1(cls, inputs, attr, params): inp = _op.add(x, b) return Gelu._impl_v1([inp], attr, params) +class Mish(OnnxOpConverter): + """Operator converter for Mish from Microsoft onnxruntime contrib opset. + + mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + e^{x})) + """ + + @classmethod + def _impl_v18(cls, inputs, attr, params): + x = inputs[0] + # Declare const + const_dtype = infer_type(x).checked_type.dtype + one = _expr.const(1.0, dtype=const_dtype) + + # Compute Mish + term1 = _op.log(one + _op.exp(x)) + return _op.multiply(x, _op.tanh(term1)) + + class LayerNormalization(OnnxOpConverter): """Operator converter for LayerNormalization from Microsoft onnxruntime contrib opset.""" @@ -2124,7 +2142,6 @@ def _impl_v13(cls, inputs, attr, params): class Prelu(OnnxOpConverter): """Operator converter for Prelu.""" - @classmethod def _impl_v1(cls, inputs, attr, params): assert len(inputs) == 2, f"Prelu need 2 inputs, {len(inputs)} given" @@ -6536,6 +6553,7 @@ def _get_convert_map(opset): "Gelu": Gelu.get_converter(opset), "FastGelu": FastGelu.get_converter(opset), "BiasGelu": BiasGelu.get_converter(opset), + "Mish": Mish.get_converter(opset), "LayerNormalization": LayerNormalization.get_converter(opset), # TODO: We need a better way to handle different domains, in case # of name collisions. EmbedLayerNormalization, SkipLayerNormalization, and Attention From baa58f785e5a32431da35777e48935f210bec725 Mon Sep 17 00:00:00 2001 From: Aarsh2001 Date: Wed, 26 Jul 2023 22:34:14 +0530 Subject: [PATCH 2/7] linter reformat --- python/tvm/relay/frontend/onnx.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 9681408a8fc5..25fbd5088a9c 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2142,6 +2142,7 @@ def _impl_v13(cls, inputs, attr, params): class Prelu(OnnxOpConverter): """Operator converter for Prelu.""" + @classmethod def _impl_v1(cls, inputs, attr, params): assert len(inputs) == 2, f"Prelu need 2 inputs, {len(inputs)} given" From 4388991a2ddc5bc526a164a7ab84e0ac31a52ebb Mon Sep 17 00:00:00 2001 From: Aarsh2001 Date: Thu, 27 Jul 2023 16:45:30 +0530 Subject: [PATCH 3/7] fixed lint issues as linter failed on the CI --- python/tvm/relay/frontend/onnx.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 25fbd5088a9c..13609704ccb7 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1160,23 +1160,23 @@ def _impl_v1(cls, inputs, attr, params): inp = _op.add(x, b) return Gelu._impl_v1([inp], attr, params) + class Mish(OnnxOpConverter): """Operator converter for Mish from Microsoft onnxruntime contrib opset. - + mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + e^{x})) """ @classmethod def _impl_v18(cls, inputs, attr, params): - x = inputs[0] - # Declare const - const_dtype = infer_type(x).checked_type.dtype - one = _expr.const(1.0, dtype=const_dtype) - - # Compute Mish - term1 = _op.log(one + _op.exp(x)) - return _op.multiply(x, _op.tanh(term1)) + x = inputs[0] + # Declare const + const_dtype = infer_type(x).checked_type.dtype + one = _expr.const(1.0, dtype=const_dtype) + # Compute Mish + term1 = _op.log(one + _op.exp(x)) + return _op.multiply(x, _op.tanh(term1)) class LayerNormalization(OnnxOpConverter): @@ -2142,7 +2142,7 @@ def _impl_v13(cls, inputs, attr, params): class Prelu(OnnxOpConverter): """Operator converter for Prelu.""" - + @classmethod def _impl_v1(cls, inputs, attr, params): assert len(inputs) == 2, f"Prelu need 2 inputs, {len(inputs)} given" From 006b6981e4e0cfd891dd38afd3c0f2c0b104733f Mon Sep 17 00:00:00 2001 From: Aarsh2001 Date: Fri, 28 Jul 2023 10:55:42 +0530 Subject: [PATCH 4/7] added test for mish operator --- tests/python/frontend/onnx/test_forward.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 72265e49818c..7a6e2790f904 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -2489,6 +2489,23 @@ def selu_x(x, alpha, gamma): ) +@tvm.testing.parametrize_targets +def test_mish(target, dev): + def mish_x(x): + return x * np.tanh(np.log1p(np.exp(x))) + + _test_onnx_op_elementwise( + target, + dev, + (2, 4, 5, 6), + mish_x, + {}, + "float64", + "Mish", + {} + ) + + @tvm.testing.parametrize_targets def test_prelu(target, dev): """test_prelu""" From 8bd384995e044ac5e40bd213abbf32de5a6235da Mon Sep 17 00:00:00 2001 From: Aarsh2001 Date: Fri, 28 Jul 2023 11:26:36 +0530 Subject: [PATCH 5/7] added test for mish operator --- tests/python/frontend/onnx/test_forward.py | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 7a6e2790f904..7c506278ea24 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -2493,17 +2493,8 @@ def selu_x(x, alpha, gamma): def test_mish(target, dev): def mish_x(x): return x * np.tanh(np.log1p(np.exp(x))) - - _test_onnx_op_elementwise( - target, - dev, - (2, 4, 5, 6), - mish_x, - {}, - "float64", - "Mish", - {} - ) + + _test_onnx_op_elementwise(target, dev, (2, 4, 5, 6), mish_x, {}, "float64", "Mish", {}) @tvm.testing.parametrize_targets From a729263666b505aa9f0b2272fb76709f2e874941 Mon Sep 17 00:00:00 2001 From: Aarsh2001 Date: Fri, 28 Jul 2023 18:02:46 +0530 Subject: [PATCH 6/7] pytest skip since ONNX Runtime in CI does not support domain version 18 --- tests/python/frontend/onnx/test_forward.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 7c506278ea24..f1c963804326 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -2488,7 +2488,7 @@ def selu_x(x, alpha, gamma): {"alpha": 0.25, "gamma": 0.3}, ) - +@pytest.mark.skip("Currently ONNX Runtime in CI does not support domain version of 18") @tvm.testing.parametrize_targets def test_mish(target, dev): def mish_x(x): From 821e08c0cb5f68ddd5592323b1c7012dd183147f Mon Sep 17 00:00:00 2001 From: Aarsh2001 Date: Fri, 28 Jul 2023 20:10:35 +0530 Subject: [PATCH 7/7] linter format --- tests/python/frontend/onnx/test_forward.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index f1c963804326..216732343028 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -2488,6 +2488,7 @@ def selu_x(x, alpha, gamma): {"alpha": 0.25, "gamma": 0.3}, ) + @pytest.mark.skip("Currently ONNX Runtime in CI does not support domain version of 18") @tvm.testing.parametrize_targets def test_mish(target, dev):