diff --git a/aiter/ops/triton/gluon/pa_decode_gluon.py b/aiter/ops/triton/gluon/pa_decode_gluon.py
index 9c07d4c4ac..8c6194cc97 100644
--- a/aiter/ops/triton/gluon/pa_decode_gluon.py
+++ b/aiter/ops/triton/gluon/pa_decode_gluon.py
@@ -620,7 +620,6 @@ def paged_attention_decode_v2_gluon_large_block_dot_kernel(
         warps_per_cta=[4, 1],
         order=[1, 0],
     )
-    shared_query_layout: gl.constexpr = gl.SwizzledSharedLayout(8, 1, 16, order=[1, 0])
 
     # Key cache layout - optimized for CDNA3 architecture
     blocked_key_layout: gl.constexpr = gl.BlockedLayout(
@@ -798,9 +797,6 @@ def paged_attention_decode_v2_gluon_large_block_dot_kernel(
     query_tensor = gl.amd.cdna3.buffer_load(
         ptr=query_ptr, offsets=query_offsets_base, mask=query_mask
     )
-    query_shared = gl.allocate_shared_memory(
-        query_tensor.dtype, query_tensor.shape, shared_query_layout, query_tensor
-    )
 
     # ==================== Query Quantization Scale Handling ====================
     if QUERY_QUANT_MODE == 0:
@@ -969,7 +965,6 @@ def paged_attention_decode_v2_gluon_large_block_dot_kernel(
 
         # Convert layouts for MFMA operation
         query_converted = gl.convert_layout(query_tensor, layout=qk_lhs_layout)
-        # query_converted = query_shared.load(qk_lhs_layout)
         key_converted = gl.convert_layout(key_block, layout=qk_rhs_layout)
         query_converted = query_converted.to(COMPUTE_TYPE)
         key_converted = key_converted.to(COMPUTE_TYPE)
@@ -1936,11 +1931,8 @@ def paged_attention_decode_v2_gluon_dot_kernel(
     else:
        OUTPUT_DTYPE: gl.constexpr = COMPUTE_TYPE
     LOG2_E: gl.constexpr = 1.4426950408889634  # log2(e) for exponential conversion
-    CONTIGUOUS_KV_ELEMENTS_PER_16B_LOAD: gl.constexpr = KV_16B_ELEMENT_COUNT
-    K_HEAD_SIZE_SPLITS: gl.constexpr = (
-        HEAD_SIZE_POW2 // CONTIGUOUS_KV_ELEMENTS_PER_16B_LOAD
-    )
+    K_HEAD_SIZE_SPLITS: gl.constexpr = HEAD_SIZE_POW2 // KV_16B_ELEMENT_COUNT
     MAX_NUM_KV_BLOCKS_PER_COMPUTE: gl.constexpr = KV_COMPUTE_BLOCK_SIZE // KV_BLOCK_SIZE
 
     # ==================== MEMORY LAYOUT DEFINITIONS ====================
@@ -1951,16 +1943,31 @@ def paged_attention_decode_v2_gluon_dot_kernel(
         warps_per_cta=[4, 1],
         order=[1, 0],
     )
-    shared_query_layout: gl.constexpr = gl.SwizzledSharedLayout(8, 1, 16, order=[1, 0])
+    shared_query_layout: gl.constexpr = gl.SwizzledSharedLayout(
+        KV_16B_ELEMENT_COUNT, 1, 16, order=[1, 0]
+    )
 
     # Key cache layout - optimized for block-wise access patterns
-    blocked_key_layout: gl.constexpr = gl.BlockedLayout(
-        size_per_thread=[1, 1, 1, CONTIGUOUS_KV_ELEMENTS_PER_16B_LOAD],
+    blocked_key_layout_fp8: gl.constexpr = gl.BlockedLayout(
+        size_per_thread=[1, 1, 1, KV_16B_ELEMENT_COUNT],
         threads_per_warp=[1, 4, 16, 1],
         warps_per_cta=[4, 1, 1, 1],
         order=[3, 2, 1, 0],
     )
+    key_warps_per_cta_f16: gl.constexpr = (
+        [4, 1, 1, 1] if KV_BLOCK_SIZE == 16 else [1, 1, 4, 1]
+    )
+    blocked_key_layout_f16: gl.constexpr = gl.BlockedLayout(
+        size_per_thread=[1, 1, 1, KV_16B_ELEMENT_COUNT],
+        threads_per_warp=[1, 4, 16, 1],
+        warps_per_cta=key_warps_per_cta_f16,
+        order=[3, 2, 1, 0],
+    )
+    blocked_key_layout: gl.constexpr = (
+        blocked_key_layout_fp8 if KV_16B_ELEMENT_COUNT == 16 else blocked_key_layout_f16
+    )
+    DOT_QK_K_WIDTH: gl.constexpr = KV_16B_ELEMENT_COUNT
 
     # QK Matrix multiplication layout using AMD MFMA instructions
     qk_mfma_layout: gl.constexpr = gl.amd.AMDMFMALayout(
         version=CDNA_VERSION,
@@ -1969,10 +1976,10 @@ def paged_attention_decode_v2_gluon_dot_kernel(
         warps_per_cta=[1, 4],
     )
     qk_lhs_operand_layout: gl.constexpr = gl.DotOperandLayout(
-        operand_index=0, parent=qk_mfma_layout, k_width=16
+        operand_index=0, parent=qk_mfma_layout, k_width=DOT_QK_K_WIDTH
     )
     qk_rhs_operand_layout: gl.constexpr = gl.DotOperandLayout(
-        operand_index=1, parent=qk_mfma_layout, k_width=16
+        operand_index=1, parent=qk_mfma_layout, k_width=DOT_QK_K_WIDTH
     )
 
     # Register allocation configuration based on group size and compute block size
@@ -2011,15 +2018,29 @@ def paged_attention_decode_v2_gluon_dot_kernel(
     # Value cache layout configuration based on transpose flag
     if VALUE_TRANSPOSED:
         # Transposed value layout for better memory access patterns
-        blocked_value_layout: gl.constexpr = gl.BlockedLayout(
-            size_per_thread=[1, 1, 1, CONTIGUOUS_KV_ELEMENTS_PER_16B_LOAD],
-            threads_per_warp=[4, 1, 16, 1],
+        value_threads_per_warp: gl.constexpr = (
+            [4, 1, 16, 1] if KV_BLOCK_SIZE == 16 else [1, 4, 16, 1]
+        )
+        blocked_value_layout_f16: gl.constexpr = gl.BlockedLayout(
+            size_per_thread=[1, 1, 1, 8],
+            threads_per_warp=value_threads_per_warp,
+            warps_per_cta=[1, 1, 4, 1],
+            order=[3, 2, 1, 0],
+        )
+        blocked_value_layout_fp8: gl.constexpr = gl.BlockedLayout(
+            size_per_thread=[1, 1, 1, 16],
+            threads_per_warp=value_threads_per_warp,
             warps_per_cta=[1, 1, 4, 1],
             order=[3, 2, 1, 0],
         )
+        blocked_value_layout: gl.constexpr = (
+            blocked_value_layout_fp8
+            if KV_16B_ELEMENT_COUNT == 16
+            else blocked_value_layout_f16
+        )
         value_dim1_offsets = gl.arange(
             0,
-            KV_BLOCK_SIZE // CONTIGUOUS_KV_ELEMENTS_PER_16B_LOAD,
+            KV_BLOCK_SIZE // KV_16B_ELEMENT_COUNT,
             layout=gl.SliceLayout(
                 0, gl.SliceLayout(2, gl.SliceLayout(3, blocked_value_layout))
             ),
@@ -2033,26 +2054,23 @@ def paged_attention_decode_v2_gluon_dot_kernel(
         )
         value_dim3_offsets = gl.arange(
             0,
-            CONTIGUOUS_KV_ELEMENTS_PER_16B_LOAD,
+            KV_16B_ELEMENT_COUNT,
             layout=gl.SliceLayout(
                 0, gl.SliceLayout(1, gl.SliceLayout(2, blocked_value_layout))
             ),
         )
     else:
         # Standard value layout
+        value_threads_per_warp: gl.constexpr = (
+            [4, 16, 1] if KV_BLOCK_SIZE == 16 else [1, 16, 4]
+        )
         blocked_value_layout: gl.constexpr = gl.BlockedLayout(
-            size_per_thread=[1, 1, CONTIGUOUS_KV_ELEMENTS_PER_16B_LOAD],
-            threads_per_warp=[4, 16, 1],
+            size_per_thread=[1, 1, 16],
+            threads_per_warp=value_threads_per_warp,
             warps_per_cta=[1, 4, 1],
             order=[2, 1, 0],
         )
-        # blocked_value_layout: gl.constexpr = gl.DistributedLinearLayout(
-        #     reg_bases=((0,0,1), (0,0,2), (0,0,4), (0,0,8), (4,0,0), (8,0,0), (0,64,0)),
-        #     lane_bases=((0,1,0), (0,2,0), (0,4,0), (0,8,0), (1,0,0), (2,0,0)),
-        #     warp_bases=((0,16,0), (0,32,0)),
-        #     block_bases=[],
-        #     shape=[16, 128, 16],
-        # )
+
         value_dim1_offsets = gl.arange(
             0,
             HEAD_SIZE_POW2,
@@ -2108,7 +2126,7 @@ def paged_attention_decode_v2_gluon_dot_kernel(
     )
     block_element_offsets = gl.arange(0, KV_BLOCK_SIZE, layout=block_element_layout)
     contiguous_kv_element_offsets = gl.arange(
-        0, CONTIGUOUS_KV_ELEMENTS_PER_16B_LOAD, layout=contiguous_kv_elements_layout
+        0, KV_16B_ELEMENT_COUNT, layout=contiguous_kv_elements_layout
     )
     qk_row_offsets = gl.arange(
         0, QUERY_GROUP_SIZE_POW2, layout=gl.SliceLayout(1, qk_linear_layout)
     )
@@ -2240,8 +2258,7 @@ def paged_attention_decode_v2_gluon_dot_kernel(
            kv_block_numbers[:, None, None, None] * stride_key_block
            + kv_head_idx * stride_key_head
            + head_size_split_offsets[None, :, None, None] * stride_key_head_split
-            + block_element_offsets[None, None, :, None]
-            * CONTIGUOUS_KV_ELEMENTS_PER_16B_LOAD
+            + block_element_offsets[None, None, :, None] * KV_16B_ELEMENT_COUNT
            + contiguous_kv_element_offsets[None, None, None, :]
         )
         key_tensor = gl.load(key_cache_ptr + key_block_offsets)
@@ -2272,6 +2289,39 @@ def paged_attention_decode_v2_gluon_dot_kernel(
             key_tensor = gl.permute(key_tensor, [1, 3, 0, 2])
             key_tensor = gl.reshape(key_tensor, [HEAD_SIZE_POW2, KV_COMPUTE_BLOCK_SIZE])
 
+        # ==================== ATTENTION SCORE COMPUTATION ====================
+        # Initialize QK accumulator
+        qk_accumulator = gl.zeros(
+            (QUERY_GROUP_SIZE_POW2, KV_COMPUTE_BLOCK_SIZE),
+            dtype=gl.float32,
+            layout=qk_mfma_layout,
+        )
+
+        # if sequence_idx == 0 \
+        #     and kv_head_idx == 0 \
+        #     and sequence_partition_idx == 0:
+        #     print("query_tensor=", query_tensor.to(tl.float32))
+        #     print("key_tensor=", key_tensor.to(tl.float32))
+        #     if QUERY_QUANT_MODE == 0 and KV_QUANT_MODE == 0:
+        #         print("QKV_per_tensor")
+        #     else:
+        #         print("QKV_per_token")
+
+        # Convert layouts for MFMA operation
+        query_converted = query_shared.load(qk_lhs_operand_layout)
+        key_converted = gl.convert_layout(key_tensor, layout=qk_rhs_operand_layout)
+
+        query_converted = query_converted.to(COMPUTE_TYPE)
+        key_converted = key_converted.to(COMPUTE_TYPE)
+
+        # Compute QK attention scores using MFMA
+        attention_scores = gl.amd.cdna3.mfma(
+            query_converted, key_converted, qk_accumulator
+        )
+        attention_scores = gl.reshape(
+            attention_scores, [QUERY_GROUP_SIZE_POW2, KV_COMPUTE_BLOCK_SIZE]
+        )
+
         # ==================== VALUE LOADING AND PROCESSING ====================
         if VALUE_TRANSPOSED:
             # Load values from transposed cache layout
@@ -2285,8 +2335,7 @@ def paged_attention_decode_v2_gluon_dot_kernel(
                kv_block_numbers_reshaped[:, None, None, None] * stride_value_block
                + kv_head_idx * stride_value_head
                + value_dim1_offsets[None, :, None, None] * stride_value_head_size
-                + value_dim2_offsets[None, None, :, None]
-                * CONTIGUOUS_KV_ELEMENTS_PER_16B_LOAD
+                + value_dim2_offsets[None, None, :, None] * KV_16B_ELEMENT_COUNT
                + value_dim3_offsets[None, None, None, :]
             )
             value_tensor = gl.load(value_cache_ptr + value_block_offsets)
@@ -2314,29 +2363,6 @@ def paged_attention_decode_v2_gluon_dot_kernel(
                 value_tensor, [KV_COMPUTE_BLOCK_SIZE, HEAD_SIZE_POW2]
             )
 
-        # ==================== ATTENTION SCORE COMPUTATION ====================
-        # Initialize QK accumulator
-        qk_accumulator = gl.zeros(
-            (QUERY_GROUP_SIZE_POW2, KV_COMPUTE_BLOCK_SIZE),
-            dtype=gl.float32,
-            layout=qk_mfma_layout,
-        )
-
-        # Convert layouts for MFMA operation
-        query_converted = query_shared.load(qk_lhs_operand_layout)
-        key_converted = gl.convert_layout(key_tensor, layout=qk_rhs_operand_layout)
-
-        query_converted = query_converted.to(COMPUTE_TYPE)
-        key_converted = key_converted.to(COMPUTE_TYPE)
-
-        # Compute QK attention scores using MFMA
-        attention_scores = gl.amd.cdna3.mfma(
-            query_converted, key_converted, qk_accumulator
-        )
-        attention_scores = gl.reshape(
-            attention_scores, [QUERY_GROUP_SIZE_POW2, KV_COMPUTE_BLOCK_SIZE]
-        )
-
         # Apply quantization scaling to attention scores
         if KV_QUANT_MODE >= 0:
             if KV_QUANT_MODE == 1:
@@ -2524,8 +2550,6 @@ def paged_attention_decode_v2_reduce_kernel(
         Various stride parameters for tensor access
         Compile-time constants for kernel configuration (no MAX_CONTEXT_PARTITION_NUM needed)
     """
-    # Mathematical constant for exponential calculations
-    LOG2_E: tl.constexpr = 1.4426950408889634
     MAX_CONTEXT_PARTITION_NUM: tl.constexpr = 16
 
     # ==================== INITIALIZATION ====================
@@ -2749,10 +2773,9 @@ def _paged_attention_decode_v2_with_dot_kernel_reshape_wrapper(
     parameters for Triton compilation and execution.
""" HEAD_SIZE_POW2 = triton.next_power_of_2(HEAD_SIZE) - # Production path - select and launch appropriate kernel + waves_per_eu = 1 QUERY_GROUP_SIZE = QUERY_SEQ_LEN * QUERY_GROUP_SIZE_ORIGINAL KV_COMPUTE_BLOCK_SIZE = CONTEXT_PARTITION_SIZE - waves_per_eu = 2 if QUERY_GROUP_SIZE < 16: QUERY_GROUP_SIZE_POW2 = 16 else: diff --git a/csrc/cpp_itfs/gluon_aot_tools/compile.py b/csrc/cpp_itfs/gluon_aot_tools/compile.py new file mode 100644 index 0000000000..e6ed6851eb --- /dev/null +++ b/csrc/cpp_itfs/gluon_aot_tools/compile.py @@ -0,0 +1,270 @@ +import binascii +import hashlib +import importlib.util +import sys +from argparse import ArgumentParser +from dataclasses import dataclass +from pathlib import Path +from typing import List + +import triton +import triton.backends + + +@dataclass +class CompileArgs: + """ + A class to contain arguments from command-line parser. + """ + + path: str = "" + kernel_name: str = "" + signature: str = "" + grid: str = "" + target: str | None = None + num_warps: int = 1 + num_stages: int = 3 + out_name: str | None = None + out_path: Path | None = None + + +desc = """ +Triton ahead-of-time compiler: + +This program compiles the kernel with name `kernel-name` in the file at the +provided `path` into self-contained C source-code that embeds the `cubin` +data along with utilities to load, unload and launch the kernel. + +signature is provided as a list of (optionally divisibility-hinted) types +or constexpr values, e.g. + +`compile.py --kernel-name kernel --signature "*fp32:16, i32:16, 1024, i32" --out-name kernel /path/to/kernel.py` + +will compile triton.JITFunction of name `kernel` inside the file `/path/to/kernel.py`. +Said kernel will be specialized such that argument 0, 1 are assumed to be multiple of 16, +and argument 2 is assumed to be a compile-time constant of value 1024, i.e. it won't be part of the generated prototype. + +The resulting entry point will have signature + +CUresult kernel_{specialization_suffix}(CUstream stream, unsigned gX, unsigned gY, unsigned gZ, float* arg0, int32_t arg1, int32_t arg2) + +Different such specialized entry points can be combined using the `linker.py` script. + +NOTE: when resolving the scope of /path/to/kernel.py, the file will be executed from within its parent directory with the python interpreter +used to run this `compile.py` script +""" + + +def main(): + # command-line arguments + parser = ArgumentParser(description=desc) + parser.add_argument( + "path", + help="Path to Python source containing desired kernel in its scope. File will be executed.", + ) + parser.add_argument( + "--kernel-name", + "-n", + type=str, + default="", + help="Name of the kernel to compile", + required=True, + ) + parser.add_argument( + "--target", + "-t", + type=str, + default=None, + help="The target to compile towards, in format of '::'; " + "e.g., 'cuda:80:32', 'hip:gfx942:64'. 
+        "Default to None, which means using current machine's GPU target",
+    )
+    parser.add_argument(
+        "--num-warps",
+        "-w",
+        type=int,
+        default=1,
+        help="Number of warps to launch the kernel",
+    )
+    parser.add_argument(
+        "--num-stages",
+        "-ns",
+        type=int,
+        default=3,
+        help="Number of stages (meta-parameter of the kernel)",
+    )
+    parser.add_argument(
+        "--out-name",
+        "-on",
+        type=str,
+        default=None,
+        help="Out name for the compiled kernel",
+    )
+    parser.add_argument(
+        "--out-path", "-o", type=Path, default=None, help="Out filename"
+    )
+    parser.add_argument(
+        "--signature", "-s", type=str, help="Signature of the kernel", required=True
+    )
+    parser.add_argument(
+        "--grid", "-g", type=str, help="Launch grid of the kernel", required=True
+    )
+    cli_args = parser.parse_args()
+    args = CompileArgs(
+        **vars(cli_args)
+    )  # A sanity check to ensure class CompileArgs is updated as well.
+    compile_kernel(args)
+
+
+def compile_kernel(args: CompileArgs):
+    out_name = args.out_name if args.out_name else args.kernel_name
+    out_path = args.out_path if args.out_path else Path(out_name)
+
+    # execute python sources and extract functions wrapped in JITFunction
+    arg_path = Path(args.path)
+    sys.path.insert(0, str(arg_path.parent))
+    spec = importlib.util.spec_from_file_location(arg_path.stem, arg_path)
+    mod = importlib.util.module_from_spec(spec)
+    spec.loader.exec_module(mod)
+    kernel = getattr(mod, args.kernel_name)
+    grid = args.grid.split(",")
+    assert len(grid) == 3
+
+    # validate and parse signature
+    signature = list(map(lambda s: s.strip(" "), args.signature.split(",")))
+
+    def hash_signature(signature: List[str]):
+        m = hashlib.sha256()
+        m.update(" ".join(signature).encode())
+        return m.hexdigest()[:8]
+
+    meta_sig = f"warps{args.num_warps}xstages{args.num_stages}"
+    sig_hash = hash_signature(signature + [meta_sig])
+
+    def constexpr(s):
+        try:
+            ret = int(s)
+            return ret
+        except ValueError:
+            pass
+        try:
+            ret = float(s)
+            return ret
+        except ValueError:
+            pass
+        return None
+
+    hints = {
+        (i,): constexpr(s.split(":")[1]) for i, s in enumerate(signature) if ":" in s
+    }
+    hints = {k: v for k, v in hints.items() if v is not None}
+    constants = {kernel.arg_names[i]: constexpr(s) for i, s in enumerate(signature)}
+    constants = {k: v for k, v in constants.items() if v is not None}
+    for key, value in hints.items():
+        if value == 1:
+            constants[kernel.arg_names[key[0]]] = value
+    signature = {kernel.arg_names[i]: s.split(":")[0] for i, s in enumerate(signature)}
+    for key in constants:
+        signature[key] = "constexpr"
+    const_sig = "x".join([str(v) for v in constants.values()])
+    doc_string = [f"{k}={v}" for k, v in constants.items()]
+    doc_string += [f"num_warps={args.num_warps}", f"num_stages={args.num_stages}"]
+
+    # compile ast into cubin
+    for h in hints.values():
+        assert h in [1, 16], f"Only 1 and 16 are valid hints, got {h}"
+    attrs = {k: [["tt.divisibility", 16]] for k, v in hints.items() if v == 16}
+    src = triton.compiler.ASTSource(
+        fn=kernel, constexprs=constants, signature=signature, attrs=attrs
+    )
+
+    target = (
+        triton.backends.compiler.GPUTarget(*args.target.split(":"))
+        if args.target
+        else triton.runtime.driver.active.get_current_target()
+    )
+    backend = triton.compiler.make_backend(target)
+    kwargs = {"num_warps": args.num_warps, "num_stages": args.num_stages}
+    options = backend.parse_options(kwargs)
+    ccinfo = triton.compile(src, target=target, options=options.__dict__)
+
+    if getattr(ccinfo.metadata, "global_scratch_size", 0) > 0:
+        raise RuntimeError(
+            "AOT compiling kernels with global scratch requirements is not yet implemented"
+        )
+    if ccinfo.metadata.profile_scratch_size > 0:
+        raise RuntimeError(
+            "AOT compiling kernels with profile scratch requirements is not yet implemented"
+        )
+
+    arg_names = []
+    arg_types = []
+    arg_names_not_1 = []
+    arg_types_not_1 = []
+    for i, arg_name in enumerate(kernel.arg_names):
+        if arg_name not in constants:
+            arg_names.append(arg_name)
+            arg_types.append(signature[arg_name])
+            arg_names_not_1.append(arg_name)
+            arg_types_not_1.append(signature[arg_name])
+        elif hints.get((i,), None) == 1:
+            arg_names.append(arg_name)
+            arg_types.append("i32")
+
+    # dump C stub code
+    suffix = ""
+    for i, ty in enumerate(signature.values()):
+        suffix += str(i)
+        if hints.get((i,), None) == 1:
+            suffix += "c"
+        if hints.get((i,), None) == 16:
+            suffix += "d"
+    func_name = "_".join([out_name, sig_hash, suffix])
+    asm = ccinfo.asm[backend.binary_ext]  # store binary data once
+
+    hex_ = str(binascii.hexlify(asm))[2:-1]
+
+    ty_to_cpp = triton.runtime.driver.active.map_python_to_cpp_type
+
+    params = {
+        "kernel_name": func_name,
+        "triton_kernel_name": args.kernel_name,
+        "bin_size": len(asm),
+        "bin_data": ", ".join([f"0x{x}{y}" for x, y in zip(hex_[::2], hex_[1::2])]),
+        "signature": ", ".join(
+            [
+                f"{ty_to_cpp(ty)} {name}"
+                for name, ty in zip(arg_names_not_1, arg_types_not_1)
+            ]
+        ),
+        "full_signature": ", ".join(
+            [f"{ty_to_cpp(ty)} {name}" for name, ty in zip(arg_names, arg_types)]
+        ),
+        "arg_pointers": ", ".join(
+            [f"&{arg}" for arg in arg_names_not_1]
+            + ["&global_scratch"]
+            + ["&profile_scratch"]
+        ),
+        "num_args": len(arg_names_not_1) + 2,  # +2 for global and profile scratch
+        "kernel_docstring": doc_string,
+        "shared": ccinfo.metadata.shared,
+        "num_warps": args.num_warps,
+        "algo_info": "_".join([const_sig, meta_sig]),
+        "gridX": grid[0],
+        "gridY": grid[1],
+        "gridZ": grid[2],
+        "_placeholder": "",
+    }
+    output_files = []
+    backend_name = target.backend
+    template_dir = Path(__file__).parent / "extra" / backend_name
+    for template_path in template_dir.glob("compile.*"):
+        ext = template_path.suffix
+        output_file = out_path.with_suffix(f".{sig_hash}_{suffix}{ext}")
+        with output_file.open("w") as fp:
+            fp.write(template_path.read_text().format(**params))
+        output_files.append(output_file)
+
+    return func_name, output_files
+
+
+if __name__ == "__main__":
+    main()
diff --git a/csrc/cpp_itfs/pa_gluon_aot/pa_decode_gluon_aot.py b/csrc/cpp_itfs/pa_gluon_aot/pa_decode_gluon_aot.py
index 20bd6c3942..955745c7ce 100644
--- a/csrc/cpp_itfs/pa_gluon_aot/pa_decode_gluon_aot.py
+++ b/csrc/cpp_itfs/pa_gluon_aot/pa_decode_gluon_aot.py
@@ -1,50 +1,50 @@
 import os
-import time
 import shutil
 import subprocess
+import time
 from pathlib import Path
-from jinja2 import Template
-import torch
+
 import aiter
 import aiter.ops.triton.utils._triton.arch_info as arch_info
+import torch
 import triton
 import triton.language as tl
+from jinja2 import Template
 
-GLUON_AOT_COMPILE_ENABLED = True
-try:
-    from triton.experimental import gluon
-    from triton.experimental.gluon import language as gl
-except ImportError:
-    print(
-        "Warning: triton.experimental.gluon or triton.experimental.gluon.language not exists, pa_decode_gluon_aot cannot use compile mode!"
-    )
-    GLUON_AOT_COMPILE_ENABLED = False
-
-try:
-    from triton.tools.compile import compile_kernel, CompileArgs
-except ImportError:
-    print("Warning: compile_kernel or CompileArgs is not in triton.tools.compile!")
-
+from aiter.ops.triton.gluon.pa_decode_gluon import get_cdna_version
+from csrc.cpp_itfs.gluon_aot_tools.compile import (
+    CompileArgs,
+    compile_kernel,
+)
 from csrc.cpp_itfs.gluon_aot_tools.compile_gluon import (
-    compile_gluon_kernel,
     CompileGluonArgs,
+    compile_gluon_kernel,
 )
+from csrc.cpp_itfs.pa_gluon_aot.transpose_query_output_gluon_aot import (
+    transpose_output_gluon_aot,
+    transpose_query_gluon_aot,
+)
 from csrc.cpp_itfs.torch_utils import torch_to_c_types
 from csrc.cpp_itfs.utils import (
-    BUILD_DIR,
     AITER_CORE_DIR,
-    get_default_func_name,
+    BUILD_DIR,
     compile_template_op,
+    get_default_func_name,
+    logger,
     mp_lock,
     not_built,
     run_lib,
-    logger,
-)
-from csrc.cpp_itfs.pa_gluon_aot.transpose_query_output_gluon_aot import (
-    transpose_query_gluon_aot,
-    transpose_output_gluon_aot,
 )
-from aiter.ops.triton.gluon.pa_decode_gluon import get_cdna_version
+
+GLUON_AOT_COMPILE_ENABLED = True
+try:
+    from triton.experimental import gluon  # noqa: F401
+    from triton.experimental.gluon import language as gl  # noqa: F401
+except ImportError:
+    print(
+        "Warning: triton.experimental.gluon or triton.experimental.gluon.language does not exist; pa_decode_gluon_aot cannot use compile mode!"
+    )
+    GLUON_AOT_COMPILE_ENABLED = False
 
 MD_NAME = "pa_decode_attention_reduce_kernel"
 
@@ -150,8 +150,8 @@ def compile(
             "This version triton is not support gluon aot compile, please upgrade to 3.5.0 or higher!"
         )
 
-    kv_compute_block_size = 256
     waves_per_eu = 1
+    kv_compute_block_size = context_partition_size
     # Select kernel implementation based on block size
     if kv_block_size > context_partition_size:
         # Use big block kernel for large block sizes
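---

Reviewer note (not part of the patch): a minimal sketch of driving the new AOT entry point in csrc/cpp_itfs/gluon_aot_tools/compile.py programmatically, mirroring the CLI described in its `desc` docstring. The kernel file path, kernel name, signature, and grid below are illustrative assumptions, not values taken from this PR.

    # Sketch only: assumes /path/to/kernel.py defines a triton.JITFunction named
    # "kernel" whose arguments match the signature string, per the desc docstring.
    from csrc.cpp_itfs.gluon_aot_tools.compile import CompileArgs, compile_kernel

    args = CompileArgs(
        path="/path/to/kernel.py",  # file is executed to resolve the kernel
        kernel_name="kernel",  # name of the JITFunction inside that file
        signature="*fp32:16, i32:16, 1024, i32",  # hinted ptr, hinted i32, constexpr, i32
        grid="1024,1,1",  # gX,gY,gZ; must split into exactly 3 fields
        num_warps=1,
        num_stages=3,
    )
    # Returns the generated C entry-point name and the emitted source files.
    func_name, output_files = compile_kernel(args)

Leaving `target=None` compiles for the current machine's GPU, as in the CLI default.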