diff --git a/aiter/ops/triton/_triton_kernels/gemm_afp4wfp4.py b/aiter/ops/triton/_triton_kernels/gemm_afp4wfp4.py index e092b3fd0b..8b7ce9f54d 100644 --- a/aiter/ops/triton/_triton_kernels/gemm_afp4wfp4.py +++ b/aiter/ops/triton/_triton_kernels/gemm_afp4wfp4.py @@ -124,7 +124,7 @@ def _gemm_afp4_wfp4_kernel( for k in range(pid_k * num_k_iter, (pid_k + 1) * num_k_iter): a_scales = tl.load(a_scale_ptrs) - b_scales = tl.load(b_scale_ptrs) + b_scales = tl.load(b_scale_ptrs, cache_modifier=cache_modifier) # Load the next block of A and B, generate a mask by checking the K dimension. # If it is out of bounds, set it to 0. @@ -297,14 +297,36 @@ def _gemm_afp4_wfp4_kernel_preshuffled_scales( accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k in range(pid_k * num_k_iter, (pid_k + 1) * num_k_iter): - a_scales = tl.load(a_scale_ptrs) - b_scales = tl.load(b_scale_ptrs, cache_modifier=cache_modifier) - if BLOCK_SIZE_M >= 32: - a_scales = tl.reshape( - a_scales, (BLOCK_SIZE_M, BLOCK_SIZE_K // SCALE_GROUP_SIZE) + if BLOCK_SIZE_M < 32: + a_scales = tl.load(a_scale_ptrs) + else: + a_scales = ( + tl.load(a_scale_ptrs) + .reshape( + BLOCK_SIZE_M // 32, + BLOCK_SIZE_K // SCALE_GROUP_SIZE // 8, + 4, + 16, + 2, + 2, + 1, + ) + .permute(0, 5, 3, 1, 4, 2, 6) + .reshape(BLOCK_SIZE_M, BLOCK_SIZE_K // SCALE_GROUP_SIZE) + ) + b_scales = ( + tl.load(b_scale_ptrs, cache_modifier=cache_modifier) + .reshape( + BLOCK_SIZE_N // 32, + BLOCK_SIZE_K // SCALE_GROUP_SIZE // 8, + 4, + 16, + 2, + 2, + 1, ) - b_scales = tl.reshape( - b_scales, (BLOCK_SIZE_N, BLOCK_SIZE_K // SCALE_GROUP_SIZE) + .permute(0, 5, 3, 1, 4, 2, 6) + .reshape(BLOCK_SIZE_N, BLOCK_SIZE_K // SCALE_GROUP_SIZE) ) # Load the next block of A and B, generate a mask by checking the K dimension. @@ -346,6 +368,222 @@ def _gemm_afp4_wfp4_kernel_preshuffled_scales( tl.store(c_ptrs, c, mask=c_mask, cache_modifier=".wt") +@triton.heuristics( + { + "EVEN_K": lambda args: (args["K"] % (args["BLOCK_SIZE_K"] // 2) == 0) + and (args["SPLITK_BLOCK_SIZE"] % args["BLOCK_SIZE_K"] == 0) + and (args["K"] % (args["SPLITK_BLOCK_SIZE"] // 2) == 0), + "GRID_MN": lambda args: triton.cdiv(args["M"], args["BLOCK_SIZE_M"]) + * triton.cdiv(args["N"], args["BLOCK_SIZE_N"]), + } +) +@triton.jit +def _gemm_afp4_wfp4_kernel_preshuffled_weight_scales( + a_ptr, + b_ptr, + c_ptr, + a_scales_ptr, + b_scales_ptr, + M, + N, + K, + stride_am, + stride_ak, + stride_bn, + stride_bk, + stride_ck, + stride_cm, + stride_cn, + stride_asm, + stride_ask, + stride_bsn, + stride_bsk, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + NUM_KSPLIT: tl.constexpr, + SPLITK_BLOCK_SIZE: tl.constexpr, + EVEN_K: tl.constexpr, + GRID_MN: tl.constexpr, + cache_modifier: tl.constexpr, +): + """Kernel for computing the matmul C = A x B. + A and B inputs are in the microscale fp4 (mxfp4) format. + A_scales and B_scales are in e8m0 format. + A has shape (M, K), B has shape (K, N) and C has shape (M, N) + """ + + tl.assume(stride_am > 0) + tl.assume(stride_ak > 0) + tl.assume(stride_bk > 0) + tl.assume(stride_bn > 0) + tl.assume(stride_cm > 0) + tl.assume(stride_cn > 0) + tl.assume(stride_asm > 0) + tl.assume(stride_ask > 0) + tl.assume(stride_bsk > 0) + tl.assume(stride_bsn > 0) + + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. + # This is done in a grouped ordering to promote L2 data reuse. 
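+    # NOTE (layout assumption, per the wrapper gemm_afp4wfp4_preshuffled_weight_scales
+    # and the tests): B arrives preshuffled as (N // 16, K * 16) fp4 elements, i.e.
+    # 16x16 tiles flattened along K as produced by shuffle_weight, and B_scales
+    # arrive in the shuffle_scales layout. The reshape/permute sequences below
+    # undo that tiling per block before tl.dot_scaled.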
+    pid_unified = tl.program_id(axis=0)
+    pid_unified = remap_xcd(pid_unified, GRID_MN * NUM_KSPLIT, NUM_XCDS=8)
+    pid_k = pid_unified % NUM_KSPLIT
+    pid = pid_unified // NUM_KSPLIT
+    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+
+    if NUM_KSPLIT == 1:
+        pid_m, pid_n = pid_grid(pid, num_pid_m, num_pid_n, GROUP_SIZE_M=GROUP_SIZE_M)
+    else:
+        pid_m = pid // num_pid_n
+        pid_n = pid % num_pid_n
+
+    tl.assume(pid_m >= 0)
+    tl.assume(pid_n >= 0)
+    # We assume 32 elements along K share the same scale.
+    SCALE_GROUP_SIZE: tl.constexpr = 32
+
+    if (pid_k * SPLITK_BLOCK_SIZE // 2) < K:
+
+        num_k_iter = tl.cdiv(SPLITK_BLOCK_SIZE // 2, BLOCK_SIZE_K // 2)
+
+        # Create pointers for the first blocks of the A and B input matrices.
+        # Block sizes are in fp4 elements; two fp4 values are packed per uint8 container.
+        offs_k = tl.arange(0, BLOCK_SIZE_K // 2)
+        offs_k_shuffle_arr = tl.arange(0, (BLOCK_SIZE_K // 2) * 16)
+        offs_k_split = pid_k * (SPLITK_BLOCK_SIZE // 2) + offs_k
+        offs_k_shuffle = pid_k * (SPLITK_BLOCK_SIZE // 2) * 16 + offs_k_shuffle_arr
+
+        offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
+        offs_bn = (pid_n * (BLOCK_SIZE_N // 16) + tl.arange(0, BLOCK_SIZE_N // 16)) % N
+        a_ptrs = a_ptr + (
+            offs_am[:, None] * stride_am + offs_k_split[None, :] * stride_ak
+        )
+        # Each preshuffled B row covers 16 logical N rows, so K offsets run 16x wide.
+        b_ptrs = b_ptr + (
+            offs_bn[:, None] * stride_bn
+            + offs_k_shuffle[None, :] * stride_bk
+        )
+        # Create pointers for the first block of A and B scales.
+
+        offs_bsn = (
+            pid_n * (BLOCK_SIZE_N // 32) + tl.arange(0, (BLOCK_SIZE_N // 32))
+        ) % N
+        offs_ks = (pid_k * (SPLITK_BLOCK_SIZE // SCALE_GROUP_SIZE) * 32) + tl.arange(
+            0, BLOCK_SIZE_K // SCALE_GROUP_SIZE * 32
+        )
+        # B scales are N x K even though B operand is K x N.
+        b_scale_ptrs = (
+            b_scales_ptr
+            + offs_bsn[:, None] * stride_bsn
+            + offs_ks[None, :] * stride_bsk
+        )
+
+        if BLOCK_SIZE_M < 32:
+            offs_ks_non_shufl = (
+                pid_k * (SPLITK_BLOCK_SIZE // SCALE_GROUP_SIZE)
+            ) + tl.arange(0, BLOCK_SIZE_K // SCALE_GROUP_SIZE)
+            a_scale_ptrs = (
+                a_scales_ptr
+                + offs_am[:, None] * stride_asm
+                + offs_ks_non_shufl[None, :] * stride_ask
+            )
+        else:
+            offs_asm = (
+                pid_m * (BLOCK_SIZE_M // 32) + tl.arange(0, (BLOCK_SIZE_M // 32))
+            ) % M
+            a_scale_ptrs = (
+                a_scales_ptr
+                + offs_asm[:, None] * stride_asm
+                + offs_ks[None, :] * stride_ask
+            )
+
+        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+
+        for k in range(pid_k * num_k_iter, (pid_k + 1) * num_k_iter):
+            if BLOCK_SIZE_M < 32:
+                a_scales = tl.load(a_scale_ptrs)
+            else:
+                a_scales = (
+                    tl.load(a_scale_ptrs)
+                    .reshape(
+                        BLOCK_SIZE_M // 32,
+                        BLOCK_SIZE_K // SCALE_GROUP_SIZE // 8,
+                        4,
+                        16,
+                        2,
+                        2,
+                        1,
+                    )
+                    .permute(0, 5, 3, 1, 4, 2, 6)
+                    .reshape(BLOCK_SIZE_M, BLOCK_SIZE_K // SCALE_GROUP_SIZE)
+                )
+
+            b_scales = (
+                tl.load(b_scale_ptrs, cache_modifier=cache_modifier)
+                .reshape(
+                    BLOCK_SIZE_N // 32,
+                    BLOCK_SIZE_K // SCALE_GROUP_SIZE // 8,
+                    4,
+                    16,
+                    2,
+                    2,
+                    1,
+                )
+                .permute(0, 5, 3, 1, 4, 2, 6)
+                .reshape(BLOCK_SIZE_N, BLOCK_SIZE_K // SCALE_GROUP_SIZE)
+            )
+
+            # Load the next block of A and B, generate a mask by checking the K dimension.
+            # If it is out of bounds, set it to 0.
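+            # Only the EVEN_K fast path is implemented below: the loads are
+            # unmasked and there is no masked fallback here, so callers must
+            # satisfy the divisibility conditions captured by the EVEN_K
+            # heuristic above. The reshape/permute on b then undoes the 16x16
+            # tile shuffle, recovering a (BLOCK_SIZE_K // 2, BLOCK_SIZE_N)
+            # packed-fp4 operand for tl.dot_scaled.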
+            if EVEN_K:
+                a = tl.load(a_ptrs)
+                b = tl.load(b_ptrs, cache_modifier=cache_modifier)
+
+            b = (
+                b.reshape(
+                    1,
+                    BLOCK_SIZE_N // 16,
+                    BLOCK_SIZE_K // 64,
+                    2,
+                    16,
+                    16,
+                )
+                .permute(0, 1, 4, 2, 3, 5)
+                .reshape(BLOCK_SIZE_N, BLOCK_SIZE_K // 2)
+                .trans(1, 0)
+            )
+
+            accumulator += tl.dot_scaled(a, a_scales, "e2m1", b, b_scales, "e2m1")
+
+            # Advance the ptrs to the next K block.
+            a_ptrs += (BLOCK_SIZE_K // 2) * stride_ak
+            b_ptrs += (BLOCK_SIZE_K // 2) * 16 * stride_bk
+            if BLOCK_SIZE_M < 32:
+                a_scale_ptrs += (BLOCK_SIZE_K // SCALE_GROUP_SIZE) * stride_ask
+            else:
+                a_scale_ptrs += BLOCK_SIZE_K * stride_ask
+            b_scale_ptrs += BLOCK_SIZE_K * stride_bsk
+
+        c = accumulator.to(c_ptr.type.element_ty)
+
+        # Write back the block of the output matrix C with masks.
+        offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
+        offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)
+        c_ptrs = (
+            c_ptr
+            + stride_cm * offs_cm[:, None]
+            + stride_cn * offs_cn[None, :]
+            + pid_k * stride_ck
+        )
+        c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
+        tl.store(c_ptrs, c, mask=c_mask, cache_modifier=".wt")
+
+
 @triton.jit
 def _gemm_afp4_wfp4_reduce_kernel(
     c_in_ptr,
@@ -398,29 +636,34 @@ def _get_config(
     M: int,
     N: int,
     K: int,
+    shuffle: bool = False,
 ):
-    if not hasattr(_get_config, "_config_dict"):
+    shuffle_filename_suffix = "" if not shuffle else "_PRESHUFFLED"
+    if not hasattr(_get_config, "_config_dict"):
+        _get_config._config_dict = {}
+    # Use dict membership here: hasattr() never sees a dict's keys, so testing
+    # hasattr(_config_dict, ...) would re-read the JSON and reset the cache on
+    # every call.
+    if f"default{shuffle_filename_suffix}" not in _get_config._config_dict:
         dev = arch_info.get_device()
-        _get_config._config_dict = {}
-        fpath = f"{AITER_TRITON_CONFIGS_PATH}/gemm/{dev}-GEMM-AFP4WFP4.json"
+        fpath = f"{AITER_TRITON_CONFIGS_PATH}/gemm/{dev}-GEMM-AFP4WFP4{shuffle_filename_suffix}.json"
         with open(fpath, "r") as file:
             config = json.load(file)
-        _get_config._config_dict["default"] = config
+        _get_config._config_dict[f"default{shuffle_filename_suffix}"] = config
 
-    key = f"{N}_{K}"
+    key = f"{N}_{K}{shuffle_filename_suffix}"
     if key not in _get_config._config_dict.keys():
         dev = arch_info.get_device()
-        fpath = (
-            f"{AITER_TRITON_CONFIGS_PATH}/gemm/{dev}-GEMM-AFP4WFP4-N={N}-K={2*K}.json"
-        )
+        fpath = f"{AITER_TRITON_CONFIGS_PATH}/gemm/{dev}-GEMM-AFP4WFP4{shuffle_filename_suffix}-N={N}-K={2*K}.json"
        if os.path.exists(fpath):
            with open(fpath, "r") as file:
                config = json.load(file)
            _get_config._config_dict[key] = config
        else:
-            key = "default"  # fall back to default config
+            key = f"default{shuffle_filename_suffix}"  # fall back to default config
 
     if M < 32:
+        BLK_M = triton.next_power_of_2(M)
+        if BLK_M >= 16 and "small_M16" in _get_config._config_dict[key]:
+            return _get_config._config_dict[key]["small_M16"]
         return _get_config._config_dict[key]["small"]
     elif M <= 128:
         BLK_M = triton.next_power_of_2(M)
diff --git a/aiter/ops/triton/configs/gemm/MI350X-GEMM-AFP4WFP4_PRESHUFFLED-N=106496-K=16384.json b/aiter/ops/triton/configs/gemm/MI350X-GEMM-AFP4WFP4_PRESHUFFLED-N=106496-K=16384.json
new file mode 100644
index 0000000000..9895a4dc51
--- /dev/null
+++ b/aiter/ops/triton/configs/gemm/MI350X-GEMM-AFP4WFP4_PRESHUFFLED-N=106496-K=16384.json
@@ -0,0 +1,86 @@
+{
+    "small": {
+        "BLOCK_SIZE_M": 8,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 8,
+        "num_warps": 2,
+        "num_stages": 2,
+        "waves_per_eu": 1,
+        "matrix_instr_nonkdim": 16,
+        "cache_modifier": ".cg",
+        "NUM_KSPLIT": 1
+    },
+    "small_M16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 8,
+        "num_warps": 2,
+        "num_stages": 2,
+
"waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "medium_M32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "medium_M64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "medium_M128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "large": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "xlarge": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + } +} diff --git a/aiter/ops/triton/configs/gemm/MI350X-GEMM-AFP4WFP4_PRESHUFFLED-N=16384-K=16384.json b/aiter/ops/triton/configs/gemm/MI350X-GEMM-AFP4WFP4_PRESHUFFLED-N=16384-K=16384.json new file mode 100644 index 0000000000..7174492af2 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/MI350X-GEMM-AFP4WFP4_PRESHUFFLED-N=16384-K=16384.json @@ -0,0 +1,86 @@ +{ + "small": { + "BLOCK_SIZE_M": 8, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 4 + }, + "small_M16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 4 + }, + "medium_M32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 4 + }, + "medium_M64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 4 + }, + "medium_M128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "large": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 8, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "xlarge": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 8, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + } +} diff --git a/aiter/ops/triton/configs/gemm/MI350X-GEMM-AFP4WFP4_PRESHUFFLED-N=16384-K=53248.json 
b/aiter/ops/triton/configs/gemm/MI350X-GEMM-AFP4WFP4_PRESHUFFLED-N=16384-K=53248.json new file mode 100644 index 0000000000..6e56e82027 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/MI350X-GEMM-AFP4WFP4_PRESHUFFLED-N=16384-K=53248.json @@ -0,0 +1,86 @@ +{ + "small": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 4 + }, + "small_M16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 4 + }, + "medium_M32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 6, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 4 + }, + "medium_M64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 8, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 4 + }, + "medium_M128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "large": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "xlarge": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + } +} diff --git a/aiter/ops/triton/configs/gemm/MI350X-GEMM-AFP4WFP4_PRESHUFFLED-N=18432-K=16384.json b/aiter/ops/triton/configs/gemm/MI350X-GEMM-AFP4WFP4_PRESHUFFLED-N=18432-K=16384.json new file mode 100644 index 0000000000..13fe8985da --- /dev/null +++ b/aiter/ops/triton/configs/gemm/MI350X-GEMM-AFP4WFP4_PRESHUFFLED-N=18432-K=16384.json @@ -0,0 +1,86 @@ +{ + "small": { + "BLOCK_SIZE_M": 8, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 4 + }, + "small_M16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 4 + }, + "medium_M32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 4 + }, + "medium_M64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 4 + }, + "medium_M128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 8, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "large": { + 
"BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "xlarge": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + } +} diff --git a/aiter/ops/triton/configs/gemm/MI350X-GEMM-AFP4WFP4_PRESHUFFLED.json b/aiter/ops/triton/configs/gemm/MI350X-GEMM-AFP4WFP4_PRESHUFFLED.json new file mode 100644 index 0000000000..eebf5eff50 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/MI350X-GEMM-AFP4WFP4_PRESHUFFLED.json @@ -0,0 +1,87 @@ +{ + "small": { + "BLOCK_SIZE_M": 8, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "small_M16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "medium_M32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "medium_M64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "medium_M128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "large": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "xlarge": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + } + +} diff --git a/aiter/ops/triton/gemm_afp4wfp4.py b/aiter/ops/triton/gemm_afp4wfp4.py index c879a91b9e..a239076ae0 100644 --- a/aiter/ops/triton/gemm_afp4wfp4.py +++ b/aiter/ops/triton/gemm_afp4wfp4.py @@ -10,9 +10,11 @@ from aiter.ops.triton._triton_kernels.gemm_afp4wfp4 import ( _gemm_afp4_wfp4_kernel, _gemm_afp4_wfp4_kernel_preshuffled_scales, + _gemm_afp4_wfp4_kernel_preshuffled_weight_scales, _gemm_afp4_wfp4_reduce_kernel, _get_config, ) +from .utils.core import AITER_TRITON_CONFIGS_PATH _LOGGER = AiterTritonLogger() @@ -202,7 +204,7 @@ def gemm_afp4wfp4_preshuffled_scales( Key parameters: - - X: Matrix X with shape (M, K). + - X: Matrix X with shape (M, K). M >= 32 is required - W: Matrix W with shape (N, K). 
    - X_scales: Matrix with shape (M // 32, K)
    - W_scales: Matrix with shape (N // 32, K)
@@ -219,6 +221,8 @@ def gemm_afp4wfp4_preshuffled_scales(
     # Transpose w
     w = w.T
 
+    assert M >= 32, f"M >= 32 is required, but got {M=}"
+
     if y is None:
         y = torch.empty((M, N), dtype=dtype, device=x.device)
 
@@ -308,7 +312,175 @@ def gemm_afp4wfp4_preshuffled_scales(
         REDUCE_BLOCK_SIZE_M,
         REDUCE_BLOCK_SIZE_N,
         ACTUAL_KSPLIT,
-        config["NUM_KSPLIT"],
+        triton.next_power_of_2(config["NUM_KSPLIT"]),
     )
 
     return y
 
 
+def gemm_afp4wfp4_preshuffled_weight_scales(
+    x,
+    w,
+    x_scales,
+    w_scales,
+    dtype: Optional[float] = torch.bfloat16,
+    y: Optional[torch.Tensor] = None,
+    config: Optional[dict] = None,
+):
+    """
+    Computes the matmul Y = X x W
+    X and W are e2m1 fp4 tensors.
+    x_scales and w_scales are e8m0 tensors.
+    Every 32 elements in the K dimension share one e8m0 scale.
+
+    Key parameters:
+    - X: Matrix X with shape (M, K).
+    - W: Preshuffled weight with shape (N // 16, K * 16): the logical (N, K)
+      matrix reorganized into flattened 16x16 tiles (see aiter.ops.shuffle.shuffle_weight).
+    - X_scales: Preshuffled scales with shape (M // 32, K) when M >= 32, or
+      plain (M, K // 32) scales when M < 32.
+    - W_scales: Preshuffled scales with shape (N // 32, K).
+
+    Returns:
+    - Y: The output matrix with shape (M, N).
+    """
+
+    assert arch_info.is_fp4_avail(), "MXFP4 is not available on your device"
+
+    M, K = x.shape
+    # Recover the logical N and the packed K from the preshuffled weight layout.
+    N, K = w.shape
+    N = N * 16
+    K = K // 16
+
+    if y is None:
+        y = torch.empty((M, N), dtype=dtype, device=x.device)
+
+    if config is None:
+        config = _get_config(M, N, K, True)
+
+    if config["NUM_KSPLIT"] > 1:
+        SPLITK_BLOCK_SIZE, BLOCK_SIZE_K, NUM_KSPLIT = get_splitk(
+            K, config["BLOCK_SIZE_K"], config["NUM_KSPLIT"]
+        )
+
+        config["SPLITK_BLOCK_SIZE"] = SPLITK_BLOCK_SIZE
+        config["BLOCK_SIZE_K"] = BLOCK_SIZE_K
+        config["NUM_KSPLIT"] = NUM_KSPLIT
+
+        if _USE_GEMM_SPLITK_BF16:
+            y_pp = torch.empty(
+                (config["NUM_KSPLIT"], M, N), dtype=y.dtype, device=y.device
+            )
+        else:
+            y_pp = torch.empty(
+                (config["NUM_KSPLIT"], M, N), dtype=torch.float32, device=y.device
+            )
+    else:
+        config["SPLITK_BLOCK_SIZE"] = 2 * K
+        y_pp = None
+
+    if config["BLOCK_SIZE_K"] >= 2 * K:
+        config["BLOCK_SIZE_K"] = triton.next_power_of_2(2 * K)
+        config["SPLITK_BLOCK_SIZE"] = 2 * K
+
+    config["BLOCK_SIZE_N"] = max(config["BLOCK_SIZE_N"], 32)
+    if M < 32:
+        assert (
+            config["BLOCK_SIZE_M"] <= 16
+        ), "for M < 32, BLOCK_SIZE_M must be 16 or less as x_scales are assumed to be un-shuffled"
+    else:
+        assert (
+            config["BLOCK_SIZE_M"] >= 32
+        ), "for M >= 32, BLOCK_SIZE_M must be 32 or more as x_scales are assumed to be preshuffled"
+
+    grid = lambda META: (  # noqa: E731
+        (
+            META["NUM_KSPLIT"]
+            * triton.cdiv(M, META["BLOCK_SIZE_M"])
+            * triton.cdiv(N, META["BLOCK_SIZE_N"])
+        ),
+    )
+
+    import os
+    from aiter.utility.triton.triton_metadata_redirect import AOTMetadataContext
+
+    metadata_pth = f"{AITER_TRITON_CONFIGS_PATH}/gemm/aot/{_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.fn.__name__}_M={M}-N={N}-K={K*2}"
+    if os.path.exists(metadata_pth):
+        with AOTMetadataContext(
+            _gemm_afp4_wfp4_kernel_preshuffled_weight_scales.fn.__name__,
+            f"{metadata_pth}",
+        ):
+            _gemm_afp4_wfp4_kernel_preshuffled_weight_scales[grid](
+                x,
+                w,
+                y if
config["NUM_KSPLIT"] == 1 else y_pp, + x_scales, + w_scales, + M, + N, + K, + x.stride(0), + x.stride(1), + w.stride(0), + w.stride(1), + 0 if config["NUM_KSPLIT"] == 1 else y_pp.stride(0), + y.stride(0) if config["NUM_KSPLIT"] == 1 else y_pp.stride(1), + y.stride(1) if config["NUM_KSPLIT"] == 1 else y_pp.stride(2), + x_scales.stride(0), + x_scales.stride(1), + w_scales.stride(0), + w_scales.stride(1), + **config, + ) + + if config["NUM_KSPLIT"] > 1: + REDUCE_BLOCK_SIZE_M = 16 + # TODO: Need to debug - REDUCE_BLOCK_SIZE_N=128 with fp32 partials fails + # NOTE: REDUCE_BLOCK_SIZE_N=16 gives best perf with fp32 partials and + # REDUCE_BLOCK_SIZE_N=128 gives best perf with bf16 partials + REDUCE_BLOCK_SIZE_N = 128 if _USE_GEMM_SPLITK_BF16 else 64 + ACTUAL_KSPLIT = triton.cdiv(K, (config["SPLITK_BLOCK_SIZE"] // 2)) + + grid_reduce = ( + triton.cdiv(M, REDUCE_BLOCK_SIZE_M), + triton.cdiv(N, REDUCE_BLOCK_SIZE_N), + ) + _gemm_afp4_wfp4_reduce_kernel[grid_reduce]( + y_pp, + y, + M, + N, + y_pp.stride(0), + y_pp.stride(1), + y_pp.stride(2), + y.stride(0), + y.stride(1), + REDUCE_BLOCK_SIZE_M, + REDUCE_BLOCK_SIZE_N, + ACTUAL_KSPLIT, + triton.next_power_of_2(config["NUM_KSPLIT"]), ) return y diff --git a/op_tests/op_benchmarks/triton/bench_gemm_afp4wfp4.py b/op_tests/op_benchmarks/triton/bench_gemm_afp4wfp4.py index 4d988e89e4..88f91841f1 100644 --- a/op_tests/op_benchmarks/triton/bench_gemm_afp4wfp4.py +++ b/op_tests/op_benchmarks/triton/bench_gemm_afp4wfp4.py @@ -6,6 +6,7 @@ from aiter.ops.triton.gemm_afp4wfp4 import ( gemm_afp4wfp4, gemm_afp4wfp4_preshuffled_scales, + gemm_afp4wfp4_preshuffled_weight_scales, ) from op_tests.triton_tests.test_gemm_afp4wfp4 import generate_gemm_afp4wfp4_inputs from op_tests.op_benchmarks.triton.utils.argparse import ( @@ -21,20 +22,18 @@ ) import aiter.ops.triton.utils._triton.arch_info as arch_info -TRITON_HIP_PRESHUFFLE_SCALES = ( - os.environ.get("TRITON_HIP_PRESHUFFLE_SCALES", "0") == "1" -) - -def bench_gemm_fn(M: int, N: int, K: int, metric: str, layout: str): +def bench_gemm_fn(M: int, N: int, K: int, metric: str, layout: str, shuffle: bool): c_dtype = torch.bfloat16 - x, w, _, _, x_scale, w_scale, _, y = generate_gemm_afp4wfp4_inputs( + x, _, w, _, _, x_scale, w_scale, _, y = generate_gemm_afp4wfp4_inputs( M, N, K, c_dtype, layout=layout, output=True, + shuffle_scales_fg=shuffle, + shuffle_weight_fg=shuffle, ) # flops flops = 2.0 * M * N * K @@ -46,11 +45,10 @@ def bench_gemm_fn(M: int, N: int, K: int, metric: str, layout: str): ) mem_write = (M * N) * 2 # TODO: Fix for c_dtype != bf16 mem = mem_read + mem_write - - if TRITON_HIP_PRESHUFFLE_SCALES: + if shuffle: ms = triton.testing.do_bench( - lambda: gemm_afp4wfp4_preshuffled_scales( - x, w, x_scale, w_scale, c_dtype, y + lambda: gemm_afp4wfp4_preshuffled_weight_scales( + x, w, x_scale, w_scale, c_dtype, y # , config=config ), warmup=25, rep=100, @@ -101,7 +99,7 @@ def run_benchmark(args, defaults): def run_model_benchmark(args): - benchmark = get_model_benchmark_object(get_caller_name_no_ext(), args) + benchmark = get_model_benchmark_object("GEMM MXFP4 x MXFP4 Benchmark", args) @triton.testing.perf_report([benchmark]) def bench_gemm_afp4wfp4( @@ -119,17 +117,17 @@ def bench_gemm_afp4wfp4( # Divide K by tensor parallel K = math.ceil(K / args.tp) - return bench_gemm_fn(M, N, K, metric, args.layout) + return bench_gemm_fn(M, N, K, metric, args.layout, args.shuffle) bench_gemm_afp4wfp4.run(save_path="." 
if args.o else None, print_data=True) def run_shape_benchmark(args): - benchmark = get_shape_benchmark_object(get_caller_name_no_ext(), args) + benchmark = get_shape_benchmark_object("GEMM MXFP4 x MXFP4 Benchmark", args) @triton.testing.perf_report([benchmark]) def bench_gemm_afp4wfp4(M, N, K, metric, model_name=None, **kwargs): - return bench_gemm_fn(M, N, K, metric, args.layout) + return bench_gemm_fn(M, N, K, metric, args.layout, args.shuffle) bench_gemm_afp4wfp4.run(save_path="." if args.o else None, print_data=True) @@ -137,6 +135,9 @@ def bench_gemm_afp4wfp4(M, N, K, metric, model_name=None, **kwargs): def parse_args(): parser = get_parser("MXFP4 x MXFP4 GEMM") parser = add_argparse_ff(parser) + parser.add_argument( + "--shuffle", action="store_true", help="Preshuffle weight and scales" + ) return get_ff_args(parser) @@ -149,7 +150,7 @@ def main(): if args.print_vgpr: print("Retrieving VGPR usage for Triton kernels...") fun = lambda: run_benchmark(args, defaults) # noqa: E731 - print_vgpr(fun, get_caller_name_no_ext()) + print_vgpr(fun, "GEMM") return 0 run_benchmark(args, defaults) diff --git a/op_tests/triton_tests/test_gemm_afp4wfp4.py b/op_tests/triton_tests/test_gemm_afp4wfp4.py index e517144656..bddccfb135 100644 --- a/op_tests/triton_tests/test_gemm_afp4wfp4.py +++ b/op_tests/triton_tests/test_gemm_afp4wfp4.py @@ -6,13 +6,11 @@ from aiter.ops.triton.gemm_afp4wfp4 import ( gemm_afp4wfp4, gemm_afp4wfp4_preshuffled_scales, + gemm_afp4wfp4_preshuffled_weight_scales, ) import aiter.ops.triton.utils._triton.arch_info as arch_info from aiter.ops.triton.utils.types import str_to_torch_dtype - -TRITON_HIP_PRESHUFFLE_SCALES = ( - os.environ.get("TRITON_HIP_PRESHUFFLE_SCALES", "0") == "1" -) +from aiter.ops.shuffle import shuffle_weight def shuffle_scales(scales: torch.Tensor): @@ -28,7 +26,21 @@ def shuffle_scales(scales: torch.Tensor): SCALE_GROUP_SIZE = 32 -def generate_gemm_afp4wfp4_inputs(M, N, K, dtype, layout="TN", output=True): +def generate_gemm_afp4wfp4_inputs( + M, + N, + K, + dtype, + layout="TN", + output=True, + shuffle_weight_fg=False, + shuffle_scales_fg=False, +): + if shuffle_weight_fg: + assert ( + shuffle_scales_fg + ), "weight shuffling is only supported with scale shuffling" + torch.manual_seed(5) if isinstance(dtype, str): dtype = str_to_torch_dtype[dtype] @@ -55,7 +67,7 @@ def generate_gemm_afp4wfp4_inputs(M, N, K, dtype, layout="TN", output=True): w = w_low | w_high << 4 # Scale of 1.0 in e8m0, bias 127. 
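    # (An e8m0 scale stores only an exponent byte b, which decodes to
    # 2.0 ** (b - 127): 127 -> 1.0, 128 -> 2.0, 126 -> 0.5.)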
-    if M >= 32 and TRITON_HIP_PRESHUFFLE_SCALES:
+    if M >= 32 and shuffle_scales_fg:
         M_pad = (M + 255) // 256 * 256
     else:
         M_pad = M
@@ -67,7 +79,7 @@ def generate_gemm_afp4wfp4_inputs(M, N, K, dtype, layout="TN", output=True):
     )
     x_scales = x_scales.T
     w_scales = w_scales.T
-    if TRITON_HIP_PRESHUFFLE_SCALES:
+    if shuffle_scales_fg:
         if M >= 32:
             x_scales_shuffled = shuffle_scales(x_scales)
         else:
@@ -77,9 +89,21 @@ def generate_gemm_afp4wfp4_inputs(M, N, K, dtype, layout="TN", output=True):
         x_scales_shuffled = x_scales
         w_scales_shuffled = w_scales
 
+    if shuffle_weight_fg:
+        use_int4 = False
+        weight_shuffle_layout = (16, 16)
+        # Flatten the 16x16-tile shuffled weight to (N // 16, K_packed * 16).
+        w_shuffled = shuffle_weight(
+            w, layout=weight_shuffle_layout, use_int4=use_int4
+        ).reshape(
+            w.shape[0] // weight_shuffle_layout[0],
+            w.shape[1] * weight_shuffle_layout[0],
+        )
+    else:
+        w_shuffled = w
+
     y = None
     if output:
-        y = torch.empty((M, N), dtype=dtype, device="cuda")
+        y = torch.empty((M, N), dtype=dtype).cuda()
         out_dtype = (None,)
     else:
         out_dtype = dtype
@@ -87,6 +111,7 @@ def generate_gemm_afp4wfp4_inputs(M, N, K, dtype, layout="TN", output=True):
     return (
         x,
         w,
+        w_shuffled,
         x_scales[:M],
         w_scales,
         x_scales_shuffled,
@@ -133,6 +158,10 @@ def get_x_vals():
     x_vals += [(16, 16384, 3328 * 2), (128, 16384, 3328 * 2)]
     x_vals += [(256, 3584, 2112)]
     x_vals += [(7, 4608, 7168), (7, 7168, 2304)]
+    x_vals += [(v, 106496, 16384) for v in [1, 8, 16, 32, 64, 128, 256]]
+    x_vals += [(v, 16384, 53248) for v in [1, 8, 16, 32, 64, 128, 256]]
+    x_vals += [(v, 18432, 16384) for v in [1, 8, 16, 32, 64, 128, 256]]
+    x_vals += [(v, 16384, 16384) for v in [1, 8, 16, 32, 64, 128, 256]]
     x_vals += [(1, 1, 32)]  # minimal case
     return x_vals
@@ -188,48 +217,82 @@ def run_torch(x, w, x_scales, w_scales, dtype):
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
 @pytest.mark.parametrize("layout", ["TN", "TT", "NN", "NT"])
 @pytest.mark.parametrize("output", [True, False])
-def test_gemm_afp4_wfp4(M: int, N: int, K: int, dtype, layout, output):
+@pytest.mark.parametrize(
+    "shuffle_scales_fg, shuffle_weight_fg",
+    # [(False, False), (True, False), (True, True)],
+    [(True, True)],
+)
+def test_gemm_afp4_wfp4(
+    M: int, N: int, K: int, dtype, layout, output, shuffle_scales_fg, shuffle_weight_fg
+):
     if not (arch_info.is_fp4_avail()):
         pytest.skip("MXFP4 not supported on this architecture")
-    torch.cuda.empty_cache()  # Helps avoid hangs in large tests
+    if shuffle_weight_fg and not shuffle_scales_fg:
+        pytest.skip("Preshuffling weight without preshuffled scales is not supported")
+
+    if shuffle_weight_fg or shuffle_scales_fg:
+        if shuffle_scales_fg and not shuffle_weight_fg and M < 32:
+            pytest.skip("Minimal tile size for preshuffled scales is 32x32x256")
 
-    if TRITON_HIP_PRESHUFFLE_SCALES:
         if N % 32 > 0:
             pytest.skip(
-                f"N = {N} is not divisible by 32, skip this test for preshuffled scales tests"
+                f"N = {N} is not divisible by 32, skip this test for preshuffled weight/scales tests"
             )
         elif K % 256 > 0:
             pytest.skip(
-                f"K = {K} is not divisible by 256, skip this test for preshuffled scales tests"
+                f"K = {K} is not divisible by 256, skip this test for preshuffled weight/scales tests"
            )
     (
         x,
         w,
+        w_triton,
         x_scales,
         w_scales,
         x_scales_triton,
         w_scales_triton,
         out_dtype,
         y,
-    ) = generate_gemm_afp4wfp4_inputs(M, N, K, dtype, layout=layout, output=output)
+    ) = generate_gemm_afp4wfp4_inputs(
+        M,
+        N,
+        K,
+        dtype,
+        layout=layout,
+        output=output,
+        shuffle_scales_fg=shuffle_scales_fg,
+        shuffle_weight_fg=shuffle_weight_fg,
+    )
 
     torch_out = run_torch(x, w, x_scales, w_scales,
dtype).to(dtype) - if TRITON_HIP_PRESHUFFLE_SCALES: + if shuffle_scales_fg and shuffle_weight_fg: + if output: + triton_out = gemm_afp4wfp4_preshuffled_weight_scales( + x, w_triton, x_scales_triton, w_scales_triton, dtype, y + ) + else: + triton_out = gemm_afp4wfp4_preshuffled_weight_scales( + x, w_triton, x_scales_triton, w_scales_triton, dtype + ) + elif shuffle_scales_fg and not shuffle_weight_fg: if output: triton_out = gemm_afp4wfp4_preshuffled_scales( - x, w, x_scales_triton, w_scales_triton, dtype, y + x, w_triton, x_scales_triton, w_scales_triton, dtype, y ) else: triton_out = gemm_afp4wfp4_preshuffled_scales( - x, w, x_scales_triton, w_scales_triton, dtype + x, w_triton, x_scales_triton, w_scales_triton, dtype ) else: if output: - triton_out = gemm_afp4wfp4(x, w, x_scales_triton, w_scales_triton, dtype, y) + triton_out = gemm_afp4wfp4( + x, w_triton, x_scales_triton, w_scales_triton, dtype, y + ) else: - triton_out = gemm_afp4wfp4(x, w, x_scales_triton, w_scales_triton, dtype) + triton_out = gemm_afp4wfp4( + x, w_triton, x_scales_triton, w_scales_triton, dtype + ) torch.testing.assert_close(torch_out, triton_out)
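
Usage sketch: a minimal end-to-end call of the new preshuffled-weight path on an
MXFP4-capable device, reusing the test helper from
op_tests/triton_tests/test_gemm_afp4wfp4.py. The shape is one of the tuned
MI350X configs added above; any shape with N % 32 == 0 and K % 256 == 0 goes
through the same path.

    import torch
    from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4_preshuffled_weight_scales
    from op_tests.triton_tests.test_gemm_afp4wfp4 import generate_gemm_afp4wfp4_inputs

    M, N, K = 64, 16384, 16384
    (
        x,                  # packed fp4 activations, (M, K // 2) uint8
        w,                  # un-shuffled packed weight (torch reference path)
        w_shuffled,         # 16x16-tile preshuffled weight fed to the kernel
        x_scales,
        w_scales,
        x_scales_shuffled,  # preshuffled e8m0 activation scales (M >= 32 here)
        w_scales_shuffled,  # preshuffled e8m0 weight scales
        out_dtype,
        y,
    ) = generate_gemm_afp4wfp4_inputs(
        M, N, K, torch.bfloat16, shuffle_scales_fg=True, shuffle_weight_fg=True
    )
    y = gemm_afp4wfp4_preshuffled_weight_scales(
        x, w_shuffled, x_scales_shuffled, w_scales_shuffled, torch.bfloat16, y
    )
    assert y.shape == (M, N)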