Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,23 @@ struct FixedPointMultiplyAttrs : public tvm::AttrsNode<FixedPointMultiplyAttrs>
}
};

/*!
 * \brief Attributes for per channel/per axes FixedPointMultiply operator
 *
 * Both shift flags default to false; no default is declared for `axes`.
 */
struct FixedPointMultiplyPerAxisAttrs : public tvm::AttrsNode<FixedPointMultiplyPerAxisAttrs> {
/*! \brief Whether left shift is required in fixed point multiplication. */
bool is_lshift_required;
/*! \brief Whether right shift is required in fixed point multiplication. */
bool is_rshift_required;
/*! \brief List of axes on which input data was quantized. */
Array<Integer> axes;

// Declares the fields above under the "relay.attrs.FixedPointMultiplyPerAxisAttrs" key,
// attaching a description to each one and a default to the two flags.
TVM_DECLARE_ATTRS(FixedPointMultiplyPerAxisAttrs, "relay.attrs.FixedPointMultiplyPerAxisAttrs") {
TVM_ATTR_FIELD(is_lshift_required)
.describe("Whether left shift is required in fixed point multiplication.")
.set_default(false);
TVM_ATTR_FIELD(is_rshift_required)
.describe("Whether right shift is required in fixed point multiplication.")
.set_default(false);
TVM_ATTR_FIELD(axes).describe("List of axes on which input data was quantized.");
}
};

/*! \brief Attributes for LayoutTransform operator */
struct LayoutTransformAttrs : public tvm::AttrsNode<LayoutTransformAttrs> {
std::string src_layout;
Expand Down
13 changes: 13 additions & 0 deletions python/tvm/relay/op/_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,19 @@ def fixed_point_multiply_compute(attrs, inputs, output_type):

register_injective_schedule("fixed_point_multiply")

# per-channel/per-axis fixed point multiply
@register_compute("fixed_point_multiply_per_axis")
def fixed_point_multiply_per_axis_compute(attrs, inputs, output_type):
    """Compute definition for the per-channel/per-axis fixed point multiply.

    Parameters
    ----------
    attrs : object
        Operator attributes; ``is_lshift_required``, ``is_rshift_required`` and
        ``axes`` are read from it.
    inputs : sequence
        Exactly four inputs, forwarded positionally to
        ``topi.fixed_point_multiply_per_axis``.
    output_type : object
        Unused by this compute.

    Returns
    -------
    outputs : list
        Single-element list holding the topi result.
    """
    assert len(inputs) == 4
    result = topi.fixed_point_multiply_per_axis(
        *inputs,
        attrs.is_lshift_required,
        attrs.is_rshift_required,
        attrs.axes,
    )
    return [result]


# Register the broadcast schedule for the fixed_point_multiply_per_axis operator.
register_broadcast_schedule("fixed_point_multiply_per_axis")

# full
@script
def _full_shape_func(shape):
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@
from .op import likely, isnan, isnullptr, isfinite, isinf, copysign
from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod, ceildiv
from .op import comm_reducer, min, max, sum
from .op import q_multiply_shift, shift_left, shift_right
from .op import q_multiply_shift, q_multiply_shift_per_axis, shift_left, shift_right
from .op import TVMBackendAllocWorkspace, TVMBackendFreeWorkspace
from .generic import add, subtract, multiply

Expand Down
54 changes: 48 additions & 6 deletions python/tvm/tir/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@
import tvm._ffi
from tvm.ir.base import Span
from tvm.runtime import convert, const
from tvm.ir import Array, Op
from tvm.ir import Array, Op, PrimExpr

from .buffer import Buffer
from .expr import Call, PrimExprWithOp, StringImm, Var, CommReducer
from .expr import Call, PrimExprWithOp, StringImm, Var, CommReducer, IntImm
from . import _ffi_api


Expand Down Expand Up @@ -263,8 +263,6 @@ def call_llvm_intrin(dtype, name, *args, span=None):
# pylint: disable=import-outside-toplevel
from tvm.target import codegen

from .expr import IntImm

if isinstance(name, str):
llvm_id = codegen.llvm_lookup_intrinsic_id(name)
elif isinstance(name, IntImm):
Expand Down Expand Up @@ -307,8 +305,6 @@ def call_llvm_pure_intrin(dtype, name, *args, span=None):
# pylint: disable=import-outside-toplevel
from tvm.target import codegen

from .expr import IntImm

if isinstance(name, str):
llvm_id = codegen.llvm_lookup_intrinsic_id(name)
elif isinstance(name, IntImm):
Expand Down Expand Up @@ -2238,6 +2234,52 @@ def q_multiply_shift(x, y, q, s):
return call_intrin("int32", "tir.q_multiply_shift", x, y, q, s)


def q_multiply_shift_per_axis(
    x: PrimExpr,
    y: PrimExpr,
    ls: PrimExpr,
    rs: PrimExpr,
    q: IntImm,
    is_lshift_required: IntImm,
    is_rshift_required: IntImm,
):
    """Multiply two Q-numbers x and y, with separate left and right shifts.

    Builds a call to the ``tir.q_multiply_shift_per_axis`` intrinsic with an
    ``int32`` result type; the arguments are forwarded in the order listed below.

    Parameters
    ----------
    x : PrimExpr
        First Q-number.
    y : PrimExpr
        Second Q-number.
    ls : PrimExpr
        Integer left shift.
    rs : PrimExpr
        Integer right shift.
    q : IntImm
        Number of fractional bits in x and y. Needs to be > 0.
    is_lshift_required : IntImm
        Whether we need to do left shift or not.
    is_rshift_required : IntImm
        Whether we need to do right shift or not.

    Returns
    -------
    z : PrimExpr
        The result.
    """
    intrin_args = (x, y, ls, rs, q, is_lshift_required, is_rshift_required)
    return call_intrin("int32", "tir.q_multiply_shift_per_axis", *intrin_args)


def shift_left(x, y, span=None):
"""Return the result of x left shifted by y bits.

Expand Down
84 changes: 76 additions & 8 deletions python/tvm/topi/hexagon/tensor_intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,6 @@
def _q_multiply_shift_hexagon(op):
"""
Implementation of q_multiply_shift through hexagon intrinsics vmpyewuh and vmpyowh when q == 31.

Please note that this is introducing a small round-up error for some corner cases with negative
shift argument. This is because we are rounding twice instead than only once. I.e.:

* original q_multiply_shift: round(x*y*2^-s)
* hexagon q_multiply_shift: round(round(x*y)*2^-s)
"""
x = op.args[0]
y = op.args[1]
Expand All @@ -47,9 +41,9 @@ def _q_multiply_shift_hexagon(op):
op.dtype, "llvm.hexagon.V6.vmpyewuh.128B", tvm.tir.const(2, "uint32"), x, y
)
mul_o_1 = tvm.tir.call_llvm_intrin(
op.dtype, "llvm.hexagon.V6.vmpyowh.rnd.sacc.128B", tvm.tir.const(3, "uint32"), mul_e_1, x, y
op.dtype, "llvm.hexagon.V6.vmpyowh.sacc.128B", tvm.tir.const(3, "uint32"), mul_e_1, x, y
)
fixup = mul_o_1 & (-shift)
fixup = 1 << (-shift - 1)
round_mul = mul_o_1 + fixup
out_negative_shift = tvm.tir.call_llvm_intrin(
op.dtype, "llvm.hexagon.V6.vaslwv.128B", tvm.tir.const(2, "uint32"), round_mul, shift
Expand All @@ -73,6 +67,80 @@ def _q_multiply_shift_hexagon(op):
)


def _q_multiply_shift_per_axis_hexagon(op):
    """
    Lower q_multiply_shift_per_axis to the hexagon intrinsics vmpyewuh and vmpyowh when q == 31.

