From 243784d57fc6d71338e042b37fa8ce3ece2a7bb7 Mon Sep 17 00:00:00 2001 From: ibsidorenko Date: Mon, 6 Jun 2022 13:00:12 +0300 Subject: [PATCH 01/25] [QNN] Disable QNN canonicalization pass. This commit enables work of TVM without QNN canonicalization pass. It adds new TOPI ops for QNN + simple compute/schedules. --- include/tvm/runtime/data_type.h | 10 + python/tvm/relay/qnn/op/_qnn.py | 27 +- python/tvm/relay/qnn/op/qnn.py | 7 - python/tvm/relay/qnn/strategy/__init__.py | 23 ++ python/tvm/relay/qnn/strategy/generic.py | 157 +++++++++++ python/tvm/relay/qnn/strategy/hexagon.py | 104 +++++++ python/tvm/te/__init__.py | 1 + python/tvm/tir/__init__.py | 1 + python/tvm/topi/hexagon/__init__.py | 1 + python/tvm/topi/hexagon/qnn.py | 314 ++++++++++++++++++++++ src/relay/backend/te_compiler_cache.cc | 109 +++++++- src/relay/backend/utils.cc | 2 +- src/relay/transforms/fuse_ops.cc | 4 +- 13 files changed, 747 insertions(+), 13 deletions(-) create mode 100644 python/tvm/relay/qnn/strategy/__init__.py create mode 100644 python/tvm/relay/qnn/strategy/generic.py create mode 100644 python/tvm/relay/qnn/strategy/hexagon.py create mode 100644 python/tvm/topi/hexagon/qnn.py diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h index e0c3106e14fa..7f68ce2ad5bb 100644 --- a/include/tvm/runtime/data_type.h +++ b/include/tvm/runtime/data_type.h @@ -124,6 +124,16 @@ class DataType { * \return the result type. */ DataType element_of() const { return with_lanes(1); } + /*! + * \brief Assignment operator. + */ + DataType& operator=(const DataType& rhs) { + if (this == &rhs) { + return *this; + } + data_ = rhs.data_; + return *this; + } /*! * \brief Equal comparator. * \param other The data type to compare against. diff --git a/python/tvm/relay/qnn/op/_qnn.py b/python/tvm/relay/qnn/op/_qnn.py index a059c293a0f8..72232bdffc20 100644 --- a/python/tvm/relay/qnn/op/_qnn.py +++ b/python/tvm/relay/qnn/op/_qnn.py @@ -19,9 +19,10 @@ from tvm import topi +from .. 
import strategy from ...op.op import register_compute from ...op.op import register_injective_schedule -from ...op.op import register_pattern, OpPattern +from ...op.op import register_strategy, register_pattern, OpPattern @register_compute("qnn.simulated_quantize") @@ -50,3 +51,27 @@ def simulated_dequantize_compute(attrs, inputs, output_type): register_injective_schedule("qnn.simulated_dequantize") register_pattern("qnn.simulated_dequantize", OpPattern.ELEMWISE) + +# qnn.quantize +register_strategy("qnn.quantize", strategy.qnn_quantize_strategy) +register_pattern("qnn.quantize", OpPattern.ELEMWISE) + +# qnn.dequantize +register_strategy("qnn.dequantize", strategy.qnn_dequantize_strategy) +register_pattern("qnn.dequantize", OpPattern.ELEMWISE) + +# qnn.requantize +register_strategy("qnn.requantize", strategy.qnn_requantize_strategy) +register_pattern("qnn.requantize", OpPattern.ELEMWISE) + +# qnn.add +register_strategy("qnn.add", strategy.qnn_add_strategy) +register_pattern("qnn.add", OpPattern.BROADCAST) + +# qnn.conv2d +register_strategy("qnn.conv2d", strategy.qnn_conv2d_strategy) +register_pattern("qnn.conv2d", OpPattern.OUT_ELEMWISE_FUSABLE) + +# qnn.dense +register_strategy("qnn.dense", strategy.qnn_dense_strategy) +register_pattern("qnn.dense", OpPattern.OUT_ELEMWISE_FUSABLE) diff --git a/python/tvm/relay/qnn/op/qnn.py b/python/tvm/relay/qnn/op/qnn.py index 1f383851071b..78d6669413ca 100644 --- a/python/tvm/relay/qnn/op/qnn.py +++ b/python/tvm/relay/qnn/op/qnn.py @@ -29,8 +29,6 @@ from tvm.topi.nn.qnn import SQNN_DTYPE_TO_CODE from tvm.topi.x86.utils import target_has_sse41 -from ... import op as reg -from ...op import OpPattern from . 
import _make, _requantize @@ -1212,11 +1210,6 @@ def batch_matmul(x, y, x_zero_point, y_zero_point, x_scale, y_scale, out_dtype=" return _make.batch_matmul(x, y, x_zero_point, y_zero_point, x_scale, y_scale, out_dtype) -# register fuse pattern for qnn ops -reg.register_pattern("qnn.quantize", OpPattern.OPAQUE) -reg.register_pattern("qnn.dequantize", OpPattern.OPAQUE) - - def leaky_relu(x, alpha, input_scale, input_zero_point, output_scale, output_zero_point): """Quantized leaky relu. diff --git a/python/tvm/relay/qnn/strategy/__init__.py b/python/tvm/relay/qnn/strategy/__init__.py new file mode 100644 index 000000000000..05778c3e9f86 --- /dev/null +++ b/python/tvm/relay/qnn/strategy/__init__.py @@ -0,0 +1,23 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# pylint: disable=wildcard-import +"""QNN op strategies.""" +from __future__ import absolute_import as _abs + +from .generic import * +from . 
import hexagon diff --git a/python/tvm/relay/qnn/strategy/generic.py b/python/tvm/relay/qnn/strategy/generic.py new file mode 100644 index 000000000000..01ef6fcdf586 --- /dev/null +++ b/python/tvm/relay/qnn/strategy/generic.py @@ -0,0 +1,157 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Definition of generic operator strategy.""" + +from tvm.target import override_native_generic_func + + +def wrap_topi_schedule(topi_schedule): + """Wrap TOPI schedule which doesn't use attrs""" + + def wrapper(_attrs, outs, target): + with target: + return topi_schedule(outs) + + return wrapper + + +def wrap_topi_compute(topi_compute): + """Wrap TOPI compute which doesn't use attrs""" + + def wrapper(_attrs, inputs, _out_type): + return [topi_compute(*inputs)] + + return wrapper + + +def wrap_compute_quantize(topi_compute): + """Wrap TOPI compute which use out data type from attrs""" + + def wrapper(attrs, inputs, _out_type): + out_dtype = attrs.out_dtype + args = [*inputs, out_dtype] + return [topi_compute(*args)] + + return wrapper + + +def wrap_topi_qnn_conv2d(topi_compute): + """Wrap TOPI compute which use conv2d attrs and output data type""" + + def wrapper(attrs, inputs, out_type): + oshape = out_type.shape + out_dtype = attrs.out_dtype + strides = attrs.strides + padding = attrs.padding + dilation = attrs.dilation + if len([*inputs]) == 11: + args = [*inputs, strides, padding, dilation, oshape, out_dtype] + elif len([*inputs]) == 10: + args = [ # QNN Conv2d params: + inputs[0], + inputs[1], + inputs[2], + inputs[3], + inputs[4], + inputs[5], + # Bias argument + None, + # Requantization params: + inputs[6], + inputs[7], + inputs[8], + inputs[9], + strides, + padding, + dilation, + oshape, + out_dtype, + ] + else: + assert len([*inputs]) == 6 + args = [ # QNN Conv2d params: + *inputs, + # Bias argument: + None, + # Requantization params: + None, + None, + None, + None, + strides, + padding, + dilation, + oshape, + out_dtype, + ] + return [topi_compute(*args)] + + return wrapper + + +@override_native_generic_func("qnn_quantize_strategy") +def qnn_quantize_strategy(attrs, inputs, out_type, target): + """qnn.quantize generic strategy""" + raise RuntimeError( + "qnn.quantize is currently only supported with Hexagon. 
" + "Please run QNN Canonicalize pass to decompose this op into supported ops." + ) + + +@override_native_generic_func("qnn_dequantize_strategy") +def qnn_dequantize_strategy(attrs, inputs, out_type, target): + """qnn.dequantize generic strategy""" + raise RuntimeError( + "qnn.dequantize is currently only supported with Hexagon. " + "Please run QNN Canonicalize pass to decompose this op into supported ops." + ) + + +@override_native_generic_func("qnn_requantize_strategy") +def qnn_requantize_strategy(attrs, inputs, out_type, target): + """qnn.requantize generic strategy""" + raise RuntimeError( + "qnn.requantize is currently only supported with Hexagon. " + "Please run QNN Canonicalize pass to decompose this op into supported ops." + ) + + +@override_native_generic_func("qnn_add_strategy") +def qnn_add_strategy(attrs, inputs, out_type, target): + """qnn.add generic strategy""" + raise RuntimeError( + "qnn.add is currently only supported with Hexagon. " + "Please run QNN Canonicalize pass to decompose this op into supported ops." + ) + + +@override_native_generic_func("qnn_conv2d_strategy") +def qnn_conv2d_strategy(attrs, inputs, out_type, target): + """qnn.conv2d generic strategy""" + raise RuntimeError( + "qnn.conv2d is currently only supported with Hexagon. " + "Please run QNN Canonicalize pass to decompose this op into supported ops." + ) + + +@override_native_generic_func("qnn_dense_strategy") +def qnn_dense_strategy(attrs, inputs, out_type, target): + """qnn.dense generic strategy""" + raise RuntimeError( + "qnn.dense is currently only supported with Hexagon. " + "Please run QNN Canonicalize pass to decompose this op into supported ops." 
+ ) diff --git a/python/tvm/relay/qnn/strategy/hexagon.py b/python/tvm/relay/qnn/strategy/hexagon.py new file mode 100644 index 000000000000..002ba2da39a1 --- /dev/null +++ b/python/tvm/relay/qnn/strategy/hexagon.py @@ -0,0 +1,104 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Definition of Hexagon operator strategy.""" +# pylint: disable=unused-argument,wildcard-import,unused-wildcard-import + +from tvm import topi +from .generic import * +from ... import op as _op + + +# TODO: This is POC code. Change it on "hexagon" instead of "cpu" +@qnn_quantize_strategy.register("cpu") +def qnn_quantize_strategy_hexagon(attrs, inputs, out_type, target): + """qnn.quantize strategy for Hexagon""" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_quantize(topi.hexagon.qnn_quantize), + wrap_topi_schedule(topi.hexagon.schedule_qnn_quantize), + name="qnn_quantize.hexagon", + ) + return strategy + + +# TODO: This is POC code. 
Change it on "hexagon" instead of "cpu" +@qnn_dequantize_strategy.register("cpu") +def qnn_dequantize_strategy_hexagon(attrs, inputs, out_type, target): + """qnn.dequantize strategy for Hexagon""" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_topi_compute(topi.hexagon.qnn_dequantize), + wrap_topi_schedule(topi.hexagon.schedule_qnn_dequantize), + name="qnn_dequantize.hexagon", + ) + return strategy + + +# TODO: This is POC code. Change it on "hexagon" instead of "cpu" +@qnn_requantize_strategy.register("cpu") +def qnn_requantize_strategy_hexagon(attrs, inputs, out_type, target): + """qnn.requantize strategy for Hexagon""" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_quantize(topi.hexagon.qnn_requantize), + wrap_topi_schedule(topi.hexagon.schedule_qnn_requantize), + name="qnn_requantize.hexagon", + ) + return strategy + + +# TODO: This is POC code. Change it on "hexagon" instead of "cpu" +@qnn_add_strategy.register("cpu") +def qnn_add_strategy_hexagon(attrs, inputs, out_type, target): + """qnn.add strategy for Hexagon""" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_topi_compute(topi.hexagon.qnn_add), + wrap_topi_schedule(topi.hexagon.schedule_qnn_add), + name="qnn_add.hexagon", + ) + return strategy + + +# TODO: This is POC code. Change it on "hexagon" instead of "cpu" +@qnn_conv2d_strategy.register("cpu") +def qnn_conv2d_strategy_hexagon(attrs, inputs, out_type, target): + """qnn.conv2d strategy for Hexagon""" + data_layout = attrs.data_layout + groups = attrs.groups + strategy = _op.OpStrategy() + if groups == 1: + if data_layout == "NCHW": + strategy.add_implementation( + wrap_topi_qnn_conv2d(topi.hexagon.qnn_conv2d), + wrap_topi_schedule(topi.hexagon.schedule_qnn_conv2d), + name="qnn_conv2d.hexagon", + ) + return strategy + + +# TODO: This is POC code. 
Change it on "hexagon" instead of "cpu" +@qnn_dense_strategy.register("cpu") +def qnn_dense_strategy_hexagon(attrs, inputs, out_type, target): + """qnn.dense strategy for Hexagon""" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_quantize(topi.hexagon.qnn_dense), + wrap_topi_schedule(topi.hexagon.schedule_qnn_dense), + name="qnn_dense.hexagon", + ) + return strategy diff --git a/python/tvm/te/__init__.py b/python/tvm/te/__init__.py index a52422f6c1d2..0907ea2ebf85 100644 --- a/python/tvm/te/__init__.py +++ b/python/tvm/te/__init__.py @@ -26,6 +26,7 @@ from tvm.tir import isnan, isfinite, isinf from tvm.tir import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod from tvm.tir import comm_reducer, min, max, sum +from tvm.tir import add, subtract, multiply from .schedule import ( Schedule, diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index 8e637d2d6564..2767f2d5f779 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -74,6 +74,7 @@ from .op import comm_reducer, min, max, sum from .op import q_multiply_shift, shift_left, shift_right from .op import TVMBackendAllocWorkspace, TVMBackendFreeWorkspace +from .generic import add, subtract, multiply from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule, ScheduleError diff --git a/python/tvm/topi/hexagon/__init__.py b/python/tvm/topi/hexagon/__init__.py index b94526e5b919..7172ddfd7af7 100644 --- a/python/tvm/topi/hexagon/__init__.py +++ b/python/tvm/topi/hexagon/__init__.py @@ -25,6 +25,7 @@ from .injective import * from .pad import * from .pooling import * +from .qnn import * from .reduce import * from .resize2d import * from .tensor_intrin import * diff --git a/python/tvm/topi/hexagon/qnn.py b/python/tvm/topi/hexagon/qnn.py new file mode 100644 index 000000000000..99589cc12663 --- /dev/null +++ b/python/tvm/topi/hexagon/qnn.py @@ -0,0 +1,314 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more 
contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Hexagon QNN operators""" +# pylint: disable=invalid-name + +import tvm +from tvm import te +from ..generic.default import default_schedule as _default_schedule +from ..utils import get_const_tuple +from ..nn.utils import get_pad_tuple +from ..nn.pad import pad + + +def qnn_quantize(data, output_scale, output_zero_point, out_dtype): + """Compute for qnn.quantize + Q_output = clamp((round(input_tensor/output_scale) + output_zero_point), + out_dtype::min, + out_dtype::max) + TODO: Support 'axis' argument. + """ + + def _compute(*indices): + value = data(*indices) + const_min = tvm.tir.min_value(out_dtype) + const_max = tvm.tir.max_value(out_dtype) + val = te.add(te.round(te.div(value, output_scale)), output_zero_point) + return te.max(tvm.te.min(val, const_max), const_min).astype(out_dtype) + + return te.compute(data.shape, _compute) + + +def schedule_qnn_quantize(outs): + """Schedule for qnn.quantize + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of qnn.quantize + in the format of an array of tensors. + + Returns + ------- + sch: Schedule + The computation schedule for the op. 
+ """ + return _default_schedule(outs, False) + + +def qnn_dequantize(data, input_scale, input_zero_point): + """Compute for qnn.dequantize + fp_output = input_scale * (Q_input - input_zero_point) + TODO: Support 'axis' argument. + """ + + def _compute(*indices): + value = data(*indices) + return te.multiply(input_scale, te.subtract(value, input_zero_point)) + + return te.compute(data.shape, _compute) + + +def schedule_qnn_dequantize(outs): + """Schedule for qnn.dequantize + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of qnn.dequantize + in the format of an array of tensors. + + Returns + ------- + sch: Schedule + The computation schedule for the op. + """ + return _default_schedule(outs, False) + + +def qnn_requantize(data, input_scale, input_zero_point, output_scale, output_zero_point, out_dtype): + """Compute for qnn.requantize + Q_output = zp_output + round((scale_input)/(scale_output) * (Q_input - zp_input)) + + TODO: support 'axis', 'rounding' and 'compute_dtype' arguments. + """ + + def _compute(*indices): + value = data(*indices) + sub = te.subtract(value, input_zero_point) + mul = te.div(input_scale, output_scale) + val = te.add(te.round(te.multiply(mul, sub)), output_zero_point) + + # clip + cast: + const_min = tvm.tir.min_value(out_dtype) + const_max = tvm.tir.max_value(out_dtype) + return te.max(tvm.te.min(val, const_max), const_min).astype(out_dtype) + + return te.compute(data.shape, _compute) + + +def schedule_qnn_requantize(outs): + """Schedule for qnn.requantize + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of qnn.requantize + in the format of an array of tensors. + + Returns + ------- + sch: Schedule + The computation schedule for the op. 
+ """ + return _default_schedule(outs, False) + + +def qnn_add( + lhs, rhs, lhs_scale, lhs_zero_point, rhs_scale, rhs_zero_point, output_scale, output_zero_point +): + """Compute for qnn.add + Q_output = zp_output + round((lhs_scale)/(scale_output) * (lhs_input - lhs_zp_input)) + + round((rhs_scale)/(scale_output) * (rhs_input - rhs_zp_input)) + + TODO: support 'axis' argument. + """ + + assert lhs.dtype == rhs.dtype + dtype = lhs.dtype + + def _compute(*indices): + lvalue = lhs(*indices) + rvalue = rhs(*indices) + q_lv = te.round( + te.multiply(te.div(lhs_scale, output_scale), te.subtract(lvalue, lhs_zero_point)) + ).astype(dtype) + q_rv = te.round( + te.multiply(te.div(rhs_scale, output_scale), te.subtract(rvalue, rhs_zero_point)) + ).astype(dtype) + return te.add(te.add(q_lv, q_rv), output_zero_point).astype(dtype) + + return te.compute(lhs.shape, _compute) + + +def schedule_qnn_add(outs): + """Schedule for qnn.add + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of qnn.add + in the format of an array of tensors. + + Returns + ------- + sch: Schedule + The computation schedule for the op. 
+ """ + return _default_schedule(outs, False) + + +def qnn_conv2d( # Conv2d inputs + data, + weight, + # Conv2d quantization params: + input_zero_point, + kernel_zero_point, + _input_scale, + _kernel_scale, + # bias + bias, + # Requantization params: + rq_input_scale, + rq_input_zero_point, + rq_output_scale, + rq_output_zero_point, + # Conv2d attributes: + strides, + padding, + dilation, + oshape, + odtype, +): + """Compute for qnn.conv2d with NCHW layout""" + in_channel = data.shape[1] # NCHW layout + kernel_height = weight.shape[2] # OIHW layout + kernel_width = weight.shape[3] # OIHW layout + + height_stride, width_stride = strides + dilation_h, dilation_w = dilation + + dilated_kernel_h = (kernel_height - 1) * dilation_h + 1 + dilated_kernel_w = (kernel_width - 1) * dilation_w + 1 + + pad_top, pad_left, pad_down, pad_right = get_pad_tuple( + get_const_tuple(padding), (dilated_kernel_h, dilated_kernel_w) + ) + + # DOPAD + if pad_top != 0 or pad_down != 0 or pad_left != 0 or pad_right != 0: + pad_before = (0, 0, pad_top, pad_left) + pad_after = (0, 0, pad_down, pad_right) + data_pad = pad(data, pad_before, pad_after, name="data_pad") + else: + data_pad = data + + ic = te.reduce_axis((0, in_channel), name="ic") + kh = te.reduce_axis((0, kernel_height), name="kh") + kw = te.reduce_axis((0, kernel_width), name="kw") + + out = te.compute( + oshape, + lambda n, oc, oh, ow: te.sum( + te.subtract( + data_pad[ + n, + ic, + oh * height_stride + kh * dilation_h, + ow * width_stride + kw * dilation_w, + ], + input_zero_point, + ).astype("int32") + * te.subtract(weight[oc, ic, kh, kw], kernel_zero_point).astype("int32"), + axis=[ic, kh, kw], + ), + ) + + # Add bias (its two trailing dims are asserted to have extent 1, so they are indexed at 0) + if bias is not None: + assert len(out.shape) == len(bias.shape) + assert bias.shape[2] == 1 and bias.shape[3] == 1 + out = te.compute(out.shape, lambda n, c, h, w: out[n, c, h, w] + bias[n, c, 0, 0]) + + def _rq_compute(*indices): + value = out(*indices) + sub = te.subtract(value, rq_input_zero_point) + mul = 
te.div(rq_input_scale, rq_output_scale) + val = te.add(te.round(te.multiply(mul, sub)), rq_output_zero_point) + + # clip + cast: + const_min = tvm.tir.min_value(odtype) + const_max = tvm.tir.max_value(odtype) + return te.max(tvm.te.min(val, const_max), const_min).astype(odtype) + + # Requantize output of convolution + # Q_output = zp_output + round((scale_input)/(scale_output) * (Q_input - zp_input)) + if rq_input_scale is not None and rq_output_scale is not None: + return te.compute(out.shape, _rq_compute) + + return out + + +def schedule_qnn_conv2d(outs): + """Schedule for qnn.conv2d + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of qnn.conv2d + in the format of an array of tensors. + + Returns + ------- + sch: Schedule + The computation schedule for the op. + """ + return _default_schedule(outs, False) + + +def qnn_dense( + data, weight, input_zero_point, kernel_zero_point, _input_scale, _kernel_scale, out_dtype +): + """Compute for qnn.dense""" + M, K = get_const_tuple(data.shape) + N, _ = get_const_tuple(weight.shape) + k = te.reduce_axis((0, K), "k") + return te.compute( + (M, N), + lambda m, n: te.sum( + te.subtract(data[m, k], input_zero_point).astype(out_dtype) + * te.subtract(weight[n, k], kernel_zero_point).astype(out_dtype), + axis=k, + ), + ) + + +def schedule_qnn_dense(outs): + """Schedule for qnn.dense + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of qnn.dense + in the format of an array of tensors. + + Returns + ------- + sch: Schedule + The computation schedule for the op. 
+ """ + return _default_schedule(outs, False) diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc index 9a0a2bef9a47..e923cc38dccd 100644 --- a/src/relay/backend/te_compiler_cache.cc +++ b/src/relay/backend/te_compiler_cache.cc @@ -29,6 +29,7 @@ #include #include #include +#include #include #include #include @@ -123,6 +124,86 @@ Array GetShape(const Array& shape) { return res; } +// Helper class that is used during lowering to TE. +// It matches a sequence of Ops and lowers them into a single TOPI operation. +class PatternMatcher { + public: + PatternMatcher() + : qnn_conv2d_op_(Op::Get("qnn.conv2d")), + qnn_requantize_op_(Op::Get("qnn.requantize")), + bias_add_op_(Op::Get("add")) {} + + // Memoize visited operations + void Register(const CallNode* call_node) { + ICHECK(call_node->op.as()); + Op op = Downcast(call_node->op); + if (op == qnn_conv2d_op_) { + registered_ops_[QConv2d]++; + anchor_op_ = call_node; + } else if (op == qnn_requantize_op_) { + registered_ops_[QRequantize]++; + } else if (op == bias_add_op_) { + registered_ops_[BiasAdd]++; + } + } + + // Check whether given Op is part of matched pattern. + bool find(const Op& op) { + if (registered_ops_.empty()) return false; + + if (op == qnn_conv2d_op_ || op == qnn_requantize_op_ || op == bias_add_op_) { + // Patterns: qnn.conv2d -> qnn.requantize or qnn.conv2d -> bias_add -> qnn.requantize + if (registered_ops_[QConv2d] && registered_ops_[QRequantize]) { + return true; + } + } + return false; + } + + // Returns whether the given Op is last in the pattern sequence. 
+ bool IsLeafOp(const Op& op) { return op == qnn_requantize_op_; } + + LoweredOutput LowerOps(const CallNode* a_op, const CallNode* leaf_op, + const Array& inputs, tvm::Target target) { + static auto flower_call = tvm::runtime::Registry::Get("relay.backend.lower_call"); + ICHECK(flower_call) << "relay.backend.lower_call is not registered."; + ICHECK(a_op == anchor_op_); + + // TODO(ibsidorenko): + // for now this code changes the output data type of the anchor op to that of the + // requantize op. After lowering it restores the previous output data type. + // It would be better to pass the new data type directly to the lowering function. + if (auto* pattr = const_cast(a_op->attrs.as())) { + const auto* requantize_attrs = leaf_op->attrs.as(); + + DataType init_dtype = pattr->out_dtype; + pattr->out_dtype = requantize_attrs->out_dtype; + + LoweredOutput lowered_out = (*flower_call)(GetRef(a_op), inputs, target); + + pattr->out_dtype = init_dtype; + + return lowered_out; + } else { + LOG(FATAL) << "Unsupported op: " << PrettyPrint(a_op->op); + return LoweredOutput({}, OpImplementation()); + } + } + + const CallNode* GetAnchorOp() { return anchor_op_; } + + private: + const Op& qnn_conv2d_op_; + const Op& qnn_requantize_op_; + const Op& bias_add_op_; + + // Main (complicated) operation in the primitive. 
+ const CallNode* anchor_op_ = nullptr; + + enum POper { QConv2d, BiasAdd, QRequantize }; + std::map registered_ops_; +}; + // Lowers Relay primitive Function to TE Compute class LowerToTECompute : public backend::MemoizedExprTranslator> { public: @@ -213,6 +294,8 @@ class LowerToTECompute : public backend::MemoizedExprTranslator inputs; int count_tuple = 0; for (Expr arg : call_node->args) { @@ -236,9 +319,27 @@ class LowerToTECompute : public backend::MemoizedExprTranslator(call_node), inputs, target_); - Array outputs = lowered_out->outputs; - op_implementations_[op.operator->()] = lowered_out->implementation; + Array outputs; + + if (pattern_matcher_.find(op)) { + if (pattern_matcher_.IsLeafOp(op)) { + // Lower anchor op when pattern leaf op was reached + auto anchor_op = pattern_matcher_.GetAnchorOp(); + LoweredOutput lowered_out = + pattern_matcher_.LowerOps(anchor_op, call_node, inputs, target_); + outputs = lowered_out->outputs; + Op a_op = Downcast(anchor_op->op); + op_implementations_[a_op.operator->()] = lowered_out->implementation; + } else { + // Forward inputs as "outputs" for successor. + readable_name_stream_ << '_' << op->name; + return inputs; + } + } else { + LoweredOutput lowered_out = (*flower_call)(GetRef(call_node), inputs, target_); + outputs = lowered_out->outputs; + op_implementations_[op.operator->()] = lowered_out->implementation; + } if (outputs.size() != 1) { const auto* tuple_type = call_node->checked_type().as(); @@ -294,6 +395,8 @@ class LowerToTECompute : public backend::MemoizedExprTranslator GetPassPrefix(bool is_homogeneous, bool is_vm) { pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions)); pass_seqs.push_back(transform::ToBasicBlockNormalForm()); // Run all dialect legalization passes. - pass_seqs.push_back(relay::qnn::transform::Legalize()); + // pass_seqs.push_back(relay::qnn::transform::Legalize()); // Legalize pass is restricted to homogeneous execution for now. 
if (is_homogeneous) { diff --git a/src/relay/transforms/fuse_ops.cc b/src/relay/transforms/fuse_ops.cc index dac5dc69ead5..afa60f1bb4e5 100644 --- a/src/relay/transforms/fuse_ops.cc +++ b/src/relay/transforms/fuse_ops.cc @@ -885,8 +885,10 @@ class FuseMutator : private MixedModeMutator { Expr Rewrite_(const CallNode* call, const Expr& post) { if (call->op.as()) { static auto fnoncomputational = Op::GetAttrMap("TNonComputational"); + static auto fqnncanonicalize = Op::GetAttrMap("FTVMQnnCanonicalize"); - if (fnoncomputational.get(Downcast(call->op), false)) { + Op op = Downcast(call->op); + if (fnoncomputational.get(op, false) && !fqnncanonicalize.count(op)) { return ExprMutator::VisitExpr_(call); } From 71c0cb0ae023cecd3985a4c5f56e4d74b225f00c Mon Sep 17 00:00:00 2001 From: ibsidorenko Date: Tue, 21 Jun 2022 16:57:49 +0300 Subject: [PATCH 02/25] added dependence of the qnn::transform::Legalize pass launch on target. --- src/relay/backend/build_module.cc | 2 +- src/relay/backend/task_extraction.cc | 2 +- src/relay/backend/te_compiler_cache.cc | 3 ++- src/relay/backend/utils.cc | 8 ++++++-- src/relay/backend/utils.h | 4 ++-- src/relay/backend/vm/compiler.cc | 2 +- 6 files changed, 13 insertions(+), 8 deletions(-) diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index bca524794a20..1d1bd69b54c9 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -328,7 +328,7 @@ class RelayBuildModule : public runtime::ModuleNode { backend::BindParamsInModule(relay_module, params_); Array pass_seqs = - GetPassPrefix(/*is_homogenous=*/config_->primitive_targets.size() == 1, /*is_vm=*/false); + GetPassPrefix(/*homogeneous target=*/config_->optional_homogeneous_target, /*is_vm=*/false); transform::PassContext pass_ctx = PassContext::Current(); if (config_->optional_homogeneous_target.defined()) { diff --git a/src/relay/backend/task_extraction.cc b/src/relay/backend/task_extraction.cc index 213841c621de..703f18587c6c 
100644 --- a/src/relay/backend/task_extraction.cc +++ b/src/relay/backend/task_extraction.cc @@ -36,7 +36,7 @@ Array ExtractTask(IRModule mod, Target target, backend::FTECompilerTIRConverter tir_converter = backend::GetTIRConverter(); backend::BindParamsInModule(mod, params); // is_vm=true for backward compatibility - Array pass_seqs = relay::backend::GetPassPrefix(/*is_homogenous=*/true, /*is_vm=*/true); + Array pass_seqs = relay::backend::GetPassPrefix(target, /*is_vm=*/true); pass_seqs.push_back(transform::FuseOps()); mod = transform::Sequential(pass_seqs)(std::move(mod)); diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc index e923cc38dccd..f69eb765c74c 100644 --- a/src/relay/backend/te_compiler_cache.cc +++ b/src/relay/backend/te_compiler_cache.cc @@ -294,7 +294,8 @@ class LowerToTECompute : public backend::MemoizedExprTranslatorkind->device_type == kDLCPU) pattern_matcher_.Register(call_node); Array inputs; int count_tuple = 0; diff --git a/src/relay/backend/utils.cc b/src/relay/backend/utils.cc index 648a6904e77c..13e526195f92 100644 --- a/src/relay/backend/utils.cc +++ b/src/relay/backend/utils.cc @@ -219,17 +219,21 @@ ExecutorCodegenMetadata::ExecutorCodegenMetadata( TVM_REGISTER_NODE_TYPE(ExecutorCodegenMetadataNode); -Array GetPassPrefix(bool is_homogeneous, bool is_vm) { +Array GetPassPrefix(Target homogeneous_target, bool is_vm) { Array pass_seqs; // TODO(mbs): Would be nice to get spans on all diagnostics, but since they arg forgotton // by most passes there's little utility in including this now. Plus we'd need to only do // this if there's no existing spans to work from. 
// pass_seqs.push_back(parser::AnnotateSpans()); Array entry_functions{"main"}; + // Can be undefined in case of heterogeneous execution + bool is_homogeneous = homogeneous_target.defined(); pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions)); pass_seqs.push_back(transform::ToBasicBlockNormalForm()); // Run all dialect legalization passes. - // pass_seqs.push_back(relay::qnn::transform::Legalize()); + // Should be changed on kDLHexagon + if ((is_homogeneous && homogeneous_target->kind->device_type != kDLCPU) || !is_homogeneous) + pass_seqs.push_back(relay::qnn::transform::Legalize()); // Legalize pass is restricted to homogeneous execution for now. if (is_homogeneous) { diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index 00c75921f2f2..91b569ad0cfc 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -676,11 +676,11 @@ inline IRModule PrimFuncToIRModule(tir::PrimFunc f) { * difference. This function unifies the shared optimization pass prefix between vm and graph * runtime, and returns the pass prefix given the backend type. * - * \param is_homogeneous True if all primitives are to be executed on the same device and target. + * \param homogeneous_target Execution target (can be undefined in case of heterogeneous execution). * \param is_vm True if passes are to be used for the vm executor. * \return An array of passes. */ -Array GetPassPrefix(bool is_homogeneous, bool is_vm); +Array GetPassPrefix(Target homogeneous_target, bool is_vm); /*! 
\brief Target hash function */ struct TargetStrHash { diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index b807f4195947..570d4b69e4b2 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -1054,7 +1054,7 @@ transform::Sequential VMCompiler::FuseAndLowerOperators(const CompilationConfig& IRModule VMCompiler::OptimizeModuleImpl(IRModule mod) { backend::BindParamsInModule(mod, params_); Array pass_seqs = relay::backend::GetPassPrefix( - /*is_homogeneous=*/config_->optional_homogeneous_target.defined(), /*is_vm=*/true); + /*homogeneous target=*/config_->optional_homogeneous_target, /*is_vm=*/true); // Always plan devices so the remaining passes don't need to distinguish homogeneous vs // heterogeneous execution. From a442f70e408481d677ff5308bce8415da2ed8151 Mon Sep 17 00:00:00 2001 From: ibsidorenko Date: Thu, 23 Jun 2022 20:21:11 +0300 Subject: [PATCH 03/25] Added new dense topi operator for the pattern qnn.dense+bias+requantize --- python/tvm/relay/qnn/strategy/generic.py | 42 ++++++++++++++++++++++ python/tvm/relay/qnn/strategy/hexagon.py | 2 +- python/tvm/topi/hexagon/qnn.py | 45 +++++++++++++++++++++--- src/relay/backend/te_compiler_cache.cc | 29 +++++++++++++-- 4 files changed, 110 insertions(+), 8 deletions(-) diff --git a/python/tvm/relay/qnn/strategy/generic.py b/python/tvm/relay/qnn/strategy/generic.py index 01ef6fcdf586..ee56e8305a34 100644 --- a/python/tvm/relay/qnn/strategy/generic.py +++ b/python/tvm/relay/qnn/strategy/generic.py @@ -103,6 +103,48 @@ def wrapper(attrs, inputs, out_type): return wrapper +def wrap_topi_qnn_dense(topi_compute): + """Wrap TOPI compute which use qnn.dense attrs""" + + def wrapper(attrs, inputs, _out_type): + out_dtype = attrs.out_dtype + if len([*inputs]) == 11: + args = [*inputs, out_dtype] + elif len([*inputs]) == 10: + args = [ # QNN Dense params: + inputs[0], + inputs[1], + inputs[2], + inputs[3], + inputs[4], + inputs[5], + # Bias argument + None, + # 
Requantization params: + inputs[6], + inputs[7], + inputs[8], + inputs[9], + out_dtype, + ] + else: + assert len([*inputs]) == 6 + args = [ # QNN Dense params: + *inputs, + # Bias argument: + None, + # Requantization params: + None, + None, + None, + None, + out_dtype, + ] + return [topi_compute(*args)] + + return wrapper + + @override_native_generic_func("qnn_quantize_strategy") def qnn_quantize_strategy(attrs, inputs, out_type, target): """qnn.quantize generic strategy""" diff --git a/python/tvm/relay/qnn/strategy/hexagon.py b/python/tvm/relay/qnn/strategy/hexagon.py index 002ba2da39a1..fe158f05781d 100644 --- a/python/tvm/relay/qnn/strategy/hexagon.py +++ b/python/tvm/relay/qnn/strategy/hexagon.py @@ -97,7 +97,7 @@ def qnn_dense_strategy_hexagon(attrs, inputs, out_type, target): """qnn.dense strategy for Hexagon""" strategy = _op.OpStrategy() strategy.add_implementation( - wrap_compute_quantize(topi.hexagon.qnn_dense), + wrap_topi_qnn_dense(topi.hexagon.qnn_dense), wrap_topi_schedule(topi.hexagon.schedule_qnn_dense), name="qnn_dense.hexagon", ) diff --git a/python/tvm/topi/hexagon/qnn.py b/python/tvm/topi/hexagon/qnn.py index 99589cc12663..03494761562f 100644 --- a/python/tvm/topi/hexagon/qnn.py +++ b/python/tvm/topi/hexagon/qnn.py @@ -281,21 +281,58 @@ def schedule_qnn_conv2d(outs): def qnn_dense( - data, weight, input_zero_point, kernel_zero_point, _input_scale, _kernel_scale, out_dtype + data, + weight, + # Dense quantization params: + input_zero_point, + kernel_zero_point, + _input_scale, + _kernel_scale, + # bias + bias, + # Requantization params: + rq_input_scale, + rq_input_zero_point, + rq_output_scale, + rq_output_zero_point, + out_dtype, ): """Compute for qnn.dense""" M, K = get_const_tuple(data.shape) N, _ = get_const_tuple(weight.shape) k = te.reduce_axis((0, K), "k") - return te.compute( + # This implementation uses "int32" dense output data type. 
+ out = te.compute( (M, N), lambda m, n: te.sum( - te.subtract(data[m, k], input_zero_point).astype(out_dtype) - * te.subtract(weight[n, k], kernel_zero_point).astype(out_dtype), + te.subtract(data[m, k], input_zero_point).astype("int32") + * te.subtract(weight[n, k], kernel_zero_point).astype("int32"), axis=k, ), ) + # Add bias + if bias is not None: + out = te.compute(out.shape, lambda n, c: out[n, c] + bias[c]) + + def _rq_compute(*indices): + value = out(*indices) + sub = te.subtract(value, rq_input_zero_point) + mul = te.div(rq_input_scale, rq_output_scale) + val = te.add(te.round(te.multiply(mul, sub)), rq_output_zero_point) + + # clip + cast: + const_min = tvm.tir.min_value(out_dtype) + const_max = tvm.tir.max_value(out_dtype) + return te.max(tvm.te.min(val, const_max), const_min).astype(out_dtype) + + # Requantize output of dense + # Q_output = zp_output + round((scale_input)/(scale_output) * (Q_input - zp_input)) + if rq_input_scale is not None and rq_output_scale is not None: + return te.compute(out.shape, _rq_compute) + + return out + def schedule_qnn_dense(outs): """Schedule for qnn.dense diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc index f69eb765c74c..741164e97f38 100644 --- a/src/relay/backend/te_compiler_cache.cc +++ b/src/relay/backend/te_compiler_cache.cc @@ -130,6 +130,7 @@ class PatternMatcher { public: PatternMatcher() : qnn_conv2d_op_(Op::Get("qnn.conv2d")), + qnn_dense_op_(Op::Get("qnn.dense")), qnn_requantize_op_(Op::Get("qnn.requantize")), bias_add_op_(Op::Get("add")) {} @@ -139,11 +140,16 @@ class PatternMatcher { Op op = Downcast(call_node->op); if (op == qnn_conv2d_op_) { registered_ops_[QConv2d]++; + ICHECK(anchor_op_ == nullptr); anchor_op_ = call_node; } else if (op == qnn_requantize_op_) { registered_ops_[QRequantize]++; } else if (op == bias_add_op_) { registered_ops_[BiasAdd]++; + } else if (op == qnn_dense_op_) { + registered_ops_[QDense]++; + ICHECK(anchor_op_ == nullptr); + 
anchor_op_ = call_node; } } @@ -151,16 +157,21 @@ class PatternMatcher { bool find(const Op& op) { if (registered_ops_.empty()) return false; - if (op == qnn_conv2d_op_ || op == qnn_requantize_op_ || op == bias_add_op_) { + if (op == qnn_conv2d_op_ || op == qnn_requantize_op_ || op == bias_add_op_ || + op == qnn_dense_op_) { // Patterns: qnn.conv2d -> qnn.requantize or qnn.conv2d -> bias_add -> qnn.requantize if (registered_ops_[QConv2d] && registered_ops_[QRequantize]) { return true; } + // Patterns: qnn.dense -> qnn.requantize or qnn.dense -> bias_add -> qnn.requantize + if (registered_ops_[QDense] && registered_ops_[QRequantize]) { + return true; + } } return false; } - // returns whether given Op is last in the pattern qequence. + // returns whether given Op is last in the pattern sequence. bool IsLeafOp(const Op& op) { return op == qnn_requantize_op_; } LoweredOutput LowerOps(const CallNode* a_op, const CallNode* leaf_op, @@ -183,6 +194,17 @@ class PatternMatcher { pattr->out_dtype = init_dtype; + return lowered_out; + } else if (auto* pattr = const_cast(a_op->attrs.as())) { + const auto* requantize_attrs = leaf_op->attrs.as(); + + DataType init_dtype = pattr->out_dtype; + pattr->out_dtype = requantize_attrs->out_dtype; + + LoweredOutput lowered_out = (*flower_call)(GetRef(a_op), inputs, target); + + pattr->out_dtype = init_dtype; + return lowered_out; } else { LOG(FATAL) << "Unsupported op: " << PrettyPrint(a_op->op); @@ -194,13 +216,14 @@ class PatternMatcher { private: const Op& qnn_conv2d_op_; + const Op& qnn_dense_op_; const Op& qnn_requantize_op_; const Op& bias_add_op_; // Main (complicated) operation in the primitive. 
const CallNode* anchor_op_ = nullptr; - enum POper { QConv2d, BiasAdd, QRequantize }; + enum POper { QConv2d, QDense, BiasAdd, QRequantize }; std::map registered_ops_; }; From 06bfeedf6919a1da704c8bd96618cd0a910c41f7 Mon Sep 17 00:00:00 2001 From: ibsidorenko Date: Wed, 29 Jun 2022 18:13:59 +0300 Subject: [PATCH 04/25] Added support of axis attribute for QNN TOPI ops --- include/tvm/relay/qnn/attrs.h | 115 +++++++++++++++++++++++ python/tvm/relay/op/op_attrs.py | 10 ++ python/tvm/relay/qnn/strategy/generic.py | 20 ++-- python/tvm/topi/hexagon/qnn.py | 82 +++++++++------- src/relay/backend/te_compiler_cache.cc | 48 +++------- src/relay/qnn/op/convolution.cc | 57 +++++++---- src/relay/qnn/op/dense.cc | 21 +++-- 7 files changed, 252 insertions(+), 101 deletions(-) diff --git a/include/tvm/relay/qnn/attrs.h b/include/tvm/relay/qnn/attrs.h index 64b2dc20981d..192ac28905e1 100644 --- a/include/tvm/relay/qnn/attrs.h +++ b/include/tvm/relay/qnn/attrs.h @@ -25,6 +25,7 @@ #define TVM_RELAY_QNN_ATTRS_H_ #include +#include #include @@ -125,6 +126,120 @@ struct BroadcastAttrs : public tvm::AttrsNode { } }; +/*! \brief Attributes used in QNN convolution operator */ +struct QConv2DAttrs : public tvm::AttrsNode { + Array strides; + Array padding; + Array dilation; + int groups; + IndexExpr channels; + Array kernel_size; + tvm::String data_layout; + tvm::String kernel_layout; + tvm::String out_layout; + tvm::String auto_scheduler_rewritten_layout; // The layout after auto-scheduler's layout rewrite + DataType out_dtype; + + // Optional extra attributes for Hexagon target. Describes requantization parameters. + // Note, It is not set up explicitly through qnn._make.conv2d. 
+ int axis; + DataType rq_out_dtype; + + TVM_DECLARE_ATTRS(QConv2DAttrs, "relay.attrs.QConv2DAttrs") { + TVM_ATTR_FIELD(strides) + .set_default(Array({1, 1})) + .describe("Specifies the strides of the convolution."); + TVM_ATTR_FIELD(padding) + .set_default(Array({0, 0})) + .describe( + "If padding is non-zero, then the input is implicitly zero-padded" + "Padding support both symmetric and asymmetric as" + "one int : same padding used on all sides" + "two int : bottom, right will use same padding as top, left" + "four int : padding width in the order of (top, left, bottom, right)"); + TVM_ATTR_FIELD(dilation) + .set_default(Array({1, 1})) + .describe("Specifies the dilation rate to use for dilated convolution."); + TVM_ATTR_FIELD(groups).set_default(1).describe( + "Controls the connections between inputs and outputs." + "At groups=1, all inputs are convolved to all outputs." + "At groups=2, the operation becomes equivalent to having two convolution" + "layers side by side, each seeing half the input channels, and producing" + "half the output channels, and both subsequently concatenated."); + TVM_ATTR_FIELD(channels) + .describe( + "The number of output channels in the convolution." + " If it is not set, inferred by shape of the weight.") + .set_default(NullValue()); + TVM_ATTR_FIELD(kernel_size) + .describe("Specifies the dimensions of the convolution window.") + .set_default(NullValue>()); + TVM_ATTR_FIELD(data_layout) + .set_default("NCHW") + .describe( + "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Convolution is applied on the 'H' and" + "'W' dimensions."); + TVM_ATTR_FIELD(kernel_layout) + .set_default("OIHW") + .describe( + "Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc." 
+ "'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width" + "dimensions respectively."); + TVM_ATTR_FIELD(out_layout) + .set_default("") + .describe( + "Dimension ordering of output. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Default to be same as input layout."); + + // use 0 bits to indicate none. + TVM_ATTR_FIELD(out_dtype) + .set_default(NullValue()) + .describe("Output data type, set to explicit type under mixed precision setting"); + + TVM_ATTR_FIELD(axis) + .describe( + "The channel axis for channel wise requantization. Default value is -1," + "which corresponds to the last axis.") + .set_default(-1); + TVM_ATTR_FIELD(rq_out_dtype) + .set_default(NullValue()) + .describe("Requantized output data type"); + } +}; + +/*! \brief Attributes for QNN dense operator */ +struct QDenseAttrs : public tvm::AttrsNode { + IndexExpr units; + tvm::String auto_scheduler_rewritten_layout; // The layout after auto-scheduler's layout rewrite + DataType out_dtype; + + // Optional extra attributes for Hexagon target. Describes requantization parameters. + // Note, It is not set up explicitly through qnn._make.dense. + int axis; + DataType rq_out_dtype; + + TVM_DECLARE_ATTRS(QDenseAttrs, "relay.attrs.QDenseAttrs") { + TVM_ATTR_FIELD(units).describe("Number of hidden units of the dense transformation."); + + // use 0 bits to indicate none. + TVM_ATTR_FIELD(out_dtype) + .set_default(NullValue()) + .describe("Output data type, set to explicit type under mixed precision setting"); + + TVM_ATTR_FIELD(axis) + .describe( + "The channel axis for channel wise requantization. 
Default value is -1," + "which corresponds to the last axis.") + .set_default(-1); + TVM_ATTR_FIELD(rq_out_dtype) + .set_default(NullValue()) + .describe("Requantized output data type"); + } +}; + } // namespace qnn } // namespace relay } // namespace tvm diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py index b76097722c07..d6d9ec3d2365 100644 --- a/python/tvm/relay/op/op_attrs.py +++ b/python/tvm/relay/op/op_attrs.py @@ -529,6 +529,16 @@ class RequantizeAttrs(Attrs): """Attributes used in requantize operators""" +@tvm._ffi.register_object("relay.attrs.QConv2DAttrs") +class QConv2DAttrs(Attrs): + """Attributes used in QNN conv2d operators""" + + +@tvm._ffi.register_object("relay.attrs.QDenseAttrs") +class QDenseAttrs(Attrs): + """Attributes used in QNN dense operators""" + + @tvm._ffi.register_object("relay.attrs.ScatterAttrs") class ScatterAttrs(Attrs): """Attributes used in scatter operators""" diff --git a/python/tvm/relay/qnn/strategy/generic.py b/python/tvm/relay/qnn/strategy/generic.py index ee56e8305a34..0a16377f41ef 100644 --- a/python/tvm/relay/qnn/strategy/generic.py +++ b/python/tvm/relay/qnn/strategy/generic.py @@ -39,11 +39,12 @@ def wrapper(_attrs, inputs, _out_type): def wrap_compute_quantize(topi_compute): - """Wrap TOPI compute which use out data type from attrs""" + """Wrap TOPI compute which use axis and out data type from attrs""" def wrapper(attrs, inputs, _out_type): + axis = attrs.axis out_dtype = attrs.out_dtype - args = [*inputs, out_dtype] + args = [*inputs, axis, out_dtype] return [topi_compute(*args)] return wrapper @@ -54,12 +55,13 @@ def wrap_topi_qnn_conv2d(topi_compute): def wrapper(attrs, inputs, out_type): oshape = out_type.shape - out_dtype = attrs.out_dtype + out_dtype = attrs.rq_out_dtype strides = attrs.strides padding = attrs.padding dilation = attrs.dilation + axis = attrs.axis if len([*inputs]) == 11: - args = [*inputs, strides, padding, dilation, oshape, out_dtype] + args = [*inputs, axis, 
strides, padding, dilation, oshape, out_dtype] elif len([*inputs]) == 10: args = [ # QNN Conv2d params: inputs[0], @@ -75,6 +77,8 @@ def wrapper(attrs, inputs, out_type): inputs[7], inputs[8], inputs[9], + axis, + # Conv2d attrs: strides, padding, dilation, @@ -92,6 +96,7 @@ def wrapper(attrs, inputs, out_type): None, None, None, + axis, strides, padding, dilation, @@ -107,9 +112,10 @@ def wrap_topi_qnn_dense(topi_compute): """Wrap TOPI compute which use qnn.dense attrs""" def wrapper(attrs, inputs, _out_type): - out_dtype = attrs.out_dtype + out_dtype = attrs.rq_out_dtype + axis = attrs.axis if len([*inputs]) == 11: - args = [*inputs, out_dtype] + args = [*inputs, axis, out_dtype] elif len([*inputs]) == 10: args = [ # QNN Dense params: inputs[0], @@ -125,6 +131,7 @@ def wrapper(attrs, inputs, _out_type): inputs[7], inputs[8], inputs[9], + axis, out_dtype, ] else: @@ -138,6 +145,7 @@ def wrapper(attrs, inputs, _out_type): None, None, None, + axis, out_dtype, ] return [topi_compute(*args)] diff --git a/python/tvm/topi/hexagon/qnn.py b/python/tvm/topi/hexagon/qnn.py index 03494761562f..306d8fafdde5 100644 --- a/python/tvm/topi/hexagon/qnn.py +++ b/python/tvm/topi/hexagon/qnn.py @@ -18,27 +18,37 @@ # pylint: disable=invalid-name import tvm -from tvm import te +from tvm import te, topi from ..generic.default import default_schedule as _default_schedule from ..utils import get_const_tuple from ..nn.utils import get_pad_tuple from ..nn.pad import pad -def qnn_quantize(data, output_scale, output_zero_point, out_dtype): +def qnn_quantize(data, output_scale, output_zero_point, axis, out_dtype): """Compute for qnn.quantize Q_output = clamp((round(input_tensor/output_scale) + output_zero_point), out_dtype::min, out_dtype::max) - TODO: Support 'axis' argument. 
""" + assert len(output_scale.shape) == 0 or len(output_scale.shape) == 1 + assert len(output_zero_point.shape) == 0 or len(output_zero_point.shape) == 1 + def _compute(*indices): value = data(*indices) + + # Account scalar and 1D quantization parameters: + scale_idx = tvm.tir.indexmod(indices[axis], topi.shape(output_scale)[0]) + scale = output_scale if len(output_scale.shape) == 0 else output_scale[scale_idx] + + zp_idx = tvm.tir.indexmod(indices[axis], topi.shape(output_zero_point)[0]) + zp = output_zero_point if len(output_zero_point.shape) == 0 else output_zero_point[zp_idx] + const_min = tvm.tir.min_value(out_dtype) const_max = tvm.tir.max_value(out_dtype) - val = te.add(te.round(te.div(value, output_scale)), output_zero_point) - return te.max(tvm.te.min(val, const_max), const_min).astype(out_dtype) + val = te.add(te.round(te.div(value, scale)), zp) + return te.max(te.min(val, const_max), const_min).astype(out_dtype) return te.compute(data.shape, _compute) @@ -90,18 +100,26 @@ def schedule_qnn_dequantize(outs): return _default_schedule(outs, False) -def qnn_requantize(data, input_scale, input_zero_point, output_scale, output_zero_point, out_dtype): +def qnn_requantize(data, input_scale, input_zp, output_scale, output_zp, axis, out_dtype): """Compute for qnn.requantize Q_output = zp_output + round((scale_input)/(scale_output) * (Q_input - zp_input)) - TODO: support 'axis', 'rounding' and 'compute_dtype' arguments. + TODO: support 'rounding' and 'compute_dtype' arguments. 
""" def _compute(*indices): value = data(*indices) - sub = te.subtract(value, input_zero_point) - mul = te.div(input_scale, output_scale) - val = te.add(te.round(te.multiply(mul, sub)), output_zero_point) + + # Account scalar and 1D quantization parameters: + iscale_idx = tvm.tir.indexmod(indices[axis], topi.shape(input_scale)[0]) + iscale = input_scale if len(input_scale.shape) == 0 else input_scale[iscale_idx] + + oscale_idx = tvm.tir.indexmod(indices[axis], topi.shape(output_scale)[0]) + oscale = output_scale if len(output_scale.shape) == 0 else output_scale[oscale_idx] + + sub = te.subtract(value, input_zp) + mul = te.div(iscale, oscale) + val = te.add(te.round(te.multiply(mul, sub)), output_zp) # clip + cast: const_min = tvm.tir.min_value(out_dtype) @@ -187,6 +205,7 @@ def qnn_conv2d( # Conv2d inputs rq_input_zero_point, rq_output_scale, rq_output_zero_point, + axis, # Conv2d attributes: strides, padding, @@ -244,21 +263,18 @@ def qnn_conv2d( # Conv2d inputs assert bias.shape[2] == 1 and bias.shape[3] == 1 out = te.compute(out.shape, lambda n, c, h, w: out[n, c, h, w] + bias[n, c, 1, 1]) - def _rq_compute(*indices): - value = out(*indices) - sub = te.subtract(value, rq_input_zero_point) - mul = te.div(rq_input_scale, rq_output_scale) - val = te.add(te.round(te.multiply(mul, sub)), rq_output_zero_point) - - # clip + cast: - const_min = tvm.tir.min_value(odtype) - const_max = tvm.tir.max_value(odtype) - return te.max(tvm.te.min(val, const_max), const_min).astype(odtype) - # Requantize output of convolution # Q_output = zp_output + round((scale_input)/(scale_output) * (Q_input - zp_input)) if rq_input_scale is not None and rq_output_scale is not None: - return te.compute(out.shape, _rq_compute) + return qnn_requantize( + out, + rq_input_scale, + rq_input_zero_point, + rq_output_scale, + rq_output_zero_point, + axis, + odtype, + ) return out @@ -295,6 +311,7 @@ def qnn_dense( rq_input_zero_point, rq_output_scale, rq_output_zero_point, + axis, out_dtype, ): 
"""Compute for qnn.dense""" @@ -315,21 +332,18 @@ def qnn_dense( if bias is not None: out = te.compute(out.shape, lambda n, c: out[n, c] + bias[c]) - def _rq_compute(*indices): - value = out(*indices) - sub = te.subtract(value, rq_input_zero_point) - mul = te.div(rq_input_scale, rq_output_scale) - val = te.add(te.round(te.multiply(mul, sub)), rq_output_zero_point) - - # clip + cast: - const_min = tvm.tir.min_value(out_dtype) - const_max = tvm.tir.max_value(out_dtype) - return te.max(tvm.te.min(val, const_max), const_min).astype(out_dtype) - # Requantize output of dense # Q_output = zp_output + round((scale_input)/(scale_output) * (Q_input - zp_input)) if rq_input_scale is not None and rq_output_scale is not None: - return te.compute(out.shape, _rq_compute) + return qnn_requantize( + out, + rq_input_scale, + rq_input_zero_point, + rq_output_scale, + rq_output_zero_point, + axis, + out_dtype, + ) return out diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc index 741164e97f38..808744588605 100644 --- a/src/relay/backend/te_compiler_cache.cc +++ b/src/relay/backend/te_compiler_cache.cc @@ -174,41 +174,17 @@ class PatternMatcher { // returns whether given Op is last in the pattern sequence. bool IsLeafOp(const Op& op) { return op == qnn_requantize_op_; } - LoweredOutput LowerOps(const CallNode* a_op, const CallNode* leaf_op, - const Array& inputs, tvm::Target target) { - static auto flower_call = tvm::runtime::Registry::Get("relay.backend.lower_call"); - ICHECK(flower_call) << "relay.backend.lower_call is not registered."; - ICHECK(a_op == anchor_op_); - - // TODO(ibsidorenko): - // now this code changes output data type of anchor op on output data type of - // requantize op. After lowering it restore previous output data type. - // It will be better to pass new data type directly to lowering function. 
- if (auto* pattr = const_cast(a_op->attrs.as())) { - const auto* requantize_attrs = leaf_op->attrs.as(); - - DataType init_dtype = pattr->out_dtype; - pattr->out_dtype = requantize_attrs->out_dtype; - - LoweredOutput lowered_out = (*flower_call)(GetRef(a_op), inputs, target); - - pattr->out_dtype = init_dtype; - - return lowered_out; - } else if (auto* pattr = const_cast(a_op->attrs.as())) { - const auto* requantize_attrs = leaf_op->attrs.as(); - - DataType init_dtype = pattr->out_dtype; - pattr->out_dtype = requantize_attrs->out_dtype; - - LoweredOutput lowered_out = (*flower_call)(GetRef(a_op), inputs, target); - - pattr->out_dtype = init_dtype; - - return lowered_out; + // Copy requantization attributes from one node to another. + void CopyAttrs(const CallNode* from, const CallNode* to) { + const auto* requantize_attrs = from->attrs.as(); + if (auto* pattr = const_cast(to->attrs.as())) { + pattr->axis = requantize_attrs->axis; + pattr->rq_out_dtype = requantize_attrs->out_dtype; + } else if (auto* pattr = const_cast(to->attrs.as())) { + pattr->axis = requantize_attrs->axis; + pattr->rq_out_dtype = requantize_attrs->out_dtype; } else { - LOG(FATAL) << "Unsupported op: " << PrettyPrint(a_op->op); - return LoweredOutput({}, OpImplementation()); + LOG(FATAL) << "Unsupported op: " << PrettyPrint(to->op); } } @@ -349,8 +325,8 @@ class LowerToTECompute : public backend::MemoizedExprTranslator(anchor_op), inputs, target_); outputs = lowered_out->outputs; Op a_op = Downcast(anchor_op->op); op_implementations_[a_op.operator->()] = lowered_out->implementation; diff --git a/src/relay/qnn/op/convolution.cc b/src/relay/qnn/op/convolution.cc index 64a5a02e6e25..2dff57d7f658 100644 --- a/src/relay/qnn/op/convolution.cc +++ b/src/relay/qnn/op/convolution.cc @@ -38,6 +38,8 @@ namespace tvm { namespace relay { namespace qnn { +TVM_REGISTER_NODE_TYPE(QConv2DAttrs); + // relay.op.qnn.conv2d bool QnnConv2DRel(const Array& types, int num_inputs, const Attrs& attrs, @@ -48,8 +50,8 @@ 
bool QnnConv2DRel(const Array& types, int num_inputs, const Attrs& attrs, const auto* data = types[0].as(); const auto* weight = types[1].as(); if (data == nullptr || weight == nullptr) return false; - const auto* param = attrs.as(); - ICHECK(param != nullptr) << "Conv2DAttrs cannot be nullptr."; + const auto* param = attrs.as(); + ICHECK(param != nullptr) << "QConv2DAttrs cannot be nullptr."; ICHECK(data->dtype == DataType::Int(8) || data->dtype == DataType::UInt(8) || data->dtype == DataType::Int(16)) << "Expected qnn conv2d type(int8, uint8, int16) for input but was " << data->dtype; @@ -83,10 +85,24 @@ bool QnnConv2DRel(const Array& types, int num_inputs, const Attrs& attrs, reporter); // weight_scale } + // Create Conv2DAttrs from QConv2DAttrs + auto conv2d_attrs = make_object(); + conv2d_attrs->strides = param->strides; + conv2d_attrs->padding = param->padding; + conv2d_attrs->dilation = param->dilation; + conv2d_attrs->groups = param->groups; + conv2d_attrs->channels = param->channels; + conv2d_attrs->kernel_size = param->kernel_size; + conv2d_attrs->data_layout = param->data_layout; + conv2d_attrs->kernel_layout = param->kernel_layout; + conv2d_attrs->out_layout = param->out_layout; + conv2d_attrs->out_dtype = param->out_dtype; + conv2d_attrs->auto_scheduler_rewritten_layout = param->auto_scheduler_rewritten_layout; + // Collect the input tensor and output tensor devoid of scale and zero points to reuse Relay // Conv2D infer type function. Array tensor_types = {types[0], types[1], types[6]}; - return Conv2DRel(tensor_types, 3, attrs, reporter); + return Conv2DRel(tensor_types, 3, Attrs(conv2d_attrs), reporter); } InferCorrectLayoutOutput QnnConvInferCorrectLayout(const Attrs& attrs, @@ -95,7 +111,7 @@ InferCorrectLayoutOutput QnnConvInferCorrectLayout(const Attrs& attrs, const Array& old_in_types) { // Use Relay Conv2D Infer correct layout. 
auto conv_new_layouts = - ConvInferCorrectLayout(attrs, new_in_layouts, old_in_layouts, old_in_types); + ConvInferCorrectLayout(attrs, new_in_layouts, old_in_layouts, old_in_types); // Fill the layouts of remaining input tensors - scales and zero points. The layouts of these // tensors can be treated as channel layout. @@ -110,7 +126,7 @@ InferCorrectLayoutOutput QnnConvInferCorrectLayout(const Attrs& attrs, return InferCorrectLayoutOutput(input_layouts, output_layouts, attrs); } -bool is_depthwise(const Conv2DAttrs* param) { +bool is_depthwise(const QConv2DAttrs* param) { return param->channels.defined() && tvm::tir::ExprDeepEqual()(param->channels, param->groups) && param->groups != 1; } @@ -124,7 +140,7 @@ using WorkloadType = std::tuple; * \param param The qnn conv2d attributes. * \return A tuple of workload. */ -WorkloadType GetWorkload(const Array& arg_types, const Conv2DAttrs* param) { +WorkloadType GetWorkload(const Array& arg_types, const QConv2DAttrs* param) { // Get conv parameters. const auto in_shape = get_shape(arg_types[0]); int batch_size, in_channels; @@ -191,7 +207,7 @@ WorkloadType GetWorkload(const Array& arg_types, const Conv2DA * int32 tensors instead of int8 tensors. */ Expr Conv2DFallBack(const Expr& data, const Expr& weight, const Expr& input_zero_point, - const Expr& kernel_zero_point, const Conv2DAttrs* param) { + const Expr& kernel_zero_point, const QConv2DAttrs* param) { // Upcast the parameters to be at least int32 to avoid overflow auto upcast_bits = param->out_dtype.bits() < 32 ? 32 : param->out_dtype.bits(); @@ -224,7 +240,7 @@ Expr Conv2DFallBack(const Expr& data, const Expr& weight, const Expr& input_zero * cannot be fused with conv in Relay. In case we see performance * degradation, we can change the conv2D API to accept a pad_const value. 
*/ -Expr Conv2DPadInput(const Expr& data, const Expr& input_zero_point, const Conv2DAttrs* param) { +Expr Conv2DPadInput(const Expr& data, const Expr& input_zero_point, const QConv2DAttrs* param) { // 1) Pad the input data auto padded_data = data; auto pad_top_value = get_const_int(param->padding[0]); @@ -270,7 +286,7 @@ Expr Conv2DPadInput(const Expr& data, const Expr& input_zero_point, const Conv2D * followed by repeat on the C axis by cm times. */ Expr DepthwiseConv2DSecondTerm(const Expr& padded_data, const Expr& kernel_zero_point, - const Conv2DAttrs* param, int kernel_h, int kernel_w, + const QConv2DAttrs* param, int kernel_h, int kernel_w, int channel_multiplier) { auto casted_t2 = Cast(padded_data, DataType::Int(32)); @@ -343,7 +359,7 @@ Expr DepthwiseConv2DSecondTerm(const Expr& padded_data, const Expr& kernel_zero_ * (1, oc, 1, 1) as (oc/m, oc%m) are just contiguous memory locations. */ Expr DepthwiseConv2DThirdTerm(const Expr& weight, const Expr& input_zero_point, - const Conv2DAttrs* param, int out_channels, int channel_multiplier) { + const QConv2DAttrs* param, int out_channels, int channel_multiplier) { // Find which dimensions are R, S. Array axes_t3; if (param->kernel_layout == "OIHW") { @@ -422,7 +438,7 @@ Expr DepthwiseConv2DFourthTerm(const Expr& input_zero_point, const Expr& kernel_ * Sigma(c,r,s) QW(k, c, r, s) * QA(n, c, h + r, w + s) * This is just conv2d on int tensors. */ -Expr Conv2DFirstTerm(const Expr& padded_data, const Expr& weight, const Conv2DAttrs* param) { +Expr Conv2DFirstTerm(const Expr& padded_data, const Expr& weight, const QConv2DAttrs* param) { // Lowering for Term 1 Array padding({0, 0, 0, 0}); return Conv2D(padded_data, weight, param->strides, padding, param->dilation, param->groups, @@ -448,7 +464,7 @@ Expr Conv2DFirstTerm(const Expr& padded_data, const Expr& weight, const Conv2DAt * opportunity to reuse alter_op_layout infrastructure. 
*/ Expr Conv2DSecondTerm(const Expr& padded_data, const Expr& kernel_zero_point, - const Conv2DAttrs* param, int kernel_h, int kernel_w, int out_channels) { + const QConv2DAttrs* param, int kernel_h, int kernel_w, int out_channels) { auto casted_t2 = Cast(padded_data, DataType::Int(32)); // We can reduce the H and W axis by using avg_pool2d. However, avg_pool2d averages the sum. @@ -518,7 +534,7 @@ Expr Conv2DSecondTerm(const Expr& padded_data, const Expr& kernel_zero_point, * a 1D tensor. The tensor is then reshaped to conform to NHWC/NCHW * format. */ -Expr Conv2DThirdTerm(const Expr& weight, const Expr& input_zero_point, const Conv2DAttrs* param, +Expr Conv2DThirdTerm(const Expr& weight, const Expr& input_zero_point, const QConv2DAttrs* param, int out_channels) { // Find which dimensions are C, R, S. Array axes_t3; @@ -569,7 +585,7 @@ Expr Conv2DThirdTerm(const Expr& weight, const Expr& input_zero_point, const Con * */ Expr Conv2DFourthTerm(int input_zero_point_int, int kernel_zero_point_int, int in_channels, - int kernel_h, int kernel_w, const Conv2DAttrs* param) { + int kernel_h, int kernel_w, const QConv2DAttrs* param) { auto upcast_bits = param->out_dtype.bits() < 32 ? 32 : param->out_dtype.bits(); int scalar_term4 = input_zero_point_int * kernel_zero_point_int * in_channels * kernel_h * kernel_w; @@ -592,7 +608,7 @@ Expr Conv2DFourthTerm(int input_zero_point_int, int kernel_zero_point_int, int i * */ Expr Conv2DFourthTerm(const Expr& input_zero_point, const Expr& kernel_zero_point, int in_channels, - int kernel_h, int kernel_w, const Conv2DAttrs* param) { + int kernel_h, int kernel_w, const QConv2DAttrs* param) { auto upcast_bits = param->out_dtype.bits() < 32 ? 
32 : param->out_dtype.bits(); Expr scalar_term4 = MakeConstantScalar(DataType::Int(upcast_bits), in_channels * kernel_h * kernel_w); @@ -712,7 +728,7 @@ Expr QnnConv2DCanonicalize(const Attrs& attrs, const Array& new_args, Expr weight = new_args[1]; Expr input_zero_point = new_args[2]; Expr kernel_zero_point = new_args[3]; - const auto* param = attrs.as(); + const auto* param = attrs.as(); ICHECK(param != nullptr); // Assertion checks for existing support. ICHECK(param->data_layout == "NCHW" || param->data_layout == "NHWC") @@ -817,7 +833,7 @@ Expr MakeQnnConv2D(Expr data, Expr weight, Expr input_zero_point, Expr kernel_ze Array padding, Array dilation, int groups, IndexExpr channels, Array kernel_size, String data_layout, String kernel_layout, String out_layout, DataType out_dtype) { - auto attrs = make_object(); + auto attrs = make_object(); attrs->strides = std::move(strides); attrs->padding = std::move(padding); attrs->dilation = std::move(dilation); @@ -828,6 +844,11 @@ Expr MakeQnnConv2D(Expr data, Expr weight, Expr input_zero_point, Expr kernel_ze attrs->kernel_layout = std::move(kernel_layout); attrs->out_layout = std::move(out_layout); attrs->out_dtype = std::move(out_dtype); + + // Optional extra attributes for requantization. + attrs->axis = -1; + attrs->rq_out_dtype = attrs->out_dtype; + static const Op& op = Op::Get("qnn.conv2d"); return Call(op, {data, weight, input_zero_point, kernel_zero_point, input_scale, kernel_scale}, Attrs(attrs), {}); @@ -846,7 +867,7 @@ operator to understand how to scale back the int32 output to (u)int8 or (u)int16 - **out**: This depends on the `layout` parameter. Output is 4D array of shape (batch_size, channels, out_height, out_width) if `layout` is `NCHW`. 
)code" TVM_ADD_FILELINE) - .set_attrs_type() + .set_attrs_type() .set_num_inputs(6) .add_argument("data", "Tensor", "The quantized input data tensor.") .add_argument("weight", "Tensor", "The quantized weight tensor.") diff --git a/src/relay/qnn/op/dense.cc b/src/relay/qnn/op/dense.cc index adaf509e7daf..e5a21f134cb8 100644 --- a/src/relay/qnn/op/dense.cc +++ b/src/relay/qnn/op/dense.cc @@ -35,6 +35,8 @@ namespace tvm { namespace relay { namespace qnn { +TVM_REGISTER_NODE_TYPE(QDenseAttrs); + // relay.op.qnn.dense bool QnnDenseRel(const Array& types, int num_inputs, const Attrs& attrs, @@ -45,8 +47,8 @@ bool QnnDenseRel(const Array& types, int num_inputs, const Attrs& attrs, const auto* data = types[0].as(); const auto* weight = types[1].as(); if (data == nullptr || weight == nullptr) return false; - const auto* param = attrs.as(); - ICHECK(param != nullptr) << "DenseAttrs cannot be nullptr."; + const auto* param = attrs.as(); + ICHECK(param != nullptr) << "QDenseAttrs cannot be nullptr."; ICHECK(data->dtype == DataType::Int(8) || data->dtype == DataType::UInt(8)) << "Expected quantized dense type(int8, uint8) for input but was " << data->dtype; ICHECK(weight->dtype == DataType::Int(8) || weight->dtype == DataType::UInt(8)) @@ -70,22 +72,27 @@ bool QnnDenseRel(const Array& types, int num_inputs, const Attrs& attrs, // Collect the input tensor and output tensor devoid of scale and zero points to reuse Relay // Dense infer type function. Array tensor_types = {types[0], types[1], types[6]}; - return MatmulRel(tensor_types, 3, attrs, reporter); + return MatmulRel(tensor_types, 3, attrs, reporter); } // Positional relay function to create quantized dense operator used by frontend FFI. 
Expr MakeQuantizedDense(Expr data, Expr weight, Expr input_zero_point, Expr kernel_zero_point, Expr input_scale, Expr kernel_scale, IndexExpr units, DataType out_dtype) { - auto attrs = make_object(); + auto attrs = make_object(); attrs->units = std::move(units); attrs->out_dtype = out_dtype; + + // Optional extra attributes for requantization. + attrs->axis = -1; + attrs->rq_out_dtype = attrs->out_dtype; + static const Op& op = Op::Get("qnn.dense"); return Call(op, {data, weight, input_zero_point, kernel_zero_point, input_scale, kernel_scale}, Attrs(attrs), {}); } Expr DenseFirstTerm(const Expr& quantized_data, const Expr& quantized_kernel, - const DenseAttrs* attrs) { + const QDenseAttrs* attrs) { return Dense(quantized_data, quantized_kernel, attrs->units, attrs->out_dtype); } @@ -170,7 +177,7 @@ Expr QnnDenseCanonicalize(const Attrs& attrs, const Array& new_args, const int reduction_dim_size = get_const_int(in_shape[1]); const int out_dim_size = get_const_int(w_shape[0]); - const auto* qnn_dense_attrs = attrs.as(); + const auto* qnn_dense_attrs = attrs.as(); auto term1 = DenseFirstTerm(quantized_data, quantized_kernel, qnn_dense_attrs); auto term2 = DenseSecondTerm(quantized_data, kernel_zero_point, out_dim_size); @@ -210,7 +217,7 @@ RELAY_REGISTER_OP("qnn.dense") - **weight**: quantized(int8, unit8) `(units, input_dim)` - **out**: quantized(int32) `(x1, x2, ..., xn, units)`. 
)code" TVM_ADD_FILELINE) - .set_attrs_type() + .set_attrs_type() .set_num_inputs(6) .add_argument("data", "quantized nD Tensor", "Input data.") .add_argument("weight", "quantized 2D Tensor", "Weight matrix.") From 777a9beac669612051e468336d57fe937323957b Mon Sep 17 00:00:00 2001 From: ibsidorenko Date: Mon, 4 Jul 2022 12:22:18 +0300 Subject: [PATCH 05/25] Fixed TOPI compute implementation for qnn.add --- python/tvm/relay/qnn/op/legalizations.py | 3 ++- python/tvm/topi/hexagon/qnn.py | 16 +++++++++++----- src/relay/backend/te_compiler_cache.cc | 4 ++++ 3 files changed, 17 insertions(+), 6 deletions(-) diff --git a/python/tvm/relay/qnn/op/legalizations.py b/python/tvm/relay/qnn/op/legalizations.py index 9bc6efdad00f..d37ed8ecfbec 100644 --- a/python/tvm/relay/qnn/op/legalizations.py +++ b/python/tvm/relay/qnn/op/legalizations.py @@ -228,7 +228,8 @@ def helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay_op): -relay.cast(kernel_zero_point, dtype="int16"), output_axis, ) - new_attrs = {k: attrs[k] for k in attrs.keys()} + # Skip optional extra attributes: + new_attrs = {k: attrs[k] for k in attrs.keys() if k not in ("axis", "rq_out_dtype")} return relay_op(shift_data, shift_kernel, **new_attrs) diff --git a/python/tvm/topi/hexagon/qnn.py b/python/tvm/topi/hexagon/qnn.py index 306d8fafdde5..49400847186a 100644 --- a/python/tvm/topi/hexagon/qnn.py +++ b/python/tvm/topi/hexagon/qnn.py @@ -23,6 +23,7 @@ from ..utils import get_const_tuple from ..nn.utils import get_pad_tuple from ..nn.pad import pad +from .. 
import tag def qnn_quantize(data, output_scale, output_zero_point, axis, out_dtype): @@ -50,7 +51,7 @@ def _compute(*indices): val = te.add(te.round(te.div(value, scale)), zp) return te.max(te.min(val, const_max), const_min).astype(out_dtype) - return te.compute(data.shape, _compute) + return te.compute(data.shape, _compute, tag=tag.ELEMWISE) def schedule_qnn_quantize(outs): @@ -80,7 +81,7 @@ def _compute(*indices): value = data(*indices) return te.multiply(input_scale, te.subtract(value, input_zero_point)) - return te.compute(data.shape, _compute) + return te.compute(data.shape, _compute, tag=tag.ELEMWISE) def schedule_qnn_dequantize(outs): @@ -164,11 +165,16 @@ def _compute(*indices): rvalue = rhs(*indices) q_lv = te.round( te.multiply(te.div(lhs_scale, output_scale), te.subtract(lvalue, lhs_zero_point)) - ).astype(dtype) + ).astype("int32") q_rv = te.round( te.multiply(te.div(rhs_scale, output_scale), te.subtract(rvalue, rhs_zero_point)) - ).astype(dtype) - return te.add(te.add(q_lv, q_rv), output_zero_point).astype(dtype) + ).astype("int32") + val = te.add(te.add(q_lv, q_rv), output_zero_point) + + # clip + cast: + const_min = tvm.tir.min_value(dtype) + const_max = tvm.tir.max_value(dtype) + return te.max(tvm.te.min(val, const_max), const_min).astype(dtype) return te.compute(lhs.shape, _compute) diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc index 808744588605..1445b543cb55 100644 --- a/src/relay/backend/te_compiler_cache.cc +++ b/src/relay/backend/te_compiler_cache.cc @@ -190,6 +190,8 @@ class PatternMatcher { const CallNode* GetAnchorOp() { return anchor_op_; } + void Clear() { registered_ops_.clear(); } + private: const Op& qnn_conv2d_op_; const Op& qnn_dense_op_; @@ -330,6 +332,8 @@ class LowerToTECompute : public backend::MemoizedExprTranslatoroutputs; Op a_op = Downcast(anchor_op->op); op_implementations_[a_op.operator->()] = lowered_out->implementation; + + pattern_matcher_.Clear(); } else { // Forward inputs 
as "outputs" for successor. readable_name_stream_ << '_' << op->name; From 3aa343ee2cff8b046207178e7bcbc3a0b4632bd4 Mon Sep 17 00:00:00 2001 From: ibsidorenko Date: Tue, 5 Jul 2022 12:00:08 +0300 Subject: [PATCH 06/25] Fixed issue with non zero padding value for qnn.conv2d --- python/tvm/topi/hexagon/qnn.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/python/tvm/topi/hexagon/qnn.py b/python/tvm/topi/hexagon/qnn.py index 49400847186a..699a12476951 100644 --- a/python/tvm/topi/hexagon/qnn.py +++ b/python/tvm/topi/hexagon/qnn.py @@ -234,6 +234,9 @@ def qnn_conv2d( # Conv2d inputs get_const_tuple(padding), (dilated_kernel_h, dilated_kernel_w) ) + # Subtract zero point from input and then do padding with 0 value + data = te.compute(data.shape, lambda *indices: te.subtract(data(*indices), input_zero_point)) + # DOPAD if pad_top != 0 or pad_down != 0 or pad_left != 0 or pad_right != 0: pad_before = (0, 0, pad_top, pad_left) @@ -249,15 +252,12 @@ def qnn_conv2d( # Conv2d inputs out = te.compute( oshape, lambda n, oc, oh, ow: te.sum( - te.subtract( - data_pad[ - n, - ic, - oh * height_stride + kh * dilation_h, - ow * width_stride + kw * dilation_w, - ], - input_zero_point, - ).astype("int32") + data_pad[ + n, + ic, + oh * height_stride + kh * dilation_h, + ow * width_stride + kw * dilation_w, + ].astype("int32") * te.subtract(weight[oc, ic, kh, kw], kernel_zero_point).astype("int32"), axis=[ic, kh, kw], ), From 46d9edfb8ba92a446bc0ce431529ef85d1167c9a Mon Sep 17 00:00:00 2001 From: ibsidorenko Date: Tue, 5 Jul 2022 14:55:16 +0300 Subject: [PATCH 07/25] Fixed Bias.add for qnn.conv2d --- python/tvm/topi/hexagon/qnn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/topi/hexagon/qnn.py b/python/tvm/topi/hexagon/qnn.py index 699a12476951..5105ff6809b9 100644 --- a/python/tvm/topi/hexagon/qnn.py +++ b/python/tvm/topi/hexagon/qnn.py @@ -267,7 +267,7 @@ def qnn_conv2d( # Conv2d inputs if bias is not None: assert 
len(out.shape) == len(bias.shape) assert bias.shape[2] == 1 and bias.shape[3] == 1 - out = te.compute(out.shape, lambda n, c, h, w: out[n, c, h, w] + bias[n, c, 1, 1]) + out = te.compute(out.shape, lambda n, c, h, w: out[n, c, h, w] + bias[n, c, 0, 0]) # Requantize output of convolution # Q_output = zp_output + round((scale_input)/(scale_output) * (Q_input - zp_input)) From 723b7abc57e87c90680ef7de3f6c2d7558c664e3 Mon Sep 17 00:00:00 2001 From: ibsidorenko Date: Wed, 6 Jul 2022 16:53:19 +0300 Subject: [PATCH 08/25] Added support of depthwise qnn.conv2d topi operator --- python/tvm/relay/qnn/strategy/hexagon.py | 16 +++- python/tvm/topi/hexagon/qnn.py | 104 +++++++++++++++++++++++ 2 files changed, 119 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/qnn/strategy/hexagon.py b/python/tvm/relay/qnn/strategy/hexagon.py index fe158f05781d..4a943e27e5e3 100644 --- a/python/tvm/relay/qnn/strategy/hexagon.py +++ b/python/tvm/relay/qnn/strategy/hexagon.py @@ -20,6 +20,7 @@ from tvm import topi from .generic import * from ... import op as _op +from ...op.strategy.generic import is_depthwise_conv2d # TODO: This is POC code. 
Change it on "hexagon" instead of "cpu" @@ -78,16 +79,29 @@ def qnn_add_strategy_hexagon(attrs, inputs, out_type, target): @qnn_conv2d_strategy.register("cpu") def qnn_conv2d_strategy_hexagon(attrs, inputs, out_type, target): """qnn.conv2d strategy for Hexagon""" + data = inputs[0] + kernel = inputs[1] data_layout = attrs.data_layout + kernel_layout = attrs.kernel_layout groups = attrs.groups strategy = _op.OpStrategy() if groups == 1: - if data_layout == "NCHW": + if data_layout == "NCHW" and kernel_layout == "OIHW": strategy.add_implementation( wrap_topi_qnn_conv2d(topi.hexagon.qnn_conv2d), wrap_topi_schedule(topi.hexagon.schedule_qnn_conv2d), name="qnn_conv2d.hexagon", ) + elif is_depthwise_conv2d(data.shape, data_layout, kernel.shape, kernel_layout, groups): + if data_layout == "NCHW" and kernel_layout == "OIHW": + strategy.add_implementation( + wrap_topi_qnn_conv2d(topi.hexagon.qnn_depthwise_conv2d), + wrap_topi_schedule(topi.hexagon.schedule_qnn_depthwise_conv2d), + name="qnn_depthwise_conv2d.hexagon", + ) + else: + raise RuntimeError("Unsupported strategy for group qnn.conv2d") + return strategy diff --git a/python/tvm/topi/hexagon/qnn.py b/python/tvm/topi/hexagon/qnn.py index 5105ff6809b9..463967d3f4a4 100644 --- a/python/tvm/topi/hexagon/qnn.py +++ b/python/tvm/topi/hexagon/qnn.py @@ -302,6 +302,110 @@ def schedule_qnn_conv2d(outs): return _default_schedule(outs, False) +def qnn_depthwise_conv2d( # Conv2d inputs + data, + weight, + # Conv2d quantization params: + input_zero_point, + kernel_zero_point, + _input_scale, + _kernel_scale, + # bias + bias, + # Requantization params: + rq_input_scale, + rq_input_zero_point, + rq_output_scale, + rq_output_zero_point, + axis, + # Conv2d attributes: + strides, + padding, + dilation, + oshape, + odtype, +): + """Compute for qnn.conv2d with NCHW layout""" + kernel_height = weight.shape[2] # OIHW layout + kernel_width = weight.shape[3] # OIHW layout + + height_stride, width_stride = strides + dilation_h, dilation_w = 
dilation + + dilated_kernel_h = (kernel_height - 1) * dilation_h + 1 + dilated_kernel_w = (kernel_width - 1) * dilation_w + 1 + + pad_top, pad_left, pad_down, pad_right = get_pad_tuple( + get_const_tuple(padding), (dilated_kernel_h, dilated_kernel_w) + ) + + # Subtract zero point from input and then do padding with 0 value + data = te.compute(data.shape, lambda *indices: te.subtract(data(*indices), input_zero_point)) + + # DOPAD + if pad_top != 0 or pad_down != 0 or pad_left != 0 or pad_right != 0: + pad_before = (0, 0, pad_top, pad_left) + pad_after = (0, 0, pad_down, pad_right) + data_pad = pad(data, pad_before, pad_after, name="data_pad") + else: + data_pad = data + + kh = te.reduce_axis((0, kernel_height), name="kh") + kw = te.reduce_axis((0, kernel_width), name="kw") + + out = te.compute( + oshape, + lambda n, oc, oh, ow: te.sum( + data_pad[ + n, + oc, + oh * height_stride + kh * dilation_h, + ow * width_stride + kw * dilation_w, + ].astype("int32") + * te.subtract(weight[oc, 0, kh, kw], kernel_zero_point).astype("int32"), + axis=[kh, kw], + ), + ) + + # Add bias + if bias is not None: + assert len(out.shape) == len(bias.shape) + assert bias.shape[2] == 1 and bias.shape[3] == 1 + out = te.compute(out.shape, lambda n, c, h, w: out[n, c, h, w] + bias[n, c, 0, 0]) + + # Requantize output of convolution + # Q_output = zp_output + round((scale_input)/(scale_output) * (Q_input - zp_input)) + if rq_input_scale is not None and rq_output_scale is not None: + return qnn_requantize( + out, + rq_input_scale, + rq_input_zero_point, + rq_output_scale, + rq_output_zero_point, + axis, + odtype, + ) + + return out + + +def schedule_qnn_depthwise_conv2d(outs): + """Schedule for depthwise qnn.conv2d + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of qnn.conv2d + in the format of an array of tensors. + + Returns + ------- + sch: Schedule + The computation schedule for the op. 
+ """ + return _default_schedule(outs, False) + + def qnn_dense( data, weight, From ce8a7659c3eaa594ef29ad56d1811bf8568d2665 Mon Sep 17 00:00:00 2001 From: ibsidorenko Date: Thu, 7 Jul 2022 11:05:21 +0300 Subject: [PATCH 09/25] Added support of 1D quantization params in qnn.dequantize --- python/tvm/relay/qnn/strategy/generic.py | 10 ++++++++++ python/tvm/relay/qnn/strategy/hexagon.py | 2 +- python/tvm/topi/hexagon/qnn.py | 13 ++++++++++--- 3 files changed, 21 insertions(+), 4 deletions(-) diff --git a/python/tvm/relay/qnn/strategy/generic.py b/python/tvm/relay/qnn/strategy/generic.py index 0a16377f41ef..d36710675a72 100644 --- a/python/tvm/relay/qnn/strategy/generic.py +++ b/python/tvm/relay/qnn/strategy/generic.py @@ -50,6 +50,16 @@ def wrapper(attrs, inputs, _out_type): return wrapper +def wrap_compute_dequantize(topi_compute): + """Wrap TOPI compute which use axis from attrs""" + + def wrapper(attrs, inputs, _out_type): + args = [*inputs, attrs.axis] + return [topi_compute(*args)] + + return wrapper + + def wrap_topi_qnn_conv2d(topi_compute): """Wrap TOPI compute which use conv2d attrs and output data type""" diff --git a/python/tvm/relay/qnn/strategy/hexagon.py b/python/tvm/relay/qnn/strategy/hexagon.py index 4a943e27e5e3..363f003fe297 100644 --- a/python/tvm/relay/qnn/strategy/hexagon.py +++ b/python/tvm/relay/qnn/strategy/hexagon.py @@ -42,7 +42,7 @@ def qnn_dequantize_strategy_hexagon(attrs, inputs, out_type, target): """qnn.dequantize strategy for Hexagon""" strategy = _op.OpStrategy() strategy.add_implementation( - wrap_topi_compute(topi.hexagon.qnn_dequantize), + wrap_compute_dequantize(topi.hexagon.qnn_dequantize), wrap_topi_schedule(topi.hexagon.schedule_qnn_dequantize), name="qnn_dequantize.hexagon", ) diff --git a/python/tvm/topi/hexagon/qnn.py b/python/tvm/topi/hexagon/qnn.py index 463967d3f4a4..7ac401572722 100644 --- a/python/tvm/topi/hexagon/qnn.py +++ b/python/tvm/topi/hexagon/qnn.py @@ -71,15 +71,22 @@ def schedule_qnn_quantize(outs): return 
_default_schedule(outs, False) -def qnn_dequantize(data, input_scale, input_zero_point): +def qnn_dequantize(data, input_scale, input_zero_point, axis): """Compute for qnn.dequantize fp_output = input_scale * (Q_input - input_zero_point) - TODO: Support 'axis' argument. """ def _compute(*indices): value = data(*indices) - return te.multiply(input_scale, te.subtract(value, input_zero_point)) + + # Account scalar and 1D quantization parameters: + scale_idx = tvm.tir.indexmod(indices[axis], topi.shape(input_scale)[0]) + scale = input_scale if len(input_scale.shape) == 0 else input_scale[scale_idx] + + zp_idx = tvm.tir.indexmod(indices[axis], topi.shape(input_zero_point)[0]) + zp = input_zero_point if len(input_zero_point.shape) == 0 else input_zero_point[zp_idx] + + return te.multiply(scale, te.subtract(value, zp)) return te.compute(data.shape, _compute, tag=tag.ELEMWISE) From 823f46af98a7b6966e97878cea93cf743dcd8c6a Mon Sep 17 00:00:00 2001 From: ibsidorenko Date: Fri, 8 Jul 2022 12:24:02 +0300 Subject: [PATCH 10/25] Added support of qnn.concatenate --- python/tvm/relay/qnn/op/_qnn.py | 4 ++ python/tvm/relay/qnn/strategy/generic.py | 18 +++++ python/tvm/relay/qnn/strategy/hexagon.py | 13 ++++ python/tvm/topi/hexagon/qnn.py | 86 ++++++++++++++++++++++++ src/relay/backend/te_compiler_cache.cc | 6 -- 5 files changed, 121 insertions(+), 6 deletions(-) diff --git a/python/tvm/relay/qnn/op/_qnn.py b/python/tvm/relay/qnn/op/_qnn.py index 72232bdffc20..4bb4feaf04ef 100644 --- a/python/tvm/relay/qnn/op/_qnn.py +++ b/python/tvm/relay/qnn/op/_qnn.py @@ -68,6 +68,10 @@ def simulated_dequantize_compute(attrs, inputs, output_type): register_strategy("qnn.add", strategy.qnn_add_strategy) register_pattern("qnn.add", OpPattern.BROADCAST) +# qnn.concatenate +register_strategy("qnn.concatenate", strategy.qnn_concatenate_strategy) +register_pattern("qnn.concatenate", OpPattern.INJECTIVE) + # qnn.conv2d register_strategy("qnn.conv2d", strategy.qnn_conv2d_strategy) 
register_pattern("qnn.conv2d", OpPattern.OUT_ELEMWISE_FUSABLE) diff --git a/python/tvm/relay/qnn/strategy/generic.py b/python/tvm/relay/qnn/strategy/generic.py index d36710675a72..64ac9ee044a2 100644 --- a/python/tvm/relay/qnn/strategy/generic.py +++ b/python/tvm/relay/qnn/strategy/generic.py @@ -163,6 +163,15 @@ def wrapper(attrs, inputs, _out_type): return wrapper +def wrap_topi_concatenate(topi_compute): + """Wrap TOPI compute which use qnn.concatenate attrs""" + + def wrapper(attrs, inputs, out_type): + return [topi_compute(inputs, attrs.axis, out_type.dtype)] + + return wrapper + + @override_native_generic_func("qnn_quantize_strategy") def qnn_quantize_strategy(attrs, inputs, out_type, target): """qnn.quantize generic strategy""" @@ -199,6 +208,15 @@ def qnn_add_strategy(attrs, inputs, out_type, target): ) +@override_native_generic_func("qnn_concatenate_strategy") +def qnn_concatenate_strategy(attrs, inputs, out_type, target): + """qnn.concatenate generic strategy""" + raise RuntimeError( + "qnn.concatenate is currently only supported with Hexagon. " + "Please run QNN Canonicalize pass to decompose this op into supported ops." + ) + + @override_native_generic_func("qnn_conv2d_strategy") def qnn_conv2d_strategy(attrs, inputs, out_type, target): """qnn.conv2d generic strategy""" diff --git a/python/tvm/relay/qnn/strategy/hexagon.py b/python/tvm/relay/qnn/strategy/hexagon.py index 363f003fe297..95fc6c3f2ef3 100644 --- a/python/tvm/relay/qnn/strategy/hexagon.py +++ b/python/tvm/relay/qnn/strategy/hexagon.py @@ -75,6 +75,19 @@ def qnn_add_strategy_hexagon(attrs, inputs, out_type, target): return strategy +# TODO: This is POC code. 
Change it on "hexagon" instead of "cpu" +@qnn_concatenate_strategy.register("cpu") +def qnn_concatenate_strategy_hexagon(attrs, inputs, out_type, target): + """qnn.concatenate strategy for Hexagon""" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_topi_concatenate(topi.hexagon.qnn_concatenate), + wrap_topi_schedule(topi.hexagon.schedule_qnn_concatenate), + name="qnn_concatenate.hexagon", + ) + return strategy + + # TODO: This is POC code. Change it on "hexagon" instead of "cpu" @qnn_conv2d_strategy.register("cpu") def qnn_conv2d_strategy_hexagon(attrs, inputs, out_type, target): diff --git a/python/tvm/topi/hexagon/qnn.py b/python/tvm/topi/hexagon/qnn.py index 7ac401572722..895e9f353172 100644 --- a/python/tvm/topi/hexagon/qnn.py +++ b/python/tvm/topi/hexagon/qnn.py @@ -24,6 +24,14 @@ from ..nn.utils import get_pad_tuple from ..nn.pad import pad from .. import tag +from ..x86.concat import concatenate + + +def clip_cast(val, dtype): + # clip + cast: + const_min = tvm.tir.min_value(dtype) + const_max = tvm.tir.max_value(dtype) + return te.max(tvm.te.min(val, const_max), const_min).astype(dtype) def qnn_quantize(data, output_scale, output_zero_point, axis, out_dtype): @@ -203,6 +211,84 @@ def schedule_qnn_add(outs): return _default_schedule(outs, False) +def requantize_tensor(tensor, i_scale, i_zp, o_scale, o_zp, out_dtype): + """Requantize tensor""" + + def _compute(*indices): + value = tensor(*indices) + mul_value = te.round( + te.multiply(te.div(i_scale, o_scale), te.subtract(value, i_zp)) + ).astype("int32") + rq_value = te.add(mul_value, o_zp) + + return clip_cast(rq_value, out_dtype) + + return te.compute(tensor.shape, _compute) + + +def qnn_concatenate(data, axis, out_dtype): + """Compute for qnn.concatenate + + Parameters + ---------- + data: Array of Tensor + The computation graph description of qnn.concatenate + in the format of an array of tensors. + + axis: int + The axis along which the tensors are concatenated. 
+ + out_dtype: string + Data type of output tensor + + Returns + ------- + out: Tensor + The computation for the op. + """ + + # Get output quantization parameters. + o_scale = data[-2] + o_zp = data[-1] + + # Initially qnn.concatenate had 3 tuples: (1) tuple with input tensors, (2) tuple with input + # scales and (3) tuple with input zero points. + # Last 2 elements in data represent output scale and zero point. + num_of_tuples = 3 + assert ((len(data) - 2) % num_of_tuples) == 0 + args_num = (len(data) - 2) // num_of_tuples + + args = [] + for i in range(args_num): + # Get next tensor and its quantization parameters. + tensor = data[i] + i_scale = data[i + args_num] + i_zp = data[i + args_num * 2] + + # Requantize tensors and add them to the list. + args.append(requantize_tensor(tensor, i_scale, i_zp, o_scale, o_zp, out_dtype)) + + # Call x86 implementation of concatenate. + return concatenate(args, axis) + + +def schedule_qnn_concatenate(outs): + """Schedule for qnn.concatenate + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of qnn.add + in the format of an array of tensors. + + Returns + ------- + sch: Schedule + The computation schedule for the op. 
+ """ + return _default_schedule(outs, False) + + def qnn_conv2d( # Conv2d inputs data, weight, diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc index 1445b543cb55..d983df3d69cb 100644 --- a/src/relay/backend/te_compiler_cache.cc +++ b/src/relay/backend/te_compiler_cache.cc @@ -309,12 +309,6 @@ class LowerToTECompute : public backend::MemoizedExprTranslatorargs.size(), 1U) - << "Only functions with a single tuple input are allowed, but " << count_tuple - << " were provided."; - } - ICHECK(call_node->op.as()) << "Primitive function only allows call into primitive ops"; Op op = Downcast(call_node->op); From 055ea96d1f16f2663af593f12fb7cc8f95ccdae9 Mon Sep 17 00:00:00 2001 From: ibsidorenko Date: Tue, 12 Jul 2022 10:07:00 +0300 Subject: [PATCH 11/25] Fixed out of range array access --- python/tvm/topi/hexagon/qnn.py | 31 ++++++++++++++----------------- 1 file changed, 14 insertions(+), 17 deletions(-) diff --git a/python/tvm/topi/hexagon/qnn.py b/python/tvm/topi/hexagon/qnn.py index 895e9f353172..081c63813fb7 100644 --- a/python/tvm/topi/hexagon/qnn.py +++ b/python/tvm/topi/hexagon/qnn.py @@ -34,6 +34,15 @@ def clip_cast(val, dtype): return te.max(tvm.te.min(val, const_max), const_min).astype(dtype) +def get_qnn_param(param, indices, axis): + # Account scalar and 1D quantization parameters: + if len(param.shape) == 0: + return param + + param_idx = tvm.tir.indexmod(indices[axis], topi.shape(param)[0]) + return param[param_idx] + + def qnn_quantize(data, output_scale, output_zero_point, axis, out_dtype): """Compute for qnn.quantize Q_output = clamp((round(input_tensor/output_scale) + output_zero_point), @@ -46,18 +55,11 @@ def qnn_quantize(data, output_scale, output_zero_point, axis, out_dtype): def _compute(*indices): value = data(*indices) + scale = get_qnn_param(output_scale, indices, axis) + zp = get_qnn_param(output_zero_point, indices, axis) - # Account scalar and 1D quantization parameters: - scale_idx = 
tvm.tir.indexmod(indices[axis], topi.shape(output_scale)[0]) - scale = output_scale if len(output_scale.shape) == 0 else output_scale[scale_idx] - - zp_idx = tvm.tir.indexmod(indices[axis], topi.shape(output_zero_point)[0]) - zp = output_zero_point if len(output_zero_point.shape) == 0 else output_zero_point[zp_idx] - - const_min = tvm.tir.min_value(out_dtype) - const_max = tvm.tir.max_value(out_dtype) val = te.add(te.round(te.div(value, scale)), zp) - return te.max(te.min(val, const_max), const_min).astype(out_dtype) + return clip_cast(val, out_dtype) return te.compute(data.shape, _compute, tag=tag.ELEMWISE) @@ -86,13 +88,8 @@ def qnn_dequantize(data, input_scale, input_zero_point, axis): def _compute(*indices): value = data(*indices) - - # Account scalar and 1D quantization parameters: - scale_idx = tvm.tir.indexmod(indices[axis], topi.shape(input_scale)[0]) - scale = input_scale if len(input_scale.shape) == 0 else input_scale[scale_idx] - - zp_idx = tvm.tir.indexmod(indices[axis], topi.shape(input_zero_point)[0]) - zp = input_zero_point if len(input_zero_point.shape) == 0 else input_zero_point[zp_idx] + scale = get_qnn_param(input_scale, indices, axis) + zp = get_qnn_param(input_zero_point, indices, axis) return te.multiply(scale, te.subtract(value, zp)) From 8efb58f5848b7aedf1f531eafb2a359c328cb261 Mon Sep 17 00:00:00 2001 From: ibsidorenko Date: Tue, 26 Jul 2022 14:12:52 +0300 Subject: [PATCH 12/25] Added meta_schedule_original_shape attribute in QDenseAttr and QConv2DAttr --- include/tvm/relay/qnn/attrs.h | 4 +++- src/relay/qnn/op/convolution.cc | 1 + 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/include/tvm/relay/qnn/attrs.h b/include/tvm/relay/qnn/attrs.h index 192ac28905e1..f61479810914 100644 --- a/include/tvm/relay/qnn/attrs.h +++ b/include/tvm/relay/qnn/attrs.h @@ -137,7 +137,8 @@ struct QConv2DAttrs : public tvm::AttrsNode { tvm::String data_layout; tvm::String kernel_layout; tvm::String out_layout; - tvm::String 
auto_scheduler_rewritten_layout; // The layout after auto-scheduler's layout rewrite + tvm::String auto_scheduler_rewritten_layout; // The layout after auto-scheduler's layout rewrite + Array meta_schedule_original_shape; // The original shape of the weights DataType out_dtype; // Optional extra attributes for Hexagon target. Describes requantization parameters. @@ -214,6 +215,7 @@ struct QConv2DAttrs : public tvm::AttrsNode { struct QDenseAttrs : public tvm::AttrsNode { IndexExpr units; tvm::String auto_scheduler_rewritten_layout; // The layout after auto-scheduler's layout rewrite + Array meta_schedule_original_shape; // The original shape of the weights DataType out_dtype; // Optional extra attributes for Hexagon target. Describes requantization parameters. diff --git a/src/relay/qnn/op/convolution.cc b/src/relay/qnn/op/convolution.cc index 2dff57d7f658..6ece9220401a 100644 --- a/src/relay/qnn/op/convolution.cc +++ b/src/relay/qnn/op/convolution.cc @@ -98,6 +98,7 @@ bool QnnConv2DRel(const Array& types, int num_inputs, const Attrs& attrs, conv2d_attrs->out_layout = param->out_layout; conv2d_attrs->out_dtype = param->out_dtype; conv2d_attrs->auto_scheduler_rewritten_layout = param->auto_scheduler_rewritten_layout; + conv2d_attrs->meta_schedule_original_shape = param->meta_schedule_original_shape; // Collect the input tensor and output tensor devoid of scale and zero points to reuse Relay // Conv2D infer type function. From 82e21cb689a73e23d5e135c203dd9d9bbb35f7c0 Mon Sep 17 00:00:00 2001 From: ibsidorenko Date: Tue, 2 Aug 2022 20:10:59 +0300 Subject: [PATCH 13/25] Added support of qnn.batch_matmul as a standalone op. 
--- include/tvm/relay/qnn/attrs.h | 2 +- python/tvm/relay/qnn/op/_qnn.py | 4 +++ python/tvm/relay/qnn/strategy/generic.py | 20 +++++++++++ python/tvm/relay/qnn/strategy/hexagon.py | 13 +++++++ python/tvm/topi/hexagon/qnn.py | 46 +++++++++++++++++++++++- 5 files changed, 83 insertions(+), 2 deletions(-) diff --git a/include/tvm/relay/qnn/attrs.h b/include/tvm/relay/qnn/attrs.h index f61479810914..5cf8b79d2f6d 100644 --- a/include/tvm/relay/qnn/attrs.h +++ b/include/tvm/relay/qnn/attrs.h @@ -214,7 +214,7 @@ struct QConv2DAttrs : public tvm::AttrsNode { /*! \brief Attributes for QNN dense operator */ struct QDenseAttrs : public tvm::AttrsNode { IndexExpr units; - tvm::String auto_scheduler_rewritten_layout; // The layout after auto-scheduler's layout rewrite + tvm::String auto_scheduler_rewritten_layout; // The layout after auto-scheduler's layout rewrite Array meta_schedule_original_shape; // The original shape of the weights DataType out_dtype; diff --git a/python/tvm/relay/qnn/op/_qnn.py b/python/tvm/relay/qnn/op/_qnn.py index 4bb4feaf04ef..4e54583a3be0 100644 --- a/python/tvm/relay/qnn/op/_qnn.py +++ b/python/tvm/relay/qnn/op/_qnn.py @@ -79,3 +79,7 @@ def simulated_dequantize_compute(attrs, inputs, output_type): # qnn.dense register_strategy("qnn.dense", strategy.qnn_dense_strategy) register_pattern("qnn.dense", OpPattern.OUT_ELEMWISE_FUSABLE) + +# qnn.batch_matmul +register_strategy("qnn.batch_matmul", strategy.qnn_batch_matmul_strategy) +register_pattern("qnn.batch_matmul", OpPattern.OUT_ELEMWISE_FUSABLE) diff --git a/python/tvm/relay/qnn/strategy/generic.py b/python/tvm/relay/qnn/strategy/generic.py index 64ac9ee044a2..8707d57ffc41 100644 --- a/python/tvm/relay/qnn/strategy/generic.py +++ b/python/tvm/relay/qnn/strategy/generic.py @@ -172,6 +172,17 @@ def wrapper(attrs, inputs, out_type): return wrapper +def wrap_topi_qnn_batch_matmul(topi_compute): + """Wrap TOPI compute which use qnn.batch_matmul attrs""" + + def wrapper(attrs, inputs, _out_type): + assert 
len([*inputs]) == 6 + args = [*inputs, attrs.transpose_a, attrs.transpose_b, attrs.out_dtype] + return [topi_compute(*args)] + + return wrapper + + @override_native_generic_func("qnn_quantize_strategy") def qnn_quantize_strategy(attrs, inputs, out_type, target): """qnn.quantize generic strategy""" @@ -233,3 +244,12 @@ def qnn_dense_strategy(attrs, inputs, out_type, target): "qnn.dense is currently only supported with Hexagon. " "Please run QNN Canonicalize pass to decompose this op into supported ops." ) + + +@override_native_generic_func("qnn_batch_matmul_strategy") +def qnn_batch_matmul_strategy(attrs, inputs, out_type, target): + """qnn.batch_matmul generic strategy""" + raise RuntimeError( + "qnn.batch_matmul is currently only supported with Hexagon. " + "Please run QNN Canonicalize pass to decompose this op into supported ops." + ) diff --git a/python/tvm/relay/qnn/strategy/hexagon.py b/python/tvm/relay/qnn/strategy/hexagon.py index 95fc6c3f2ef3..a6737f827af7 100644 --- a/python/tvm/relay/qnn/strategy/hexagon.py +++ b/python/tvm/relay/qnn/strategy/hexagon.py @@ -129,3 +129,16 @@ def qnn_dense_strategy_hexagon(attrs, inputs, out_type, target): name="qnn_dense.hexagon", ) return strategy + + +# TODO: This is POC code. 
Change it on "hexagon" instead of "cpu" +@qnn_batch_matmul_strategy.register("cpu") +def qnn_batch_matmul_strategy_hexagon(attrs, inputs, out_type, target): + """qnn.batch_matmul strategy for Hexagon""" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_topi_qnn_batch_matmul(topi.hexagon.qnn_batch_matmul), + wrap_topi_schedule(topi.hexagon.schedule_qnn_batch_matmul), + name="qnn_batch_matmul.hexagon", + ) + return strategy diff --git a/python/tvm/topi/hexagon/qnn.py b/python/tvm/topi/hexagon/qnn.py index 081c63813fb7..3ca10fb3d6e5 100644 --- a/python/tvm/topi/hexagon/qnn.py +++ b/python/tvm/topi/hexagon/qnn.py @@ -23,7 +23,7 @@ from ..utils import get_const_tuple from ..nn.utils import get_pad_tuple from ..nn.pad import pad -from .. import tag +from .. import tag, nn from ..x86.concat import concatenate @@ -563,3 +563,47 @@ def schedule_qnn_dense(outs): The computation schedule for the op. """ return _default_schedule(outs, False) + + +def qnn_batch_matmul( + tensor_a, + tensor_b, + # batch_matmul quantization params: + a_zero_point, + b_zero_point, + _a_scale, + _b_scale, + # Attributes + transpose_a, + transpose_b, + out_dtype, +): + """Compute for qnn.dense""" + + # Preprocess tensor_a: subtract zp + a_sub_zp = te.compute( + tensor_a.shape, lambda *indices: te.subtract(tensor_a(*indices), a_zero_point) + ) + # Preprocess tensor_b: subtract zp + b_sub_zp = te.compute( + tensor_b.shape, lambda *indices: te.subtract(tensor_b(*indices), b_zero_point) + ) + + return nn.batch_matmul(a_sub_zp, b_sub_zp, None, out_dtype, transpose_a, transpose_b) + + +def schedule_qnn_batch_matmul(outs): + """Schedule for qnn.batch_matmul + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of qnn.batch_matmul + in the format of an array of tensors. + + Returns + ------- + sch: Schedule + The computation schedule for the op. 
+ """ + return _default_schedule(outs, False) From a63a4e4e66d488859290c63b51f3e4b2bdcac652 Mon Sep 17 00:00:00 2001 From: ibsidorenko Date: Wed, 3 Aug 2022 12:53:46 +0300 Subject: [PATCH 14/25] Added per channel zp in qnn.dense and qnn.conv2d. --- python/tvm/topi/hexagon/qnn.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/python/tvm/topi/hexagon/qnn.py b/python/tvm/topi/hexagon/qnn.py index 3ca10fb3d6e5..6dc13823f399 100644 --- a/python/tvm/topi/hexagon/qnn.py +++ b/python/tvm/topi/hexagon/qnn.py @@ -339,6 +339,7 @@ def qnn_conv2d( # Conv2d inputs kh = te.reduce_axis((0, kernel_height), name="kh") kw = te.reduce_axis((0, kernel_width), name="kw") + # axis=0 in get_qnn_param means 'O' dimension in "OIHW" weights layout. out = te.compute( oshape, lambda n, oc, oh, ow: te.sum( @@ -348,7 +349,9 @@ def qnn_conv2d( # Conv2d inputs oh * height_stride + kh * dilation_h, ow * width_stride + kw * dilation_w, ].astype("int32") - * te.subtract(weight[oc, ic, kh, kw], kernel_zero_point).astype("int32"), + * te.subtract( + weight[oc, ic, kh, kw], get_qnn_param(kernel_zero_point, (oc, ic, kh, kw), axis=0) + ).astype("int32"), axis=[ic, kh, kw], ), ) @@ -519,11 +522,14 @@ def qnn_dense( N, _ = get_const_tuple(weight.shape) k = te.reduce_axis((0, K), "k") # This implementation uses "int32" dense output data type. + # axis=0 in get_qnn_param mean 'N' dimension in "NK" weights layout. out = te.compute( (M, N), lambda m, n: te.sum( te.subtract(data[m, k], input_zero_point).astype("int32") - * te.subtract(weight[n, k], kernel_zero_point).astype("int32"), + * te.subtract(weight[n, k], get_qnn_param(kernel_zero_point, (n, k), axis=0)).astype( + "int32" + ), axis=k, ), ) From 235084fefda183bd901ef6e289d484a75ad3b7a9 Mon Sep 17 00:00:00 2001 From: ibsidorenko Date: Fri, 5 Aug 2022 12:08:41 +0300 Subject: [PATCH 15/25] Fixed corner cases like dense+bias+bias+rq. 
--- src/relay/backend/te_compiler_cache.cc | 40 +++++++++++++++----------- 1 file changed, 24 insertions(+), 16 deletions(-) diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc index d983df3d69cb..7e433dfb9ee8 100644 --- a/src/relay/backend/te_compiler_cache.cc +++ b/src/relay/backend/te_compiler_cache.cc @@ -125,7 +125,8 @@ Array GetShape(const Array& shape) { } // Helper class that is used during lowering to TE. -// It matches sequence of Ops and lower them into single TOPI operation. +// It matches sequence of Ops and lower them into single TOPI operation. Has sense for Hexagon only. +// All supported patterns are enumerated in "supported_patterns_" class PatternMatcher { public: PatternMatcher() @@ -139,33 +140,32 @@ class PatternMatcher { ICHECK(call_node->op.as()); Op op = Downcast(call_node->op); if (op == qnn_conv2d_op_) { - registered_ops_[QConv2d]++; + registered_ops_.push_front(P_QConv2d); ICHECK(anchor_op_ == nullptr); anchor_op_ = call_node; } else if (op == qnn_requantize_op_) { - registered_ops_[QRequantize]++; + registered_ops_.push_front(P_QRequantize); } else if (op == bias_add_op_) { - registered_ops_[BiasAdd]++; + registered_ops_.push_front(P_BiasAdd); } else if (op == qnn_dense_op_) { - registered_ops_[QDense]++; + registered_ops_.push_front(P_QDense); ICHECK(anchor_op_ == nullptr); anchor_op_ = call_node; + } else { + registered_ops_.push_front(P_Opaque); } } - // Check whether given Op is part of matched pattern. + // Check whether given Op is a part of matched pattern. 
bool find(const Op& op) { if (registered_ops_.empty()) return false; if (op == qnn_conv2d_op_ || op == qnn_requantize_op_ || op == bias_add_op_ || op == qnn_dense_op_) { - // Patterns: qnn.conv2d -> qnn.requantize or qnn.conv2d -> bias_add -> qnn.requantize - if (registered_ops_[QConv2d] && registered_ops_[QRequantize]) { - return true; - } - // Patterns: qnn.dense -> qnn.requantize or qnn.dense -> bias_add -> qnn.requantize - if (registered_ops_[QDense] && registered_ops_[QRequantize]) { - return true; + for (const auto& pat : supported_patterns_) { + auto it = + std::search(registered_ops_.begin(), registered_ops_.end(), pat.begin(), pat.end()); + if (it != registered_ops_.end()) return true; } } return false; @@ -198,11 +198,19 @@ class PatternMatcher { const Op& qnn_requantize_op_; const Op& bias_add_op_; - // Main (complicated) operation in the primitive. + // Main (complicated) operation in the primitive (for example qnn.conv2d, qnn.dense etc.). const CallNode* anchor_op_ = nullptr; - enum POper { QConv2d, QDense, BiasAdd, QRequantize }; - std::map registered_ops_; + enum POper { P_QConv2d, P_QDense, P_BiasAdd, P_QRequantize, P_Opaque }; + + std::deque registered_ops_; + + const std::vector> supported_patterns_ = { + {P_QDense, P_BiasAdd, P_QRequantize}, // Pattern qnn.dense -> bias_add -> qnn.requantize + {P_QDense, P_QRequantize}, // Patter qnn.dense -> qnn.requantize + {P_QConv2d, P_BiasAdd, P_QRequantize}, // Pattern qnn.conv2d -> bias_add -> qnn.requantize + {P_QConv2d, P_QRequantize} // Patter qnn.conv2d -> qnn.requantize + }; }; // Lowers Relay primitive Function to TE Compute From 862ea2e1434d117e0e15242a4f2b5a6085ae4f12 Mon Sep 17 00:00:00 2001 From: ibsidorenko Date: Fri, 5 Aug 2022 18:36:50 +0300 Subject: [PATCH 16/25] Added unit test. 
--- .../test_wo_qnn_canonicalization.py | 153 ++++++++++++++++++ 1 file changed, 153 insertions(+) create mode 100644 tests/python/contrib/test_hexagon/test_wo_qnn_canonicalization.py diff --git a/tests/python/contrib/test_hexagon/test_wo_qnn_canonicalization.py b/tests/python/contrib/test_hexagon/test_wo_qnn_canonicalization.py new file mode 100644 index 000000000000..a58f58486d2c --- /dev/null +++ b/tests/python/contrib/test_hexagon/test_wo_qnn_canonicalization.py @@ -0,0 +1,153 @@ +import pytest +import numpy as np + +import tvm.testing +from tvm import relay +from tvm.contrib.hexagon.session import Session +from tvm.contrib import graph_executor +from tvm.relay.backend import Executor + + +@tvm.testing.requires_hexagon +def test_no_qnn_pass(): + x = relay.var("x", shape=(4, 8), dtype="float32") + op0 = relay.qnn.op.quantize(x, relay.const(2.0), relay.const(10), out_dtype="uint8") + op1 = relay.qnn.op.dequantize(op0, relay.const(0.5), relay.const(5)) + mod = tvm.IRModule.from_expr(op1) + + target_hexagon = tvm.target.hexagon("v68") + # Default compilation flow + with tvm.transform.PassContext(opt_level=3): + opt_mod_1, _ = relay.optimize(mod, tvm.target.Target(target_hexagon, host=target_hexagon)) + + # Disable QNN legalization and canonicalization passes + with tvm.transform.PassContext(opt_level=3, disabled_pass=["Legalize"]): + opt_mod_2, _ = relay.optimize(mod, tvm.target.Target(target_hexagon, host=target_hexagon)) + + # Check that during Default compilation flow we do not call qnn::canonicalization pass. 
+ tvm.ir.assert_structural_equal(opt_mod_1, opt_mod_2) + + +def execute(executor, data_np, weight_np, bias_np = None): + executor.set_input("data", data_np) + executor.set_input("weight", weight_np) + if bias_np is not None: + executor.set_input("bias", bias_np) + executor.run() + return executor.get_output(0) + + +@tvm.testing.requires_hexagon +def test_qnn_conv2d_rq(hexagon_session: Session): + data_shape = [1, 64, 64, 64] + weight_shape = [64, 64, 3, 3] + data = relay.var("data", shape=data_shape, dtype="float32") + weight = relay.var("weight", shape=weight_shape, dtype="float32") + op0 = relay.qnn.op.quantize(data, relay.const(0.078), relay.const(0), out_dtype="int8") + op1 = relay.qnn.op.quantize(weight, relay.const(0.07), relay.const(0), out_dtype="int8") + op2 = relay.qnn.op.conv2d(op0, + op1, + input_zero_point=relay.const(0), + kernel_zero_point=relay.const(0), + input_scale=relay.const(0.078), + kernel_scale=relay.const(0.07), + padding=[0, 0, 0, 0], + channels=64, + kernel_size=[3, 3]) + op5 = relay.qnn.op.requantize(op2, + input_scale=relay.const(0.05), + input_zero_point=relay.const(0), + output_scale=relay.const(0.21), + output_zero_point=relay.const(61), + out_dtype="int8") + relay_mod = tvm.IRModule.from_expr(op5) + + target_hexagon = tvm.target.hexagon("v68") + target_llvm = tvm.target.Target("llvm") + executor = Executor("graph", {"link-params": True}) + with tvm.transform.PassContext(opt_level=3): + hexagon_lowered = tvm.relay.build( + relay_mod, + tvm.target.Target(target_hexagon, host=target_hexagon), + executor=executor, + ) + + llvm_lowered = tvm.relay.build( + relay_mod, + tvm.target.Target(target_llvm, host=target_llvm), + executor=executor, + ) + + data_np = np.random.rand(*data_shape) - 0.5 + weight_np = np.random.rand(*weight_shape) - 0.5 + + hx_m = hexagon_session.get_executor_from_factory(hexagon_lowered) + hexagon_output = execute(hx_m, data_np, weight_np) + + dev = tvm.cpu(0) + llvm_m = 
graph_executor.GraphModule(llvm_lowered["default"](dev)) + llvm_out = execute(llvm_m, data_np, weight_np) + + np.testing.assert_equal(hexagon_output.numpy(), llvm_out.numpy()) + + +@tvm.testing.requires_hexagon +def test_qnn_dense_bias_rq(hexagon_session: Session): + data_shape = [8, 8] + weight_shape = [16, 8] + bias_shape = [16] + data = relay.var("data", shape=data_shape, dtype="float32") + weight = relay.var("weight", shape=weight_shape, dtype="float32") + bias = relay.var("bias", shape=bias_shape, dtype="float32") + + op0 = relay.qnn.op.quantize(data, relay.const(0.08), relay.const(0), out_dtype="int8") + op1 = relay.qnn.op.quantize(weight, relay.const(0.07), relay.const(0), out_dtype="int8") + op2 = relay.qnn.op.dense(op0, + op1, + input_zero_point=relay.const(0), + kernel_zero_point=relay.const(0), + input_scale=relay.const(0.08), + kernel_scale=relay.const(0.07), + units=None) + op3 = relay.qnn.op.quantize(bias, relay.const(0.5), relay.const(0), out_dtype="int32") + op4 = relay.nn.bias_add(op2, op3) + op5 = relay.qnn.op.requantize(op4, + input_scale=relay.const(0.05), + input_zero_point=relay.const(0), + output_scale=relay.const(0.212), + output_zero_point=relay.const(10), + out_dtype="int8") + relay_mod = tvm.IRModule.from_expr(op5) + + target_hexagon = tvm.target.hexagon("v68") + target_llvm = tvm.target.Target("llvm") + executor = Executor("graph", {"link-params": True}) + with tvm.transform.PassContext(opt_level=3): + hexagon_lowered = tvm.relay.build( + relay_mod, + tvm.target.Target(target_hexagon, host=target_hexagon), + executor=executor, + ) + + llvm_lowered = tvm.relay.build( + relay_mod, + tvm.target.Target(target_llvm, host=target_llvm), + executor=executor, + ) + + data_np = np.random.rand(*data_shape) - 0.5 + weight_np = np.random.rand(*weight_shape) - 0.5 + bias_np = np.random.rand(*bias_shape) + + hx_m = hexagon_session.get_executor_from_factory(hexagon_lowered) + hexagon_output = execute(hx_m, data_np, weight_np, bias_np) + + dev = 
tvm.cpu(0) + llvm_m = graph_executor.GraphModule(llvm_lowered["default"](dev)) + llvm_out = execute(llvm_m, data_np, weight_np, bias_np) + + np.testing.assert_equal(hexagon_output.numpy(), llvm_out.numpy()) + + +if __name__ == "__main__": + tvm.testing.main() From c985961eb25e6f9c9892efd70d887c0f918657ee Mon Sep 17 00:00:00 2001 From: ibsidorenko Date: Tue, 9 Aug 2022 14:48:03 +0300 Subject: [PATCH 17/25] Removed rq_out_dtype and axis attributes declaration in QConv2DAttra and QDenseAttrs. --- include/tvm/relay/qnn/attrs.h | 18 ---- python/tvm/relay/qnn/op/legalizations.py | 3 +- .../test_wo_qnn_canonicalization.py | 87 ++++++++++++------- 3 files changed, 57 insertions(+), 51 deletions(-) diff --git a/include/tvm/relay/qnn/attrs.h b/include/tvm/relay/qnn/attrs.h index 5cf8b79d2f6d..994e0dde6c08 100644 --- a/include/tvm/relay/qnn/attrs.h +++ b/include/tvm/relay/qnn/attrs.h @@ -199,15 +199,6 @@ struct QConv2DAttrs : public tvm::AttrsNode { TVM_ATTR_FIELD(out_dtype) .set_default(NullValue()) .describe("Output data type, set to explicit type under mixed precision setting"); - - TVM_ATTR_FIELD(axis) - .describe( - "The channel axis for channel wise requantization. Default value is -1," - "which corresponds to the last axis.") - .set_default(-1); - TVM_ATTR_FIELD(rq_out_dtype) - .set_default(NullValue()) - .describe("Requantized output data type"); } }; @@ -230,15 +221,6 @@ struct QDenseAttrs : public tvm::AttrsNode { TVM_ATTR_FIELD(out_dtype) .set_default(NullValue()) .describe("Output data type, set to explicit type under mixed precision setting"); - - TVM_ATTR_FIELD(axis) - .describe( - "The channel axis for channel wise requantization. 
Default value is -1," - "which corresponds to the last axis.") - .set_default(-1); - TVM_ATTR_FIELD(rq_out_dtype) - .set_default(NullValue()) - .describe("Requantized output data type"); } }; diff --git a/python/tvm/relay/qnn/op/legalizations.py b/python/tvm/relay/qnn/op/legalizations.py index d37ed8ecfbec..9bc6efdad00f 100644 --- a/python/tvm/relay/qnn/op/legalizations.py +++ b/python/tvm/relay/qnn/op/legalizations.py @@ -228,8 +228,7 @@ def helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay_op): -relay.cast(kernel_zero_point, dtype="int16"), output_axis, ) - # Skip optional extra attributes: - new_attrs = {k: attrs[k] for k in attrs.keys() if k not in ("axis", "rq_out_dtype")} + new_attrs = {k: attrs[k] for k in attrs.keys()} return relay_op(shift_data, shift_kernel, **new_attrs) diff --git a/tests/python/contrib/test_hexagon/test_wo_qnn_canonicalization.py b/tests/python/contrib/test_hexagon/test_wo_qnn_canonicalization.py index a58f58486d2c..b32a260a5f02 100644 --- a/tests/python/contrib/test_hexagon/test_wo_qnn_canonicalization.py +++ b/tests/python/contrib/test_hexagon/test_wo_qnn_canonicalization.py @@ -1,3 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ import pytest import numpy as np @@ -28,7 +45,7 @@ def test_no_qnn_pass(): tvm.ir.assert_structural_equal(opt_mod_1, opt_mod_2) -def execute(executor, data_np, weight_np, bias_np = None): +def execute(executor, data_np, weight_np, bias_np=None): executor.set_input("data", data_np) executor.set_input("weight", weight_np) if bias_np is not None: @@ -45,21 +62,25 @@ def test_qnn_conv2d_rq(hexagon_session: Session): weight = relay.var("weight", shape=weight_shape, dtype="float32") op0 = relay.qnn.op.quantize(data, relay.const(0.078), relay.const(0), out_dtype="int8") op1 = relay.qnn.op.quantize(weight, relay.const(0.07), relay.const(0), out_dtype="int8") - op2 = relay.qnn.op.conv2d(op0, - op1, - input_zero_point=relay.const(0), - kernel_zero_point=relay.const(0), - input_scale=relay.const(0.078), - kernel_scale=relay.const(0.07), - padding=[0, 0, 0, 0], - channels=64, - kernel_size=[3, 3]) - op5 = relay.qnn.op.requantize(op2, - input_scale=relay.const(0.05), - input_zero_point=relay.const(0), - output_scale=relay.const(0.21), - output_zero_point=relay.const(61), - out_dtype="int8") + op2 = relay.qnn.op.conv2d( + op0, + op1, + input_zero_point=relay.const(0), + kernel_zero_point=relay.const(0), + input_scale=relay.const(0.078), + kernel_scale=relay.const(0.07), + padding=[0, 0, 0, 0], + channels=64, + kernel_size=[3, 3], + ) + op5 = relay.qnn.op.requantize( + op2, + input_scale=relay.const(0.05), + input_zero_point=relay.const(0), + output_scale=relay.const(0.21), + output_zero_point=relay.const(61), + out_dtype="int8", + ) relay_mod = tvm.IRModule.from_expr(op5) target_hexagon = tvm.target.hexagon("v68") @@ -77,7 +98,7 @@ def test_qnn_conv2d_rq(hexagon_session: Session): tvm.target.Target(target_llvm, host=target_llvm), executor=executor, ) - + data_np = np.random.rand(*data_shape) - 0.5 weight_np = np.random.rand(*weight_shape) - 0.5 @@ -102,21 +123,25 @@ def test_qnn_dense_bias_rq(hexagon_session: Session): op0 = relay.qnn.op.quantize(data, relay.const(0.08), 
relay.const(0), out_dtype="int8") op1 = relay.qnn.op.quantize(weight, relay.const(0.07), relay.const(0), out_dtype="int8") - op2 = relay.qnn.op.dense(op0, - op1, - input_zero_point=relay.const(0), - kernel_zero_point=relay.const(0), - input_scale=relay.const(0.08), - kernel_scale=relay.const(0.07), - units=None) + op2 = relay.qnn.op.dense( + op0, + op1, + input_zero_point=relay.const(0), + kernel_zero_point=relay.const(0), + input_scale=relay.const(0.08), + kernel_scale=relay.const(0.07), + units=None, + ) op3 = relay.qnn.op.quantize(bias, relay.const(0.5), relay.const(0), out_dtype="int32") op4 = relay.nn.bias_add(op2, op3) - op5 = relay.qnn.op.requantize(op4, - input_scale=relay.const(0.05), - input_zero_point=relay.const(0), - output_scale=relay.const(0.212), - output_zero_point=relay.const(10), - out_dtype="int8") + op5 = relay.qnn.op.requantize( + op4, + input_scale=relay.const(0.05), + input_zero_point=relay.const(0), + output_scale=relay.const(0.212), + output_zero_point=relay.const(10), + out_dtype="int8", + ) relay_mod = tvm.IRModule.from_expr(op5) target_hexagon = tvm.target.hexagon("v68") @@ -134,7 +159,7 @@ def test_qnn_dense_bias_rq(hexagon_session: Session): tvm.target.Target(target_llvm, host=target_llvm), executor=executor, ) - + data_np = np.random.rand(*data_shape) - 0.5 weight_np = np.random.rand(*weight_shape) - 0.5 bias_np = np.random.rand(*bias_shape) From 92862ff81283465d5ff8af6475dbba403d1abf49 Mon Sep 17 00:00:00 2001 From: ibsidorenko Date: Tue, 9 Aug 2022 15:28:41 +0300 Subject: [PATCH 18/25] Changed target x86->Hexagon to disable QNN passes. 
--- python/tvm/relay/qnn/strategy/hexagon.py | 24 +++----- python/tvm/topi/hexagon/qnn.py | 78 +++++++++++++++++++----- src/relay/backend/te_compiler_cache.cc | 3 +- src/relay/backend/utils.cc | 4 +- 4 files changed, 75 insertions(+), 34 deletions(-) diff --git a/python/tvm/relay/qnn/strategy/hexagon.py b/python/tvm/relay/qnn/strategy/hexagon.py index a6737f827af7..c7f59cc096fc 100644 --- a/python/tvm/relay/qnn/strategy/hexagon.py +++ b/python/tvm/relay/qnn/strategy/hexagon.py @@ -23,8 +23,7 @@ from ...op.strategy.generic import is_depthwise_conv2d -# TODO: This is POC code. Change it on "hexagon" instead of "cpu" -@qnn_quantize_strategy.register("cpu") +@qnn_quantize_strategy.register("hexagon") def qnn_quantize_strategy_hexagon(attrs, inputs, out_type, target): """qnn.quantize strategy for Hexagon""" strategy = _op.OpStrategy() @@ -36,8 +35,7 @@ def qnn_quantize_strategy_hexagon(attrs, inputs, out_type, target): return strategy -# TODO: This is POC code. Change it on "hexagon" instead of "cpu" -@qnn_dequantize_strategy.register("cpu") +@qnn_dequantize_strategy.register("hexagon") def qnn_dequantize_strategy_hexagon(attrs, inputs, out_type, target): """qnn.dequantize strategy for Hexagon""" strategy = _op.OpStrategy() @@ -49,8 +47,7 @@ def qnn_dequantize_strategy_hexagon(attrs, inputs, out_type, target): return strategy -# TODO: This is POC code. Change it on "hexagon" instead of "cpu" -@qnn_requantize_strategy.register("cpu") +@qnn_requantize_strategy.register("hexagon") def qnn_requantize_strategy_hexagon(attrs, inputs, out_type, target): """qnn.requantize strategy for Hexagon""" strategy = _op.OpStrategy() @@ -62,8 +59,7 @@ def qnn_requantize_strategy_hexagon(attrs, inputs, out_type, target): return strategy -# TODO: This is POC code. 
Change it on "hexagon" instead of "cpu" -@qnn_add_strategy.register("cpu") +@qnn_add_strategy.register("hexagon") def qnn_add_strategy_hexagon(attrs, inputs, out_type, target): """qnn.add strategy for Hexagon""" strategy = _op.OpStrategy() @@ -75,8 +71,7 @@ def qnn_add_strategy_hexagon(attrs, inputs, out_type, target): return strategy -# TODO: This is POC code. Change it on "hexagon" instead of "cpu" -@qnn_concatenate_strategy.register("cpu") +@qnn_concatenate_strategy.register("hexagon") def qnn_concatenate_strategy_hexagon(attrs, inputs, out_type, target): """qnn.concatenate strategy for Hexagon""" strategy = _op.OpStrategy() @@ -88,8 +83,7 @@ def qnn_concatenate_strategy_hexagon(attrs, inputs, out_type, target): return strategy -# TODO: This is POC code. Change it on "hexagon" instead of "cpu" -@qnn_conv2d_strategy.register("cpu") +@qnn_conv2d_strategy.register("hexagon") def qnn_conv2d_strategy_hexagon(attrs, inputs, out_type, target): """qnn.conv2d strategy for Hexagon""" data = inputs[0] @@ -118,8 +112,7 @@ def qnn_conv2d_strategy_hexagon(attrs, inputs, out_type, target): return strategy -# TODO: This is POC code. Change it on "hexagon" instead of "cpu" -@qnn_dense_strategy.register("cpu") +@qnn_dense_strategy.register("hexagon") def qnn_dense_strategy_hexagon(attrs, inputs, out_type, target): """qnn.dense strategy for Hexagon""" strategy = _op.OpStrategy() @@ -131,8 +124,7 @@ def qnn_dense_strategy_hexagon(attrs, inputs, out_type, target): return strategy -# TODO: This is POC code. 
Change it on "hexagon" instead of "cpu" -@qnn_batch_matmul_strategy.register("cpu") +@qnn_batch_matmul_strategy.register("hexagon") def qnn_batch_matmul_strategy_hexagon(attrs, inputs, out_type, target): """qnn.batch_matmul strategy for Hexagon""" strategy = _op.OpStrategy() diff --git a/python/tvm/topi/hexagon/qnn.py b/python/tvm/topi/hexagon/qnn.py index 6dc13823f399..4d997183c584 100644 --- a/python/tvm/topi/hexagon/qnn.py +++ b/python/tvm/topi/hexagon/qnn.py @@ -19,7 +19,6 @@ import tvm from tvm import te, topi -from ..generic.default import default_schedule as _default_schedule from ..utils import get_const_tuple from ..nn.utils import get_pad_tuple from ..nn.pad import pad @@ -43,8 +42,31 @@ def get_qnn_param(param, indices, axis): return param[param_idx] +def default_schedule(outs): + """Simple default schedule for QNN ops. + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of dense in the format + of an array of tensors. + + Returns + ------- + sch: Schedule + The computation schedule for the op. + """ + outs = [outs] if isinstance(outs, tvm.te.tensor.Tensor) else outs + s = tvm.te.create_schedule([x.op for x in outs]) + tvm.te.schedule.AutoInlineInjective(s) + return s + + def qnn_quantize(data, output_scale, output_zero_point, axis, out_dtype): """Compute for qnn.quantize + + Note! This is POC code. There was no goal to implement high performance compute function. + Q_output = clamp((round(input_tensor/output_scale) + output_zero_point), out_dtype::min, out_dtype::max) @@ -78,11 +100,14 @@ def schedule_qnn_quantize(outs): sch: Schedule The computation schedule for the op. """ - return _default_schedule(outs, False) + return default_schedule(outs) def qnn_dequantize(data, input_scale, input_zero_point, axis): """Compute for qnn.dequantize + + Note! This is POC code. There was no goal to implement high performance compute function. 
+ fp_output = input_scale * (Q_input - input_zero_point) """ @@ -110,11 +135,14 @@ def schedule_qnn_dequantize(outs): sch: Schedule The computation schedule for the op. """ - return _default_schedule(outs, False) + return default_schedule(outs) def qnn_requantize(data, input_scale, input_zp, output_scale, output_zp, axis, out_dtype): """Compute for qnn.requantize + + Note! This is POC code. There was no goal to implement high performance compute function. + Q_output = zp_output + round((scale_input)/(scale_output) * (Q_input - zp_input)) TODO: support 'rounding' and 'compute_dtype' arguments. @@ -156,13 +184,16 @@ def schedule_qnn_requantize(outs): sch: Schedule The computation schedule for the op. """ - return _default_schedule(outs, False) + return default_schedule(outs) def qnn_add( lhs, rhs, lhs_scale, lhs_zero_point, rhs_scale, rhs_zero_point, output_scale, output_zero_point ): """Compute for qnn.add + + Note! This is POC code. There was no goal to implement high performance compute function. + Q_output = zp_output + round((lhs_scale)/(scale_output) * (lhs_input - lhs_zp_input)) + round((rhs_scale)/(scale_output) * (rhs_input - rhs_zp_input)) @@ -205,7 +236,7 @@ def schedule_qnn_add(outs): sch: Schedule The computation schedule for the op. """ - return _default_schedule(outs, False) + return default_schedule(outs) def requantize_tensor(tensor, i_scale, i_zp, o_scale, o_zp, out_dtype): @@ -226,6 +257,8 @@ def _compute(*indices): def qnn_concatenate(data, axis, out_dtype): """Compute for qnn.concatenate + Note! This is POC code. There was no goal to implement high performance compute function. + Parameters ---------- data: Array of Tensor @@ -283,7 +316,7 @@ def schedule_qnn_concatenate(outs): sch: Schedule The computation schedule for the op. 
""" - return _default_schedule(outs, False) + return default_schedule(outs) def qnn_conv2d( # Conv2d inputs @@ -309,7 +342,11 @@ def qnn_conv2d( # Conv2d inputs oshape, odtype, ): - """Compute for qnn.conv2d with NCHW layout""" + """Compute for qnn.conv2d with NCHW layout + + Note! This is POC code. There was no goal to implement high performance compute function. + + """ in_channel = data.shape[1] # NCHW layout kernel_height = weight.shape[2] # OIHW layout kernel_width = weight.shape[3] # OIHW layout @@ -392,7 +429,7 @@ def schedule_qnn_conv2d(outs): sch: Schedule The computation schedule for the op. """ - return _default_schedule(outs, False) + return default_schedule(outs) def qnn_depthwise_conv2d( # Conv2d inputs @@ -418,7 +455,11 @@ def qnn_depthwise_conv2d( # Conv2d inputs oshape, odtype, ): - """Compute for qnn.conv2d with NCHW layout""" + """Compute for qnn.conv2d with NCHW layout + + Note! This is POC code. There was no goal to implement high performance compute function. + + """ kernel_height = weight.shape[2] # OIHW layout kernel_width = weight.shape[3] # OIHW layout @@ -496,7 +537,7 @@ def schedule_qnn_depthwise_conv2d(outs): sch: Schedule The computation schedule for the op. """ - return _default_schedule(outs, False) + return default_schedule(outs) def qnn_dense( @@ -517,7 +558,12 @@ def qnn_dense( axis, out_dtype, ): - """Compute for qnn.dense""" + """Compute for qnn.dense + + Note! This is POC code. There was no goal to implement high performance compute function. + + """ + M, K = get_const_tuple(data.shape) N, _ = get_const_tuple(weight.shape) k = te.reduce_axis((0, K), "k") @@ -568,7 +614,7 @@ def schedule_qnn_dense(outs): sch: Schedule The computation schedule for the op. """ - return _default_schedule(outs, False) + return default_schedule(outs) def qnn_batch_matmul( @@ -584,7 +630,11 @@ def qnn_batch_matmul( transpose_b, out_dtype, ): - """Compute for qnn.dense""" + """Compute for qnn.dense + + Note! This is POC code. 
There was no goal to implement high performance compute function. + + """ # Preprocess tensor_a: subtract zp a_sub_zp = te.compute( @@ -612,4 +662,4 @@ def schedule_qnn_batch_matmul(outs): sch: Schedule The computation schedule for the op. """ - return _default_schedule(outs, False) + return default_schedule(outs) diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc index 7e433dfb9ee8..8f7cacb27f8d 100644 --- a/src/relay/backend/te_compiler_cache.cc +++ b/src/relay/backend/te_compiler_cache.cc @@ -303,8 +303,7 @@ class LowerToTECompute : public backend::MemoizedExprTranslatorkind->device_type == kDLCPU) pattern_matcher_.Register(call_node); + if (target_->kind->device_type == kDLHexagon) pattern_matcher_.Register(call_node); Array inputs; int count_tuple = 0; diff --git a/src/relay/backend/utils.cc b/src/relay/backend/utils.cc index 13e526195f92..f47a4eebe33f 100644 --- a/src/relay/backend/utils.cc +++ b/src/relay/backend/utils.cc @@ -231,8 +231,8 @@ Array GetPassPrefix(Target homogeneous_target, bool is_vm) { pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions)); pass_seqs.push_back(transform::ToBasicBlockNormalForm()); // Run all dialect legalization passes. - // Should be changed on kDLHexagon - if ((is_homogeneous && homogeneous_target->kind->device_type != kDLCPU) || !is_homogeneous) + // Skip these passes for Hexagon target. + if ((is_homogeneous && homogeneous_target->kind->device_type != kDLHexagon) || !is_homogeneous) pass_seqs.push_back(relay::qnn::transform::Legalize()); // Legalize pass is restricted to homogeneous execution for now. From a28720542bd60b89059697bab542ac7e38c349ce Mon Sep 17 00:00:00 2001 From: ibsidorenko Date: Wed, 10 Aug 2022 11:08:12 +0300 Subject: [PATCH 19/25] Fixed issue with QDenseAttrs and QConv2dAttrs. 
--- python/tvm/relay/qnn/strategy/generic.py | 13 +++++++---- .../backend/contrib/cmsisnn/relay_to_tir.cc | 4 ++-- src/relay/qnn/utils.cc | 22 +++++++++++++++++++ 3 files changed, 33 insertions(+), 6 deletions(-) diff --git a/python/tvm/relay/qnn/strategy/generic.py b/python/tvm/relay/qnn/strategy/generic.py index 8707d57ffc41..0ae4cbda58bb 100644 --- a/python/tvm/relay/qnn/strategy/generic.py +++ b/python/tvm/relay/qnn/strategy/generic.py @@ -16,9 +16,14 @@ # under the License. """Definition of generic operator strategy.""" +from tvm import _ffi from tvm.target import override_native_generic_func +GET_RQ_OUT_DTYPE = _ffi.get_global_func("relay.attrs.get_rq_out_dtype") +GET_RQ_AXIS = _ffi.get_global_func("relay.attrs.get_rq_axis") + + def wrap_topi_schedule(topi_schedule): """Wrap TOPI schedule which doesn't use attrs""" @@ -64,12 +69,12 @@ def wrap_topi_qnn_conv2d(topi_compute): """Wrap TOPI compute which use conv2d attrs and output data type""" def wrapper(attrs, inputs, out_type): + out_dtype = GET_RQ_OUT_DTYPE(attrs) + axis = GET_RQ_AXIS(attrs) oshape = out_type.shape - out_dtype = attrs.rq_out_dtype strides = attrs.strides padding = attrs.padding dilation = attrs.dilation - axis = attrs.axis if len([*inputs]) == 11: args = [*inputs, axis, strides, padding, dilation, oshape, out_dtype] elif len([*inputs]) == 10: @@ -122,8 +127,8 @@ def wrap_topi_qnn_dense(topi_compute): """Wrap TOPI compute which use qnn.dense attrs""" def wrapper(attrs, inputs, _out_type): - out_dtype = attrs.rq_out_dtype - axis = attrs.axis + out_dtype = GET_RQ_OUT_DTYPE(attrs) + axis = GET_RQ_AXIS(attrs) if len([*inputs]) == 11: args = [*inputs, axis, out_dtype] elif len([*inputs]) == 10: diff --git a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc index da51e6b762dd..d0605cdef5ec 100644 --- a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc +++ b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc @@ -179,7 +179,7 @@ class 
RelayToTIRVisitor : public MixedModeMutator { // https://github.com/ARM-software/CMSIS_5/blob/def6f800f95661eb3451d317f7d0dde504f6020d/CMSIS/NN/Source/ConvolutionFunctions/arm_convolve_wrapper_s8.c#L50 // prepare cmsis_nn_conv_params - const Conv2DAttrs* conv2d_attrs = conv2d_call->attrs.as(); + const qnn::QConv2DAttrs* conv2d_attrs = conv2d_call->attrs.as(); int32_t input_offset = -GetScalarFromConstant(conv2d_call->args[2]); int32_t output_offset = GetScalarFromConstant(requantize_call->args[4]); int32_t stride_w = qnn::get_const_int(conv2d_attrs->strides[1]); @@ -328,7 +328,7 @@ class RelayToTIRVisitor : public MixedModeMutator { // https://github.com/ARM-software/CMSIS_5/blob/def6f800f95661eb3451d317f7d0dde504f6020d/CMSIS/NN/Source/ConvolutionFunctions/arm_convolve_wrapper_s8.c#L50 // prepare cmsis_nn_fc_params - const DenseAttrs* dense_attrs = fc_call->attrs.as(); + const qnn::QDenseAttrs* dense_attrs = fc_call->attrs.as(); int32_t input_offset = -GetScalarFromConstant(fc_call->args[2]); int32_t filter_offset = -GetScalarFromConstant(fc_call->args[3]); int32_t output_offset = GetScalarFromConstant(requantize_call->args[4]); diff --git a/src/relay/qnn/utils.cc b/src/relay/qnn/utils.cc index ed7a415cf6af..eedf68724535 100644 --- a/src/relay/qnn/utils.cc +++ b/src/relay/qnn/utils.cc @@ -213,6 +213,28 @@ std::string SelectRequntizeParameter(const std::string& arg_value, const std::st } } +TVM_REGISTER_GLOBAL("relay.attrs.get_rq_out_dtype").set_body_typed([](const Attrs& attrs) { + if (attrs->IsInstance()) { + return attrs.as()->rq_out_dtype; + } else if (attrs->IsInstance()) { + return attrs.as()->rq_out_dtype; + } else { + LOG(FATAL) << "Unhandled attribute: " << attrs; + } + return DataType(); +}); + +TVM_REGISTER_GLOBAL("relay.attrs.get_rq_axis").set_body_typed([](const Attrs& attrs) { + if (attrs->IsInstance()) { + return attrs.as()->axis; + } else if (attrs->IsInstance()) { + return attrs.as()->axis; + } else { + LOG(FATAL) << "Unhandled attribute: " << 
attrs; + } + return -1; +}); + } // namespace qnn } // namespace relay } // namespace tvm From 72af59c95ca29adcdce36796375f7e0894572435 Mon Sep 17 00:00:00 2001 From: ibsidorenko Date: Wed, 10 Aug 2022 14:20:35 +0300 Subject: [PATCH 20/25] Fixed build for Cortex-M. --- src/relay/backend/contrib/cmsisnn/convolutions.cc | 2 +- src/relay/backend/contrib/cmsisnn/convolutions.h | 2 +- src/relay/backend/contrib/cmsisnn/generate_constants.cc | 5 +++-- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/relay/backend/contrib/cmsisnn/convolutions.cc b/src/relay/backend/contrib/cmsisnn/convolutions.cc index ebac83b81250..67b24da8393a 100644 --- a/src/relay/backend/contrib/cmsisnn/convolutions.cc +++ b/src/relay/backend/contrib/cmsisnn/convolutions.cc @@ -29,7 +29,7 @@ namespace relay { namespace contrib { namespace cmsisnn { -bool IsCMSISNNDepthwise(const Conv2DAttrs* conv2d_attrs, const Array& input_shape, +bool IsCMSISNNDepthwise(const qnn::QConv2DAttrs* conv2d_attrs, const Array& input_shape, const Array& kernel_shape) { std::string kernel_layout = conv2d_attrs->kernel_layout.c_str(); int kernel_pos_o = kernel_layout.find("O"); diff --git a/src/relay/backend/contrib/cmsisnn/convolutions.h b/src/relay/backend/contrib/cmsisnn/convolutions.h index e635702bf353..7369b7492c1a 100644 --- a/src/relay/backend/contrib/cmsisnn/convolutions.h +++ b/src/relay/backend/contrib/cmsisnn/convolutions.h @@ -49,7 +49,7 @@ namespace cmsisnn { * attributes */ -bool IsCMSISNNDepthwise(const Conv2DAttrs* conv2d_attrs, const Array& input_shape, +bool IsCMSISNNDepthwise(const qnn::QConv2DAttrs* conv2d_attrs, const Array& input_shape, const Array& kernel_shape); } // namespace cmsisnn diff --git a/src/relay/backend/contrib/cmsisnn/generate_constants.cc b/src/relay/backend/contrib/cmsisnn/generate_constants.cc index e08b61c457f9..c71019ff3ccf 100644 --- a/src/relay/backend/contrib/cmsisnn/generate_constants.cc +++ b/src/relay/backend/contrib/cmsisnn/generate_constants.cc @@ -51,7 
+51,8 @@ class GenerateConstantsMutator : public MixedModeMutator { private: /*! * \brief Converts Kernel layout from HWIO to OHWI to align to CMSIS-NN requirements */ - Expr ConvertKernelLayout(Expr kernel_expr, const Conv2DAttrs* conv2d_attrs, Attrs* new_attrs) { + Expr ConvertKernelLayout(Expr kernel_expr, const qnn::QConv2DAttrs* conv2d_attrs, + Attrs* new_attrs) { auto attrs = make_object(); attrs->strides = std::move(conv2d_attrs->strides); attrs->padding = std::move(conv2d_attrs->padding); @@ -94,7 +95,7 @@ class GenerateConstantsMutator : public MixedModeMutator { conv2d_call = requantize_input; } - auto* conv2d_attrs = conv2d_call->attrs.as(); + auto* conv2d_attrs = conv2d_call->attrs.as(); tvm::Attrs new_conv2d_attrs = conv2d_call->attrs; Expr conv2d_kernel = conv2d_call->args[1]; From 9fe740119a26523d30499b84ea13d4029a1feb8a Mon Sep 17 00:00:00 2001 From: ibsidorenko Date: Thu, 11 Aug 2022 14:58:00 +0300 Subject: [PATCH 21/25] Removed QDenseAttrs and QConv2dAttrs --- include/tvm/relay/qnn/attrs.h | 99 ------------------- python/tvm/relay/backend/te_compiler.py | 33 ++++--- python/tvm/relay/op/op_attrs.py | 10 -- python/tvm/relay/qnn/strategy/generic.py | 21 +--- python/tvm/topi/hexagon/__init__.py | 1 - python/tvm/topi/hexagon/qnn.py | 33 +++++-- .../backend/contrib/cmsisnn/convolutions.cc | 2 +- .../backend/contrib/cmsisnn/convolutions.h | 2 +- .../contrib/cmsisnn/generate_constants.cc | 5 +- .../backend/contrib/cmsisnn/relay_to_tir.cc | 4 +- src/relay/backend/te_compiler_cache.cc | 21 +--- src/relay/backend/utils.cc | 3 +- src/relay/qnn/op/convolution.cc | 58 ++++------- src/relay/qnn/op/dense.cc | 21 ++-- src/relay/qnn/utils.cc | 22 ----- 15 files changed, 82 insertions(+), 253 deletions(-) diff --git a/include/tvm/relay/qnn/attrs.h b/include/tvm/relay/qnn/attrs.h index 994e0dde6c08..64b2dc20981d 100644 --- a/include/tvm/relay/qnn/attrs.h +++ b/include/tvm/relay/qnn/attrs.h @@ -25,7 +25,6 @@ #define TVM_RELAY_QNN_ATTRS_H_ #include -#include #include 
@@ -126,104 +125,6 @@ struct BroadcastAttrs : public tvm::AttrsNode { } }; -/*! \brief Attributes used in QNN convolution operator */ -struct QConv2DAttrs : public tvm::AttrsNode { - Array strides; - Array padding; - Array dilation; - int groups; - IndexExpr channels; - Array kernel_size; - tvm::String data_layout; - tvm::String kernel_layout; - tvm::String out_layout; - tvm::String auto_scheduler_rewritten_layout; // The layout after auto-scheduler's layout rewrite - Array meta_schedule_original_shape; // The original shape of the weights - DataType out_dtype; - - // Optional extra attributes for Hexagon target. Describes requantization parameters. - // Note, It is not set up explicitly through qnn._make.conv2d. - int axis; - DataType rq_out_dtype; - - TVM_DECLARE_ATTRS(QConv2DAttrs, "relay.attrs.QConv2DAttrs") { - TVM_ATTR_FIELD(strides) - .set_default(Array({1, 1})) - .describe("Specifies the strides of the convolution."); - TVM_ATTR_FIELD(padding) - .set_default(Array({0, 0})) - .describe( - "If padding is non-zero, then the input is implicitly zero-padded" - "Padding support both symmetric and asymmetric as" - "one int : same padding used on all sides" - "two int : bottom, right will use same padding as top, left" - "four int : padding width in the order of (top, left, bottom, right)"); - TVM_ATTR_FIELD(dilation) - .set_default(Array({1, 1})) - .describe("Specifies the dilation rate to use for dilated convolution."); - TVM_ATTR_FIELD(groups).set_default(1).describe( - "Controls the connections between inputs and outputs." - "At groups=1, all inputs are convolved to all outputs." - "At groups=2, the operation becomes equivalent to having two convolution" - "layers side by side, each seeing half the input channels, and producing" - "half the output channels, and both subsequently concatenated."); - TVM_ATTR_FIELD(channels) - .describe( - "The number of output channels in the convolution." 
- " If it is not set, inferred by shape of the weight.") - .set_default(NullValue()); - TVM_ATTR_FIELD(kernel_size) - .describe("Specifies the dimensions of the convolution window.") - .set_default(NullValue>()); - TVM_ATTR_FIELD(data_layout) - .set_default("NCHW") - .describe( - "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." - "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" - "dimensions respectively. Convolution is applied on the 'H' and" - "'W' dimensions."); - TVM_ATTR_FIELD(kernel_layout) - .set_default("OIHW") - .describe( - "Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc." - "'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width" - "dimensions respectively."); - TVM_ATTR_FIELD(out_layout) - .set_default("") - .describe( - "Dimension ordering of output. Can be 'NCHW', 'NHWC', etc." - "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" - "dimensions respectively. Default to be same as input layout."); - - // use 0 bits to indicate none. - TVM_ATTR_FIELD(out_dtype) - .set_default(NullValue()) - .describe("Output data type, set to explicit type under mixed precision setting"); - } -}; - -/*! \brief Attributes for QNN dense operator */ -struct QDenseAttrs : public tvm::AttrsNode { - IndexExpr units; - tvm::String auto_scheduler_rewritten_layout; // The layout after auto-scheduler's layout rewrite - Array meta_schedule_original_shape; // The original shape of the weights - DataType out_dtype; - - // Optional extra attributes for Hexagon target. Describes requantization parameters. - // Note, It is not set up explicitly through qnn._make.dense. - int axis; - DataType rq_out_dtype; - - TVM_DECLARE_ATTRS(QDenseAttrs, "relay.attrs.QDenseAttrs") { - TVM_ATTR_FIELD(units).describe("Number of hidden units of the dense transformation."); - - // use 0 bits to indicate none. 
- TVM_ATTR_FIELD(out_dtype) - .set_default(NullValue()) - .describe("Output data type, set to explicit type under mixed precision setting"); - } -}; - } // namespace qnn } // namespace relay } // namespace tvm diff --git a/python/tvm/relay/backend/te_compiler.py b/python/tvm/relay/backend/te_compiler.py index a2fbf555e12b..173f31ef08f9 100644 --- a/python/tvm/relay/backend/te_compiler.py +++ b/python/tvm/relay/backend/te_compiler.py @@ -281,25 +281,28 @@ def get_shape(shape): @tvm._ffi.register_func("relay.backend.lower_call") -def lower_call(call, inputs, target): +def lower_call(call, inputs, target, otype=None): """Lower the call expression to op implementation and tensor outputs.""" assert isinstance(call.op, tvm.ir.Op) op = call.op - # Prepare the call_node->checked_type(). For the call node inputs, we ensure that - # the shape is Int32. Following code ensures the same for the output as well. - # TODO(@icemelon9): Support recursive tuple - ret_type = call.checked_type - if isinstance(ret_type, _ty.TensorType): - ret_type = _ty.TensorType(get_shape(ret_type.shape), ret_type.dtype) - elif isinstance(ret_type, _ty.TupleType): - new_fields = [] - for field in ret_type.fields: - if isinstance(field, _ty.TensorType): - new_fields.append(_ty.TensorType(get_shape(field.shape), field.dtype)) - else: - new_fields.append(field) - ret_type = _ty.TupleType(new_fields) + if otype is not None: + ret_type = otype + else: + # Prepare the call_node->checked_type(). For the call node inputs, we ensure that + # the shape is Int32. Following code ensures the same for the output as well. 
+ # TODO(@icemelon9): Support recursive tuple + ret_type = call.checked_type + if isinstance(ret_type, _ty.TensorType): + ret_type = _ty.TensorType(get_shape(ret_type.shape), ret_type.dtype) + elif isinstance(ret_type, _ty.TupleType): + new_fields = [] + for field in ret_type.fields: + if isinstance(field, _ty.TensorType): + new_fields.append(_ty.TensorType(get_shape(field.shape), field.dtype)) + else: + new_fields.append(field) + ret_type = _ty.TupleType(new_fields) is_dyn = _ty.is_dynamic(call.checked_type) for arg in call.args: diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py index d6d9ec3d2365..b76097722c07 100644 --- a/python/tvm/relay/op/op_attrs.py +++ b/python/tvm/relay/op/op_attrs.py @@ -529,16 +529,6 @@ class RequantizeAttrs(Attrs): """Attributes used in requantize operators""" -@tvm._ffi.register_object("relay.attrs.QConv2DAttrs") -class QConv2DAttrs(Attrs): - """Attributes used in QNN conv2d operators""" - - -@tvm._ffi.register_object("relay.attrs.QDenseAttrs") -class QDenseAttrs(Attrs): - """Attributes used in QNN dense operators""" - - @tvm._ffi.register_object("relay.attrs.ScatterAttrs") class ScatterAttrs(Attrs): """Attributes used in scatter operators""" diff --git a/python/tvm/relay/qnn/strategy/generic.py b/python/tvm/relay/qnn/strategy/generic.py index 0ae4cbda58bb..57a364f7e057 100644 --- a/python/tvm/relay/qnn/strategy/generic.py +++ b/python/tvm/relay/qnn/strategy/generic.py @@ -16,14 +16,9 @@ # under the License. 
"""Definition of generic operator strategy.""" -from tvm import _ffi from tvm.target import override_native_generic_func -GET_RQ_OUT_DTYPE = _ffi.get_global_func("relay.attrs.get_rq_out_dtype") -GET_RQ_AXIS = _ffi.get_global_func("relay.attrs.get_rq_axis") - - def wrap_topi_schedule(topi_schedule): """Wrap TOPI schedule which doesn't use attrs""" @@ -69,14 +64,13 @@ def wrap_topi_qnn_conv2d(topi_compute): """Wrap TOPI compute which use conv2d attrs and output data type""" def wrapper(attrs, inputs, out_type): - out_dtype = GET_RQ_OUT_DTYPE(attrs) - axis = GET_RQ_AXIS(attrs) + out_dtype = out_type.dtype oshape = out_type.shape strides = attrs.strides padding = attrs.padding dilation = attrs.dilation if len([*inputs]) == 11: - args = [*inputs, axis, strides, padding, dilation, oshape, out_dtype] + args = [*inputs, strides, padding, dilation, oshape, out_dtype] elif len([*inputs]) == 10: args = [ # QNN Conv2d params: inputs[0], @@ -92,7 +86,6 @@ def wrapper(attrs, inputs, out_type): inputs[7], inputs[8], inputs[9], - axis, # Conv2d attrs: strides, padding, @@ -111,7 +104,6 @@ def wrapper(attrs, inputs, out_type): None, None, None, - axis, strides, padding, dilation, @@ -126,11 +118,10 @@ def wrapper(attrs, inputs, out_type): def wrap_topi_qnn_dense(topi_compute): """Wrap TOPI compute which use qnn.dense attrs""" - def wrapper(attrs, inputs, _out_type): - out_dtype = GET_RQ_OUT_DTYPE(attrs) - axis = GET_RQ_AXIS(attrs) + def wrapper(_attrs, inputs, out_type): + out_dtype = out_type.dtype if len([*inputs]) == 11: - args = [*inputs, axis, out_dtype] + args = [*inputs, out_dtype] elif len([*inputs]) == 10: args = [ # QNN Dense params: inputs[0], @@ -146,7 +137,6 @@ def wrapper(attrs, inputs, _out_type): inputs[7], inputs[8], inputs[9], - axis, out_dtype, ] else: @@ -160,7 +150,6 @@ def wrapper(attrs, inputs, _out_type): None, None, None, - axis, out_dtype, ] return [topi_compute(*args)] diff --git a/python/tvm/topi/hexagon/__init__.py b/python/tvm/topi/hexagon/__init__.py 
index 7172ddfd7af7..b94526e5b919 100644 --- a/python/tvm/topi/hexagon/__init__.py +++ b/python/tvm/topi/hexagon/__init__.py @@ -25,7 +25,6 @@ from .injective import * from .pad import * from .pooling import * -from .qnn import * from .reduce import * from .resize2d import * from .tensor_intrin import * diff --git a/python/tvm/topi/hexagon/qnn.py b/python/tvm/topi/hexagon/qnn.py index 4d997183c584..6e1abcdc5668 100644 --- a/python/tvm/topi/hexagon/qnn.py +++ b/python/tvm/topi/hexagon/qnn.py @@ -151,12 +151,8 @@ def qnn_requantize(data, input_scale, input_zp, output_scale, output_zp, axis, o def _compute(*indices): value = data(*indices) - # Account scalar and 1D quantization parameters: - iscale_idx = tvm.tir.indexmod(indices[axis], topi.shape(input_scale)[0]) - iscale = input_scale if len(input_scale.shape) == 0 else input_scale[iscale_idx] - - oscale_idx = tvm.tir.indexmod(indices[axis], topi.shape(output_scale)[0]) - oscale = output_scale if len(output_scale.shape) == 0 else output_scale[oscale_idx] + iscale = get_qnn_param(input_scale, indices, axis) + oscale = get_qnn_param(output_scale, indices, axis) sub = te.subtract(value, input_zp) mul = te.div(iscale, oscale) @@ -334,7 +330,6 @@ def qnn_conv2d( # Conv2d inputs rq_input_zero_point, rq_output_scale, rq_output_zero_point, - axis, # Conv2d attributes: strides, padding, @@ -402,6 +397,13 @@ def qnn_conv2d( # Conv2d inputs # Requantize output of convolution # Q_output = zp_output + round((scale_input)/(scale_output) * (Q_input - zp_input)) if rq_input_scale is not None and rq_output_scale is not None: + # Now supported only scalar and 1D quantization parameters + assert len(rq_input_scale.shape) == 0 or len(rq_input_scale.shape) == 1 + assert len(rq_output_scale.shape) == 0 or len(rq_output_scale.shape) == 1 + axis = -1 + if len(rq_input_scale.shape) == 1 or len(rq_output_scale.shape) == 1: + axis = 1 # Axis param should correspond to 'C' dimension. 
+ return qnn_requantize( out, rq_input_scale, @@ -447,7 +449,6 @@ def qnn_depthwise_conv2d( # Conv2d inputs rq_input_zero_point, rq_output_scale, rq_output_zero_point, - axis, # Conv2d attributes: strides, padding, @@ -510,6 +511,13 @@ def qnn_depthwise_conv2d( # Conv2d inputs # Requantize output of convolution # Q_output = zp_output + round((scale_input)/(scale_output) * (Q_input - zp_input)) if rq_input_scale is not None and rq_output_scale is not None: + # Now supported only scalar and 1D quantization parameters + assert len(rq_input_scale.shape) == 0 or len(rq_input_scale.shape) == 1 + assert len(rq_output_scale.shape) == 0 or len(rq_output_scale.shape) == 1 + axis = -1 + if len(rq_input_scale.shape) == 1 or len(rq_output_scale.shape) == 1: + axis = 1 # Axis param should correspond to 'C' dimension. + return qnn_requantize( out, rq_input_scale, @@ -555,7 +563,6 @@ def qnn_dense( rq_input_zero_point, rq_output_scale, rq_output_zero_point, - axis, out_dtype, ): """Compute for qnn.dense @@ -563,7 +570,6 @@ def qnn_dense( Note! This is POC code. There was no goal to implement high performance compute function. """ - M, K = get_const_tuple(data.shape) N, _ = get_const_tuple(weight.shape) k = te.reduce_axis((0, K), "k") @@ -587,6 +593,13 @@ def qnn_dense( # Requantize output of dense # Q_output = zp_output + round((scale_input)/(scale_output) * (Q_input - zp_input)) if rq_input_scale is not None and rq_output_scale is not None: + # Now supported only scalar and 1D quantization parameters + assert len(rq_input_scale.shape) == 0 or len(rq_input_scale.shape) == 1 + assert len(rq_output_scale.shape) == 0 or len(rq_output_scale.shape) == 1 + axis = -1 + if len(rq_input_scale.shape) == 1 or len(rq_output_scale.shape) == 1: + axis = 1 # Axis param should correspond to 'N' dimension. 
+ return qnn_requantize( out, rq_input_scale, diff --git a/src/relay/backend/contrib/cmsisnn/convolutions.cc b/src/relay/backend/contrib/cmsisnn/convolutions.cc index 67b24da8393a..ebac83b81250 100644 --- a/src/relay/backend/contrib/cmsisnn/convolutions.cc +++ b/src/relay/backend/contrib/cmsisnn/convolutions.cc @@ -29,7 +29,7 @@ namespace relay { namespace contrib { namespace cmsisnn { -bool IsCMSISNNDepthwise(const qnn::QConv2DAttrs* conv2d_attrs, const Array& input_shape, +bool IsCMSISNNDepthwise(const Conv2DAttrs* conv2d_attrs, const Array& input_shape, const Array& kernel_shape) { std::string kernel_layout = conv2d_attrs->kernel_layout.c_str(); int kernel_pos_o = kernel_layout.find("O"); diff --git a/src/relay/backend/contrib/cmsisnn/convolutions.h b/src/relay/backend/contrib/cmsisnn/convolutions.h index 7369b7492c1a..e635702bf353 100644 --- a/src/relay/backend/contrib/cmsisnn/convolutions.h +++ b/src/relay/backend/contrib/cmsisnn/convolutions.h @@ -49,7 +49,7 @@ namespace cmsisnn { * attributes */ -bool IsCMSISNNDepthwise(const qnn::QConv2DAttrs* conv2d_attrs, const Array& input_shape, +bool IsCMSISNNDepthwise(const Conv2DAttrs* conv2d_attrs, const Array& input_shape, const Array& kernel_shape); } // namespace cmsisnn diff --git a/src/relay/backend/contrib/cmsisnn/generate_constants.cc b/src/relay/backend/contrib/cmsisnn/generate_constants.cc index c71019ff3ccf..e08b61c457f9 100644 --- a/src/relay/backend/contrib/cmsisnn/generate_constants.cc +++ b/src/relay/backend/contrib/cmsisnn/generate_constants.cc @@ -51,8 +51,7 @@ class GenerateConstantsMutator : public MixedModeMutator { private: /*! 
* \brief Converts Kernel layout from HWIO to OHWI to align to CMSIS-NN requirements */ - Expr ConvertKernelLayout(Expr kernel_expr, const qnn::QConv2DAttrs* conv2d_attrs, - Attrs* new_attrs) { + Expr ConvertKernelLayout(Expr kernel_expr, const Conv2DAttrs* conv2d_attrs, Attrs* new_attrs) { auto attrs = make_object(); attrs->strides = std::move(conv2d_attrs->strides); attrs->padding = std::move(conv2d_attrs->padding); @@ -95,7 +94,7 @@ class GenerateConstantsMutator : public MixedModeMutator { conv2d_call = requantize_input; } - auto* conv2d_attrs = conv2d_call->attrs.as(); + auto* conv2d_attrs = conv2d_call->attrs.as(); tvm::Attrs new_conv2d_attrs = conv2d_call->attrs; Expr conv2d_kernel = conv2d_call->args[1]; diff --git a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc index d0605cdef5ec..da51e6b762dd 100644 --- a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc +++ b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc @@ -179,7 +179,7 @@ class RelayToTIRVisitor : public MixedModeMutator { // https://github.com/ARM-software/CMSIS_5/blob/def6f800f95661eb3451d317f7d0dde504f6020d/CMSIS/NN/Source/ConvolutionFunctions/arm_convolve_wrapper_s8.c#L50 // prepare cmsis_nn_conv_params - const qnn::QConv2DAttrs* conv2d_attrs = conv2d_call->attrs.as(); + const Conv2DAttrs* conv2d_attrs = conv2d_call->attrs.as(); int32_t input_offset = -GetScalarFromConstant(conv2d_call->args[2]); int32_t output_offset = GetScalarFromConstant(requantize_call->args[4]); int32_t stride_w = qnn::get_const_int(conv2d_attrs->strides[1]); @@ -328,7 +328,7 @@ class RelayToTIRVisitor : public MixedModeMutator { // https://github.com/ARM-software/CMSIS_5/blob/def6f800f95661eb3451d317f7d0dde504f6020d/CMSIS/NN/Source/ConvolutionFunctions/arm_convolve_wrapper_s8.c#L50 // prepare cmsis_nn_fc_params - const qnn::QDenseAttrs* dense_attrs = fc_call->attrs.as(); + const DenseAttrs* dense_attrs = fc_call->attrs.as(); int32_t input_offset = 
-GetScalarFromConstant(fc_call->args[2]); int32_t filter_offset = -GetScalarFromConstant(fc_call->args[3]); int32_t output_offset = GetScalarFromConstant(requantize_call->args[4]); diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc index 8f7cacb27f8d..633952ed2e53 100644 --- a/src/relay/backend/te_compiler_cache.cc +++ b/src/relay/backend/te_compiler_cache.cc @@ -29,7 +29,6 @@ #include #include #include -#include #include #include #include @@ -174,20 +173,6 @@ class PatternMatcher { // returns whether given Op is last in the pattern sequence. bool IsLeafOp(const Op& op) { return op == qnn_requantize_op_; } - // Copy requantization attributes from one node to another. - void CopyAttrs(const CallNode* from, const CallNode* to) { - const auto* requantize_attrs = from->attrs.as(); - if (auto* pattr = const_cast(to->attrs.as())) { - pattr->axis = requantize_attrs->axis; - pattr->rq_out_dtype = requantize_attrs->out_dtype; - } else if (auto* pattr = const_cast(to->attrs.as())) { - pattr->axis = requantize_attrs->axis; - pattr->rq_out_dtype = requantize_attrs->out_dtype; - } else { - LOG(FATAL) << "Unsupported op: " << PrettyPrint(to->op); - } - } - const CallNode* GetAnchorOp() { return anchor_op_; } void Clear() { registered_ops_.clear(); } @@ -303,7 +288,7 @@ class LowerToTECompute : public backend::MemoizedExprTranslatorkind->device_type == kDLHexagon) pattern_matcher_.Register(call_node); + if (target_->GetTargetDeviceType() == kDLHexagon) pattern_matcher_.Register(call_node); Array inputs; int count_tuple = 0; @@ -328,8 +313,8 @@ class LowerToTECompute : public backend::MemoizedExprTranslator(anchor_op), inputs, target_); + LoweredOutput lowered_out = + (*flower_call)(GetRef(anchor_op), inputs, target_, call_node->checked_type()); outputs = lowered_out->outputs; Op a_op = Downcast(anchor_op->op); op_implementations_[a_op.operator->()] = lowered_out->implementation; diff --git a/src/relay/backend/utils.cc 
b/src/relay/backend/utils.cc index f47a4eebe33f..6d486e3db475 100644 --- a/src/relay/backend/utils.cc +++ b/src/relay/backend/utils.cc @@ -232,7 +232,8 @@ Array GetPassPrefix(Target homogeneous_target, bool is_vm) { pass_seqs.push_back(transform::ToBasicBlockNormalForm()); // Run all dialect legalization passes. // Skip these passes for Hexagon target. - if ((is_homogeneous && homogeneous_target->kind->device_type != kDLHexagon) || !is_homogeneous) + if ((is_homogeneous && homogeneous_target->GetTargetDeviceType() != kDLHexagon) || + !is_homogeneous) pass_seqs.push_back(relay::qnn::transform::Legalize()); // Legalize pass is restricted to homogeneous execution for now. diff --git a/src/relay/qnn/op/convolution.cc b/src/relay/qnn/op/convolution.cc index 6ece9220401a..64a5a02e6e25 100644 --- a/src/relay/qnn/op/convolution.cc +++ b/src/relay/qnn/op/convolution.cc @@ -38,8 +38,6 @@ namespace tvm { namespace relay { namespace qnn { -TVM_REGISTER_NODE_TYPE(QConv2DAttrs); - // relay.op.qnn.conv2d bool QnnConv2DRel(const Array& types, int num_inputs, const Attrs& attrs, @@ -50,8 +48,8 @@ bool QnnConv2DRel(const Array& types, int num_inputs, const Attrs& attrs, const auto* data = types[0].as(); const auto* weight = types[1].as(); if (data == nullptr || weight == nullptr) return false; - const auto* param = attrs.as(); - ICHECK(param != nullptr) << "QConv2DAttrs cannot be nullptr."; + const auto* param = attrs.as(); + ICHECK(param != nullptr) << "Conv2DAttrs cannot be nullptr."; ICHECK(data->dtype == DataType::Int(8) || data->dtype == DataType::UInt(8) || data->dtype == DataType::Int(16)) << "Expected qnn conv2d type(int8, uint8, int16) for input but was " << data->dtype; @@ -85,25 +83,10 @@ bool QnnConv2DRel(const Array& types, int num_inputs, const Attrs& attrs, reporter); // weight_scale } - // Create Conv2DAttrs from QConv2DAttrs - auto conv2d_attrs = make_object(); - conv2d_attrs->strides = param->strides; - conv2d_attrs->padding = param->padding; - 
conv2d_attrs->dilation = param->dilation; - conv2d_attrs->groups = param->groups; - conv2d_attrs->channels = param->channels; - conv2d_attrs->kernel_size = param->kernel_size; - conv2d_attrs->data_layout = param->data_layout; - conv2d_attrs->kernel_layout = param->kernel_layout; - conv2d_attrs->out_layout = param->out_layout; - conv2d_attrs->out_dtype = param->out_dtype; - conv2d_attrs->auto_scheduler_rewritten_layout = param->auto_scheduler_rewritten_layout; - conv2d_attrs->meta_schedule_original_shape = param->meta_schedule_original_shape; - // Collect the input tensor and output tensor devoid of scale and zero points to reuse Relay // Conv2D infer type function. Array tensor_types = {types[0], types[1], types[6]}; - return Conv2DRel(tensor_types, 3, Attrs(conv2d_attrs), reporter); + return Conv2DRel(tensor_types, 3, attrs, reporter); } InferCorrectLayoutOutput QnnConvInferCorrectLayout(const Attrs& attrs, @@ -112,7 +95,7 @@ InferCorrectLayoutOutput QnnConvInferCorrectLayout(const Attrs& attrs, const Array& old_in_types) { // Use Relay Conv2D Infer correct layout. auto conv_new_layouts = - ConvInferCorrectLayout(attrs, new_in_layouts, old_in_layouts, old_in_types); + ConvInferCorrectLayout(attrs, new_in_layouts, old_in_layouts, old_in_types); // Fill the layouts of remaining input tensors - scales and zero points. The layouts of these // tensors can be treated as channel layout. @@ -127,7 +110,7 @@ InferCorrectLayoutOutput QnnConvInferCorrectLayout(const Attrs& attrs, return InferCorrectLayoutOutput(input_layouts, output_layouts, attrs); } -bool is_depthwise(const QConv2DAttrs* param) { +bool is_depthwise(const Conv2DAttrs* param) { return param->channels.defined() && tvm::tir::ExprDeepEqual()(param->channels, param->groups) && param->groups != 1; } @@ -141,7 +124,7 @@ using WorkloadType = std::tuple; * \param param The qnn conv2d attributes. * \return A tuple of workload. 
*/ -WorkloadType GetWorkload(const Array& arg_types, const QConv2DAttrs* param) { +WorkloadType GetWorkload(const Array& arg_types, const Conv2DAttrs* param) { // Get conv parameters. const auto in_shape = get_shape(arg_types[0]); int batch_size, in_channels; @@ -208,7 +191,7 @@ WorkloadType GetWorkload(const Array& arg_types, const QConv2D * int32 tensors instead of int8 tensors. */ Expr Conv2DFallBack(const Expr& data, const Expr& weight, const Expr& input_zero_point, - const Expr& kernel_zero_point, const QConv2DAttrs* param) { + const Expr& kernel_zero_point, const Conv2DAttrs* param) { // Upcast the parameters to be at least int32 to avoid overflow auto upcast_bits = param->out_dtype.bits() < 32 ? 32 : param->out_dtype.bits(); @@ -241,7 +224,7 @@ Expr Conv2DFallBack(const Expr& data, const Expr& weight, const Expr& input_zero * cannot be fused with conv in Relay. In case we see performance * degradation, we can change the conv2D API to accept a pad_const value. */ -Expr Conv2DPadInput(const Expr& data, const Expr& input_zero_point, const QConv2DAttrs* param) { +Expr Conv2DPadInput(const Expr& data, const Expr& input_zero_point, const Conv2DAttrs* param) { // 1) Pad the input data auto padded_data = data; auto pad_top_value = get_const_int(param->padding[0]); @@ -287,7 +270,7 @@ Expr Conv2DPadInput(const Expr& data, const Expr& input_zero_point, const QConv2 * followed by repeat on the C axis by cm times. */ Expr DepthwiseConv2DSecondTerm(const Expr& padded_data, const Expr& kernel_zero_point, - const QConv2DAttrs* param, int kernel_h, int kernel_w, + const Conv2DAttrs* param, int kernel_h, int kernel_w, int channel_multiplier) { auto casted_t2 = Cast(padded_data, DataType::Int(32)); @@ -360,7 +343,7 @@ Expr DepthwiseConv2DSecondTerm(const Expr& padded_data, const Expr& kernel_zero_ * (1, oc, 1, 1) as (oc/m, oc%m) are just contiguous memory locations. 
*/ Expr DepthwiseConv2DThirdTerm(const Expr& weight, const Expr& input_zero_point, - const QConv2DAttrs* param, int out_channels, int channel_multiplier) { + const Conv2DAttrs* param, int out_channels, int channel_multiplier) { // Find which dimensions are R, S. Array axes_t3; if (param->kernel_layout == "OIHW") { @@ -439,7 +422,7 @@ Expr DepthwiseConv2DFourthTerm(const Expr& input_zero_point, const Expr& kernel_ * Sigma(c,r,s) QW(k, c, r, s) * QA(n, c, h + r, w + s) * This is just conv2d on int tensors. */ -Expr Conv2DFirstTerm(const Expr& padded_data, const Expr& weight, const QConv2DAttrs* param) { +Expr Conv2DFirstTerm(const Expr& padded_data, const Expr& weight, const Conv2DAttrs* param) { // Lowering for Term 1 Array padding({0, 0, 0, 0}); return Conv2D(padded_data, weight, param->strides, padding, param->dilation, param->groups, @@ -465,7 +448,7 @@ Expr Conv2DFirstTerm(const Expr& padded_data, const Expr& weight, const QConv2DA * opportunity to reuse alter_op_layout infrastructure. */ Expr Conv2DSecondTerm(const Expr& padded_data, const Expr& kernel_zero_point, - const QConv2DAttrs* param, int kernel_h, int kernel_w, int out_channels) { + const Conv2DAttrs* param, int kernel_h, int kernel_w, int out_channels) { auto casted_t2 = Cast(padded_data, DataType::Int(32)); // We can reduce the H and W axis by using avg_pool2d. However, avg_pool2d averages the sum. @@ -535,7 +518,7 @@ Expr Conv2DSecondTerm(const Expr& padded_data, const Expr& kernel_zero_point, * a 1D tensor. The tensor is then reshaped to conform to NHWC/NCHW * format. */ -Expr Conv2DThirdTerm(const Expr& weight, const Expr& input_zero_point, const QConv2DAttrs* param, +Expr Conv2DThirdTerm(const Expr& weight, const Expr& input_zero_point, const Conv2DAttrs* param, int out_channels) { // Find which dimensions are C, R, S. 
Array axes_t3; @@ -586,7 +569,7 @@ Expr Conv2DThirdTerm(const Expr& weight, const Expr& input_zero_point, const QCo * */ Expr Conv2DFourthTerm(int input_zero_point_int, int kernel_zero_point_int, int in_channels, - int kernel_h, int kernel_w, const QConv2DAttrs* param) { + int kernel_h, int kernel_w, const Conv2DAttrs* param) { auto upcast_bits = param->out_dtype.bits() < 32 ? 32 : param->out_dtype.bits(); int scalar_term4 = input_zero_point_int * kernel_zero_point_int * in_channels * kernel_h * kernel_w; @@ -609,7 +592,7 @@ Expr Conv2DFourthTerm(int input_zero_point_int, int kernel_zero_point_int, int i * */ Expr Conv2DFourthTerm(const Expr& input_zero_point, const Expr& kernel_zero_point, int in_channels, - int kernel_h, int kernel_w, const QConv2DAttrs* param) { + int kernel_h, int kernel_w, const Conv2DAttrs* param) { auto upcast_bits = param->out_dtype.bits() < 32 ? 32 : param->out_dtype.bits(); Expr scalar_term4 = MakeConstantScalar(DataType::Int(upcast_bits), in_channels * kernel_h * kernel_w); @@ -729,7 +712,7 @@ Expr QnnConv2DCanonicalize(const Attrs& attrs, const Array& new_args, Expr weight = new_args[1]; Expr input_zero_point = new_args[2]; Expr kernel_zero_point = new_args[3]; - const auto* param = attrs.as(); + const auto* param = attrs.as(); ICHECK(param != nullptr); // Assertion checks for existing support. 
ICHECK(param->data_layout == "NCHW" || param->data_layout == "NHWC") @@ -834,7 +817,7 @@ Expr MakeQnnConv2D(Expr data, Expr weight, Expr input_zero_point, Expr kernel_ze Array padding, Array dilation, int groups, IndexExpr channels, Array kernel_size, String data_layout, String kernel_layout, String out_layout, DataType out_dtype) { - auto attrs = make_object(); + auto attrs = make_object(); attrs->strides = std::move(strides); attrs->padding = std::move(padding); attrs->dilation = std::move(dilation); @@ -845,11 +828,6 @@ Expr MakeQnnConv2D(Expr data, Expr weight, Expr input_zero_point, Expr kernel_ze attrs->kernel_layout = std::move(kernel_layout); attrs->out_layout = std::move(out_layout); attrs->out_dtype = std::move(out_dtype); - - // Optional extra attributes for requantization. - attrs->axis = -1; - attrs->rq_out_dtype = attrs->out_dtype; - static const Op& op = Op::Get("qnn.conv2d"); return Call(op, {data, weight, input_zero_point, kernel_zero_point, input_scale, kernel_scale}, Attrs(attrs), {}); @@ -868,7 +846,7 @@ operator to understand how to scale back the int32 output to (u)int8 or (u)int16 - **out**: This depends on the `layout` parameter. Output is 4D array of shape (batch_size, channels, out_height, out_width) if `layout` is `NCHW`. 
)code" TVM_ADD_FILELINE) - .set_attrs_type() + .set_attrs_type() .set_num_inputs(6) .add_argument("data", "Tensor", "The quantized input data tensor.") .add_argument("weight", "Tensor", "The quantized weight tensor.") diff --git a/src/relay/qnn/op/dense.cc b/src/relay/qnn/op/dense.cc index e5a21f134cb8..adaf509e7daf 100644 --- a/src/relay/qnn/op/dense.cc +++ b/src/relay/qnn/op/dense.cc @@ -35,8 +35,6 @@ namespace tvm { namespace relay { namespace qnn { -TVM_REGISTER_NODE_TYPE(QDenseAttrs); - // relay.op.qnn.dense bool QnnDenseRel(const Array& types, int num_inputs, const Attrs& attrs, @@ -47,8 +45,8 @@ bool QnnDenseRel(const Array& types, int num_inputs, const Attrs& attrs, const auto* data = types[0].as(); const auto* weight = types[1].as(); if (data == nullptr || weight == nullptr) return false; - const auto* param = attrs.as(); - ICHECK(param != nullptr) << "QDenseAttrs cannot be nullptr."; + const auto* param = attrs.as(); + ICHECK(param != nullptr) << "DenseAttrs cannot be nullptr."; ICHECK(data->dtype == DataType::Int(8) || data->dtype == DataType::UInt(8)) << "Expected quantized dense type(int8, uint8) for input but was " << data->dtype; ICHECK(weight->dtype == DataType::Int(8) || weight->dtype == DataType::UInt(8)) @@ -72,27 +70,22 @@ bool QnnDenseRel(const Array& types, int num_inputs, const Attrs& attrs, // Collect the input tensor and output tensor devoid of scale and zero points to reuse Relay // Dense infer type function. Array tensor_types = {types[0], types[1], types[6]}; - return MatmulRel(tensor_types, 3, attrs, reporter); + return MatmulRel(tensor_types, 3, attrs, reporter); } // Positional relay function to create quantized dense operator used by frontend FFI. 
Expr MakeQuantizedDense(Expr data, Expr weight, Expr input_zero_point, Expr kernel_zero_point, Expr input_scale, Expr kernel_scale, IndexExpr units, DataType out_dtype) { - auto attrs = make_object(); + auto attrs = make_object(); attrs->units = std::move(units); attrs->out_dtype = out_dtype; - - // Optional extra attributes for requantization. - attrs->axis = -1; - attrs->rq_out_dtype = attrs->out_dtype; - static const Op& op = Op::Get("qnn.dense"); return Call(op, {data, weight, input_zero_point, kernel_zero_point, input_scale, kernel_scale}, Attrs(attrs), {}); } Expr DenseFirstTerm(const Expr& quantized_data, const Expr& quantized_kernel, - const QDenseAttrs* attrs) { + const DenseAttrs* attrs) { return Dense(quantized_data, quantized_kernel, attrs->units, attrs->out_dtype); } @@ -177,7 +170,7 @@ Expr QnnDenseCanonicalize(const Attrs& attrs, const Array& new_args, const int reduction_dim_size = get_const_int(in_shape[1]); const int out_dim_size = get_const_int(w_shape[0]); - const auto* qnn_dense_attrs = attrs.as(); + const auto* qnn_dense_attrs = attrs.as(); auto term1 = DenseFirstTerm(quantized_data, quantized_kernel, qnn_dense_attrs); auto term2 = DenseSecondTerm(quantized_data, kernel_zero_point, out_dim_size); @@ -217,7 +210,7 @@ RELAY_REGISTER_OP("qnn.dense") - **weight**: quantized(int8, unit8) `(units, input_dim)` - **out**: quantized(int32) `(x1, x2, ..., xn, units)`. 
)code" TVM_ADD_FILELINE) - .set_attrs_type() + .set_attrs_type() .set_num_inputs(6) .add_argument("data", "quantized nD Tensor", "Input data.") .add_argument("weight", "quantized 2D Tensor", "Weight matrix.") diff --git a/src/relay/qnn/utils.cc b/src/relay/qnn/utils.cc index eedf68724535..ed7a415cf6af 100644 --- a/src/relay/qnn/utils.cc +++ b/src/relay/qnn/utils.cc @@ -213,28 +213,6 @@ std::string SelectRequntizeParameter(const std::string& arg_value, const std::st } } -TVM_REGISTER_GLOBAL("relay.attrs.get_rq_out_dtype").set_body_typed([](const Attrs& attrs) { - if (attrs->IsInstance()) { - return attrs.as()->rq_out_dtype; - } else if (attrs->IsInstance()) { - return attrs.as()->rq_out_dtype; - } else { - LOG(FATAL) << "Unhandled attribute: " << attrs; - } - return DataType(); -}); - -TVM_REGISTER_GLOBAL("relay.attrs.get_rq_axis").set_body_typed([](const Attrs& attrs) { - if (attrs->IsInstance()) { - return attrs.as()->axis; - } else if (attrs->IsInstance()) { - return attrs.as()->axis; - } else { - LOG(FATAL) << "Unhandled attribute: " << attrs; - } - return -1; -}); - } // namespace qnn } // namespace relay } // namespace tvm From 968be2bc92f7fdbcb33f05416f741d4d0a26f31f Mon Sep 17 00:00:00 2001 From: ibsidorenko Date: Mon, 17 Oct 2022 12:55:58 +0300 Subject: [PATCH 22/25] Fix tests after rebase --- python/tvm/topi/hexagon/qnn/__init__.py | 1 + python/tvm/topi/hexagon/{ => qnn}/qnn.py | 10 +++++----- 2 files changed, 6 insertions(+), 5 deletions(-) rename python/tvm/topi/hexagon/{ => qnn}/qnn.py (99%) diff --git a/python/tvm/topi/hexagon/qnn/__init__.py b/python/tvm/topi/hexagon/qnn/__init__.py index 2616b9315a9b..0f8630c5484d 100644 --- a/python/tvm/topi/hexagon/qnn/__init__.py +++ b/python/tvm/topi/hexagon/qnn/__init__.py @@ -25,3 +25,4 @@ ) from .quantize import quantize_compute, tir_quantize_schedule +from .qnn import * diff --git a/python/tvm/topi/hexagon/qnn.py b/python/tvm/topi/hexagon/qnn/qnn.py similarity index 99% rename from 
python/tvm/topi/hexagon/qnn.py rename to python/tvm/topi/hexagon/qnn/qnn.py index 6e1abcdc5668..90e6dc651abd 100644 --- a/python/tvm/topi/hexagon/qnn.py +++ b/python/tvm/topi/hexagon/qnn/qnn.py @@ -19,11 +19,11 @@ import tvm from tvm import te, topi -from ..utils import get_const_tuple -from ..nn.utils import get_pad_tuple -from ..nn.pad import pad -from .. import tag, nn -from ..x86.concat import concatenate +from ...utils import get_const_tuple +from ...nn.utils import get_pad_tuple +from ...nn.pad import pad +from ... import tag, nn +from ...x86.concat import concatenate def clip_cast(val, dtype): From b7e6e263b247ca69ceee1a144927bae8f413e9ab Mon Sep 17 00:00:00 2001 From: ibsidorenko Date: Mon, 17 Oct 2022 18:34:30 +0300 Subject: [PATCH 23/25] Address code review comments. --- include/tvm/relay/qnn/transform.h | 7 ++++ include/tvm/relay/transform.h | 6 ++++ python/tvm/topi/hexagon/qnn/__init__.py | 2 +- python/tvm/topi/hexagon/qnn/{qnn.py => nn.py} | 33 +++++++------------ src/relay/backend/te_compiler_cache.cc | 12 +++---- src/relay/backend/utils.cc | 9 +++-- src/relay/qnn/pass/legalize.cc | 12 ++++++- 7 files changed, 48 insertions(+), 33 deletions(-) rename python/tvm/topi/hexagon/qnn/{qnn.py => nn.py} (95%) diff --git a/include/tvm/relay/qnn/transform.h b/include/tvm/relay/qnn/transform.h index d1f07c924d6b..6977a5b4dd1d 100644 --- a/include/tvm/relay/qnn/transform.h +++ b/include/tvm/relay/qnn/transform.h @@ -51,6 +51,13 @@ namespace transform { */ TVM_DLL Pass Legalize(); +/*! + * \brief Legalizes a QNN expr (without QNN Canonicalization). + * + * \return The pass. 
+ */ +TVM_DLL Pass QnnLegalize(); + } // namespace transform } // namespace qnn diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index cdea8e8e3c23..dc4f5076c41c 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -598,6 +598,12 @@ TVM_DLL Pass RemoveStandaloneReshapes(); } // namespace transform +namespace legalize { + +TVM_DLL Expr Legalize(const Expr& expr, const std::string& legalize_map_attr_name); + +} // namespace legalize + /*! * \brief Bind the free variables to a Relay expression. This is a helper * function usually called by other pass functions to help optimizations. diff --git a/python/tvm/topi/hexagon/qnn/__init__.py b/python/tvm/topi/hexagon/qnn/__init__.py index 0f8630c5484d..bafc6846b6fb 100644 --- a/python/tvm/topi/hexagon/qnn/__init__.py +++ b/python/tvm/topi/hexagon/qnn/__init__.py @@ -25,4 +25,4 @@ ) from .quantize import quantize_compute, tir_quantize_schedule -from .qnn import * +from .nn import * diff --git a/python/tvm/topi/hexagon/qnn/qnn.py b/python/tvm/topi/hexagon/qnn/nn.py similarity index 95% rename from python/tvm/topi/hexagon/qnn/qnn.py rename to python/tvm/topi/hexagon/qnn/nn.py index 90e6dc651abd..40cfd0ee96b1 100644 --- a/python/tvm/topi/hexagon/qnn/qnn.py +++ b/python/tvm/topi/hexagon/qnn/nn.py @@ -65,8 +65,6 @@ def default_schedule(outs): def qnn_quantize(data, output_scale, output_zero_point, axis, out_dtype): """Compute for qnn.quantize - Note! This is POC code. There was no goal to implement high performance compute function. - Q_output = clamp((round(input_tensor/output_scale) + output_zero_point), out_dtype::min, out_dtype::max) @@ -106,8 +104,6 @@ def schedule_qnn_quantize(outs): def qnn_dequantize(data, input_scale, input_zero_point, axis): """Compute for qnn.dequantize - Note! This is POC code. There was no goal to implement high performance compute function. 
- fp_output = input_scale * (Q_input - input_zero_point) """ @@ -141,8 +137,6 @@ def schedule_qnn_dequantize(outs): def qnn_requantize(data, input_scale, input_zp, output_scale, output_zp, axis, out_dtype): """Compute for qnn.requantize - Note! This is POC code. There was no goal to implement high performance compute function. - Q_output = zp_output + round((scale_input)/(scale_output) * (Q_input - zp_input)) TODO: support 'rounding' and 'compute_dtype' arguments. @@ -188,8 +182,6 @@ def qnn_add( ): """Compute for qnn.add - Note! This is POC code. There was no goal to implement high performance compute function. - Q_output = zp_output + round((lhs_scale)/(scale_output) * (lhs_input - lhs_zp_input)) + round((rhs_scale)/(scale_output) * (rhs_input - rhs_zp_input)) @@ -253,8 +245,6 @@ def _compute(*indices): def qnn_concatenate(data, axis, out_dtype): """Compute for qnn.concatenate - Note! This is POC code. There was no goal to implement high performance compute function. - Parameters ---------- data: Array of Tensor @@ -337,10 +327,11 @@ def qnn_conv2d( # Conv2d inputs oshape, odtype, ): - """Compute for qnn.conv2d with NCHW layout - - Note! This is POC code. There was no goal to implement high performance compute function. + """Compute for qnn.conv2d with NCHW layout. + Output data type should be specified through the 'odtype' parameter. qnn.conv2d leverages int32 + type to store intermediate results. If 'odtype' differs from int32, you need to specify + requantization parameters. """ in_channel = data.shape[1] # NCHW layout kernel_height = weight.shape[2] # OIHW layout @@ -458,8 +449,9 @@ def qnn_depthwise_conv2d( # Conv2d inputs ): """Compute for qnn.conv2d with NCHW layout - Note! This is POC code. There was no goal to implement high performance compute function. - + Output data type should be specified through the 'odtype' parameter. qdepthwise nn.conv2d + leverages int32 type to store intermediate results. 
If 'odtype' differs from int32, you need to + specify requantization parameters. """ kernel_height = weight.shape[2] # OIHW layout kernel_width = weight.shape[3] # OIHW layout @@ -567,8 +559,9 @@ def qnn_dense( ): """Compute for qnn.dense - Note! This is POC code. There was no goal to implement high performance compute function. - + Output data type should be specified through the 'odtype' parameter. qnn.dense leverages int32 + type to store intermediate results. If 'odtype' differs from int32, you need to specify + requantization parameters. """ M, K = get_const_tuple(data.shape) N, _ = get_const_tuple(weight.shape) @@ -643,11 +636,7 @@ def qnn_batch_matmul( transpose_b, out_dtype, ): - """Compute for qnn.dense - - Note! This is POC code. There was no goal to implement high performance compute function. - - """ + """Compute for qnn.batch_matmul""" # Preprocess tensor_a: subtract zp a_sub_zp = te.compute( diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc index 633952ed2e53..e7326ed5dd4d 100644 --- a/src/relay/backend/te_compiler_cache.cc +++ b/src/relay/backend/te_compiler_cache.cc @@ -124,11 +124,11 @@ Array GetShape(const Array& shape) { } // Helper class that is used during lowering to TE. -// It matches sequence of Ops and lower them into single TOPI operation. Has sense for Hexagon only. -// All supported patterns are enumerated in "supported_patterns_" -class PatternMatcher { +// It matches sequence of Ops and lower them into single TOPI operation. All supported patterns are +// enumerated in "supported_patterns_". 
+class QnnPatternMatcher { public: - PatternMatcher() + QnnPatternMatcher() : qnn_conv2d_op_(Op::Get("qnn.conv2d")), qnn_dense_op_(Op::Get("qnn.dense")), qnn_requantize_op_(Op::Get("qnn.requantize")), @@ -288,7 +288,7 @@ class LowerToTECompute : public backend::MemoizedExprTranslatorGetTargetDeviceType() == kDLHexagon) pattern_matcher_.Register(call_node); + pattern_matcher_.Register(call_node); Array inputs; int count_tuple = 0; @@ -385,7 +385,7 @@ class LowerToTECompute : public backend::MemoizedExprTranslator GetPassPrefix(Target homogeneous_target, bool is_vm) { pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions)); pass_seqs.push_back(transform::ToBasicBlockNormalForm()); // Run all dialect legalization passes. - // Skip these passes for Hexagon target. - if ((is_homogeneous && homogeneous_target->GetTargetDeviceType() != kDLHexagon) || - !is_homogeneous) + if (is_homogeneous && homogeneous_target->GetTargetDeviceType() == kDLHexagon) { + // Run QNN Legalize. + pass_seqs.push_back(relay::qnn::transform::QnnLegalize()); + } else { + // Run QNN Legalize + QNN Canonicalize. pass_seqs.push_back(relay::qnn::transform::Legalize()); + } // Legalize pass is restricted to homogeneous execution for now. 
if (is_homogeneous) { diff --git a/src/relay/qnn/pass/legalize.cc b/src/relay/qnn/pass/legalize.cc index 33b9e59ab241..0b0137e4b197 100644 --- a/src/relay/qnn/pass/legalize.cc +++ b/src/relay/qnn/pass/legalize.cc @@ -30,9 +30,19 @@ namespace qnn { namespace transform { +Pass QnnLegalize() { + runtime::TypedPackedFunc pass_func = + [=](Function f, IRModule m, tvm::transform::PassContext pc) { + return Downcast(relay::legalize::Legalize(f, "FTVMQnnLegalize")); + }; + return tvm::relay::transform::CreateFunctionPass(pass_func, 1, "qnn.Legalize", {"InferType"}); +} + +TVM_REGISTER_GLOBAL("relay.qnn._transform.QnnLegalize").set_body_typed(QnnLegalize); + Pass Legalize() { Array pass_seqs; - pass_seqs.push_back(relay::transform::Legalize("FTVMQnnLegalize")); + pass_seqs.push_back(QnnLegalize()); pass_seqs.push_back(relay::transform::Legalize("FTVMQnnCanonicalize")); relay::transform::Pass seq = relay::transform::Sequential(pass_seqs); return seq; From c1b63999843291063025bd5f87d3a8e7a8756eb3 Mon Sep 17 00:00:00 2001 From: ibsidorenko Date: Tue, 18 Oct 2022 12:32:43 +0300 Subject: [PATCH 24/25] [QNN] Add option to disable QNN passes. QNN passes are enabled by default. To disable use disabled_pass=["qnn.Legalize"] in pass config. --- include/tvm/relay/qnn/transform.h | 7 ------ include/tvm/relay/transform.h | 6 ----- src/relay/backend/utils.cc | 8 +------ src/relay/qnn/pass/legalize.cc | 14 ++--------- .../test_wo_qnn_canonicalization.py | 23 ++++++++++++------- 5 files changed, 18 insertions(+), 40 deletions(-) diff --git a/include/tvm/relay/qnn/transform.h b/include/tvm/relay/qnn/transform.h index 6977a5b4dd1d..d1f07c924d6b 100644 --- a/include/tvm/relay/qnn/transform.h +++ b/include/tvm/relay/qnn/transform.h @@ -51,13 +51,6 @@ namespace transform { */ TVM_DLL Pass Legalize(); -/*! - * \brief Legalizes a QNN expr (without QNN Canonicalization). - * - * \return The pass.
- */ -TVM_DLL Pass QnnLegalize(); - } // namespace transform } // namespace qnn diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index dc4f5076c41c..cdea8e8e3c23 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -598,12 +598,6 @@ TVM_DLL Pass RemoveStandaloneReshapes(); } // namespace transform -namespace legalize { - -TVM_DLL Expr Legalize(const Expr& expr, const std::string& legalize_map_attr_name); - -} // namespace legalize - /*! * \brief Bind the free variables to a Relay expression. This is a helper * function usually called by other pass functions to help optimizations. diff --git a/src/relay/backend/utils.cc b/src/relay/backend/utils.cc index 07d90a3be1b2..0108211151cd 100644 --- a/src/relay/backend/utils.cc +++ b/src/relay/backend/utils.cc @@ -231,13 +231,7 @@ Array GetPassPrefix(Target homogeneous_target, bool is_vm) { pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions)); pass_seqs.push_back(transform::ToBasicBlockNormalForm()); // Run all dialect legalization passes. - if (is_homogeneous && homogeneous_target->GetTargetDeviceType() == kDLHexagon) { - // Run QNN Legalize. - pass_seqs.push_back(relay::qnn::transform::QnnLegalize()); - } else { - // Run QNN Legalize + QNN Canonicalize. - pass_seqs.push_back(relay::qnn::transform::Legalize()); - } + pass_seqs.push_back(relay::qnn::transform::Legalize()); // Legalize pass is restricted to homogeneous execution for now. 
if (is_homogeneous) { diff --git a/src/relay/qnn/pass/legalize.cc b/src/relay/qnn/pass/legalize.cc index 0b0137e4b197..a5906cf5e694 100644 --- a/src/relay/qnn/pass/legalize.cc +++ b/src/relay/qnn/pass/legalize.cc @@ -30,21 +30,11 @@ namespace qnn { namespace transform { -Pass QnnLegalize() { - runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, tvm::transform::PassContext pc) { - return Downcast(relay::legalize::Legalize(f, "FTVMQnnLegalize")); - }; - return tvm::relay::transform::CreateFunctionPass(pass_func, 1, "qnn.Legalize", {"InferType"}); -} - -TVM_REGISTER_GLOBAL("relay.qnn._transform.QnnLegalize").set_body_typed(QnnLegalize); - Pass Legalize() { Array pass_seqs; - pass_seqs.push_back(QnnLegalize()); + pass_seqs.push_back(relay::transform::Legalize("FTVMQnnLegalize")); pass_seqs.push_back(relay::transform::Legalize("FTVMQnnCanonicalize")); - relay::transform::Pass seq = relay::transform::Sequential(pass_seqs); + relay::transform::Pass seq = relay::transform::Sequential(pass_seqs, "qnn.Legalize"); return seq; } diff --git a/tests/python/contrib/test_hexagon/test_wo_qnn_canonicalization.py b/tests/python/contrib/test_hexagon/test_wo_qnn_canonicalization.py index b32a260a5f02..24da1faac697 100644 --- a/tests/python/contrib/test_hexagon/test_wo_qnn_canonicalization.py +++ b/tests/python/contrib/test_hexagon/test_wo_qnn_canonicalization.py @@ -38,11 +38,16 @@ def test_no_qnn_pass(): opt_mod_1, _ = relay.optimize(mod, tvm.target.Target(target_hexagon, host=target_hexagon)) # Disable QNN legalization and canonicalization passes - with tvm.transform.PassContext(opt_level=3, disabled_pass=["Legalize"]): + with tvm.transform.PassContext(opt_level=3, disabled_pass=["qnn.Legalize"]): opt_mod_2, _ = relay.optimize(mod, tvm.target.Target(target_hexagon, host=target_hexagon)) - # Check that during Default compilation flow we do not call qnn::canonicalization pass. 
- tvm.ir.assert_structural_equal(opt_mod_1, opt_mod_2) + # Check that QNN ops are absent with default compilation flow. + assert "qnn.quantize" not in opt_mod_1.astext(show_meta_data=False) + assert "qnn.dequantize" not in opt_mod_1.astext(show_meta_data=False) + + # Check that QNN ops are present without "qnn.Legalize" passes. + assert "qnn.quantize" in opt_mod_2.astext(show_meta_data=False) + assert "qnn.dequantize" in opt_mod_2.astext(show_meta_data=False) def execute(executor, data_np, weight_np, bias_np=None): @@ -56,8 +61,8 @@ def execute(executor, data_np, weight_np, bias_np=None): @tvm.testing.requires_hexagon def test_qnn_conv2d_rq(hexagon_session: Session): - data_shape = [1, 64, 64, 64] - weight_shape = [64, 64, 3, 3] + data_shape = [1, 8, 32, 32] + weight_shape = [16, 8, 3, 3] data = relay.var("data", shape=data_shape, dtype="float32") weight = relay.var("weight", shape=weight_shape, dtype="float32") op0 = relay.qnn.op.quantize(data, relay.const(0.078), relay.const(0), out_dtype="int8") @@ -70,7 +75,7 @@ def test_qnn_conv2d_rq(hexagon_session: Session): input_scale=relay.const(0.078), kernel_scale=relay.const(0.07), padding=[0, 0, 0, 0], - channels=64, + channels=16, kernel_size=[3, 3], ) op5 = relay.qnn.op.requantize( @@ -86,13 +91,14 @@ def test_qnn_conv2d_rq(hexagon_session: Session): target_hexagon = tvm.target.hexagon("v68") target_llvm = tvm.target.Target("llvm") executor = Executor("graph", {"link-params": True}) - with tvm.transform.PassContext(opt_level=3): + with tvm.transform.PassContext(opt_level=3, disabled_pass=["qnn.Legalize"]): hexagon_lowered = tvm.relay.build( relay_mod, tvm.target.Target(target_hexagon, host=target_hexagon), executor=executor, ) + with tvm.transform.PassContext(opt_level=3): llvm_lowered = tvm.relay.build( relay_mod, tvm.target.Target(target_llvm, host=target_llvm), @@ -147,13 +153,14 @@ def test_qnn_dense_bias_rq(hexagon_session: Session): target_hexagon = tvm.target.hexagon("v68") target_llvm = 
tvm.target.Target("llvm") executor = Executor("graph", {"link-params": True}) - with tvm.transform.PassContext(opt_level=3): + with tvm.transform.PassContext(opt_level=3, disabled_pass=["qnn.Legalize"]): hexagon_lowered = tvm.relay.build( relay_mod, tvm.target.Target(target_hexagon, host=target_hexagon), executor=executor, ) + with tvm.transform.PassContext(opt_level=3): llvm_lowered = tvm.relay.build( relay_mod, tvm.target.Target(target_llvm, host=target_llvm), From d5dc496b911dd38d6d919dc2b8b92abb5f04e541 Mon Sep 17 00:00:00 2001 From: ibsidorenko Date: Tue, 18 Oct 2022 13:17:23 +0300 Subject: [PATCH 25/25] Revert changes of GetPassPrefix interface. --- src/relay/backend/build_module.cc | 2 +- src/relay/backend/task_extraction.cc | 2 +- src/relay/backend/utils.cc | 4 +--- src/relay/backend/utils.h | 4 ++-- src/relay/backend/vm/compiler.cc | 2 +- 5 files changed, 6 insertions(+), 8 deletions(-) diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 1d1bd69b54c9..bca524794a20 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -328,7 +328,7 @@ class RelayBuildModule : public runtime::ModuleNode { backend::BindParamsInModule(relay_module, params_); Array pass_seqs = - GetPassPrefix(/*homogeneous target=*/config_->optional_homogeneous_target, /*is_vm=*/false); + GetPassPrefix(/*is_homogenous=*/config_->primitive_targets.size() == 1, /*is_vm=*/false); transform::PassContext pass_ctx = PassContext::Current(); if (config_->optional_homogeneous_target.defined()) { diff --git a/src/relay/backend/task_extraction.cc b/src/relay/backend/task_extraction.cc index 703f18587c6c..213841c621de 100644 --- a/src/relay/backend/task_extraction.cc +++ b/src/relay/backend/task_extraction.cc @@ -36,7 +36,7 @@ Array ExtractTask(IRModule mod, Target target, backend::FTECompilerTIRConverter tir_converter = backend::GetTIRConverter(); backend::BindParamsInModule(mod, params); // is_vm=true for backward compatibility - 
Array pass_seqs = relay::backend::GetPassPrefix(target, /*is_vm=*/true); + Array pass_seqs = relay::backend::GetPassPrefix(/*is_homogenous=*/true, /*is_vm=*/true); pass_seqs.push_back(transform::FuseOps()); mod = transform::Sequential(pass_seqs)(std::move(mod)); diff --git a/src/relay/backend/utils.cc b/src/relay/backend/utils.cc index 0108211151cd..51bcab527d1b 100644 --- a/src/relay/backend/utils.cc +++ b/src/relay/backend/utils.cc @@ -219,15 +219,13 @@ ExecutorCodegenMetadata::ExecutorCodegenMetadata( TVM_REGISTER_NODE_TYPE(ExecutorCodegenMetadataNode); -Array GetPassPrefix(Target homogeneous_target, bool is_vm) { +Array GetPassPrefix(bool is_homogeneous, bool is_vm) { Array pass_seqs; // TODO(mbs): Would be nice to get spans on all diagnostics, but since they arg forgotton // by most passes there's little utility in including this now. Plus we'd need to only do // this if there's no existing spans to work from. // pass_seqs.push_back(parser::AnnotateSpans()); Array entry_functions{"main"}; - // Can be undefined in case of heterogeneous execution - bool is_homogeneous = homogeneous_target.defined(); pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions)); pass_seqs.push_back(transform::ToBasicBlockNormalForm()); // Run all dialect legalization passes. diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index 91b569ad0cfc..00c75921f2f2 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -676,11 +676,11 @@ inline IRModule PrimFuncToIRModule(tir::PrimFunc f) { * difference. This function unifies the shared optimization pass prefix between vm and graph * runtime, and returns the pass prefix given the backend type. * - * \param homogeneous_target Execution target (can be undefined in case of heterogeneous execution). + * \param is_homogeneous True if all primitives are to be executed on the same device and target. * \param is_vm True if passes are to be used for the vm executor. * \return An array of passes. 
*/ -Array GetPassPrefix(Target homogeneous_target, bool is_vm); +Array GetPassPrefix(bool is_homogeneous, bool is_vm); /*! \brief Target hash function */ struct TargetStrHash { diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 570d4b69e4b2..b807f4195947 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -1054,7 +1054,7 @@ transform::Sequential VMCompiler::FuseAndLowerOperators(const CompilationConfig& IRModule VMCompiler::OptimizeModuleImpl(IRModule mod) { backend::BindParamsInModule(mod, params_); Array pass_seqs = relay::backend::GetPassPrefix( - /*homogeneous target=*/config_->optional_homogeneous_target, /*is_vm=*/true); + /*is_homogeneous=*/config_->optional_homogeneous_target.defined(), /*is_vm=*/true); // Always plan devices so the remaining passes don't need to distinguish homogeneous vs // heterogeneous execution.