Merged
aiter/ops/triton/gluon/pa_decode_gluon.py (57 changes: 28 additions & 29 deletions)
@@ -1463,6 +1463,7 @@ def paged_attention_decode_sliding_window(
* stride_output_head
+ output_head_size_offsets[None, :]
)

max_logits = gl.full(
(QUERY_GROUP_SIZE_POW2,),
float("-inf"),
@@ -1481,12 +1482,15 @@ def paged_attention_decode_sliding_window(

# ==================== SEQUENCE PROCESSING ====================
query_converted = query_shared.load(qk_lhs_operand_layout)
# query_converted = gl.convert_layout(query_tensor, layout=qk_lhs_operand_layout)
sequence_partition_start_idx = (
context_length - SLIDING_WINDOW
) // CONTEXT_PARTITION_SIZE

if SLIDING_WINDOW > 0:
sequence_partition_start_idx = (
context_length - SLIDING_WINDOW
) // CONTEXT_PARTITION_SIZE
else:
sequence_partition_start_idx = 0
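# Partitions that end before the last SLIDING_WINDOW tokens cannot contribute,
# so the scan starts at the first partition overlapping the window; without a
# window it starts at partition 0.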
sequence_partition_end_idx = gl.cdiv(context_length, CONTEXT_PARTITION_SIZE)
# num_iterations = sequence_partition_end_idx - sequence_partition_start_idx

if QUERY_QUANT_MODE < 0 and COMPUTE_TYPE.is_fp8():
# Quantize bf16 query to fp8
# Convert query to float32 for computation
@@ -1524,11 +1528,11 @@ def paged_attention_decode_sliding_window(
)
# Create mask for valid blocks
valid_block_mask = block_indices < num_kv_blocks
# masked_block_indices = gl.where(valid_block_mask, block_indices, 0)
masked_block_indices = gl.where(valid_block_mask, block_indices, 0)
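# Clamp out-of-range slots to index 0 so the block-table gather below stays in bounds.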
block_table_start_ptr = block_tables_ptr + sequence_idx * stride_block_table_seq
kv_block_numbers = gl.amd.cdna3.buffer_load(
ptr=block_table_start_ptr + kv_block_start_idx, offsets=block_indices
).to(gl.uint32)
ptr=block_table_start_ptr + kv_block_start_idx, offsets=masked_block_indices
).to(gl.int64)
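# 64-bit block numbers keep the cache-offset arithmetic derived from them
# from overflowing 32 bits on large KV caches.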

# ==================== KEY LOADING AND PROCESSING ====================
# Calculate key cache offsets and load keys
@@ -1540,20 +1544,15 @@ def paged_attention_decode_sliding_window(
* CONTIGUOUS_KV_ELEMENTS_PER_16B_LOAD
+ contiguous_kv_element_offsets[None, None, None, :]
)
# Optimize: Start key load, then prepare QK MFMA accumulators/query (overlaps with key load)
key_tensor = gl.amd.cdna3.buffer_load(
ptr=key_cache_ptr,
offsets=key_block_offsets,
mask=valid_block_mask[:, None, None, None],
)

# Optimize: Start key load, then prepare QK MFMA accumulators/query (overlaps with key load)
key_tensor = gl.load(key_cache_ptr + key_block_offsets)
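# With out-of-range slots clamped above, every address is valid, so a plain
# gl.load replaces the masked buffer_load.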
# Prepare QK MFMA while key loads (these don't depend on key data)
qk_accumulator = gl.zeros(
(QUERY_GROUP_SIZE_POW2, CONTEXT_PARTITION_SIZE),
dtype=gl.float32,
layout=qk_mfma_layout,
)

# Load key quantization scales if needed (overlaps with key tensor load)
if KV_QUANT_MODE >= 0:
if KV_QUANT_MODE == 0:
@@ -1622,11 +1621,7 @@ def paged_attention_decode_sliding_window(
* CONTIGUOUS_KV_ELEMENTS_PER_16B_LOAD
+ value_dim3_offsets[None, None, None, :]
)
value_tensor = gl.amd.cdna3.buffer_load(
ptr=value_cache_ptr,
offsets=value_block_offsets,
mask=valid_block_mask[:, None, None, None],
)
value_tensor = gl.load(value_cache_ptr + value_block_offsets)
# Compute QK attention scores using MFMA (overlaps with value load)
attention_scores = gl.amd.cdna3.mfma(
query_converted, key_converted, qk_accumulator
Expand Down Expand Up @@ -1655,11 +1650,7 @@ def paged_attention_decode_sliding_window(
)

# Schedule: Start value VMEM load, then QK MFMA
value_tensor = gl.amd.cdna3.buffer_load(
ptr=value_cache_ptr,
offsets=value_block_offsets,
mask=valid_block_mask[:, None, None],
)
value_tensor = gl.load(value_cache_ptr + value_block_offsets)
# Compute QK attention scores using MFMA (overlaps with value load)
attention_scores = gl.amd.cdna3.mfma(
query_converted, key_converted, qk_accumulator
Expand Down Expand Up @@ -1790,8 +1781,6 @@ def paged_attention_decode_sliding_window(
attention_accumulator += attention_output
max_logits = new_max_logits

# ==================== OUTPUT NORMALIZATION AND STORING ====================
# Normalize attention output by softmax denominator
if sinks_ptr is not None:
sinks_values = gl.load(
sinks_ptr + (kv_head_idx * query_group_size + query_group_offsets),
@@ -1800,6 +1789,8 @@ def paged_attention_decode_sliding_window(
exp_sums += gl.exp(
gl.convert_layout(sinks_values, layout=max_logits.type.layout) - max_logits
)
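# Attention sinks only add an exp(sink - running_max) term to the softmax
# denominator; they contribute nothing to the accumulated values.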
# ==================== OUTPUT NORMALIZATION AND STORING ====================
# Normalize attention output by softmax denominator

exp_sums_reciprocal = 1.0 / exp_sums
exp_sums_reciprocal_cvt = gl.convert_layout(
@@ -2549,6 +2540,13 @@ def paged_attention_decode_v2_reduce_kernel(
head_size_offsets = tl.arange(0, HEAD_SIZE_POW2)

# Initialize global accumulation variables
# if USE_SINKS:
# global_max = tl.load(
# sink_token_ptr + (kv_head_idx * query_group_size + query_group_offsets),
# mask=query_group_offsets < query_group_size,
# other=float("-inf"),
# ).to(tl.float32)
# else:
global_max = tl.full((QUERY_GROUP_SIZE_POW2,), float("-inf"), dtype=tl.float32)
global_max_prev = global_max
global_exp_sum = tl.zeros((QUERY_GROUP_SIZE_POW2,), dtype=tl.float32)
@@ -2602,7 +2600,6 @@ def paged_attention_decode_v2_reduce_kernel(
mask=query_group_offsets < query_group_size,
)
global_exp_sum += gl.exp(sink_token_values - global_max)

# ==================== SECOND PASS: COMPUTE RESCALED EXP SUMS AND ACCUMULATE ====================
for iter_idx in range(num_iterations):
partition_base = iter_idx * MAX_CONTEXT_PARTITION_NUM
Expand Down Expand Up @@ -2972,6 +2969,7 @@ def pa_decode_gluon(
alibi_slopes: torch.Tensor = None,
sinks: torch.Tensor = None,
sliding_window: int = 0,
one_shot=None,
) -> None:
"""
Paged Attention Decode with FP8/BF16/FP16 Support.
@@ -3263,7 +3261,8 @@ def pa_decode_gluon(
fp8_max_value = torch.finfo(aiter.dtypes.fp8).max

# ==================== ATTENTION DECODE KERNEL EXECUTION ====================
one_shot = sliding_window > 0
if one_shot is None:
one_shot = sliding_window > 0
_paged_attention_decode_v2_with_dot_kernel_reshape_wrapper(
grid,
exp_sums,