    The call is returned unchanged when the first argument is not an int32x32 vector, when the
    number of fractional bits is not 31, or when both a left and a right shift are required.
    Otherwise three candidate results are built (left shift, right shift, no shift) and the
    matching one is picked with nested Select expressions on the two "required" flags.
    """
    x, y = op.args[0], op.args[1]
    left_shift, right_shift = op.args[2], op.args[3]
    fractional_bits = op.args[4]
    is_lshift_required, is_rshift_required = op.args[5], op.args[6]

    # The intrinsics below are only used for int32x32 vectors holding q31 numbers.
    if x.dtype != "int32x32" or fractional_bits.value != 31:
        return op

    # Needing both a left and a right shift is not handled here: it is not clear yet how to
    # implement that case with vector HVX instructions without an accuracy drop.
    if is_rshift_required.value and is_lshift_required.value:
        return op

    def _mpy_even(lhs):
        # vmpyewuh applied to (lhs, y).
        return tvm.tir.call_llvm_intrin(
            op.dtype, "llvm.hexagon.V6.vmpyewuh.128B", tvm.tir.const(2, "uint32"), lhs, y
        )

    def _mpy_odd(intrin_name, even_part, lhs):
        # One of the vmpyowh "sacc" variants applied to (even_part, lhs, y).
        return tvm.tir.call_llvm_intrin(
            op.dtype, intrin_name, tvm.tir.const(3, "uint32"), even_part, lhs, y
        )

    # Candidate 1: left shift. x is shifted first, then multiplied with the ".rnd" variant.
    pre_shifted = x << left_shift
    lshift_result = _mpy_odd(
        "llvm.hexagon.V6.vmpyowh.rnd.sacc.128B", _mpy_even(pre_shifted), pre_shifted
    )

    # Candidate 2: right shift. Multiply with the non-".rnd" variant, add the
    # 1 << (right_shift - 1) fixup, then shift right with vasrwv.
    unrounded = _mpy_odd("llvm.hexagon.V6.vmpyowh.sacc.128B", _mpy_even(x), x)
    rounding_fixup = 1 << (right_shift - 1)
    fixed_up = unrounded + rounding_fixup
    rshift_result = tvm.tir.call_llvm_intrin(
        op.dtype, "llvm.hexagon.V6.vasrwv.128B", tvm.tir.const(2, "uint32"), fixed_up, right_shift
    )

    # Candidate 3: neither shift. Plain multiply with the ".rnd" variant.
    plain_result = _mpy_odd("llvm.hexagon.V6.vmpyowh.rnd.sacc.128B", _mpy_even(x), x)

    shifted_result = tvm.tir.Select(is_lshift_required, lshift_result, rshift_result)
    no_shift_needed = tvm.tir.Not(tvm.tir.Or(is_lshift_required, is_rshift_required))
    return tvm.tir.Select(no_shift_needed, plain_result, shifted_result)


# Register _q_multiply_shift_per_axis_hexagon as the lowering of the
# "tir.q_multiply_shift_per_axis" intrinsic for the "hexagon" target.
register_intrin_lowering(
"tir.q_multiply_shift_per_axis",
target="hexagon",
f=_q_multiply_shift_per_axis_hexagon,
level=99,
)


def dot_vrmpy(x_ty, y_ty):
"""Generates vrmpy instruction for tensorization."""
int32_lanes = 32
Expand Down
58 changes: 58 additions & 0 deletions python/tvm/topi/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from tvm import te
from . import tag
from . import cpp
from .utils import get_const_tuple


@tvm.te.tag_scope(tag=tag.ELEMWISE)
Expand Down Expand Up @@ -672,6 +673,63 @@ def _compute(*indices):
return te.compute(x.shape, _compute)


@tvm.te.tag_scope(tag=tag.BROADCAST)
def fixed_point_multiply_per_axis(
    x: te.Tensor,
    y: te.Tensor,
    lshift: te.Tensor,
    rshift: te.Tensor,
    is_lshift_required: int,
    is_rshift_required: int,
    axes,
):
    """Fixed point multiplication between data and per-axis fixed point constants.

