diff --git a/python/tvm/relax/frontend/common.py b/python/tvm/relax/frontend/common.py index bbd0c55aac2e..ba2960c159fc 100644 --- a/python/tvm/relax/frontend/common.py +++ b/python/tvm/relax/frontend/common.py @@ -17,8 +17,10 @@ # pylint: disable=invalid-name """Commons for Relax frontend.""" from typing import Dict, List, Tuple +import numpy as _np import tvm +from tvm import topi def detach_params(mod: tvm.IRModule) -> Tuple[tvm.IRModule, Dict[str, List[tvm.nd.NDArray]]]: @@ -53,3 +55,75 @@ def detach_params(mod: tvm.IRModule) -> Tuple[tvm.IRModule, Dict[str, List[tvm.n else: detached_mod[gv] = func return detached_mod, params_dict + + +def autopad( + bb, + data, + strides, + kernel_shape, + dilations=(1, 1), + pad_type="constant", + deconv=False, + mode="SAME_UPPER", + pad_value=0.0, +): + """ + Perform autopadding with dynamic input shapes + """ + # get attributes as constants + strides = _np.array(strides) + dilated_kernel_shape = _np.array( + [(kernel - 1) * dilation + 1 for kernel, dilation in zip(kernel_shape, dilations)] + ) + # get input shape + ndim = data.struct_info.ndim + data_shape = list(data.struct_info.shape) + shape = data_shape[2:ndim] + + # set up integer constants + zero = 0 + one = 1 + two = 2 + + # Calculate total padding + mod = shape % strides + + left = _np.maximum(dilated_kernel_shape - strides, zero) + right = _np.maximum(dilated_kernel_shape - mod, zero) + + total_pad = _np.where(_np.equal(mod, zero), left, right) + if deconv: + total_pad = _np.array(kernel_shape) - one - total_pad + + # split total padding into before and after + pad_before = _np.floor_divide(total_pad, two) + pad_after = total_pad - pad_before + + # combine + if "LOWER" in mode: + pad = _np.concatenate( + [_np.reshape(pad_after, [-1, 1]), _np.reshape(pad_before, [-1, 1])], axis=1 + ) + else: + pad = _np.concatenate( + [_np.reshape(pad_before, [-1, 1]), _np.reshape(pad_after, [-1, 1])], axis=1 + ) + + # pad N and C with zeros + pad = _np.concatenate([_np.zeros([2, 2], dtype="int64"), pad], axis=0) + + if pad_type not in ["constant", "edge", "reflect"]: + raise tvm.error.OpAttributeInvalid( + "Value " + pad_type + ' in attribute "mode" is invalid for operator Pad.' + ) + + if pad_type == "constant": + return bb.emit_te(topi.nn.pad, data, pad[:, 0].tolist(), pad[:, 1].tolist(), pad_value) + elif pad_type == "reflect": + return bb.emit_te( + topi.nn.mirror_pad, data, pad[:, 0].tolist(), pad[:, 1].tolist(), "REFLECT" + ) + else: + # TODO(gigiblender) Support edge mode. + raise NotImplementedError("Pad mode {} not implemented".format(pad_type)) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index b64e87822a0a..9e0f5a060c6f 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -49,6 +49,8 @@ from tvm.ir.supply import NameSupply from tvm.tir.generic import cast +from ..common import autopad + def get_type(elem_type: Union[str, int]) -> str: """Converts onnx integer datatype to numpy datatype""" @@ -1208,11 +1210,15 @@ class Conv(OnnxOpConverter): @classmethod def _impl_v11(cls, bb, inputs, attr, params): + data = inputs[0] if hasattr(inputs[0].struct_info, "ndim"): ndim = inputs[0].struct_info.ndim else: ndim = len(inputs[0].struct_info.shape) + if "kernel_shape" not in attr: + attr["kernel_shape"] = inputs[1].struct_info.shape.values[2:] + if ndim == 3: op = relax.op.nn.conv1d data_layout = "NCW" @@ -1228,9 +1234,33 @@ def _impl_v11(cls, bb, inputs, attr, params): else: raise NotImplementedError("Ndim > 5 not supported for convolution.") + if "auto_pad" in attr: + attr["auto_pad"] = attr["auto_pad"].decode("utf-8") + if attr["auto_pad"] in ("SAME_UPPER", "SAME_LOWER"): + data = autopad( + bb, + inputs[0], + attr.get("strides", [1] * (ndim - 2)), + attr["kernel_shape"], + attr.get("dilations", [1] * (ndim - 2)), + mode=attr["auto_pad"], + deconv=False, + ) + elif attr["auto_pad"] == "VALID": + attr["pads"] = [0 for _ in range(ndim - 2)] + elif attr["auto_pad"] == "NOTSET": + pass + else: + msg = ( + f'Value {attr["auto_pad"]} in attribute "auto_pad" of operator Conv ' + f"is invalid." + ) + raise tvm.error.OpAttributeInvalid(msg) + attr.pop("auto_pad") + conv_out = bb.normalize( op( - data=inputs[0], + data=data, weight=inputs[1], strides=attr.get("strides", 1), padding=attr.get("pads", 0), diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index 89f08e5af91f..4cd4704ac0be 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -980,23 +980,57 @@ def test_shrink(): @pytest.mark.parametrize("dilation", [1, 2]) @pytest.mark.parametrize("bias", [True, False]) @pytest.mark.parametrize("pad", [0, 2]) -def test_conv(stride: int, dilation: int, pad: int, bias: bool): +@pytest.mark.parametrize("auto_pad", ["SAME_UPPER", "SAME_LOWER", "VALID"]) +def test_conv(stride: int, dilation: int, pad: int, bias: bool, auto_pad: str): def _verify_conv(input_shape, weight_shape): nd = len(weight_shape) - 2 - output_shape = [input_shape[0], weight_shape[0]] + [ - (input_shape[i] + 2 * pad - dilation * (weight_shape[i] - 1) - 1) // stride + 1 - for i in range(2, len(input_shape)) - ] - bias_shape = [output_shape[1]] - conv_node = helper.make_node( - "Conv", - inputs=["x", "w"] + (["b"] if bias else []), - outputs=["y"], - strides=[stride] * nd, - dilations=[dilation] * nd, - pads=[pad] * nd * 2, - group=input_shape[1] // weight_shape[1], - ) + if auto_pad == "VALID": + output_shape = [input_shape[0], weight_shape[0]] + [ + (input_shape[i] - dilation * (weight_shape[i] - 1) - 1) // stride + 1 + for i in range(2, len(input_shape)) + ] + bias_shape = [output_shape[1]] + conv_node = helper.make_node( + "Conv", + inputs=["x", "w"] + (["b"] if bias else []), + outputs=["y"], + strides=[stride] * nd, + dilations=[dilation] * nd, + auto_pad=auto_pad, + group=input_shape[1] // weight_shape[1], + ) + elif auto_pad in ("SAME_UPPER", "SAME_LOWER"): + if dilation == 2: + # auto_pad = "SAME" and dilation = 2 is not supported in ONNX + return + output_shape = [input_shape[0], weight_shape[0]] + [ + (input_shape[i] + stride - 1) // stride for i in range(2, len(input_shape)) + ] + bias_shape = [output_shape[1]] + conv_node = helper.make_node( + "Conv", + inputs=["x", "w"] + (["b"] if bias else []), + outputs=["y"], + strides=[stride] * nd, + dilations=[dilation] * nd, + auto_pad=auto_pad, + group=input_shape[1] // weight_shape[1], + ) + else: + output_shape = [input_shape[0], weight_shape[0]] + [ + (input_shape[i] + 2 * pad - dilation * (weight_shape[i] - 1) - 1) // stride + 1 + for i in range(2, len(input_shape)) + ] + bias_shape = [output_shape[1]] + conv_node = helper.make_node( + "Conv", + inputs=["x", "w"] + (["b"] if bias else []), + outputs=["y"], + strides=[stride] * nd, + dilations=[dilation] * nd, + pads=[pad] * nd * 2, + group=input_shape[1] // weight_shape[1], + ) graph = helper.make_graph( [conv_node], "conv_test",