From 2d2a5f59e9379221d88b6d7c74cc24b437fd5da5 Mon Sep 17 00:00:00 2001 From: Giuseppe Rossini Date: Tue, 30 Jun 2020 19:11:48 +0100 Subject: [PATCH 01/11] Fixed point multiplication improvements for AArch64 Change-Id: Ib3c10348d4c0eac11fa92b39cc6e792560e9eba4 --- include/tvm/relay/attrs/transform.h | 11 ++++++ include/tvm/tir/builtin.h | 11 ++++++ include/tvm/tir/op.h | 11 ++++++ python/tvm/relay/op/_tensor.py | 8 ++++ python/tvm/tir/__init__.py | 1 + python/tvm/tir/op.py | 21 +++++++++++ src/relay/op/tensor/unary.cc | 14 +++++++ src/relay/qnn/op/requantize.cc | 14 ++++++- src/relay/qnn/util.cc | 19 ---------- src/relay/qnn/util.h | 21 +++++++++++ src/relay/transforms/pattern_util.h | 8 ++++ src/target/intrin_rule.cc | 46 +++++++++++++++++++++++ src/tir/op/builtin.cc | 5 +++ src/tir/op/op.cc | 5 +++ src/tir/transforms/lower_intrin.cc | 13 ++++++- topi/python/topi/arm_cpu/conv2d_gemm.py | 13 ++++--- topi/python/topi/arm_cpu/conv2d_int8.py | 14 ++++++- topi/python/topi/arm_cpu/injective.py | 8 +++- topi/python/topi/arm_cpu/tensor_intrin.py | 43 +++++++++++++++++++++ topi/python/topi/math.py | 24 ++++++++++++ 20 files changed, 280 insertions(+), 30 deletions(-) diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index b0c8108b1624..70f844252268 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -298,6 +298,17 @@ struct ClipAttrs : public tvm::AttrsNode { } }; +/*! \brief Attributes for FixedPointMultiply operator */ +struct FixedPointMultiplyAttrs : public tvm::AttrsNode { + int32_t multiplier; + int32_t shift; + + TVM_DECLARE_ATTRS(FixedPointMultiplyAttrs, "relay.attrs.FixedPointMultiplyAttrs") { + TVM_ATTR_FIELD(multiplier).describe("Integer multiplier."); + TVM_ATTR_FIELD(shift).describe("Shift."); + } +}; + /*! \brief Attributes for LayoutTransform operator */ struct LayoutTransformAttrs : public tvm::AttrsNode { std::string src_layout; diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h index 464ce6c143c5..7f5bd2d8c368 100644 --- a/include/tvm/tir/builtin.h +++ b/include/tvm/tir/builtin.h @@ -92,6 +92,17 @@ TVM_DLL const Op& shift_right(); */ TVM_DLL const Op& large_uint_imm(); +/*! + * \brief Execute a fixed point multiplication y = round(x * m * 2^s). + * The default rounding rule is to the nearest value, rounding half up + * (i.e., round(x.1) = x and round (x.5) = x+1) + * \param x input value + * \param m integer multiplier + * \param s integer shift + * \return The constructed expression. + */ +TVM_DLL const Op& fixed_point_multiply(); + /*! * \brief See pesudo code * diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h index 31ce13c7e66a..515060595f0e 100644 --- a/include/tvm/tir/op.h +++ b/include/tvm/tir/op.h @@ -552,6 +552,17 @@ TVM_DLL PrimExpr trunc(PrimExpr x); */ TVM_DLL PrimExpr LargeUIntImm(DataType dtype, int64_t low, int64_t high); +/*! + * \brief Execute a fixed point multiplication y = round(x * m * 2^s). + * The default rounding rule is to the nearest value, rounding half up + * (i.e., round(x.1) = x and round (x.5) = x+1) + * \param x input value + * \param m integer multiplier + * \param s integer shift + * \return The constructed expression. + */ +TVM_DLL PrimExpr fixed_point_multiply(PrimExpr x, PrimExpr m, PrimExpr s); + // Intrinsic operators #define TVM_DECLARE_INTRIN_UNARY(OpName) \ inline PrimExpr OpName(PrimExpr x) { \ diff --git a/python/tvm/relay/op/_tensor.py b/python/tvm/relay/op/_tensor.py index d4911d95e90d..feeec1fa89ec 100644 --- a/python/tvm/relay/op/_tensor.py +++ b/python/tvm/relay/op/_tensor.py @@ -131,6 +131,14 @@ def clip_compute(attrs, inputs, output_type): register_injective_schedule("clip") +# fixed point multiply +@register_compute("fixed_point_multiply") +def fixed_point_multiply_compute(attrs, inputs, output_type): + assert len(inputs) == 1 + return [topi.fixed_point_multiply(inputs[0], attrs.multiplier, attrs.shift)] + +register_injective_schedule("fixed_point_multiply") + # full @script def _full_shape_func(shape): diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index 9dbdc07b4a46..134c23093a68 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -45,6 +45,7 @@ from .op import isnan, isfinite, isinf, copysign from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod from .op import comm_reducer, min, max, sum +from .op import fixed_point_multiply from . import ir_builder from . import transform diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index cbbd59fe4eaf..4b2fc70b7f44 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -965,6 +965,27 @@ def popcount(x): """ return call_intrin(x.dtype, "tir.popcount", x) +def fixed_point_multiply(x, m, s): + """Execute a fixed point multiplication y = round(x * m * 2^s). + The default rounding rule is to the nearest value, rounding half up + (i.e., round(x.1) = x and round (x.5) = x+1) + + Parameters + ---------- + x : PrimExpr + Input argument. + m : PrimExpr + Integer multiplier + s : PrimExpr + Integer shift + + Returns + ------- + y : PrimExpr + The result. + """ + return call_intrin(x.dtype, "tir.fixed_point_multiply", x, m, s) + def fmod(x, y): """Return the remainder of x divided by y with the same sign as x. diff --git a/src/relay/op/tensor/unary.cc b/src/relay/op/tensor/unary.cc index 958b8b535873..2f4c150bf4be 100644 --- a/src/relay/op/tensor/unary.cc +++ b/src/relay/op/tensor/unary.cc @@ -277,6 +277,20 @@ Expr MakeClip(Expr a, double a_min, double a_max) { TVM_REGISTER_GLOBAL("relay.op._make.clip").set_body_typed(MakeClip); +// relay.fixed_point_multiply +TVM_REGISTER_NODE_TYPE(FixedPointMultiplyAttrs); + +RELAY_REGISTER_OP("fixed_point_multiply") + .describe(R"code( fixed point multiplication )code" TVM_ADD_FILELINE) + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .add_type_rel("Identity", IdentityRel) + .set_attr("TOpPattern", kElemWise) + .set_attr("TOpIsStateful", false) + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) + .set_attrs_type() + .set_support_level(3); + RELAY_REGISTER_OP("clip") .describe(R"code(Clip tensor values. This function takes a tensor, a minimum value `a_min`, and a maximum value `a_max`, and returns a clipped tensor where all values below `a_min` are set to `a_min` and all values above `a_max` are set to `a_max`. `a_min` and `a_max` are cast to the tensor's dtype. diff --git a/src/relay/qnn/op/requantize.cc b/src/relay/qnn/op/requantize.cc index bdeaf05c86bd..28013055f31b 100644 --- a/src/relay/qnn/op/requantize.cc +++ b/src/relay/qnn/op/requantize.cc @@ -153,9 +153,19 @@ Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale, static_cast(input_scale_float) / static_cast(output_scale_float); // Skip if input and output scales are same. if (!IsEqualScalar(input_scale, output_scale)) { - scaled_int32_t = - FixedPointMultiply(scaled_int32_t, double_multiplier, input_shape, param->rounding); + int32_t fixed_point_multiplier, shift; + std::tie(fixed_point_multiplier, shift) = GetFixedPointMultiplierShift(double_multiplier); + + const bool is_upward_rounding = (param->rounding == "UPWARD"); + + // When using upward rounding (i.e., x.5 rounded to x+1), leverage + // the fixed_point_muliply intrinsic + scaled_int32_t = (is_upward_rounding ? relay::FixedPointMultiply( + scaled_int32_t, fixed_point_multiplier, shift) + : FixedPointMultiply(scaled_int32_t, double_multiplier, + input_shape, param->rounding)); } + } else { // This is per-channel (per=axis) quantization. std::vector double_multipliers; diff --git a/src/relay/qnn/util.cc b/src/relay/qnn/util.cc index 4daa5c9334de..2bf2c3fa2445 100644 --- a/src/relay/qnn/util.cc +++ b/src/relay/qnn/util.cc @@ -30,25 +30,6 @@ namespace tvm { namespace relay { namespace qnn { -/* - * \brief Convert FP32 representation into fixed point representation. - * \param double_multplier The input FP32 number. - * \return The pair of multiplier and shift for fixed point representation. - * \note Converts a floating point number so that it can be represented by - * integers. The representation is - * float_number = (significand) * 2^(exponent) - * - * The significand is a number between 0.5 and 1. This is represented by - * an integer number. For example, if it is int32, then the decimal point - * exists between bit 31 and 30 from LSB (or between first and second bit - * from the left). - * - * Some examples are - * 0.25 = (0.5) * 2^(-1) - * 0.125 = (0.5) * 2^(-2) - * - * Credit to TFLite reference implementation. - */ std::pair GetFixedPointMultiplierShift(double double_multiplier) { int32_t significand, exponent; if (double_multiplier == 0.) { diff --git a/src/relay/qnn/util.h b/src/relay/qnn/util.h index 736b7361a300..cacecc82f05b 100644 --- a/src/relay/qnn/util.h +++ b/src/relay/qnn/util.h @@ -70,6 +70,27 @@ static inline int32_t GetQmax(const DataType& dtype) { } } +/* + * \brief Convert FP32 representation into fixed point representation. + * \param double_multplier The input FP32 number. + * \return The pair of multiplier and shift for fixed point representation. + * \note Converts a floating point number so that it can be represented by + * integers. The representation is + * float_number = (significand) * 2^(exponent) + * + * The significand is a number between 0.5 and 1. This is represented by + * an integer number. For example, if it is int32, then the decimal point + * exists between bit 31 and 30 from LSB (or between first and second bit + * from the left). + * + * Some examples are + * 0.25 = (0.5) * 2^(-1) + * 0.125 = (0.5) * 2^(-2) + * + * Credit to TFLite reference implementation. + */ +std::pair GetFixedPointMultiplierShift(double double_multiplier); + Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale, const Expr& input_zero_point, const Expr& output_scale, const Expr& output_zero_point, const RequantizeAttrs* param, diff --git a/src/relay/transforms/pattern_util.h b/src/relay/transforms/pattern_util.h index adbd1bd44431..b3e36818870a 100644 --- a/src/relay/transforms/pattern_util.h +++ b/src/relay/transforms/pattern_util.h @@ -495,6 +495,14 @@ inline Expr Round(Expr x) { inline Expr Clip(Expr x, double a_min, double a_max) { return MakeClip(x, a_min, a_max); } +inline Expr FixedPointMultiply(Expr x, int32_t multiplier, int32_t shift) { + static const Op& op = Op::Get("fixed_point_multiply"); + auto attrs = make_object(); + attrs->multiplier = multiplier; + attrs->shift = shift; + return Call(op, {x}, Attrs(attrs), {}); +} + inline Expr Add(Expr lhs, Expr rhs) { static const Op& op = Op::Get("add"); return Call(op, {lhs, rhs}, Attrs(), {}); diff --git a/src/target/intrin_rule.cc b/src/target/intrin_rule.cc index 31fadf1ce5ac..fa2eef3f8f2a 100644 --- a/src/target/intrin_rule.cc +++ b/src/target/intrin_rule.cc @@ -115,6 +115,52 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.isinf") *rv = isinf(call->args[0]); }); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.fixed_point_multiply") + .set_body([](const TVMArgs& args, TVMRetValue* rv) { + using tir::make_const; + + PrimExpr e = args[0]; + const tir::CallNode* call = e.as(); + CHECK(call != nullptr); + + PrimExpr tensor = call->args[0]; + PrimExpr fixed_point_multiplier = call->args[1]; + PrimExpr shift = call->args[2]; + + // Only int32 types are supported (any number of lanes is allowed) + CHECK(tensor.dtype().code() == DLDataTypeCode::kDLInt && tensor.dtype().bits() == 32); + CHECK(fixed_point_multiplier.dtype().code() == DLDataTypeCode::kDLInt && + fixed_point_multiplier.dtype().bits() == 32); + CHECK(shift.dtype().code() == DLDataTypeCode::kDLInt && shift.dtype().bits() == 32); + + DataType hp_dtype = DataType::Int(64, tensor.dtype().lanes()); + DataType lp_dtype = DataType::Int(32, tensor.dtype().lanes()); + + // 1) Calculating the integer multiplier and integer shift + PrimExpr zero = make_const(shift.dtype(), 0); + PrimExpr left_shift = tir::Select((shift > zero), shift, zero); + PrimExpr right_shift = tir::Select(shift > zero, zero, -shift); + + // 2) Multiply the integer multiplier + tensor = tir::Select(left_shift != zero, tensor << cast(hp_dtype, left_shift), + cast(hp_dtype, tensor)); + + // 3) Perform the multiplication in higher precision. + tensor = tensor * fixed_point_multiplier; + + // 4) Find the rounding scalar + PrimExpr total_right_shift = right_shift + 31; + PrimExpr pos_rounding_value = (make_const(hp_dtype, 1) << (total_right_shift - 1)); + + tensor = tensor + pos_rounding_value; + + // 5) Simply right shift the result to get the final output. + tensor = tensor >> total_right_shift; + + // 6) The fixed point multiplication keeps the value in int32 range. Casting back to int32. + *rv = cast(lp_dtype, tensor); + }); + } // namespace intrin } // namespace codegen } // namespace tvm diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc index d23662c78d37..68a07852b5a6 100644 --- a/src/tir/op/builtin.cc +++ b/src/tir/op/builtin.cc @@ -89,6 +89,11 @@ TIR_DEFINE_BUILTIN_FUNC(if_then_else) .set_num_inputs(3) .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)); +TIR_DEFINE_BUILTIN_FUNC(fixed_point_multiply) + .set_num_inputs(3) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)) + .set_attr("TVectorizable", true); + TIR_DEFINE_BUILTIN_FUNC(isnullptr).set_num_inputs(1).set_attr( "TCallEffectKind", Integer(CallEffectKind::kPure)); diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index a0ba8d655232..aec41001bbf2 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -90,6 +90,11 @@ PrimExpr LargeUIntImm(DataType t, int64_t low, int64_t high) { {make_const(DataType::UInt(32), low), make_const(DataType::UInt(32), high)}); } +// fixed_point_multiply +PrimExpr fixed_point_multiply(PrimExpr x, PrimExpr m, PrimExpr s) { + return tir::Call(x.dtype(), tir::builtin::fixed_point_multiply(), {x, m, s}); +} + // The public function with a quick checking path. void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs) { // NOLINT(*) if (lhs.dtype() == rhs.dtype()) return; diff --git a/src/tir/transforms/lower_intrin.cc b/src/tir/transforms/lower_intrin.cc index 5372ef8fb1fa..1c529d86523e 100644 --- a/src/tir/transforms/lower_intrin.cc +++ b/src/tir/transforms/lower_intrin.cc @@ -40,8 +40,15 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { using IRMutatorWithAnalyzer::VisitExpr_; using IRMutatorWithAnalyzer::VisitStmt_; - IntrinInjecter(arith::Analyzer* analyzer, std::string target) : IRMutatorWithAnalyzer(analyzer) { + IntrinInjecter(arith::Analyzer* analyzer, std::string target, std::string mtriple = "") + : IRMutatorWithAnalyzer(analyzer) { patterns_.push_back("tvm.intrin.rule." + target + "."); + + bool is_llvm_aarch64 = (mtriple.find("aarch64") != std::string::npos); + if (is_llvm_aarch64) { + patterns_.push_back("tvm.intrin.rule." + target + "." + "aarch64."); + } + patterns_.push_back("tvm.intrin.rule.default."); fma_ = runtime::Registry::Get(patterns_[0] + "fma"); if (target == "stackvm") { @@ -287,7 +294,9 @@ Pass LowerIntrin() { auto target = f->GetAttr(tvm::attr::kTarget); CHECK(target.defined()) << "LowerIntrin: Require the target attribute"; arith::Analyzer analyzer; - n->body = IntrinInjecter(&analyzer, target.value()->id->name)(std::move(n->body)); + auto mtriple = target.value()->GetAttr("mtriple", ""); + n->body = + IntrinInjecter(&analyzer, target.value()->id->name, mtriple.value())(std::move(n->body)); return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.LowerIntrin", {}); diff --git a/topi/python/topi/arm_cpu/conv2d_gemm.py b/topi/python/topi/arm_cpu/conv2d_gemm.py index 63d96bb44d92..fa5aff37fa31 100644 --- a/topi/python/topi/arm_cpu/conv2d_gemm.py +++ b/topi/python/topi/arm_cpu/conv2d_gemm.py @@ -119,7 +119,7 @@ def compute_conv2d_gemm_without_weight_transform(cfg, C = te.compute((batches, M, N), lambda b, x, y: C_interleaved[b, x // 4, y // 4, idxm(x, 4), idxm(y, 4)], - name="C", tag='injective') + name="C") # --- Produce the conv output out_shape = (batches, OH, OW, OC) @@ -129,7 +129,7 @@ def compute_conv2d_gemm_without_weight_transform(cfg, return out # Schedules -def schedule_conv2d_gemm(cfg, s, out): +def schedule_conv2d_gemm(cfg, s, out, final_out): """Create schedule for tensors""" C = out.op.input_tensors[0] C_interleaved = C.op.input_tensors[0] @@ -172,8 +172,11 @@ def schedule_conv2d_gemm(cfg, s, out): s[C_interleaved].tensorize(yi, gem_v_dotprod) # Output transform - N, OH, OW, OC = out.shape - s[C].split(C.op.axis[1], OW) - s[C].compute_at(s[out], out.op.axis[3]) + if out != final_out: + n, h, w, c = out.op.axis + _, inner = s[out].split(c, 4) + s[C].compute_at(s[out],inner) + s[out].vectorize(inner) + return s diff --git a/topi/python/topi/arm_cpu/conv2d_int8.py b/topi/python/topi/arm_cpu/conv2d_int8.py index 5a895c084c06..34924c7da452 100644 --- a/topi/python/topi/arm_cpu/conv2d_int8.py +++ b/topi/python/topi/arm_cpu/conv2d_int8.py @@ -137,11 +137,23 @@ def compute_conv2d_NHWC_quantized_without_transform(cfg, data, B, strides, paddi def schedule_conv2d_NHWC_quantized(cfg, outs): """Create schedule for tensors""" s = te.create_schedule([x.op for x in outs]) + # Vectorize the output and then inline all the rest + out = outs[0] + n,h,w,c = out.op.axis + outer, inner = s[out].split(c, 4) + s[out].vectorize(inner) def _callback(op): """Traverse operators from computation graph""" if op.name == "conv2d_gemm_output": - schedule_conv2d_gemm(cfg, s, op.output(0)) + conv_out = op.output(0) + schedule_conv2d_gemm(cfg, s, conv_out, out) + if out != conv_out: + s[conv_out].compute_at(s[out], inner) + else: + C = conv_out.op.input_tensors[0] + s[C].compute_at(s[out], inner) + traverse_inline(s, outs[0].op, _callback) return s diff --git a/topi/python/topi/arm_cpu/injective.py b/topi/python/topi/arm_cpu/injective.py index 966520088bc7..41ed92502f5e 100644 --- a/topi/python/topi/arm_cpu/injective.py +++ b/topi/python/topi/arm_cpu/injective.py @@ -62,9 +62,15 @@ def schedule_injective(outs): outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs s = te.create_schedule([x.op for x in outs]) x = outs[0] + dtype = x.op.input_tensors[0].dtype + print(dtype) + if dtype == 'int32': + max_vlen = 4 + else: + max_vlen = 8 if list(s[x].op.axis): # do not vectorize for broadcast - (io, ii) = s[x].split(list(s[x].op.axis)[-1], 8) + (io, ii) = s[x].split(list(s[x].op.axis)[-1], max_vlen) s[x].vectorize(ii) tvm.te.schedule.AutoInlineInjective(s) diff --git a/topi/python/topi/arm_cpu/tensor_intrin.py b/topi/python/topi/arm_cpu/tensor_intrin.py index dfa2f05e7960..fa004820d6c3 100644 --- a/topi/python/topi/arm_cpu/tensor_intrin.py +++ b/topi/python/topi/arm_cpu/tensor_intrin.py @@ -451,3 +451,46 @@ def _instr(index): return te.decl_tensor_intrin( C.op, _intrin_func, binds={data:a_buffer, kernel:b_buffer}, default_buffer_params=buffer_params) + +def _fixed_point_multiply_arm(op): + """ + Implementation of fixed point multiplication through arm + intrinsics sqrdmulh and srshl + """ + x = op.args[0] + multiplier = op.args[1] + shift = op.args[2] + + # Don't use this intrinsic if we don't have a int32x4 vector + if x.dtype != "int32x4": + return op + + # Case 1, shift is negative + sqrdmulh = tvm.tir.call_llvm_intrin(op.dtype, + 'llvm.aarch64.neon.sqrdmulh', + tvm.tir.const(2, 'uint32'), + x, + multiplier) + + fixup = (sqrdmulh & (-shift)) >> 31 + fixed_up_x = (sqrdmulh + fixup) + out_1 = tvm.tir.call_llvm_intrin(op.dtype, + 'llvm.aarch64.neon.srshl', + tvm.tir.const(2, 'uint32'), + sqrdmulh, + shift) + + # Case 2, shift is positive + x = x * (1 << (shift)) + out_2 = tvm.tir.call_llvm_intrin(op.dtype, + 'llvm.aarch64.neon.sqrdmulh', + tvm.tir.const(2, 'uint32'), + x, + multiplier) + + # Select depending on the shift + return tvm.tir.Select(shift < 0, out_1, out_2) + +tvm.target.intrin.register_intrin_rule("llvm.aarch64", + "fixed_point_multiply", + _fixed_point_multiply_arm, override=True) diff --git a/topi/python/topi/math.py b/topi/python/topi/math.py index b4228a4a9178..475f8074fb35 100644 --- a/topi/python/topi/math.py +++ b/topi/python/topi/math.py @@ -612,6 +612,30 @@ def _compute(*indices): return tvm.te.max(tvm.te.min(value, const_max), const_min) return te.compute(x.shape, _compute) +@tvm.te.tag_scope(tag=tag.ELEMWISE) +def fixed_point_multiply(x, multiplier, shift): + """ + + Parameters + ---------- + x : tvm.te.Tensor or Expr + Input argument. + multiplier: Integer multiplier + shift: Integer shift + + Returns + ------- + y : tvm.te.Tensor + The result. + """ + def _compute(*indices): + value = x(*indices) + m = tvm.tir.const(multiplier, x.dtype) + s = tvm.tir.const(shift, x.dtype) + return tvm.tir.fixed_point_multiply(value, m, s) + + assert x.dtype == "int32", "input tensor type needs to be int32" + return te.compute(x.shape, _compute) def cast(x, dtype): """Cast input to specified data type. From c4b25c355cb76fc4a942a824804a34af66cfee8c Mon Sep 17 00:00:00 2001 From: Giuseppe Rossini Date: Thu, 2 Jul 2020 10:56:49 +0100 Subject: [PATCH 02/11] Fix python linting errors Change-Id: I4cf5ac18aa24b39374b83805dcc8e1663e173909 --- topi/python/topi/arm_cpu/conv2d_gemm.py | 2 +- topi/python/topi/arm_cpu/conv2d_int8.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/topi/python/topi/arm_cpu/conv2d_gemm.py b/topi/python/topi/arm_cpu/conv2d_gemm.py index fa5aff37fa31..e97de56a0b65 100644 --- a/topi/python/topi/arm_cpu/conv2d_gemm.py +++ b/topi/python/topi/arm_cpu/conv2d_gemm.py @@ -175,7 +175,7 @@ def schedule_conv2d_gemm(cfg, s, out, final_out): if out != final_out: n, h, w, c = out.op.axis _, inner = s[out].split(c, 4) - s[C].compute_at(s[out],inner) + s[C].compute_at(s[out], inner) s[out].vectorize(inner) diff --git a/topi/python/topi/arm_cpu/conv2d_int8.py b/topi/python/topi/arm_cpu/conv2d_int8.py index 34924c7da452..89a37fa41294 100644 --- a/topi/python/topi/arm_cpu/conv2d_int8.py +++ b/topi/python/topi/arm_cpu/conv2d_int8.py @@ -139,7 +139,7 @@ def schedule_conv2d_NHWC_quantized(cfg, outs): s = te.create_schedule([x.op for x in outs]) # Vectorize the output and then inline all the rest out = outs[0] - n,h,w,c = out.op.axis + n, h, w, c = out.op.axis outer, inner = s[out].split(c, 4) s[out].vectorize(inner) From 210841ee78aed7782912a2aff3bed5545e7e45e2 Mon Sep 17 00:00:00 2001 From: Giuseppe Rossini Date: Thu, 2 Jul 2020 11:05:53 +0100 Subject: [PATCH 03/11] Fix doxygen errors Change-Id: Ie3c861f8ead3f1ea5b30d5e9d7d94e222299d407 --- include/tvm/tir/builtin.h | 4 ---- 1 file changed, 4 deletions(-) diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h index 7f5bd2d8c368..e2d37e1d529e 100644 --- a/include/tvm/tir/builtin.h +++ b/include/tvm/tir/builtin.h @@ -96,10 +96,6 @@ TVM_DLL const Op& large_uint_imm(); * \brief Execute a fixed point multiplication y = round(x * m * 2^s). * The default rounding rule is to the nearest value, rounding half up * (i.e., round(x.1) = x and round (x.5) = x+1) - * \param x input value - * \param m integer multiplier - * \param s integer shift - * \return The constructed expression. */ TVM_DLL const Op& fixed_point_multiply(); From 1e958333e3468acadcd422f93355b965d6594eb6 Mon Sep 17 00:00:00 2001 From: Giuseppe Rossini Date: Thu, 2 Jul 2020 15:58:46 +0100 Subject: [PATCH 04/11] Fix arm_cpu injective tests Change-Id: I6ad9da61b61e6bd737627f26fba59767418c07cd --- topi/python/topi/arm_cpu/injective.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/topi/python/topi/arm_cpu/injective.py b/topi/python/topi/arm_cpu/injective.py index 41ed92502f5e..50a21f3b1e94 100644 --- a/topi/python/topi/arm_cpu/injective.py +++ b/topi/python/topi/arm_cpu/injective.py @@ -61,13 +61,11 @@ def schedule_injective(outs): """ outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs s = te.create_schedule([x.op for x in outs]) - x = outs[0] - dtype = x.op.input_tensors[0].dtype - print(dtype) - if dtype == 'int32': - max_vlen = 4 - else: - max_vlen = 8 + out = outs[0] + ins = out.op.input_tensors + dtype = ins[0].dtype if len(ins) else out.dtype + max_vlen = 4 if dtype == 'int32' else 8 + if list(s[x].op.axis): # do not vectorize for broadcast (io, ii) = s[x].split(list(s[x].op.axis)[-1], max_vlen) From a063a0abc94aa71232bcfca0f609eef8c92a31ae Mon Sep 17 00:00:00 2001 From: Giuseppe Rossini Date: Thu, 2 Jul 2020 16:02:42 +0100 Subject: [PATCH 05/11] Fix python linting errors - 2 Change-Id: Ic864a235aa5da5786393cbf6146dd815c121df5e --- topi/python/topi/arm_cpu/injective.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/topi/python/topi/arm_cpu/injective.py b/topi/python/topi/arm_cpu/injective.py index 50a21f3b1e94..45b9336a9ed7 100644 --- a/topi/python/topi/arm_cpu/injective.py +++ b/topi/python/topi/arm_cpu/injective.py @@ -63,7 +63,7 @@ def schedule_injective(outs): s = te.create_schedule([x.op for x in outs]) out = outs[0] ins = out.op.input_tensors - dtype = ins[0].dtype if len(ins) else out.dtype + dtype = ins[0].dtype if len(ins) > 0 else out.dtype max_vlen = 4 if dtype == 'int32' else 8 if list(s[x].op.axis): From a26b432af192b3e268ac0a72d235a476de8a6bae Mon Sep 17 00:00:00 2001 From: Giuseppe Rossini Date: Thu, 2 Jul 2020 16:32:16 +0100 Subject: [PATCH 06/11] Fix arm_cpu injective tests - 2 Change-Id: If9ca1cc3d947b1656c836c7f88de90470d92f979 --- topi/python/topi/arm_cpu/injective.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/topi/python/topi/arm_cpu/injective.py b/topi/python/topi/arm_cpu/injective.py index 45b9336a9ed7..54306deadd85 100644 --- a/topi/python/topi/arm_cpu/injective.py +++ b/topi/python/topi/arm_cpu/injective.py @@ -61,9 +61,9 @@ def schedule_injective(outs): """ outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs s = te.create_schedule([x.op for x in outs]) - out = outs[0] - ins = out.op.input_tensors - dtype = ins[0].dtype if len(ins) > 0 else out.dtype + x = outs[0] + ins = x.op.input_tensors + dtype = ins[0].dtype if len(ins) > 0 else x.dtype max_vlen = 4 if dtype == 'int32' else 8 if list(s[x].op.axis): From 070234d5d258a9016319658952580b1e0b29a735 Mon Sep 17 00:00:00 2001 From: Giuseppe Rossini Date: Wed, 8 Jul 2020 17:53:30 +0100 Subject: [PATCH 07/11] Redesign: introduce a qmuls (q-multiply and shift) general intrinsic Change-Id: I1966fef9aee32eab50e4b984bbe81018488c8c02 --- include/tvm/tir/builtin.h | 5 +-- include/tvm/tir/op.h | 19 +++++++---- python/tvm/tir/__init__.py | 2 +- python/tvm/tir/op.py | 21 ++++++++---- src/target/intrin_rule.cc | 39 +++++++++++------------ src/tir/op/builtin.cc | 2 +- src/tir/op/op.cc | 6 ++-- topi/python/topi/arm_cpu/tensor_intrin.py | 37 +++++++++++++-------- topi/python/topi/math.py | 3 +- 9 files changed, 79 insertions(+), 55 deletions(-) diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h index e2d37e1d529e..4feabdb5bbde 100644 --- a/include/tvm/tir/builtin.h +++ b/include/tvm/tir/builtin.h @@ -93,11 +93,12 @@ TVM_DLL const Op& shift_right(); TVM_DLL const Op& large_uint_imm(); /*! - * \brief Execute a fixed point multiplication y = round(x * m * 2^s). + * \brief Execute a multiplication between two Q-numbers x and y + * followed by a right shift s * The default rounding rule is to the nearest value, rounding half up * (i.e., round(x.1) = x and round (x.5) = x+1) */ -TVM_DLL const Op& fixed_point_multiply(); +TVM_DLL const Op& qmuls(); /*! * \brief See pesudo code diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h index 515060595f0e..f40a209b5461 100644 --- a/include/tvm/tir/op.h +++ b/include/tvm/tir/op.h @@ -553,15 +553,22 @@ TVM_DLL PrimExpr trunc(PrimExpr x); TVM_DLL PrimExpr LargeUIntImm(DataType dtype, int64_t low, int64_t high); /*! - * \brief Execute a fixed point multiplication y = round(x * m * 2^s). - * The default rounding rule is to the nearest value, rounding half up + * \brief Execute a multiplication between two Q-numbers x and y + * followed by a right shift s. The mathematical expression is: + * + * out = round(x*y*2^-s) + * + * More about Q-numbers here: https://en.wikipedia.org/wiki/Q_(number_format) + * + * The rounding rule is to the nearest value, rounding half up * (i.e., round(x.1) = x and round (x.5) = x+1) - * \param x input value - * \param m integer multiplier - * \param s integer shift + * \param x first Q-number + * \param y second Q-number + * \param q Q-ness of x and y + * \param s integer right shift * \return The constructed expression. */ -TVM_DLL PrimExpr fixed_point_multiply(PrimExpr x, PrimExpr m, PrimExpr s); +TVM_DLL PrimExpr qmuls(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s); // Intrinsic operators #define TVM_DECLARE_INTRIN_UNARY(OpName) \ diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index 134c23093a68..1e51a3ab4544 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -45,7 +45,7 @@ from .op import isnan, isfinite, isinf, copysign from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod from .op import comm_reducer, min, max, sum -from .op import fixed_point_multiply +from .op import qmuls from . import ir_builder from . import transform diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index 4b2fc70b7f44..457ccbff8b6f 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -965,17 +965,24 @@ def popcount(x): """ return call_intrin(x.dtype, "tir.popcount", x) -def fixed_point_multiply(x, m, s): - """Execute a fixed point multiplication y = round(x * m * 2^s). - The default rounding rule is to the nearest value, rounding half up +def qmuls(x, y, q, s): + """Execute a multiplication between two Q-numbers x and y + followed by a right shift s. The mathematical expression is: + + out = round(x*y*2^-s) + + More about Q-numbers here: https://en.wikipedia.org/wiki/Q_(number_format) + The rounding rule is to the nearest value, rounding half up (i.e., round(x.1) = x and round (x.5) = x+1) Parameters ---------- x : PrimExpr - Input argument. - m : PrimExpr - Integer multiplier + First Q-number + y : PrimExpr + Second Q-number + q : PrimExpr + Q-ness of x and y s : PrimExpr Integer shift @@ -984,7 +991,7 @@ def fixed_point_multiply(x, m, s): y : PrimExpr The result. """ - return call_intrin(x.dtype, "tir.fixed_point_multiply", x, m, s) + return call_intrin(x.dtype, "tir.qmuls", x, y, q, s) def fmod(x, y): """Return the remainder of x divided by y with the same sign as x. diff --git a/src/target/intrin_rule.cc b/src/target/intrin_rule.cc index fa2eef3f8f2a..29fb692e6c62 100644 --- a/src/target/intrin_rule.cc +++ b/src/target/intrin_rule.cc @@ -115,7 +115,7 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.isinf") *rv = isinf(call->args[0]); }); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.fixed_point_multiply") +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.qmuls") .set_body([](const TVMArgs& args, TVMRetValue* rv) { using tir::make_const; @@ -123,42 +123,41 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.fixed_point_multiply") const tir::CallNode* call = e.as(); CHECK(call != nullptr); - PrimExpr tensor = call->args[0]; - PrimExpr fixed_point_multiplier = call->args[1]; - PrimExpr shift = call->args[2]; + PrimExpr x = call->args[0]; + PrimExpr y = call->args[1]; + PrimExpr q = call->args[2]; + PrimExpr s = call->args[3]; // Only int32 types are supported (any number of lanes is allowed) - CHECK(tensor.dtype().code() == DLDataTypeCode::kDLInt && tensor.dtype().bits() == 32); - CHECK(fixed_point_multiplier.dtype().code() == DLDataTypeCode::kDLInt && - fixed_point_multiplier.dtype().bits() == 32); - CHECK(shift.dtype().code() == DLDataTypeCode::kDLInt && shift.dtype().bits() == 32); + CHECK(x.dtype().code() == DLDataTypeCode::kDLInt && x.dtype().bits() == 32); + CHECK(y.dtype().code() == DLDataTypeCode::kDLInt && y.dtype().bits() == 32); + CHECK(s.dtype().code() == DLDataTypeCode::kDLInt && s.dtype().bits() == 32); - DataType hp_dtype = DataType::Int(64, tensor.dtype().lanes()); - DataType lp_dtype = DataType::Int(32, tensor.dtype().lanes()); + DataType hp_dtype = DataType::Int(64, x.dtype().lanes()); + DataType lp_dtype = DataType::Int(32, x.dtype().lanes()); // 1) Calculating the integer multiplier and integer shift - PrimExpr zero = make_const(shift.dtype(), 0); - PrimExpr left_shift = tir::Select((shift > zero), shift, zero); - PrimExpr right_shift = tir::Select(shift > zero, zero, -shift); + PrimExpr zero = make_const(s.dtype(), 0); + PrimExpr left_shift = tir::Select((s > zero), s, zero); + PrimExpr right_shift = tir::Select(s > zero, zero, -s); // 2) Multiply the integer multiplier - tensor = tir::Select(left_shift != zero, tensor << cast(hp_dtype, left_shift), - cast(hp_dtype, tensor)); + x = tir::Select(left_shift != zero, x << cast(hp_dtype, left_shift), cast(hp_dtype, x)); // 3) Perform the multiplication in higher precision. - tensor = tensor * fixed_point_multiplier; + x = x * y; // 4) Find the rounding scalar - PrimExpr total_right_shift = right_shift + 31; + PrimExpr total_right_shift = right_shift + q; PrimExpr pos_rounding_value = (make_const(hp_dtype, 1) << (total_right_shift - 1)); - tensor = tensor + pos_rounding_value; + x = x + pos_rounding_value; // 5) Simply right shift the result to get the final output. - tensor = tensor >> total_right_shift; + x = x >> total_right_shift; // 6) The fixed point multiplication keeps the value in int32 range. Casting back to int32. - *rv = cast(lp_dtype, tensor); + *rv = cast(lp_dtype, x); }); } // namespace intrin diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc index 68a07852b5a6..24f1ec66e74a 100644 --- a/src/tir/op/builtin.cc +++ b/src/tir/op/builtin.cc @@ -89,7 +89,7 @@ TIR_DEFINE_BUILTIN_FUNC(if_then_else) .set_num_inputs(3) .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)); -TIR_DEFINE_BUILTIN_FUNC(fixed_point_multiply) +TIR_DEFINE_BUILTIN_FUNC(qmuls) .set_num_inputs(3) .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)) .set_attr("TVectorizable", true); diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index aec41001bbf2..541c66117edd 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -90,9 +90,9 @@ PrimExpr LargeUIntImm(DataType t, int64_t low, int64_t high) { {make_const(DataType::UInt(32), low), make_const(DataType::UInt(32), high)}); } -// fixed_point_multiply -PrimExpr fixed_point_multiply(PrimExpr x, PrimExpr m, PrimExpr s) { - return tir::Call(x.dtype(), tir::builtin::fixed_point_multiply(), {x, m, s}); +// Q-multiplication +PrimExpr qmuls(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s) { + return tir::Call(x.dtype(), tir::builtin::qmuls(), {x, y, q, s}); } // The public function with a quick checking path. diff --git a/topi/python/topi/arm_cpu/tensor_intrin.py b/topi/python/topi/arm_cpu/tensor_intrin.py index fa004820d6c3..5a7393e170c1 100644 --- a/topi/python/topi/arm_cpu/tensor_intrin.py +++ b/topi/python/topi/arm_cpu/tensor_intrin.py @@ -452,17 +452,26 @@ def _instr(index): C.op, _intrin_func, binds={data:a_buffer, kernel:b_buffer}, default_buffer_params=buffer_params) -def _fixed_point_multiply_arm(op): +def _qmuls_arm(op): """ - Implementation of fixed point multiplication through arm - intrinsics sqrdmulh and srshl + Implementation of qmuls through arm intrinsics sqrdmulh and srshl + when q == 31. + + Please note that this is introducing a small round-up error for + some corner cases. This is because we are rounding twice instead + than only once. I.e.: + + * original qmuls: round(x*y*2^-s) + * arm qmuls: round(round(x*y)*2^-s) """ x = op.args[0] - multiplier = op.args[1] - shift = op.args[2] + y = op.args[1] + q = op.args[2] + s = op.args[3] # Don't use this intrinsic if we don't have a int32x4 vector - if x.dtype != "int32x4": + # and if we are not multiplying q31 numbers + if x.dtype != "int32x4" and q == 31: return op # Case 1, shift is negative @@ -470,27 +479,27 @@ def _fixed_point_multiply_arm(op): 'llvm.aarch64.neon.sqrdmulh', tvm.tir.const(2, 'uint32'), x, - multiplier) + y) - fixup = (sqrdmulh & (-shift)) >> 31 + fixup = (sqrdmulh & (-s)) >> 31 fixed_up_x = (sqrdmulh + fixup) out_1 = tvm.tir.call_llvm_intrin(op.dtype, 'llvm.aarch64.neon.srshl', tvm.tir.const(2, 'uint32'), sqrdmulh, - shift) + s) # Case 2, shift is positive - x = x * (1 << (shift)) + x = x * (1 << (s)) out_2 = tvm.tir.call_llvm_intrin(op.dtype, 'llvm.aarch64.neon.sqrdmulh', tvm.tir.const(2, 'uint32'), x, - multiplier) + y) # Select depending on the shift - return tvm.tir.Select(shift < 0, out_1, out_2) + return tvm.tir.Select(s < 0, out_1, out_2) tvm.target.intrin.register_intrin_rule("llvm.aarch64", - "fixed_point_multiply", - _fixed_point_multiply_arm, override=True) + "qmuls", + _qmuls_arm, override=True) diff --git a/topi/python/topi/math.py b/topi/python/topi/math.py index 475f8074fb35..cb0d15d52943 100644 --- a/topi/python/topi/math.py +++ b/topi/python/topi/math.py @@ -632,7 +632,8 @@ def _compute(*indices): value = x(*indices) m = tvm.tir.const(multiplier, x.dtype) s = tvm.tir.const(shift, x.dtype) - return tvm.tir.fixed_point_multiply(value, m, s) + q = tvm.tir.const(31, x.dtype) + return tvm.tir.qmuls(value, m, q, s) assert x.dtype == "int32", "input tensor type needs to be int32" return te.compute(x.shape, _compute) From 0c7a010918be1bbdca3fa6293b2d20e21e407111 Mon Sep 17 00:00:00 2001 From: Giuseppe Rossini Date: Thu, 9 Jul 2020 09:52:28 +0100 Subject: [PATCH 08/11] Fix python linting errors - 3 Change-Id: Ib87a19a8ee2d532954a7db1eb5793666e7aef366 --- topi/python/topi/arm_cpu/tensor_intrin.py | 6 +++--- topi/python/topi/math.py | 8 ++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/topi/python/topi/arm_cpu/tensor_intrin.py b/topi/python/topi/arm_cpu/tensor_intrin.py index 5a7393e170c1..71009339e9dc 100644 --- a/topi/python/topi/arm_cpu/tensor_intrin.py +++ b/topi/python/topi/arm_cpu/tensor_intrin.py @@ -455,10 +455,10 @@ def _instr(index): def _qmuls_arm(op): """ Implementation of qmuls through arm intrinsics sqrdmulh and srshl - when q == 31. + when q == 31. - Please note that this is introducing a small round-up error for - some corner cases. This is because we are rounding twice instead + Please note that this is introducing a small round-up error for + some corner cases. This is because we are rounding twice instead than only once. I.e.: * original qmuls: round(x*y*2^-s) diff --git a/topi/python/topi/math.py b/topi/python/topi/math.py index cb0d15d52943..92783664f57e 100644 --- a/topi/python/topi/math.py +++ b/topi/python/topi/math.py @@ -630,10 +630,10 @@ def fixed_point_multiply(x, multiplier, shift): """ def _compute(*indices): value = x(*indices) - m = tvm.tir.const(multiplier, x.dtype) - s = tvm.tir.const(shift, x.dtype) - q = tvm.tir.const(31, x.dtype) - return tvm.tir.qmuls(value, m, q, s) + return tvm.tir.qmuls(value, + tvm.tir.const(multiplier, x.dtype), + tvm.tir.const(31, x.dtype), + tvm.tir.const(shift, x.dtype)) assert x.dtype == "int32", "input tensor type needs to be int32" return te.compute(x.shape, _compute) From e8730eb389277338c34a5ede6ba588881daf1e90 Mon Sep 17 00:00:00 2001 From: Giuseppe Rossini Date: Tue, 14 Jul 2020 17:00:13 +0100 Subject: [PATCH 09/11] Addressing review comments Change-Id: Ie82e75204e5a421d17660f381f3e31fc325cd26c --- include/tvm/relay/attrs/transform.h | 6 +++-- include/tvm/tir/op.h | 5 +++- python/tvm/relay/op/tensor.py | 21 +++++++++++++++ python/tvm/tir/op.py | 2 +- src/relay/op/tensor/unary.cc | 33 ++++++++++++++--------- src/relay/qnn/op/requantize.cc | 10 +++---- src/relay/qnn/util.cc | 24 +++++++---------- src/relay/qnn/util.h | 11 ++++---- src/relay/quantize/realize.cc | 20 +++++++++++--- src/target/intrin_rule.cc | 4 +-- tests/python/relay/test_op_level3.py | 17 ++++++++++++ topi/python/topi/arm_cpu/injective.py | 5 +--- topi/python/topi/arm_cpu/tensor_intrin.py | 2 +- topi/python/topi/math.py | 6 +++-- 14 files changed, 113 insertions(+), 53 deletions(-) diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index 70f844252268..36579749c48a 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -304,8 +304,10 @@ struct FixedPointMultiplyAttrs : public tvm::AttrsNode int32_t shift; TVM_DECLARE_ATTRS(FixedPointMultiplyAttrs, "relay.attrs.FixedPointMultiplyAttrs") { - TVM_ATTR_FIELD(multiplier).describe("Integer multiplier."); - TVM_ATTR_FIELD(shift).describe("Shift."); + TVM_ATTR_FIELD(multiplier) + .describe("Multiplier of a fixed floating point number described as multiplier*2^(shift)"); + TVM_ATTR_FIELD(shift).describe( + "Shift of a fixed floating point number described as multiplier*2^(shift)"); } }; diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h index f40a209b5461..8a80088fc1dd 100644 --- a/include/tvm/tir/op.h +++ b/include/tvm/tir/op.h @@ -558,13 +558,16 @@ TVM_DLL PrimExpr LargeUIntImm(DataType dtype, int64_t low, int64_t high); * * out = round(x*y*2^-s) * + * Please note that the two Q-numbers x and y are supposed to have + * the same number of fractional bits q. + * * More about Q-numbers here: https://en.wikipedia.org/wiki/Q_(number_format) * * The rounding rule is to the nearest value, rounding half up * (i.e., round(x.1) = x and round (x.5) = x+1) * \param x first Q-number * \param y second Q-number - * \param q Q-ness of x and y + * \param q number of fractional bits in x and y. Needs to be > 0 * \param s integer right shift * \return The constructed expression. */ diff --git a/python/tvm/relay/op/tensor.py b/python/tvm/relay/op/tensor.py index a02e08d2deb7..c002c8b2ff7e 100644 --- a/python/tvm/relay/op/tensor.py +++ b/python/tvm/relay/op/tensor.py @@ -1034,6 +1034,27 @@ def clip(a, a_min, a_max): """ return _make.clip(a, a_min, a_max) +def fixed_point_multiply(data, multiplier, shift): + """Fixed point multiplication between data and a fixed point + constant expressed as multiplier * 2^(-shift), where multiplier + is a Q-number with 31 fractional bits + + Parameters + ---------- + data : relay.Expr + The input tensor. + multiplier : int + The integer multiplier of the fixed point constant. + a_max : float + The integer shift of the fixed point constant. + + Returns + ------- + result : relay.Expr + The output of the fixed point multiplication + """ + return _make.fixed_point_multiply(data, multiplier, shift) + def concatenate(data, axis): """Concatenate the input tensors along the given axis. diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index 457ccbff8b6f..feef4413f043 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -982,7 +982,7 @@ def qmuls(x, y, q, s): y : PrimExpr Second Q-number q : PrimExpr - Q-ness of x and y + Number of fractional bits in x and y. Needs to be > 0 s : PrimExpr Integer shift diff --git a/src/relay/op/tensor/unary.cc b/src/relay/op/tensor/unary.cc index 2f4c150bf4be..fc61661566c3 100644 --- a/src/relay/op/tensor/unary.cc +++ b/src/relay/op/tensor/unary.cc @@ -277,32 +277,41 @@ Expr MakeClip(Expr a, double a_min, double a_max) { TVM_REGISTER_GLOBAL("relay.op._make.clip").set_body_typed(MakeClip); -// relay.fixed_point_multiply -TVM_REGISTER_NODE_TYPE(FixedPointMultiplyAttrs); - -RELAY_REGISTER_OP("fixed_point_multiply") - .describe(R"code( fixed point multiplication )code" TVM_ADD_FILELINE) +RELAY_REGISTER_OP("clip") + .describe(R"code(Clip tensor values. +This function takes a tensor, a minimum value `a_min`, and a maximum value `a_max`, and returns a clipped tensor where all values below `a_min` are set to `a_min` and all values above `a_max` are set to `a_max`. `a_min` and `a_max` are cast to the tensor's dtype. +)code" TVM_ADD_FILELINE) .set_num_inputs(1) .add_argument("data", "Tensor", "The input tensor.") .add_type_rel("Identity", IdentityRel) .set_attr("TOpPattern", kElemWise) .set_attr("TOpIsStateful", false) .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) - .set_attrs_type() + .set_attrs_type() .set_support_level(3); -RELAY_REGISTER_OP("clip") - .describe(R"code(Clip tensor values. -This function takes a tensor, a minimum value `a_min`, and a maximum value `a_max`, and returns a clipped tensor where all values below `a_min` are set to `a_min` and all values above `a_max` are set to `a_max`. `a_min` and `a_max` are cast to the tensor's dtype. -)code" TVM_ADD_FILELINE) +// relay.fixed_point_multiply +TVM_REGISTER_NODE_TYPE(FixedPointMultiplyAttrs); + +TVM_REGISTER_GLOBAL("relay.op._make.fixed_point_multiply") + .set_body_typed([](Expr a, int32_t multiplier, int32_t shift) { + auto attrs = make_object(); + attrs->multiplier = multiplier; + attrs->shift = shift; + static const Op& op = Op::Get("fixed_point_multiply"); + return Call(op, {a}, Attrs(attrs), {}); + }); + +RELAY_REGISTER_OP("fixed_point_multiply") + .describe(R"code(fixed point multiplication)code" TVM_ADD_FILELINE) .set_num_inputs(1) .add_argument("data", "Tensor", "The input tensor.") .add_type_rel("Identity", IdentityRel) .set_attr("TOpPattern", kElemWise) .set_attr("TOpIsStateful", false) .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) - .set_attrs_type() - .set_support_level(3); + .set_attrs_type() + .set_support_level(10); RELAY_REGISTER_UNARY_OP("floor") .describe(R"code(Returns the floor of input array, computed element-wise. diff --git a/src/relay/qnn/op/requantize.cc b/src/relay/qnn/op/requantize.cc index 28013055f31b..222d91021b19 100644 --- a/src/relay/qnn/op/requantize.cc +++ b/src/relay/qnn/op/requantize.cc @@ -159,11 +159,11 @@ Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale, const bool is_upward_rounding = (param->rounding == "UPWARD"); // When using upward rounding (i.e., x.5 rounded to x+1), leverage - // the fixed_point_muliply intrinsic - scaled_int32_t = (is_upward_rounding ? relay::FixedPointMultiply( - scaled_int32_t, fixed_point_multiplier, shift) - : FixedPointMultiply(scaled_int32_t, double_multiplier, - input_shape, param->rounding)); + // the FixedPointMultiply operator + scaled_int32_t = + (is_upward_rounding + ? FixedPointMultiply(scaled_int32_t, fixed_point_multiplier, shift) + : FixedPointMultiplyToNearest(scaled_int32_t, double_multiplier, input_shape)); } } else { diff --git a/src/relay/qnn/util.cc b/src/relay/qnn/util.cc index 2bf2c3fa2445..113038e327d7 100644 --- a/src/relay/qnn/util.cc +++ b/src/relay/qnn/util.cc @@ -56,8 +56,8 @@ std::pair GetFixedPointMultiplierShift(double double_multiplie return std::make_pair(significand, exponent); } -Expr FixedPointMultiply(Expr tensor, double multiplier, const Array& input_shape, - const std::string& rounding) { +Expr FixedPointMultiplyToNearest(Expr tensor, double multiplier, + const Array& input_shape) { // Choose high precision datatype to be int64. This is for avoiding overflow // in multiplication of two int32 values. DataType hp_dtype = DataType::Int(64); @@ -90,19 +90,15 @@ Expr FixedPointMultiply(Expr tensor, double multiplier, const Array& int64_t pos_rounding_value = (1ll << (total_right_shift - 1)); Expr round_scalar; - if (rounding == "UPWARD") { - round_scalar = MakeConstantScalar(hp_dtype, pos_rounding_value); - } else if (rounding == "TONEAREST") { - auto pos_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value); - auto neg_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value - 1); - auto pos_rounder_t = Full(pos_rounder, input_shape, hp_dtype); - auto neg_rounder_t = Full(neg_rounder, input_shape, hp_dtype); - auto zero_t = Zeros(input_shape, hp_dtype); - round_scalar = Where(GreaterEqual(tensor, zero_t), pos_rounder_t, neg_rounder_t); - } else { - LOG(FATAL) << "Rounding mode " << rounding << " not supported."; - } + auto pos_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value); + auto neg_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value - 1); + auto pos_rounder_t = Full(pos_rounder, input_shape, hp_dtype); + auto neg_rounder_t = Full(neg_rounder, input_shape, hp_dtype); + + auto zero_t = Zeros(input_shape, hp_dtype); + round_scalar = Where(GreaterEqual(tensor, zero_t), pos_rounder_t, neg_rounder_t); + // Add the rounding scalar. tensor = Add(tensor, round_scalar); diff --git a/src/relay/qnn/util.h b/src/relay/qnn/util.h index cacecc82f05b..72eb2a46b2ae 100644 --- a/src/relay/qnn/util.h +++ b/src/relay/qnn/util.h @@ -115,13 +115,12 @@ static inline int64_t get_const_int(const tvm::PrimExpr& x) { /* * \brief Fixed point multiplication between integer tensor with floating point - scalar. + * scalar. This implementation rounds to the nearest value when it is midway + * between two representable values. * \param tensor The quantized input tensor of dtype int64. * \param multiplier The scalar multiplier. * \param input_shape Shape of the input tensor. - * \param rounding "UPWARD" or "TONEAREST". The rounding direction when the value - is midway between" "two representable values. - * \return The sequence of Relay ops for fixed point multiplication. + * \return The sequence of Relay ops for fixed point multiplication with TONEARES rounding. * \note Original compuation is scale_fp32 * quantized_tensor. To convert into * integer computation, the multiplication with fp32 scalar can be @@ -135,8 +134,8 @@ static inline int64_t get_const_int(const tvm::PrimExpr& x) { * 2) Round the result. * 3) Right shift the result */ -Expr FixedPointMultiply(Expr tensor, double multiplier, const Array& input_shape, - const std::string& rounding); +Expr FixedPointMultiplyToNearest(Expr tensor, double multiplier, + const Array& input_shape); /* * \brief Fixed point multiplication between integer tensor with floating point diff --git a/src/relay/quantize/realize.cc b/src/relay/quantize/realize.cc index ddf945a0b19f..ace2c2473173 100644 --- a/src/relay/quantize/realize.cc +++ b/src/relay/quantize/realize.cc @@ -113,7 +113,14 @@ inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype, } else if (static_cast(factor) == factor) { return Multiply(data, MakeConstantScalar(dtype, factor)); } else { - data = qnn::FixedPointMultiply(data, factor, data_shape, cfg->rounding); + if (cfg->rounding == "UPWARD") { + int32_t fixed_point_multiplier, shift; + std::tie(fixed_point_multiplier, shift) = qnn::GetFixedPointMultiplierShift(factor); + data = relay::FixedPointMultiply(data, fixed_point_multiplier, shift); + } else { + data = qnn::FixedPointMultiplyToNearest(data, factor, data_shape); + } + return Cast(data, dtype); } } @@ -164,8 +171,15 @@ Expr QuantizeRealize(const Call& ref_call, const Array& new_args, const Ob return QRealizeIntExpr(data, dom_scale, n->dtype); } else { data = Cast(data, DataType::Int(64)); - data = qnn::FixedPointMultiply(data, idom_scale_imm / odom_scale_imm, - ref_call->type_as()->shape, cfg->rounding); + if (cfg->rounding == "UPWARD") { + int32_t fixed_point_multiplier, shift; + std::tie(fixed_point_multiplier, shift) = + qnn::GetFixedPointMultiplierShift(idom_scale_imm / odom_scale_imm); + data = relay::FixedPointMultiply(data, fixed_point_multiplier, shift); + } else { + data = qnn::FixedPointMultiplyToNearest(data, idom_scale_imm / odom_scale_imm, + ref_call->type_as()->shape); + } data = Cast(Clip(data, clip_min_imm, clip_max_imm), n->dtype); return QRealizeIntExpr(data, dom_scale, n->dtype); } diff --git a/src/target/intrin_rule.cc b/src/target/intrin_rule.cc index 29fb692e6c62..6801d5549d99 100644 --- a/src/target/intrin_rule.cc +++ b/src/target/intrin_rule.cc @@ -138,14 +138,14 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.qmuls") // 1) Calculating the integer multiplier and integer shift PrimExpr zero = make_const(s.dtype(), 0); - PrimExpr left_shift = tir::Select((s > zero), s, zero); + PrimExpr left_shift = tir::Select(s > zero, s, zero); PrimExpr right_shift = tir::Select(s > zero, zero, -s); // 2) Multiply the integer multiplier x = tir::Select(left_shift != zero, x << cast(hp_dtype, left_shift), cast(hp_dtype, x)); // 3) Perform the multiplication in higher precision. - x = x * y; + x = x * cast(hp_dtype, y); // 4) Find the rounding scalar PrimExpr total_right_shift = right_shift + q; diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index 115900fea0f3..93b44bfd3504 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -84,6 +84,22 @@ def test_clip(): ref_res = np.clip(data, 1., 4.) np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01) +def test_fixed_point_multiply(): + # Test 23 * 1/16 + # [m,s] = [0.5, -3] = frexp(1/16) + # M = 0.5*2^31 = 1073741824 + # so M = 1073741824 and s = -3 + + a = relay.var("a", relay.TensorType((10, 4), "int32")) + y = relay.fixed_point_multiply(a, 1073741824, -3) + yy = run_infer_type(y) + assert yy.checked_type == relay.TensorType((10, 4), "int32") + + data = 23*np.ones((10, 4)).astype('int32') + intrp = create_executor() + op_res = intrp.evaluate(y, { a: relay.const(data) }) + ref_res = np.ones((10, 4)).astype('int32') + np.testing.assert_allclose(op_res.asnumpy(), ref_res, atol=1) def test_reinterpret(): a = relay.var("a", relay.TensorType((1000, 4), "float32")) @@ -1034,3 +1050,4 @@ def verify_sparse_to_dense(sparse_indices, sparse_values, default_value, output_ test_isinf() test_unravel_index() test_sparse_to_dense() + test_fixed_point_multiply() diff --git a/topi/python/topi/arm_cpu/injective.py b/topi/python/topi/arm_cpu/injective.py index 54306deadd85..3e3c73d26553 100644 --- a/topi/python/topi/arm_cpu/injective.py +++ b/topi/python/topi/arm_cpu/injective.py @@ -62,13 +62,10 @@ def schedule_injective(outs): outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs s = te.create_schedule([x.op for x in outs]) x = outs[0] - ins = x.op.input_tensors - dtype = ins[0].dtype if len(ins) > 0 else x.dtype - max_vlen = 4 if dtype == 'int32' else 8 if list(s[x].op.axis): # do not vectorize for broadcast - (io, ii) = s[x].split(list(s[x].op.axis)[-1], max_vlen) + (io, ii) = s[x].split(list(s[x].op.axis)[-1], 4) s[x].vectorize(ii) tvm.te.schedule.AutoInlineInjective(s) diff --git a/topi/python/topi/arm_cpu/tensor_intrin.py b/topi/python/topi/arm_cpu/tensor_intrin.py index 71009339e9dc..28701bde8c57 100644 --- a/topi/python/topi/arm_cpu/tensor_intrin.py +++ b/topi/python/topi/arm_cpu/tensor_intrin.py @@ -471,7 +471,7 @@ def _qmuls_arm(op): # Don't use this intrinsic if we don't have a int32x4 vector # and if we are not multiplying q31 numbers - if x.dtype != "int32x4" and q == 31: + if x.dtype != "int32x4" or q.val != 31: return op # Case 1, shift is negative diff --git a/topi/python/topi/math.py b/topi/python/topi/math.py index 92783664f57e..6189acc651f2 100644 --- a/topi/python/topi/math.py +++ b/topi/python/topi/math.py @@ -620,8 +620,10 @@ def fixed_point_multiply(x, multiplier, shift): ---------- x : tvm.te.Tensor or Expr Input argument. - multiplier: Integer multiplier - shift: Integer shift + multiplier: int + Multiplier of a fixed floating point number described as multiplier*2^(shift) + shift: int + Shift of a fixed floating point number described as multiplier*2^(shift) Returns ------- From d18e2fb521c0d8649e834bce2c178f28474cedf0 Mon Sep 17 00:00:00 2001 From: Giuseppe Rossini Date: Wed, 15 Jul 2020 12:18:33 +0100 Subject: [PATCH 10/11] Fixing test failures Change-Id: I74cc675764cf8d260fe68a41e770b1ec7e84729a --- python/tvm/tir/op.py | 2 +- src/target/intrin_rule.cc | 13 +++++++------ src/tir/op/op.cc | 2 +- topi/python/topi/math.py | 8 +++----- 4 files changed, 12 insertions(+), 13 deletions(-) diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index feef4413f043..d012617667d4 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -991,7 +991,7 @@ def qmuls(x, y, q, s): y : PrimExpr The result. """ - return call_intrin(x.dtype, "tir.qmuls", x, y, q, s) + return call_intrin('int32', "tir.qmuls", x, y, q, s) def fmod(x, y): """Return the remainder of x divided by y with the same sign as x. diff --git a/src/target/intrin_rule.cc b/src/target/intrin_rule.cc index 6801d5549d99..fd29f00cbae8 100644 --- a/src/target/intrin_rule.cc +++ b/src/target/intrin_rule.cc @@ -129,7 +129,6 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.qmuls") PrimExpr s = call->args[3]; // Only int32 types are supported (any number of lanes is allowed) - CHECK(x.dtype().code() == DLDataTypeCode::kDLInt && x.dtype().bits() == 32); CHECK(y.dtype().code() == DLDataTypeCode::kDLInt && y.dtype().bits() == 32); CHECK(s.dtype().code() == DLDataTypeCode::kDLInt && s.dtype().bits() == 32); @@ -141,16 +140,18 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.qmuls") PrimExpr left_shift = tir::Select(s > zero, s, zero); PrimExpr right_shift = tir::Select(s > zero, zero, -s); - // 2) Multiply the integer multiplier - x = tir::Select(left_shift != zero, x << cast(hp_dtype, left_shift), cast(hp_dtype, x)); + // 2) Cast and Multiply the integer multiplier + PrimExpr one = make_const(hp_dtype, 1); + x = cast(hp_dtype, x); + y = cast(hp_dtype, y); + x = tir::Select(left_shift != zero, x << left_shift, x); // 3) Perform the multiplication in higher precision. - x = x * cast(hp_dtype, y); + x = x * y; // 4) Find the rounding scalar PrimExpr total_right_shift = right_shift + q; - PrimExpr pos_rounding_value = (make_const(hp_dtype, 1) << (total_right_shift - 1)); - + PrimExpr pos_rounding_value = (one << (total_right_shift - 1)); x = x + pos_rounding_value; // 5) Simply right shift the result to get the final output. diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index 541c66117edd..4ff98725cf58 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -92,7 +92,7 @@ PrimExpr LargeUIntImm(DataType t, int64_t low, int64_t high) { // Q-multiplication PrimExpr qmuls(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s) { - return tir::Call(x.dtype(), tir::builtin::qmuls(), {x, y, q, s}); + return tir::Call(DataType::Int(32, x.dtype().lanes()), tir::builtin::qmuls(), {x, y, q, s}); } // The public function with a quick checking path. diff --git a/topi/python/topi/math.py b/topi/python/topi/math.py index 6189acc651f2..398ec0561289 100644 --- a/topi/python/topi/math.py +++ b/topi/python/topi/math.py @@ -633,11 +633,9 @@ def fixed_point_multiply(x, multiplier, shift): def _compute(*indices): value = x(*indices) return tvm.tir.qmuls(value, - tvm.tir.const(multiplier, x.dtype), - tvm.tir.const(31, x.dtype), - tvm.tir.const(shift, x.dtype)) - - assert x.dtype == "int32", "input tensor type needs to be int32" + tvm.tir.const(multiplier, 'int32'), + tvm.tir.const(31, 'int32'), + tvm.tir.const(shift, 'int32')) return te.compute(x.shape, _compute) def cast(x, dtype): From a257327a2f14dbba9000348b36ba89811faf66ff Mon Sep 17 00:00:00 2001 From: Giuseppe Rossini Date: Fri, 17 Jul 2020 12:39:05 +0100 Subject: [PATCH 11/11] Renaming qmuls to q_multiply_shift Change-Id: I5a8ed60ba855208040304fcdf6e1ea28061f06ad --- include/tvm/tir/builtin.h | 2 +- include/tvm/tir/op.h | 2 +- python/tvm/tir/__init__.py | 2 +- python/tvm/tir/op.py | 4 ++-- src/target/intrin_rule.cc | 2 +- src/tir/op/builtin.cc | 2 +- src/tir/op/op.cc | 5 +++-- topi/python/topi/arm_cpu/tensor_intrin.py | 18 ++++++++--------- topi/python/topi/math.py | 24 ++++++++++++----------- 9 files changed, 32 insertions(+), 29 deletions(-) diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h index 4feabdb5bbde..bea53136fd54 100644 --- a/include/tvm/tir/builtin.h +++ b/include/tvm/tir/builtin.h @@ -98,7 +98,7 @@ TVM_DLL const Op& large_uint_imm(); * The default rounding rule is to the nearest value, rounding half up * (i.e., round(x.1) = x and round (x.5) = x+1) */ -TVM_DLL const Op& qmuls(); +TVM_DLL const Op& q_multiply_shift(); /*! * \brief See pesudo code diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h index 8a80088fc1dd..68ca2663ede9 100644 --- a/include/tvm/tir/op.h +++ b/include/tvm/tir/op.h @@ -571,7 +571,7 @@ TVM_DLL PrimExpr LargeUIntImm(DataType dtype, int64_t low, int64_t high); * \param s integer right shift * \return The constructed expression. */ -TVM_DLL PrimExpr qmuls(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s); +TVM_DLL PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s); // Intrinsic operators #define TVM_DECLARE_INTRIN_UNARY(OpName) \ diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index 1e51a3ab4544..1aac55fa9920 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -45,7 +45,7 @@ from .op import isnan, isfinite, isinf, copysign from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod from .op import comm_reducer, min, max, sum -from .op import qmuls +from .op import q_multiply_shift from . import ir_builder from . import transform diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index d012617667d4..10783768e593 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -965,7 +965,7 @@ def popcount(x): """ return call_intrin(x.dtype, "tir.popcount", x) -def qmuls(x, y, q, s): +def q_multiply_shift(x, y, q, s): """Execute a multiplication between two Q-numbers x and y followed by a right shift s. The mathematical expression is: @@ -991,7 +991,7 @@ def qmuls(x, y, q, s): y : PrimExpr The result. """ - return call_intrin('int32', "tir.qmuls", x, y, q, s) + return call_intrin('int32', "tir.q_multiply_shift", x, y, q, s) def fmod(x, y): """Return the remainder of x divided by y with the same sign as x. diff --git a/src/target/intrin_rule.cc b/src/target/intrin_rule.cc index fd29f00cbae8..fa0ee38d8130 100644 --- a/src/target/intrin_rule.cc +++ b/src/target/intrin_rule.cc @@ -115,7 +115,7 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.isinf") *rv = isinf(call->args[0]); }); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.qmuls") +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.q_multiply_shift") .set_body([](const TVMArgs& args, TVMRetValue* rv) { using tir::make_const; diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc index 24f1ec66e74a..3afb8810e774 100644 --- a/src/tir/op/builtin.cc +++ b/src/tir/op/builtin.cc @@ -89,7 +89,7 @@ TIR_DEFINE_BUILTIN_FUNC(if_then_else) .set_num_inputs(3) .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)); -TIR_DEFINE_BUILTIN_FUNC(qmuls) +TIR_DEFINE_BUILTIN_FUNC(q_multiply_shift) .set_num_inputs(3) .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)) .set_attr("TVectorizable", true); diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index 4ff98725cf58..75a483c4d165 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -91,8 +91,9 @@ PrimExpr LargeUIntImm(DataType t, int64_t low, int64_t high) { } // Q-multiplication -PrimExpr qmuls(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s) { - return tir::Call(DataType::Int(32, x.dtype().lanes()), tir::builtin::qmuls(), {x, y, q, s}); +PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s) { + return tir::Call(DataType::Int(32, x.dtype().lanes()), tir::builtin::q_multiply_shift(), + {x, y, q, s}); } // The public function with a quick checking path. diff --git a/topi/python/topi/arm_cpu/tensor_intrin.py b/topi/python/topi/arm_cpu/tensor_intrin.py index 28701bde8c57..270bfbe87766 100644 --- a/topi/python/topi/arm_cpu/tensor_intrin.py +++ b/topi/python/topi/arm_cpu/tensor_intrin.py @@ -452,17 +452,17 @@ def _instr(index): C.op, _intrin_func, binds={data:a_buffer, kernel:b_buffer}, default_buffer_params=buffer_params) -def _qmuls_arm(op): +def _q_multiply_shift_arm(op): """ - Implementation of qmuls through arm intrinsics sqrdmulh and srshl - when q == 31. + Implementation of q_multiply_shift_arm through arm intrinsics + sqrdmulh and srshl when q == 31. Please note that this is introducing a small round-up error for some corner cases. This is because we are rounding twice instead than only once. I.e.: - * original qmuls: round(x*y*2^-s) - * arm qmuls: round(round(x*y)*2^-s) + * original q_multiply_shift: round(x*y*2^-s) + * arm q_multiply_shift: round(round(x*y)*2^-s) """ x = op.args[0] y = op.args[1] @@ -470,8 +470,8 @@ def _qmuls_arm(op): s = op.args[3] # Don't use this intrinsic if we don't have a int32x4 vector - # and if we are not multiplying q31 numbers - if x.dtype != "int32x4" or q.val != 31: + # or if we are not multiplying q31 numbers + if x.dtype != "int32x4" or q.value != 31: return op # Case 1, shift is negative @@ -501,5 +501,5 @@ def _qmuls_arm(op): return tvm.tir.Select(s < 0, out_1, out_2) tvm.target.intrin.register_intrin_rule("llvm.aarch64", - "qmuls", - _qmuls_arm, override=True) + "q_multiply_shift", + _q_multiply_shift_arm, override=True) diff --git a/topi/python/topi/math.py b/topi/python/topi/math.py index 398ec0561289..046b10342c0b 100644 --- a/topi/python/topi/math.py +++ b/topi/python/topi/math.py @@ -614,16 +614,18 @@ def _compute(*indices): @tvm.te.tag_scope(tag=tag.ELEMWISE) def fixed_point_multiply(x, multiplier, shift): - """ + """Fixed point multiplication between data and a fixed point + constant expressed as multiplier * 2^(-shift), where multiplier + is a Q-number with 31 fractional bits Parameters ---------- - x : tvm.te.Tensor or Expr - Input argument. - multiplier: int - Multiplier of a fixed floating point number described as multiplier*2^(shift) - shift: int - Shift of a fixed floating point number described as multiplier*2^(shift) + x : tvm.te.Tensor or Expr + Input argument. + multiplier : int + Multiplier of a fixed floating point number described as multiplier*2^(-shift). + shift : int + Shift of a fixed floating point number described as multiplier*2^(-shift). Returns ------- @@ -632,10 +634,10 @@ def fixed_point_multiply(x, multiplier, shift): """ def _compute(*indices): value = x(*indices) - return tvm.tir.qmuls(value, - tvm.tir.const(multiplier, 'int32'), - tvm.tir.const(31, 'int32'), - tvm.tir.const(shift, 'int32')) + return tvm.tir.q_multiply_shift(value, + tvm.tir.const(multiplier, 'int32'), + tvm.tir.const(31, 'int32'), + tvm.tir.const(shift, 'int32')) return te.compute(x.shape, _compute) def cast(x, dtype):