1 change: 0 additions & 1 deletion aiter/jit/core.py
@@ -116,7 +116,6 @@ def update_config_files(file_path: str, merge_name: str):
return new_file_path


-# @functools.lru_cache(maxsize=1)
def get_config_file(env_name, default_file, tuned_file_name):
config_env_file = os.getenv(env_name)
# default_file = f"{AITER_ROOT_DIR}/aiter/configs/{tuned_file_name}.csv"
44 changes: 38 additions & 6 deletions aiter/utility/base_tuner.py
@@ -12,7 +12,6 @@
from operator import itemgetter
import time
from aiter import dtypes
-from aiter import core

INVALID_TIME = -1

@@ -168,6 +167,37 @@ def result_to_df(self, rets):
"""transfer results to dataframe"""
pass

+    def update_config_files(self, file_path: str, merge_name: str):
+        path_list = file_path.split(os.pathsep) if file_path else []
+        if len(path_list) <= 1:
+            return file_path
+        df_list = []
+        ## merge config files,
+        ## e.g. AITER_CONFIG_GEMM_A4W4="/path1:/path2"
+
+        df_list.append(pd.read_csv(path_list[0]))
+        for i, path in enumerate(path_list[1:]):
+            if os.path.exists(path):
+                df = pd.read_csv(path)
+                ## check columns
+                assert (
+                    df.columns.tolist() == df_list[0].columns.tolist()
+                ), f"Column mismatch between {path_list[0]} and {path}, {df_list[0].columns.tolist()}, {df.columns.tolist()}"
+
+                df_list.append(df)
+            else:
+                print(f"path {i+1}: {path} (does not exist)")
+        merge_df = pd.concat(df_list, ignore_index=True) if df_list else pd.DataFrame()
+        ## drop duplicates, keeping the fastest entry per key
+        merge_df = (
+            merge_df.sort_values("us")
+            .drop_duplicates(subset=self.keys, keep="first")
+            .reset_index(drop=True)
+        )
+        new_file_path = f"/tmp/{merge_name}.csv"
+        merge_df.to_csv(new_file_path, index=False)
+        return new_file_path
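
For reference, a minimal self-contained sketch of the merge-and-dedup semantics the new helper implements; the key columns, kernel names, and timings are illustrative stand-ins for a tuner's real `self.keys` and config rows:

import pandas as pd

keys = ["M", "N", "K"]  # stand-in for self.keys
a = pd.DataFrame({"M": [64], "N": [128], "K": [256], "kernelName": ["k1"], "us": [12.0]})
b = pd.DataFrame({"M": [64], "N": [128], "K": [256], "kernelName": ["k2"], "us": [9.5]})

merged = (
    pd.concat([a, b], ignore_index=True)
    .sort_values("us")
    .drop_duplicates(subset=keys, keep="first")
    .reset_index(drop=True)
)
print(merged)  # one row per (M, N, K); the faster k2 entry wins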

def get_untuned_gemm_list(self, untuned_gemm_file):
assert os.path.exists(
untuned_gemm_file
@@ -183,7 +213,7 @@ def get_out_file(self, tuned_file):
return path_list[0]

def get_tuned_gemm_list(self, tuned_gemm_file, columns=[]):
-        all_tuned_file = core.update_config_files(tuned_gemm_file, self.name)
+        all_tuned_file = self.update_config_files(tuned_gemm_file, self.name)
if os.path.exists(all_tuned_file):
column_order = pd.read_csv(all_tuned_file, nrows=0).columns.tolist()
tunedf = pd.read_csv(all_tuned_file)
@@ -331,17 +361,19 @@ def tune_summary(self, status):
)
logger.info("Successfully tuned shapes:")
if not self.success.empty:
-            print(self.success)
+            print(self.success, flush=True)
logger.info("Failed shapes:")
-        print(self.failed)
+        print(self.failed, flush=True)

tunedf_subset = tunedf[self.untunedf.columns].astype(self.untunedf.dtypes)
mask = self.untunedf.apply(tuple, axis=1).isin(
tunedf_subset.apply(tuple, axis=1)
)
self.remain_untuned = self.untunedf[~mask]
logger.info("untuned shapes:")
print(self.remain_untuned)

if not self.remain_untuned.empty:
logger.info("untuned shapes:")
print(self.remain_untuned)

@abstractmethod
def result_to_csv(self, results, file, concat=False):
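
The remaining-untuned report in tune_summary relies on a tuple-based anti-join (the astype cast above aligns dtypes so tuples compare equal across the two frames). A standalone sketch with illustrative columns:

import pandas as pd

untunedf = pd.DataFrame({"M": [1, 2, 3], "N": [4, 5, 6]})
tunedf = pd.DataFrame({"M": [1, 3], "N": [4, 6]})

# rows of untunedf whose (M, N) tuple never appears in tunedf
mask = untunedf.apply(tuple, axis=1).isin(tunedf.apply(tuple, axis=1))
remain_untuned = untunedf[~mask]
print(remain_untuned)  # only the (2, 5) row remains
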
137 changes: 77 additions & 60 deletions hsa/gfx942/fmoe_2stages/tune.py
@@ -26,10 +26,10 @@
from aiter import dtypes
from aiter import ActivationType as ActivationType
from aiter.jit.utils.chip_info import get_gfx
-from aiter.utility import fp4_utils
import torch.nn.functional as F
from einops import rearrange
from aiter.utility.base_tuner import TunerCommon
+from aiter.utility import fp4_utils
from aiter.utility.fp4_utils import moe_mxfp4_sort


@@ -134,15 +125,6 @@ def ck_moe_stage1_fwd_out(
):
inter_dim = w1_qt_shffle_ck.shape[1] // 2
token_num = a1_qt.shape[0]

-    a1_scale = moe_mxfp4_sort(
-        a1_scale,
-        sorted_ids=sorted_ids,
-        num_valid_ids=num_valid_ids,
-        token_num=token_num,
-        block_size=blockM,
-    )

out = torch.empty(
(token_num, topk, inter_dim),
dtype=dtype,
@@ -158,7 +149,7 @@
out,
topk,
kernelName,
-        fp4_utils.e8m0_shuffle(w1_scale),
+        w1_scale,
a1_scale,
blockM,
sorted_weights,
@@ -195,14 +186,6 @@ def ck_moe_stage2_fwd_out(
model_dim = w2_qt_shffle_ck.shape[1]
token_num = a2_qt.shape[0]

-    a2_scale = moe_mxfp4_sort(
-        a2_scale[: token_num * topk, :].view(token_num, topk, -1),
-        sorted_ids=sorted_ids,
-        num_valid_ids=num_valid_ids,
-        token_num=token_num,
-        block_size=blockM,
-    )

out = torch.zeros(
(token_num, model_dim),
dtype=dtype,
@@ -218,9 +201,7 @@
out,
topk,
kernelName,
-        fp4_utils.e8m0_shuffle(
-            w2_scale
-        ),  # e8m0_shuffle will do nothing if it's a fp32
+        w2_scale,
a2_scale,
blockM,
sorted_weights,
@@ -415,14 +396,13 @@ def get_1stage_fmoe_func(
quant_type == QuantType.No
and activation == ActivationType.Silu
and not isG1U1
-        or doweight_stage1
or quant_type == QuantType.per_1x32
):
print("not support No Quant Silu G1U0 1 stage tuning!")
print("not support No Quant Silu G1U0 1 stage or per_1x32 quant tuning!")
else:
if quant_type == QuantType.per_1x128:
fmoe_func = FmoeTuner.run_1stage_fmoe_fp8_blockscale_g1u1
elif (q_dtype_a == dtypes.fp8) & doweight_stage1:
elif (q_dtype_a == dtypes.fp8) and doweight_stage1:
fmoe_func = FmoeTuner.run_1stage_fmoe_g1u1_tkw1
elif isG1U1:
fmoe_func = FmoeTuner.run_1stage_fmoe_g1u1
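
The `&` -> `and` change above is behaviour-preserving here (both operands were already parenthesised booleans), but `and` is the safer idiom: `&` binds tighter than `==` and never short-circuits. A small sketch of the pitfall:

x, flag = 2, True
print(x == 2 & flag)    # False: parsed as x == (2 & True), i.e. x == 0
print((x == 2) & flag)  # True, but both sides are always evaluated
print(x == 2 and flag)  # True, and the right side is skipped when the left is False
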
@@ -467,13 +447,24 @@ def generate_data(
)
a1_qt = a1_qt.view(token, model_dim)
a1_scale = a1_scale.squeeze(-1)
+        elif (
+            q_type == aiter.QuantType.per_1x32
+            and (q_dtype_a in [dtypes.bf16, dtypes.fp16])
+            and q_dtype_w == dtypes.fp4x2
+        ):  # a16w4
+            a1_qt = input.to(dtype)
+            a1_scale = None
else:
torch_quant = aiter.get_torch_quant(q_type)
a1_qt, a1_scale = torch_quant(input, quant_dtype=q_dtype_a)
del w1, w2, score
+        if q_dtype_w is not dtypes.fp4x2:
+            w1_qt_shffle = shuffle_weight(w1_qt, (16, 16))
+            w2_qt_shffle = shuffle_weight(w2_qt, (16, 16))
+        else:
+            w1_qt_shffle = w1_qt
+            w2_qt_shffle = w2_qt

-        w1_qt_shffle = shuffle_weight(w1_qt, (16, 16))
-        w2_qt_shffle = shuffle_weight(w2_qt, (16, 16))
sorted_ids, sorted_weights, sorted_expert_ids, num_valid_ids, moe_buf = (
moe_sorting(topk_ids, topk_weights, expert, model_dim, dtype, blockM)
)
@@ -639,24 +630,39 @@ def generate_data_2stages(
else:
w1_qt_shffle_ck = w1_qt_shffle
w2_qt_shffle_ck = w2_qt_shffle
+        w1_scale_aiter = fp4_utils.e8m0_shuffle(w1_scale)
+        w2_scale_aiter = fp4_utils.e8m0_shuffle(w2_scale)
if stage == 1:
if not doweight_stage1:
sorted_weights = None
+            if q_type == QuantType.per_1x32:
+                a1_scale_fp4_sort = moe_mxfp4_sort(
+                    a1_scale,  # a1_scale[: token * topk, :].view(token, topk, -1),
+                    sorted_ids=sorted_ids,
+                    num_valid_ids=num_valid_ids,
+                    token_num=token,
+                    block_size=blockM,
+                )
+            else:
+                a1_scale_fp4_sort = a1_scale
+
return (
-                a1_qt,
-                w1_qt_shffle_ck,
-                w2_qt_shffle_ck,
-                a1_scale,
-                w1_scale,
-                sorted_ids,
-                sorted_expert_ids,
-                sorted_weights,
-                num_valid_ids,
-                moe_buf,
-                w1_qt,
-                w2_qt,
-                topk_weights,
-                topk_ids,
+                a1_qt,  # 0
+                w1_qt_shffle_ck,  # 1
+                w2_qt_shffle_ck,  # 2
+                a1_scale,  # 3
+                w1_scale,  # 4
+                sorted_ids,  # 5
+                sorted_expert_ids,  # 6
+                sorted_weights,  # 7
+                num_valid_ids,  # 8
+                moe_buf,  # 9
+                w1_qt,  # 10
+                w2_qt,  # 11
+                topk_weights,  # 12
+                topk_ids,  # 13
+                a1_scale_fp4_sort,  # 14
+                w1_scale_aiter,  # 15
)
elif stage == 2:
ref1 = FmoeTuner.run_torch_moe_stage1(
@@ -687,21 +693,33 @@ def generate_data_2stages(
a2_qt = a2_qt.view(token, topk, -1)
if doweight_stage1:
sorted_weights = None
+            if q_type == QuantType.per_1x32:
+                a2_scale_mxfp4_sort = moe_mxfp4_sort(
+                    a2_scale[: token * topk, :].view(token, topk, -1),
+                    sorted_ids=sorted_ids,
+                    num_valid_ids=num_valid_ids,
+                    token_num=token,
+                    block_size=blockM,
+                )
+            else:
+                a2_scale_mxfp4_sort = a2_scale
return (
-                a2_qt,
-                w1_qt_shffle_ck,
-                w2_qt_shffle_ck,
-                a2_scale,
-                w2_scale,
-                sorted_ids,
-                sorted_expert_ids,
-                sorted_weights,
-                num_valid_ids,
-                moe_buf,
-                w1_qt,
-                w2_qt,
-                topk_weights,
-                topk_ids,
+                a2_qt,  # 0
+                w1_qt_shffle_ck,  # 1
+                w2_qt_shffle_ck,  # 2
+                a2_scale,  # 3
+                w2_scale,  # 4
+                sorted_ids,  # 5
+                sorted_expert_ids,  # 6
+                sorted_weights,  # 7
+                num_valid_ids,  # 8
+                moe_buf,  # 9
+                w1_qt,  # 10
+                w2_qt,  # 11
+                topk_weights,  # 12
+                topk_ids,  # 13
+                a2_scale_mxfp4_sort,  # 14
+                w2_scale_aiter,  # 15
)
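
The numbered comments on the return tuples document the positional convention used by the task descriptions further down, where index lists such as [0, 1, 2, 5, 6, 7, 8, 15, 14] select kernel arguments out of the generated data. A sketch of that selection, assuming itemgetter-style picking (the actual harness wiring is not shown in this diff):

from operator import itemgetter

data = tuple(f"arg{i}" for i in range(16))  # stands in for the 16-element tuple above
picked = itemgetter(0, 1, 2, 5, 6, 7, 8, 15, 14)(data)
print(picked)  # ('arg0', 'arg1', 'arg2', 'arg5', 'arg6', 'arg7', 'arg8', 'arg15', 'arg14')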

@staticmethod
Expand Down Expand Up @@ -766,7 +784,7 @@ def generate_data_1stage(
fc1_smooth_scale = None
fc2_smooth_scale = None
if q_type == QuantType.per_1x32:
-            a1_scale = fp4_utils.moe_mxfp4_sort(
+            a1_scale = moe_mxfp4_sort(
a1_scale,
sorted_ids,
num_valid_ids,
@@ -1268,8 +1286,6 @@ def get_1stage_file_info(self, q_type, q_dtype_a, doweight_stage1):
quantDtype = ""
if doweight_stage1:
extraInfo_1stage = "_tkw1"
-            if q_dtype_a == dtypes.fp8:
-                quantDtype = "Int8"  ## tmp solution, need to be updated
if q_type == QuantType.No:
quantDtype_1stage = "noquant"
elif q_type == QuantType.per_1x128:
@@ -1559,7 +1575,7 @@ def gen_2stages_task(self, key, blockMs):
),
FmoeTuner.ck_moe_stage1_fwd_out, # func
(
-                    [0, 1, 2, 5, 6, 7, 8, 4, 3],
+                    [0, 1, 2, 5, 6, 7, 8, 15, 14],
dtype,
topk,
kernel.name,
@@ -1570,6 +1586,7 @@
{},
FmoeTuner.run_torch_moe_stage1,
(
+                    # [a1_qt, w1_qt, w2_qt, topk_weights, topk_ids, a1_scale, w1_scale]
[0, 10, 11, 12, 13, 3, 4],
dtype,
act_type,
@@ -1610,7 +1627,7 @@ def gen_2stages_task(self, key, blockMs):
),
FmoeTuner.ck_moe_stage2_fwd_out, # func
(
-                    [0, 1, 2, 5, 6, 7, 8, 4, 3],
+                    [0, 1, 2, 5, 6, 7, 8, 15, 14],
dtype,
topk,
kernel.name,
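
Taken together, the gfx942 changes hoist the scale preparation (moe_mxfp4_sort, e8m0_shuffle) out of the per-kernel wrappers and into data generation, presumably so the one-off preprocessing is not charged to every timed kernel invocation. The pattern in miniature, with a sleep standing in for the real sort:

import time

def expensive_sort(scale):
    time.sleep(0.005)  # stand-in for moe_mxfp4_sort / e8m0_shuffle
    return scale

# before: every timed call to the wrapper paid for the sort
def stage1_old(a_scale):
    a_scale = expensive_sort(a_scale)
    # ... kernel launch elided
    return a_scale

# after: sort once at data-generation time, pass the result straight through
def generate_data(a_scale):
    return expensive_sort(a_scale)

def stage1_new(a_scale_sorted):
    # ... kernel launch elided
    return a_scale_sorted
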
2 changes: 0 additions & 2 deletions hsa/gfx950/fmoe_2stages/tune.py
@@ -57,8 +57,6 @@ def get_1stage_file_info(self, q_type, q_dtype_a, doweight_stage1):
quantDtype = ""
if doweight_stage1:
extraInfo_1stage = "_tkw1"
-            if q_dtype_a == dtypes.fp8:
-                quantDtype = "Int8"  ## tmp solution, need to be updated
if q_type == QuantType.No:
quantDtype_1stage = "noquant"
elif q_type == QuantType.per_1x128: