[TOPI] Improve CUDA softmax scheduling #5600
@@ -16,12 +16,12 @@

```python
# under the License.
# pylint: disable=invalid-name, unused-variable, trailing-whitespace
"""Schedule for softmax operator"""
from tvm import target as target_
from tvm import te
from tvm.contrib import cudnn
from .. import generic
from .injective import schedule_injective_from_existing
```
```python
def schedule_softmax(outs):
    """Schedule for softmax op.
```

@@ -39,6 +39,7 @@ def schedule_softmax(outs):

```python
    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
    s = te.create_schedule([x.op for x in outs])
    softmax = outs[0]
    tgt = target_.Target.current(allow_none=False)
```
Member: We might want to register the warp-level strategies only when the target is cuda, given that the "gpu" schedule is reused by other GPUs that do not support warps.

Contributor: Hm. I'd like to benefit from the GPU schedules with warps...
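For reference, a minimal sketch of the guard the reviewer is describing, assuming the check would live in the `sched_warp_softmax` helper introduced further down in this diff (`tgt` and `softmax` are the local variables of `schedule_softmax` shown above). This is not what the PR does; the PR keeps the warp path for all non-nvptx targets:

```python
# Sketch only (reviewer's suggestion): take the warp path on the plain CUDA
# backend, fall back to the generic "gpu" schedule on targets whose codegen
# may not support warp shuffles.
def sched_warp_softmax():
    if tgt.target_name == "nvptx":
        # nvptx only lowers 32-bit warp shuffle instructions.
        return softmax.dtype in ("float32", "int32")
    # Assumption: only the "cuda" target is known to support the warp schedule.
    return tgt.target_name == "cuda"
```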
```python
    op_tag = softmax.op.tag
    if op_tag == 'softmax_output':
```
@@ -53,13 +54,61 @@ def schedule_softmax(outs):

```python
        raise ValueError('Tag is expected to be softmax_output or log_softmax_output. \
                          Got {0}'.format(op_tag))
```
```python
    # The nvptx backend only supports 32-bits warp shuffle instructions.
    #
    # TODO(tvm-team) Fix nvptx codegen or deprecate nvptx backend.
    def sched_warp_softmax():
        if tgt.target_name == "nvptx":
            return softmax.dtype == "float32" or softmax.dtype == "int32"
        return True
```
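One way to check that the warp path is actually taken is to build a small 2-D softmax and look for shuffle intrinsics in the generated kernel. A sketch, assuming a CUDA-capable machine and the standalone `topi` package of this era:

```python
import tvm
from tvm import te
import topi  # standalone topi package at the time of this PR

x = te.placeholder((32, 128), name="x", dtype="float32")
y = topi.nn.softmax(x)

with tvm.target.create("cuda"):
    s = topi.cuda.schedule_softmax(y)
    f = tvm.build(s, [x, y], "cuda")

# If the warp-level reduction was used, shuffle intrinsics such as
# __shfl_down_sync typically appear in the generated CUDA source; the
# non-warp fallback reduces through shared memory instead.
print("__shfl" in f.imported_modules[0].get_source())
```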
```python
    if len(softmax.shape) > 2:
        ops = [max_elem.op, expsum.op, softmax.op]
        if exp is not None:
            ops.append(exp.op)

        for op in ops:
            s = schedule_injective_from_existing(s, op.output(0))
```
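For inputs with rank greater than two, each stage simply gets the injective schedule above instead of the warp path. A sketch of that case, assuming a rank-3 input such as attention scores and the `topi.nn.softmax(x, axis=-1)` API:

```python
import tvm
from tvm import te
import topi

x = te.placeholder((8, 128, 128), name="x", dtype="float32")
y = topi.nn.softmax(x, axis=-1)

with tvm.target.create("cuda"):
    s = topi.cuda.schedule_softmax(y)

# Every stage (max_elem, exp, expsum, softmax) is scheduled as an injective op.
print(tvm.lower(s, [x, y], simple_mode=True))
```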
```python
    elif sched_warp_softmax():
        # A warp of 32 threads performs a row reduction.
        num_thread = tgt.thread_warp_size
        block_x = te.thread_axis("blockIdx.x")
        thread_x = te.thread_axis((0, num_thread), "threadIdx.x")

        # (4) softmax
        xo, xi = s[softmax].split(softmax.op.axis[1], nparts=num_thread)
        _, xii = s[softmax].split(xi, factor=4)
        s[softmax].vectorize(xii)
        s[softmax].bind(xo, thread_x)
        s[softmax].bind(softmax.op.axis[0], block_x)
```
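To make the split/vectorize pattern in block (4) concrete: the outer split distributes a row across the warp lanes, and the inner `factor=4` split becomes a vectorized access per lane. A toy sketch of just that pattern (the elementwise op, shapes, and lane count are assumptions, not the PR's code):

```python
import tvm
from tvm import te

# Toy elementwise op standing in for the softmax normalization step.
n = 256
A = te.placeholder((1, n), name="A")
B = te.compute((1, n), lambda i, j: A[i, j] * 2.0, name="B")

s = te.create_schedule(B.op)
xo, xi = s[B].split(B.op.axis[1], nparts=32)   # one chunk per warp lane
_, xii = s[B].split(xi, factor=4)              # 4-wide vectorized inner loop
s[B].vectorize(xii)
s[B].bind(xo, te.thread_axis((0, 32), "threadIdx.x"))
s[B].bind(B.op.axis[0], te.thread_axis("blockIdx.x"))

# Each lane walks its slice of the row in float32x4 steps.
print(tvm.lower(s, [A, B], simple_mode=True))
```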
```python
        # (3) expsum
        k = expsum.op.reduce_axis[0]
        ko, _ = s[expsum].split(k, nparts=num_thread)
        s[expsum].bind(ko, thread_x)
        s[expsum].compute_at(s[softmax], xo)
```
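Block (3) relies on TVM's cross-thread reduction: splitting expsum's reduce axis and binding the outer part to `threadIdx.x` makes the warp lanes cooperate on one row's sum, and `compute_at(s[softmax], xo)` nests that reduction inside the softmax row loop. A toy sketch of just the cross-thread reduction, with assumed shapes:

```python
import tvm
from tvm import te

# Toy row reduction standing in for expsum.
n = 128
A = te.placeholder((1, n), name="A")
k = te.reduce_axis((0, n), name="k")
B = te.compute((1,), lambda i: te.sum(A[i, k], axis=k), name="B")

s = te.create_schedule(B.op)
ko, _ = s[B].split(k, nparts=32)
s[B].bind(ko, te.thread_axis((0, 32), "threadIdx.x"))  # cross-thread reduction
s[B].bind(B.op.axis[0], te.thread_axis("blockIdx.x"))

# The lowered IR contains a tvm_thread_allreduce over the 32 lanes.
print(tvm.lower(s, [A, B], simple_mode=True))
```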
```python
        # (2) exp
        if exp is not None:
            xo, xi = s[exp].split(exp.op.axis[1], nparts=num_thread)
            _, xii = s[exp].split(xi, factor=4)
            s[exp].vectorize(xii)
            s[exp].bind(xo, thread_x)
            s[exp].compute_at(s[expsum], expsum.op.axis[0])
            s[exp].compute_at(s[softmax], softmax.op.axis[0])
            s[exp].set_scope("warp")
```
```python
        # (1) max_elem
        k = max_elem.op.reduce_axis[0]
        ko, _ = s[max_elem].split(k, nparts=num_thread)
        s[max_elem].bind(ko, thread_x)
        if exp is not None:
            s[max_elem].compute_at(s[exp], xo)
        else:
            s[max_elem].bind(ko, thread_x)
            s[max_elem].bind(max_elem.op.axis[0], block_x)
```
```python
    else:
        num_thread = 64
        block_x = te.thread_axis("blockIdx.x")
```
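An end-to-end sanity check for the new warp path: run the scheduled softmax and compare against a NumPy reference. This is a sketch; it assumes a CUDA device and the era's API names such as `tvm.target.create`, `tvm.gpu`, and `topi.testing.softmax_python`:

```python
import numpy as np
import tvm
from tvm import te
import topi
import topi.testing

x = te.placeholder((64, 1024), name="x", dtype="float32")
y = topi.nn.softmax(x)

with tvm.target.create("cuda"):
    s = topi.cuda.schedule_softmax(y)
    f = tvm.build(s, [x, y], "cuda")

ctx = tvm.gpu(0)
x_np = np.random.uniform(size=(64, 1024)).astype("float32")
y_ref = topi.testing.softmax_python(x_np)

x_nd = tvm.nd.array(x_np, ctx)
y_nd = tvm.nd.array(np.zeros_like(x_np), ctx)
f(x_nd, y_nd)
np.testing.assert_allclose(y_nd.asnumpy(), y_ref, rtol=1e-5)
```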