python/tvm/relay/frontend/onnx.py: 90 additions & 74 deletions

@@ -216,6 +216,77 @@ def get_scalar(x, params, dtype="float32"):
     return _op.cast(x, dtype)
 
 
+def matmul_out_dtype(inputs, out_dtype):
+    """Common function to handle MatMul and MatMulInteger16"""
+    a_shape = shape_of(inputs[0])
+    a_rank = infer_shape(a_shape)[0]
+    b_shape = shape_of(inputs[1])
+    b_rank = infer_shape(b_shape)[0]
+    if a_rank > 2 or b_rank > 2:
+
+        def flatten_to_nd(x, x_shape, nd=3):
+            ndims = infer_shape(x_shape)[0]
+            if ndims == nd:
+                return x
+            newshape = _op.concatenate(
+                [
+                    _expr.const([-1], dtype=infer_type(x_shape).checked_type.dtype),
+                    _op.strided_slice(x_shape, [ndims - nd + 1], [ndims]),
+                ],
+                0,
+            )
+            out = _op.reshape(x, fold_constant(newshape))
+            return out
+
+        b_type = infer_type(inputs[1])
+        # Convert to dense if the second matrix is 2d and non-dynamic
+        if b_rank == 2 and not _ty.is_dynamic(b_type.checked_type):
+            a = flatten_to_nd(inputs[0], a_shape, 2)
+            b = _op.transpose(inputs[1])
+            output = _op.nn.dense(a, b, out_dtype=out_dtype)
+        else:
+            # Convert a and b into 3 dimensional tensors.
+            a = flatten_to_nd(inputs[0], a_shape, 3)
+            b = flatten_to_nd(inputs[1], b_shape, 3)
+            # Perform a NN batch matmul.
+            output = _op.nn.batch_matmul(a, b, out_dtype=out_dtype, transpose_b=False)
+        # Determine the output batch dimension.
+        if a_rank > b_rank:
+            out_batch = _op.strided_slice(a_shape, [0], [a_rank - 2])
+        elif a_rank < b_rank:
+            out_batch = _op.strided_slice(b_shape, [0], [b_rank - 2])
+        # If it's unclear how broadcasting should be applied, the output
+        # shape is determined by choosing the maximum value from each input.
+        else:
+            out_batch = _op.concatenate(
+                [
+                    _op.maximum(
+                        _op.strided_slice(a_shape, [i], [i + 1]),
+                        _op.strided_slice(b_shape, [i], [i + 1]),
+                    )
+                    for i in range(a_rank - 2)
+                ],
+                0,
+            )
+        # Reshape output to original dimensions.
+        final_shape = _op.concatenate(
+            [
+                out_batch,
+                _op.strided_slice(
+                    a_shape, [infer_shape(a_shape)[0] - 2], [infer_shape(a_shape)[0] - 1]
+                ),
+                _op.strided_slice(
+                    b_shape, [infer_shape(b_shape)[0] - 1], [infer_shape(b_shape)[0]]
+                ),
+            ],
+            0,
+        )
+        return _op.reshape(output, fold_constant(final_shape))
+    # Otherwise a simple dense op will get the job done.
+    input_1_t = _op.transpose(inputs[1], axes=(1, 0))
+    return _op.nn.dense(inputs[0], input_1_t, out_dtype=out_dtype)
+
+
 class OnnxOpConverter(object):
     """A helper class for holding onnx op converters."""

@@ -735,80 +806,24 @@ class MatMul(OnnxOpConverter):
     def _impl_v1(cls, inputs, attr, params):
         assert len(inputs) == 2, "MatMul op takes 2 inputs, {} given".format(len(inputs))
         # Need to check input shape as batch matmul must be supported.
-        a_shape = shape_of(inputs[0])
-        a_rank = infer_shape(a_shape)[0]
-        b_shape = shape_of(inputs[1])
-        b_rank = infer_shape(b_shape)[0]
-        # When performing a batch matmul, we need to properly handle N-dim shapes.
-        if a_rank > 2 or b_rank > 2:
-
-            def flatten_to_nd(x, x_shape, nd=3):
-                ndims = infer_shape(x_shape)[0]
-                if ndims == nd:
-                    return x
-                newshape = _op.concatenate(
-                    [
-                        _expr.const([-1], dtype=infer_type(x_shape).checked_type.dtype),
-                        _op.strided_slice(x_shape, [ndims - nd + 1], [ndims]),
-                    ],
-                    0,
-                )
-                out = _op.reshape(x, fold_constant(newshape))
-                return out
-
-            b_type = infer_type(inputs[1])
-            # Convert to dense if the second matrix is 2d and non-dynamic
-            if b_rank == 2 and not _ty.is_dynamic(b_type.checked_type):
-                a = flatten_to_nd(inputs[0], a_shape, 2)
-                b = _op.transpose(inputs[1])
-                output = _op.nn.dense(a, b)
-            else:
-                # Convert a and b into 3 dimensional tensors.
-                a = flatten_to_nd(inputs[0], a_shape, 3)
-                b = flatten_to_nd(inputs[1], b_shape, 3)
-                if ONNX_DEFAULT_CONFIGS["use_nt_batch_matmul"]:
-                    # Transpose matrix dimensions of b.
-                    b = _op.transpose(b, [0, 2, 1])
-                    # Perform a NT batch matmul.
-                    output = _op.nn.batch_matmul(a, b)
-                else:
-                    # Perform a NN batch matmul.
-                    output = _op.nn.batch_matmul(a, b, transpose_b=False)
-            # Determine the output batch dimension.
-            if a_rank > b_rank:
-                out_batch = _op.strided_slice(a_shape, [0], [a_rank - 2])
-            elif a_rank < b_rank:
-                out_batch = _op.strided_slice(b_shape, [0], [b_rank - 2])
-            # If it's unclear how broadcasting should be applied, the output
-            # shape is determined by choosing the maximum value from each input.
-            else:
-                out_batch = _op.concatenate(
-                    [
-                        _op.maximum(
-                            _op.strided_slice(a_shape, [i], [i + 1]),
-                            _op.strided_slice(b_shape, [i], [i + 1]),
-                        )
-                        for i in range(a_rank - 2)
-                    ],
-                    0,
-                )
-            # Reshape output to original dimensions.
-            final_shape = _op.concatenate(
-                [
-                    out_batch,
-                    _op.strided_slice(
-                        a_shape, [infer_shape(a_shape)[0] - 2], [infer_shape(a_shape)[0] - 1]
-                    ),
-                    _op.strided_slice(
-                        b_shape, [infer_shape(b_shape)[0] - 1], [infer_shape(b_shape)[0]]
-                    ),
-                ],
-                0,
-            )
-            return _op.reshape(output, fold_constant(final_shape))
-        # Otherwise a simple dense op will get the job done.
-        input_1_t = _op.transpose(inputs[1], axes=(1, 0))
-        return _op.nn.dense(inputs[0], input_1_t)
+        return matmul_out_dtype(inputs, out_dtype=infer_type(inputs[0]).checked_type.dtype)
 
 
+class MatMulInteger16(OnnxOpConverter):
+    """Operator converter for MatMulInteger16 from Microsoft onnxruntime contrib opset."""
+
+    @classmethod
+    def _impl_v10(cls, inputs, attr, params):
+        assert len(inputs) == 2, "MatMulInteger16 op takes 2 inputs, {} given".format(len(inputs))
+        a_dtype = infer_type(inputs[0]).checked_type.dtype
+        b_dtype = infer_type(inputs[1]).checked_type.dtype
+        # Check input data types
+        assert a_dtype in ("int16", "uint16"), "MatMulInteger16: invalid dtype for first input"
+        assert b_dtype in ("int16", "uint16"), "MatMulInteger16: invalid dtype for second input"
+        out_dtype = "int32"
+        if a_dtype == "uint16" and b_dtype == "uint16":
+            out_dtype = "uint32"
+        return matmul_out_dtype(inputs, out_dtype)
+
+
 class Mod(OnnxOpConverter):
@@ -4298,6 +4313,7 @@ def _get_convert_map(opset):
         "Softsign": Softsign.get_converter(opset),
         "Gemm": Gemm.get_converter(opset),
         "MatMul": MatMul.get_converter(opset),
+        "MatMulInteger16": MatMulInteger16.get_converter(opset),
         "Mod": Mod.get_converter(opset),
         "Xor": Renamer("logical_xor"),
         # defs/nn
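With the entry registered in the convert map, models carrying the contrib op import through the usual frontend entry point. A hypothetical usage sketch (the model file name is assumed):

import onnx
from tvm import relay

# Load an ONNX model containing a com.microsoft MatMulInteger16 node.
model = onnx.load("model_with_matmulinteger16.onnx")  # hypothetical file
mod, params = relay.frontend.from_onnx(model)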
python/tvm/relay/op/strategy/cuda.py: 1 addition & 1 deletion

@@ -839,7 +839,7 @@ def batch_matmul_strategy_cuda(attrs, inputs, out_type, target):
         )
     else:
         strategy.add_implementation(
-            wrap_compute_batch_matmul(topi.cuda.batch_matmul),
+            wrap_compute_batch_matmul(topi.cuda.batch_matmul, need_out_dtype=True),
             wrap_topi_schedule(topi.cuda.schedule_batch_matmul),
             name="batch_matmul.cuda",
             plevel=10,
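Note on the strategy change: passing need_out_dtype=True makes the wrapped TOPI compute honor the out_dtype from the op's type relation rather than defaulting to the input dtype, which is what lets the CUDA batch_matmul kernel accumulate MatMulInteger16 results in int32/uint32 (my reading of wrap_compute_batch_matmul; not spelled out in the diff itself).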
tests/python/frontend/onnx/test_forward.py: 42 additions & 0 deletions

@@ -1283,6 +1283,47 @@ def verify_batch_matmul(a_shape, b_shape, out_shape, convert_config=None):
     )
 
 
+@tvm.testing.parametrize_targets
+def test_matmulinteger16(target, dev):
+    def verify_matmulinteger16(a_shape, b_shape, out_shape):
+        a_dtype = "int16"
+        b_dtype = "int16"
+        low = np.iinfo(np.int16).min
+        high = np.iinfo(np.int16).max
+
+        a_proto = TensorProto.INT16
+        b_proto = TensorProto.INT16
+        out_proto = TensorProto.INT32
+        a_array = np.random.randint(low, high, size=a_shape).astype(a_dtype)
+        b_array = np.random.randint(low, high, size=b_shape).astype(b_dtype)
+
+        mul_node = helper.make_node("MatMulInteger16", ["a", "b"], ["out"], domain="com.microsoft")
+
+        graph = helper.make_graph(
+            [mul_node],
+            "matmuli16_test",
+            inputs=[
+                helper.make_tensor_value_info("a", a_proto, list(a_shape)),
+                helper.make_tensor_value_info("b", b_proto, list(b_shape)),
+            ],
+            outputs=[helper.make_tensor_value_info("out", out_proto, list(out_shape))],
+        )
+
+        model = helper.make_model(graph, producer_name="matmuli16_test")
+        verify_with_ort_with_inputs(model, [a_array, b_array], target=target, dev=dev)
+
+    # 2D computation to verify matmul op
+    verify_matmulinteger16((4, 3), (3, 4), (4, 4))
+    verify_matmulinteger16((5, 7), (7, 8), (5, 8))
+    # Verify 3D matmul using batch_matmul op
+    verify_matmulinteger16((2, 4, 3), (1, 3, 4), (2, 4, 4))
+    verify_matmulinteger16((1, 4, 3), (2, 3, 4), (2, 4, 4))
+    # Test implicit broadcasting
+    verify_matmulinteger16((2, 3, 5, 3), (2, 3, 3, 5), (2, 3, 5, 5))
+    verify_matmulinteger16((2, 7, 3), (3, 7), (2, 7, 7))
+    verify_matmulinteger16((2, 3, 4, 3), (3, 4), (2, 3, 4, 4))
+
+
 def verify_simple_dynamic_model(a_shape, b_shape, target, dev):
     def verify_model(model, a_shape, b_shape):
         a_array = np.random.uniform(size=a_shape).astype("float32")
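Shape coverage note: the 2-D cases exercise the plain dense path, the 3-D cases exercise batch_matmul with batch-dimension broadcasting, and the last two mixed-rank cases, where the second operand is a static 2-D matrix, exercise the flatten-plus-transposed-dense path in matmul_out_dtype.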
@@ -5856,6 +5897,7 @@ def repeat(N, D):
     test_onehot()
     test_gemm()
     test_matmul()
+    test_matmulinteger16()
    test_gather()
     test_gatherelements()
     test_gather_nd()