diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h
index b0c8108b1624..36579749c48a 100644
--- a/include/tvm/relay/attrs/transform.h
+++ b/include/tvm/relay/attrs/transform.h
@@ -298,6 +298,19 @@ struct ClipAttrs : public tvm::AttrsNode<ClipAttrs> {
   }
 };
 
+/*! \brief Attributes for FixedPointMultiply operator */
+struct FixedPointMultiplyAttrs : public tvm::AttrsNode<FixedPointMultiplyAttrs> {
+  int32_t multiplier;
+  int32_t shift;
+
+  TVM_DECLARE_ATTRS(FixedPointMultiplyAttrs, "relay.attrs.FixedPointMultiplyAttrs") {
+    TVM_ATTR_FIELD(multiplier)
+        .describe("Multiplier of a fixed floating point number described as multiplier*2^(shift)");
+    TVM_ATTR_FIELD(shift).describe(
+        "Shift of a fixed floating point number described as multiplier*2^(shift)");
+  }
+};
+
 /*! \brief Attributes for LayoutTransform operator */
 struct LayoutTransformAttrs : public tvm::AttrsNode<LayoutTransformAttrs> {
   std::string src_layout;
diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h
index 464ce6c143c5..bea53136fd54 100644
--- a/include/tvm/tir/builtin.h
+++ b/include/tvm/tir/builtin.h
@@ -92,6 +92,14 @@ TVM_DLL const Op& shift_right();
  */
 TVM_DLL const Op& large_uint_imm();
 
+/*!
+ * \brief Execute a multiplication between two Q-numbers x and y
+ * followed by a right shift s
+ * The default rounding rule is to the nearest value, rounding half up
+ * (i.e., round(x.1) = x and round(x.5) = x+1)
+ */
+TVM_DLL const Op& q_multiply_shift();
+
 /*!
  * \brief See pesudo code
  *
diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h
index 31ce13c7e66a..68ca2663ede9 100644
--- a/include/tvm/tir/op.h
+++ b/include/tvm/tir/op.h
@@ -552,6 +552,27 @@ TVM_DLL PrimExpr trunc(PrimExpr x);
  */
 TVM_DLL PrimExpr LargeUIntImm(DataType dtype, int64_t low, int64_t high);
 
+/*!
+ * \brief Execute a multiplication between two Q-numbers x and y
+ * followed by a right shift s. The mathematical expression is:
+ *
+ *    out = round(x*y*2^-s)
+ *
+ * Please note that the two Q-numbers x and y are supposed to have
+ * the same number of fractional bits q.
+ *
+ * More about Q-numbers here: https://en.wikipedia.org/wiki/Q_(number_format)
+ *
+ * The rounding rule is to the nearest value, rounding half up
+ * (i.e., round(x.1) = x and round(x.5) = x+1)
+ * \param x first Q-number
+ * \param y second Q-number
+ * \param q number of fractional bits in x and y. Needs to be > 0
+ * \param s integer right shift
+ * \return The constructed expression.
+ */
+TVM_DLL PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s);
+
 // Intrinsic operators
 #define TVM_DECLARE_INTRIN_UNARY(OpName) \
   inline PrimExpr OpName(PrimExpr x) { \
diff --git a/python/tvm/relay/op/_tensor.py b/python/tvm/relay/op/_tensor.py
index d4911d95e90d..feeec1fa89ec 100644
--- a/python/tvm/relay/op/_tensor.py
+++ b/python/tvm/relay/op/_tensor.py
@@ -131,6 +131,14 @@ def clip_compute(attrs, inputs, output_type):
 
 register_injective_schedule("clip")
 
+# fixed point multiply
+@register_compute("fixed_point_multiply")
+def fixed_point_multiply_compute(attrs, inputs, output_type):
+    assert len(inputs) == 1
+    return [topi.fixed_point_multiply(inputs[0], attrs.multiplier, attrs.shift)]
+
+register_injective_schedule("fixed_point_multiply")
+
 # full
 @script
 def _full_shape_func(shape):
diff --git a/python/tvm/relay/op/tensor.py b/python/tvm/relay/op/tensor.py
index a02e08d2deb7..c002c8b2ff7e 100644
--- a/python/tvm/relay/op/tensor.py
+++ b/python/tvm/relay/op/tensor.py
@@ -1034,6 +1034,27 @@ def clip(a, a_min, a_max):
     """
     return _make.clip(a, a_min, a_max)
 
+def fixed_point_multiply(data, multiplier, shift):
+    """Fixed point multiplication between data and a fixed point
+    constant expressed as multiplier * 2^(shift), where multiplier
+    is a Q-number with 31 fractional bits.
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The input tensor.
+    multiplier : int
+        The integer multiplier of the fixed point constant.
+    shift : int
+        The integer shift of the fixed point constant.
+
+    Returns
+    -------
+    result : relay.Expr
+        The output of the fixed point multiplication
+    """
+    return _make.fixed_point_multiply(data, multiplier, shift)
+
 def concatenate(data, axis):
     """Concatenate the input tensors along the given axis.
diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py
index 9dbdc07b4a46..1aac55fa9920 100644
--- a/python/tvm/tir/__init__.py
+++ b/python/tvm/tir/__init__.py
@@ -45,6 +45,7 @@
 from .op import isnan, isfinite, isinf, copysign
 from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod
 from .op import comm_reducer, min, max, sum
+from .op import q_multiply_shift
 
 from . import ir_builder
 from . import transform
diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py
index cbbd59fe4eaf..10783768e593 100644
--- a/python/tvm/tir/op.py
+++ b/python/tvm/tir/op.py
@@ -965,6 +965,34 @@ def popcount(x):
     """
     return call_intrin(x.dtype, "tir.popcount", x)
 
+def q_multiply_shift(x, y, q, s):
+    """Execute a multiplication between two Q-numbers x and y
+    followed by a right shift s. The mathematical expression is:
+
+       out = round(x*y*2^-s)
+
+    More about Q-numbers here: https://en.wikipedia.org/wiki/Q_(number_format)
+    The rounding rule is to the nearest value, rounding half up
+    (i.e., round(x.1) = x and round(x.5) = x+1)
+
+    Parameters
+    ----------
+    x : PrimExpr
+        First Q-number
+    y : PrimExpr
+        Second Q-number
+    q : PrimExpr
+        Number of fractional bits in x and y. Needs to be > 0
+    s : PrimExpr
+        Integer shift
+
+    Returns
+    -------
+    out : PrimExpr
+        The result.
+    """
+    return call_intrin('int32', "tir.q_multiply_shift", x, y, q, s)
+
 def fmod(x, y):
     """Return the remainder of x divided by y with the same sign as x.
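For intuition, the behaviour of the new `q_multiply_shift` intrinsic can be sketched with a small NumPy model that mirrors the default lowering rule registered in src/target/intrin_rule.cc further down in this patch. This is an illustrative sketch only; the helper name `q_multiply_shift_ref` is hypothetical and not part of the patch.

import numpy as np

def q_multiply_shift_ref(x, y, q, s):
    """Model of the default lowering: widen to 64 bits, optionally
    pre-shift left by s (when s > 0), multiply, then right shift by
    q plus any remaining right shift, rounding half up."""
    x, y = np.int64(x), np.int64(y)
    left_shift = s if s > 0 else 0
    right_shift = 0 if s > 0 else -s
    prod = (x << left_shift) * y
    total_right_shift = right_shift + q
    nudge = np.int64(1) << (total_right_shift - 1)  # round half up
    return np.int32((prod + nudge) >> total_right_shift)

# 23 * (1/16): frexp(1/16) = (0.5, -3), and 0.5 as a Q31 number is 1 << 30,
# matching the relay test added later in this patch (multiplier=1073741824, shift=-3).
assert q_multiply_shift_ref(23, 1 << 30, q=31, s=-3) == 1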
diff --git a/src/relay/op/tensor/unary.cc b/src/relay/op/tensor/unary.cc
index 958b8b535873..fc61661566c3 100644
--- a/src/relay/op/tensor/unary.cc
+++ b/src/relay/op/tensor/unary.cc
@@ -290,6 +290,29 @@ This function takes a tensor, a minimum value `a_min`, and a maximum value `a_ma
     .set_attrs_type<ClipAttrs>()
     .set_support_level(3);
 
+// relay.fixed_point_multiply
+TVM_REGISTER_NODE_TYPE(FixedPointMultiplyAttrs);
+
+TVM_REGISTER_GLOBAL("relay.op._make.fixed_point_multiply")
+    .set_body_typed([](Expr a, int32_t multiplier, int32_t shift) {
+      auto attrs = make_object<FixedPointMultiplyAttrs>();
+      attrs->multiplier = multiplier;
+      attrs->shift = shift;
+      static const Op& op = Op::Get("fixed_point_multiply");
+      return Call(op, {a}, Attrs(attrs), {});
+    });
+
+RELAY_REGISTER_OP("fixed_point_multiply")
+    .describe(R"code(fixed point multiplication)code" TVM_ADD_FILELINE)
+    .set_num_inputs(1)
+    .add_argument("data", "Tensor", "The input tensor.")
+    .add_type_rel("Identity", IdentityRel)
+    .set_attr<TOpPattern>("TOpPattern", kElemWise)
+    .set_attr<TOpIsStateful>("TOpIsStateful", false)
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
+    .set_attrs_type<FixedPointMultiplyAttrs>()
+    .set_support_level(10);
+
 RELAY_REGISTER_UNARY_OP("floor")
     .describe(R"code(Returns the floor of input array, computed element-wise.
 )code" TVM_ADD_FILELINE)
diff --git a/src/relay/qnn/op/requantize.cc b/src/relay/qnn/op/requantize.cc
index bdeaf05c86bd..222d91021b19 100644
--- a/src/relay/qnn/op/requantize.cc
+++ b/src/relay/qnn/op/requantize.cc
@@ -153,9 +153,19 @@ Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale,
         static_cast<double>(input_scale_float) / static_cast<double>(output_scale_float);
     // Skip if input and output scales are same.
     if (!IsEqualScalar(input_scale, output_scale)) {
+      int32_t fixed_point_multiplier, shift;
+      std::tie(fixed_point_multiplier, shift) = GetFixedPointMultiplierShift(double_multiplier);
+
+      const bool is_upward_rounding = (param->rounding == "UPWARD");
+
+      // When using upward rounding (i.e., x.5 rounded to x+1), leverage
+      // the FixedPointMultiply operator
       scaled_int32_t =
-          FixedPointMultiply(scaled_int32_t, double_multiplier, input_shape, param->rounding);
+          (is_upward_rounding
+               ? FixedPointMultiply(scaled_int32_t, fixed_point_multiplier, shift)
+               : FixedPointMultiplyToNearest(scaled_int32_t, double_multiplier, input_shape));
     }
+
   } else {
     // This is per-channel (per=axis) quantization.
     std::vector<double> double_multipliers;
diff --git a/src/relay/qnn/util.cc b/src/relay/qnn/util.cc
index 4daa5c9334de..113038e327d7 100644
--- a/src/relay/qnn/util.cc
+++ b/src/relay/qnn/util.cc
@@ -30,25 +30,6 @@ namespace tvm {
 namespace relay {
 namespace qnn {
 
-/*
- * \brief Convert FP32 representation into fixed point representation.
- * \param double_multplier The input FP32 number.
- * \return The pair of multiplier and shift for fixed point representation.
- * \note Converts a floating point number so that it can be represented by
- * integers. The representation is
- * float_number = (significand) * 2^(exponent)
- *
- * The significand is a number between 0.5 and 1. This is represented by
- * an integer number. For example, if it is int32, then the decimal point
- * exists between bit 31 and 30 from LSB (or between first and second bit
- * from the left).
- *
- * Some examples are
- * 0.25 = (0.5) * 2^(-1)
- * 0.125 = (0.5) * 2^(-2)
- *
- * Credit to TFLite reference implementation.
- */
 std::pair<int32_t, int32_t> GetFixedPointMultiplierShift(double double_multiplier) {
   int32_t significand, exponent;
   if (double_multiplier == 0.) {
@@ -75,8 +56,8 @@ std::pair<int32_t, int32_t> GetFixedPointMultiplierShift(double double_multiplie
   return std::make_pair(significand, exponent);
 }
 
-Expr FixedPointMultiply(Expr tensor, double multiplier, const Array<IndexExpr>& input_shape,
-                        const std::string& rounding) {
+Expr FixedPointMultiplyToNearest(Expr tensor, double multiplier,
+                                 const Array<IndexExpr>& input_shape) {
   // Choose high precision datatype to be int64. This is for avoiding overflow
   // in multiplication of two int32 values.
   DataType hp_dtype = DataType::Int(64);
@@ -109,19 +90,15 @@ Expr FixedPointMultiply(Expr tensor, double multiplier, const Array<IndexExpr>&
   int64_t pos_rounding_value = (1ll << (total_right_shift - 1));
 
   Expr round_scalar;
-  if (rounding == "UPWARD") {
-    round_scalar = MakeConstantScalar(hp_dtype, pos_rounding_value);
-  } else if (rounding == "TONEAREST") {
-    auto pos_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value);
-    auto neg_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value - 1);
-    auto pos_rounder_t = Full(pos_rounder, input_shape, hp_dtype);
-    auto neg_rounder_t = Full(neg_rounder, input_shape, hp_dtype);
-    auto zero_t = Zeros(input_shape, hp_dtype);
-    round_scalar = Where(GreaterEqual(tensor, zero_t), pos_rounder_t, neg_rounder_t);
-  } else {
-    LOG(FATAL) << "Rounding mode " << rounding << " not supported.";
-  }
+  auto pos_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value);
+  auto neg_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value - 1);
+  auto pos_rounder_t = Full(pos_rounder, input_shape, hp_dtype);
+  auto neg_rounder_t = Full(neg_rounder, input_shape, hp_dtype);
+
+  auto zero_t = Zeros(input_shape, hp_dtype);
+  round_scalar = Where(GreaterEqual(tensor, zero_t), pos_rounder_t, neg_rounder_t);
+
   // Add the rounding scalar.
   tensor = Add(tensor, round_scalar);
 
diff --git a/src/relay/qnn/util.h b/src/relay/qnn/util.h
index 736b7361a300..72eb2a46b2ae 100644
--- a/src/relay/qnn/util.h
+++ b/src/relay/qnn/util.h
@@ -70,6 +70,27 @@ static inline int32_t GetQmax(const DataType& dtype) {
   }
 }
 
+/*
+ * \brief Convert FP32 representation into fixed point representation.
+ * \param double_multiplier The input FP32 number.
+ * \return The pair of multiplier and shift for fixed point representation.
+ * \note Converts a floating point number so that it can be represented by
+ * integers. The representation is
+ * float_number = (significand) * 2^(exponent)
+ *
+ * The significand is a number between 0.5 and 1. This is represented by
+ * an integer number. For example, if it is int32, then the decimal point
+ * exists between bit 31 and 30 from LSB (or between first and second bit
+ * from the left).
+ *
+ * Some examples are
+ * 0.25 = (0.5) * 2^(-1)
+ * 0.125 = (0.5) * 2^(-2)
+ *
+ * Credit to TFLite reference implementation.
+ */
+std::pair<int32_t, int32_t> GetFixedPointMultiplierShift(double double_multiplier);
+
 Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale,
                      const Expr& input_zero_point, const Expr& output_scale,
                      const Expr& output_zero_point, const RequantizeAttrs* param,
@@ -94,13 +115,12 @@ static inline int64_t get_const_int(const tvm::PrimExpr& x) {
 
 /*
  * \brief Fixed point multiplication between integer tensor with floating point
-   scalar.
+ * scalar. This implementation rounds to the nearest value when it is midway
+ * between two representable values.
  * \param tensor The quantized input tensor of dtype int64.
 * \param multiplier The scalar multiplier.
 * \param input_shape Shape of the input tensor.
- * \param rounding "UPWARD" or "TONEAREST". The rounding direction when the value
-   is midway between" "two representable values.
- * \return The sequence of Relay ops for fixed point multiplication.
+ * \return The sequence of Relay ops for fixed point multiplication with TONEAREST rounding.
 * \note Original compuation is scale_fp32 * quantized_tensor.  To convert into
 *       integer computation, the multiplication with fp32 scalar can be
 *
@@ -114,8 +134,8 @@ static inline int64_t get_const_int(const tvm::PrimExpr& x) {
 *       2) Round the result.
 *       3) Right shift the result
 */
-Expr FixedPointMultiply(Expr tensor, double multiplier, const Array<IndexExpr>& input_shape,
-                        const std::string& rounding);
+Expr FixedPointMultiplyToNearest(Expr tensor, double multiplier,
+                                 const Array<IndexExpr>& input_shape);
 
 /*
  * \brief Fixed point multiplication between integer tensor with floating point
diff --git a/src/relay/quantize/realize.cc b/src/relay/quantize/realize.cc
index ddf945a0b19f..ace2c2473173 100644
--- a/src/relay/quantize/realize.cc
+++ b/src/relay/quantize/realize.cc
@@ -113,7 +113,14 @@ inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype,
   } else if (static_cast<int>(factor) == factor) {
     return Multiply(data, MakeConstantScalar(dtype, factor));
   } else {
-    data = qnn::FixedPointMultiply(data, factor, data_shape, cfg->rounding);
+    if (cfg->rounding == "UPWARD") {
+      int32_t fixed_point_multiplier, shift;
+      std::tie(fixed_point_multiplier, shift) = qnn::GetFixedPointMultiplierShift(factor);
+      data = relay::FixedPointMultiply(data, fixed_point_multiplier, shift);
+    } else {
+      data = qnn::FixedPointMultiplyToNearest(data, factor, data_shape);
+    }
+
     return Cast(data, dtype);
   }
 }
@@ -164,8 +171,15 @@ Expr QuantizeRealize(const Call& ref_call, const Array<Expr>& new_args, const Ob
     return QRealizeIntExpr(data, dom_scale, n->dtype);
   } else {
     data = Cast(data, DataType::Int(64));
-    data = qnn::FixedPointMultiply(data, idom_scale_imm / odom_scale_imm,
-                                   ref_call->type_as<TensorTypeNode>()->shape, cfg->rounding);
+    if (cfg->rounding == "UPWARD") {
+      int32_t fixed_point_multiplier, shift;
+      std::tie(fixed_point_multiplier, shift) =
+          qnn::GetFixedPointMultiplierShift(idom_scale_imm / odom_scale_imm);
+      data = relay::FixedPointMultiply(data, fixed_point_multiplier, shift);
+    } else {
+      data = qnn::FixedPointMultiplyToNearest(data, idom_scale_imm / odom_scale_imm,
+                                              ref_call->type_as<TensorTypeNode>()->shape);
+    }
     data = Cast(Clip(data, clip_min_imm, clip_max_imm), n->dtype);
     return QRealizeIntExpr(data, dom_scale, n->dtype);
   }
diff --git a/src/relay/transforms/pattern_util.h b/src/relay/transforms/pattern_util.h
index adbd1bd44431..b3e36818870a 100644
--- a/src/relay/transforms/pattern_util.h
+++ b/src/relay/transforms/pattern_util.h
@@ -495,6 +495,14 @@ inline Expr Round(Expr x) {
 
 inline Expr Clip(Expr x, double a_min, double a_max) { return MakeClip(x, a_min, a_max); }
 
+inline Expr FixedPointMultiply(Expr x, int32_t multiplier, int32_t shift) {
+  static const Op& op = Op::Get("fixed_point_multiply");
+  auto attrs = make_object<FixedPointMultiplyAttrs>();
+  attrs->multiplier = multiplier;
+  attrs->shift = shift;
+  return Call(op, {x}, Attrs(attrs), {});
+}
+
 inline Expr Add(Expr lhs, Expr rhs) {
   static const Op& op = Op::Get("add");
   return Call(op, {lhs, rhs}, Attrs(), {});
diff --git a/src/target/intrin_rule.cc b/src/target/intrin_rule.cc
index 31fadf1ce5ac..fa0ee38d8130 100644
--- a/src/target/intrin_rule.cc
+++ b/src/target/intrin_rule.cc
@@ -115,6 +115,52 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.isinf")
       *rv = isinf(call->args[0]);
     });
 
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.q_multiply_shift")
+    .set_body([](const TVMArgs& args, TVMRetValue* rv) {
+      using tir::make_const;
+
+      PrimExpr e = args[0];
+      const tir::CallNode* call = e.as<tir::CallNode>();
+      CHECK(call != nullptr);
+
+      PrimExpr x = call->args[0];
+      PrimExpr y = call->args[1];
+      PrimExpr q = call->args[2];
+      PrimExpr s = call->args[3];
+
+      // Only int32 types are supported (any number of lanes is allowed)
+      CHECK(y.dtype().code() == DLDataTypeCode::kDLInt && y.dtype().bits() == 32);
+      CHECK(s.dtype().code() == DLDataTypeCode::kDLInt && s.dtype().bits() == 32);
+
+      DataType hp_dtype = DataType::Int(64, x.dtype().lanes());
+      DataType lp_dtype = DataType::Int(32, x.dtype().lanes());
+
+      // 1) Calculating the integer multiplier and integer shift
+      PrimExpr zero = make_const(s.dtype(), 0);
+      PrimExpr left_shift = tir::Select(s > zero, s, zero);
+      PrimExpr right_shift = tir::Select(s > zero, zero, -s);
+
+      // 2) Cast and Multiply the integer multiplier
+      PrimExpr one = make_const(hp_dtype, 1);
+      x = cast(hp_dtype, x);
+      y = cast(hp_dtype, y);
+      x = tir::Select(left_shift != zero, x << left_shift, x);
+
+      // 3) Perform the multiplication in higher precision.
+      x = x * y;
+
+      // 4) Find the rounding scalar
+      PrimExpr total_right_shift = right_shift + q;
+      PrimExpr pos_rounding_value = (one << (total_right_shift - 1));
+      x = x + pos_rounding_value;
+
+      // 5) Simply right shift the result to get the final output.
+      x = x >> total_right_shift;
+
+      // 6) The fixed point multiplication keeps the value in int32 range. Casting back to int32.
+      *rv = cast(lp_dtype, x);
+    });
+
 }  // namespace intrin
 }  // namespace codegen
 }  // namespace tvm
diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc
index d23662c78d37..3afb8810e774 100644
--- a/src/tir/op/builtin.cc
+++ b/src/tir/op/builtin.cc
@@ -89,6 +89,11 @@ TIR_DEFINE_BUILTIN_FUNC(if_then_else)
     .set_num_inputs(3)
     .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure));
 
+TIR_DEFINE_BUILTIN_FUNC(q_multiply_shift)
+    .set_num_inputs(3)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure))
+    .set_attr<bool>("TVectorizable", true);
+
 TIR_DEFINE_BUILTIN_FUNC(isnullptr).set_num_inputs(1).set_attr<TCallEffectKind>(
     "TCallEffectKind", Integer(CallEffectKind::kPure));
 
diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc
index a0ba8d655232..75a483c4d165 100644
--- a/src/tir/op/op.cc
+++ b/src/tir/op/op.cc
@@ -90,6 +90,12 @@ PrimExpr LargeUIntImm(DataType t, int64_t low, int64_t high) {
                    {make_const(DataType::UInt(32), low), make_const(DataType::UInt(32), high)});
 }
 
+// Q-multiplication
+PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s) {
+  return tir::Call(DataType::Int(32, x.dtype().lanes()), tir::builtin::q_multiply_shift(),
+                   {x, y, q, s});
+}
+
 // The public function with a quick checking path.
 void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs) {  // NOLINT(*)
   if (lhs.dtype() == rhs.dtype()) return;
diff --git a/src/tir/transforms/lower_intrin.cc b/src/tir/transforms/lower_intrin.cc
index 5372ef8fb1fa..1c529d86523e 100644
--- a/src/tir/transforms/lower_intrin.cc
+++ b/src/tir/transforms/lower_intrin.cc
@@ -40,8 +40,15 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
   using IRMutatorWithAnalyzer::VisitExpr_;
   using IRMutatorWithAnalyzer::VisitStmt_;
 
-  IntrinInjecter(arith::Analyzer* analyzer, std::string target) : IRMutatorWithAnalyzer(analyzer) {
+  IntrinInjecter(arith::Analyzer* analyzer, std::string target, std::string mtriple = "")
+      : IRMutatorWithAnalyzer(analyzer) {
     patterns_.push_back("tvm.intrin.rule." + target + ".");
+ target + "."); + + bool is_llvm_aarch64 = (mtriple.find("aarch64") != std::string::npos); + if (is_llvm_aarch64) { + patterns_.push_back("tvm.intrin.rule." + target + "." + "aarch64."); + } + patterns_.push_back("tvm.intrin.rule.default."); fma_ = runtime::Registry::Get(patterns_[0] + "fma"); if (target == "stackvm") { @@ -287,7 +294,9 @@ Pass LowerIntrin() { auto target = f->GetAttr(tvm::attr::kTarget); CHECK(target.defined()) << "LowerIntrin: Require the target attribute"; arith::Analyzer analyzer; - n->body = IntrinInjecter(&analyzer, target.value()->id->name)(std::move(n->body)); + auto mtriple = target.value()->GetAttr("mtriple", ""); + n->body = + IntrinInjecter(&analyzer, target.value()->id->name, mtriple.value())(std::move(n->body)); return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.LowerIntrin", {}); diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index 115900fea0f3..93b44bfd3504 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -84,6 +84,22 @@ def test_clip(): ref_res = np.clip(data, 1., 4.) np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01) +def test_fixed_point_multiply(): + # Test 23 * 1/16 + # [m,s] = [0.5, -3] = frexp(1/16) + # M = 0.5*2^31 = 1073741824 + # so M = 1073741824 and s = -3 + + a = relay.var("a", relay.TensorType((10, 4), "int32")) + y = relay.fixed_point_multiply(a, 1073741824, -3) + yy = run_infer_type(y) + assert yy.checked_type == relay.TensorType((10, 4), "int32") + + data = 23*np.ones((10, 4)).astype('int32') + intrp = create_executor() + op_res = intrp.evaluate(y, { a: relay.const(data) }) + ref_res = np.ones((10, 4)).astype('int32') + np.testing.assert_allclose(op_res.asnumpy(), ref_res, atol=1) def test_reinterpret(): a = relay.var("a", relay.TensorType((1000, 4), "float32")) @@ -1034,3 +1050,4 @@ def verify_sparse_to_dense(sparse_indices, sparse_values, default_value, output_ test_isinf() test_unravel_index() test_sparse_to_dense() + test_fixed_point_multiply() diff --git a/topi/python/topi/arm_cpu/conv2d_gemm.py b/topi/python/topi/arm_cpu/conv2d_gemm.py index 63d96bb44d92..e97de56a0b65 100644 --- a/topi/python/topi/arm_cpu/conv2d_gemm.py +++ b/topi/python/topi/arm_cpu/conv2d_gemm.py @@ -119,7 +119,7 @@ def compute_conv2d_gemm_without_weight_transform(cfg, C = te.compute((batches, M, N), lambda b, x, y: C_interleaved[b, x // 4, y // 4, idxm(x, 4), idxm(y, 4)], - name="C", tag='injective') + name="C") # --- Produce the conv output out_shape = (batches, OH, OW, OC) @@ -129,7 +129,7 @@ def compute_conv2d_gemm_without_weight_transform(cfg, return out # Schedules -def schedule_conv2d_gemm(cfg, s, out): +def schedule_conv2d_gemm(cfg, s, out, final_out): """Create schedule for tensors""" C = out.op.input_tensors[0] C_interleaved = C.op.input_tensors[0] @@ -172,8 +172,11 @@ def schedule_conv2d_gemm(cfg, s, out): s[C_interleaved].tensorize(yi, gem_v_dotprod) # Output transform - N, OH, OW, OC = out.shape - s[C].split(C.op.axis[1], OW) - s[C].compute_at(s[out], out.op.axis[3]) + if out != final_out: + n, h, w, c = out.op.axis + _, inner = s[out].split(c, 4) + s[C].compute_at(s[out], inner) + s[out].vectorize(inner) + return s diff --git a/topi/python/topi/arm_cpu/conv2d_int8.py b/topi/python/topi/arm_cpu/conv2d_int8.py index 5a895c084c06..89a37fa41294 100644 --- a/topi/python/topi/arm_cpu/conv2d_int8.py +++ b/topi/python/topi/arm_cpu/conv2d_int8.py @@ -137,11 +137,23 @@ def compute_conv2d_NHWC_quantized_without_transform(cfg, data, B, strides, paddi def 
     """Create schedule for tensors"""
     s = te.create_schedule([x.op for x in outs])
+    # Vectorize the output and then inline all the rest
+    out = outs[0]
+    n, h, w, c = out.op.axis
+    outer, inner = s[out].split(c, 4)
+    s[out].vectorize(inner)
 
     def _callback(op):
         """Traverse operators from computation graph"""
         if op.name == "conv2d_gemm_output":
-            schedule_conv2d_gemm(cfg, s, op.output(0))
+            conv_out = op.output(0)
+            schedule_conv2d_gemm(cfg, s, conv_out, out)
+            if out != conv_out:
+                s[conv_out].compute_at(s[out], inner)
+            else:
+                C = conv_out.op.input_tensors[0]
+                s[C].compute_at(s[out], inner)
+
     traverse_inline(s, outs[0].op, _callback)
     return s
diff --git a/topi/python/topi/arm_cpu/injective.py b/topi/python/topi/arm_cpu/injective.py
index 966520088bc7..3e3c73d26553 100644
--- a/topi/python/topi/arm_cpu/injective.py
+++ b/topi/python/topi/arm_cpu/injective.py
@@ -62,9 +62,10 @@ def schedule_injective(outs):
     outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
     s = te.create_schedule([x.op for x in outs])
     x = outs[0]
+
     if list(s[x].op.axis):
         # do not vectorize for broadcast
-        (io, ii) = s[x].split(list(s[x].op.axis)[-1], 8)
+        (io, ii) = s[x].split(list(s[x].op.axis)[-1], 4)
         s[x].vectorize(ii)
     tvm.te.schedule.AutoInlineInjective(s)
 
diff --git a/topi/python/topi/arm_cpu/tensor_intrin.py b/topi/python/topi/arm_cpu/tensor_intrin.py
index dfa2f05e7960..270bfbe87766 100644
--- a/topi/python/topi/arm_cpu/tensor_intrin.py
+++ b/topi/python/topi/arm_cpu/tensor_intrin.py
@@ -451,3 +451,55 @@ def _instr(index):
     return te.decl_tensor_intrin(
         C.op, _intrin_func, binds={data:a_buffer, kernel:b_buffer},
         default_buffer_params=buffer_params)
+
+def _q_multiply_shift_arm(op):
+    """
+    Implementation of q_multiply_shift_arm through arm intrinsics
+    sqrdmulh and srshl when q == 31.
+
+    Please note that this is introducing a small round-up error for
+    some corner cases. This is because we are rounding twice instead
+    of only once. I.e.:
+
+    * original q_multiply_shift: round(x*y*2^-s)
+    * arm q_multiply_shift: round(round(x*y)*2^-s)
+    """
+    x = op.args[0]
+    y = op.args[1]
+    q = op.args[2]
+    s = op.args[3]
+
+    # Don't use this intrinsic if we don't have an int32x4 vector
+    # or if we are not multiplying q31 numbers
+    if x.dtype != "int32x4" or q.value != 31:
+        return op
+
+    # Case 1, shift is negative
+    sqrdmulh = tvm.tir.call_llvm_intrin(op.dtype,
+                                        'llvm.aarch64.neon.sqrdmulh',
+                                        tvm.tir.const(2, 'uint32'),
+                                        x,
+                                        y)
+
+    fixup = (sqrdmulh & (-s)) >> 31
+    fixed_up_x = (sqrdmulh + fixup)
+    out_1 = tvm.tir.call_llvm_intrin(op.dtype,
+                                     'llvm.aarch64.neon.srshl',
+                                     tvm.tir.const(2, 'uint32'),
+                                     sqrdmulh,
+                                     s)
+
+    # Case 2, shift is positive
+    x = x * (1 << (s))
+    out_2 = tvm.tir.call_llvm_intrin(op.dtype,
+                                     'llvm.aarch64.neon.sqrdmulh',
+                                     tvm.tir.const(2, 'uint32'),
+                                     x,
+                                     y)
+
+    # Select depending on the shift
+    return tvm.tir.Select(s < 0, out_1, out_2)
+
+tvm.target.intrin.register_intrin_rule("llvm.aarch64",
+                                       "q_multiply_shift",
+                                       _q_multiply_shift_arm, override=True)
diff --git a/topi/python/topi/math.py b/topi/python/topi/math.py
index b4228a4a9178..046b10342c0b 100644
--- a/topi/python/topi/math.py
+++ b/topi/python/topi/math.py
@@ -612,6 +612,33 @@ def _compute(*indices):
         return tvm.te.max(tvm.te.min(value, const_max), const_min)
     return te.compute(x.shape, _compute)
 
+@tvm.te.tag_scope(tag=tag.ELEMWISE)
+def fixed_point_multiply(x, multiplier, shift):
+    """Fixed point multiplication between data and a fixed point
+    constant expressed as multiplier * 2^(shift), where multiplier
+    is a Q-number with 31 fractional bits.
+
+    Parameters
+    ----------
+    x : tvm.te.Tensor or Expr
+        Input argument.
+    multiplier : int
+        Multiplier of a fixed floating point number described as multiplier*2^(shift).
+    shift : int
+        Shift of a fixed floating point number described as multiplier*2^(shift).
+
+    Returns
+    -------
+    y : tvm.te.Tensor
+        The result.
+    """
+    def _compute(*indices):
+        value = x(*indices)
+        return tvm.tir.q_multiply_shift(value,
+                                        tvm.tir.const(multiplier, 'int32'),
+                                        tvm.tir.const(31, 'int32'),
+                                        tvm.tir.const(shift, 'int32'))
+    return te.compute(x.shape, _compute)
 
 def cast(x, dtype):
     """Cast input to specified data type.
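To see how the (multiplier, shift) pair consumed by relay.fixed_point_multiply is obtained from a floating point scale, the following sketch mirrors GetFixedPointMultiplierShift from src/relay/qnn/util.cc using Python's math.frexp. It is illustrative only; the helper name to_fixed_point is hypothetical and not part of the patch.

import math

def to_fixed_point(double_multiplier):
    """Return (multiplier, shift) where multiplier is the Q31 significand
    and double_multiplier == (multiplier / 2**31) * 2**shift."""
    if double_multiplier == 0.0:
        return 0, 0
    significand, exponent = math.frexp(double_multiplier)  # significand in [0.5, 1)
    significand_q31 = int(round(significand * (1 << 31)))
    if significand_q31 == (1 << 31):  # rounding pushed the significand up to 1.0
        significand_q31 //= 2
        exponent += 1
    return significand_q31, exponent

multiplier, shift = to_fixed_point(1.0 / 16.0)
assert (multiplier, shift) == (1 << 30, -3)

# These are exactly the constants used by test_fixed_point_multiply above:
#   y = relay.fixed_point_multiply(a, multiplier, shift)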