Merged
41 changes: 41 additions & 0 deletions python/tvm/relax/backend/dispatch_sort_scan.py
@@ -154,7 +154,48 @@ def visit_call_(self, call: relax.Call) -> relax.Expr:
    if call.op.name in ("relax.cumprod", "relax.cumsum"):
        tgt = self._get_target(call.struct_info)
        axis = int(call.attrs.axis) if call.attrs.axis is not None else call.attrs.axis
        shape = call.struct_info.shape
        kwargs = {}
        if (
            (axis == -1 or axis == len(shape) - 1)
Contributor:
For tensors of unknown shape, the shape field is None. Instead of len(call.struct_info.shape), can we use call.struct_info.ndim? (Alternatively, since it looks like the implementation requires an explicit shape in order to apply a reshape, we could add shape is not None to this condition.)

Member Author:
Thanks for catching this. Unfortunately, the original implementation does not support unknown shapes, so I added a check in the pass.
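A minimal sketch of such a guard (hypothetical; the exact check added to the pass may differ):

    # Only take the specialized GPU path when the static shape is known.
    shape = call.struct_info.shape
    if shape is not None and (axis == -1 or axis == call.struct_info.ndim - 1):
        ...  # safe to flatten to 2D and dispatch the TIR kernel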

            and is_gpu_target(tgt)
            and not can_use_thrust(tgt, "tvm.contrib.thrust.sum_scan")
            and call.op.name == "relax.cumsum"
            and call.attrs.exclusive == 0
        ):
            from tvm.relax.backend_tir import (  # pylint: disable=import-outside-toplevel
                gpu_2d_continuous_cumsum,
            )

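            # Flatten all leading axes into one so the kernel sees a
            # 2D [dim, shape[-1]] problem.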
            dim = 1
            for i in range(len(shape) - 1):
                dim *= shape[i]
            in_dtype = call.args[0].struct_info.dtype
            out_dtype = call.attrs.dtype
            out_dtype = out_dtype or in_dtype
            cumsum_2d_shape = relax.ShapeExpr([dim, shape[-1]])
            reshape = relax.call_pure_packed(
                "vm.builtin.reshape",
                call.args[0],
                cumsum_2d_shape,
                sinfo_args=relax.TensorStructInfo(cumsum_2d_shape, out_dtype),
            )
            gv = self.builder_.add_func(
                gpu_2d_continuous_cumsum(in_dtype=in_dtype, out_dtype=out_dtype),
                "gpu_2d_continuous_cumsum",
            )
            cumsum = relax.call_tir(
                gv,
                reshape,
                out_sinfo=relax.TensorStructInfo(cumsum_2d_shape, out_dtype),
            )
            return relax.call_pure_packed(
                "vm.builtin.reshape",
                cumsum,
                shape,
                sinfo_args=call.struct_info,
            )

        with tgt:
            if call.op.name == "relax.cumsum":
                te_func = topi.cuda.cumsum if is_gpu_target(tgt) else topi.cumsum
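At the NumPy level, the new dispatch path amounts to the following (an illustrative sketch, not code from the pass):

    import numpy as np

    def cumsum_last_axis_via_2d(x: np.ndarray) -> np.ndarray:
        # Mirror of the Relax rewrite: reshape to [dim, n], scan rows, reshape back.
        x2d = x.reshape(-1, x.shape[-1])  # vm.builtin.reshape
        y2d = np.cumsum(x2d, axis=-1)     # call_tir(gpu_2d_continuous_cumsum, ...)
        return y2d.reshape(x.shape)       # vm.builtin.reshape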
1 change: 1 addition & 0 deletions python/tvm/relax/backend_tir/__init__.py
@@ -18,3 +18,4 @@

from . import contrib
from .pattern import get_tir_pattern
from .cumsum import gpu_2d_continuous_cumsum
193 changes: 193 additions & 0 deletions python/tvm/relax/backend_tir/cumsum.py
@@ -0,0 +1,193 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, too-many-nested-blocks
"""Backend kernels for cumsum operator."""

import math
from typing import Optional

from tvm.script import tir as T
from tvm.tir import PrimFunc


def _is_power_of_two(n: int):
    """Check if n is a power of 2."""
    return n > 0 and (n & (n - 1)) == 0


def gpu_2d_continuous_cumsum(
    ty_len: int = 4,
    tx_len: int = 32,
    thread_elem: int = 4,
    in_dtype: str = "int32",
    out_dtype: Optional[str] = None,
) -> PrimFunc:
"""Generate GPU kernel for 2D continuous cumsum, i.e. The cumsum axis is -1

Parameters
----------
ty_len : int
The length of thread.y

tx_len : int
The length of thread.x

thread_elem : int
The number of elements processed by single thread

in_dtype : str
The input data type

out_dtype : Optional[str]
The output data type, if None, it will be the same as in_dtype

Returns
-------
cumsum : PrimFunc
The generated cumsum kernel
"""

    out_dtype = out_dtype or in_dtype

    # Configuration for GPU kernel
    TX = T.int64(tx_len)  # thread.x
    TY = T.int64(ty_len)  # thread.y
    N = T.int64(thread_elem)  # number of elements in single thread

    if not _is_power_of_two(TX) or not _is_power_of_two(TY) or not _is_power_of_two(N):
        raise ValueError("Configuration of TX, TY, N must be power of 2")

    # number of elements to be processed by a single warp
    warp_elem = T.int64(tx_len * thread_elem)
    # number of elements to be processed by a single block (SM)
    block_elem = T.int64(tx_len * ty_len * thread_elem)

    LOG_TX = T.int64(int(math.log2(tx_len)))
    LOG_BLOCK_N = T.int64(int(math.log2(tx_len * ty_len * thread_elem)))
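    # LOG_TX sets the strides of the warp-level scan below; LOG_BLOCK_N
    # determines how many cross-block rounds the full scan needs.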

    @T.macro
    def block_inclusive_inside_block(
        batch: T.int64,
        cur_len: T.int64,
        source: T.Buffer,
        output: T.Buffer,
        tmp_buf: T.Buffer,
        src_offset: T.int64,
        tmp_offset: T.int64,
    ):
        for by in T.thread_binding(batch, thread="blockIdx.y"):
            for bx in T.thread_binding(T.ceildiv(cur_len, block_elem), thread="blockIdx.x"):
                with T.block():
                    local_buf = T.alloc_buffer((thread_elem,), out_dtype, scope="local")
                    shared_buf = T.alloc_buffer((block_elem,), out_dtype, scope="shared")
                    for ty in T.thread_binding(TY, thread="threadIdx.y"):
                        for tx in T.thread_binding(TX, thread="threadIdx.x"):
                            tx_idx = bx * block_elem + ty * warp_elem + tx * thread_elem
                            # Load data from global memory
                            for i in T.vectorized(N):
                                local_buf[i] = T.if_then_else(
                                    tx_idx + i < cur_len,
                                    T.Cast(out_dtype, source[by, src_offset + tx_idx + i]),
                                    T.Cast(out_dtype, 0),
                                )
                            # Inclusive scan inside thread
                            for i in T.unroll(1, N):
                                local_buf[i] += local_buf[i - 1]
                            # Store data to shared memory
                            for i in T.vectorized(N):
                                shared_buf[ty * warp_elem + tx * thread_elem + i] = local_buf[i]
                            # Inclusive scan inside warp
                            for i in T.unroll(LOG_TX):
                                for j in T.vectorized(N):
                                    idx: T.int64 = ty * warp_elem + tx * thread_elem
                                    if tx >= (1 << i):
                                        shared_buf[idx + j] += shared_buf[
                                            idx - (1 << i) * thread_elem + N - 1
                                        ]
                            # Inclusive scan inside block
                            for i in T.unroll(1, TY):
                                for j in T.vectorized(N):
                                    if ty == 0:
                                        idx: T.int64 = i * warp_elem + tx * thread_elem
                                        shared_buf[idx + j] += shared_buf[i * warp_elem - 1]
                            # Write sum of block to global memory
                            for i in T.vectorized(N):
                                idx: T.int64 = ty * warp_elem + tx * thread_elem + i
                                if bx * block_elem + idx < cur_len:
                                    output[by, src_offset + bx * block_elem + idx] = shared_buf[idx]
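                            # Thread (0, 0) records the block's total so the
                            # next round can scan the per-block sums.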
                            if tx == 0 and ty == 0:
                                for i in T.vectorized(N):
                                    tmp_buf[by, tmp_offset + bx] = shared_buf[block_elem - 1]

    @T.macro
    def update_cross_block(
        batch: T.int64,
        cur_len: T.int64,
        source: T.Buffer,
        output: T.Buffer,
        src_offset: T.int64,
        out_offset: T.int64,
    ):
        for by in T.thread_binding(batch, thread="blockIdx.y"):
            for bx in T.thread_binding(T.ceildiv(cur_len, block_elem), thread="blockIdx.x"):
                for ty in T.thread_binding(TY, thread="threadIdx.y"):
                    for tx in T.thread_binding(TX, thread="threadIdx.x"):
                        for i in T.serial(N):
                            idx: T.int64 = bx * block_elem + ty * warp_elem + i * TX + tx
                            if idx < cur_len:
                                output[by, out_offset + idx] += T.if_then_else(
                                    bx > 0, source[by, src_offset + bx - 1], 0
                                )

    @T.prim_func(private=True)
    def cumsum(var_a: T.handle, var_out: T.handle):
        T.func_attr({"tir.is_scheduled": 1})  # prevent further scheduling
        m, n = T.int64(), T.int64()
        A = T.match_buffer(var_a, [m, n], dtype=in_dtype)
        Out = T.match_buffer(var_out, [m, n], dtype=out_dtype)
        Tmp = T.alloc_buffer([m, n], dtype=out_dtype)
        ceil_log2 = T.Cast("int64", T.ceil(T.log2(T.Cast("float32", n))))
        total_rounds = ceil_log2 // LOG_BLOCK_N

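        # Round 0: per-block inclusive scan of every row of A into Out,
        # recording each block's total in Tmp.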
        block_inclusive_inside_block(
            m, n, A, Out, Tmp, src_offset=T.int64(0), tmp_offset=T.int64(0)
        )
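        # Upward sweep: scan the per-block totals level by level, shrinking
        # the problem by a factor of block_elem each round.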
        for i in range(total_rounds):
            cur_len = T.ceildiv(n, 1 << (LOG_BLOCK_N * (i + 1)))
            block_inclusive_inside_block(
                m,
                cur_len,
                Tmp,
                Tmp,
                Tmp,
                src_offset=i * T.ceildiv(n, block_elem),
                tmp_offset=(i + 1) * T.ceildiv(n, block_elem),
            )
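        # Downward sweep: propagate the scanned totals from the coarsest
        # level back down, finishing by updating Out itself.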
        for i in range(total_rounds - 1):
            real_idx = total_rounds - 1 - i - 1
            cur_len = T.ceildiv(n, 1 << (LOG_BLOCK_N * (real_idx + 1)))
            update_cross_block(
                m,
                cur_len,
                Tmp,
                Tmp,
                src_offset=(real_idx + 1) * T.ceildiv(n, block_elem),
                out_offset=real_idx * T.ceildiv(n, block_elem),
            )
        update_cross_block(m, n, Tmp, Out, src_offset=0, out_offset=0)

    return cumsum
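For intuition, the kernel follows the standard multi-block scan recipe. A NumPy sketch of one row (illustration only; the kernel repeats the totals scan recursively when the totals themselves overflow one block, and block here stands in for block_elem):

    import numpy as np

    def blocked_cumsum(row: np.ndarray, block: int = 512) -> np.ndarray:
        # Phase 1: inclusive scan inside each block.
        pad = (-len(row)) % block
        blocks = np.pad(row, (0, pad)).reshape(-1, block)
        scanned = np.cumsum(blocks, axis=1)
        # Phase 2: scan the per-block totals.
        prefix = np.cumsum(scanned[:, -1])
        # Phase 3: add each block's preceding total back in.
        scanned[1:] += prefix[:-1, None]
        return scanned.reshape(-1)[: len(row)]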
38 changes: 33 additions & 5 deletions tests/python/relax/test_backend_dispatch_sort_scan.py
@@ -15,18 +15,19 @@
# specific language governing permissions and limitations
# under the License.

import numpy as np
import pytest

import tvm
import tvm.script
import tvm.testing
from tvm import dlight, relax, tir, topi
from tvm.contrib.thrust import can_use_thrust


from tvm.ir.base import assert_structural_equal
from tvm.relax.backend import DispatchSortScan
from tvm.script import ir as I
from tvm.script import relax as R
from tvm.script import tir as T

def test_dispatch_scanop():
@@ -399,5 +400,32 @@ def foo(x: R.Tensor((2, 3), "float32", "vulkan")):
    assert_structural_equal(mod, expected_mod)


@tvm.testing.requires_cuda
def test_dispatch_cumsum_gpu():
    """Test cumsum kernel dispatch and numerical correctness"""

    @I.ir_module
    class Module:
        @R.function
        def main(x: R.Tensor(("m", "n"), "int32")):
            with R.dataflow():
                gv = R.cumsum(x, axis=-1, exclusive=False)
                R.output(gv)
            return gv

    size = (8, 2000)
    np_data = np.random.randint(0, 10, size).astype("int32")
    np_cumsum = np.cumsum(np_data, axis=-1)
    for target in ["cuda", "vulkan -supports_int64=1"]:
Contributor:
Nitpick: use @tvm.testing.parametrize_targets("cuda", "vulkan -supports_int64=1") instead of looping over each target. This runs each target in a separate pytest environment, which:

  • Exercises each target as a separate pytest case, so a failure on one specific backend can be distinguished from a failure on every backend.
  • Applies the appropriate @tvm.testing.requires_* marks for each target. Currently, this test would fail if a developer runs it with set(USE_CUDA ON) and set(USE_VULKAN OFF).

@tvm.testing.parametrize_targets("cuda", "vulkan -supports_int64=1")
def test_dispatch_cumsum_gpu(target, dev):
    ...

Member Author:
Fixed in #16947.

        with tvm.target.Target(target):
            mod = DispatchSortScan()(Module)
        ex = tvm.relax.build(mod, target)
        device = tvm.device(target, 0)
        vm = tvm.relax.VirtualMachine(ex, device)
        tvm_data = tvm.nd.array(np_data, device)
        cumsum = vm["main"](tvm_data)
        tvm.testing.assert_allclose(cumsum.numpy(), np_cumsum)


if __name__ == "__main__":
    tvm.testing.main()