diff --git a/python/tvm/dlight/gpu/decode_gemv.py b/python/tvm/dlight/gpu/decode_gemv.py index afcfdb30206b..1aa5d68fc53e 100644 --- a/python/tvm/dlight/gpu/decode_gemv.py +++ b/python/tvm/dlight/gpu/decode_gemv.py @@ -233,7 +233,11 @@ def _sch_inner_spatial( _, *s = sch.get_loops(epilogue) # pylint: disable=invalid-name _, tx, ty = sch.split(sch.fuse(*s), factors=[None, len_tx, len_ty]) sch.bind(tx, "threadIdx.x") - sch.bind(ty, "threadIdx.x") + sch.bind(ty, "threadIdx.y") else: + # The epilogue is element-wise without broadcasting. + # Thus the remaining spatial part should be bound to tx. sch.set_scope(block, 0, "local") + _, *s = sch.get_loops(epilogue) # pylint: disable=invalid-name + sch.bind(sch.fuse(*s), "threadIdx.x") # pylint: enable=invalid-name diff --git a/tests/python/dlight/test_gpu_decode_gemv.py b/tests/python/dlight/test_gpu_decode_gemv.py index 7b19e6b7f811..971f5f4d09ba 100644 --- a/tests/python/dlight/test_gpu_decode_gemv.py +++ b/tests/python/dlight/test_gpu_decode_gemv.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. 
# pylint: disable=missing-docstring,line-too-long,invalid-name,too-few-public-methods,too-many-locals + +import tvm.testing from tvm import dlight as dl from tvm.ir import assert_structural_equal from tvm.script import ir as I @@ -489,11 +491,225 @@ def main(A: T.Buffer((1, 1, 4096), "float16"), B: T.Buffer((4096,), "float16"), assert_structural_equal(mod, After) +def test_spatial_inner_no_broadcasting(): + # fmt: off + @I.ir_module + class Module: + @T.prim_func + def main(lv575: T.Buffer((1376, 4096), "uint32"), lv576: T.Buffer((344, 4096), "float16"), lv574: T.Buffer((1, 1, 11008), "float16"), lv570: T.Buffer((1, 1, 4096), "float16"), p_output0_intermediate: T.Buffer((1, 1, 4096), "float16")): + T.func_attr({"tir.noalias": T.bool(True)}) + p_output0_intermediate_1 = T.alloc_buffer((11008, 4096), "float16") + var_matmul_intermediate = T.alloc_buffer((1, 1, 4096), "float16") + for i, j in T.grid(11008, 4096): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv575[v_i // 8, v_j], lv576[v_i // 32, v_j]) + T.writes(p_output0_intermediate_1[v_i, v_j]) + p_output0_intermediate_1[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv575[v_i // 8, v_j], T.Cast("uint32", v_i % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv576[v_i // 32, v_j] + for i0, i1, i2, k in T.grid(1, 1, 4096, 11008): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv574[v_i0, v_i1, v_k], p_output0_intermediate_1[v_k, v_i2]) + T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv574[v_i0, v_i1, v_k] * p_output0_intermediate_1[v_k, v_i2] + for ax0, ax1, ax2 in T.grid(1, 1, 4096): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(lv570[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, 
v_ax2]) + T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) + p_output0_intermediate[v_ax0, v_ax1, v_ax2] = lv570[v_ax0, v_ax1, v_ax2] + var_matmul_intermediate[v_ax0, v_ax1, v_ax2] + + @I.ir_module + class Expected: + @T.prim_func + def main(lv575: T.Buffer((1376, 4096), "uint32"), lv576: T.Buffer((344, 4096), "float16"), lv574: T.Buffer((1, 1, 11008), "float16"), lv570: T.Buffer((1, 1, 4096), "float16"), p_output0_intermediate: T.Buffer((1, 1, 4096), "float16")): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + var_matmul_intermediate_local = T.alloc_buffer((1, 1, 4096), "float16", scope="local") + var_matmul_intermediate_rf_local = T.alloc_buffer((16, 1, 1, 4096), "float16", scope="local") + for ax0_fused_0 in T.thread_binding(256, thread="blockIdx.x"): + for ax0_fused_1 in T.thread_binding(16, thread="threadIdx.x"): + for ax1_0_fused_1 in T.thread_binding(16, thread="threadIdx.y"): + with T.block("matmul_rf_init"): + vax1_0_fused_1 = T.axis.spatial(16, ax1_0_fused_1) + v0 = T.axis.spatial(4096, ax0_fused_0 * 16 + ax0_fused_1) + T.reads() + T.writes(var_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0]) + var_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0] = T.float16(0) + for ax1_0_fused_0, ax1_1 in T.grid(86, 8): + with T.block("matmul_rf_update"): + vax1_0_fused_1 = T.axis.spatial(16, ax1_0_fused_1) + v0 = T.axis.spatial(4096, ax0_fused_0 * 16 + ax0_fused_1) + vax1_0_fused_0, vax1_1 = T.axis.remap("RR", [ax1_0_fused_0, ax1_1]) + T.reads(var_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0], lv574[0, 0, vax1_0_fused_0 * 128 + vax1_0_fused_1 * 8 + vax1_1], lv575[(vax1_0_fused_0 * 128 + vax1_0_fused_1 * 8 + vax1_1) // 8, v0], lv576[(vax1_0_fused_0 * 128 + vax1_0_fused_1 * 8 + vax1_1) // 32, v0]) + T.writes(var_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0]) + var_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0] = var_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0] + lv574[0, 0, vax1_0_fused_0 * 128 + 
vax1_0_fused_1 * 8 + vax1_1] * ((T.Cast("float16", T.bitwise_and(T.shift_right(lv575[(vax1_0_fused_0 * 128 + vax1_0_fused_1 * 8 + vax1_1) // 8, v0], T.Cast("uint32", (vax1_0_fused_0 * 128 + vax1_0_fused_1 * 8 + vax1_1) % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv576[(vax1_0_fused_0 * 128 + vax1_0_fused_1 * 8 + vax1_1) // 32, v0]) + for ax1_fused in T.thread_binding(16, thread="threadIdx.x"): + for ax0 in T.thread_binding(16, thread="threadIdx.y"): + with T.block("matmul"): + vax1_0_fused_1 = T.axis.reduce(16, ax0) + v0 = T.axis.spatial(4096, ax0_fused_0 * 16 + ax1_fused) + T.reads(var_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0]) + T.writes(var_matmul_intermediate_local[0, 0, v0]) + with T.init(): + var_matmul_intermediate_local[0, 0, v0] = T.float16(0) + var_matmul_intermediate_local[0, 0, v0] = var_matmul_intermediate_local[0, 0, v0] + var_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0] + for ax0_fused in T.thread_binding(16, thread="threadIdx.x"): + with T.block("T_add"): + v0 = T.axis.spatial(4096, ax0_fused_0 * 16 + ax0_fused) + T.reads(lv570[0, 0, v0], var_matmul_intermediate_local[0, 0, v0]) + T.writes(p_output0_intermediate[0, 0, v0]) + p_output0_intermediate[0, 0, v0] = lv570[0, 0, v0] + var_matmul_intermediate_local[0, 0, v0] + # fmt: on + + target = Target("nvidia/geforce-rtx-3090-ti") + with target: + mod = dl.ApplyDefaultSchedule(dl.gpu.DecodeGEMV())(Module) # pylint: disable=not-callable + assert_structural_equal(mod, Expected) + + +def test_spatial_inner_broadcasting(): + # fmt: off + @I.ir_module + class Module: + @T.prim_func + def main(A: T.Buffer((256, 256), "float32"), B: T.Buffer((256, 256), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + temp_local = T.alloc_buffer((256,)) + for j in T.serial(256): + for k in T.serial(256): + with T.block("sum"): + vj, vk = T.axis.remap("SR", [j, k]) + T.reads(A[vk, vj]) + T.writes(temp_local[vj]) + with T.init(): + temp_local[vj] = T.float32(0) + temp_local[vj] = 
temp_local[vj] + A[vk, vj] + for i, j in T.grid(256, 256): + with T.block("add"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(temp_local[vj]) + T.writes(B[vi, vj]) + B[vi, vj] = A[vi, vj] + temp_local[vj] + + @I.ir_module + class Expected: + @T.prim_func + def main(A: T.Buffer((256, 256), "float32"), B: T.Buffer((256, 256), "float32")): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + temp_local_shared = T.alloc_buffer((256,), scope="shared") + temp_local_rf_local = T.alloc_buffer((16, 256), scope="local") + for ax0_fused_0 in T.thread_binding(16, thread="blockIdx.x"): + for ax0_fused_1 in T.thread_binding(16, thread="threadIdx.x"): + for ax1_fused_1 in T.thread_binding(16, thread="threadIdx.y"): + with T.block("sum_rf_init"): + vax1_fused_1 = T.axis.spatial(16, ax1_fused_1) + v0 = T.axis.spatial(256, ax0_fused_0 * 16 + ax0_fused_1) + T.reads() + T.writes(temp_local_rf_local[vax1_fused_1, v0]) + temp_local_rf_local[vax1_fused_1, v0] = T.float32(0) + for ax1_fused_0, u in T.grid(16, 1): + with T.block("sum_rf_update"): + vax1_fused_1 = T.axis.spatial(16, ax1_fused_1) + v0 = T.axis.spatial(256, ax0_fused_0 * 16 + ax0_fused_1) + vax1_fused_0 = T.axis.reduce(16, ax1_fused_0) + T.reads(temp_local_rf_local[vax1_fused_1, v0], A[vax1_fused_0 * 16 + vax1_fused_1, v0]) + T.writes(temp_local_rf_local[vax1_fused_1, v0]) + temp_local_rf_local[vax1_fused_1, v0] = temp_local_rf_local[vax1_fused_1, v0] + A[vax1_fused_0 * 16 + vax1_fused_1, v0] + for ax1_fused in T.thread_binding(16, thread="threadIdx.x"): + for ax0 in T.thread_binding(16, thread="threadIdx.y"): + with T.block("sum"): + vax1_fused_1 = T.axis.reduce(16, ax0) + v0 = T.axis.spatial(256, ax0_fused_0 * 16 + ax1_fused) + T.reads(temp_local_rf_local[vax1_fused_1, v0]) + T.writes(temp_local_shared[v0]) + with T.init(): + temp_local_shared[v0] = T.float32(0) + temp_local_shared[v0] = temp_local_shared[v0] + temp_local_rf_local[vax1_fused_1, v0] + for ax0_ax1_fused_0 in range(16): + for 
ax0_ax1_fused_1 in T.thread_binding(16, thread="threadIdx.x"): + for ax0_ax1_fused_2 in T.thread_binding(16, thread="threadIdx.y"): + with T.block("add"): + v0 = T.axis.spatial(256, (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1 * 16 + ax0_ax1_fused_2) // 16) + v1 = T.axis.spatial(256, ax0_fused_0 * 16 + (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1 * 16 + ax0_ax1_fused_2) % 16) + T.reads(temp_local_shared[v1]) + T.writes(B[v0, v1]) + B[v0, v1] = A[v0, v1] + temp_local_shared[v1] + # fmt: on + + target = Target("nvidia/geforce-rtx-3090-ti") + with target: + mod = dl.ApplyDefaultSchedule(dl.gpu.DecodeGEMV())(Module) # pylint: disable=not-callable + assert_structural_equal(mod, Expected) + + +def test_reduction_inner_no_broadcasting(): + # fmt: off + @I.ir_module + class Module: + @T.prim_func + def main(A: T.Buffer((256, 256), "float32"), B: T.Buffer((256,), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + temp_local = T.alloc_buffer((256,)) + for i in T.serial(256): + for k in T.serial(256): + with T.block("sum"): + vi, vk = T.axis.remap("SR", [i, k]) + T.reads(A[vi, vk]) + T.writes(temp_local[vi]) + with T.init(): + temp_local[vi] = T.float32(0) + temp_local[vi] = temp_local[vi] + A[vi, vk] + for i in T.grid(256): + with T.block("add"): + vi = T.axis.remap("S", [i]) + T.reads(temp_local[vi]) + T.writes(B[vi,]) + B[vi] = temp_local[vi] + T.float32(1) + + @I.ir_module + class Expected: + @T.prim_func + def main(A: T.Buffer((256, 256), "float32"), B: T.Buffer((256,), "float32")): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + # with T.block("root"): + temp_local_local = T.alloc_buffer((256,), scope="local") + temp_local_rf_local = T.alloc_buffer((256, 256), scope="local") + for ax0_fused in T.thread_binding(256, thread="blockIdx.x"): + for ax1_fused_1 in T.thread_binding(256, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): + with T.block("sum_rf_init"): + vax1_fused_1, v0 = 
T.axis.remap("SS", [ax1_fused_1, ax0_fused]) + T.reads() + T.writes(temp_local_rf_local[vax1_fused_1, v0]) + temp_local_rf_local[vax1_fused_1, v0] = T.float32(0) + for ax1_fused_0, u in T.grid(1, 1): + with T.block("sum_rf_update"): + vax1_fused_1, v0, vax1_fused_0 = T.axis.remap("SSR", [ax1_fused_1, ax0_fused, ax1_fused_0]) + T.reads(temp_local_rf_local[vax1_fused_1, v0], A[v0, vax1_fused_0 * 256 + vax1_fused_1]) + T.writes(temp_local_rf_local[vax1_fused_1, v0]) + temp_local_rf_local[vax1_fused_1, v0] = temp_local_rf_local[vax1_fused_1, v0] + A[v0, vax1_fused_0 * 256 + vax1_fused_1] + for ax1_fused in range(1): + for ax0 in T.thread_binding(256, thread="threadIdx.x"): + with T.block("sum"): + vax1_fused_1, v0 = T.axis.remap("RS", [ax0, ax0_fused]) + T.reads(temp_local_rf_local[vax1_fused_1, v0]) + T.writes(temp_local_local[v0]) + with T.init(): + temp_local_local[v0] = T.float32(0) + temp_local_local[v0] = temp_local_local[v0] + temp_local_rf_local[vax1_fused_1, v0] + with T.block("add"): + v0 = T.axis.spatial(256, ax0_fused) + T.reads(temp_local_local[v0]) + T.writes(B[v0]) + B[v0] = temp_local_local[v0] + T.float32(1) + # fmt: on + + target = Target("nvidia/geforce-rtx-3090-ti") + with target: + mod = dl.ApplyDefaultSchedule(dl.gpu.DecodeGEMV())(Module) # pylint: disable=not-callable + assert_structural_equal(mod, Expected) + + if __name__ == "__main__": - test_decode_gemv_1() - test_decode_gemv_2() - test_decode_gemv_3() - test_decode_gemv_4() - test_decode_gemv_sigmoid() - test_decode_gemv_1_fp32() - test_reduction_no_spatial() + tvm.testing.main()