From 4994209de500e5c2766fd6e91add522552fb8665 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Mon, 12 Jul 2021 22:56:00 +0000 Subject: [PATCH 1/4] Add ConvInteger support and fix some ConvTranspose padding bugs. --- python/tvm/relay/frontend/onnx.py | 107 ++++++++++++-- tests/python/frontend/onnx/test_forward.py | 160 ++++++++++++++++++++- 2 files changed, 249 insertions(+), 18 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index f876b1d14fa1..6e7ec54e139c 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -391,7 +391,6 @@ def autopad( dilations, ndim, pad_type="constant", - deconv=False, mode="SAME_UPPER", pad_value=0.0, ): @@ -421,8 +420,6 @@ def autopad( right = _op.maximum(dilated_kernel_shape - mod, zero) total_pad = _op.where(_op.equal(mod, zero), left, right) - if deconv: - total_pad = _op.const(np.array(kernel_shape), dtype="int64") - one - total_pad # split total padding into before and after pad_before = _op.floor_divide(total_pad, two) @@ -441,7 +438,10 @@ def autopad( # pad N and C with zeros pad = _op.concatenate([_op.const(np.zeros([2, 2], dtype="int64"), dtype="int64"), pad], axis=0) - return _op.nn.pad(data, fold_constant(pad), _op.const(pad_value), pad_type) + if not isinstance(pad_value, _expr.Var): + pad_value = _op.const(pad_value) + + return _op.nn.pad(data, fold_constant(pad), pad_value, pad_type) class Conv(OnnxOpConverter): @@ -545,17 +545,20 @@ def _impl_v1(cls, inputs, attr, params): if "auto_pad" in attr: attr["auto_pad"] = attr["auto_pad"].decode("utf-8") if attr["auto_pad"] in ("SAME_UPPER", "SAME_LOWER"): - # Warning: Convolution does not yet support dynamic shapes, - # one will need to run dynamic_to_static on this model after import - data = autopad( - data, - attr.get("strides", [1] * (ndim - 2)), - attr["kernel_shape"], - attr.get("dilations", [1] * (ndim - 2)), - ndim, - deconv=True, - mode=attr["auto_pad"], - ) + strides = attr.get("strides", [1] * (ndim - 2)) + kernel_shape = attr["kernel_shape"] + dilations = attr.get("dilations", [1] * (ndim - 2)) + dilated_kernel_shape = [ + (kernel - 1) * dilation + 1 for kernel, dilation in zip(kernel_shape, dilations) + ] + pads = [k - s for k, s in zip(dilated_kernel_shape, strides)] + # Convert total padding into separate padding. + left = np.floor_divide(pads, 2) + right = pads - left + if attr["auto_pad"] == "SAME_UPPER": + attr["pads"] = list(right) + list(left) + else: + attr["pads"] = list(left) + list(right) elif attr["auto_pad"] == "VALID": attr["pads"] = tuple([0 for i in range(ndim - 2)]) elif attr["auto_pad"] == "NOTSET": @@ -3193,6 +3196,79 @@ def get_scalar(x, dtype="float32"): return _qnn.op.quantize(out, c_scale, c_zero_point, out_dtype=dtype) +class ConvInteger(OnnxOpConverter): + """Operator converter for ConvInteger.""" + + @classmethod + def _impl_v10(cls, inputs, attr, params): + data = inputs[0] + weight = inputs[1] + data_zp = inputs[2] + weight_zp = inputs[3] + if data_zp is None: + data_zp = _expr.const(0, "int32") + if weight_zp is None: + weight_zp = _expr.const(0, "int32") + + input_type = infer_type(data) + input_shape = get_const_tuple(input_type.checked_type.shape) + + ndim = len(input_shape) + kernel_type = infer_type(weight) + kernel_shape = get_const_tuple(kernel_type.checked_type.shape) + if "kernel_shape" not in attr: + attr["kernel_shape"] = kernel_shape[2:] + + if "auto_pad" in attr: + attr["auto_pad"] = attr["auto_pad"].decode("utf-8") + if attr["auto_pad"] in ("SAME_UPPER", "SAME_LOWER"): + # Warning: Convolution does not yet support dynamic shapes, + # one will need to run dynamic_to_static on this model after import + data = autopad( + data, + attr.get("strides", [1] * (ndim - 2)), + attr["kernel_shape"], + attr.get("dilations", [1] * (ndim - 2)), + ndim, + pad_value=data_zp, + mode=attr["auto_pad"], + ) + elif attr["auto_pad"] == "VALID": + attr["pads"] = tuple([0 for i in range(ndim - 2)]) + elif attr["auto_pad"] == "NOTSET": + pass + else: + msg = 'Value {} in attribute "auto_pad" of operator Conv is invalid.' + raise tvm.error.OpAttributeInvalid(msg.format(attr["auto_pad"])) + attr.pop("auto_pad") + + out_channels = kernel_shape[0] + dilation = attr.get("dilations", [1] * (ndim - 2)) + strides = attr.get("strides", [1] * (ndim - 2)) + padding = attr["pads"] if "pads" in attr else 0 + groups = attr["group"] if "group" in attr else 1 + + if ndim != 4: + raise tvm.error.OpAttributeInvalid( + "Only 2D kernels are supported for operator ConvInteger." + ) + + return _qnn.op.conv2d( + data, + weight, + _op.cast(data_zp, "int32"), + _op.cast(weight_zp, "int32"), + _expr.const(1.0, "float32"), + _expr.const(1.0, "float32"), + kernel_size=attr["kernel_shape"], + channels=out_channels, + strides=strides, + padding=padding, + dilation=dilation, + groups=groups, + ) + + class BitShift(OnnxOpConverter): """Operator converter for NonZero""" @@ -3421,6 +3497,7 @@ def _get_convert_map(opset): "ReverseSequence": ReverseSequence.get_converter(opset), "QLinearConv": QLinearConv.get_converter(opset), "QLinearAdd": QLinearAdd.get_converter(opset), + "ConvInteger": ConvInteger.get_converter(opset), } diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index c5407697de46..e9fa6201695c 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -2729,7 +2729,7 @@ def repeat(N, D): verify_convtranspose_with_padding( (1, 1) + repeat(5, D), (1, 1) + repeat(3, D), - (1, 1) + repeat(9, D), + (1, 1) + repeat(10, D), None, repeat(3, D), repeat(2, D), @@ -4420,7 +4420,6 @@ def verify_eyelike(indata): onnx_test_folders = sorted(glob.glob("/".join(f.split("/")[0:-1]) + "/backend/test/data/node/*/")) unsupported_onnx_tests = [ - "test_basic_convinteger/", "test_cast_DOUBLE_to_FLOAT16/", "test_cast_FLOAT_to_STRING/", "test_cast_STRING_to_FLOAT/", @@ -4428,7 +4427,6 @@ def verify_eyelike(indata): "test_compress_1/", "test_compress_default_axis/", "test_compress_negative_axis/", - "test_convinteger_with_padding/", "test_convtranspose_dilations/", "test_convtranspose_output_shape/", "test_cumsum_1d/", @@ -4872,6 +4870,161 @@ def test_qlinearadd(): verify_qlinearadd([5, 1, 7], [2, 7], [5, 2, 7]) +def verify_convinteger( + x_shape, + w_shape, + y_shape, + padding, + kernel_shape, + strides, + dilations, + auto_pad="NOTSET", + dtype="uint8", +): + + x_array = np.random.randint(low=0, high=255, size=x_shape).astype(dtype) + w_array = np.random.uniform(low=0, high=255, size=w_shape).astype(dtype) + x_zero_point_array = np.random.randint(0, 255, size=[]).astype(dtype) + w_zero_point_array = np.random.randint(0, 255, size=[]).astype(dtype) + + ONNX_DTYPE = mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)] + input_nodes = [ + helper.make_tensor_value_info("x", ONNX_DTYPE, list(x_shape)), + helper.make_tensor_value_info("w", ONNX_DTYPE, list(w_shape)), + helper.make_tensor_value_info("x_zero_point", ONNX_DTYPE, []), + helper.make_tensor_value_info("w_zero_point", ONNX_DTYPE, []), + ] + input_names = [ + "x", + "w", + "x_zero_point", + "w_zero_point", + ] + input_values = [x_array, w_array, x_zero_point_array, w_zero_point_array] + + if padding is None: + ## autopadding with unset default attributes + kwargs = {} + if not all([s == 1 for s in strides]): + kwargs["strides"] = strides + if not all([d == 1 for d in dilations]): + kwargs["dilations"] = dilations + + node = helper.make_node( + "ConvInteger", + inputs=input_names, + outputs=["y"], + # Default values for other attributes: + auto_pad=auto_pad, + **kwargs, + ) + else: + node = helper.make_node( + "ConvInteger", + inputs=input_names, + outputs=["y"], + kernel_shape=kernel_shape, + # Default values for other attributes: + strides=strides, + dilations=dilations, + # groups=1 + pads=padding, + ) + + graph = helper.make_graph( + [node], + "convinteger_test", + inputs=input_nodes, + outputs=[helper.make_tensor_value_info("y", TensorProto.INT32, list(y_shape))], + ) + model = helper.make_model(graph, producer_name="convinteger_test") + # opt_level=1 will cause error + verify_with_ort_with_inputs(model, input_values, opt_level=2) + + +def test_convinteger(): + def repeat(N, D): + return tuple([N for _ in range(D)]) + + # only support 2D ConvInteger because we only support qnn.conv2d for now. + D = 2 + + # Convolution with padding + verify_convinteger( + (1, 1) + repeat(5, D), + (1, 1) + repeat(3, D), + (1, 1) + repeat(5, D), + 2 * repeat(1, D), + repeat(3, D), + repeat(1, D), + repeat(1, D), + ) + + # Convolution with asymmetric padding + verify_convinteger( + (1, 1) + repeat(5, D), + (1, 1) + repeat(3, D), + (1, 1) + repeat(4, D), + repeat(0, D) + repeat(1, D), + repeat(3, D), + repeat(1, D), + repeat(1, D), + ) + # Convolution without padding + verify_convinteger( + (1, 1) + repeat(5, D), + (1, 1) + repeat(3, D), + (1, 1) + repeat(3, D), + 2 * repeat(0, D), + repeat(3, D), + repeat(1, D), + repeat(1, D), + ) + # Convolution with autopadding + verify_convinteger( + (1, 1) + repeat(5, D), + (1, 1) + repeat(3, D), + (1, 1) + repeat(5, D), + None, + repeat(3, D), + repeat(1, D), + repeat(1, D), + auto_pad="SAME_UPPER", + ) + # Convolution with valid autopadding + verify_convinteger( + (1, 1) + repeat(5, D), + (1, 1) + repeat(3, D), + (1, 1) + repeat(3, D), + None, + repeat(3, D), + repeat(1, D), + repeat(1, D), + auto_pad="VALID", + ) + # Convolution with non uniform stride + verify_convinteger( + (1, 1) + repeat(5, D), + (1, 1) + repeat(3, D), + (1, 1) + repeat(3, D), + None, + repeat(3, D), + repeat(2, D), + repeat(1, D), + auto_pad="SAME_UPPER", + ) + # Convolution with dilation + verify_convinteger( + (1, 1) + repeat(5, D), + (1, 1) + repeat(3, D), + (1, 1) + repeat(5, D), + 2 * repeat(2, D), + repeat(3, D), + repeat(1, D), + repeat(2, D), + ) + + if __name__ == "__main__": test_flatten() test_reshape() @@ -4955,3 +5108,4 @@ def test_qlinearadd(): test_reverse_sequence() test_eyelike() test_qlinearconv() + test_convinteger() From e696bf78c8ad98d85c4a62a5e15763698906663d Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Mon, 12 Jul 2021 23:11:37 +0000 Subject: [PATCH 2/4] Simplify pads check. --- python/tvm/relay/frontend/onnx.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 6e7ec54e139c..692b3e1f3478 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -438,7 +438,7 @@ def autopad( # pad N and C with zeros pad = _op.concatenate([_op.const(np.zeros([2, 2], dtype="int64"), dtype="int64"), pad], axis=0) - if not isinstance(pad_value, _expr.Var): + if isinstance(pad_value, int) or isinstance(pad_value, float): pad_value = _op.const(pad_value) return _op.nn.pad(data, fold_constant(pad), pad_value, pad_type) From 8176642fe0f55894f0ec2f118b9e8d556b094115 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Mon, 12 Jul 2021 23:31:23 +0000 Subject: [PATCH 3/4] Fix style. --- python/tvm/relay/frontend/onnx.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 692b3e1f3478..6a15a25135c2 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -410,7 +410,6 @@ def autopad( # set up integer constants zero = _op.const(0, dtype="int64") - one = _op.const(1, dtype="int64") two = _op.const(2, dtype="int64") # Calculate total padding @@ -438,7 +437,7 @@ def autopad( # pad N and C with zeros pad = _op.concatenate([_op.const(np.zeros([2, 2], dtype="int64"), dtype="int64"), pad], axis=0) - if isinstance(pad_value, int) or isinstance(pad_value, float): + if isinstance(pad_value, (float, int)): pad_value = _op.const(pad_value) return _op.nn.pad(data, fold_constant(pad), pad_value, pad_type) From a0871ce4b05991b6514bf43386cf6b2b197ae986 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Tue, 13 Jul 2021 16:42:59 +0000 Subject: [PATCH 4/4] Remove changes to conv_transpose. --- python/tvm/relay/frontend/onnx.py | 29 +++++++++++----------- tests/python/frontend/onnx/test_forward.py | 2 +- 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 6a15a25135c2..68b51819662c 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -391,6 +391,7 @@ def autopad( dilations, ndim, pad_type="constant", + deconv=False, mode="SAME_UPPER", pad_value=0.0, ): @@ -410,6 +411,7 @@ def autopad( # set up integer constants zero = _op.const(0, dtype="int64") + one = _op.const(1, dtype="int64") two = _op.const(2, dtype="int64") # Calculate total padding @@ -419,6 +421,8 @@ def autopad( right = _op.maximum(dilated_kernel_shape - mod, zero) total_pad = _op.where(_op.equal(mod, zero), left, right) + if deconv: + total_pad = _op.const(np.array(kernel_shape), dtype="int64") - one - total_pad # split total padding into before and after pad_before = _op.floor_divide(total_pad, two) @@ -544,20 +548,17 @@ def _impl_v1(cls, inputs, attr, params): if "auto_pad" in attr: attr["auto_pad"] = attr["auto_pad"].decode("utf-8") if attr["auto_pad"] in ("SAME_UPPER", "SAME_LOWER"): - strides = attr.get("strides", [1] * (ndim - 2)) - kernel_shape = attr["kernel_shape"] - dilations = attr.get("dilations", [1] * (ndim - 2)) - dilated_kernel_shape = [ - (kernel - 1) * dilation + 1 for kernel, dilation in zip(kernel_shape, dilations) - ] - pads = [k - s for k, s in zip(dilated_kernel_shape, strides)] - # Convert total padding into separate padding. - left = np.floor_divide(pads, 2) - right = pads - left - if attr["auto_pad"] == "SAME_UPPER": - attr["pads"] = list(right) + list(left) - else: - attr["pads"] = list(left) + list(right) + # Warning: Convolution does not yet support dynamic shapes, + # one will need to run dynamic_to_static on this model after import + data = autopad( + data, + attr.get("strides", [1] * (ndim - 2)), + attr["kernel_shape"], + attr.get("dilations", [1] * (ndim - 2)), + ndim, + deconv=True, + mode=attr["auto_pad"], + ) elif attr["auto_pad"] == "VALID": attr["pads"] = tuple([0 for i in range(ndim - 2)]) elif attr["auto_pad"] == "NOTSET": diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index e9fa6201695c..3ca9f0006a4e 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -2729,7 +2729,7 @@ def repeat(N, D): verify_convtranspose_with_padding( (1, 1) + repeat(5, D), (1, 1) + repeat(3, D), - (1, 1) + repeat(10, D), + (1, 1) + repeat(9, D), None, repeat(3, D), repeat(2, D),