Merged
249 changes: 181 additions & 68 deletions aiter/ops/triton/_triton_kernels/lean_atten.py
@@ -261,6 +261,7 @@ def la_persistent(
num_heads_k: tl.constexpr,
gqa_group_size: tl.constexpr,
use_64_indexing: tl.constexpr,
RAGGED_BATCH: tl.constexpr,
):
if is_pod:
current_pid = pod_pid
@@ -356,6 +357,7 @@ def la_persistent(
num_splits=num_splits,
gqa_group_size=gqa_group_size,
use_64_indexing=use_64_indexing,
RAGGED_BATCH=RAGGED_BATCH,
)


@@ -410,6 +412,7 @@ def la_persistent_inner(
num_splits: tl.constexpr,
gqa_group_size: tl.constexpr,
use_64_indexing: tl.constexpr,
RAGGED_BATCH: tl.constexpr,
):

tl.assume(stride_qm > 0) # n_ctx_q
@@ -468,23 +471,30 @@ def la_persistent_inner(
tile_head_idx * batch_size + tile_batch_idx
) * num_m_blocks + per_head_tile_idx
else:
tile_idx = (
tile_head_idx * batch_size
) # Output tile idx, 1 output tile per head per batch
tile_iter = tile_head_idx * tiles_per_head
if batch_size == 1:
req_size = tiles_per_head
if not RAGGED_BATCH:
group_size = tiles_per_head // batch_size
tile_batch_idx = (iter % tiles_per_head) // group_size
tile_idx = tile_head_idx * batch_size + tile_batch_idx
tile_iter = tile_head_idx * tiles_per_head + (tile_batch_idx * group_size)
tile_iter_end = tile_iter + group_size
else:
req_size = tl.load(batch_num_block_n)
tile_iter_end = tile_iter + req_size
for b in range(1, batch_size):
next_req_size = tl.load(batch_num_block_n + b)
local_head_iter = iter % tiles_per_head
if (local_head_iter < next_req_size) and (local_head_iter >= req_size):
tile_iter = tile_iter + req_size
tile_idx = tile_idx + b
tile_iter_end = tile_iter + (next_req_size - req_size)
req_size = next_req_size
tile_idx = (
tile_head_idx * batch_size
) # Output tile idx, 1 output tile per head per batch
tile_iter = tile_head_idx * tiles_per_head
if batch_size == 1:
req_size = tiles_per_head
else:
req_size = tl.load(batch_num_block_n)
tile_iter_end = tile_iter + req_size
for b in range(1, batch_size):
next_req_size = tl.load(batch_num_block_n + b)
local_head_iter = iter % tiles_per_head
if (local_head_iter < next_req_size) and (local_head_iter >= req_size):
tile_iter = tile_iter + req_size
tile_idx = tile_idx + b
tile_iter_end = tile_iter + (next_req_size - req_size)
req_size = next_req_size
# Local lean tile ID within a loop of an output tile
local_iter = iter - tile_iter
local_iter_end = tl.minimum(tile_iter_end, cta_end_tile_gid) - tile_iter
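To make the new scheduling easier to review, here is a minimal pure-Python mirror of the non-causal branch above. It is a sketch only, not part of the PR: iter_id stands in for the kernel's iter, and batch_num_block_n is treated as a plain list of cumulative block counts.

def map_iter_to_output_tile(iter_id, tile_head_idx, tiles_per_head,
                            batch_size, batch_num_block_n, ragged_batch):
    # Mirrors the non-causal tile scheduling of la_persistent_inner.
    if not ragged_batch:
        # Uniform batch: every (head, batch) output tile owns an equal slice.
        group_size = tiles_per_head // batch_size
        tile_batch_idx = (iter_id % tiles_per_head) // group_size
        tile_idx = tile_head_idx * batch_size + tile_batch_idx
        tile_iter = tile_head_idx * tiles_per_head + tile_batch_idx * group_size
        tile_iter_end = tile_iter + group_size
    else:
        # Ragged batch: per-batch extents come from cumulative block counts.
        tile_idx = tile_head_idx * batch_size
        tile_iter = tile_head_idx * tiles_per_head
        req_size = tiles_per_head if batch_size == 1 else batch_num_block_n[0]
        tile_iter_end = tile_iter + req_size
        for b in range(1, batch_size):
            next_req_size = batch_num_block_n[b]
            local_head_iter = iter_id % tiles_per_head
            if req_size <= local_head_iter < next_req_size:
                tile_iter = tile_iter + req_size
                tile_idx = tile_idx + b
                tile_iter_end = tile_iter + (next_req_size - req_size)
            req_size = next_req_size
    return tile_idx, tile_iter, tile_iter_end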
@@ -510,28 +520,52 @@ def la_persistent_inner(
offs_k = tl.arange(0, HEAD_DIM)
mask_k_cols = offs_k < HEAD_DIM_ORIG

if causal:
if causal or not RAGGED_BATCH:
# Prefill (causal) or non-ragged batch
b_seq_size = tile_batch_idx * num_n_blocks
else:
# Decode with RAGGED_BATCH
tile_batch_idx = tile_idx % batch_size
b_seq_size = 0
if tile_batch_idx > 0:
b_seq_size = tl.load(
batch_num_block_n + tile_batch_idx - 1
)  # Cumulative K/V block count of preceding batches

k_offs = (
(b_seq_size + local_iter) * BLOCK_N * stride_kn
+ tile_khead_idx_global * stride_kh
+ offs_n[None, :] * stride_kn
+ offs_k[:, None] * stride_kk
)
v_offs = (
(b_seq_size + local_iter) * BLOCK_N * stride_vn
+ tile_khead_idx_global * stride_vh
+ offs_n[:, None] * stride_vn
+ offs_k[None, :] * stride_vk
)
if use_64_indexing:
BLOCK_N64 = tl.full((), BLOCK_N, tl.int64)
stride_kn64 = tl.full((), stride_kn, tl.int64)
stride_vn64 = tl.full((), stride_vn, tl.int64)
stride_kh64 = tl.full((), stride_kh, tl.int64)
stride_vh64 = tl.full((), stride_vh, tl.int64)
stride_kk64 = tl.full((), stride_kk, tl.int64)
stride_vk64 = tl.full((), stride_vk, tl.int64)
bn64 = tl.full((), b_seq_size, tl.int64) + tl.full((), local_iter, tl.int64)
k_offs = (
(bn64 * BLOCK_N64) * stride_kn64
+ tl.full((), tile_khead_idx_global, tl.int64) * stride_kh64
+ offs_n[None, :] * stride_kn64
+ offs_k[:, None] * stride_kk64
)
v_offs = (
(bn64 * BLOCK_N64) * stride_vn64
+ tl.full((), tile_khead_idx_global, tl.int64) * stride_vh64
+ offs_n[:, None] * stride_vn64
+ offs_k[None, :] * stride_vk64
)
else:
k_offs = (
(b_seq_size + local_iter) * BLOCK_N * stride_kn
+ tile_khead_idx_global * stride_kh
+ offs_n[None, :] * stride_kn
+ offs_k[:, None] * stride_kk
)
v_offs = (
(b_seq_size + local_iter) * BLOCK_N * stride_vn
+ tile_khead_idx_global * stride_vh
+ offs_n[:, None] * stride_vn
+ offs_k[None, :] * stride_vk
)

k_ptrs = K + k_offs
k_ptrs = tl.multiple_of(k_ptrs, (16, 1))
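For context on the new use_64_indexing branches: each K/V element offset is roughly (number of K/V rows ahead of this tile) * row stride, and once that product can reach 2**31 it no longer fits Triton's default 32-bit pointer arithmetic; this is why the launcher compares stride-times-extent products against 1 << 31 and, when any of them is too large, the kernel promotes b_seq_size, BLOCK_N and the strides to tl.int64. A rough bound, using sizes that are assumptions rather than values from this PR:

# Illustrative arithmetic only; HEAD_DIM is an assumed value.
HEAD_DIM = 128                        # elements per K/V row
stride_kn = HEAD_DIM                  # row stride of a contiguous K tensor
int32_limit = 2**31 - 1
max_kv_rows = int32_limit // stride_kn
print(max_kv_rows)                    # 16777215 -> ~16.7M total K/V rows before
                                      # the 32-bit offsets above can overflow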
@@ -545,12 +579,27 @@ def la_persistent_inner(
q_idx = tile_batch_idx
q_start_m = 0

q_offs = (
q_idx * BLOCK_M * stride_qm
+ tile_head_idx_global * stride_qh
+ offs_m[:, None] * stride_qm
+ offs_k[None, :] * stride_qk
)
if use_64_indexing:
q_idx64 = tl.full((), q_idx, tl.int64)
BLOCK_M64 = tl.full((), BLOCK_M, tl.int64)
stride_qm64 = tl.full((), stride_qm, tl.int64)
stride_qk64 = tl.full((), stride_qk, tl.int64)
th64 = tl.full((), tile_head_idx_global, tl.int64) * tl.full(
(), stride_qh, tl.int64
)
q_offs = (
q_idx64 * BLOCK_M64 * stride_qm64
+ th64
+ offs_m[:, None] * stride_qm64
+ offs_k[None, :] * stride_qk64
)
else:
q_offs = (
q_idx * BLOCK_M * stride_qm
+ tile_head_idx_global * stride_qh
+ offs_m[:, None] * stride_qm
+ offs_k[None, :] * stride_qk
)
q_ptrs = Q + q_offs
q_ptrs = tl.multiple_of(q_ptrs, (1, 16))

@@ -594,12 +643,27 @@ def la_persistent_inner(
# Update pointers of partial results Mp[cta], Lp[cta], Op[cta]
mp_ptrs = Mp + current_pid * BLOCK_M + offs_m
lp_ptrs = Lp + current_pid * BLOCK_M + offs_m
op_ptrs = (
Op
+ current_pid * stride_oph # stride_oph is total_program dimension
+ offs_m[:, None] * stride_opm
+ offs_k[None, :] * stride_opn
)
if use_64_indexing:
current_pid64 = tl.full((), current_pid, tl.int64)
BLOCK_M64 = tl.full((), BLOCK_M, tl.int64)
stride_oph64 = tl.full((), stride_oph, tl.int64)
stride_opm64 = tl.full((), stride_opm, tl.int64)
stride_opn64 = tl.full((), stride_opn, tl.int64)
offs_m64 = tl.full([BLOCK_M], 0, tl.int64) + tl.cast(offs_m, tl.int64)
offs_k64 = tl.full([HEAD_DIM], 0, tl.int64) + tl.cast(offs_k, tl.int64)
op_ptrs = (
Op
+ current_pid64 * stride_oph64
+ offs_m64[:, None] * stride_opm64
+ offs_k64[None, :] * stride_opn64
)
else:
op_ptrs = (
Op
+ current_pid * stride_oph  # stride_oph strides the total_programs dimension
+ offs_m[:, None] * stride_opm
+ offs_k[None, :] * stride_opn
)

tl.store(mp_ptrs, m_i, cache_modifier=".wt")
tl.store(lp_ptrs, l_i, cache_modifier=".wt")
@@ -705,19 +769,41 @@ def la_persistent_inner(
offs_mplp = temp_pid * BLOCK_M + offs_m
mp_ptrs = Mp + offs_mplp
lp_ptrs = Lp + offs_mplp
op_ptrs0 = (
Op
+ temp_pid * stride_oph
+ offs_m[:, None] * stride_opm
+ tl.arange(0, HEAD_DIM // 2)[None, :] * stride_opn
)
op_ptrs1 = (
Op
+ temp_pid * stride_oph
+ offs_m[:, None] * stride_opm
+ (tl.arange(0, HEAD_DIM // 2)[None, :] + HEAD_DIM // 2)
* stride_opn
)
if use_64_indexing:
temp_pid64 = tl.full((), temp_pid, tl.int64)
stride_oph64 = tl.full((), stride_oph, tl.int64)
stride_opm64 = tl.full((), stride_opm, tl.int64)
stride_opn64 = tl.full((), stride_opn, tl.int64)
offs_m64 = tl.cast(offs_m, tl.int64)
offs0 = tl.arange(0, HEAD_DIM // 2)
offs0_64 = tl.cast(offs0, tl.int64)
offs1_64 = offs0_64 + tl.full((), HEAD_DIM // 2, tl.int64)
op_ptrs0 = (
Op
+ temp_pid64 * stride_oph64
+ offs_m64[:, None] * stride_opm64
+ offs0_64[None, :] * stride_opn64
)
op_ptrs1 = (
Op
+ temp_pid64 * stride_oph64
+ offs_m64[:, None] * stride_opm64
+ offs1_64[None, :] * stride_opn64
)
else:
op_ptrs0 = (
Op
+ temp_pid * stride_oph
+ offs_m[:, None] * stride_opm
+ tl.arange(0, HEAD_DIM // 2)[None, :] * stride_opn
)
op_ptrs1 = (
Op
+ temp_pid * stride_oph
+ offs_m[:, None] * stride_opm
+ (tl.arange(0, HEAD_DIM // 2)[None, :] + HEAD_DIM // 2)
* stride_opn
)

m_cta = tl.load(mp_ptrs, cache_modifier=".cv")
l_cta = tl.load(lp_ptrs, cache_modifier=".cv")
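The Mp/Lp/Op traffic in this region implements the usual streaming-softmax handoff: each non-host CTA writes its running max, running denominator and unnormalized accumulator, and the host CTA folds them into its own partials before the final normalization. As a reference, a generic version of that combine is sketched below in PyTorch; it illustrates the technique, not this kernel's exact code.

import torch

def combine_partials(m_a, l_a, acc_a, m_b, l_b, acc_b):
    # Log-sum-exp style merge of two softmax partial results:
    # running max m, denominator l, and unnormalized accumulator acc.
    m_new = torch.maximum(m_a, m_b)
    alpha = torch.exp(m_a - m_new)   # rescale factor for partial a
    beta = torch.exp(m_b - m_new)    # rescale factor for partial b
    l_new = alpha * l_a + beta * l_b
    acc_new = alpha[:, None] * acc_a + beta[:, None] * acc_b
    return m_new, l_new, acc_new     # acc_new / l_new is applied at the very end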
@@ -744,20 +830,47 @@ def la_persistent_inner(
# host CTA write final result to memory
# acc = acc / l_i[:, None]
# tl.store(o_ptrs, acc.to(Out.type.element_ty))
o_ptrs0 = (
Out
+ q_idx * BLOCK_M * stride_om
+ tile_head_idx_global * stride_oh
+ offs_m[:, None] * stride_om
+ tl.arange(0, HEAD_DIM // 2)[None, :] * stride_on
)
o_ptrs1 = (
Out
+ q_idx * BLOCK_M * stride_om
+ tile_head_idx_global * stride_oh
+ offs_m[:, None] * stride_om
+ (tl.arange(0, HEAD_DIM // 2)[None, :] + HEAD_DIM // 2) * stride_on
)
if use_64_indexing:
q_idx64 = tl.full((), q_idx, tl.int64)
BLOCK_M64 = tl.full((), BLOCK_M, tl.int64)
stride_om64 = tl.full((), stride_om, tl.int64)
stride_on64 = tl.full((), stride_on, tl.int64)
th64 = tl.full((), tile_head_idx_global, tl.int64) * tl.full(
(), stride_oh, tl.int64
)
offs0 = tl.arange(0, HEAD_DIM // 2)
offs0_64 = tl.cast(offs0, tl.int64)
offs1_64 = offs0_64 + tl.full((), HEAD_DIM // 2, tl.int64)

o_ptrs0 = (
Out
+ q_idx64 * BLOCK_M64 * stride_om64
+ th64
+ offs_m[:, None] * stride_om64
+ offs0_64[None, :] * stride_on64
)
o_ptrs1 = (
Out
+ q_idx64 * BLOCK_M64 * stride_om64
+ th64
+ offs_m[:, None] * stride_om64
+ offs1_64[None, :] * stride_on64
)
else:
o_ptrs0 = (
Out
+ q_idx * BLOCK_M * stride_om
+ tile_head_idx_global * stride_oh
+ offs_m[:, None] * stride_om
+ tl.arange(0, HEAD_DIM // 2)[None, :] * stride_on
)
o_ptrs1 = (
Out
+ q_idx * BLOCK_M * stride_om
+ tile_head_idx_global * stride_oh
+ offs_m[:, None] * stride_om
+ (tl.arange(0, HEAD_DIM // 2)[None, :] + HEAD_DIM // 2) * stride_on
)

acc0 = acc0 / l_i[:, None]
acc1 = acc1 / l_i[:, None]
11 changes: 9 additions & 2 deletions aiter/ops/triton/lean_atten.py
@@ -20,6 +20,7 @@
import torch
from typing import Optional
from bisect import bisect_right
import math
import triton
import triton.language as tl
from aiter.ops.triton._triton_kernels.lean_atten import la_persistent, _get_config
@@ -45,6 +46,7 @@ def persistent_lean_attention(
batch_size: int,
sm_scale: torch.float16,
causal: bool = True, # causal masking
RAGGED_BATCH: bool = False,
config: Optional[dict] = None,
program_count: Optional[int] = None,
):
@@ -79,6 +81,7 @@ def persistent_lean_attention(
causal=causal,
batch_size=batch_size,
sm_scale=sm_scale,
RAGGED_BATCH=RAGGED_BATCH,
num_warps=config["num_warps"],
waves_per_eu=config["waves_per_eu"],
config=config,
@@ -102,6 +105,7 @@ def _persistent_lean_attention(
causal: bool, # causal masking
batch_size: int,
sm_scale: torch.float16, # typically 1 / sqrt(d)
RAGGED_BATCH: bool,
num_warps: int,
waves_per_eu: int,
config: dict = {},
@@ -187,6 +191,9 @@ def _persistent_lean_attention(
MASKED_BLOCKS=MASKED_BLOCKS,
MODE=CAUSAL_MODE,
)
if not causal:
max_output_tile_cnt = math.ceil((H * batch_size) / total_programs) + 4

if DEBUG:
print(f"max_output_tile_cnt={max_output_tile_cnt}")

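A quick check of the new non-causal bound, with a head count, batch size and program count that are assumptions for illustration only: 64 heads and a batch of 8 give 512 output tiles, so with 304 programs each program finalizes at most ceil(512 / 304) = 2 of them, and the +4 safety margin yields max_output_tile_cnt = 6.

import math
H, batch_size, total_programs = 64, 8, 304   # assumed sizes, not from this PR
print(math.ceil((H * batch_size) / total_programs) + 4)   # -> 6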
@@ -243,8 +250,6 @@ def _persistent_lean_attention(
f"locks must have length >= total_programs ({total_programs}), got {locks.numel()}"
)

max_output_tile_cnt = max_output_tile_cnt + 4

grid = (total_programs, 1, 1)

o = torch.empty_like(q, dtype=v.dtype)
@@ -321,7 +326,9 @@ def _persistent_lean_attention(
or (Op.stride(0) * total_programs) >= (1 << 31)
or (Op.stride(1) * N_CTX_Q) >= (1 << 31)
or (o.stride(0) * N_CTX_Q) >= (1 << 31)
or (q.stride(0) * N_CTX_Q) >= (1 << 31)
),
RAGGED_BATCH=RAGGED_BATCH,
**config,
)
"""