diff --git a/aiter/ops/triton/gluon/pa_decode_gluon.py b/aiter/ops/triton/gluon/pa_decode_gluon.py
index 37e7ddd268..9c07d4c4ac 100644
--- a/aiter/ops/triton/gluon/pa_decode_gluon.py
+++ b/aiter/ops/triton/gluon/pa_decode_gluon.py
@@ -1463,6 +1463,7 @@ def paged_attention_decode_sliding_window(
         * stride_output_head
         + output_head_size_offsets[None, :]
     )
+
     max_logits = gl.full(
         (QUERY_GROUP_SIZE_POW2,),
         float("-inf"),
@@ -1481,12 +1482,15 @@ def paged_attention_decode_sliding_window(
 
     # ==================== SEQUENCE PROCESSING ====================
     query_converted = query_shared.load(qk_lhs_operand_layout)
-    # query_converted = gl.convert_layout(query_tensor, layout=qk_lhs_operand_layout)
-    sequence_partition_start_idx = (
-        context_length - SLIDING_WINDOW
-    ) // CONTEXT_PARTITION_SIZE
+
+    if SLIDING_WINDOW > 0:
+        sequence_partition_start_idx = (
+            context_length - SLIDING_WINDOW
+        ) // CONTEXT_PARTITION_SIZE
+    else:
+        sequence_partition_start_idx = 0
     sequence_partition_end_idx = gl.cdiv(context_length, CONTEXT_PARTITION_SIZE)
-    # num_iterations = sequence_partition_end_idx - sequence_partition_start_idx
+
     if QUERY_QUANT_MODE < 0 and COMPUTE_TYPE.is_fp8():
         # Quantize bf16 query to fp8
         # Convert query to float32 for computation
@@ -1524,11 +1528,11 @@ def paged_attention_decode_sliding_window(
         )
         # Create mask for valid blocks
         valid_block_mask = block_indices < num_kv_blocks
-        # masked_block_indices = gl.where(valid_block_mask, block_indices, 0)
+        masked_block_indices = gl.where(valid_block_mask, block_indices, 0)
         block_table_start_ptr = block_tables_ptr + sequence_idx * stride_block_table_seq
         kv_block_numbers = gl.amd.cdna3.buffer_load(
-            ptr=block_table_start_ptr + kv_block_start_idx, offsets=block_indices
-        ).to(gl.uint32)
+            ptr=block_table_start_ptr + kv_block_start_idx, offsets=masked_block_indices
+        ).to(gl.int64)
 
         # ==================== KEY LOADING AND PROCESSING ====================
         # Calculate key cache offsets and load keys
@@ -1540,20 +1544,15 @@ def paged_attention_decode_sliding_window(
             * CONTIGUOUS_KV_ELEMENTS_PER_16B_LOAD
             + contiguous_kv_element_offsets[None, None, None, :]
         )
 
-        # Optimize: Start key load, then prepare QK MFMA accumulators/query (overlaps with key load)
-        key_tensor = gl.amd.cdna3.buffer_load(
-            ptr=key_cache_ptr,
-            offsets=key_block_offsets,
-            mask=valid_block_mask[:, None, None, None],
-        )
+        # Optimize: Start key load, then prepare QK MFMA accumulators/query (overlaps with key load)
+        key_tensor = gl.load(key_cache_ptr + key_block_offsets)
         # Prepare QK MFMA while key loads (these don't depend on key data)
         qk_accumulator = gl.zeros(
             (QUERY_GROUP_SIZE_POW2, CONTEXT_PARTITION_SIZE),
             dtype=gl.float32,
             layout=qk_mfma_layout,
         )
-
         # Load key quantization scales if needed (overlaps with key tensor load)
         if KV_QUANT_MODE >= 0:
             if KV_QUANT_MODE == 0:
@@ -1622,11 +1621,7 @@ def paged_attention_decode_sliding_window(
             * CONTIGUOUS_KV_ELEMENTS_PER_16B_LOAD
             + value_dim3_offsets[None, None, None, :]
         )
-        value_tensor = gl.amd.cdna3.buffer_load(
-            ptr=value_cache_ptr,
-            offsets=value_block_offsets,
-            mask=valid_block_mask[:, None, None, None],
-        )
+        value_tensor = gl.load(value_cache_ptr + value_block_offsets)
         # Compute QK attention scores using MFMA (overlaps with value load)
         attention_scores = gl.amd.cdna3.mfma(
             query_converted, key_converted, qk_accumulator
@@ -1655,11 +1650,7 @@ def paged_attention_decode_sliding_window(
         )
 
         # Schedule: Start value VMEM load, then QK MFMA
-        value_tensor = gl.amd.cdna3.buffer_load(
-            ptr=value_cache_ptr,
-            offsets=value_block_offsets,
-            mask=valid_block_mask[:, None, None],
-        )
+        value_tensor = gl.load(value_cache_ptr + value_block_offsets)
         # Compute QK attention scores using MFMA (overlaps with value load)
         attention_scores = gl.amd.cdna3.mfma(
             query_converted, key_converted, qk_accumulator
@@ -1790,8 +1781,6 @@ def paged_attention_decode_sliding_window(
         attention_accumulator += attention_output
         max_logits = new_max_logits
 
-    # ==================== OUTPUT NORMALIZATION AND STORING ====================
-    # Normalize attention output by softmax denominator
     if sinks_ptr is not None:
         sinks_values = gl.load(
             sinks_ptr + (kv_head_idx * query_group_size + query_group_offsets),
@@ -1800,6 +1789,8 @@ def paged_attention_decode_sliding_window(
         exp_sums += gl.exp(
             gl.convert_layout(sinks_values, layout=max_logits.type.layout) - max_logits
         )
+    # ==================== OUTPUT NORMALIZATION AND STORING ====================
+    # Normalize attention output by softmax denominator
     exp_sums_reciprocal = 1.0 / exp_sums
 
     exp_sums_reciprocal_cvt = gl.convert_layout(
@@ -2549,6 +2540,13 @@ def paged_attention_decode_v2_reduce_kernel(
     head_size_offsets = tl.arange(0, HEAD_SIZE_POW2)
 
     # Initialize global accumulation variables
+    # if USE_SINKS:
+    #     global_max = tl.load(
+    #         sink_token_ptr + (kv_head_idx * query_group_size + query_group_offsets),
+    #         mask=query_group_offsets < query_group_size,
+    #         other=float("-inf"),
+    #     ).to(tl.float32)
+    # else:
     global_max = tl.full((QUERY_GROUP_SIZE_POW2,), float("-inf"), dtype=tl.float32)
     global_max_prev = global_max
     global_exp_sum = tl.zeros((QUERY_GROUP_SIZE_POW2,), dtype=tl.float32)
@@ -2602,7 +2600,6 @@ def paged_attention_decode_v2_reduce_kernel(
             mask=query_group_offsets < query_group_size,
         )
         global_exp_sum += gl.exp(sink_token_values - global_max)
-
     # ==================== SECOND PASS: COMPUTE RESCALED EXP SUMS AND ACCUMULATE ====================
     for iter_idx in range(num_iterations):
         partition_base = iter_idx * MAX_CONTEXT_PARTITION_NUM
@@ -2972,6 +2969,7 @@ def pa_decode_gluon(
     alibi_slopes: torch.Tensor = None,
     sinks: torch.Tensor = None,
     sliding_window: int = 0,
+    one_shot=None,
 ) -> None:
     """
     Paged Attention Decode with FP8/BF16/FP16 Support.
@@ -3263,7 +3261,8 @@ def pa_decode_gluon(
     fp8_max_value = torch.finfo(aiter.dtypes.fp8).max
 
     # ==================== ATTENTION DECODE KERNEL EXECUTION ====================
-    one_shot = sliding_window > 0
+    if one_shot is None:
+        one_shot = sliding_window > 0
     _paged_attention_decode_v2_with_dot_kernel_reshape_wrapper(
         grid,
         exp_sums,
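
For reference, a minimal plain-Python sketch of the two behavior changes in this patch: the SLIDING_WINDOW > 0 guard on the partition start index and the new one_shot=None override in pa_decode_gluon. This is not the Gluon kernel API; math.ceil stands in for gl.cdiv, and the helper names partition_range / resolve_one_shot are hypothetical illustrations only.

import math


def partition_range(context_length: int, sliding_window: int, partition_size: int):
    # Mirrors the kernel change: only clamp the start partition to the sliding
    # window when one is enabled; otherwise scan from the first partition.
    if sliding_window > 0:
        start = (context_length - sliding_window) // partition_size
    else:
        start = 0
    end = math.ceil(context_length / partition_size)  # stands in for gl.cdiv
    return start, end


def resolve_one_shot(one_shot, sliding_window: int) -> bool:
    # Mirrors the wrapper change: one_shot=None keeps the old default
    # (one-shot path only when a sliding window is set); an explicit bool wins.
    return sliding_window > 0 if one_shot is None else bool(one_shot)


if __name__ == "__main__":
    assert partition_range(4096, 1024, 256) == (12, 16)
    assert partition_range(4096, 0, 256) == (0, 16)
    assert resolve_one_shot(None, 1024) is True
    assert resolve_one_shot(False, 1024) is False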