From 29e864429a176a2533ec15a60ecb961c5b47cb03 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 1 Mar 2020 21:00:45 +0900 Subject: [PATCH 01/14] qnn support initial import --- python/tvm/relay/frontend/pytorch.py | 75 +++- python/tvm/relay/frontend/qnn_torch.py | 566 +++++++++++++++++++++++++ 2 files changed, 626 insertions(+), 15 deletions(-) create mode 100644 python/tvm/relay/frontend/qnn_torch.py diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index b256faa5d6f9..e38bed1385ad 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -32,6 +32,8 @@ from .common import infer_shape as _infer_shape from .common import infer_value as _infer_value +import qnn_torch + __all__ = ["from_pytorch"] # operator implementation @@ -146,6 +148,10 @@ def _impl(inputs, input_types): def _relu(): def _impl(inputs, input_types): data = inputs[0] + if input_types[0] == "quint8": + assert len(inputs) == 3, "Input quant param not found in op inputs" + input_zero_point = _expr.const(inputs[2]) + return qnn_torch.quantized_relu(data, input_zero_point) return _op.nn.relu(data) return _impl @@ -154,9 +160,14 @@ def _impl(inputs, input_types): data = inputs[0] output_size = _infer_shape(inputs[1]) - return _op.nn.adaptive_avg_pool2d( - data, - output_size=output_size) + def func(x): + return _op.nn.adaptive_avg_pool2d(x, output_size=output_size) + + if input_types[0] == "quint8": + return qnn_torch.quantized_adaptive_avg_2d(data, func) + + return func(data) + return _impl def _adaptive_max_2d(): @@ -503,7 +514,18 @@ def _impl(inputs, input_types): else: exclude = False - return _op.mean(data, axis, keepdims, exclude) + def func(x): + return _op.mean(x, axis, keepdims, exclude) + + if input_types[0] == "quint8": + assert len(inputs) == 6, "Input quant param not found in op inputs" + input_scale = _expr.const(inputs[4]) + input_zero_point = _expr.const(inputs[5]) + return qnn_torch.quantized_mean(data, input_scale, + input_zero_point, func) + + return func(data) + return _impl def _chunk(): @@ -665,7 +687,16 @@ def _impl(inputs, input_types): else: coord_trans = "half_pixel" - return _op.image.resize(data, out_size, "NCHW", method, coord_trans) + def func(x): + return _op.image.resize(x, out_size, "NCHW", "bilinear", coord_trans) + + if input_types[0] == "quint8": + assert len(inputs) == 7, "Input quant param not found in op inputs" + input_scale = _expr.const(inputs[-2]) + input_zero_point = _expr.const(inputs[-1]) + return qnn_torch.quantized_upsample(data, input_scale, + input_zero_point, func) + return func(data) return _impl @@ -932,7 +963,7 @@ def _get_operator_nodes(nodes): return ops -def parse_inputs(graph_inputs, input_shapes): +def convert_inputs(graph_inputs, input_shapes): """ Return Relay vars from torch input vars """ ir_inputs = list(graph_inputs) input_vars = {} @@ -983,7 +1014,7 @@ def terminate(users): return get_use_chains(root_getattr_node, terminate) -def parse_params(graph, state_dict): +def convert_params(graph, state_dict): """ Return Relay vars and TVM NDArrays for input parameters A chain of prim::GetAttr nodes is processed one at a time @@ -991,6 +1022,7 @@ def parse_params(graph, state_dict): getattr_nodes = graph.findAllNodes("prim::GetAttr", recurse=True) params = {} param_tensors = {} + packed_param_map = {} seen = set() for node in getattr_nodes: @@ -1003,17 +1035,20 @@ def parse_params(graph, state_dict): full_attr = _getattr_full_name(getattrs) full_attr_node_name = 
_get_output_name(getattrs[-1]) - if full_attr in state_dict: + if full_attr.endswith("_packed_params"): # for quantized models + assert full_attr in state_dict + packed_param_map[full_attr_node_name] = full_attr + elif full_attr in state_dict: torch_tensor = state_dict[full_attr] tensor, var = _get_tensor_and_var(torch_tensor, full_attr_node_name) param_tensors[full_attr_node_name] = tensor params[full_attr_node_name] = var - return params, param_tensors + return params, param_tensors, packed_param_map -def parse_operators(operators, outputs, output_index_map, ret_name): +def convert_operators(operators, outputs, output_index_map, ret_name): """ Convert each Torch IR operators to Relay equivalent """ for node_name, op_node in operators.items(): operator = op_node.kind() @@ -1089,17 +1124,27 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None): _report_missing_conversion(op_names) params = script_module.state_dict() - input_vars = parse_inputs(graph.inputs(), input_shapes) - param_vars, tensors = parse_params(graph, params) + input_vars = convert_inputs(graph.inputs(), input_shapes) + param_vars, tensors, packed_param_map = convert_params(graph, params) + tvm_params = {k: tvm.nd.array(v) for k, v in tensors.items()} input_vars.update(param_vars) outputs = list(input_vars.values()) output_index_map = dict(zip(input_vars.keys(), range(len(outputs)))) ret_name = _get_input_names(graph.return_node())[0] - body = parse_operators(_get_operator_nodes(graph.nodes()), outputs, - output_index_map, ret_name) + # For quantized models + if "aten::quantize_per_tensor" in op_names: + weight_quant_params = qnn_torch.get_weight_quant_params(script_module) + qnn_torch.add_input_quant_params_to_op_inputs(graph) + qnn_torch.add_quant_params_to_outputs(outputs, output_index_map, + packed_param_map, + weight_quant_params) + qnn_torch.add_quant_params(tvm_params, weight_quant_params) + _convert_map.update(qnn_torch.convert_map) + + body = convert_operators(_get_operator_nodes(graph.nodes()), outputs, + output_index_map, ret_name) func = tvm.relay.Function(_analysis.free_vars(body), body) - tvm_params = {k: tvm.nd.array(v) for k, v in tensors.items()} return _module.IRModule.from_expr(func), tvm_params diff --git a/python/tvm/relay/frontend/qnn_torch.py b/python/tvm/relay/frontend/qnn_torch.py new file mode 100644 index 000000000000..09d0e3841bb6 --- /dev/null +++ b/python/tvm/relay/frontend/qnn_torch.py @@ -0,0 +1,566 @@ +import torch +import tvm +import numpy as np +from tvm import relay +from tvm.relay import expr as _expr +from tvm.relay import op as _op +from tvm.relay.frontend.common import infer_shape + + +class QuantParam: + def __init__(self, weight, bias, scale, zero_point, param_key): + param_prefix = param_key[:-len("._packed_params")] + self.weight_var = _expr.var(param_prefix + "_weight", + shape=weight.shape) + self.weight = weight + + if bias is not None: + self.bias_var = _expr.var(param_prefix + "_bias", + shape=bias.shape) + self.bias = bias.detach().numpy() + else: + self.bias_var = None + self.bias = None + + self.scale = _expr.const(scale) + self.zero_point = _expr.const(zero_point, dtype="int32") + + +def unpack_quant_params(param_name, packed_params, unpack_func): + qweight, bias = unpack_func(packed_params) + weight_np = qweight.dequantize().numpy() + + if qweight.qscheme() == torch.per_tensor_affine: + param = QuantParam(weight_np, bias, qweight.q_scale(), + int(qweight.q_zero_point()), param_name) + else: + scales = qweight.q_per_channel_scales().numpy() + 
zero_points = qweight.q_per_channel_zero_points().numpy() + assert np.all(zero_points == 0) + param = QuantParam(weight_np, bias, scales, 0, param_name) + + return param + + +def get_weight_quant_params(script_module): + conv_packed_params = [] + linear_packed_params = [] + + for name, m in script_module.named_modules(): + if isinstance(m, torch.jit.RecursiveScriptModule): + if "Conv" in m.original_name: + conv_packed_params.append((name, m.state_dict())) + elif m.original_name == "LinearPackedParams": + linear_packed_params.append((name, m.state_dict())) + + pairs = [(torch.ops.quantized.conv2d_unpack, conv_packed_params), + (torch.ops.quantized.linear_unpack, linear_packed_params)] + + quant_params = {} + param_name = "_packed_params" + for unpack_func, params in pairs: + for name, state_dict in params: + assert len(state_dict) == 1 + assert param_name in state_dict + key = name + "." + param_name + packed_param = state_dict[param_name] + quant_params[key] = unpack_quant_params(key, packed_param, + unpack_func) + + return quant_params + + +def add_quant_params_to_outputs(outputs, output_index_map, + packed_param_map, quant_params): + for node_name, packed_param_name in packed_param_map.items(): + qparam = quant_params[packed_param_name] + output_index_map[node_name] = len(outputs) + qweight = relay.qnn.op.quantize(qparam.weight_var, qparam.scale, + qparam.zero_point, out_dtype="int8", + axis=0) + param_tup = (qweight, qparam.scale, qparam.zero_point, qparam.bias_var) + outputs.append(param_tup) + + +def get_quant_param_for_input(input_value): + output_quant_param_indices = { + "aten::quantize_per_tensor": (1, 2), + "quantized::conv2d": (6, 7), + "quantized::conv2d_relu": (6, 7), + "quantized::linear": (2, 3), + "quantized::linear_relu": (2, 3), + "quantized::add_relu": (2, 3), + "quantized::add": (2, 3), + "quantized::mul_relu": (2, 3), + "quantized::mul": (2, 3), + "quantized::cat": (2, 3), + "quantized::mul_scalar": (2, 3), + "quantized::add_scalar": (2, 3) + } + + def dfs(current_node): + # trace back to find the producer of this input value + current_op = current_node.kind() + if current_op in output_quant_param_indices: + indices = output_quant_param_indices[current_op] + scale = current_node.inputsAt(indices[0]) + zp = current_node.inputsAt(indices[1]) + return scale, zp + else: + # Assume quantized tensor comes earlier in the args + for arg in current_node.inputs(): + return dfs(arg.node()) + + assert False, "No producer for %s" % (str(current_node)) + + return dfs(input_value.node()) + + +def get_add_scalar_output_quant_param(input_scale, input_zero_point, + scalar): + # refer to aten/src/ATen/native/quantized/cpu/qadd.cpp + q_min = 0 + q_max = 255 + s = input_scale + z = input_zero_point + c = scalar + c_q = round(c / s) + + if q_min > z - c_q: + s_prime = (float(q_max) - (z - c_q)) / (float(q_max) - q_min) * s + z_prime = q_min + elif q_max < z - c_q: + s_prime = (float(z - c_q) - q_min) / (float(q_max) - q_min) * s + z_prime = q_max + else: + s_prime = s + z_prime = z - c_q + + return s_prime, z_prime + + +def get_mul_scalar_output_quant_param(input_scale, input_zero_point, + scalar): + # refer to aten/src/ATen/native/quantized/cpu/qmul.cpp + q_min = 0 + q_max = 255 + self_scale = input_scale + self_zero_point = input_zero_point + other_val = scalar + + if other_val > 0.0: + s_prime = other_val * self_scale + z_prime = self_zero_point + elif other_val == 0.0: + s_prime = 1.0 + z_prime = 0 + else: + s_prime = abs(other_val) * self_scale + z_prime = q_max - (self_zero_point - 
q_min) + + return s_prime, z_prime + + +def add_output_quant_params_to_scalar_op(node, graph, + input_scale, input_zero_point, + scalar): + operator = node.kind() + + if operator == "quantized::mul_scalar": + out_scale, out_zero_point = \ + get_mul_scalar_output_quant_param(input_scale, input_zero_point, + scalar) + elif operator == "quantized::add_scalar": + out_scale, out_zero_point = \ + get_add_scalar_output_quant_param(input_scale, input_zero_point, + scalar) + else: + assert False, "unsupported scalar op: %s" % operator + + # create new constant nodes and add them to graph + out_scale_node = graph.create("prim::Constant") + out_zero_point_node = graph.create("prim::Constant") + out_scale_node.insertBefore(node) + out_zero_point_node.insertBefore(node) + out_scale_node.f_("value", out_scale) + out_zero_point_node.i_("value", out_zero_point) + out_scale_node.output().setType(torch._C.FloatType.get()) + out_zero_point_node.output().setType(torch._C.IntType.get()) + node.addInput(out_scale_node.output()) + node.addInput(out_zero_point_node.output()) + + +def add_input_quant_params_to_op_inputs(graph): + # Quantized operators in PyTorch do not take input quant params as + # arguments. But QNN expects them to be passed in as arguements. + # To simplify the translation of inputs, we add input quant params + # to inputs of PyTorch quantized operator nodes. See _impl in + # _quantized_conv2d() below for example of why this is helpful. + num_quantized_inputs = {"quantized::conv2d": 1, + "quantized::conv2d_relu": 1, + "quantized::linear": 1, + "quantized::linear_relu": 1, + "quantized::add_relu": 2, + "quantized::add": 2, + "quantized::mul_relu": 2, + "quantized::mul": 2, + "aten::dequantize": 1, + "aten::mean": 1, + "aten::upsample_bilinear2d": 1, + "aten::relu_": 1, + "aten::relu": 1, + "quantized::add_scalar": 1, + "quantized::mul_scalar": 1, + 'quantized::relu6': 1} + + need_input_quant_param = set(num_quantized_inputs.keys()) + need_input_quant_param.add("quantized::cat") + + for node in graph.nodes(): + operator = node.kind() + if operator not in need_input_quant_param: + continue + + input_scales = [] + input_zero_points = [] + + if operator == "quantized::cat": + inputs = node.inputsAt(0).node().inputs() + for inp in inputs: + scale, zp = get_quant_param_for_input(inp) + input_scales.append(scale) + input_zero_points.append(zp) + else: + for i in range(num_quantized_inputs[operator]): + scale, zp = get_quant_param_for_input(node.inputsAt(i)) + input_scales.append(scale) + input_zero_points.append(zp) + + if operator in ["quantized::add_scalar", "quantized::mul_scalar"]: + scalar = node.inputsAt(1).node().f("value") + inp_scale = input_scales[0].node().f("value") + inp_zero_point = input_zero_points[0].node().i("value") + + add_output_quant_params_to_scalar_op(node, graph, + inp_scale, inp_zero_point, + scalar) + + for scale, zp in zip(input_scales, input_zero_points): + node.addInput(scale) + node.addInput(zp) + + +def add_quant_params(params, quant_params): + for qparam in quant_params.values(): + params[qparam.weight_var.name_hint] = tvm.nd.array(qparam.weight) + if qparam.bias is not None: + params[qparam.bias_var.name_hint] = tvm.nd.array(qparam.bias) + + +def quantized_adaptive_avg_2d(data, func): + inp = _op.cast(data, dtype="int32") + out = func(inp) + return _op.cast(out, "uint8") + + +def quantized_mean(data, input_scale, input_zero_point, func): + dequantized = relay.qnn.op.dequantize(data, input_scale, input_zero_point) + out = func(dequantized) + return 
relay.qnn.op.quantize(out, input_scale, input_zero_point, + out_dtype="uint8", axis=1) + + +def quantized_upsample(data, input_scale, input_zero_point, func): + data = relay.qnn.op.dequantize(data, input_scale, input_zero_point) + out = func(data) + return relay.qnn.op.quantize(out, input_scale, input_zero_point, + out_dtype="uint8", axis=1) + + +def quantized_relu(data, input_zero_point): + zp = _op.cast(input_zero_point, dtype="uint8") + return _op.tensor.maximum(data, zp) + + +def _quantize_per_tensor(): + def _impl(inputs, input_type): + return relay.qnn.op.quantize(inputs[0], _expr.const(inputs[1]), + _expr.const(inputs[2]), out_dtype="uint8", + axis=1) + return _impl + + +def _dequantize(): + def _impl(inputs, input_type): + inp_scale = _expr.const(inputs[1]) + inp_zero_point = _expr.const(inputs[2]) + return relay.qnn.op.dequantize(inputs[0], inp_scale, inp_zero_point) + return _impl + + +def get_numpy(relay_const_scalar): + return relay_const_scalar.data.asnumpy() + + +def get_scalar(relay_const_scalar): + return np.asscalar(get_numpy(relay_const_scalar)) + + +def _quantized_conv2d(with_relu=False): + def _impl(inputs, input_type): + # refer to src/ATen/native/quantized/cpu/qconv.cpp + # inputs[0]: input tensor + # inputs[1]: (weight, scale, zero_point, bias) + # inputs[2-5]: stride, padding, dilation, groups + # inputs[6]: output_scale + # inputs[7]: output_zero_point + # inputs[8]: input_scale (added manually by frontend) + # inputs[9]: input_zero_point (added manually by frontend) + weight = inputs[1][0] + weight_scale = inputs[1][1] + weight_zero_point = inputs[1][2] + + output_scale = _expr.const(inputs[6]) + output_zero_point = _expr.const(inputs[7]) + + assert len(inputs) == 10, "Input quant params not found in op inputs" + input_scale = _expr.const(inputs[8]) + input_zero_point = _expr.const(inputs[9]) + + strides, padding, dilation = inputs[2], inputs[3], inputs[4] + strides = infer_shape(inputs[2]) + padding = infer_shape(inputs[3]) + dilation = infer_shape(inputs[4]) + groups = inputs[5] + + weight_shape = infer_shape(weight) + kernel_size = (weight_shape[2], weight_shape[3]) + out_channels = weight_shape[0] + + if padding[0] != 0 or padding[1] != 0: + pad_val = get_scalar(input_zero_point) + inp = _op.nn.pad(inputs[0], pad_width=((0, 0), + (0, 0), + (padding[0], padding[0]), + (padding[1], padding[1])), + pad_value=float(pad_val)) + else: + inp = inputs[0] + + conv_out = relay.qnn.op.conv2d(inp, weight, + input_zero_point, weight_zero_point, + input_scale, weight_scale, + kernel_size=kernel_size, + dilation=dilation, strides=strides, + padding=(0, 0), groups=groups, + channels=out_channels) + + # input scale * weight scale + requant_input_scale = _expr.const(inputs[8] * get_numpy(weight_scale)) + bias_var = inputs[1][3] + + if bias_var is not None: + qbias = relay.qnn.op.quantize(bias_var, requant_input_scale, + _expr.const(0, "int32"), + out_dtype="int32", axis=0) + conv_res = _op.nn.bias_add(conv_out, qbias) + else: + conv_res = conv_out + + requantized = relay.qnn.op.requantize(conv_res, + requant_input_scale, + _expr.const(0, "int32"), + output_scale, output_zero_point, + out_dtype="int32", axis=1) + clip_min = 0 + if with_relu: + clip_min = get_scalar(output_zero_point) + + clip = _op.tensor.clip(requantized, clip_min, 255.) 
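+        # the clip above keeps the requantized int32 result within the uint8
+        # range; with a fused ReLU the lower bound is the output zero point,
+        # since quantized values below it dequantize to negative numbers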
+ return _op.cast(clip, dtype="uint8") + + return _impl + + +def _binop(relay_op, with_relu=False): + def _impl(inputs, input_type): + output_scale = _expr.const(inputs[2]) + output_zero_point = _expr.const(inputs[3]) + assert len(inputs) == 8, "Input quant params not found in op inputs" + input_scale_lhs = _expr.const(inputs[4]) + input_zero_point_lhs = _expr.const(inputs[5]) + input_scale_rhs = _expr.const(inputs[6]) + input_zero_point_rhs = _expr.const(inputs[7]) + lhs = inputs[0] + rhs = inputs[1] + + if isinstance(lhs, _expr.Call) and lhs.op.name == 'qnn.quantize': + lhs = lhs.args[0] + else: + lhs = relay.qnn.op.dequantize(lhs, + input_scale_lhs, + input_zero_point_lhs) + + if isinstance(rhs, _expr.Call) and rhs.op.name == 'qnn.quantize': + rhs = rhs.args[0] + else: + rhs = relay.qnn.op.dequantize(rhs, + input_scale_rhs, + input_zero_point_rhs) + fp32_out = relay_op(lhs, rhs) + + if with_relu: + fp32_out = _op.nn.relu(fp32_out) + + return relay.qnn.op.quantize(fp32_out, + output_scale, + output_zero_point, + axis=-1, + out_dtype="uint8") + return _impl + + +def _linear(with_relu=False): + def _impl(inputs, input_type): + weight = inputs[1][0] + weight_scale = inputs[1][1] + weight_zero_point = inputs[1][2] + output_scale = _expr.const(inputs[2]) + output_zero_point = _expr.const(inputs[3]) + assert len(inputs) == 6, "Input quant params not found in op inputs" + input_scale = _expr.const(inputs[4]) + input_zero_point = _expr.const(inputs[5]) + + weight_shape = infer_shape(weight) + dense = relay.qnn.op.dense(inputs[0], weight, + input_zero_point, weight_zero_point, + input_scale, weight_scale, + units=weight_shape[0]) + + requant_input_scale = _expr.const(inputs[4] * get_numpy(weight_scale)) + bias_var = inputs[1][3] + + if bias_var is not None: + qbias = relay.qnn.op.quantize(bias_var, requant_input_scale, + _expr.const(0, "int32"), + out_dtype="int32", axis=0) + dense_res = _op.nn.bias_add(dense, qbias) + else: + dense_res = dense + + requantized = relay.qnn.op.requantize(dense_res, + requant_input_scale, + relay.const(0, 'int32'), + output_scale, output_zero_point, + out_dtype="int32", axis=1) + clip_min = 0 + if with_relu: + clip_min = get_scalar(output_zero_point) + + clip = _op.tensor.clip(requantized, clip_min, 255.) 
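+        # same epilogue as the quantized conv2d conversion: clip to the uint8
+        # range (lower bound is the output zero point when ReLU is fused)
+        # before casting back to uint8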
+ return _op.cast(clip, dtype="uint8") + + return _impl + + +def _cat(): + def _impl(inputs, input_type): + axis = inputs[1] + output_scale = _expr.const(inputs[2]) + output_zero_point = _expr.const(inputs[3]) + num_inputs = (len(inputs) - 4) // 2 + dequantized = [] + + for i in range(0, num_inputs): + inp_scale = _expr.const(inputs[4+i*2]) + inp_zp = _expr.const(inputs[4+i*2+1]) + dequantized.append(relay.qnn.op.dequantize(inputs[0][i], + inp_scale, inp_zp)) + + concat = _op.tensor.concatenate(dequantized, axis=axis) + return relay.qnn.op.quantize(concat, output_scale, output_zero_point, + axis=1, out_dtype="uint8") + + return _impl + + +def _add_scalar(): + def _impl(inputs, input_type): + # refer to aten/src/ATen/native/quantized/cpu/qadd.cpp + assert len(inputs) == 6, "Input quant params not found in op inputs" + s = inputs[4] + z = inputs[5] + c = inputs[1] + c_q = round(c / s) + q_min = 0 + q_max = 255 + + out_scale = _expr.const(inputs[2]) + out_zp = _expr.const(inputs[3]) + + if q_min > z - c_q or q_max < z - c_q: + dequant = relay.qnn.op.dequantize(inputs[0], + _expr.const(s), _expr.const(z)) + dequantized_add = _op.tensor.add(dequant, _expr.const(c_q * s)) + return relay.qnn.op.quantize(dequantized_add, out_scale, out_zp, + axis=1, out_dtype="uint8") + else: + # only scale change + return inputs[0] + + return _impl + + +def quantize_scalar(data, scale, zero_point): + transformed = zero_point + data / scale + return max(0, min(round(transformed), 255)) + + +def _relu6(): + def _impl(inputs, input_type): + assert len(inputs) == 4, "Input quant params not found in op inputs" + input_scale = inputs[2] + input_zero_point = inputs[3] + six = quantize_scalar(6., input_scale, input_zero_point) + return _op.tensor.clip(inputs[0], input_zero_point, six) + return _impl + + +def _mul_scalar(): + def _impl(inputs, input_type): + # refer to aten/src/ATen/native/quantized/cpu/qmul.cpp + assert len(inputs) == 6, "Input quant params not found in op inputs" + other_val = inputs[1] # scalar + + if other_val > 0.0: + # only scale change + return inputs[0] + elif other_val == 0.0: + shape = infer_shape(inputs[0]) + return _op.full(_expr.const(0), shape, dtype="uint8") + else: + q_min = 0 + q_max = 255 + bias = _expr.const(q_max + q_min, dtype="int8") + int8 = bias - _op.cast(inputs[0], "int8") + return _op.cast(int8, "uint8") + + return _impl + + +convert_map = { + 'aten::quantize_per_tensor': _quantize_per_tensor(), + 'quantized::conv2d_relu': _quantized_conv2d(True), + 'aten::dequantize': _dequantize(), + 'quantized::conv2d': _quantized_conv2d(), + 'quantized::add_relu': _binop(relay.add, True), + 'quantized::add': _binop(relay.add), + 'quantized::mul_relu': _binop(relay.multiply, True), + 'quantized::mul': _binop(relay.multiply), + 'quantized::linear': _linear(), + 'quantized::linear_relu': _linear(True), + 'quantized::cat': _cat(), + 'quantized::add_scalar': _add_scalar(), + 'quantized::mul_scalar': _mul_scalar(), + 'quantized::relu6': _relu6() +} From e7df8ec19dadad3b15ac0743980b072022a87967 Mon Sep 17 00:00:00 2001 From: masahi Date: Mon, 2 Mar 2020 12:09:02 +0900 Subject: [PATCH 02/14] fix upsampling num input --- python/tvm/relay/frontend/pytorch.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index e38bed1385ad..ba86ef797299 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -688,10 +688,20 @@ def _impl(inputs, input_types): coord_trans 
= "half_pixel" def func(x): - return _op.image.resize(x, out_size, "NCHW", "bilinear", coord_trans) + return _op.image.resize(x, out_size, "NCHW", method, coord_trans) if input_types[0] == "quint8": - assert len(inputs) == 7, "Input quant param not found in op inputs" + import torch + from packaging import version + + # Torch version > 1.4 changed upsampling API + if version.parse(torch.__version__) > version.parse("1.4.0"): + num_inputs = 7 + else: + num_inputs = 5 + + assert len(inputs) == num_inputs, "Input quant param not found in op inputs" + input_scale = _expr.const(inputs[-2]) input_zero_point = _expr.const(inputs[-1]) return qnn_torch.quantized_upsample(data, input_scale, From f5a319f7b4a235c492adb2d0cfc38d7f8d73275d Mon Sep 17 00:00:00 2001 From: masahi Date: Mon, 2 Mar 2020 12:39:51 +0900 Subject: [PATCH 03/14] imagenet tests added --- tests/python/frontend/pytorch/test_forward.py | 124 ++++++++++++++++++ 1 file changed, 124 insertions(+) diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index c2ff94de546f..d3559367b4d6 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -803,6 +803,127 @@ def forward(self, inp): ctx_list=[("cuda", tvm.gpu(0))]) +def test_quantized_imagenet(): + import os + from tvm.contrib.download import download_testdata + from PIL import Image + + def get_transform(): + import torchvision.transforms as transforms + normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + return transforms.Compose([ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + normalize, + ]) + + def get_real_image(im_height, im_width): + repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV1/' + img_name = 'elephant-299.jpg' + image_url = os.path.join(repo_base, img_name) + img_path = download_testdata(image_url, img_name, module='data') + return Image.open(img_path).resize((im_height, im_width)) + + def get_imagenet_input(): + im = get_real_image(224, 224) + preprocess = get_transform() + pt_tensor = preprocess(im) + return np.expand_dims(pt_tensor.numpy(), 0) + + def get_tvm_runtime(script_module, input_name): + + input_shapes = {input_name: (1, 3, 224, 224)} + mod, params = relay.frontend.from_pytorch(script_module, input_shapes) + + with relay.build_config(opt_level=3): + json, lib, params = relay.build(mod, target="llvm -mcpu=core-avx2", + params=params) + + runtime = tvm.contrib.graph_runtime.create(json, lib, tvm.cpu(0)) + runtime.set_input(**params) + return runtime + + def get_qconfig(per_channel): + from torch.quantization.observer import MovingAverageMinMaxObserver + from torch.quantization.observer import default_weight_observer + + if per_channel: + return torch.quantization.get_default_qconfig('fbgemm') + else: + act = MovingAverageMinMaxObserver.with_args(reduce_range=False) + return torch.quantization.QConfig(activation=act, + weight=default_weight_observer) + + def quantize_model(model, inp, per_channel=False, dummy=True): + model.fuse_model() + model.qconfig = get_qconfig(per_channel) + torch.quantization.prepare(model, inplace=True) + model(inp) + torch.quantization.convert(model, inplace=True) + + from torchvision.models.quantization import resnet as qresnet + from torchvision.models.quantization import mobilenet as qmobilenet + from torchvision.models.quantization import inception as qinception + from torchvision.models.quantization import googlenet 
as qgooglenet + + qmodels = [] + + for per_channel in [False, True]: + qmodels += [ + ("resnet18", qresnet.resnet18(pretrained=True), per_channel), + ("mobilenet_v2", qmobilenet.mobilenet_v2(pretrained=True), per_channel), + ("inception_v3", qinception.inception_v3(pretrained=True), per_channel), + ("googlenet", qgooglenet(pretrained=True), per_channel), + ] + + results = [] + + for (model_name, raw_model, per_channel) in qmodels: + raw_model.eval() + + if per_channel: + model_name += ", per channel quantization" + else: + model_name += ", per tensor quantization" + + inp = get_imagenet_input() + pt_inp = torch.from_numpy(inp) + + quantize_model(raw_model, pt_inp, per_channel=per_channel, dummy=False) + script_module = torch.jit.trace(raw_model, pt_inp).eval() + + with torch.no_grad(): + pt_result = script_module(pt_inp).numpy() + + input_name = get_graph_input_names(script_module)[0] + runtime = get_tvm_runtime(script_module, input_name) + runtime.set_input(input_name, inp) + runtime.run() + + tvm_result = runtime.get_output(0).asnumpy() + + results.append((model_name, pt_result[0], tvm_result[0])) + + pt_top3_labels = np.argsort(pt_result)[::-1][:3] + tvm_top3_labels = np.argsort(pt_result)[::-1][:3] + + assert set(pt_top3_labels) == set(tvm_top3_labels) + + for (model_name, pt_result, tvm_result) in results: + max_abs_diff = np.max(np.abs(tvm_result - pt_result)) + mean_abs_diff = np.mean(np.abs(tvm_result - pt_result)) + num_correct = np.sum(tvm_result == pt_result) + + print("\nModel name: %s" % model_name) + print("PyTorch top5 label:", np.argsort(pt_result)[::-1][:5]) + print("TVM top5 label:", np.argsort(tvm_result)[::-1][:5]) + print("max abs diff:", max_abs_diff) + print("mean abs_diff:", mean_abs_diff) + print("%d in 1000 raw outputs identical." % num_correct) + + if __name__ == "__main__": # Single operator tests test_forward_add() @@ -849,3 +970,6 @@ def forward(self, inp): test_custom_conversion_map() test_segmentaton_models() + + # Quantization test + test_quantized_imagenet() From 93f374ab296519c1bc7c4bfb0078fc19d10aebc7 Mon Sep 17 00:00:00 2001 From: masahi Date: Mon, 2 Mar 2020 13:03:34 +0900 Subject: [PATCH 04/14] add qunatized module tests --- python/tvm/relay/frontend/pytorch.py | 3 +- tests/python/frontend/pytorch/qnn_test.py | 362 ++++++++++++++++++ tests/python/frontend/pytorch/test_forward.py | 124 +----- 3 files changed, 367 insertions(+), 122 deletions(-) create mode 100644 tests/python/frontend/pytorch/qnn_test.py diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index ba86ef797299..53113dc27547 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -32,7 +32,7 @@ from .common import infer_shape as _infer_shape from .common import infer_value as _infer_value -import qnn_torch +from . 
import qnn_torch __all__ = ["from_pytorch"] @@ -880,6 +880,7 @@ def _report_missing_conversion(op_names): "prim::ListConstruct", "prim::ListUnpack", "prim::TupleConstruct", "prim::TupleUnpack"] known_ops += list(_convert_map.keys()) + known_ops += list(qnn_torch.convert_map.keys()) missing = [op_name for op_name in op_names if op_name not in known_ops] diff --git a/tests/python/frontend/pytorch/qnn_test.py b/tests/python/frontend/pytorch/qnn_test.py new file mode 100644 index 000000000000..496b40a11627 --- /dev/null +++ b/tests/python/frontend/pytorch/qnn_test.py @@ -0,0 +1,362 @@ +import os + +from PIL import Image + +import numpy as np + +import torch +from torch import nn +from torch.quantization import QuantStub, DeQuantStub +from torch.quantization import fuse_modules, QuantWrapper + +import tvm +from tvm import relay +from tvm.relay.frontend.pytorch import get_graph_input_names +from tvm.contrib.download import download_testdata + + +def torch_version_check(): + from packaging import version + return version.parse(torch.__version__) > version.parse("1.4.0") + + +def get_tvm_runtime(script_module, input_name): + + input_shapes = {input_name: (1, 3, 224, 224)} + mod, params = relay.frontend.from_pytorch(script_module, input_shapes) + + with relay.build_config(opt_level=3): + # test on only cpu for now, torch cannot run quant models on cuda + json, lib, params = relay.build(mod, target="llvm", params=params) + + runtime = tvm.contrib.graph_runtime.create(json, lib, tvm.cpu(0)) + runtime.set_input(**params) + return runtime + + +def get_qconfig(per_channel): + from torch.quantization.observer import MovingAverageMinMaxObserver + from torch.quantization.observer import default_weight_observer + + if per_channel: + return torch.quantization.get_default_qconfig('fbgemm') + else: + act = MovingAverageMinMaxObserver.with_args(reduce_range=False) + return torch.quantization.QConfig(activation=act, + weight=default_weight_observer) + + +def quantize_model(model, inp, per_channel=False, dummy=True): + model.fuse_model() + model.qconfig = get_qconfig(per_channel) + torch.quantization.prepare(model, inplace=True) + model(inp) + torch.quantization.convert(model, inplace=True) + + +class ConvBn(nn.Module): + def __init__(self, with_relu=False): + super().__init__() + layers = [nn.Conv2d(3, 32, 3, bias=True), + nn.BatchNorm2d(32)] + if with_relu: + layers.append(nn.ReLU()) + self.conv = nn.Sequential(*layers) + self.quant_wrap = QuantWrapper(self.conv) + self.with_relu = with_relu + + def forward(self, x): + return self.quant_wrap(x) + + def fuse_model(self): + indices = ["0", "1"] + if self.with_relu: + indices.append("2") + fuse_modules(self.conv, indices, inplace=True) + + +class Linear(nn.Module): + def __init__(self, with_relu=False): + super().__init__() + layers = [nn.Linear(16, 32)] + if with_relu: + layers.append(nn.ReLU()) + self.fc = nn.Sequential(*layers) + self.quant_wrap = QuantWrapper(self.fc) + self.with_relu = with_relu + + def forward(self, x): + return self.quant_wrap(x) + + def fuse_model(self): + if self.with_relu: + fuse_modules(self.fc, ["0", "1"], inplace=True) + + +class ReLU(nn.Module): + def __init__(self): + super().__init__() + self.relu = QuantWrapper(nn.ReLU()) + + def forward(self, x): + return self.relu(x) + + def fuse_model(self): + pass + + +# Mobilenet V3 related modules +class Hsigmoid(nn.Module): + def __init__(self, inplace=True, add_stub=False): + super().__init__() + self.float_op = nn.quantized.FloatFunctional() + self.relu6 = nn.ReLU6(inplace=inplace) + 
self.quant = QuantStub() + self.dequant = DeQuantStub() + self.add_stub = add_stub + + def forward(self, x): + if self.add_stub: + x = self.quant(x) + relu6 = self.relu6(self.float_op.add_scalar(x, 3.)) + mul = self.float_op.mul_scalar(relu6, 1/6.) + if self.add_stub: + mul = self.dequant(mul) + return mul + + def fuse_model(self): + pass + + +class Hswish(nn.Module): + def __init__(self, inplace=True, add_stub=False): + super(Hswish, self).__init__() + self.float_op = nn.quantized.FloatFunctional() + self.hsigmoid = Hsigmoid(inplace, add_stub=False) + self.quant = QuantStub() + self.dequant = DeQuantStub() + self.add_stub = add_stub + + def forward(self, x): + if self.add_stub: + x = self.quant(x) + mul = self.float_op.mul(x, self.hsigmoid(x)) + if self.add_stub: + mul = self.dequant(mul) + return mul + + def fuse_model(self): + pass + + +class SqueezeExcite(nn.Module): + def __init__(self, channel, reduction=4, add_stub=False): + super(SqueezeExcite, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Sequential( + nn.Linear(channel, channel // reduction, bias=False), + nn.ReLU(inplace=True), + nn.Linear(channel // reduction, channel, bias=False), + Hsigmoid(add_stub=False) + ) + self.fmul = nn.quantized.FloatFunctional() + self.quant = QuantStub() + self.dequant = DeQuantStub() + self.add_stub = add_stub + + def forward(self, x): + b, c, _, _ = x.size() + if self.add_stub: + x = self.quant(x) + y = self.avg_pool(x).view(b, c) + y = self.fc(y).view(b, c, 1, 1) + out = self.fmul.mul(x, y.expand_as(x)) + if self.add_stub: + return self.dequant(out) + else: + return out + + def fuse_model(self): + fuse_modules(self.fc, ["0", "1"], inplace=True) + + +class MulScalarNegative(nn.Module): + def __init__(self, ): + super().__init__() + self.float_op = nn.quantized.FloatFunctional() + self.quant = QuantStub() + self.dequant = DeQuantStub() + + def forward(self, x): + x = self.quant(x) + mul = self.float_op.mul_scalar(x, -0.3) + return self.dequant(mul) + + def fuse_model(self): + pass + + +class UpsamplingBilinear(nn.Module): + def __init__(self): + super().__init__() + self.relu = QuantWrapper(nn.ReLU()) + self.quant = QuantStub() + self.dequant = DeQuantStub() + + def forward(self, x): + x = self.quant(x) + upsample = nn.functional.interpolate(x, scale_factor=2, + mode='bilinear', + align_corners=True) + return self.dequant(upsample) + + def fuse_model(self): + pass + + +def test_quantized_modules(): + imagenet_ishape = (1, 3, 224, 224) + + qmodules = [ + ("conv_bn", imagenet_ishape, ConvBn(), False), + ("conv_bn_relu", imagenet_ishape, ConvBn(with_relu=True), False), + ("relu", imagenet_ishape, ReLU(), False), + ("linear", (16, 16), Linear(), False), + ("linear_relu", (16, 16), Linear(with_relu=True), False), + ("upsample bilinear", (1, 3, 64, 64), UpsamplingBilinear(), False), + ] + + qmodules += [ + ("conv_bn, per_channel", imagenet_ishape, ConvBn(), True), + ("conv_bn_relu, per_channel", imagenet_ishape, ConvBn(with_relu=True), True), + ("linear, per_channel", (16, 16), Linear(), False), + ("linear_relu, per_channel", (16, 16), Linear(with_relu=True), True) + ] + + if torch_version_check(): + qmodules += [ + ("hsigmoid", imagenet_ishape, Hsigmoid(add_stub=True), False), + ("hswish", imagenet_ishape, Hswish(add_stub=True), False), + ("semodule", (1, 16, 64, 64), SqueezeExcite(16, add_stub=True), False), + ("semodule, per_channel", (1, 16, 64, 64), SqueezeExcite(16, add_stub=True), True), + ("mul_scalar negative", imagenet_ishape, MulScalarNegative(), False) + ] + else: 
+ print("Skipping tests that requires nightly torch build (newer than 1.4)") + + for (module_name, ishape, raw_module, per_channel) in qmodules: + raw_module.eval() + inp = torch.rand(ishape) + + quantize_model(raw_module, inp, per_channel=per_channel, dummy=True) + script_module = torch.jit.trace(raw_module, inp).eval() + + with torch.no_grad(): + pt_result = script_module(inp.clone()).numpy() + + input_name = get_graph_input_names(script_module)[0] + + runtime = get_tvm_runtime(script_module, input_name) + runtime.set_input(input_name, inp.numpy().copy()) + runtime.run() + tvm_result = runtime.get_output(0).asnumpy() + + # tvm.testing.assert_allclose(tvm_result, pt_result, rtol=1e-1, atol=1e-1) + + max_abs_diff = np.max(np.abs(tvm_result - pt_result)) + mean_abs_diff = np.mean(np.abs(tvm_result - pt_result)) + num_identical = np.sum(tvm_result == pt_result) + correct_ratio = num_identical / float(np.prod(tvm_result.shape)) + + print(module_name, max_abs_diff, mean_abs_diff, correct_ratio) + + +def test_quantized_imagenet(): + + def get_transform(): + import torchvision.transforms as transforms + normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + return transforms.Compose([ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + normalize, + ]) + + def get_real_image(im_height, im_width): + repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV1/' + img_name = 'elephant-299.jpg' + image_url = os.path.join(repo_base, img_name) + img_path = download_testdata(image_url, img_name, module='data') + return Image.open(img_path).resize((im_height, im_width)) + + def get_imagenet_input(): + im = get_real_image(224, 224) + preprocess = get_transform() + pt_tensor = preprocess(im) + return np.expand_dims(pt_tensor.numpy(), 0) + + from torchvision.models.quantization import resnet as qresnet + from torchvision.models.quantization import mobilenet as qmobilenet + from torchvision.models.quantization import inception as qinception + from torchvision.models.quantization import googlenet as qgooglenet + + qmodels = [] + + for per_channel in [False, True]: + qmodels += [ + ("resnet18", qresnet.resnet18(pretrained=True), per_channel), + ("mobilenet_v2", qmobilenet.mobilenet_v2(pretrained=True), per_channel), + ("inception_v3", qinception.inception_v3(pretrained=True), per_channel), + ("googlenet", qgooglenet(pretrained=True), per_channel), + ] + + results = [] + + for (model_name, raw_model, per_channel) in qmodels: + raw_model.eval() + + if per_channel: + model_name += ", per channel quantization" + else: + model_name += ", per tensor quantization" + + inp = get_imagenet_input() + pt_inp = torch.from_numpy(inp) + + quantize_model(raw_model, pt_inp, per_channel=per_channel, dummy=False) + script_module = torch.jit.trace(raw_model, pt_inp).eval() + + with torch.no_grad(): + pt_result = script_module(pt_inp).numpy() + + input_name = get_graph_input_names(script_module)[0] + runtime = get_tvm_runtime(script_module, input_name) + runtime.set_input(input_name, inp) + runtime.run() + + tvm_result = runtime.get_output(0).asnumpy() + + results.append((model_name, pt_result[0], tvm_result[0])) + + pt_top3_labels = np.argsort(pt_result)[::-1][:3] + tvm_top3_labels = np.argsort(pt_result)[::-1][:3] + + assert set(pt_top3_labels) == set(tvm_top3_labels) + + for (model_name, pt_result, tvm_result) in results: + max_abs_diff = np.max(np.abs(tvm_result - pt_result)) + mean_abs_diff = np.mean(np.abs(tvm_result - pt_result)) 
+ num_identical = np.sum(tvm_result == pt_result) + + print("\nModel name: %s" % model_name) + print("PyTorch top5 label:", np.argsort(pt_result)[::-1][:5]) + print("TVM top5 label:", np.argsort(tvm_result)[::-1][:5]) + print("max abs diff:", max_abs_diff) + print("mean abs_diff:", mean_abs_diff) + print("%d in 1000 raw outputs identical." % num_identical) + + +test_quantized_modules() +test_quantized_imagenet() diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index d3559367b4d6..641f5c9f99dd 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -803,127 +803,6 @@ def forward(self, inp): ctx_list=[("cuda", tvm.gpu(0))]) -def test_quantized_imagenet(): - import os - from tvm.contrib.download import download_testdata - from PIL import Image - - def get_transform(): - import torchvision.transforms as transforms - normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], - std=[0.229, 0.224, 0.225]) - return transforms.Compose([ - transforms.Resize(256), - transforms.CenterCrop(224), - transforms.ToTensor(), - normalize, - ]) - - def get_real_image(im_height, im_width): - repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV1/' - img_name = 'elephant-299.jpg' - image_url = os.path.join(repo_base, img_name) - img_path = download_testdata(image_url, img_name, module='data') - return Image.open(img_path).resize((im_height, im_width)) - - def get_imagenet_input(): - im = get_real_image(224, 224) - preprocess = get_transform() - pt_tensor = preprocess(im) - return np.expand_dims(pt_tensor.numpy(), 0) - - def get_tvm_runtime(script_module, input_name): - - input_shapes = {input_name: (1, 3, 224, 224)} - mod, params = relay.frontend.from_pytorch(script_module, input_shapes) - - with relay.build_config(opt_level=3): - json, lib, params = relay.build(mod, target="llvm -mcpu=core-avx2", - params=params) - - runtime = tvm.contrib.graph_runtime.create(json, lib, tvm.cpu(0)) - runtime.set_input(**params) - return runtime - - def get_qconfig(per_channel): - from torch.quantization.observer import MovingAverageMinMaxObserver - from torch.quantization.observer import default_weight_observer - - if per_channel: - return torch.quantization.get_default_qconfig('fbgemm') - else: - act = MovingAverageMinMaxObserver.with_args(reduce_range=False) - return torch.quantization.QConfig(activation=act, - weight=default_weight_observer) - - def quantize_model(model, inp, per_channel=False, dummy=True): - model.fuse_model() - model.qconfig = get_qconfig(per_channel) - torch.quantization.prepare(model, inplace=True) - model(inp) - torch.quantization.convert(model, inplace=True) - - from torchvision.models.quantization import resnet as qresnet - from torchvision.models.quantization import mobilenet as qmobilenet - from torchvision.models.quantization import inception as qinception - from torchvision.models.quantization import googlenet as qgooglenet - - qmodels = [] - - for per_channel in [False, True]: - qmodels += [ - ("resnet18", qresnet.resnet18(pretrained=True), per_channel), - ("mobilenet_v2", qmobilenet.mobilenet_v2(pretrained=True), per_channel), - ("inception_v3", qinception.inception_v3(pretrained=True), per_channel), - ("googlenet", qgooglenet(pretrained=True), per_channel), - ] - - results = [] - - for (model_name, raw_model, per_channel) in qmodels: - raw_model.eval() - - if per_channel: - model_name += ", per channel quantization" - else: - model_name += ", per 
tensor quantization" - - inp = get_imagenet_input() - pt_inp = torch.from_numpy(inp) - - quantize_model(raw_model, pt_inp, per_channel=per_channel, dummy=False) - script_module = torch.jit.trace(raw_model, pt_inp).eval() - - with torch.no_grad(): - pt_result = script_module(pt_inp).numpy() - - input_name = get_graph_input_names(script_module)[0] - runtime = get_tvm_runtime(script_module, input_name) - runtime.set_input(input_name, inp) - runtime.run() - - tvm_result = runtime.get_output(0).asnumpy() - - results.append((model_name, pt_result[0], tvm_result[0])) - - pt_top3_labels = np.argsort(pt_result)[::-1][:3] - tvm_top3_labels = np.argsort(pt_result)[::-1][:3] - - assert set(pt_top3_labels) == set(tvm_top3_labels) - - for (model_name, pt_result, tvm_result) in results: - max_abs_diff = np.max(np.abs(tvm_result - pt_result)) - mean_abs_diff = np.mean(np.abs(tvm_result - pt_result)) - num_correct = np.sum(tvm_result == pt_result) - - print("\nModel name: %s" % model_name) - print("PyTorch top5 label:", np.argsort(pt_result)[::-1][:5]) - print("TVM top5 label:", np.argsort(tvm_result)[::-1][:5]) - print("max abs diff:", max_abs_diff) - print("mean abs_diff:", mean_abs_diff) - print("%d in 1000 raw outputs identical." % num_correct) - - if __name__ == "__main__": # Single operator tests test_forward_add() @@ -972,4 +851,7 @@ def quantize_model(model, inp, per_channel=False, dummy=True): test_segmentaton_models() # Quantization test + from qnn_test import test_quantized_imagenet, test_quantized_modules + + test_quantized_modules() test_quantized_imagenet() From 0615ca63be3ea6c5f6be8e6969be145e8ee86703 Mon Sep 17 00:00:00 2001 From: masahi Date: Mon, 2 Mar 2020 13:21:09 +0900 Subject: [PATCH 05/14] quantized module tests working --- python/tvm/relay/frontend/pytorch.py | 13 +++++++++ tests/python/frontend/pytorch/qnn_test.py | 35 ++++++++++++----------- 2 files changed, 31 insertions(+), 17 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 53113dc27547..f496d51257d8 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -19,6 +19,7 @@ # pylint: disable=import-outside-toplevel, simplifiable-if-expression, unnecessary-comprehension """PT: PyTorch frontend.""" import itertools +import logging import numpy as np @@ -710,6 +711,17 @@ def func(x): return _impl + +def _expand_as(): + def _impl(inputs, input_types): + # TODO: maybe fix this + # This assumes expand_as can be removed because TVM has broadcast op + msg = "aten::expand_as(...) 
found, assume it is part of broadcast op" + logging.warning(msg) + return inputs[0] + return _impl + + # Helper functions for operator implementation def _convert_data_type(input_type): @@ -830,6 +842,7 @@ def _convert_elemwise_input(data, input_type): "aten::detach" : _identity(), "aten::upsample_bilinear2d" : _upsample("bilinear"), "aten::upsample_nearest2d" : _upsample("nearest_neighbor"), + "aten::expand_as" : _expand_as() } diff --git a/tests/python/frontend/pytorch/qnn_test.py b/tests/python/frontend/pytorch/qnn_test.py index 496b40a11627..0a8f60843ceb 100644 --- a/tests/python/frontend/pytorch/qnn_test.py +++ b/tests/python/frontend/pytorch/qnn_test.py @@ -20,9 +20,9 @@ def torch_version_check(): return version.parse(torch.__version__) > version.parse("1.4.0") -def get_tvm_runtime(script_module, input_name): +def get_tvm_runtime(script_module, input_name, ishape): - input_shapes = {input_name: (1, 3, 224, 224)} + input_shapes = {input_name: ishape} mod, params = relay.frontend.from_pytorch(script_module, input_shapes) with relay.build_config(opt_level=3): @@ -218,20 +218,22 @@ def test_quantized_modules(): imagenet_ishape = (1, 3, 224, 224) qmodules = [ - ("conv_bn", imagenet_ishape, ConvBn(), False), - ("conv_bn_relu", imagenet_ishape, ConvBn(with_relu=True), False), ("relu", imagenet_ishape, ReLU(), False), - ("linear", (16, 16), Linear(), False), - ("linear_relu", (16, 16), Linear(with_relu=True), False), ("upsample bilinear", (1, 3, 64, 64), UpsamplingBilinear(), False), ] - qmodules += [ - ("conv_bn, per_channel", imagenet_ishape, ConvBn(), True), - ("conv_bn_relu, per_channel", imagenet_ishape, ConvBn(with_relu=True), True), - ("linear, per_channel", (16, 16), Linear(), False), - ("linear_relu, per_channel", (16, 16), Linear(with_relu=True), True) - ] + for per_channel in [False, True]: + if per_channel: + postfix = ", per_channel" + else: + postfix = "" + + qmodules += [ + ("conv_bn" + postfix, imagenet_ishape, ConvBn(), per_channel), + ("conv_bn_relu" + postfix, imagenet_ishape, ConvBn(with_relu=True), per_channel), + ("linear" + postfix, (16, 16), Linear(), per_channel), + ("linear_relu" + postfix, (16, 16), Linear(with_relu=True), per_channel) + ] if torch_version_check(): qmodules += [ @@ -242,7 +244,7 @@ def test_quantized_modules(): ("mul_scalar negative", imagenet_ishape, MulScalarNegative(), False) ] else: - print("Skipping tests that requires nightly torch build (newer than 1.4)") + print("Skipping tests that require torch > 1.4") for (module_name, ishape, raw_module, per_channel) in qmodules: raw_module.eval() @@ -256,7 +258,7 @@ def test_quantized_modules(): input_name = get_graph_input_names(script_module)[0] - runtime = get_tvm_runtime(script_module, input_name) + runtime = get_tvm_runtime(script_module, input_name, ishape) runtime.set_input(input_name, inp.numpy().copy()) runtime.run() tvm_result = runtime.get_output(0).asnumpy() @@ -272,7 +274,6 @@ def test_quantized_modules(): def test_quantized_imagenet(): - def get_transform(): import torchvision.transforms as transforms normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], @@ -332,7 +333,7 @@ def get_imagenet_input(): pt_result = script_module(pt_inp).numpy() input_name = get_graph_input_names(script_module)[0] - runtime = get_tvm_runtime(script_module, input_name) + runtime = get_tvm_runtime(script_module, input_name, (1, 3, 224, 224)) runtime.set_input(input_name, inp) runtime.run() @@ -359,4 +360,4 @@ def get_imagenet_input(): test_quantized_modules() -test_quantized_imagenet() 
+#test_quantized_imagenet() From da89492d0feb2eec3e8b9752edbed4beb5b3db5a Mon Sep 17 00:00:00 2001 From: masahi Date: Mon, 2 Mar 2020 13:52:43 +0900 Subject: [PATCH 06/14] imagenet test working --- tests/python/frontend/pytorch/qnn_test.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/tests/python/frontend/pytorch/qnn_test.py b/tests/python/frontend/pytorch/qnn_test.py index 0a8f60843ceb..883f48b12c4d 100644 --- a/tests/python/frontend/pytorch/qnn_test.py +++ b/tests/python/frontend/pytorch/qnn_test.py @@ -27,6 +27,7 @@ def get_tvm_runtime(script_module, input_name, ishape): with relay.build_config(opt_level=3): # test on only cpu for now, torch cannot run quant models on cuda + # also not to make CI too slow json, lib, params = relay.build(mod, target="llvm", params=params) runtime = tvm.contrib.graph_runtime.create(json, lib, tvm.cpu(0)) @@ -180,6 +181,7 @@ def fuse_model(self): fuse_modules(self.fc, ["0", "1"], inplace=True) +# test on quantized::mul_scalar with negative scale class MulScalarNegative(nn.Module): def __init__(self, ): super().__init__() @@ -263,6 +265,7 @@ def test_quantized_modules(): runtime.run() tvm_result = runtime.get_output(0).asnumpy() + # we cannot make any guarantee on how close the raw output is to torch # tvm.testing.assert_allclose(tvm_result, pt_result, rtol=1e-1, atol=1e-1) max_abs_diff = np.max(np.abs(tvm_result - pt_result)) @@ -341,23 +344,22 @@ def get_imagenet_input(): results.append((model_name, pt_result[0], tvm_result[0])) - pt_top3_labels = np.argsort(pt_result)[::-1][:3] - tvm_top3_labels = np.argsort(pt_result)[::-1][:3] + pt_top3_labels = np.argsort(pt_result[0])[::-1][:3] + tvm_top3_labels = np.argsort(pt_result[0])[::-1][:3] assert set(pt_top3_labels) == set(tvm_top3_labels) + print("Torch top3 label:", pt_top3_labels) + print("TVM top3 label:", tvm_top3_labels) + for (model_name, pt_result, tvm_result) in results: max_abs_diff = np.max(np.abs(tvm_result - pt_result)) mean_abs_diff = np.mean(np.abs(tvm_result - pt_result)) num_identical = np.sum(tvm_result == pt_result) print("\nModel name: %s" % model_name) - print("PyTorch top5 label:", np.argsort(pt_result)[::-1][:5]) - print("TVM top5 label:", np.argsort(tvm_result)[::-1][:5]) + print("PyTorch top3 label:", np.argsort(pt_result)[::-1][:3]) + print("TVM top3 label:", np.argsort(tvm_result)[::-1][:3]) print("max abs diff:", max_abs_diff) print("mean abs_diff:", mean_abs_diff) print("%d in 1000 raw outputs identical." % num_identical) - - -test_quantized_modules() -#test_quantized_imagenet() From 0daf4ca0da2e4410c5726efe43d5abd8ed94887f Mon Sep 17 00:00:00 2001 From: masahi Date: Mon, 2 Mar 2020 14:16:45 +0900 Subject: [PATCH 07/14] fix lint --- python/tvm/relay/frontend/qnn_torch.py | 141 +++++++++++++--------- tests/python/frontend/pytorch/qnn_test.py | 18 +++ 2 files changed, 102 insertions(+), 57 deletions(-) diff --git a/python/tvm/relay/frontend/qnn_torch.py b/python/tvm/relay/frontend/qnn_torch.py index 09d0e3841bb6..b48cad46f3c2 100644 --- a/python/tvm/relay/frontend/qnn_torch.py +++ b/python/tvm/relay/frontend/qnn_torch.py @@ -1,6 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +""" Functions to convert quantized torch models to QNN""" + +import numpy as np + import torch + import tvm -import numpy as np from tvm import relay from tvm.relay import expr as _expr from tvm.relay import op as _op @@ -8,6 +29,8 @@ class QuantParam: + """ A placeholder for weight quantization parameters """ + def __init__(self, weight, bias, scale, zero_point, param_key): param_prefix = param_key[:-len("._packed_params")] self.weight_var = _expr.var(param_prefix + "_weight", @@ -26,7 +49,7 @@ def __init__(self, weight, bias, scale, zero_point, param_key): self.zero_point = _expr.const(zero_point, dtype="int32") -def unpack_quant_params(param_name, packed_params, unpack_func): +def _unpack_quant_params(param_name, packed_params, unpack_func): qweight, bias = unpack_func(packed_params) weight_np = qweight.dequantize().numpy() @@ -43,6 +66,7 @@ def unpack_quant_params(param_name, packed_params, unpack_func): def get_weight_quant_params(script_module): + """ Retrive and unpack weight parameters from quantized modules """ conv_packed_params = [] linear_packed_params = [] @@ -64,14 +88,15 @@ def get_weight_quant_params(script_module): assert param_name in state_dict key = name + "." + param_name packed_param = state_dict[param_name] - quant_params[key] = unpack_quant_params(key, packed_param, - unpack_func) + quant_params[key] = _unpack_quant_params(key, packed_param, + unpack_func) return quant_params def add_quant_params_to_outputs(outputs, output_index_map, packed_param_map, quant_params): + """ Add quant params to outputs so that they can be referenced later """ for node_name, packed_param_name in packed_param_map.items(): qparam = quant_params[packed_param_name] output_index_map[node_name] = len(outputs) @@ -82,7 +107,7 @@ def add_quant_params_to_outputs(outputs, output_index_map, outputs.append(param_tup) -def get_quant_param_for_input(input_value): +def _get_quant_param_for_input(input_value): output_quant_param_indices = { "aten::quantize_per_tensor": (1, 2), "quantized::conv2d": (6, 7), @@ -106,18 +131,18 @@ def dfs(current_node): scale = current_node.inputsAt(indices[0]) zp = current_node.inputsAt(indices[1]) return scale, zp - else: - # Assume quantized tensor comes earlier in the args - for arg in current_node.inputs(): - return dfs(arg.node()) + + # Assume quantized tensor comes earlier in the args + for arg in current_node.inputs(): + return dfs(arg.node()) assert False, "No producer for %s" % (str(current_node)) return dfs(input_value.node()) -def get_add_scalar_output_quant_param(input_scale, input_zero_point, - scalar): +def _get_add_scalar_output_quant_param(input_scale, input_zero_point, + scalar): # refer to aten/src/ATen/native/quantized/cpu/qadd.cpp q_min = 0 q_max = 255 @@ -139,8 +164,8 @@ def get_add_scalar_output_quant_param(input_scale, input_zero_point, return s_prime, z_prime -def get_mul_scalar_output_quant_param(input_scale, input_zero_point, - scalar): +def _get_mul_scalar_output_quant_param(input_scale, input_zero_point, + scalar): # refer to aten/src/ATen/native/quantized/cpu/qmul.cpp q_min = 0 q_max = 255 @@ -161,19 
+186,19 @@ def get_mul_scalar_output_quant_param(input_scale, input_zero_point, return s_prime, z_prime -def add_output_quant_params_to_scalar_op(node, graph, - input_scale, input_zero_point, - scalar): +def _add_output_quant_params_to_scalar_op(node, graph, + input_scale, input_zero_point, + scalar): operator = node.kind() if operator == "quantized::mul_scalar": out_scale, out_zero_point = \ - get_mul_scalar_output_quant_param(input_scale, input_zero_point, - scalar) + _get_mul_scalar_output_quant_param(input_scale, input_zero_point, + scalar) elif operator == "quantized::add_scalar": out_scale, out_zero_point = \ - get_add_scalar_output_quant_param(input_scale, input_zero_point, - scalar) + _get_add_scalar_output_quant_param(input_scale, input_zero_point, + scalar) else: assert False, "unsupported scalar op: %s" % operator @@ -191,11 +216,13 @@ def add_output_quant_params_to_scalar_op(node, graph, def add_input_quant_params_to_op_inputs(graph): - # Quantized operators in PyTorch do not take input quant params as - # arguments. But QNN expects them to be passed in as arguements. - # To simplify the translation of inputs, we add input quant params - # to inputs of PyTorch quantized operator nodes. See _impl in - # _quantized_conv2d() below for example of why this is helpful. + """ + Quantized operators in PyTorch do not take input quant params as + arguments. But QNN expects them to be passed in as arguements. + To simplify the translation of inputs, we add input quant params + to inputs of PyTorch quantized operator nodes. See _impl in + _quantized_conv2d() below for example of why this is helpful. + """ num_quantized_inputs = {"quantized::conv2d": 1, "quantized::conv2d_relu": 1, "quantized::linear": 1, @@ -227,12 +254,12 @@ def add_input_quant_params_to_op_inputs(graph): if operator == "quantized::cat": inputs = node.inputsAt(0).node().inputs() for inp in inputs: - scale, zp = get_quant_param_for_input(inp) + scale, zp = _get_quant_param_for_input(inp) input_scales.append(scale) input_zero_points.append(zp) else: for i in range(num_quantized_inputs[operator]): - scale, zp = get_quant_param_for_input(node.inputsAt(i)) + scale, zp = _get_quant_param_for_input(node.inputsAt(i)) input_scales.append(scale) input_zero_points.append(zp) @@ -241,9 +268,9 @@ def add_input_quant_params_to_op_inputs(graph): inp_scale = input_scales[0].node().f("value") inp_zero_point = input_zero_points[0].node().i("value") - add_output_quant_params_to_scalar_op(node, graph, - inp_scale, inp_zero_point, - scalar) + _add_output_quant_params_to_scalar_op(node, graph, + inp_scale, inp_zero_point, + scalar) for scale, zp in zip(input_scales, input_zero_points): node.addInput(scale) @@ -251,6 +278,7 @@ def add_input_quant_params_to_op_inputs(graph): def add_quant_params(params, quant_params): + """ Add quant parameters to TVM param map """ for qparam in quant_params.values(): params[qparam.weight_var.name_hint] = tvm.nd.array(qparam.weight) if qparam.bias is not None: @@ -283,7 +311,7 @@ def quantized_relu(data, input_zero_point): def _quantize_per_tensor(): - def _impl(inputs, input_type): + def _impl(inputs, _): return relay.qnn.op.quantize(inputs[0], _expr.const(inputs[1]), _expr.const(inputs[2]), out_dtype="uint8", axis=1) @@ -291,23 +319,23 @@ def _impl(inputs, input_type): def _dequantize(): - def _impl(inputs, input_type): + def _impl(inputs, _): inp_scale = _expr.const(inputs[1]) inp_zero_point = _expr.const(inputs[2]) return relay.qnn.op.dequantize(inputs[0], inp_scale, inp_zero_point) return _impl -def 
get_numpy(relay_const_scalar): +def _get_numpy(relay_const_scalar): return relay_const_scalar.data.asnumpy() -def get_scalar(relay_const_scalar): - return np.asscalar(get_numpy(relay_const_scalar)) +def _get_scalar(relay_const_scalar): + return np.asscalar(_get_numpy(relay_const_scalar)) def _quantized_conv2d(with_relu=False): - def _impl(inputs, input_type): + def _impl(inputs, _): # refer to src/ATen/native/quantized/cpu/qconv.cpp # inputs[0]: input tensor # inputs[1]: (weight, scale, zero_point, bias) @@ -338,7 +366,7 @@ def _impl(inputs, input_type): out_channels = weight_shape[0] if padding[0] != 0 or padding[1] != 0: - pad_val = get_scalar(input_zero_point) + pad_val = _get_scalar(input_zero_point) inp = _op.nn.pad(inputs[0], pad_width=((0, 0), (0, 0), (padding[0], padding[0]), @@ -356,7 +384,7 @@ def _impl(inputs, input_type): channels=out_channels) # input scale * weight scale - requant_input_scale = _expr.const(inputs[8] * get_numpy(weight_scale)) + requant_input_scale = _expr.const(inputs[8] * _get_numpy(weight_scale)) bias_var = inputs[1][3] if bias_var is not None: @@ -374,7 +402,7 @@ def _impl(inputs, input_type): out_dtype="int32", axis=1) clip_min = 0 if with_relu: - clip_min = get_scalar(output_zero_point) + clip_min = _get_scalar(output_zero_point) clip = _op.tensor.clip(requantized, clip_min, 255.) return _op.cast(clip, dtype="uint8") @@ -383,7 +411,7 @@ def _impl(inputs, input_type): def _binop(relay_op, with_relu=False): - def _impl(inputs, input_type): + def _impl(inputs, _): output_scale = _expr.const(inputs[2]) output_zero_point = _expr.const(inputs[3]) assert len(inputs) == 8, "Input quant params not found in op inputs" @@ -421,7 +449,7 @@ def _impl(inputs, input_type): def _linear(with_relu=False): - def _impl(inputs, input_type): + def _impl(inputs, _): weight = inputs[1][0] weight_scale = inputs[1][1] weight_zero_point = inputs[1][2] @@ -437,7 +465,7 @@ def _impl(inputs, input_type): input_scale, weight_scale, units=weight_shape[0]) - requant_input_scale = _expr.const(inputs[4] * get_numpy(weight_scale)) + requant_input_scale = _expr.const(inputs[4] * _get_numpy(weight_scale)) bias_var = inputs[1][3] if bias_var is not None: @@ -455,7 +483,7 @@ def _impl(inputs, input_type): out_dtype="int32", axis=1) clip_min = 0 if with_relu: - clip_min = get_scalar(output_zero_point) + clip_min = _get_scalar(output_zero_point) clip = _op.tensor.clip(requantized, clip_min, 255.) 
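# A minimal sketch, with made-up scale and zero point, of why the fused-ReLU
# path clips at the output zero point: under affine quantization
# real = scale * (q - zp), the real value 0.0 maps to q == zp, so clamping
# the uint8 result at zp is ReLU performed directly in the quantized domain.
scale, zp = 0.1, 54                  # assumed output quant params
q = 40                               # example quantized activation, real -1.4
q_relu = max(q, zp)                  # same effect as clip(requantized, zp, 255.)
assert scale * (q_relu - zp) == max(0.0, scale * (q - zp))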
return _op.cast(clip, dtype="uint8") @@ -464,7 +492,7 @@ def _impl(inputs, input_type): def _cat(): - def _impl(inputs, input_type): + def _impl(inputs, _): axis = inputs[1] output_scale = _expr.const(inputs[2]) output_zero_point = _expr.const(inputs[3]) @@ -485,7 +513,7 @@ def _impl(inputs, input_type): def _add_scalar(): - def _impl(inputs, input_type): + def _impl(inputs, _): # refer to aten/src/ATen/native/quantized/cpu/qadd.cpp assert len(inputs) == 6, "Input quant params not found in op inputs" s = inputs[4] @@ -504,9 +532,8 @@ def _impl(inputs, input_type): dequantized_add = _op.tensor.add(dequant, _expr.const(c_q * s)) return relay.qnn.op.quantize(dequantized_add, out_scale, out_zp, axis=1, out_dtype="uint8") - else: - # only scale change - return inputs[0] + # only scale change + return inputs[0] return _impl @@ -517,7 +544,7 @@ def quantize_scalar(data, scale, zero_point): def _relu6(): - def _impl(inputs, input_type): + def _impl(inputs, _): assert len(inputs) == 4, "Input quant params not found in op inputs" input_scale = inputs[2] input_zero_point = inputs[3] @@ -527,7 +554,7 @@ def _impl(inputs, input_type): def _mul_scalar(): - def _impl(inputs, input_type): + def _impl(inputs, _): # refer to aten/src/ATen/native/quantized/cpu/qmul.cpp assert len(inputs) == 6, "Input quant params not found in op inputs" other_val = inputs[1] # scalar @@ -535,15 +562,15 @@ def _impl(inputs, input_type): if other_val > 0.0: # only scale change return inputs[0] - elif other_val == 0.0: + if other_val == 0.0: shape = infer_shape(inputs[0]) return _op.full(_expr.const(0), shape, dtype="uint8") - else: - q_min = 0 - q_max = 255 - bias = _expr.const(q_max + q_min, dtype="int8") - int8 = bias - _op.cast(inputs[0], "int8") - return _op.cast(int8, "uint8") + + q_min = 0 + q_max = 255 + bias = _expr.const(q_max + q_min, dtype="int8") + int8 = bias - _op.cast(inputs[0], "int8") + return _op.cast(int8, "uint8") return _impl diff --git a/tests/python/frontend/pytorch/qnn_test.py b/tests/python/frontend/pytorch/qnn_test.py index 883f48b12c4d..a3da657ed64f 100644 --- a/tests/python/frontend/pytorch/qnn_test.py +++ b/tests/python/frontend/pytorch/qnn_test.py @@ -1,3 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=import-self, invalid-name, unused-argument +""" Tests on quantized torch model conversion """ import os from PIL import Image From 6a99ec9baeb023e6518f82e123dfa8e66efa547d Mon Sep 17 00:00:00 2001 From: masahi Date: Mon, 2 Mar 2020 14:36:47 +0900 Subject: [PATCH 08/14] remove top level torch import to fix ci error --- python/tvm/relay/frontend/qnn_torch.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/frontend/qnn_torch.py b/python/tvm/relay/frontend/qnn_torch.py index b48cad46f3c2..1189deba3161 100644 --- a/python/tvm/relay/frontend/qnn_torch.py +++ b/python/tvm/relay/frontend/qnn_torch.py @@ -15,12 +15,10 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=invalid-name -""" Functions to convert quantized torch models to QNN""" +""" Functions to convert quantized torch models to QNN """ import numpy as np -import torch - import tvm from tvm import relay from tvm.relay import expr as _expr @@ -53,6 +51,7 @@ def _unpack_quant_params(param_name, packed_params, unpack_func): qweight, bias = unpack_func(packed_params) weight_np = qweight.dequantize().numpy() + import torch if qweight.qscheme() == torch.per_tensor_affine: param = QuantParam(weight_np, bias, qweight.q_scale(), int(qweight.q_zero_point()), param_name) @@ -70,6 +69,7 @@ def get_weight_quant_params(script_module): conv_packed_params = [] linear_packed_params = [] + import torch for name, m in script_module.named_modules(): if isinstance(m, torch.jit.RecursiveScriptModule): if "Conv" in m.original_name: @@ -189,6 +189,7 @@ def _get_mul_scalar_output_quant_param(input_scale, input_zero_point, def _add_output_quant_params_to_scalar_op(node, graph, input_scale, input_zero_point, scalar): + import torch operator = node.kind() if operator == "quantized::mul_scalar": From 96ccb37cad690b49eb6397bb2fec23270013b913 Mon Sep 17 00:00:00 2001 From: masahi Date: Mon, 2 Mar 2020 14:40:45 +0900 Subject: [PATCH 09/14] disable lint warning on outside toplevel import --- python/tvm/relay/frontend/qnn_torch.py | 2 +- tests/python/frontend/pytorch/qnn_test.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/python/tvm/relay/frontend/qnn_torch.py b/python/tvm/relay/frontend/qnn_torch.py index 1189deba3161..4a632c97ca17 100644 --- a/python/tvm/relay/frontend/qnn_torch.py +++ b/python/tvm/relay/frontend/qnn_torch.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=invalid-name +# pylint: disable=invalid-name, import-outside-toplevel """ Functions to convert quantized torch models to QNN """ import numpy as np diff --git a/tests/python/frontend/pytorch/qnn_test.py b/tests/python/frontend/pytorch/qnn_test.py index a3da657ed64f..6607fea40e57 100644 --- a/tests/python/frontend/pytorch/qnn_test.py +++ b/tests/python/frontend/pytorch/qnn_test.py @@ -14,7 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-# pylint: disable=import-self, invalid-name, unused-argument """ Tests on quantized torch model conversion """ import os From 75c1200410f7ec9242d591aa367bb53cc82ca268 Mon Sep 17 00:00:00 2001 From: masahi Date: Mon, 2 Mar 2020 16:37:38 +0900 Subject: [PATCH 10/14] revert parse -> convert change --- python/tvm/relay/frontend/pytorch.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index f496d51257d8..f0bdd806975e 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -987,7 +987,7 @@ def _get_operator_nodes(nodes): return ops -def convert_inputs(graph_inputs, input_shapes): +def parse_inputs(graph_inputs, input_shapes): """ Return Relay vars from torch input vars """ ir_inputs = list(graph_inputs) input_vars = {} @@ -1038,7 +1038,7 @@ def terminate(users): return get_use_chains(root_getattr_node, terminate) -def convert_params(graph, state_dict): +def parse_params(graph, state_dict): """ Return Relay vars and TVM NDArrays for input parameters A chain of prim::GetAttr nodes is processed one at a time @@ -1072,7 +1072,7 @@ def convert_params(graph, state_dict): return params, param_tensors, packed_param_map -def convert_operators(operators, outputs, output_index_map, ret_name): +def parse_operators(operators, outputs, output_index_map, ret_name): """ Convert each Torch IR operators to Relay equivalent """ for node_name, op_node in operators.items(): operator = op_node.kind() @@ -1148,8 +1148,8 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None): _report_missing_conversion(op_names) params = script_module.state_dict() - input_vars = convert_inputs(graph.inputs(), input_shapes) - param_vars, tensors, packed_param_map = convert_params(graph, params) + input_vars = parse_inputs(graph.inputs(), input_shapes) + param_vars, tensors, packed_param_map = parse_params(graph, params) tvm_params = {k: tvm.nd.array(v) for k, v in tensors.items()} input_vars.update(param_vars) @@ -1167,8 +1167,8 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None): qnn_torch.add_quant_params(tvm_params, weight_quant_params) _convert_map.update(qnn_torch.convert_map) - body = convert_operators(_get_operator_nodes(graph.nodes()), outputs, - output_index_map, ret_name) + body = parse_operators(_get_operator_nodes(graph.nodes()), outputs, + output_index_map, ret_name) func = tvm.relay.Function(_analysis.free_vars(body), body) return _module.IRModule.from_expr(func), tvm_params From 9c9556d2e634286aa3beed78a81079163d552a6e Mon Sep 17 00:00:00 2001 From: masahi Date: Mon, 2 Mar 2020 19:05:52 +0900 Subject: [PATCH 11/14] add comments to qnn translation --- python/tvm/relay/frontend/qnn_torch.py | 97 +++++++++++++++++++++++--- 1 file changed, 89 insertions(+), 8 deletions(-) diff --git a/python/tvm/relay/frontend/qnn_torch.py b/python/tvm/relay/frontend/qnn_torch.py index 4a632c97ca17..2f0553cfc7fd 100644 --- a/python/tvm/relay/frontend/qnn_torch.py +++ b/python/tvm/relay/frontend/qnn_torch.py @@ -48,6 +48,8 @@ def __init__(self, weight, bias, scale, zero_point, param_key): def _unpack_quant_params(param_name, packed_params, unpack_func): + # Torch stores quantized params in a custom packed format, + # need to unpack and retrieve them as numpy arrays qweight, bias = unpack_func(packed_params) weight_np = qweight.dequantize().numpy() @@ -96,7 +98,10 @@ def get_weight_quant_params(script_module): def add_quant_params_to_outputs(outputs, 
output_index_map, packed_param_map, quant_params): - """ Add quant params to outputs so that they can be referenced later """ + """ + Add quant params to outputs so that they can be referenced by other + ops later. Weights are quantized here. + """ for node_name, packed_param_name in packed_param_map.items(): qparam = quant_params[packed_param_name] output_index_map[node_name] = len(outputs) @@ -108,6 +113,17 @@ def add_quant_params_to_outputs(outputs, output_index_map, def _get_quant_param_for_input(input_value): + """ + We want to know the input scale and zp of this input_value, since + input quant params are not explicitly passed around in torch (they + are embeded in a QTensor data structure, not visible statically). + We know that it is quantized using output scale and zp + of some previous quantized op. The purpose of this function + is to find that pair of paramters. + """ + # Indices for output scale and zp + # For example, in quantized::conv2d(%input, %1, %2, %3, %4, %5, %6, %7), + # 6th and 7th arg are output scale and zp respectively. output_quant_param_indices = { "aten::quantize_per_tensor": (1, 2), "quantized::conv2d": (6, 7), @@ -132,10 +148,12 @@ def dfs(current_node): zp = current_node.inputsAt(indices[1]) return scale, zp + # Trace back eariler nodes, dfs order # Assume quantized tensor comes earlier in the args for arg in current_node.inputs(): return dfs(arg.node()) + # shouldn't happen assert False, "No producer for %s" % (str(current_node)) return dfs(input_value.node()) @@ -143,7 +161,11 @@ def dfs(current_node): def _get_add_scalar_output_quant_param(input_scale, input_zero_point, scalar): - # refer to aten/src/ATen/native/quantized/cpu/qadd.cpp + """ + Determine the output scale and zp of quantized::add_scalar op + This is used for mobilenet v3 + Refer to aten/src/ATen/native/quantized/cpu/qadd.cpp + """ q_min = 0 q_max = 255 s = input_scale @@ -166,7 +188,11 @@ def _get_add_scalar_output_quant_param(input_scale, input_zero_point, def _get_mul_scalar_output_quant_param(input_scale, input_zero_point, scalar): - # refer to aten/src/ATen/native/quantized/cpu/qmul.cpp + """ + Determine the output scale and zp of quantized::mul_scalar op + This is used for mobilenet v3 + Refer to aten/src/ATen/native/quantized/cpu/qmul.cpp + """ q_min = 0 q_max = 255 self_scale = input_scale @@ -189,6 +215,24 @@ def _get_mul_scalar_output_quant_param(input_scale, input_zero_point, def _add_output_quant_params_to_scalar_op(node, graph, input_scale, input_zero_point, scalar): + """ + The output scale and zp of {add,mul}_scalar op are not explicit in the IR + They are required for _get_quant_param_for_input above to work correctly + So calculate these params using the same way torch does, and make new + constant nodes in the input IR. Also add these params to the inputs of + scalar op. + + For example, + %6 : float = prim::Constant[value=3.]() + %input : QUInt8(1, 3, 224, 224) = quantized::add_scalar(%x.1, %6) + becomes + %6 : float = prim::Constant[value=3.]() + %7 : float = prim::Constant[value=0.015686161816120148]() + %8 : int = prim::Constant[value=0]() + %input : UInt8(1, 3, 224, 224) = quantized::add_scalar(%x.1, %6, %7, %8) + + %7 and %8 are newly created output scale and zp constant nodes + """ import torch operator = node.kind() @@ -218,12 +262,31 @@ def _add_output_quant_params_to_scalar_op(node, graph, def add_input_quant_params_to_op_inputs(graph): """ - Quantized operators in PyTorch do not take input quant params as - arguments. 
But QNN expects them to be passed in as arguements. - To simplify the translation of inputs, we add input quant params - to inputs of PyTorch quantized operator nodes. See _impl in - _quantized_conv2d() below for example of why this is helpful. + In Torch, input quant params are not explicitly passed around + Instead, they are stored in QTensor data structure, and retrieved + at runtime by each quantized ops. + However, they need to be known statically for QNN translation. + To workaround and simplify the translation of inputs, we manually add + input quant params to inputs of Torch quantized operators listed below. + See _quantized_conv2d() below for example of why this is helpful. + + For example, + %input : QUInt8(1, 512, 7, 7) = quantized::add(%x.8, %x.9, %434, %435) + becomes + %395 : float = prim::Constant[value=0.036212071776390076]() + %396 : int = prim::Constant[value=0]() + %430 : float = prim::Constant[value=0.16080744564533234]() + %431 : int = prim::Constant[value=42]() + %input : QUInt8(1, 512, 7, 7) = quantized::add(%x.8, %x.9, %434, %435, + %430, %431, %395, %396) + + %434, %435 are output scale and zp of quantized::add op + %430, %431, %395, %396 are two pairs of input (scale, zp) for two tensors + added by this function """ + # How many quantized tensors each op takes as inputs? + # A pair of (scale, zp) for each input quantized tensor will be added + # to the input nodes num_quantized_inputs = {"quantized::conv2d": 1, "quantized::conv2d_relu": 1, "quantized::linear": 1, @@ -293,6 +356,7 @@ def quantized_adaptive_avg_2d(data, func): def quantized_mean(data, input_scale, input_zero_point, func): + # refer to aten/src/ATen/native/quantized/cpu/qreduction.cpp dequantized = relay.qnn.op.dequantize(data, input_scale, input_zero_point) out = func(dequantized) return relay.qnn.op.quantize(out, input_scale, input_zero_point, @@ -300,6 +364,7 @@ def quantized_mean(data, input_scale, input_zero_point, func): def quantized_upsample(data, input_scale, input_zero_point, func): + # currently piggy backs to fp32, it gets identical output as torch data = relay.qnn.op.dequantize(data, input_scale, input_zero_point) out = func(data) return relay.qnn.op.quantize(out, input_scale, input_zero_point, @@ -307,6 +372,7 @@ def quantized_upsample(data, input_scale, input_zero_point, func): def quantized_relu(data, input_zero_point): + # refer to aten/src/ATen/native/quantized/cpu/qrelu.cpp zp = _op.cast(input_zero_point, dtype="uint8") return _op.tensor.maximum(data, zp) @@ -353,6 +419,8 @@ def _impl(inputs, _): output_zero_point = _expr.const(inputs[7]) assert len(inputs) == 10, "Input quant params not found in op inputs" + # These are manually added by add_input_quant_params_to_op_inputs above + # In torch, they are retrieved from QTensor data structure at runtime input_scale = _expr.const(inputs[8]) input_zero_point = _expr.const(inputs[9]) @@ -412,10 +480,13 @@ def _impl(inputs, _): def _binop(relay_op, with_relu=False): + # refer to aten/src/ATen/native/quantized/cpu/{qadd, qmul}.cpp + # they piggy backs to fp32 math by dequantize -> fp32 math -> quantize def _impl(inputs, _): output_scale = _expr.const(inputs[2]) output_zero_point = _expr.const(inputs[3]) assert len(inputs) == 8, "Input quant params not found in op inputs" + # Manually added by add_input_quant_params_to_op_inputs above input_scale_lhs = _expr.const(inputs[4]) input_zero_point_lhs = _expr.const(inputs[5]) input_scale_rhs = _expr.const(inputs[6]) @@ -450,6 +521,7 @@ def _impl(inputs, _): def _linear(with_relu=False): + # 
similar to conv def _impl(inputs, _): weight = inputs[1][0] weight_scale = inputs[1][1] @@ -457,6 +529,7 @@ def _impl(inputs, _): output_scale = _expr.const(inputs[2]) output_zero_point = _expr.const(inputs[3]) assert len(inputs) == 6, "Input quant params not found in op inputs" + # Manually added by add_input_quant_params_to_op_inputs above input_scale = _expr.const(inputs[4]) input_zero_point = _expr.const(inputs[5]) @@ -493,6 +566,9 @@ def _impl(inputs, _): def _cat(): + # refer to aten/src/ATen/native/quantized/cpu/qconcat.cpp + # for concat they also piggy backs to fp32(!) + # dequantize -> fp32 math -> quantize def _impl(inputs, _): axis = inputs[1] output_scale = _expr.const(inputs[2]) @@ -516,6 +592,8 @@ def _impl(inputs, _): def _add_scalar(): def _impl(inputs, _): # refer to aten/src/ATen/native/quantized/cpu/qadd.cpp + # math for calculating output scale and zp are already done + # during _add_output_quant_params_to_scalar_op above assert len(inputs) == 6, "Input quant params not found in op inputs" s = inputs[4] z = inputs[5] @@ -545,6 +623,7 @@ def quantize_scalar(data, scale, zero_point): def _relu6(): + # refer to src/ATen/native/quantized/cpu/qrelu.cpp def _impl(inputs, _): assert len(inputs) == 4, "Input quant params not found in op inputs" input_scale = inputs[2] @@ -557,6 +636,8 @@ def _impl(inputs, _): def _mul_scalar(): def _impl(inputs, _): # refer to aten/src/ATen/native/quantized/cpu/qmul.cpp + # math for calculating output scale and zp are already done + # during _add_output_quant_params_to_scalar_op above assert len(inputs) == 6, "Input quant params not found in op inputs" other_val = inputs[1] # scalar From a6239e19b430b1f069010e94dac939743dd36ff5 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 3 Mar 2020 06:50:00 +0900 Subject: [PATCH 12/14] address comments, add sample outputs --- python/tvm/relay/frontend/pytorch.py | 2 +- python/tvm/relay/frontend/qnn_torch.py | 12 +-- tests/python/frontend/pytorch/qnn_test.py | 104 ++++++++++++++++++---- 3 files changed, 96 insertions(+), 22 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index f0bdd806975e..8249345ae4ee 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -151,7 +151,7 @@ def _impl(inputs, input_types): data = inputs[0] if input_types[0] == "quint8": assert len(inputs) == 3, "Input quant param not found in op inputs" - input_zero_point = _expr.const(inputs[2]) + input_zero_point = _expr.const(inputs[2], dtype="int32") return qnn_torch.quantized_relu(data, input_zero_point) return _op.nn.relu(data) return _impl diff --git a/python/tvm/relay/frontend/qnn_torch.py b/python/tvm/relay/frontend/qnn_torch.py index 2f0553cfc7fd..502fea46a3b2 100644 --- a/python/tvm/relay/frontend/qnn_torch.py +++ b/python/tvm/relay/frontend/qnn_torch.py @@ -26,7 +26,7 @@ from tvm.relay.frontend.common import infer_shape -class QuantParam: +class QNNParam: """ A placeholder for weight quantization parameters """ def __init__(self, weight, bias, scale, zero_point, param_key): @@ -55,13 +55,13 @@ def _unpack_quant_params(param_name, packed_params, unpack_func): import torch if qweight.qscheme() == torch.per_tensor_affine: - param = QuantParam(weight_np, bias, qweight.q_scale(), - int(qweight.q_zero_point()), param_name) + param = QNNParam(weight_np, bias, qweight.q_scale(), + int(qweight.q_zero_point()), param_name) else: scales = qweight.q_per_channel_scales().numpy() zero_points = qweight.q_per_channel_zero_points().numpy() 
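# An illustrative snippet, using a dummy weight tensor, of the per-channel
# layout unpacked in this branch: torch stores one scale per output channel,
# and the zero points read back below are expected to be all zero (symmetric
# per-channel quantization), which is what QNN supports.
import torch
w = torch.randn(8, 4, 3, 3)
per_ch_scales = (w.abs().reshape(8, -1).max(dim=1).values / 127.0).double()
per_ch_zps = torch.zeros(8, dtype=torch.int64)
qw = torch.quantize_per_channel(w, per_ch_scales, per_ch_zps,
                                axis=0, dtype=torch.qint8)
print(qw.q_per_channel_scales(), qw.q_per_channel_zero_points())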
assert np.all(zero_points == 0) - param = QuantParam(weight_np, bias, scales, 0, param_name) + param = QNNParam(weight_np, bias, scales, 0, param_name) return param @@ -119,7 +119,7 @@ def _get_quant_param_for_input(input_value): are embeded in a QTensor data structure, not visible statically). We know that it is quantized using output scale and zp of some previous quantized op. The purpose of this function - is to find that pair of paramters. + is to find that pair of parameters. """ # Indices for output scale and zp # For example, in quantized::conv2d(%input, %1, %2, %3, %4, %5, %6, %7), @@ -245,7 +245,7 @@ def _add_output_quant_params_to_scalar_op(node, graph, _get_add_scalar_output_quant_param(input_scale, input_zero_point, scalar) else: - assert False, "unsupported scalar op: %s" % operator + raise NotImplementedError("unsupported scalar op: %s" % operator) # create new constant nodes and add them to graph out_scale_node = graph.create("prim::Constant") diff --git a/tests/python/frontend/pytorch/qnn_test.py b/tests/python/frontend/pytorch/qnn_test.py index 6607fea40e57..846a9ba03bf1 100644 --- a/tests/python/frontend/pytorch/qnn_test.py +++ b/tests/python/frontend/pytorch/qnn_test.py @@ -282,15 +282,34 @@ def test_quantized_modules(): runtime.run() tvm_result = runtime.get_output(0).asnumpy() - # we cannot make any guarantee on how close the raw output is to torch - # tvm.testing.assert_allclose(tvm_result, pt_result, rtol=1e-1, atol=1e-1) - max_abs_diff = np.max(np.abs(tvm_result - pt_result)) mean_abs_diff = np.mean(np.abs(tvm_result - pt_result)) num_identical = np.sum(tvm_result == pt_result) - correct_ratio = num_identical / float(np.prod(tvm_result.shape)) + match_ratio = num_identical / float(np.prod(tvm_result.shape)) + + print(module_name, max_abs_diff, mean_abs_diff, match_ratio) + + # sample outputs + """ + relu 0.0039215684 2.6052087e-08 0.9999933567176871 + upsample bilinear 0.0 0.0 1.0 + conv_bn 0.22062653 0.011478779 0.6909348115006899 + conv_bn_relu 0.3700896 0.010921672 0.7489366477964451 + linear 0.15987062 0.009231662 0.794921875 + linear_relu 0.14180502 0.0053220326 0.8828125 + conv_bn, per_channel 0.01654929 2.9486866e-06 0.9998218235127019 + conv_bn_relu, per_channel 0.009089053 1.4926576e-06 0.9998357732732732 + linear, per_channel 0.0 0.0 1.0 + linear_relu, per_channel 0.0 0.0 1.0 + hsigmoid 0.002614379 0.00020525524 0.9214896896258503 + hswish 0.0052286386 0.00063522335 0.7587359162414966 + semodule, per_channel 0.0039885044 0.0008620687 0.7838592529296875 + mul_scalar negative 0.0011764616 7.815566e-09 0.9999933567176871 + """ - print(module_name, max_abs_diff, mean_abs_diff, correct_ratio) + # we cannot make any guarantee on how close the raw output is to torch + # tvm.testing.assert_allclose(tvm_result, pt_result, rtol=1e-1, atol=1e-1) + assert match_ratio > 0.6 def test_quantized_imagenet(): @@ -361,22 +380,77 @@ def get_imagenet_input(): results.append((model_name, pt_result[0], tvm_result[0])) - pt_top3_labels = np.argsort(pt_result[0])[::-1][:3] - tvm_top3_labels = np.argsort(pt_result[0])[::-1][:3] - - assert set(pt_top3_labels) == set(tvm_top3_labels) - - print("Torch top3 label:", pt_top3_labels) - print("TVM top3 label:", tvm_top3_labels) - for (model_name, pt_result, tvm_result) in results: max_abs_diff = np.max(np.abs(tvm_result - pt_result)) mean_abs_diff = np.mean(np.abs(tvm_result - pt_result)) num_identical = np.sum(tvm_result == pt_result) + pt_top3_labels = np.argsort(pt_result)[::-1][:3] + tvm_top3_labels = np.argsort(pt_result)[::-1][:3] 
print("\nModel name: %s" % model_name) - print("PyTorch top3 label:", np.argsort(pt_result)[::-1][:3]) - print("TVM top3 label:", np.argsort(tvm_result)[::-1][:3]) + print("PyTorch top3 label:", pt_top3_labels) + print("TVM top3 label:", tvm_top3_labels) print("max abs diff:", max_abs_diff) print("mean abs_diff:", mean_abs_diff) print("%d in 1000 raw outputs identical." % num_identical) + + assert set(pt_top3_labels) == set(tvm_top3_labels) + + # sample outputs + """ + Model name: resnet18, per tensor quantization + PyTorch top3 label: [386 101 385] + TVM top3 label: [386 101 385] + max abs diff: 0.65681696 + mean abs_diff: 0.14055882 + 236 in 1000 raw outputs identical. + + Model name: mobilenet_v2, per tensor quantization + PyTorch top3 label: [101 386 385] + TVM top3 label: [101 386 385] + max abs diff: 2.1262953 + mean abs_diff: 0.41025686 + 101 in 1000 raw outputs identical. + + Model name: inception_v3, per tensor quantization + PyTorch top3 label: [101 386 385] + TVM top3 label: [101 386 385] + max abs diff: 0.9994669 + mean abs_diff: 0.098697364 + 272 in 1000 raw outputs identical. + + Model name: googlenet, per tensor quantization + PyTorch top3 label: [101 386 385] + TVM top3 label: [101 386 385] + max abs diff: 0.28248847 + mean abs_diff: 0.0634469 + 274 in 1000 raw outputs identical. + + Model name: resnet18, per channel quantization + PyTorch top3 label: [101 386 385] + TVM top3 label: [101 386 385] + max abs diff: 0.65908074 + mean abs_diff: 0.1274223 + 469 in 1000 raw outputs identical. + + Model name: mobilenet_v2, per channel quantization + PyTorch top3 label: [101 386 385] + TVM top3 label: [101 386 385] + max abs diff: 0.71120834 + mean abs_diff: 0.15883648 + 423 in 1000 raw outputs identical. + + Model name: inception_v3, per channel quantization + PyTorch top3 label: [386 101 385] + TVM top3 label: [386 101 385] + max abs diff: 1.3372154 + mean abs_diff: 0.1225224 + 401 in 1000 raw outputs identical. + + Model name: googlenet, per channel quantization + PyTorch top3 label: [101 386 385] + TVM top3 label: [101 386 385] + max abs diff: 0.34015465 + mean abs_diff: 0.054197952 + 558 in 1000 raw outputs identical. 
+ """ From d33ac9941d35ddb92ae42d674154b453e4030d05 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 3 Mar 2020 07:56:50 +0900 Subject: [PATCH 13/14] add more comments --- python/tvm/relay/frontend/pytorch.py | 3 +- python/tvm/relay/frontend/qnn_torch.py | 44 ++++++++++++++++++++------ 2 files changed, 37 insertions(+), 10 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 8249345ae4ee..4cd24f401011 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1060,7 +1060,8 @@ def parse_params(graph, state_dict): full_attr_node_name = _get_output_name(getattrs[-1]) if full_attr.endswith("_packed_params"): # for quantized models - assert full_attr in state_dict + err_msg = "parameter %s not found in state dict" % full_attr + assert full_attr in state_dict, err_msg packed_param_map[full_attr_node_name] = full_attr elif full_attr in state_dict: torch_tensor = state_dict[full_attr] diff --git a/python/tvm/relay/frontend/qnn_torch.py b/python/tvm/relay/frontend/qnn_torch.py index 502fea46a3b2..61848984a018 100644 --- a/python/tvm/relay/frontend/qnn_torch.py +++ b/python/tvm/relay/frontend/qnn_torch.py @@ -60,7 +60,9 @@ def _unpack_quant_params(param_name, packed_params, unpack_func): else: scales = qweight.q_per_channel_scales().numpy() zero_points = qweight.q_per_channel_zero_points().numpy() - assert np.all(zero_points == 0) + # This is an assumption posed by QNN + msg = "The values of zero points should be all zero for per channel" + assert np.all(zero_points == 0), msg param = QNNParam(weight_np, bias, scales, 0, param_name) return param @@ -72,6 +74,8 @@ def get_weight_quant_params(script_module): linear_packed_params = [] import torch + # conv and linear requires different unpacking function + # extract all conv and linear parameters separately to distinguish them for name, m in script_module.named_modules(): if isinstance(m, torch.jit.RecursiveScriptModule): if "Conv" in m.original_name: @@ -165,6 +169,7 @@ def _get_add_scalar_output_quant_param(input_scale, input_zero_point, Determine the output scale and zp of quantized::add_scalar op This is used for mobilenet v3 Refer to aten/src/ATen/native/quantized/cpu/qadd.cpp + The names of variables are the same as torch impl """ q_min = 0 q_max = 255 @@ -192,6 +197,7 @@ def _get_mul_scalar_output_quant_param(input_scale, input_zero_point, Determine the output scale and zp of quantized::mul_scalar op This is used for mobilenet v3 Refer to aten/src/ATen/native/quantized/cpu/qmul.cpp + The names of variables are the same as torch impl """ q_min = 0 q_max = 255 @@ -316,6 +322,8 @@ def add_input_quant_params_to_op_inputs(graph): input_zero_points = [] if operator == "quantized::cat": + # the number of inputs to concat is not constant + # so handle it separately inputs = node.inputsAt(0).node().inputs() for inp in inputs: scale, zp = _get_quant_param_for_input(inp) @@ -332,6 +340,7 @@ def add_input_quant_params_to_op_inputs(graph): inp_scale = input_scales[0].node().f("value") inp_zero_point = input_zero_points[0].node().i("value") + # see the comments in this function above _add_output_quant_params_to_scalar_op(node, graph, inp_scale, inp_zero_point, scalar) @@ -349,24 +358,25 @@ def add_quant_params(params, quant_params): params[qparam.bias_var.name_hint] = tvm.nd.array(qparam.bias) -def quantized_adaptive_avg_2d(data, func): +def quantized_adaptive_avg_2d(data, func_fp32): + # this follows tflite impl inp = _op.cast(data, dtype="int32") - out = 
func(inp) + out = func_fp32(inp) return _op.cast(out, "uint8") -def quantized_mean(data, input_scale, input_zero_point, func): +def quantized_mean(data, input_scale, input_zero_point, func_fp32): # refer to aten/src/ATen/native/quantized/cpu/qreduction.cpp dequantized = relay.qnn.op.dequantize(data, input_scale, input_zero_point) - out = func(dequantized) + out = func_fp32(dequantized) return relay.qnn.op.quantize(out, input_scale, input_zero_point, out_dtype="uint8", axis=1) -def quantized_upsample(data, input_scale, input_zero_point, func): +def quantized_upsample(data, input_scale, input_zero_point, func_fp32): # currently piggy backs to fp32, it gets identical output as torch data = relay.qnn.op.dequantize(data, input_scale, input_zero_point) - out = func(data) + out = func_fp32(data) return relay.qnn.op.quantize(out, input_scale, input_zero_point, out_dtype="uint8", axis=1) @@ -387,6 +397,7 @@ def _impl(inputs, _): def _dequantize(): def _impl(inputs, _): + assert len(inputs) == 3, "Input quant params not found in op inputs" inp_scale = _expr.const(inputs[1]) inp_zero_point = _expr.const(inputs[2]) return relay.qnn.op.dequantize(inputs[0], inp_scale, inp_zero_point) @@ -444,6 +455,8 @@ def _impl(inputs, _): else: inp = inputs[0] + # padding is (0, 0) because we did explicit pad op with + # pad value being zero point above conv_out = relay.qnn.op.conv2d(inp, weight, input_zero_point, weight_zero_point, input_scale, weight_scale, @@ -456,6 +469,13 @@ def _impl(inputs, _): requant_input_scale = _expr.const(inputs[8] * _get_numpy(weight_scale)) bias_var = inputs[1][3] + # Torch does bias add and requanize scale in fp32 + # refer to third_party/fbgemm/include/fbgemm/OutputProcessing-inl.h + # Instead, we do bias add in int32 and use qnn requantize, which needs + # integer input. + # We observed no loss in accuracy in doing this way, and it is better + # for tvm because bias quantization can be done at compile time + # Instead, the torch way requires rounding of activation at runtime if bias_var is not None: qbias = relay.qnn.op.quantize(bias_var, requant_input_scale, _expr.const(0, "int32"), @@ -542,6 +562,7 @@ def _impl(inputs, _): requant_input_scale = _expr.const(inputs[4] * _get_numpy(weight_scale)) bias_var = inputs[1][3] + # See comments at quantized_conv above on the bias + requantize step if bias_var is not None: qbias = relay.qnn.op.quantize(bias_var, requant_input_scale, _expr.const(0, "int32"), @@ -569,6 +590,7 @@ def _cat(): # refer to aten/src/ATen/native/quantized/cpu/qconcat.cpp # for concat they also piggy backs to fp32(!) # dequantize -> fp32 math -> quantize + # we can also use QNN concat op. 
we observed no change in accuracy def _impl(inputs, _): axis = inputs[1] output_scale = _expr.const(inputs[2]) @@ -590,10 +612,9 @@ def _impl(inputs, _): def _add_scalar(): + # this is used for mobilenet v3 def _impl(inputs, _): # refer to aten/src/ATen/native/quantized/cpu/qadd.cpp - # math for calculating output scale and zp are already done - # during _add_output_quant_params_to_scalar_op above assert len(inputs) == 6, "Input quant params not found in op inputs" s = inputs[4] z = inputs[5] @@ -602,6 +623,8 @@ def _impl(inputs, _): q_min = 0 q_max = 255 + # math for calculating output scale and zp are already done + # during _add_output_quant_params_to_scalar_op above out_scale = _expr.const(inputs[2]) out_zp = _expr.const(inputs[3]) @@ -618,6 +641,7 @@ def _impl(inputs, _): def quantize_scalar(data, scale, zero_point): + # used to quantize 6., in mobilenet v3 transformed = zero_point + data / scale return max(0, min(round(transformed), 255)) @@ -634,6 +658,7 @@ def _impl(inputs, _): def _mul_scalar(): + # this is used for mobilenet v3 def _impl(inputs, _): # refer to aten/src/ATen/native/quantized/cpu/qmul.cpp # math for calculating output scale and zp are already done @@ -648,6 +673,7 @@ def _impl(inputs, _): shape = infer_shape(inputs[0]) return _op.full(_expr.const(0), shape, dtype="uint8") + # negative scale case q_min = 0 q_max = 255 bias = _expr.const(q_max + q_min, dtype="int8") From 80b1a61125500fb3f60ecef01fb6918acc5ba860 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 3 Mar 2020 09:00:21 +0900 Subject: [PATCH 14/14] refactor bias add and requantize step --- python/tvm/relay/frontend/qnn_torch.py | 139 ++++++++++------------ tests/python/frontend/pytorch/qnn_test.py | 1 - 2 files changed, 65 insertions(+), 75 deletions(-) diff --git a/python/tvm/relay/frontend/qnn_torch.py b/python/tvm/relay/frontend/qnn_torch.py index 61848984a018..0704e34b77ef 100644 --- a/python/tvm/relay/frontend/qnn_torch.py +++ b/python/tvm/relay/frontend/qnn_torch.py @@ -412,6 +412,42 @@ def _get_scalar(relay_const_scalar): return np.asscalar(_get_numpy(relay_const_scalar)) +def _do_bias_and_requantize(output, bias, input_scale, weight_scale, + output_scale, output_zero_point, + with_relu): + """ Output processing for conv and linear """ + # this is a vector for per channel case + requant_input_scale = _expr.const(_get_numpy(input_scale) * + _get_numpy(weight_scale)) + # Torch does bias add and requanize scale in fp32 + # refer to third_party/fbgemm/include/fbgemm/OutputProcessing-inl.h + # Instead, we do bias add in int32 and use qnn requantize, which needs + # integer input. + # We observed no loss in accuracy in doing this way, and it is better + # for tvm because bias quantization can be done at compile time + # Instead, the torch way requires rounding of activation at runtime + + if bias is not None: + qbias = relay.qnn.op.quantize(bias, requant_input_scale, + _expr.const(0, "int32"), + out_dtype="int32", axis=0) + requantize_input = _op.nn.bias_add(output, qbias) + else: + requantize_input = output + + requantized = relay.qnn.op.requantize(requantize_input, + requant_input_scale, + relay.const(0, 'int32'), + output_scale, output_zero_point, + out_dtype="int32", axis=1) + clip_min = 0 + if with_relu: + clip_min = _get_scalar(output_zero_point) + + clip = _op.tensor.clip(requantized, clip_min, 255.) 
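# A small numpy check, with made-up numbers, of the requantize scale used
# here: since real = scale * (q - zp), the int32 accumulator of a quantized
# conv or dense represents the fp32 result at exactly
# input_scale * weight_scale, which is why the bias is quantized and the
# output requantized at that combined scale.
import numpy as np
s_x, z_x = 0.02, 3                               # assumed activation params
s_w = 0.005                                      # assumed weight scale, zp 0
x_q = np.array([10, 7], dtype=np.int64)
w_q = np.array([4, -2], dtype=np.int64)
acc = np.dot(x_q - z_x, w_q)                     # int32-style accumulation
fp32_ref = np.dot((x_q - z_x) * s_x, w_q * s_w)  # dequantized reference
assert np.isclose(fp32_ref, acc * (s_x * s_w))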
+ return _op.cast(clip, dtype="uint8") + + def _quantized_conv2d(with_relu=False): def _impl(inputs, _): # refer to src/ATen/native/quantized/cpu/qconv.cpp @@ -464,37 +500,38 @@ def _impl(inputs, _): dilation=dilation, strides=strides, padding=(0, 0), groups=groups, channels=out_channels) - - # input scale * weight scale - requant_input_scale = _expr.const(inputs[8] * _get_numpy(weight_scale)) bias_var = inputs[1][3] - # Torch does bias add and requanize scale in fp32 - # refer to third_party/fbgemm/include/fbgemm/OutputProcessing-inl.h - # Instead, we do bias add in int32 and use qnn requantize, which needs - # integer input. - # We observed no loss in accuracy in doing this way, and it is better - # for tvm because bias quantization can be done at compile time - # Instead, the torch way requires rounding of activation at runtime - if bias_var is not None: - qbias = relay.qnn.op.quantize(bias_var, requant_input_scale, - _expr.const(0, "int32"), - out_dtype="int32", axis=0) - conv_res = _op.nn.bias_add(conv_out, qbias) - else: - conv_res = conv_out - - requantized = relay.qnn.op.requantize(conv_res, - requant_input_scale, - _expr.const(0, "int32"), - output_scale, output_zero_point, - out_dtype="int32", axis=1) - clip_min = 0 - if with_relu: - clip_min = _get_scalar(output_zero_point) + return _do_bias_and_requantize(conv_out, bias_var, input_scale, + weight_scale, output_scale, + output_zero_point, with_relu) + + return _impl + + +def _linear(with_relu=False): + # similar to conv + def _impl(inputs, _): + weight = inputs[1][0] + weight_scale = inputs[1][1] + weight_zero_point = inputs[1][2] + output_scale = _expr.const(inputs[2]) + output_zero_point = _expr.const(inputs[3]) + assert len(inputs) == 6, "Input quant params not found in op inputs" + # Manually added by add_input_quant_params_to_op_inputs above + input_scale = _expr.const(inputs[4]) + input_zero_point = _expr.const(inputs[5]) - clip = _op.tensor.clip(requantized, clip_min, 255.) 
- return _op.cast(clip, dtype="uint8") + weight_shape = infer_shape(weight) + dense = relay.qnn.op.dense(inputs[0], weight, + input_zero_point, weight_zero_point, + input_scale, weight_scale, + units=weight_shape[0]) + bias_var = inputs[1][3] + + return _do_bias_and_requantize(dense, bias_var, input_scale, + weight_scale, output_scale, + output_zero_point, with_relu) return _impl @@ -540,52 +577,6 @@ def _impl(inputs, _): return _impl -def _linear(with_relu=False): - # similar to conv - def _impl(inputs, _): - weight = inputs[1][0] - weight_scale = inputs[1][1] - weight_zero_point = inputs[1][2] - output_scale = _expr.const(inputs[2]) - output_zero_point = _expr.const(inputs[3]) - assert len(inputs) == 6, "Input quant params not found in op inputs" - # Manually added by add_input_quant_params_to_op_inputs above - input_scale = _expr.const(inputs[4]) - input_zero_point = _expr.const(inputs[5]) - - weight_shape = infer_shape(weight) - dense = relay.qnn.op.dense(inputs[0], weight, - input_zero_point, weight_zero_point, - input_scale, weight_scale, - units=weight_shape[0]) - - requant_input_scale = _expr.const(inputs[4] * _get_numpy(weight_scale)) - bias_var = inputs[1][3] - - # See comments at quantized_conv above on the bias + requantize step - if bias_var is not None: - qbias = relay.qnn.op.quantize(bias_var, requant_input_scale, - _expr.const(0, "int32"), - out_dtype="int32", axis=0) - dense_res = _op.nn.bias_add(dense, qbias) - else: - dense_res = dense - - requantized = relay.qnn.op.requantize(dense_res, - requant_input_scale, - relay.const(0, 'int32'), - output_scale, output_zero_point, - out_dtype="int32", axis=1) - clip_min = 0 - if with_relu: - clip_min = _get_scalar(output_zero_point) - - clip = _op.tensor.clip(requantized, clip_min, 255.) - return _op.cast(clip, dtype="uint8") - - return _impl - - def _cat(): # refer to aten/src/ATen/native/quantized/cpu/qconcat.cpp # for concat they also piggy backs to fp32(!) diff --git a/tests/python/frontend/pytorch/qnn_test.py b/tests/python/frontend/pytorch/qnn_test.py index 846a9ba03bf1..e3a876c79591 100644 --- a/tests/python/frontend/pytorch/qnn_test.py +++ b/tests/python/frontend/pytorch/qnn_test.py @@ -309,7 +309,6 @@ def test_quantized_modules(): # we cannot make any guarantee on how close the raw output is to torch # tvm.testing.assert_allclose(tvm_result, pt_result, rtol=1e-1, atol=1e-1) - assert match_ratio > 0.6 def test_quantized_imagenet():
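For reference, a minimal sketch of producing a quantized TorchScript module of
the kind exercised by these tests, using eager-mode static quantization; the
model choice and the single random calibration batch are placeholders, not
taken from the test code:

import torch
from torchvision.models.quantization import resnet18

model = resnet18(pretrained=True).eval()
model.fuse_model()                             # fuse conv + bn + relu modules
model.qconfig = torch.quantization.get_default_qconfig("fbgemm")
torch.quantization.prepare(model, inplace=True)
model(torch.rand(1, 3, 224, 224))              # calibrate observers
torch.quantization.convert(model, inplace=True)
script_module = torch.jit.trace(model, torch.rand(1, 3, 224, 224)).eval()

The resulting script_module, together with the input shape, is what the
get_tvm_runtime helper above converts and builds with relay.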