35 changes: 27 additions & 8 deletions aiter/ops/triton/_triton_kernels/extend_attention.py
@@ -17,23 +17,43 @@
It supports page size = 1 and prefill with KV cache (i.e. extend).
"""

from typing import Optional
import functools
import json
import torch
import triton
import triton.language as tl


# from .prefill_attention import context_attention_fwd
from .activation import _tanh
from ..utils._triton.pid_preprocessing import pid_grid, remap_xcd
from ..utils._triton.pid_preprocessing import remap_xcd
from ..utils._triton import arch_info
from ..utils.core import AITER_TRITON_CONFIGS_PATH
from ..utils.device_info import get_num_xcds


@triton.jit
from ..utils._triton.kernel_repr import make_kernel_repr


_fwd_kernel_extend_repr = make_kernel_repr(
"_fwd_kernel",
[
"logit_cap",
"Lq",
"Lv",
"BLOCK_DMODEL",
"BLOCK_DPE",
"BLOCK_DV",
"BLOCK_M",
"BLOCK_N",
"USE_CUSTOM_MASK",
"IS_CAUSAL",
"SKIP_PREFIX_CUSTOM_MASK",
"STORE_TRANSPOSE",
"NUM_Q_HEADS",
"NUM_BLOCKS",
"NUM_XCDS",
],
)


@triton.jit(repr=_fwd_kernel_extend_repr)
def _fwd_kernel(
Q_Extend,
K_Extend,
@@ -74,7 +94,6 @@ def _fwd_kernel(
STORE_TRANSPOSE: tl.constexpr,
NUM_Q_HEADS: tl.constexpr,
NUM_BLOCKS: tl.constexpr,
BATCH: tl.constexpr,
NUM_XCDS: tl.constexpr,
):
workgroup_id = tl.program_id(0) # workgroup index
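Note: `make_kernel_repr`, imported above from `..utils._triton.kernel_repr`, builds the callable that `triton.jit(repr=...)` invokes to name each compiled specialization after its constexpr values, so cache entries and profiler traces identify the exact configuration. The helper itself is not part of this diff; below is a minimal sketch under the assumption that Triton passes the repr callable an object exposing a `constants` dict keyed by parameter name (the real helper may differ):

from typing import Callable, Sequence


def make_kernel_repr(base_name: str, constexpr_names: Sequence[str]) -> Callable[[object], str]:
    # Hypothetical sketch; the actual implementation lives in
    # aiter/ops/triton/utils/_triton/kernel_repr.py and may differ.
    def _repr(specialization: object) -> str:
        # Assumption: the specialization exposes a ``constants`` mapping
        # from constexpr parameter names to their specialized values.
        constants = getattr(specialization, "constants", {})
        parts = [f"{name}={constants[name]}" for name in constexpr_names if name in constants]
        # Fall back to the bare kernel name if nothing was specialized.
        return f"{base_name}_{'_'.join(parts)}" if parts else base_name

    return _repr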
48 changes: 42 additions & 6 deletions aiter/ops/triton/_triton_kernels/moe_align_block_size.py
@@ -3,9 +3,45 @@

import triton
import triton.language as tl


@triton.jit
from ..utils._triton.kernel_repr import make_kernel_repr


_moe_align_block_size_stage1_repr = make_kernel_repr(
"_moe_align_block_size_stage1_kernel",
[
"num_experts",
"numel",
"tokens_per_thread",
],
)

_moe_align_block_size_stage2_repr = make_kernel_repr(
"_moe_align_block_size_stage2_kernel",
[
"num_experts",
],
)

_moe_align_block_size_stage3_repr = make_kernel_repr(
"_moe_align_block_size_stage3_kernel",
[
"num_experts",
"block_size",
],
)

_moe_align_block_size_stage4_repr = make_kernel_repr(
"_moe_align_block_size_stage4_kernel",
[
"num_experts",
"block_size",
"numel",
"tokens_per_thread",
],
)


@triton.jit(repr=_moe_align_block_size_stage1_repr)
def _moe_align_block_size_stage1_kernel(
topk_ids_ptr,
tokens_cnts_ptr,
@@ -26,7 +62,7 @@ def _moe_align_block_size_stage1_kernel(
tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1)


@triton.jit
@triton.jit(repr=_moe_align_block_size_stage2_repr)
def _moe_align_block_size_stage2_kernel(
tokens_cnts_ptr,
num_experts: tl.constexpr,
@@ -40,7 +76,7 @@ def _moe_align_block_size_stage2_kernel(
tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt)


@triton.jit
@triton.jit(repr=_moe_align_block_size_stage3_repr)
def _moe_align_block_size_stage3_kernel(
total_tokens_post_pad_ptr,
tokens_cnts_ptr,
@@ -57,7 +93,7 @@ def _moe_align_block_size_stage3_kernel(
tl.store(total_tokens_post_pad_ptr, last_cumsum)


@triton.jit
@triton.jit(repr=_moe_align_block_size_stage4_repr)
def _moe_align_block_size_stage4_kernel(
topk_ids_ptr,
sorted_token_ids_ptr,
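With a repr attached to each stage kernel, the constexpr configuration shows up directly in compiled-kernel names (in caches, logs, and profiler output). A hypothetical round trip through the sketch above, using a stand-in for the specialization object Triton would pass:

from types import SimpleNamespace

# Stand-in for the specialization object Triton would hand to the repr;
# the attribute layout is an assumption carried over from the sketch above.
spec = SimpleNamespace(constants={"num_experts": 8, "numel": 4096, "tokens_per_thread": 16})

stage1_repr = make_kernel_repr(
    "_moe_align_block_size_stage1_kernel",
    ["num_experts", "numel", "tokens_per_thread"],
)
print(stage1_repr(spec))
# _moe_align_block_size_stage1_kernel_num_experts=8_numel=4096_tokens_per_thread=16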
95 changes: 88 additions & 7 deletions aiter/ops/triton/_triton_kernels/moe_op.py
@@ -5,18 +5,102 @@
import triton.language as tl
from ..utils._triton.pid_preprocessing import pid_grid, remap_xcd
from ..utils._triton.moe_common import _write_zeros_to_output
from ..utils._triton.kernel_repr import make_kernel_repr


# Source:
# MoE Kernel adapted from VLLM


_fused_moe_kernel_gptq_awq_repr = make_kernel_repr(
"_fused_moe_kernel_gptq_awq",
[
"N",
"K",
"group_size",
"BLOCK_SIZE_M",
"BLOCK_SIZE_N",
"BLOCK_SIZE_K",
"GROUP_SIZE_M",
"EVEN_K",
"MUL_ROUTED_WEIGHT",
"top_k",
"compute_type",
"has_zp",
"use_int4_w4a16",
"use_int8_w8a16",
"NUM_XCDS",
],
)

_fused_moe_persistent_kernel_gptq_awq_repr = make_kernel_repr(
"_fused_moe_persistent_kernel_gptq_awq",
[
"N",
"K",
"group_size",
"BLOCK_SIZE_M",
"BLOCK_SIZE_N",
"BLOCK_SIZE_K",
"GROUP_SIZE_M",
"EVEN_K",
"NUM_SMS",
"MUL_ROUTED_WEIGHT",
"top_k",
"compute_type",
"has_zp",
"use_int4_w4a16",
"use_int8_w8a16",
"NUM_XCDS",
],
)

_fused_moe_kernel_repr = make_kernel_repr(
"_fused_moe_kernel",
[
"group_n",
"group_k",
"BLOCK_SIZE_M",
"BLOCK_SIZE_N",
"BLOCK_SIZE_K",
"GROUP_SIZE_M",
"EVEN_K",
"MUL_ROUTED_WEIGHT",
"top_k",
"compute_type",
"use_fp8_w8a8",
"use_int8_w8a16",
"NUM_XCDS",
],
)

_fused_moe_persistent_kernel_repr = make_kernel_repr(
"_fused_moe_persistent_kernel",
[
"group_n",
"group_k",
"BLOCK_SIZE_M",
"BLOCK_SIZE_N",
"BLOCK_SIZE_K",
"GROUP_SIZE_M",
"EVEN_K",
"NUM_SMS",
"MUL_ROUTED_WEIGHT",
"top_k",
"compute_type",
"use_fp8_w8a8",
"use_int8_w8a16",
"NUM_XCDS",
],
)


@triton.heuristics(
{
"EVEN_K": lambda args: args["K"] % args["BLOCK_SIZE_K"] == 0,
}
)
@triton.jit
@triton.jit(repr=_fused_moe_kernel_gptq_awq_repr)
def _fused_moe_kernel_gptq_awq(
# Pointers to matrices
a_ptr,
@@ -254,7 +338,7 @@ def _fused_moe_kernel_gptq_awq(
"EVEN_K": lambda args: args["K"] % args["BLOCK_SIZE_K"] == 0,
}
)
@triton.jit
@triton.jit(repr=_fused_moe_persistent_kernel_gptq_awq_repr)
def _fused_moe_persistent_kernel_gptq_awq(
# Pointers to matrices
a_ptr,
@@ -269,7 +353,6 @@ def _fused_moe_persistent_kernel_gptq_awq(
# Matrix dimensions
N: tl.constexpr,
K: tl.constexpr,
EM,
num_valid_tokens,
# The stride variables represent how much to increase the ptr by when
# moving by 1 element in a particular dimension. E.g. `stride_am` is
@@ -483,7 +566,7 @@ def _fused_moe_persistent_kernel_gptq_awq(
"EVEN_K": lambda args: args["K"] % args["BLOCK_SIZE_K"] == 0,
}
)
@triton.jit
@triton.jit(repr=_fused_moe_kernel_repr)
def _fused_moe_kernel(
# Pointers to matrices
a_ptr,
@@ -498,7 +581,6 @@ def _fused_moe_kernel(
# Matrix dimensions
N,
K,
EM,
num_valid_tokens,
# The stride variables represent how much to increase the ptr by when
# moving by 1 element in a particular dimension. E.g. `stride_am` is
@@ -691,7 +773,7 @@ def _fused_moe_kernel(
"EVEN_K": lambda args: args["K"] % args["BLOCK_SIZE_K"] == 0,
}
)
@triton.jit
@triton.jit(repr=_fused_moe_persistent_kernel_repr)
def _fused_moe_persistent_kernel(
# Pointers to matrices
a_ptr,
@@ -706,7 +788,6 @@ def _fused_moe_persistent_kernel(
# Matrix dimensions
N,
K,
EM,
num_valid_tokens,
# The stride variables represent how much to increase the ptr by when
# moving by 1 element in a particular dimension. E.g. `stride_am` is
51 changes: 47 additions & 4 deletions aiter/ops/triton/_triton_kernels/moe_op_e2e.py
@@ -5,18 +5,63 @@
import triton.language as tl

from ..utils._triton.pid_preprocessing import pid_grid, remap_xcd
from ..utils._triton.kernel_repr import make_kernel_repr

# Source:
# MoE Kernel adapted from VLLM


_e2e_moe_kernel_repr = make_kernel_repr(
"e2e_moe_kernel",
[
"top_k",
"EM",
"N",
"K",
"EVEN_K",
"MUL_ROUTED_WEIGHT",
"use_fp8_w8a8",
"use_int8_w8a16",
"BLOCK_SIZE_M",
"BLOCK_SIZE_N",
"BLOCK_SIZE_K1",
"BLOCK_SIZE_K2",
"GROUP_SIZE_M",
"GRID_MN",
"atomic_num_stages",
"dtype",
"NUM_XCDS",
],
)

_e2e_moe_persistent_kernel_repr = make_kernel_repr(
"e2e_moe_persistent_kernel",
[
"top_k",
"N",
"K",
"EVEN_K",
"EVEN_N",
"MUL_ROUTED_WEIGHT",
"use_fp8_w8a8",
"use_int8_w8a16",
"BLOCK_SIZE_M",
"BLOCK_SIZE_N1",
"BLOCK_SIZE_N2",
"BLOCK_SIZE_K1",
"BLOCK_SIZE_K2",
"NUM_SMS",
],
)


@triton.heuristics(
{
"GRID_MN": lambda args: triton.cdiv(args["EM"], args["BLOCK_SIZE_M"])
* triton.cdiv(args["N"], args["BLOCK_SIZE_N"])
}
)
@triton.jit
@triton.jit(repr=_e2e_moe_kernel_repr)
def e2e_moe_kernel(
A,
W1,
@@ -316,7 +361,7 @@ def e2e_moe_kernel(
# tl.store(out_ptrs + k * BLOCK_SIZE_K2, out, mask=c_mask)


@triton.jit
@triton.jit(repr=_e2e_moe_persistent_kernel_repr)
def e2e_moe_persistent_kernel(
A,
W1,
@@ -346,7 +391,6 @@ def e2e_moe_persistent_kernel(
expert_ids_ptr,
num_tokens_post_padded_ptr,
num_valid_tokens,
EM: tl.constexpr,
N: tl.constexpr,
K: tl.constexpr,
EVEN_K: tl.constexpr,
Expand All @@ -360,7 +404,6 @@ def e2e_moe_persistent_kernel(
BLOCK_SIZE_K1: tl.constexpr, # original block_size_k
BLOCK_SIZE_K2: tl.constexpr, # outputs (EM, BLOCK_SIZE_K2)
NUM_SMS: tl.constexpr,
NUM_XCDS: tl.constexpr,
):
start_m = tl.program_id(axis=0)
num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)