diff --git a/aiter/jit/core.py b/aiter/jit/core.py
index 560727795a..41809562fa 100644
--- a/aiter/jit/core.py
+++ b/aiter/jit/core.py
@@ -116,7 +116,6 @@ def update_config_files(file_path: str, merge_name: str):
     return new_file_path
 
 
-# @functools.lru_cache(maxsize=1)
 def get_config_file(env_name, default_file, tuned_file_name):
     config_env_file = os.getenv(env_name)
     # default_file = f"{AITER_ROOT_DIR}/aiter/configs/{tuned_file_name}.csv"
diff --git a/aiter/utility/base_tuner.py b/aiter/utility/base_tuner.py
index 9baec8030e..775a927c63 100644
--- a/aiter/utility/base_tuner.py
+++ b/aiter/utility/base_tuner.py
@@ -12,7 +12,6 @@ from operator import itemgetter
 import time
 
 from aiter import dtypes
-from aiter import core
 
 
 INVALID_TIME = -1
@@ -168,6 +167,37 @@ def result_to_df(self, rets):
         """transfer results to dataframe"""
         pass
 
+    def update_config_files(self, file_path: str, merge_name: str):
+        path_list = file_path.split(os.pathsep) if file_path else []
+        if len(path_list) <= 1:
+            return file_path
+        df_list = []
+        ## merge config files
+        ## example: AITER_CONFIG_GEMM_A4W4="/path1:/path2"
+
+        df_list.append(pd.read_csv(path_list[0]))
+        for i, path in enumerate(path_list[1:]):
+            if os.path.exists(path):
+                df = pd.read_csv(path)
+                ## check columns
+                assert (
+                    df.columns.tolist() == df_list[0].columns.tolist()
+                ), f"Column mismatch between {path_list[0]} and {path}: {df_list[0].columns.tolist()} vs {df.columns.tolist()}"
+
+                df_list.append(df)
+            else:
+                print(f"path {i + 1}: {path} (does not exist)")
+        merge_df = pd.concat(df_list, ignore_index=True) if df_list else pd.DataFrame()
+        ## drop duplicates, keeping the fastest entry per key
+        merge_df = (
+            merge_df.sort_values("us")
+            .drop_duplicates(subset=self.keys, keep="first")
+            .reset_index(drop=True)
+        )
+        new_file_path = f"/tmp/{merge_name}.csv"
+        merge_df.to_csv(new_file_path, index=False)
+        return new_file_path
+
     def get_untuned_gemm_list(self, untuned_gemm_file):
         assert os.path.exists(
             untuned_gemm_file
@@ -183,7 +213,7 @@ def get_out_file(self, tuned_file):
         return path_list[0]
 
     def get_tuned_gemm_list(self, tuned_gemm_file, columns=[]):
-        all_tuned_file = core.update_config_files(tuned_gemm_file, self.name)
+        all_tuned_file = self.update_config_files(tuned_gemm_file, self.name)
         if os.path.exists(all_tuned_file):
             column_order = pd.read_csv(all_tuned_file, nrows=0).columns.tolist()
             tunedf = pd.read_csv(all_tuned_file)
@@ -331,17 +361,19 @@ def tune_summary(self, status):
         )
         logger.info("Successfully tuned shapes:")
         if not self.success.empty:
-            print(self.success)
+            print(self.success, flush=True)
         logger.info("Failed shapes:")
-        print(self.failed)
+        print(self.failed, flush=True)
 
         tunedf_subset = tunedf[self.untunedf.columns].astype(self.untunedf.dtypes)
         mask = self.untunedf.apply(tuple, axis=1).isin(
             tunedf_subset.apply(tuple, axis=1)
         )
         self.remain_untuned = self.untunedf[~mask]
-        logger.info("untuned shapes:")
-        print(self.remain_untuned)
+
+        if not self.remain_untuned.empty:
+            logger.info("untuned shapes:")
+            print(self.remain_untuned)
 
     @abstractmethod
     def result_to_csv(self, results, file, concat=False):
diff --git a/hsa/gfx942/fmoe_2stages/tune.py b/hsa/gfx942/fmoe_2stages/tune.py
index 16d2f51441..a7fe782343 100644
--- a/hsa/gfx942/fmoe_2stages/tune.py
+++ b/hsa/gfx942/fmoe_2stages/tune.py
@@ -26,10 +26,10 @@
 from aiter import dtypes
 from aiter import ActivationType as ActivationType
 from aiter.jit.utils.chip_info import get_gfx
-from aiter.utility import fp4_utils
 import torch.nn.functional as F
 from einops import rearrange
 from aiter.utility.base_tuner import TunerCommon
+from aiter.utility import fp4_utils
 from aiter.utility.fp4_utils import moe_mxfp4_sort
 
 
@@ -134,15 +134,6 @@ def ck_moe_stage1_fwd_out(
     ):
         inter_dim = w1_qt_shffle_ck.shape[1] // 2
         token_num = a1_qt.shape[0]
-
-        a1_scale = moe_mxfp4_sort(
-            a1_scale,
-            sorted_ids=sorted_ids,
-            num_valid_ids=num_valid_ids,
-            token_num=token_num,
-            block_size=blockM,
-        )
-
         out = torch.empty(
             (token_num, topk, inter_dim),
             dtype=dtype,
@@ -158,7 +149,7 @@
             out,
             topk,
             kernelName,
-            fp4_utils.e8m0_shuffle(w1_scale),
+            w1_scale,
             a1_scale,
             blockM,
             sorted_weights,
@@ -195,14 +186,6 @@ def ck_moe_stage2_fwd_out(
         model_dim = w2_qt_shffle_ck.shape[1]
         token_num = a2_qt.shape[0]
 
-        a2_scale = moe_mxfp4_sort(
-            a2_scale[: token_num * topk, :].view(token_num, topk, -1),
-            sorted_ids=sorted_ids,
-            num_valid_ids=num_valid_ids,
-            token_num=token_num,
-            block_size=blockM,
-        )
-
         out = torch.zeros(
             (token_num, model_dim),
             dtype=dtype,
@@ -218,9 +201,7 @@
             out,
             topk,
             kernelName,
-            fp4_utils.e8m0_shuffle(
-                w2_scale
-            ),  # e8m0_shuffle will do nothing if it's a fp32
+            w2_scale,
             a2_scale,
             blockM,
             sorted_weights,
@@ -415,14 +396,13 @@ def get_1stage_fmoe_func(
             quant_type == QuantType.No
            and activation == ActivationType.Silu
             and not isG1U1
-            or doweight_stage1
             or quant_type == QuantType.per_1x32
         ):
-            print("not support No Quant Silu G1U0 1 stage tuning!")
+            print("1-stage tuning does not support No Quant Silu G1U0 or per_1x32 quant!")
         else:
             if quant_type == QuantType.per_1x128:
                 fmoe_func = FmoeTuner.run_1stage_fmoe_fp8_blockscale_g1u1
-            elif (q_dtype_a == dtypes.fp8) & doweight_stage1:
+            elif (q_dtype_a == dtypes.fp8) and doweight_stage1:
                 fmoe_func = FmoeTuner.run_1stage_fmoe_g1u1_tkw1
             elif isG1U1:
                 fmoe_func = FmoeTuner.run_1stage_fmoe_g1u1
@@ -467,13 +447,24 @@ def generate_data(
             )
             a1_qt = a1_qt.view(token, model_dim)
             a1_scale = a1_scale.squeeze(-1)
+        elif (
+            q_type == aiter.QuantType.per_1x32
+            and (q_dtype_a in [dtypes.bf16, dtypes.fp16])
+            and q_dtype_w == dtypes.fp4x2
+        ):  # a16w4
+            a1_qt = input.to(dtype)
+            a1_scale = None
         else:
             torch_quant = aiter.get_torch_quant(q_type)
             a1_qt, a1_scale = torch_quant(input, quant_dtype=q_dtype_a)
         del w1, w2, score
+        if q_dtype_w is not dtypes.fp4x2:
+            w1_qt_shffle = shuffle_weight(w1_qt, (16, 16))
+            w2_qt_shffle = shuffle_weight(w2_qt, (16, 16))
+        else:
+            w1_qt_shffle = w1_qt
+            w2_qt_shffle = w2_qt
-        w1_qt_shffle = shuffle_weight(w1_qt, (16, 16))
-        w2_qt_shffle = shuffle_weight(w2_qt, (16, 16))
 
         sorted_ids, sorted_weights, sorted_expert_ids, num_valid_ids, moe_buf = (
             moe_sorting(topk_ids, topk_weights, expert, model_dim, dtype, blockM)
         )
@@ -639,24 +630,39 @@ def generate_data_2stages(
         else:
             w1_qt_shffle_ck = w1_qt_shffle
             w2_qt_shffle_ck = w2_qt_shffle
+        w1_scale_aiter = fp4_utils.e8m0_shuffle(w1_scale)
+        w2_scale_aiter = fp4_utils.e8m0_shuffle(w2_scale)
         if stage == 1:
             if not doweight_stage1:
                 sorted_weights = None
+            if q_type == QuantType.per_1x32:
+                a1_scale_fp4_sort = moe_mxfp4_sort(
+                    a1_scale,  # a1_scale[: token * topk, :].view(token, topk, -1),
+                    sorted_ids=sorted_ids,
+                    num_valid_ids=num_valid_ids,
+                    token_num=token,
+                    block_size=blockM,
+                )
+            else:
+                a1_scale_fp4_sort = a1_scale
+
             return (
-                a1_qt,
-                w1_qt_shffle_ck,
-                w2_qt_shffle_ck,
-                a1_scale,
-                w1_scale,
-                sorted_ids,
-                sorted_expert_ids,
-                sorted_weights,
-                num_valid_ids,
-                moe_buf,
-                w1_qt,
-                w2_qt,
-                topk_weights,
-                topk_ids,
+                a1_qt,  # 0
+                w1_qt_shffle_ck,  # 1
+                w2_qt_shffle_ck,  # 2
+                a1_scale,  # 3
+                w1_scale,  # 4
+                sorted_ids,  # 5
+                sorted_expert_ids,  # 6
+                sorted_weights,  # 7
+                num_valid_ids,  # 8
+                moe_buf,  # 9
+                w1_qt,  # 10
+                w2_qt,  # 11
+                topk_weights,  # 12
+                topk_ids,  # 13
+                a1_scale_fp4_sort,  # 14
+                w1_scale_aiter,  # 15
             )
         elif stage == 2:
             ref1 = FmoeTuner.run_torch_moe_stage1(
@@ -687,21 +693,33 @@ def generate_data_2stages(
             a2_qt = a2_qt.view(token, topk, -1)
             if doweight_stage1:
                 sorted_weights = None
+            if q_type == QuantType.per_1x32:
+                a2_scale_mxfp4_sort = moe_mxfp4_sort(
+                    a2_scale[: token * topk, :].view(token, topk, -1),
+                    sorted_ids=sorted_ids,
+                    num_valid_ids=num_valid_ids,
+                    token_num=token,
+                    block_size=blockM,
+                )
+            else:
+                a2_scale_mxfp4_sort = a2_scale
             return (
-                a2_qt,
-                w1_qt_shffle_ck,
-                w2_qt_shffle_ck,
-                a2_scale,
-                w2_scale,
-                sorted_ids,
-                sorted_expert_ids,
-                sorted_weights,
-                num_valid_ids,
-                moe_buf,
-                w1_qt,
-                w2_qt,
-                topk_weights,
-                topk_ids,
+                a2_qt,  # 0
+                w1_qt_shffle_ck,  # 1
+                w2_qt_shffle_ck,  # 2
+                a2_scale,  # 3
+                w2_scale,  # 4
+                sorted_ids,  # 5
+                sorted_expert_ids,  # 6
+                sorted_weights,  # 7
+                num_valid_ids,  # 8
+                moe_buf,  # 9
+                w1_qt,  # 10
+                w2_qt,  # 11
+                topk_weights,  # 12
+                topk_ids,  # 13
+                a2_scale_mxfp4_sort,  # 14
+                w2_scale_aiter,  # 15
             )
 
     @staticmethod
@@ -766,7 +784,7 @@ def generate_data_1stage(
         fc1_smooth_scale = None
         fc2_smooth_scale = None
         if q_type == QuantType.per_1x32:
-            a1_scale = fp4_utils.moe_mxfp4_sort(
+            a1_scale = moe_mxfp4_sort(
                 a1_scale,
                 sorted_ids,
                 num_valid_ids,
@@ -1268,8 +1286,6 @@ def get_1stage_file_info(self, q_type, q_dtype_a, doweight_stage1):
         quantDtype = ""
         if doweight_stage1:
             extraInfo_1stage = "_tkw1"
-            if q_dtype_a == dtypes.fp8:
-                quantDtype = "Int8"  ## tmp solution, need to be updated
         if q_type == QuantType.No:
             quantDtype_1stage = "noquant"
         elif q_type == QuantType.per_1x128:
@@ -1559,7 +1575,7 @@ def gen_2stages_task(self, key, blockMs):
             ),
             FmoeTuner.ck_moe_stage1_fwd_out,  # func
             (
-                [0, 1, 2, 5, 6, 7, 8, 4, 3],
+                [0, 1, 2, 5, 6, 7, 8, 15, 14],
                 dtype,
                 topk,
                 kernel.name,
@@ -1570,6 +1586,7 @@
             {},
             FmoeTuner.run_torch_moe_stage1,
             (
+                # [a1_qt, w1_qt, w2_qt, topk_weights, topk_ids, a1_scale, w1_scale]
                 [0, 10, 11, 12, 13, 3, 4],
                 dtype,
                 act_type,
@@ -1610,7 +1627,7 @@
             ),
             FmoeTuner.ck_moe_stage2_fwd_out,  # func
             (
-                [0, 1, 2, 5, 6, 7, 8, 4, 3],
+                [0, 1, 2, 5, 6, 7, 8, 15, 14],
                 dtype,
                 topk,
                 kernel.name,
diff --git a/hsa/gfx950/fmoe_2stages/tune.py b/hsa/gfx950/fmoe_2stages/tune.py
index c97a75cd74..560dbebc2e 100644
--- a/hsa/gfx950/fmoe_2stages/tune.py
+++ b/hsa/gfx950/fmoe_2stages/tune.py
@@ -57,8 +57,6 @@ def get_1stage_file_info(self, q_type, q_dtype_a, doweight_stage1):
         quantDtype = ""
         if doweight_stage1:
            extraInfo_1stage = "_tkw1"
-            if q_dtype_a == dtypes.fp8:
-                quantDtype = "Int8"  ## tmp solution, need to be updated
         if q_type == QuantType.No:
             quantDtype_1stage = "noquant"
         elif q_type == QuantType.per_1x128:
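
Note (not part of the patch): the new TunerCommon.update_config_files above merges os.pathsep-separated tuned-config CSVs and, when the same tuning key appears in several files, keeps the row with the lowest "us" (microseconds). A minimal standalone sketch of that merge-and-dedup behavior follows; the file names and the "keys" list are illustrative stand-ins for self.keys, not values from the patch.

import os
import pandas as pd

# Two toy tuned-config files with identical columns; kern_b is faster.
keys = ["M", "N", "K"]  # stand-in for self.keys (shape-identifying columns)
pd.DataFrame(
    {"M": [64], "N": [128], "K": [256], "kernelName": ["kern_a"], "us": [12.0]}
).to_csv("/tmp/cfg1.csv", index=False)
pd.DataFrame(
    {"M": [64], "N": [128], "K": [256], "kernelName": ["kern_b"], "us": [9.5]}
).to_csv("/tmp/cfg2.csv", index=False)

# Same path handling as the patch: an os.pathsep-separated list, e.g.
# AITER_CONFIG_GEMM_A4W4="/tmp/cfg1.csv:/tmp/cfg2.csv" on Linux.
paths = "/tmp/cfg1.csv:/tmp/cfg2.csv".split(os.pathsep)
dfs = [pd.read_csv(p) for p in paths if os.path.exists(p)]

# Sort by latency, then keep the first (fastest) record per key.
merged = (
    pd.concat(dfs, ignore_index=True)
    .sort_values("us")
    .drop_duplicates(subset=keys, keep="first")
    .reset_index(drop=True)
)
print(merged)  # one row: kernelName == "kern_b", us == 9.5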