From 5e29f11989de2f889109c0b45e2a2bfa26a6e31e Mon Sep 17 00:00:00 2001
From: HongHongHongL
Date: Tue, 19 Nov 2024 19:13:27 +0800
Subject: [PATCH 1/7] add auto_pad support for conv

---
 .../tvm/relax/frontend/onnx/onnx_frontend.py | 32 +++++++++++-
 tests/python/relax/test_frontend_onnx.py     | 52 +++++++++++++++++++
 2 files changed, 83 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index b64e87822a0a..ff29bbc71750 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -49,6 +49,8 @@
 from tvm.ir.supply import NameSupply
 from tvm.tir.generic import cast
 
+from ..common import autopad
+
 
 def get_type(elem_type: Union[str, int]) -> str:
     """Converts onnx integer datatype to numpy datatype"""
@@ -1208,11 +1210,15 @@ class Conv(OnnxOpConverter):
 
     @classmethod
     def _impl_v11(cls, bb, inputs, attr, params):
+        data = inputs[0]
         if hasattr(inputs[0].struct_info, "ndim"):
            ndim = inputs[0].struct_info.ndim
         else:
             ndim = len(inputs[0].struct_info.shape)
 
+        if "kernel_shape" not in attr:
+            attr["kernel_shape"] = inputs[1].struct_info.shape.values[2:]
+
         if ndim == 3:
             op = relax.op.nn.conv1d
             data_layout = "NCW"
@@ -1227,10 +1233,34 @@ def _impl_v11(cls, bb, inputs, attr, params):
             kernel_layout = "OIDHW"
         else:
             raise NotImplementedError("Ndim > 5 not supported for convolution.")
+ 
+        if "auto_pad" in attr:
+            attr["auto_pad"] = attr["auto_pad"].decode("utf-8")
+            if attr["auto_pad"] in ("SAME_UPPER", "SAME_LOWER"):
+                data = autopad(
+                    bb,
+                    inputs[0],
+                    attr.get("strides", [1] * (ndim - 2)),
+                    attr["kernel_shape"],
+                    attr.get("dilations", [1] * (ndim - 2)),
+                    mode=attr["auto_pad"],
+                    deconv=False,
+                )
+            elif attr["auto_pad"] == "VALID":
+                attr["pads"] = [0 for _ in range(ndim - 2)]
+            elif attr["auto_pad"] == "NOTSET":
+                pass
+            else:
+                msg = (
+                    f'Value {attr["auto_pad"]} in attribute "auto_pad" of operator Conv '
+                    f"is invalid."
+                )
+                raise tvm.error.OpAttributeInvalid(msg)
+            attr.pop("auto_pad")
 
         conv_out = bb.normalize(
             op(
-                data=inputs[0],
+                data=data,
                 weight=inputs[1],
                 strides=attr.get("strides", 1),
                 padding=attr.get("pads", 0),
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index 89f08e5af91f..991373b1d386 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -1022,6 +1022,58 @@ def _verify_conv(input_shape, weight_shape):
     _verify_conv([3, 4, 32, 32, 32], [2, 4, 3, 3, 3])  # group=2
 
 
+@pytest.mark.parametrize("stride", [1, 2])
+@pytest.mark.parametrize("dilation", [1])
+@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("auto_pad", ["SAME_UPPER", "SAME_LOWER", "VALID"])
+def test_conv_auto_pad(stride: int, dilation: int, bias: bool, auto_pad: str):
+    def _verify_conv(input_shape, weight_shape):
+        nd = len(weight_shape) - 2
+        if auto_pad == "VALID":
+            output_shape = [input_shape[0], weight_shape[0]] + [
+                (input_shape[i] - dilation * (weight_shape[i] - 1) - 1) // stride + 1
+                for i in range(2, len(input_shape))
+            ]
+        else:
+            output_shape = [input_shape[0], weight_shape[0]] + [
+                (input_shape[i] + stride - 1) // stride
+                for i in range(2, len(input_shape))
+            ]
+        bias_shape = [output_shape[1]]
+        conv_node = helper.make_node(
+            "Conv",
+            inputs=["x", "w"] + (["b"] if bias else []),
+            outputs=["y"],
+            strides=[stride] * nd,
+            dilations=[dilation] * nd,
+            auto_pad=auto_pad,
+            group=input_shape[1] // weight_shape[1],
+        )
+        graph = helper.make_graph(
+            [conv_node],
+            "conv_test",
+            inputs=[
+                helper.make_tensor_value_info("x", TensorProto.FLOAT, input_shape),
+                helper.make_tensor_value_info("w", TensorProto.FLOAT, weight_shape),
+            ]
+            + ([helper.make_tensor_value_info("b", TensorProto.FLOAT, bias_shape)] if bias else []),
+            outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, output_shape)],
+        )
+
+        model = helper.make_model(graph, producer_name="conv_test")
+        check_correctness(model, atol=1e-4)
+
+    # Conv1D
+    _verify_conv([3, 4, 32], [4, 4, 3])
+    _verify_conv([3, 4, 32], [2, 4, 3])  # group=2
+    # Conv2D
+    _verify_conv([3, 4, 32, 32], [4, 4, 3, 3])
+    _verify_conv([3, 4, 32, 32], [2, 4, 3, 3])  # group=2
+    # Conv3D
+    _verify_conv([3, 4, 32, 32, 32], [4, 4, 3, 3, 3])
+    _verify_conv([3, 4, 32, 32, 32], [2, 4, 3, 3, 3])  # group=2
+
+
 @pytest.mark.parametrize("stride", [1, 2])
 @pytest.mark.parametrize("dilation", [1])
 @pytest.mark.parametrize("bias", [True, False])

From 72e24ad0a841b48b740dc1b24921082d4b0cb0cf Mon Sep 17 00:00:00 2001
From: Honglin Zhu
Date: Wed, 20 Nov 2024 11:03:27 +0800
Subject: [PATCH 2/7] Update test_frontend_onnx.py

---
 tests/python/relax/test_frontend_onnx.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index 991373b1d386..a4696a01db77 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -1036,8 +1036,7 @@ def _verify_conv(input_shape, weight_shape):
         else:
             output_shape = [input_shape[0], weight_shape[0]] + [
-                (input_shape[i] + stride - 1) // stride
-                for i in range(2, len(input_shape))
+                (input_shape[i] + stride - 1) // stride for i in range(2, len(input_shape))
             ]
         bias_shape = [output_shape[1]]
         conv_node = helper.make_node(

From 6ea1f685eabdbd57ad52d52c17503dd785d20bef Mon Sep 17 00:00:00 2001
From: Honglin Zhu
Date: Wed, 20 Nov 2024 11:04:07 +0800
Subject: [PATCH 3/7] Update onnx_frontend.py

---
 python/tvm/relax/frontend/onnx/onnx_frontend.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index ff29bbc71750..9e0f5a060c6f 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -1233,7 +1233,7 @@ def _impl_v11(cls, bb, inputs, attr, params):
             kernel_layout = "OIDHW"
         else:
             raise NotImplementedError("Ndim > 5 not supported for convolution.")
- 
+
         if "auto_pad" in attr:
             attr["auto_pad"] = attr["auto_pad"].decode("utf-8")
             if attr["auto_pad"] in ("SAME_UPPER", "SAME_LOWER"):

From cdee5e06c9136fb5c840ccbd95e8290a65947fc3 Mon Sep 17 00:00:00 2001
From: HongHongHongL
Date: Wed, 20 Nov 2024 15:18:35 +0800
Subject: [PATCH 4/7] add common.py

---
 python/tvm/relax/frontend/common.py | 72 +++++++++++++++++++++++++++++
 1 file changed, 72 insertions(+)

diff --git a/python/tvm/relax/frontend/common.py b/python/tvm/relax/frontend/common.py
index bbd0c55aac2e..999da0893146 100644
--- a/python/tvm/relax/frontend/common.py
+++ b/python/tvm/relax/frontend/common.py
@@ -16,9 +16,11 @@
 # under the License.
 # pylint: disable=invalid-name
 """Commons for Relax frontend."""
+import numpy as _np
 from typing import Dict, List, Tuple
 
 import tvm
+from tvm import topi
 
 
 def detach_params(mod: tvm.IRModule) -> Tuple[tvm.IRModule, Dict[str, List[tvm.nd.NDArray]]]:
@@ -53,3 +55,73 @@ def detach_params(mod: tvm.IRModule) -> Tuple[tvm.IRModule, Dict[str, List[tvm.n
         else:
             detached_mod[gv] = func
     return detached_mod, params_dict
+
+
+def autopad(
+    bb,
+    data,
+    strides,
+    kernel_shape,
+    dilations=(1, 1),
+    pad_type="constant",
+    deconv=False,
+    mode="SAME_UPPER",
+    pad_value=0.0,
+):
+    """
+    Perform autopadding with dynamic input shapes
+    """
+    # get attributes as constants
+    strides = _np.array(strides)
+    dilated_kernel_shape = _np.array(
+            [(kernel - 1) * dilation + 1 for kernel, dilation in zip(kernel_shape, dilations)]
+        )
+    # get input shape
+    ndim = data.struct_info.ndim
+    data_shape = [s for s in data.struct_info.shape]
+    shape = data_shape[2:ndim]
+
+    # set up integer constants
+    zero = 0
+    one = 1
+    two = 2
+
+    # Calculate total padding
+    mod = shape % strides
+
+    left = _np.maximum(dilated_kernel_shape - strides, zero)
+    right = _np.maximum(dilated_kernel_shape - mod, zero)
+
+    total_pad = _np.where(_np.equal(mod, zero), left, right)
+    if deconv:
+        total_pad = _np.array(kernel_shape) - one - total_pad
+
+    # split total padding into before and after
+    pad_before = _np.floor_divide(total_pad, two)
+    pad_after = total_pad - pad_before
+
+    # combine
+    if "LOWER" in mode:
+        pad = _np.concatenate(
+            [_np.reshape(pad_after, [-1, 1]), _np.reshape(pad_before, [-1, 1])], axis=1
+        )
+    else:
+        pad = _np.concatenate(
+            [_np.reshape(pad_before, [-1, 1]), _np.reshape(pad_after, [-1, 1])], axis=1
+        )
+
+    # pad N and C with zeros
+    pad = _np.concatenate([_np.zeros([2, 2], dtype="int64"), pad], axis=0)
+ 
+    if not pad_type in ["constant", "edge", "reflect"]:
+        raise tvm.error.OpAttributeInvalid(
+            "Value " + pad_type + ' in attribute "mode" is invalid for operator Pad.'
+        )
+
+    if pad_type == "constant":
+        return bb.emit_te(topi.nn.pad, data, pad[:,0].tolist(), pad[:,1].tolist(), pad_value)
+    elif pad_type == "reflect":
+        return bb.emit_te(topi.nn.mirror_pad, data, pad[:,0].tolist(), pad[:,1].tolist(), "REFLECT")
+    else:
+        # TODO(gigiblender) Support edge mode.
+        raise NotImplementedError("Pad mode {} not implemented".format(pad_type))

From a697619e1797d830c3f0688d4a1c9994c55a7277 Mon Sep 17 00:00:00 2001
From: HongHongHongL
Date: Wed, 20 Nov 2024 15:32:26 +0800
Subject: [PATCH 5/7] reformat common.py

---
 python/tvm/relax/frontend/common.py | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/python/tvm/relax/frontend/common.py b/python/tvm/relax/frontend/common.py
index 999da0893146..649bb0dd168c 100644
--- a/python/tvm/relax/frontend/common.py
+++ b/python/tvm/relax/frontend/common.py
@@ -74,8 +74,8 @@ def autopad(
     # get attributes as constants
     strides = _np.array(strides)
     dilated_kernel_shape = _np.array(
-            [(kernel - 1) * dilation + 1 for kernel, dilation in zip(kernel_shape, dilations)]
-        )
+        [(kernel - 1) * dilation + 1 for kernel, dilation in zip(kernel_shape, dilations)]
+    )
     # get input shape
     ndim = data.struct_info.ndim
     data_shape = [s for s in data.struct_info.shape]
     shape = data_shape[2:ndim]
@@ -112,16 +112,18 @@ def autopad(
     # pad N and C with zeros
     pad = _np.concatenate([_np.zeros([2, 2], dtype="int64"), pad], axis=0)
- 
+
     if not pad_type in ["constant", "edge", "reflect"]:
         raise tvm.error.OpAttributeInvalid(
             "Value " + pad_type + ' in attribute "mode" is invalid for operator Pad.'
         )
 
     if pad_type == "constant":
-        return bb.emit_te(topi.nn.pad, data, pad[:,0].tolist(), pad[:,1].tolist(), pad_value)
+        return bb.emit_te(topi.nn.pad, data, pad[:, 0].tolist(), pad[:, 1].tolist(), pad_value)
     elif pad_type == "reflect":
-        return bb.emit_te(topi.nn.mirror_pad, data, pad[:,0].tolist(), pad[:,1].tolist(), "REFLECT")
+        return bb.emit_te(
+            topi.nn.mirror_pad, data, pad[:, 0].tolist(), pad[:, 1].tolist(), "REFLECT"
+        )
     else:
         # TODO(gigiblender) Support edge mode.
-        raise NotImplementedError("Pad mode {} not implemented".format(pad_type)) 
+        raise NotImplementedError("Pad mode {} not implemented".format(pad_type))

From 686c9b949d2f6a4c32e264ae6a266f2f8a0e02eb Mon Sep 17 00:00:00 2001
From: HongHongHongL
Date: Wed, 20 Nov 2024 15:50:01 +0800
Subject: [PATCH 6/7] reformat common.py

---
 python/tvm/relax/frontend/common.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/python/tvm/relax/frontend/common.py b/python/tvm/relax/frontend/common.py
index 649bb0dd168c..ba2960c159fc 100644
--- a/python/tvm/relax/frontend/common.py
+++ b/python/tvm/relax/frontend/common.py
@@ -16,8 +16,8 @@
 # under the License.
 # pylint: disable=invalid-name
 """Commons for Relax frontend."""
-import numpy as _np
 from typing import Dict, List, Tuple
+import numpy as _np
 
 import tvm
 from tvm import topi
@@ -74,7 +74,7 @@ def autopad(
     )
     # get input shape
     ndim = data.struct_info.ndim
-    data_shape = [s for s in data.struct_info.shape]
+    data_shape = list(data.struct_info.shape)
     shape = data_shape[2:ndim]
@@ -113,7 +113,7 @@ def autopad(
     # pad N and C with zeros
     pad = _np.concatenate([_np.zeros([2, 2], dtype="int64"), pad], axis=0)
 
-    if not pad_type in ["constant", "edge", "reflect"]:
+    if pad_type not in ["constant", "edge", "reflect"]:
         raise tvm.error.OpAttributeInvalid(
             "Value " + pad_type + ' in attribute "mode" is invalid for operator Pad.'
         )

From e217694307ccf12a07102b2c9b429b7690e5f612 Mon Sep 17 00:00:00 2001
From: HongHongHongL
Date: Wed, 20 Nov 2024 19:01:24 +0800
Subject: [PATCH 7/7] combine test into test_conv

---
 tests/python/relax/test_frontend_onnx.py | 97 ++++++++++--------------
 1 file changed, 40 insertions(+), 57 deletions(-)

diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index a4696a01db77..4cd4704ac0be 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -980,53 +980,8 @@ def test_shrink():
 @pytest.mark.parametrize("dilation", [1, 2])
 @pytest.mark.parametrize("bias", [True, False])
 @pytest.mark.parametrize("pad", [0, 2])
-def test_conv(stride: int, dilation: int, pad: int, bias: bool):
-    def _verify_conv(input_shape, weight_shape):
-        nd = len(weight_shape) - 2
-        output_shape = [input_shape[0], weight_shape[0]] + [
-            (input_shape[i] + 2 * pad - dilation * (weight_shape[i] - 1) - 1) // stride + 1
-            for i in range(2, len(input_shape))
-        ]
-        bias_shape = [output_shape[1]]
-        conv_node = helper.make_node(
-            "Conv",
-            inputs=["x", "w"] + (["b"] if bias else []),
-            outputs=["y"],
-            strides=[stride] * nd,
-            dilations=[dilation] * nd,
-            pads=[pad] * nd * 2,
-            group=input_shape[1] // weight_shape[1],
-        )
-        graph = helper.make_graph(
-            [conv_node],
-            "conv_test",
-            inputs=[
-                helper.make_tensor_value_info("x", TensorProto.FLOAT, input_shape),
-                helper.make_tensor_value_info("w", TensorProto.FLOAT, weight_shape),
-            ]
-            + ([helper.make_tensor_value_info("b", TensorProto.FLOAT, bias_shape)] if bias else []),
-            outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, output_shape)],
-        )
-
-        model = helper.make_model(graph, producer_name="conv_test")
-        check_correctness(model, atol=1e-4)
-
-    # Conv1D
-    _verify_conv([3, 4, 32], [4, 4, 3])
-    _verify_conv([3, 4, 32], [2, 4, 3])  # group=2
-    # Conv2D
-    _verify_conv([3, 4, 32, 32], [4, 4, 3, 3])
-    _verify_conv([3, 4, 32, 32], [2, 4, 3, 3])  # group=2
-    # Conv3D
-    _verify_conv([3, 4, 32, 32, 32], [4, 4, 3, 3, 3])
-    _verify_conv([3, 4, 32, 32, 32], [2, 4, 3, 3, 3])  # group=2
-
-
-@pytest.mark.parametrize("stride", [1, 2])
-@pytest.mark.parametrize("dilation", [1])
-@pytest.mark.parametrize("bias", [True, False])
-@pytest.mark.parametrize("auto_pad", ["SAME_UPPER", "SAME_LOWER", "VALID"])
-def test_conv_auto_pad(stride: int, dilation: int, bias: bool, auto_pad: str):
+@pytest.mark.parametrize("auto_pad", ["NOTSET", "SAME_UPPER", "SAME_LOWER", "VALID"])
+def test_conv(stride: int, dilation: int, pad: int, bias: bool, auto_pad: str):
     def _verify_conv(input_shape, weight_shape):
         nd = len(weight_shape) - 2
         if auto_pad == "VALID":
@@ -1034,20 +989,48 @@ def _verify_conv(input_shape, weight_shape):
             output_shape = [input_shape[0], weight_shape[0]] + [
                 (input_shape[i] - dilation * (weight_shape[i] - 1) - 1) // stride + 1
                 for i in range(2, len(input_shape))
             ]
-        else:
+            bias_shape = [output_shape[1]]
+            conv_node = helper.make_node(
+                "Conv",
+                inputs=["x", "w"] + (["b"] if bias else []),
+                outputs=["y"],
+                strides=[stride] * nd,
+                dilations=[dilation] * nd,
+                auto_pad=auto_pad,
+                group=input_shape[1] // weight_shape[1],
+            )
+        elif auto_pad in ("SAME_UPPER", "SAME_LOWER"):
+            if dilation == 2:
+                # auto_pad = "SAME" and dilation = 2 is not supported in ONNX
+                return
             output_shape = [input_shape[0], weight_shape[0]] + [
                 (input_shape[i] + stride - 1) // stride for i in range(2, len(input_shape))
             ]
-        bias_shape = [output_shape[1]]
-        conv_node = helper.make_node(
-            "Conv",
-            inputs=["x", "w"] + (["b"] if bias else []),
-            outputs=["y"],
-            strides=[stride] * nd,
-            dilations=[dilation] * nd,
-            auto_pad=auto_pad,
-            group=input_shape[1] // weight_shape[1],
-        )
+            bias_shape = [output_shape[1]]
+            conv_node = helper.make_node(
+                "Conv",
+                inputs=["x", "w"] + (["b"] if bias else []),
+                outputs=["y"],
+                strides=[stride] * nd,
+                dilations=[dilation] * nd,
+                auto_pad=auto_pad,
+                group=input_shape[1] // weight_shape[1],
+            )
+        else:
+            output_shape = [input_shape[0], weight_shape[0]] + [
+                (input_shape[i] + 2 * pad - dilation * (weight_shape[i] - 1) - 1) // stride + 1
+                for i in range(2, len(input_shape))
+            ]
+            bias_shape = [output_shape[1]]
+            conv_node = helper.make_node(
+                "Conv",
+                inputs=["x", "w"] + (["b"] if bias else []),
+                outputs=["y"],
+                strides=[stride] * nd,
+                dilations=[dilation] * nd,
+                pads=[pad] * nd * 2,
+                group=input_shape[1] // weight_shape[1],
+            )
         graph = helper.make_graph(
             [conv_node],
             "conv_test",
             inputs=[
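
Notes (not part of the patches above):

A minimal sketch of what this series enables: importing an ONNX Conv node that
specifies auto_pad instead of explicit pads. The tensor names, shapes, and the
final mod.show() call are illustrative assumptions, not taken from the patches;
only from_onnx and the Conv/autopad handling shown above are assumed.

    from onnx import TensorProto, helper

    from tvm.relax.frontend.onnx import from_onnx

    conv_node = helper.make_node(
        "Conv",
        inputs=["x", "w"],
        outputs=["y"],
        strides=[2, 2],
        auto_pad="SAME_UPPER",  # handled by the new autopad path in patch 1
    )
    graph = helper.make_graph(
        [conv_node],
        "conv_auto_pad",
        inputs=[
            helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 4, 32, 32]),
            helper.make_tensor_value_info("w", TensorProto.FLOAT, [8, 4, 3, 3]),
        ],
        # SAME_UPPER with stride 2 gives ceil(32 / 2) = 16 per spatial dim.
        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, 8, 16, 16])],
    )
    model = helper.make_model(graph, producer_name="conv_auto_pad")
    mod = from_onnx(model)  # Conv now lowers via frontend.common.autopad
    mod.show()

For reference, the padding arithmetic that autopad performs, worked by hand in
plain numpy for one spatial axis under the same assumed configuration (extent
32, kernel 3, stride 2, dilation 1, mode SAME_UPPER):

    import numpy as np

    shape = np.array([32])
    strides = np.array([2])
    dilated_kernel = np.array([(3 - 1) * 1 + 1])    # kernel dilated to 3

    mod = shape % strides                           # 0
    left = np.maximum(dilated_kernel - strides, 0)  # 1
    right = np.maximum(dilated_kernel - mod, 0)     # 3
    total_pad = np.where(mod == 0, left, right)     # 1
    pad_before = total_pad // 2                     # 0
    pad_after = total_pad - pad_before              # 1 (SAME_UPPER pads the tail)

    # output extent: (32 + 0 + 1 - 3) // 2 + 1 == 16 == ceil(32 / 2)

SAME_LOWER swaps pad_before and pad_after, which is exactly what the
"LOWER" in mode branch of autopad does before topi.nn.pad is emitted.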