From 64b7520c2ee07e8c660e4eb2453577251161df8a Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Sun, 16 Mar 2025 17:19:49 -0400 Subject: [PATCH] [Dlight] Fix general reduction rule to support non-last reduction axis This PR fixes a bug in the general reduction dlight rule, which happens when there is a trailing spatial block, and for the previous reduction blocks, the reduction axes are not on the back. In the case above, the loop orders of the reduction blocks and the trailing spatial block are inconsistent, while the dlight rule before this fix always treat the loop orders as consistent. As a result, though the function after applying the rule is numerically correct, it may require much extra shared memory use (in proportion to the size of spatial loops). And when the spatial dimensions are large, the required share memory size may exceed the device limit. This PR fixes this bug and adds a unit test. --- python/tvm/dlight/gpu/general_reduction.py | 40 ++++++++ .../dlight/test_gpu_general_reduction.py | 91 +++++++++++++++++++ 2 files changed, 131 insertions(+) diff --git a/python/tvm/dlight/gpu/general_reduction.py b/python/tvm/dlight/gpu/general_reduction.py index a068e732b986..d3979ce0e4c3 100644 --- a/python/tvm/dlight/gpu/general_reduction.py +++ b/python/tvm/dlight/gpu/general_reduction.py @@ -99,6 +99,46 @@ def f_layout_mapping(*iters): except AssertionError: return None + if "R" not in block_infos[-1].dom_kind(): + # The final block is a spatial block. + # It is possible that the loop order of the last block is not the same as + # previous blocks. + # Thus we reorder spatial loops to align with reduction loops for followup schedule. + # We first collect all the buffers written by reduction blocks, + # then in the final block, any index of those buffers are spatial. + reduced_buffers = [] + for block_info in block_infos[:-1]: + for buffer_write in sch.get(block_info.block_rv).writes: + reduced_buffers.append(buffer_write.buffer) + + spatial_block = sch.get(block_infos[-1].block_rv) + spatial_loops = set() + block_var_to_loop_var = {} + loops = sch.get_loops(block_infos[-1].block_rv) + for block_iter, loop_rv in zip(spatial_block.iter_vars, loops): + block_var_to_loop_var[block_iter.var] = sch.get(loop_rv).loop_var + + def _visit_expr(e: tir.PrimExpr): + if isinstance(e, tir.Var) and e in block_var_to_loop_var: + spatial_loops.add(block_var_to_loop_var[e]) + + for buffer_read in spatial_block.reads: + buffer = buffer_read.buffer + if buffer in reduced_buffers: + for read_range in buffer_read.region: + tir.stmt_functor.post_order_visit(read_range.min, _visit_expr) + tir.stmt_functor.post_order_visit(read_range.extent, _visit_expr) + + s_loops = [] + other_loops = [] + for loop_rv in loops: + loop = sch.get(loop_rv) + if loop.loop_var in spatial_loops or loop.extent == 1: + s_loops.append(loop_rv) + else: + other_loops.append(loop_rv) + sch.reorder(*s_loops, *other_loops) + loops = sch.get_loops(block_infos[-1].block_rv) bx = sch.fuse(*loops[:num_leading_s]) r_loop, tx = sch.split(loops[-1], [None, len_tx]) diff --git a/tests/python/dlight/test_gpu_general_reduction.py b/tests/python/dlight/test_gpu_general_reduction.py index e1a9a8e018ce..9549441a1154 100644 --- a/tests/python/dlight/test_gpu_general_reduction.py +++ b/tests/python/dlight/test_gpu_general_reduction.py @@ -222,6 +222,97 @@ def main(A: T.Buffer((T.int64(1), T.int64(1), T.int64(32000)), "float32"), T_sof _check(Before, After) +def test_softmax_3(): + # fmt: off + @I.ir_module + class Before: + @T.prim_func + def main(input: T.Buffer((T.int64(1), T.int64(4), T.int64(32), T.int64(8192)), "float32"), T_softmax_norm: T.Buffer((T.int64(1), T.int64(4), T.int64(32), T.int64(8192)), "float32")): + # with T.block("root"): + T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(4), T.int64(8192))) + T_softmax_exp = T.alloc_buffer((T.int64(1), T.int64(4), T.int64(32), T.int64(8192))) + T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(4), T.int64(8192))) + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(4), T.int64(8192), T.int64(32)): + with T.block("T_softmax_maxelem"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(input[v_i0, v_i1, v_k, v_i2]) + T.writes(T_softmax_maxelem[v_i0, v_i1, v_i2]) + with T.init(): + T_softmax_maxelem[v_i0, v_i1, v_i2] = T.float32(-340282346638528859811704183484516925440.0) + T_softmax_maxelem[v_i0, v_i1, v_i2] = T.max(T_softmax_maxelem[v_i0, v_i1, v_i2], input[v_i0, v_i1, v_k, v_i2]) + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(4), T.int64(32), T.int64(8192)): + with T.block("T_softmax_exp"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(input[v_i0, v_i1, v_i2, v_i3], T_softmax_maxelem[v_i0, v_i1, v_i3]) + T.writes(T_softmax_exp[v_i0, v_i1, v_i2, v_i3]) + T_softmax_exp[v_i0, v_i1, v_i2, v_i3] = T.exp(input[v_i0, v_i1, v_i2, v_i3] - T_softmax_maxelem[v_i0, v_i1, v_i3]) + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(4), T.int64(8192), T.int64(32)): + with T.block("T_softmax_expsum"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(T_softmax_exp[v_i0, v_i1, v_k, v_i2]) + T.writes(T_softmax_expsum[v_i0, v_i1, v_i2]) + with T.init(): + T_softmax_expsum[v_i0, v_i1, v_i2] = T.float32(0.0) + T_softmax_expsum[v_i0, v_i1, v_i2] = T_softmax_expsum[v_i0, v_i1, v_i2] + T_softmax_exp[v_i0, v_i1, v_k, v_i2] + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(4), T.int64(32), T.int64(8192)): + with T.block("T_softmax_norm"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_i3], T_softmax_expsum[v_i0, v_i1, v_i3]) + T.writes(T_softmax_norm[v_i0, v_i1, v_i2, v_i3]) + T.block_attr({"axis": 2}) + T_softmax_norm[v_i0, v_i1, v_i2, v_i3] = T_softmax_exp[v_i0, v_i1, v_i2, v_i3] / T_softmax_expsum[v_i0, v_i1, v_i3] + + + @I.ir_module + class After: + @T.prim_func + def main(input: T.Buffer((T.int64(1), T.int64(4), T.int64(32), T.int64(8192)), "float32"), T_softmax_norm: T.Buffer((T.int64(1), T.int64(4), T.int64(32), T.int64(8192)), "float32")): + T.func_attr({"tir.is_scheduled": 1}) + # with T.block("root"): + T_softmax_maxelem_shared = T.alloc_buffer((T.int64(1), T.int64(4), T.int64(8192)), scope="shared") + T_softmax_expsum_shared = T.alloc_buffer((T.int64(1), T.int64(4), T.int64(8192)), scope="shared") + for ax0_ax2_fused in T.thread_binding(T.int64(32768), thread="blockIdx.x"): + for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): + for ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + for ax2_fused_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): + with T.block("T_softmax_maxelem"): + v0 = T.axis.spatial(T.int64(4), ax0_ax2_fused // T.int64(8192) + ax0) + v1 = T.axis.spatial(T.int64(8192), ax0_ax2_fused % T.int64(8192) + ax1) + v2 = T.axis.reduce(T.int64(32), ax2_fused_0 * T.int64(256) + ax2_fused_1) + T.where(ax2_fused_0 * T.int64(256) + ax2_fused_1 < T.int64(32)) + T.reads(input[T.int64(0), v0, v2, v1]) + T.writes(T_softmax_maxelem_shared[T.int64(0), v0, v1]) + with T.init(): + T_softmax_maxelem_shared[T.int64(0), v0, v1] = T.float32(-340282346638528859811704183484516925440.0) + T_softmax_maxelem_shared[T.int64(0), v0, v1] = T.max(T_softmax_maxelem_shared[T.int64(0), v0, v1], input[T.int64(0), v0, v2, v1]) + for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): + for ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + for ax2_fused_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): + with T.block("T_softmax_expsum"): + v0 = T.axis.spatial(T.int64(4), ax0_ax2_fused // T.int64(8192) + ax0) + v1 = T.axis.spatial(T.int64(8192), ax0_ax2_fused % T.int64(8192) + ax1) + v2 = T.axis.reduce(T.int64(32), ax2_fused_0 * T.int64(256) + ax2_fused_1) + T.where(ax2_fused_0 * T.int64(256) + ax2_fused_1 < T.int64(32)) + T.reads(input[T.int64(0), v0, v2, v1], T_softmax_maxelem_shared[T.int64(0), v0, v1]) + T.writes(T_softmax_expsum_shared[T.int64(0), v0, v1]) + with T.init(): + T_softmax_expsum_shared[T.int64(0), v0, v1] = T.float32(0.0) + T_softmax_expsum_shared[T.int64(0), v0, v1] = T_softmax_expsum_shared[T.int64(0), v0, v1] + T.exp(input[T.int64(0), v0, v2, v1] - T_softmax_maxelem_shared[T.int64(0), v0, v1]) + for ax1_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): + for ax1_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): + with T.block("T_softmax_norm"): + v0 = T.axis.spatial(T.int64(4), ax0_ax2_fused // T.int64(8192)) + v1 = T.axis.spatial(T.int64(32), ax1_0 * T.int64(256) + ax1_1) + v2 = T.axis.spatial(T.int64(8192), ax0_ax2_fused % T.int64(8192)) + T.where(ax1_0 * T.int64(256) + ax1_1 < T.int64(32)) + T.reads(input[T.int64(0), v0, v1, v2], T_softmax_maxelem_shared[T.int64(0), v0, v2], T_softmax_expsum_shared[T.int64(0), v0, v2]) + T.writes(T_softmax_norm[T.int64(0), v0, v1, v2]) + T.block_attr({"axis": 2}) + T_softmax_norm[T.int64(0), v0, v1, v2] = T.exp(input[T.int64(0), v0, v1, v2] - T_softmax_maxelem_shared[T.int64(0), v0, v2]) / T_softmax_expsum_shared[T.int64(0), v0, v2] + # fmt: on + _check(Before, After) + + def test_layer_norm(): # fmt: off @I.ir_module