[TOPI] Custom schedule for standalone transpose in cuda #8030
@@ -0,0 +1,67 @@ (new file)

```python
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""CUDA implementations of transforms"""

from ... import te
from ...target import Target
from ..utils import traverse_inline


def schedule_transpose(outs):
    """Schedule an unfused transpose."""
    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
    s = te.create_schedule([x.op for x in outs])
    schedule_transpose_from_existing(s, outs[0])
    return s


def schedule_transpose_from_existing(s, out):
    """Schedule for transpose on the gpu.

    Roughly follows this:
    https://developer.nvidia.com/blog/efficient-matrix-transpose-cuda-cc/, but
    without the padding for shared memory. For better performance, we could
    rewrite it in tir to add the padding. Also, rewriting in tir would allow
    us to use warp shuffles instead of shared memory (see
    https://github.com/bryancatanzaro/trove).
    """

    def _callback(op):
        # pylint: disable=invalid-name
        m, n = s[op].op.axis
        warp_size = int(Target.current(allow_none=False).thread_warp_size)
        no, ni = s[op].split(n, factor=warp_size)
        mo, mi = s[op].split(m, factor=warp_size)
        s[op].reorder(mo, no, mi, ni)
        s[op].bind(mo, te.thread_axis("blockIdx.x"))
        s[op].bind(no, te.thread_axis("blockIdx.y"))
        c = s.cache_read(op.input_tensors[0], "shared", op)
        s[c].compute_at(s[op], no)
        thread_x = te.thread_axis("threadIdx.x")
        thread_y = te.thread_axis("threadIdx.y")
        s[op].bind(ni, thread_x)
        # This is a hack to make the scheduling language realize that this axis
        # can be scheduled.
        a, _ = s[c].split(s[c].op.axis[1], factor=1)
        s[c].bind(a, thread_x)
        # Use 4 warps per block. Slightly faster than 1 warp per block.
        ao, _ = s[op].split(mi, nparts=4)
        s[op].bind(ao, thread_y)
        ao, _ = s[c].split(s[c].op.axis[0], nparts=4)
        s[c].bind(ao, thread_y)

    traverse_inline(s, out.op, _callback)
```
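For reference, here is a minimal sketch of how the new standalone schedule could be exercised directly from TE, outside of relay. The import path, tensor shapes, and device setup below are illustrative assumptions, not taken from this diff:

```python
# Sketch only: assumes a CUDA-enabled TVM build and that schedule_transpose
# is re-exported from tvm.topi.cuda (the exact module path is an assumption).
import numpy as np
import tvm
from tvm import te
from tvm.topi.cuda import schedule_transpose

m, n = 1024, 2048
A = te.placeholder((m, n), name="A", dtype="float32")
B = te.compute((n, m), lambda i, j: A[j, i], name="B")  # standalone transpose

# The schedule queries Target.current(), so build it under a target context.
with tvm.target.Target("cuda"):
    s = schedule_transpose(B)
    f = tvm.build(s, [A, B], target="cuda")

dev = tvm.device("cuda", 0)
a = tvm.nd.array(np.random.rand(m, n).astype("float32"), dev)
b = tvm.nd.array(np.zeros((n, m), dtype="float32"), dev)
f(a, b)
np.testing.assert_allclose(b.asnumpy(), a.asnumpy().T)
```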
```diff
@@ -20,6 +20,7 @@
 import tvm
 from tvm import te
 from tvm import topi
+from tvm import relay
 import tvm.topi.testing
 from tvm.contrib.nvcc import have_fp16
```

```diff
@@ -870,6 +871,31 @@ def test_transpose():
     verify_transpose((3, 10), None)
 
 
+@tvm.testing.parametrize_targets("cuda", "rocm")
+def test_transpose_unfused_schedule(target, dev):
+    shape = (100, tvm.target.Target(target).thread_warp_size + 3)
+    x = relay.var("x", relay.TensorType(shape, "float32"))
+    f = relay.transpose(x)
+    ex = relay.create_executor(
+        kind="graph", mod=tvm.IRModule.from_expr(relay.Function([x], f)), device=dev, target=target
+    )
+    r = np.random.rand(*shape)
+    tvm.testing.assert_allclose(ex.evaluate()(r).asnumpy(), np.transpose(r))
+
+    # We want to make sure schedule does not fire here, but there is no way of
+    # inspecting which schedules were used.
```
Comment on lines +885 to +886:

Contributor: Like this comment mentions, there is no way of inspecting which schedules were used, so it seems to me that the difference between this test and …

Contributor (Author): We could. I like to keep it separate so the intention is known.

Contributor: Fair enough. Then it might be better to name it …

Contributor (Author): switched to …
```diff
+    x = relay.var("x", relay.TensorType(shape, "float32"))
+    y = relay.var("y", relay.TensorType(shape, "float32"))
+    f = relay.transpose(x + y)
+    ex = relay.create_executor(
+        kind="graph",
+        mod=tvm.IRModule.from_expr(relay.Function([x, y], f)),
+        device=dev,
+        target=target,
+    )
+    tvm.testing.assert_allclose(ex.evaluate()(r, r).asnumpy(), np.transpose(r + r))
+
+
 @tvm.testing.uses_gpu
 def test_reshape():
     verify_reshape((1, 2, 3, 4), (2, 3, 4))
```
```diff
@@ -357,7 +357,7 @@ def tune_and_evaluate(tuning_opt):
     )
 
     # filter out non-packed conv2d task
-    tasks = list(filter(lambda t: len(t.args[0][1]) > 4, tasks))
+    tasks = list(filter(lambda t: len(t.args[0][1]) > 4 and "conv" in t.name, tasks))
```
Contributor: what happened here, did this transpose change introduce a new task or something?

Contributor (Author): yes

Contributor (Author): actually, no, but this check makes sure anyway.

Contributor: Isn't the newly added schedule not tunable? Or is there any concern about adding knobs?

Contributor (Author): We may want to tune it in the future.
```diff
 
     # We should have extracted 10 convolution tasks
     assert len(tasks) == 10
```
Contributor: is there a more principled way to do this? like maybe with an OpStrategy or something

Contributor (Author): As far as I can tell, there is not a better way to do this. There is a way to add implementations based on input sizes, but these are not on a per-target basis. If you know a better way, let me know.
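For context on that exchange, here is a rough sketch of what the OpStrategy route could look like: a per-target strategy for transpose that picks the dedicated CUDA schedule while other targets keep the generic injective one. This is not what the PR does; the transpose-specific wiring (the compute wrapper, the export location of the schedule, and how it would interact with transpose's existing injective registration) is an assumption.

```python
# Hypothetical sketch only: per-target dispatch for transpose via relay's
# OpStrategy. Anything marked "assumed" is not part of this PR or of TVM's
# current transpose registration.
from tvm import topi
from tvm.relay.op import op as _op
from tvm.relay.op.strategy.generic import wrap_topi_schedule
from tvm.target import override_native_generic_func


def wrap_compute_transpose(topi_compute):
    """Assumed helper: adapt topi.transpose to the strategy compute signature."""
    def _compute(attrs, inputs, out_type):
        return [topi_compute(inputs[0], attrs.axes)]
    return _compute


@override_native_generic_func("transpose_strategy")
def transpose_strategy(attrs, inputs, out_type, target):
    """Fallback for all targets: keep the generic injective schedule."""
    strategy = _op.OpStrategy()
    strategy.add_implementation(
        wrap_compute_transpose(topi.transpose),
        wrap_topi_schedule(topi.generic.schedule_injective),
        name="transpose.generic",
    )
    return strategy


@transpose_strategy.register(["cuda", "gpu"])
def transpose_strategy_cuda(attrs, inputs, out_type, target):
    """CUDA/ROCm: prefer the dedicated standalone-transpose schedule."""
    strategy = _op.OpStrategy()
    strategy.add_implementation(
        wrap_compute_transpose(topi.transpose),
        wrap_topi_schedule(topi.cuda.schedule_transpose),  # assumed export location
        name="transpose.cuda",
        plevel=15,  # outrank the generic implementation
    )
    return strategy


# Wiring this in would also require replacing transpose's existing injective
# registration, e.g. something along the lines of:
# _op.register_strategy("transpose", transpose_strategy)
```

The trade-off discussed above still applies: a dedicated strategy makes the dispatch explicit and per-target, but it also means transpose no longer fuses as a plain injective op unless that case is handled separately.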