From 0ddf7c81c8fec9ff583fc9f74093b71a9901087b Mon Sep 17 00:00:00 2001
From: An Wang
Date: Wed, 1 Sep 2021 13:30:22 -0700
Subject: [PATCH 01/10] add qlinearconcat op

---
 python/tvm/relay/frontend/onnx.py          | 55 +++++++++++++++-------
 tests/python/frontend/onnx/test_forward.py | 34 +++++++++++++
 2 files changed, 71 insertions(+), 18 deletions(-)

diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index f9b49204b85e..6490aebf22b6 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -195,6 +195,17 @@ def _dim_check(attrs):
     return _dim_check, "Only 1d, 2d and 3d kernel supported."
 
 
+def get_scalar(x, dtype="float32"):
+    """Helper to get a scalar value for Quantized operators."""
+    if isinstance(x, _expr.Var) and x.name_hint in params:
+        return _op.const(params[x.name_hint].numpy(), dtype)
+    rank = len(infer_shape(x))
+    assert rank <= 1, "scale and zero_point input must be scalars"
+    if rank == 1:
+        x = _op.squeeze(x, [0])
+    return _op.cast(x, dtype)
+
+
 class OnnxOpConverter(object):
     """A helper class for holding onnx op converters."""
 
@@ -3135,15 +3146,6 @@ class QLinearConv(OnnxOpConverter):
 
     @classmethod
     def _impl_v10(cls, inputs, attr, params):
-        def get_scalar(x, dtype="float32"):
-            if isinstance(x, _expr.Var) and x.name_hint in params:
-                return _op.const(params[x.name_hint].numpy(), dtype)
-            rank = len(infer_shape(x))
-            assert rank <= 1, "QLinearConv scale and zero_point input must be scalars"
-            if rank == 1:
-                x = _op.squeeze(x, [0])
-            return _op.cast(x, dtype)
-
         data = inputs[0]
         x_scale = get_scalar(inputs[1])
         x_zero_point = get_scalar(inputs[2], "int32")
@@ -3239,15 +3241,6 @@ class QLinearAdd(OnnxOpConverter):
 
     @classmethod
     def _impl_v10(cls, inputs, attr, params):
-        def get_scalar(x, dtype="float32"):
-            if isinstance(x, _expr.Var) and x.name_hint in params:
-                return _op.const(params[x.name_hint].numpy(), dtype)
-            rank = len(infer_shape(x))
-            assert rank <= 1, "QLinearConv scale and zero_point input must be scalars"
-            if rank == 1:
-                x = _op.squeeze(x, [0])
-            return _op.cast(x, dtype)
-
         a = inputs[0]
         a_scale = get_scalar(inputs[1])
         a_zero_point = get_scalar(inputs[2], "int32")
@@ -3306,6 +3299,32 @@ def get_scalar(x, dtype="float32"):
         return _qnn.op.quantize(out, y_scale, y_zero_point, out_dtype=dtype)
 
 
+class QLinearConcat(OnnxOpConverter):
+    """Operator converter for QLinearConcat from Microsoft onnxruntime contrib opset."""
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        # which axis to concat on
+        axis = attr["axis"]
+
+        y_scale = inputs[0]
+        y_zero_point = get_scalar(inputs[1], "int8")
+
+        # input tensors, scales, zero_points
+        assert (
+            len(inputs) % 3 == 2
+        ), "Additional number of inputs beyond y_scale, y_zero_point for QLinearConcat must be a multiple of 3, indicating a complete input tensor/scale/zero_point trifecta"
+        tensors = []
+        scales = []
+        zero_points = []
+        for i in range(2, len(inputs), 3):
+            tensors.append(inputs[i])
+            scales.append(get_scalar(inputs[i + 1]))
+            zero_points.append(get_scalar(inputs[i + 2], "int8"))
+
+        return _qnn.op.concatenate(tensors, scales, zero_points, y_scale, y_zero_point, axis)
+
+
 class ConvInteger(OnnxOpConverter):
     """Operator converter for ConvInteger."""
 
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index a1d821686ed5..84e4e3907ea3 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -5266,6 +5266,39 @@ def repeat(N, D):
     )
 
 
+@tvm.testing.parametrize_targets
+def test_qlinearconcat(target, dev):
+    def verify_qlinearconcat(shapes, out_shape, axis=None):
+        input_names = []
+        input_values = []
+        input_nodes = []
+        for i in range(len(shapes)):
+            tensor_name = str(chr(ord("a") + i))
+            shape = shapes[i]
+            node = helper.make_tensor_value_info(tensor_name, TensorProto.FLOAT, list(shape))
+
+            input_names.append(tensor_name)
+            input_values.append(np.random.random(shape).astype("float32"))
+            input_nodes.append(node)
+
+        node = helper.make_node("Concat", input_names, ["OUT"])
+        if axis is not None:
+            axis_attr = helper.make_attribute("axis", axis)
+            node.attribute.append(axis_attr)
+        graph = helper.make_graph(
+            [node],
+            "qlinearconcat_test",
+            inputs=input_nodes,
+            outputs=[helper.make_tensor_value_info("OUT", TensorProto.FLOAT, list(out_shape))],
+        )
+        model = helper.make_model(graph, producer_name="qlinearconcat_test")
+        quantize_and_verify_with_ort(model, input_names, shapes, target, dev)
+
+    verify_qlinearconcat([[2, 1], [2, 1]], [4, 1], 0)
+    verify_qlinearconcat([[2, 1], [2, 1]], [2, 2], 1)
+    verify_qlinearconcat([[1, 2, 2], [1, 2, 3]], [1, 2, 5], 2)
+
+
 @tvm.testing.parametrize_targets
 def test_qlinearadd(target, dev):
     def verify_qlinearadd(a_shape, b_shape, c_shape):
@@ -5623,6 +5656,7 @@ def repeat(N, D):
     test_aten()
     test_reverse_sequence()
     test_eyelike()
+    test_qlinearconcat()
     test_qlinearconv()
     test_random_uniform()
     test_convinteger()

From db51c1c37fca9bedbe8a9157c6317363572434b3 Mon Sep 17 00:00:00 2001
From: An Wang
Date: Wed, 1 Sep 2021 15:09:48 -0700
Subject: [PATCH 02/10] fix tests

---
 python/tvm/relay/frontend/onnx.py          | 58 ++++++++++------------
 tests/python/frontend/onnx/test_forward.py |  7 +--
 2 files changed, 29 insertions(+), 36 deletions(-)

diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 6490aebf22b6..c46804eb7dd5 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -195,7 +195,7 @@ def _dim_check(attrs):
     return _dim_check, "Only 1d, 2d and 3d kernel supported."
 
 
-def get_scalar(x, dtype="float32"):
+def get_scalar(x, params, dtype="float32"):
     """Helper to get a scalar value for Quantized operators."""
     if isinstance(x, _expr.Var) and x.name_hint in params:
         return _op.const(params[x.name_hint].numpy(), dtype)
@@ -3147,13 +3147,13 @@ class QLinearConv(OnnxOpConverter):
     @classmethod
     def _impl_v10(cls, inputs, attr, params):
         data = inputs[0]
-        x_scale = get_scalar(inputs[1])
-        x_zero_point = get_scalar(inputs[2], "int32")
+        x_scale = get_scalar(inputs[1], params)
+        x_zero_point = get_scalar(inputs[2], params, "int32")
         weight = inputs[3]
-        w_scale = get_scalar(inputs[4])
-        w_zero_point = get_scalar(inputs[5], "int32")
-        y_scale = fold_constant(get_scalar(inputs[6]))
-        y_zero_point = get_scalar(inputs[7], "int32")
+        w_scale = get_scalar(inputs[4], params)
+        w_zero_point = get_scalar(inputs[5], params, "int32")
+        y_scale = fold_constant(get_scalar(inputs[6], params))
+        y_zero_point = get_scalar(inputs[7], params, "int32")
 
         input_shape = infer_shape(data)
 
@@ -3242,13 +3242,13 @@ class QLinearAdd(OnnxOpConverter):
     @classmethod
     def _impl_v10(cls, inputs, attr, params):
         a = inputs[0]
-        a_scale = get_scalar(inputs[1])
-        a_zero_point = get_scalar(inputs[2], "int32")
+        a_scale = get_scalar(inputs[1], params)
+        a_zero_point = get_scalar(inputs[2], params, "int32")
         b = inputs[3]
-        b_scale = get_scalar(inputs[4])
-        b_zero_point = get_scalar(inputs[5], "int32")
-        c_scale = get_scalar(inputs[6])
-        c_zero_point = get_scalar(inputs[7], "int32")
+        b_scale = get_scalar(inputs[4], params)
+        b_zero_point = get_scalar(inputs[5], params, "int32")
+        c_scale = get_scalar(inputs[6], params)
+        c_zero_point = get_scalar(inputs[7], params, "int32")
 
         dtype = infer_type(a).checked_type.dtype
 
@@ -3270,23 +3270,14 @@ class QLinearMul(OnnxOpConverter):
 
     @classmethod
     def _impl_v10(cls, inputs, attr, params):
-        def get_scalar(x, dtype="float32"):
-            if isinstance(x, _expr.Var) and x.name_hint in params:
-                return _op.const(params[x.name_hint].numpy(), dtype)
-            rank = len(infer_shape(x))
-            assert rank <= 1, "QLinearMul scale and zero_point input must be scalars"
-            if rank == 1:
-                x = _op.squeeze(x, [0])
-            return _op.cast(x, dtype)
-
         a = inputs[0]
-        a_scale = get_scalar(inputs[1])
-        a_zero_point = get_scalar(inputs[2], "int32")
+        a_scale = get_scalar(inputs[1], params)
+        a_zero_point = get_scalar(inputs[2], params, "int32")
         b = inputs[3]
-        b_scale = get_scalar(inputs[4])
-        b_zero_point = get_scalar(inputs[5], "int32")
-        y_scale = fold_constant(get_scalar(inputs[6]))
-        y_zero_point = get_scalar(inputs[7], "int32")
+        b_scale = get_scalar(inputs[4], params)
+        b_zero_point = get_scalar(inputs[5], params, "int32")
+        y_scale = fold_constant(get_scalar(inputs[6], params))
+        y_zero_point = get_scalar(inputs[7], params, "int32")
 
         dtype = infer_type(a).checked_type.dtype
 
@@ -3307,20 +3298,20 @@ def _impl_v1(cls, inputs, attr, params):
         # which axis to concat on
         axis = attr["axis"]
 
-        y_scale = inputs[0]
-        y_zero_point = get_scalar(inputs[1], "int8")
+        y_scale = fold_constant(get_scalar(inputs[0], params))
+        y_zero_point = get_scalar(inputs[1], params, "int32")
 
         # input tensors, scales, zero_points
         assert (
             len(inputs) % 3 == 2
-        ), "Additional number of inputs beyond y_scale, y_zero_point for QLinearConcat must be a multiple of 3, indicating a complete input tensor/scale/zero_point trifecta"
+        ), "Additional number of inputs beyond y_scale, y_zero_point for QLinearConcat must be a multiple of 3, indicating complete input tensor/scale/zero_point tubles"
         tensors = []
         scales = []
         zero_points = []
         for i in range(2, len(inputs), 3):
             tensors.append(inputs[i])
-            scales.append(get_scalar(inputs[i + 1]))
-            zero_points.append(get_scalar(inputs[i + 2], "int8"))
+            scales.append(get_scalar(inputs[i + 1], params))
+            zero_points.append(get_scalar(inputs[i + 2], params, "int32"))
 
         return _qnn.op.concatenate(tensors, scales, zero_points, y_scale, y_zero_point, axis)
 
@@ -3650,6 +3641,7 @@ def _get_convert_map(opset):
         "DynamicQuantizeLinear": DynamicQuantizeLinear.get_converter(opset),
         "ReverseSequence": ReverseSequence.get_converter(opset),
         "QLinearConv": QLinearConv.get_converter(opset),
+        "QLinearConcat": QLinearConcat.get_converter(opset),
         "QLinearAdd": QLinearAdd.get_converter(opset),
         "QLinearMul": QLinearMul.get_converter(opset),
         "ConvInteger": ConvInteger.get_converter(opset),
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 84e4e3907ea3..316a8045ffb2 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -5281,7 +5281,7 @@ def verify_qlinearconcat(shapes, out_shape, axis=None):
             input_values.append(np.random.random(shape).astype("float32"))
             input_nodes.append(node)
 
-        node = helper.make_node("Concat", input_names, ["OUT"])
+        node = helper.make_node("Concat", input_names, ["C"])
         if axis is not None:
             axis_attr = helper.make_attribute("axis", axis)
             node.attribute.append(axis_attr)
@@ -5289,14 +5289,15 @@ def verify_qlinearconcat(shapes, out_shape, axis=None):
             [node],
             "qlinearconcat_test",
            inputs=input_nodes,
-            outputs=[helper.make_tensor_value_info("OUT", TensorProto.FLOAT, list(out_shape))],
+            outputs=[helper.make_tensor_value_info("C", TensorProto.FLOAT, list(out_shape))],
         )
+        breakpoint()
         model = helper.make_model(graph, producer_name="qlinearconcat_test")
         quantize_and_verify_with_ort(model, input_names, shapes, target, dev)
 
     verify_qlinearconcat([[2, 1], [2, 1]], [4, 1], 0)
     verify_qlinearconcat([[2, 1], [2, 1]], [2, 2], 1)
-    verify_qlinearconcat([[1, 2, 2], [1, 2, 3]], [1, 2, 5], 2)
+    verify_qlinearconcat([[1, 2], [2, 2], [3,2]], [6, 2], 0)
 
 
 @tvm.testing.parametrize_targets

From 12d63335c816ecf28db57e16f2b2682fff0f7993 Mon Sep 17 00:00:00 2001
From: An Wang
Date: Wed, 1 Sep 2021 16:52:08 -0700
Subject: [PATCH 03/10] Fix

---
 python/tvm/relay/frontend/onnx.py          | 4 +++-
 tests/python/frontend/onnx/test_forward.py | 1 -
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index c46804eb7dd5..f2f413bf5f21 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -3301,10 +3301,12 @@ def _impl_v1(cls, inputs, attr, params):
         y_scale = fold_constant(get_scalar(inputs[0], params))
         y_zero_point = get_scalar(inputs[1], params, "int32")
 
+        out_dtype = infer_type(inputs[1]).checked_type.dtype
+
         # input tensors, scales, zero_points
         assert (
             len(inputs) % 3 == 2
-        ), "Additional number of inputs beyond y_scale, y_zero_point for QLinearConcat must be a multiple of 3, indicating complete input tensor/scale/zero_point tubles"
+        ), "Additional number of inputs beyond y_scale, y_zero_point for QLinearConcat must be a multiple of 3, indicating complete input tensor/scale/zero_point tuples"
         tensors = []
         scales = []
         zero_points = []
         for i in range(2, len(inputs), 3):
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 316a8045ffb2..1ce17a34b6e0 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -5291,7 +5291,6 @@ def verify_qlinearconcat(shapes, out_shape, axis=None):
             inputs=input_nodes,
             outputs=[helper.make_tensor_value_info("C", TensorProto.FLOAT, list(out_shape))],
         )
-        breakpoint()
         model = helper.make_model(graph, producer_name="qlinearconcat_test")
         quantize_and_verify_with_ort(model, input_names, shapes, target, dev)
 

From d030da1a2ef745d7c4b49c1ca96c40b513d7b2c5 Mon Sep 17 00:00:00 2001
From: An Wang
Date: Thu, 2 Sep 2021 10:40:54 -0700
Subject: [PATCH 04/10] lint

---
 tests/python/frontend/onnx/test_forward.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 1ce17a34b6e0..0abda3474f0b 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -5296,7 +5296,7 @@ def verify_qlinearconcat(shapes, out_shape, axis=None):
 
     verify_qlinearconcat([[2, 1], [2, 1]], [4, 1], 0)
     verify_qlinearconcat([[2, 1], [2, 1]], [2, 2], 1)
-    verify_qlinearconcat([[1, 2], [2, 2], [3,2]], [6, 2], 0)
+    verify_qlinearconcat([[1, 2], [2, 2], [3, 2]], [6, 2], 0)
 
 
 @tvm.testing.parametrize_targets

From 01f3b2e112199d25a7cec65cd54ca942297b1655 Mon Sep 17 00:00:00 2001
From: An Wang
Date: Thu, 2 Sep 2021 10:53:58 -0700
Subject: [PATCH 05/10] lint

---
 python/tvm/relay/frontend/onnx.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index f2f413bf5f21..dc0411f918ef 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -3301,12 +3301,10 @@ def _impl_v1(cls, inputs, attr, params):
         y_scale = fold_constant(get_scalar(inputs[0], params))
         y_zero_point = get_scalar(inputs[1], params, "int32")
 
-        out_dtype = infer_type(inputs[1]).checked_type.dtype
-
         # input tensors, scales, zero_points
         assert (
             len(inputs) % 3 == 2
-        ), "Additional number of inputs beyond y_scale, y_zero_point for QLinearConcat must be a multiple of 3, indicating complete input tensor/scale/zero_point tuples"
+        ), "Additional input count must be a multiple of 3 -- tensor/scale/zero_point tuples"
         tensors = []
         scales = []
         zero_points = []

From 77084dcff5138168a78b1e266db809c3dabd5e3c Mon Sep 17 00:00:00 2001
From: An Wang
Date: Thu, 2 Sep 2021 11:35:44 -0700
Subject: [PATCH 06/10] review

---
 python/tvm/relay/frontend/onnx.py          | 3 ++-
 tests/python/frontend/onnx/test_forward.py | 2 +-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index dc0411f918ef..f1c69ec36af5 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -197,7 +197,8 @@ def _dim_check(attrs):
 
 def get_scalar(x, params, dtype="float32"):
     """Helper to get a scalar value for Quantized operators."""
-    if isinstance(x, _expr.Var) and x.name_hint in params:
+    if isinstance(x, _expr.Var):
+        assert x.name_hint in params, "Var should be found in params lookup"
         return _op.const(params[x.name_hint].numpy(), dtype)
     rank = len(infer_shape(x))
     assert rank <= 1, "scale and zero_point input must be scalars"
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 0abda3474f0b..01f02543407f 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -5273,7 +5273,7 @@ def verify_qlinearconcat(shapes, out_shape, axis=None):
         input_values = []
         input_nodes = []
         for i in range(len(shapes)):
-            tensor_name = str(chr(ord("a") + i))
+            tensor_name = chr(ord("a") + i)
             shape = shapes[i]
             node = helper.make_tensor_value_info(tensor_name, TensorProto.FLOAT, list(shape))
 

From 83ebf69fefc2387ce8af2ff122c0137d9fa5f2ea Mon Sep 17 00:00:00 2001
From: An Wang
Date: Thu, 2 Sep 2021 13:57:34 -0700
Subject: [PATCH 07/10] boop ci


From c75b2e7ac8eac8a076a56099032cc5f22d738385 Mon Sep 17 00:00:00 2001
From: An Wang
Date: Fri, 3 Sep 2021 13:43:46 -0700
Subject: [PATCH 08/10] fix regression

---
 python/tvm/relay/frontend/onnx.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index f1c69ec36af5..dc0411f918ef 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -197,8 +197,7 @@ def _dim_check(attrs):
 
 def get_scalar(x, params, dtype="float32"):
     """Helper to get a scalar value for Quantized operators."""
-    if isinstance(x, _expr.Var):
-        assert x.name_hint in params, "Var should be found in params lookup"
+    if isinstance(x, _expr.Var) and x.name_hint in params:
         return _op.const(params[x.name_hint].numpy(), dtype)
     rank = len(infer_shape(x))
     assert rank <= 1, "scale and zero_point input must be scalars"

From 540d388b6dfdbbfdbab5f7cfc6a1ed0305dff224 Mon Sep 17 00:00:00 2001
From: An Wang
Date: Fri, 3 Sep 2021 16:11:56 -0700
Subject: [PATCH 09/10] noop


From 660a0fb90727d43b779bb8044d40d352a095ff5f Mon Sep 17 00:00:00 2001
From: An Wang
Date: Tue, 7 Sep 2021 12:07:46 -0700
Subject: [PATCH 10/10] jostle ci
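For reference, a minimal sketch (not part of the patch series) of the input layout the new converter consumes: the onnxruntime QLinearConcat contrib node takes y_scale and y_zero_point first, then one (tensor, scale, zero_point) triple per input, which is exactly the len(inputs) % 3 == 2 layout the converter asserts before handing everything to _qnn.op.concatenate. The tensor names, shapes, and quantization values below are illustrative assumptions, and the com.microsoft domain/opset import reflects where onnxruntime keeps its contrib ops.

# Illustrative sketch only -- builds a standalone QLinearConcat model with the
# 2 + 3*N input layout the converter unpacks; names and values are made up.
from onnx import TensorProto, helper

graph_inputs = [
    helper.make_tensor_value_info("a", TensorProto.UINT8, [2, 1]),
    helper.make_tensor_value_info("b", TensorProto.UINT8, [2, 1]),
]
# Scales are float scalars, zero points are uint8 scalars, stored as initializers.
initializers = [
    helper.make_tensor("y_scale", TensorProto.FLOAT, [], [0.05]),
    helper.make_tensor("y_zero_point", TensorProto.UINT8, [], [128]),
    helper.make_tensor("a_scale", TensorProto.FLOAT, [], [0.02]),
    helper.make_tensor("a_zero_point", TensorProto.UINT8, [], [127]),
    helper.make_tensor("b_scale", TensorProto.FLOAT, [], [0.03]),
    helper.make_tensor("b_zero_point", TensorProto.UINT8, [], [126]),
]
node = helper.make_node(
    "QLinearConcat",
    # y_scale, y_zero_point, then one (tensor, scale, zero_point) triple per input
    ["y_scale", "y_zero_point", "a", "a_scale", "a_zero_point", "b", "b_scale", "b_zero_point"],
    ["y"],
    domain="com.microsoft",
    axis=0,
)
graph = helper.make_graph(
    [node],
    "qlinearconcat_sketch",
    inputs=graph_inputs,
    outputs=[helper.make_tensor_value_info("y", TensorProto.UINT8, [4, 1])],
    initializer=initializers,
)
model = helper.make_model(
    graph,
    producer_name="qlinearconcat_sketch",
    opset_imports=[helper.make_opsetid("", 13), helper.make_opsetid("com.microsoft", 1)],
)

On the TVM side, each (tensor, scale, zero_point) triple becomes one entry in the lists passed to _qnn.op.concatenate, which requantizes every input to the output's y_scale/y_zero_point before concatenating along the given axis.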