diff --git a/python/tvm/dlight/gpu/gemv.py b/python/tvm/dlight/gpu/gemv.py index ed32ea77858f..cbef6235c098 100644 --- a/python/tvm/dlight/gpu/gemv.py +++ b/python/tvm/dlight/gpu/gemv.py @@ -208,8 +208,17 @@ def apply( # pylint: disable=too-many-locals,too-many-branches,too-many-return- elif is_inner_reduction: self.sch_inner_reduction(sch, target, block, vector_input_buffers, epilogue) return sch + elif target.kind.name == "opencl" and "android" in str(target.host): + ret = self.sch_outer_reduction(sch, target, block, vector_input_buffers, epilogue) + if ret is None: + return self.sch_outer_reduction_fallback( + sch, target, block, vector_input_buffers, epilogue + ) + return sch else: - return self.sch_outer_reduction(sch, target, block, vector_input_buffers, epilogue) + return self.sch_outer_reduction_fallback( + sch, target, block, vector_input_buffers, epilogue + ) def sch_inner_reduction( # pylint: disable=too-many-arguments, invalid-name, unused-argument self, @@ -486,7 +495,7 @@ def apply( LOAD_V_SHARED = False LOAD_V_VEC = -1 UNROLL = 8 - TS, TR = 2, 32 + TS, TR = 2, 64 elif target.kind.name == "vulkan": VEC_C = 4 LOAD_V_SHARED = True @@ -553,6 +562,191 @@ def sch_outer_reduction( # pylint: disable=too-many-arguments, invalid-name, un epilogue_info: Optional[BlockInfo], ): """Schedule the outer reduction block.""" + + def get_max_factor(n, factors): + factors = sorted(factors, reverse=True) + for factor in factors: + if n % factor == 0: + return factor + return 1 + + def apply( + sch: tir.Schedule, + gemv, + TAG_S, + TAG_R, + TS, + TR, + SCALE_PACK, + DEC_PACK, + VEC_LOAD, + VEC_C, + LOAD_V_SHARED, + LOAD_V_VEC, + UNROLL, + LOAD_V_TILE, + ): + # rfactor: reduce to tx * vec_c + batch, s, r, c = sch.get_loops(block=gemv) + s = sch.fuse(batch, s) + r = sch.fuse(r, c) + bx, ts = sch.split(s, factors=[None, TS], preserve_unit_iters=True) + r, v_tile, tr, tile_r, vec_c = sch.split( + r, factors=[None, LOAD_V_TILE, TR, SCALE_PACK, DEC_PACK], preserve_unit_iters=True + ) + sch.reorder(bx, ts, r, v_tile, tile_r, tr, vec_c) + tr_vec_c = sch.fuse(tr, vec_c) + rf = sch.rfactor(tr_vec_c, 0) + + # rfactor: reduce to tx + bx, ts, tr_vec_c = sch.get_loops(block=gemv) + tr, vec_c = sch.split(tr_vec_c, factors=[TR, None], preserve_unit_iters=True) + rf2 = sch.rfactor(tr, 0) + + # bind, vectorize compute + bx, ts, r, v_tile, tile_r, tr_vec_c = sch.get_loops(block=rf) + tr, vec_c = sch.split(tr_vec_c, factors=[TR, DEC_PACK]) + sch.reorder(bx, ts, tr, r, v_tile, tile_r, vec_c) + # sch.bind(batch, "blockIdx.z") + sch.bind(bx, "blockIdx.x") + sch.bind(ts, "threadIdx.x") + sch.bind(tr, "threadIdx.y") + sch.vectorize(vec_c) + + # decompose independent scale read to outer loop + block_rf_stmt = sch.get(rf) + if len(block_rf_stmt.reads) >= 3: + As_local = sch.cache_read(rf, read_buffer_index=2, storage_scope="local") + sch.compute_at(As_local, v_tile, preserve_unit_loops=True) + # *tile_thr, vec_s = sch.get_loops(block=As_local) + # sch.vectorize(vec_s) + + Aq_local = sch.cache_read(rf, read_buffer_index=1, storage_scope="local") + sch.compute_at(Aq_local, tile_r, preserve_unit_loops=True) + # *tile_thr, vec_s = sch.get_loops(block=Aq_local) + # sch.vectorize(vec_s) + + if LOAD_V_SHARED: + V_shared = sch.cache_read(rf, read_buffer_index=0, storage_scope="shared") + sch.compute_at(V_shared, r, preserve_unit_loops=True) + l = sch.get_loops(block=V_shared)[-1] + _, v_tile, tx, ty, vec = sch.split( + l, factors=[None, LOAD_V_TILE, TS, TR, LOAD_V_VEC], preserve_unit_iters=True + ) + sch.bind(ty, "threadIdx.y") + 
sch.bind(tx, "threadIdx.x") + sch.vectorize(vec) + + # reduce tile_s * tr * vec to tile_s * tr + sch.reverse_compute_at(rf2, loop=bx, preserve_unit_loops=True) + tr, vec_c, ts = sch.get_loops(block=rf2)[1:] + sch.reorder(ts, tr, vec_c) + sch.bind(ts, "threadIdx.x") + sch.bind(tr, "threadIdx.y") + + # reduce tile_s * tr to tile_s + sch.reverse_compute_at(gemv, loop=bx, preserve_unit_loops=True) + tr, ts = sch.get_loops(block=gemv)[1:] + sch.reorder(ts, tr) + sch.bind(ts, "threadIdx.x") + sch.bind(tr, "threadIdx.y") + + sch.decompose_reduction(rf, loop=sch.get_loops(block=rf)[2]) + sch.decompose_reduction(rf2, loop=sch.get_loops(block=rf2)[-1]) + + sch.set_scope(rf, buffer_index=0, storage_scope="local") + sch.set_scope(rf2, buffer_index=0, storage_scope="local") + + sch.annotate( + block_or_loop=sch.get_loops(rf2)[3], + ann_key="pragma_auto_unroll_max_step", + ann_val=DEC_PACK, + ) + sch.annotate( + block_or_loop=sch.get_loops(rf2)[3], ann_key="pragma_unroll_explicit", ann_val=1 + ) + + # Schedule epilogue + if epilogue_info is not None: + epilogue = epilogue_info.block_rv + if is_broadcast_epilogue(sch, block, epilogue): + sch.reverse_compute_at(epilogue, bx) + sch.set_scope(block, 0, "shared") + _, _, *s = sch.get_loops(epilogue) # pylint: disable=invalid-name + _, tx = sch.split(sch.fuse(*s), factors=[None, TX]) + sch.bind(tx, "threadIdx.x") + else: + sch.reverse_compute_at(epilogue, bx, preserve_unit_loops=True) + ts_tile_s = sch.fuse(*sch.get_loops(epilogue)[1:]) + ts_tile_s = sch.get_loops(epilogue)[-1] + ts, _ = sch.split(ts_tile_s, factors=[TS, None], preserve_unit_iters=True) + sch.bind(ts, "threadIdx.x") + sch.set_scope(block, 0, "local") + return sch + + # Specify the `len_tx` and `len_ty` according to the loop extent + batch, s, r, c = sch.get_loops(block=block) + _, len_s, len_r, len_c = ( + get_extent(sch, batch), + get_extent(sch, s), + get_extent(sch, r), + get_extent(sch, c), + ) + + TAG_S, TAG_R = "threadIdx.x", "threadIdx.y" + VEC_C = 1 + UNROLL = 4 + TS, TR = 64, 4 + DEC_PACK = 8 + SCALE_PACK = 4 + LOAD_V_SHARED = False + LOAD_V_VEC = 4 + LOAD_V_TILE = 8 + + if LOAD_V_SHARED is False: + LOAD_V_TILE = 1 + + if not isinstance(len_r, int): + return None + + if isinstance(len_s, int) and len_s > 32000: + return None + + _, TILE_R = ( + 1, + len_c + if len_c > 1 + else max(get_max_factor(len_r, [TR * 1, TR * 2, TR * 4, TR * 8]) // TR, 1), + ) + LOAD_V_VEC = min(get_max_factor(TILE_R, [1, 2, 4, 8]), LOAD_V_VEC) + VEC_LOAD = 1 + + return apply( + sch, + gemv=block, + TAG_S=TAG_S, + TAG_R=TAG_R, + TS=TS, + TR=TR, + SCALE_PACK=SCALE_PACK, + DEC_PACK=DEC_PACK, + VEC_LOAD=VEC_LOAD, + VEC_C=VEC_C, + LOAD_V_SHARED=LOAD_V_SHARED, + LOAD_V_VEC=LOAD_V_VEC, + UNROLL=UNROLL, + LOAD_V_TILE=LOAD_V_TILE, + ) + + def sch_outer_reduction_fallback( # pylint: disable=too-many-arguments, invalid-name, unused-argument + self, + sch: tir.Schedule, + target: Target, + block: tir.schedule.BlockRV, + vector_input_buffers: List[tir.Buffer], + epilogue_info: Optional[BlockInfo], + ): + """Schedule the outer reduction block.""" # NOTE: Only Android is supported so far if not (target.kind.name == "opencl" and "android" in str(target.host)): return None diff --git a/python/tvm/dlight/gpu/matmul.py b/python/tvm/dlight/gpu/matmul.py index ed81b7f6881f..f4ef1f50448b 100644 --- a/python/tvm/dlight/gpu/matmul.py +++ b/python/tvm/dlight/gpu/matmul.py @@ -777,7 +777,7 @@ def get_configs(self, target: Target) -> Config: elif target.kind.name == "opencl" and "android" in str(target.host): return Matmul.Config( 
block_size_x=8, - block_size_y=8, + block_size_y=16, vthread_x=1, vthread_y=1, micro_size_x=8, diff --git a/tests/python/dlight/test_gpu_gemv.py b/tests/python/dlight/test_gpu_gemv.py index 0fd7f791599f..4aae617654d2 100644 --- a/tests/python/dlight/test_gpu_gemv.py +++ b/tests/python/dlight/test_gpu_gemv.py @@ -732,77 +732,331 @@ def expected( T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) # with T.block("root"): var_matmul_intermediate_local = T.alloc_buffer((1, 1, 4096), "float16", scope="local") - lv574_local = T.alloc_buffer((1, 1, 11008), "float16", scope="local") - for u_fused in T.thread_binding(1, thread="blockIdx.y"): - for ax0_fused_0 in T.thread_binding(32, thread="blockIdx.x"): - for ax0_fused_1 in T.thread_binding( - 64, - thread="threadIdx.x", - annotations={"pragma_auto_unroll_max_step": 8, "pragma_unroll_explicit": 1}, - ): - for ax0_fused_2_init in T.vectorized(2): - with T.block("matmul_init"): + var_matmul_intermediate_rf_local = T.alloc_buffer( + (32, 1, 1, 4096), "float16", scope="local" + ) + var_matmul_intermediate_rf_local_1 = T.alloc_buffer( + (4, 1, 1, 4096), "float16", scope="local" + ) + lv576_local = T.alloc_buffer((344, 4096), "float16", scope="local") + lv575_local = T.alloc_buffer((1376, 4096), "uint32", scope="local") + for u_fused_ax0_fused_fused_0 in T.thread_binding(64, thread="blockIdx.x"): + for u_fused_ax0_fused_fused_1 in T.thread_binding(64, thread="threadIdx.x"): + for ( + ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0_init + ) in T.thread_binding(4, thread="threadIdx.y"): + for ( + ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1_init + ) in T.vectorized(8): + with T.block("matmul_rf_init"): + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused = T.axis.spatial( + 32, + ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0_init * 8 + + ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1_init, + ) v0 = T.axis.spatial( - 4096, ax0_fused_0 * 128 + ax0_fused_1 * 2 + ax0_fused_2_init + 4096, u_fused_ax0_fused_fused_0 * 64 + u_fused_ax0_fused_fused_1 ) T.reads() - T.writes(var_matmul_intermediate_local[0, 0, v0]) - var_matmul_intermediate_local[0, 0, v0] = T.float16(0) - for ax1_0_fused_0, ax1_0_fused_1 in T.grid(344, 4): + T.writes( + var_matmul_intermediate_rf_local[ + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, + 0, + 0, + v0, + ] + ) + var_matmul_intermediate_rf_local[ + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, 0, 0, v0 + ] = T.float16(0) + for ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 in T.thread_binding( + 4, thread="threadIdx.y" + ): + for ax1_0_fused_ax1_1_fused_0, ax1_0_fused_ax1_1_fused_1 in T.grid(86, 1): for ax0, ax1 in T.grid(1, 1): - for ax2 in T.vectorized(8): - with T.block("lv574_local"): - v0, v1 = T.axis.remap("SS", [ax0, ax1]) - v2 = T.axis.spatial( - 11008, ax1_0_fused_0 * 32 + ax1_0_fused_1 * 8 + ax2 + with T.block("lv576_local"): + v0 = T.axis.spatial( + 344, + ax1_0_fused_ax1_1_fused_0 * 4 + + ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 + + ax0, + ) + v1 = T.axis.spatial( + 4096, + u_fused_ax0_fused_fused_0 * 64 + + u_fused_ax0_fused_fused_1 + + ax1, + ) + T.reads(lv576[v0, v1]) + T.writes(lv576_local[v0, v1]) + lv576_local[v0, v1] = lv576[v0, v1] + for ax1_0_fused_ax1_1_fused_3 in range(4): + for ax0, ax1 in T.grid(1, 1): + with T.block("lv575_local"): + v0 = T.axis.spatial( + 1376, + ax1_0_fused_ax1_1_fused_0 * 16 + + ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 + * 4 + + 
ax1_0_fused_ax1_1_fused_3 + + ax0, + ) + v1 = T.axis.spatial( + 4096, + u_fused_ax0_fused_fused_0 * 64 + + u_fused_ax0_fused_fused_1 + + ax1, + ) + T.reads(lv575[v0, v1]) + T.writes(lv575_local[v0, v1]) + lv575_local[v0, v1] = lv575[v0, v1] + for ( + ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1 + ) in T.vectorized(8): + with T.block("matmul_rf_update"): + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused = T.axis.spatial( + 32, + ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 + * 8 + + ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1, ) - T.reads(lv574[v0, v1, v2]) - T.writes(lv574_local[v0, v1, v2]) - lv574_local[v0, v1, v2] = lv574[v0, v1, v2] - for ax1_1 in range(8): - for ax0_fused_2 in T.vectorized(2): - with T.block("matmul_update"): v0 = T.axis.spatial( - 4096, ax0_fused_0 * 128 + ax0_fused_1 * 2 + ax0_fused_2 + 4096, + u_fused_ax0_fused_fused_0 * 64 + u_fused_ax0_fused_fused_1, ) - v1 = T.axis.reduce( - 11008, ax1_0_fused_0 * 32 + ax1_0_fused_1 * 8 + ax1_1 + ( + vax1_0_fused_ax1_1_fused_0, + vax1_0_fused_ax1_1_fused_1, + vax1_0_fused_ax1_1_fused_3, + ) = T.axis.remap( + "RRR", + [ + ax1_0_fused_ax1_1_fused_0, + ax1_0_fused_ax1_1_fused_1, + ax1_0_fused_ax1_1_fused_3, + ], ) T.reads( - var_matmul_intermediate_local[0, 0, v0], - lv574_local[0, 0, v1], - lv575[v1 // 8, v0], - lv576[v1 // 32, v0], + var_matmul_intermediate_rf_local[ + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, + 0, + 0, + v0, + ], + lv574[ + 0, + 0, + vax1_0_fused_ax1_1_fused_0 * 128 + + vax1_0_fused_ax1_1_fused_1 * 128 + + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused + // 8 + * 32 + + vax1_0_fused_ax1_1_fused_3 * 8 + + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused + % 8, + ], + lv575_local[ + vax1_0_fused_ax1_1_fused_0 * 16 + + vax1_0_fused_ax1_1_fused_1 * 16 + + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused + // 8 + * 4 + + vax1_0_fused_ax1_1_fused_3, + v0, + ], + lv576_local[ + vax1_0_fused_ax1_1_fused_0 * 4 + + vax1_0_fused_ax1_1_fused_1 * 4 + + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused + // 8 + + vax1_0_fused_ax1_1_fused_3 // 4, + v0, + ], + ) + T.writes( + var_matmul_intermediate_rf_local[ + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, + 0, + 0, + v0, + ], ) - T.writes(var_matmul_intermediate_local[0, 0, v0]) - var_matmul_intermediate_local[ - 0, 0, v0 - ] = var_matmul_intermediate_local[0, 0, v0] + lv574_local[ - 0, 0, v1 + var_matmul_intermediate_rf_local[ + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, + 0, + 0, + v0, + ] = var_matmul_intermediate_rf_local[ + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, + 0, + 0, + v0, + ] + lv574[ + 0, + 0, + vax1_0_fused_ax1_1_fused_0 * 128 + + vax1_0_fused_ax1_1_fused_1 * 128 + + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused + // 8 + * 32 + + vax1_0_fused_ax1_1_fused_3 * 8 + + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused + % 8, ] * ( ( T.Cast( "float16", T.bitwise_and( T.shift_right( - lv575[v1 // 8, v0], - T.Cast("uint32", v1 % 8) * T.uint32(4), + lv575_local[ + vax1_0_fused_ax1_1_fused_0 * 16 + + vax1_0_fused_ax1_1_fused_1 * 16 + + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused + // 8 + * 4 + + vax1_0_fused_ax1_1_fused_3, + v0, + ], + T.Cast( + "uint32", + ( + vax1_0_fused_ax1_1_fused_0 * 128 + + vax1_0_fused_ax1_1_fused_1 * 128 + + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused + // 8 + * 32 + + vax1_0_fused_ax1_1_fused_3 * 8 + + 
vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused + % 8 + ) + % 8, + ) + * T.uint32(4), ), T.uint32(15), ), ) - T.float16(7) ) - * lv576[v1 // 32, v0] + * lv576_local[ + vax1_0_fused_ax1_1_fused_0 * 4 + + vax1_0_fused_ax1_1_fused_1 * 4 + + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused + // 8 + + vax1_0_fused_ax1_1_fused_3 // 4, + v0, + ] ) - for ax0 in range(2): - with T.block("T_add"): - v0 = T.axis.spatial(4096, ax0_fused_0 * 128 + ax0_fused_1 * 2 + ax0) - T.reads(lv570[0, 0, v0], var_matmul_intermediate_local[0, 0, v0]) - T.writes(p_output0_intermediate[0, 0, v0]) - p_output0_intermediate[0, 0, v0] = ( - lv570[0, 0, v0] + var_matmul_intermediate_local[0, 0, v0] + for ax2 in T.thread_binding(64, thread="threadIdx.x"): + for ax0 in T.thread_binding(4, thread="threadIdx.y"): + with T.block("matmul_rf_init"): + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 = ( + T.axis.spatial(4, ax0) + ) + v0 = T.axis.spatial(4096, u_fused_ax0_fused_fused_0 * 64 + ax2) + T.reads() + T.writes( + var_matmul_intermediate_rf_local_1[ + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, + 0, + 0, + v0, + ] + ) + var_matmul_intermediate_rf_local_1[ + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, 0, 0, v0 + ] = T.float16(0) + for ax1 in T.serial( + 8, + annotations={"pragma_auto_unroll_max_step": 8, "pragma_unroll_explicit": 1}, + ): + with T.block("matmul_rf_update"): + ( + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1, + ) = T.axis.remap("SR", [ax0, ax1]) + v0 = T.axis.spatial(4096, u_fused_ax0_fused_fused_0 * 64 + ax2) + T.reads( + var_matmul_intermediate_rf_local_1[ + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, + 0, + 0, + v0, + ], + var_matmul_intermediate_rf_local[ + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 * 8 + + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1, + 0, + 0, + v0, + ], + ) + T.writes( + var_matmul_intermediate_rf_local_1[ + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, + 0, + 0, + v0, + ] + ) + var_matmul_intermediate_rf_local_1[ + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, + 0, + 0, + v0, + ] = ( + var_matmul_intermediate_rf_local_1[ + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, + 0, + 0, + v0, + ] + + var_matmul_intermediate_rf_local[ + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 * 8 + + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1, + 0, + 0, + v0, + ] ) + for ax1 in T.thread_binding(64, thread="threadIdx.x"): + for ax0 in T.thread_binding(4, thread="threadIdx.y"): + with T.block("matmul"): + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 = ( + T.axis.reduce(4, ax0) + ) + v0 = T.axis.spatial(4096, u_fused_ax0_fused_fused_0 * 64 + ax1) + T.reads( + var_matmul_intermediate_rf_local_1[ + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, + 0, + 0, + v0, + ] + ) + T.writes(var_matmul_intermediate_local[0, 0, v0]) + with T.init(): + var_matmul_intermediate_local[0, 0, v0] = T.float16(0) + var_matmul_intermediate_local[0, 0, v0] = ( + var_matmul_intermediate_local[0, 0, v0] + + var_matmul_intermediate_rf_local_1[ + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, + 0, + 0, + v0, + ] + ) + for ax0_fused_0 in T.thread_binding(64, thread="threadIdx.x"): + for ax0_fused_1 in range(1): + with T.block("T_add"): + v0 = T.axis.spatial( + 4096, u_fused_ax0_fused_fused_0 * 64 + ax0_fused_0 + 
ax0_fused_1 + ) + T.reads(lv570[0, 0, v0], var_matmul_intermediate_local[0, 0, v0]) + T.writes(p_output0_intermediate[0, 0, v0]) + p_output0_intermediate[0, 0, v0] = ( + lv570[0, 0, v0] + var_matmul_intermediate_local[0, 0, v0] + ) mod = tvm.IRModule({"main": before}) with Target("opencl", host="llvm -mtriple=aarch64-linux-android"): @@ -852,38 +1106,82 @@ def expected(p_lv612: T.handle, p_lv613: T.handle, lv1607: T.Buffer((T.int64(1), p_output0_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(1), v)) # with T.block("root"): var_matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(1), v), "float16", scope="local") - lv1607_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16", scope="local") - for u_fused in T.thread_binding(1, thread="blockIdx.y"): - for ax0_fused_0 in T.thread_binding((v + T.int64(63)) // T.int64(64), thread="blockIdx.x"): - for ax0_fused_1 in T.thread_binding(T.int64(64), thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 8, "pragma_unroll_explicit": 1}): - for ax0_fused_2_init in T.vectorized(T.int64(1)): - with T.block("matmul_init"): - v0 = T.axis.spatial(v, ax0_fused_0 * T.int64(64) + ax0_fused_1 + ax0_fused_2_init) - T.where(ax0_fused_0 * T.int64(64) + ax0_fused_1 + ax0_fused_2_init < v) + var_matmul_intermediate_rf_local = T.alloc_buffer((T.int64(32), T.int64(1), T.int64(1), v), "float16", scope="local") + var_matmul_intermediate_rf_local_1 = T.alloc_buffer((T.int64(4), T.int64(1), T.int64(1), v), "float16", scope="local") + lv613_local = T.alloc_buffer((T.int64(128), v), "float16", scope="local") + lv612_local = T.alloc_buffer((T.int64(512), v), "uint32", scope="local") + for u_fused_ax0_fused_fused_0 in T.thread_binding((v + T.int64(63)) // T.int64(64), thread="blockIdx.x"): + for u_fused_ax0_fused_fused_1 in T.thread_binding(T.int64(64), thread="threadIdx.x"): + for ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0_init in T.thread_binding(T.int64(4), thread="threadIdx.y"): + for ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1_init in T.vectorized(T.int64(8)): + with T.block("matmul_rf_init"): + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused = T.axis.spatial(T.int64(32), ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0_init * T.int64(8) + ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1_init) + v0 = T.axis.spatial(v, u_fused_ax0_fused_fused_0 * T.int64(64) + u_fused_ax0_fused_fused_1) + T.where(u_fused_ax0_fused_fused_0 * T.int64(64) + u_fused_ax0_fused_fused_1 < v) T.reads() - T.writes(var_matmul_intermediate_local[T.int64(0), T.int64(0), v0]) - var_matmul_intermediate_local[T.int64(0), T.int64(0), v0] = T.float16(0) - for ax1_0_fused_0, ax1_0_fused_1 in T.grid(T.int64(128), T.int64(4)): + T.writes(var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, T.int64(0), T.int64(0), v0]) + var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, T.int64(0), T.int64(0), v0] = T.float16(0) + for ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 in T.thread_binding(T.int64(4), thread="threadIdx.y"): + for ax1_0_fused_ax1_1_fused_0, ax1_0_fused_ax1_1_fused_1 in T.grid(T.int64(32), T.int64(1)): for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): - for ax2 in T.vectorized(T.int64(8)): - with T.block("lv1607_local"): - v0, v1 = T.axis.remap("SS", [ax0, ax1]) - v2 = T.axis.spatial(T.int64(4096), ax1_0_fused_0 * T.int64(32) + ax1_0_fused_1 * T.int64(8) + ax2) - T.reads(lv1607[v0, v1, 
v2]) - T.writes(lv1607_local[v0, v1, v2]) - lv1607_local[v0, v1, v2] = lv1607[v0, v1, v2] - for ax1_1 in range(T.int64(8)): - for ax0_fused_2 in T.vectorized(T.int64(1)): - with T.block("matmul_update"): - v0 = T.axis.spatial(v, ax0_fused_0 * T.int64(64) + ax0_fused_1 + ax0_fused_2) - v1 = T.axis.reduce(T.int64(4096), ax1_0_fused_0 * T.int64(32) + ax1_0_fused_1 * T.int64(8) + ax1_1) - T.where(ax0_fused_0 * T.int64(64) + ax0_fused_1 + ax0_fused_2 < v) - T.reads(var_matmul_intermediate_local[T.int64(0), T.int64(0), v0], lv1607_local[T.int64(0), T.int64(0), v1], lv612[v1 // T.int64(8), v0], lv613[v1 // T.int64(32), v0]) - T.writes(var_matmul_intermediate_local[T.int64(0), T.int64(0), v0]) - var_matmul_intermediate_local[T.int64(0), T.int64(0), v0] = var_matmul_intermediate_local[T.int64(0), T.int64(0), v0] + lv1607_local[T.int64(0), T.int64(0), v1] * ((T.Cast("float16", T.bitwise_and(T.shift_right(lv612[v1 // T.int64(8), v0], T.Cast("uint32", v1 % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv613[v1 // T.int64(32), v0]) + with T.block("lv613_local"): + v0 = T.axis.spatial(T.int64(128), ax1_0_fused_ax1_1_fused_0 * T.int64(4) + ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 + ax0) + v1 = T.axis.spatial(v, u_fused_ax0_fused_fused_0 * T.int64(64) + u_fused_ax0_fused_fused_1 + ax1) + T.where(u_fused_ax0_fused_fused_0 * T.int64(64) + u_fused_ax0_fused_fused_1 < v) + T.reads(lv613[v0, v1]) + T.writes(lv613_local[v0, v1]) + lv613_local[v0, v1] = lv613[v0, v1] + for ax1_0_fused_ax1_1_fused_3 in range(T.int64(4)): + for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): + with T.block("lv612_local"): + v0 = T.axis.spatial(T.int64(512), ax1_0_fused_ax1_1_fused_0 * T.int64(16) + ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 * T.int64(4) + ax1_0_fused_ax1_1_fused_3 + ax0) + v1 = T.axis.spatial(v, u_fused_ax0_fused_fused_0 * T.int64(64) + u_fused_ax0_fused_fused_1 + ax1) + T.where(u_fused_ax0_fused_fused_0 * T.int64(64) + u_fused_ax0_fused_fused_1 < v) + T.reads(lv612[v0, v1]) + T.writes(lv612_local[v0, v1]) + lv612_local[v0, v1] = lv612[v0, v1] + for ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1 in T.vectorized(T.int64(8)): + with T.block("matmul_rf_update"): + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused = T.axis.spatial(T.int64(32), ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 * T.int64(8) + ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1) + v0 = T.axis.spatial(v, u_fused_ax0_fused_fused_0 * T.int64(64) + u_fused_ax0_fused_fused_1) + vax1_0_fused_ax1_1_fused_0, vax1_0_fused_ax1_1_fused_1, vax1_0_fused_ax1_1_fused_3 = T.axis.remap("RRR", [ax1_0_fused_ax1_1_fused_0, ax1_0_fused_ax1_1_fused_1, ax1_0_fused_ax1_1_fused_3]) + T.where(u_fused_ax0_fused_fused_0 * T.int64(64) + u_fused_ax0_fused_fused_1 < v) + T.reads(var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, T.int64(0), T.int64(0), v0], lv1607[T.int64(0), T.int64(0), vax1_0_fused_ax1_1_fused_0 * T.int64(128) + vax1_0_fused_ax1_1_fused_1 * T.int64(128) + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused // T.int64(8) * T.int64(32) + vax1_0_fused_ax1_1_fused_3 * T.int64(8) + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused % T.int64(8)], lv612_local[vax1_0_fused_ax1_1_fused_0 * T.int64(16) + vax1_0_fused_ax1_1_fused_1 * T.int64(16) + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused // T.int64(8) * T.int64(4) + vax1_0_fused_ax1_1_fused_3, v0], lv613_local[vax1_0_fused_ax1_1_fused_0 * T.int64(4) 
+ vax1_0_fused_ax1_1_fused_1 * T.int64(4) + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused // T.int64(8) + vax1_0_fused_ax1_1_fused_3 // T.int64(4), v0]) + T.writes(var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, T.int64(0), T.int64(0), v0]) + var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, T.int64(0), T.int64(0), v0] = var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, T.int64(0), T.int64(0), v0] + lv1607[T.int64(0), T.int64(0), vax1_0_fused_ax1_1_fused_0 * T.int64(128) + vax1_0_fused_ax1_1_fused_1 * T.int64(128) + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused // T.int64(8) * T.int64(32) + vax1_0_fused_ax1_1_fused_3 * T.int64(8) + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused % T.int64(8)] * ((T.Cast("float16", T.bitwise_and(T.shift_right(lv612_local[vax1_0_fused_ax1_1_fused_0 * T.int64(16) + vax1_0_fused_ax1_1_fused_1 * T.int64(16) + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused // T.int64(8) * T.int64(4) + vax1_0_fused_ax1_1_fused_3, v0], T.Cast("uint32", (vax1_0_fused_ax1_1_fused_0 * T.int64(128) + vax1_0_fused_ax1_1_fused_1 * T.int64(128) + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused // T.int64(8) * T.int64(32) + vax1_0_fused_ax1_1_fused_3 * T.int64(8) + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused % T.int64(8)) % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv613_local[vax1_0_fused_ax1_1_fused_0 * T.int64(4) + vax1_0_fused_ax1_1_fused_1 * T.int64(4) + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused // T.int64(8) + vax1_0_fused_ax1_1_fused_3 // T.int64(4), v0]) + for ax2 in T.thread_binding(T.int64(64), thread="threadIdx.x"): + for ax0 in T.thread_binding(T.int64(4), thread="threadIdx.y"): + with T.block("matmul_rf_init"): + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 = T.axis.spatial(T.int64(4), ax0) + v0 = T.axis.spatial(v, u_fused_ax0_fused_fused_0 * T.int64(64) + ax2) + T.where(u_fused_ax0_fused_fused_0 * T.int64(64) + ax2 < v) + T.reads() + T.writes(var_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, T.int64(0), T.int64(0), v0]) + var_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, T.int64(0), T.int64(0), v0] = T.float16(0) + for ax1 in T.serial(T.int64(8), annotations={"pragma_auto_unroll_max_step": 8, "pragma_unroll_explicit": 1}): + with T.block("matmul_rf_update"): + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1 = T.axis.remap("SR", [ax0, ax1]) + v0 = T.axis.spatial(v, u_fused_ax0_fused_fused_0 * T.int64(64) + ax2) + T.where(u_fused_ax0_fused_fused_0 * T.int64(64) + ax2 < v) + T.reads(var_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, T.int64(0), T.int64(0), v0], var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 * T.int64(8) + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1, T.int64(0), T.int64(0), v0]) + T.writes(var_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, T.int64(0), T.int64(0), v0]) + var_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, T.int64(0), T.int64(0), v0] = var_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, 
T.int64(0), T.int64(0), v0] + var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 * T.int64(8) + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1, T.int64(0), T.int64(0), v0] + for ax1 in T.thread_binding(T.int64(64), thread="threadIdx.x"): + for ax0 in T.thread_binding(T.int64(4), thread="threadIdx.y"): + with T.block("matmul"): + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 = T.axis.reduce(T.int64(4), ax0) + v0 = T.axis.spatial(v, u_fused_ax0_fused_fused_0 * T.int64(64) + ax1) + T.where(u_fused_ax0_fused_fused_0 * T.int64(64) + ax1 < v) + T.reads(var_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, T.int64(0), T.int64(0), v0]) + T.writes(var_matmul_intermediate_local[T.int64(0), T.int64(0), v0]) + with T.init(): + var_matmul_intermediate_local[T.int64(0), T.int64(0), v0] = T.float16(0) + var_matmul_intermediate_local[T.int64(0), T.int64(0), v0] = var_matmul_intermediate_local[T.int64(0), T.int64(0), v0] + var_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, T.int64(0), T.int64(0), v0] + for ax0_fused_0 in T.thread_binding(T.int64(64), thread="threadIdx.x"): + for ax0_fused_1 in range(T.int64(1)): with T.block("compute"): - v0 = T.axis.spatial(v, ax0_fused_0 * T.int64(64) + ax0_fused_1) - T.where(ax0_fused_0 * T.int64(64) + ax0_fused_1 < v) + v0 = T.axis.spatial(v, u_fused_ax0_fused_fused_0 * T.int64(64) + ax0_fused_0 + ax0_fused_1) + T.where(u_fused_ax0_fused_fused_0 * T.int64(64) + (ax0_fused_0 + ax0_fused_1) < v) T.reads(var_matmul_intermediate_local[T.int64(0), T.int64(0), v0]) T.writes(p_output0_intermediate[T.int64(0), T.int64(0), v0]) p_output0_intermediate[T.int64(0), T.int64(0), v0] = T.Cast("float32", var_matmul_intermediate_local[T.int64(0), T.int64(0), v0]) diff --git a/tests/python/dlight/test_gpu_matmul.py b/tests/python/dlight/test_gpu_matmul.py index a421d9e6c734..63117073d156 100644 --- a/tests/python/dlight/test_gpu_matmul.py +++ b/tests/python/dlight/test_gpu_matmul.py @@ -634,18 +634,18 @@ def expected(var_inp0: T.handle, inp1: T.Buffer((T.int64(4096), T.int64(4096)), inp0 = T.match_buffer(var_inp0, (T.int64(1), m, T.int64(4096))) matmul = T.match_buffer(var_matmul, (T.int64(1), m, T.int64(4096))) # with T.block("root"): - matmul_reindex_pad_local = T.alloc_buffer((T.int64(1), (m + T.int64(15)) // T.int64(16) * T.int64(16), T.int64(4096)), scope="local") - for ax0_ax1_0_fused in T.thread_binding((m + T.int64(15)) // T.int64(16), thread="blockIdx.y"): + matmul_reindex_pad_local = T.alloc_buffer((T.int64(1), (m + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(4096)), scope="local") + for ax0_ax1_0_fused in T.thread_binding((m + T.int64(31)) // T.int64(32), thread="blockIdx.y"): for ax2_0 in T.thread_binding(T.int64(64), thread="blockIdx.x"): for ax1_1 in T.thread_binding(T.int64(1), thread="vthread.y"): for ax2_1 in T.thread_binding(T.int64(1), thread="vthread.x"): - for ax1_2 in T.thread_binding(T.int64(8), thread="threadIdx.y"): + for ax1_2 in T.thread_binding(T.int64(16), thread="threadIdx.y"): for ax2_2 in T.thread_binding(T.int64(8), thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1}): for ax1_3_init, ax2_3_0_init in T.grid(T.int64(2), T.int64(1)): for ax2_3_1_init in T.vectorized(T.int64(8)): with T.block("matmul_init"): v0 = T.axis.spatial(T.int64(1), T.int64(0)) - v1 = T.axis.spatial((m + T.int64(15)) // T.int64(16) * T.int64(16), ax0_ax1_0_fused * 
T.int64(16) + ax1_1 * T.int64(16) + ax1_2 * T.int64(2) + ax1_3_init) + v1 = T.axis.spatial((m + T.int64(31)) // T.int64(32) * T.int64(32), ax0_ax1_0_fused * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(2) + ax1_3_init) v2 = T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(8) + ax2_3_0_init * T.int64(8) + ax2_3_1_init) T.reads() T.writes(matmul_reindex_pad_local[T.int64(0), v1, v2]) @@ -654,7 +654,7 @@ def expected(var_inp0: T.handle, inp1: T.Buffer((T.int64(4096), T.int64(4096)), for ax2_3_1 in T.vectorized(T.int64(8)): with T.block("matmul_update"): v0 = T.axis.spatial(T.int64(1), T.int64(0)) - v1 = T.axis.spatial((m + T.int64(15)) // T.int64(16) * T.int64(16), ax0_ax1_0_fused * T.int64(16) + ax1_1 * T.int64(16) + ax1_2 * T.int64(2) + ax1_3) + v1 = T.axis.spatial((m + T.int64(31)) // T.int64(32) * T.int64(32), ax0_ax1_0_fused * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(2) + ax1_3) v2 = T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(8) + ax2_3_0 * T.int64(8) + ax2_3_1) v3 = T.axis.reduce(T.int64(4096), ax3_0 * T.int64(16) + ax3_1) T.reads(matmul_reindex_pad_local[T.int64(0), v1, v2], inp0[T.int64(0), v1, v3], inp1[v3, v2]) @@ -664,7 +664,7 @@ def expected(var_inp0: T.handle, inp1: T.Buffer((T.int64(4096), T.int64(4096)), for ax2_1_1 in T.vectorized(T.int64(8)): with T.block("matmul_reindex_pad_local"): v0 = T.axis.spatial(T.int64(1), ax0) - v1 = T.axis.spatial((m + T.int64(15)) // T.int64(16) * T.int64(16), ax0_ax1_0_fused * T.int64(16) + ax1_2 * T.int64(2) + ax1) + v1 = T.axis.spatial((m + T.int64(31)) // T.int64(32) * T.int64(32), ax0_ax1_0_fused * T.int64(32) + ax1_2 * T.int64(2) + ax1) v2 = T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + ax2_2 * T.int64(8) + ax2_0_1 * T.int64(8) + ax2_1_1) T.reads(matmul_reindex_pad_local[v0, v1, v2]) T.writes(matmul[T.int64(0), v1, v2])
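
Note (not part of the patch): the new `elif` in `GEMV.apply()` only takes the Adreno-oriented `sch_outer_reduction` path when the build target is OpenCL with an Android host, and it falls back to `sch_outer_reduction_fallback` whenever that schedule returns `None`; every other target goes straight to the fallback. The snippet below is a minimal sketch of the target/host pair the updated tests construct and the exact check the dispatch performs; it is illustrative only and assumes nothing beyond what the diff itself uses.

```python
# Sketch: reproduce the target/host pair from the new tests and the dispatch
# check added to GEMV.apply(). Only Target and its attributes are used here;
# both the constructor call and the check appear verbatim in the patch.
from tvm.target import Target

# Same target/host pair as the updated tests in test_gpu_gemv.py.
target = Target("opencl", host="llvm -mtriple=aarch64-linux-android")

# The new elif branch keys off these two properties of the target.
is_android_opencl = target.kind.name == "opencl" and "android" in str(target.host)
print(is_android_opencl)  # True for this target/host pair

# Under this target the dlight pipeline (e.g. dl.ApplyDefaultSchedule(dl.gpu.GEMV()))
# routes outer-reduction GEMVs through the new schedule, with the fallback used
# when sch_outer_reduction declines (returns None).
```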