diff --git a/python/tvm/dlight/gpu/gemv.py b/python/tvm/dlight/gpu/gemv.py index ffd6b6d09533..55b38fc66b01 100644 --- a/python/tvm/dlight/gpu/gemv.py +++ b/python/tvm/dlight/gpu/gemv.py @@ -342,12 +342,16 @@ def apply( sch.reverse_compute_at(rf2, loop=bx, preserve_unit_loops=True) tr, vec_c, *ts_tile_s = sch.get_loops(block=rf2)[1:] ts_tile_s = sch.fuse(*ts_tile_s) - ts, tile_s = sch.split(ts_tile_s, factors=[TS, None], preserve_unit_iters=True) + ts_o, ts_i, tile_s = sch.split( + ts_tile_s, factors=[None, TS, TILE_S], preserve_unit_iters=True + ) tile_s, vec_s = sch.split( tile_s, factors=[None, get_max_factor(TILE_S, [1, 2, 4, 8])], preserve_unit_iters=True, ) + assert sch.get(ts_o).extent.value == 1 + ts = sch.fuse(ts_o, ts_i) sch.reorder(ts, tr, tile_s, vec_s, vec_c) sch.bind(ts, TAG_S) sch.bind(tr, TAG_R) @@ -357,7 +361,11 @@ def apply( sch.reverse_compute_at(gemv, loop=bx, preserve_unit_loops=True) tr, *ts_tile_s = sch.get_loops(block=gemv)[1:] ts_tile_s = sch.fuse(*ts_tile_s) - ts, tile_s = sch.split(ts_tile_s, factors=[TS, None], preserve_unit_iters=True) + ts_o, ts_i, tile_s = sch.split( + ts_tile_s, factors=[None, TS, TILE_S], preserve_unit_iters=True + ) + assert sch.get(ts_o).extent.value == 1 + ts = sch.fuse(ts_o, ts_i) sch.reorder(tile_s, ts, tr) sch.bind(ts, TAG_S) sch.bind(tr, TAG_R) @@ -411,7 +419,11 @@ def apply( sch.reverse_compute_at(epilogue, bx, preserve_unit_loops=True) ts_tile_s = sch.fuse(*sch.get_loops(epilogue)[1:]) ts_tile_s = sch.get_loops(epilogue)[-1] - ts, tile_s = sch.split(ts_tile_s, factors=[TS, None], preserve_unit_iters=True) + ts_o, ts_i, tile_s = sch.split( + ts_tile_s, factors=[None, TS, TILE_S], preserve_unit_iters=True + ) + assert sch.get(ts_o).extent.value == 1 + ts = sch.fuse(ts_o, ts_i) sch.bind(ts, TAG_S) sch.set_scope(block, 0, "local") # pylint: enable=invalid-name diff --git a/python/tvm/dlight/gpu/low_batch_gemv.py b/python/tvm/dlight/gpu/low_batch_gemv.py index 84a9319248c5..9a92c9e0e9dc 100644 --- a/python/tvm/dlight/gpu/low_batch_gemv.py +++ b/python/tvm/dlight/gpu/low_batch_gemv.py @@ -17,7 +17,7 @@ """A rule for low-batch GEMM / decode-GEMM using GEMV schedule.""" import re from functools import reduce -from typing import List, Optional, Union, Set +from typing import List, Optional, Set, Union from tvm import DataType, arith, ir, tir from tvm.target import Target @@ -428,12 +428,16 @@ def apply( sch.reverse_compute_at(rf2, loop=bx, preserve_unit_loops=True) tr, vec_c, batch_loop, *ts_tile_s = sch.get_loops(block=rf2)[2:] ts_tile_s = sch.fuse(*ts_tile_s) - ts, tile_s = sch.split(ts_tile_s, factors=[TS, None], preserve_unit_iters=True) + ts_o, ts_i, tile_s = sch.split( + ts_tile_s, factors=[None, TS, TILE_S], preserve_unit_iters=True + ) tile_s, vec_s = sch.split( tile_s, factors=[None, get_max_factor(TILE_S, [1, 2, 4, 8])], preserve_unit_iters=True, ) + assert sch.get(ts_o).extent.value == 1 + ts = sch.fuse(ts_o, ts_i) sch.reorder(ts, tr, tile_s, batch_loop, vec_s, vec_c) sch.bind(ts, TAG_S) sch.bind(tr, TAG_R) @@ -444,7 +448,11 @@ def apply( tr, batch_loop, *ts_tile_s = sch.get_loops(block=gemv)[2:] ts_tile_s = sch.fuse(*ts_tile_s) - ts, tile_s = sch.split(ts_tile_s, factors=[TS, None], preserve_unit_iters=True) + ts_o, ts_i, tile_s = sch.split( + ts_tile_s, factors=[None, TS, TILE_S], preserve_unit_iters=True + ) + assert sch.get(ts_o).extent.value == 1 + ts = sch.fuse(ts_o, ts_i) sch.reorder(tile_s, batch_loop, ts, tr) sch.bind(ts, TAG_S) sch.bind(tr, TAG_R) @@ -499,7 +507,11 @@ def apply( sch.reverse_compute_at(epilogue, bx, preserve_unit_loops=True) ts_tile_s = sch.fuse(*sch.get_loops(epilogue)[3:]) ts_tile_s = sch.get_loops(epilogue)[-1] - ts, tile_s = sch.split(ts_tile_s, factors=[TS, None], preserve_unit_iters=True) + ts_o, ts_i, tile_s = sch.split( + ts_tile_s, factors=[None, TS, TILE_S], preserve_unit_iters=True + ) + assert sch.get(ts_o).extent.value == 1 + ts = sch.fuse(ts_o, ts_i) sch.bind(ts, TAG_S) sch.set_scope(block, 0, "local") diff --git a/tests/python/dlight/test_gpu_low_batch_gemv.py b/tests/python/dlight/test_gpu_low_batch_gemv.py index d3e635ddaa4e..4b63cfddba3c 100644 --- a/tests/python/dlight/test_gpu_low_batch_gemv.py +++ b/tests/python/dlight/test_gpu_low_batch_gemv.py @@ -275,5 +275,111 @@ def before(var_A: T.handle, var_B: T.handle, matmul: T.Buffer((T.int64(1), T.int tvm.ir.assert_structural_equal(mod["main"], before) +def test_small_spatial_axis(): + @T.prim_func(private=True) + def func(var_A: T.handle, B: T.Buffer((T.int64(8), T.int64(4096)), "float16"), var_C: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + batch_size = T.int64() + A = T.match_buffer(var_A, (batch_size, T.int64(4096)), "float16") + C = T.match_buffer(var_C, (batch_size, T.int64(8)), "float16") + for i0, i1, k in T.grid(batch_size, T.int64(8), T.int64(4096)): + with T.block("NT_matmul"): + v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k]) + T.reads(A[v_i0, v_k], B[v_i1, v_k]) + T.writes(C[v_i0, v_i1]) + with T.init(): + C[v_i0, v_i1] = T.float16(0) + C[v_i0, v_i1] = C[v_i0, v_i1] + A[v_i0, v_k] * B[v_i1, v_k] + + # fmt: off + @T.prim_func(private=True) + def expected(var_A: T.handle, B: T.Buffer((T.int64(8), T.int64(4096)), "float16"), var_C: T.handle): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + batch_size = T.int64() + A = T.match_buffer(var_A, (batch_size, T.int64(4096)), "float16") + C = T.match_buffer(var_C, (batch_size, T.int64(8)), "float16") + # with T.block("root"): + C_pad_local = T.alloc_buffer(((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), T.int64(8)), "float16", scope="local") + C_pad_rf_local = T.alloc_buffer((T.int64(128), (batch_size + T.int64(3)) // T.int64(4) * T.int64(4), T.int64(8)), "float16", scope="local") + C_pad_rf_local_1 = T.alloc_buffer((T.int64(32), (batch_size + T.int64(3)) // T.int64(4) * T.int64(4), T.int64(8)), "float16", scope="local") + for ax0_0 in T.thread_binding((batch_size + T.int64(3)) // T.int64(4), thread="blockIdx.y"): + for u_fused_ax1_fused_fused_0 in T.thread_binding(T.int64(1), thread="blockIdx.x"): + for u_fused_ax1_fused_fused_1 in T.thread_binding(T.int64(16), thread="threadIdx.y"): + for ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 in T.thread_binding(T.int64(32), thread="threadIdx.x"): + for ax0_1_init, u_fused_ax1_fused_fused_2_init in T.grid(T.int64(4), T.int64(2)): + for ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1_init in T.vectorized(T.int64(4)): + with T.block("NT_matmul_rf_init"): + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused = T.axis.spatial(T.int64(128), ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 * T.int64(4) + ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1_init) + v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax0_1_init) + v1 = T.axis.spatial(T.int64(8), u_fused_ax1_fused_fused_0 * T.int64(32) + u_fused_ax1_fused_fused_1 * T.int64(2) + u_fused_ax1_fused_fused_2_init) + T.where((u_fused_ax1_fused_fused_0 * T.int64(16) + u_fused_ax1_fused_fused_1) * T.int64(2) + u_fused_ax1_fused_fused_2_init < T.int64(8)) + T.reads() + T.writes(C_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, v0, v1]) + C_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, v0, v1] = T.float16(0) + for ax2_fused_u_fused_0 in T.serial(T.int64(16), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): + for ax0_1, u_fused_ax1_fused_fused_2, ax2_fused_u_fused_2 in T.grid(T.int64(4), T.int64(2), T.int64(2)): + for ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1 in T.vectorized(T.int64(4)): + with T.block("NT_matmul_rf_update"): + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused = T.axis.spatial(T.int64(128), ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 * T.int64(4) + ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1) + v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax0_1) + v1 = T.axis.spatial(T.int64(8), u_fused_ax1_fused_fused_0 * T.int64(32) + u_fused_ax1_fused_fused_1 * T.int64(2) + u_fused_ax1_fused_fused_2) + vax2_fused_u_fused_0, vax2_fused_u_fused_2 = T.axis.remap("RR", [ax2_fused_u_fused_0, ax2_fused_u_fused_2]) + T.where((u_fused_ax1_fused_fused_0 * T.int64(16) + u_fused_ax1_fused_fused_1) * T.int64(2) + u_fused_ax1_fused_fused_2 < T.int64(8)) + T.reads(C_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, v0, v1], A[v0, vax2_fused_u_fused_0 * T.int64(256) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_u_fused_2 * T.int64(4) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused % T.int64(4)], B[v1, vax2_fused_u_fused_0 * T.int64(256) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_u_fused_2 * T.int64(4) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused % T.int64(4)]) + T.writes(C_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, v0, v1]) + C_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, v0, v1] = C_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, v0, v1] + T.if_then_else(v0 < batch_size, A[v0, vax2_fused_u_fused_0 * T.int64(256) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_u_fused_2 * T.int64(4) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused % T.int64(4)], T.float16(0)) * B[v1, vax2_fused_u_fused_0 * T.int64(256) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_u_fused_2 * T.int64(4) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused % T.int64(4)] + for ax3_fused_0_ax3_fused_1_fused in T.thread_binding(T.int64(16), thread="threadIdx.y"): + for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"): + for ax3_fused_2_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): + for ax2 in range(T.int64(4)): + for ax3_fused_2_1 in T.vectorized(T.int64(2)): + with T.block("NT_matmul_rf_init"): + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 = T.axis.spatial(T.int64(32), ax0) + v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax2) + v1 = T.axis.spatial(T.int64(8), ax3_fused_0_ax3_fused_1_fused * T.int64(2) + ax3_fused_2_0 * T.int64(2) + ax3_fused_2_1) + T.where((T.Mul(T.int64(0), T.int64(16)) + ax3_fused_0_ax3_fused_1_fused % T.int64(16)) * T.int64(2) + (ax3_fused_2_0 * T.int64(2) + ax3_fused_2_1) < T.int64(8)) + T.reads() + T.writes(C_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, v1]) + C_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, v1] = T.float16(0) + for ax1 in range(T.int64(4)): + with T.block("NT_matmul_rf_update"): + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1]) + v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax2) + v1 = T.axis.spatial(T.int64(8), ax3_fused_0_ax3_fused_1_fused * T.int64(2) + ax3_fused_2_0 * T.int64(2) + ax3_fused_2_1) + T.where((T.Mul(T.int64(0), T.int64(16)) + ax3_fused_0_ax3_fused_1_fused % T.int64(16)) * T.int64(2) + (ax3_fused_2_0 * T.int64(2) + ax3_fused_2_1) < T.int64(8)) + T.reads(C_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, v1], C_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 * T.int64(4) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1, v0, v1]) + T.writes(C_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, v1]) + C_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, v1] = C_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, v1] + C_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 * T.int64(4) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1, v0, v1] + for ax2_fused_2, ax1 in T.grid(T.int64(2), T.int64(4)): + for ax2_fused_0_ax2_fused_1_fused in T.thread_binding(T.int64(16), thread="threadIdx.y"): + for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"): + with T.block("NT_matmul"): + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 = T.axis.reduce(T.int64(32), ax0) + v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax1) + v1 = T.axis.spatial(T.int64(8), ax2_fused_0_ax2_fused_1_fused * T.int64(2) + ax2_fused_2) + T.where((T.Mul(T.int64(0), T.int64(16)) + ax2_fused_0_ax2_fused_1_fused % T.int64(16)) * T.int64(2) + ax2_fused_2 < T.int64(8)) + T.reads(C_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, v1]) + T.writes(C_pad_local[v0, v1]) + with T.init(): + C_pad_local[v0, v1] = T.float16(0) + C_pad_local[v0, v1] = C_pad_local[v0, v1] + C_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, v1] + for ax0 in range(T.int64(4)): + for ax1_fused_0_ax1_fused_1_fused in T.thread_binding(T.int64(16), thread="threadIdx.y"): + for ax1_fused_2 in range(T.int64(2)): + with T.block("C_pad"): + v0 = T.axis.spatial(batch_size, ax0_0 * T.int64(4) + ax0) + v1 = T.axis.spatial(T.int64(8), ax1_fused_0_ax1_fused_1_fused * T.int64(2) + ax1_fused_2) + T.where((ax0_0 - (batch_size + T.int64(3)) // T.int64(4) < T.int64(0) or ax0_0 == T.int64(0)) and ax0_0 * T.int64(4) + ax0 < batch_size and (T.Mul(T.int64(0), T.int64(16)) + ax1_fused_0_ax1_fused_1_fused % T.int64(16)) * T.int64(2) + ax1_fused_2 < T.int64(8)) + T.reads(C_pad_local[v0, v1]) + T.writes(C[v0, v1]) + C[v0, v1] = C_pad_local[v0, v1] + # fmt: on + + mod = tvm.IRModule({"main": func}) + with Target("cuda"): + mod = dl.ApplyDefaultSchedule(dl.gpu.LowBatchGEMV(4))(mod) + tvm.ir.assert_structural_equal(mod["main"], expected) + + if __name__ == "__main__": tvm.testing.main()