python/tvm/dlight/gpu/general_reduction.py (27 additions, 8 deletions)
@@ -18,7 +18,7 @@
"""Reduction rule for operators including softmax, layer norm, RMS norm, etc"""
from typing import List, Union

from tvm import tir
from tvm import arith, tir
from tvm.target import Target

from ..base import normalize_prim_func, try_inline_contiguous_spatial
@@ -57,13 +57,32 @@ def apply( # pylint: disable=too-many-locals
# Align the number of block iters of the last block.
num_last_block_iter = len(block_infos[-1].dom_kind())
if num_last_block_iter < len(dom_kind):
index_map = tir.IndexMap.from_func(
lambda *iters: (
[tir.const(0, iters[0].dtype)] * (len(dom_kind) - num_last_block_iter)
+ list(iters)
),
ndim=num_last_block_iter,
)

def f_layout_mapping(*iters):
analyzer = arith.Analyzer()
# Try to match the iters of last block to the iters of the first block.
# For matched positions, use the iter from the input `iters`.
# For unmatched positions, use a new iter which is constant 0.
num_matched = 0
target_layout_iters = []
for block_iter in block_infos[0].iters:
if num_matched < len(iters) and analyzer.can_prove_equal(
block_iter.dom, block_infos[-1].iters[num_matched].dom
):
target_layout_iters.append(iters[num_matched])
num_matched += 1
else:
target_layout_iters.append(tir.const(0, iters[0].dtype))

# If all the iters of the last block can match, return the new layout.
if num_matched == len(iters):
return target_layout_iters
# Otherwise, fallback to appending zeros in the beginning.
return [tir.const(0, iters[0].dtype)] * (
len(dom_kind) - num_last_block_iter
) + list(iters)

index_map = tir.IndexMap.from_func(f_layout_mapping, ndim=num_last_block_iter)
sch.transform_block_layout(block_infos[-1].block_rv, index_map)

try:
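The `f_layout_mapping` added above aligns the iters of the trailing block with those of the first block by proving their domains equal, instead of unconditionally left-padding with zeros. For example, if the trailing block exposes two iters whose extents match the first two iters of the first block, the old mapping produced (0, v0, v1) and shifted the batch axis onto the chunk axis, while the new mapping yields (v0, v1, 0), keeping the spatial axes aligned with the reduction blocks. A minimal illustrative sketch of the two mappings (not part of the diff; extents and iter count are assumed):

# Illustrative sketch only; mirrors the old and new mappings above for a
# 2-iter trailing block whose target layout has 3 iters.
from tvm import tir

# Old behavior: left-pad with zeros, (v0, v1) -> (0, v0, v1).
old_map = tir.IndexMap.from_func(
    lambda v0, v1: [tir.const(0, v0.dtype), v0, v1]
)

# New behavior when both iter domains match the first block: (v0, v1) -> (v0, v1, 0).
new_map = tir.IndexMap.from_func(
    lambda v0, v1: [v0, v1, tir.const(0, v0.dtype)]
)

print(old_map)  # (v0, v1) -> (0, v0, v1)
print(new_map)  # (v0, v1) -> (v0, v1, 0)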
tests/python/dlight/test_gpu_general_reduction.py (149 additions, 0 deletions)
@@ -453,5 +453,154 @@ def main(A: T.Buffer((1, 2048), "float32"), B: T.Buffer((2048,), "float32"), C:
_check(Before, After)


def test_logsumexp():
@I.ir_module
class Before:
@T.prim_func
def compute_lse(var_A: T.handle, var_blocked_lse: T.handle):
T.func_attr({"tir.noalias": T.bool(True)})
batch_size = T.int64(is_size_var=True)
vocab_size = T.int64(is_size_var=True)
num_chunks = T.int64(is_size_var=True)
A = T.match_buffer(var_A, (batch_size, vocab_size), dtype="float32")
blocked_lse = T.match_buffer(var_blocked_lse, (batch_size, num_chunks), dtype="float32")
A_pad = T.alloc_buffer((batch_size, num_chunks, T.int64(4096)), dtype="float32")
temp_max = T.alloc_buffer((batch_size, num_chunks), dtype="float32")
temp_sum = T.alloc_buffer((batch_size, num_chunks), dtype="float32")

for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(4096)):
with T.block("pad"):
v0, v1, v2 = T.axis.remap("SSS", [l0, l1, l2])
A_pad[v0, v1, v2] = T.if_then_else(
v1 * T.int64(4096) + v2 < vocab_size,
A[v0, v1 * T.int64(4096) + v2],
T.min_value("float32"),
)

for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(4096)):
with T.block("max"):
v0, v1, v2 = T.axis.remap("SSR", [l0, l1, l2])
with T.init():
temp_max[v0, v1] = T.min_value("float32")
temp_max[v0, v1] = T.max(temp_max[v0, v1], A_pad[v0, v1, v2])

for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(4096)):
with T.block("sum_exp"):
v0, v1, v2 = T.axis.remap("SSR", [l0, l1, l2])
with T.init():
temp_sum[v0, v1] = T.float32(0)
temp_sum[v0, v1] += T.if_then_else(
v1 * T.int64(4096) + v2 < vocab_size,
T.exp(A_pad[v0, v1, v2] - temp_max[v0, v1]),
T.float32(0),
)

for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(1)):
with T.block("log"):
v0, v1, v2 = T.axis.remap("SSS", [l0, l1, l2])
blocked_lse[v0, v1] = T.log(temp_sum[v0, v1]) + temp_max[v0, v1]

@I.ir_module
class After:
@T.prim_func
def compute_lse(var_A: T.handle, var_blocked_lse: T.handle):
T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
batch_size, vocab_size = T.int64(is_size_var=True), T.int64(is_size_var=True)
A = T.match_buffer(var_A, (batch_size, vocab_size))
num_chunks = T.int64(is_size_var=True)
blocked_lse = T.match_buffer(var_blocked_lse, (batch_size, num_chunks))
temp_max_shared = T.alloc_buffer((batch_size, num_chunks), scope="shared")
temp_sum_shared = T.alloc_buffer((batch_size, num_chunks), scope="shared")
for ax0_ax1_fused in T.thread_binding(batch_size * num_chunks, thread="blockIdx.x"):
for ax0, ax1 in T.grid(T.int64(1), T.int64(1)):
for ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"):
for ax2_fused_0 in T.serial(
T.int64(16),
annotations={
"pragma_auto_unroll_max_step": 256,
"pragma_unroll_explicit": 1,
},
):
with T.block("max"):
v0 = T.axis.spatial(
batch_size,
ax0_ax1_fused % (num_chunks * batch_size) // num_chunks + ax0,
)
v1 = T.axis.spatial(num_chunks, ax0_ax1_fused % num_chunks + ax1)
v2 = T.axis.reduce(
T.int64(4096), ax2_fused_0 * T.int64(256) + ax2_fused_1
)
T.reads(A[v0, v1 * T.int64(4096) + v2])
T.writes(temp_max_shared[v0, v1])
with T.init():
temp_max_shared[v0, v1] = T.min_value("float32")
temp_max_shared[v0, v1] = T.max(
temp_max_shared[v0, v1],
T.if_then_else(
v1 * T.int64(4096) + v2 < vocab_size,
A[v0, v1 * T.int64(4096) + v2],
T.min_value("float32"),
),
)
for ax0, ax1 in T.grid(T.int64(1), T.int64(1)):
for ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"):
for ax2_fused_0 in T.serial(
T.int64(16),
annotations={
"pragma_auto_unroll_max_step": 256,
"pragma_unroll_explicit": 1,
},
):
with T.block("sum_exp"):
v0 = T.axis.spatial(
batch_size,
ax0_ax1_fused % (num_chunks * batch_size) // num_chunks + ax0,
)
v1 = T.axis.spatial(num_chunks, ax0_ax1_fused % num_chunks + ax1)
v2 = T.axis.reduce(
T.int64(4096), ax2_fused_0 * T.int64(256) + ax2_fused_1
)
T.reads(A[v0, v1 * T.int64(4096) + v2], temp_max_shared[v0, v1])
T.writes(temp_sum_shared[v0, v1])
with T.init():
temp_sum_shared[v0, v1] = T.float32(0)
temp_sum_shared[v0, v1] = temp_sum_shared[v0, v1] + T.if_then_else(
v1 * T.int64(4096) + v2 < vocab_size,
T.exp(
(
T.if_then_else(
v1 * T.int64(4096) + v2 < vocab_size,
A[v0, v1 * T.int64(4096) + v2],
T.min_value("float32"),
)
- temp_max_shared[v0, v1]
)
),
T.float32(0),
)
for ax2_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"):
for ax2_0 in T.serial(
T.int64(1),
annotations={
"pragma_auto_unroll_max_step": 256,
"pragma_unroll_explicit": 1,
},
):
with T.block("log"):
v0 = T.axis.spatial(
batch_size, ax0_ax1_fused % (num_chunks * batch_size) // num_chunks
)
v1 = T.axis.spatial(num_chunks, ax0_ax1_fused % num_chunks)
v2 = T.axis.spatial(T.int64(1), ax2_0 * T.int64(256) + ax2_1)
T.where(ax2_0 * T.int64(256) + ax2_1 < T.int64(1))
T.reads(temp_sum_shared[v0, v1], temp_max_shared[v0, v1])
T.writes(blocked_lse[v0, v1])
blocked_lse[v0, v1] = (
T.log(temp_sum_shared[v0, v1]) + temp_max_shared[v0, v1]
)

_check(Before, After)
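
The `_check` helper referenced above is defined earlier in this test file and is not shown in this diff. A hedged sketch of how such a check is commonly written for dlight rules (the helper body and target string here are assumptions for illustration, not taken from this diff):

# Illustrative sketch only; assumes the rule under test is
# tvm.dlight.gpu.GeneralReduction and an NVIDIA GPU target.
import tvm
from tvm import dlight as dl
from tvm.target import Target

def _check_sketch(mod_before, mod_after):
    # Apply the dlight schedule rule under a GPU target context,
    # then compare against the expected scheduled module.
    with Target("nvidia/geforce-rtx-3090-ti"):
        mod = dl.ApplyDefaultSchedule(dl.gpu.GeneralReduction())(mod_before)
    tvm.ir.assert_structural_equal(mod, mod_after)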


if __name__ == "__main__":
tvm.testing.main()