From 4002fd582d4ca4d0f4b7de5cd6348950d0311f2f Mon Sep 17 00:00:00 2001
From: An Wang
Date: Tue, 17 Aug 2021 15:33:26 -0700
Subject: [PATCH 1/4] add qlinearmatmul

---
 python/tvm/relay/frontend/onnx.py          | 35 +++++++++++
 tests/python/frontend/onnx/test_forward.py | 67 ++++++++++++++++++++--
 2 files changed, 97 insertions(+), 5 deletions(-)

diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 4866189ed872..faad4ea2748b 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -3279,6 +3279,40 @@ def get_scalar(x, dtype="float32"):
         return _qnn.op.quantize(out, c_scale, c_zero_point, out_dtype=dtype)
 
 
+class QLinearMatMul(OnnxOpConverter):
+    """Operator converter for QLinearMatMul from Microsoft onnxruntime contrib opset."""
+
+    @classmethod
+    def _impl_v10(cls, inputs, attr, params):
+        def get_scalar(x, dtype="float32"):
+            if isinstance(x, _expr.Var) and x.name_hint in params:
+                return _op.const(params[x.name_hint].numpy(), dtype)
+            rank = len(infer_shape(x))
+            assert rank <= 1, "QLinearMatMul scale and zero_point input must be scalars"
+            if rank == 1:
+                x = _op.squeeze(x, [0])
+            return _op.cast(x, dtype)
+
+        a = inputs[0]
+        a_scale = get_scalar(inputs[1])
+        a_zero_point = get_scalar(inputs[2], "int32")
+        b = inputs[3]
+        b_scale = get_scalar(inputs[4])
+        b_zero_point = get_scalar(inputs[5], "int32")
+        y_scale = fold_constant(get_scalar(inputs[6]))
+        y_zero_point = get_scalar(inputs[7], "int32")
+
+        dtype = infer_type(a).checked_type.dtype
+
+        ## Onnxruntime doesn't actually do this op in integer; it dequantizes to fp32
+        ## and then requantizes after:
+        ## https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/core/mlas/lib/qlmul.cpp
+        a = _qnn.op.dequantize(inputs[0], a_scale, a_zero_point)
+        b = _qnn.op.dequantize(inputs[3], b_scale, b_zero_point)
+        out = _op.nn.matmul(a, b)
+        return _qnn.op.quantize(out, y_scale, y_zero_point, out_dtype=dtype)
+
+
 class ConvInteger(OnnxOpConverter):
     """Operator converter for ConvInteger."""
 
@@ -3605,6 +3639,7 @@ def _get_convert_map(opset):
         "ReverseSequence": ReverseSequence.get_converter(opset),
         "QLinearConv": QLinearConv.get_converter(opset),
         "QLinearAdd": QLinearAdd.get_converter(opset),
+        "QLinearMatMul": QLinearMatMul.get_converter(opset),
         "ConvInteger": ConvInteger.get_converter(opset),
         # Random number generation.
         "RandomUniform": RandomUniform.get_converter(opset),
"RandomUniform": RandomUniform.get_converter(opset), diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 93b9cfa07464..b602ec7870b8 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -4768,7 +4768,6 @@ def verify_eyelike(indata): "test_pow_types_int32_int32", "test_pow_types_int64_float32", "test_pow_types_int64_int64", - "test_qlinearmatmul_2D", "test_qlinearmatmul_3D", "test_range_float_type_positive_delta_expanded", "test_range_int32_type_negative_delta_expanded", @@ -5273,8 +5272,6 @@ def verify_qlinearadd(a_shape, b_shape, c_shape): ] input_values = [a_array, b_array] - node = helper.make_node("QLinearAdd", inputs=input_names, outputs=["C"]) - node = helper.make_node("Add", ["a", "b"], ["C"]) graph = helper.make_graph( [node], @@ -5282,7 +5279,7 @@ def verify_qlinearadd(a_shape, b_shape, c_shape): inputs=input_nodes, outputs=[helper.make_tensor_value_info("C", TensorProto.FLOAT, list(c_shape))], ) - model = helper.make_model(graph, producer_name="qlinearconv_test") + model = helper.make_model(graph, producer_name="qlinearadd_test") from onnxruntime.quantization import quantize_static, CalibrationDataReader, QuantType class RandomDataReader(CalibrationDataReader): @@ -5307,13 +5304,73 @@ def get_next(self): quantized_model = quantize_static(model_fp32, model_quant, RandomDataReader()) # opt_level=1 will cause error with qnn lowering model = onnx.load(model_quant) - verify_with_ort_with_inputs(model, input_values, opt_level=2, target=target, dev=dev) + verify_with_ort_with_inputs( + model, input_values, opt_level=2, target=target, dev=dev, use_vm=True + ) verify_qlinearadd([4, 2], [4, 2], [4, 2]) verify_qlinearadd([4, 2], [2], [4, 2]) verify_qlinearadd([5, 1, 7], [2, 7], [5, 2, 7]) +@tvm.testing.parametrize_targets +def test_qlinearmatmul(target, dev): + def verify_qlinearmatmul(a_shape, b_shape, c_shape): + + a_array = np.random.random(a_shape).astype("float32") + b_array = np.random.random(b_shape).astype("float32") + + input_nodes = [ + helper.make_tensor_value_info("a", TensorProto.FLOAT, list(a_shape)), + helper.make_tensor_value_info("b", TensorProto.FLOAT, list(b_shape)), + ] + input_names = [ + "a", + "b", + ] + input_values = [a_array, b_array] + + node = helper.make_node("MatMul", input_names, ["C"]) + graph = helper.make_graph( + [node], + "qlinearmatmul_test", + inputs=input_nodes, + outputs=[helper.make_tensor_value_info("C", TensorProto.FLOAT, list(c_shape))], + ) + model = helper.make_model(graph, producer_name="qlinearmatmul_test") + from onnxruntime.quantization import quantize_static, CalibrationDataReader, QuantType + + class RandomDataReader(CalibrationDataReader): + def __init__(self, n=10): + self.data = iter( + [ + { + "a": np.random.random(a_shape).astype("float32"), + "b": np.random.random(b_shape).astype("float32"), + } + for _ in range(n) + ] + ) + + def get_next(self): + return next(self.data, None) + + d = tvm.contrib.utils.tempdir() + model_fp32 = os.path.join(d.temp_dir, "model.onnx") + onnx.save_model(model, model_fp32) + model_quant = os.path.join(d.temp_dir, "model.quant.onnx") + quantized_model = quantize_static(model_fp32, model_quant, RandomDataReader()) + # opt_level=1 will cause error with qnn lowering + model = onnx.load(model_quant) + verify_with_ort_with_inputs( + model, input_values, opt_level=2, target=target, dev=dev, use_vm=True + ) + + verify_qlinearmatmul([2, 4], [4, 3], [2, 3]) + verify_qlinearmatmul([2, 2], [2, 1], [2, 1]) + 
From 925c730949223d193357e7934015b8ee98c6178e Mon Sep 17 00:00:00 2001
From: An Wang
Date: Tue, 17 Aug 2021 15:38:28 -0700
Subject: [PATCH 2/4] noop

From cd2e7bcf0b8321390f5f74ca1452b09cd7e80438 Mon Sep 17 00:00:00 2001
From: An Wang
Date: Tue, 17 Aug 2021 16:17:26 -0700
Subject: [PATCH 3/4] mul not matmul

---
 python/tvm/relay/frontend/onnx.py          | 10 +++++-----
 tests/python/frontend/onnx/test_forward.py | 17 +++++++++--------
 2 files changed, 14 insertions(+), 13 deletions(-)

diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index faad4ea2748b..695203772e57 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -3279,8 +3279,8 @@ def get_scalar(x, dtype="float32"):
         return _qnn.op.quantize(out, c_scale, c_zero_point, out_dtype=dtype)
 
 
-class QLinearMatMul(OnnxOpConverter):
-    """Operator converter for QLinearMatMul from Microsoft onnxruntime contrib opset."""
+class QLinearMul(OnnxOpConverter):
+    """Operator converter for QLinearMul from Microsoft onnxruntime contrib opset."""
 
     @classmethod
     def _impl_v10(cls, inputs, attr, params):
@@ -3288,7 +3288,7 @@ def get_scalar(x, dtype="float32"):
             if isinstance(x, _expr.Var) and x.name_hint in params:
                 return _op.const(params[x.name_hint].numpy(), dtype)
             rank = len(infer_shape(x))
-            assert rank <= 1, "QLinearMatMul scale and zero_point input must be scalars"
+            assert rank <= 1, "QLinearMul scale and zero_point input must be scalars"
             if rank == 1:
                 x = _op.squeeze(x, [0])
             return _op.cast(x, dtype)
@@ -3309,7 +3309,7 @@ def get_scalar(x, dtype="float32"):
         ## https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/core/mlas/lib/qlmul.cpp
         a = _qnn.op.dequantize(inputs[0], a_scale, a_zero_point)
         b = _qnn.op.dequantize(inputs[3], b_scale, b_zero_point)
-        out = _op.nn.matmul(a, b)
+        out = _op.multiply(a, b)
         return _qnn.op.quantize(out, y_scale, y_zero_point, out_dtype=dtype)
 
 
@@ -3639,7 +3639,7 @@ def _get_convert_map(opset):
         "ReverseSequence": ReverseSequence.get_converter(opset),
         "QLinearConv": QLinearConv.get_converter(opset),
         "QLinearAdd": QLinearAdd.get_converter(opset),
-        "QLinearMatMul": QLinearMatMul.get_converter(opset),
+        "QLinearMul": QLinearMul.get_converter(opset),
         "ConvInteger": ConvInteger.get_converter(opset),
         # Random number generation.
         "RandomUniform": RandomUniform.get_converter(opset),

diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index b602ec7870b8..83ef32dd1066 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -4768,6 +4768,7 @@ def verify_eyelike(indata):
     "test_pow_types_int32_int32",
     "test_pow_types_int64_float32",
     "test_pow_types_int64_int64",
+    "test_qlinearmatmul_2D",
     "test_qlinearmatmul_3D",
     "test_range_float_type_positive_delta_expanded",
     "test_range_int32_type_negative_delta_expanded",
@@ -5314,8 +5315,8 @@ def get_next(self):
 
 
 @tvm.testing.parametrize_targets
-def test_qlinearmatmul(target, dev):
-    def verify_qlinearmatmul(a_shape, b_shape, c_shape):
+def test_qlinearmul(target, dev):
+    def verify_qlinearmul(a_shape, b_shape, c_shape):
 
         a_array = np.random.random(a_shape).astype("float32")
         b_array = np.random.random(b_shape).astype("float32")
@@ -5330,14 +5331,14 @@ def verify_qlinearmatmul(a_shape, b_shape, c_shape):
         ]
         input_values = [a_array, b_array]
 
-        node = helper.make_node("MatMul", input_names, ["C"])
+        node = helper.make_node("Mul", input_names, ["C"])
         graph = helper.make_graph(
             [node],
-            "qlinearmatmul_test",
+            "qlinearmul_test",
             inputs=input_nodes,
             outputs=[helper.make_tensor_value_info("C", TensorProto.FLOAT, list(c_shape))],
         )
-        model = helper.make_model(graph, producer_name="qlinearmatmul_test")
+        model = helper.make_model(graph, producer_name="qlinearmul_test")
         from onnxruntime.quantization import quantize_static, CalibrationDataReader, QuantType
 
         class RandomDataReader(CalibrationDataReader):
@@ -5366,9 +5367,9 @@ def get_next(self):
             model, input_values, opt_level=2, target=target, dev=dev, use_vm=True
         )
 
-    verify_qlinearmatmul([2, 4], [4, 3], [2, 3])
-    verify_qlinearmatmul([2, 2], [2, 1], [2, 1])
-    verify_qlinearmatmul([5, 3], [3, 4], [5, 4])
+    verify_qlinearmul([4, 2], [4, 2], [4, 2])
+    verify_qlinearmul([4, 2], [2], [4, 2])
+    verify_qlinearmul([5, 1, 7], [2, 7], [5, 2, 7])
 
 
 @tvm.testing.parametrize_targets
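PATCH 3/4 retargets the converter at onnxruntime's QLinearMul contrib op, so the same dequantize/requantize pattern now wraps an elementwise multiply rather than a matrix product. A hedged numpy sketch of the retargeted semantics (same illustrative names and uint8 assumption as before; broadcasting follows ONNX Mul):

    import numpy as np

    def qlinear_mul_ref(a, a_scale, a_zp, b, b_scale, b_zp, y_scale, y_zp):
        # Elementwise multiply in fp32; numpy broadcasting matches ONNX Mul.
        a_fp32 = (a.astype("float32") - a_zp) * a_scale
        b_fp32 = (b.astype("float32") - b_zp) * b_scale
        y = np.round((a_fp32 * b_fp32) / y_scale) + y_zp
        return np.clip(y, 0, 255).astype("uint8")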
"RandomUniform": RandomUniform.get_converter(opset), diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index b602ec7870b8..83ef32dd1066 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -4768,6 +4768,7 @@ def verify_eyelike(indata): "test_pow_types_int32_int32", "test_pow_types_int64_float32", "test_pow_types_int64_int64", + "test_qlinearmatmul_2D", "test_qlinearmatmul_3D", "test_range_float_type_positive_delta_expanded", "test_range_int32_type_negative_delta_expanded", @@ -5314,8 +5315,8 @@ def get_next(self): @tvm.testing.parametrize_targets -def test_qlinearmatmul(target, dev): - def verify_qlinearmatmul(a_shape, b_shape, c_shape): +def test_qlinearmul(target, dev): + def verify_qlinearmul(a_shape, b_shape, c_shape): a_array = np.random.random(a_shape).astype("float32") b_array = np.random.random(b_shape).astype("float32") @@ -5330,14 +5331,14 @@ def verify_qlinearmatmul(a_shape, b_shape, c_shape): ] input_values = [a_array, b_array] - node = helper.make_node("MatMul", input_names, ["C"]) + node = helper.make_node("Mul", input_names, ["C"]) graph = helper.make_graph( [node], - "qlinearmatmul_test", + "qlinearmul_test", inputs=input_nodes, outputs=[helper.make_tensor_value_info("C", TensorProto.FLOAT, list(c_shape))], ) - model = helper.make_model(graph, producer_name="qlinearmatmul_test") + model = helper.make_model(graph, producer_name="qlinearmul_test") from onnxruntime.quantization import quantize_static, CalibrationDataReader, QuantType class RandomDataReader(CalibrationDataReader): @@ -5366,9 +5367,9 @@ def get_next(self): model, input_values, opt_level=2, target=target, dev=dev, use_vm=True ) - verify_qlinearmatmul([2, 4], [4, 3], [2, 3]) - verify_qlinearmatmul([2, 2], [2, 1], [2, 1]) - verify_qlinearmatmul([5, 3], [3, 4], [5, 4]) + verify_qlinearmul([4, 2], [4, 2], [4, 2]) + verify_qlinearmul([4, 2], [2], [4, 2]) + verify_qlinearmul([5, 1, 7], [2, 7], [5, 2, 7]) @tvm.testing.parametrize_targets From 520d62f6c056b8e04de9ee9c1e5d38f3eea4a4d7 Mon Sep 17 00:00:00 2001 From: An Wang Date: Tue, 17 Aug 2021 16:41:27 -0700 Subject: [PATCH 4/4] refactor some common qlinear op test code --- tests/python/frontend/onnx/test_forward.py | 89 +++++++++------------- 1 file changed, 35 insertions(+), 54 deletions(-) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 83ef32dd1066..446b2473ec59 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -234,6 +234,39 @@ def verify_with_ort( ) +def quantize_and_verify_with_ort(onnx_model, input_names, input_shapes, target, dev): + from onnxruntime.quantization import quantize_static, CalibrationDataReader, QuantType + + input_arrays = [np.random.random(shape).astype("float32") for shape in input_shapes] + + class RandomDataReader(CalibrationDataReader): + def __init__(self, n=10): + input_dict = dict(zip(input_names, input_shapes)) + self.data = iter( + [ + { + name: np.random.random(shape).astype("float32") + for name, shape in input_dict.items() + } + for _ in range(n) + ] + ) + + def get_next(self): + return next(self.data, None) + + d = tvm.contrib.utils.tempdir() + model_fp32 = os.path.join(d.temp_dir, "model.onnx") + onnx.save_model(onnx_model, model_fp32) + model_quant = os.path.join(d.temp_dir, "model.quant.onnx") + quantized_model = quantize_static(model_fp32, model_quant, RandomDataReader()) + # opt_level=1 will cause error with 
+    model = onnx.load(model_quant)
+    verify_with_ort_with_inputs(
+        model, input_arrays, opt_level=2, target=target, dev=dev, use_vm=True
+    )
+
+
 def make_constant_node(name, data_type, dims, vals):
     return helper.make_node(
         "Constant",
@@ -5281,33 +5314,7 @@ def verify_qlinearadd(a_shape, b_shape, c_shape):
             outputs=[helper.make_tensor_value_info("C", TensorProto.FLOAT, list(c_shape))],
         )
         model = helper.make_model(graph, producer_name="qlinearadd_test")
-        from onnxruntime.quantization import quantize_static, CalibrationDataReader, QuantType
-
-        class RandomDataReader(CalibrationDataReader):
-            def __init__(self, n=10):
-                self.data = iter(
-                    [
-                        {
-                            "a": np.random.random(a_shape).astype("float32"),
-                            "b": np.random.random(b_shape).astype("float32"),
-                        }
-                        for _ in range(n)
-                    ]
-                )
-
-            def get_next(self):
-                return next(self.data, None)
-
-        d = tvm.contrib.utils.tempdir()
-        model_fp32 = os.path.join(d.temp_dir, "model.onnx")
-        onnx.save_model(model, model_fp32)
-        model_quant = os.path.join(d.temp_dir, "model.quant.onnx")
-        quantized_model = quantize_static(model_fp32, model_quant, RandomDataReader())
-        # opt_level=1 will cause error with qnn lowering
-        model = onnx.load(model_quant)
-        verify_with_ort_with_inputs(
-            model, input_values, opt_level=2, target=target, dev=dev, use_vm=True
-        )
+        quantize_and_verify_with_ort(model, input_names, [a_shape, b_shape], target, dev)
 
     verify_qlinearadd([4, 2], [4, 2], [4, 2])
     verify_qlinearadd([4, 2], [2], [4, 2])
@@ -5339,33 +5346,7 @@ def verify_qlinearmul(a_shape, b_shape, c_shape):
             outputs=[helper.make_tensor_value_info("C", TensorProto.FLOAT, list(c_shape))],
         )
         model = helper.make_model(graph, producer_name="qlinearmul_test")
-        from onnxruntime.quantization import quantize_static, CalibrationDataReader, QuantType
-
-        class RandomDataReader(CalibrationDataReader):
-            def __init__(self, n=10):
-                self.data = iter(
-                    [
-                        {
-                            "a": np.random.random(a_shape).astype("float32"),
-                            "b": np.random.random(b_shape).astype("float32"),
-                        }
-                        for _ in range(n)
-                    ]
-                )
-
-            def get_next(self):
-                return next(self.data, None)
-
-        d = tvm.contrib.utils.tempdir()
-        model_fp32 = os.path.join(d.temp_dir, "model.onnx")
-        onnx.save_model(model, model_fp32)
-        model_quant = os.path.join(d.temp_dir, "model.quant.onnx")
-        quantized_model = quantize_static(model_fp32, model_quant, RandomDataReader())
-        # opt_level=1 will cause error with qnn lowering
-        model = onnx.load(model_quant)
-        verify_with_ort_with_inputs(
-            model, input_values, opt_level=2, target=target, dev=dev, use_vm=True
-        )
+        quantize_and_verify_with_ort(model, input_names, [a_shape, b_shape], target, dev)
 
     verify_qlinearmul([4, 2], [4, 2], [4, 2])
     verify_qlinearmul([4, 2], [2], [4, 2])
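With the helper factored out in PATCH 4/4, a QLinear op test only builds the fp32 graph and delegates quantization and checking to quantize_and_verify_with_ort. A hypothetical example of a new test written against the helper (not part of the series; it reuses helper and TensorProto from the test file's existing imports):

    @tvm.testing.parametrize_targets
    def test_qlinear_helper_demo(target, dev):
        a_shape = b_shape = c_shape = [4, 2]
        input_nodes = [
            helper.make_tensor_value_info("a", TensorProto.FLOAT, a_shape),
            helper.make_tensor_value_info("b", TensorProto.FLOAT, b_shape),
        ]
        node = helper.make_node("Add", ["a", "b"], ["C"])
        graph = helper.make_graph(
            [node],
            "qlinear_helper_demo",
            inputs=input_nodes,
            outputs=[helper.make_tensor_value_info("C", TensorProto.FLOAT, c_shape)],
        )
        model = helper.make_model(graph, producer_name="qlinear_helper_demo")
        # The helper quantizes with ORT's quantize_static and compares TVM
        # against onnxruntime on random inputs.
        quantize_and_verify_with_ort(model, ["a", "b"], [a_shape, b_shape], target, dev)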