13 changes: 13 additions & 0 deletions include/tvm/relay/attrs/transform.h
@@ -298,6 +298,19 @@ struct ClipAttrs : public tvm::AttrsNode<ClipAttrs> {
}
};

/*! \brief Attributes for FixedPointMultiply operator */
struct FixedPointMultiplyAttrs : public tvm::AttrsNode<FixedPointMultiplyAttrs> {
int32_t multiplier;
int32_t shift;

TVM_DECLARE_ATTRS(FixedPointMultiplyAttrs, "relay.attrs.FixedPointMultiplyAttrs") {
TVM_ATTR_FIELD(multiplier)
.describe("Multiplier of a fixed point number described as multiplier*2^(shift)");
TVM_ATTR_FIELD(shift).describe(
"Shift of a fixed point number described as multiplier*2^(shift)");
}
};

/*! \brief Attributes for LayoutTransform operator */
struct LayoutTransformAttrs : public tvm::AttrsNode<LayoutTransformAttrs> {
std::string src_layout;
8 changes: 8 additions & 0 deletions include/tvm/tir/builtin.h
@@ -92,6 +92,14 @@ TVM_DLL const Op& shift_right();
*/
TVM_DLL const Op& large_uint_imm();

/*!
* \brief Execute a multiplication between two Q-numbers x and y
* followed by a right shift s
* The default rounding rule is to the nearest value, rounding half up
* (i.e., round(x.1) = x and round(x.5) = x+1)
*/
TVM_DLL const Op& q_multiply_shift();

/*!
* \brief See pseudo code
*
21 changes: 21 additions & 0 deletions include/tvm/tir/op.h
@@ -552,6 +552,27 @@ TVM_DLL PrimExpr trunc(PrimExpr x);
*/
TVM_DLL PrimExpr LargeUIntImm(DataType dtype, int64_t low, int64_t high);

/*!
* \brief Execute a multiplication between two Q-numbers x and y
* followed by a right shift s. The mathematical expression is:
*
* out = round(x*y*2^-s)
*
* Please note that the two Q-numbers x and y are supposed to have
* the same number of fractional bits q.
*
* More about Q-numbers here: https://en.wikipedia.org/wiki/Q_(number_format)
*
* The rounding rule is to the nearest value, rounding half up
* (i.e., round(x.1) = x and round(x.5) = x+1)
* \param x first Q-number
* \param y second Q-number
* \param q number of fractional bits in x and y. Needs to be > 0
* \param s integer right shift
* \return The constructed expression.
*/
TVM_DLL PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s);
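For illustration, a worked instance of the formula above: with q = 31 fractional bits, the integer 2^30 represents 0.5 (since 2^30 * 2^-31 = 0.5). Taking x = y = 2^30 and s = 31,

out = round(2^30 * 2^30 * 2^-31) = 2^29,

which, read as a Q31 number, is 0.25 = 0.5 * 0.5, the expected product.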

// Intrinsic operators
#define TVM_DECLARE_INTRIN_UNARY(OpName) \
inline PrimExpr OpName(PrimExpr x) { \
8 changes: 8 additions & 0 deletions python/tvm/relay/op/_tensor.py
@@ -131,6 +131,14 @@ def clip_compute(attrs, inputs, output_type):

register_injective_schedule("clip")

# fixed point multiply
@register_compute("fixed_point_multiply")
def fixed_point_multiply_compute(attrs, inputs, output_type):
assert len(inputs) == 1
return [topi.fixed_point_multiply(inputs[0], attrs.multiplier, attrs.shift)]

register_injective_schedule("fixed_point_multiply")

# full
@script
def _full_shape_func(shape):
21 changes: 21 additions & 0 deletions python/tvm/relay/op/tensor.py
@@ -1034,6 +1034,27 @@ def clip(a, a_min, a_max):
"""
return _make.clip(a, a_min, a_max)

def fixed_point_multiply(data, multiplier, shift):
"""Fixed point multiplication between data and a fixed point
constant expressed as multiplier * 2^(shift), where multiplier
is a Q-number with 31 fractional bits.

Parameters
----------
data : relay.Expr
The input tensor.
multiplier : int
The integer multiplier of the fixed point constant.
shift : int
The integer shift of the fixed point constant.

Returns
-------
result : relay.Expr
The output of the fixed point multiplication
"""
return _make.fixed_point_multiply(data, multiplier, shift)
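A minimal usage sketch of the new operator (illustrative only; the tensor shape and the multiplier/shift values are assumptions, of the kind produced by the fixed point conversion helper on the C++ side):

import tvm
from tvm import relay

x = relay.var("x", shape=(1, 64), dtype="int32")
# Hypothetical Q31 multiplier and shift; 2^30 is 0.5 in Q31.
y = relay.fixed_point_multiply(x, multiplier=1073741824, shift=-1)
mod = tvm.IRModule.from_expr(relay.Function([x], y))
print(mod)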


def concatenate(data, axis):
"""Concatenate the input tensors along the given axis.
1 change: 1 addition & 0 deletions python/tvm/tir/__init__.py
@@ -45,6 +45,7 @@
from .op import isnan, isfinite, isinf, copysign
from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod
from .op import comm_reducer, min, max, sum
from .op import q_multiply_shift

from . import ir_builder
from . import transform
28 changes: 28 additions & 0 deletions python/tvm/tir/op.py
@@ -965,6 +965,34 @@ def popcount(x):
"""
return call_intrin(x.dtype, "tir.popcount", x)

def q_multiply_shift(x, y, q, s):
"""Execute a multiplication between two Q-numbers x and y
followed by a right shift s. The mathematical expression is:

out = round(x*y*2^-s)

More about Q-numbers here: https://en.wikipedia.org/wiki/Q_(number_format)
The rounding rule is to the nearest value, rounding half up
(i.e., round(x.1) = x and round(x.5) = x+1)

Parameters
----------
x : PrimExpr
First Q-number
y : PrimExpr
Second Q-number
q : PrimExpr
Number of fractional bits in x and y. Needs to be > 0
s : PrimExpr
Integer shift

Returns
-------
out : PrimExpr
The result.
"""
return call_intrin('int32', "tir.q_multiply_shift", x, y, q, s)
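A small sketch of building a TIR expression with the intrinsic (the variable names and constant values are illustrative assumptions):

import tvm
from tvm import tir

x = tir.Var("x", "int32")                # first Q-number
y = tir.Var("y", "int32")                # second Q-number
q = tir.IntImm("int32", 31)              # 31 fractional bits
s = tir.IntImm("int32", 31)              # right shift
expr = tir.q_multiply_shift(x, y, q, s)  # builds a call to "tir.q_multiply_shift"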

def fmod(x, y):
"""Return the remainder of x divided by y with the same sign as x.

23 changes: 23 additions & 0 deletions src/relay/op/tensor/unary.cc
@@ -290,6 +290,29 @@ This function takes a tensor, a minimum value `a_min`, and a maximum value `a_ma
.set_attrs_type<ClipAttrs>()
.set_support_level(3);

// relay.fixed_point_multiply
TVM_REGISTER_NODE_TYPE(FixedPointMultiplyAttrs);

TVM_REGISTER_GLOBAL("relay.op._make.fixed_point_multiply")
.set_body_typed([](Expr a, int32_t multiplier, int32_t shift) {
auto attrs = make_object<FixedPointMultiplyAttrs>();
attrs->multiplier = multiplier;
attrs->shift = shift;
static const Op& op = Op::Get("fixed_point_multiply");
return Call(op, {a}, Attrs(attrs), {});
});

RELAY_REGISTER_OP("fixed_point_multiply")
.describe(R"code(fixed point multiplication)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.add_type_rel("Identity", IdentityRel)
.set_attr<TOpPattern>("TOpPattern", kElemWise)
.set_attr<TOpIsStateful>("TOpIsStateful", false)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
.set_attrs_type<FixedPointMultiplyAttrs>()
.set_support_level(10);

RELAY_REGISTER_UNARY_OP("floor")
.describe(R"code(Returns the floor of input array, computed element-wise.
)code" TVM_ADD_FILELINE)
12 changes: 11 additions & 1 deletion src/relay/qnn/op/requantize.cc
@@ -153,9 +153,19 @@ Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale,
static_cast<double>(input_scale_float) / static_cast<double>(output_scale_float);
// Skip if input and output scales are same.
if (!IsEqualScalar(input_scale, output_scale)) {
int32_t fixed_point_multiplier, shift;
std::tie(fixed_point_multiplier, shift) = GetFixedPointMultiplierShift(double_multiplier);

const bool is_upward_rounding = (param->rounding == "UPWARD");

// When using upward rounding (i.e., x.5 rounded to x+1), leverage
// the FixedPointMultiply operator
scaled_int32_t =
FixedPointMultiply(scaled_int32_t, double_multiplier, input_shape, param->rounding);
(is_upward_rounding
? FixedPointMultiply(scaled_int32_t, fixed_point_multiplier, shift)
: FixedPointMultiplyToNearest(scaled_int32_t, double_multiplier, input_shape));
}

} else {
// This is per-channel (per=axis) quantization.
std::vector<double> double_multipliers;
43 changes: 10 additions & 33 deletions src/relay/qnn/util.cc
@@ -30,25 +30,6 @@ namespace tvm {
namespace relay {
namespace qnn {

/*
* \brief Convert FP32 representation into fixed point representation.
* \param double_multplier The input FP32 number.
* \return The pair of multiplier and shift for fixed point representation.
* \note Converts a floating point number so that it can be represented by
* integers. The representation is
* float_number = (significand) * 2^(exponent)
*
* The significand is a number between 0.5 and 1. This is represented by
* an integer number. For example, if it is int32, then the decimal point
* exists between bit 31 and 30 from LSB (or between first and second bit
* from the left).
*
* Some examples are
* 0.25 = (0.5) * 2^(-1)
* 0.125 = (0.5) * 2^(-2)
*
* Credit to TFLite reference implementation.
*/
std::pair<int32_t, int32_t> GetFixedPointMultiplierShift(double double_multiplier) {
int32_t significand, exponent;
if (double_multiplier == 0.) {
@@ -75,8 +56,8 @@ std::pair<int32_t, int32_t> GetFixedPointMultiplierShift(double double_multiplie
return std::make_pair(significand, exponent);
}

Expr FixedPointMultiply(Expr tensor, double multiplier, const Array<IndexExpr>& input_shape,
const std::string& rounding) {
Expr FixedPointMultiplyToNearest(Expr tensor, double multiplier,
const Array<IndexExpr>& input_shape) {
// Choose high precision datatype to be int64. This is for avoiding overflow
// in multiplication of two int32 values.
DataType hp_dtype = DataType::Int(64);
@@ -109,19 +90,15 @@ Expr FixedPointMultiply(Expr tensor, double multiplier, const Array<IndexExpr>&
int64_t pos_rounding_value = (1ll << (total_right_shift - 1));

Expr round_scalar;
if (rounding == "UPWARD") {
round_scalar = MakeConstantScalar(hp_dtype, pos_rounding_value);
} else if (rounding == "TONEAREST") {
auto pos_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value);
auto neg_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value - 1);
auto pos_rounder_t = Full(pos_rounder, input_shape, hp_dtype);
auto neg_rounder_t = Full(neg_rounder, input_shape, hp_dtype);

auto zero_t = Zeros(input_shape, hp_dtype);
round_scalar = Where(GreaterEqual(tensor, zero_t), pos_rounder_t, neg_rounder_t);
} else {
LOG(FATAL) << "Rounding mode " << rounding << " not supported.";
}
auto pos_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value);
auto neg_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value - 1);
auto pos_rounder_t = Full(pos_rounder, input_shape, hp_dtype);
auto neg_rounder_t = Full(neg_rounder, input_shape, hp_dtype);

auto zero_t = Zeros(input_shape, hp_dtype);
round_scalar = Where(GreaterEqual(tensor, zero_t), pos_rounder_t, neg_rounder_t);

// Add the rounding scalar.
tensor = Add(tensor, round_scalar);

32 changes: 26 additions & 6 deletions src/relay/qnn/util.h
@@ -70,6 +70,27 @@ static inline int32_t GetQmax(const DataType& dtype) {
}
}

/*
* \brief Convert FP32 representation into fixed point representation.
* \param double_multiplier The input FP32 number.
* \return The pair of multiplier and shift for fixed point representation.
* \note Converts a floating point number so that it can be represented by
* integers. The representation is
* float_number = (significand) * 2^(exponent)
*
* The significand is a number between 0.5 and 1. This is represented by
* an integer number. For example, if it is int32, then the decimal point
* exists between bit 31 and 30 from LSB (or between first and second bit
* from the left).
*
* Some examples are
* 0.25 = (0.5) * 2^(-1)
* 0.125 = (0.5) * 2^(-2)
*
* Credit to TFLite reference implementation.
*/
std::pair<int32_t, int32_t> GetFixedPointMultiplierShift(double double_multiplier);
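A Python sketch of the decomposition described above (for illustration only; the C++ implementation in util.cc is the reference, and the handling of the case where rounding overflows the significand is an assumption):

import math

def fixed_point_multiplier_shift(double_multiplier):
    # double_multiplier ~= (significand / 2**31) * 2**exponent
    if double_multiplier == 0.0:
        return 0, 0
    significand, exponent = math.frexp(double_multiplier)  # significand in [0.5, 1)
    significand = int(round(significand * (1 << 31)))      # Q31 integer representation
    if significand == (1 << 31):                           # rounding reached 1.0 (assumed handling)
        significand //= 2
        exponent += 1
    return significand, exponent

# fixed_point_multiplier_shift(0.25) -> (1073741824, -1), i.e. 0.5 * 2^(-1)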

Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale,
const Expr& input_zero_point, const Expr& output_scale,
const Expr& output_zero_point, const RequantizeAttrs* param,
@@ -94,13 +115,12 @@ static inline int64_t get_const_int(const tvm::PrimExpr& x) {

/*
* \brief Fixed point multiplication between integer tensor with floating point
scalar.
* scalar. This implementation rounds to the nearest value when it is midway
* between two representable values.
* \param tensor The quantized input tensor of dtype int64.
* \param multiplier The scalar multiplier.
* \param input_shape Shape of the input tensor.
* \param rounding "UPWARD" or "TONEAREST". The rounding direction when the value
is midway between" "two representable values.
* \return The sequence of Relay ops for fixed point multiplication.
* \return The sequence of Relay ops for fixed point multiplication with TONEAREST rounding.

* \note Original computation is scale_fp32 * quantized_tensor. To convert into
* integer computation, the multiplication with fp32 scalar can be
Expand All @@ -114,8 +134,8 @@ static inline int64_t get_const_int(const tvm::PrimExpr& x) {
* 2) Round the result.
* 3) Right shift the result
*/
Expr FixedPointMultiply(Expr tensor, double multiplier, const Array<IndexExpr>& input_shape,
const std::string& rounding);
Expr FixedPointMultiplyToNearest(Expr tensor, double multiplier,
const Array<IndexExpr>& input_shape);
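A worked instance of the three steps above (a sketch; it assumes the total right shift is 31 minus the exponent returned by GetFixedPointMultiplierShift): multiplying a tensor element 100 by the scale 0.25,

GetFixedPointMultiplierShift(0.25) -> multiplier = 2^30, exponent = -1
1) upcast to int64 and multiply:    100 * 2^30 = 107374182400
2) add the rounding value 2^(32-1): 107374182400 + 2147483648 = 109521666048
3) right shift by 32:               109521666048 >> 32 = 25 = round(100 * 0.25)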

/*
* \brief Fixed point multiplication between integer tensor with floating point
20 changes: 17 additions & 3 deletions src/relay/quantize/realize.cc
@@ -113,7 +113,14 @@ inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype,
} else if (static_cast<int>(factor) == factor) {
return Multiply(data, MakeConstantScalar(dtype, factor));
} else {
data = qnn::FixedPointMultiply(data, factor, data_shape, cfg->rounding);
if (cfg->rounding == "UPWARD") {
int32_t fixed_point_multiplier, shift;
std::tie(fixed_point_multiplier, shift) = qnn::GetFixedPointMultiplierShift(factor);
data = relay::FixedPointMultiply(data, fixed_point_multiplier, shift);
} else {
data = qnn::FixedPointMultiplyToNearest(data, factor, data_shape);
}

return Cast(data, dtype);
}
}
@@ -164,8 +171,15 @@ Expr QuantizeRealize(const Call& ref_call, const Array<Expr>& new_args, const Ob
return QRealizeIntExpr(data, dom_scale, n->dtype);
} else {
data = Cast(data, DataType::Int(64));
data = qnn::FixedPointMultiply(data, idom_scale_imm / odom_scale_imm,
ref_call->type_as<TensorTypeNode>()->shape, cfg->rounding);
if (cfg->rounding == "UPWARD") {
int32_t fixed_point_multiplier, shift;
std::tie(fixed_point_multiplier, shift) =
qnn::GetFixedPointMultiplierShift(idom_scale_imm / odom_scale_imm);
data = relay::FixedPointMultiply(data, fixed_point_multiplier, shift);
} else {
data = qnn::FixedPointMultiplyToNearest(data, idom_scale_imm / odom_scale_imm,
ref_call->type_as<TensorTypeNode>()->shape);
}
data = Cast(Clip(data, clip_min_imm, clip_max_imm), n->dtype);
return QRealizeIntExpr(data, dom_scale, n->dtype);
}
8 changes: 8 additions & 0 deletions src/relay/transforms/pattern_util.h
@@ -495,6 +495,14 @@ inline Expr Round(Expr x) {

inline Expr Clip(Expr x, double a_min, double a_max) { return MakeClip(x, a_min, a_max); }

inline Expr FixedPointMultiply(Expr x, int32_t multiplier, int32_t shift) {
static const Op& op = Op::Get("fixed_point_multiply");
auto attrs = make_object<FixedPointMultiplyAttrs>();
attrs->multiplier = multiplier;
attrs->shift = shift;
return Call(op, {x}, Attrs(attrs), {});
}

inline Expr Add(Expr lhs, Expr rhs) {
static const Op& op = Op::Get("add");
return Call(op, {lhs, rhs}, Attrs(), {});