49 changes: 44 additions & 5 deletions aiter/fused_moe.py
@@ -218,7 +218,13 @@ def fused_moe_(
quant_type = quant_remap.get(quant_type, quant_type)
q_dtype_w = w1.dtype
q_dtype_a = w1.dtype if w1.dtype != torch.uint32 else dtypes.fp8
q_dtype_a = dtypes.fp4x2 if quant_type == QuantType.per_1x32 else q_dtype_a
bf16_fp8_bound = 512
if quant_type == QuantType.per_1x32 and M < bf16_fp8_bound:
q_dtype_a = dtypes.bf16
elif quant_type == QuantType.per_1x32 and M >= bf16_fp8_bound:
q_dtype_a = dtypes.fp8
elif quant_type == QuantType.per_1x32:
q_dtype_a = dtypes.fp4x2

metadata = get_2stage_cfgs(
get_padded_M(M), # consider token_num > 1024 as prefill
@@ -729,6 +735,13 @@ def use_cfg():
logger.info(
f"[fused_moe] using {'1stage' if run_1stage else '2stage'} {'default' if cfg is None else tag} for {keys} "
)

def get_block_m() -> int:
if q_dtype_a == dtypes.fp8:
return 32
else:
return 16 if token < 2048 else 32 if token < 16384 else 64

if run_1stage:
return MOEMetadata(
functools.partial(
@@ -760,7 +773,7 @@ def use_cfg():
k_pad_zeros=intermediate_pad // 128 * 128,
bias2=bias2,
),
16 if token < 2048 else 32 if token < 16384 else 64,
get_block_m(),
ksplit,
False,
)
@@ -877,11 +890,23 @@ def fused_moe_2stages(
if (
quant_type == QuantType.per_1x32
and dtype in [dtypes.bf16, dtypes.fp16]
and q_dtype_a in [dtypes.bf16, dtypes.fp16]
and w1.dtype == dtypes.fp4x2
and activation == ActivationType.Swiglu
):
a1 = hidden_states.to(dtype)
a1_scale = None
elif (
quant_type == aiter.QuantType.per_1x32
and dtype in [dtypes.bf16, dtypes.fp16]
and q_dtype_a == dtypes.fp8
and w1.dtype == dtypes.fp4x2
and activation == aiter.ActivationType.Swiglu
):
a1 = hidden_states.to(dtypes.fp8)
M = sorted_ids.shape[0]
N = a1.shape[-1]
a1_scale = torch.ones([M, N // 32], dtype=dtypes.fp8_e8m0, device=a1.device)
elif quant_type == QuantType.per_1x32:
if token_num <= token_num_quant_moe_sort_switch:
a1, a1_scale = fused_dynamic_mxfp4_quant_moe_sort(
@@ -945,17 +970,29 @@ def fused_moe_2stages(
topk,
block_m=block_size_M,
a1_scale=a1_scale,
w1_scale=w1_scale,
w1_scale=(
w1_scale.view(dtypes.fp8_e8m0) if w1.dtype == dtypes.fp4x2 else w1_scale
),
sorted_weights=sorted_weights if doweight_stage1 else None,
)

if (
quant_type == QuantType.per_1x32
and dtype in [dtypes.bf16, dtypes.fp16]
and q_dtype_a in [dtypes.bf16, dtypes.fp16]
and w1.dtype == dtypes.fp4x2
and activation == ActivationType.Swiglu
):
a2_scale = None
elif (
quant_type == aiter.QuantType.per_1x32
and dtype in [dtypes.bf16]
and q_dtype_a == dtypes.fp8
and w1.dtype == dtypes.fp4x2
and activation == aiter.ActivationType.Swiglu
):
a2 = a2.to(dtypes.fp8)
a2_scale = a1_scale
elif quant_type == QuantType.per_1x32:
a2 = a2.view(-1, inter_dim)
if token_num <= token_num_quant_moe_sort_switch:
@@ -1011,7 +1048,9 @@ def fused_moe_2stages(
num_valid_ids,
moe_out,
topk,
w2_scale=w2_scale,
w2_scale=(
w2_scale.view(dtypes.fp8_e8m0) if w2.dtype == dtypes.fp4x2 else w2_scale
),
a2_scale=a2_scale,
block_m=block_size_M,
sorted_weights=sorted_weights if not doweight_stage1 else None,
@@ -1431,7 +1470,7 @@ def cktile_moe_stage1(
if w1.dtype is torch.uint32:
D = D * 8
out = torch.empty(
(token_num, topk, D), dtype=hidden_states.dtype, device=hidden_states.device
(token_num, topk, D), dtype=dtypes.bf16, device=hidden_states.device
)
# print("Run cktile_moe_stage1: M=%d, N(N*2)=%d, K=%d, topk=%d, expert=%d"%(token_num, w1.shape[1], hidden_states.shape[1], topk, w1.shape[0]))
aiter.moe_cktile2stages_gemm1(
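Note on the new activation-dtype selection in `fused_moe_`: for `per_1x32` (MXFP4) weights, stage-1 activations stay in bf16 below `bf16_fp8_bound = 512` tokens and switch to fp8 at or above it. The sketch below is an illustrative restatement in plain Python, not aiter code; the helper name and string dtype labels are made up for clarity.

```python
# Illustrative only: a standalone restatement of the q_dtype_a selection above,
# using plain strings in place of aiter's dtypes module.
def select_activation_dtype(quant_type: str, m: int, bf16_fp8_bound: int = 512) -> str:
    """Pick the stage-1 activation dtype for per_1x32 (MXFP4) weights.

    Small batches keep bf16 activations; larger batches quantize activations
    to fp8 so the fp8-activation x fp4-weight GEMM path can be used.
    """
    if quant_type != "per_1x32":
        raise ValueError("this helper only models the per_1x32 branch")
    return "bf16" if m < bf16_fp8_bound else "fp8"

# Example: 256 tokens stay in bf16, 1024 tokens are quantized to fp8.
assert select_activation_dtype("per_1x32", 256) == "bf16"
assert select_activation_dtype("per_1x32", 1024) == "fp8"
```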
37 changes: 37 additions & 0 deletions aiter/ops/quant.py
@@ -101,6 +101,43 @@ def per_1x32_f4_quant(x, scale=None, quant_dtype=dtypes.fp4x2, shuffle=False):
return y, scale.view(dtypes.fp8_e8m0)


def per_1x32_f8_scale_f8_quant(
x, scale=None, quant_dtype=dtypes.fp8, scale_type=dtypes.fp32, shuffle=False
):
assert quant_dtype == dtypes.fp8
block_size = 32
F8E8M0_EXP_BIAS = 127
dtypeMax = 448.0
MAX_POW2 = int(torch.log2(torch.tensor(dtypeMax, dtype=torch.float32)).item())
dtypeMax = 2.0**MAX_POW2

shape_original = x.shape
x = x.view(-1, shape_original[-1])

m, n = x.shape
x = x.view(-1, block_size)
max_abs = torch.amax(torch.abs(x.float()), 1)

# fp8e8m0fnu_from_fp32_value
if scale_type == dtypes.fp32:
scale_f32 = max_abs / dtypeMax
scale_e8m0_biased = None
else:
scale_e8m0_biased = fp4_utils.f32_to_e8m0(max_abs / dtypeMax)
scale_f32 = fp4_utils.e8m0_to_f32(scale_e8m0_biased)
# scale_f32 = max_abs / dtypeMax

y = x.float() / scale_f32.view(-1, 1)
y = y.view(*shape_original[:-1], -1)
if scale_type == dtypes.fp32:
scale = scale_f32.view(m, -1)
else:
scale = scale_e8m0_biased.view(m, -1) # .view(torch.uint8)
if shuffle:
scale = fp4_utils.e8m0_shuffle(scale)
return y.to(quant_dtype), scale


def per_tensor_quant(
x, scale=None, scale_dtype=dtypes.fp32, quant_dtype=dtypes.i8, dtypeMax=None
):
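A minimal usage sketch for the new `per_1x32_f8_scale_f8_quant` reference quantizer, assuming aiter is installed so that `aiter.ops.quant` is importable as in this repo layout. With the default `scale_type` (fp32), it returns an fp8 tensor plus one fp32 scale per 1x32 block of the last dimension, so dequantization is a blockwise multiply; everything besides the function call itself is illustrative.

```python
# Hypothetical round-trip check for per_1x32_f8_scale_f8_quant (fp32 scales).
import torch
from aiter.ops.quant import per_1x32_f8_scale_f8_quant  # assumed import path

x = torch.randn(4, 128, dtype=torch.float32)      # last dim must be a multiple of 32
y, scale = per_1x32_f8_scale_f8_quant(x)           # y: fp8, scale: fp32, shape (4, 128//32)

# Dequantize: each scale covers one 1x32 block along the last dimension.
x_hat = (y.float().view(-1, 32) * scale.view(-1, 1)).view_as(x)
max_err = (x - x_hat).abs().max().item()
print(f"max abs reconstruction error: {max_err:.4f}")
```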
195 changes: 194 additions & 1 deletion aiter/ops/triton/fused_mxfp4_quant.py
@@ -5,7 +5,6 @@
from typing import Optional
from aiter.utility import dtypes
from aiter.ops.triton._triton_kernels.fused_mxfp4_quant import (
_rmsmorm_op,
_fused_rms_mxfp4_quant_kernel,
_fused_flatten_mxfp4_quant,
_fused_reduce_act_mul_and_dynamic_mxfp4_quant_kernel,
@@ -650,3 +649,197 @@ def fused_dynamic_mxfp4_quant_moe_sort(
x_fp4.view(dtypes.fp4x2),
blockscale_e8m0_sorted.view(dtypes.fp8_e8m0).view(-1, N_o),
)


@triton.jit
def _fused_quant_fp8_sort_kernel(
# Pointers
input_ptr,
sorted_ids_ptr,
num_valid_ids_ptr,
x_fp8_ptr,
scale_sorted_ptr,
# Input/Output strides
stride_input_m: tl.constexpr,
stride_input_n: tl.constexpr,
stride_x_fp8_m: tl.constexpr,
stride_x_fp8_n: tl.constexpr,
stride_scale_o3: tl.constexpr,
stride_scale_o2: tl.constexpr,
stride_scale_o1: tl.constexpr,
stride_scale_o0: tl.constexpr,
# Problem size
M_input: tl.constexpr,
N_input: tl.constexpr,
N_scale_cols: tl.constexpr,
token_num: tl.constexpr,
# Block configuration
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr, # quant_block_size / 2
QUANT_BLOCK_SIZE: tl.constexpr,
TOPK: tl.constexpr,
# Quantization parameters
DTYPE_MAX: tl.constexpr,
DTYPE_MIN: tl.constexpr,
):
pid_m = tl.program_id(0) * 2
pid_n = tl.program_id(1) * 2

num_valid_ids = tl.load(num_valid_ids_ptr)
if pid_m * BLOCK_SIZE_M >= num_valid_ids:
return

out = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.uint32)

for i in range(4):
m = i % 2 * BLOCK_SIZE_M # 0 or BLOCK_SIZE_M
n = i // 2 * BLOCK_SIZE_N # 0 or BLOCK_SIZE_N

sorted_ids_offs_m = pid_m * BLOCK_SIZE_M + m + tl.arange(0, BLOCK_SIZE_M)
sorted_ids_mask = sorted_ids_offs_m < num_valid_ids
sorted_ids = tl.load(
sorted_ids_ptr + sorted_ids_offs_m,
mask=sorted_ids_mask,
other=0,
)
topk_ids = sorted_ids >> 24
token_ids = sorted_ids & 0xFFFFFF

if TOPK == 1:
original_m_idx = token_ids
else:
original_m_idx = token_ids * TOPK + topk_ids

input_offs_n = (pid_n * BLOCK_SIZE_N + n) * QUANT_BLOCK_SIZE + tl.arange(
0, BLOCK_SIZE_N * QUANT_BLOCK_SIZE
)
input_offs = (
original_m_idx[:, None] * stride_input_m
+ input_offs_n[None, :] * stride_input_n
)
input_mask = (original_m_idx < M_input)[:, None] & (input_offs_n < N_input)[
None, :
]

x = tl.load(input_ptr + input_offs, mask=input_mask, other=0.0).to(tl.float32)

x_reshaped = x.reshape(BLOCK_SIZE_M * BLOCK_SIZE_N, QUANT_BLOCK_SIZE)

amax = tl.max(tl.abs(x_reshaped), axis=-1, keep_dims=True)

amax = amax.to(tl.int32, bitcast=True)
amax = (amax + 0x200000).to(tl.uint32, bitcast=True) & 0xFF800000
amax = amax.to(tl.float32, bitcast=True)

scale_e8m0_unbiased = tl.log2(amax).floor() - tl.log2(DTYPE_MAX).floor()
scale_e8m0_unbiased = tl.clamp(scale_e8m0_unbiased, min=-127, max=127)

quant_scale = tl.exp2(-scale_e8m0_unbiased)
x_fp8 = tl.clamp(x_reshaped * quant_scale, DTYPE_MIN, DTYPE_MAX)
x_fp8 = x_fp8.reshape(BLOCK_SIZE_M, BLOCK_SIZE_N * QUANT_BLOCK_SIZE)

scale_e8m0 = (scale_e8m0_unbiased.to(tl.uint8) + 127).to(tl.uint8)
scale_e8m0 = scale_e8m0.reshape(BLOCK_SIZE_M, BLOCK_SIZE_N) # [BLOCK_SIZE_M]

out_offs_n = (pid_n * BLOCK_SIZE_N + n) * QUANT_BLOCK_SIZE + tl.arange(
0, BLOCK_SIZE_N * QUANT_BLOCK_SIZE
)
out_offs = (
original_m_idx[:, None] * stride_x_fp8_m
+ out_offs_n[None, :] * stride_x_fp8_n
)
out_mask = (original_m_idx < M_input)[:, None] & (out_offs_n < N_input)[None, :]
tl.store(
x_fp8_ptr + out_offs, x_fp8.to(x_fp8_ptr.type.element_ty), mask=out_mask
)

out = out | (scale_e8m0.to(tl.uint32) << (i * 8))

offs_0 = tl.arange(0, BLOCK_SIZE_M)
offs_1 = tl.arange(0, BLOCK_SIZE_N)
offs_2 = pid_n // 2
offs_3 = pid_m // 2
offs = (
offs_0[:, None] * stride_scale_o0
+ offs_1[None, :] * stride_scale_o1
+ offs_2 * stride_scale_o2
+ offs_3 * stride_scale_o3
)
tl.store(scale_sorted_ptr + offs, out)


def fused_quant_fp8_sort(
input: torch.Tensor,
sorted_ids: torch.Tensor,
num_valid_ids: torch.Tensor,
token_num: int,
block_size: int = 32,
quant_block_size: int = 8,
quant_dtype: torch.dtype = dtypes.fp8,
) -> tuple[torch.Tensor, torch.Tensor]:
BLOCK_SIZE_M = block_size
BLOCK_SIZE_N = quant_block_size
BLOCK_SIZE_M_u32 = BLOCK_SIZE_M // 2
BLOCK_SIZE_N_u32 = BLOCK_SIZE_N // 2

M, N = input.shape
assert (
N % quant_block_size == 0
), f"N ({N}) must be multiple of quant_block_size ({quant_block_size})"
assert block_size % 32 == 0, "block_size must be multiple of 32"

M_sorted = sorted_ids.shape[0]
N_blocks = triton.cdiv(N, block_size)

if quant_dtype == dtypes.fp8:
DTYPE_MAX = 448.0
DTYPE_MIN = -448.0
elif quant_dtype == torch.float8_e4m3fn:
DTYPE_MAX = 448.0
DTYPE_MIN = -448.0
else:
DTYPE_MAX = 448.0
DTYPE_MIN = -448.0

x_fp8 = torch.empty_like(input, dtype=quant_dtype, device="cuda")
M_o, N_o = sorted_ids.shape[0], N_blocks

# [M_sorted_blocks/2, N_blocks/2, BLOCK_SIZE_N_u32, BLOCK_SIZE_M_u32]
scale_e8m0_packed = torch.empty(
(
triton.cdiv(M_o, BLOCK_SIZE_M),
triton.cdiv(N_o, BLOCK_SIZE_N),
BLOCK_SIZE_N_u32,
BLOCK_SIZE_M_u32,
),
dtype=torch.uint32,
device=input.device,
)

grid = (
triton.cdiv(M_o, BLOCK_SIZE_M), # 32
triton.cdiv(N_o, BLOCK_SIZE_N), # 8
)

_fused_quant_fp8_sort_kernel[grid](
input,
sorted_ids,
num_valid_ids,
x_fp8,
scale_e8m0_packed,
*input.stride(),
*x_fp8.stride(),
*scale_e8m0_packed.stride(),
M_input=M,
N_input=N,
N_scale_cols=N_blocks,
token_num=token_num,
BLOCK_SIZE_M=BLOCK_SIZE_M // 2,
BLOCK_SIZE_N=BLOCK_SIZE_N // 2,
QUANT_BLOCK_SIZE=32,
TOPK=M // token_num,
DTYPE_MAX=DTYPE_MAX,
DTYPE_MIN=DTYPE_MIN,
)

return x_fp8, scale_e8m0_packed.view(dtypes.fp8_e8m0).view(-1, N_o)
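For readers following the e8m0 math inside `_fused_quant_fp8_sort_kernel`: each 32-element block is scaled by a power of two chosen so its max magnitude lands near the fp8 range, and the biased exponent is stored as the scale byte. The snippet below is a plain-PyTorch restatement of that per-block step under stated assumptions (it omits the kernel's bit-level rounding of `amax` and assumes a PyTorch build with `torch.float8_e4m3fn`); the helper name is illustrative, not part of aiter.

```python
# Plain-PyTorch sketch of the per-block fp8 + e8m0 scaling done in the kernel.
import torch

def quant_block_fp8_e8m0(x_block: torch.Tensor, dtype_max: float = 448.0):
    """Quantize one 32-element block to fp8 with a power-of-two (e8m0) scale."""
    amax = x_block.abs().max()
    # Unbiased exponent of the scale: chosen so amax maps near dtype_max.
    exp = torch.floor(torch.log2(amax)) - torch.floor(torch.log2(torch.tensor(dtype_max)))
    exp = torch.clamp(exp, -127, 127)
    scale_byte = (exp + 127).to(torch.uint8)                      # biased e8m0 scale
    x_fp8 = torch.clamp(x_block * 2.0 ** (-exp), -dtype_max, dtype_max)
    return x_fp8.to(torch.float8_e4m3fn), scale_byte

# Dequantization multiplies the fp8 values back by 2 ** (scale_byte - 127).
x = torch.randn(32)
q, s = quant_block_fp8_e8m0(x)
x_hat = q.float() * 2.0 ** (s.float() - 127)
```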
2 changes: 1 addition & 1 deletion aiter/ops/triton/utils/gemm_config_utils.py
@@ -17,7 +17,7 @@
Cold start: 290.8928 ms
LRU Cache: ENABLED
Avg per call: 0.110 us
vs
vs
LRU Cache: DISABLED
Avg per call: 2.503 us
"""