143 changes: 83 additions & 60 deletions aiter/ops/triton/gluon/pa_decode_gluon.py
@@ -620,7 +620,6 @@ def paged_attention_decode_v2_gluon_large_block_dot_kernel(
warps_per_cta=[4, 1],
order=[1, 0],
)
shared_query_layout: gl.constexpr = gl.SwizzledSharedLayout(8, 1, 16, order=[1, 0])

# Key cache layout - optimized for CDNA3 architecture
blocked_key_layout: gl.constexpr = gl.BlockedLayout(
@@ -798,9 +797,6 @@ def paged_attention_decode_v2_gluon_large_block_dot_kernel(
query_tensor = gl.amd.cdna3.buffer_load(
ptr=query_ptr, offsets=query_offsets_base, mask=query_mask
)
query_shared = gl.allocate_shared_memory(
query_tensor.dtype, query_tensor.shape, shared_query_layout, query_tensor
)

# ==================== Query Quantization Scale Handling ====================
if QUERY_QUANT_MODE == 0:
@@ -969,7 +965,6 @@ def paged_attention_decode_v2_gluon_large_block_dot_kernel(

# Convert layouts for MFMA operation
query_converted = gl.convert_layout(query_tensor, layout=qk_lhs_layout)
# query_converted = query_shared.load(qk_lhs_layout)
key_converted = gl.convert_layout(key_block, layout=qk_rhs_layout)
query_converted = query_converted.to(COMPUTE_TYPE)
key_converted = key_converted.to(COMPUTE_TYPE)
@@ -1936,11 +1931,8 @@ def paged_attention_decode_v2_gluon_dot_kernel(
else:
OUTPUT_DTYPE: gl.constexpr = COMPUTE_TYPE
LOG2_E: gl.constexpr = 1.4426950408889634 # log2(e) for exponential conversion
CONTIGUOUS_KV_ELEMENTS_PER_16B_LOAD: gl.constexpr = KV_16B_ELEMENT_COUNT

K_HEAD_SIZE_SPLITS: gl.constexpr = (
HEAD_SIZE_POW2 // CONTIGUOUS_KV_ELEMENTS_PER_16B_LOAD
)
K_HEAD_SIZE_SPLITS: gl.constexpr = HEAD_SIZE_POW2 // KV_16B_ELEMENT_COUNT
MAX_NUM_KV_BLOCKS_PER_COMPUTE: gl.constexpr = KV_COMPUTE_BLOCK_SIZE // KV_BLOCK_SIZE

# ==================== MEMORY LAYOUT DEFINITIONS ====================
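For reference, a minimal sketch of the constexpr arithmetic consolidated in this hunk (K_HEAD_SIZE_SPLITS now divides by KV_16B_ELEMENT_COUNT directly); the concrete head size, compute-block size, and per-load element counts below are illustrative assumptions, not values taken from the kernel launch:

def kv_tile_constants(head_size_pow2, kv_16b_element_count, kv_compute_block_size, kv_block_size):
    # Number of 16-byte chunks needed to cover one head dimension.
    k_head_size_splits = head_size_pow2 // kv_16b_element_count
    # Number of physical KV-cache blocks covered by one compute tile.
    max_num_kv_blocks_per_compute = kv_compute_block_size // kv_block_size
    return k_head_size_splits, max_num_kv_blocks_per_compute

# fp16/bf16 KV cache: 16 bytes per load / 2 bytes per element = 8 elements.
assert kv_tile_constants(128, 8, 256, 16) == (16, 16)
# fp8 KV cache: 16 bytes per load / 1 byte per element = 16 elements.
assert kv_tile_constants(128, 16, 256, 16) == (8, 16)
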
@@ -1951,16 +1943,31 @@ def paged_attention_decode_v2_gluon_dot_kernel(
warps_per_cta=[4, 1],
order=[1, 0],
)
shared_query_layout: gl.constexpr = gl.SwizzledSharedLayout(8, 1, 16, order=[1, 0])
shared_query_layout: gl.constexpr = gl.SwizzledSharedLayout(
KV_16B_ELEMENT_COUNT, 1, 16, order=[1, 0]
)

# Key cache layout - optimized for block-wise access patterns
blocked_key_layout: gl.constexpr = gl.BlockedLayout(
size_per_thread=[1, 1, 1, CONTIGUOUS_KV_ELEMENTS_PER_16B_LOAD],
blocked_key_layout_fp8: gl.constexpr = gl.BlockedLayout(
size_per_thread=[1, 1, 1, KV_16B_ELEMENT_COUNT],
threads_per_warp=[1, 4, 16, 1],
warps_per_cta=[4, 1, 1, 1],
order=[3, 2, 1, 0],
)
key_warps_per_cta_f16: gl.constexpr = (
[4, 1, 1, 1] if KV_BLOCK_SIZE == 16 else [1, 1, 4, 1]
)
blocked_key_layout_f16: gl.constexpr = gl.BlockedLayout(
size_per_thread=[1, 1, 1, KV_16B_ELEMENT_COUNT],
threads_per_warp=[1, 4, 16, 1],
warps_per_cta=key_warps_per_cta_f16,
order=[3, 2, 1, 0],
)
blocked_key_layout: gl.constexpr = (
blocked_key_layout_fp8 if KV_16B_ELEMENT_COUNT == 16 else blocked_key_layout_f16
)

DOT_QK_K_WIDTH: gl.constexpr = KV_16B_ELEMENT_COUNT
# QK Matrix multiplication layout using AMD MFMA instructions
qk_mfma_layout: gl.constexpr = gl.amd.AMDMFMALayout(
version=CDNA_VERSION,
@@ -1969,10 +1976,10 @@ def paged_attention_decode_v2_gluon_dot_kernel(
warps_per_cta=[1, 4],
)
qk_lhs_operand_layout: gl.constexpr = gl.DotOperandLayout(
operand_index=0, parent=qk_mfma_layout, k_width=16
operand_index=0, parent=qk_mfma_layout, k_width=DOT_QK_K_WIDTH
)
qk_rhs_operand_layout: gl.constexpr = gl.DotOperandLayout(
operand_index=1, parent=qk_mfma_layout, k_width=16
operand_index=1, parent=qk_mfma_layout, k_width=DOT_QK_K_WIDTH
)
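The k_width of both dot operands now tracks DOT_QK_K_WIDTH (= KV_16B_ELEMENT_COUNT) instead of a hard-coded 16, so an fp16/bf16 KV cache uses k_width=8 while fp8 keeps 16. A hedged sketch of how that count follows from the element width (the mapping below is a hypothetical host-side helper, not code from this PR):

# Hypothetical helper: elements of the KV-cache dtype per 16-byte vector load.
KV_ELEMENT_BYTES = {"fp16": 2, "bf16": 2, "fp8": 1}

def kv_16b_element_count(kv_dtype: str) -> int:
    return 16 // KV_ELEMENT_BYTES[kv_dtype]

assert kv_16b_element_count("fp16") == 8   # selects the *_f16 layouts, k_width = 8
assert kv_16b_element_count("fp8") == 16   # selects the *_fp8 layouts, k_width = 16
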

# Register allocation configuration based on group size and compute block size
@@ -2011,15 +2018,29 @@ def paged_attention_decode_v2_gluon_dot_kernel(
# Value cache layout configuration based on transpose flag
if VALUE_TRANSPOSED:
# Transposed value layout for better memory access patterns
blocked_value_layout: gl.constexpr = gl.BlockedLayout(
size_per_thread=[1, 1, 1, CONTIGUOUS_KV_ELEMENTS_PER_16B_LOAD],
threads_per_warp=[4, 1, 16, 1],
value_threads_per_warp: gl.constexpr = (
[4, 1, 16, 1] if KV_BLOCK_SIZE == 16 else [1, 4, 16, 1]
)
blocked_value_layout_f16: gl.constexpr = gl.BlockedLayout(
size_per_thread=[1, 1, 1, 8],
threads_per_warp=value_threads_per_warp,
warps_per_cta=[1, 1, 4, 1],
order=[3, 2, 1, 0],
)
blocked_value_layout_fp8: gl.constexpr = gl.BlockedLayout(
size_per_thread=[1, 1, 1, 16],
threads_per_warp=value_threads_per_warp,
warps_per_cta=[1, 1, 4, 1],
order=[3, 2, 1, 0],
)
blocked_value_layout: gl.constexpr = (
blocked_value_layout_fp8
if KV_16B_ELEMENT_COUNT == 16
else blocked_value_layout_f16
)
value_dim1_offsets = gl.arange(
0,
KV_BLOCK_SIZE // CONTIGUOUS_KV_ELEMENTS_PER_16B_LOAD,
KV_BLOCK_SIZE // KV_16B_ELEMENT_COUNT,
layout=gl.SliceLayout(
0, gl.SliceLayout(2, gl.SliceLayout(3, blocked_value_layout))
),
@@ -2033,26 +2054,23 @@ def paged_attention_decode_v2_gluon_dot_kernel(
)
value_dim3_offsets = gl.arange(
0,
CONTIGUOUS_KV_ELEMENTS_PER_16B_LOAD,
KV_16B_ELEMENT_COUNT,
layout=gl.SliceLayout(
0, gl.SliceLayout(1, gl.SliceLayout(2, blocked_value_layout))
),
)
else:
# Standard value layout
value_threads_per_warp: gl.constexpr = (
[4, 16, 1] if KV_BLOCK_SIZE == 16 else [1, 16, 4]
)
blocked_value_layout: gl.constexpr = gl.BlockedLayout(
size_per_thread=[1, 1, CONTIGUOUS_KV_ELEMENTS_PER_16B_LOAD],
threads_per_warp=[4, 16, 1],
size_per_thread=[1, 1, 16],
threads_per_warp=value_threads_per_warp,
warps_per_cta=[1, 4, 1],
order=[2, 1, 0],
)
# blocked_value_layout: gl.constexpr = gl.DistributedLinearLayout(
# reg_bases=((0,0,1), (0,0,2), (0,0,4), (0,0,8), (4,0,0), (8,0,0), (0,64,0)),
# lane_bases=((0,1,0), (0,2,0), (0,4,0), (0,8,0), (1,0,0), (2,0,0)),
# warp_bases=((0,16,0), (0,32,0)),
# block_bases=[],
# shape=[16, 128, 16],
# )

value_dim1_offsets = gl.arange(
0,
HEAD_SIZE_POW2,
@@ -2108,7 +2126,7 @@ def paged_attention_decode_v2_gluon_dot_kernel(
)
block_element_offsets = gl.arange(0, KV_BLOCK_SIZE, layout=block_element_layout)
contiguous_kv_element_offsets = gl.arange(
0, CONTIGUOUS_KV_ELEMENTS_PER_16B_LOAD, layout=contiguous_kv_elements_layout
0, KV_16B_ELEMENT_COUNT, layout=contiguous_kv_elements_layout
)
qk_row_offsets = gl.arange(
0, QUERY_GROUP_SIZE_POW2, layout=gl.SliceLayout(1, qk_linear_layout)
@@ -2240,8 +2258,7 @@ def paged_attention_decode_v2_gluon_dot_kernel(
kv_block_numbers[:, None, None, None] * stride_key_block
+ kv_head_idx * stride_key_head
+ head_size_split_offsets[None, :, None, None] * stride_key_head_split
+ block_element_offsets[None, None, :, None]
* CONTIGUOUS_KV_ELEMENTS_PER_16B_LOAD
+ block_element_offsets[None, None, :, None] * KV_16B_ELEMENT_COUNT
+ contiguous_kv_element_offsets[None, None, None, :]
)
key_tensor = gl.load(key_cache_ptr + key_block_offsets)
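A NumPy sketch of the 4-D offset arithmetic behind the key load above, assuming (purely for illustration) a contiguous key cache of shape [num_blocks, num_kv_heads, K_HEAD_SIZE_SPLITS, KV_BLOCK_SIZE, KV_16B_ELEMENT_COUNT]; the real strides arrive as kernel arguments:

import numpy as np

NUM_BLOCKS, NUM_KV_HEADS, KV_BLOCK_SIZE = 8, 2, 16          # illustrative sizes
KV_16B_ELEMENT_COUNT, HEAD_SIZE_POW2 = 8, 128
K_HEAD_SIZE_SPLITS = HEAD_SIZE_POW2 // KV_16B_ELEMENT_COUNT

key_cache = np.arange(
    NUM_BLOCKS * NUM_KV_HEADS * K_HEAD_SIZE_SPLITS * KV_BLOCK_SIZE * KV_16B_ELEMENT_COUNT,
    dtype=np.float32,
).reshape(NUM_BLOCKS, NUM_KV_HEADS, K_HEAD_SIZE_SPLITS, KV_BLOCK_SIZE, KV_16B_ELEMENT_COUNT)
stride_key_block, stride_key_head, stride_key_head_split = (
    s // key_cache.itemsize for s in key_cache.strides[:3]
)

kv_block_numbers = np.array([3, 5])                          # blocks in this compute tile
head_split = np.arange(K_HEAD_SIZE_SPLITS)
block_elem = np.arange(KV_BLOCK_SIZE)
contig = np.arange(KV_16B_ELEMENT_COUNT)
kv_head_idx = 1

key_block_offsets = (
    kv_block_numbers[:, None, None, None] * stride_key_block
    + kv_head_idx * stride_key_head
    + head_split[None, :, None, None] * stride_key_head_split
    + block_elem[None, None, :, None] * KV_16B_ELEMENT_COUNT
    + contig[None, None, None, :]
)
# The flat gather reproduces a direct slice of the assumed cache layout.
gathered = key_cache.reshape(-1)[key_block_offsets]
assert np.array_equal(gathered, key_cache[kv_block_numbers, kv_head_idx])
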
@@ -2272,6 +2289,39 @@ def paged_attention_decode_v2_gluon_dot_kernel(
key_tensor = gl.permute(key_tensor, [1, 3, 0, 2])
key_tensor = gl.reshape(key_tensor, [HEAD_SIZE_POW2, KV_COMPUTE_BLOCK_SIZE])

# ==================== ATTENTION SCORE COMPUTATION ====================
# Initialize QK accumulator
qk_accumulator = gl.zeros(
(QUERY_GROUP_SIZE_POW2, KV_COMPUTE_BLOCK_SIZE),
dtype=gl.float32,
layout=qk_mfma_layout,
)

# if sequence_idx == 0 \
# and kv_head_idx == 0 \
# and sequence_partition_idx == 0:
# print("query_tensor=", query_tensor.to(tl.float32))
# print("key_tensor=", key_tensor.to(tl.float32))
# if QUERY_QUANT_MODE == 0 and KV_QUANT_MODE == 0:
# print("QKV_per_tensor")
# else:
# print("QKV_per_token")

# Convert layouts for MFMA operation
query_converted = query_shared.load(qk_lhs_operand_layout)
key_converted = gl.convert_layout(key_tensor, layout=qk_rhs_operand_layout)

query_converted = query_converted.to(COMPUTE_TYPE)
key_converted = key_converted.to(COMPUTE_TYPE)

# Compute QK attention scores using MFMA
attention_scores = gl.amd.cdna3.mfma(
query_converted, key_converted, qk_accumulator
)
attention_scores = gl.reshape(
attention_scores, [QUERY_GROUP_SIZE_POW2, KV_COMPUTE_BLOCK_SIZE]
)
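Shape-wise, the MFMA call above accumulates qk_accumulator + query @ key over one compute tile; a plain NumPy sketch with illustrative sizes (the kernel accumulates in fp32 via qk_accumulator):

import numpy as np

QUERY_GROUP_SIZE_POW2, HEAD_SIZE_POW2, KV_COMPUTE_BLOCK_SIZE = 16, 128, 256  # illustrative

query_converted = np.random.rand(QUERY_GROUP_SIZE_POW2, HEAD_SIZE_POW2).astype(np.float32)
key_converted = np.random.rand(HEAD_SIZE_POW2, KV_COMPUTE_BLOCK_SIZE).astype(np.float32)
qk_accumulator = np.zeros((QUERY_GROUP_SIZE_POW2, KV_COMPUTE_BLOCK_SIZE), dtype=np.float32)

# mfma(A, B, acc) computes acc + A @ B; the accumulator starts at zero here.
attention_scores = qk_accumulator + query_converted @ key_converted
assert attention_scores.shape == (QUERY_GROUP_SIZE_POW2, KV_COMPUTE_BLOCK_SIZE)
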

# ==================== VALUE LOADING AND PROCESSING ====================
if VALUE_TRANSPOSED:
# Load values from transposed cache layout
@@ -2285,8 +2335,7 @@ def paged_attention_decode_v2_gluon_dot_kernel(
kv_block_numbers_reshaped[:, None, None, None] * stride_value_block
+ kv_head_idx * stride_value_head
+ value_dim1_offsets[None, :, None, None] * stride_value_head_size
+ value_dim2_offsets[None, None, :, None]
* CONTIGUOUS_KV_ELEMENTS_PER_16B_LOAD
+ value_dim2_offsets[None, None, :, None] * KV_16B_ELEMENT_COUNT
+ value_dim3_offsets[None, None, None, :]
)
value_tensor = gl.load(value_cache_ptr + value_block_offsets)
@@ -2314,29 +2363,6 @@ def paged_attention_decode_v2_gluon_dot_kernel(
value_tensor, [KV_COMPUTE_BLOCK_SIZE, HEAD_SIZE_POW2]
)

# ==================== ATTENTION SCORE COMPUTATION ====================
# Initialize QK accumulator
qk_accumulator = gl.zeros(
(QUERY_GROUP_SIZE_POW2, KV_COMPUTE_BLOCK_SIZE),
dtype=gl.float32,
layout=qk_mfma_layout,
)

# Convert layouts for MFMA operation
query_converted = query_shared.load(qk_lhs_operand_layout)
key_converted = gl.convert_layout(key_tensor, layout=qk_rhs_operand_layout)

query_converted = query_converted.to(COMPUTE_TYPE)
key_converted = key_converted.to(COMPUTE_TYPE)

# Compute QK attention scores using MFMA
attention_scores = gl.amd.cdna3.mfma(
query_converted, key_converted, qk_accumulator
)
attention_scores = gl.reshape(
attention_scores, [QUERY_GROUP_SIZE_POW2, KV_COMPUTE_BLOCK_SIZE]
)

# Apply quantization scaling to attention scores
if KV_QUANT_MODE >= 0:
if KV_QUANT_MODE == 1:
@@ -2524,8 +2550,6 @@ def paged_attention_decode_v2_reduce_kernel(
Various stride parameters for tensor access
Compile-time constants for kernel configuration (no MAX_CONTEXT_PARTITION_NUM needed)
"""
# Mathematical constant for exponential calculations
LOG2_E: tl.constexpr = 1.4426950408889634
MAX_CONTEXT_PARTITION_NUM: tl.constexpr = 16

# ==================== INITIALIZATION ====================
@@ -2749,10 +2773,9 @@ def _paged_attention_decode_v2_with_dot_kernel_reshape_wrapper(
parameters for Triton compilation and execution.
"""
HEAD_SIZE_POW2 = triton.next_power_of_2(HEAD_SIZE)
# Production path - select and launch appropriate kernel
waves_per_eu = 1
QUERY_GROUP_SIZE = QUERY_SEQ_LEN * QUERY_GROUP_SIZE_ORIGINAL
KV_COMPUTE_BLOCK_SIZE = CONTEXT_PARTITION_SIZE
waves_per_eu = 2
if QUERY_GROUP_SIZE < 16:
QUERY_GROUP_SIZE_POW2 = 16
else:
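For context, a sketch of the launch-constant setup in this wrapper hunk; the else branch is collapsed in the diff, so the next_power_of_2 rounding below is an assumption, and the helper name is hypothetical:

import triton

def launch_constants(HEAD_SIZE, QUERY_SEQ_LEN, QUERY_GROUP_SIZE_ORIGINAL, CONTEXT_PARTITION_SIZE):
    HEAD_SIZE_POW2 = triton.next_power_of_2(HEAD_SIZE)
    QUERY_GROUP_SIZE = QUERY_SEQ_LEN * QUERY_GROUP_SIZE_ORIGINAL
    KV_COMPUTE_BLOCK_SIZE = CONTEXT_PARTITION_SIZE
    waves_per_eu = 2
    # Assumed rounding for the collapsed else branch: pad the group size to a
    # power of two, with a floor of 16 as in the branch shown above.
    QUERY_GROUP_SIZE_POW2 = max(16, triton.next_power_of_2(QUERY_GROUP_SIZE))
    return HEAD_SIZE_POW2, QUERY_GROUP_SIZE, QUERY_GROUP_SIZE_POW2, KV_COMPUTE_BLOCK_SIZE, waves_per_eu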