From 59509e4b05e530441f2c6d0ed612691f4bc75e4f Mon Sep 17 00:00:00 2001 From: Valerie Chen Date: Wed, 12 Nov 2025 12:56:27 -0500 Subject: [PATCH 1/4] leanAttn softmax fix for spurious data mismatch test failures --- .../ops/triton/_triton_kernels/lean_atten.py | 36 +++++++++------- aiter/ops/triton/lean_atten.py | 7 +--- op_tests/op_benchmarks/triton/bench_la.py | 2 - op_tests/triton_tests/test_la.py | 42 ++++++++++--------- 4 files changed, 45 insertions(+), 42 deletions(-) diff --git a/aiter/ops/triton/_triton_kernels/lean_atten.py b/aiter/ops/triton/_triton_kernels/lean_atten.py index 0ca0959169..78f1010504 100644 --- a/aiter/ops/triton/_triton_kernels/lean_atten.py +++ b/aiter/ops/triton/_triton_kernels/lean_atten.py @@ -28,7 +28,6 @@ from ..utils.core import AITER_TRITON_CONFIGS_PATH -LOG_TWO_E = 1.44269504 # log_2(e) value for softmax scaling # Support tensor in [B, Seqlen, H, d] format. Taking tensors in [B*Seqlen, H, d] as inputs @@ -106,29 +105,34 @@ def _attention_inner( m_i, l_i, acc, - qk_scale, - causal, q_start_m, b_seq_size, offs_m, offs_n, - BLOCK_M, - BLOCK_N, - HEAD_DIM_ORIG: tl.constexpr, - HEAD_DIM: tl.constexpr, local_iter, local_iter_end, + SM_SCALE: tl.constexpr, + causal: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + HEAD_DIM_ORIG: tl.constexpr, + HEAD_DIM: tl.constexpr, use_64_indexing: tl.constexpr, ): """ Performs attention calculation for an (maybe partial) output tile """ + RCP_LN2: tl.constexpr = 1.4426950408889634 + # Define head-dimension mask for padded dims offs_k_local = tl.arange(0, HEAD_DIM) mask_k_cols_local = offs_k_local < HEAD_DIM_ORIG for l_iter in range(local_iter, local_iter_end): k = tl.load(k_ptrs, mask=mask_k_cols_local[:, None], other=0.0) - qk = tl.dot(q, k) * qk_scale + qk_scale = SM_SCALE * RCP_LN2 + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk = qk * qk_scale if causal: # Get the starting column index of the current K block @@ -215,7 +219,6 @@ def la_persistent( Q, K, V, - qk_scale, Mp, Lp, Op, @@ -238,6 +241,7 @@ def la_persistent( stride_oph, # total_programs stride_opm, # n_ctx_q stride_opn, # head_dim + SM_SCALE, HEADS_PER_XCD: tl.constexpr, HEAD_DIM_ORIG: tl.constexpr, HEAD_DIM: tl.constexpr, @@ -311,7 +315,6 @@ def la_persistent( Q, K, V, - qk_scale, Mp, Lp, Op, @@ -338,6 +341,7 @@ def la_persistent( current_pid=current_pid, xcd_pid=xcd_pid, xcd_id=xcd_id, + SM_SCALE=SM_SCALE, HEADS_PER_XCD=HEADS_PER_XCD, HEAD_DIM=HEAD_DIM, HEAD_DIM_ORIG=HEAD_DIM_ORIG, @@ -366,7 +370,6 @@ def la_persistent_inner( Q, K, V, - qk_scale, Mp, Lp, Op, @@ -393,7 +396,8 @@ def la_persistent_inner( current_pid, # SOC pid xcd_pid, # XCD pid xcd_id, # The XCD the pid belongs to - HEADS_PER_XCD, + SM_SCALE, + HEADS_PER_XCD: tl.constexpr, HEAD_DIM: tl.constexpr, HEAD_DIM_ORIG: tl.constexpr, BLOCK_M: tl.constexpr, @@ -618,18 +622,18 @@ def la_persistent_inner( m_i, l_i, acc, - qk_scale, - causal, q_start_m, b_seq_size, offs_m, offs_n, + local_iter, + local_iter_end, + SM_SCALE, + causal, BLOCK_M, BLOCK_N, HEAD_DIM_ORIG=HEAD_DIM_ORIG, HEAD_DIM=HEAD_DIM, - local_iter=local_iter, - local_iter_end=local_iter_end, use_64_indexing=use_64_indexing, ) diff --git a/aiter/ops/triton/lean_atten.py b/aiter/ops/triton/lean_atten.py index aca61ab770..8a7e53f054 100644 --- a/aiter/ops/triton/lean_atten.py +++ b/aiter/ops/triton/lean_atten.py @@ -30,7 +30,6 @@ _LOGGER = AiterTritonLogger() -LOG_TWO_E = 1.44269504 # log_2(e) value for softmax scaling # Support tensor in [B, Seqlen, H, d] format. 
Taking tensors in [B*Seqlen, H, d] as inputs @@ -80,7 +79,6 @@ def persistent_lean_attention( XCD_REMAP=config["XCD_REMAP"], causal=causal, batch_size=batch_size, - sm_scale=sm_scale, RAGGED_BATCH=RAGGED_BATCH, num_warps=config["num_warps"], waves_per_eu=config["waves_per_eu"], @@ -104,7 +102,6 @@ def _persistent_lean_attention( XCD_REMAP: bool, # xcd_remap for spatial causal: bool, # causal masking batch_size: int, - sm_scale: torch.float16, # typically 1 / sqrt(d) RAGGED_BATCH: bool, num_warps: int, waves_per_eu: int, @@ -144,7 +141,7 @@ def _persistent_lean_attention( GQA_GROUP_SIZE = H // H_K HEADS_PER_XCD = H // NUM_XCDS - qk_scale = sm_scale * LOG_TWO_E + sm_scale = q.shape[-1] ** (-0.5) ( num_m_blocks, @@ -271,7 +268,6 @@ def _persistent_lean_attention( q, k, v, - qk_scale, Mp, Lp, Op, @@ -294,6 +290,7 @@ def _persistent_lean_attention( Op.stride(0), # total_programs Op.stride(1), # n_ctx_q Op.stride(2), # head_dim + sm_scale, HEADS_PER_XCD=HEADS_PER_XCD, HEAD_DIM_ORIG=HEAD_DIM_K, HEAD_DIM=HEAD_DIM_K, diff --git a/op_tests/op_benchmarks/triton/bench_la.py b/op_tests/op_benchmarks/triton/bench_la.py index 457f2c2b5b..9f5e402928 100644 --- a/op_tests/op_benchmarks/triton/bench_la.py +++ b/op_tests/op_benchmarks/triton/bench_la.py @@ -340,7 +340,6 @@ def bench_lean_attention( list_sum_block_n.append(len_sum) batch_num_block_n = torch.tensor(list_sum_block_n, device="cuda", dtype=torch.int32) - sm_scale = 0.5 # Allocate Tensors q = torch.empty((n_ctx_q * batch, hq, d), dtype=init_dtype, device="cuda").normal_( @@ -377,7 +376,6 @@ def bench_lean_attention( XCD_REMAP, causal, batch, - sm_scale, RAGGED_BATCH, num_warps, waves_per_eu, diff --git a/op_tests/triton_tests/test_la.py b/op_tests/triton_tests/test_la.py index 946e2c6e97..aed43af3a2 100644 --- a/op_tests/triton_tests/test_la.py +++ b/op_tests/triton_tests/test_la.py @@ -4,6 +4,7 @@ import sys import pytest import torch +import math from typing import Union, List from aiter.ops.triton.lean_atten import ( _persistent_lean_attention, @@ -67,12 +68,13 @@ def get_lean_attn_inputs( return q, k, v, Mp, Lp, Op, locks, batch_num_block_n -def reference_attention(q, k, v, n_ctx, n_ctx_q, sm_scale, causal): +def reference_attention(q, k, v, n_ctx, n_ctx_q, causal): # Calculate Pytorch refence output ref_out = torch.empty_like(q, dtype=q.dtype) start = 0 start_q = 0 + d = q.shape[-1] for b in n_ctx: qb = q[start_q : (start_q + int(n_ctx_q)), :, :] @@ -87,7 +89,7 @@ def reference_attention(q, k, v, n_ctx, n_ctx_q, sm_scale, causal): group_size = qb_reshaped.shape[0] // kb_reshaped.shape[0] kb_reshaped = kb_reshaped.repeat_interleave(group_size, dim=0) vb_reshaped = vb_reshaped.repeat_interleave(group_size, dim=0) - p = torch.matmul(qb_reshaped, kb_reshaped.transpose(-2, -1)) * sm_scale + p = torch.matmul(qb_reshaped, kb_reshaped.transpose(-2, -1)) / math.sqrt(d) if causal: M = torch.tril(torch.ones((n_ctx_q, b), device="cuda")) mask = M == 0 @@ -325,7 +327,7 @@ def reference_attention(q, k, v, n_ctx, n_ctx_q, sm_scale, causal): # (False, 1, 64, 4096, [4096], 128, 304, torch.float16, 128, 16, 3, 4), ], ) -@pytest.mark.skip(reason="This test is temporarily disabled.") + def test_persistent_lean_attention( request, causal, @@ -374,8 +376,6 @@ def test_persistent_lean_attention( list_sum_block_n.append(len_sum) batch_num_block_n = torch.tensor(list_sum_block_n, device="cuda", dtype=torch.int32) - sm_scale = 0.5 - q, k, v, Mp, Lp, Op, locks, batch_num_block_n = get_lean_attn_inputs( batch, n_ctx_q, @@ -406,18 +406,26 @@ def 
test_persistent_lean_attention( XCD_REMAP, causal, batch, - sm_scale, RAGGED_BATCH, num_warps, waves_per_eu, ) # Calculate Pytorch refence output - ref_out = reference_attention(q, k, v, n_ctx, n_ctx_q, sm_scale, causal) + ref_out = reference_attention(q, k, v, n_ctx, n_ctx_q, causal) # Compare result atol = 1.4e-1 if init_dtype == "fp8" else 1e-2 rtol = 1e-2 if init_dtype == "fp8" else 3e-3 - torch.testing.assert_close(ref_out, la_out, atol=atol, rtol=rtol) + #torch.testing.assert_close(ref_out, la_out, atol=atol, rtol=rtol) + # # Compare result + #atol = 1e-2 + #rtol = 1e-2 + try: + torch.testing.assert_close(ref_out, la_out, atol=atol, rtol=rtol) + except AssertionError: + print("Assertion failed! Showing mismatches:") + print_mismatches(ref_out, la_out, atol, rtol) + raise # Re-raise the exception after printing mismatches # NOTE: Tests where the workload < num_sms currently fail. @@ -446,7 +454,6 @@ def test_persistent_lean_attention_outer( ): torch.manual_seed(20) - sm_scale = 0.5 config = _get_config( batch_size=batch, causal=causal, @@ -482,14 +489,13 @@ def test_persistent_lean_attention_outer( locks, batch_num_block_n, batch, - sm_scale, causal=causal, RAGGED_BATCH=RAGGED_BATCH, config=config, ) # Calculate Pytorch refence output - ref_out = reference_attention(q, k, v, n_ctx, n_ctx_q, sm_scale, causal) + ref_out = reference_attention(q, k, v, n_ctx, n_ctx_q, causal) # Compare result atol = 1.4e-1 if init_dtype == "fp8" else 1e-2 rtol = 1e-2 if init_dtype == "fp8" else 3e-3 @@ -528,17 +534,17 @@ def print_mismatches(ref_out, la_out, atol=1e-8, rtol=1e-5): def main(): - batch = 3 + batch = 8 causal = False - hq = 128 - hk = 128 + hq = 64 + hk = 64 n_ctx_q = 16 - n_ctx = [4096, 32768, 65536] # [131072] * batch # [16384] #[8192] + n_ctx = [1024, 1024, 2048, 2048, 4096, 4096, 32768, 65536] #[4096, 32768, 65536] # [131072] * batch # [16384] #[8192] d = 128 total_programs = 912 init_dtype = torch.float16 BLOCK_M = 16 - BLOCK_N = 128 + BLOCK_N = 64 XCD_REMAP = True waves_per_eu = 2 num_warps = 4 @@ -565,7 +571,6 @@ def main(): list_sum_block_n.append(len_sum) batch_num_block_n = torch.tensor(list_sum_block_n, device="cuda", dtype=torch.int32) - sm_scale = 0.5 q, k, v, Mp, Lp, Op, locks, batch_num_block_n = get_lean_attn_inputs( batch, @@ -595,14 +600,13 @@ def main(): XCD_REMAP, causal, batch, - sm_scale, RAGGED_BATCH, num_warps, waves_per_eu, ) # print(f"ms={ms}") - ref_out = reference_attention(q, k, v, n_ctx, n_ctx_q, sm_scale, causal) + ref_out = reference_attention(q, k, v, n_ctx, n_ctx_q, causal) # # Compare result atol = 1.4e-1 if init_dtype == "fp8" else 1e-2 From 893a616523f17406f4eb425380bb4e4c6c534d1e Mon Sep 17 00:00:00 2001 From: Valerie Chen Date: Wed, 12 Nov 2025 13:12:35 -0500 Subject: [PATCH 2/4] black fix --- .../ops/triton/_triton_kernels/lean_atten.py | 2 +- op_tests/op_benchmarks/triton/bench_la.py | 1 - op_tests/triton_tests/test_la.py | 19 +++++++++++++------ 3 files changed, 14 insertions(+), 8 deletions(-) diff --git a/aiter/ops/triton/_triton_kernels/lean_atten.py b/aiter/ops/triton/_triton_kernels/lean_atten.py index 78f1010504..6e75fd6d0b 100644 --- a/aiter/ops/triton/_triton_kernels/lean_atten.py +++ b/aiter/ops/triton/_triton_kernels/lean_atten.py @@ -131,7 +131,7 @@ def _attention_inner( k = tl.load(k_ptrs, mask=mask_k_cols_local[:, None], other=0.0) qk_scale = SM_SCALE * RCP_LN2 qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) + qk += tl.dot(q, k) qk = qk * qk_scale if causal: diff --git a/op_tests/op_benchmarks/triton/bench_la.py 
b/op_tests/op_benchmarks/triton/bench_la.py index 9f5e402928..5a92719335 100644 --- a/op_tests/op_benchmarks/triton/bench_la.py +++ b/op_tests/op_benchmarks/triton/bench_la.py @@ -340,7 +340,6 @@ def bench_lean_attention( list_sum_block_n.append(len_sum) batch_num_block_n = torch.tensor(list_sum_block_n, device="cuda", dtype=torch.int32) - # Allocate Tensors q = torch.empty((n_ctx_q * batch, hq, d), dtype=init_dtype, device="cuda").normal_( mean=0.0, std=0.5 diff --git a/op_tests/triton_tests/test_la.py b/op_tests/triton_tests/test_la.py index aed43af3a2..6cf5581992 100644 --- a/op_tests/triton_tests/test_la.py +++ b/op_tests/triton_tests/test_la.py @@ -327,7 +327,6 @@ def reference_attention(q, k, v, n_ctx, n_ctx_q, causal): # (False, 1, 64, 4096, [4096], 128, 304, torch.float16, 128, 16, 3, 4), ], ) - def test_persistent_lean_attention( request, causal, @@ -416,10 +415,10 @@ def test_persistent_lean_attention( # Compare result atol = 1.4e-1 if init_dtype == "fp8" else 1e-2 rtol = 1e-2 if init_dtype == "fp8" else 3e-3 - #torch.testing.assert_close(ref_out, la_out, atol=atol, rtol=rtol) + # torch.testing.assert_close(ref_out, la_out, atol=atol, rtol=rtol) # # Compare result - #atol = 1e-2 - #rtol = 1e-2 + # atol = 1e-2 + # rtol = 1e-2 try: torch.testing.assert_close(ref_out, la_out, atol=atol, rtol=rtol) except AssertionError: @@ -539,7 +538,16 @@ def main(): hq = 64 hk = 64 n_ctx_q = 16 - n_ctx = [1024, 1024, 2048, 2048, 4096, 4096, 32768, 65536] #[4096, 32768, 65536] # [131072] * batch # [16384] #[8192] + n_ctx = [ + 1024, + 1024, + 2048, + 2048, + 4096, + 4096, + 32768, + 65536, + ] # [4096, 32768, 65536] # [131072] * batch # [16384] #[8192] d = 128 total_programs = 912 init_dtype = torch.float16 @@ -571,7 +579,6 @@ def main(): list_sum_block_n.append(len_sum) batch_num_block_n = torch.tensor(list_sum_block_n, device="cuda", dtype=torch.int32) - q, k, v, Mp, Lp, Op, locks, batch_num_block_n = get_lean_attn_inputs( batch, n_ctx_q, From 63dd46bd35cfdab8fa13eae670588b512e7e06f6 Mon Sep 17 00:00:00 2001 From: Valerie Chen Date: Fri, 14 Nov 2025 13:39:54 -0500 Subject: [PATCH 3/4] Remove unused parameters per PR review request --- aiter/ops/triton/_triton_kernels/lean_atten.py | 8 -------- aiter/ops/triton/lean_atten.py | 6 ------ 2 files changed, 14 deletions(-) diff --git a/aiter/ops/triton/_triton_kernels/lean_atten.py b/aiter/ops/triton/_triton_kernels/lean_atten.py index 6e75fd6d0b..a743c4a9df 100644 --- a/aiter/ops/triton/_triton_kernels/lean_atten.py +++ b/aiter/ops/triton/_triton_kernels/lean_atten.py @@ -33,8 +33,6 @@ @functools.lru_cache(maxsize=1024) def _get_config( - causal: bool, - batch_size: int, ): if not hasattr(_get_config, "_config_dict"): dev = arch_info.get_device() @@ -261,8 +259,6 @@ def la_persistent( tiles_per_head: tl.constexpr, num_splits: tl.constexpr, max_output_tile_cnt: tl.constexpr, - num_heads_q: tl.constexpr, - num_heads_k: tl.constexpr, gqa_group_size: tl.constexpr, use_64_indexing: tl.constexpr, RAGGED_BATCH: tl.constexpr, @@ -348,8 +344,6 @@ def la_persistent( BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, MASKED_BLOCKS=MASKED_BLOCKS, - XCD_REMAP=XCD_REMAP, - NUM_XCDS=NUM_XCDS, batch_size=batch_size, causal=causal, num_m_blocks=num_m_blocks, @@ -403,8 +397,6 @@ def la_persistent_inner( BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, MASKED_BLOCKS: tl.constexpr, - XCD_REMAP: tl.constexpr, - NUM_XCDS: tl.constexpr, batch_size: tl.constexpr, causal: tl.constexpr, num_m_blocks: tl.constexpr, diff --git a/aiter/ops/triton/lean_atten.py b/aiter/ops/triton/lean_atten.py index 
8a7e53f054..2bd5ad3466 100644 --- a/aiter/ops/triton/lean_atten.py +++ b/aiter/ops/triton/lean_atten.py @@ -158,7 +158,6 @@ def _persistent_lean_attention( N_CTX_Q, N_CTX_K, H, - H_K, BLOCK_M, BLOCK_N, total_programs, @@ -216,7 +215,6 @@ def _persistent_lean_attention( N_CTX_Q, N_CTX_K, H, - H_K, BLOCK_M, BLOCK_N, total_programs, @@ -314,8 +312,6 @@ def _persistent_lean_attention( num_warps=num_warps, num_stages=1, num_ctas=1, - num_heads_q=H, - num_heads_k=H_K, gqa_group_size=GQA_GROUP_SIZE, use_64_indexing=( (k.stride(0) * N_CTX_K) >= (1 << 31) @@ -347,7 +343,6 @@ def get_num_splits_and_buffer_sizes( max_seqlen_q, max_seqlen_k, num_heads, - num_heads_k, BLOCK_M, BLOCK_N, num_SMs, @@ -367,7 +362,6 @@ def get_num_splits_and_buffer_sizes( # print(f"block_m: {BLOCK_M}, block_n: {BLOCK_N} ") # print(f"num_m_block: {num_m_blocks}, num_n_block: {num_n_blocks} ") # print(f"max_seqlen_q: {max_seqlen_q}, max_seqlen_k: {max_seqlen_k}") - # print(f"num_heads: {num_heads}, num_heads_k: {num_heads_k} ") if max_seqlen_q == 1: causal = False From 520171f55ba55d9ef9748706fcf0a52823908148 Mon Sep 17 00:00:00 2001 From: Valerie Chen Date: Fri, 14 Nov 2025 13:57:15 -0500 Subject: [PATCH 4/4] Black fix --- aiter/ops/triton/_triton_kernels/lean_atten.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/aiter/ops/triton/_triton_kernels/lean_atten.py b/aiter/ops/triton/_triton_kernels/lean_atten.py index a743c4a9df..c71d810642 100644 --- a/aiter/ops/triton/_triton_kernels/lean_atten.py +++ b/aiter/ops/triton/_triton_kernels/lean_atten.py @@ -32,8 +32,7 @@ @functools.lru_cache(maxsize=1024) -def _get_config( -): +def _get_config(): if not hasattr(_get_config, "_config_dict"): dev = arch_info.get_device() fpath = f"{AITER_TRITON_CONFIGS_PATH}/{dev}-LEANATTN-DEFAULT.json"
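
Note on the scaling change in patch 1: rather than having tests and wrappers pass a precomputed qk_scale (previously a hard-coded sm_scale = 0.5 multiplied by LOG_TWO_E), the host wrapper now derives sm_scale = head_dim ** -0.5 and the kernel folds in RCP_LN2 = 1/ln(2) itself, so the online softmax can keep using base-2 exponentials (the removed LOG_TWO_E comment described the same log2(e) trick). The snippet below is a plain PyTorch illustration of that equivalence, not the Triton kernel; the BLOCK_M/BLOCK_N/HEAD_DIM values are arbitrary stand-ins chosen only for the example.

import math
import torch

torch.manual_seed(0)
BLOCK_M, BLOCK_N, HEAD_DIM = 16, 64, 128
RCP_LN2 = 1.0 / math.log(2.0)  # == log2(e) ~= 1.4426950408889634

q = torch.randn(BLOCK_M, HEAD_DIM)
k = torch.randn(BLOCK_N, HEAD_DIM)
sm_scale = HEAD_DIM ** -0.5  # derived from the head dim, as in the wrapper

qk = q @ k.T

# Reference path: natural-base softmax of the scaled logits, matching
# reference_attention's matmul(...) / math.sqrt(d).
ref = torch.softmax(qk * sm_scale, dim=-1)

# Kernel-style path: fold 1/ln(2) into the scale and exponentiate with exp2,
# since exp(x) == 2 ** (x * log2(e)); subtracting the row max mirrors the
# online-softmax running maximum and cancels in the normalization.
qk2 = qk * (sm_scale * RCP_LN2)
m = qk2.max(dim=-1, keepdim=True).values
p = torch.exp2(qk2 - m)
out = p / p.sum(dim=-1, keepdim=True)

assert torch.allclose(ref, out, atol=1e-6)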