From 1af0a0368a409bdddb2c413d3e7f52d7cb79f5ff Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Sat, 23 Mar 2024 18:47:29 -0400 Subject: [PATCH] [Fix][Dlight] (Low-batched-)GeMV on small spatial loops This PR fixes an issue in the dlight GeMV rule and the low-batch GeMV rule. The issue happens when the inner spatial loop has small length (e.g., in the MoE gate layer, this length is usually 8). The error is because the GeMV scheduling does not make sure that each TIR block reads/writes the same number of local registers, and this inconsistency leads to wrong generated code. For example, in the schedule (prior to this fix), the first TIR block was scheduled to assign each thread 2 local registers, while the second block was scheduled to assign each thread 1 local register, which is incorrect. Unfortunately, this error only shows up when the spatial loop has small length. One regression test is added. --- python/tvm/dlight/gpu/gemv.py | 18 ++- python/tvm/dlight/gpu/low_batch_gemv.py | 20 +++- .../python/dlight/test_gpu_low_batch_gemv.py | 106 ++++++++++++++++++ 3 files changed, 137 insertions(+), 7 deletions(-) diff --git a/python/tvm/dlight/gpu/gemv.py b/python/tvm/dlight/gpu/gemv.py index ffd6b6d09533..55b38fc66b01 100644 --- a/python/tvm/dlight/gpu/gemv.py +++ b/python/tvm/dlight/gpu/gemv.py @@ -342,12 +342,16 @@ def apply( sch.reverse_compute_at(rf2, loop=bx, preserve_unit_loops=True) tr, vec_c, *ts_tile_s = sch.get_loops(block=rf2)[1:] ts_tile_s = sch.fuse(*ts_tile_s) - ts, tile_s = sch.split(ts_tile_s, factors=[TS, None], preserve_unit_iters=True) + ts_o, ts_i, tile_s = sch.split( + ts_tile_s, factors=[None, TS, TILE_S], preserve_unit_iters=True + ) tile_s, vec_s = sch.split( tile_s, factors=[None, get_max_factor(TILE_S, [1, 2, 4, 8])], preserve_unit_iters=True, ) + assert sch.get(ts_o).extent.value == 1 + ts = sch.fuse(ts_o, ts_i) sch.reorder(ts, tr, tile_s, vec_s, vec_c) sch.bind(ts, TAG_S) sch.bind(tr, TAG_R) @@ -357,7 +361,11 @@ def apply( sch.reverse_compute_at(gemv, loop=bx, preserve_unit_loops=True) tr, *ts_tile_s = sch.get_loops(block=gemv)[1:] ts_tile_s = sch.fuse(*ts_tile_s) - ts, tile_s = sch.split(ts_tile_s, factors=[TS, None], preserve_unit_iters=True) + ts_o, ts_i, tile_s = sch.split( + ts_tile_s, factors=[None, TS, TILE_S], preserve_unit_iters=True + ) + assert sch.get(ts_o).extent.value == 1 + ts = sch.fuse(ts_o, ts_i) sch.reorder(tile_s, ts, tr) sch.bind(ts, TAG_S) sch.bind(tr, TAG_R) @@ -411,7 +419,11 @@ def apply( sch.reverse_compute_at(epilogue, bx, preserve_unit_loops=True) ts_tile_s = sch.fuse(*sch.get_loops(epilogue)[1:]) ts_tile_s = sch.get_loops(epilogue)[-1] - ts, tile_s = sch.split(ts_tile_s, factors=[TS, None], preserve_unit_iters=True) + ts_o, ts_i, tile_s = sch.split( + ts_tile_s, factors=[None, TS, TILE_S], preserve_unit_iters=True + ) + assert sch.get(ts_o).extent.value == 1 + ts = sch.fuse(ts_o, ts_i) sch.bind(ts, TAG_S) sch.set_scope(block, 0, "local") # pylint: enable=invalid-name diff --git a/python/tvm/dlight/gpu/low_batch_gemv.py b/python/tvm/dlight/gpu/low_batch_gemv.py index 84a9319248c5..9a92c9e0e9dc 100644 --- a/python/tvm/dlight/gpu/low_batch_gemv.py +++ b/python/tvm/dlight/gpu/low_batch_gemv.py @@ -17,7 +17,7 @@ """A rule for low-batch GEMM / decode-GEMM using GEMV schedule.""" import re from functools import reduce -from typing import List, Optional, Union, Set +from typing import List, Optional, Set, Union from tvm import DataType, arith, ir, tir from tvm.target import Target @@ -428,12 +428,16 @@ def apply( sch.reverse_compute_at(rf2, loop=bx, preserve_unit_loops=True) tr, vec_c, batch_loop, *ts_tile_s = sch.get_loops(block=rf2)[2:] ts_tile_s = sch.fuse(*ts_tile_s) - ts, tile_s = sch.split(ts_tile_s, factors=[TS, None], preserve_unit_iters=True) + ts_o, ts_i, tile_s = sch.split( + ts_tile_s, factors=[None, TS, TILE_S], preserve_unit_iters=True + ) tile_s, vec_s = sch.split( tile_s, factors=[None, get_max_factor(TILE_S, [1, 2, 4, 8])], preserve_unit_iters=True, ) + assert sch.get(ts_o).extent.value == 1 + ts = sch.fuse(ts_o, ts_i) sch.reorder(ts, tr, tile_s, batch_loop, vec_s, vec_c) sch.bind(ts, TAG_S) sch.bind(tr, TAG_R) @@ -444,7 +448,11 @@ def apply( tr, batch_loop, *ts_tile_s = sch.get_loops(block=gemv)[2:] ts_tile_s = sch.fuse(*ts_tile_s) - ts, tile_s = sch.split(ts_tile_s, factors=[TS, None], preserve_unit_iters=True) + ts_o, ts_i, tile_s = sch.split( + ts_tile_s, factors=[None, TS, TILE_S], preserve_unit_iters=True + ) + assert sch.get(ts_o).extent.value == 1 + ts = sch.fuse(ts_o, ts_i) sch.reorder(tile_s, batch_loop, ts, tr) sch.bind(ts, TAG_S) sch.bind(tr, TAG_R) @@ -499,7 +507,11 @@ def apply( sch.reverse_compute_at(epilogue, bx, preserve_unit_loops=True) ts_tile_s = sch.fuse(*sch.get_loops(epilogue)[3:]) ts_tile_s = sch.get_loops(epilogue)[-1] - ts, tile_s = sch.split(ts_tile_s, factors=[TS, None], preserve_unit_iters=True) + ts_o, ts_i, tile_s = sch.split( + ts_tile_s, factors=[None, TS, TILE_S], preserve_unit_iters=True + ) + assert sch.get(ts_o).extent.value == 1 + ts = sch.fuse(ts_o, ts_i) sch.bind(ts, TAG_S) sch.set_scope(block, 0, "local") diff --git a/tests/python/dlight/test_gpu_low_batch_gemv.py b/tests/python/dlight/test_gpu_low_batch_gemv.py index d3e635ddaa4e..4b63cfddba3c 100644 --- a/tests/python/dlight/test_gpu_low_batch_gemv.py +++ b/tests/python/dlight/test_gpu_low_batch_gemv.py @@ -275,5 +275,111 @@ def before(var_A: T.handle, var_B: T.handle, matmul: T.Buffer((T.int64(1), T.int tvm.ir.assert_structural_equal(mod["main"], before) +def test_small_spatial_axis(): + @T.prim_func(private=True) + def func(var_A: T.handle, B: T.Buffer((T.int64(8), T.int64(4096)), "float16"), var_C: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + batch_size = T.int64() + A = T.match_buffer(var_A, (batch_size, T.int64(4096)), "float16") + C = T.match_buffer(var_C, (batch_size, T.int64(8)), "float16") + for i0, i1, k in T.grid(batch_size, T.int64(8), T.int64(4096)): + with T.block("NT_matmul"): + v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k]) + T.reads(A[v_i0, v_k], B[v_i1, v_k]) + T.writes(C[v_i0, v_i1]) + with T.init(): + C[v_i0, v_i1] = T.float16(0) + C[v_i0, v_i1] = C[v_i0, v_i1] + A[v_i0, v_k] * B[v_i1, v_k] + + # fmt: off + @T.prim_func(private=True) + def expected(var_A: T.handle, B: T.Buffer((T.int64(8), T.int64(4096)), "float16"), var_C: T.handle): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + batch_size = T.int64() + A = T.match_buffer(var_A, (batch_size, T.int64(4096)), "float16") + C = T.match_buffer(var_C, (batch_size, T.int64(8)), "float16") + # with T.block("root"): + C_pad_local = T.alloc_buffer(((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), T.int64(8)), "float16", scope="local") + C_pad_rf_local = T.alloc_buffer((T.int64(128), (batch_size + T.int64(3)) // T.int64(4) * T.int64(4), T.int64(8)), "float16", scope="local") + C_pad_rf_local_1 = T.alloc_buffer((T.int64(32), (batch_size + T.int64(3)) // T.int64(4) * T.int64(4), T.int64(8)), "float16", scope="local") + for ax0_0 in T.thread_binding((batch_size + T.int64(3)) // T.int64(4), thread="blockIdx.y"): + for u_fused_ax1_fused_fused_0 in T.thread_binding(T.int64(1), thread="blockIdx.x"): + for u_fused_ax1_fused_fused_1 in T.thread_binding(T.int64(16), thread="threadIdx.y"): + for ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 in T.thread_binding(T.int64(32), thread="threadIdx.x"): + for ax0_1_init, u_fused_ax1_fused_fused_2_init in T.grid(T.int64(4), T.int64(2)): + for ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1_init in T.vectorized(T.int64(4)): + with T.block("NT_matmul_rf_init"): + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused = T.axis.spatial(T.int64(128), ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 * T.int64(4) + ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1_init) + v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax0_1_init) + v1 = T.axis.spatial(T.int64(8), u_fused_ax1_fused_fused_0 * T.int64(32) + u_fused_ax1_fused_fused_1 * T.int64(2) + u_fused_ax1_fused_fused_2_init) + T.where((u_fused_ax1_fused_fused_0 * T.int64(16) + u_fused_ax1_fused_fused_1) * T.int64(2) + u_fused_ax1_fused_fused_2_init < T.int64(8)) + T.reads() + T.writes(C_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, v0, v1]) + C_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, v0, v1] = T.float16(0) + for ax2_fused_u_fused_0 in T.serial(T.int64(16), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): + for ax0_1, u_fused_ax1_fused_fused_2, ax2_fused_u_fused_2 in T.grid(T.int64(4), T.int64(2), T.int64(2)): + for ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1 in T.vectorized(T.int64(4)): + with T.block("NT_matmul_rf_update"): + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused = T.axis.spatial(T.int64(128), ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 * T.int64(4) + ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1) + v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax0_1) + v1 = T.axis.spatial(T.int64(8), u_fused_ax1_fused_fused_0 * T.int64(32) + u_fused_ax1_fused_fused_1 * T.int64(2) + u_fused_ax1_fused_fused_2) + vax2_fused_u_fused_0, vax2_fused_u_fused_2 = T.axis.remap("RR", [ax2_fused_u_fused_0, ax2_fused_u_fused_2]) + T.where((u_fused_ax1_fused_fused_0 * T.int64(16) + u_fused_ax1_fused_fused_1) * T.int64(2) + u_fused_ax1_fused_fused_2 < T.int64(8)) + T.reads(C_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, v0, v1], A[v0, vax2_fused_u_fused_0 * T.int64(256) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_u_fused_2 * T.int64(4) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused % T.int64(4)], B[v1, vax2_fused_u_fused_0 * T.int64(256) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_u_fused_2 * T.int64(4) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused % T.int64(4)]) + T.writes(C_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, v0, v1]) + C_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, v0, v1] = C_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, v0, v1] + T.if_then_else(v0 < batch_size, A[v0, vax2_fused_u_fused_0 * T.int64(256) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_u_fused_2 * T.int64(4) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused % T.int64(4)], T.float16(0)) * B[v1, vax2_fused_u_fused_0 * T.int64(256) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_u_fused_2 * T.int64(4) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused % T.int64(4)] + for ax3_fused_0_ax3_fused_1_fused in T.thread_binding(T.int64(16), thread="threadIdx.y"): + for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"): + for ax3_fused_2_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): + for ax2 in range(T.int64(4)): + for ax3_fused_2_1 in T.vectorized(T.int64(2)): + with T.block("NT_matmul_rf_init"): + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 = T.axis.spatial(T.int64(32), ax0) + v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax2) + v1 = T.axis.spatial(T.int64(8), ax3_fused_0_ax3_fused_1_fused * T.int64(2) + ax3_fused_2_0 * T.int64(2) + ax3_fused_2_1) + T.where((T.Mul(T.int64(0), T.int64(16)) + ax3_fused_0_ax3_fused_1_fused % T.int64(16)) * T.int64(2) + (ax3_fused_2_0 * T.int64(2) + ax3_fused_2_1) < T.int64(8)) + T.reads() + T.writes(C_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, v1]) + C_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, v1] = T.float16(0) + for ax1 in range(T.int64(4)): + with T.block("NT_matmul_rf_update"): + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1]) + v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax2) + v1 = T.axis.spatial(T.int64(8), ax3_fused_0_ax3_fused_1_fused * T.int64(2) + ax3_fused_2_0 * T.int64(2) + ax3_fused_2_1) + T.where((T.Mul(T.int64(0), T.int64(16)) + ax3_fused_0_ax3_fused_1_fused % T.int64(16)) * T.int64(2) + (ax3_fused_2_0 * T.int64(2) + ax3_fused_2_1) < T.int64(8)) + T.reads(C_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, v1], C_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 * T.int64(4) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1, v0, v1]) + T.writes(C_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, v1]) + C_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, v1] = C_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, v1] + C_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 * T.int64(4) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1, v0, v1] + for ax2_fused_2, ax1 in T.grid(T.int64(2), T.int64(4)): + for ax2_fused_0_ax2_fused_1_fused in T.thread_binding(T.int64(16), thread="threadIdx.y"): + for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"): + with T.block("NT_matmul"): + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 = T.axis.reduce(T.int64(32), ax0) + v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax1) + v1 = T.axis.spatial(T.int64(8), ax2_fused_0_ax2_fused_1_fused * T.int64(2) + ax2_fused_2) + T.where((T.Mul(T.int64(0), T.int64(16)) + ax2_fused_0_ax2_fused_1_fused % T.int64(16)) * T.int64(2) + ax2_fused_2 < T.int64(8)) + T.reads(C_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, v1]) + T.writes(C_pad_local[v0, v1]) + with T.init(): + C_pad_local[v0, v1] = T.float16(0) + C_pad_local[v0, v1] = C_pad_local[v0, v1] + C_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, v1] + for ax0 in range(T.int64(4)): + for ax1_fused_0_ax1_fused_1_fused in T.thread_binding(T.int64(16), thread="threadIdx.y"): + for ax1_fused_2 in range(T.int64(2)): + with T.block("C_pad"): + v0 = T.axis.spatial(batch_size, ax0_0 * T.int64(4) + ax0) + v1 = T.axis.spatial(T.int64(8), ax1_fused_0_ax1_fused_1_fused * T.int64(2) + ax1_fused_2) + T.where((ax0_0 - (batch_size + T.int64(3)) // T.int64(4) < T.int64(0) or ax0_0 == T.int64(0)) and ax0_0 * T.int64(4) + ax0 < batch_size and (T.Mul(T.int64(0), T.int64(16)) + ax1_fused_0_ax1_fused_1_fused % T.int64(16)) * T.int64(2) + ax1_fused_2 < T.int64(8)) + T.reads(C_pad_local[v0, v1]) + T.writes(C[v0, v1]) + C[v0, v1] = C_pad_local[v0, v1] + # fmt: on + + mod = tvm.IRModule({"main": func}) + with Target("cuda"): + mod = dl.ApplyDefaultSchedule(dl.gpu.LowBatchGEMV(4))(mod) + tvm.ir.assert_structural_equal(mod["main"], expected) + + if __name__ == "__main__": tvm.testing.main()