From 4c04c40d26e649831b1df35a352af492fb63f17d Mon Sep 17 00:00:00 2001
From: Ruihang Lai
Date: Sun, 16 Jul 2023 01:42:03 -0400
Subject: [PATCH] [Unity][Dlight] Fix decode-GeMV rule when spatial-inner
 without broadcasting

This PR fixes a bug in the previous decode-GeMV dlight scheduling rule.
Previously, when the inner dimension of the largest tensor is spatial,
the fused epilogue block ended up not being bound to any thread axis.
This is wrong: it generates invalid GPU code that produces incorrect
numerical results.

The cause is that after doing reverse-compute-at of the epilogue block,
there is at least one remaining spatial axis, and that axis is supposed
to be bound to threadIdx.

This PR fixes the issue and adds three test cases, covering both the
reduction-inner and spatial-inner cases, with and without broadcasting.
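To illustrate the core of the fix, here is a minimal standalone sketch
(the prim_func, buffer, and block names are made up for illustration;
the real rule operates on the loops that remain after the
reverse-compute-at of the epilogue):

    import tvm
    from tvm.script import tir as T

    @T.prim_func
    def epilogue(A: T.Buffer((4096,), "float16"),
                 B: T.Buffer((4096,), "float16")):
        # An element-wise epilogue with a single spatial loop.
        for i in T.serial(4096):
            with T.block("ew"):
                vi = T.axis.remap("S", [i])
                B[vi] = A[vi] + T.float16(1)

    sch = tvm.tir.Schedule(epilogue)
    (i,) = sch.get_loops(sch.get_block("ew"))
    bx, tx = sch.split(i, factors=[None, 16])
    sch.bind(bx, "blockIdx.x")
    # This thread binding is what the buggy path failed to emit for the
    # fused epilogue block.
    sch.bind(tx, "threadIdx.x")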
---
 python/tvm/dlight/gpu/decode_gemv.py        |   6 +-
 tests/python/dlight/test_gpu_decode_gemv.py | 230 +++++++++++++++++++-
 2 files changed, 228 insertions(+), 8 deletions(-)

diff --git a/python/tvm/dlight/gpu/decode_gemv.py b/python/tvm/dlight/gpu/decode_gemv.py
index afcfdb30206b..1aa5d68fc53e 100644
--- a/python/tvm/dlight/gpu/decode_gemv.py
+++ b/python/tvm/dlight/gpu/decode_gemv.py
@@ -233,7 +233,11 @@ def _sch_inner_spatial(
         _, *s = sch.get_loops(epilogue)  # pylint: disable=invalid-name
         _, tx, ty = sch.split(sch.fuse(*s), factors=[None, len_tx, len_ty])
         sch.bind(tx, "threadIdx.x")
-        sch.bind(ty, "threadIdx.x")
+        sch.bind(ty, "threadIdx.y")
     else:
+        # The epilogue is element-wise without broadcasting.
+        # Thus the remaining spatial part should be bound to tx.
         sch.set_scope(block, 0, "local")
+        _, *s = sch.get_loops(epilogue)  # pylint: disable=invalid-name
+        sch.bind(sch.fuse(*s), "threadIdx.x")
     # pylint: enable=invalid-name
diff --git a/tests/python/dlight/test_gpu_decode_gemv.py b/tests/python/dlight/test_gpu_decode_gemv.py
index 7b19e6b7f811..971f5f4d09ba 100644
--- a/tests/python/dlight/test_gpu_decode_gemv.py
+++ b/tests/python/dlight/test_gpu_decode_gemv.py
@@ -15,6 +15,8 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=missing-docstring,line-too-long,invalid-name,too-few-public-methods,too-many-locals
+
+import tvm.testing
 from tvm import dlight as dl
 from tvm.ir import assert_structural_equal
 from tvm.script import ir as I
@@ -489,11 +491,225 @@ def main(A: T.Buffer((1, 1, 4096), "float16"), B: T.Buffer((4096,), "float16"),
     assert_structural_equal(mod, After)
 
 
+def test_spatial_inner_no_broadcasting():
+    # fmt: off
+    @I.ir_module
+    class Module:
+        @T.prim_func
+        def main(lv575: T.Buffer((1376, 4096), "uint32"), lv576: T.Buffer((344, 4096), "float16"), lv574: T.Buffer((1, 1, 11008), "float16"), lv570: T.Buffer((1, 1, 4096), "float16"), p_output0_intermediate: T.Buffer((1, 1, 4096), "float16")):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            p_output0_intermediate_1 = T.alloc_buffer((11008, 4096), "float16")
+            var_matmul_intermediate = T.alloc_buffer((1, 1, 4096), "float16")
+            for i, j in T.grid(11008, 4096):
+                with T.block("decode"):
+                    v_i, v_j = T.axis.remap("SS", [i, j])
+                    T.reads(lv575[v_i // 8, v_j], lv576[v_i // 32, v_j])
+                    T.writes(p_output0_intermediate_1[v_i, v_j])
+                    p_output0_intermediate_1[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv575[v_i // 8, v_j], T.Cast("uint32", v_i % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv576[v_i // 32, v_j]
+            for i0, i1, i2, k in T.grid(1, 1, 4096, 11008):
+                with T.block("matmul"):
+                    v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
+                    T.reads(lv574[v_i0, v_i1, v_k], p_output0_intermediate_1[v_k, v_i2])
+                    T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])
+                    with T.init():
+                        var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
+                    var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv574[v_i0, v_i1, v_k] * p_output0_intermediate_1[v_k, v_i2]
+            for ax0, ax1, ax2 in T.grid(1, 1, 4096):
+                with T.block("T_add"):
+                    v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+                    T.reads(lv570[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2])
+                    T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2])
+                    p_output0_intermediate[v_ax0, v_ax1, v_ax2] = lv570[v_ax0, v_ax1, v_ax2] + var_matmul_intermediate[v_ax0, v_ax1, v_ax2]
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def main(lv575: T.Buffer((1376, 4096), "uint32"), lv576: T.Buffer((344, 4096), "float16"), lv574: T.Buffer((1, 1, 11008), "float16"), lv570: T.Buffer((1, 1, 4096), "float16"), p_output0_intermediate: T.Buffer((1, 1, 4096), "float16")):
+            T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
+            var_matmul_intermediate_local = T.alloc_buffer((1, 1, 4096), "float16", scope="local")
+            var_matmul_intermediate_rf_local = T.alloc_buffer((16, 1, 1, 4096), "float16", scope="local")
+            for ax0_fused_0 in T.thread_binding(256, thread="blockIdx.x"):
+                for ax0_fused_1 in T.thread_binding(16, thread="threadIdx.x"):
+                    for ax1_0_fused_1 in T.thread_binding(16, thread="threadIdx.y"):
+                        with T.block("matmul_rf_init"):
+                            vax1_0_fused_1 = T.axis.spatial(16, ax1_0_fused_1)
+                            v0 = T.axis.spatial(4096, ax0_fused_0 * 16 + ax0_fused_1)
+                            T.reads()
+                            T.writes(var_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0])
+                            var_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0] = T.float16(0)
+                        for ax1_0_fused_0, ax1_1 in T.grid(86, 8):
+                            with T.block("matmul_rf_update"):
+                                vax1_0_fused_1 = T.axis.spatial(16, ax1_0_fused_1)
+                                v0 = T.axis.spatial(4096, ax0_fused_0 * 16 + ax0_fused_1)
+                                vax1_0_fused_0, vax1_1 = T.axis.remap("RR", [ax1_0_fused_0, ax1_1])
+                                T.reads(var_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0], lv574[0, 0, vax1_0_fused_0 * 128 + vax1_0_fused_1 * 8 + vax1_1], lv575[(vax1_0_fused_0 * 128 + vax1_0_fused_1 * 8 + vax1_1) // 8, v0], lv576[(vax1_0_fused_0 * 128 + vax1_0_fused_1 * 8 + vax1_1) // 32, v0])
+                                T.writes(var_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0])
+                                var_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0] = var_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0] + lv574[0, 0, vax1_0_fused_0 * 128 + vax1_0_fused_1 * 8 + vax1_1] * ((T.Cast("float16", T.bitwise_and(T.shift_right(lv575[(vax1_0_fused_0 * 128 + vax1_0_fused_1 * 8 + vax1_1) // 8, v0], T.Cast("uint32", (vax1_0_fused_0 * 128 + vax1_0_fused_1 * 8 + vax1_1) % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv576[(vax1_0_fused_0 * 128 + vax1_0_fused_1 * 8 + vax1_1) // 32, v0])
+                for ax1_fused in T.thread_binding(16, thread="threadIdx.x"):
+                    for ax0 in T.thread_binding(16, thread="threadIdx.y"):
+                        with T.block("matmul"):
+                            vax1_0_fused_1 = T.axis.reduce(16, ax0)
+                            v0 = T.axis.spatial(4096, ax0_fused_0 * 16 + ax1_fused)
+                            T.reads(var_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0])
+                            T.writes(var_matmul_intermediate_local[0, 0, v0])
+                            with T.init():
+                                var_matmul_intermediate_local[0, 0, v0] = T.float16(0)
+                            var_matmul_intermediate_local[0, 0, v0] = var_matmul_intermediate_local[0, 0, v0] + var_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0]
+                for ax0_fused in T.thread_binding(16, thread="threadIdx.x"):
+                    with T.block("T_add"):
+                        v0 = T.axis.spatial(4096, ax0_fused_0 * 16 + ax0_fused)
+                        T.reads(lv570[0, 0, v0], var_matmul_intermediate_local[0, 0, v0])
+                        T.writes(p_output0_intermediate[0, 0, v0])
+                        p_output0_intermediate[0, 0, v0] = lv570[0, 0, v0] + var_matmul_intermediate_local[0, 0, v0]
+    # fmt: on
+
+    target = Target("nvidia/geforce-rtx-3090-ti")
+    with target:
+        mod = dl.ApplyDefaultSchedule(dl.gpu.DecodeGEMV())(Module)  # pylint: disable=not-callable
+    assert_structural_equal(mod, Expected)
+
+
+def test_spatial_inner_broadcasting():
+    # fmt: off
+    @I.ir_module
+    class Module:
+        @T.prim_func
+        def main(A: T.Buffer((256, 256), "float32"), B: T.Buffer((256, 256), "float32")):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            temp_local = T.alloc_buffer((256,))
+            for j in T.serial(256):
+                for k in T.serial(256):
+                    with T.block("sum"):
+                        vj, vk = T.axis.remap("SR", [j, k])
+                        T.reads(A[vk, vj])
+                        T.writes(temp_local[vj])
+                        with T.init():
+                            temp_local[vj] = T.float32(0)
+                        temp_local[vj] = temp_local[vj] + A[vk, vj]
+            for i, j in T.grid(256, 256):
+                with T.block("add"):
+                    vi, vj = T.axis.remap("SS", [i, j])
+                    T.reads(temp_local[vj])
+                    T.writes(B[vi, vj])
+                    B[vi, vj] = A[vi, vj] + temp_local[vj]
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def main(A: T.Buffer((256, 256), "float32"), B: T.Buffer((256, 256), "float32")):
+            T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
+            temp_local_shared = T.alloc_buffer((256,), scope="shared")
+            temp_local_rf_local = T.alloc_buffer((16, 256), scope="local")
+            for ax0_fused_0 in T.thread_binding(16, thread="blockIdx.x"):
+                for ax0_fused_1 in T.thread_binding(16, thread="threadIdx.x"):
+                    for ax1_fused_1 in T.thread_binding(16, thread="threadIdx.y"):
+                        with T.block("sum_rf_init"):
+                            vax1_fused_1 = T.axis.spatial(16, ax1_fused_1)
+                            v0 = T.axis.spatial(256, ax0_fused_0 * 16 + ax0_fused_1)
+                            T.reads()
+                            T.writes(temp_local_rf_local[vax1_fused_1, v0])
+                            temp_local_rf_local[vax1_fused_1, v0] = T.float32(0)
+                        for ax1_fused_0, u in T.grid(16, 1):
+                            with T.block("sum_rf_update"):
+                                vax1_fused_1 = T.axis.spatial(16, ax1_fused_1)
+                                v0 = T.axis.spatial(256, ax0_fused_0 * 16 + ax0_fused_1)
+                                vax1_fused_0 = T.axis.reduce(16, ax1_fused_0)
+                                T.reads(temp_local_rf_local[vax1_fused_1, v0], A[vax1_fused_0 * 16 + vax1_fused_1, v0])
+                                T.writes(temp_local_rf_local[vax1_fused_1, v0])
+                                temp_local_rf_local[vax1_fused_1, v0] = temp_local_rf_local[vax1_fused_1, v0] + A[vax1_fused_0 * 16 + vax1_fused_1, v0]
+                for ax1_fused in T.thread_binding(16, thread="threadIdx.x"):
+                    for ax0 in T.thread_binding(16, thread="threadIdx.y"):
+                        with T.block("sum"):
+                            vax1_fused_1 = T.axis.reduce(16, ax0)
+                            v0 = T.axis.spatial(256, ax0_fused_0 * 16 + ax1_fused)
+                            T.reads(temp_local_rf_local[vax1_fused_1, v0])
+                            T.writes(temp_local_shared[v0])
+                            with T.init():
+                                temp_local_shared[v0] = T.float32(0)
+                            temp_local_shared[v0] = temp_local_shared[v0] + temp_local_rf_local[vax1_fused_1, v0]
+                for ax0_ax1_fused_0 in range(16):
+                    for ax0_ax1_fused_1 in T.thread_binding(16, thread="threadIdx.x"):
+                        for ax0_ax1_fused_2 in T.thread_binding(16, thread="threadIdx.y"):
+                            with T.block("add"):
+                                v0 = T.axis.spatial(256, (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1 * 16 + ax0_ax1_fused_2) // 16)
+                                v1 = T.axis.spatial(256, ax0_fused_0 * 16 + (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1 * 16 + ax0_ax1_fused_2) % 16)
+                                T.reads(temp_local_shared[v1])
+                                T.writes(B[v0, v1])
+                                B[v0, v1] = A[v0, v1] + temp_local_shared[v1]
+    # fmt: on
+
+    target = Target("nvidia/geforce-rtx-3090-ti")
+    with target:
+        mod = dl.ApplyDefaultSchedule(dl.gpu.DecodeGEMV())(Module)  # pylint: disable=not-callable
+    assert_structural_equal(mod, Expected)
+
+
+def test_reduction_inner_no_broadcasting():
+    # fmt: off
+    @I.ir_module
+    class Module:
+        @T.prim_func
+        def main(A: T.Buffer((256, 256), "float32"), B: T.Buffer((256,), "float32")):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            temp_local = T.alloc_buffer((256,))
+            for i in T.serial(256):
+                for k in T.serial(256):
+                    with T.block("sum"):
+                        vi, vk = T.axis.remap("SR", [i, k])
+                        T.reads(A[vi, vk])
+                        T.writes(temp_local[vi])
+                        with T.init():
+                            temp_local[vi] = T.float32(0)
+                        temp_local[vi] = temp_local[vi] + A[vi, vk]
+            for i in T.grid(256):
+                with T.block("add"):
+                    vi = T.axis.remap("S", [i])
+                    T.reads(temp_local[vi])
+                    T.writes(B[vi])
+                    B[vi] = temp_local[vi] + T.float32(1)
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def main(A: T.Buffer((256, 256), "float32"), B: T.Buffer((256,), "float32")):
+            T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
+            # with T.block("root"):
+            temp_local_local = T.alloc_buffer((256,), scope="local")
+            temp_local_rf_local = T.alloc_buffer((256, 256), scope="local")
+            for ax0_fused in T.thread_binding(256, thread="blockIdx.x"):
+                for ax1_fused_1 in T.thread_binding(256, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
+                    with T.block("sum_rf_init"):
+                        vax1_fused_1, v0 = T.axis.remap("SS", [ax1_fused_1, ax0_fused])
+                        T.reads()
+                        T.writes(temp_local_rf_local[vax1_fused_1, v0])
+                        temp_local_rf_local[vax1_fused_1, v0] = T.float32(0)
+                    for ax1_fused_0, u in T.grid(1, 1):
+                        with T.block("sum_rf_update"):
+                            vax1_fused_1, v0, vax1_fused_0 = T.axis.remap("SSR", [ax1_fused_1, ax0_fused, ax1_fused_0])
+                            T.reads(temp_local_rf_local[vax1_fused_1, v0], A[v0, vax1_fused_0 * 256 + vax1_fused_1])
+                            T.writes(temp_local_rf_local[vax1_fused_1, v0])
+                            temp_local_rf_local[vax1_fused_1, v0] = temp_local_rf_local[vax1_fused_1, v0] + A[v0, vax1_fused_0 * 256 + vax1_fused_1]
+                for ax1_fused in range(1):
+                    for ax0 in T.thread_binding(256, thread="threadIdx.x"):
+                        with T.block("sum"):
+                            vax1_fused_1, v0 = T.axis.remap("RS", [ax0, ax0_fused])
+                            T.reads(temp_local_rf_local[vax1_fused_1, v0])
+                            T.writes(temp_local_local[v0])
+                            with T.init():
+                                temp_local_local[v0] = T.float32(0)
+                            temp_local_local[v0] = temp_local_local[v0] + temp_local_rf_local[vax1_fused_1, v0]
+                with T.block("add"):
+                    v0 = T.axis.spatial(256, ax0_fused)
+                    T.reads(temp_local_local[v0])
+                    T.writes(B[v0])
+                    B[v0] = temp_local_local[v0] + T.float32(1)
+    # fmt: on
+
+    target = Target("nvidia/geforce-rtx-3090-ti")
+    with target:
+        mod = dl.ApplyDefaultSchedule(dl.gpu.DecodeGEMV())(Module)  # pylint: disable=not-callable
+    assert_structural_equal(mod, Expected)
+ vax1_fused_1, v0 = T.axis.remap("RS", [ax0, ax0_fused]) + T.reads(temp_local_rf_local[vax1_fused_1, v0]) + T.writes(temp_local_local[v0]) + with T.init(): + temp_local_local[v0] = T.float32(0) + temp_local_local[v0] = temp_local_local[v0] + temp_local_rf_local[vax1_fused_1, v0] + with T.block("add"): + v0 = T.axis.spatial(256, ax0_fused) + T.reads(temp_local_local[v0]) + T.writes(B[v0]) + B[v0] = temp_local_local[v0] + T.float32(1) + # fmt: on + + target = Target("nvidia/geforce-rtx-3090-ti") + with target: + mod = dl.ApplyDefaultSchedule(dl.gpu.DecodeGEMV())(Module) # pylint: disable=not-callable + assert_structural_equal(mod, Expected) + + if __name__ == "__main__": - test_decode_gemv_1() - test_decode_gemv_2() - test_decode_gemv_3() - test_decode_gemv_4() - test_decode_gemv_sigmoid() - test_decode_gemv_1_fp32() - test_reduction_no_spatial() + tvm.testing.main()
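Note (not part of the patch): beyond the structural checks above, the
numerical claim in the message can be sanity-checked end to end. A
hedged sketch, assuming a CUDA-enabled TVM build and an NVIDIA GPU; it
reuses the reduction-inner workload shape from the tests, with block
read/write regions left to be inferred:

    import numpy as np
    import tvm
    from tvm import dlight as dl
    from tvm.script import ir as I
    from tvm.script import tir as T
    from tvm.target import Target

    @I.ir_module
    class Module:
        @T.prim_func
        def main(A: T.Buffer((256, 256), "float32"), B: T.Buffer((256,), "float32")):
            T.func_attr({"tir.noalias": T.bool(True)})
            temp = T.alloc_buffer((256,))
            # Row sum followed by an element-wise epilogue, as in the tests.
            for i, k in T.grid(256, 256):
                with T.block("sum"):
                    vi, vk = T.axis.remap("SR", [i, k])
                    with T.init():
                        temp[vi] = T.float32(0)
                    temp[vi] = temp[vi] + A[vi, vk]
            for i in T.serial(256):
                with T.block("add"):
                    vi = T.axis.remap("S", [i])
                    B[vi] = temp[vi] + T.float32(1)

    with Target("nvidia/geforce-rtx-3090-ti"):
        mod = dl.ApplyDefaultSchedule(dl.gpu.DecodeGEMV())(Module)

    # Build and compare against a NumPy reference.
    lib = tvm.build(mod, target="cuda")
    dev = tvm.cuda(0)
    a = tvm.nd.array(np.random.rand(256, 256).astype("float32"), dev)
    b = tvm.nd.array(np.zeros((256,), "float32"), dev)
    lib["main"](a, b)
    np.testing.assert_allclose(b.numpy(), a.numpy().sum(axis=1) + 1.0, rtol=1e-5)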