2 changes: 1 addition & 1 deletion aiter/fused_moe.py

@@ -645,7 +645,7 @@ def FinalFunc():
             doweight_stage1,
         ) in fused_moe_1stage_dict[get_gfx()]:
             if q_type == QuantType.per_1x128:
-                run_1stage = True and (inter_dim % 256 == 0)
+                run_1stage = True and (inter_dim % 128 == 0)
             elif q_type == QuantType.per_Token and q_dtype_w == dtypes.i8:
                 run_1stage = token > 32
             elif q_type == QuantType.per_Token and q_dtype_w == dtypes.fp8:
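For context: the per_1x128 quant path now takes the fused 1-stage route whenever inter_dim is a multiple of 128 instead of 256, matching the 32x128 kernel tiles registered later in this change. A minimal, self-contained sketch of that dispatch rule (the QuantType values mirror the diff; the helper itself is hypothetical):

```python
# Sketch of the relaxed 1-stage dispatch rule; only the QuantType names
# and the modulus come from the diff, the rest is illustrative.
from enum import Enum, auto

class QuantType(Enum):
    per_1x128 = auto()
    per_Token = auto()

def can_run_1stage(q_type: QuantType, inter_dim: int) -> bool:
    if q_type == QuantType.per_1x128:
        # Previously required inter_dim % 256 == 0; the new 32x128
        # kernels let any multiple of 128 take the 1-stage path.
        return inter_dim % 128 == 0
    return False

assert can_run_1stage(QuantType.per_1x128, 1408)       # 11 * 128, newly eligible
assert not can_run_1stage(QuantType.per_1x128, 192)    # not a multiple of 128
```

Shapes like 1408 that are multiples of 128 but not of 256 previously fell back to the 2-stage path; that is the case this relaxation targets.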
18 changes: 9 additions & 9 deletions csrc/py_itfs_cu/asm_fmoe.cu

@@ -260,8 +260,8 @@ FMoeKernel* get_heuristic_kernel(
     uint32_t tg_num = 0;
     uint32_t num_persistent_tgs = 0;
     uint32_t round = 0xffffffff;
-    std::string arch_id = get_gpu_arch();
-    std::string selectedKl = kernel_name.empty() ? "" : arch_id + kernel_name;
+    std::string arch_id = get_gpu_arch();
+    std::string selectedKl = kernel_name.empty() ? "" : arch_id + kernel_name;
     int vskip = 1;
     static std::unordered_map<std::string, std::unique_ptr<FMoeKernel>> impl_ptr_map;

@@ -272,8 +272,8 @@ FMoeKernel* get_heuristic_kernel(
     {
         for(const auto& el : *cfgs)
         {
-            if (el.first.find(arch_id) != 0)
-                continue;
+            if(el.first.find(arch_id) != 0)
+                continue;
             const auto& cfg = el.second;
             if(cfg.vskip == vskip && cfg.smf == smf)
             {

@@ -675,8 +675,8 @@ void fmoe_g1u1_tkw1(torch::Tensor& out, // [token_cnt, dim]
     const int token_cnt = input.size(0);
     const int block_m = 32; // fmoe sorting kernel and fmoe kernel only support 32 for now
     const int estimated_sub_X_cnt = (token_cnt * topk + block_m - 1) / block_m;
-    int model_dim = down.size(1);
-    int inter_dim = down.size(2);
+    int model_dim = down.size(1);
+    int inter_dim = down.size(2);
     inter_dim *= model_dim / gate.size(2);

@@ -839,7 +839,7 @@ void fmoe_fp8_blockscale_g1u1(torch::Tensor& out, // [token_cnt, d
     int sub_X_cnt = sorted_expert_ids.size(0);
     const char* enable_vskip = std::getenv("AITER_ENABLE_VSKIP");

-    if(out.dtype() == at::ScalarType::BFloat16 && inter_dim % 256 == 0 && fc_scale_blkn == 128 &&
+    if(out.dtype() == at::ScalarType::BFloat16 && inter_dim % 128 == 0 && fc_scale_blkn == 128 &&
        fc_scale_blkk == 128)
     {
         if(activation == ActivationType::Silu)

@@ -850,8 +850,8 @@ void fmoe_fp8_blockscale_g1u1(torch::Tensor& out, // [token_cnt, d
             TORCH_CHECK(
                 false, __func__, "Unsupported activation type for fmoe_fp8_blockscale_g1u1");

-        impl_ptr = get_heuristic_kernel(inter_dim, sorted_expert_ids.size(0), config_map, 0, kernel_name);
-
+        impl_ptr =
+            get_heuristic_kernel(inter_dim, sorted_expert_ids.size(0), config_map, 0, kernel_name);
         impl_ptr->launch_kernel<uint8_t, uint16_t, false>(out,
                                                           input,
                                                           gate,
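The selector in get_heuristic_kernel keys its config map by GPU arch plus kernel name and skips entries whose key does not start with the current arch id (the `el.first.find(arch_id) != 0` check above). A rough Python sketch of that filtering, with config fields taken from the CSV columns below; the tie-breaking preference shown here is an assumption, not the kernel's actual scoring:

```python
# Hypothetical model of the arch-prefixed config lookup in
# get_heuristic_kernel; field names follow the CSV schema below.
from dataclasses import dataclass
from typing import Optional

@dataclass
class KernelCfg:
    co_name: str
    vskip: int
    smf: int
    subGU_n: int  # tile width along inter_dim: 128 or 256

def pick_kernel(cfgs: dict[str, KernelCfg], arch_id: str,
                inter_dim: int, vskip: int, smf: int) -> Optional[KernelCfg]:
    best = None
    for name, cfg in cfgs.items():
        if not name.startswith(arch_id):    # mirrors el.first.find(arch_id) != 0
            continue
        if cfg.vskip != vskip or cfg.smf != smf:
            continue
        if inter_dim % cfg.subGU_n != 0:    # tile must evenly divide inter_dim
            continue
        # Assumed preference: take the widest tile that still divides inter_dim.
        if best is None or cfg.subGU_n > best.subGU_n:
            best = cfg
    return best
```

With both 32x128 and 32x256 rows registered, an inter_dim of 1408 now finds a 32x128 candidate instead of failing the divisibility check outright.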
4 changes: 4 additions & 0 deletions hsa/gfx942/fmoe/gelu/fmoe_bf16_blockscaleFp8_g1u1_gelu.csv

@@ -1,5 +1,9 @@
 knl_name,co_name,atm,vskip,smf,tg_num_perCU,ps,subGU_m,subGU_n
+_ZN5aiter47fmoe_bf16_blockscaleFp8_g1u1_vs_gelu_1tg_32x128E,fmoe_bf16_blockscaleFp8_g1u1_vs_gelu_1tg_32x128.co,0,1,0,1,0,32,128
 _ZN5aiter47fmoe_bf16_blockscaleFp8_g1u1_vs_gelu_1tg_32x256E,fmoe_bf16_blockscaleFp8_g1u1_vs_gelu_1tg_32x256.co,0,1,0,1,0,32,256
+_ZN5aiter50fmoe_bf16_blockscaleFp8_g1u1_vs_gelu_1tg_ps_32x128E,fmoe_bf16_blockscaleFp8_g1u1_vs_gelu_1tg_ps_32x128.co,0,1,0,1,1,32,128
 _ZN5aiter50fmoe_bf16_blockscaleFp8_g1u1_vs_gelu_1tg_ps_32x256E,fmoe_bf16_blockscaleFp8_g1u1_vs_gelu_1tg_ps_32x256.co,0,1,0,1,1,32,256
+_ZN5aiter49fmoe_bf16_blockscaleFp8_g1u1_novs_gelu_1tg_32x128E,fmoe_bf16_blockscaleFp8_g1u1_novs_gelu_1tg_32x128.co,0,0,0,1,0,32,128
 _ZN5aiter49fmoe_bf16_blockscaleFp8_g1u1_novs_gelu_1tg_32x256E,fmoe_bf16_blockscaleFp8_g1u1_novs_gelu_1tg_32x256.co,0,0,0,1,0,32,256
+_ZN5aiter52fmoe_bf16_blockscaleFp8_g1u1_novs_gelu_1tg_ps_32x128E,fmoe_bf16_blockscaleFp8_g1u1_novs_gelu_1tg_ps_32x128.co,0,0,0,1,1,32,128
 _ZN5aiter52fmoe_bf16_blockscaleFp8_g1u1_novs_gelu_1tg_ps_32x256E,fmoe_bf16_blockscaleFp8_g1u1_novs_gelu_1tg_ps_32x256.co,0,0,0,1,1,32,256
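Each row of this table binds a mangled kernel symbol (knl_name) to its code object (co_name) and tile shape (subGU_m x subGU_n); the four new rows register the 32x128 variants alongside the existing 32x256 ones. A small sketch of reading such a table with the standard csv module (the helper name and dict layout are hypothetical):

```python
# Hypothetical reader for the kernel-config CSV above; column names come
# straight from the header row.
import csv

def load_fmoe_cfgs(path: str) -> dict[str, dict]:
    cfgs = {}
    with open(path, newline="") as f:
        for row in csv.DictReader(f):
            cfgs[row["knl_name"]] = {
                "co_name": row["co_name"],
                "vskip": int(row["vskip"]),
                "smf": int(row["smf"]),
                "ps": int(row["ps"]),
                "tile": (int(row["subGU_m"]), int(row["subGU_n"])),
            }
    return cfgs

# e.g. load_fmoe_cfgs("hsa/gfx942/fmoe/gelu/fmoe_bf16_blockscaleFp8_g1u1_gelu.csv")
# yields both 32x128 and 32x256 tile variants after this change.
```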
Binary file not shown (×4).
Binary file modified hsa/gfx942/fmoe/gelu/fmoe_bf16_pertokenFp8_g1u1_vs_gelu_1tg_32x128.co (mode 100755 → 100644)
Binary file not shown (×3).
4 changes: 4 additions & 0 deletions hsa/gfx942/fmoe/silu/fmoe_bf16_blockscaleFp8_g1u1_silu.csv

@@ -1,5 +1,9 @@
 knl_name,co_name,atm,vskip,smf,tg_num_perCU,ps,subGU_m,subGU_n
+_ZN5aiter47fmoe_bf16_blockscaleFp8_g1u1_vs_silu_1tg_32x128E,fmoe_bf16_blockscaleFp8_g1u1_vs_silu_1tg_32x128.co,0,1,0,1,0,32,128
 _ZN5aiter47fmoe_bf16_blockscaleFp8_g1u1_vs_silu_1tg_32x256E,fmoe_bf16_blockscaleFp8_g1u1_vs_silu_1tg_32x256.co,0,1,0,1,0,32,256
+_ZN5aiter52fmoe_bf16_blockscaleFp8_g1u1_novs_silu_1tg_ps_32x128E,fmoe_bf16_blockscaleFp8_g1u1_novs_silu_1tg_ps_32x128.co,0,0,0,1,1,32,128
 _ZN5aiter52fmoe_bf16_blockscaleFp8_g1u1_novs_silu_1tg_ps_32x256E,fmoe_bf16_blockscaleFp8_g1u1_novs_silu_1tg_ps_32x256.co,0,0,0,1,1,32,256
+_ZN5aiter50fmoe_bf16_blockscaleFp8_g1u1_vs_silu_1tg_ps_32x128E,fmoe_bf16_blockscaleFp8_g1u1_vs_silu_1tg_ps_32x128.co,0,1,0,1,1,32,128
 _ZN5aiter50fmoe_bf16_blockscaleFp8_g1u1_vs_silu_1tg_ps_32x256E,fmoe_bf16_blockscaleFp8_g1u1_vs_silu_1tg_ps_32x256.co,0,1,0,1,1,32,256
+_ZN5aiter49fmoe_bf16_blockscaleFp8_g1u1_novs_silu_1tg_32x128E,fmoe_bf16_blockscaleFp8_g1u1_novs_silu_1tg_32x128.co,0,0,0,1,0,32,128
 _ZN5aiter49fmoe_bf16_blockscaleFp8_g1u1_novs_silu_1tg_32x256E,fmoe_bf16_blockscaleFp8_g1u1_novs_silu_1tg_32x256.co,0,0,0,1,0,32,256
Binary file not shown (×2).
Binary file modified hsa/gfx942/fmoe/silu/fmoe_bf16_pertokenFp8_g1u1_vs_silu_1tg_32x128.co (mode 100755 → 100644)
Binary file not shown (×1).