diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h
index 2741d68eec14..274a421e5719 100644
--- a/include/tvm/relay/attrs/transform.h
+++ b/include/tvm/relay/attrs/transform.h
@@ -405,6 +405,23 @@ struct FixedPointMultiplyAttrs : public tvm::AttrsNode<FixedPointMultiplyAttrs>
   }
 };
 
+/*! \brief Attributes for per channel/per axes FixedPointMultiply operator */
+struct FixedPointMultiplyPerAxisAttrs : public tvm::AttrsNode<FixedPointMultiplyPerAxisAttrs> {
+  bool is_lshift_required;
+  bool is_rshift_required;
+  Array<Integer> axes;
+
+  TVM_DECLARE_ATTRS(FixedPointMultiplyPerAxisAttrs, "relay.attrs.FixedPointMultiplyPerAxisAttrs") {
+    TVM_ATTR_FIELD(is_lshift_required)
+        .describe("Whether left shift is required in fixed point multiplication.")
+        .set_default(false);
+    TVM_ATTR_FIELD(is_rshift_required)
+        .describe("Whether right shift is required in fixed point multiplication.")
+        .set_default(false);
+    TVM_ATTR_FIELD(axes).describe("List of axes on which input data was quantized.");
+  }
+};
+
 /*! \brief Attributes for LayoutTransform operator */
 struct LayoutTransformAttrs : public tvm::AttrsNode<LayoutTransformAttrs> {
   std::string src_layout;
diff --git a/python/tvm/relay/op/_tensor.py b/python/tvm/relay/op/_tensor.py
index a04199f6a5b1..cf318a025c36 100644
--- a/python/tvm/relay/op/_tensor.py
+++ b/python/tvm/relay/op/_tensor.py
@@ -170,6 +170,19 @@ def fixed_point_multiply_compute(attrs, inputs, output_type):
 
 register_injective_schedule("fixed_point_multiply")
 
+
+# per-channel/per-axis fixed point multiply
+@register_compute("fixed_point_multiply_per_axis")
+def fixed_point_multiply_per_axis_compute(attrs, inputs, output_type):
+    assert len(inputs) == 4
+    return [
+        topi.fixed_point_multiply_per_axis(
+            *inputs, attrs.is_lshift_required, attrs.is_rshift_required, attrs.axes
+        )
+    ]
+
+
+register_broadcast_schedule("fixed_point_multiply_per_axis")
+
 # full
 @script
 def _full_shape_func(shape):
diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py
index 2767f2d5f779..d02f7fab7a5c 100644
--- a/python/tvm/tir/__init__.py
+++ b/python/tvm/tir/__init__.py
@@ -72,7 +72,7 @@
 from .op import likely, isnan, isnullptr, isfinite, isinf, copysign
 from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod, ceildiv
 from .op import comm_reducer, min, max, sum
-from .op import q_multiply_shift, shift_left, shift_right
+from .op import q_multiply_shift, q_multiply_shift_per_axis, shift_left, shift_right
 from .op import TVMBackendAllocWorkspace, TVMBackendFreeWorkspace
 
 from .generic import add, subtract, multiply
diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py
index 1fd3050c0a7f..588b40ae4033 100644
--- a/python/tvm/tir/op.py
+++ b/python/tvm/tir/op.py
@@ -21,10 +21,10 @@
 import tvm._ffi
 from tvm.ir.base import Span
 from tvm.runtime import convert, const
-from tvm.ir import Array, Op
+from tvm.ir import Array, Op, PrimExpr
 
 from .buffer import Buffer
-from .expr import Call, PrimExprWithOp, StringImm, Var, CommReducer
+from .expr import Call, PrimExprWithOp, StringImm, Var, CommReducer, IntImm
 from . import _ffi_api
@@ -263,8 +263,6 @@ def call_llvm_intrin(dtype, name, *args, span=None):
     # pylint: disable=import-outside-toplevel
     from tvm.target import codegen
 
-    from .expr import IntImm
-
     if isinstance(name, str):
         llvm_id = codegen.llvm_lookup_intrinsic_id(name)
     elif isinstance(name, IntImm):
@@ -307,8 +305,6 @@ def call_llvm_pure_intrin(dtype, name, *args, span=None):
     # pylint: disable=import-outside-toplevel
     from tvm.target import codegen
 
-    from .expr import IntImm
-
    if isinstance(name, str):
        llvm_id = codegen.llvm_lookup_intrinsic_id(name)
    elif isinstance(name, IntImm):
@@ -2238,6 +2234,52 @@ def q_multiply_shift(x, y, q, s):
     return call_intrin("int32", "tir.q_multiply_shift", x, y, q, s)
 
 
+def q_multiply_shift_per_axis(
+    x: PrimExpr,
+    y: PrimExpr,
+    ls: PrimExpr,
+    rs: PrimExpr,
+    q: IntImm,
+    is_lshift_required: IntImm,
+    is_rshift_required: IntImm,
+):
+    """Execute a multiplication between two Q-numbers x and y
+
+    Parameters
+    ----------
+    x : PrimExpr
+        First Q-number.
+    y : PrimExpr
+        Second Q-number.
+    ls : PrimExpr
+        Integer left shift.
+    rs : PrimExpr
+        Integer right shift.
+    q : IntImm
+        Number of fractional bits in x and y. Needs to be > 0.
+    is_lshift_required : IntImm
+        Whether we need to do left shift or not.
+    is_rshift_required : IntImm
+        Whether we need to do right shift or not.
+
+    Returns
+    -------
+    z : PrimExpr
+        The result.
+    """
+    return call_intrin(
+        "int32",
+        "tir.q_multiply_shift_per_axis",
+        x,
+        y,
+        ls,
+        rs,
+        q,
+        is_lshift_required,
+        is_rshift_required,
+    )
+
+
 def shift_left(x, y, span=None):
     """Return the result of x left shifted by y bits.
diff --git a/python/tvm/topi/hexagon/tensor_intrin.py b/python/tvm/topi/hexagon/tensor_intrin.py
index adea4690d4a7..3e9fd47b0fc6 100644
--- a/python/tvm/topi/hexagon/tensor_intrin.py
+++ b/python/tvm/topi/hexagon/tensor_intrin.py
@@ -25,12 +25,6 @@ def _q_multiply_shift_hexagon(op):
     """
     Implementation of q_multiply_shift through hexagon intrinsics vmpyewuh and vmpyowh when
     q == 31.
-
-    Please note that this is introducing a small round-up error for some corner cases with negative
-    shift argument. This is because we are rounding twice instead than only once. I.e.:
-
-    * original q_multiply_shift: round(x*y*2^-s)
-    * hexagon q_multiply_shift: round(round(x*y)*2^-s)
     """
     x = op.args[0]
     y = op.args[1]
@@ -47,9 +41,9 @@ def _q_multiply_shift_hexagon(op):
         op.dtype, "llvm.hexagon.V6.vmpyewuh.128B", tvm.tir.const(2, "uint32"), x, y
     )
     mul_o_1 = tvm.tir.call_llvm_intrin(
-        op.dtype, "llvm.hexagon.V6.vmpyowh.rnd.sacc.128B", tvm.tir.const(3, "uint32"), mul_e_1, x, y
+        op.dtype, "llvm.hexagon.V6.vmpyowh.sacc.128B", tvm.tir.const(3, "uint32"), mul_e_1, x, y
     )
-    fixup = mul_o_1 & (-shift)
+    fixup = 1 << (-shift - 1)
     round_mul = mul_o_1 + fixup
     out_negative_shift = tvm.tir.call_llvm_intrin(
         op.dtype, "llvm.hexagon.V6.vaslwv.128B", tvm.tir.const(2, "uint32"), round_mul, shift
@@ -73,6 +67,80 @@ def _q_multiply_shift_hexagon(op):
     )
 
 
+def _q_multiply_shift_per_axis_hexagon(op):
+    """
+    Implementation of q_multiply_shift_per_axis through hexagon intrinsics vmpyewuh and vmpyowh when
+    q == 31.
+    """
+    x = op.args[0]
+    y = op.args[1]
+    left_shift = op.args[2]
+    right_shift = op.args[3]
+    fractional_bits = op.args[4]
+    is_lshift_required = op.args[5]
+    is_rshift_required = op.args[6]
+
+    # Don't use this intrinsic if we don't have an int32x32 vector
+    # or if we are not multiplying q31 numbers
+    if x.dtype != "int32x32" or fractional_bits.value != 31:
+        return op
+
+    # Don't use this intrinsic when we need to do both left and right shifts.
+    # For now it is not clear how to implement this case through vector HVX instructions without
+    # accuracy drop.
+    if is_rshift_required.value and is_lshift_required.value:
+        return op
+
+    # Case 1: do the left shift
+    shifted_x = x << left_shift
+    mul_e_1 = tvm.tir.call_llvm_intrin(
+        op.dtype, "llvm.hexagon.V6.vmpyewuh.128B", tvm.tir.const(2, "uint32"), shifted_x, y
+    )
+    left_shift_out = tvm.tir.call_llvm_intrin(
+        op.dtype,
+        "llvm.hexagon.V6.vmpyowh.rnd.sacc.128B",
+        tvm.tir.const(3, "uint32"),
+        mul_e_1,
+        shifted_x,
+        y,
+    )
+
+    # Case 2: do the right shift
+    mul_e_2 = tvm.tir.call_llvm_intrin(
+        op.dtype, "llvm.hexagon.V6.vmpyewuh.128B", tvm.tir.const(2, "uint32"), x, y
+    )
+    mul_o_2 = tvm.tir.call_llvm_intrin(
+        op.dtype, "llvm.hexagon.V6.vmpyowh.sacc.128B", tvm.tir.const(3, "uint32"), mul_e_2, x, y
+    )
+    fixup = 1 << (right_shift - 1)
+    round_mul = mul_o_2 + fixup
+    right_shift_out = tvm.tir.call_llvm_intrin(
+        op.dtype, "llvm.hexagon.V6.vasrwv.128B", tvm.tir.const(2, "uint32"), round_mul, right_shift
+    )
+
+    # Case 3: do neither right nor left shift
+    mul_e_3 = tvm.tir.call_llvm_intrin(
+        op.dtype, "llvm.hexagon.V6.vmpyewuh.128B", tvm.tir.const(2, "uint32"), x, y
+    )
+    no_shift_out = tvm.tir.call_llvm_intrin(
+        op.dtype, "llvm.hexagon.V6.vmpyowh.rnd.sacc.128B", tvm.tir.const(3, "uint32"), mul_e_3, x, y
+    )
+
+    return tvm.tir.Select(
+        tvm.tir.Not(tvm.tir.Or(is_lshift_required, is_rshift_required)),
+        no_shift_out,
+        tvm.tir.Select(is_lshift_required, left_shift_out, right_shift_out),
+    )
+
+
+register_intrin_lowering(
+    "tir.q_multiply_shift_per_axis",
+    target="hexagon",
+    f=_q_multiply_shift_per_axis_hexagon,
+    level=99,
+)
+
+
 def dot_vrmpy(x_ty, y_ty):
     """Generates vrmpy instruction for tensorization."""
     int32_lanes = 32
diff --git a/python/tvm/topi/math.py b/python/tvm/topi/math.py
index 9823024ea0bf..dd191c49be28 100644
--- a/python/tvm/topi/math.py
+++ b/python/tvm/topi/math.py
@@ -20,6 +20,7 @@
 from tvm import te
 from . import tag
 from . import cpp
+from .utils import get_const_tuple
 
 
 @tvm.te.tag_scope(tag=tag.ELEMWISE)
@@ -672,6 +673,63 @@ def _compute(*indices):
     return te.compute(x.shape, _compute)
 
 
+@tvm.te.tag_scope(tag=tag.BROADCAST)
+def fixed_point_multiply_per_axis(
+    x: te.Tensor,
+    y: te.Tensor,
+    lshift: te.Tensor,
+    rshift: te.Tensor,
+    is_lshift_required: int,
+    is_rshift_required: int,
+    axes,
+):
+    """Fixed point multiplication between data and a fixed point constant expressed as
+    multiplier * 2^(-shift), where multiplier is a Q-number with 31 fractional bits
+
+    Parameters
+    ----------
+    x : tvm.te.Tensor
+        Input argument.
+    y : tvm.te.Tensor
+        Multiplier of a fixed point number described as multiplier*2^(-shift).
+    lshift : tvm.te.Tensor
+        Left shifts of a fixed point number described as multiplier*2^(-shift).
+    rshift : tvm.te.Tensor
+        Right shifts of a fixed point number described as multiplier*2^(-shift).
+    is_lshift_required : int
+        Whether we need to do left shift or not.
+    is_rshift_required : int
+        Whether we need to do right shift or not.
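+    axes : list or tuple of int
+        Axes of x along which the input was quantized; y, lshift and rshift are
+        indexed by these axes.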
+
+    Returns
+    -------
+    z : tvm.te.Tensor
+        The result.
+    """
+
+    def _compute(*indices):
+        elements = []
+        for element in get_const_tuple(axes):
+            elements += [indices[element]]
+        param_indices = tuple(elements)
+
+        value = x(*indices)
+        m = y(*param_indices)
+        l_shift = lshift(*param_indices)
+        r_shift = rshift(*param_indices)
+        return tvm.tir.q_multiply_shift_per_axis(
+            value,
+            m,
+            l_shift,
+            r_shift,
+            tvm.tir.const(31, "int32"),
+            tvm.tir.const(is_lshift_required, "bool"),
+            tvm.tir.const(is_rshift_required, "bool"),
+        )
+
+    return te.compute(x.shape, _compute)
+
+
 def cast(x, dtype, span=None):
     """Cast input to specified data type.
diff --git a/src/relay/op/make_op.h b/src/relay/op/make_op.h
index 85938a739182..50d8531c7dd0 100644
--- a/src/relay/op/make_op.h
+++ b/src/relay/op/make_op.h
@@ -54,6 +54,10 @@ Expr MakeBatchMatmul(Expr lhs, Expr rhs, DataType out_dtype, bool transpose_a, bool transpose_b);
 
 Expr MakeExpandDims(Expr data, int axis, int num_newaxis);
 
+Expr MakeFixedPointMultiplyPerAxis(Expr x, Expr m, Expr lshift, Expr rshift,
+                                   bool is_lshift_required, bool is_rshift_required,
+                                   Array<Integer> axis);
+
 Expr MakeFull(Expr fill_value, Array<Integer> shape, DataType dtype);
 
 Expr MakeLayoutTransform(Expr data, String src_layout, String dst_layout);
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 985222307ad9..5f063a290740 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -4302,5 +4302,134 @@ RELAY_REGISTER_OP("trilu")
     .set_support_level(3)
     .set_attr<TOpPattern>("TOpPattern", kElemWise);
 
+// FixedPointMultiplyPerAxis
+
+TVM_REGISTER_NODE_TYPE(FixedPointMultiplyPerAxisAttrs);
+
+bool FixedPointMultiplyPerAxisRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                                  const TypeReporter& reporter) {
+  ICHECK_EQ(types.size(), 5) << "FixedPointMultiplyPerAxis: expect 5 types but " << types.size()
+                             << " provided";
+  ICHECK_EQ(num_inputs, 4) << "FixedPointMultiplyPerAxis: expect 4 inputs but " << num_inputs
+                           << " provided";
+
+  for (int i = 0; i < num_inputs; i++) {
+    auto data = types[i].as<TensorTypeNode>();
+    if (data == nullptr) {
+      ICHECK(types[i].as<IncompleteTypeNode>())
+          << "FixedPointMultiplyPerAxis: expect input type to be TensorType but get " << types[i];
+      return false;
+    }
+  }
+
+  return IdentityRel({types[0], types[4]}, 1, attrs, reporter);
+}
+
+InferCorrectLayoutOutput FixedPointMultiplyPerAxisInferCorrectLayout(
+    const Attrs& attrs, const Array<Layout>& new_in_layouts, const Array<Layout>& old_in_layouts,
+    const Array<tvm::relay::Type>& old_in_types) {
+  const auto* attrs_ptr = attrs.as<FixedPointMultiplyPerAxisAttrs>();
+  ICHECK(attrs_ptr);
+  ObjectPtr<FixedPointMultiplyPerAxisAttrs> param =
+      make_object<FixedPointMultiplyPerAxisAttrs>(*attrs_ptr);
+
+  Array<Array<IndexExpr>> old_in_shapes;
+  for (auto old_in_t : old_in_types) {
+    ICHECK(old_in_t.as<TensorTypeNode>());
+    old_in_shapes.push_back(old_in_t.as<TensorTypeNode>()->shape);
+  }
+
+  Array<Layout> input_layouts, output_layouts;
+
+  if (new_in_layouts.defined()) {
+    const Layout& new_layout = new_in_layouts[0];
+    const Layout& old_layout = old_in_layouts[0];
+
+    std::unordered_set<std::string> old_dims;
+    for (auto axis : param->axes) {
+      ICHECK_GE(axis->value, 0) << "Axis out of bounds in FixedPointMultiplyPerAxis operator.";
+      ICHECK_LT(axis->value, old_in_shapes[0].size())
+          << "Axis out of bounds in FixedPointMultiplyPerAxis operator.";
+      old_dims.emplace(old_layout[axis->value].name());
+    }
+
+    Array<tvm::Integer> new_axes;
+    std::string new_layout_string = "";
+    for (size_t axis_index = 0; axis_index < new_layout->axes.size(); ++axis_index) {
+      const auto& layout_axis = LayoutAxis::Get(new_layout->axes[axis_index]);
+      const std::string& layout_dim = layout_axis.name();
+      if (layout_axis.IsPrimal()) {
+        if (old_dims.count(layout_dim)) {
+          new_axes.push_back(tvm::Integer(axis_index));
+          new_layout_string += layout_dim;
+        }
+      } else {
+        auto primal_dim = layout_axis.ToPrimal().name();
+        if (old_dims.count(primal_dim)) {
+          new_axes.push_back(tvm::Integer(axis_index));
+          new_layout_string += std::to_string(new_layout.FactorOf(layout_axis)) + layout_dim;
+        }
+      }
+    }
+
+    Layout channel_layout = Layout(new_layout_string);
+
+    input_layouts = {new_layout, channel_layout, channel_layout, channel_layout};
+    output_layouts = {new_layout};
+    param->axes = std::move(new_axes);
+  } else if (old_in_layouts.defined()) {
+    ICHECK_EQ(old_in_layouts.size(), 4);
+    ICHECK_EQ(param->axes.size(), 1);  // Other cases have not been tested.
+    const Layout& old_layout = old_in_layouts[0];
+    if (old_layout.defined()) {
+      std::string layout_string = old_layout[param->axes[0]->value].name();
+      Layout channel_layout = Layout(layout_string);
+
+      input_layouts = {old_layout, channel_layout, channel_layout, channel_layout};
+      output_layouts = {old_layout};
+    } else {
+      // Set the layouts to undef.
+      Layout undef = Layout::Undef();
+      input_layouts = Array<Layout>(4, undef);
+      output_layouts = {undef};
+    }
+  } else {
+    // Set the layouts to undef.
+    Layout undef = Layout::Undef();
+    input_layouts = Array<Layout>(4, undef);
+    output_layouts = {undef};
+  }
+
+  return InferCorrectLayoutOutput(input_layouts, output_layouts, Attrs(param));
+}
+
+Expr MakeFixedPointMultiplyPerAxis(Expr x, Expr m, Expr lshift, Expr rshift,
+                                   bool is_lshift_required, bool is_rshift_required,
+                                   Array<Integer> axes) {
+  auto attrs = make_object<FixedPointMultiplyPerAxisAttrs>();
+  attrs->is_lshift_required = is_lshift_required;
+  attrs->is_rshift_required = is_rshift_required;
+  attrs->axes = std::move(axes);
+  static const Op& op = Op::Get("fixed_point_multiply_per_axis");
+  return Call(op, {x, m, lshift, rshift}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relay.op._make.fixed_point_multiply_per_axis")
+    .set_body_typed(MakeFixedPointMultiplyPerAxis);
+
+RELAY_REGISTER_OP("fixed_point_multiply_per_axis")
+    .describe(R"code(per channel fixed point multiplication)code" TVM_ADD_FILELINE)
+    .set_num_inputs(4)
+    .add_argument("data", "Tensor", "The input tensor.")
+    .add_argument("fp_multiplier", "Tensor", "The multipliers tensor.")
+    .add_argument("left_shift", "Tensor", "The left shifts tensor.")
+    .add_argument("right_shift", "Tensor", "The right shifts tensor.")
+    .add_type_rel("FixedPointMultiplyPerAxis", FixedPointMultiplyPerAxisRel)
+    .set_attr<TOpPattern>("TOpPattern", kBroadcast)
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout",
+                                   FixedPointMultiplyPerAxisInferCorrectLayout)
+    .set_attrs_type<FixedPointMultiplyPerAxisAttrs>()
+    .set_support_level(10);
+
 }  // namespace relay
 }  // namespace tvm
diff --git a/src/relay/qnn/op/requantize.cc b/src/relay/qnn/op/requantize.cc
index ae321b459788..1614652719c6 100644
--- a/src/relay/qnn/op/requantize.cc
+++ b/src/relay/qnn/op/requantize.cc
@@ -214,6 +214,7 @@ Expr RequantizeLowerInt(const Expr& input_tensor, const Expr& input_scale,
   // if the input scale is per-tensor or per-channel. If it is per-tensor, there is a single scale for
   // the whole tensor. For per-channel (aka per-axis), there is a vector of scales for the input
   // tensor. Depending on the quantization type, the fixed point multiplication routine is called.
+  const bool is_upward_rounding = (param->rounding == "UPWARD");
   auto scaled_int32_t = tensor;
   float output_scale_float = GetScalarFromConstant<float>(output_scale);
   if (IsConstScalar(input_scale)) {
@@ -225,8 +226,6 @@ Expr RequantizeLowerInt(const Expr& input_tensor, const Expr& input_scale,
     if (!IsEqualScalar(input_scale, output_scale)) {
       auto [fixed_point_multiplier, shift] = GetFixedPointMultiplierShift(double_multiplier);
 
-      const bool is_upward_rounding = (param->rounding == "UPWARD");
-
       // When using upward rounding (i.e., x.5 rounded to x+1), leverage
       // the FixedPointMultiply operator
       scaled_int32_t =
@@ -246,8 +245,13 @@ Expr RequantizeLowerInt(const Expr& input_tensor, const Expr& input_scale,
     }
     int axis = param->axis;
     axis = (axis == -1) ? input_shape.size() - 1 : axis;
-    scaled_int32_t = FixedPointMultiplyPerChannel(scaled_int32_t, double_multipliers, input_shape,
-                                                  axis, param->rounding);
+
+    // When using "UPWARD" rounding, leverage the FixedPointMultiplyPerAxis operator;
+    // for "TONEAREST" rounding, lower to a sequence of multiply, add and shift operators.
+    scaled_int32_t = is_upward_rounding
+                         ? FixedPointMultiplyPerChannel(scaled_int32_t, double_multipliers, axis)
+                         : FixedPointMultiplyPerChannelToNearest(scaled_int32_t, double_multipliers,
+                                                                 input_shape, axis);
   }
 
   // 3) Add the output zero point.
diff --git a/src/relay/qnn/utils.cc b/src/relay/qnn/utils.cc
index ed7a415cf6af..ab72bd957080 100644
--- a/src/relay/qnn/utils.cc
+++ b/src/relay/qnn/utils.cc
@@ -108,6 +108,32 @@ Expr FixedPointMultiplyToNearest(Expr tensor, double multiplier,
   return Cast(tensor, DataType::Int(32));
 }
 
+Expr FixedPointMultiplyPerChannel(Expr tensor, const std::vector<double>& multipliers, int axis) {
+  DataType dtype = DataType::Int(32);
+  int64_t n_channels = static_cast<int64_t>(multipliers.size());
+
+  std::vector<int32_t> fixed_pt_multipliers, lshifts, rshifts;
+  bool is_lshift_required = false, is_rshift_required = false;
+  for (auto multiplier : multipliers) {
+    auto [fixed_pt_multiplier, shift] = GetFixedPointMultiplierShift(multiplier);
+    int lshift = shift > 0 ? shift : 0;
+    int rshift = shift > 0 ? 0 : -shift;
+    fixed_pt_multipliers.push_back(fixed_pt_multiplier);
+    lshifts.push_back(lshift);
+    rshifts.push_back(rshift);
+    is_lshift_required = is_lshift_required | (lshift != 0);
+    is_rshift_required = is_rshift_required | (rshift != 0);
+  }
+
+  auto left_shift_expr = MakeConstantTensor(dtype, {n_channels}, lshifts);
+  auto right_shift_expr = MakeConstantTensor(dtype, {n_channels}, rshifts);
+  auto fixed_pt_multiplier_expr = MakeConstantTensor(dtype, {n_channels}, fixed_pt_multipliers);
+
+  return FixedPointMultiplyPerAxis(tensor, fixed_pt_multiplier_expr, left_shift_expr,
+                                   right_shift_expr, is_lshift_required, is_rshift_required,
+                                   {axis});
+}
+
 Expr FixedPointMultiplyPerChannel(Expr tensor, std::vector<double> multipliers,
                                   const Array<IndexExpr>& input_shape, int channel_axis,
                                   const std::string& rounding) {
@@ -197,6 +223,11 @@ Expr FixedPointMultiplyPerChannel(Expr tensor, std::vector<double> multipliers,
   return Cast(tensor, DataType::Int(32));
 }
 
+Expr FixedPointMultiplyPerChannelToNearest(Expr tensor, std::vector<double> multipliers,
+                                           const Array<IndexExpr>& input_shape, int channel_axis) {
+  return FixedPointMultiplyPerChannel(tensor, multipliers, input_shape, channel_axis, "TONEAREST");
+}
+
 std::string SelectRequntizeParameter(const std::string& arg_value, const std::string& cfg_value,
                                      const bool is_cfg_default, const std::string& name) {
   if (arg_value == "None") {
diff --git a/src/relay/qnn/utils.h b/src/relay/qnn/utils.h
index d084e4871e95..87195eb34d94 100644
--- a/src/relay/qnn/utils.h
+++ b/src/relay/qnn/utils.h
@@ -212,6 +212,23 @@ Expr FixedPointMultiplyToNearest(Expr tensor, double multiplier,
 Expr FixedPointMultiplyPerChannel(Expr tensor, std::vector<double> multiplier,
                                   const Array<IndexExpr>& input_shape, int channel_axis,
                                   const std::string& rounding);
+
+/*
+ * \brief Wrapper for 'FixedPointMultiplyPerChannel' with rounding parameter == "TONEAREST".
+ */
+Expr FixedPointMultiplyPerChannelToNearest(Expr tensor, std::vector<double> multiplier,
+                                           const Array<IndexExpr>& input_shape, int channel_axis);
+
+/*
+ * \brief Creates FixedPointMultiply operation where the input tensor is
+ *        per-axis/per-channel quantized.
+ * \param tensor The quantized input tensor.
+ * \param multipliers List of scalar multipliers.
+ * \param axis The axis along which the input tensor is quantized.
+ * \return The Relay op.
+ */
+Expr FixedPointMultiplyPerChannel(Expr tensor, const std::vector<double>& multipliers, int axis);
+
 /*
  * \brief Checks whether an expr type is scalar of a given data type.
  * \param expr_type The type of expr to be checked.
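Note (illustrative, not part of the patch): the per-channel decomposition built by the new `FixedPointMultiplyPerChannel` overload above can be sketched in Python roughly as follows. The helper name `decompose_multiplier` is hypothetical; the real code uses `GetFixedPointMultiplierShift` and packs the results into constant tensors.

```python
import math


def decompose_multiplier(multiplier: float):
    """Approximate a float scale as fixed_pt * 2^-31, shifted left/right by lshift/rshift."""
    significand, exponent = math.frexp(multiplier)  # multiplier == significand * 2^exponent
    fixed_pt = int(round(significand * (1 << 31)))  # Q31 representation of the significand
    if fixed_pt == (1 << 31):  # rounding pushed the significand up to 1.0
        fixed_pt //= 2
        exponent += 1
    lshift = exponent if exponent > 0 else 0
    rshift = 0 if exponent > 0 else -exponent
    return fixed_pt, lshift, rshift


# e.g. 0.15 -> (1288490189, 0, 2): one Q31 multiply followed by a right shift of 2.
print(decompose_multiplier(0.15))
```

The `is_lshift_required` / `is_rshift_required` flags passed to `FixedPointMultiplyPerAxis` are simply the OR over all channels of `lshift != 0` / `rshift != 0`.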
diff --git a/src/relay/transforms/pattern_utils.h b/src/relay/transforms/pattern_utils.h
index ffe1cc2ca2ab..d03939e09ea8 100644
--- a/src/relay/transforms/pattern_utils.h
+++ b/src/relay/transforms/pattern_utils.h
@@ -661,6 +661,13 @@ inline Expr FixedPointMultiply(Expr x, int32_t multiplier, int32_t shift) {
   return Call(op, {x}, Attrs(attrs), {});
 }
 
+inline Expr FixedPointMultiplyPerAxis(Expr x, Expr m, Expr lshift, Expr rshift,
+                                      bool is_lshift_required, bool is_rshift_required,
+                                      Array<Integer> axes) {
+  return MakeFixedPointMultiplyPerAxis(x, m, lshift, rshift, is_lshift_required, is_rshift_required,
+                                       axes);
+}
+
 inline Expr Add(Expr lhs, Expr rhs) {
   static const Op& op = Op::Get("add");
   return Call(op, {lhs, rhs}, Attrs(), {});
diff --git a/src/runtime/crt/host/Makefile b/src/runtime/crt/host/Makefile
index d9e87c7d6a41..ea2966045bb2 100644
--- a/src/runtime/crt/host/Makefile
+++ b/src/runtime/crt/host/Makefile
@@ -21,7 +21,7 @@ CXXFLAGS ?= -Werror -Wall -std=c++11 -DTVM_HOST_USE_GRAPH_EXECUTOR_MODULE
 LDFLAGS ?= -Werror -Wall
 
 # Codegen produces spurious lines like: int32_t arg2_code = ((int32_t*)arg_type_ids)[(2)];
-MODEL_CFLAGS ?= -Wno-error=unused-variable -Wno-error=missing-braces
+MODEL_CFLAGS ?= -Wno-error=unused-variable -Wno-error=missing-braces -Wno-error=unused-const-variable
 
 AR ?= ${PREFIX}ar
 CC ?= ${PREFIX}gcc
diff --git a/src/target/intrin_rule.cc b/src/target/intrin_rule.cc
index e697d9b60273..86c50f2609be 100644
--- a/src/target/intrin_rule.cc
+++ b/src/target/intrin_rule.cc
@@ -151,6 +151,46 @@ TVM_REGISTER_OP("tir.isinf")
       return isinf(call->args[0]);
     });
 
+/*!
+ * \brief Performs a fixed point multiplication.
+ * \param x Input operand.
+ * \param y Integer multiplier.
+ * \param q Number of fractional bits in x and y.
+ * \param left_shift Integer left shift.
+ * \param right_shift Integer right shift.
+ * \param is_left_shift_required Flag whether we need to do left shift or not.
+ * \return Calculated expression.
+ */
+static PrimExpr QMultiplyShift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr left_shift,
+                               PrimExpr right_shift, PrimExpr is_left_shift_required) {
+  // Only int32 types are supported (any number of lanes is allowed)
+  ICHECK(y.dtype().code() == DLDataTypeCode::kDLInt && y.dtype().bits() == 32);
+  ICHECK(left_shift.dtype().code() == DLDataTypeCode::kDLInt && left_shift.dtype().bits() == 32);
+  ICHECK(right_shift.dtype().code() == DLDataTypeCode::kDLInt && right_shift.dtype().bits() == 32);
+
+  DataType hp_dtype = DataType::Int(64, x.dtype().lanes());
+  DataType lp_dtype = DataType::Int(32, x.dtype().lanes());
+
+  // 1) Cast to the higher precision type and apply the left shift if it is required.
+  PrimExpr one = make_const(hp_dtype, 1);
+  x = cast(hp_dtype, x);
+  y = cast(hp_dtype, y);
+  x = tir::Select(is_left_shift_required, x << left_shift, x);
+
+  // 2) Perform the multiplication in higher precision.
+  x = x * y;
+
+  // 3) Find the rounding scalar
+  PrimExpr total_right_shift = right_shift + q;
+  PrimExpr pos_rounding_value = (one << (total_right_shift - 1));
+  x = x + pos_rounding_value;
+
+  // 4) Simply right shift the result to get the final output.
+  x = x >> total_right_shift;
+
+  // 5) The fixed point multiplication keeps the value in int32 range. Casting back to int32.
+  return cast(lp_dtype, x);
+}
+
 TVM_REGISTER_OP("tir.q_multiply_shift")
     .set_attr<FLegalize>("default.FLegalize", [](const PrimExpr& e) -> PrimExpr {
       using tir::make_const;
@@ -194,40 +234,34 @@ TVM_REGISTER_OP("tir.q_multiply_shift")
        }
      } else {
        // Only int32 types are supported (any number of lanes is allowed)
-       ICHECK(y.dtype().code() == DLDataTypeCode::kDLInt && y.dtype().bits() == 32);
        ICHECK(s.dtype().code() == DLDataTypeCode::kDLInt && s.dtype().bits() == 32);
-       DataType hp_dtype = DataType::Int(64, x.dtype().lanes());
-       DataType lp_dtype = DataType::Int(32, x.dtype().lanes());
-
-       // 1) Calculating the integer multiplier and integer shift
+       // Calculating integer shifts
        PrimExpr zero = make_const(s.dtype(), 0);
        PrimExpr left_shift = tir::Select(s > zero, s, zero);
        PrimExpr right_shift = tir::Select(s > zero, zero, -s);
+       PrimExpr is_left_shift_required = (left_shift != zero);
 
-       // 2) Cast and Multiply the integer multiplier
-       PrimExpr one = make_const(hp_dtype, 1);
-       x = cast(hp_dtype, x);
-       y = cast(hp_dtype, y);
-       x = tir::Select(left_shift != zero, x << left_shift, x);
-
-       // 3) Perform the multiplication in higher precision.
-       x = x * y;
-
-       // 4) Find the rounding scalar
-       PrimExpr total_right_shift = right_shift + q;
-       PrimExpr pos_rounding_value = (one << (total_right_shift - 1));
-       x = x + pos_rounding_value;
-
-       // 5) Simply right shift the result to get the final output.
-       x = x >> total_right_shift;
-
-       // 6) The fixed point multiplication keeps the value in int32 range. Casting back to
-       // int32.
-       return cast(lp_dtype, x);
+       return QMultiplyShift(x, y, q, left_shift, right_shift, is_left_shift_required);
      }
    });
 
+TVM_REGISTER_OP("tir.q_multiply_shift_per_axis")
+    .set_attr<FLegalize>("default.FLegalize", [](const PrimExpr& e) -> PrimExpr {
+      const tir::CallNode* call = e.as<tir::CallNode>();
+      ICHECK(call != nullptr);
+
+      PrimExpr x = call->args[0];
+      PrimExpr y = call->args[1];
+      PrimExpr left_shift = call->args[2];
+      PrimExpr right_shift = call->args[3];
+      PrimExpr q = call->args[4];
+      PrimExpr is_lshift_required = call->args[5];
+      // Note, 7th argument is "is_rshift_required" flag, but we don't need that here.
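+      // (The reference lowering applies the right shift unconditionally; the right shift amount
+      // is expected to be zero whenever no right shift is required, so the flag is redundant.)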
+      // PrimExpr is_rshift_required = call->args[6];
+
+      return QMultiplyShift(x, y, q, left_shift, right_shift, is_lshift_required);
+    });
+
 }  // namespace legalize
 }  // namespace codegen
 }  // namespace tvm
diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc
index 1e2d790c76e1..929626bd7d53 100644
--- a/src/tir/op/builtin.cc
+++ b/src/tir/op/builtin.cc
@@ -98,6 +98,11 @@ TIR_DEFINE_BUILTIN_FUNC(q_multiply_shift)
     .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure))
     .set_attr<bool>("TVectorizable", true);
 
+TIR_DEFINE_BUILTIN_FUNC(q_multiply_shift_per_axis)
+    .set_num_inputs(7)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure))
+    .set_attr<bool>("TVectorizable", true);
+
 TIR_DEFINE_BUILTIN_FUNC(isnullptr).set_num_inputs(1).set_attr<TCallEffectKind>(
     "TCallEffectKind", Integer(CallEffectKind::kPure));
diff --git a/tests/python/contrib/test_hexagon/test_fixed_point_multiply.py b/tests/python/contrib/test_hexagon/test_fixed_point_multiply.py
index ee03599ff1f4..e7e4aa212e35 100644
--- a/tests/python/contrib/test_hexagon/test_fixed_point_multiply.py
+++ b/tests/python/contrib/test_hexagon/test_fixed_point_multiply.py
@@ -78,11 +78,25 @@ def run_module(graph_mod, inputs):
     return output
 
 
+in_scale_const, out_scale_const = tvm.testing.parameters(
+    (1.3, 30.0),
+    (1.37, 1.0),
+    (0.6, 1.0),
+    ((1.7, 0.6), 1.0),
+    ((0.007, 1.9), 1.0),
+)
+
+multiplier, shift = tvm.testing.parameters(
+    (1288490240, -2),  # 0.15
+    (1395864320, 1),  # 1.3
+    (1288490188, 0),  # 0.6
+)
+
+
 @tvm.testing.requires_hexagon
-def test_fixed_point_multiply_positive_shift(hexagon_session: Session):
+def test_fixed_point_multiply(hexagon_session: Session, multiplier: int, shift: int):
     ishape = (6, 32)
     a = relay.var("a", relay.TensorType(ishape, "int32"))
-    multiplier, shift = (1395864320, 1)  # 1.3
     fpm = relay.fixed_point_multiply(a, multiplier, shift)
     relay_mod = tvm.IRModule.from_expr(fpm)
 
@@ -108,22 +122,37 @@ def test_fixed_point_multiply_positive_shift(hexagon_session: Session):
 
 
 @tvm.testing.requires_hexagon
-def test_fixed_point_multiply_negative_shift(hexagon_session: Session):
-    ishape = (6, 32)
-    a = relay.var("a", relay.TensorType(ishape, "int32"))
-    multiplier, shift = (1288490240, -2)  # 0.15
-    fpm = relay.fixed_point_multiply(a, multiplier, shift)
-    relay_mod = tvm.IRModule.from_expr(fpm)
+def test_per_channel_fixed_point_multiply(
+    hexagon_session: Session, in_scale_const, out_scale_const
+):
+    ishape = [1, 128, 56, 56]
+    axis = 1
+    a = relay.var("a", shape=ishape, dtype="int32")
+
+    # Make list of input scales from in_scale_const parameter.
+    if isinstance(in_scale_const, tuple):
+        in_scale = list(in_scale_const) * (ishape[axis] // len(in_scale_const))
+    else:
+        in_scale = [in_scale_const] * ishape[axis]
+    assert len(in_scale) == ishape[axis]
+
+    # qnn.requantize is lowered to fixed_point_multiply if zp == 0 and in_dtype == out_dtype.
+    iscale = relay.const(in_scale)
+    izero = relay.const(0)
+    oscale = relay.const(out_scale_const)
+    ozero = relay.const(0)
+    op = relay.qnn.op.requantize(a, iscale, izero, oscale, ozero, axis=axis, out_dtype="int32")
+    mod = tvm.IRModule.from_expr(op)
 
     with tvm.transform.PassContext(opt_level=3):
         # Compile for Hexagon...
-        hexagon_lowered = build_module(relay_mod, tvm.target.hexagon("v68"))
+        hexagon_lowered = build_module(mod, tvm.target.hexagon("v68"))
 
         # Compile for LLVM...
-        llvm_lowered = build_module(relay_mod, tvm.target.Target("llvm"))
+        llvm_lowered = build_module(mod, tvm.target.Target("llvm"))
 
-    data_in = np.arange(-96, 96).reshape(ishape)
-    inputs = {"a": data_in}
+    a_np = np.random.randint(-1000, 1000, size=np.prod(ishape)).reshape(ishape)
+    inputs = {"a": a_np}
 
     # Run hexagon...
     graph_mod = hexagon_session.get_executor_from_factory(hexagon_lowered)
@@ -133,7 +162,7 @@ def test_fixed_point_multiply_negative_shift(hexagon_session: Session):
     llvm_graph_mod = tvm.contrib.graph_executor.GraphModule(llvm_lowered["default"](tvm.cpu(0)))
     expected_output = run_module(llvm_graph_mod, inputs)
 
-    tvm.testing.assert_allclose(hexagon_output, expected_output, atol=1)
+    tvm.testing.assert_allclose(hexagon_output, expected_output)
 
 
 if __name__ == "__main__":
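Usage note (not part of the patch): a minimal sketch of how the new per-axis path is exercised end to end, assuming this patch is applied. Per-channel `qnn.requantize` with `rounding="UPWARD"`, zero zero-points and int32 in/out is what now lowers to `fixed_point_multiply_per_axis`; the shapes, scales and the LLVM target below are arbitrary example values.

```python
import numpy as np
import tvm
from tvm import relay
from tvm.contrib import graph_executor

ishape, axis = [1, 4, 2, 2], 1
a = relay.var("a", shape=ishape, dtype="int32")
in_scales = relay.const([0.5, 1.5, 0.25, 2.0])  # one scale per channel along axis 1
out_scale = relay.const(1.0)
zero = relay.const(0)

# zp == 0 and in_dtype == out_dtype, so requantize lowers to a fixed point multiply.
body = relay.qnn.op.requantize(
    a, in_scales, zero, out_scale, zero, axis=axis, rounding="UPWARD", out_dtype="int32"
)
mod = tvm.IRModule.from_expr(body)

with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="llvm")

dev = tvm.cpu(0)
rt = graph_executor.GraphModule(lib["default"](dev))
rt.set_input("a", np.arange(-8, 8, dtype="int32").reshape(ishape))
rt.run()
print(rt.get_output(0).numpy())
```

On Hexagon targets the same graph additionally picks up the HVX lowering registered in `tensor_intrin.py`; on other targets the generic `default.FLegalize` lowering from `intrin_rule.cc` is used.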