47 changes: 21 additions & 26 deletions aiter/ops/triton/_triton_kernels/lean_atten.py
@@ -28,15 +28,11 @@
from ..utils.core import AITER_TRITON_CONFIGS_PATH


LOG_TWO_E = 1.44269504 # log_2(e) value for softmax scaling
# Support tensor in [B, Seqlen, H, d] format. Taking tensors in [B*Seqlen, H, d] as inputs


@functools.lru_cache(maxsize=1024)
def _get_config(
causal: bool,
batch_size: int,
):
def _get_config():
if not hasattr(_get_config, "_config_dict"):
dev = arch_info.get_device()
fpath = f"{AITER_TRITON_CONFIGS_PATH}/{dev}-LEANATTN-DEFAULT.json"
@@ -106,29 +102,34 @@ def _attention_inner(
m_i,
l_i,
acc,
qk_scale,
causal,
q_start_m,
b_seq_size,
offs_m,
offs_n,
BLOCK_M,
BLOCK_N,
HEAD_DIM_ORIG: tl.constexpr,
HEAD_DIM: tl.constexpr,
local_iter,
local_iter_end,
SM_SCALE: tl.constexpr,
causal: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
HEAD_DIM_ORIG: tl.constexpr,
HEAD_DIM: tl.constexpr,
use_64_indexing: tl.constexpr,
):
"""
Performs the attention calculation for a (possibly partial) output tile
"""
RCP_LN2: tl.constexpr = 1.4426950408889634

# Define head-dimension mask for padded dims
offs_k_local = tl.arange(0, HEAD_DIM)
mask_k_cols_local = offs_k_local < HEAD_DIM_ORIG
for l_iter in range(local_iter, local_iter_end):
k = tl.load(k_ptrs, mask=mask_k_cols_local[:, None], other=0.0)
qk = tl.dot(q, k) * qk_scale
qk_scale = SM_SCALE * RCP_LN2
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
qk = qk * qk_scale

if causal:
# Get the starting column index of the current K block
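
A note on the scaling change in this hunk: instead of receiving a pre-multiplied `qk_scale`, the kernel now takes `SM_SCALE` and folds in `RCP_LN2` (log2(e)) itself, so the softmax can be evaluated with base-2 exponentials. A minimal PyTorch sketch of the identity this relies on (names and shapes here are illustrative, not part of the kernel):

```python
import torch

# exp(x) == 2 ** (x * log2(e)), so folding log2(e) into the softmax scale
# lets a kernel use exp2 while producing the same probabilities.
RCP_LN2 = 1.4426950408889634  # log2(e)

def softmax_base_e(scores: torch.Tensor, sm_scale: float) -> torch.Tensor:
    x = scores * sm_scale
    x = x - x.max(dim=-1, keepdim=True).values
    p = torch.exp(x)
    return p / p.sum(dim=-1, keepdim=True)

def softmax_base_2(scores: torch.Tensor, sm_scale: float) -> torch.Tensor:
    qk_scale = sm_scale * RCP_LN2  # same folding as SM_SCALE * RCP_LN2 above
    x = scores * qk_scale
    x = x - x.max(dim=-1, keepdim=True).values
    p = torch.exp2(x)
    return p / p.sum(dim=-1, keepdim=True)

scores = torch.randn(4, 8)
assert torch.allclose(softmax_base_e(scores, 0.125), softmax_base_2(scores, 0.125), atol=1e-6)
```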
@@ -215,7 +216,6 @@ def la_persistent(
Q,
K,
V,
qk_scale,
Mp,
Lp,
Op,
@@ -238,6 +238,7 @@
stride_oph, # total_programs
stride_opm, # n_ctx_q
stride_opn, # head_dim
SM_SCALE,
HEADS_PER_XCD: tl.constexpr,
HEAD_DIM_ORIG: tl.constexpr,
HEAD_DIM: tl.constexpr,
@@ -257,8 +258,6 @@
tiles_per_head: tl.constexpr,
num_splits: tl.constexpr,
max_output_tile_cnt: tl.constexpr,
num_heads_q: tl.constexpr,
num_heads_k: tl.constexpr,
gqa_group_size: tl.constexpr,
use_64_indexing: tl.constexpr,
RAGGED_BATCH: tl.constexpr,
@@ -311,7 +310,6 @@ def la_persistent(
Q,
K,
V,
qk_scale,
Mp,
Lp,
Op,
@@ -338,14 +336,13 @@
current_pid=current_pid,
xcd_pid=xcd_pid,
xcd_id=xcd_id,
SM_SCALE=SM_SCALE,
HEADS_PER_XCD=HEADS_PER_XCD,
HEAD_DIM=HEAD_DIM,
HEAD_DIM_ORIG=HEAD_DIM_ORIG,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
MASKED_BLOCKS=MASKED_BLOCKS,
XCD_REMAP=XCD_REMAP,
NUM_XCDS=NUM_XCDS,
batch_size=batch_size,
causal=causal,
num_m_blocks=num_m_blocks,
@@ -366,7 +363,6 @@ def la_persistent_inner(
Q,
K,
V,
qk_scale,
Mp,
Lp,
Op,
@@ -393,14 +389,13 @@
current_pid, # SOC pid
xcd_pid, # XCD pid
xcd_id, # The XCD the pid belongs to
HEADS_PER_XCD,
SM_SCALE,
HEADS_PER_XCD: tl.constexpr,
HEAD_DIM: tl.constexpr,
HEAD_DIM_ORIG: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
MASKED_BLOCKS: tl.constexpr,
XCD_REMAP: tl.constexpr,
NUM_XCDS: tl.constexpr,
batch_size: tl.constexpr,
causal: tl.constexpr,
num_m_blocks: tl.constexpr,
@@ -618,18 +613,18 @@ def la_persistent_inner(
m_i,
l_i,
acc,
qk_scale,
causal,
q_start_m,
b_seq_size,
offs_m,
offs_n,
local_iter,
local_iter_end,
SM_SCALE,
causal,
BLOCK_M,
BLOCK_N,
HEAD_DIM_ORIG=HEAD_DIM_ORIG,
HEAD_DIM=HEAD_DIM,
local_iter=local_iter,
local_iter_end=local_iter_end,
use_64_indexing=use_64_indexing,
)

13 changes: 2 additions & 11 deletions aiter/ops/triton/lean_atten.py
@@ -30,7 +30,6 @@

_LOGGER = AiterTritonLogger()

LOG_TWO_E = 1.44269504 # log_2(e) value for softmax scaling
# Support tensor in [B, Seqlen, H, d] format. Taking tensors in [B*Seqlen, H, d] as inputs


@@ -80,7 +79,6 @@ def persistent_lean_attention(
XCD_REMAP=config["XCD_REMAP"],
causal=causal,
batch_size=batch_size,
sm_scale=sm_scale,
RAGGED_BATCH=RAGGED_BATCH,
num_warps=config["num_warps"],
waves_per_eu=config["waves_per_eu"],
@@ -104,7 +102,6 @@ def _persistent_lean_attention(
XCD_REMAP: bool, # xcd_remap for spatial
causal: bool, # causal masking
batch_size: int,
sm_scale: torch.float16, # typically 1 / sqrt(d)
RAGGED_BATCH: bool,
num_warps: int,
waves_per_eu: int,
@@ -144,7 +141,7 @@ def _persistent_lean_attention(
GQA_GROUP_SIZE = H // H_K
HEADS_PER_XCD = H // NUM_XCDS

qk_scale = sm_scale * LOG_TWO_E
sm_scale = q.shape[-1] ** (-0.5)

(
num_m_blocks,
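
Since `sm_scale` is no longer a user-supplied argument, the scale is now derived from the head dimension inside `_persistent_lean_attention`. The changed line assumes the standard 1/sqrt(head_dim) softmax scaling; a quick illustrative check (tensor shapes here are made up for the example):

```python
import math
import torch

# Conventional attention scaling: multiply Q·K^T by head_dim ** -0.5.
q = torch.randn(2, 8, 128, 64)            # [batch, heads, seq_q, head_dim]
k = torch.randn(2, 8, 128, 64)            # [batch, heads, seq_k, head_dim]
sm_scale = q.shape[-1] ** (-0.5)          # 64 ** -0.5 == 0.125
assert math.isclose(sm_scale, 1.0 / math.sqrt(q.shape[-1]))
scores = torch.matmul(q, k.transpose(-2, -1)) * sm_scale   # [2, 8, 128, 128]
```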
@@ -161,7 +158,6 @@ def _persistent_lean_attention(
N_CTX_Q,
N_CTX_K,
H,
H_K,
BLOCK_M,
BLOCK_N,
total_programs,
@@ -219,7 +215,6 @@ def _persistent_lean_attention(
N_CTX_Q,
N_CTX_K,
H,
H_K,
BLOCK_M,
BLOCK_N,
total_programs,
@@ -271,7 +266,6 @@ def _persistent_lean_attention(
q,
k,
v,
qk_scale,
Mp,
Lp,
Op,
@@ -294,6 +288,7 @@
Op.stride(0), # total_programs
Op.stride(1), # n_ctx_q
Op.stride(2), # head_dim
sm_scale,
HEADS_PER_XCD=HEADS_PER_XCD,
HEAD_DIM_ORIG=HEAD_DIM_K,
HEAD_DIM=HEAD_DIM_K,
@@ -317,8 +312,6 @@ def _persistent_lean_attention(
num_warps=num_warps,
num_stages=1,
num_ctas=1,
num_heads_q=H,
num_heads_k=H_K,
gqa_group_size=GQA_GROUP_SIZE,
use_64_indexing=(
(k.stride(0) * N_CTX_K) >= (1 << 31)
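
For reference, the `use_64_indexing` guard computed here switches the kernel to 64-bit offsets once the largest batched-sequence offset into `k` could exceed the int32 range (the full expression is truncated in this view). A rough standalone sketch of the k-side check, with a hypothetical helper name and an assumed [B*Seqlen, H_K, d] layout:

```python
import torch

def needs_64bit_indexing(t: torch.Tensor, n_ctx: int) -> bool:
    # Largest flat element offset reachable along the batched sequence
    # dimension; 32-bit signed offsets overflow at 2**31 elements.
    return (t.stride(0) * n_ctx) >= (1 << 31)

# Example: a K tensor whose offsets stay well within int32 range.
k = torch.empty(8192, 8, 128, dtype=torch.float16)
print(needs_64bit_indexing(k, n_ctx=8192))  # False: 8192 * 1024 << 2**31
```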
@@ -350,7 +343,6 @@ def get_num_splits_and_buffer_sizes(
max_seqlen_q,
max_seqlen_k,
num_heads,
num_heads_k,
BLOCK_M,
BLOCK_N,
num_SMs,
@@ -370,7 +362,6 @@ def get_num_splits_and_buffer_sizes(
# print(f"block_m: {BLOCK_M}, block_n: {BLOCK_N} ")
# print(f"num_m_block: {num_m_blocks}, num_n_block: {num_n_blocks} ")
# print(f"max_seqlen_q: {max_seqlen_q}, max_seqlen_k: {max_seqlen_k}")
# print(f"num_heads: {num_heads}, num_heads_k: {num_heads_k} ")

if max_seqlen_q == 1:
causal = False
3 changes: 0 additions & 3 deletions op_tests/op_benchmarks/triton/bench_la.py
@@ -340,8 +340,6 @@ def bench_lean_attention(
list_sum_block_n.append(len_sum)
batch_num_block_n = torch.tensor(list_sum_block_n, device="cuda", dtype=torch.int32)

sm_scale = 0.5

# Allocate Tensors
q = torch.empty((n_ctx_q * batch, hq, d), dtype=init_dtype, device="cuda").normal_(
mean=0.0, std=0.5
@@ -377,7 +375,6 @@ def bench_lean_attention(
XCD_REMAP,
causal,
batch,
sm_scale,
RAGGED_BATCH,
num_warps,
waves_per_eu,