From 1580ab56ac41be82710c7adadd27822cee27a3d8 Mon Sep 17 00:00:00 2001
From: Ruihang Lai
Date: Wed, 24 Apr 2024 23:01:35 -0400
Subject: [PATCH] [Fix][Dlight] Fix GeneralReduction for log-sum-exp

This PR fixes the GeneralReduction dlight rule so that it can schedule
the log-sum-exp function. Prior to this fix, the rule made a strong
assumption about the iteration pattern of the given function: the
assumption holds for softmax, but log-sum-exp failed to schedule due to
a pattern mismatch. This PR generalizes the rule so that the iteration
domains of the last block are matched against those of the first block,
with constant-zero iters inserted at the unmatched positions (falling
back to the previous behavior of prepending zeros when the matching
fails). This enables the rule to match the log-sum-exp pattern and
apply the subsequent scheduling. A regression test is added.
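
For illustration, the rule is exercised through the usual dlight flow.
A minimal sketch (the IRModule `mod` below is a hypothetical stand-in
for any module containing a log-sum-exp PrimFunc, such as `compute_lse`
in the regression test):

    import tvm
    from tvm import dlight as dl

    # Hypothetical input: an IRModule containing a log-sum-exp PrimFunc,
    # e.g. the `Before` module from the regression test below.
    mod = ...
    with tvm.target.Target("cuda"):
        mod = dl.ApplyDefaultSchedule(dl.gpu.GeneralReduction())(mod)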
---
 python/tvm/dlight/gpu/general_reduction.py    |  35 +++-
 .../dlight/test_gpu_general_reduction.py      | 149 ++++++++++++++++++
 2 files changed, 176 insertions(+), 8 deletions(-)

diff --git a/python/tvm/dlight/gpu/general_reduction.py b/python/tvm/dlight/gpu/general_reduction.py
index 28b68a8b62a7..ef6bb1db91e1 100644
--- a/python/tvm/dlight/gpu/general_reduction.py
+++ b/python/tvm/dlight/gpu/general_reduction.py
@@ -18,7 +18,7 @@
 """Reduction rule for operators including softmax, layer norm, RMS norm, etc"""
 from typing import List, Union
 
-from tvm import tir
+from tvm import arith, tir
 from tvm.target import Target
 
 from ..base import normalize_prim_func, try_inline_contiguous_spatial
@@ -57,13 +57,32 @@ def apply(  # pylint: disable=too-many-locals
         # Align the number of block iters of the last block.
         num_last_block_iter = len(block_infos[-1].dom_kind())
         if num_last_block_iter < len(dom_kind):
-            index_map = tir.IndexMap.from_func(
-                lambda *iters: (
-                    [tir.const(0, iters[0].dtype)] * (len(dom_kind) - num_last_block_iter)
-                    + list(iters)
-                ),
-                ndim=num_last_block_iter,
-            )
+
+            def f_layout_mapping(*iters):
+                analyzer = arith.Analyzer()
+                # Try to match the iters of the last block to the iters of the
+                # first block: for matched positions, use the iter from the
+                # input `iters`; for unmatched positions, use constant 0.
+                num_matched = 0
+                target_layout_iters = []
+                for block_iter in block_infos[0].iters:
+                    if num_matched < len(iters) and analyzer.can_prove_equal(
+                        block_iter.dom, block_infos[-1].iters[num_matched].dom
+                    ):
+                        target_layout_iters.append(iters[num_matched])
+                        num_matched += 1
+                    else:
+                        target_layout_iters.append(tir.const(0, iters[0].dtype))
+
+                # If all iters of the last block can be matched, return the new layout.
+                if num_matched == len(iters):
+                    return target_layout_iters
+                # Otherwise, fall back to prepending zeros (the old behavior).
+                return [tir.const(0, iters[0].dtype)] * (
+                    len(dom_kind) - num_last_block_iter
+                ) + list(iters)
+
+            index_map = tir.IndexMap.from_func(f_layout_mapping, ndim=num_last_block_iter)
             sch.transform_block_layout(block_infos[-1].block_rv, index_map)
 
         try:
diff --git a/tests/python/dlight/test_gpu_general_reduction.py b/tests/python/dlight/test_gpu_general_reduction.py
index 44c9a4a126ab..e1a9a8e018ce 100644
--- a/tests/python/dlight/test_gpu_general_reduction.py
+++ b/tests/python/dlight/test_gpu_general_reduction.py
@@ -453,5 +453,154 @@ def main(A: T.Buffer((1, 2048), "float32"), B: T.Buffer((2048,), "float32"), C:
     _check(Before, After)
 
 
+def test_logsumexp():
+    @I.ir_module
+    class Before:
+        @T.prim_func
+        def compute_lse(var_A: T.handle, var_blocked_lse: T.handle):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            batch_size = T.int64(is_size_var=True)
+            vocab_size = T.int64(is_size_var=True)
+            num_chunks = T.int64(is_size_var=True)
+            A = T.match_buffer(var_A, (batch_size, vocab_size), dtype="float32")
+            blocked_lse = T.match_buffer(var_blocked_lse, (batch_size, num_chunks), dtype="float32")
+            A_pad = T.alloc_buffer((batch_size, num_chunks, T.int64(4096)), dtype="float32")
+            temp_max = T.alloc_buffer((batch_size, num_chunks), dtype="float32")
+            temp_sum = T.alloc_buffer((batch_size, num_chunks), dtype="float32")
+
+            for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(4096)):
+                with T.block("pad"):
+                    v0, v1, v2 = T.axis.remap("SSS", [l0, l1, l2])
+                    A_pad[v0, v1, v2] = T.if_then_else(
+                        v1 * T.int64(4096) + v2 < vocab_size,
+                        A[v0, v1 * T.int64(4096) + v2],
+                        T.min_value("float32"),
+                    )
+
+            for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(4096)):
+                with T.block("max"):
+                    v0, v1, v2 = T.axis.remap("SSR", [l0, l1, l2])
+                    with T.init():
+                        temp_max[v0, v1] = T.min_value("float32")
+                    temp_max[v0, v1] = T.max(temp_max[v0, v1], A_pad[v0, v1, v2])
+
+            for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(4096)):
+                with T.block("sum_exp"):
+                    v0, v1, v2 = T.axis.remap("SSR", [l0, l1, l2])
+                    with T.init():
+                        temp_sum[v0, v1] = T.float32(0)
+                    temp_sum[v0, v1] += T.if_then_else(
+                        v1 * T.int64(4096) + v2 < vocab_size,
+                        T.exp(A_pad[v0, v1, v2] - temp_max[v0, v1]),
+                        T.float32(0),
+                    )
+
+            for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(1)):
+                with T.block("log"):
+                    v0, v1, v2 = T.axis.remap("SSS", [l0, l1, l2])
+                    blocked_lse[v0, v1] = T.log(temp_sum[v0, v1]) + temp_max[v0, v1]
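+
+    # Expected output: `GeneralReduction` fuses the spatial (batch, chunk)
+    # iters onto "blockIdx.x", splits each 4096-wide reduction into
+    # 16 serial steps x 256 "threadIdx.x" threads, and keeps the "max" and
+    # "sum_exp" results in shared memory for the final "log" block.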
+    @I.ir_module
+    class After:
+        @T.prim_func
+        def compute_lse(var_A: T.handle, var_blocked_lse: T.handle):
+            T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
+            batch_size, vocab_size = T.int64(is_size_var=True), T.int64(is_size_var=True)
+            A = T.match_buffer(var_A, (batch_size, vocab_size))
+            num_chunks = T.int64(is_size_var=True)
+            blocked_lse = T.match_buffer(var_blocked_lse, (batch_size, num_chunks))
+            temp_max_shared = T.alloc_buffer((batch_size, num_chunks), scope="shared")
+            temp_sum_shared = T.alloc_buffer((batch_size, num_chunks), scope="shared")
+            for ax0_ax1_fused in T.thread_binding(batch_size * num_chunks, thread="blockIdx.x"):
+                for ax0, ax1 in T.grid(T.int64(1), T.int64(1)):
+                    for ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"):
+                        for ax2_fused_0 in T.serial(
+                            T.int64(16),
+                            annotations={
+                                "pragma_auto_unroll_max_step": 256,
+                                "pragma_unroll_explicit": 1,
+                            },
+                        ):
+                            with T.block("max"):
+                                v0 = T.axis.spatial(
+                                    batch_size,
+                                    ax0_ax1_fused % (num_chunks * batch_size) // num_chunks + ax0,
+                                )
+                                v1 = T.axis.spatial(num_chunks, ax0_ax1_fused % num_chunks + ax1)
+                                v2 = T.axis.reduce(
+                                    T.int64(4096), ax2_fused_0 * T.int64(256) + ax2_fused_1
+                                )
+                                T.reads(A[v0, v1 * T.int64(4096) + v2])
+                                T.writes(temp_max_shared[v0, v1])
+                                with T.init():
+                                    temp_max_shared[v0, v1] = T.min_value("float32")
+                                temp_max_shared[v0, v1] = T.max(
+                                    temp_max_shared[v0, v1],
+                                    T.if_then_else(
+                                        v1 * T.int64(4096) + v2 < vocab_size,
+                                        A[v0, v1 * T.int64(4096) + v2],
+                                        T.min_value("float32"),
+                                    ),
+                                )
+                for ax0, ax1 in T.grid(T.int64(1), T.int64(1)):
+                    for ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"):
+                        for ax2_fused_0 in T.serial(
+                            T.int64(16),
+                            annotations={
+                                "pragma_auto_unroll_max_step": 256,
+                                "pragma_unroll_explicit": 1,
+                            },
+                        ):
+                            with T.block("sum_exp"):
+                                v0 = T.axis.spatial(
+                                    batch_size,
+                                    ax0_ax1_fused % (num_chunks * batch_size) // num_chunks + ax0,
+                                )
+                                v1 = T.axis.spatial(num_chunks, ax0_ax1_fused % num_chunks + ax1)
+                                v2 = T.axis.reduce(
+                                    T.int64(4096), ax2_fused_0 * T.int64(256) + ax2_fused_1
+                                )
+                                T.reads(A[v0, v1 * T.int64(4096) + v2], temp_max_shared[v0, v1])
+                                T.writes(temp_sum_shared[v0, v1])
+                                with T.init():
+                                    temp_sum_shared[v0, v1] = T.float32(0)
+                                temp_sum_shared[v0, v1] = temp_sum_shared[v0, v1] + T.if_then_else(
+                                    v1 * T.int64(4096) + v2 < vocab_size,
+                                    T.exp(
+                                        T.if_then_else(
+                                            v1 * T.int64(4096) + v2 < vocab_size,
+                                            A[v0, v1 * T.int64(4096) + v2],
+                                            T.min_value("float32"),
+                                        )
+                                        - temp_max_shared[v0, v1]
+                                    ),
+                                    T.float32(0),
+                                )
+                for ax2_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"):
+                    for ax2_0 in T.serial(
+                        T.int64(1),
+                        annotations={
+                            "pragma_auto_unroll_max_step": 256,
+                            "pragma_unroll_explicit": 1,
+                        },
+                    ):
+                        with T.block("log"):
+                            v0 = T.axis.spatial(
+                                batch_size, ax0_ax1_fused % (num_chunks * batch_size) // num_chunks
+                            )
+                            v1 = T.axis.spatial(num_chunks, ax0_ax1_fused % num_chunks)
+                            v2 = T.axis.spatial(T.int64(1), ax2_0 * T.int64(256) + ax2_1)
+                            T.where(ax2_0 * T.int64(256) + ax2_1 < T.int64(1))
+                            T.reads(temp_sum_shared[v0, v1], temp_max_shared[v0, v1])
+                            T.writes(blocked_lse[v0, v1])
+                            blocked_lse[v0, v1] = (
+                                T.log(temp_sum_shared[v0, v1]) + temp_max_shared[v0, v1]
+                            )
+
+    _check(Before, After)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
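
For reference, the new regression test can be run in isolation with
pytest (assuming a TVM checkout with the Python test dependencies
installed):

    pytest tests/python/dlight/test_gpu_general_reduction.py -k test_logsumexp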