diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index f09cc56de372..3a70cd090a54 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -1918,6 +1918,20 @@ def _impl_v1(cls, bb, inputs, attr, params): ) + relax.op.nn.relu(inputs[0]) +class HardSigmoid(OnnxOpConverter): + """Converts an onnx HardSigmoid node into an equivalent Relax expression.""" + + @classmethod + def _impl_v1(cls, bb, inputs, attr, params): + x = inputs[0] + dtype = x.struct_info.dtype + alpha = float(attr.get("alpha", 0.2)) + alpha = relax.const(alpha, dtype=dtype) + beta = float(attr.get("beta", 0.5)) + beta = relax.const(beta, dtype=dtype) + return relax.op.clip(relax.op.add(relax.op.multiply(alpha, x), beta), 0, 1) + + class HardSwish(OnnxOpConverter): """Converts an onnx HardSwish node into an equivalent Relax expression.""" @@ -2014,6 +2028,7 @@ def _get_convert_map(): "Reciprocal": Reciprocal, "OneHot": OneHot, "Elu": Elu, + "HardSigmoid": HardSigmoid, "HardSwish": HardSwish, } diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index 0161534d17f7..0fc7ec064402 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -590,6 +590,12 @@ def test_elu(): verify_unary("Elu", [32, 32]) +def test_hardsigmoid(): + verify_unary("HardSigmoid", [32, 32]) + verify_unary("HardSigmoid", [32, 32], attrs={"alpha": 0.3, "beta": 0.4}) + verify_unary("HardSigmoid", [1, 3, 20, 20], attrs={"alpha": 0.5, "beta": 0.6}) + + def test_hardswish(): verify_unary("HardSwish", [32, 32])