diff --git a/python/tvm/contrib/target/onnx.py b/python/tvm/contrib/target/onnx.py index a38bcf5bcefa..b05265fa976a 100644 --- a/python/tvm/contrib/target/onnx.py +++ b/python/tvm/contrib/target/onnx.py @@ -617,6 +617,22 @@ def convert_attributes(cls, attrs): return {"value": 1} +class LRN(OpConverter): + """Operator converter for LRN.""" + + @classmethod + def convert_attributes(cls, attrs): + """axis attr is not supported as an argument in onnx. + Onnx only supports axis=1 (channels).""" + if attrs.get_int("axis") != 1: + raise RuntimeError( + "Unsupported axis %s in operator relay lrn operator. " + "Only axis = 1 is supported by Onnx." % (attrs.get_int("axis")) + ) + + return {"alpha": attrs.alpha, "beta": attrs.beta, "bias": attrs.bias, "size": attrs.size} + + relay_to_onnx_op_mapping = { "reshape": Reshape, "nn.conv2d": Conv, @@ -650,6 +666,7 @@ def convert_attributes(cls, attrs): "layout_transform": LayoutTransform, "clip": Clip, "expand_dims": Expand, + "nn.lrn": LRN, } diff --git a/tests/python/contrib/test_onnx.py b/tests/python/contrib/test_onnx.py index d99946d19d66..eef66f29e27a 100644 --- a/tests/python/contrib/test_onnx.py +++ b/tests/python/contrib/test_onnx.py @@ -515,6 +515,21 @@ def verify_expand_dims(dshape, axis, num_newaxis, dtype="float32"): verify_expand_dims((1, 1, 1001), 2, 2) +def test_lrn(): + def verify_lrn(xshape, size, dtype="float32"): + x = relay.var("x", relay.ty.TensorType(xshape, dtype)) + y = relay.nn.lrn(x, size=size, axis=1, alpha=1.0, beta=1.0, bias=1.0) + func = relay.Function([x], y) + x_data = np.random.uniform(size=xshape).astype(dtype) + verify_results(func, [x_data], "test_lrn", rtol=1e-5, atol=1e-5) + + isize = [(1, 1, 480, 640), (1, 3, 224, 224)] + sizes = [1, 3] + for i in isize: + for s in sizes: + verify_lrn(i, s) + + if __name__ == "__main__": test_add() test_bias_add() @@ -538,3 +553,4 @@ def verify_expand_dims(dshape, axis, num_newaxis, dtype="float32"): test_layout_transform() test_clip() test_expand_dims() + test_lrn()