diff --git a/python/tvm/dlight/gpu/general_reduction.py b/python/tvm/dlight/gpu/general_reduction.py
index a068e732b986..d3979ce0e4c3 100644
--- a/python/tvm/dlight/gpu/general_reduction.py
+++ b/python/tvm/dlight/gpu/general_reduction.py
@@ -99,6 +99,46 @@ def f_layout_mapping(*iters):
         except AssertionError:
             return None
 
+        if "R" not in block_infos[-1].dom_kind():
+            # The final block is a spatial block. It is possible that the loop order
+            # of the last block differs from that of the previous blocks, so we
+            # reorder its spatial loops to align with the reduction loops for the
+            # follow-up schedule. We first collect all the buffers written by the
+            # reduction blocks; then, in the final block, any loop index used to
+            # read those buffers is spatial.
+            reduced_buffers = []
+            for block_info in block_infos[:-1]:
+                for buffer_write in sch.get(block_info.block_rv).writes:
+                    reduced_buffers.append(buffer_write.buffer)
+
+            spatial_block = sch.get(block_infos[-1].block_rv)
+            spatial_loops = set()
+            block_var_to_loop_var = {}
+            loops = sch.get_loops(block_infos[-1].block_rv)
+            for block_iter, loop_rv in zip(spatial_block.iter_vars, loops):
+                block_var_to_loop_var[block_iter.var] = sch.get(loop_rv).loop_var
+
+            def _visit_expr(e: tir.PrimExpr):
+                if isinstance(e, tir.Var) and e in block_var_to_loop_var:
+                    spatial_loops.add(block_var_to_loop_var[e])
+
+            for buffer_read in spatial_block.reads:
+                buffer = buffer_read.buffer
+                if buffer in reduced_buffers:
+                    for read_range in buffer_read.region:
+                        tir.stmt_functor.post_order_visit(read_range.min, _visit_expr)
+                        tir.stmt_functor.post_order_visit(read_range.extent, _visit_expr)
+
+            s_loops = []
+            other_loops = []
+            for loop_rv in loops:
+                loop = sch.get(loop_rv)
+                if loop.loop_var in spatial_loops or loop.extent == 1:
+                    s_loops.append(loop_rv)
+                else:
+                    other_loops.append(loop_rv)
+            sch.reorder(*s_loops, *other_loops)
+
         loops = sch.get_loops(block_infos[-1].block_rv)
         bx = sch.fuse(*loops[:num_leading_s])
         r_loop, tx = sch.split(loops[-1], [None, len_tx])
diff --git a/tests/python/dlight/test_gpu_general_reduction.py b/tests/python/dlight/test_gpu_general_reduction.py
index e1a9a8e018ce..9549441a1154 100644
--- a/tests/python/dlight/test_gpu_general_reduction.py
+++ b/tests/python/dlight/test_gpu_general_reduction.py
@@ -222,6 +222,97 @@ def main(A: T.Buffer((T.int64(1), T.int64(1), T.int64(32000)), "float32"), T_sof
     _check(Before, After)
 
 
+def test_softmax_3():
+    # fmt: off
+    @I.ir_module
+    class Before:
+        @T.prim_func
+        def main(input: T.Buffer((T.int64(1), T.int64(4), T.int64(32), T.int64(8192)), "float32"), T_softmax_norm: T.Buffer((T.int64(1), T.int64(4), T.int64(32), T.int64(8192)), "float32")):
+            # with T.block("root"):
+            T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(4), T.int64(8192)))
+            T_softmax_exp = T.alloc_buffer((T.int64(1), T.int64(4), T.int64(32), T.int64(8192)))
+            T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(4), T.int64(8192)))
+            for i0, i1, i2, k in T.grid(T.int64(1), T.int64(4), T.int64(8192), T.int64(32)):
+                with T.block("T_softmax_maxelem"):
+                    v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
+                    T.reads(input[v_i0, v_i1, v_k, v_i2])
+                    T.writes(T_softmax_maxelem[v_i0, v_i1, v_i2])
+                    with T.init():
+                        T_softmax_maxelem[v_i0, v_i1, v_i2] = T.float32(-340282346638528859811704183484516925440.0)
+                    T_softmax_maxelem[v_i0, v_i1, v_i2] = T.max(T_softmax_maxelem[v_i0, v_i1, v_i2], input[v_i0, v_i1, v_k, v_i2])
+            for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(4), T.int64(32), T.int64(8192)):
+                with T.block("T_softmax_exp"):
+                    v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
+                    T.reads(input[v_i0, v_i1, v_i2, v_i3], T_softmax_maxelem[v_i0, v_i1, v_i3])
+                    T.writes(T_softmax_exp[v_i0, v_i1, v_i2, v_i3])
+                    T_softmax_exp[v_i0, v_i1, v_i2, v_i3] = T.exp(input[v_i0, v_i1, v_i2, v_i3] - T_softmax_maxelem[v_i0, v_i1, v_i3])
+            for i0, i1, i2, k in T.grid(T.int64(1), T.int64(4), T.int64(8192), T.int64(32)):
+                with T.block("T_softmax_expsum"):
+                    v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
+                    T.reads(T_softmax_exp[v_i0, v_i1, v_k, v_i2])
+                    T.writes(T_softmax_expsum[v_i0, v_i1, v_i2])
+                    with T.init():
+                        T_softmax_expsum[v_i0, v_i1, v_i2] = T.float32(0.0)
+                    T_softmax_expsum[v_i0, v_i1, v_i2] = T_softmax_expsum[v_i0, v_i1, v_i2] + T_softmax_exp[v_i0, v_i1, v_k, v_i2]
+            for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(4), T.int64(32), T.int64(8192)):
+                with T.block("T_softmax_norm"):
+                    v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
+                    T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_i3], T_softmax_expsum[v_i0, v_i1, v_i3])
+                    T.writes(T_softmax_norm[v_i0, v_i1, v_i2, v_i3])
+                    T.block_attr({"axis": 2})
+                    T_softmax_norm[v_i0, v_i1, v_i2, v_i3] = T_softmax_exp[v_i0, v_i1, v_i2, v_i3] / T_softmax_expsum[v_i0, v_i1, v_i3]
+
+
+    @I.ir_module
+    class After:
+        @T.prim_func
+        def main(input: T.Buffer((T.int64(1), T.int64(4), T.int64(32), T.int64(8192)), "float32"), T_softmax_norm: T.Buffer((T.int64(1), T.int64(4), T.int64(32), T.int64(8192)), "float32")):
+            T.func_attr({"tir.is_scheduled": 1})
+            # with T.block("root"):
+            T_softmax_maxelem_shared = T.alloc_buffer((T.int64(1), T.int64(4), T.int64(8192)), scope="shared")
+            T_softmax_expsum_shared = T.alloc_buffer((T.int64(1), T.int64(4), T.int64(8192)), scope="shared")
+            for ax0_ax2_fused in T.thread_binding(T.int64(32768), thread="blockIdx.x"):
+                for ax0, ax1 in T.grid(T.int64(1), T.int64(1)):
+                    for ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"):
+                        for ax2_fused_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
+                            with T.block("T_softmax_maxelem"):
+                                v0 = T.axis.spatial(T.int64(4), ax0_ax2_fused // T.int64(8192) + ax0)
+                                v1 = T.axis.spatial(T.int64(8192), ax0_ax2_fused % T.int64(8192) + ax1)
+                                v2 = T.axis.reduce(T.int64(32), ax2_fused_0 * T.int64(256) + ax2_fused_1)
+                                T.where(ax2_fused_0 * T.int64(256) + ax2_fused_1 < T.int64(32))
+                                T.reads(input[T.int64(0), v0, v2, v1])
+                                T.writes(T_softmax_maxelem_shared[T.int64(0), v0, v1])
+                                with T.init():
+                                    T_softmax_maxelem_shared[T.int64(0), v0, v1] = T.float32(-340282346638528859811704183484516925440.0)
+                                T_softmax_maxelem_shared[T.int64(0), v0, v1] = T.max(T_softmax_maxelem_shared[T.int64(0), v0, v1], input[T.int64(0), v0, v2, v1])
+                for ax0, ax1 in T.grid(T.int64(1), T.int64(1)):
+                    for ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"):
+                        for ax2_fused_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
+                            with T.block("T_softmax_expsum"):
+                                v0 = T.axis.spatial(T.int64(4), ax0_ax2_fused // T.int64(8192) + ax0)
+                                v1 = T.axis.spatial(T.int64(8192), ax0_ax2_fused % T.int64(8192) + ax1)
+                                v2 = T.axis.reduce(T.int64(32), ax2_fused_0 * T.int64(256) + ax2_fused_1)
+                                T.where(ax2_fused_0 * T.int64(256) + ax2_fused_1 < T.int64(32))
+                                T.reads(input[T.int64(0), v0, v2, v1], T_softmax_maxelem_shared[T.int64(0), v0, v1])
+                                T.writes(T_softmax_expsum_shared[T.int64(0), v0, v1])
+                                with T.init():
+                                    T_softmax_expsum_shared[T.int64(0), v0, v1] = T.float32(0.0)
+                                T_softmax_expsum_shared[T.int64(0), v0, v1] = T_softmax_expsum_shared[T.int64(0), v0, v1] + T.exp(input[T.int64(0), v0, v2, v1] - T_softmax_maxelem_shared[T.int64(0), v0, v1])
+                for ax1_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"):
+                    for ax1_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
+                        with T.block("T_softmax_norm"):
+                            v0 = T.axis.spatial(T.int64(4), ax0_ax2_fused // T.int64(8192))
+                            v1 = T.axis.spatial(T.int64(32), ax1_0 * T.int64(256) + ax1_1)
+                            v2 = T.axis.spatial(T.int64(8192), ax0_ax2_fused % T.int64(8192))
+                            T.where(ax1_0 * T.int64(256) + ax1_1 < T.int64(32))
+                            T.reads(input[T.int64(0), v0, v1, v2], T_softmax_maxelem_shared[T.int64(0), v0, v2], T_softmax_expsum_shared[T.int64(0), v0, v2])
+                            T.writes(T_softmax_norm[T.int64(0), v0, v1, v2])
+                            T.block_attr({"axis": 2})
+                            T_softmax_norm[T.int64(0), v0, v1, v2] = T.exp(input[T.int64(0), v0, v1, v2] - T_softmax_maxelem_shared[T.int64(0), v0, v2]) / T_softmax_expsum_shared[T.int64(0), v0, v2]
+    # fmt: on
+    _check(Before, After)
+
+
 def test_layer_norm():
     # fmt: off
     @I.ir_module
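
For reference, the updated rule can be exercised outside the test file roughly as follows. This is a minimal sketch mirroring what the tests' _check helper does (apply the rule under a CUDA target, then compare against the expected module); `Before` and `After` refer to the modules defined in `test_softmax_3` above.

    import tvm
    from tvm import dlight as dl

    # Apply the GeneralReduction dlight rule under a CUDA target. With the
    # change above, the trailing spatial block's loops are reordered to match
    # the reduction blocks before the rule fuses and thread-binds them.
    with tvm.target.Target("cuda"):
        mod = dl.ApplyDefaultSchedule(dl.gpu.GeneralReduction())(Before)

    # The scheduled module should be structurally equal to `After`.
    tvm.ir.assert_structural_equal(mod, After)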