diff --git a/examples/plot_layout/fragment_mfma_load_a.py b/examples/plot_layout/fragment_mfma_load_a.py new file mode 100644 index 000000000..2c3b282a6 --- /dev/null +++ b/examples/plot_layout/fragment_mfma_load_a.py @@ -0,0 +1,133 @@ +import tilelang.language as T +from typing import Literal, Callable +from tvm.tir import IndexMap +from tilelang.intrinsics.utils import get_mma_micro_size + +from tilelang.intrinsics.mfma_layout import ( + shared_16x4_to_local_64x1_layout_A, + shared_16x16_to_local_64x4_layout_A, + shared_16x32_to_local_64x8_layout_A, + shared_16x64_to_local_64x16_layout_A, +) + + +def make_mfma_load_base_layout(dtype: str = "float16", + matrix: Literal["A", "B"] = "A", + k_dim: int = 16, + transposed: bool = False) -> T.Fragment: + """ + Create a layout function for storing MFMA results into a fragment buffer. + This layout is used in conjunction with `inverse_mfma_store_layout` to + map fragment indices to threads and local indices. + + Parameters + ---------- + dtype : str + The data type of the matrix. + matrix : Literal["A", "B"] + The mfma operand to be loaded. + k_dim : int + The k dimension of the mfma. + transposed : bool + Whether the matrix is transposed, by default False. + + Returns + ------- + T.Fragment + Describes how threads and indices in fragment are laid out. + + """ + + assert matrix in ["A", "B"], "matrix should be either A or B" + # s represents spatial axis + # r represents reduction axis + # sr represents the two dims are spatial + reduction + # rs represents the two dims are reduction + spatial + transform_func_sr_a: Callable = None + transform_func_sr_b: Callable = None + + if k_dim == 4: + transform_func_sr_a = shared_16x4_to_local_64x1_layout_A + transform_func_sr_b = shared_16x4_to_local_64x1_layout_A + elif k_dim == 16: + transform_func_sr_a = shared_16x16_to_local_64x4_layout_A + transform_func_sr_b = shared_16x16_to_local_64x4_layout_A + elif k_dim == 32: + transform_func_sr_a = shared_16x32_to_local_64x8_layout_A + transform_func_sr_b = shared_16x32_to_local_64x8_layout_A + elif k_dim == 64: + transform_func_sr_a = shared_16x64_to_local_64x16_layout_A + transform_func_sr_b = shared_16x64_to_local_64x16_layout_A + else: + raise ValueError("k_dim must be 4 or 16 or 32 or 64 currently") + + is_sr_conditions = [False] + is_sr_conditions.append(matrix == "A" and not transposed) + is_sr_conditions.append(matrix == "B" and transposed) + is_sr_axis_order = any(is_sr_conditions) + + micro_size_x, micro_size_y, micro_size_k = get_mma_micro_size(dtype) + + # the layout of mma.sync is row.col. + # so the b matrix expected a transposed basic layout + transform_func: Callable = None + if matrix == "A": + transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a( + j, i) + micro_size_s, micro_size_r = micro_size_x, micro_size_k + elif matrix == "B": + transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b( + j, i) + micro_size_s, micro_size_r = micro_size_k, micro_size_y + else: + raise ValueError(f"Unsupported matrix {matrix}") + + inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype="int32") + + def forward_thread(i: int, j: int) -> int: + """ + Given the row index `i` and column index `j` in the fragment, + """ + lane_id, _ = inverse_mma_load_layout.map_indices([i, j]) + return lane_id + + def forward_index(i: int, j: int) -> int: + """ + Given the row index `i` and column index `j` in the fragment, + """ + _, local_id = inverse_mma_load_layout.map_indices([i, j]) + return local_id + + base_fragment = T.Fragment( + [micro_size_s, micro_size_r] if is_sr_axis_order else [micro_size_r, micro_size_s], + forward_thread_fn=forward_thread, + forward_index_fn=forward_index, + ) + return base_fragment + + +block_rows = 2 +block_cols = 2 +warp_rows = 2 +warp_cols = 2 +chunk = 2 + +from tilelang.tools import plot_layout + +# ldmatrix layout 16x16 +base_layout = make_mfma_load_base_layout(dtype="float16", matrix="A", transposed=False) +print(base_layout) +plot_layout(base_layout, name="base_layout") + +# warp layout 32x32 +warp_layout = base_layout.repeat([warp_rows, warp_cols], + repeat_on_thread=False, + lower_dim_first=False) +print(warp_layout) +plot_layout(warp_layout, name="warp_layout") + +# block layout 64x32 +block_layout = warp_layout.repeat([block_rows, 1], repeat_on_thread=True, + lower_dim_first=True).replicate(block_cols) +print(block_layout) +plot_layout(block_layout, name="block_layout") diff --git a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_amd.py b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_amd.py new file mode 100644 index 000000000..15aa33c8e --- /dev/null +++ b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_amd.py @@ -0,0 +1,501 @@ +from tilelang import tvm as tvm +import tilelang.testing + + +def matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, +): + A_shape = (K, M) if trans_A else (M, K) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + + import tilelang.language as T + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + if trans_A: + T.copy(A[k * block_K, by * block_M], A_shared) + else: + T.copy(A[by * block_M, k * block_K], A_shared) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared) + else: + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm_v2(A_shared, B_shared, C_local, trans_A, trans_B) + # T.gemm(A_shared, B_shared, C_local, trans_A, trans_B) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def run_gemm_ss( + M, + N, + K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + block_M, + block_N, + block_K, + num_stages=3, + num_threads=256, +): + program = matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + num_stages, + num_threads, + ) + + kernel = tilelang.compile( + program, + out_idx=[2], + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }) + + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) + + def ref_program(A, B): + import torch + + if trans_A: + A = A.T + if trans_B: + B = B.T + C = torch.matmul(A.to(torch.float), B.to(torch.float)) + C = C.to(torch.__getattribute__(out_dtype)) + return C + + profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2) + + latency = profiler.do_bench(profiler.func, warmup=100) + print(f"GEMM SS latency: {latency} ms") + + +def test_gemm_ss(): + # GEMM tests for float16 + run_gemm_ss(1024, 1024, 1024, False, False, "float16", "float16", "float32", 128, 128, 32) + run_gemm_ss(1024, 1024, 1024, False, True, "float16", "float16", "float32", 128, 128, 32) + run_gemm_ss(1024, 1024, 1024, True, False, "float16", "float16", "float32", 128, 128, 32) + run_gemm_ss(1024, 1024, 1024, True, True, "float16", "float16", "float32", 128, 128, 32) + + # GEMM tests for int8 tests + run_gemm_ss(1024, 1024, 1024, False, True, "int8", "int8", "int32", 128, 128, 32) + run_gemm_ss(1024, 1024, 1024, False, False, "int8", "int8", "int32", 128, 128, 32) + run_gemm_ss(1024, 1024, 1024, True, False, "int8", "int8", "int32", 128, 128, 32) + run_gemm_ss(1024, 1024, 1024, True, True, "int8", "int8", "int32", 128, 128, 32) + + +def matmul_rs( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, +): + A_shape = (K, M) if trans_A else (M, K) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + A_frag_shape = A_shared_shape + + import tilelang.language as T + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + A_frag = T.alloc_fragment(A_frag_shape, in_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.clear(C_local) + T.annotate_layout({ + A_shared: tilelang.layout.make_swizzled_layout(A_shared), + }) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + if trans_A: + T.copy(A[k * block_K, by * block_M], A_shared) + else: + T.copy(A[by * block_M, k * block_K], A_shared) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared) + else: + T.copy(B[k * block_K, bx * block_N], B_shared) + T.copy(A_shared, A_frag) + T.gemm_v2(A_frag, B_shared, C_local, trans_A, trans_B) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def run_gemm_rs( + M, + N, + K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + block_M, + block_N, + block_K, + num_stages=3, + num_threads=256, +): + program = matmul_rs( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + num_stages, + num_threads, + ) + + kernel = tilelang.compile( + program, + out_idx=[2], + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }) + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) + + def ref_program(A, B): + import torch + + if trans_A: + A = A.T + if trans_B: + B = B.T + C = torch.matmul(A.to(torch.float), B.to(torch.float)) + C = C.to(torch.__getattribute__(out_dtype)) + return C + + profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2) + + +def test_gemm_rs(): + # GEMM tests for float16 + run_gemm_rs(1024, 1024, 1024, False, False, "float16", "float16", "float32", 128, 128, 32) + run_gemm_rs(1024, 1024, 1024, False, True, "float16", "float16", "float32", 128, 128, 32) + run_gemm_rs(1024, 1024, 1024, True, False, "float16", "float16", "float32", 128, 128, 32) + run_gemm_rs(1024, 1024, 1024, True, True, "float16", "float16", "float32", 128, 128, 32) + + # GEMM tests for int8 tests + run_gemm_rs(1024, 1024, 1024, False, True, "int8", "int8", "int32", 128, 128, 32) + run_gemm_rs(1024, 1024, 1024, False, False, "int8", "int8", "int32", 128, 128, 32) + run_gemm_rs(1024, 1024, 1024, True, False, "int8", "int8", "int32", 128, 128, 32) + run_gemm_rs(1024, 1024, 1024, True, True, "int8", "int8", "int32", 128, 128, 32) + + +def matmul_sr( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, +): + A_shape = (K, M) if trans_A else (M, K) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + B_frag_shape = B_shared_shape + + import tilelang.language as T + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + B_frag = T.alloc_fragment(B_frag_shape, in_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.clear(C_local) + T.annotate_layout({ + B_shared: tilelang.layout.make_swizzled_layout(B_shared), + }) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + if trans_A: + T.copy(A[k * block_K, by * block_M], A_shared) + else: + T.copy(A[by * block_M, k * block_K], A_shared) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared) + else: + T.copy(B[k * block_K, bx * block_N], B_shared) + T.copy(B_shared, B_frag) + T.gemm_v2(A_shared, B_frag, C_local, trans_A, trans_B) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def run_gemm_sr( + M, + N, + K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + block_M, + block_N, + block_K, + num_stages=3, + num_threads=256, +): + program = matmul_sr( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + num_stages, + num_threads, + ) + + kernel = tilelang.compile( + program, + out_idx=[2], + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }) + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) + + def ref_program(A, B): + import torch + + if trans_A: + A = A.T + if trans_B: + B = B.T + C = torch.matmul(A.to(torch.float), B.to(torch.float)) + C = C.to(torch.__getattribute__(out_dtype)) + return C + + profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2) + + +def test_gemm_sr(): + # GEMM tests for float16 + run_gemm_sr(1024, 1024, 1024, False, False, "float16", "float16", "float32", 128, 128, 32) + run_gemm_sr(1024, 1024, 1024, False, True, "float16", "float16", "float32", 128, 128, 32) + run_gemm_sr(1024, 1024, 1024, True, False, "float16", "float16", "float32", 128, 128, 32) + run_gemm_sr(1024, 1024, 1024, True, True, "float16", "float16", "float32", 128, 128, 32) + + # GEMM tests for int8 tests + run_gemm_sr(1024, 1024, 1024, False, True, "int8", "int8", "int32", 128, 128, 32) + run_gemm_sr(1024, 1024, 1024, False, False, "int8", "int8", "int32", 128, 128, 32) + run_gemm_sr(1024, 1024, 1024, True, False, "int8", "int8", "int32", 128, 128, 32) + run_gemm_sr(1024, 1024, 1024, True, True, "int8", "int8", "int32", 128, 128, 32) + + +def matmul_rr( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, +): + A_shape = (K, M) if trans_A else (M, K) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + A_frag_shape = A_shared_shape + B_frag_shape = B_shared_shape + + import tilelang.language as T + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + A_frag = T.alloc_fragment(A_frag_shape, in_dtype) + B_frag = T.alloc_fragment(B_frag_shape, in_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.clear(C_local) + T.annotate_layout({ + A_shared: tilelang.layout.make_swizzled_layout(A_shared), + B_shared: tilelang.layout.make_swizzled_layout(B_shared), + }) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + if trans_A: + T.copy(A[k * block_K, by * block_M], A_shared) + else: + T.copy(A[by * block_M, k * block_K], A_shared) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared) + else: + T.copy(B[k * block_K, bx * block_N], B_shared) + T.copy(A_shared, A_frag) + T.copy(B_shared, B_frag) + T.gemm_v2(A_frag, B_frag, C_local, trans_A, trans_B) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def run_gemm_rr( + M, + N, + K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + block_M, + block_N, + block_K, + num_stages=3, + num_threads=256, +): + program = matmul_rr( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + num_stages, + num_threads, + ) + + kernel = tilelang.compile( + program, + out_idx=[2], + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }) + print(program) + + print(kernel.get_kernel_source()) + + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) + + def ref_program(A, B): + import torch + + if trans_A: + A = A.T + if trans_B: + B = B.T + C = torch.matmul(A.to(torch.float), B.to(torch.float)) + C = C.to(torch.__getattribute__(out_dtype)) + return C + + profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2) + + +def test_gemm_rr(): + # GEMM tests for float16 + run_gemm_rr(1024, 1024, 1024, False, False, "float16", "float16", "float32", 128, 128, 32) + run_gemm_rr(1024, 1024, 1024, False, True, "float16", "float16", "float32", 128, 128, 32) + run_gemm_rr(1024, 1024, 1024, True, False, "float16", "float16", "float32", 128, 128, 32) + run_gemm_rr(1024, 1024, 1024, True, True, "float16", "float16", "float32", 128, 128, 32) + + # GEMM tests for int8 tests + run_gemm_rr(1024, 1024, 1024, False, True, "int8", "int8", "int32", 128, 128, 32) + run_gemm_rr(1024, 1024, 1024, False, False, "int8", "int8", "int32", 128, 128, 32) + run_gemm_rr(1024, 1024, 1024, True, False, "int8", "int8", "int32", 128, 128, 32) + run_gemm_rr(1024, 1024, 1024, True, True, "int8", "int8", "int32", 128, 128, 32) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/intrinsics/mfma_macro_generator.py b/tilelang/intrinsics/mfma_macro_generator.py index aa369980f..c1e0c3e9e 100644 --- a/tilelang/intrinsics/mfma_macro_generator.py +++ b/tilelang/intrinsics/mfma_macro_generator.py @@ -2,10 +2,32 @@ from tilelang import tvm as tvm import tilelang.language as T from tvm import DataType -from tvm.tir import PrimExpr +from tvm.tir import PrimExpr, IndexMap, Buffer, Var from tvm.runtime import convert from .utils import ( mfma_store_index_map,) +from typing import Literal, Callable + +from tilelang.utils import is_fragment + +from .mfma_layout import ( + shared_16x4_to_local_64x1_layout_A, + shared_4x16_to_local_64x1_layout_B, + shared_16x16_to_local_64x4_layout_A, + shared_16x16_to_local_64x4_layout_B, + shared_16x32_to_local_64x8_layout_A, + shared_16x32_to_local_64x8_layout_B, + shared_16x64_to_local_64x16_layout_A, + shared_16x64_to_local_64x16_layout_B, + thread_id_shared_access_64x1_to_16x4_layout_A, + thread_id_shared_access_64x1_to_4x16_layout_B, + thread_id_shared_access_64x4_to_16x16_layout_A, + thread_id_shared_access_64x4_to_16x16_layout_B, + thread_id_shared_access_64x8_to_16x32_layout_A, + thread_id_shared_access_64x8_to_16x32_layout_B, + thread_id_shared_access_64x16_to_16x64_layout_A, + thread_id_shared_access_64x16_to_16x64_layout_B, +) lift = convert @@ -53,6 +75,7 @@ def __init__( k_pack: int | None = None, is_m_first: bool | None = False, b_preshuffle: bool | None = False, + thread_var: Var | None = None, ): self.a_dtype = a_dtype self.b_dtype = b_dtype @@ -79,6 +102,7 @@ def __init__( self.reduce_k = reduce_k self.threads = (self.WARP_SIZE * (block_row_warps * block_col_warps) * reduce_k) self.num_elems_per_byte = num_elems_per_byte + self.thread_var = thread_var def _initialize_k_dim(self, a_dtype="float16"): if isinstance(a_dtype, str): @@ -147,24 +171,6 @@ def _initialize_b_preshuffle(self, b_preshuffle: bool | None = False): self.b_preshuffle = b_preshuffle def get_ldmatrix_index_map(self, is_b=False): - from .mfma_layout import ( - shared_16x4_to_local_64x1_layout_A, - shared_4x16_to_local_64x1_layout_B, - shared_16x16_to_local_64x4_layout_A, - shared_16x16_to_local_64x4_layout_B, - shared_16x32_to_local_64x8_layout_A, - shared_16x32_to_local_64x8_layout_B, - shared_16x64_to_local_64x16_layout_A, - shared_16x64_to_local_64x16_layout_B, - thread_id_shared_access_64x1_to_16x4_layout_A, - thread_id_shared_access_64x1_to_4x16_layout_B, - thread_id_shared_access_64x4_to_16x16_layout_A, - thread_id_shared_access_64x4_to_16x16_layout_B, - thread_id_shared_access_64x8_to_16x32_layout_A, - thread_id_shared_access_64x8_to_16x32_layout_B, - thread_id_shared_access_64x16_to_16x64_layout_A, - thread_id_shared_access_64x16_to_16x64_layout_B, - ) k_dim = self.k_dim * self.k_pack transposed = self.a_transposed if not is_b else self.b_transposed @@ -200,6 +206,22 @@ def get_ldmatrix_index_map(self, is_b=False): return index_map, reverse_index_map + def get_store_index_map(self, inverse: bool = False) -> IndexMap: + warp_size, local_size_c = self.WARP_SIZE, self.local_size_out + index_map = IndexMap.from_func(mfma_store_index_map, index_dtype="int32") + if not inverse: + return index_map + inverse_index_map = index_map.inverse([warp_size, local_size_c]) + return inverse_index_map + + def get_thread_binding(self): + if self.thread_var is None: + current_frame = T.KernelLaunchFrame.Current() + assert current_frame is not None, "Must be called in a T.Kernel Frame" + return current_frame.get_thread_binding() + else: + return self.thread_var + def extract_thread_binding(self, thread_id, is_m_first=None) -> tuple[PrimExpr, PrimExpr, PrimExpr]: @@ -238,8 +260,7 @@ def ldmatrix_a(self, A_local_buf, A_shared_buf, ki, rk=0): local_size_a = self.local_size_a k_pack = self.k_pack is_transposed = self.a_transposed - current_frame = T.KernelLaunchFrame.Current() - thread_binding = current_frame.get_thread_binding() + thread_binding = self.get_thread_binding() _, reverse_index_map = self.get_ldmatrix_index_map(is_b=False) @T.macro @@ -279,8 +300,7 @@ def ldmatrix_b(self, B_local_buf, B_shared_buf, ki, rk=0): local_size_b = self.local_size_b k_pack = self.k_pack is_transposed = self.b_transposed - current_frame = T.KernelLaunchFrame.Current() - thread_binding = current_frame.get_thread_binding() + thread_binding = self.get_thread_binding() _, reverse_index_map = self.get_ldmatrix_index_map(is_b=True) @T.macro @@ -316,7 +336,11 @@ def _warp_ldmatrix_b( return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk) - def mfma(self, A_local_buf, B_local_buf, C_local_buf): + def mfma(self, + A_local_buf: Buffer, + B_local_buf: Buffer, + C_local_buf: Buffer, + k_inner: PrimExpr | None = 0): warp_rows = self.warp_rows warp_cols = self.warp_cols local_size_a = self.local_size_a @@ -329,8 +353,15 @@ def mfma(self, A_local_buf, B_local_buf, C_local_buf): compute_b_dtype = b_dtype if local_size_b == 1 else f"{b_dtype}x{local_size_b}" compute_out_dtype = out_dtype if local_size_out == 1 else f"{out_dtype}x{local_size_out}" + a_is_fragment = is_fragment(A_local_buf) + b_is_fragment = is_fragment(B_local_buf) + a_local_stride: PrimExpr = k_inner * warp_rows * local_size_a if a_is_fragment else 0 + b_local_stride: PrimExpr = k_inner * warp_cols * local_size_b if b_is_fragment else 0 + + print(a_local_stride, b_local_stride) + @T.macro - def _warp_mma(A_local_buf, B_local_buf, C_local_buf): + def _warp_mfma(A_local_buf, B_local_buf, C_local_buf): for kp, i, j in T.grid(k_pack, warp_rows, warp_cols): T.tvm_mfma( mfma_suffix, @@ -340,15 +371,15 @@ def _warp_mma(A_local_buf, B_local_buf, C_local_buf): compute_b_dtype, compute_out_dtype, B_local_buf.data, - ((j * k_pack + kp) * local_size_b) // local_size_b, + (b_local_stride + (j * k_pack + kp) * local_size_b) // local_size_b, A_local_buf.data, - ((i * k_pack + kp) * local_size_a) // local_size_a, + (a_local_stride + (i * k_pack + kp) * local_size_a) // local_size_a, C_local_buf.data, (i * warp_cols * local_size_out + j * local_size_out) // local_size_out, dtype=compute_out_dtype, ) - return _warp_mma(A_local_buf, B_local_buf, C_local_buf) + return _warp_mfma(A_local_buf, B_local_buf, C_local_buf) def stmatrix(self, C_local_buf, C_buf, pid_m=None, pid_n=None): block_row_warps = self.block_row_warps @@ -356,8 +387,7 @@ def stmatrix(self, C_local_buf, C_buf, pid_m=None, pid_n=None): warp_rows = self.warp_rows warp_cols = self.warp_cols local_size_out = self.local_size_out - current_frame = T.KernelLaunchFrame.Current() - thread_binding = current_frame.get_thread_binding() + thread_binding = self.get_thread_binding() is_global = pid_m is not None and pid_n is not None BLOCK_M = block_row_warps * warp_rows BLOCK_N = block_col_warps * warp_cols @@ -366,7 +396,7 @@ def stmatrix(self, C_local_buf, C_buf, pid_m=None, pid_n=None): assert C_buf_dims in {2, 4}, "C_buf should be 2D or 4D" # STS - # MMA Store must be in simulated instead of TVM Intrins + # MFMA Store must be in simulated instead of TVM Intrins # As TVM Intrins is like a hack that the threadIdx.x should be always # equal to the warp_size @T.macro @@ -400,6 +430,217 @@ def _warp_stmatrix_global(C_local_buf, C_buf, thread_binding): thread_binding) if is_global else _warp_stmatrix_shared( C_local_buf, C_buf, thread_binding) + def make_mfma_load_layout(self, + local_buf: Buffer, + matrix: Literal["A", "B"] = "A") -> T.Fragment: + """ + Create a layout function for storing MFMA results into a fragment buffer. + + Parameters + ---------- + local_buf : tir.Buffer + The local buffer representing a fragment of a matrix. + + Returns + ------- + T.Fragment + A fragment object that describes how threads and indices + in `local_buf` are laid out. + + Raises + ------ + AssertionError + If `local_buf` is not detected to be a fragment buffer. + """ + from tilelang.utils import is_fragment + assert matrix in ["A", "B"], "matrix should be either A or B" + matrix_is_a: bool = matrix == "A" + matrix_is_b: bool = matrix == "B" + transposed = self.a_transposed if matrix_is_a else self.b_transposed + + # s represents spatial axis + # r represents reduction axis + # sr represents the two dims are spatial + reduction + # rs represents the two dims are reduction + spatial + # sr also can represent a non-transposed basic layout + # then rs also can represent a transposed basic layout + transform_func_sr_a: Callable = None + transform_func_sr_b: Callable = None + + k_dim = self.k_dim * self.k_pack + + if k_dim == 4: + transform_func_sr_a = shared_16x4_to_local_64x1_layout_A + transform_func_sr_b = shared_16x4_to_local_64x1_layout_A + elif k_dim == 16: + transform_func_sr_a = shared_16x16_to_local_64x4_layout_A + transform_func_sr_b = shared_16x16_to_local_64x4_layout_A + elif k_dim == 32: + transform_func_sr_a = shared_16x32_to_local_64x8_layout_A + transform_func_sr_b = shared_16x32_to_local_64x8_layout_A + elif k_dim == 64: + transform_func_sr_a = shared_16x64_to_local_64x16_layout_A + transform_func_sr_b = shared_16x64_to_local_64x16_layout_A + else: + raise ValueError("k_dim must be 4 or 16 or 32 or 64 currently") + + is_sr_conditions = [False] + is_sr_conditions.append(matrix_is_a and not transposed) + is_sr_conditions.append(matrix_is_b and transposed) + is_sr_axis_order = any(is_sr_conditions) + + transform_func: Callable = None + if matrix_is_a: + transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a( + j, i) + elif matrix_is_b: + transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b( + j, i) + else: + raise ValueError(f"Unsupported matrix {matrix}") + + assert is_fragment(local_buf), f"local_buf must be a fragment, but got {local_buf.scope()}" + + if matrix_is_a: + micro_size_s, micro_size_r = self.micro_size_x, self.micro_size_k + else: + micro_size_r, micro_size_s = self.micro_size_k, self.micro_size_y + + block_row_warps, block_col_warps = ( + self.block_row_warps, + self.block_col_warps, + ) + + inverse_mfma_load_layout = IndexMap.from_func(transform_func, index_dtype="int32") + + def forward_thread(i: int, j: int) -> int: + """ + Given the row index `i` and column index `j` in the fragment, + """ + lane_id, _ = inverse_mfma_load_layout.map_indices([i, j]) + return lane_id + + def forward_index(i: int, j: int) -> int: + """ + Given the row index `i` and column index `j` in the fragment, + """ + _, local_id = inverse_mfma_load_layout.map_indices([i, j]) + return local_id + + base_fragment = T.Fragment( + [micro_size_s, micro_size_r] if is_sr_axis_order else [micro_size_r, micro_size_s], + forward_thread_fn=forward_thread, + forward_index_fn=forward_index, + ) + + warp_rows, warp_cols = self.warp_rows, self.warp_cols + chunk = self.chunk + + warp_s = warp_rows if matrix_is_a else warp_cols + warp_r = chunk // micro_size_r + block_s = block_row_warps if matrix_is_a else block_col_warps + replicate = block_col_warps if matrix_is_a else block_row_warps + + if is_sr_axis_order: + warp_fragment = base_fragment.repeat([warp_s, warp_r], + repeat_on_thread=False, + lower_dim_first=False) + if matrix_is_a: + block_fragment = warp_fragment.repeat([block_s, 1], + repeat_on_thread=True, + lower_dim_first=True).replicate(replicate) + elif matrix_is_b: + block_fragment = warp_fragment.replicate(replicate).repeat([block_s, 1], + repeat_on_thread=True, + lower_dim_first=True) + else: + raise ValueError(f"Unsupported matrix type {matrix}") + else: + warp_fragment = base_fragment.repeat([warp_r, warp_s], + repeat_on_thread=False, + lower_dim_first=True) + if matrix_is_a: + block_fragment = warp_fragment.repeat([1, block_s], + repeat_on_thread=True, + lower_dim_first=True).replicate(replicate) + elif matrix_is_b: + block_fragment = warp_fragment.replicate(replicate).repeat([1, block_s], + repeat_on_thread=True, + lower_dim_first=True) + else: + raise ValueError(f"Unsupported matrix type {matrix}") + + return block_fragment + + def make_mfma_store_layout(self, local_buf: Buffer) -> T.Fragment: + """ + Create a layout function for storing MFMA results into a fragment buffer. + + Parameters + ---------- + local_buf : tir.Buffer + The local buffer representing a fragment of a matrix. + + Returns + ------- + T.Fragment + A fragment object that describes how threads and indices + in `local_buf` are laid out. + + Raises + ------ + AssertionError + If `local_buf` is not detected to be a fragment buffer. + """ + from tilelang.utils import is_fragment + + shape = local_buf.shape + inverse_mfma_store_layout = self.get_store_index_map(inverse=True) + assert is_fragment(local_buf), "local_buf must be a fragment" + micro_size_x, micro_size_y = self.micro_size_x, self.micro_size_y + local_size_out = self.local_size_out + block_row_warps, block_col_warps = self.block_row_warps, self.block_col_warps + warp_rows, warp_cols = self.warp_rows, self.warp_cols + warp_size = self.WARP_SIZE + is_m_first = self.is_m_first + + def forward_thread(i: int, j: int) -> int: + """ + Given the row index `i` and column index `j` in the fragment, + map them to a thread index according to `inverse_mfma_store_layout`. + """ + # the upper bounds of i and j are block_row_warps * warp_rows * micro_size_x and block_col_warps * warp_cols * micro_size_y + # the upper bounds of block_row_warps and block_col_warps are warp_rows and warp_cols + block_i, block_j = (i // micro_size_x) // warp_rows, (j // micro_size_y) // warp_cols + # upper bounds of mfma_i and mfma_j are micro_size_x and micro_size_y + mfma_i, mfma_j = i % micro_size_x, j % micro_size_y + lane_id, _ = inverse_mfma_store_layout.map_indices([mfma_i, mfma_j]) + if is_m_first: + thread_id = block_i * (block_col_warps * warp_cols) + block_j * warp_size + lane_id + else: + thread_id = block_j * (block_row_warps * warp_size) + block_i * warp_size + lane_id + return thread_id + + def forward_index(i: int, j: int) -> int: + """ + Given the row index `i` and column index `j` in the fragment, + map them to a local index in a single thread according + to `inverse_mfma_store_layout`. + """ + # the upper bounds of i and j are block_row_warps * warp_rows * micro_size_x and block_col_warps * warp_cols * micro_size_y + # the upper bounds of warp_i and warp_j are warp_rows and warp_cols + warp_i, warp_j = (i // micro_size_x) % warp_rows, (j // micro_size_y) % warp_cols + # upper bounds of mfma_i and mfma_j are micro_size_x and micro_size_y + mfma_i, mfma_j = i % micro_size_x, j % micro_size_y + _, local_id = inverse_mfma_store_layout.map_indices([mfma_i, mfma_j]) + return warp_i * (warp_cols * local_size_out) + warp_j * local_size_out + local_id + + return T.Fragment( + shape, + forward_thread_fn=forward_thread, + forward_index_fn=forward_index, + ) + class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter): diff --git a/tilelang/tileop/gemm/__init__.py b/tilelang/tileop/gemm/__init__.py index 63a999f4d..d0ea704cc 100644 --- a/tilelang/tileop/gemm/__init__.py +++ b/tilelang/tileop/gemm/__init__.py @@ -8,6 +8,7 @@ from tilelang.ir import GemmWarpPolicy from .gemm_mma import GemmMMA from .gemm_wgmma import GemmWGMMA +from .gemm_mfma import GemmMFMA from tilelang import _ffi_api @@ -28,14 +29,18 @@ def gemm_py_lower(gemm_py, layout_map, target, thread_bounds, thread_var): # same definition with src/op/gemm_py.h class GemmInst(IntEnum): MMA = 0 - WGMMMA = 1 - MFMA = 2 + WGMMA = 1 + TCGEN5MMA = 2 + MFMA = 3 def is_mma(self) -> bool: return self == GemmInst.MMA def is_wgmma(self) -> bool: - return self == GemmInst.WGMMMA + return self == GemmInst.WGMMA + + def is_tcgen5mma(self) -> bool: + return self == GemmInst.TCGEN5MMA def is_mfma(self) -> bool: return self == GemmInst.MFMA @@ -115,6 +120,8 @@ def _get_implementation_class(self, gemm_inst: GemmInst): elif gemm_inst.is_wgmma(): return GemmWGMMA elif gemm_inst.is_mfma(): - raise NotImplementedError("MFMA is not implemented") + return GemmMFMA + elif gemm_inst.is_tcgen5mma(): + raise NotImplementedError("TCGEN5MMA is not implemented") else: raise ValueError(f"Unsupported GEMM instruction: {gemm_inst}") diff --git a/tilelang/tileop/gemm/gemm_mfma.py b/tilelang/tileop/gemm/gemm_mfma.py new file mode 100644 index 000000000..76d971317 --- /dev/null +++ b/tilelang/tileop/gemm/gemm_mfma.py @@ -0,0 +1,215 @@ +from .gemm_base import GemmBase +from tilelang.layout import make_swizzled_layout +from tilelang.intrinsics.mfma_macro_generator import ( + MatrixCoreIntrinEmitter,) +from tilelang.utils.language import is_shared, is_fragment +from tilelang import tvm as tvm +from tvm.target import Target +from tvm import tir +from tilelang import language as T +from tilelang.transform.simplify import _Simplify + + +class GemmMFMA(GemmBase): + + def infer_layout(self, target: Target, thread_nums: int): + m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, + False) + warp_row_tiles = int(self.M // m_warp) + warp_col_tiles = int(self.N // n_warp) + mfma_emitter = MatrixCoreIntrinEmitter( + a_dtype=self.in_dtype, + b_dtype=self.in_dtype, + accum_dtype=self.accum_dtype, + a_transposed=self.trans_A, + b_transposed=self.trans_B, + block_row_warps=m_warp, + block_col_warps=n_warp, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=self.chunk, + ) + + if self.is_gemm_ss(): + return { + self.A: make_swizzled_layout(self.A), + self.B: make_swizzled_layout(self.B), + self.C: mfma_emitter.make_mfma_store_layout(self.C), + } + elif self.is_gemm_sr(): + return { + self.A: make_swizzled_layout(self.A), + self.B: mfma_emitter.make_mfma_load_layout(self.B, matrix="B"), + self.C: mfma_emitter.make_mfma_store_layout(self.C), + } + elif self.is_gemm_rs(): + return { + self.A: mfma_emitter.make_mfma_load_layout(self.A, matrix="A"), + self.B: make_swizzled_layout(self.B), + self.C: mfma_emitter.make_mfma_store_layout(self.C), + } + elif self.is_gemm_rr(): + return { + self.A: mfma_emitter.make_mfma_load_layout(self.A, matrix="A"), + self.B: mfma_emitter.make_mfma_load_layout(self.B, matrix="B"), + self.C: mfma_emitter.make_mfma_store_layout(self.C), + } + else: + raise ValueError( + f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}") + + def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: tir.Var): + m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, + False) + warp_row_tiles = int(self.M // m_warp) + warp_col_tiles = int(self.N // n_warp) + mfma_emitter = MatrixCoreIntrinEmitter( + a_dtype=self.in_dtype, + b_dtype=self.in_dtype, + accum_dtype=self.accum_dtype, + a_transposed=self.trans_A, + b_transposed=self.trans_B, + block_row_warps=m_warp, + block_col_warps=n_warp, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=self.chunk, + thread_var=thread_var, + ) + + in_dtype = self.in_dtype + warp_rows = mfma_emitter.warp_rows + warp_cols = mfma_emitter.warp_cols + local_size_a = mfma_emitter.local_size_a + local_size_b = mfma_emitter.local_size_b + block_K = mfma_emitter.chunk + micro_size_k = mfma_emitter.micro_size_k + A_shared = self.A + B_shared = self.B + C_local = self.C + + assert block_K >= micro_size_k, f"block_K ({block_K}) must be >= micro_size_k ({micro_size_k})" + + if self.is_gemm_ss(): + + @T.prim_func + def _gemm_ssr() -> None: + """ + The inner macro that loads data from shared buffers A_shared and + B_shared into local fragments, then issues Matrix Core mfma ops, + accumulating into C_local. + """ + A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) + B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) + + for ki in T.serial(0, (block_K // micro_size_k)): + # Load A into fragment + mfma_emitter.ldmatrix_a( + A_local, + A_shared, + ki, + ) + + # Load B into fragment + mfma_emitter.ldmatrix_b( + B_local, + B_shared, + ki, + ) + + # Perform Matrix Multiplication + mfma_emitter.mfma(A_local, B_local, C_local, ki) + + # Simplify to optimize the index computing + # Must inline let statements to simplify the analysis + return _Simplify(_gemm_ssr, inline_let=True) + elif self.is_gemm_sr(): + B_local = self.B + + @T.prim_func + def _gemm_srr() -> None: + """ + The inner macro that loads data from shared buffers A_shared and + B_shared into local fragments, then issues Matrix Core mfma ops, + accumulating into C_local. + """ + A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) + + for ki in T.serial(0, (block_K // micro_size_k)): + + # Load A into fragment + mfma_emitter.ldmatrix_a( + A_local, + A_shared, + ki, + ) + + # Perform Matrix Multiplication + mfma_emitter.mfma(A_local, B_local, C_local, ki) + + # Simplify to optimize the index computing + # Must inline let statements to simplify the analysis + # alloc_buffers body + # insert into parent block + return _Simplify(_gemm_srr, inline_let=True) + elif self.is_gemm_rs(): + A_local = self.A + + @T.prim_func + def _gemm_rsr() -> None: + """ + The inner macro that loads data from shared buffers A_shared and + B_shared into local fragments, then issues Matrix Core mfma ops, + accumulating into C_local. + """ + B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) + + for ki in T.serial(0, (block_K // micro_size_k)): + + # Load B into fragment + mfma_emitter.ldmatrix_b( + B_local, + B_shared, + ki, + ) + + # Perform Matrix Multiplication + mfma_emitter.mfma(A_local, B_local, C_local, ki) + + # Simplify to optimize the index computing + # Must inline let statements to simplify the analysis + return _Simplify(_gemm_rsr, inline_let=True) + elif self.is_gemm_rr(): + A_local = self.A + B_local = self.B + + @T.prim_func + def _gemm_rsr() -> None: + """ + The inner macro that loads data from shared buffers A_shared and + B_shared into local fragments, then issues Matrix Core mfma ops, + accumulating into C_local. + """ + + for ki in T.serial(0, (block_K // micro_size_k)): + # Perform Matrix Multiplication + mfma_emitter.mfma(A_local, B_local, C_local, ki) + + # Simplify to optimize the index computing + # Must inline let statements to simplify the analysis + return _Simplify(_gemm_rsr, inline_let=True) + else: + raise ValueError( + f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}") + + def is_gemm_ss(self) -> bool: + return is_shared(self.A) and is_shared(self.B) + + def is_gemm_sr(self) -> bool: + return is_shared(self.A) and is_fragment(self.B) + + def is_gemm_rs(self) -> bool: + return is_fragment(self.A) and is_shared(self.B) + + def is_gemm_rr(self) -> bool: + return is_fragment(self.A) and is_fragment(self.B)