Add an ONNX MatMulInteger converter and an int8/uint8 dense padding fallback.

- python/tvm/relay/frontend/onnx.py: MatMulInteger re-uses the QLinearMatMul
  converter with scales of 1.0 and expected_out_dtypes=["int32"]; in that case
  QLinearMatMul returns the matmul result before the requantization step.
- python/tvm/topi/cuda/tensorcore_alter_op.py: for int8/uint8 dense, when the
  tensorcore padding is skipped (extra_flops_ratio > 2), try padding M, K, N
  with _pad_to(M, K, N, (4, 4, 4)) instead; _extra_flops is factored out.
- tests: drop "test_matmulinteger" from a list of test names in
  test_forward.py (presumably the unsupported/skip list -- confirm), and
  change the (2, 16) dense legalize case to (1, 16).

diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 474b688e2ad8..ab0eeb091043 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -3913,10 +3913,17 @@ class QLinearMatMul(OnnxOpConverter):
     - Only supports 2D input tensors.
     - Not guaranteed to meet the integer-overflow behavior stipulated in the
       ONNX documentation for this operator.
+
+    The QLinearMatMul converter is re-used for MatMulInteger and is adapted for
+    the latter with the optional `expected_out_dtypes` argument.
     """
 
     @classmethod
-    def _impl_v10(cls, inputs, attr, params):
+    def _impl_v10(cls, inputs, attr, params, expected_out_dtypes=None):
+        if expected_out_dtypes is None:
+            # The default QLinearMatMul converter is expected to have one of
+            # these output dtypes.
+            expected_out_dtypes = ["int8", "uint8"]
 
         # Some of the ops used below take scalar-like inputs, and may require either
         # of the following:
@@ -3966,7 +3973,7 @@ def try_resolve_to_const(x, dtype_override=None):
         assert b_zp_type.dtype == b_type.dtype
 
         assert y_scale_type.dtype == "float32"
-        assert y_zp_type.dtype in ["int8", "uint8"]
+        assert y_zp_type.dtype in expected_out_dtypes
 
         # TODO: relax this limitation in a future version of this importer.
         a_rank = len(a_shape)
@@ -4028,6 +4035,11 @@ def try_resolve_to_const(x, dtype_override=None):
         matmul_result_scale_scalar = fold_constant(_op.multiply(a_scale_scalar, b_scale_scalar))
         matmul_result_zp_scalar = _op.const(0, dtype="int32")
 
+        if "int32" in expected_out_dtypes:
+            # This is the adaptation of the QLinearMatMul converter for MatMulInteger,
+            # in the MatMulInteger case we skip the unnecessary requantization step.
+            return matmul_result
+
         # requantize requires y_scale to be constant,
         # if y_scale is not constant, doing dequantize -> quantize
         if isinstance(y_scale_scalar, _expr.Constant):
@@ -4053,6 +4065,58 @@ def try_resolve_to_const(x, dtype_override=None):
         return y
 
 
+class MatMulInteger(OnnxOpConverter):
+    """Operator converter for MatMulInteger."""
+
+    @classmethod
+    def _impl_v10(cls, inputs, attr, params):
+        a = inputs[0]
+        b = inputs[1]
+
+        a_dtype = infer_type(a).checked_type.dtype
+        b_dtype = infer_type(b).checked_type.dtype
+
+        assert a_dtype in ("int8", "uint8"), "MatMulInteger: invalid dtype for first input"
+        assert b_dtype in ("int8", "uint8"), "MatMulInteger: invalid dtype for second input"
+
+        assert a_dtype == b_dtype, "MatMulInteger: input dtypes must match"
+
+        a_scale = _op.const(1.0, dtype="float32")
+        b_scale = _op.const(1.0, dtype="float32")
+        out_scale = _op.const(1.0, dtype="float32")
+
+        a_zero_point = _op.const(0.0, dtype=a_dtype)
+        b_zero_point = _op.const(0.0, dtype=b_dtype)
+        out_zero_point = _op.const(0.0, dtype="int32")
+
+        if len(inputs) == 4:
+            a_zero_point = inputs[2]
+            b_zero_point = inputs[3]
+
+            a_zp_dtype = infer_type(a_zero_point).checked_type.dtype
+            b_zp_dtype = infer_type(b_zero_point).checked_type.dtype
+            assert (
+                a_zp_dtype == a_dtype and b_zp_dtype == b_dtype
+            ), "MatMulInteger: input dtype doesn't match zero point dtype"
+        elif len(inputs) != 2:
+            raise AssertionError(
+                "MatMulInteger op takes 2 or 4 inputs, {} given".format(len(inputs))
+            )
+
+        inputs = [
+            a,
+            a_scale,
+            a_zero_point,
+            b,
+            b_scale,
+            b_zero_point,
+            out_scale,
+            out_zero_point,
+        ]
+
+        return QLinearMatMul.get_converter(10)(inputs, attr, params, expected_out_dtypes=["int32"])
+
+
 class QLinearMul(OnnxOpConverter):
     """Operator converter for QLinearMul from Microsoft onnxruntime contrib opset."""
 
@@ -4781,6 +4845,7 @@ def _get_convert_map(opset):
         "Softsign": Softsign.get_converter(opset),
         "Gemm": Gemm.get_converter(opset),
         "MatMul": MatMul.get_converter(opset),
+        "MatMulInteger": MatMulInteger.get_converter(opset),
         "MatMulInteger16": MatMulInteger16.get_converter(opset),
         "Mod": Mod.get_converter(opset),
         "Xor": Renamer("logical_xor"),
diff --git a/python/tvm/topi/cuda/tensorcore_alter_op.py b/python/tvm/topi/cuda/tensorcore_alter_op.py
index 080ddf28b7c2..0ba428014548 100644
--- a/python/tvm/topi/cuda/tensorcore_alter_op.py
+++ b/python/tvm/topi/cuda/tensorcore_alter_op.py
@@ -167,8 +167,22 @@ def _dense_legalize(attrs, inputs, arg_types):
         return None
 
     (dm, dk, dn), extra_flops_ratio = pad_to_tensorcore(M, K, N, candidates)
+    skip_pad = extra_flops_ratio > 2
+
+    if skip_pad and dtype in ["int8", "uint8"]:
+        skip_pad = False
+        # If tensorcore schedule padding fails, pad to nearest upward 4x4x4 as long as
+        # the additional flops ratio isn't double or more.
+        # Note that 4x4x4 is invalid for tensorcore scheduling, but padding upwards to 4x4x4
+        # doesn't hurt if tensorcore padding has already failed.
+        if M % 4 == 0 and K % 4 == 0 and N % 4 == 0:
+            # No need to pad
+            return None
+        (dm, dk, dn) = _pad_to(M, K, N, (4, 4, 4))
+        extra_flops_ratio = _extra_flops(M, K, N, dm, dk, dn) / (M * K * N)
+        skip_pad = extra_flops_ratio > 2
 
-    if extra_flops_ratio > 2:
+    if skip_pad:
         logger.info("dense pad_to_tensorcore skipped, extra_flops_ratio %s", extra_flops_ratio)
         return None
 
@@ -198,7 +212,7 @@ def pad_to_tensorcore(M, K, N, candidates):
     best_pad = (0, 0, 0)
     for padding in candidates:
         dm, dk, dn = _pad_to(M, K, N, padding)
-        e = (M + dm) * (N + dn) * (K + dk) - M * N * K
+        e = _extra_flops(M, K, N, dm, dk, dn)
         # print(dm, dk, dn, e, flops)
         if e < extra_flops:
             extra_flops = e
@@ -206,6 +220,10 @@ def pad_to_tensorcore(M, K, N, candidates):
     return best_pad, extra_flops / flops
 
 
+def _extra_flops(M, K, N, dm, dk, dn):
+    return (M + dm) * (N + dn) * (K + dk) - M * N * K
+
+
 def _pad_to(M, K, N, PADDING):
     dm, dk, dn = 0, 0, 0
 
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 0751f4a2e293..94fd0a5de40b 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -5053,7 +5053,6 @@ def verify_eyelike(indata):
     "test_loop11",
     "test_loop13_seq",
     "test_lstm_batchwise",
-    "test_matmulinteger",
     "test_maxpool_with_argmax_2d_precomputed_pads",
     "test_maxpool_with_argmax_2d_precomputed_strides",
     "test_maxunpool_export_with_output_shape",
diff --git a/tests/python/relay/test_pass_legalize_tensorcore.py b/tests/python/relay/test_pass_legalize_tensorcore.py
index 97860630dea5..0e3c171d87da 100644
--- a/tests/python/relay/test_pass_legalize_tensorcore.py
+++ b/tests/python/relay/test_pass_legalize_tensorcore.py
@@ -249,6 +249,7 @@ def expected():
         a = before()
         a = run_opt_pass(a, transform.Legalize())
         b = run_opt_pass(expected(), transform.InferType())
+
         assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "Expected = \n" + str(b)
 
     # dense
@@ -259,7 +260,7 @@ def expected():
         _test_legalize_dense((8, 16), (31, 16), (0, 0, 1), dtype)
         _test_legalize_dense((7, 15), (31, 15), (1, 1, 1), dtype)
         _test_legalize_dense((3, 16), (32, 16), (5, 0, 0), dtype)
-        _test_legalize_dense((2, 16), (32, 16), (0, 0, 0), dtype, False)
+        _test_legalize_dense((1, 16), (32, 16), (0, 0, 0), dtype, False)
 
     # Test if units parameter is correctly updated
     _test_legalize_dense((8, 16), (30, 16), (0, 0, 2), "float16", units=30)