diff --git a/aiter/ops/mha.py b/aiter/ops/mha.py
index 39a2ae4a2..b5f94ce61 100644
--- a/aiter/ops/mha.py
+++ b/aiter/ops/mha.py
@@ -5,13 +5,203 @@
 import torch
 from torch import Generator, Tensor
-
 from ..jit.core import CK_DIR, AITER_META_DIR, compile_ops
 from ..jit.utils.chip_info import get_gfx
 from ..jit.utils.torch_guard import torch_compile_guard
 from ..utility import dtypes
 
 
+def compose_mha_fwd_variant_suffix_and_filter(
+    dtype: str,
+    logits_positive: bool,
+    has_bias: bool,
+    has_alibi: bool,
+    use_mask: bool,
+    return_lse: bool,
+    dropout_zero: bool,
+    skip_zero: bool,
+    no_qscale: bool,
+):
+    # Each feature knob contributes exactly one token. The module-name suffix
+    # is the plain concatenation of the tokens; the generate.py filter is the
+    # same tokens separated by "*" wildcards.
+    tokens = [
+        f"_{dtype}",
+        "_logits" if logits_positive else "_nlogits",
+        "_bias" if has_bias else ("_alibi" if has_alibi else "_nbias"),
+        "_mask" if use_mask else "_nmask",
+        "_lse" if return_lse else "_nlse",
+        "_ndropout" if dropout_zero else "_dropout",
+        "_nskip" if skip_zero else "_skip",
+        # no_qscale is True when any q/k/v descale is missing; otherwise only
+        # per-tensor quantization is supported for now.
+        "_nqscale" if no_qscale else "_pertensor",
+    ]
+    suffix = "".join(tokens)
+    filt = "*" + dtype + "*" + "".join(f"{tok}*" for tok in tokens[1:])
+    return suffix, filt
+
+
+def _parse_mha_varlen_fwd_md_name(md_name: str):
+    dtype = (
+        "bf16" if "_bf16" in md_name else ("fp16" if "_fp16" in md_name else "fp8bf16")
+    )
+    logits_positive = "_logits" in md_name and "_nlogits" not in md_name
+    has_bias = "_bias" in md_name
+    has_alibi = "_alibi" in md_name
+    use_mask = "_mask" in md_name and "_nmask" not in md_name
+    return_lse = "_lse" in md_name and "_nlse" not in md_name
+    dropout_zero = "_ndropout" in md_name
+    skip_zero = "_nskip" in md_name
+    no_qscale = "_nqscale" in md_name
+    return (
+        dtype,
+        logits_positive,
+        has_bias,
+        has_alibi,
+        use_mask,
+        return_lse,
+        dropout_zero,
+        skip_zero,
+        no_qscale,
+    )
+
+
+def get_mha_varlen_prebuild_variants_by_names(
+    md_names, ck_dir: str, receipt: int = 200
+):
+    variants = []
+    for md_name in md_names:
+        (
+            dtype,
+            logits_positive,
+            has_bias,
+            has_alibi,
+            use_mask,
+            return_lse,
+            dropout_zero,
+            skip_zero,
+            no_qscale,
+        ) = _parse_mha_varlen_fwd_md_name(md_name)
+        suffix, filter_pattern = compose_mha_fwd_variant_suffix_and_filter(
+            dtype=dtype,
+            logits_positive=logits_positive,
+            has_bias=has_bias,
+            has_alibi=has_alibi,
+            use_mask=use_mask,
+            return_lse=return_lse,
+            dropout_zero=dropout_zero,
+            skip_zero=skip_zero,
+            no_qscale=no_qscale,
+        )
+        blob_gen_cmd = [
+            f"{ck_dir}/example/ck_tile/01_fmha/generate.py -d fwd --receipt {receipt} --filter {filter_pattern} --output_dir {{}}",
+            f'{ck_dir}/example/ck_tile/01_fmha/generate.py -d fwd_splitkv --receipt {receipt} --filter " @ " --output_dir {{}}',
+        ]
+        variants.append(
+            {"md_name": f"mha_varlen_fwd{suffix}", "blob_gen_cmd": blob_gen_cmd}
+        )
+    return variants
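+
+
+# Illustrative round trip (assumed example): for
+#   md_name = "mha_varlen_fwd_bf16_nlogits_nbias_mask_nlse_ndropout_nskip_nqscale"
+# the parser yields dtype="bf16", use_mask=True, dropout_zero=True,
+# skip_zero=True, no_qscale=True, and False for the logits/bias/alibi/lse
+# knobs; the compose helper then rebuilds
+#   suffix = "_bf16_nlogits_nbias_mask_nlse_ndropout_nskip_nqscale"
+#   filter = "*bf16*_nlogits*_nbias*_mask*_nlse*_ndropout*_nskip*_nqscale*"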
+
+
+# These prebuild-variant helpers are kept in this file so setup.py can extract
+# and reuse them when composing prebuild targets.
+
+
 def cmdGenFunc_mha_fwd(
     q: Tensor,
     k: Tensor,
@@ -292,65 +482,39 @@ def cmdGenFunc_mha_varlen_fwd(
         causal = False
     md_name = "mha_varlen_fwd"
     if block_table is None:
-        filter_fwd = "*"  # get_fwd_blobs()
         if q.dtype == dtypes.fp16:
-            md_name += "_fp16"
-            filter_fwd += "fp16*"
+            dtype_tok = "fp16"
         elif q.dtype == dtypes.bf16:
-            md_name += "_bf16"
-            filter_fwd += "bf16*"
+            dtype_tok = "bf16"
         elif q.dtype == dtypes.fp8:
             if out is None or out.dtype == dtypes.bf16:
-                md_name += "_fp8bf16"
-                filter_fwd += "fp8bf16*"
+                dtype_tok = "fp8bf16"
             else:
                 raise NotImplementedError("Unsupported output dtype for FP8 MHA")
-        if 0.0 < logits_soft_cap:
-            md_name += "_logits"
-            filter_fwd += "_logits*"
         else:
-            md_name += "_nlogits"
-            filter_fwd += "_nlogits*"
-        if bias is not None:
-            md_name += "_bias"
-            filter_fwd += "_bias*"
-        elif alibi_slopes is not None:
-            md_name += "_alibi"
-            filter_fwd += "_alibi*"
-        else:
-            md_name += "_nbias"
-            filter_fwd += "_nbias*"
+            raise NotImplementedError("Unsupported dtype")
         if not causal and window_size_left == -1 and window_size_right == -1:
-            md_name += "_nmask"
-            filter_fwd += "_nmask*"
+            use_mask = False
         else:
-            md_name += "_mask"
-            filter_fwd += "_mask*"
-        if return_softmax_lse:
-            md_name += "_lse"
-            filter_fwd += "_lse*"
-        else:
-            md_name += "_nlse"
-            filter_fwd += "_nlse*"
-        if dropout_p == 0:
-            md_name += "_ndropout"
-            filter_fwd += "_ndropout*"
-        else:
-            md_name += "_dropout"
-            filter_fwd += "_dropout*"
-        if min_seqlen_q == 0:
-            md_name += "_nskip"
-            filter_fwd += "_nskip*"
-        else:
-            md_name += "_skip"
-            filter_fwd += "_skip*"
-        if q_descale is None or k_descale is None or v_descale is None:
-            md_name += "_nqscale"
-            filter_fwd += "_nqscale*"
-        else:
-            # only support per-tensor quantization for now
-            md_name += "_pertensor"
-            filter_fwd += "_pertensor*"
+            use_mask = True
+        has_bias = bias is not None
+        has_alibi = alibi_slopes is not None
+        dropout_zero = dropout_p == 0
+        skip_zero = min_seqlen_q == 0
+        # True when any per-tensor descale is missing (selects "_nqscale")
+        no_qscale = q_descale is None or k_descale is None or v_descale is None
+        logits_positive = 0.0 < logits_soft_cap
+        suffix, filter_fwd = compose_mha_fwd_variant_suffix_and_filter(
+            dtype=dtype_tok,
+            logits_positive=logits_positive,
+            has_bias=has_bias,
+            has_alibi=has_alibi,
+            use_mask=use_mask,
+            return_lse=return_softmax_lse,
+            dropout_zero=dropout_zero,
+            skip_zero=skip_zero,
+            no_qscale=no_qscale,
+        )
+        md_name += suffix
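+        # filter_fwd below restricts generate.py to emitting only this
+        # variant's kernels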
f"{CK_DIR}/example/ck_tile/01_fmha/generate.py -d fwd " "--receipt 200 --filter {} --output_dir {{}}".format(filter_fwd) @@ -967,11 +1131,7 @@ def cmdGenFunc_mha_batch_prefill( return_softmax_lse: bool, return_dropout_randval: bool, out: Optional[Tensor] = None, - bias: Optional[Tensor] = None, alibi_slopes: Optional[Tensor] = None, - q_descale: Optional[Tensor] = None, - k_descale: Optional[Tensor] = None, - v_descale: Optional[Tensor] = None, gen: Optional[Generator] = None, ): # causal=true is the same as causal=false in this case @@ -986,12 +1146,6 @@ def cmdGenFunc_mha_batch_prefill( elif q.dtype == torch.bfloat16: md_name += "_bf16" filter_fwd += "bf16*" - elif q.dtype == dtypes.fp8: - if out is None or out.dtype == dtypes.bf16: - md_name += "_fp8bf16" - filter_fwd += "fp8bf16*" - else: - raise NotImplementedError("Unsupported output dtype for FP8 MHA") if 0.0 < logits_soft_cap: md_name += "_logits" filter_fwd += "_logits*" @@ -1022,17 +1176,11 @@ def cmdGenFunc_mha_batch_prefill( else: md_name += "_dropout" filter_fwd += "_dropout*" - if q_descale is None or k_descale is None or v_descale is None: - md_name += "_nqscale" - filter_fwd += "_nqscale*" - else: - # only support per-tensor quantization for now - md_name += "_pertensor" - filter_fwd += "_pertensor*" blob_gen_cmd = [ f"{CK_DIR}/example/ck_tile/01_fmha/generate.py -d batch_prefill " "--receipt 200 --filter {} --output_dir {{}}".format(filter_fwd) ] + return { "md_name": md_name, "blob_gen_cmd": blob_gen_cmd, @@ -2597,9 +2745,6 @@ def mha_batch_prefill_fake_tensors( return_dropout_randval: bool, out: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None, - q_descale: Optional[torch.Tensor] = None, - k_descale: Optional[torch.Tensor] = None, - v_descale: Optional[torch.Tensor] = None, gen: Optional[Generator] = None, ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: # ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: @@ -2664,11 +2809,7 @@ def mha_batch_prefill( return_softmax_lse: bool, return_dropout_randval: bool, out: Optional[Tensor] = None, - bias: Optional[Tensor] = None, alibi_slopes: Optional[Tensor] = None, - q_descale: Optional[torch.Tensor] = None, - k_descale: Optional[torch.Tensor] = None, - v_descale: Optional[torch.Tensor] = None, gen: Optional[Generator] = None, ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: ... 
@@ -2688,15 +2829,11 @@ def _mha_batch_prefill(
     logits_soft_cap: float = 0.0,
     window_size_left: int = -1,
     window_size_right: int = -1,
-    bias: Optional[torch.Tensor] = None,
     alibi_slopes: Optional[torch.Tensor] = None,
     return_lse: bool = False,
     return_softmax: bool = False,
     zero_tensors: bool = False,
     out: torch.Tensor = None,
-    q_descale: Optional[torch.Tensor] = None,
-    k_descale: Optional[torch.Tensor] = None,
-    v_descale: Optional[torch.Tensor] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
 
     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
@@ -2719,11 +2856,8 @@ def _mha_batch_prefill(
         return_lse,
         return_softmax,
         out,
-        bias,
         alibi_slopes,
-        q_descale,
-        k_descale,
-        v_descale,
+        None,
         # custom_build_args={"md_name": md_name, "blob_gen_cmd": blob_gen_cmd},
     )
     return out, softmax_lse, S_dmask, rng_state
@@ -2748,9 +2882,6 @@ def mha_batch_prefill_func(
     return_lse=False,
     return_attn_probs=False,
     out=None,
-    q_descale=None,
-    k_descale=None,
-    v_descale=None,
 ):
     if softmax_scale is None:
         softmax_scale = q.shape[-1] ** (-0.5)
@@ -2780,9 +2911,6 @@ def mha_batch_prefill_func(
         return_lse=return_lse,
         return_softmax=return_attn_probs and dropout_p > 0,
         out=out,
-        q_descale=q_descale,
-        k_descale=k_descale,
-        v_descale=v_descale,
     )
 
     out = out_padded[..., :head_size_v_og]
diff --git a/setup.py b/setup.py
index c70445943..f8fd7ecfd 100644
--- a/setup.py
+++ b/setup.py
@@ -99,13 +99,16 @@ def get_exclude_ops():
     for module in all_modules:
         if PREBUILD_KERNELS == 1:
-            # Exclude mha, _tune, and specific module
+            # Exclude _tune and one ASM module; mha modules are handled below
             if (
-                "mha" in module
-                or "_tune" in module
+                "_tune" in module
                 or module == "module_gemm_mi350_a8w8_blockscale_asm"
             ):
                 exclude_ops.append(module)
+            if "mha" in module and module not in [
+                "module_fmha_v3_fwd",
+                "module_fmha_v3_varlen_fwd",
+            ]:
+                exclude_ops.append(module)
         elif PREBUILD_KERNELS == 2:
             # Exclude _bwd, _tune, and specific module
             if (
@@ -158,6 +161,50 @@ def get_exclude_ops():
     except Exception:
         pass
 
+    if PREBUILD_KERNELS == 1:
+        base_args = core.get_args_of_build("module_mha_varlen_fwd")
+        if isinstance(base_args, dict) and base_args.get("srcs"):
+            import re
+
+            _mha_path = os.path.join(this_dir, "aiter", "ops", "mha.py")
+            with open(_mha_path, "r", encoding="utf-8") as f:
+                _src = f.read()
+
+            def _extract_def(src, name):
+                # match the "def <name>(...):" header (non-greedy up to "):")
+                pat = re.compile(rf"^def\s+{name}\s*\(.*?\):", re.M | re.S)
+                m = pat.search(src)
+                if not m:
+                    raise RuntimeError(f"Failed to extract function: {name}")
+                start = m.start()
+                # the body extends to the next top-level def/class
+                pat_next = re.compile(r"^(def|class)\s+", re.M)
+                m2 = pat_next.search(src, m.end())
+                end = m2.start() if m2 else len(src)
+                return src[start:end]
+
+            blocks = []
+            for fn in [
+                "compose_mha_fwd_variant_suffix_and_filter",
+                "_parse_mha_varlen_fwd_md_name",
+                "get_mha_varlen_prebuild_variants_by_names",
+            ]:
+                blocks.append(_extract_def(_src, fn))
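+            # exec the extracted sources in a private namespace; importing
+            # aiter.ops.mha directly would require torch at setup time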
f"-DPREBUILD_KERNELS={PREBUILD_KERNELS}" ] + blob_gen_cmd = one_opt_args["blob_gen_cmd"] core.build_module( md_name=one_opt_args["md_name"], srcs=one_opt_args["srcs"], flags_extra_cc=flags_cc, flags_extra_hip=flags_hip, - blob_gen_cmd=one_opt_args["blob_gen_cmd"], + blob_gen_cmd=blob_gen_cmd, extra_include=one_opt_args["extra_include"], extra_ldflags=None, verbose=False, @@ -190,7 +238,6 @@ def build_one_module(one_opt_args): with ThreadPoolExecutor(max_workers=prebuid_thread_num) as executor: list(executor.map(build_one_module, all_opts_args_build)) - else: raise NotImplementedError("Only ROCM is supported")