diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index a695e0002b34..63e3a8eb9fd7 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -209,6 +209,86 @@ def _dim_check(attrs):
     return _dim_check, "Only 1d, 2d and 3d kernel supported."
 
 
+def collect_quant_params(graph, quant_ops):
+    """Collect the quantization parameters for the quant ops that sit
+    between QuantizeLinear and DequantizeLinear nodes.
+    """
+    # initialize the params from initializers and Constant nodes
+    params = {}
+    for init in graph.initializer:
+        params[init.name] = get_numpy(init)
+    for node in graph.node:
+        if node.op_type == "Constant":
+            for attr in node.attribute:
+                if attr.name == "value" and attr.HasField("t"):
+                    params[node.output[0]] = get_numpy(attr.t)
+
+    # map each output to its producer and each input to its consumers for the DFS below
+    output_to_nodes = {}
+    input_to_nodes = {}
+    for node in graph.node:
+        for k in node.output:
+            output_to_nodes[k] = node
+        for idx, k in enumerate(node.input):
+            if k not in input_to_nodes:
+                input_to_nodes[k] = []
+            input_to_nodes[k].append([idx, node])
+
+    quant_params = {}
+    # set the quant params for the inputs of quant ops reached before a DequantizeLinear
+    def forward_dfs(node, input_idx, scale, zero_point):
+        if node.op_type == "DequantizeLinear":
+            return
+        if node.op_type in quant_ops:
+            quant_params[node.input[input_idx]] = (scale, zero_point)
+            return
+        for out in node.output:
+            if out in input_to_nodes:
+                for next_input_idx, next_node in input_to_nodes[out]:
+                    forward_dfs(next_node, next_input_idx, scale, zero_point)
+
+    # set the quant params for the outputs of quant ops reached after a QuantizeLinear
+    def backward_dfs(node, scale, zero_point):
+        if node.op_type == "QuantizeLinear":
+            return
+        if node.op_type in quant_ops:
+            for out in node.output:
+                quant_params[out] = (scale, zero_point)
+            return
+        for inp in node.input:
+            if inp in output_to_nodes:
+                previous_node = output_to_nodes[inp]
+                backward_dfs(previous_node, scale, zero_point)
+
+    for node in graph.node:
+        # pass quant params forward to the inputs of downstream quant ops
+        if node.op_type == "QuantizeLinear":
+            if node.input[1] in params and node.input[2] in params:
+                forward_dfs(node, 0, params[node.input[1]], params[node.input[2]])
+
+        # pass quant params backward to the outputs of upstream quant ops
+        if node.op_type == "DequantizeLinear":
+            if node.input[1] in params and node.input[2] in params:
+                backward_dfs(node, params[node.input[1]], params[node.input[2]])
+
+    return quant_params
+
+
+def is_quantized_expr(x):
+    ttype = infer_type(x).checked_type
+    return ttype.dtype in ["int8", "uint8"]
+
+
+def check_quant_params(attr, input_ids, output_ids):
+    for input_id in input_ids:
+        if attr["tvm_custom"]["input_quant_params"][input_id] is None:
+            return False
+    for output_id in output_ids:
+        if attr["tvm_custom"]["output_quant_params"][output_id] is None:
+            return False
+    return True
+
+
 class OnnxOpConverter(object):
     """A helper class for holding onnx op converters."""
 
@@ -264,6 +344,36 @@ def _impl_v1(cls, inputs, attr, params):
             # TODO(zhreshold): remove hard coded infershape
             axis = int(attr.get("axis", 0))
             inputs[1] = _op.expand_dims(inputs[1], axis=axis, num_newaxis=2)
+
+        if (
+            is_quantized_expr(inputs[0])
+            and is_quantized_expr(inputs[1])
+            and check_quant_params(attr, [0, 1], [0])
+        ):
+
+            lhs_scale, lhs_zero_point = attr["tvm_custom"]["input_quant_params"][0]
+            rhs_scale, rhs_zero_point = attr["tvm_custom"]["input_quant_params"][1]
+            out_scale, out_zero_point = attr["tvm_custom"]["output_quant_params"][0]
attr["tvm_custom"]["output_quant_params"][0] + + args = [ + inputs[0], + inputs[1], + _expr.const(lhs_scale), + _expr.const(int(lhs_zero_point)), + _expr.const(rhs_scale), + _expr.const(int(rhs_zero_point)), + _expr.const(out_scale), + _expr.const(int(out_zero_point)), + ] + + if op_name == "add": + return _qnn.op.add(*args) + if op_name == "multiply": + return _qnn.op.mul(*args) + if op_name == "subtract": + return _qnn.op.subtract(*args) + raise Exception("Unsupported op: {} for quantized input!".format(op_name)) + return get_relay_op(op_name)(*inputs) @@ -406,6 +516,7 @@ class Conv(OnnxOpConverter): @classmethod def _impl_v1(cls, inputs, attr, params): + # Use shape of input to determine convolution type. data = inputs[0] input_shape = infer_shape(data) @@ -436,16 +547,51 @@ def _impl_v1(cls, inputs, attr, params): msg = 'Value {} in attribute "auto_pad" of operator Conv is invalid.' raise tvm.error.OpAttributeInvalid(msg.format(attr["auto_pad"])) attr.pop("auto_pad") - out = AttrCvt( - op_name=dimension_picker("conv"), - transforms={ - "kernel_shape": "kernel_size", - "dilations": ("dilation", 1), - "pads": ("padding", 0), - "group": ("groups", 1), - }, - custom_check=dimension_constraint(), - )([data, inputs[1]], attr, params) + + if ( + ndim == 4 + and is_quantized_expr(data) + and is_quantized_expr(inputs[1]) + and check_quant_params(attr, [0, 1], [0]) + ): + + input_scale, input_zero_point = attr["tvm_custom"]["input_quant_params"][0] + kernel_scale, kernel_zero_point = attr["tvm_custom"]["input_quant_params"][1] + + weight_shape = infer_shape(inputs[1]) + kernel_size = (weight_shape[2], weight_shape[3]) + out_channels = weight_shape[0] + + padding = attr["pads"] if "pads" in attr else 0 + dilation = attr["dilations"] if "dilations" in attr else 1 + groups = attr["group"] if "group" in attr else 1 + strides = attr["strides"] + + out = _qnn.op.conv2d( + data, + inputs[1], + _expr.const(int(input_zero_point)), + _expr.const(int(kernel_zero_point)), + _expr.const(input_scale), + _expr.const(kernel_scale), + kernel_size=kernel_size, + dilation=dilation, + strides=strides, + padding=padding, + groups=groups, + channels=out_channels, + ) + else: + out = AttrCvt( + op_name=dimension_picker("conv"), + transforms={ + "kernel_shape": "kernel_size", + "dilations": ("dilation", 1), + "pads": ("padding", 0), + "group": ("groups", 1), + }, + custom_check=dimension_constraint(), + )([data, inputs[1]], attr, params) use_bias = len(inputs) == 3 if use_bias: @@ -587,9 +733,43 @@ def _impl_v1(cls, inputs, attr, params): inputs[0] = _op.nn.batch_flatten(inputs[0]) if alpha != 1.0: inputs[0] *= _expr.const(alpha) - out = _op.nn.dense(inputs[0], inputs[1], units=channels) - if len(inputs) == 3: - out = out + _expr.const(beta) * inputs[2] + + if ( + is_quantized_expr(inputs[0]) + and is_quantized_expr(inputs[1]) + and check_quant_params(attr, [0, 1], [0]) + ): + + input_scale, input_zero_point = attr["tvm_custom"]["input_quant_params"][0] + kernel_scale, kernel_zero_point = attr["tvm_custom"]["input_quant_params"][1] + + out = _qnn.op.dense( + inputs[0], + inputs[1], + _expr.const(int(input_zero_point)), + _expr.const(int(kernel_zero_point)), + _expr.const(input_scale), + _expr.const(kernel_scale), + units=channels, + ) + if len(inputs) == 3: + output_scale, output_zero_point = attr["tvm_custom"]["output_quant_params"][0] + bias = _qnn.op.dequantize( + inputs[2], _expr.const(output_scale), _expr.const(output_zero_point), axis=0 + ) + bias = _expr.const(beta) * bias + bias = _qnn.op.quantize( + bias, + 
+                    _expr.const(output_scale),
+                    _expr.const(0, "int32"),
+                    out_dtype="int32",
+                    axis=0,
+                )
+                out = out + bias
+        else:
+            out = _op.nn.dense(inputs[0], inputs[1], units=channels)
+            if len(inputs) == 3:
+                out = out + _expr.const(beta) * inputs[2]
 
         return out
 
@@ -2937,6 +3117,7 @@ def __init__(self, shape, dtype, freeze_params=False):
         self._dtype = dtype
         self.opset = None
         self._freeze_params = freeze_params
+        self._quant_params = {}
 
     def __enter__(self):
         self._old_manager = GraphProto.current
@@ -3054,6 +3235,7 @@ def from_onnx(self, graph, opset, get_output_expr=False):
             msg = "The following operators are not supported for frontend ONNX: "
             msg += ", ".join(unsupported_ops)
             raise tvm.error.OpNotImplemented(msg)
+        self._quant_params = collect_quant_params(graph, ["Conv", "Gemm", "Add", "Mul", "Sub"])
         # construct nodes, nodes are stored as directed acyclic graph
         for node in graph.node:
            op_name = node.op_type
@@ -3067,9 +3249,16 @@ def from_onnx(self, graph, opset, get_output_expr=False):
                    inputs[i] = None
            i_name = self._parse_value_proto(node)
            node_output = self._fix_outputs(op_name, node.output)
+
            attr["tvm_custom"] = {}
            attr["tvm_custom"]["name"] = i_name
            attr["tvm_custom"]["num_outputs"] = len(node_output)
+            attr["tvm_custom"]["input_quant_params"] = [
+                self._quant_params[x] if x in self._quant_params else None for x in node.input
+            ]
+            attr["tvm_custom"]["output_quant_params"] = [
+                self._quant_params[x] if x in self._quant_params else None for x in node.output
+            ]
 
            op = self._convert_operator(op_name, inputs, attr, opset)
            if not isinstance(op, _expr.TupleWrapper):
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 1a3d0d4ac6e0..5addb9d1cf62 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -4351,6 +4351,181 @@ def verify_embedding_bag(num_embedding, embedding_dim, data_shape, num_bags=None
     verify_embedding_bag(32, 2, [3, 3])
 
 
+def check_contain_qnn_op(graph_def, input_data, qnn_op_name):
+    class QnnOpChecker(relay.expr_functor.ExprVisitor):
+        def __init__(self, qnn_op_name):
+            relay.expr_functor.ExprVisitor.__init__(self)
+            self.valid = False
+            self.qnn_op_name = qnn_op_name
+
+        def visit_call(self, call):
+            if hasattr(call.op, "name") and call.op.name == self.qnn_op_name:
+                self.valid = True
+            super().visit_call(call)
+
+    input_names, shape_dict = get_input_data_shape_dict(graph_def, input_data)
+    mod, params = relay.frontend.from_onnx(graph_def, shape_dict)
+
+    checker = QnnOpChecker(qnn_op_name)
+    checker.visit(mod["main"])
+    return checker.valid
+
+
+def get_quantize_node(input, output, zp_dtype, dequantize=False):
+    scale_name = output + "_scale"
+    zp_name = output + "_zp"
+
+    scale_node = make_constant_node(scale_name, TensorProto.FLOAT, (), [np.random.rand()])
+    if zp_dtype == TensorProto.INT8:
+        zp_node = make_constant_node(zp_name, TensorProto.INT8, (), [np.random.randint(-128, 127)])
+    else:
+        zp_node = make_constant_node(zp_name, zp_dtype, (), [np.random.randint(0, 255)])
+
+    quant_node = helper.make_node(
+        "DequantizeLinear" if dequantize is True else "QuantizeLinear",
+        inputs=[input, scale_name, zp_name],
+        outputs=[output],
+        axis=0,
+    )
+    return [scale_node, zp_node, quant_node]
+
+
+def test_quantized_conv2d():
+    def verify_quantized_conv2d(
+        x_shape, w_shape, y_shape, padding, kernel_shape, strides, dilations, zp_dtype
+    ):
+        nodes = []
+        nodes.extend(get_quantize_node("x", "quant_x", zp_dtype))
+        nodes.extend(get_quantize_node("W", "quant_W", zp_dtype))
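+        # the Conv node below consumes the quantized input and weight produced above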
+        nodes.append(
+            helper.make_node(
+                "Conv",
+                inputs=["quant_x", "quant_W"],
+                outputs=["conv_out"],
+                kernel_shape=kernel_shape,
+                strides=strides,
+                dilations=dilations,
+                pads=padding,
+            )
+        )
+        nodes.extend(get_quantize_node("conv_out", "y", TensorProto.INT32, dequantize=True))
+
+        graph = helper.make_graph(
+            nodes,
+            "quantized_conv2d_test",
+            inputs=[
+                helper.make_tensor_value_info("x", TensorProto.FLOAT, list(x_shape)),
+                helper.make_tensor_value_info("W", TensorProto.FLOAT, list(w_shape)),
+            ],
+            outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(y_shape))],
+        )
+        model = helper.make_model(graph, producer_name="quantized_conv2d_test")
+
+        x_array = np.random.uniform(size=x_shape).astype("float32")
+        w_array = np.random.uniform(size=w_shape).astype("float32")
+
+        # onnxruntime cannot run Conv with int8 inputs, so we only check that the
+        # expression contains the expected qnn op, then run the module in TVM
+        assert check_contain_qnn_op(model, [x_array, w_array], "qnn.conv2d")
+        get_tvm_output_with_vm(model, [x_array, w_array], target="llvm", device=tvm.cpu(0))
+
+    verify_quantized_conv2d(
+        (1, 1, 5, 5), (1, 1, 3, 3), (1, 1, 5, 5), (2, 2), (3, 3), (1, 1), (1, 1), TensorProto.INT8
+    )
+    verify_quantized_conv2d(
+        (1, 1, 5, 5), (1, 1, 3, 3), (1, 1, 3, 3), (0, 0), (3, 3), (1, 1), (1, 1), TensorProto.UINT8
+    )
+
+
+def test_quantized_gemm():
+    def verify_quantized_gemm(a_shape, b_shape, c_shape, zp_dtype):
+        out_shape = [a_shape[0], b_shape[1]]
+        a_array = np.random.uniform(size=a_shape).astype("float32")
+        b_array = np.random.uniform(size=b_shape).astype("float32")
+
+        nodes = []
+        input_names = ["quant_a", "quant_b"]
+        input_values = [a_array, b_array]
+        input_nodes = [
+            helper.make_tensor_value_info("a", TensorProto.FLOAT, list(a_shape)),
+            helper.make_tensor_value_info("b", TensorProto.FLOAT, list(b_shape)),
+        ]
+        nodes.extend(get_quantize_node("a", "quant_a", zp_dtype))
+        nodes.extend(get_quantize_node("b", "quant_b", zp_dtype))
+
+        if c_shape is not None:
+            c_array = np.random.uniform(size=c_shape).astype("float32")
+            input_names.append("quant_c")
+            input_nodes.append(helper.make_tensor_value_info("c", TensorProto.FLOAT, list(c_shape)))
+            input_values.append(c_array)
+            nodes.extend(get_quantize_node("c", "quant_c", TensorProto.INT32))
+
+        nodes.append(helper.make_node("Gemm", input_names, ["gemm_out"]))
+        nodes.extend(get_quantize_node("gemm_out", "out", TensorProto.INT32, dequantize=True))
+
+        graph = helper.make_graph(
+            nodes,
+            "quantized_gemm_test",
+            inputs=input_nodes,
+            outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))],
+        )
+        model = helper.make_model(graph, producer_name="quantized_gemm_test")
+
+        # onnxruntime cannot run Gemm with int8 inputs, so we only check that the
+        # expression contains the expected qnn op, then run the module in TVM
+        assert check_contain_qnn_op(model, input_values, "qnn.dense")
+        get_tvm_output_with_vm(model, input_values, target="llvm", device=tvm.cpu(0))
+
+    verify_quantized_gemm(a_shape=(4, 8), b_shape=(8, 4), c_shape=None, zp_dtype=TensorProto.INT8)
+    verify_quantized_gemm(a_shape=(4, 8), b_shape=(8, 4), c_shape=(4,), zp_dtype=TensorProto.UINT8)
+
+
+def test_quantized_binary_ops():
+    in_shape = (1, 2, 3, 3)
+    out_shape = in_shape
+    dtype = "float32"
+
+    def verify_binary_op(op, x, y, zp_dtype):
+
+        nodes = []
+        nodes.extend(get_quantize_node("in1", "quant_in1", zp_dtype))
+        nodes.extend(get_quantize_node("in2", "quant_in2", zp_dtype))
+        nodes.append(helper.make_node(op, ["quant_in1", "quant_in2"], ["binary_out"]))
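+        # dequantize the binary op result back to float for the graph output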
nodes.extend(get_quantize_node("binary_out", "out", TensorProto.INT32, dequantize=True)) + + graph = helper.make_graph( + nodes, + "_test", + inputs=[ + helper.make_tensor_value_info("in1", TensorProto.FLOAT, x.shape), + helper.make_tensor_value_info("in2", TensorProto.FLOAT, y.shape), + ], + outputs=[ + helper.make_tensor_value_info( + "out", mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)], list(out_shape) + ) + ], + ) + model = helper.make_model(graph, producer_name="_test") + + # onnxruntime can not run add, mul, substract with int8 as input, + # so we only check if the expr contain qnn op, and run the mod in tvm + qnn_op_maps = {"Add": "qnn.add", "Mul": "qnn.mul", "Sub": "qnn.subtract"} + assert check_contain_qnn_op(model, [x, y], qnn_op_maps[op]) == True + get_tvm_output_with_vm(model, [x, y], target="llvm", device=tvm.cpu(0)) + + x = np.random.uniform(size=in_shape).astype(dtype) + y = np.random.uniform(size=in_shape).astype(dtype) + z = np.random.uniform(size=(3,)).astype(dtype) + + verify_binary_op("Add", x, y, TensorProto.INT8) + verify_binary_op("Add", x, z, TensorProto.UINT8) + verify_binary_op("Sub", x, y, TensorProto.UINT8) + verify_binary_op("Sub", x, z, TensorProto.INT8) + verify_binary_op("Mul", x, y, TensorProto.INT8) + verify_binary_op("Mul", x, z, TensorProto.UINT8) + + if __name__ == "__main__": test_flatten() test_reshape() @@ -4431,3 +4606,6 @@ def verify_embedding_bag(num_embedding, embedding_dim, data_shape, num_bags=None test_cumsum() test_wrong_input() test_aten() + test_quantized_conv2d() + test_quantized_gemm() + test_quantized_binary_ops()