From 4ae9608c388cbc2c6379b2c7cb3840c78545f46d Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 15 Dec 2022 07:49:22 +0900 Subject: [PATCH 1/3] fixed tensor core batch_matmul legalize for transpose_b = False case --- python/tvm/topi/cuda/tensorcore_alter_op.py | 28 +++++++++++++++++---- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/python/tvm/topi/cuda/tensorcore_alter_op.py b/python/tvm/topi/cuda/tensorcore_alter_op.py index 0ba428014548..117d861a9f7c 100644 --- a/python/tvm/topi/cuda/tensorcore_alter_op.py +++ b/python/tvm/topi/cuda/tensorcore_alter_op.py @@ -48,14 +48,22 @@ def _batch_matmul_legalize(attrs, inputs, arg_types): x_tensor, y_tensor = arg_types[0], arg_types[1] dtype = x_tensor.dtype + if attrs.transpose_a: + B, K, M = x_tensor.shape + else: + B, M, K = x_tensor.shape + + if attrs.transpose_b: + B, N, K = y_tensor.shape + else: + B, K, N = y_tensor.shape + # Collect the output tensor. output_tensor = arg_types[2] # Collect the input exprs. x, y = inputs - B, M, K = x_tensor.shape - B, N, K = y_tensor.shape if ( isinstance(B, tir.expr.Any) or isinstance(M, tir.expr.Any) @@ -96,9 +104,19 @@ def _batch_matmul_legalize(attrs, inputs, arg_types): return None logger.info("batch_matmul pad_to_tensorcore, extra_flops %s", extra_flops) - x_ = relay.nn.pad(x, pad_width=((0, 0), (0, dm), (0, dk))) if dm or dk else x - y_ = relay.nn.pad(y, pad_width=((0, 0), (0, dn), (0, dk))) if dn or dk else y - out_ = relay.nn.batch_matmul(x_, y_, attrs.out_dtype) + + if attrs.transpose_a: + x_ = relay.nn.pad(x, pad_width=((0, 0), (0, dk), (0, dm))) if dm or dk else x + else: + x_ = relay.nn.pad(x, pad_width=((0, 0), (0, dm), (0, dk))) if dm or dk else x + + if attrs.transpose_b: + y_ = relay.nn.pad(y, pad_width=((0, 0), (0, dn), (0, dk))) if dn or dk else y + else: + y_ = relay.nn.pad(y, pad_width=((0, 0), (0, dk), (0, dn))) if dn or dk else y + + out_ = relay.nn.batch_matmul(x_, y_, **attrs) + out = ( relay.strided_slice(out_, begin=[0, 0, 0], end=[x.value for x in output_tensor.shape]) if dm or dn From 658df0259b6375f01ec69c0e0b6bb850f4523535 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 15 Dec 2022 08:43:21 +0900 Subject: [PATCH 2/3] add test --- .../relay/test_pass_legalize_tensorcore.py | 43 ++++++++++++++++--- 1 file changed, 36 insertions(+), 7 deletions(-) diff --git a/tests/python/relay/test_pass_legalize_tensorcore.py b/tests/python/relay/test_pass_legalize_tensorcore.py index 0e3c171d87da..c9782aec1b2c 100644 --- a/tests/python/relay/test_pass_legalize_tensorcore.py +++ b/tests/python/relay/test_pass_legalize_tensorcore.py @@ -277,17 +277,27 @@ def expected(): @tvm.testing.uses_gpu def test_legalize_batch_matmul(): - def _test_legalize_batch_matmul(data_shape, kernel_shape, pad_shape, dtype, do_pad=True): + def _test_legalize_batch_matmul( + data_shape, kernel_shape, pad_shape, dtype, do_pad=True, transpose_a=False, transpose_b=True + ): """test legalize dense to enable tensorcore""" - B, M, _ = data_shape - _, N, _ = kernel_shape + if transpose_a: + B, _, M = data_shape + else: + B, M, _ = data_shape + + if transpose_b: + _, N, _ = kernel_shape + else: + _, _, N = kernel_shape + out_shape = (B, M, N) dm, dk, dn = pad_shape def before(): x = relay.var("x", shape=data_shape, dtype=dtype) weight = relay.var("weight", shape=kernel_shape, dtype=dtype) - y = relay.nn.batch_matmul(x, weight) + y = relay.nn.batch_matmul(x, weight, transpose_a=transpose_a, transpose_b=transpose_b) y = relay.Function([x, weight], y) return y @@ -298,19 +308,31 @@ def legalize_batch_matmul(attrs, inputs, types): def expected(): if not do_pad: return before() + x = relay.var("x", shape=data_shape, dtype=dtype) + weight = relay.var("weight", shape=(kernel_shape), dtype=dtype) + if dm or dk: - x_pad = relay.nn.pad(x, pad_width=((0, 0), (0, dm), (0, dk))) + if transpose_a: + x_pad = relay.nn.pad(x, pad_width=((0, 0), (0, dk), (0, dm))) + else: + x_pad = relay.nn.pad(x, pad_width=((0, 0), (0, dm), (0, dk))) else: x_pad = x - weight = relay.var("weight", shape=(kernel_shape), dtype=dtype) + if dn or dk: - weight_pad = relay.nn.pad(weight, pad_width=((0, 0), (0, dn), (0, dk))) + if transpose_b: + weight_pad = relay.nn.pad(weight, pad_width=((0, 0), (0, dn), (0, dk))) + else: + weight_pad = relay.nn.pad(weight, pad_width=((0, 0), (0, dk), (0, dn))) else: weight_pad = weight + y_pad = relay.nn.batch_matmul( x_pad, weight_pad, + transpose_a=transpose_a, + transpose_b=transpose_b, ) if dm or dn: y = relay.strided_slice(y_pad, begin=[0, 0, 0], end=out_shape) @@ -343,6 +365,13 @@ def expected(): _test_legalize_batch_matmul((16, 8, 16), (16, 32, 16), (0, 16, 0), "int4") _test_legalize_batch_matmul((16, 2, 16), (16, 32, 16), (0, 0, 0), "int4", False) + _test_legalize_batch_matmul( + (16, 8, 16), (16, 16, 32), (0, 0, 0), "float16", False, transpose_b=False + ) + _test_legalize_batch_matmul( + (16, 16, 8), (16, 32, 16), (0, 0, 0), "float16", False, transpose_a=True + ) + if __name__ == "__main__": test_legalize_conv2d_NHWC() From 2e74f44b159794484f471373bcbe509ecf67c671 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 15 Dec 2022 08:45:55 +0900 Subject: [PATCH 3/3] clean up --- python/tvm/topi/cuda/tensorcore_alter_op.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/python/tvm/topi/cuda/tensorcore_alter_op.py b/python/tvm/topi/cuda/tensorcore_alter_op.py index 117d861a9f7c..dbbf9e74903c 100644 --- a/python/tvm/topi/cuda/tensorcore_alter_op.py +++ b/python/tvm/topi/cuda/tensorcore_alter_op.py @@ -106,14 +106,18 @@ def _batch_matmul_legalize(attrs, inputs, arg_types): logger.info("batch_matmul pad_to_tensorcore, extra_flops %s", extra_flops) if attrs.transpose_a: - x_ = relay.nn.pad(x, pad_width=((0, 0), (0, dk), (0, dm))) if dm or dk else x + pad_width = ((0, 0), (0, dk), (0, dm)) else: - x_ = relay.nn.pad(x, pad_width=((0, 0), (0, dm), (0, dk))) if dm or dk else x + pad_width = ((0, 0), (0, dm), (0, dk)) + + x_ = relay.nn.pad(x, pad_width=pad_width) if dm or dk else x if attrs.transpose_b: - y_ = relay.nn.pad(y, pad_width=((0, 0), (0, dn), (0, dk))) if dn or dk else y + pad_width = ((0, 0), (0, dn), (0, dk)) else: - y_ = relay.nn.pad(y, pad_width=((0, 0), (0, dk), (0, dn))) if dn or dk else y + pad_width = ((0, 0), (0, dk), (0, dn)) + + y_ = relay.nn.pad(y, pad_width=pad_width) if dn or dk else y out_ = relay.nn.batch_matmul(x_, y_, **attrs)