diff --git a/python/tvm/dlight/gpu/low_batch_gemv.py b/python/tvm/dlight/gpu/low_batch_gemv.py
index 696722c3f016..20911f0e7d9c 100644
--- a/python/tvm/dlight/gpu/low_batch_gemv.py
+++ b/python/tvm/dlight/gpu/low_batch_gemv.py
@@ -500,7 +500,7 @@ def apply(
                 sch.set_scope(block, 0, "shared")
                 _, _, _, *s = sch.get_loops(epilogue)  # pylint: disable=invalid-name
                 _, tx = sch.split(sch.fuse(*s), factors=[None, TX])
-                sch.bind(tx, "threadIdx.x")
+                sch.bind(tx, TAG_S)
             else:
                 sch.reverse_compute_at(epilogue, bx, preserve_unit_loops=True)
                 ts_tile_s = sch.fuse(*sch.get_loops(epilogue)[3:])
@@ -538,17 +538,16 @@ def apply(
             else:
                 TS, TR = 16, 32
         elif target.kind.name == "metal":
-            # Note that the following tile size is tuned on M2 Ultra for 7B
-            TAG_S, TAG_R = "threadIdx.x", "threadIdx.y"
-            VEC_C = 1
+            VEC_C = 4
             LOAD_V_SHARED = False
             LOAD_V_VEC = -1
-            UNROLL = 256
+            UNROLL = 8
             if isinstance(len_S, int):
                 if len_S > len_R:
-                    TS, TR = 2, 32
+                    TS, TR = 8, 32
                 else:
-                    TS, TR = 2, 64
+                    TAG_S, TAG_R = "threadIdx.x", "threadIdx.y"
+                    TS, TR = 8, 32
         elif target.kind.name == "rocm":
             VEC_C = 4
             LOAD_V_SHARED = True
diff --git a/tests/python/dlight/test_gpu_low_batch_gemv.py b/tests/python/dlight/test_gpu_low_batch_gemv.py
index 4b63cfddba3c..6072664b3a45 100644
--- a/tests/python/dlight/test_gpu_low_batch_gemv.py
+++ b/tests/python/dlight/test_gpu_low_batch_gemv.py
@@ -15,7 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=missing-docstring
-import pytest
 
 import tvm.testing
 from tvm import dlight as dl
@@ -65,82 +64,83 @@ def expected(lv429: T.Buffer((T.int64(4096), T.int64(3584)), "uint32"), lv430: T
         # with T.block("root"):
         dequantize_intermediate_intermediate_local = T.alloc_buffer((T.int64(4096), T.int64(28672)), "float16", scope="local")
         NT_matmul_intermediate_pad_local = T.alloc_buffer(((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), T.int64(1), T.int64(4096)), "float16", scope="local")
-        NT_matmul_intermediate_pad_rf_local = T.alloc_buffer((T.int64(64), (batch_size + T.int64(3)) // T.int64(4) * T.int64(4), T.int64(1), T.int64(4096)), "float16", scope="local")
-        NT_matmul_intermediate_pad_rf_local_1 = T.alloc_buffer((T.int64(64), (batch_size + T.int64(3)) // T.int64(4) * T.int64(4), T.int64(1), T.int64(4096)), "float16", scope="local")
+        NT_matmul_intermediate_pad_rf_local = T.alloc_buffer((T.int64(128), (batch_size + T.int64(3)) // T.int64(4) * T.int64(4), T.int64(1), T.int64(4096)), "float16", scope="local")
+        NT_matmul_intermediate_pad_rf_local_1 = T.alloc_buffer((T.int64(32), (batch_size + T.int64(3)) // T.int64(4) * T.int64(4), T.int64(1), T.int64(4096)), "float16", scope="local")
         for ax0_0 in T.thread_binding((batch_size + T.int64(3)) // T.int64(4), thread="blockIdx.y"):
-            for u_fused_ax1_fused_fused_0 in T.thread_binding(T.int64(1024), thread="blockIdx.x"):
-                for u_fused_ax1_fused_fused_1 in T.thread_binding(T.int64(2), thread="threadIdx.x"):
-                    for ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 in T.thread_binding(T.int64(64), thread="threadIdx.y"):
+            for u_fused_ax1_fused_fused_0 in T.thread_binding(T.int64(256), thread="blockIdx.x"):
+                for u_fused_ax1_fused_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"):
+                    for ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 in T.thread_binding(T.int64(32), thread="threadIdx.y"):
                         for ax0_1_init, u_fused_ax1_fused_fused_2_init in T.grid(T.int64(4), T.int64(2)):
-                            for ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1_init in T.vectorized(T.int64(1)):
+                            for ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1_init in T.vectorized(T.int64(4)):
                                 with T.block("NT_matmul_rf_init"):
-                                    vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused = T.axis.spatial(T.int64(64), ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 + ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1_init)
+                                    vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused = T.axis.spatial(T.int64(128), ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 * T.int64(4) + ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1_init)
                                     v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax0_1_init)
-                                    v1 = T.axis.spatial(T.int64(4096), u_fused_ax1_fused_fused_0 * T.int64(4) + u_fused_ax1_fused_fused_1 * T.int64(2) + u_fused_ax1_fused_fused_2_init)
+                                    v1 = T.axis.spatial(T.int64(4096), u_fused_ax1_fused_fused_0 * T.int64(16) + u_fused_ax1_fused_fused_1 * T.int64(2) + u_fused_ax1_fused_fused_2_init)
                                     T.reads()
                                     T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, v0, T.int64(0), v1])
                                     NT_matmul_intermediate_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, v0, T.int64(0), v1] = T.float16(0)
-                        for ax2_fused_u_fused_0 in T.serial(T.int64(56), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
+                        for ax2_fused_u_fused_0 in T.serial(T.int64(112), annotations={"pragma_auto_unroll_max_step": 8, "pragma_unroll_explicit": 1}):
                             for ax0_0_1, ax1 in T.grid(T.int64(2), T.int64(8)):
                                 for ax0_1 in T.vectorized(T.int64(1)):
                                     with T.block("dequantize"):
-                                        v0 = T.axis.spatial(T.int64(4096), u_fused_ax1_fused_fused_0 * T.int64(4) + u_fused_ax1_fused_fused_1 * T.int64(2) + ax0_0_1 + ax0_1)
-                                        v1 = T.axis.spatial(T.int64(28672), ax2_fused_u_fused_0 * T.int64(512) + ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 * T.int64(8) + ax1)
+                                        v0 = T.axis.spatial(T.int64(4096), u_fused_ax1_fused_fused_0 * T.int64(16) + u_fused_ax1_fused_fused_1 * T.int64(2) + ax0_0_1 + ax0_1)
+                                        v1 = T.axis.spatial(T.int64(28672), ax2_fused_u_fused_0 * T.int64(256) + ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 * T.int64(8) + ax1)
                                         T.reads(lv429[v0, v1 // T.int64(8)], lv430[v0, v1 // T.int64(32)])
                                         T.writes(dequantize_intermediate_intermediate_local[v0, v1])
                                         dequantize_intermediate_intermediate_local[v0, v1] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv429[v0, v1 // T.int64(8)], T.Cast("uint32", v1 % T.int64(8) * T.int64(4))), T.uint32(15))) - T.float16(7)) * lv430[v0, v1 // T.int64(32)]
-                            for ax0_1, u_fused_ax1_fused_fused_2, ax2_fused_u_fused_2 in T.grid(T.int64(4), T.int64(2), T.int64(8)):
-                                for ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1 in T.vectorized(T.int64(1)):
+                            for ax0_1, u_fused_ax1_fused_fused_2, ax2_fused_u_fused_2 in T.grid(T.int64(4), T.int64(2), T.int64(2)):
+                                for ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1 in T.vectorized(T.int64(4)):
                                     with T.block("NT_matmul_rf_update"):
-                                        vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused = T.axis.spatial(T.int64(64), ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 + ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1)
+                                        vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused = T.axis.spatial(T.int64(128), ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 * T.int64(4) + ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1)
                                         v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax0_1)
-                                        v1 = T.axis.spatial(T.int64(4096), u_fused_ax1_fused_fused_0 * T.int64(4) + u_fused_ax1_fused_fused_1 * T.int64(2) + u_fused_ax1_fused_fused_2)
+                                        v1 = T.axis.spatial(T.int64(4096), u_fused_ax1_fused_fused_0 * T.int64(16) + u_fused_ax1_fused_fused_1 * T.int64(2) + u_fused_ax1_fused_fused_2)
                                         vax2_fused_u_fused_0, vax2_fused_u_fused_2 = T.axis.remap("RR", [ax2_fused_u_fused_0, ax2_fused_u_fused_2])
-                                        T.reads(NT_matmul_intermediate_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, v0, T.int64(0), v1], lv807[v0, T.int64(0), vax2_fused_u_fused_0 * T.int64(512) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused * T.int64(8) + vax2_fused_u_fused_2], dequantize_intermediate_intermediate_local[v1, vax2_fused_u_fused_0 * T.int64(512) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused * T.int64(8) + vax2_fused_u_fused_2])
+                                        T.reads(NT_matmul_intermediate_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, v0, T.int64(0), v1], lv807[v0, T.int64(0), vax2_fused_u_fused_0 * T.int64(256) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_u_fused_2 * T.int64(4) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused % T.int64(4)], dequantize_intermediate_intermediate_local[v1, vax2_fused_u_fused_0 * T.int64(256) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_u_fused_2 * T.int64(4) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused % T.int64(4)])
                                         T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, v0, T.int64(0), v1])
-                                        NT_matmul_intermediate_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, v0, T.int64(0), v1] = NT_matmul_intermediate_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, v0, T.int64(0), v1] + T.if_then_else(v0 < batch_size, lv807[v0, T.int64(0), vax2_fused_u_fused_0 * T.int64(512) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused * T.int64(8) + vax2_fused_u_fused_2], T.float16(0)) * dequantize_intermediate_intermediate_local[v1, vax2_fused_u_fused_0 * T.int64(512) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused * T.int64(8) + vax2_fused_u_fused_2]
-                for ax3_fused_0 in T.thread_binding(T.int64(2), thread="threadIdx.x"):
-                    for ax0 in T.thread_binding(T.int64(64), thread="threadIdx.y"):
-                        for ax3_fused_1_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
+                                        NT_matmul_intermediate_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, v0, T.int64(0), v1] = NT_matmul_intermediate_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, v0, T.int64(0), v1] + T.if_then_else(v0 < batch_size, lv807[v0, T.int64(0), vax2_fused_u_fused_0 * T.int64(256) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_u_fused_2 * T.int64(4) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused % T.int64(4)], T.float16(0)) * dequantize_intermediate_intermediate_local[v1, vax2_fused_u_fused_0 * T.int64(256) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_u_fused_2 * T.int64(4) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused % T.int64(4)]
+                for ax3_fused_0_ax3_fused_1_fused in T.thread_binding(T.int64(8), thread="threadIdx.x"):
+                    for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.y"):
+                        for ax3_fused_2_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step": 8, "pragma_unroll_explicit": 1}):
                             for ax2 in range(T.int64(4)):
-                                for ax3_fused_1_1 in T.vectorized(T.int64(2)):
+                                for ax3_fused_2_1 in T.vectorized(T.int64(2)):
                                     with T.block("NT_matmul_rf_init"):
-                                        vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 = T.axis.spatial(T.int64(64), ax0)
+                                        vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 = T.axis.spatial(T.int64(32), ax0)
                                         v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax2)
-                                        v1 = T.axis.spatial(T.int64(4096), u_fused_ax1_fused_fused_0 * T.int64(4) + ax3_fused_0 * T.int64(2) + ax3_fused_1_0 * T.int64(2) + ax3_fused_1_1)
+                                        v1 = T.axis.spatial(T.int64(4096), u_fused_ax1_fused_fused_0 * T.int64(16) + ax3_fused_0_ax3_fused_1_fused * T.int64(2) + ax3_fused_2_0 * T.int64(2) + ax3_fused_2_1)
                                         T.reads()
                                         T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, T.int64(0), v1])
                                         NT_matmul_intermediate_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, T.int64(0), v1] = T.float16(0)
-                                    for ax1 in range(T.int64(1)):
+                                    for ax1 in range(T.int64(4)):
                                         with T.block("NT_matmul_rf_update"):
                                             vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
                                             v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax2)
-                                            v1 = T.axis.spatial(T.int64(4096), u_fused_ax1_fused_fused_0 * T.int64(4) + ax3_fused_0 * T.int64(2) + ax3_fused_1_0 * T.int64(2) + ax3_fused_1_1)
-                                            T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, T.int64(0), v1], NT_matmul_intermediate_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1, v0, T.int64(0), v1])
+                                            v1 = T.axis.spatial(T.int64(4096), u_fused_ax1_fused_fused_0 * T.int64(16) + ax3_fused_0_ax3_fused_1_fused * T.int64(2) + ax3_fused_2_0 * T.int64(2) + ax3_fused_2_1)
+                                            T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, T.int64(0), v1], NT_matmul_intermediate_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 * T.int64(4) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1, v0, T.int64(0), v1])
                                             T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, T.int64(0), v1])
-                                            NT_matmul_intermediate_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, T.int64(0), v1] = NT_matmul_intermediate_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, T.int64(0), v1] + NT_matmul_intermediate_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1, v0, T.int64(0), v1]
-                for ax2_fused_1, ax1 in T.grid(T.int64(2), T.int64(4)):
-                    for ax2_fused_0 in T.thread_binding(T.int64(2), thread="threadIdx.x"):
-                        for ax0 in T.thread_binding(T.int64(64), thread="threadIdx.y"):
+                                            NT_matmul_intermediate_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, T.int64(0), v1] = NT_matmul_intermediate_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, T.int64(0), v1] + NT_matmul_intermediate_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 * T.int64(4) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1, v0, T.int64(0), v1]
+                for ax2_fused_2, ax1 in T.grid(T.int64(2), T.int64(4)):
+                    for ax2_fused_0_ax2_fused_1_fused in T.thread_binding(T.int64(8), thread="threadIdx.x"):
+                        for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.y"):
                             with T.block("NT_matmul"):
-                                vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 = T.axis.reduce(T.int64(64), ax0)
+                                vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 = T.axis.reduce(T.int64(32), ax0)
                                 v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax1)
-                                v1 = T.axis.spatial(T.int64(4096), u_fused_ax1_fused_fused_0 * T.int64(4) + ax2_fused_0 * T.int64(2) + ax2_fused_1)
+                                v1 = T.axis.spatial(T.int64(4096), u_fused_ax1_fused_fused_0 * T.int64(16) + ax2_fused_0_ax2_fused_1_fused * T.int64(2) + ax2_fused_2)
                                 T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, T.int64(0), v1])
                                 T.writes(NT_matmul_intermediate_pad_local[v0, T.int64(0), v1])
                                 with T.init():
                                     NT_matmul_intermediate_pad_local[v0, T.int64(0), v1] = T.float16(0)
                                 NT_matmul_intermediate_pad_local[v0, T.int64(0), v1] = NT_matmul_intermediate_pad_local[v0, T.int64(0), v1] + NT_matmul_intermediate_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, T.int64(0), v1]
                 for ax0 in range(T.int64(4)):
-                    for ax1_fused_0 in T.thread_binding(T.int64(2), thread="threadIdx.x"):
-                        for ax1_fused_1 in range(T.int64(2)):
+                    for ax1_fused_0_ax1_fused_1_fused in T.thread_binding(T.int64(8), thread="threadIdx.x"):
+                        for ax1_fused_2 in range(T.int64(2)):
                             with T.block("NT_matmul_intermediate_pad"):
                                 v0 = T.axis.spatial(batch_size, ax0_0 * T.int64(4) + ax0)
-                                v1 = T.axis.spatial(T.int64(4096), u_fused_ax1_fused_fused_0 * T.int64(4) + ax1_fused_0 * T.int64(2) + ax1_fused_1)
+                                v1 = T.axis.spatial(T.int64(4096), u_fused_ax1_fused_fused_0 * T.int64(16) + ax1_fused_0_ax1_fused_1_fused * T.int64(2) + ax1_fused_2)
                                 T.where((ax0_0 - (batch_size + T.int64(3)) // T.int64(4) < T.int64(0) or ax0_0 == T.int64(0)) and ax0_0 * T.int64(4) + ax0 < batch_size)
                                 T.reads(NT_matmul_intermediate_pad_local[v0, T.int64(0), v1])
                                 T.writes(NT_matmul_intermediate[v0, T.int64(0), v1])
                                 NT_matmul_intermediate[v0, T.int64(0), v1] = NT_matmul_intermediate_pad_local[v0, T.int64(0), v1]
+    # fmt: on
 
     mod = tvm.IRModule({"main": before})
     with Target("metal"):
@@ -176,70 +176,70 @@ def expected(var_A: T.handle, B: T.Buffer((T.int64(4096), T.int64(4096)), "float
         NT_matmul = T.match_buffer(var_NT_matmul, (batch_size, T.int64(1), T.int64(4096)), "float16")
         # with T.block("root"):
         NT_matmul_pad_local = T.alloc_buffer(((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), T.int64(1), T.int64(4096)), "float16", scope="local")
-        NT_matmul_pad_rf_local = T.alloc_buffer((T.int64(64), (batch_size + T.int64(3)) // T.int64(4) * T.int64(4), T.int64(1), T.int64(4096)), "float16", scope="local")
-        NT_matmul_pad_rf_local_1 = T.alloc_buffer((T.int64(64), (batch_size + T.int64(3)) // T.int64(4) * T.int64(4), T.int64(1), T.int64(4096)), "float16", scope="local")
+        NT_matmul_pad_rf_local = T.alloc_buffer((T.int64(128), (batch_size + T.int64(3)) // T.int64(4) * T.int64(4), T.int64(1), T.int64(4096)), "float16", scope="local")
+        NT_matmul_pad_rf_local_1 = T.alloc_buffer((T.int64(32), (batch_size + T.int64(3)) // T.int64(4) * T.int64(4), T.int64(1), T.int64(4096)), "float16", scope="local")
         for ax0_0 in T.thread_binding((batch_size + T.int64(3)) // T.int64(4), thread="blockIdx.y"):
-            for u_fused_ax1_fused_fused_0 in T.thread_binding(T.int64(1024), thread="blockIdx.x"):
-                for u_fused_ax1_fused_fused_1 in T.thread_binding(T.int64(2), thread="threadIdx.x"):
-                    for ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 in T.thread_binding(T.int64(64), thread="threadIdx.y"):
+            for u_fused_ax1_fused_fused_0 in T.thread_binding(T.int64(256), thread="blockIdx.x"):
+                for u_fused_ax1_fused_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"):
+                    for ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 in T.thread_binding(T.int64(32), thread="threadIdx.y"):
                         for ax0_1_init, u_fused_ax1_fused_fused_2_init in T.grid(T.int64(4), T.int64(2)):
-                            for ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1_init in T.vectorized(T.int64(1)):
+                            for ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1_init in T.vectorized(T.int64(4)):
                                 with T.block("NT_matmul_rf_init"):
-                                    vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused = T.axis.spatial(T.int64(64), ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 + ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1_init)
+                                    vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused = T.axis.spatial(T.int64(128), ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 * T.int64(4) + ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1_init)
                                     v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax0_1_init)
-                                    v1 = T.axis.spatial(T.int64(4096), u_fused_ax1_fused_fused_0 * T.int64(4) + u_fused_ax1_fused_fused_1 * T.int64(2) + u_fused_ax1_fused_fused_2_init)
+                                    v1 = T.axis.spatial(T.int64(4096), u_fused_ax1_fused_fused_0 * T.int64(16) + u_fused_ax1_fused_fused_1 * T.int64(2) + u_fused_ax1_fused_fused_2_init)
                                     T.reads()
                                     T.writes(NT_matmul_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, v0, T.int64(0), v1])
                                     NT_matmul_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, v0, T.int64(0), v1] = T.float16(0)
-                        for ax2_fused_u_fused_0 in T.serial(T.int64(8), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
-                            for ax0_1, u_fused_ax1_fused_fused_2, ax2_fused_u_fused_2 in T.grid(T.int64(4), T.int64(2), T.int64(8)):
-                                for ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1 in T.vectorized(T.int64(1)):
+                        for ax2_fused_u_fused_0 in T.serial(T.int64(16), annotations={"pragma_auto_unroll_max_step": 8, "pragma_unroll_explicit": 1}):
+                            for ax0_1, u_fused_ax1_fused_fused_2, ax2_fused_u_fused_2 in T.grid(T.int64(4), T.int64(2), T.int64(2)):
+                                for ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1 in T.vectorized(T.int64(4)):
                                     with T.block("NT_matmul_rf_update"):
-                                        vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused = T.axis.spatial(T.int64(64), ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 + ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1)
+                                        vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused = T.axis.spatial(T.int64(128), ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 * T.int64(4) + ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1)
                                         v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax0_1)
-                                        v1 = T.axis.spatial(T.int64(4096), u_fused_ax1_fused_fused_0 * T.int64(4) + u_fused_ax1_fused_fused_1 * T.int64(2) + u_fused_ax1_fused_fused_2)
+                                        v1 = T.axis.spatial(T.int64(4096), u_fused_ax1_fused_fused_0 * T.int64(16) + u_fused_ax1_fused_fused_1 * T.int64(2) + u_fused_ax1_fused_fused_2)
                                         vax2_fused_u_fused_0, vax2_fused_u_fused_2 = T.axis.remap("RR", [ax2_fused_u_fused_0, ax2_fused_u_fused_2])
-                                        T.reads(NT_matmul_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, v0, T.int64(0), v1], A[v0, T.int64(0), vax2_fused_u_fused_0 * T.int64(512) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused * T.int64(8) + vax2_fused_u_fused_2], B[v1, vax2_fused_u_fused_0 * T.int64(512) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused * T.int64(8) + vax2_fused_u_fused_2])
+                                        T.reads(NT_matmul_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, v0, T.int64(0), v1], A[v0, T.int64(0), vax2_fused_u_fused_0 * T.int64(256) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_u_fused_2 * T.int64(4) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused % T.int64(4)], B[v1, vax2_fused_u_fused_0 * T.int64(256) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_u_fused_2 * T.int64(4) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused % T.int64(4)])
                                         T.writes(NT_matmul_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, v0, T.int64(0), v1])
-                                        NT_matmul_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, v0, T.int64(0), v1] = NT_matmul_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, v0, T.int64(0), v1] + T.if_then_else(v0 < batch_size, A[v0, T.int64(0), vax2_fused_u_fused_0 * T.int64(512) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused * T.int64(8) + vax2_fused_u_fused_2], T.float16(0)) * B[v1, vax2_fused_u_fused_0 * T.int64(512) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused * T.int64(8) + vax2_fused_u_fused_2]
-                for ax3_fused_0 in T.thread_binding(T.int64(2), thread="threadIdx.x"):
-                    for ax0 in T.thread_binding(T.int64(64), thread="threadIdx.y"):
-                        for ax3_fused_1_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
+                                        NT_matmul_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, v0, T.int64(0), v1] = NT_matmul_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, v0, T.int64(0), v1] + T.if_then_else(v0 < batch_size, A[v0, T.int64(0), vax2_fused_u_fused_0 * T.int64(256) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_u_fused_2 * T.int64(4) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused % T.int64(4)], T.float16(0)) * B[v1, vax2_fused_u_fused_0 * T.int64(256) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_u_fused_2 * T.int64(4) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused % T.int64(4)]
+                for ax3_fused_0_ax3_fused_1_fused in T.thread_binding(T.int64(8), thread="threadIdx.x"):
+                    for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.y"):
+                        for ax3_fused_2_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step": 8, "pragma_unroll_explicit": 1}):
                             for ax2 in range(T.int64(4)):
-                                for ax3_fused_1_1 in T.vectorized(T.int64(2)):
+                                for ax3_fused_2_1 in T.vectorized(T.int64(2)):
                                     with T.block("NT_matmul_rf_init"):
-                                        vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 = T.axis.spatial(T.int64(64), ax0)
+                                        vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 = T.axis.spatial(T.int64(32), ax0)
                                         v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax2)
-                                        v1 = T.axis.spatial(T.int64(4096), u_fused_ax1_fused_fused_0 * T.int64(4) + ax3_fused_0 * T.int64(2) + ax3_fused_1_0 * T.int64(2) + ax3_fused_1_1)
+                                        v1 = T.axis.spatial(T.int64(4096), u_fused_ax1_fused_fused_0 * T.int64(16) + ax3_fused_0_ax3_fused_1_fused * T.int64(2) + ax3_fused_2_0 * T.int64(2) + ax3_fused_2_1)
                                         T.reads()
                                         T.writes(NT_matmul_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, T.int64(0), v1])
                                         NT_matmul_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, T.int64(0), v1] = T.float16(0)
-                                    for ax1 in range(T.int64(1)):
+                                    for ax1 in range(T.int64(4)):
                                         with T.block("NT_matmul_rf_update"):
                                             vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
                                             v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax2)
-                                            v1 = T.axis.spatial(T.int64(4096), u_fused_ax1_fused_fused_0 * T.int64(4) + ax3_fused_0 * T.int64(2) + ax3_fused_1_0 * T.int64(2) + ax3_fused_1_1)
-                                            T.reads(NT_matmul_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, T.int64(0), v1], NT_matmul_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1, v0, T.int64(0), v1])
+                                            v1 = T.axis.spatial(T.int64(4096), u_fused_ax1_fused_fused_0 * T.int64(16) + ax3_fused_0_ax3_fused_1_fused * T.int64(2) + ax3_fused_2_0 * T.int64(2) + ax3_fused_2_1)
+                                            T.reads(NT_matmul_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, T.int64(0), v1], NT_matmul_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 * T.int64(4) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1, v0, T.int64(0), v1])
                                             T.writes(NT_matmul_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, T.int64(0), v1])
-                                            NT_matmul_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, T.int64(0), v1] = NT_matmul_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, T.int64(0), v1] + NT_matmul_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1, v0, T.int64(0), v1]
-                for ax2_fused_1, ax1 in T.grid(T.int64(2), T.int64(4)):
-                    for ax2_fused_0 in T.thread_binding(T.int64(2), thread="threadIdx.x"):
-                        for ax0 in T.thread_binding(T.int64(64), thread="threadIdx.y"):
+                                            NT_matmul_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, T.int64(0), v1] = NT_matmul_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, T.int64(0), v1] + NT_matmul_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 * T.int64(4) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1, v0, T.int64(0), v1]
+                for ax2_fused_2, ax1 in T.grid(T.int64(2), T.int64(4)):
+                    for ax2_fused_0_ax2_fused_1_fused in T.thread_binding(T.int64(8), thread="threadIdx.x"):
+                        for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.y"):
                             with T.block("NT_matmul"):
-                                vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 = T.axis.reduce(T.int64(64), ax0)
+                                vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 = T.axis.reduce(T.int64(32), ax0)
                                 v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax1)
-                                v1 = T.axis.spatial(T.int64(4096), u_fused_ax1_fused_fused_0 * T.int64(4) + ax2_fused_0 * T.int64(2) + ax2_fused_1)
+                                v1 = T.axis.spatial(T.int64(4096), u_fused_ax1_fused_fused_0 * T.int64(16) + ax2_fused_0_ax2_fused_1_fused * T.int64(2) + ax2_fused_2)
                                 T.reads(NT_matmul_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, T.int64(0), v1])
                                 T.writes(NT_matmul_pad_local[v0, T.int64(0), v1])
                                 with T.init():
                                     NT_matmul_pad_local[v0, T.int64(0), v1] = T.float16(0)
                                 NT_matmul_pad_local[v0, T.int64(0), v1] = NT_matmul_pad_local[v0, T.int64(0), v1] + NT_matmul_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, T.int64(0), v1]
                 for ax0 in range(T.int64(4)):
-                    for ax1_fused_0 in T.thread_binding(T.int64(2), thread="threadIdx.x"):
-                        for ax1_fused_1 in range(T.int64(2)):
+                    for ax1_fused_0_ax1_fused_1_fused in T.thread_binding(T.int64(8), thread="threadIdx.x"):
+                        for ax1_fused_2 in range(T.int64(2)):
                             with T.block("NT_matmul_pad"):
                                 v0 = T.axis.spatial(batch_size, ax0_0 * T.int64(4) + ax0)
-                                v1 = T.axis.spatial(T.int64(4096), u_fused_ax1_fused_fused_0 * T.int64(4) + ax1_fused_0 * T.int64(2) + ax1_fused_1)
+                                v1 = T.axis.spatial(T.int64(4096), u_fused_ax1_fused_fused_0 * T.int64(16) + ax1_fused_0_ax1_fused_1_fused * T.int64(2) + ax1_fused_2)
                                 T.where((ax0_0 - (batch_size + T.int64(3)) // T.int64(4) < T.int64(0) or ax0_0 == T.int64(0)) and ax0_0 * T.int64(4) + ax0 < batch_size)
                                 T.reads(NT_matmul_pad_local[v0, T.int64(0), v1])
                                 T.writes(NT_matmul[v0, T.int64(0), v1])
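
How the new Metal tile maps onto the expected schedules above (a reading aid, not part of the patch): TS = 8 becomes the threadIdx.x extent and TR = 32 the threadIdx.y extent, while VEC_C = 4 is the T.vectorized extent, so the first rfactor stage keeps TR * VEC_C = 128 partial sums and the second stage folds 128 / 32 = 4 of them per thread (the for ax1 in range(T.int64(4)) loop). UNROLL = 8 surfaces as "pragma_auto_unroll_max_step": 8, and with 16 output columns per block the 4096-wide output now takes 4096 / 16 = 256 blockIdx.x blocks instead of 1024.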
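
A minimal driver sketch showing how these expected schedules are produced, assuming the pattern this test file already uses (before is either unscheduled workload above, expected the matching TVMScript; the bucket size 4 passed to LowBatchGEMV is an assumption inferred from the (batch_size + 3) // 4 * 4 padding in the expected IR):

    import tvm
    from tvm import dlight as dl
    from tvm.target import Target

    # Wrap the unscheduled TIR function in a module, then apply the dlight
    # low-batch GEMV rule under a Metal target so that the tuning branch
    # changed in this patch is the one that fires.
    mod = tvm.IRModule({"main": before})
    with Target("metal"):
        mod = dl.ApplyDefaultSchedule(dl.gpu.LowBatchGEMV(4))(mod)  # bucket=4 assumed

    # The scheduled function should match the expected TVMScript above.
    tvm.ir.assert_structural_equal(mod["main"], expected)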