2 changes: 1 addition & 1 deletion 3rdparty/composable_kernel
Submodule composable_kernel updated 361 files
434 changes: 429 additions & 5 deletions aiter/configs/tuned_fmoe.csv

Large diffs are not rendered by default.

203 changes: 181 additions & 22 deletions aiter/fused_moe.py
@@ -104,6 +104,7 @@ def fused_moe(
intermediate_pad=0,
bias1=None,
bias2=None,
splitk=0,
):
if not block_size_M:
block_size_M = -1
@@ -217,7 +218,15 @@ def fused_moe_(
quant_type = quant_remap.get(quant_type, quant_type)
q_dtype_w = w1.dtype
q_dtype_a = w1.dtype if w1.dtype != torch.uint32 else dtypes.fp8
q_dtype_a = dtypes.fp4x2 if quant_type == QuantType.per_1x32 else q_dtype_a
bf16_fp8_bound = 512
if quant_type == QuantType.per_1x32:
if activation == ActivationType.Swiglu:
if get_gfx() != "gfx950" or M < bf16_fp8_bound:
q_dtype_a = dtypes.bf16
elif M >= bf16_fp8_bound:
q_dtype_a = dtypes.fp8
else:
q_dtype_a = dtypes.fp4x2
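    # Recap of the activation-dtype selection above: for per_1x32 (MXFP4) weights
    # with Swiglu, activations stay in bf16 on non-gfx950 targets or when
    # M < bf16_fp8_bound (512), and switch to fp8 on gfx950 once M >= 512;
    # every other per_1x32 path quantizes activations to fp4x2.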

metadata = get_2stage_cfgs(
get_padded_M(M), # consider token_num > 1024 as prefill
@@ -234,8 +243,6 @@ def fused_moe_(
doweight_stage1,
hidden_pad,
intermediate_pad,
bias1,
bias2,
)

block_size_M = metadata.block_m if block_size_M is None else block_size_M
@@ -472,6 +479,33 @@ def get_block_size_M(token, topk, expert, inter_dim):
return sorted(tmp, key=lambda x: x[:2])[0][-1]


@functools.lru_cache(maxsize=2048)
def get_ksplit(token, topk, expert, inter_dim, model_dim):
aiter_ksplit = int(os.environ.get("AITER_KSPLIT", "0"))
if aiter_ksplit != 0:
return aiter_ksplit
# only for moe_blk gemm1 a8w8 decode scenario
if token * topk > expert:
return 0
cu_num = get_cu_num()
tileN = 128

tgM = token * topk # decode tile num
tgN = (inter_dim * 2 + tileN - 1) // tileN

tg_num = tgN * tgM
# if all cu already active
if tg_num >= cu_num:
return 0
tilek = 256
split_max = (cu_num + tg_num - 1) // tg_num
# at least split = 2
for i in reversed(range(2, split_max + 1)):
if (model_dim % i == 0) and ((model_dim // i) % tilek == 0):
return i
return 0
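# Worked example for get_ksplit (hypothetical decode shape; a CU count of 80 is
# assumed purely for illustration):
#   token=1, topk=8, expert=128, inter_dim=256, model_dim=7168
#   token * topk = 8 <= expert, so a k-split is considered
#   tgM = 8, tgN = ceil(512 / 128) = 4, tg_num = 32 < 80 CUs
#   split_max = ceil(80 / 32) = 3; i = 3 is rejected (7168 % 3 != 0),
#   i = 2 is accepted (7168 % 2 == 0 and 3584 % 256 == 0), so ksplit = 2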


cfg_2stages = None
# fmt: off
fused_moe_1stage_dict = {
@@ -512,7 +546,8 @@ def nextPow2(n):
def get_padded_M(M):
padded_m = M
if M >= 1 and M <= 16:
padded_m = 16
# decoding policy may be changed in the future.
padded_m = nextPow2(padded_m)
elif M < 1024:
padded_m = nextPow2(padded_m)
elif M < 2048:
@@ -531,6 +566,7 @@ class MOEMetadata:
block_m: int
ksplit: int
run_1stage: bool = False
has_bias: bool = False


@functools.lru_cache(maxsize=2048)
@@ -549,8 +585,6 @@ def get_2stage_cfgs(
doweight_stage1,
hidden_pad,
intermediate_pad,
bias1,
bias2,
):
def get_cfg_2stages(tune_file):
import pandas as pd
@@ -620,8 +654,22 @@ def FinalFunc():
)
logger.info("\033[0m")

def use_cfg():
problem_type = (activation, dtype, q_dtype_a, q_dtype_w, q_type)
bypass_type = (
ActivationType.Silu,
dtypes.bf16,
dtypes.fp8,
dtypes.fp8,
QuantType.per_1x128,
)
if problem_type == bypass_type and (token * topk) <= 128: # bypass tuned
aiter.logger.info("bypass tuned results for fp8 blockscale")
return False
return True
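    # use_cfg() above skips the tuned-config lookup for the fp8 block-scale
    # (per_1x128) Silu case with token * topk <= 128, so that path falls back
    # to the heuristic defaults chosen below.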

# cfg = cfg_2stages.get(keys, None)
cfg = cfg_2stages.get(keys, None) if cfg_2stages else None
cfg = cfg_2stages.get(keys, None) if cfg_2stages and use_cfg() else None
if cfg is None and os.environ.get("AITER_ONLINE_TUNE", "0") == "1":
lock_path = os.path.join(bd_dir, f"lock_fmoe_tune_{keys}")
mp_lock(lock_path, MainFunc=MainFunc, FinalFunc=FinalFunc)
@@ -630,7 +678,7 @@ def FinalFunc():
cfg = cfg_2stages.get(keys, None) if cfg_2stages else None
if cfg is None:
logger.warning(f"Fmoe tuning not support for {keys}")
if cfg is None:
if cfg is None or int(os.environ.get("AITER_BYPASS_TUNE_CONFIG", "0")):
ksplit = 0
kernelName1 = ""
kernelName2 = ""
@@ -645,7 +693,7 @@ def FinalFunc():
doweight_stage1,
) in fused_moe_1stage_dict[get_gfx()]:
if q_type == QuantType.per_1x128:
run_1stage = True and (inter_dim % 128 == 0)
run_1stage = token > 32 and (inter_dim % 256 == 0)
elif q_type == QuantType.per_Token and q_dtype_w == dtypes.i8:
run_1stage = token > 32
elif q_type == QuantType.per_Token and q_dtype_w == dtypes.fp8:
@@ -657,11 +705,23 @@ def FinalFunc():
BLOCK_SIZE_M
if run_1stage
else (
64
(64 if token > 32 else 16)
if q_type == QuantType.per_1x128
else get_block_size_M(token, topk, expert, inter_dim)
)
)
ksplit = (
ksplit
if (run_1stage)
else (
get_ksplit(token, topk, expert, inter_dim, model_dim)
if q_type == QuantType.per_1x128
else ksplit
)
)
aiter.logger.info(
f"run_1stage = {run_1stage}, ksplit = {ksplit} q_type = {q_type}"
)
else:
block_m = cfg["block_m"]
ksplit = cfg["ksplit"]
@@ -673,6 +733,13 @@ def FinalFunc():
logger.info(
f"[fused_moe] using {'1stage' if run_1stage else '2stage'} {'default' if cfg is None else tag} for {keys} "
)

def get_block_m() -> int:
if q_dtype_a == dtypes.fp8:
return 32
else:
return 16 if token < 2048 else 32 if token < 16384 else 64

if run_1stage:
return MOEMetadata(
functools.partial(
@@ -696,17 +763,16 @@ def FinalFunc():
cktile_moe_stage1,
n_pad_zeros=intermediate_pad // 64 * 64 * (2 if use_g1u1 else 1),
k_pad_zeros=hidden_pad // 128 * 128,
bias1=bias1,
),
functools.partial(
cktile_moe_stage2,
n_pad_zeros=hidden_pad // 64 * 64,
k_pad_zeros=intermediate_pad // 128 * 128,
bias2=bias2,
),
16 if token < 2048 else 32 if token < 16384 else 64,
get_block_m(),
ksplit,
False,
True,
)
if (
"ck2stages" in kernelName1
@@ -717,14 +783,16 @@ def FinalFunc():
dtypes.fp16,
torch.uint32,
dtypes.fp4x2,
dtypes.fp8,
]
):
return MOEMetadata(
functools.partial(
aiter.ck_moe_stage1_fwd,
ck_moe_stage1,
kernelName=kernelName1,
activation=activation,
quant_type=q_type,
splitk=ksplit,
),
functools.partial(
aiter.ck_moe_stage2_fwd,
@@ -812,17 +880,27 @@ def fused_moe_2stages(
doweight_stage1,
hidden_pad,
intermediate_pad,
bias1,
bias2,
)
if (
quant_type == QuantType.per_1x32
and dtype in [dtypes.bf16, dtypes.fp16]
and q_dtype_a in [dtypes.bf16, dtypes.fp16]
and w1.dtype == dtypes.fp4x2
and activation == ActivationType.Swiglu
):
a1 = hidden_states.to(dtype)
a1_scale = None
elif (
quant_type == aiter.QuantType.per_1x32
and dtype in [dtypes.bf16, dtypes.fp16]
and q_dtype_a == dtypes.fp8
and w1.dtype == dtypes.fp4x2
and activation == aiter.ActivationType.Swiglu
):
a1 = hidden_states.to(dtypes.fp8)
M = sorted_ids.shape[0]
N = a1.shape[-1]
a1_scale = torch.ones([M, N // 32], dtype=dtypes.fp8_e8m0, device=a1.device)
elif quant_type == QuantType.per_1x32:
if token_num <= token_num_quant_moe_sort_switch:
a1, a1_scale = fused_dynamic_mxfp4_quant_moe_sort(
@@ -874,7 +952,17 @@ def fused_moe_2stages(
dtype=dtype,
device=device,
)

extra_stage1_args = {}
extra_stage2_args = {}
if (
not metadata.run_1stage
and metadata.has_bias
and dtype in [dtypes.bf16, dtypes.fp16]
and quant_type == QuantType.per_1x32
and activation == ActivationType.Swiglu
):
extra_stage1_args["bias1"] = bias1
extra_stage2_args["bias2"] = bias2
a2 = metadata.stage1(
a1,
w1,
Expand All @@ -886,17 +974,31 @@ def fused_moe_2stages(
topk,
block_m=block_size_M,
a1_scale=a1_scale,
w1_scale=w1_scale,
w1_scale=(
w1_scale.view(dtypes.fp8_e8m0) if w1.dtype == dtypes.fp4x2 else w1_scale
),
sorted_weights=sorted_weights if doweight_stage1 else None,
dtype=dtype,
**extra_stage1_args,
)

if (
quant_type == QuantType.per_1x32
and dtype in [dtypes.bf16, dtypes.fp16]
and q_dtype_a in [dtypes.bf16, dtypes.fp16]
and w1.dtype == dtypes.fp4x2
and activation == ActivationType.Swiglu
):
a2_scale = None
elif (
quant_type == aiter.QuantType.per_1x32
and dtype in [dtypes.bf16]
and q_dtype_a == dtypes.fp8
and w1.dtype == dtypes.fp4x2
and activation == aiter.ActivationType.Swiglu
):
a2 = a2.to(dtypes.fp8)
a2_scale = a1_scale
elif quant_type == QuantType.per_1x32:
a2 = a2.view(-1, inter_dim)
if token_num <= token_num_quant_moe_sort_switch:
Expand Down Expand Up @@ -952,10 +1054,13 @@ def fused_moe_2stages(
num_valid_ids,
moe_out,
topk,
w2_scale=w2_scale,
w2_scale=(
w2_scale.view(dtypes.fp8_e8m0) if w2.dtype == dtypes.fp4x2 else w2_scale
),
a2_scale=a2_scale,
block_m=block_size_M,
sorted_weights=sorted_weights if not doweight_stage1 else None,
**extra_stage2_args,
)

return moe_out
@@ -1293,6 +1398,60 @@ def torch_moe_stage2(
return out.sum(1).to(dtype)


def ck_moe_stage1(
hidden_states,
w1, # [E, inter_dim*2, model_dim]
w2, # [E, model_dim, inter_dim]
sorted_token_ids, # [max_num_tokens_padded]
sorted_expert_ids, # [max_num_m_blocks]
num_valid_ids, # [1]
out,
topk,
block_m,
a1_scale,
w1_scale,
kernelName="",
sorted_weights=None,
quant_type=aiter.QuantType.No,
activation=ActivationType.Gelu,
splitk=1,
dtype=None,
):
token_num = hidden_states.shape[0]
tmp_out = (
torch.zeros(
(token_num, topk, w1.shape[1]), dtype=dtypes.fp32, device=out.device
)
if splitk > 1
else out
)
aiter.ck_moe_stage1_fwd(
hidden_states,
w1,
w2,
sorted_token_ids,
sorted_expert_ids,
num_valid_ids,
tmp_out,
topk,
kernelName,
w1_scale,
a1_scale,
block_m,
sorted_weights,
quant_type,
activation,
splitk,
out.dtype,
)
if splitk > 1:
if activation == ActivationType.Silu:
aiter.silu_and_mul(out, tmp_out.view(dtypes.fp32).to(out.dtype))
else:
aiter.gelu_and_mul(out, tmp_out.view(dtypes.fp32).to(out.dtype))
return out
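# Note on the splitk path in ck_moe_stage1: when splitk > 1 the stage-1 GEMM
# accumulates its partial results into an fp32 buffer (tmp_out); the gated
# activation (silu_and_mul / gelu_and_mul) is then applied once over the
# accumulated result and written to `out` in the output dtype.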


def cktile_moe_stage1(
hidden_states,
w1,
@@ -1309,6 +1468,7 @@ def cktile_moe_stage1(
n_pad_zeros=0,
k_pad_zeros=0,
bias1=None,
dtype=torch.bfloat16,
):
token_num = hidden_states.shape[0]
_, n1, k1 = w1.shape
@@ -1318,9 +1478,8 @@ def cktile_moe_stage1(

if w1.dtype is torch.uint32:
D = D * 8
out = torch.empty(
(token_num, topk, D), dtype=hidden_states.dtype, device=hidden_states.device
)
out = torch.empty((token_num, topk, D), dtype=dtype, device=hidden_states.device)

# print("Run cktile_moe_stage1: M=%d, N(N*2)=%d, K=%d, topk=%d, expert=%d"%(token_num, w1.shape[1], hidden_states.shape[1], topk, w1.shape[0]))
aiter.moe_cktile2stages_gemm1(
hidden_states,