diff --git a/aiter/ops/triton/_triton_kernels/fp8_mqa_logits.py b/aiter/ops/triton/_triton_kernels/fp8_mqa_logits.py
new file mode 100644
index 0000000000..bc58d2421f
--- /dev/null
+++ b/aiter/ops/triton/_triton_kernels/fp8_mqa_logits.py
@@ -0,0 +1,110 @@
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _fp8_mqa_logits_kernel(
+    Q_ptr,  # fp8e4m3 [seq_len, H, D]
+    KV_ptr,  # fp8e4m3 [seq_len_kv, D]
+    kv_scales_ptr,  # fp32 [seq_len_kv]
+    weights_ptr,  # fp32 [seq_len, H]
+    cu_start_ptr,  # int32 [seq_len]
+    cu_end_ptr,  # int32 [seq_len]
+    logits_ptr,  # fp32 [seq_len, seq_len_kv]
+    seq_len,
+    seq_len_kv,
+    NUM_HEADS: tl.constexpr,
+    HEAD_SIZE: tl.constexpr,
+    # strides
+    stride_q_s: tl.int64,
+    stride_q_h: tl.constexpr,
+    stride_q_d: tl.constexpr,
+    stride_kv_s: tl.int64,
+    stride_kv_d: tl.constexpr,
+    stride_w_s: tl.int64,
+    stride_w_h: tl.constexpr,
+    stride_logits_s: tl.int64,
+    stride_logits_k: tl.int64,
+    # block sizes
+    BLOCK_KV: tl.constexpr,
+):
+    row_id = tl.program_id(0)
+
+    tl.assume(row_id >= 0)
+    tl.assume(stride_q_s > 0)
+    tl.assume(stride_q_h > 0)
+    tl.assume(stride_q_d > 0)
+    tl.assume(stride_kv_s > 0)
+    tl.assume(stride_kv_d > 0)
+    tl.assume(stride_w_s > 0)
+    tl.assume(stride_w_h > 0)
+
+    h_inds = tl.arange(0, NUM_HEADS)[:, None]
+    d_inds = tl.arange(0, HEAD_SIZE)
+
+    # load Q[NUM_HEADS, HEAD_SIZE] for this query row
+    q_ptrs = (
+        Q_ptr + row_id * stride_q_s + h_inds * stride_q_h + d_inds[None, :] * stride_q_d
+    )
+
+    q_block = tl.load(q_ptrs, cache_modifier=".cg")
+    w_ptrs = weights_ptr + row_id * stride_w_s + h_inds * stride_w_h
+    w_block = tl.load(w_ptrs, cache_modifier=".cg").to(tl.float32)
+
+    # Load the KV window [start, end) for this query row
+    start_ind = tl.load(cu_start_ptr + row_id)
+    end_ind = tl.load(cu_end_ptr + row_id)
+
+    start_ind = tl.maximum(start_ind, 0)
+    end_ind = tl.minimum(end_ind, seq_len_kv)
+    # Tile the window relative to start_ind so that the pointers advanced by the main
+    # loop line up with the tail tile handled below.
+    unmasked_end_ind = start_ind + ((end_ind - start_ind) // BLOCK_KV) * BLOCK_KV
+
+    logits_row_ptrs = logits_ptr + row_id * stride_logits_s
+    kv_col_offsets = tl.arange(0, BLOCK_KV) + start_ind
+    kv_ptrs = (
+        KV_ptr + kv_col_offsets[None, :] * stride_kv_s + d_inds[:, None] * stride_kv_d
+    )
+
+    kv_scales_ptrs = kv_scales_ptr + kv_col_offsets
+
+    logits_ptrs = logits_row_ptrs + kv_col_offsets * stride_logits_k
+
+    # Loop over KV tiles
+    for _ in tl.range(start_ind, unmasked_end_ind, BLOCK_KV):
+        kv_block = tl.load(kv_ptrs)
+        kv_scales = tl.load(kv_scales_ptrs)
+
+        # [NUM_HEADS, BLOCK_KV] = [NUM_HEADS, HEAD_SIZE] x [HEAD_SIZE, BLOCK_KV]
+        scores = tl.dot(q_block, kv_block)
+        # Multiply by kv_scales (broadcast along rows)
+        scores = scores * kv_scales[None, :]
+        # ReLU
+        scores = tl.maximum(scores, 0.0)
+        scores = scores * w_block
+        # [NUM_HEADS, BLOCK_KV] -> [BLOCK_KV, ]
+        scores = tl.sum(scores, axis=0)
+        tl.store(logits_ptrs, scores)
+
+        kv_ptrs += BLOCK_KV * stride_kv_s
+        kv_scales_ptrs += BLOCK_KV
+        logits_ptrs += BLOCK_KV * stride_logits_k
+
+    if unmasked_end_ind != end_ind:
+        # masked load
+        kv_col_offsets = tl.arange(0, BLOCK_KV) + unmasked_end_ind
+        kv_col_mask = kv_col_offsets < seq_len_kv
+        kv_block = tl.load(kv_ptrs, mask=kv_col_mask[None, :], other=0.0)
+        kv_scales = tl.load(kv_scales_ptrs, mask=kv_col_mask, other=0.0)
+
+        # [NUM_HEADS, BLOCK_KV] = [NUM_HEADS, HEAD_SIZE] x [HEAD_SIZE, BLOCK_KV]
+        scores = tl.dot(q_block, kv_block)
+        # Multiply by kv_scales (broadcast along rows)
+        scores = scores * kv_scales[None, :]
+        # ReLU
+        scores = tl.maximum(scores, 0.0)
+        scores = scores * w_block
+        # [NUM_HEADS, BLOCK_KV] -> [BLOCK_KV, ]
+        scores = tl.sum(scores, axis=0)
+        # masked store
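+        # Only KV positions inside this row's [start_ind, end_ind) window are written;
+        # everything outside the window keeps the -inf the logits buffer was
+        # initialized with.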
+        in_window = (kv_col_offsets >= start_ind) & (kv_col_offsets < end_ind)
+        tl.store(logits_ptrs, scores, mask=in_window)
diff --git a/aiter/ops/triton/_triton_kernels/unified_attention_sparse_mla.py b/aiter/ops/triton/_triton_kernels/unified_attention_sparse_mla.py
new file mode 100644
index 0000000000..06ab445bf4
--- /dev/null
+++ b/aiter/ops/triton/_triton_kernels/unified_attention_sparse_mla.py
@@ -0,0 +1,252 @@
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def cdiv_fn(x, y):
+    return (x + y - 1) // y
+
+
+@triton.jit
+def apply_softcap(S, x):
+    Sdiv = S / x
+    p1 = tl.exp(Sdiv)
+    p2 = tl.exp(-Sdiv)
+    return x * (p1 - p2) / (p1 + p2)
+
+
+@triton.jit
+def find_seq_idx(
+    query_start_len_ptr,
+    target_idx,
+    num_seqs,
+    BLOCK_Q: tl.constexpr,
+    use_q_block_mode: tl.constexpr,
+):
+    left: tl.int32 = 0
+    right = num_seqs
+    while left < right:
+        mid = (left + right) // 2
+        val = tl.load(query_start_len_ptr + mid)
+        mid_val = val // BLOCK_Q + mid if use_q_block_mode else val
+
+        if mid_val <= target_idx:
+            left = mid + 1
+        else:
+            right = mid
+
+    return left - 1
+
+
+@triton.jit
+def _kernel_unified_attention_sparse_mla_2d(
+    output_ptr,  # [num_tokens, num_query_heads, KV_LORA_RANK]
+    query_ptr,  # [num_tokens, num_query_heads, KV_LORA_RANK + ROPE_RANK]
+    key_cache_ptr,  # [num_blks, blk_size, 1, KV_LORA_RANK + ROPE_RANK]
+    value_cache_ptr,  # [num_blks, blk_size, 1, KV_LORA_RANK]
+    block_tables_ptr,  # [num_seqs, max_num_blocks_per_seq]
+    topk_indices_ptr,  # [num_tokens, topk]
+    seq_lens_ptr,  # [num_seqs]
+    scale,  # float32
+    num_query_heads: tl.constexpr,  # int
+    num_queries_per_kv: tl.constexpr,  # int
+    block_table_stride: tl.int64,  # int
+    query_stride_0: tl.int64,  # int
+    query_stride_1: tl.int64,  # int
+    output_stride_0: tl.int64,  # int
+    output_stride_1: tl.int64,  # int
+    BLOCK_SIZE: tl.constexpr,  # int
+    stride_k_cache_0: tl.int64,  # int
+    stride_k_cache_1: tl.int64,  # int
+    stride_k_cache_2: tl.int64,  # int
+    stride_k_cache_3: tl.constexpr,  # int
+    stride_v_cache_0: tl.int64,  # int
+    stride_v_cache_1: tl.int64,  # int
+    stride_v_cache_2: tl.int64,  # int
+    stride_v_cache_3: tl.constexpr,  # int
+    topk_count: tl.constexpr,
+    query_start_len_ptr,  # [num_seqs+1]
+    num_seqs: tl.int32,
+    BLOCK_M: tl.constexpr,  # int
+    ROPE_RANK: tl.constexpr,
+    KV_LORA_RANK: tl.constexpr,
+    TILE_SIZE: tl.constexpr,
+    ALL_DECODE: tl.constexpr = False,
+):
+    """
+    TODO:
+    -- Masking can be simplified
+    -- Tests fail when all topk indices are -1; this is unlikely in practice
+    """
+    # only one query per program
+    # these can be removed, but they keep the kernel similar to the MHA kernel
+    BLOCK_Q: tl.constexpr = 1
+    kv_head_idx = 0  # assume there is a single kv head
+
+    q_block_global_idx = tl.program_id(0)
+    q_ind = q_block_global_idx // (num_query_heads // BLOCK_M)
+    head_ind = q_block_global_idx % (num_query_heads // BLOCK_M)
+    seq_idx = find_seq_idx(query_start_len_ptr, q_ind, num_seqs, BLOCK_Q, False)
+    q_block_start_idx = tl.load(query_start_len_ptr + seq_idx)
+
+    q_block_local_idx = q_ind - q_block_start_idx
+    cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx)
+    cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1)
+    cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index
+
+    if q_block_local_idx * BLOCK_Q >= cur_batch_query_len:
+        return
+
+    offs_m = tl.arange(0, BLOCK_M) + head_ind * BLOCK_M
+
+    # load Q in two parts with different dim offsets
+    offs_lora = tl.arange(0, KV_LORA_RANK)
+    offs_rope = tl.arange(KV_LORA_RANK, KV_LORA_RANK + ROPE_RANK)
+
+    query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv
+
+    query_offset_0 = cur_batch_in_all_start_index + query_pos
+    query_offset_1 = kv_head_idx * num_queries_per_kv + offs_m % num_queries_per_kv
+
+    query_mask_0 = query_pos < cur_batch_query_len
+    query_mask_1 = query_offset_1 < num_query_heads
+
+    if ALL_DECODE or BLOCK_M >= num_query_heads:
+        Q_cache_modifier: tl.constexpr = ".cg"
+    else:
+        Q_cache_modifier: tl.constexpr = ""
+
+    # load Q in two parts
+    # q_pe: (BLOCK_M, ROPE_RANK)
+    q_rope_offset = (
+        query_offset_0[:, None] * query_stride_0
+        + query_offset_1[:, None] * query_stride_1
+        + offs_rope[None, :]
+    )
+    Q_rope = tl.load(
+        query_ptr + q_rope_offset,
+        mask=query_mask_0[:, None] & query_mask_1[:, None],
+        other=0.0,
+        cache_modifier=Q_cache_modifier,
+    )
+
+    # q_lora: (BLOCK_M, KV_LORA_RANK)
+    q_lora_offset = (
+        query_offset_0[:, None] * query_stride_0
+        + query_offset_1[:, None] * query_stride_1
+        + offs_lora[None, :]
+    )
+    Q_lora = tl.load(
+        query_ptr + q_lora_offset,
+        mask=query_mask_0[:, None] & query_mask_1[:, None],
+        other=0.0,
+        cache_modifier=Q_cache_modifier,
+    )
+
+    M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
+    L = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
+    acc = tl.zeros([BLOCK_M, KV_LORA_RANK], dtype=tl.float32)
+
+    block_table_offset = seq_idx * block_table_stride
+
+    # iterate topk indices in tiles of TILE_SIZE
+    num_tiles = (topk_count + TILE_SIZE - 1) // TILE_SIZE
+    KV_cache_modifier: tl.constexpr = ".cg" if ALL_DECODE else ""
+    for t in range(0, num_tiles):
+        tile_start = t * TILE_SIZE
+        offs_t = tl.arange(0, TILE_SIZE)
+        valid_t = (tile_start + offs_t) < topk_count
+
+        # load top-k token positions for this query
+        topk_row_ptr = topk_indices_ptr + q_ind * topk_count
+        topk_pos = tl.load(topk_row_ptr + tile_start + offs_t, mask=valid_t, other=0)
+        # ignore -1, means not valid
+        valid_t = valid_t & (topk_pos != -1)
+
+        # map positions to block id and in-block offset
+        physical_block_idx = topk_pos // BLOCK_SIZE
+        slot = topk_pos % BLOCK_SIZE
+        # Compute S = scale * (q_rope k_rope + q_lora k_lora)
+        # q_rope: (BLOCK_M, ROPE_RANK) k_rope: (ROPE_RANK, TILE_SIZE)
+        # q_lora: (BLOCK_M, KV_LORA_RANK) k_lora: (KV_LORA_RANK, TILE_SIZE)
+        S = tl.zeros([BLOCK_M, TILE_SIZE], dtype=tl.float32)
+        # load k in two parts
+        # K_rope: (ROPE_RANK, TILE_SIZE)
+        k_rope_ptrs = (
+            key_cache_ptr
+            + physical_block_idx[None, :] * stride_k_cache_0
+            + kv_head_idx * stride_k_cache_2
+            + offs_rope[:, None] * stride_k_cache_3
+            + slot[None, :] * stride_k_cache_1
+        )
+        K_rope = tl.load(
+            k_rope_ptrs,
+            mask=valid_t[None, :],
+            other=0.0,
+            cache_modifier=KV_cache_modifier,
+        )
+        S += scale * tl.dot(Q_rope, K_rope)
+        # K_lora: (KV_LORA_RANK, TILE_SIZE)
+        k_lora_ptrs = (
+            key_cache_ptr
+            + physical_block_idx[None, :] * stride_k_cache_0
+            + kv_head_idx * stride_k_cache_2
+            + offs_lora[:, None] * stride_k_cache_3
+            + slot[None, :] * stride_k_cache_1
+        )
+        K_lora = tl.load(
+            k_lora_ptrs,
+            mask=valid_t[None, :],
+            other=0.0,
+            cache_modifier=KV_cache_modifier,
+        )
+
+        S += scale * tl.dot(Q_lora, K_lora)
+
+        S = tl.where(
+            query_mask_1[:, None] & query_mask_0[:, None] & valid_t[None, :],
+            S,
+            float("-inf"),
+        )
+
+        m_j = tl.maximum(M, tl.max(S, axis=1))
+        m_j = tl.where(m_j > float("-inf"), m_j, 0.0)
+        P = tl.exp(S - m_j[:, None])
+        l_j = tl.sum(P, axis=1)
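+        # Online-softmax update: rescale the running accumulator and normalizer by
+        # exp(M_old - M_new) before adding this tile's contribution.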
+        alpha = tl.exp(M - m_j)
+
+        acc = acc * alpha[:, None]
+        L = L * alpha + l_j
+        M = m_j
+
+        # load V with shape (TILE_SIZE, KV_LORA_RANK)
+        v_lora_ptrs = (
+            value_cache_ptr
+            + physical_block_idx[:, None] * stride_v_cache_0
+            + kv_head_idx * stride_v_cache_2
+            + slot[:, None] * stride_v_cache_1
+            + offs_lora[None, :] * stride_v_cache_3
+        )
+        V_lora = tl.load(
+            v_lora_ptrs,
+            mask=valid_t[:, None],
+            other=0.0,
+            cache_modifier=KV_cache_modifier,
+        )
+
+        acc += tl.dot(P.to(V_lora.dtype), V_lora)
+
+    # epilogue
+    one_over_L = 1.0 / L[:, None]
+    acc = acc * one_over_L
+
+    output_offs_lora = (
+        query_offset_0[:, None] * output_stride_0
+        + query_offset_1[:, None] * output_stride_1
+        + offs_lora[None, :]
+    )
+    tl.store(
+        output_ptr + output_offs_lora,
+        acc,
+        mask=query_mask_0[:, None] & query_mask_1[:, None],
+    )
diff --git a/aiter/ops/triton/fp8_mqa_logits.py b/aiter/ops/triton/fp8_mqa_logits.py
new file mode 100644
index 0000000000..512973b853
--- /dev/null
+++ b/aiter/ops/triton/fp8_mqa_logits.py
@@ -0,0 +1,75 @@
+import torch
+import math
+import triton
+
+from aiter.ops.triton._triton_kernels.fp8_mqa_logits import _fp8_mqa_logits_kernel
+
+
+def fp8_mqa_logits(
+    Q,
+    KV,
+    kv_scales,
+    weights,
+    cu_starts,
+    cu_ends,
+):
+    """
+    This function computes the logits to be used by a topk function for sparse attention.
+
+    Q: [seq_len, NUM_HEADS, HEAD_SIZE], dtype float8
+    KV: [seq_len_kv, HEAD_SIZE], dtype float8
+    kv_scales: [seq_len_kv], dtype float32
+    weights: [seq_len, NUM_HEADS], dtype float32
+    cu_starts: [seq_len], dtype int32, start indices
+    cu_ends: [seq_len], dtype int32, end indices
+
+    Returns:
+        logits: [seq_len, seq_len_kv], dtype float32 (positions outside each query's [start, end) window stay -inf)
+    """
+    BLOCK_KV = 128
+    seq_len, num_heads, head_size = Q.shape
+    seq_len_kv = KV.shape[0]
+    # TODO: Currently assuming num_heads and head_size are powers of 2.
+    assert num_heads & (num_heads - 1) == 0, "num q. heads should be power of 2."
+    assert head_size & (head_size - 1) == 0, "head size should be power of 2."
+    # Initialize with -inf because of causal masking
+    logits = torch.full(
+        (seq_len, seq_len_kv),
+        fill_value=-float("inf"),
+        dtype=torch.float32,
+        device=Q.device,
+    )
+
+    stride_q_s, stride_q_h, stride_q_d = Q.stride()
+    stride_kv_s, stride_kv_d = KV.stride()
+    stride_w_s, stride_w_h = weights.stride()
+    stride_logits_s, stride_logits_k = logits.stride()
+    _fp8_mqa_logits_kernel[(seq_len,)](
+        Q_ptr=Q,
+        KV_ptr=KV,
+        kv_scales_ptr=kv_scales,
+        weights_ptr=weights,
+        cu_start_ptr=cu_starts,
+        cu_end_ptr=cu_ends,
+        logits_ptr=logits,
+        seq_len=seq_len,
+        seq_len_kv=seq_len_kv,
+        NUM_HEADS=num_heads,
+        HEAD_SIZE=head_size,
+        stride_q_s=stride_q_s,
+        stride_q_h=stride_q_h,
+        stride_q_d=stride_q_d,
+        stride_kv_s=stride_kv_s,
+        stride_kv_d=stride_kv_d,
+        stride_w_s=stride_w_s,
+        stride_w_h=stride_w_h,
+        stride_logits_s=stride_logits_s,
+        stride_logits_k=stride_logits_k,
+        BLOCK_KV=BLOCK_KV,
+        num_warps=4,
+        num_stages=2,
+        waves_per_eu=2,
+        matrix_instr_nonkdim=16,
+    )
+
+    return logits
diff --git a/aiter/ops/triton/unified_attention_sparse_mla.py b/aiter/ops/triton/unified_attention_sparse_mla.py
new file mode 100644
index 0000000000..c2b8e1f8ce
--- /dev/null
+++ b/aiter/ops/triton/unified_attention_sparse_mla.py
@@ -0,0 +1,95 @@
+from aiter.ops.triton._triton_kernels.unified_attention_sparse_mla import (
+    _kernel_unified_attention_sparse_mla_2d,
+)
+
+
+def unified_attention_sparse_mla(
+    q,
+    kv,
+    out,
+    cu_seqlens_q,
+    max_seqlen_q,
+    seqused_k,
+    max_seqlen_k,
+    softmax_scale,
+    topk_indices,
+    block_table,
+    kv_lora_rank,
+):
+    """
+    This function computes sparse MLA attention over the top-k selected KV positions.
+
+    Note: topk_indices index the KV cache, not block_table.
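+    For example, with block_size 64 a top-k entry of 70 addresses kv[70 // 64, 70 % 64]
+    (slot 6 of physical block 1) directly, without going through block_table.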
+
+    Q: [seq_len, NUM_HEADS, kv_lora_rank + rope_rank], dtype bfloat16
+    KV: [NUM_BLOCKS, BLOCK_SIZE, 1, kv_lora_rank + rope_rank], dtype bfloat16
+    cu_seqlens_q: [BATCH + 1], dtype int32
+    max_seqlen_q: scalar, dtype int32
+    max_seqlen_k: scalar, dtype int32
+    softmax_scale: scalar, dtype float32
+    topk_indices: [seq_len, TOP_K], dtype int32
+    block_table: [BATCH, MAX_NUM_BLOCKS_PER_BATCH], dtype int32
+    kv_lora_rank: scalar, dtype int32
+
+    Returns:
+        out (in-place): [seq_len, NUM_HEADS, kv_lora_rank], dtype bfloat16
+    """
+
+    # TODO: This kernel is simplified for initial development and not yet optimized.
+
+    block_size = kv.shape[1]
+    num_seqs = len(seqused_k)
+    num_query_heads = q.shape[1]
+    num_kv_heads = 1
+    num_queries_per_kv = num_query_heads // num_kv_heads
+    head_size = q.shape[2]
+    topk_count = topk_indices.shape[1]
+    k = kv
+    v = kv[..., :kv_lora_rank]
+
+    BLOCK_M = 16
+
+    total_num_q_blocks = q.shape[0] * (num_query_heads // BLOCK_M)
+    ALL_DECODE = max_seqlen_q == 1
+
+    ROPE_RANK = head_size - kv_lora_rank
+    KV_LORA_RANK = kv_lora_rank
+    TILE_SIZE = block_size
+    num_stages_2d = 1
+    num_warps = 4
+    _kernel_unified_attention_sparse_mla_2d[(total_num_q_blocks,)](
+        output_ptr=out,
+        query_ptr=q,
+        key_cache_ptr=k,
+        value_cache_ptr=v,
+        block_tables_ptr=block_table,
+        topk_indices_ptr=topk_indices,
+        seq_lens_ptr=seqused_k,
+        scale=softmax_scale,
+        num_query_heads=num_query_heads,
+        num_queries_per_kv=num_queries_per_kv,
+        block_table_stride=block_table.stride(0),
+        query_stride_0=q.stride(0),
+        query_stride_1=q.stride(1),
+        output_stride_0=out.stride(0),
+        output_stride_1=out.stride(1),
+        BLOCK_SIZE=block_size,
+        stride_k_cache_0=k.stride(0),
+        stride_k_cache_1=k.stride(1),
+        stride_k_cache_2=k.stride(2),
+        stride_k_cache_3=k.stride(3),
+        stride_v_cache_0=v.stride(0),
+        stride_v_cache_1=v.stride(1),
+        stride_v_cache_2=v.stride(2),
+        stride_v_cache_3=v.stride(3),
+        topk_count=topk_count,
+        query_start_len_ptr=cu_seqlens_q,
+        num_seqs=num_seqs,
+        BLOCK_M=BLOCK_M,
+        ROPE_RANK=ROPE_RANK,
+        KV_LORA_RANK=KV_LORA_RANK,
+        TILE_SIZE=TILE_SIZE,
+        ALL_DECODE=ALL_DECODE,
+        num_warps=num_warps,
+        num_stages=num_stages_2d,
+    )
diff --git a/op_tests/triton_tests/test_fp8_mqa_logits.py b/op_tests/triton_tests/test_fp8_mqa_logits.py
new file mode 100644
index 0000000000..c7e69d1757
--- /dev/null
+++ b/op_tests/triton_tests/test_fp8_mqa_logits.py
@@ -0,0 +1,129 @@
+# tests are adapted from https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_attention.py
+import torch
+import pytest
+from typing import Tuple
+from aiter.ops.triton.utils.types import get_fp8_dtypes
+from aiter.ops.triton.fp8_mqa_logits import fp8_mqa_logits
+
+e5m2_type, e4m3_type = get_fp8_dtypes()
+fp8_info = torch.finfo(e4m3_type)
+fp8_max = fp8_info.max
+
+
+def calc_diff(x: torch.Tensor, y: torch.Tensor):
+    x, y = x.double(), y.double()
+    denominator = (x * x + y * y).sum()
+    sim = 2 * (x * y).sum() / denominator
+    return 1 - sim
+
+
+def ceil_to_ue8m0(x: torch.Tensor):
+    assert x.view(-1).amax().item() > 0
+    return torch.pow(2.0, torch.ceil(torch.log2(x.abs())))
+
+
+def per_custom_dims_cast_to_fp8(
+    x: torch.Tensor, dims: Tuple, use_ue8m0: bool
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    excluded_dims = tuple([i for i in range(x.dim()) if i not in set(dims)])
+    x_amax = x.abs().float().amax(dim=excluded_dims, keepdim=True).clamp(1e-4)
+    sf = x_amax / fp8_max
+    sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf
+    x_scaled = (x * (1.0 / sf)).to(e4m3_type)
+    return x_scaled, sf.squeeze()
+
+
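+# Reference: logits[m, n] = sum_h weights[m, h] * relu(q[m, h, :] @ kv[n, :]) for n in
+# the per-query window [cu_seqlen_ks[m], cu_seqlen_ke[m]); other positions are -inf.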
+def ref_fp8_mqa_logits(
+    q: torch.Tensor,
+    kv: torch.Tensor,
+    weights: torch.Tensor,
+    cu_seqlen_ks: torch.Tensor,
+    cu_seqlen_ke: torch.Tensor,
+    cost_only: bool = False,
+):
+    seq_len_kv = kv.shape[0]
+
+    if cost_only:
+        start = cu_seqlen_ks.clamp(min=0, max=seq_len_kv)
+        end = cu_seqlen_ke.clamp(min=0, max=seq_len_kv)
+        count_ones_per_row = (end - start).clamp(min=0)
+        return count_ones_per_row.sum()
+
+    k = kv
+    q = q.float()
+    k = k.float()
+
+    mask_lo = (
+        torch.arange(0, seq_len_kv, device="cuda")[None, :] >= cu_seqlen_ks[:, None]
+    )
+    mask_hi = (
+        torch.arange(0, seq_len_kv, device="cuda")[None, :] < cu_seqlen_ke[:, None]
+    )
+    mask = mask_lo & mask_hi
+
+    score = torch.einsum("mhd,nd->hmn", q, k)
+    logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
+    logits = logits.masked_fill(~mask, float("-inf"))
+
+    cost = mask.sum()
+    return logits, cost
+
+
+def generate_cp_test_data(seq_len, seq_len_kv):
+    assert seq_len_kv % seq_len == 0 and seq_len % 2 == 0
+    chunk_size = seq_len // 2
+    cp_size = seq_len_kv // seq_len
+    # Select an arbitrary CP rank
+    cp_id = cp_size // 3
+    ks = torch.zeros(seq_len, dtype=torch.int, device="cuda")
+    ke = torch.zeros(seq_len, dtype=torch.int, device="cuda")
+    for i in range(chunk_size):
+        ke[i] = cp_id * chunk_size + i
+        ke[i + chunk_size] = (cp_size * 2 - 1 - cp_id) * chunk_size + i
+    return ks, ke
+
+
+@pytest.mark.parametrize("s_q", [1, 17, 61, 128, 1024])
+@pytest.mark.parametrize("s_k", [16, 76, 113, 1024, 2048])
+@pytest.mark.parametrize("num_heads", [16, 64])
+@pytest.mark.parametrize("head_dim", [64, 128])
+@pytest.mark.parametrize("disable_cp", [True, False])
+@torch.inference_mode()
+def test_fp8_mqa_logits(
+    s_q: int,
+    s_k: int,
+    num_heads: int,
+    head_dim: int,
+    disable_cp: bool,
+) -> None:
+    torch.manual_seed(0)
+    if s_q > s_k:
+        pytest.skip()
+    q = torch.randn(s_q, num_heads, head_dim, device="cuda", dtype=torch.bfloat16)
+    kv = torch.randn(s_k, head_dim, device="cuda", dtype=torch.bfloat16)
+    kv_fp8, scales = per_custom_dims_cast_to_fp8(kv, (0,), False)
+    kv = (kv_fp8.to(torch.float32) * scales[:, None]).to(torch.bfloat16)
+    weights = torch.randn(s_q, num_heads, device="cuda", dtype=torch.float32)
+    # to respect the assert in generate_cp_test_data
+    if disable_cp or s_k % s_q != 0 or s_q % 2 != 0:
+        ks = torch.zeros(s_q, dtype=torch.int, device="cuda")
+        ke = torch.arange(s_q, dtype=torch.int, device="cuda") + (s_k - s_q)
+    else:
+        ks, ke = generate_cp_test_data(s_q, s_k)
+
+    q_fp8 = q.to(e4m3_type)
+    kv_fp8, scales = per_custom_dims_cast_to_fp8(kv, (0,), False)
+
+    ref_logits, ref_cost = ref_fp8_mqa_logits(
+        q=q, kv=kv, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke
+    )
+
+    logits = fp8_mqa_logits(q_fp8, kv_fp8, scales, weights, ks, ke)
+
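+    # Both the kernel and the reference mark out-of-window positions with -inf:
+    # compare those masks exactly, then zero them so calc_diff only sees finite values.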
+    ref_neginf_mask = ref_logits == float("-inf")
+    neginf_mask = logits == float("-inf")
+    assert torch.equal(neginf_mask, ref_neginf_mask)
+    ref_logits = ref_logits.masked_fill(ref_neginf_mask, 0)
+    logits = logits.masked_fill(neginf_mask, 0)
+    diff = calc_diff(logits, ref_logits)
+    assert diff < 1e-3, f"{diff=}"
diff --git a/op_tests/triton_tests/test_unified_attention_sparse_mla.py b/op_tests/triton_tests/test_unified_attention_sparse_mla.py
new file mode 100644
index 0000000000..27dd949eac
--- /dev/null
+++ b/op_tests/triton_tests/test_unified_attention_sparse_mla.py
@@ -0,0 +1,367 @@
+# test code is adapted from flashMLA:
+# https://github.com/deepseek-ai/FlashMLA/blob/main/tests/test_flash_mla_decoding.py
+import random
+import dataclasses
+from typing import Optional, Tuple
+
+import torch
+import pytest
+from math import ceil
+from aiter.ops.triton.unified_attention_sparse_mla import unified_attention_sparse_mla
+
+
+def cdiv(a, b):
+    return ceil(a / b)
+
+
+@dataclasses.dataclass
+class Param:
+    b: int  # Batch size
+    s_q: int  # Number of queries for one request
+    s_k: int  # Seq len, or mean seq len if varlen == True
+    is_varlen: bool
+    is_causal: bool
+    is_fp8: bool
+    topk: Optional[int] = None
+    test_performance: bool = True
+    is_all_indices_invalid: bool = False
+    have_zero_seqlen_k: bool = False
+    block_size: int = 64
+    h_q: int = 128  # Number of q heads
+    h_kv: int = 1  # Number of kv heads
+    d: int = 576  # Q/K head dim (= dv + RoPE dim)
+    dv: int = 512  # V head dim
+    seed: int = 0
+
+
+def generate_test_data(
+    t: Param,
+) -> Tuple[
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    Optional[torch.Tensor],
+    Optional[torch.Tensor],
+]:
+    """
+    Generate test data from a given configuration
+    Return: (cache_seqlens, q, block_table, blocked_k, abs_indices, indices_in_kvcache)
+    Note: this function reseeds the global random number generators
+    """
+    random.seed(t.seed)
+    torch.manual_seed(t.seed)
+    torch.cuda.manual_seed(t.seed)
+    torch.backends.cudnn.deterministic = True
+
+    assert t.h_q % t.h_kv == 0
+
+    cache_seqlens_cpu = torch.full((t.b,), t.s_k, dtype=torch.int32, device="cpu")
+    if t.is_varlen:
+        for i in range(t.b):
+            cache_seqlens_cpu[i] = max(random.normalvariate(t.s_k, t.s_k / 2), t.s_q)
+
+    if t.have_zero_seqlen_k:
+        zeros_mask = torch.randn(t.b, dtype=torch.float32, device="cpu") > 0
+        cache_seqlens_cpu[zeros_mask] = 0
+
+    max_seqlen = cache_seqlens_cpu.max().item()
+    max_seqlen_pad = cdiv(max_seqlen, 256) * 256
+    cache_seqlens = cache_seqlens_cpu.cuda()
+
+    q = torch.randn(t.b, t.s_q, t.h_q, t.d)
+    q.clamp_(min=-1.0, max=1.0)
+
+    block_table = torch.arange(
+        t.b * max_seqlen_pad // t.block_size, dtype=torch.int32
+    ).view(t.b, max_seqlen_pad // t.block_size)
+    block_table = block_table.view(-1)[torch.randperm(block_table.numel())].view(
+        t.b, -1
+    )
+    blocked_k = torch.randn(block_table.numel(), t.block_size, t.h_kv, t.d) / 10
+    blocked_k.clamp_(min=-1.0, max=1.0)
+
+    if t.topk is None:
+        for i in range(t.b):
+            cur_len = cache_seqlens_cpu[i].item()
+            cur_num_blocks = cdiv(cur_len, t.block_size)
+            blocked_k[block_table[i][cur_num_blocks:]] = float("nan")
+            if cur_len % t.block_size != 0:
+                blocked_k[block_table[i][cur_num_blocks - 1]][
+                    cur_len % t.block_size :
+                ] = float("nan")
+            block_table[i][cur_num_blocks:] = 2147480000
+        return cache_seqlens, q, block_table, blocked_k, None, None
+    else:
+        block_table_cpu = block_table.cpu()
+        abs_indices = torch.empty(t.b, t.s_q, t.topk, dtype=torch.int32, device="cpu")
+        indices_in_kvcache = torch.empty(
+            t.b, t.s_q, t.topk, dtype=torch.int32, device="cpu"
+        )
+        for i in range(t.b):
+            # Generate indices
+            for j in range(t.s_q):
+                cur_abs_indices = torch.randperm(
+                    int(cache_seqlens_cpu[i].item()), device="cpu"
+                )[: t.topk]
+                cur_blocked_indices = block_table_cpu[
+                    i, cur_abs_indices // t.block_size
+                ] * t.block_size + (cur_abs_indices % t.block_size)
+                if len(cur_abs_indices) < t.topk:
+                    pad_len = t.topk - len(cur_abs_indices)
+                    cur_abs_indices = torch.cat(
+                        [cur_abs_indices, torch.full((pad_len,), -1, device="cpu")]
+                    )
+                    cur_blocked_indices = torch.cat(
+                        [cur_blocked_indices, torch.full((pad_len,), -1, device="cpu")]
+                    )
+
+                # Shuffle the selected indices
+                perm = torch.randperm(t.topk, device="cpu")
+                cur_abs_indices = cur_abs_indices[perm]
+                cur_blocked_indices = cur_blocked_indices[perm]
+
+                # Fill it with invalid indices if needed
+                if t.is_all_indices_invalid:
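+                    # Degenerate case: every top-k slot is -1, i.e. this query has no
+                    # attendable keys (the kernel's TODO notes this is not exercised).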
+                    cur_abs_indices.fill_(-1)
+                    cur_blocked_indices.fill_(-1)
+
+                abs_indices[i, j, :] = cur_abs_indices
+                indices_in_kvcache[i, j, :] = cur_blocked_indices
+
+        # Mask nonused KV as NaN
+        all_indices = indices_in_kvcache.flatten().tolist()
+        all_indices = list(set(all_indices))
+        if -1 in all_indices:
+            all_indices.remove(-1)
+        all_indices = torch.tensor(all_indices, dtype=torch.int32, device="cpu")
+
+        blocked_k = blocked_k.view(-1, t.h_kv, t.d)
+        nonused_indices_mask = torch.ones(
+            blocked_k.size(0) * blocked_k.size(1), dtype=torch.bool, device="cpu"
+        )
+        nonused_indices_mask[all_indices] = False
+        blocked_k[nonused_indices_mask, :, :] = float("nan")
+        blocked_k = blocked_k.view(-1, t.block_size, t.h_kv, t.d)
+
+        abs_indices = abs_indices.to(q.device)
+        indices_in_kvcache = indices_in_kvcache.to(q.device)
+
+        return cache_seqlens, q, block_table, blocked_k, abs_indices, indices_in_kvcache
+
+
+def reference_torch(
+    cache_seqlens: torch.Tensor,  # [batch_size]
+    block_table: torch.Tensor,  # [batch_size, ?]
+    q: torch.Tensor,  # [batch_size, s_q, h_q, d]
+    blocked_k: torch.Tensor,  # [?, block_size, h_kv, d]
+    dv: int,
+    scale: float,
+    is_causal: bool,
+    indices: Optional[torch.Tensor] = None,  # [batch_size, s_q, topk]
+) -> torch.Tensor:
+    """
+    A reference implementation in PyTorch
+    """
+
+    def get_topk_attn_mask(s_q: int, s_k: int, indices: torch.Tensor):
+        mask = torch.zeros(s_q, s_k, dtype=torch.bool)
+        for i in range(s_q):
+            cur_indices = indices[i]
+            valid_indices = cur_indices[cur_indices != -1]
+            mask[i, valid_indices] = True
+        return mask
+
+    def scaled_dot_product_attention(
+        batch_idx: int,
+        query: torch.Tensor,  # [h_q, s_q, d]
+        kv: torch.Tensor,  # [h_kv, s_k, d]
+        dv: int,
+        scale: float,
+        is_causal,
+        indices: Optional[torch.Tensor],  # [s_q, topk]
+    ) -> torch.Tensor:
+        h_q = query.size(0)
+        h_kv = kv.size(0)
+        s_q = query.shape[-2]
+        s_k = kv.shape[-2]
+        query = query.float() * scale
+        kv = kv.float()
+        if h_kv != 1:
+            kv = kv.repeat_interleave(h_q // h_kv, dim=0)
+        kv[kv != kv] = 0.0
+        attn_weight = query @ kv.transpose(-2, -1)  # [h_q, s_q, s_k]
+        if (is_causal and query.size(1) > 1) or indices is not None:
+            mask = torch.ones(s_q, s_k, dtype=torch.bool)
+            if is_causal:
+                assert indices is None
+                mask = mask.tril(diagonal=s_k - s_q)
+            if indices is not None:
+                mask &= get_topk_attn_mask(s_q, s_k, indices)
+            attn_bias = torch.zeros(s_q, s_k, dtype=torch.float)
+            attn_bias.masked_fill_(mask.logical_not(), float("-inf"))
+            attn_weight += attn_bias.to(q.dtype)
+        # attn_weight /= math.sqrt(query.size(-1))
+        lse = attn_weight.logsumexp(dim=-1)  # [h_q, s_q]
+        attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32)
+        output = attn_weight @ kv[..., :dv]  # [h_q, s_q, dv]
+        # Correct for q tokens which have no attendable k
+        lonely_q_mask = lse == float("-inf")
+        output[lonely_q_mask.unsqueeze(-1).broadcast_to(h_q, s_q, dv)] = 0.0
+        lse[lonely_q_mask] = float("+inf")
+
+        return output
+
+    b, s_q, h_q, d = q.size()
+    block_size = blocked_k.size(1)
+    h_kv = blocked_k.size(2)
+    cache_seqlens_cpu = cache_seqlens.cpu()
+    out_ref = torch.empty(b, s_q, h_q, dv, dtype=torch.float32)
+    for i in range(b):
+        cur_len = cache_seqlens_cpu[i].item()
+        cur_num_blocks = cdiv(cur_len, block_size)
+        cur_block_indices = block_table[i][0:cur_num_blocks]
+        cur_kv = blocked_k[cur_block_indices].view(-1, h_kv, d)[:cur_len, ...]
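+        # Gather this request's KV from its (shuffled) physical blocks, then run dense
+        # attention restricted to the top-k mask as the reference for this batch entry.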
+        cur_out = scaled_dot_product_attention(
+            i,
+            q[i].transpose(0, 1),
+            cur_kv.transpose(0, 1),
+            dv,
+            scale,
+            is_causal,
+            indices[i] if indices is not None else None,
+        )
+        out_ref[i] = cur_out.transpose(0, 1)
+    out_ref = out_ref.to(torch.bfloat16)
+    return out_ref
+
+
+def chunk_input(
+    cache_seqlens,
+    q,
+    block_table,
+    blocked_k,
+    abs_indices,
+    indices_in_kvcache,
+    dtype=torch.bfloat16,
+):
+    q_new = q.reshape(-1, q.shape[2], q.shape[3])
+    abs_indices = abs_indices.reshape(-1, abs_indices.shape[2])
+    indices_in_kvcache = indices_in_kvcache.reshape(-1, indices_in_kvcache.shape[2])
+    max_q_len = q.shape[1]
+    max_kv_len = max(cache_seqlens)
+    query_lens = [q.shape[1]] * q.shape[0]  # B * [q_len,]
+    cu_query_lens = torch.tensor(
+        [0] + query_lens, dtype=torch.int32, device="cuda"
+    ).cumsum(dim=0, dtype=torch.int32)
+    cache_seqlens = cache_seqlens.to("cuda")
+    q_new = q_new.to("cuda")
+    block_table = block_table.to("cuda")
+    blocked_k = blocked_k.to("cuda")
+    abs_indices = abs_indices.to("cuda")
+    indices_in_kvcache = indices_in_kvcache.to("cuda")
+    return (
+        cu_query_lens,
+        max_q_len,
+        cache_seqlens,
+        max_kv_len,
+        q_new.to(dtype),
+        block_table,
+        blocked_k.to(dtype),
+        abs_indices,
+        indices_in_kvcache,
+    )
+
+
+@pytest.mark.parametrize("batch", [1, 8])
+@pytest.mark.parametrize("s_q", [1, 64, 177])
+@pytest.mark.parametrize("s_k", [1, 64, 177])
+@pytest.mark.parametrize("top_k", [64, 78])
+@pytest.mark.parametrize("num_q_heads", [16, 32])
+@pytest.mark.parametrize("lora_dim", [256, 512])
+@pytest.mark.parametrize(
+    "rope_dim",
+    [
+        64,
+    ],
+)
+@pytest.mark.parametrize("block_size", [16, 64])
+@torch.inference_mode()
+def test_triton_unified_attn(
+    batch: int,
+    s_q: int,
+    s_k: int,
+    top_k: int,
+    num_q_heads: int,
+    lora_dim: int,
+    rope_dim: int,
+    block_size: int,
+) -> None:
+    total_dim = lora_dim + rope_dim
+    softmax_scale = lora_dim**-0.5
+
+    test_p = Param(
+        batch,
+        s_q,
+        s_k,
+        d=total_dim,
+        dv=lora_dim,
+        h_q=num_q_heads,
+        block_size=block_size,
+        is_varlen=True,
+        is_causal=False,
+        is_fp8=False,
+        topk=top_k,
+        test_performance=False,
+    )
+    (cache_seqlens, q, block_table, blocked_k, abs_indices, indices_in_kvcache) = (
+        generate_test_data(test_p)
+    )
+    ref_output = reference_torch(
+        cache_seqlens,
+        block_table,
+        q,
+        blocked_k,
+        lora_dim,
+        softmax_scale,
+        False,
+        abs_indices,
+    )
+
+    (
+        cu_seqlens_q,
+        max_seqlen_q,
+        seqused_k,
+        max_seqlen_k,
+        q,
+        block_table,
+        blocked_k,
+        abs_indices,
+        indices_in_kvcache,
+    ) = chunk_input(
+        cache_seqlens, q, block_table, blocked_k, abs_indices, indices_in_kvcache
+    )
+
+    output = torch.empty((*q.shape[:-1], lora_dim), device=q.device, dtype=q.dtype)
+
+    unified_attention_sparse_mla(
+        q,
+        blocked_k,
+        output,
+        cu_seqlens_q,
+        max_seqlen_q,
+        seqused_k,
+        max_seqlen_k,
+        softmax_scale,
+        indices_in_kvcache,
+        block_table,
+        lora_dim,
+    )
+
+    ref_output = ref_output.to(output.device).to(q.dtype)
+    output = output.reshape(ref_output.shape)
+
+    atol, rtol = 1.5e-2, 1e-2
+    torch.testing.assert_close(
+        output,
+        ref_output,
+        atol=atol,
+        rtol=rtol,
+        msg=f"max abs diff: {torch.max(torch.abs(output - ref_output))}",
+    )