[Relax][TIR] Introduce new cumsum op for gpu (#16934)
The backend package `__init__.py` exports the new kernel (`@@ -18,3 +18,4 @@`):

```diff
 from . import contrib
 from .pattern import get_tir_pattern
+from .cumsum import gpu_2d_continuous_cumsum
```
The main addition is a new file, `cumsum.py` (193 lines, `@@ -0,0 +1,193 @@`), containing the kernel generator:

```python
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, too-many-nested-blocks
"""Backend kernels for cumsum operator."""

import math
from typing import Optional

from tvm.script import tir as T
from tvm.tir import PrimFunc


def _is_power_of_two(n: int):
    """Check if n is a power of 2."""
    return n > 0 and (n & (n - 1)) == 0


def gpu_2d_continuous_cumsum(
    ty_len: int = 4,
    tx_len: int = 32,
    thread_elem: int = 4,
    in_dtype: str = "int32",
    out_dtype: Optional[str] = None,
) -> PrimFunc:
    """Generate a GPU kernel for 2D continuous cumsum, i.e. cumsum along axis -1.

    Parameters
    ----------
    ty_len : int
        The length of thread.y

    tx_len : int
        The length of thread.x

    thread_elem : int
        The number of elements processed by a single thread

    in_dtype : str
        The input data type

    out_dtype : Optional[str]
        The output data type; if None, it is the same as in_dtype

    Returns
    -------
    cumsum : PrimFunc
        The generated cumsum kernel
    """

    out_dtype = out_dtype or in_dtype

    # Configuration for GPU kernel
    TX = T.int64(tx_len)  # thread.x
    TY = T.int64(ty_len)  # thread.y
    N = T.int64(thread_elem)  # number of elements in a single thread

    if not _is_power_of_two(TX) or not _is_power_of_two(TY) or not _is_power_of_two(N):
        raise ValueError("Configuration of TX, TY, N must be power of 2")

    # number of elements to be processed by a single warp
    warp_elem = T.int64(tx_len * thread_elem)
    # number of elements to be processed by a single block (SM)
    block_elem = T.int64(tx_len * ty_len * thread_elem)

    LOG_TX = T.int64(int(math.log2(tx_len)))
    LOG_BLOCK_N = T.int64(int(math.log2(tx_len * ty_len * thread_elem)))

    @T.macro
    def block_inclusive_inside_block(
        batch: T.int64,
        cur_len: T.int64,
        source: T.Buffer,
        output: T.Buffer,
        tmp_buf: T.Buffer,
        src_offset: T.int64,
        tmp_offset: T.int64,
    ):
        for by in T.thread_binding(batch, thread="blockIdx.y"):
            for bx in T.thread_binding(T.ceildiv(cur_len, block_elem), thread="blockIdx.x"):
                with T.block():
                    local_buf = T.alloc_buffer((thread_elem,), out_dtype, scope="local")
                    shared_buf = T.alloc_buffer((block_elem,), out_dtype, scope="shared")
                    for ty in T.thread_binding(TY, thread="threadIdx.y"):
                        for tx in T.thread_binding(TX, thread="threadIdx.x"):
                            tx_idx = bx * block_elem + ty * warp_elem + tx * thread_elem
                            # Load data from global memory
                            for i in T.vectorized(N):
                                local_buf[i] = T.if_then_else(
                                    tx_idx + i < cur_len,
                                    T.Cast(out_dtype, source[by, src_offset + tx_idx + i]),
                                    T.Cast(out_dtype, 0),
                                )
                            # Inclusive scan inside thread
                            for i in T.unroll(1, N):
                                local_buf[i] += local_buf[i - 1]
                            # Store data to shared memory
                            for i in T.vectorized(N):
                                shared_buf[ty * warp_elem + tx * thread_elem + i] = local_buf[i]
                            # Inclusive scan inside warp
                            for i in T.unroll(LOG_TX):
                                for j in T.vectorized(N):
                                    idx: T.int64 = ty * warp_elem + tx * thread_elem
                                    if tx >= (1 << i):
                                        shared_buf[idx + j] += shared_buf[
                                            idx - (1 << i) * thread_elem + N - 1
                                        ]
                            # Inclusive scan inside block
                            for i in T.unroll(1, TY):
                                for j in T.vectorized(N):
                                    if ty == 0:
                                        idx: T.int64 = i * warp_elem + tx * thread_elem
                                        shared_buf[idx + j] += shared_buf[i * warp_elem - 1]
                            # Write sum of block to global memory
                            for i in T.vectorized(N):
                                idx: T.int64 = ty * warp_elem + tx * thread_elem + i
                                if bx * block_elem + idx < cur_len:
                                    output[by, src_offset + bx * block_elem + idx] = shared_buf[idx]
                            if tx == 0 and ty == 0:
                                for i in T.vectorized(N):
                                    tmp_buf[by, tmp_offset + bx] = shared_buf[block_elem - 1]

    @T.macro
    def update_cross_block(
        batch: T.int64,
        cur_len: T.int64,
        source: T.Buffer,
        output: T.Buffer,
        src_offset: T.int64,
        out_offset: T.int64,
    ):
        for by in T.thread_binding(batch, thread="blockIdx.y"):
            for bx in T.thread_binding(T.ceildiv(cur_len, block_elem), thread="blockIdx.x"):
                for ty in T.thread_binding(TY, thread="threadIdx.y"):
                    for tx in T.thread_binding(TX, thread="threadIdx.x"):
                        for i in T.serial(N):
                            idx: T.int64 = bx * block_elem + ty * warp_elem + i * TX + tx
                            if idx < cur_len:
                                output[by, out_offset + idx] += T.if_then_else(
                                    bx > 0, source[by, src_offset + bx - 1], 0
                                )

    @T.prim_func(private=True)
    def cumsum(var_a: T.handle, var_out: T.handle):
        T.func_attr({"tir.is_scheduled": 1})  # prevent further scheduling
        m, n = T.int64(), T.int64()
        A = T.match_buffer(var_a, [m, n], dtype=in_dtype)
        Out = T.match_buffer(var_out, [m, n], dtype=out_dtype)
        Tmp = T.alloc_buffer([m, n], dtype=out_dtype)
        ceil_log2 = T.Cast("int64", T.ceil(T.log2(T.Cast("float32", n))))
        total_rounds = ceil_log2 // LOG_BLOCK_N

        block_inclusive_inside_block(
            m, n, A, Out, Tmp, src_offset=T.int64(0), tmp_offset=T.int64(0)
        )
        for i in range(total_rounds):
            cur_len = T.ceildiv(n, 1 << (LOG_BLOCK_N * (i + 1)))
            block_inclusive_inside_block(
                m,
                cur_len,
                Tmp,
                Tmp,
                Tmp,
                src_offset=i * T.ceildiv(n, block_elem),
                tmp_offset=(i + 1) * T.ceildiv(n, block_elem),
            )
        for i in range(total_rounds - 1):
            real_idx = total_rounds - 1 - i - 1
            cur_len = T.ceildiv(n, 1 << (LOG_BLOCK_N * (real_idx + 1)))
            update_cross_block(
                m,
                cur_len,
                Tmp,
                Tmp,
                src_offset=(real_idx + 1) * T.ceildiv(n, block_elem),
                out_offset=real_idx * T.ceildiv(n, block_elem),
            )
        update_cross_block(m, n, Tmp, Out, src_offset=0, out_offset=0)

    return cumsum
```
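The kernel follows the standard blocked-scan recipe: each thread block computes an inclusive scan of its chunk and writes the chunk total into `Tmp`, the per-block totals are themselves scanned (recursively, over `total_rounds` levels, via offsets into `Tmp`), and `update_cross_block` adds each block's prefix back in. For intuition only, here is a NumPy model of that three-phase strategy for a single row; `blocked_cumsum` and its `block` parameter are illustrative names, not part of the PR:

```python
import numpy as np


def blocked_cumsum(x: np.ndarray, block: int = 512) -> np.ndarray:
    """NumPy model of the kernel's blocked-scan strategy (one row).

    block=512 mirrors the kernel's default block_elem = 32 * 4 * 4.
    """
    n = x.size
    out = np.zeros(n, dtype=x.dtype)
    n_blocks = -(-n // block)  # ceildiv, like T.ceildiv(n, block_elem)
    sums = np.zeros(n_blocks, dtype=x.dtype)
    # Phase 1: inclusive scan inside each block; record each block's total.
    for b in range(n_blocks):
        seg = x[b * block : (b + 1) * block]
        out[b * block : b * block + seg.size] = np.cumsum(seg)
        sums[b] = out[b * block + seg.size - 1]
    # Phase 2: scan the per-block totals (the kernel recurses here).
    prefix = np.cumsum(sums)
    # Phase 3: add the preceding blocks' total back into each block.
    for b in range(1, n_blocks):
        out[b * block : (b + 1) * block] += prefix[b - 1]
    return out


x = np.random.randint(0, 10, 2000)
assert (blocked_cumsum(x) == np.cumsum(x)).all()
```

And a minimal usage sketch of the generator itself, with two stated assumptions: the import path guesses that the `__init__.py` shown above belongs to `tvm.relax.backend_tir`, and, since the PrimFunc is declared `private=True`, a global symbol is attached before building:

```python
import numpy as np
import tvm
from tvm.relax.backend_tir import gpu_2d_continuous_cumsum  # assumed path

func = gpu_2d_continuous_cumsum(in_dtype="int32")  # returns a tir.PrimFunc
# The function is private (no global symbol); name it so it can be built.
func = func.with_attr("global_symbol", "cumsum")
rt_mod = tvm.build(func, target="cuda")

dev = tvm.cuda(0)
a_np = np.random.randint(0, 10, (8, 2000)).astype("int32")
a = tvm.nd.array(a_np, dev)
out = tvm.nd.empty(a_np.shape, "int32", dev)
rt_mod["cumsum"](a, out)
np.testing.assert_array_equal(out.numpy(), np.cumsum(a_np, axis=-1))
```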
The test file's import block is reorganized (`@@ -15,18 +15,19 @@`):

```diff
 # specific language governing permissions and limitations
 # under the License.

 import numpy as np
 import pytest

 import tvm
-from tvm import topi, relax, tir, dlight
 import tvm.script
 import tvm.testing
-from tvm.script import relax as R, tir as T, ir as I
+from tvm import dlight, relax, tir, topi
 from tvm.contrib.thrust import can_use_thrust

-from tvm.relax.backend import DispatchSortScan
 from tvm.ir.base import assert_structural_equal
+from tvm.relax.backend import DispatchSortScan
+from tvm.script import ir as I
+from tvm.script import relax as R
+from tvm.script import tir as T


 def test_dispatch_scanop():
```
A new end-to-end test is appended near the bottom of the file (`@@ -399,5 +400,32 @@`, just before the existing `if __name__ == "__main__": tvm.testing.main()` guard):

```python
@tvm.testing.requires_cuda
def test_dispatch_cumsum_gpu():
    """Test cumsum kernel dispatch and numerical correctness"""

    @I.ir_module
    class Module:
        @R.function
        def main(x: R.Tensor(("m", "n"), "int32")):
            with R.dataflow():
                gv = R.cumsum(x, axis=-1, exclusive=False)
                R.output(gv)
            return gv

    size = (8, 2000)
    np_data = np.random.randint(0, 10, size).astype("int32")
    np_cumsum = np.cumsum(np_data, axis=-1)
    for target in ["cuda", "vulkan -supports_int64=1"]:
        with tvm.target.Target(target):
            mod = DispatchSortScan()(Module)
            ex = tvm.relax.build(mod, target)
            device = tvm.device(target, 0)
            vm = tvm.relax.VirtualMachine(ex, device)
            tvm_data = tvm.nd.array(np_data, device)
            cumsum = vm["main"](tvm_data)
            tvm.testing.assert_allclose(cumsum.numpy(), np_cumsum)
```

Contributor review comment on the target loop — Nitpick: use `@tvm.testing.parametrize_targets` instead of looping over targets by hand:

```python
@tvm.testing.parametrize_targets("cuda", "vulkan -supports_int64=1")
def test_dispatch_cumsum_gpu(target, dev):
    ...
```

Author reply: fixed in #16947.
A reviewer also flagged the dispatch condition in the pass: for tensors of unknown shape, the `shape` field is `None`. Instead of `len(call.struct_info.shape)`, can we use `call.struct_info.ndim`? (Alternatively, since it looks like the implementation requires an explicit shape in order to apply a reshape, we could add `shape is not None` to this condition.)

Author reply: Thanks for catching this. Unfortunately, the original implementation does not support unknown shapes; I added a check in the pass.
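The added check itself is not shown in this diff. As a hypothetical sketch of the guard the thread describes (the function name is assumed, not the actual pass code):

```python
from tvm import relax


def can_dispatch_gpu_cumsum(call: relax.Call) -> bool:
    """Hypothetical guard: dispatch only when the tensor shape is known.

    For tensors of unknown shape, TensorStructInfo.shape is None, so test
    it explicitly instead of calling len() on it; the rank check mirrors
    the kernel's 2D layout.
    """
    sinfo = call.struct_info
    return sinfo.shape is not None and sinfo.ndim == 2
```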