31 changes: 31 additions & 0 deletions python/tvm/relay/frontend/common.py
@@ -835,3 +835,34 @@ def lstm_cell(
outputs_list.append(hidden_state) # [seq_num, (batch, hidden_size)]

return outputs_list, hidden_state, cell_state


def ensure_scalar_shape(x):
"""
Assume that `x` is a tensor with one element (regardless of tensor rank).
Return a version of that tensor with rank 0.
"""
x_shape = infer_shape(x)
x_rank = len(x_shape)

if x_rank == 0:
return x

num_elem = np.prod(x_shape)
assert num_elem == 1, "Cannot squeeze tensor shape {} to scalar form.".format(x_shape)

return _op.squeeze(x)


def try_resolve_var_to_const(x, graph_params):
"""
Try to resolve the value of tensor `x` to a specific value.
If successful, return a Const op with that value.
If unsuccessful, simply return `x`.
"""
if isinstance(x, _expr.Var) and x.name_hint in graph_params:
value = graph_params[x.name_hint].numpy()
dtype = infer_type(x).checked_type.dtype
return _op.const(value, dtype)

return x
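
The two new helpers are designed to compose: resolve a scale or zero-point Var to a Const using the graph params, then squeeze it to rank 0. A minimal usage sketch, not part of the diff (it assumes a TVM build containing these helpers; the names and values below are illustrative):

import numpy as np
import tvm
from tvm import relay
from tvm.relay.frontend.common import ensure_scalar_shape, try_resolve_var_to_const

# A rank-1, single-element scale input, as ONNX graphs commonly provide it.
scale_var = relay.var("a_scale", shape=(1,), dtype="float32")
graph_params = {"a_scale": tvm.nd.array(np.array([0.25], dtype="float32"))}

scale_const = try_resolve_var_to_const(scale_var, graph_params)  # Var -> Const([0.25])
scale_scalar = ensure_scalar_shape(scale_const)  # squeeze shape (1,) down to rank 0
# scale_scalar is a squeeze(...) expression; fold_constant (see onnx.py below)
# reduces it to a rank-0 Const.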
153 changes: 153 additions & 0 deletions python/tvm/relay/frontend/onnx.py
@@ -40,6 +40,7 @@
from .common import (
AttrCvt,
Renamer,
ensure_scalar_shape,
fold_constant,
get_name,
get_relay_op,
@@ -50,6 +51,7 @@
infer_value,
lstm_cell,
new_var,
try_resolve_var_to_const,
unbind,
)

@@ -3506,6 +3508,156 @@ def _impl_v10(cls, inputs, attr, params):
return _qnn.op.quantize(out, c_scale, c_zero_point, out_dtype=dtype)


class QLinearMatMul(OnnxOpConverter):
"""
Operator converter for QLinearMatMul, introduced in opset 10 of the standard ONNX operator set.

Limitations:
- Only supports 2D input tensors.
- Not guaranteed to meet the integer-overflow behavior stipulated in the
ONNX documentation for this operator.
"""

@classmethod
def _impl_v10(cls, inputs, attr, params):

# Some of the ops used below take scalar-like inputs and may require one or
# both of the following:
#
# - the input is a Const node (not merely an expression that *could* be
#   reduced to a single Const at graph-compilation time)
#
# - the input has a specific dtype
#
# This function attempts to present 'x' in a form that meets both of those
# requirements.
def try_resolve_to_const_scalar(x, dtype_override=None):
x2 = try_resolve_var_to_const(x, params)
x3 = ensure_scalar_shape(x2)

x_dtype = infer_type(x).checked_type.dtype
if (dtype_override is not None) and (dtype_override != x_dtype):
x4 = _op.cast(x3, dtype_override)
else:
x4 = x3

x5 = fold_constant(x4)
return x5

# Unpack the inputs and obtain some type info...
a, a_scale, a_zp, b, b_scale, b_zp, y_scale, y_zp = inputs

a_type = infer_type(a).checked_type # 'T1' in ONNX doc for this op
a_scale_type = infer_type(a_scale).checked_type
a_zp_type = infer_type(a_zp).checked_type

b_type = infer_type(b).checked_type # 'T2' in ONNX doc for this op
b_scale_type = infer_type(b_scale).checked_type
b_zp_type = infer_type(b_zp).checked_type

y_scale_type = infer_type(y_scale).checked_type
y_zp_type = infer_type(y_zp).checked_type # 'T3' in ONNX doc for this op

a_shape = infer_shape(a)
b_shape = infer_shape(b)

# Verify type assumptions, based on the ONNX doc for this op...
assert a_type.dtype in ["int8", "uint8"]
assert a_scale_type.dtype == "float32"
assert a_zp_type.dtype == a_type.dtype

assert b_type.dtype in ["int8", "uint8"]
assert b_scale_type.dtype == "float32"
assert b_zp_type.dtype == b_type.dtype

assert y_scale_type.dtype == "float32"
assert y_zp_type.dtype in ["int8", "uint8"]

# TODO: relax this limitation in a future version of this importer.
a_rank = len(a_shape)
b_rank = len(b_shape)
assert (a_rank == 2) and (b_rank == 2), (
"QLinearMatMul importer currently requires both 'a' and 'b' tensors to be 2D, but"
" rank(a)={}, rank(b)={}".format(a_rank, b_rank)
)

# _qnn.op.dense requires the zero-point values to have dtype int32.
a_scale_scalar = try_resolve_to_const_scalar(a_scale)
a_zp_scalar = try_resolve_to_const_scalar(a_zp, "int32")

b_scale_scalar = try_resolve_to_const_scalar(b_scale)
b_zp_scalar = try_resolve_to_const_scalar(b_zp, "int32")

y_scale_scalar = try_resolve_to_const_scalar(y_scale)
y_zp_scalar = try_resolve_to_const_scalar(y_zp, "int32")

# TODO: Confirm that we're using 'num_hidden_units' correctly / as intended with
# the '_qnn.op.dense' instance below.
num_hidden_units = b_shape[-1]

# - Specify the matmul result dtype as int32, so that hopefully the matmul will use
# a 32-bit accumulator as seems to be required by the ONNX op's documentation.
#
# In more detail:
# The ONNX documentation for this op is clear about acceptable overflow
# behavior during the matmul operation:
# - The scalar multiplication ops MUST NOT overflow.
# - The scalar addition ops, which sum the results of the scalar multiplication,
# MAY overflow, but if they do so, it must behave as one would expect during
# 32-bit integer-addition overflow.
# As of this writing, Relay's qnn.op.dense operator doesn't expose a way for us to
# express these constraints.
#
# TODO: Extend TVM / Relay / TIR / etc. to allow this kind of constraint to be
# expressed in a Relay graph. And then update this importer and various TVM
# backends accordingly.
matmul_result_dtype = "int32"

[Inline review thread on the _qnn.op.dense call below]

Contributor:
Can you add a TODO about refactoring this with the regular MatMul code, which already handles broadcasting and the 3D cases? Feel free to assign it to yourself or me, e.g. TODO(AndrewZhaoLuo): ... or TODO(cconvey): ...

Contributor (author):
I'm thinking that this should probably be part of a larger discussion about which collection of matmul-related ops Relay / Relax should provide. If you don't mind, I'll bring the topic up in a different forum.

Contributor:
I mean the MatMul code in the onnx frontend, which handles numpy-style broadcasting. That one is just a frontend change in how we generate Relay.

matmul_result = _qnn.op.dense(
a,
_op.transpose(b),
a_zp_scalar,
b_zp_scalar,
a_scale_scalar,
b_scale_scalar,
num_hidden_units,
matmul_result_dtype,
)

# This information may only be documented in the C++ code-comments for the
# qnn.dense op, but the quantized tensor returned by _qnn.op.dense
# has scale==(a_scale_scalar * b_scale_scalar) and zero_point==0.
#
# 'matmul_result_zp_scalar' has type 'int32' to satisfy input requirements
# of the [de/re]quantize ops below.
matmul_result_scale_scalar = fold_constant(_op.multiply(a_scale_scalar, b_scale_scalar))
matmul_result_zp_scalar = _op.const(0, dtype="int32")

# requantize requires y_scale to be a constant; if it is not,
# fall back to dequantize -> quantize instead.
if isinstance(y_scale_scalar, _expr.Constant):
y = _qnn.op.requantize(
matmul_result,
matmul_result_scale_scalar,
matmul_result_zp_scalar,
y_scale_scalar,
y_zp_scalar,
axis=-1,
rounding="TONEAREST",
out_dtype=y_zp_type.dtype,
)
else:
matmul_result_deq = _qnn.op.dequantize(
matmul_result, matmul_result_scale_scalar, matmul_result_zp_scalar, axis=0
)

y = _qnn.op.quantize(
matmul_result_deq, y_scale_scalar, y_zp_scalar, axis=0, out_dtype=y_zp_type.dtype
)

return y
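
For intuition, the requantize branch is a single rescaling of the int32 accumulator: the matmul result carries scale a_scale * b_scale and zero point 0, and requantize maps it onto (y_scale, y_zp). A hedged NumPy sketch of that arithmetic, with made-up values and uint8 output assumed (np.rint approximates rounding="TONEAREST"):

import numpy as np

a_scale, b_scale = 0.02, 0.05  # input quantization scales
y_scale, y_zp = 0.1, 3  # requested output quantization

matmul_i32 = np.array([120, -87, 4000], dtype=np.int32)  # hypothetical accumulator values

# real value = (a_scale * b_scale) * (q - 0), then re-quantize to (y_scale, y_zp)
y_real = (a_scale * b_scale) * matmul_i32.astype(np.float64)
y_q = np.clip(np.rint(y_real / y_scale) + y_zp, 0, 255).astype(np.uint8)
print(y_q)  # [ 4  2 43]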


class QLinearMul(OnnxOpConverter):
"""Operator converter for QLinearMul from Microsoft onnxruntime contrib opset."""

@@ -4234,6 +4386,7 @@ def _get_convert_map(opset):
"QLinearConv": QLinearConv.get_converter(opset),
"QLinearConcat": QLinearConcat.get_converter(opset),
"QLinearAdd": QLinearAdd.get_converter(opset),
"QLinearMatMul": QLinearMatMul.get_converter(opset),
"QLinearMul": QLinearMul.get_converter(opset),
"QLinearSigmoid": QLinearSigmoid.get_converter(opset),
"ConvInteger": ConvInteger.get_converter(opset),
1 change: 0 additions & 1 deletion tests/python/frontend/onnx/test_forward.py
@@ -4941,7 +4941,6 @@ def verify_eyelike(indata):
"test_mvn",
# This test fails llvm with a lowering error:
"test_nllloss_NCd1d2d3_none_no_weight_negative_ii_expanded",
"test_qlinearmatmul_2D",
"test_qlinearmatmul_3D",
"test_range_float_type_positive_delta_expanded",
"test_range_int32_type_negative_delta_expanded",
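
The skip-list change above means test_qlinearmatmul_2D now runs against the new converter. For a hand-rolled check outside the test suite, a hedged sketch (it assumes the onnx package and a TVM build containing this change; shapes and quantization parameters are made up):

import onnx
from onnx import TensorProto, helper

from tvm import relay

# A tiny ONNX graph containing a single 2D QLinearMatMul node.
a = helper.make_tensor_value_info("a", TensorProto.UINT8, [2, 3])
b = helper.make_tensor_value_info("b", TensorProto.UINT8, [3, 4])
y = helper.make_tensor_value_info("y", TensorProto.UINT8, [2, 4])

inits = [
    helper.make_tensor("a_scale", TensorProto.FLOAT, [], [0.02]),
    helper.make_tensor("a_zp", TensorProto.UINT8, [], [128]),
    helper.make_tensor("b_scale", TensorProto.FLOAT, [], [0.05]),
    helper.make_tensor("b_zp", TensorProto.UINT8, [], [128]),
    helper.make_tensor("y_scale", TensorProto.FLOAT, [], [0.1]),
    helper.make_tensor("y_zp", TensorProto.UINT8, [], [3]),
]

node = helper.make_node(
    "QLinearMatMul",
    ["a", "a_scale", "a_zp", "b", "b_scale", "b_zp", "y_scale", "y_zp"],
    ["y"],
)
graph = helper.make_graph([node], "qlinearmatmul_2d", [a, b], [y], initializer=inits)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 10)])

# The initializers become params, so the scale/zero-point Vars are resolved
# to Consts by try_resolve_var_to_const during import.
mod, params = relay.frontend.from_onnx(model, {"a": (2, 3), "b": (3, 4)})
print(mod)  # expected to contain qnn.dense followed by qnn.requantize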