From b5478be25d0910441cbcd8d683825ee988ffa018 Mon Sep 17 00:00:00 2001 From: Tristan Konolige Date: Wed, 12 May 2021 11:51:29 -0700 Subject: [PATCH 01/12] [TOPI] Custom schedule for standalone transpose in cuda --- python/tvm/relay/op/_transform.py | 4 ++- python/tvm/relay/op/strategy/cuda.py | 18 +++++++++++++ python/tvm/relay/op/strategy/generic.py | 7 ++++++ python/tvm/topi/cuda/sparse.py | 15 ++++++++--- .../python/topi/python/test_topi_transform.py | 25 +++++++++++++++++++ 5 files changed, 65 insertions(+), 4 deletions(-) diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 76adee477a1a..412acb4cea17 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -53,7 +53,6 @@ _reg.register_injective_schedule("slice_like") _reg.register_injective_schedule("split") _reg.register_injective_schedule("take") -_reg.register_injective_schedule("transpose") _reg.register_injective_schedule("stack") _reg.register_injective_schedule("contrib_reverse_reshape") _reg.register_injective_schedule("gather") @@ -746,6 +745,9 @@ def transpose_shape_func(attrs, inputs, _): return [_transpose_shape_func(inputs[0], convert(axes))] +_reg.register_schedule("transpose", strategy.schedule_transpose) + + @script def _squeeze_shape_func(data_shape, keep_axes, remove_axes): out = output_tensor((len(keep_axes),), "int64") diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index a6775ae7bd20..b9bd4ec6288c 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -23,6 +23,7 @@ from tvm.te import SpecializedCondition from .. import op as _op +from ....target import Target from .generic import * @@ -1068,3 +1069,20 @@ def unique_strategy_cuda(attrs, inputs, out_type, target): name="unique.cuda", ) return strategy + + +@schedule_transpose.register(["cuda", "gpu", "rocm"]) +def schedule_transpose_cuda(attrs, outs, target): + """ + Transpose cuda strategy + Dispatches to and optimized schedule if the transpose is standalone (not fused). + """ + warp_size = int(Target.current(allow_none=False).thread_warp_size) + if ( + isinstance(outs[0].op.input_tensors[0].op, te.PlaceholderOp) + and len(outs[0].shape) == 2 + and (attrs.axes is None or (len(attrs.axes) == 2 and attrs.axes == [1, 0])) + and outs[0].shape[1] >= warp_size + ): + return topi.cuda.schedule_transpose(outs) + return schedule_injective(attrs, outs, target) diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index 7451b397265f..3b54e2a54f0c 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -1547,3 +1547,10 @@ def unique_strategy(attrs, inputs, out_type, target): name="unique.generic", ) return strategy + + +@generic_func +def schedule_transpose(attrs, outs, target): + """schedule transpose""" + with target: + return topi.generic.schedule_injective(outs) diff --git a/python/tvm/topi/cuda/sparse.py b/python/tvm/topi/cuda/sparse.py index 1e846ebf5311..8b4ba1335bc7 100644 --- a/python/tvm/topi/cuda/sparse.py +++ b/python/tvm/topi/cuda/sparse.py @@ -105,13 +105,22 @@ def _callback(op): return s -def schedule_cuda_transpose(s, out): +def schedule_transpose(outs): + outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs + s = te.create_schedule([x.op for x in outs]) + schedule_transpose_existing(s, outs[0]) + return s + + +def schedule_transpose_existing(s, out): """Schedule for transpose on the gpu. Roughly follows this: https://developer.nvidia.com/blog/efficient-matrix-transpose-cuda-cc/, but without the padding for shared memory. For better performance, we could - rewrite it in tir to add the padding. + rewrite it in tir to add the padding. Also, rewriting in tir would allow + use to use warp shuffles instead of shared memory (see + https://github.com/bryancatanzaro/trove). """ def _callback(op): @@ -388,7 +397,7 @@ def schedule_sparse_dense_padded(outs): # necessary data_t = outs[0].op.input_tensors[0] s = te.create_schedule([outs[0].op, data_t.op]) - schedule_cuda_transpose(s, outs[0].op.input_tensors[0]) + schedule_transpose_existing(s, outs[0].op.input_tensors[0]) return s diff --git a/tests/python/topi/python/test_topi_transform.py b/tests/python/topi/python/test_topi_transform.py index 16f9f13f05b0..391fe50b2415 100644 --- a/tests/python/topi/python/test_topi_transform.py +++ b/tests/python/topi/python/test_topi_transform.py @@ -20,6 +20,7 @@ import tvm from tvm import te from tvm import topi +from tvm import relay import tvm.topi.testing from tvm.contrib.nvcc import have_fp16 @@ -870,6 +871,30 @@ def test_transpose(): verify_transpose((3, 10), None) +@tvm.testing.parametrize_targets +def test_transpose_schedule(target, dev): + shape = (100, 34) + x = relay.var("x", relay.TensorType(shape, "float32")) + f = relay.transpose(x) + ex = relay.create_executor( + kind="graph", mod=tvm.IRModule.from_expr(relay.Function([x], f)), device=dev, target=target + ) + r = np.random.rand(*shape) + tvm.testing.assert_allclose(ex.evaluate()(r).asnumpy(), np.transpose(r)) + + # make sure schedule does not fire here + x = relay.var("x", relay.TensorType(shape, "float32")) + y = relay.var("y", relay.TensorType(shape, "float32")) + f = relay.transpose(x + y) + ex = relay.create_executor( + kind="graph", + mod=tvm.IRModule.from_expr(relay.Function([x, y], f)), + device=dev, + target=target, + ) + tvm.testing.assert_allclose(ex.evaluate()(r, r).asnumpy(), np.transpose(r + r)) + + @tvm.testing.uses_gpu def test_reshape(): verify_reshape((1, 2, 3, 4), (2, 3, 4)) From b57aadee57b71a52740a0e75093c82a5bf180d46 Mon Sep 17 00:00:00 2001 From: Tristan Konolige Date: Thu, 13 May 2021 09:53:47 -0700 Subject: [PATCH 02/12] check if input is not Any --- python/tvm/relay/op/strategy/cuda.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index b9bd4ec6288c..8571123dd654 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -24,6 +24,7 @@ from .. import op as _op from ....target import Target +from ....tir import IntImm from .generic import * @@ -1082,6 +1083,7 @@ def schedule_transpose_cuda(attrs, outs, target): isinstance(outs[0].op.input_tensors[0].op, te.PlaceholderOp) and len(outs[0].shape) == 2 and (attrs.axes is None or (len(attrs.axes) == 2 and attrs.axes == [1, 0])) + and isinstance(outs[0].shape[1], (int, IntImm)) and outs[0].shape[1] >= warp_size ): return topi.cuda.schedule_transpose(outs) From 09a65f2d8a328609bedace7009de95e55113499a Mon Sep 17 00:00:00 2001 From: Tristan Konolige Date: Thu, 13 May 2021 14:53:34 -0700 Subject: [PATCH 03/12] fix vta test --- vta/tutorials/autotvm/tune_relay_vta.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vta/tutorials/autotvm/tune_relay_vta.py b/vta/tutorials/autotvm/tune_relay_vta.py index 38633b01d976..2f505b2a86a6 100644 --- a/vta/tutorials/autotvm/tune_relay_vta.py +++ b/vta/tutorials/autotvm/tune_relay_vta.py @@ -357,7 +357,7 @@ def tune_and_evaluate(tuning_opt): ) # filter out non-packed conv2d task - tasks = list(filter(lambda t: len(t.args[0][1]) > 4, tasks)) + tasks = list(filter(lambda t: len(t.args[0][1]) > 4 and "conv" in t.name, tasks)) # We should have extracted 10 convolution tasks assert len(tasks) == 10 From 95bf4304d1aaa9854775ab96c771d06101a2f2c1 Mon Sep 17 00:00:00 2001 From: Tristan Konolige Date: Thu, 13 May 2021 14:53:55 -0700 Subject: [PATCH 04/12] check input shape --- python/tvm/relay/op/strategy/cuda.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index 8571123dd654..6c5b1e0cdead 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -1083,6 +1083,8 @@ def schedule_transpose_cuda(attrs, outs, target): isinstance(outs[0].op.input_tensors[0].op, te.PlaceholderOp) and len(outs[0].shape) == 2 and (attrs.axes is None or (len(attrs.axes) == 2 and attrs.axes == [1, 0])) + and isinstance(outs[0].shape[0], (int, IntImm)) + and outs[0].shape[0] >= warp_size and isinstance(outs[0].shape[1], (int, IntImm)) and outs[0].shape[1] >= warp_size ): From c0d051f3f381a857458c7003af61cf73de1c70a0 Mon Sep 17 00:00:00 2001 From: Tristan Konolige Date: Fri, 14 May 2021 09:41:22 -0700 Subject: [PATCH 05/12] fix injective --- python/tvm/relay/op/strategy/generic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index 3b54e2a54f0c..fbfe8ce33ec6 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -1553,4 +1553,4 @@ def unique_strategy(attrs, inputs, out_type, target): def schedule_transpose(attrs, outs, target): """schedule transpose""" with target: - return topi.generic.schedule_injective(outs) + return schedule_injective(attrs, outs, target) From cdd9f5e955dff682710cd2e9218ac407ea941996 Mon Sep 17 00:00:00 2001 From: Tristan Konolige Date: Mon, 17 May 2021 11:23:30 -0700 Subject: [PATCH 06/12] move transpose out of sparse.py --- python/tvm/topi/cuda/__init__.py | 1 + python/tvm/topi/cuda/sparse.py | 48 +---------------------- python/tvm/topi/cuda/transform.py | 63 +++++++++++++++++++++++++++++++ 3 files changed, 66 insertions(+), 46 deletions(-) create mode 100644 python/tvm/topi/cuda/transform.py diff --git a/python/tvm/topi/cuda/__init__.py b/python/tvm/topi/cuda/__init__.py index 4d838db8bfba..2788a884724a 100644 --- a/python/tvm/topi/cuda/__init__.py +++ b/python/tvm/topi/cuda/__init__.py @@ -57,4 +57,5 @@ from .argwhere import * from .scan import * from .sparse_reshape import * +from .transfrom import * from .unique import * diff --git a/python/tvm/topi/cuda/sparse.py b/python/tvm/topi/cuda/sparse.py index 8b4ba1335bc7..b6baa9cd67a5 100644 --- a/python/tvm/topi/cuda/sparse.py +++ b/python/tvm/topi/cuda/sparse.py @@ -24,6 +24,7 @@ from .. import nn from ..utils import traverse_inline, get_const_tuple, prod, get_const_int, ceil_div +from .transform import schedule_transpose_from_existing def sparse_dense(data, weight_data, weight_indices, weight_indptr, sparse_lhs=False): @@ -105,51 +106,6 @@ def _callback(op): return s -def schedule_transpose(outs): - outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs - s = te.create_schedule([x.op for x in outs]) - schedule_transpose_existing(s, outs[0]) - return s - - -def schedule_transpose_existing(s, out): - """Schedule for transpose on the gpu. - - Roughly follows this: - https://developer.nvidia.com/blog/efficient-matrix-transpose-cuda-cc/, but - without the padding for shared memory. For better performance, we could - rewrite it in tir to add the padding. Also, rewriting in tir would allow - use to use warp shuffles instead of shared memory (see - https://github.com/bryancatanzaro/trove). - """ - - def _callback(op): - # pylint: disable=invalid-name - m, n = s[op].op.axis - warp_size = int(tvm.target.Target.current(allow_none=False).thread_warp_size) - no, ni = s[op].split(n, factor=warp_size) - mo, mi = s[op].split(m, factor=warp_size) - s[op].reorder(mo, no, mi, ni) - s[op].bind(mo, te.thread_axis("blockIdx.x")) - s[op].bind(no, te.thread_axis("blockIdx.y")) - c = s.cache_read(op.input_tensors[0], "shared", op) - s[c].compute_at(s[op], no) - thread_x = te.thread_axis("threadIdx.x") - thread_y = te.thread_axis("threadIdx.y") - s[op].bind(ni, thread_x) - # This is a hack to make the scheduling language realize that this axis - # can be scheduled. - a, _ = s[c].split(s[c].op.axis[1], factor=1) - s[c].bind(a, thread_x) - # Use 4 warps per block. Slightly faster than 1 warp per block - ao, _ = s[op].split(mi, nparts=4) - s[op].bind(ao, thread_y) - ao, _ = s[c].split(s[c].op.axis[0], nparts=4) - s[c].bind(ao, thread_y) - - traverse_inline(s, out.op, _callback) - - def sparse_dense_tir(data, w_data, w_indices, w_indptr): """Compute data * w^T. @@ -397,7 +353,7 @@ def schedule_sparse_dense_padded(outs): # necessary data_t = outs[0].op.input_tensors[0] s = te.create_schedule([outs[0].op, data_t.op]) - schedule_transpose_existing(s, outs[0].op.input_tensors[0]) + schedule_transpose_from_existing(s, outs[0].op.input_tensors[0]) return s diff --git a/python/tvm/topi/cuda/transform.py b/python/tvm/topi/cuda/transform.py new file mode 100644 index 000000000000..eb0e1b2cb423 --- /dev/null +++ b/python/tvm/topi/cuda/transform.py @@ -0,0 +1,63 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +def schedule_transpose(outs): + outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs + s = te.create_schedule([x.op for x in outs]) + schedule_transpose_from_existing(s, outs[0]) + return s + + +def schedule_transpose_from_existing(s, out): + """Schedule for transpose on the gpu. + + Roughly follows this: + https://developer.nvidia.com/blog/efficient-matrix-transpose-cuda-cc/, but + without the padding for shared memory. For better performance, we could + rewrite it in tir to add the padding. Also, rewriting in tir would allow + use to use warp shuffles instead of shared memory (see + https://github.com/bryancatanzaro/trove). + """ + + def _callback(op): + # pylint: disable=invalid-name + m, n = s[op].op.axis + warp_size = int(tvm.target.Target.current(allow_none=False).thread_warp_size) + no, ni = s[op].split(n, factor=warp_size) + mo, mi = s[op].split(m, factor=warp_size) + s[op].reorder(mo, no, mi, ni) + s[op].bind(mo, te.thread_axis("blockIdx.x")) + s[op].bind(no, te.thread_axis("blockIdx.y")) + c = s.cache_read(op.input_tensors[0], "shared", op) + s[c].compute_at(s[op], no) + thread_x = te.thread_axis("threadIdx.x") + thread_y = te.thread_axis("threadIdx.y") + s[op].bind(ni, thread_x) + # This is a hack to make the scheduling language realize that this axis + # can be scheduled. + a, _ = s[c].split(s[c].op.axis[1], factor=1) + s[c].bind(a, thread_x) + # Use 4 warps per block. Slightly faster than 1 warp per block + ao, _ = s[op].split(mi, nparts=4) + s[op].bind(ao, thread_y) + ao, _ = s[c].split(s[c].op.axis[0], nparts=4) + s[c].bind(ao, thread_y) + + traverse_inline(s, out.op, _callback) + + From 5804cc7a876178e02c5e18b40ae583453913a32a Mon Sep 17 00:00:00 2001 From: Tristan Konolige Date: Mon, 17 May 2021 11:35:56 -0700 Subject: [PATCH 07/12] update comments, use warp size --- tests/python/topi/python/test_topi_transform.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/python/topi/python/test_topi_transform.py b/tests/python/topi/python/test_topi_transform.py index 391fe50b2415..fdc1cf7e33f2 100644 --- a/tests/python/topi/python/test_topi_transform.py +++ b/tests/python/topi/python/test_topi_transform.py @@ -873,7 +873,7 @@ def test_transpose(): @tvm.testing.parametrize_targets def test_transpose_schedule(target, dev): - shape = (100, 34) + shape = (100, target.thread_warp_size + 3) x = relay.var("x", relay.TensorType(shape, "float32")) f = relay.transpose(x) ex = relay.create_executor( @@ -882,7 +882,8 @@ def test_transpose_schedule(target, dev): r = np.random.rand(*shape) tvm.testing.assert_allclose(ex.evaluate()(r).asnumpy(), np.transpose(r)) - # make sure schedule does not fire here + # We want to make sure schedule does not fire here, but there is no way of + # inspecting which schedules were used. x = relay.var("x", relay.TensorType(shape, "float32")) y = relay.var("y", relay.TensorType(shape, "float32")) f = relay.transpose(x + y) From f4982a2ae0b952566772a7a73975a67d2d465c80 Mon Sep 17 00:00:00 2001 From: Tristan Konolige Date: Mon, 17 May 2021 13:53:22 -0700 Subject: [PATCH 08/12] missspelled transform --- python/tvm/topi/cuda/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/topi/cuda/__init__.py b/python/tvm/topi/cuda/__init__.py index 2788a884724a..21ddf57ca1d0 100644 --- a/python/tvm/topi/cuda/__init__.py +++ b/python/tvm/topi/cuda/__init__.py @@ -57,5 +57,5 @@ from .argwhere import * from .scan import * from .sparse_reshape import * -from .transfrom import * +from .transform import * from .unique import * From 410f3125a1777ceec390481c009716c275f4a588 Mon Sep 17 00:00:00 2001 From: Tristan Konolige Date: Tue, 18 May 2021 15:07:22 -0700 Subject: [PATCH 09/12] formatting --- python/tvm/topi/cuda/transform.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/tvm/topi/cuda/transform.py b/python/tvm/topi/cuda/transform.py index eb0e1b2cb423..d1640a0a37b7 100644 --- a/python/tvm/topi/cuda/transform.py +++ b/python/tvm/topi/cuda/transform.py @@ -59,5 +59,3 @@ def _callback(op): s[c].bind(ao, thread_y) traverse_inline(s, out.op, _callback) - - From 804a8cd30ff8b34dd5d7b4397820e67c1d554af4 Mon Sep 17 00:00:00 2001 From: Tristan Konolige Date: Tue, 18 May 2021 15:31:52 -0700 Subject: [PATCH 10/12] rename test --- tests/python/topi/python/test_topi_transform.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/topi/python/test_topi_transform.py b/tests/python/topi/python/test_topi_transform.py index fdc1cf7e33f2..80e2a7673be5 100644 --- a/tests/python/topi/python/test_topi_transform.py +++ b/tests/python/topi/python/test_topi_transform.py @@ -872,7 +872,7 @@ def test_transpose(): @tvm.testing.parametrize_targets -def test_transpose_schedule(target, dev): +def test_transpose_unfused_schedule(target, dev): shape = (100, target.thread_warp_size + 3) x = relay.var("x", relay.TensorType(shape, "float32")) f = relay.transpose(x) From 30fcc8e59e2ab3c551ef2fac5423dcaebab5edfd Mon Sep 17 00:00:00 2001 From: Tristan Konolige Date: Tue, 18 May 2021 16:12:45 -0700 Subject: [PATCH 11/12] comment --- python/tvm/topi/cuda/transform.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/tvm/topi/cuda/transform.py b/python/tvm/topi/cuda/transform.py index d1640a0a37b7..58c48df9ddeb 100644 --- a/python/tvm/topi/cuda/transform.py +++ b/python/tvm/topi/cuda/transform.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +"""CUDA implementations of transforms""" def schedule_transpose(outs): From 676a3423072cb57527ca85e8ef16f917562fb1ca Mon Sep 17 00:00:00 2001 From: Tristan Konolige Date: Wed, 19 May 2021 09:54:27 -0700 Subject: [PATCH 12/12] fix tests --- python/tvm/topi/cuda/transform.py | 7 ++++++- tests/python/topi/python/test_topi_transform.py | 4 ++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/python/tvm/topi/cuda/transform.py b/python/tvm/topi/cuda/transform.py index 58c48df9ddeb..89caf94bbbc1 100644 --- a/python/tvm/topi/cuda/transform.py +++ b/python/tvm/topi/cuda/transform.py @@ -16,8 +16,13 @@ # under the License. """CUDA implementations of transforms""" +from ... import te +from ...target import Target +from ..utils import traverse_inline + def schedule_transpose(outs): + """Schedule a unfused transpose""" outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs s = te.create_schedule([x.op for x in outs]) schedule_transpose_from_existing(s, outs[0]) @@ -38,7 +43,7 @@ def schedule_transpose_from_existing(s, out): def _callback(op): # pylint: disable=invalid-name m, n = s[op].op.axis - warp_size = int(tvm.target.Target.current(allow_none=False).thread_warp_size) + warp_size = int(Target.current(allow_none=False).thread_warp_size) no, ni = s[op].split(n, factor=warp_size) mo, mi = s[op].split(m, factor=warp_size) s[op].reorder(mo, no, mi, ni) diff --git a/tests/python/topi/python/test_topi_transform.py b/tests/python/topi/python/test_topi_transform.py index 80e2a7673be5..94cdc613ce9c 100644 --- a/tests/python/topi/python/test_topi_transform.py +++ b/tests/python/topi/python/test_topi_transform.py @@ -871,9 +871,9 @@ def test_transpose(): verify_transpose((3, 10), None) -@tvm.testing.parametrize_targets +@tvm.testing.parametrize_targets("cuda", "rocm") def test_transpose_unfused_schedule(target, dev): - shape = (100, target.thread_warp_size + 3) + shape = (100, tvm.target.Target(target).thread_warp_size + 3) x = relay.var("x", relay.TensorType(shape, "float32")) f = relay.transpose(x) ex = relay.create_executor(