    Each constant is expressed as multiplier * 2^(-shift), where multiplier is a
    Q-number with 31 fractional bits. For every element of ``x``, the multiplier and
    the shifts are looked up in ``y``, ``lshift`` and ``rshift`` using only the
    element's indices along ``axes``.

    Parameters
    ----------
    x : tvm.te.Tensor
        Input argument.
    y : tvm.te.Tensor
        Multipliers of the fixed point constants described as multiplier*2^(-shift).
    lshift : tvm.te.Tensor
        Left shifts of the fixed point constants described as multiplier*2^(-shift).
    rshift : tvm.te.Tensor
        Right shifts of the fixed point constants described as multiplier*2^(-shift).
    is_lshift_required : int
        Whether we need to do left shift or not.
    is_rshift_required : int
        Whether we need to do right shift or not.
    axes : sequence of constant ints
        Axes of ``x`` whose indices are used, in the given order, to index ``y``,
        ``lshift`` and ``rshift``.

    Returns
    -------
    z : tvm.te.Tensor
        The result, with the same shape as ``x``.
    """

    def _compute(*indices):
        # Keep only the indices along the requested axes; the parameter tensors are
        # indexed by those alone.
        param_indices = tuple(indices[axis] for axis in get_const_tuple(axes))

        return tvm.tir.q_multiply_shift_per_axis(
            x(*indices),
            y(*param_indices),
            lshift(*param_indices),
            rshift(*param_indices),
            tvm.tir.const(31, "int32"),  # number of fractional bits of the multiplier
            tvm.tir.const(is_lshift_required, "bool"),
            tvm.tir.const(is_rshift_required, "bool"),
        )

    return te.compute(x.shape, _compute)


def cast(x, dtype, span=None):
"""Cast input to specified data type.

Expand Down
4 changes: 4 additions & 0 deletions src/relay/op/make_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ Expr MakeBatchMatmul(Expr lhs, Expr rhs, DataType out_dtype, bool transpose_a, b

Expr MakeExpandDims(Expr data, int axis, int num_newaxis);

// Per-channel/per-axis fixed point multiply. The last parameter is a list of axes and is named
// `axes` to match FixedPointMultiplyPerAxisAttrs::axes.
// NOTE(review): the definition lives outside this header -- confirm it uses the same order.
Expr MakeFixedPointMultiplyPerAxis(Expr x, Expr m, Expr lshift, Expr rshift,
                                   bool is_lshift_required, bool is_rshift_required,
                                   Array<Integer> axes);

Expr MakeFull(Expr fill_value, Array<Integer> shape, DataType dtype);

Expr MakeLayoutTransform(Expr data, String src_layout, String dst_layout);
Expand Down
Loading