From 483a1be059f2a539c5c3877328ee7148fd6fcde8 Mon Sep 17 00:00:00 2001 From: ibsidorenko Date: Fri, 14 Oct 2022 11:17:31 +0300 Subject: [PATCH 1/2] [Relay][Hexagon] Add per-channel FixedPointMultiply operation Main goal of this commit is to improve performance for Hexagon target and preserve performance/accuracy for x86, GPU and etc. targets. "qnn.requantize" operation is lowered into the sequence of multiply, add, shift during QNN canonicalization pass if scale quantization parameter is the vector of scalars. This commit adds new Relay per-channel/per-axis FixedPointMultiply operation and is used in "qnn.requantize" operation lowering. per-channel/per-axis FixedPointMultiply is implemented through tir.q_multiply_shift_per_axis intrinsic. For Hexagon target it overrides default implementation and generates HVX vmpye/vmpyo instruction (see _q_multiply_shift_per_axis_hexagon). For all other targets it uses default implementation (64 bits arithmetic). Performance/accuracy measurement: CPU(x86) target: accuracy and performance are the same. For other targets should be the same (otherwise it is bug). Hexagon target: speedup of qnn.requantize 7x-9x times (Snapdragon 888, 3.08 ms -> 0.39 ms) --- include/tvm/relay/attrs/transform.h | 17 +++ python/tvm/relay/op/_tensor.py | 13 ++ python/tvm/tir/__init__.py | 2 +- python/tvm/tir/op.py | 54 +++++++- python/tvm/topi/hexagon/tensor_intrin.py | 84 ++++++++++-- python/tvm/topi/math.py | 58 ++++++++ src/relay/op/make_op.h | 4 + src/relay/op/tensor/transform.cc | 129 ++++++++++++++++++ src/relay/qnn/op/requantize.cc | 12 +- src/relay/qnn/utils.cc | 31 +++++ src/relay/qnn/utils.h | 17 +++ src/relay/transforms/pattern_utils.h | 7 + src/runtime/crt/host/Makefile | 2 +- src/target/intrin_rule.cc | 84 ++++++++---- src/tir/op/builtin.cc | 5 + .../test_hexagon/test_fixed_point_multiply.py | 55 ++++++-- 16 files changed, 516 insertions(+), 58 deletions(-) diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index 2741d68eec14..274a421e5719 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -405,6 +405,23 @@ struct FixedPointMultiplyAttrs : public tvm::AttrsNode } }; +/*! \brief Attributes for per channel/per axes FixedPointMultiply operator */ +struct FixedPointMultiplyPerAxisAttrs : public tvm::AttrsNode { + bool is_lshift_required; + bool is_rshift_required; + Array axes; + + TVM_DECLARE_ATTRS(FixedPointMultiplyPerAxisAttrs, "relay.attrs.FixedPointMultiplyPerAxisAttrs") { + TVM_ATTR_FIELD(is_lshift_required) + .describe("Whether left shift is required in fixed point multiplication.") + .set_default(false); + TVM_ATTR_FIELD(is_rshift_required) + .describe("Whether right shift is required in fixed point multiplication.") + .set_default(false); + TVM_ATTR_FIELD(axes).describe("List of axes on which input data was quantized."); + } +}; + /*! \brief Attributes for LayoutTransform operator */ struct LayoutTransformAttrs : public tvm::AttrsNode { std::string src_layout; diff --git a/python/tvm/relay/op/_tensor.py b/python/tvm/relay/op/_tensor.py index a04199f6a5b1..cf318a025c36 100644 --- a/python/tvm/relay/op/_tensor.py +++ b/python/tvm/relay/op/_tensor.py @@ -170,6 +170,19 @@ def fixed_point_multiply_compute(attrs, inputs, output_type): register_injective_schedule("fixed_point_multiply") +# per-channel/per-axis fixed point multiply +@register_compute("fixed_point_multiply_per_axis") +def fixed_point_multiply_per_axis_compute(attrs, inputs, output_type): + assert len(inputs) == 4 + return [ + topi.fixed_point_multiply_per_axis( + *inputs, attrs.is_lshift_required, attrs.is_rshift_required, attrs.axes + ) + ] + + +register_broadcast_schedule("fixed_point_multiply_per_axis") + # full @script def _full_shape_func(shape): diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index 2767f2d5f779..d02f7fab7a5c 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -72,7 +72,7 @@ from .op import likely, isnan, isnullptr, isfinite, isinf, copysign from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod, ceildiv from .op import comm_reducer, min, max, sum -from .op import q_multiply_shift, shift_left, shift_right +from .op import q_multiply_shift, q_multiply_shift_per_axis, shift_left, shift_right from .op import TVMBackendAllocWorkspace, TVMBackendFreeWorkspace from .generic import add, subtract, multiply diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index 1fd3050c0a7f..588b40ae4033 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -21,10 +21,10 @@ import tvm._ffi from tvm.ir.base import Span from tvm.runtime import convert, const -from tvm.ir import Array, Op +from tvm.ir import Array, Op, PrimExpr from .buffer import Buffer -from .expr import Call, PrimExprWithOp, StringImm, Var, CommReducer +from .expr import Call, PrimExprWithOp, StringImm, Var, CommReducer, IntImm from . import _ffi_api @@ -263,8 +263,6 @@ def call_llvm_intrin(dtype, name, *args, span=None): # pylint: disable=import-outside-toplevel from tvm.target import codegen - from .expr import IntImm - if isinstance(name, str): llvm_id = codegen.llvm_lookup_intrinsic_id(name) elif isinstance(name, IntImm): @@ -307,8 +305,6 @@ def call_llvm_pure_intrin(dtype, name, *args, span=None): # pylint: disable=import-outside-toplevel from tvm.target import codegen - from .expr import IntImm - if isinstance(name, str): llvm_id = codegen.llvm_lookup_intrinsic_id(name) elif isinstance(name, IntImm): @@ -2238,6 +2234,52 @@ def q_multiply_shift(x, y, q, s): return call_intrin("int32", "tir.q_multiply_shift", x, y, q, s) +def q_multiply_shift_per_axis( + x: PrimExpr, + y: PrimExpr, + ls: PrimExpr, + rs: PrimExpr, + q: IntImm, + is_lshift_required: IntImm, + is_rshift_required: IntImm, +): + """Execute a multiplication between two Q-numbers x and y + + Parameters + ---------- + x : PrimExpr + First Q-number. + y : PrimExpr + Second Q-number. + ls : PrimExpr + Integer left shift. + rs : PrimExpr + Integer right shift. + q : IntImm + Number of fractional bits in x and y. Needs to be > 0. + is_lshift_required : IntImm + Whether we need to do left shift or not. + is_rshift_required : IntImm + Whether we need to do right shift or not. + + Returns + ------- + z : PrimExpr + The result. + """ + return call_intrin( + "int32", + "tir.q_multiply_shift_per_axis", + x, + y, + ls, + rs, + q, + is_lshift_required, + is_rshift_required, + ) + + def shift_left(x, y, span=None): """Return the result of x left shifted by y bits. diff --git a/python/tvm/topi/hexagon/tensor_intrin.py b/python/tvm/topi/hexagon/tensor_intrin.py index adea4690d4a7..3e9fd47b0fc6 100644 --- a/python/tvm/topi/hexagon/tensor_intrin.py +++ b/python/tvm/topi/hexagon/tensor_intrin.py @@ -25,12 +25,6 @@ def _q_multiply_shift_hexagon(op): """ Implementation of q_multiply_shift through hexagon intrinsics vmpyewuh and vmpyowh when q == 31. - - Please note that this is introducing a small round-up error for some corner cases with negative - shift argument. This is because we are rounding twice instead than only once. I.e.: - - * original q_multiply_shift: round(x*y*2^-s) - * hexagon q_multiply_shift: round(round(x*y)*2^-s) """ x = op.args[0] y = op.args[1] @@ -47,9 +41,9 @@ def _q_multiply_shift_hexagon(op): op.dtype, "llvm.hexagon.V6.vmpyewuh.128B", tvm.tir.const(2, "uint32"), x, y ) mul_o_1 = tvm.tir.call_llvm_intrin( - op.dtype, "llvm.hexagon.V6.vmpyowh.rnd.sacc.128B", tvm.tir.const(3, "uint32"), mul_e_1, x, y + op.dtype, "llvm.hexagon.V6.vmpyowh.sacc.128B", tvm.tir.const(3, "uint32"), mul_e_1, x, y ) - fixup = mul_o_1 & (-shift) + fixup = 1 << (-shift - 1) round_mul = mul_o_1 + fixup out_negative_shift = tvm.tir.call_llvm_intrin( op.dtype, "llvm.hexagon.V6.vaslwv.128B", tvm.tir.const(2, "uint32"), round_mul, shift @@ -73,6 +67,80 @@ def _q_multiply_shift_hexagon(op): ) +def _q_multiply_shift_per_axis_hexagon(op): + """ + Implementation of q_multiply_shift_per_axis through hexagon intrinsics vmpyewuh and vmpyowh when + q == 31. + """ + x = op.args[0] + y = op.args[1] + left_shift = op.args[2] + right_shift = op.args[3] + fractional_bits = op.args[4] + is_lshift_required = op.args[5] + is_rshift_required = op.args[6] + + # Don't use this intrinsic if we don't have a int32x32 vector + # or if we are not multiplying q31 numbers + if x.dtype != "int32x32" or fractional_bits.value != 31: + return op + + # Don't use this intrinsic when we need do both: left and right shifts. + # For now it is not clear how to implement this case through vector HVX instructions without + # accuracy drop. + if is_rshift_required.value and is_lshift_required.value: + return op + + # Case 1: do the left shift + shifted_x = x << left_shift + mul_e_1 = tvm.tir.call_llvm_intrin( + op.dtype, "llvm.hexagon.V6.vmpyewuh.128B", tvm.tir.const(2, "uint32"), shifted_x, y + ) + left_shift_out = tvm.tir.call_llvm_intrin( + op.dtype, + "llvm.hexagon.V6.vmpyowh.rnd.sacc.128B", + tvm.tir.const(3, "uint32"), + mul_e_1, + shifted_x, + y, + ) + + # Case 2: do the right shift + mul_e_2 = tvm.tir.call_llvm_intrin( + op.dtype, "llvm.hexagon.V6.vmpyewuh.128B", tvm.tir.const(2, "uint32"), x, y + ) + mul_o_2 = tvm.tir.call_llvm_intrin( + op.dtype, "llvm.hexagon.V6.vmpyowh.sacc.128B", tvm.tir.const(3, "uint32"), mul_e_2, x, y + ) + fixup = 1 << (right_shift - 1) + round_mul = mul_o_2 + fixup + right_shift_out = tvm.tir.call_llvm_intrin( + op.dtype, "llvm.hexagon.V6.vasrwv.128B", tvm.tir.const(2, "uint32"), round_mul, right_shift + ) + + # Case 3: do neither right nor left shift + mul_e_3 = tvm.tir.call_llvm_intrin( + op.dtype, "llvm.hexagon.V6.vmpyewuh.128B", tvm.tir.const(2, "uint32"), x, y + ) + no_shift_out = tvm.tir.call_llvm_intrin( + op.dtype, "llvm.hexagon.V6.vmpyowh.rnd.sacc.128B", tvm.tir.const(3, "uint32"), mul_e_3, x, y + ) + + return tvm.tir.Select( + tvm.tir.Not(tvm.tir.Or(is_lshift_required, is_rshift_required)), + no_shift_out, + tvm.tir.Select(is_lshift_required, left_shift_out, right_shift_out), + ) + + +register_intrin_lowering( + "tir.q_multiply_shift_per_axis", + target="hexagon", + f=_q_multiply_shift_per_axis_hexagon, + level=99, +) + + def dot_vrmpy(x_ty, y_ty): """Generates vrmpy instruciton for tensorization.""" int32_lanes = 32 diff --git a/python/tvm/topi/math.py b/python/tvm/topi/math.py index 9823024ea0bf..dd191c49be28 100644 --- a/python/tvm/topi/math.py +++ b/python/tvm/topi/math.py @@ -20,6 +20,7 @@ from tvm import te from . import tag from . import cpp +from .utils import get_const_tuple @tvm.te.tag_scope(tag=tag.ELEMWISE) @@ -672,6 +673,63 @@ def _compute(*indices): return te.compute(x.shape, _compute) +@tvm.te.tag_scope(tag=tag.BROADCAST) +def fixed_point_multiply_per_axis( + x: te.Tensor, + y: te.Tensor, + lshift: te.Tensor, + rshift: te.Tensor, + is_lshift_required: int, + is_rshift_required: int, + axes, +): + """Fixed point multiplication between data and a fixed point constant expressed as + multiplier * 2^(-shift), where multiplier is a Q-number with 31 fractional bits + + Parameters + ---------- + x : tvm.te.Tensor + Input argument. + y : tvm.te.Tensor + Multiplier of a fixed floating point number described as multiplier*2^(-shift). + lshift : tvm.te.Tensor + Left shifts of a fixed floating point number described as multiplier*2^(-shift). + rshift : tvm.te.Tensor + Right shifts of a fixed floating point number described as multiplier*2^(-shift). + is_lshift_required : int + Whether we need to do left shift or not. + is_rshift_required : int + Whether we need to do right shift or not. + + Returns + ------- + z : tvm.te.Tensor + The result. + """ + + def _compute(*indices): + elements = [] + for element in get_const_tuple(axes): + elements += [indices[element]] + param_indices = tuple(elements) + + value = x(*indices) + m = y(*param_indices) + l_shift = lshift(*param_indices) + r_shift = rshift(*param_indices) + return tvm.tir.q_multiply_shift_per_axis( + value, + m, + l_shift, + r_shift, + tvm.tir.const(31, "int32"), + tvm.tir.const(is_lshift_required, "bool"), + tvm.tir.const(is_rshift_required, "bool"), + ) + + return te.compute(x.shape, _compute) + + def cast(x, dtype, span=None): """Cast input to specified data type. diff --git a/src/relay/op/make_op.h b/src/relay/op/make_op.h index 85938a739182..50d8531c7dd0 100644 --- a/src/relay/op/make_op.h +++ b/src/relay/op/make_op.h @@ -54,6 +54,10 @@ Expr MakeBatchMatmul(Expr lhs, Expr rhs, DataType out_dtype, bool transpose_a, b Expr MakeExpandDims(Expr data, int axis, int num_newaxis); +Expr MakeFixedPointMultiplyPerAxis(Expr x, Expr m, Expr lshift, Expr rshift, + bool is_lshift_required, bool is_rshift_required, + Array axis); + Expr MakeFull(Expr fill_value, Array shape, DataType dtype); Expr MakeLayoutTransform(Expr data, String src_layout, String dst_layout); diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 985222307ad9..5f063a290740 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -4302,5 +4302,134 @@ RELAY_REGISTER_OP("trilu") .set_support_level(3) .set_attr("TOpPattern", kElemWise); +// FixedPointMultiplyPerAxis + +TVM_REGISTER_NODE_TYPE(FixedPointMultiplyPerAxisAttrs); + +bool FixedPointMultiplyPerAxisRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + ICHECK_EQ(types.size(), 5) << "FixedPointMultiplyPerAxis: expect 5 types but " << types.size() + << " provided"; + ICHECK_EQ(num_inputs, 4) << "FixedPointMultiplyPerAxis: expect 4 inputs but " << num_inputs + << " provided"; + + for (int i = 0; i < num_inputs; i++) { + auto data = types[i].as(); + if (data == nullptr) { + ICHECK(types[i].as()) + << "FixedPointMultiplyPerAxis: expect input type to be TensorType but get " << types[i]; + return false; + } + } + + return IdentityRel({types[0], types[4]}, 1, attrs, reporter); +} + +InferCorrectLayoutOutput FixedPointMultiplyPerAxisInferCorrectLayout( + const Attrs& attrs, const Array& new_in_layouts, const Array& old_in_layouts, + const Array& old_in_types) { + const auto* attrs_ptr = attrs.as(); + ICHECK(attrs_ptr); + ObjectPtr param = + make_object(*attrs_ptr); + + Array> old_in_shapes; + for (auto old_in_t : old_in_types) { + ICHECK(old_in_t.as()); + old_in_shapes.push_back(old_in_t.as()->shape); + } + + Array input_layouts, output_layouts; + + if (new_in_layouts.defined()) { + const Layout& new_layout = new_in_layouts[0]; + const Layout& old_layout = old_in_layouts[0]; + + std::unordered_set old_dims; + for (auto axis : param->axes) { + ICHECK_GE(axis->value, 0) << "Axis out of bounds in FixedPointMultiplyPerAxis operator."; + ICHECK_LT(axis->value, old_in_shapes[0].size()) + << "Axis out of bounds in FixedPointMultiplyPerAxis operator."; + old_dims.emplace(old_layout[axis->value].name()); + } + + Array new_axes; + std::string new_layout_string = ""; + for (size_t axis_index = 0; axis_index < new_layout->axes.size(); ++axis_index) { + const auto& layout_axis = LayoutAxis::Get(new_layout->axes[axis_index]); + const std::string& layout_dim = layout_axis.name(); + if (layout_axis.IsPrimal()) { + if (old_dims.count(layout_dim)) { + new_axes.push_back(tvm::Integer(axis_index)); + new_layout_string += layout_dim; + } + } else { + auto primal_dim = layout_axis.ToPrimal().name(); + if (old_dims.count(primal_dim)) { + new_axes.push_back(tvm::Integer(axis_index)); + new_layout_string += std::to_string(new_layout.FactorOf(layout_axis)) + layout_dim; + } + } + } + + Layout channel_layout = Layout(new_layout_string); + + input_layouts = {new_layout, channel_layout, channel_layout, channel_layout}; + output_layouts = {new_layout}; + param->axes = std::move(new_axes); + } else if (old_in_layouts.defined()) { + ICHECK_EQ(old_in_layouts.size(), 4); + ICHECK_EQ(param->axes.size(), 1); // Not tested other cases + const Layout& old_layout = old_in_layouts[0]; + if (old_layout.defined()) { + std::string layout_string = old_layout[param->axes[0]->value].name(); + Layout channel_layout = Layout(layout_string); + + input_layouts = {old_layout, channel_layout, channel_layout, channel_layout}; + output_layouts = {old_layout}; + } else { + // Set the layouts to undef. + Layout undef = Layout::Undef(); + input_layouts = Array(4, undef); + output_layouts = {undef}; + } + } else { + // Set the layouts to undef. + Layout undef = Layout::Undef(); + input_layouts = Array(4, undef); + output_layouts = {undef}; + } + + return InferCorrectLayoutOutput(input_layouts, output_layouts, Attrs(param)); +} + +Expr MakeFixedPointMultiplyPerAxis(Expr x, Expr m, Expr lshift, Expr rshift, + bool is_lshift_required, bool is_rshift_required, + Array axes) { + auto attrs = make_object(); + attrs->is_lshift_required = is_lshift_required; + attrs->is_rshift_required = is_rshift_required; + attrs->axes = std::move(axes); + static const Op& op = Op::Get("fixed_point_multiply_per_axis"); + return Call(op, {x, m, lshift, rshift}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relay.op._make.fixed_point_multiply_per_axis") + .set_body_typed(MakeFixedPointMultiplyPerAxis); + +RELAY_REGISTER_OP("fixed_point_multiply_per_axis") + .describe(R"code(per channel fixed point multiplication)code" TVM_ADD_FILELINE) + .set_num_inputs(4) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("fp_multiplier", "Tensor", "The multipliers tensor.") + .add_argument("left_shift", "Tensor", "The left shifts tensor.") + .add_argument("right_shift", "Tensor", "The right shifts tensor.") + .add_type_rel("FixedPointMultiplyPerAxis", FixedPointMultiplyPerAxisRel) + .set_attr("TOpPattern", kBroadcast) + .set_attr("FInferCorrectLayout", + FixedPointMultiplyPerAxisInferCorrectLayout) + .set_attrs_type() + .set_support_level(10); + } // namespace relay } // namespace tvm diff --git a/src/relay/qnn/op/requantize.cc b/src/relay/qnn/op/requantize.cc index ae321b459788..1614652719c6 100644 --- a/src/relay/qnn/op/requantize.cc +++ b/src/relay/qnn/op/requantize.cc @@ -214,6 +214,7 @@ Expr RequantizeLowerInt(const Expr& input_tensor, const Expr& input_scale, // if the input scale is per-tensor or per-channel. If it is per-tensor, there is single scale for // the whole tensor. For per-channel (aka per-axis), there is a vector of scales for the input // tensor. Depending on the quantization type, the fixed point multiplication routing is called. + const bool is_upward_rounding = (param->rounding == "UPWARD"); auto scaled_int32_t = tensor; float output_scale_float = GetScalarFromConstant(output_scale); if (IsConstScalar(input_scale)) { @@ -225,8 +226,6 @@ Expr RequantizeLowerInt(const Expr& input_tensor, const Expr& input_scale, if (!IsEqualScalar(input_scale, output_scale)) { auto [fixed_point_multiplier, shift] = GetFixedPointMultiplierShift(double_multiplier); - const bool is_upward_rounding = (param->rounding == "UPWARD"); - // When using upward rounding (i.e., x.5 rounded to x+1), leverage // the FixedPointMultiply operator scaled_int32_t = @@ -246,8 +245,13 @@ Expr RequantizeLowerInt(const Expr& input_tensor, const Expr& input_scale, } int axis = param->axis; axis = (axis == -1) ? input_shape.size() - 1 : axis; - scaled_int32_t = FixedPointMultiplyPerChannel(scaled_int32_t, double_multipliers, input_shape, - axis, param->rounding); + + // When using "upward" rounding, leverage the FixedPointMultiplyPerAxis operator, + // for "tonearest" rounding - lower to multiply, add, shift operators sequence. + scaled_int32_t = is_upward_rounding + ? FixedPointMultiplyPerChannel(scaled_int32_t, double_multipliers, axis) + : FixedPointMultiplyPerChannelToNearest(scaled_int32_t, double_multipliers, + input_shape, axis); } // 3) Add the output zero point. diff --git a/src/relay/qnn/utils.cc b/src/relay/qnn/utils.cc index ed7a415cf6af..ab72bd957080 100644 --- a/src/relay/qnn/utils.cc +++ b/src/relay/qnn/utils.cc @@ -108,6 +108,32 @@ Expr FixedPointMultiplyToNearest(Expr tensor, double multiplier, return Cast(tensor, DataType::Int(32)); } +Expr FixedPointMultiplyPerChannel(Expr tensor, const std::vector& multipliers, int axis) { + DataType dtype = DataType::Int(32); + int64_t n_channels = static_cast(multipliers.size()); + + std::vector fixed_pt_multipliers, lshifts, rshifts; + bool is_lshift_required = false, is_rshift_required = false; + for (auto multiplier : multipliers) { + auto [fixed_pt_multiplier, shift] = GetFixedPointMultiplierShift(multiplier); + int lshift = shift > 0 ? shift : 0; + int rshift = shift > 0 ? 0 : -shift; + fixed_pt_multipliers.push_back(fixed_pt_multiplier); + lshifts.push_back(lshift); + rshifts.push_back(rshift); + is_lshift_required = is_lshift_required | (lshift != 0); + is_rshift_required = is_rshift_required | (rshift != 0); + } + + auto left_shift_expr = MakeConstantTensor(dtype, {n_channels}, lshifts); + auto right_shift_expr = MakeConstantTensor(dtype, {n_channels}, rshifts); + auto fixed_pt_multiplier_expr = MakeConstantTensor(dtype, {n_channels}, fixed_pt_multipliers); + + return FixedPointMultiplyPerAxis(tensor, fixed_pt_multiplier_expr, left_shift_expr, + right_shift_expr, is_lshift_required, is_rshift_required, + {axis}); +} + Expr FixedPointMultiplyPerChannel(Expr tensor, std::vector multipliers, const Array& input_shape, int channel_axis, const std::string& rounding) { @@ -197,6 +223,11 @@ Expr FixedPointMultiplyPerChannel(Expr tensor, std::vector multipliers, return Cast(tensor, DataType::Int(32)); } +Expr FixedPointMultiplyPerChannelToNearest(Expr tensor, std::vector multipliers, + const Array& input_shape, int channel_axis) { + return FixedPointMultiplyPerChannel(tensor, multipliers, input_shape, channel_axis, "TONEAREST"); +} + std::string SelectRequntizeParameter(const std::string& arg_value, const std::string& cfg_value, const bool is_cfg_default, const std::string& name) { if (arg_value == "None") { diff --git a/src/relay/qnn/utils.h b/src/relay/qnn/utils.h index d084e4871e95..87195eb34d94 100644 --- a/src/relay/qnn/utils.h +++ b/src/relay/qnn/utils.h @@ -212,6 +212,23 @@ Expr FixedPointMultiplyToNearest(Expr tensor, double multiplier, Expr FixedPointMultiplyPerChannel(Expr tensor, std::vector multiplier, const Array& input_shape, int channel_axis, const std::string& rounding); + +/* + * Wrapper for 'FixedPointMultiplyPerChannel' with rounding parameter == "TONEAREST". + */ +Expr FixedPointMultiplyPerChannelToNearest(Expr tensor, std::vector multiplier, + const Array& input_shape, int channel_axis); + +/* + * \brief Creates FixedPointMultiply operation where the input tensor is + per-axis/per-channel quantized.. + * \param tensor The quantized input tensor. + * \param multipliers List of scalar multipliers. + * \param channel_axis The channel_axis along which the input tensor is quantized. + * \return The Relay op. + */ +Expr FixedPointMultiplyPerChannel(Expr tensor, const std::vector& multipliers, int axis); + /* * \brief Checks whether an expr type is scalar of a given data type. * \param expr_type The type of expr to be checked. diff --git a/src/relay/transforms/pattern_utils.h b/src/relay/transforms/pattern_utils.h index ffe1cc2ca2ab..d03939e09ea8 100644 --- a/src/relay/transforms/pattern_utils.h +++ b/src/relay/transforms/pattern_utils.h @@ -661,6 +661,13 @@ inline Expr FixedPointMultiply(Expr x, int32_t multiplier, int32_t shift) { return Call(op, {x}, Attrs(attrs), {}); } +inline Expr FixedPointMultiplyPerAxis(Expr x, Expr m, Expr lshift, Expr rshift, + bool is_lshift_required, bool is_rshift_required, + Array axes) { + return MakeFixedPointMultiplyPerAxis(x, m, lshift, rshift, is_lshift_required, is_rshift_required, + axes); +} + inline Expr Add(Expr lhs, Expr rhs) { static const Op& op = Op::Get("add"); return Call(op, {lhs, rhs}, Attrs(), {}); diff --git a/src/runtime/crt/host/Makefile b/src/runtime/crt/host/Makefile index d9e87c7d6a41..ea2966045bb2 100644 --- a/src/runtime/crt/host/Makefile +++ b/src/runtime/crt/host/Makefile @@ -21,7 +21,7 @@ CXXFLAGS ?= -Werror -Wall -std=c++11 -DTVM_HOST_USE_GRAPH_EXECUTOR_MODULE LDFLAGS ?= -Werror -Wall # Codegen produces spurious lines like: int32_t arg2_code = ((int32_t*)arg_type_ids)[(2)]; -MODEL_CFLAGS ?= -Wno-error=unused-variable -Wno-error=missing-braces +MODEL_CFLAGS ?= -Wno-error=unused-variable -Wno-error=missing-braces -Wno-error=unused-const-variable AR ?= ${PREFIX}ar CC ?= ${PREFIX}gcc diff --git a/src/target/intrin_rule.cc b/src/target/intrin_rule.cc index e697d9b60273..f18b63714418 100644 --- a/src/target/intrin_rule.cc +++ b/src/target/intrin_rule.cc @@ -151,6 +151,46 @@ TVM_REGISTER_OP("tir.isinf") return isinf(call->args[0]); }); +/*! + * \brief Makes fixed point multiplication. + * \param x Input tensor. + * \param y Integer multiplier. + * \param left_shift Integer left shift. + * \param right_shift Integer right shift. + * \param is_left_shift_required Flag whether we need to do left shift or not. + * \return Calculated expression. + */ +static PrimExpr QMultiplyShift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr left_shift, + PrimExpr right_shift, PrimExpr is_left_shift_required) { + // Only int32 types are supported (any number of lanes is allowed) + ICHECK(y.dtype().code() == DLDataTypeCode::kDLInt && y.dtype().bits() == 32); + ICHECK(left_shift.dtype().code() == DLDataTypeCode::kDLInt && left_shift.dtype().bits() == 32); + ICHECK(right_shift.dtype().code() == DLDataTypeCode::kDLInt && right_shift.dtype().bits() == 32); + + DataType hp_dtype = DataType::Int(64, x.dtype().lanes()); + DataType lp_dtype = DataType::Int(32, x.dtype().lanes()); + + // 1) Cast and Multiply the integer multiplier + PrimExpr one = make_const(hp_dtype, 1); + x = cast(hp_dtype, x); + y = cast(hp_dtype, y); + x = tir::Select(is_left_shift_required, x << left_shift, x); + + // 2) Perform the multiplication in higher precision. + x = x * y; + + // 3) Find the rounding scalar + PrimExpr total_right_shift = right_shift + q; + PrimExpr pos_rounding_value = (one << (total_right_shift - 1)); + x = x + pos_rounding_value; + + // 4) Simply right shift the result to get the final output. + x = x >> total_right_shift; + + // 5) The fixed point multiplication keeps the value in int32 range. Casting back to int32. + return cast(lp_dtype, x); +} + TVM_REGISTER_OP("tir.q_multiply_shift") .set_attr("default.FLegalize", [](const PrimExpr& e) -> PrimExpr { using tir::make_const; @@ -194,40 +234,34 @@ TVM_REGISTER_OP("tir.q_multiply_shift") } } else { // Only int32 types are supported (any number of lanes is allowed) - ICHECK(y.dtype().code() == DLDataTypeCode::kDLInt && y.dtype().bits() == 32); ICHECK(s.dtype().code() == DLDataTypeCode::kDLInt && s.dtype().bits() == 32); - DataType hp_dtype = DataType::Int(64, x.dtype().lanes()); - DataType lp_dtype = DataType::Int(32, x.dtype().lanes()); - - // 1) Calculating the integer multiplier and integer shift + // Calculating integer shifts PrimExpr zero = make_const(s.dtype(), 0); PrimExpr left_shift = tir::Select(s > zero, s, zero); PrimExpr right_shift = tir::Select(s > zero, zero, -s); + PrimExpr is_left_shift_required = (left_shift != zero); - // 2) Cast and Multiply the integer multiplier - PrimExpr one = make_const(hp_dtype, 1); - x = cast(hp_dtype, x); - y = cast(hp_dtype, y); - x = tir::Select(left_shift != zero, x << left_shift, x); - - // 3) Perform the multiplication in higher precision. - x = x * y; - - // 4) Find the rounding scalar - PrimExpr total_right_shift = right_shift + q; - PrimExpr pos_rounding_value = (one << (total_right_shift - 1)); - x = x + pos_rounding_value; - - // 5) Simply right shift the result to get the final output. - x = x >> total_right_shift; - - // 6) The fixed point multiplication keeps the value in int32 range. Casting back to - // int32. - return cast(lp_dtype, x); + return QMultiplyShift(x, y, q, left_shift, right_shift, is_left_shift_required); } }); +TVM_REGISTER_OP("tir.q_multiply_shift_per_axis") + .set_attr("default.FLegalize", [](const PrimExpr& e) -> PrimExpr { + const tir::CallNode* call = e.as(); + ICHECK(call != nullptr); + + PrimExpr x = call->args[0]; + PrimExpr y = call->args[1]; + PrimExpr left_shift = call->args[2]; + PrimExpr right_shift = call->args[3]; + PrimExpr q = call->args[4]; + PrimExpr is_lshift_required = call->args[5]; + // Note, 7th argument is "is_rshift_required" flag, but we do need that here. + // PrimExpr is_rshift_required = call->args[6]; + + return QMultiplyShift(x, y, q, left_shift, right_shift, is_lshift_required); + }); } // namespace legalize } // namespace codegen } // namespace tvm diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc index 1e2d790c76e1..929626bd7d53 100644 --- a/src/tir/op/builtin.cc +++ b/src/tir/op/builtin.cc @@ -98,6 +98,11 @@ TIR_DEFINE_BUILTIN_FUNC(q_multiply_shift) .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)) .set_attr("TVectorizable", true); +TIR_DEFINE_BUILTIN_FUNC(q_multiply_shift_per_axis) + .set_num_inputs(7) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)) + .set_attr("TVectorizable", true); + TIR_DEFINE_BUILTIN_FUNC(isnullptr).set_num_inputs(1).set_attr( "TCallEffectKind", Integer(CallEffectKind::kPure)); diff --git a/tests/python/contrib/test_hexagon/test_fixed_point_multiply.py b/tests/python/contrib/test_hexagon/test_fixed_point_multiply.py index ee03599ff1f4..e7e4aa212e35 100644 --- a/tests/python/contrib/test_hexagon/test_fixed_point_multiply.py +++ b/tests/python/contrib/test_hexagon/test_fixed_point_multiply.py @@ -78,11 +78,25 @@ def run_module(graph_mod, inputs): return output +in_scale_const, out_scale_const = tvm.testing.parameters( + (1.3, 30.0), + (1.37, 1.0), + (0.6, 1.0), + ((1.7, 0.6), 1.0), + ((0.007, 1.9), 1.0), +) + +multiplier, shift = tvm.testing.parameters( + (1288490240, -2), # 0.15 + (1395864320, 1), # 1.3 + (1288490188, 0), # 0.6 +) + + @tvm.testing.requires_hexagon -def test_fixed_point_multiply_positive_shift(hexagon_session: Session): +def test_fixed_point_multiply(hexagon_session: Session, multiplier: int, shift: int): ishape = (6, 32) a = relay.var("a", relay.TensorType(ishape, "int32")) - multiplier, shift = (1395864320, 1) # 1.3 fpm = relay.fixed_point_multiply(a, multiplier, shift) relay_mod = tvm.IRModule.from_expr(fpm) @@ -108,22 +122,37 @@ def test_fixed_point_multiply_positive_shift(hexagon_session: Session): @tvm.testing.requires_hexagon -def test_fixed_point_multiply_negative_shift(hexagon_session: Session): - ishape = (6, 32) - a = relay.var("a", relay.TensorType(ishape, "int32")) - multiplier, shift = (1288490240, -2) # 0.15 - fpm = relay.fixed_point_multiply(a, multiplier, shift) - relay_mod = tvm.IRModule.from_expr(fpm) +def test_per_channel_fixed_point_multiply( + hexagon_session: Session, in_scale_const, out_scale_const +): + ishape = [1, 128, 56, 56] + axis = 1 + a = relay.var("a", shape=ishape, dtype="int32") + + # Make list of input scales from in_scale_const parameter. + if isinstance(in_scale_const, tuple): + in_scale = list(in_scale_const) * (ishape[axis] // len(in_scale_const)) + else: + in_scale = [in_scale_const] * ishape[axis] + assert len(in_scale) == ishape[axis] + + # qnn.requantize is lowered to fixed_point_multiply if zp == 0 and in_dtype == out_dtype. + iscale = relay.const(in_scale) + izero = relay.const(0) + oscale = relay.const(out_scale_const) + ozero = relay.const(0) + op = relay.qnn.op.requantize(a, iscale, izero, oscale, ozero, axis=axis, out_dtype="int32") + mod = tvm.IRModule.from_expr(op) with tvm.transform.PassContext(opt_level=3): # Compile for Hexagon... - hexagon_lowered = build_module(relay_mod, tvm.target.hexagon("v68")) + hexagon_lowered = build_module(mod, tvm.target.hexagon("v68")) # Compile for LLVM... - llvm_lowered = build_module(relay_mod, tvm.target.Target("llvm")) + llvm_lowered = build_module(mod, tvm.target.Target("llvm")) - data_in = np.arange(-96, 96).reshape(ishape) - inputs = {"a": data_in} + a_np = np.random.randint(-1000, 1000, size=np.prod(ishape)).reshape(ishape) + inputs = {"a": a_np} # Run hexagon... graph_mod = hexagon_session.get_executor_from_factory(hexagon_lowered) @@ -133,7 +162,7 @@ def test_fixed_point_multiply_negative_shift(hexagon_session: Session): llvm_graph_mod = tvm.contrib.graph_executor.GraphModule(llvm_lowered["default"](tvm.cpu(0))) expected_output = run_module(llvm_graph_mod, inputs) - tvm.testing.assert_allclose(hexagon_output, expected_output, atol=1) + tvm.testing.assert_allclose(hexagon_output, expected_output) if __name__ == "__main__": From 62309a6d8aa7aefd6499787036f950cce5587726 Mon Sep 17 00:00:00 2001 From: ibsidorenko Date: Thu, 27 Oct 2022 09:44:38 +0300 Subject: [PATCH 2/2] Address code review comments --- src/target/intrin_rule.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/target/intrin_rule.cc b/src/target/intrin_rule.cc index f18b63714418..86c50f2609be 100644 --- a/src/target/intrin_rule.cc +++ b/src/target/intrin_rule.cc @@ -257,7 +257,7 @@ TVM_REGISTER_OP("tir.q_multiply_shift_per_axis") PrimExpr right_shift = call->args[3]; PrimExpr q = call->args[4]; PrimExpr is_lshift_required = call->args[5]; - // Note, 7th argument is "is_rshift_required" flag, but we do need that here. + // Note, 7th argument is "is_rshift_required" flag, but we don't need that here. // PrimExpr is_rshift_required = call->args[6]; return QMultiplyShift(x, y, q, left_shift, right_shift, is_lshift_required);