diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index e8f0fbffc0dc..7135fccdf43b 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -2973,6 +2973,8 @@ def _impl_v13(cls, inputs, attr, params):
         data, scale, zp = inputs
         out_dtype = infer_type(zp).checked_type.dtype
         axis = attr.get("axis", 1)
+        if len(infer_shape(data)) < 2:
+            axis = 0
         return _qnn.op.quantize(data, scale, _op.cast(zp, "int32"), axis, out_dtype)
 
 
@@ -3033,10 +3035,11 @@ def get_scalar(x, dtype="float32"):
         weight = inputs[3]
         w_scale = get_scalar(inputs[4])
         w_zero_point = get_scalar(inputs[5], "int32")
-        y_scale = get_scalar(inputs[6])
+        y_scale = fold_constant(get_scalar(inputs[6]))
         y_zero_point = get_scalar(inputs[7], "int32")
 
         input_shape = infer_shape(data)
+        ndim = len(input_shape)
         kernel_type = infer_type(weight)
         kernel_shapes = [get_const_tuple(kernel_type.checked_type.shape)]
 
@@ -3116,6 +3119,44 @@ def get_scalar(x, dtype="float32"):
         return out
 
 
+class QLinearAdd(OnnxOpConverter):
+    """Operator converter for QLinearAdd from Microsoft onnxruntime contrib opset."""
+
+    @classmethod
+    def _impl_v10(cls, inputs, attr, params):
+        def get_scalar(x, dtype="float32"):
+            if isinstance(x, _expr.Var) and x.name_hint in params:
+                return _op.const(params[x.name_hint].numpy(), dtype)
+            rank = len(infer_shape(x))
+            assert rank <= 1, "QLinearAdd scale and zero_point inputs must be scalars"
+            if rank == 1:
+                x = _op.squeeze(x, [0])
+            return _op.cast(x, dtype)
+
+        a = inputs[0]
+        a_scale = get_scalar(inputs[1])
+        a_zero_point = get_scalar(inputs[2], "int32")
+        b = inputs[3]
+        b_scale = get_scalar(inputs[4])
+        b_zero_point = get_scalar(inputs[5], "int32")
+        c_scale = get_scalar(inputs[6])
+        c_zero_point = get_scalar(inputs[7], "int32")
+
+        dtype = infer_type(a).checked_type.dtype
+
+        ## Onnxruntime doesn't actually do this op in integer, they dequantize to fp32
+        ## and then requantize after
+        ## https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/core/mlas/lib/qladd.cpp
+        a = _qnn.op.dequantize(
+            inputs[0], a_scale, a_zero_point
+        )
+        b = _qnn.op.dequantize(
+            inputs[3], b_scale, b_zero_point
+        )
+        out = _op.add(a, b)
+        return _qnn.op.quantize(out, c_scale, c_zero_point, out_dtype=dtype)
+
+
 class BitShift(OnnxOpConverter):
     """Operator converter for NonZero"""
 
@@ -3343,6 +3384,7 @@ def _get_convert_map(opset):
         "DynamicQuantizeLinear": DynamicQuantizeLinear.get_converter(opset),
         "ReverseSequence": ReverseSequence.get_converter(opset),
         "QLinearConv": QLinearConv.get_converter(opset),
+        "QLinearAdd": QLinearAdd.get_converter(opset),
     }
 
 
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index db71855fd80f..52c3346e5807 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -14,6 +14,7 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import os
 import re
 
 import numpy as np
@@ -120,7 +121,7 @@ def get_tvm_output(
 def get_onnxruntime_output(model, inputs):
     import onnxruntime.backend
 
-    rep = onnxruntime.backend.prepare(model, "CPU")
+    rep = onnxruntime.backend.prepare(model.SerializeToString(), "CPU")
     if isinstance(inputs, list) and len(inputs) == 1:
         inp = inputs[0]
     else:
@@ -149,6 +150,7 @@ def verify_with_ort_with_inputs(
 ):
     if opset is not None:
         model.opset_import[0].version = opset
+
     ort_out = get_onnxruntime_output(model, inputs)
 
     if targets is None:
@@ -4755,6 +4757,62 @@ def repeat(N, D):
     )
 
 
+def verify_qlinearadd(a_shape, b_shape, c_shape):
+
+    a_array = np.random.random(a_shape).astype("float32")
+    b_array = np.random.random(b_shape).astype("float32")
+
+    input_nodes = [
+        helper.make_tensor_value_info("a", TensorProto.FLOAT, list(a_shape)),
+        helper.make_tensor_value_info("b", TensorProto.FLOAT, list(b_shape)),
+    ]
+    input_names = [
+        "a",
+        "b",
+    ]
+    input_values = [a_array, b_array]
+
+    node = helper.make_node("Add", ["a", "b"], ["C"])
+    graph = helper.make_graph(
+        [node],
+        "qlinearadd_test",
+        inputs=input_nodes,
+        outputs=[helper.make_tensor_value_info("C", TensorProto.FLOAT, list(c_shape))],
+    )
+    model = helper.make_model(graph, producer_name="qlinearadd_test")
+    from onnxruntime.quantization import quantize_static, CalibrationDataReader, QuantType
+
+    class RandomDataReader(CalibrationDataReader):
+        def __init__(self, n=10):
+            self.data = iter(
+                [
+                    {
+                        "a": np.random.random(a_shape).astype("float32"),
+                        "b": np.random.random(b_shape).astype("float32"),
+                    }
+                    for _ in range(n)
+                ]
+            )
+
+        def get_next(self):
+            return next(self.data, None)
+
+    d = tvm.contrib.utils.tempdir()
+    model_fp32 = os.path.join(d.temp_dir, "model.onnx")
+    onnx.save_model(model, model_fp32)
+    model_quant = os.path.join(d.temp_dir, "model.quant.onnx")
+    quantize_static(model_fp32, model_quant, RandomDataReader())
+    # opt_level=1 will cause error with qnn lowering
+    model = onnx.load(model_quant)
+    verify_with_ort_with_inputs(model, input_values, opt_level=2)
+
+
+def test_qlinearadd():
+    verify_qlinearadd([4, 2], [4, 2], [4, 2])
+    verify_qlinearadd([4, 2], [2], [4, 2])
+    verify_qlinearadd([5, 1, 7], [2, 7], [5, 2, 7])
+
+
 if __name__ == "__main__":
     test_flatten()
     test_reshape()
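
Note (illustration only, not part of the patch): the QLinearAdd converter above follows onnxruntime's float fallback, i.e. it dequantizes both inputs, adds them in fp32, and requantizes the sum with the output scale and zero point. A minimal numpy sketch of that behavior, with made-up tensors and quantization parameters:

# Illustration only -- mimics dequantize -> fp32 add -> requantize.
import numpy as np

def dequantize(q, scale, zero_point):
    # integer tensor -> float32, per the usual affine quantization scheme
    return (q.astype("int32") - zero_point).astype("float32") * scale

def quantize(x, scale, zero_point, dtype="uint8"):
    # float32 -> integer tensor, rounding and clipping to the dtype range
    info = np.iinfo(np.dtype(dtype))
    q = np.round(x / scale) + zero_point
    return np.clip(q, info.min, info.max).astype(dtype)

# Made-up quantized inputs and (scale, zero_point) pairs.
a_q = np.array([10, 20, 30], dtype="uint8")
b_q = np.array([5, 15, 25], dtype="uint8")
a_scale, a_zp = 0.1, 0
b_scale, b_zp = 0.2, 0
c_scale, c_zp = 0.3, 0

# QLinearAdd(a, b) ~= quantize(dequantize(a) + dequantize(b), c_scale, c_zp)
c_q = quantize(dequantize(a_q, a_scale, a_zp) + dequantize(b_q, b_scale, b_zp), c_scale, c_zp)
print(c_q)  # [ 7 17 27]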