From 642b0400197bd78d76d3165100d0e31a9ebe4afa Mon Sep 17 00:00:00 2001 From: krishnaraj36 Date: Thu, 25 Apr 2024 16:27:12 +0530 Subject: [PATCH 1/7] Enable gemv schedule for adreno Enabled new gemv schedule for opencl target, which effectively improves decode performance of mlc-llm LLM models with q4f16_0 format. Few LLM models Decode performance on Snapdragon Gen-3 android. Models Baseline Latest improved Llama-2-7B 10 tok/sec 12.5 tok/sec Qwen-7b 8.5 tok/sec 11 tok/sec --- python/tvm/dlight/gpu/gemv.py | 195 +++++++++++++++++++++++++++++++- python/tvm/dlight/gpu/matmul.py | 2 +- 2 files changed, 195 insertions(+), 2 deletions(-) diff --git a/python/tvm/dlight/gpu/gemv.py b/python/tvm/dlight/gpu/gemv.py index ed32ea77858f..693f4d3eaf0c 100644 --- a/python/tvm/dlight/gpu/gemv.py +++ b/python/tvm/dlight/gpu/gemv.py @@ -208,6 +208,13 @@ def apply( # pylint: disable=too-many-locals,too-many-branches,too-many-return- elif is_inner_reduction: self.sch_inner_reduction(sch, target, block, vector_input_buffers, epilogue) return sch + elif target.kind.name == "opencl" and "android" in str(target.host): + ret = self.sch_adreno_inner_reduction( + sch, target, block, vector_input_buffers, epilogue + ) + if ret is None: + return self.sch_outer_reduction(sch, target, block, vector_input_buffers, epilogue) + return sch else: return self.sch_outer_reduction(sch, target, block, vector_input_buffers, epilogue) @@ -486,7 +493,7 @@ def apply( LOAD_V_SHARED = False LOAD_V_VEC = -1 UNROLL = 8 - TS, TR = 2, 32 + TS, TR = 2, 64 elif target.kind.name == "vulkan": VEC_C = 4 LOAD_V_SHARED = True @@ -544,6 +551,192 @@ def apply( SUPPORT_WARP_SHUFFLE=SUPPORT_WARP_SHUFFLE, ) + def sch_adreno_inner_reduction( # pylint: disable=too-many-arguments, invalid-name, unused-argument + self, + sch: tir.Schedule, + target: Target, + block: tir.schedule.BlockRV, + vector_input_buffers: List[tir.Buffer], + epilogue_info: Optional[BlockInfo], + ): + """Schedule the inner reduction block.""" + + def get_max_factor(n, factors): + factors = sorted(factors, reverse=True) + for factor in factors: + if n % factor == 0: + return factor + return 1 + + def apply( + sch: tir.Schedule, + gemv, + TAG_S, + TAG_R, + TS, + TR, + SCALE_PACK, + DEC_PACK, + VEC_LOAD, + VEC_C, + LOAD_V_SHARED, + LOAD_V_VEC, + UNROLL, + LOAD_V_TILE, + ): + # rfactor: reduce to tx * vec_c + batch, s, r, c = sch.get_loops(block=gemv) + s = sch.fuse(batch, s) + r = sch.fuse(r, c) + bx, ts = sch.split(s, factors=[None, TS], preserve_unit_iters=True) + r, v_tile, tr, tile_r, vec_c = sch.split( + r, factors=[None, LOAD_V_TILE, TR, SCALE_PACK, DEC_PACK], preserve_unit_iters=True + ) + sch.reorder(bx, ts, r, v_tile, tile_r, tr, vec_c) + tr_vec_c = sch.fuse(tr, vec_c) + rf = sch.rfactor(tr_vec_c, 0) + + # rfactor: reduce to tx + bx, ts, tr_vec_c = sch.get_loops(block=gemv) + tr, vec_c = sch.split(tr_vec_c, factors=[TR, None], preserve_unit_iters=True) + rf2 = sch.rfactor(tr, 0) + + # bind, vectorize compute + bx, ts, r, v_tile, tile_r, tr_vec_c = sch.get_loops(block=rf) + tr, vec_c = sch.split(tr_vec_c, factors=[TR, DEC_PACK]) + sch.reorder(bx, ts, tr, r, v_tile, tile_r, vec_c) + # sch.bind(batch, "blockIdx.z") + sch.bind(bx, "blockIdx.x") + sch.bind(ts, "threadIdx.x") + sch.bind(tr, "threadIdx.y") + sch.vectorize(vec_c) + + # decompose independent scale read to outer loop + block_rf_stmt = sch.get(rf) + if len(block_rf_stmt.reads) >= 3: + As_local = sch.cache_read(rf, read_buffer_index=2, storage_scope="local") + sch.compute_at(As_local, v_tile, preserve_unit_loops=True) + # *tile_thr, vec_s = sch.get_loops(block=As_local) + # sch.vectorize(vec_s) + + Aq_local = sch.cache_read(rf, read_buffer_index=1, storage_scope="local") + sch.compute_at(Aq_local, tile_r, preserve_unit_loops=True) + # *tile_thr, vec_s = sch.get_loops(block=Aq_local) + # sch.vectorize(vec_s) + + if LOAD_V_SHARED: + V_shared = sch.cache_read(rf, read_buffer_index=0, storage_scope="shared") + sch.compute_at(V_shared, r, preserve_unit_loops=True) + l = sch.get_loops(block=V_shared)[-1] + _, v_tile, tx, ty, vec = sch.split( + l, factors=[None, LOAD_V_TILE, TS, TR, LOAD_V_VEC], preserve_unit_iters=True + ) + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + sch.vectorize(vec) + + # reduce tile_s * tr * vec to tile_s * tr + sch.reverse_compute_at(rf2, loop=bx, preserve_unit_loops=True) + tr, vec_c, ts = sch.get_loops(block=rf2)[1:] + sch.reorder(ts, tr, vec_c) + sch.bind(ts, "threadIdx.x") + sch.bind(tr, "threadIdx.y") + + # reduce tile_s * tr to tile_s + sch.reverse_compute_at(gemv, loop=bx, preserve_unit_loops=True) + tr, ts = sch.get_loops(block=gemv)[1:] + sch.reorder(ts, tr) + sch.bind(ts, "threadIdx.x") + sch.bind(tr, "threadIdx.y") + + sch.decompose_reduction(rf, loop=sch.get_loops(block=rf)[2]) + sch.decompose_reduction(rf2, loop=sch.get_loops(block=rf2)[-1]) + + sch.set_scope(rf, buffer_index=0, storage_scope="local") + sch.set_scope(rf2, buffer_index=0, storage_scope="local") + + sch.annotate( + block_or_loop=sch.get_loops(rf2)[3], + ann_key="pragma_auto_unroll_max_step", + ann_val=DEC_PACK, + ) + sch.annotate( + block_or_loop=sch.get_loops(rf2)[3], ann_key="pragma_unroll_explicit", ann_val=1 + ) + + # Schedule epilogue + if epilogue_info is not None: + epilogue = epilogue_info.block_rv + if is_broadcast_epilogue(sch, block, epilogue): + sch.reverse_compute_at(epilogue, bx) + sch.set_scope(block, 0, "shared") + _, _, *s = sch.get_loops(epilogue) # pylint: disable=invalid-name + _, tx = sch.split(sch.fuse(*s), factors=[None, TX]) + sch.bind(tx, "threadIdx.x") + else: + sch.reverse_compute_at(epilogue, bx, preserve_unit_loops=True) + ts_tile_s = sch.fuse(*sch.get_loops(epilogue)[1:]) + ts_tile_s = sch.get_loops(epilogue)[-1] + ts, _ = sch.split(ts_tile_s, factors=[TS, None], preserve_unit_iters=True) + sch.bind(ts, "threadIdx.x") + sch.set_scope(block, 0, "local") + return sch + # return sch.mod["main"].with_attr("tir.is_scheduled", 1) + + # Specify the `len_tx` and `len_ty` according to the loop extent + batch, s, r, c = sch.get_loops(block=block) + _, len_s, len_r, len_c = ( + get_extent(sch, batch), + get_extent(sch, s), + get_extent(sch, r), + get_extent(sch, c), + ) + + TAG_S, TAG_R = "threadIdx.x", "threadIdx.y" + VEC_C = 1 + UNROLL = 4 + TS, TR = 64, 4 + DEC_PACK = 8 + SCALE_PACK = 4 + LOAD_V_SHARED = False + LOAD_V_VEC = 4 + LOAD_V_TILE = 8 + + if LOAD_V_SHARED is False: + LOAD_V_TILE = 1 + + if not isinstance(len_r, int): + return None + + if isinstance(len_s, int) and len_s > 32000: + return None + + _, TILE_R = ( + 1, + len_c + if len_c > 1 + else max(get_max_factor(len_r, [TR * 1, TR * 2, TR * 4, TR * 8]) // TR, 1), + ) + LOAD_V_VEC = min(get_max_factor(TILE_R, [1, 2, 4, 8]), LOAD_V_VEC) + VEC_LOAD = 1 + + return apply( + sch, + gemv=block, + TAG_S=TAG_S, + TAG_R=TAG_R, + TS=TS, + TR=TR, + SCALE_PACK=SCALE_PACK, + DEC_PACK=DEC_PACK, + VEC_LOAD=VEC_LOAD, + VEC_C=VEC_C, + LOAD_V_SHARED=LOAD_V_SHARED, + LOAD_V_VEC=LOAD_V_VEC, + UNROLL=UNROLL, + LOAD_V_TILE=LOAD_V_TILE, + ) + def sch_outer_reduction( # pylint: disable=too-many-arguments, invalid-name, unused-argument self, sch: tir.Schedule, diff --git a/python/tvm/dlight/gpu/matmul.py b/python/tvm/dlight/gpu/matmul.py index ed81b7f6881f..f4ef1f50448b 100644 --- a/python/tvm/dlight/gpu/matmul.py +++ b/python/tvm/dlight/gpu/matmul.py @@ -777,7 +777,7 @@ def get_configs(self, target: Target) -> Config: elif target.kind.name == "opencl" and "android" in str(target.host): return Matmul.Config( block_size_x=8, - block_size_y=8, + block_size_y=16, vthread_x=1, vthread_y=1, micro_size_x=8, From 5038596658ad42f78af844218281b06b480f2329 Mon Sep 17 00:00:00 2001 From: krishnaraj36 Date: Fri, 26 Apr 2024 15:33:31 +0530 Subject: [PATCH 2/7] Modified test case according to dlight schedule update --- tests/python/dlight/test_gpu_gemv.py | 248 ++++++++++++++----------- tests/python/dlight/test_gpu_matmul.py | 12 +- 2 files changed, 149 insertions(+), 111 deletions(-) diff --git a/tests/python/dlight/test_gpu_gemv.py b/tests/python/dlight/test_gpu_gemv.py index 0fd7f791599f..228c1b4a256d 100644 --- a/tests/python/dlight/test_gpu_gemv.py +++ b/tests/python/dlight/test_gpu_gemv.py @@ -722,87 +722,81 @@ def before( ) @T.prim_func(private=True) - def expected( - lv575: T.Buffer((1376, 4096), "uint32"), - lv576: T.Buffer((344, 4096), "float16"), - lv574: T.Buffer((1, 1, 11008), "float16"), - lv570: T.Buffer((1, 1, 4096), "float16"), - p_output0_intermediate: T.Buffer((1, 1, 4096), "float16"), - ): + def expected(lv575: T.Buffer((1376, 4096), "uint32"), lv576: T.Buffer((344, 4096), "float16"), lv574: T.Buffer((1, 1, 11008), "float16"), lv570: T.Buffer((1, 1, 4096), "float16"), p_output0_intermediate: T.Buffer((1, 1, 4096), "float16")): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) # with T.block("root"): var_matmul_intermediate_local = T.alloc_buffer((1, 1, 4096), "float16", scope="local") - lv574_local = T.alloc_buffer((1, 1, 11008), "float16", scope="local") - for u_fused in T.thread_binding(1, thread="blockIdx.y"): - for ax0_fused_0 in T.thread_binding(32, thread="blockIdx.x"): - for ax0_fused_1 in T.thread_binding( - 64, - thread="threadIdx.x", - annotations={"pragma_auto_unroll_max_step": 8, "pragma_unroll_explicit": 1}, - ): - for ax0_fused_2_init in T.vectorized(2): - with T.block("matmul_init"): - v0 = T.axis.spatial( - 4096, ax0_fused_0 * 128 + ax0_fused_1 * 2 + ax0_fused_2_init - ) + var_matmul_intermediate_rf_local = T.alloc_buffer((32, 1, 1, 4096), "float16", scope="local") + var_matmul_intermediate_rf_local_1 = T.alloc_buffer((4, 1, 1, 4096), "float16", scope="local") + lv576_local = T.alloc_buffer((344, 4096), "float16", scope="local") + lv575_local = T.alloc_buffer((1376, 4096), "uint32", scope="local") + for u_fused_ax0_fused_fused_0 in T.thread_binding(64, thread="blockIdx.x"): + for u_fused_ax0_fused_fused_1 in T.thread_binding(64, thread="threadIdx.x"): + for ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0_init in T.thread_binding(4, thread="threadIdx.y"): + for ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1_init in T.vectorized(8): + with T.block("matmul_rf_init"): + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused = T.axis.spatial(32, ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0_init * 8 + ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1_init) + v0 = T.axis.spatial(4096, u_fused_ax0_fused_fused_0 * 64 + u_fused_ax0_fused_fused_1) T.reads() - T.writes(var_matmul_intermediate_local[0, 0, v0]) - var_matmul_intermediate_local[0, 0, v0] = T.float16(0) - for ax1_0_fused_0, ax1_0_fused_1 in T.grid(344, 4): + T.writes(var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, 0, 0, v0]) + var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, 0, 0, v0] = T.float16(0) + for ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 in T.thread_binding(4, thread="threadIdx.y"): + for ax1_0_fused_ax1_1_fused_0, ax1_0_fused_ax1_1_fused_1 in T.grid(86, 1): for ax0, ax1 in T.grid(1, 1): - for ax2 in T.vectorized(8): - with T.block("lv574_local"): - v0, v1 = T.axis.remap("SS", [ax0, ax1]) - v2 = T.axis.spatial( - 11008, ax1_0_fused_0 * 32 + ax1_0_fused_1 * 8 + ax2 - ) - T.reads(lv574[v0, v1, v2]) - T.writes(lv574_local[v0, v1, v2]) - lv574_local[v0, v1, v2] = lv574[v0, v1, v2] - for ax1_1 in range(8): - for ax0_fused_2 in T.vectorized(2): - with T.block("matmul_update"): - v0 = T.axis.spatial( - 4096, ax0_fused_0 * 128 + ax0_fused_1 * 2 + ax0_fused_2 - ) - v1 = T.axis.reduce( - 11008, ax1_0_fused_0 * 32 + ax1_0_fused_1 * 8 + ax1_1 - ) - T.reads( - var_matmul_intermediate_local[0, 0, v0], - lv574_local[0, 0, v1], - lv575[v1 // 8, v0], - lv576[v1 // 32, v0], - ) - T.writes(var_matmul_intermediate_local[0, 0, v0]) - var_matmul_intermediate_local[ - 0, 0, v0 - ] = var_matmul_intermediate_local[0, 0, v0] + lv574_local[ - 0, 0, v1 - ] * ( - ( - T.Cast( - "float16", - T.bitwise_and( - T.shift_right( - lv575[v1 // 8, v0], - T.Cast("uint32", v1 % 8) * T.uint32(4), - ), - T.uint32(15), - ), - ) - - T.float16(7) - ) - * lv576[v1 // 32, v0] - ) - for ax0 in range(2): - with T.block("T_add"): - v0 = T.axis.spatial(4096, ax0_fused_0 * 128 + ax0_fused_1 * 2 + ax0) - T.reads(lv570[0, 0, v0], var_matmul_intermediate_local[0, 0, v0]) - T.writes(p_output0_intermediate[0, 0, v0]) - p_output0_intermediate[0, 0, v0] = ( - lv570[0, 0, v0] + var_matmul_intermediate_local[0, 0, v0] - ) + with T.block("lv576_local"): + v0 = T.axis.spatial(344, ax1_0_fused_ax1_1_fused_0 * 4 + ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 + ax0) + v1 = T.axis.spatial(4096, u_fused_ax0_fused_fused_0 * 64 + u_fused_ax0_fused_fused_1 + ax1) + T.reads(lv576[v0, v1]) + T.writes(lv576_local[v0, v1]) + lv576_local[v0, v1] = lv576[v0, v1] + for ax1_0_fused_ax1_1_fused_3 in range(4): + for ax0, ax1 in T.grid(1, 1): + with T.block("lv575_local"): + v0 = T.axis.spatial(1376, ax1_0_fused_ax1_1_fused_0 * 16 + ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 * 4 + ax1_0_fused_ax1_1_fused_3 + ax0) + v1 = T.axis.spatial(4096, u_fused_ax0_fused_fused_0 * 64 + u_fused_ax0_fused_fused_1 + ax1) + T.reads(lv575[v0, v1]) + T.writes(lv575_local[v0, v1]) + lv575_local[v0, v1] = lv575[v0, v1] + for ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1 in T.vectorized(8): + with T.block("matmul_rf_update"): + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused = T.axis.spatial(32, ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 * 8 + ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1) + v0 = T.axis.spatial(4096, u_fused_ax0_fused_fused_0 * 64 + u_fused_ax0_fused_fused_1) + vax1_0_fused_ax1_1_fused_0, vax1_0_fused_ax1_1_fused_1, vax1_0_fused_ax1_1_fused_3 = T.axis.remap("RRR", [ax1_0_fused_ax1_1_fused_0, ax1_0_fused_ax1_1_fused_1, ax1_0_fused_ax1_1_fused_3]) + T.reads(var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, 0, 0, v0], lv574[0, 0, vax1_0_fused_ax1_1_fused_0 * 128 + vax1_0_fused_ax1_1_fused_1 * 128 + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused // 8 * 32 + vax1_0_fused_ax1_1_fused_3 * 8 + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused % 8], lv575_local[vax1_0_fused_ax1_1_fused_0 * 16 + vax1_0_fused_ax1_1_fused_1 * 16 + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused // 8 * 4 + vax1_0_fused_ax1_1_fused_3, v0], lv576_local[vax1_0_fused_ax1_1_fused_0 * 4 + vax1_0_fused_ax1_1_fused_1 * 4 + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused // 8 + vax1_0_fused_ax1_1_fused_3 // 4, v0]) + T.writes(var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, 0, 0, v0]) + var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, 0, 0, v0] = var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, 0, 0, v0] + lv574[0, 0, vax1_0_fused_ax1_1_fused_0 * 128 + vax1_0_fused_ax1_1_fused_1 * 128 + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused // 8 * 32 + vax1_0_fused_ax1_1_fused_3 * 8 + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused % 8] * ((T.Cast("float16", T.bitwise_and(T.shift_right(lv575_local[vax1_0_fused_ax1_1_fused_0 * 16 + vax1_0_fused_ax1_1_fused_1 * 16 + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused // 8 * 4 + vax1_0_fused_ax1_1_fused_3, v0], T.Cast("uint32", (vax1_0_fused_ax1_1_fused_0 * 128 + vax1_0_fused_ax1_1_fused_1 * 128 + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused // 8 * 32 + vax1_0_fused_ax1_1_fused_3 * 8 + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused % 8) % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv576_local[vax1_0_fused_ax1_1_fused_0 * 4 + vax1_0_fused_ax1_1_fused_1 * 4 + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused // 8 + vax1_0_fused_ax1_1_fused_3 // 4, v0]) + for ax2 in T.thread_binding(64, thread="threadIdx.x"): + for ax0 in T.thread_binding(4, thread="threadIdx.y"): + with T.block("matmul_rf_init"): + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 = T.axis.spatial(4, ax0) + v0 = T.axis.spatial(4096, u_fused_ax0_fused_fused_0 * 64 + ax2) + T.reads() + T.writes(var_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, 0, 0, v0]) + var_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, 0, 0, v0] = T.float16(0) + for ax1 in T.serial(8, annotations={"pragma_auto_unroll_max_step": 8, "pragma_unroll_explicit": 1}): + with T.block("matmul_rf_update"): + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1 = T.axis.remap("SR", [ax0, ax1]) + v0 = T.axis.spatial(4096, u_fused_ax0_fused_fused_0 * 64 + ax2) + T.reads(var_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, 0, 0, v0], var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 * 8 + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1, 0, 0, v0]) + T.writes(var_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, 0, 0, v0]) + var_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, 0, 0, v0] = var_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, 0, 0, v0] + var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 * 8 + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1, 0, 0, v0] + for ax1 in T.thread_binding(64, thread="threadIdx.x"): + for ax0 in T.thread_binding(4, thread="threadIdx.y"): + with T.block("matmul"): + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 = T.axis.reduce(4, ax0) + v0 = T.axis.spatial(4096, u_fused_ax0_fused_fused_0 * 64 + ax1) + T.reads(var_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, 0, 0, v0]) + T.writes(var_matmul_intermediate_local[0, 0, v0]) + with T.init(): + var_matmul_intermediate_local[0, 0, v0] = T.float16(0) + var_matmul_intermediate_local[0, 0, v0] = var_matmul_intermediate_local[0, 0, v0] + var_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, 0, 0, v0] + for ax0_fused_0 in T.thread_binding(64, thread="threadIdx.x"): + for ax0_fused_1 in range(1): + with T.block("T_add"): + v0 = T.axis.spatial(4096, u_fused_ax0_fused_fused_0 * 64 + ax0_fused_0 + ax0_fused_1) + T.reads(lv570[0, 0, v0], var_matmul_intermediate_local[0, 0, v0]) + T.writes(p_output0_intermediate[0, 0, v0]) + p_output0_intermediate[0, 0, v0] = lv570[0, 0, v0] + var_matmul_intermediate_local[0, 0, v0] mod = tvm.IRModule({"main": before}) with Target("opencl", host="llvm -mtriple=aarch64-linux-android"): @@ -852,38 +846,82 @@ def expected(p_lv612: T.handle, p_lv613: T.handle, lv1607: T.Buffer((T.int64(1), p_output0_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(1), v)) # with T.block("root"): var_matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(1), v), "float16", scope="local") - lv1607_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16", scope="local") - for u_fused in T.thread_binding(1, thread="blockIdx.y"): - for ax0_fused_0 in T.thread_binding((v + T.int64(63)) // T.int64(64), thread="blockIdx.x"): - for ax0_fused_1 in T.thread_binding(T.int64(64), thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 8, "pragma_unroll_explicit": 1}): - for ax0_fused_2_init in T.vectorized(T.int64(1)): - with T.block("matmul_init"): - v0 = T.axis.spatial(v, ax0_fused_0 * T.int64(64) + ax0_fused_1 + ax0_fused_2_init) - T.where(ax0_fused_0 * T.int64(64) + ax0_fused_1 + ax0_fused_2_init < v) + var_matmul_intermediate_rf_local = T.alloc_buffer((T.int64(32), T.int64(1), T.int64(1), v), "float16", scope="local") + var_matmul_intermediate_rf_local_1 = T.alloc_buffer((T.int64(4), T.int64(1), T.int64(1), v), "float16", scope="local") + lv613_local = T.alloc_buffer((T.int64(128), v), "float16", scope="local") + lv612_local = T.alloc_buffer((T.int64(512), v), "uint32", scope="local") + for u_fused_ax0_fused_fused_0 in T.thread_binding((v + T.int64(63)) // T.int64(64), thread="blockIdx.x"): + for u_fused_ax0_fused_fused_1 in T.thread_binding(T.int64(64), thread="threadIdx.x"): + for ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0_init in T.thread_binding(T.int64(4), thread="threadIdx.y"): + for ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1_init in T.vectorized(T.int64(8)): + with T.block("matmul_rf_init"): + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused = T.axis.spatial(T.int64(32), ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0_init * T.int64(8) + ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1_init) + v0 = T.axis.spatial(v, u_fused_ax0_fused_fused_0 * T.int64(64) + u_fused_ax0_fused_fused_1) + T.where(u_fused_ax0_fused_fused_0 * T.int64(64) + u_fused_ax0_fused_fused_1 < v) T.reads() - T.writes(var_matmul_intermediate_local[T.int64(0), T.int64(0), v0]) - var_matmul_intermediate_local[T.int64(0), T.int64(0), v0] = T.float16(0) - for ax1_0_fused_0, ax1_0_fused_1 in T.grid(T.int64(128), T.int64(4)): + T.writes(var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, T.int64(0), T.int64(0), v0]) + var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, T.int64(0), T.int64(0), v0] = T.float16(0) + for ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 in T.thread_binding(T.int64(4), thread="threadIdx.y"): + for ax1_0_fused_ax1_1_fused_0, ax1_0_fused_ax1_1_fused_1 in T.grid(T.int64(32), T.int64(1)): for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): - for ax2 in T.vectorized(T.int64(8)): - with T.block("lv1607_local"): - v0, v1 = T.axis.remap("SS", [ax0, ax1]) - v2 = T.axis.spatial(T.int64(4096), ax1_0_fused_0 * T.int64(32) + ax1_0_fused_1 * T.int64(8) + ax2) - T.reads(lv1607[v0, v1, v2]) - T.writes(lv1607_local[v0, v1, v2]) - lv1607_local[v0, v1, v2] = lv1607[v0, v1, v2] - for ax1_1 in range(T.int64(8)): - for ax0_fused_2 in T.vectorized(T.int64(1)): - with T.block("matmul_update"): - v0 = T.axis.spatial(v, ax0_fused_0 * T.int64(64) + ax0_fused_1 + ax0_fused_2) - v1 = T.axis.reduce(T.int64(4096), ax1_0_fused_0 * T.int64(32) + ax1_0_fused_1 * T.int64(8) + ax1_1) - T.where(ax0_fused_0 * T.int64(64) + ax0_fused_1 + ax0_fused_2 < v) - T.reads(var_matmul_intermediate_local[T.int64(0), T.int64(0), v0], lv1607_local[T.int64(0), T.int64(0), v1], lv612[v1 // T.int64(8), v0], lv613[v1 // T.int64(32), v0]) - T.writes(var_matmul_intermediate_local[T.int64(0), T.int64(0), v0]) - var_matmul_intermediate_local[T.int64(0), T.int64(0), v0] = var_matmul_intermediate_local[T.int64(0), T.int64(0), v0] + lv1607_local[T.int64(0), T.int64(0), v1] * ((T.Cast("float16", T.bitwise_and(T.shift_right(lv612[v1 // T.int64(8), v0], T.Cast("uint32", v1 % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv613[v1 // T.int64(32), v0]) + with T.block("lv613_local"): + v0 = T.axis.spatial(T.int64(128), ax1_0_fused_ax1_1_fused_0 * T.int64(4) + ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 + ax0) + v1 = T.axis.spatial(v, u_fused_ax0_fused_fused_0 * T.int64(64) + u_fused_ax0_fused_fused_1 + ax1) + T.where(u_fused_ax0_fused_fused_0 * T.int64(64) + u_fused_ax0_fused_fused_1 < v) + T.reads(lv613[v0, v1]) + T.writes(lv613_local[v0, v1]) + lv613_local[v0, v1] = lv613[v0, v1] + for ax1_0_fused_ax1_1_fused_3 in range(T.int64(4)): + for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): + with T.block("lv612_local"): + v0 = T.axis.spatial(T.int64(512), ax1_0_fused_ax1_1_fused_0 * T.int64(16) + ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 * T.int64(4) + ax1_0_fused_ax1_1_fused_3 + ax0) + v1 = T.axis.spatial(v, u_fused_ax0_fused_fused_0 * T.int64(64) + u_fused_ax0_fused_fused_1 + ax1) + T.where(u_fused_ax0_fused_fused_0 * T.int64(64) + u_fused_ax0_fused_fused_1 < v) + T.reads(lv612[v0, v1]) + T.writes(lv612_local[v0, v1]) + lv612_local[v0, v1] = lv612[v0, v1] + for ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1 in T.vectorized(T.int64(8)): + with T.block("matmul_rf_update"): + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused = T.axis.spatial(T.int64(32), ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 * T.int64(8) + ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1) + v0 = T.axis.spatial(v, u_fused_ax0_fused_fused_0 * T.int64(64) + u_fused_ax0_fused_fused_1) + vax1_0_fused_ax1_1_fused_0, vax1_0_fused_ax1_1_fused_1, vax1_0_fused_ax1_1_fused_3 = T.axis.remap("RRR", [ax1_0_fused_ax1_1_fused_0, ax1_0_fused_ax1_1_fused_1, ax1_0_fused_ax1_1_fused_3]) + T.where(u_fused_ax0_fused_fused_0 * T.int64(64) + u_fused_ax0_fused_fused_1 < v) + T.reads(var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, T.int64(0), T.int64(0), v0], lv1607[T.int64(0), T.int64(0), vax1_0_fused_ax1_1_fused_0 * T.int64(128) + vax1_0_fused_ax1_1_fused_1 * T.int64(128) + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused // T.int64(8) * T.int64(32) + vax1_0_fused_ax1_1_fused_3 * T.int64(8) + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused % T.int64(8)], lv612_local[vax1_0_fused_ax1_1_fused_0 * T.int64(16) + vax1_0_fused_ax1_1_fused_1 * T.int64(16) + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused // T.int64(8) * T.int64(4) + vax1_0_fused_ax1_1_fused_3, v0], lv613_local[vax1_0_fused_ax1_1_fused_0 * T.int64(4) + vax1_0_fused_ax1_1_fused_1 * T.int64(4) + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused // T.int64(8) + vax1_0_fused_ax1_1_fused_3 // T.int64(4), v0]) + T.writes(var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, T.int64(0), T.int64(0), v0]) + var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, T.int64(0), T.int64(0), v0] = var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, T.int64(0), T.int64(0), v0] + lv1607[T.int64(0), T.int64(0), vax1_0_fused_ax1_1_fused_0 * T.int64(128) + vax1_0_fused_ax1_1_fused_1 * T.int64(128) + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused // T.int64(8) * T.int64(32) + vax1_0_fused_ax1_1_fused_3 * T.int64(8) + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused % T.int64(8)] * ((T.Cast("float16", T.bitwise_and(T.shift_right(lv612_local[vax1_0_fused_ax1_1_fused_0 * T.int64(16) + vax1_0_fused_ax1_1_fused_1 * T.int64(16) + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused // T.int64(8) * T.int64(4) + vax1_0_fused_ax1_1_fused_3, v0], T.Cast("uint32", (vax1_0_fused_ax1_1_fused_0 * T.int64(128) + vax1_0_fused_ax1_1_fused_1 * T.int64(128) + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused // T.int64(8) * T.int64(32) + vax1_0_fused_ax1_1_fused_3 * T.int64(8) + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused % T.int64(8)) % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv613_local[vax1_0_fused_ax1_1_fused_0 * T.int64(4) + vax1_0_fused_ax1_1_fused_1 * T.int64(4) + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused // T.int64(8) + vax1_0_fused_ax1_1_fused_3 // T.int64(4), v0]) + for ax2 in T.thread_binding(T.int64(64), thread="threadIdx.x"): + for ax0 in T.thread_binding(T.int64(4), thread="threadIdx.y"): + with T.block("matmul_rf_init"): + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 = T.axis.spatial(T.int64(4), ax0) + v0 = T.axis.spatial(v, u_fused_ax0_fused_fused_0 * T.int64(64) + ax2) + T.where(u_fused_ax0_fused_fused_0 * T.int64(64) + ax2 < v) + T.reads() + T.writes(var_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, T.int64(0), T.int64(0), v0]) + var_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, T.int64(0), T.int64(0), v0] = T.float16(0) + for ax1 in T.serial(T.int64(8), annotations={"pragma_auto_unroll_max_step": 8, "pragma_unroll_explicit": 1}): + with T.block("matmul_rf_update"): + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1 = T.axis.remap("SR", [ax0, ax1]) + v0 = T.axis.spatial(v, u_fused_ax0_fused_fused_0 * T.int64(64) + ax2) + T.where(u_fused_ax0_fused_fused_0 * T.int64(64) + ax2 < v) + T.reads(var_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, T.int64(0), T.int64(0), v0], var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 * T.int64(8) + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1, T.int64(0), T.int64(0), v0]) + T.writes(var_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, T.int64(0), T.int64(0), v0]) + var_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, T.int64(0), T.int64(0), v0] = var_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, T.int64(0), T.int64(0), v0] + var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 * T.int64(8) + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1, T.int64(0), T.int64(0), v0] + for ax1 in T.thread_binding(T.int64(64), thread="threadIdx.x"): + for ax0 in T.thread_binding(T.int64(4), thread="threadIdx.y"): + with T.block("matmul"): + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 = T.axis.reduce(T.int64(4), ax0) + v0 = T.axis.spatial(v, u_fused_ax0_fused_fused_0 * T.int64(64) + ax1) + T.where(u_fused_ax0_fused_fused_0 * T.int64(64) + ax1 < v) + T.reads(var_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, T.int64(0), T.int64(0), v0]) + T.writes(var_matmul_intermediate_local[T.int64(0), T.int64(0), v0]) + with T.init(): + var_matmul_intermediate_local[T.int64(0), T.int64(0), v0] = T.float16(0) + var_matmul_intermediate_local[T.int64(0), T.int64(0), v0] = var_matmul_intermediate_local[T.int64(0), T.int64(0), v0] + var_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, T.int64(0), T.int64(0), v0] + for ax0_fused_0 in T.thread_binding(T.int64(64), thread="threadIdx.x"): + for ax0_fused_1 in range(T.int64(1)): with T.block("compute"): - v0 = T.axis.spatial(v, ax0_fused_0 * T.int64(64) + ax0_fused_1) - T.where(ax0_fused_0 * T.int64(64) + ax0_fused_1 < v) + v0 = T.axis.spatial(v, u_fused_ax0_fused_fused_0 * T.int64(64) + ax0_fused_0 + ax0_fused_1) + T.where(u_fused_ax0_fused_fused_0 * T.int64(64) + (ax0_fused_0 + ax0_fused_1) < v) T.reads(var_matmul_intermediate_local[T.int64(0), T.int64(0), v0]) T.writes(p_output0_intermediate[T.int64(0), T.int64(0), v0]) p_output0_intermediate[T.int64(0), T.int64(0), v0] = T.Cast("float32", var_matmul_intermediate_local[T.int64(0), T.int64(0), v0]) diff --git a/tests/python/dlight/test_gpu_matmul.py b/tests/python/dlight/test_gpu_matmul.py index a421d9e6c734..63117073d156 100644 --- a/tests/python/dlight/test_gpu_matmul.py +++ b/tests/python/dlight/test_gpu_matmul.py @@ -634,18 +634,18 @@ def expected(var_inp0: T.handle, inp1: T.Buffer((T.int64(4096), T.int64(4096)), inp0 = T.match_buffer(var_inp0, (T.int64(1), m, T.int64(4096))) matmul = T.match_buffer(var_matmul, (T.int64(1), m, T.int64(4096))) # with T.block("root"): - matmul_reindex_pad_local = T.alloc_buffer((T.int64(1), (m + T.int64(15)) // T.int64(16) * T.int64(16), T.int64(4096)), scope="local") - for ax0_ax1_0_fused in T.thread_binding((m + T.int64(15)) // T.int64(16), thread="blockIdx.y"): + matmul_reindex_pad_local = T.alloc_buffer((T.int64(1), (m + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(4096)), scope="local") + for ax0_ax1_0_fused in T.thread_binding((m + T.int64(31)) // T.int64(32), thread="blockIdx.y"): for ax2_0 in T.thread_binding(T.int64(64), thread="blockIdx.x"): for ax1_1 in T.thread_binding(T.int64(1), thread="vthread.y"): for ax2_1 in T.thread_binding(T.int64(1), thread="vthread.x"): - for ax1_2 in T.thread_binding(T.int64(8), thread="threadIdx.y"): + for ax1_2 in T.thread_binding(T.int64(16), thread="threadIdx.y"): for ax2_2 in T.thread_binding(T.int64(8), thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1}): for ax1_3_init, ax2_3_0_init in T.grid(T.int64(2), T.int64(1)): for ax2_3_1_init in T.vectorized(T.int64(8)): with T.block("matmul_init"): v0 = T.axis.spatial(T.int64(1), T.int64(0)) - v1 = T.axis.spatial((m + T.int64(15)) // T.int64(16) * T.int64(16), ax0_ax1_0_fused * T.int64(16) + ax1_1 * T.int64(16) + ax1_2 * T.int64(2) + ax1_3_init) + v1 = T.axis.spatial((m + T.int64(31)) // T.int64(32) * T.int64(32), ax0_ax1_0_fused * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(2) + ax1_3_init) v2 = T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(8) + ax2_3_0_init * T.int64(8) + ax2_3_1_init) T.reads() T.writes(matmul_reindex_pad_local[T.int64(0), v1, v2]) @@ -654,7 +654,7 @@ def expected(var_inp0: T.handle, inp1: T.Buffer((T.int64(4096), T.int64(4096)), for ax2_3_1 in T.vectorized(T.int64(8)): with T.block("matmul_update"): v0 = T.axis.spatial(T.int64(1), T.int64(0)) - v1 = T.axis.spatial((m + T.int64(15)) // T.int64(16) * T.int64(16), ax0_ax1_0_fused * T.int64(16) + ax1_1 * T.int64(16) + ax1_2 * T.int64(2) + ax1_3) + v1 = T.axis.spatial((m + T.int64(31)) // T.int64(32) * T.int64(32), ax0_ax1_0_fused * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(2) + ax1_3) v2 = T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(8) + ax2_3_0 * T.int64(8) + ax2_3_1) v3 = T.axis.reduce(T.int64(4096), ax3_0 * T.int64(16) + ax3_1) T.reads(matmul_reindex_pad_local[T.int64(0), v1, v2], inp0[T.int64(0), v1, v3], inp1[v3, v2]) @@ -664,7 +664,7 @@ def expected(var_inp0: T.handle, inp1: T.Buffer((T.int64(4096), T.int64(4096)), for ax2_1_1 in T.vectorized(T.int64(8)): with T.block("matmul_reindex_pad_local"): v0 = T.axis.spatial(T.int64(1), ax0) - v1 = T.axis.spatial((m + T.int64(15)) // T.int64(16) * T.int64(16), ax0_ax1_0_fused * T.int64(16) + ax1_2 * T.int64(2) + ax1) + v1 = T.axis.spatial((m + T.int64(31)) // T.int64(32) * T.int64(32), ax0_ax1_0_fused * T.int64(32) + ax1_2 * T.int64(2) + ax1) v2 = T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + ax2_2 * T.int64(8) + ax2_0_1 * T.int64(8) + ax2_1_1) T.reads(matmul_reindex_pad_local[v0, v1, v2]) T.writes(matmul[T.int64(0), v1, v2]) From b8534a8ef2531b166cf10f8a1a1e843acc021a92 Mon Sep 17 00:00:00 2001 From: krishnaraj36 Date: Fri, 26 Apr 2024 16:14:27 +0530 Subject: [PATCH 3/7] Fix lint error --- tests/python/dlight/test_gpu_gemv.py | 328 ++++++++++++++++++++++++--- 1 file changed, 294 insertions(+), 34 deletions(-) diff --git a/tests/python/dlight/test_gpu_gemv.py b/tests/python/dlight/test_gpu_gemv.py index 228c1b4a256d..4aae617654d2 100644 --- a/tests/python/dlight/test_gpu_gemv.py +++ b/tests/python/dlight/test_gpu_gemv.py @@ -722,81 +722,341 @@ def before( ) @T.prim_func(private=True) - def expected(lv575: T.Buffer((1376, 4096), "uint32"), lv576: T.Buffer((344, 4096), "float16"), lv574: T.Buffer((1, 1, 11008), "float16"), lv570: T.Buffer((1, 1, 4096), "float16"), p_output0_intermediate: T.Buffer((1, 1, 4096), "float16")): + def expected( + lv575: T.Buffer((1376, 4096), "uint32"), + lv576: T.Buffer((344, 4096), "float16"), + lv574: T.Buffer((1, 1, 11008), "float16"), + lv570: T.Buffer((1, 1, 4096), "float16"), + p_output0_intermediate: T.Buffer((1, 1, 4096), "float16"), + ): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) # with T.block("root"): var_matmul_intermediate_local = T.alloc_buffer((1, 1, 4096), "float16", scope="local") - var_matmul_intermediate_rf_local = T.alloc_buffer((32, 1, 1, 4096), "float16", scope="local") - var_matmul_intermediate_rf_local_1 = T.alloc_buffer((4, 1, 1, 4096), "float16", scope="local") + var_matmul_intermediate_rf_local = T.alloc_buffer( + (32, 1, 1, 4096), "float16", scope="local" + ) + var_matmul_intermediate_rf_local_1 = T.alloc_buffer( + (4, 1, 1, 4096), "float16", scope="local" + ) lv576_local = T.alloc_buffer((344, 4096), "float16", scope="local") lv575_local = T.alloc_buffer((1376, 4096), "uint32", scope="local") for u_fused_ax0_fused_fused_0 in T.thread_binding(64, thread="blockIdx.x"): for u_fused_ax0_fused_fused_1 in T.thread_binding(64, thread="threadIdx.x"): - for ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0_init in T.thread_binding(4, thread="threadIdx.y"): - for ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1_init in T.vectorized(8): + for ( + ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0_init + ) in T.thread_binding(4, thread="threadIdx.y"): + for ( + ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1_init + ) in T.vectorized(8): with T.block("matmul_rf_init"): - vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused = T.axis.spatial(32, ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0_init * 8 + ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1_init) - v0 = T.axis.spatial(4096, u_fused_ax0_fused_fused_0 * 64 + u_fused_ax0_fused_fused_1) + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused = T.axis.spatial( + 32, + ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0_init * 8 + + ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1_init, + ) + v0 = T.axis.spatial( + 4096, u_fused_ax0_fused_fused_0 * 64 + u_fused_ax0_fused_fused_1 + ) T.reads() - T.writes(var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, 0, 0, v0]) - var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, 0, 0, v0] = T.float16(0) - for ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 in T.thread_binding(4, thread="threadIdx.y"): + T.writes( + var_matmul_intermediate_rf_local[ + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, + 0, + 0, + v0, + ] + ) + var_matmul_intermediate_rf_local[ + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, 0, 0, v0 + ] = T.float16(0) + for ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 in T.thread_binding( + 4, thread="threadIdx.y" + ): for ax1_0_fused_ax1_1_fused_0, ax1_0_fused_ax1_1_fused_1 in T.grid(86, 1): for ax0, ax1 in T.grid(1, 1): with T.block("lv576_local"): - v0 = T.axis.spatial(344, ax1_0_fused_ax1_1_fused_0 * 4 + ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 + ax0) - v1 = T.axis.spatial(4096, u_fused_ax0_fused_fused_0 * 64 + u_fused_ax0_fused_fused_1 + ax1) + v0 = T.axis.spatial( + 344, + ax1_0_fused_ax1_1_fused_0 * 4 + + ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 + + ax0, + ) + v1 = T.axis.spatial( + 4096, + u_fused_ax0_fused_fused_0 * 64 + + u_fused_ax0_fused_fused_1 + + ax1, + ) T.reads(lv576[v0, v1]) T.writes(lv576_local[v0, v1]) lv576_local[v0, v1] = lv576[v0, v1] for ax1_0_fused_ax1_1_fused_3 in range(4): for ax0, ax1 in T.grid(1, 1): with T.block("lv575_local"): - v0 = T.axis.spatial(1376, ax1_0_fused_ax1_1_fused_0 * 16 + ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 * 4 + ax1_0_fused_ax1_1_fused_3 + ax0) - v1 = T.axis.spatial(4096, u_fused_ax0_fused_fused_0 * 64 + u_fused_ax0_fused_fused_1 + ax1) + v0 = T.axis.spatial( + 1376, + ax1_0_fused_ax1_1_fused_0 * 16 + + ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 + * 4 + + ax1_0_fused_ax1_1_fused_3 + + ax0, + ) + v1 = T.axis.spatial( + 4096, + u_fused_ax0_fused_fused_0 * 64 + + u_fused_ax0_fused_fused_1 + + ax1, + ) T.reads(lv575[v0, v1]) T.writes(lv575_local[v0, v1]) lv575_local[v0, v1] = lv575[v0, v1] - for ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1 in T.vectorized(8): + for ( + ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1 + ) in T.vectorized(8): with T.block("matmul_rf_update"): - vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused = T.axis.spatial(32, ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 * 8 + ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1) - v0 = T.axis.spatial(4096, u_fused_ax0_fused_fused_0 * 64 + u_fused_ax0_fused_fused_1) - vax1_0_fused_ax1_1_fused_0, vax1_0_fused_ax1_1_fused_1, vax1_0_fused_ax1_1_fused_3 = T.axis.remap("RRR", [ax1_0_fused_ax1_1_fused_0, ax1_0_fused_ax1_1_fused_1, ax1_0_fused_ax1_1_fused_3]) - T.reads(var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, 0, 0, v0], lv574[0, 0, vax1_0_fused_ax1_1_fused_0 * 128 + vax1_0_fused_ax1_1_fused_1 * 128 + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused // 8 * 32 + vax1_0_fused_ax1_1_fused_3 * 8 + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused % 8], lv575_local[vax1_0_fused_ax1_1_fused_0 * 16 + vax1_0_fused_ax1_1_fused_1 * 16 + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused // 8 * 4 + vax1_0_fused_ax1_1_fused_3, v0], lv576_local[vax1_0_fused_ax1_1_fused_0 * 4 + vax1_0_fused_ax1_1_fused_1 * 4 + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused // 8 + vax1_0_fused_ax1_1_fused_3 // 4, v0]) - T.writes(var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, 0, 0, v0]) - var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, 0, 0, v0] = var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, 0, 0, v0] + lv574[0, 0, vax1_0_fused_ax1_1_fused_0 * 128 + vax1_0_fused_ax1_1_fused_1 * 128 + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused // 8 * 32 + vax1_0_fused_ax1_1_fused_3 * 8 + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused % 8] * ((T.Cast("float16", T.bitwise_and(T.shift_right(lv575_local[vax1_0_fused_ax1_1_fused_0 * 16 + vax1_0_fused_ax1_1_fused_1 * 16 + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused // 8 * 4 + vax1_0_fused_ax1_1_fused_3, v0], T.Cast("uint32", (vax1_0_fused_ax1_1_fused_0 * 128 + vax1_0_fused_ax1_1_fused_1 * 128 + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused // 8 * 32 + vax1_0_fused_ax1_1_fused_3 * 8 + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused % 8) % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv576_local[vax1_0_fused_ax1_1_fused_0 * 4 + vax1_0_fused_ax1_1_fused_1 * 4 + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused // 8 + vax1_0_fused_ax1_1_fused_3 // 4, v0]) + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused = T.axis.spatial( + 32, + ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 + * 8 + + ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1, + ) + v0 = T.axis.spatial( + 4096, + u_fused_ax0_fused_fused_0 * 64 + u_fused_ax0_fused_fused_1, + ) + ( + vax1_0_fused_ax1_1_fused_0, + vax1_0_fused_ax1_1_fused_1, + vax1_0_fused_ax1_1_fused_3, + ) = T.axis.remap( + "RRR", + [ + ax1_0_fused_ax1_1_fused_0, + ax1_0_fused_ax1_1_fused_1, + ax1_0_fused_ax1_1_fused_3, + ], + ) + T.reads( + var_matmul_intermediate_rf_local[ + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, + 0, + 0, + v0, + ], + lv574[ + 0, + 0, + vax1_0_fused_ax1_1_fused_0 * 128 + + vax1_0_fused_ax1_1_fused_1 * 128 + + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused + // 8 + * 32 + + vax1_0_fused_ax1_1_fused_3 * 8 + + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused + % 8, + ], + lv575_local[ + vax1_0_fused_ax1_1_fused_0 * 16 + + vax1_0_fused_ax1_1_fused_1 * 16 + + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused + // 8 + * 4 + + vax1_0_fused_ax1_1_fused_3, + v0, + ], + lv576_local[ + vax1_0_fused_ax1_1_fused_0 * 4 + + vax1_0_fused_ax1_1_fused_1 * 4 + + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused + // 8 + + vax1_0_fused_ax1_1_fused_3 // 4, + v0, + ], + ) + T.writes( + var_matmul_intermediate_rf_local[ + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, + 0, + 0, + v0, + ], + ) + var_matmul_intermediate_rf_local[ + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, + 0, + 0, + v0, + ] = var_matmul_intermediate_rf_local[ + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, + 0, + 0, + v0, + ] + lv574[ + 0, + 0, + vax1_0_fused_ax1_1_fused_0 * 128 + + vax1_0_fused_ax1_1_fused_1 * 128 + + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused + // 8 + * 32 + + vax1_0_fused_ax1_1_fused_3 * 8 + + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused + % 8, + ] * ( + ( + T.Cast( + "float16", + T.bitwise_and( + T.shift_right( + lv575_local[ + vax1_0_fused_ax1_1_fused_0 * 16 + + vax1_0_fused_ax1_1_fused_1 * 16 + + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused + // 8 + * 4 + + vax1_0_fused_ax1_1_fused_3, + v0, + ], + T.Cast( + "uint32", + ( + vax1_0_fused_ax1_1_fused_0 * 128 + + vax1_0_fused_ax1_1_fused_1 * 128 + + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused + // 8 + * 32 + + vax1_0_fused_ax1_1_fused_3 * 8 + + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused + % 8 + ) + % 8, + ) + * T.uint32(4), + ), + T.uint32(15), + ), + ) + - T.float16(7) + ) + * lv576_local[ + vax1_0_fused_ax1_1_fused_0 * 4 + + vax1_0_fused_ax1_1_fused_1 * 4 + + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused + // 8 + + vax1_0_fused_ax1_1_fused_3 // 4, + v0, + ] + ) for ax2 in T.thread_binding(64, thread="threadIdx.x"): for ax0 in T.thread_binding(4, thread="threadIdx.y"): with T.block("matmul_rf_init"): - vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 = T.axis.spatial(4, ax0) + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 = ( + T.axis.spatial(4, ax0) + ) v0 = T.axis.spatial(4096, u_fused_ax0_fused_fused_0 * 64 + ax2) T.reads() - T.writes(var_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, 0, 0, v0]) - var_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, 0, 0, v0] = T.float16(0) - for ax1 in T.serial(8, annotations={"pragma_auto_unroll_max_step": 8, "pragma_unroll_explicit": 1}): + T.writes( + var_matmul_intermediate_rf_local_1[ + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, + 0, + 0, + v0, + ] + ) + var_matmul_intermediate_rf_local_1[ + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, 0, 0, v0 + ] = T.float16(0) + for ax1 in T.serial( + 8, + annotations={"pragma_auto_unroll_max_step": 8, "pragma_unroll_explicit": 1}, + ): with T.block("matmul_rf_update"): - vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1 = T.axis.remap("SR", [ax0, ax1]) + ( + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1, + ) = T.axis.remap("SR", [ax0, ax1]) v0 = T.axis.spatial(4096, u_fused_ax0_fused_fused_0 * 64 + ax2) - T.reads(var_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, 0, 0, v0], var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 * 8 + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1, 0, 0, v0]) - T.writes(var_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, 0, 0, v0]) - var_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, 0, 0, v0] = var_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, 0, 0, v0] + var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 * 8 + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1, 0, 0, v0] + T.reads( + var_matmul_intermediate_rf_local_1[ + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, + 0, + 0, + v0, + ], + var_matmul_intermediate_rf_local[ + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 * 8 + + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1, + 0, + 0, + v0, + ], + ) + T.writes( + var_matmul_intermediate_rf_local_1[ + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, + 0, + 0, + v0, + ] + ) + var_matmul_intermediate_rf_local_1[ + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, + 0, + 0, + v0, + ] = ( + var_matmul_intermediate_rf_local_1[ + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, + 0, + 0, + v0, + ] + + var_matmul_intermediate_rf_local[ + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 * 8 + + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1, + 0, + 0, + v0, + ] + ) for ax1 in T.thread_binding(64, thread="threadIdx.x"): for ax0 in T.thread_binding(4, thread="threadIdx.y"): with T.block("matmul"): - vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 = T.axis.reduce(4, ax0) + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 = ( + T.axis.reduce(4, ax0) + ) v0 = T.axis.spatial(4096, u_fused_ax0_fused_fused_0 * 64 + ax1) - T.reads(var_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, 0, 0, v0]) + T.reads( + var_matmul_intermediate_rf_local_1[ + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, + 0, + 0, + v0, + ] + ) T.writes(var_matmul_intermediate_local[0, 0, v0]) with T.init(): var_matmul_intermediate_local[0, 0, v0] = T.float16(0) - var_matmul_intermediate_local[0, 0, v0] = var_matmul_intermediate_local[0, 0, v0] + var_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, 0, 0, v0] + var_matmul_intermediate_local[0, 0, v0] = ( + var_matmul_intermediate_local[0, 0, v0] + + var_matmul_intermediate_rf_local_1[ + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, + 0, + 0, + v0, + ] + ) for ax0_fused_0 in T.thread_binding(64, thread="threadIdx.x"): for ax0_fused_1 in range(1): with T.block("T_add"): - v0 = T.axis.spatial(4096, u_fused_ax0_fused_fused_0 * 64 + ax0_fused_0 + ax0_fused_1) + v0 = T.axis.spatial( + 4096, u_fused_ax0_fused_fused_0 * 64 + ax0_fused_0 + ax0_fused_1 + ) T.reads(lv570[0, 0, v0], var_matmul_intermediate_local[0, 0, v0]) T.writes(p_output0_intermediate[0, 0, v0]) - p_output0_intermediate[0, 0, v0] = lv570[0, 0, v0] + var_matmul_intermediate_local[0, 0, v0] + p_output0_intermediate[0, 0, v0] = ( + lv570[0, 0, v0] + var_matmul_intermediate_local[0, 0, v0] + ) mod = tvm.IRModule({"main": before}) with Target("opencl", host="llvm -mtriple=aarch64-linux-android"): From 63d3a676627834ca402bfd76a2fa37fe374b1bc6 Mon Sep 17 00:00:00 2001 From: krishnaraj36 Date: Mon, 29 Apr 2024 09:50:25 +0530 Subject: [PATCH 4/7] Updated naming of schedule func --- python/tvm/dlight/gpu/gemv.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/python/tvm/dlight/gpu/gemv.py b/python/tvm/dlight/gpu/gemv.py index 693f4d3eaf0c..84544a68af6a 100644 --- a/python/tvm/dlight/gpu/gemv.py +++ b/python/tvm/dlight/gpu/gemv.py @@ -208,15 +208,15 @@ def apply( # pylint: disable=too-many-locals,too-many-branches,too-many-return- elif is_inner_reduction: self.sch_inner_reduction(sch, target, block, vector_input_buffers, epilogue) return sch - elif target.kind.name == "opencl" and "android" in str(target.host): - ret = self.sch_adreno_inner_reduction( + elif target.kind.name == "opencl": + ret = self.sch_outer_reduction( sch, target, block, vector_input_buffers, epilogue ) if ret is None: - return self.sch_outer_reduction(sch, target, block, vector_input_buffers, epilogue) + return self.sch_outer_reduction_fallback(sch, target, block, vector_input_buffers, epilogue) return sch else: - return self.sch_outer_reduction(sch, target, block, vector_input_buffers, epilogue) + return self.sch_outer_reduction_fallback(sch, target, block, vector_input_buffers, epilogue) def sch_inner_reduction( # pylint: disable=too-many-arguments, invalid-name, unused-argument self, @@ -551,7 +551,7 @@ def apply( SUPPORT_WARP_SHUFFLE=SUPPORT_WARP_SHUFFLE, ) - def sch_adreno_inner_reduction( # pylint: disable=too-many-arguments, invalid-name, unused-argument + def sch_outer_reduction( # pylint: disable=too-many-arguments, invalid-name, unused-argument self, sch: tir.Schedule, target: Target, @@ -737,7 +737,7 @@ def apply( LOAD_V_TILE=LOAD_V_TILE, ) - def sch_outer_reduction( # pylint: disable=too-many-arguments, invalid-name, unused-argument + def sch_outer_reduction_fallback( # pylint: disable=too-many-arguments, invalid-name, unused-argument self, sch: tir.Schedule, target: Target, From bddc0dc24f286555050191ca4b7c91ce379efac0 Mon Sep 17 00:00:00 2001 From: krishnaraj36 Date: Mon, 29 Apr 2024 10:08:08 +0530 Subject: [PATCH 5/7] fixed lint error --- python/tvm/dlight/gpu/gemv.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/python/tvm/dlight/gpu/gemv.py b/python/tvm/dlight/gpu/gemv.py index 84544a68af6a..01a2b68865ef 100644 --- a/python/tvm/dlight/gpu/gemv.py +++ b/python/tvm/dlight/gpu/gemv.py @@ -209,14 +209,16 @@ def apply( # pylint: disable=too-many-locals,too-many-branches,too-many-return- self.sch_inner_reduction(sch, target, block, vector_input_buffers, epilogue) return sch elif target.kind.name == "opencl": - ret = self.sch_outer_reduction( - sch, target, block, vector_input_buffers, epilogue - ) + ret = self.sch_outer_reduction(sch, target, block, vector_input_buffers, epilogue) if ret is None: - return self.sch_outer_reduction_fallback(sch, target, block, vector_input_buffers, epilogue) + return self.sch_outer_reduction_fallback( + sch, target, block, vector_input_buffers, epilogue + ) return sch else: - return self.sch_outer_reduction_fallback(sch, target, block, vector_input_buffers, epilogue) + return self.sch_outer_reduction_fallback( + sch, target, block, vector_input_buffers, epilogue + ) def sch_inner_reduction( # pylint: disable=too-many-arguments, invalid-name, unused-argument self, From 3392a164201a92e3e6bce12d76034b7706b2e2c6 Mon Sep 17 00:00:00 2001 From: krishnaraj36 Date: Mon, 29 Apr 2024 11:08:28 +0530 Subject: [PATCH 6/7] Corrected comments --- python/tvm/dlight/gpu/gemv.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/tvm/dlight/gpu/gemv.py b/python/tvm/dlight/gpu/gemv.py index 01a2b68865ef..44b4edf33d2e 100644 --- a/python/tvm/dlight/gpu/gemv.py +++ b/python/tvm/dlight/gpu/gemv.py @@ -561,7 +561,7 @@ def sch_outer_reduction( # pylint: disable=too-many-arguments, invalid-name, un vector_input_buffers: List[tir.Buffer], epilogue_info: Optional[BlockInfo], ): - """Schedule the inner reduction block.""" + """Schedule the outer reduction block.""" def get_max_factor(n, factors): factors = sorted(factors, reverse=True) @@ -683,7 +683,6 @@ def apply( sch.bind(ts, "threadIdx.x") sch.set_scope(block, 0, "local") return sch - # return sch.mod["main"].with_attr("tir.is_scheduled", 1) # Specify the `len_tx` and `len_ty` according to the loop extent batch, s, r, c = sch.get_loops(block=block) From a89de129c885a674d9fbda6e230507186eca4f8f Mon Sep 17 00:00:00 2001 From: krishnaraj36 Date: Mon, 29 Apr 2024 11:27:51 +0530 Subject: [PATCH 7/7] Update gemv.py --- python/tvm/dlight/gpu/gemv.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/dlight/gpu/gemv.py b/python/tvm/dlight/gpu/gemv.py index 44b4edf33d2e..cbef6235c098 100644 --- a/python/tvm/dlight/gpu/gemv.py +++ b/python/tvm/dlight/gpu/gemv.py @@ -208,7 +208,7 @@ def apply( # pylint: disable=too-many-locals,too-many-branches,too-many-return- elif is_inner_reduction: self.sch_inner_reduction(sch, target, block, vector_input_buffers, epilogue) return sch - elif target.kind.name == "opencl": + elif target.kind.name == "opencl" and "android" in str(target.host): ret = self.sch_outer_reduction(sch, target, block, vector_input_buffers, epilogue) if ret is None: return self.sch_outer_reduction_fallback(