diff --git a/.github/workflows/pre-checks.yaml b/.github/workflows/pre-checks.yaml
index 6bc1cd6b9a..f5dc4eff35 100644
--- a/.github/workflows/pre-checks.yaml
+++ b/.github/workflows/pre-checks.yaml
@@ -35,7 +35,7 @@ jobs:
       - name: Checkout code
         uses: actions/checkout@v4
       - name: Set up Python environment
-        uses: actions/setup-python@v2
+        uses: actions/setup-python@v4
         with:
           python-version: "3.12"
       - name: Install dependencies
@@ -46,7 +46,16 @@ jobs:
         env:
           REVIEWDOG_GITHUB_API_TOKEN: ${{ secrets.GITHUB_TOKEN }}
         run: |
-          ruff check . -e | reviewdog -efm="%f:%l:%c: %m" -diff="git diff FETCH_HEAD" -reporter=github-pr-check -tee
+          ruff check . \
+            --output-format=rdjson \
+            --exit-zero \
+            --no-fix \
+            | reviewdog \
+            -f=rdjson \
+            -name="ruff" \
+            -reporter=github-pr-review \
+            -filter-mode=diff_context \
+            -fail-on-error=true

   upload-success-artifact:
     name: Upload Success Signal
diff --git a/.github/workflows/triton-test.yaml b/.github/workflows/triton-test.yaml
index 1aa47ac333..0bfb935143 100644
--- a/.github/workflows/triton-test.yaml
+++ b/.github/workflows/triton-test.yaml
@@ -29,7 +29,7 @@ jobs:
       GITHUB_SHA: ${{ github.sha }}

   triton:
-    runs-on: aiter-mi300-1gpu
+    runs-on: aiter-1gpu-runner
     needs: [check-signal]
     env:
       DOCKER_IMAGE: "rocm/pytorch:latest"
diff --git a/3rdparty/composable_kernel b/3rdparty/composable_kernel
index 573ace94ac..9bd67c2cf2 160000
--- a/3rdparty/composable_kernel
+++ b/3rdparty/composable_kernel
@@ -1 +1 @@
-Subproject commit 573ace94aca3b5797d1536670a42052f1a291d48
+Subproject commit 9bd67c2cf2fe8e4479a433bcd6d467e2ea9aedb4
diff --git a/MANIFEST.in b/MANIFEST.in
index 0d8f6f5d5d..de0ff3c02c 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -1,2 +1,6 @@
 graft aiter
-graft aiter_meta
\ No newline at end of file
+graft aiter_meta
+
+# exclude cache and compiled files (.pyc / .pyo / .pyd)
+global-exclude *.py[cod]
+prune aiter/jit/build
\ No newline at end of file
diff --git a/aiter/aot/sampling.py b/aiter/aot/sampling.py
new file mode 100644
index 0000000000..0d758d5d83
--- /dev/null
+++ b/aiter/aot/sampling.py
@@ -0,0 +1,89 @@
+from collections import namedtuple
+import os
+import concurrent.futures
+from csrc.cpp_itfs.sampling.top_k_renorm_probs import (
+    compile as top_k_renorm_probs_compile,
+)
+from csrc.cpp_itfs.sampling.top_p_sampling_from_probs import (
+    compile as top_p_sampling_from_probs_compile,
+)
+from csrc.cpp_itfs.sampling.top_k_top_p_sampling_from_probs import (
+    compile as top_k_top_p_sampling_from_probs_compile,
+)
+
+TopKRenormConfig = namedtuple(
+    "TopKRenormConfig",
+    ["vec_size", "func_name"],
+)
+
+TopPSamplingConfig = namedtuple(
+    "TopPSamplingConfig",
+    ["vec_size", "deterministic", "func_name"],
+)
+
+TopKTopPSamplingConfig = namedtuple(
+    "TopKTopPSamplingConfig",
+    ["vec_size", "deterministic", "func_name"],
+)
+
+
+def process_top_k_renorm_config(config):
+    return top_k_renorm_probs_compile(config.vec_size)
+
+
+def process_top_p_sampling_config(config):
+    return top_p_sampling_from_probs_compile(config.vec_size, config.deterministic)
+
+
+def process_top_k_top_p_sampling_config(config):
+    return top_k_top_p_sampling_from_probs_compile(
+        config.vec_size, config.deterministic
+    )
+
+
+def main():
+    # Generate configs for top_k_renorm_probs
+    top_k_renorm_configs = []
+    for vec_size in range(1, 5):
+        top_k_renorm_configs.append(
+            TopKRenormConfig(
+                vec_size=vec_size,
+                func_name="top_k_renorm_probs",
+            )
+        )
+
+    # Generate configs for top_p_sampling_from_probs
+    top_p_sampling_configs = []
+    for vec_size in range(1, 5):
+        for deterministic in [False, True]:
+            top_p_sampling_configs.append(
+                TopPSamplingConfig(
+                    vec_size=vec_size,
+                    deterministic=deterministic,
+                    func_name="top_p_sampling_from_probs",
+                )
+            )
+
+    # Generate configs for top_k_top_p_sampling_from_probs
+    top_k_top_p_sampling_configs = []
+    for vec_size in range(1, 5):
+        for deterministic in [False, True]:
+            top_k_top_p_sampling_configs.append(
+                TopKTopPSamplingConfig(
+                    vec_size=vec_size,
+                    deterministic=deterministic,
+                    func_name="top_k_top_p_sampling_from_probs",
+                )
+            )
+
+    max_jobs = int(os.environ.get("MAX_JOBS", os.cpu_count() or 16))
+
+    # Process all configs in parallel
+    with concurrent.futures.ProcessPoolExecutor(max_workers=max_jobs) as executor:
+        executor.map(process_top_k_renorm_config, top_k_renorm_configs)
+        executor.map(process_top_p_sampling_config, top_p_sampling_configs)
+        executor.map(process_top_k_top_p_sampling_config, top_k_top_p_sampling_configs)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/aiter/configs/a4w4_blockscale_tuned_gemm.csv b/aiter/configs/a4w4_blockscale_tuned_gemm.csv
index 51a05157f0..3988c91f18 100644
--- a/aiter/configs/a4w4_blockscale_tuned_gemm.csv
+++ b/aiter/configs/a4w4_blockscale_tuned_gemm.csv
@@ -921,3 +921,5 @@ cu_num,M,N,K,kernelId,splitK,us,kernelName,tflops,bw,errRatio
 256,8,3072,1536,42,0,5.4682,_ZN5aiter42f4gemm_bf16_per1x32Fp4_BpreShuffle_128x128E,13.81,441.57,0.0
 256,8,7168,2048,29,0,5.836,_ZN5aiter41f4gemm_bf16_per1x32Fp4_BpreShuffle_64x128E,40.25,1278.77,0.0
 256,8,512,7168,29,0,9.6677,_ZN5aiter41f4gemm_bf16_per1x32Fp4_BpreShuffle_64x128E,6.07,193.62,0.0
+256,32768,2112,7168,48,0,293.0219,_ZN5aiter42f4gemm_bf16_per1x32Fp4_BpreShuffle_160x384E,3385.88,898.98,0.0
+256,65536,2112,7168,48,0,575.6528,_ZN5aiter42f4gemm_bf16_per1x32Fp4_BpreShuffle_160x384E,3447.0,902.06,0.0
diff --git a/aiter/configs/a4w4_blockscale_untuned_gemm.csv b/aiter/configs/a4w4_blockscale_untuned_gemm.csv
index 3c91c37b07..e78f1eb3f5 100644
--- a/aiter/configs/a4w4_blockscale_untuned_gemm.csv
+++ b/aiter/configs/a4w4_blockscale_untuned_gemm.csv
@@ -193,3 +193,5 @@ M,N,K
 3000, 7168, 2048
 3000, 512, 7168
 60000, 4096, 512
+32768, 2112, 7168
+65536, 2112, 7168
diff --git a/aiter/configs/tuned_fmoe.csv b/aiter/configs/tuned_fmoe.csv
index 275cd8944e..1ce76a65be 100644
--- a/aiter/configs/tuned_fmoe.csv
+++ b/aiter/configs/tuned_fmoe.csv
@@ -4,6 +4,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w,
 80,4,2304,1536,8,2,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,32,0,17.6606,_ZN5aiter47fmoe_stage1_bf16_pertokenFp8_g1u1_32x64_4tg_pf3E,0.0%,15.126,moe_ck2stages_gemm2_256x32x64x256_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight1_F8_F8_B16,0.3%,32.7866,0,5.18,2591.37
 80,4,2304,1536,8,2,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,1,32,0,17.8008,_ZN5aiter56fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x64_4tg_pf3E,0.0%,14.5115,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,32.3123,0,5.26,2629.41
 80,512,6144,4096,8,2,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,64,0,774.6328,moe_ck2stages_gemm1_256x64x128x64_1x4_TypeCast_v3_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,459.0113,moe_ck2stages_gemm2_256x64x128x128_1x4_TypeCastExpertWeight_v3_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,2.3%,1233.6441,0,125.34,989.38
+256,1,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,130.4639,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,70.3202,moe_ck2stages_gemm2_256x32x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,1.7%,200.7841,0,7.02,14040.11 +256,2,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,130.4639,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,70.3202,moe_ck2stages_gemm2_256x32x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,1.7%,200.7841,0,7.02,14040.11 +256,4,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,130.4639,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,70.3202,moe_ck2stages_gemm2_256x32x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,1.7%,200.7841,0,7.02,14040.11 +256,8,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,130.4639,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,70.3202,moe_ck2stages_gemm2_256x32x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,1.7%,200.7841,0,7.02,14040.11 256,16,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,130.4639,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,70.3202,moe_ck2stages_gemm2_256x32x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,1.7%,200.7841,0,7.02,14040.11 256,32,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,64,0,195.38,moe_ck2stages_gemm1_256x64x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,107.5659,moe_ck2stages_gemm2_256x64x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,1.7%,302.9459,0,9.3,9306.91 256,64,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,278.093,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,140.8376,moe_ck2stages_gemm2_256x32x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,1.7%,418.9306,0,13.46,6732.4 @@ -11,6 +15,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 256,256,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,64,0,306.0006,moe_ck2stages_gemm1_256x64x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,170.2105,moe_ck2stages_gemm2_256x64x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,1.7%,476.2111,0,47.35,5934.16 256,512,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,64,0,309.2402,moe_ck2stages_gemm1_256x64x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,184.9719,moe_ck2stages_gemm2_256x64x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,1.6%,494.2121,0,91.25,5732.87 
256,1024,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,64,0,325.0568,moe_ck2stages_gemm1_256x64x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,231.4032,moe_ck2stages_gemm2_256x64x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,1.6%,556.46,0,162.09,5117.95 +256,1,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,64,0,128.2525,moe_ck2stages_gemm1_256x64x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,72.0121,moe_ck2stages_gemm2_256x64x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,200.2646,0,7.04,14076.53 +256,2,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,64,0,128.2525,moe_ck2stages_gemm1_256x64x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,72.0121,moe_ck2stages_gemm2_256x64x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,200.2646,0,7.04,14076.53 +256,4,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,64,0,128.2525,moe_ck2stages_gemm1_256x64x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,72.0121,moe_ck2stages_gemm2_256x64x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,200.2646,0,7.04,14076.53 +256,8,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,64,0,128.2525,moe_ck2stages_gemm1_256x64x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,72.0121,moe_ck2stages_gemm2_256x64x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,200.2646,0,7.04,14076.53 256,16,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,64,0,128.2525,moe_ck2stages_gemm1_256x64x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,72.0121,moe_ck2stages_gemm2_256x64x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,200.2646,0,7.04,14076.53 256,32,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,64,0,195.9999,moe_ck2stages_gemm1_256x64x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,102.7882,moe_ck2stages_gemm2_256x64x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,298.7881,0,9.43,9436.42 256,64,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,64,0,277.4499,moe_ck2stages_gemm1_256x64x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,139.0861,moe_ck2stages_gemm2_256x64x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,416.536,0,13.53,6771.1 @@ -18,6 +26,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 256,256,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,64,0,306.2672,moe_ck2stages_gemm1_256x64x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,164.6962,moe_ck2stages_gemm2_256x64x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,470.9634,0,47.88,6000.28 
256,512,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,64,0,309.6434,moe_ck2stages_gemm1_256x64x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,178.4363,moe_ck2stages_gemm2_256x64x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,488.0797,0,92.4,5804.9 256,1024,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,64,0,325.7872,moe_ck2stages_gemm1_256x64x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,223.4421,moe_ck2stages_gemm2_256x64x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,549.2293,0,164.22,5185.32 +256,1,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,0,32,0,67.4265,moe_ck2stages_gemm1_256x32x64x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,41.189,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,1.6%,108.6155,0,12.98,12978.17 +256,2,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,0,32,0,67.4265,moe_ck2stages_gemm1_256x32x64x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,41.189,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,1.6%,108.6155,0,12.98,12978.17 +256,4,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,0,32,0,67.4265,moe_ck2stages_gemm1_256x32x64x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,41.189,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,1.6%,108.6155,0,12.98,12978.17 +256,8,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,0,32,0,67.4265,moe_ck2stages_gemm1_256x32x64x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,41.189,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,1.6%,108.6155,0,12.98,12978.17 256,16,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,0,32,0,67.4265,moe_ck2stages_gemm1_256x32x64x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,41.189,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,1.6%,108.6155,0,12.98,12978.17 256,32,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,0,32,0,102.7345,moe_ck2stages_gemm1_256x32x64x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,56.8998,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,1.6%,159.6343,0,17.66,8832.53 256,64,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,0,32,0,140.8235,moe_ck2stages_gemm1_256x32x64x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,76.5494,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,1.6%,217.3729,0,25.93,6489.6 @@ -25,6 +37,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 
256,256,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,0,32,0,158.9481,moe_ck2stages_gemm1_256x32x64x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,92.9698,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,1.6%,251.91790000000003,0,89.51,5616.08 256,512,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,0,32,0,161.9427,moe_ck2stages_gemm1_256x32x64x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,114.4508,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,1.6%,276.3935,0,163.16,5138.67 256,1024,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,0,64,0,168.3246,moe_ck2stages_gemm1_256x64x64x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,205.6813,moe_ck2stages_gemm2_256x64x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,1.6%,374.0059,0,241.16,3826.96 +256,1,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,1,32,0,68.0621,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,40.8199,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,108.882,0,12.94,12946.4 +256,2,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,1,32,0,68.0621,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,40.8199,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,108.882,0,12.94,12946.4 +256,4,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,1,32,0,68.0621,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,40.8199,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,108.882,0,12.94,12946.4 +256,8,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,1,32,0,68.0621,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,40.8199,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,108.882,0,12.94,12946.4 256,16,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,1,32,0,68.0621,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,40.8199,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,108.882,0,12.94,12946.4 256,32,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,1,32,0,102.8318,moe_ck2stages_gemm1_256x32x64x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,57.3307,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,160.1625,0,17.6,8803.4 
256,64,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,1,32,0,141.6806,moe_ck2stages_gemm1_256x32x64x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,77.5578,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,219.2384,0,25.71,6434.38 @@ -32,6 +48,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 256,256,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,1,32,0,159.3862,moe_ck2stages_gemm1_256x32x64x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,95.0034,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,254.3896,0,88.64,5561.51 256,512,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,1,32,0,162.5288,moe_ck2stages_gemm1_256x32x64x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,113.9963,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,276.5251,0,163.09,5136.23 256,1024,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,1,64,0,168.5532,moe_ck2stages_gemm1_256x64x64x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,205.3887,moe_ck2stages_gemm2_256x64x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,373.9419,0,241.2,3827.62 +256,1,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,0,32,0,68.6613,moe_ck2stages_gemm1_256x32x64x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,46.3816,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight1_F8_F8_B16,1.7%,115.0429,0,12.25,12253.08 +256,2,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,0,32,0,68.6613,moe_ck2stages_gemm1_256x32x64x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,46.3816,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight1_F8_F8_B16,1.7%,115.0429,0,12.25,12253.08 +256,4,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,0,32,0,68.6613,moe_ck2stages_gemm1_256x32x64x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,46.3816,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight1_F8_F8_B16,1.7%,115.0429,0,12.25,12253.08 +256,8,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,0,32,0,68.6613,moe_ck2stages_gemm1_256x32x64x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,46.3816,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight1_F8_F8_B16,1.7%,115.0429,0,12.25,12253.08 256,16,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,0,32,0,68.6613,moe_ck2stages_gemm1_256x32x64x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,46.3816,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight1_F8_F8_B16,1.7%,115.0429,0,12.25,12253.08 
256,32,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,0,32,0,158.0965,_ZN5aiter45fmoe_bf16_pertokenFp8_g1u1_vs_gelu_1tg_32x256E,0.0%,0.0,Null,0,158.0965,1,17.83,8918.44 256,64,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,0,32,0,215.8536,_ZN5aiter45fmoe_bf16_pertokenFp8_g1u1_vs_gelu_1tg_32x256E,0.0%,0.0,Null,0,215.8536,1,26.12,6535.27 @@ -39,6 +59,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 256,256,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,0,32,0,254.5557,_ZN5aiter45fmoe_bf16_pertokenFp8_g1u1_vs_gelu_1tg_32x256E,0.0%,0.0,Null,0,254.5557,1,88.58,5557.88 256,512,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,0,32,0,267.5654,_ZN5aiter48fmoe_bf16_pertokenFp8_g1u1_vs_gelu_1tg_ps_32x256E,0.0%,0.0,Null,0,267.5654,1,168.55,5308.22 256,1024,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,0,32,0,366.6991,_ZN5aiter45fmoe_bf16_pertokenFp8_g1u1_vs_gelu_1tg_32x256E,0.0%,0.0,Null,0,366.6991,1,245.96,3903.22 +256,1,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,1,32,0,68.3263,moe_ck2stages_gemm1_256x32x64x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,44.1851,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,112.5114,0,12.53,12528.78 +256,2,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,1,32,0,68.3263,moe_ck2stages_gemm1_256x32x64x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,44.1851,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,112.5114,0,12.53,12528.78 +256,4,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,1,32,0,68.3263,moe_ck2stages_gemm1_256x32x64x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,44.1851,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,112.5114,0,12.53,12528.78 +256,8,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,1,32,0,68.3263,moe_ck2stages_gemm1_256x32x64x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,44.1851,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,112.5114,0,12.53,12528.78 256,16,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,1,32,0,68.3263,moe_ck2stages_gemm1_256x32x64x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,44.1851,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,112.5114,0,12.53,12528.78 256,32,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,1,32,0,100.365,_ZN5aiter56fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x64_4tg_pf2E,0.0%,61.0618,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,161.4268,0,17.46,8734.45 
256,64,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,1,32,0,140.407,_ZN5aiter57fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x256_2tg_pf3E,0.0%,78.9057,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,219.3127,0,25.7,6432.2 @@ -46,6 +70,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 256,256,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,1,32,0,159.338,moe_ck2stages_gemm1_256x32x64x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,102.7582,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,262.0962,0,86.03,5397.98 256,512,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,1,32,0,161.3644,_ZN5aiter57fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x128_3tg_pf2E,0.0%,132.3204,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,293.6848,0,153.56,4836.12 256,1024,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,1,64,0,163.9563,_ZN5aiter57fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_64x128_2tg_pf3E,0.0%,218.341,moe_ck2stages_gemm2_256x64x128x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,382.2973,0,235.93,3743.96 +256,1,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,268.7481,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,135.0723,moe_ck2stages_gemm2_256x32x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,3.3%,403.82040000000006,0,6.98,13960.67 +256,2,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,268.7481,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,135.0723,moe_ck2stages_gemm2_256x32x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,3.3%,403.82040000000006,0,6.98,13960.67 +256,4,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,268.7481,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,135.0723,moe_ck2stages_gemm2_256x32x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,3.3%,403.82040000000006,0,6.98,13960.67 +256,8,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,268.7481,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,135.0723,moe_ck2stages_gemm2_256x32x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,3.3%,403.82040000000006,0,6.98,13960.67 256,16,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,268.7481,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,135.0723,moe_ck2stages_gemm2_256x32x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,3.3%,403.82040000000006,0,6.98,13960.67 
256,32,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,378.5195,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,196.1646,moe_ck2stages_gemm2_256x32x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,3.3%,574.6841,0,9.81,9810.72 256,64,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,559.7713,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,271.7302,moe_ck2stages_gemm2_256x32x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,3.3%,831.5015000000001,0,13.56,6781.68 @@ -53,6 +81,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 256,256,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,612.6749,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,322.9055,moe_ck2stages_gemm2_256x32x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,3.3%,935.5804,0,48.2,6033.14 256,512,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,64,0,623.7185,moe_ck2stages_gemm1_256x64x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,338.7751,moe_ck2stages_gemm2_256x64x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,3.3%,962.4936,0,93.71,5872.06 256,1024,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,64,0,649.3028,moe_ck2stages_gemm1_256x64x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,368.4383,moe_ck2stages_gemm2_256x64x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,3.3%,1017.7411,0,177.24,5567.73 +256,1,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,265.8935,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,135.088,moe_ck2stages_gemm2_256x32x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,400.9815,0,7.03,14059.51 +256,2,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,265.8935,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,135.088,moe_ck2stages_gemm2_256x32x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,400.9815,0,7.03,14059.51 +256,4,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,265.8935,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,135.088,moe_ck2stages_gemm2_256x32x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,400.9815,0,7.03,14059.51 +256,8,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,265.8935,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,135.088,moe_ck2stages_gemm2_256x32x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,400.9815,0,7.03,14059.51 
256,16,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,265.8935,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,135.088,moe_ck2stages_gemm2_256x32x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,400.9815,0,7.03,14059.51 256,32,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,376.5017,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,196.4837,moe_ck2stages_gemm2_256x32x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,572.9854,0,9.84,9839.8 256,64,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,556.9744,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,271.6147,moe_ck2stages_gemm2_256x32x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,828.5890999999999,0,13.61,6805.52 @@ -60,6 +92,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 256,256,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,64,0,614.0275,moe_ck2stages_gemm1_256x64x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,317.5052,moe_ck2stages_gemm2_256x64x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,931.5327,0,48.41,6059.35 256,512,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,64,0,624.6592,moe_ck2stages_gemm1_256x64x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,332.6196,moe_ck2stages_gemm2_256x64x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,957.2788,0,94.22,5904.05 256,1024,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,64,0,644.3248,moe_ck2stages_gemm1_256x64x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,363.2348,moe_ck2stages_gemm2_256x64x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,1007.5596,0,179.04,5623.99 +256,1,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,0,32,0,139.2785,moe_ck2stages_gemm1_256x32x64x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,70.4958,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,3.3%,209.7743,0,13.44,13437.85 +256,2,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,0,32,0,139.2785,moe_ck2stages_gemm1_256x32x64x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,70.4958,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,3.3%,209.7743,0,13.44,13437.85 +256,4,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,0,32,0,139.2785,moe_ck2stages_gemm1_256x32x64x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,70.4958,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,3.3%,209.7743,0,13.44,13437.85 
+256,8,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,0,32,0,139.2785,moe_ck2stages_gemm1_256x32x64x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,70.4958,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,3.3%,209.7743,0,13.44,13437.85 256,16,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,0,32,0,139.2785,moe_ck2stages_gemm1_256x32x64x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,70.4958,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,3.3%,209.7743,0,13.44,13437.85 256,32,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,0,32,0,194.034,moe_ck2stages_gemm1_256x32x64x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,100.7957,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,3.4%,294.8297,0,19.12,9562.34 256,64,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,0,32,0,274.3536,_ZN5aiter44fmoe_stage1_bf16_pertokenFp8_g1u1_32x512_pf2E,0.0%,140.3968,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,3.3%,414.7504,0,27.18,6799.15 @@ -67,6 +103,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 256,256,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,0,64,0,314.2275,moe_ck2stages_gemm1_256x64x64x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,169.8802,moe_ck2stages_gemm2_256x64x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,3.3%,484.1077,0,93.16,5833.57 256,512,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,0,32,0,318.189,_ZN5aiter44fmoe_stage1_bf16_pertokenFp8_g1u1_32x512_pf2E,0.0%,183.5436,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,3.3%,501.7326000000001,0,179.77,5639.62 256,1024,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,0,64,0,328.7642,_ZN5aiter48fmoe_stage1_bf16_pertokenFp8_g1u1_64x128_2tg_pf2E,0.0%,226.0569,moe_ck2stages_gemm2_256x64x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,3.3%,554.8211,0,325.13,5119.83 +256,1,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,1,32,0,138.6795,moe_ck2stages_gemm1_256x32x64x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,70.6801,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,209.3596,0,13.46,13464.47 +256,2,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,1,32,0,138.6795,moe_ck2stages_gemm1_256x32x64x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,70.6801,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,209.3596,0,13.46,13464.47 
+256,4,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,1,32,0,138.6795,moe_ck2stages_gemm1_256x32x64x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,70.6801,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,209.3596,0,13.46,13464.47 +256,8,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,1,32,0,138.6795,moe_ck2stages_gemm1_256x32x64x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,70.6801,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,209.3596,0,13.46,13464.47 256,16,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,1,32,0,138.6795,moe_ck2stages_gemm1_256x32x64x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,70.6801,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,209.3596,0,13.46,13464.47 256,32,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,1,32,0,193.8469,moe_ck2stages_gemm1_256x32x64x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,101.2026,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,295.0495,0,19.11,9555.21 256,64,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,1,32,0,277.7873,_ZN5aiter53fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x512_pf2E,0.0%,141.146,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,418.9333,0,26.91,6731.26 @@ -74,6 +114,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 256,256,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,1,32,0,314.5026,moe_ck2stages_gemm1_256x32x64x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,168.2646,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,482.7672,0,93.41,5849.77 256,512,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,1,32,0,318.2151,_ZN5aiter53fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x512_pf3E,0.0%,183.4334,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,501.6485,0,179.8,5640.57 256,1024,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,1,64,0,328.5261,_ZN5aiter57fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_64x128_2tg_pf2E,0.0%,225.2001,moe_ck2stages_gemm2_256x64x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,553.7262,0,325.77,5129.96 +256,1,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,0,32,0,203.1825,_ZN5aiter45fmoe_bf16_pertokenFp8_g1u1_vs_gelu_1tg_32x256E,0.0%,0.0,Null,0,203.1825,1,13.87,13873.81 +256,2,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,0,32,0,203.1825,_ZN5aiter45fmoe_bf16_pertokenFp8_g1u1_vs_gelu_1tg_32x256E,0.0%,0.0,Null,0,203.1825,1,13.87,13873.81 
+256,4,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,0,32,0,203.1825,_ZN5aiter45fmoe_bf16_pertokenFp8_g1u1_vs_gelu_1tg_32x256E,0.0%,0.0,Null,0,203.1825,1,13.87,13873.81 +256,8,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,0,32,0,203.1825,_ZN5aiter45fmoe_bf16_pertokenFp8_g1u1_vs_gelu_1tg_32x256E,0.0%,0.0,Null,0,203.1825,1,13.87,13873.81 256,16,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,0,32,0,203.1825,_ZN5aiter45fmoe_bf16_pertokenFp8_g1u1_vs_gelu_1tg_32x256E,0.0%,0.0,Null,0,203.1825,1,13.87,13873.81 256,32,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,0,32,0,196.4497,_ZN5aiter48fmoe_stage1_bf16_pertokenFp8_g1u1_32x128_3tg_pf2E,0.0%,105.2123,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight1_F8_F8_B16,3.4%,301.66200000000003,0,18.69,9345.76 256,64,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,0,32,0,264.1173,_ZN5aiter44fmoe_stage1_bf16_pertokenFp8_g1u1_32x512_pf2E,0.0%,144.1125,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight1_F8_F8_B16,3.3%,408.2298,0,27.62,6907.75 @@ -81,6 +125,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 256,256,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,0,32,0,482.7665,_ZN5aiter48fmoe_bf16_pertokenFp8_g1u1_vs_gelu_1tg_ps_32x512E,0.0%,0.0,Null,0,482.7665,1,93.41,5849.78 256,512,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,0,32,0,494.6598,_ZN5aiter48fmoe_bf16_pertokenFp8_g1u1_vs_gelu_1tg_ps_32x512E,0.0%,0.0,Null,0,494.6598,1,182.34,5720.26 256,1024,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,0,32,0,333.8711,_ZN5aiter47fmoe_stage1_bf16_pertokenFp8_g1u1_32x64_4tg_pf2E,0.0%,248.6884,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight1_F8_F8_B16,3.3%,582.5595000000001,0,309.65,4876.06 +256,1,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,1,32,0,129.8926,_ZN5aiter57fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x128_3tg_pf2E,0.0%,73.8599,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,203.7525,0,13.83,13835.0 +256,2,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,1,32,0,129.8926,_ZN5aiter57fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x128_3tg_pf2E,0.0%,73.8599,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,203.7525,0,13.83,13835.0 +256,4,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,1,32,0,129.8926,_ZN5aiter57fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x128_3tg_pf2E,0.0%,73.8599,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,203.7525,0,13.83,13835.0 
+256,8,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,1,32,0,129.8926,_ZN5aiter57fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x128_3tg_pf2E,0.0%,73.8599,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,203.7525,0,13.83,13835.0 256,16,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,1,32,0,129.8926,_ZN5aiter57fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x128_3tg_pf2E,0.0%,73.8599,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,203.7525,0,13.83,13835.0 256,32,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,1,32,0,196.3192,_ZN5aiter57fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x128_3tg_pf2E,0.0%,102.4978,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,298.817,0,18.86,9434.74 256,64,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,1,32,0,264.1664,_ZN5aiter53fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x512_pf2E,0.0%,141.5633,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,405.7297,0,27.79,6950.31 @@ -88,6 +136,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 256,256,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,1,32,0,305.4521,_ZN5aiter53fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x512_pf2E,0.0%,172.4236,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,477.8757,0,94.37,5909.65 256,512,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,1,32,0,306.7972,_ZN5aiter53fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x512_pf2E,0.0%,190.6723,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,497.4695,0,181.31,5687.95 256,1024,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,1,32,0,333.2413,_ZN5aiter56fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x64_4tg_pf2E,0.0%,244.2778,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,577.5191,0,312.35,4918.61 +256,1,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,89.5023,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,51.4998,moe_ck2stages_gemm2_256x32x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.5%,141.0021,0,8.57,8568.82 +256,2,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,89.5023,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,51.4998,moe_ck2stages_gemm2_256x32x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.5%,141.0021,0,8.57,8568.82 
+256,4,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,89.5023,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,51.4998,moe_ck2stages_gemm2_256x32x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.5%,141.0021,0,8.57,8568.82 +256,8,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,89.5023,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,51.4998,moe_ck2stages_gemm2_256x32x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.5%,141.0021,0,8.57,8568.82 256,16,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,89.5023,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,51.4998,moe_ck2stages_gemm2_256x32x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.5%,141.0021,0,8.57,8568.82 256,32,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,127.8742,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,68.7529,moe_ck2stages_gemm2_256x32x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.5%,196.6271,0,12.29,6146.07 256,64,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,136.8058,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,75.6377,moe_ck2stages_gemm2_256x32x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.5%,212.4435,0,22.74,5690.96 @@ -95,11 +147,19 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 256,256,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,140.7161,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,80.4366,moe_ck2stages_gemm2_256x32x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.5%,221.1527,0,87.39,5481.07 256,512,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,64,0,144.981,moe_ck2stages_gemm1_256x64x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,104.8371,moe_ck2stages_gemm2_256x64x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.5%,249.8181,0,154.73,4868.94 256,1024,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,128,0,171.882,moe_ck2stages_gemm1_256x128x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,152.5554,moe_ck2stages_gemm2_256x128x128x64_1x4_TypeCast_v3_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.5%,324.4374,0,238.29,3774.96 +256,1,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,89.9594,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,51.0022,moe_ck2stages_gemm2_256x32x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,140.9616,0,8.57,8571.28 
+256,2,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,89.9594,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,51.0022,moe_ck2stages_gemm2_256x32x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,140.9616,0,8.57,8571.28 +256,4,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,89.9594,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,51.0022,moe_ck2stages_gemm2_256x32x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,140.9616,0,8.57,8571.28 +256,8,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,89.9594,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,51.0022,moe_ck2stages_gemm2_256x32x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,140.9616,0,8.57,8571.28 256,16,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,89.9594,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,51.0022,moe_ck2stages_gemm2_256x32x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,140.9616,0,8.57,8571.28 256,32,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,127.4464,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,69.0267,moe_ck2stages_gemm2_256x32x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,196.4731,0,12.3,6150.89 256,64,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,136.676,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,75.4552,moe_ck2stages_gemm2_256x32x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,212.1312,0,22.78,5699.34 256,128,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,139.2281,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,77.6845,moe_ck2stages_gemm2_256x32x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,216.9126,0,44.55,5578.55 256,256,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,141.0976,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,80.1778,moe_ck2stages_gemm2_256x32x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,221.2754,0,87.35,5478.03 +256,1,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,131.8625,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,70.246,moe_ck2stages_gemm2_256x32x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,1.4%,202.1085,0,6.97,13948.11 
+256,2,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,131.8625,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,70.246,moe_ck2stages_gemm2_256x32x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,1.4%,202.1085,0,6.97,13948.11 +256,4,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,131.8625,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,70.246,moe_ck2stages_gemm2_256x32x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,1.4%,202.1085,0,6.97,13948.11 +256,8,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,131.8625,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,70.246,moe_ck2stages_gemm2_256x32x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,1.4%,202.1085,0,6.97,13948.11 256,16,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,131.8625,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,70.246,moe_ck2stages_gemm2_256x32x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,1.4%,202.1085,0,6.97,13948.11 256,32,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,198.5347,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,102.7245,moe_ck2stages_gemm2_256x32x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,1.4%,301.2592,0,9.36,9359.02 256,64,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,277.5506,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,141.6194,moe_ck2stages_gemm2_256x32x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,1.3%,419.17,0,13.45,6728.55 @@ -107,6 +167,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 256,256,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,64,0,307.9132,moe_ck2stages_gemm1_256x64x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,170.5755,moe_ck2stages_gemm2_256x64x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,1.3%,478.4887,0,47.12,5905.91 256,512,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,64,0,310.8521,moe_ck2stages_gemm1_256x64x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,187.6128,moe_ck2stages_gemm2_256x64x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,1.3%,498.4649,0,90.47,5683.96 256,1024,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,64,0,325.8822,moe_ck2stages_gemm1_256x64x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,228.9235,moe_ck2stages_gemm2_256x64x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,1.3%,554.8057,0,162.57,5133.21 
+256,1,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,64,0,128.4088,moe_ck2stages_gemm1_256x64x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_silu_B16_B16_B16,0.0%,71.9127,moe_ck2stages_gemm2_256x64x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,200.3215,0,7.04,14072.53 +256,2,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,64,0,128.4088,moe_ck2stages_gemm1_256x64x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_silu_B16_B16_B16,0.0%,71.9127,moe_ck2stages_gemm2_256x64x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,200.3215,0,7.04,14072.53 +256,4,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,64,0,128.4088,moe_ck2stages_gemm1_256x64x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_silu_B16_B16_B16,0.0%,71.9127,moe_ck2stages_gemm2_256x64x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,200.3215,0,7.04,14072.53 +256,8,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,64,0,128.4088,moe_ck2stages_gemm1_256x64x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_silu_B16_B16_B16,0.0%,71.9127,moe_ck2stages_gemm2_256x64x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,200.3215,0,7.04,14072.53 256,16,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,64,0,128.4088,moe_ck2stages_gemm1_256x64x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_silu_B16_B16_B16,0.0%,71.9127,moe_ck2stages_gemm2_256x64x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,200.3215,0,7.04,14072.53 256,32,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,64,0,198.743,moe_ck2stages_gemm1_256x64x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_silu_B16_B16_B16,0.0%,102.3427,moe_ck2stages_gemm2_256x64x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,301.0857,0,9.36,9364.41 256,64,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,64,0,278.5912,moe_ck2stages_gemm1_256x64x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_silu_B16_B16_B16,0.0%,137.9968,moe_ck2stages_gemm2_256x64x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,416.588,0,13.53,6770.26 @@ -114,6 +178,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 256,256,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,64,0,307.5774,moe_ck2stages_gemm1_256x64x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_silu_B16_B16_B16,0.0%,164.881,moe_ck2stages_gemm2_256x64x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,472.4584,0,47.73,5981.29 256,512,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,64,0,311.7026,moe_ck2stages_gemm1_256x64x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_silu_B16_B16_B16,0.0%,177.8505,moe_ck2stages_gemm2_256x64x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,489.5531,0,92.12,5787.43 
256,1024,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,64,0,326.6863,moe_ck2stages_gemm1_256x64x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_silu_B16_B16_B16,0.0%,223.1283,moe_ck2stages_gemm2_256x64x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,549.8146,0,164.04,5179.81 +256,1,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,0,32,0,67.999,moe_ck2stages_gemm1_256x32x64x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_silu_F8_F8_B16,0.0%,40.5166,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,1.4%,108.5156,0,12.99,12990.12 +256,2,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,0,32,0,67.999,moe_ck2stages_gemm1_256x32x64x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_silu_F8_F8_B16,0.0%,40.5166,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,1.4%,108.5156,0,12.99,12990.12 +256,4,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,0,32,0,67.999,moe_ck2stages_gemm1_256x32x64x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_silu_F8_F8_B16,0.0%,40.5166,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,1.4%,108.5156,0,12.99,12990.12 +256,8,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,0,32,0,67.999,moe_ck2stages_gemm1_256x32x64x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_silu_F8_F8_B16,0.0%,40.5166,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,1.4%,108.5156,0,12.99,12990.12 256,16,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,0,32,0,67.999,moe_ck2stages_gemm1_256x32x64x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_silu_F8_F8_B16,0.0%,40.5166,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,1.4%,108.5156,0,12.99,12990.12 256,32,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,0,32,0,102.4854,moe_ck2stages_gemm1_256x32x64x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_silu_F8_F8_B16,0.0%,56.7236,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,1.4%,159.209,0,17.7,8856.12 256,64,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,0,32,0,140.7195,moe_ck2stages_gemm1_256x32x64x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_silu_F8_F8_B16,0.0%,75.8797,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,1.3%,216.5992,0,26.03,6512.78 @@ -121,6 +189,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 256,256,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,0,32,0,158.7735,moe_ck2stages_gemm1_256x32x64x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_silu_F8_F8_B16,0.0%,92.6941,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,1.3%,251.4676,0,89.67,5626.14 
256,512,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,0,32,0,162.9055,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_silu_F8_F8_B16,0.0%,114.6803,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,1.3%,277.5858,0,162.46,5116.6 256,1024,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,0,64,0,171.438,moe_ck2stages_gemm1_256x64x64x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_silu_F8_F8_B16,0.0%,205.2641,moe_ck2stages_gemm2_256x64x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,1.3%,376.7021,0,239.43,3799.57 +256,1,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,1,32,0,67.7683,moe_ck2stages_gemm1_256x32x64x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_silu_F8_F8_B16,0.0%,40.8282,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,108.5965,0,12.98,12980.44 +256,2,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,1,32,0,67.7683,moe_ck2stages_gemm1_256x32x64x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_silu_F8_F8_B16,0.0%,40.8282,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,108.5965,0,12.98,12980.44 +256,4,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,1,32,0,67.7683,moe_ck2stages_gemm1_256x32x64x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_silu_F8_F8_B16,0.0%,40.8282,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,108.5965,0,12.98,12980.44 +256,8,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,1,32,0,67.7683,moe_ck2stages_gemm1_256x32x64x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_silu_F8_F8_B16,0.0%,40.8282,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,108.5965,0,12.98,12980.44 256,16,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,1,32,0,67.7683,moe_ck2stages_gemm1_256x32x64x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_silu_F8_F8_B16,0.0%,40.8282,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,108.5965,0,12.98,12980.44 256,32,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,1,32,0,104.0822,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_silu_F8_F8_B16,0.0%,57.7224,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,161.8046,0,17.42,8714.06 256,64,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,1,32,0,142.2581,moe_ck2stages_gemm1_256x32x64x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_silu_F8_F8_B16,0.0%,77.8633,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,220.1214,0,25.61,6408.57 @@ -128,6 +200,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 
256,256,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,1,32,0,158.256,moe_ck2stages_gemm1_256x32x64x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_silu_F8_F8_B16,0.0%,94.9442,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,253.2002,0,89.05,5587.64 256,512,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,1,32,0,162.1092,moe_ck2stages_gemm1_256x32x64x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_silu_F8_F8_B16,0.0%,114.1086,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,276.2178,0,163.27,5141.94 256,1024,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,1,64,0,169.5988,moe_ck2stages_gemm1_256x64x64x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_silu_F8_F8_B16,0.0%,205.5691,moe_ck2stages_gemm2_256x64x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,375.1679,0,240.41,3815.11 +256,1,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,0,32,0,68.2374,moe_ck2stages_gemm1_256x32x64x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight0_silu_F8_F8_B16,0.0%,46.5664,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight1_F8_F8_B16,1.4%,114.8038,0,12.28,12278.6 +256,2,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,0,32,0,68.2374,moe_ck2stages_gemm1_256x32x64x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight0_silu_F8_F8_B16,0.0%,46.5664,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight1_F8_F8_B16,1.4%,114.8038,0,12.28,12278.6 +256,4,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,0,32,0,68.2374,moe_ck2stages_gemm1_256x32x64x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight0_silu_F8_F8_B16,0.0%,46.5664,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight1_F8_F8_B16,1.4%,114.8038,0,12.28,12278.6 +256,8,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,0,32,0,68.2374,moe_ck2stages_gemm1_256x32x64x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight0_silu_F8_F8_B16,0.0%,46.5664,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight1_F8_F8_B16,1.4%,114.8038,0,12.28,12278.6 256,16,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,0,32,0,68.2374,moe_ck2stages_gemm1_256x32x64x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight0_silu_F8_F8_B16,0.0%,46.5664,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight1_F8_F8_B16,1.4%,114.8038,0,12.28,12278.6 256,32,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,0,32,0,100.638,_ZN5aiter47fmoe_stage1_bf16_pertokenFp8_g1u1_32x64_4tg_pf2E,0.0%,64.7122,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight1_F8_F8_B16,1.3%,165.3502,0,17.05,8527.2 
256,64,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,0,32,0,139.9452,_ZN5aiter48fmoe_stage1_bf16_pertokenFp8_g1u1_32x256_2tg_pf3E,0.0%,84.8694,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight1_F8_F8_B16,1.3%,224.8146,0,25.07,6274.78 @@ -135,7 +211,15 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 256,256,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,0,32,0,267.3292,_ZN5aiter45fmoe_bf16_pertokenFp8_g1u1_vs_silu_1tg_32x256E,0.0%,0.0,Null,0,267.3292,1,84.35,5292.32 256,512,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,0,32,0,272.5758,_ZN5aiter48fmoe_bf16_pertokenFp8_g1u1_vs_silu_1tg_ps_32x256E,0.0%,0.0,Null,0,272.5758,1,165.45,5210.65 256,1024,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,0,32,0,367.9317,_ZN5aiter45fmoe_bf16_pertokenFp8_g1u1_vs_silu_1tg_32x256E,0.0%,0.0,Null,0,367.9317,1,245.14,3890.14 +256,1,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,1,32,0,68.7365,_ZN5aiter57fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x128_3tg_pf3E,0.0%,44.3023,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,113.0388,0,12.47,12470.32 +256,2,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,1,32,0,68.7365,_ZN5aiter57fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x128_3tg_pf3E,0.0%,44.3023,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,113.0388,0,12.47,12470.32 +256,4,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,1,32,0,68.7365,_ZN5aiter57fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x128_3tg_pf3E,0.0%,44.3023,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,113.0388,0,12.47,12470.32 +256,8,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,1,32,0,68.7365,_ZN5aiter57fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x128_3tg_pf3E,0.0%,44.3023,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,113.0388,0,12.47,12470.32 256,16,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,1,32,0,68.7365,_ZN5aiter57fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x128_3tg_pf3E,0.0%,44.3023,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,113.0388,0,12.47,12470.32 +256,1,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,268.2034,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_silu_B16_B16_B16,0.0%,134.7329,moe_ck2stages_gemm2_256x32x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,402.9363,0,7.0,13991.3 
+256,2,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,268.2034,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_silu_B16_B16_B16,0.0%,134.7329,moe_ck2stages_gemm2_256x32x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,402.9363,0,7.0,13991.3 +256,4,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,268.2034,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_silu_B16_B16_B16,0.0%,134.7329,moe_ck2stages_gemm2_256x32x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,402.9363,0,7.0,13991.3 +256,8,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,268.2034,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_silu_B16_B16_B16,0.0%,134.7329,moe_ck2stages_gemm2_256x32x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,402.9363,0,7.0,13991.3 256,16,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,268.2034,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_silu_B16_B16_B16,0.0%,134.7329,moe_ck2stages_gemm2_256x32x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,402.9363,0,7.0,13991.3 256,32,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,64,0,380.515,moe_ck2stages_gemm1_256x64x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_silu_B16_B16_B16,0.0%,194.6522,moe_ck2stages_gemm2_256x64x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,575.1672,0,9.8,9802.47 256,64,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,64,0,563.744,moe_ck2stages_gemm1_256x64x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_silu_B16_B16_B16,0.0%,268.523,moe_ck2stages_gemm2_256x64x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,832.267,0,13.55,6775.45 @@ -143,6 +227,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 256,256,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,64,0,613.3676,moe_ck2stages_gemm1_256x64x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_silu_B16_B16_B16,0.0%,316.2011,moe_ck2stages_gemm2_256x64x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,929.5687,0,48.51,6072.15 256,512,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,64,0,621.3253,moe_ck2stages_gemm1_256x64x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_silu_B16_B16_B16,0.0%,334.5202,moe_ck2stages_gemm2_256x64x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,955.8455,0,94.36,5912.91 256,1024,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,64,0,647.0892,moe_ck2stages_gemm1_256x64x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_silu_B16_B16_B16,0.0%,363.288,moe_ck2stages_gemm2_256x64x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,1010.3772,0,178.54,5608.31 
+256,1,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,0,32,0,138.6611,moe_ck2stages_gemm1_256x32x64x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_silu_F8_F8_B16,0.0%,69.9939,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,2.8%,208.655,0,13.51,13509.94 +256,2,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,0,32,0,138.6611,moe_ck2stages_gemm1_256x32x64x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_silu_F8_F8_B16,0.0%,69.9939,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,2.8%,208.655,0,13.51,13509.94 +256,4,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,0,32,0,138.6611,moe_ck2stages_gemm1_256x32x64x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_silu_F8_F8_B16,0.0%,69.9939,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,2.8%,208.655,0,13.51,13509.94 +256,8,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,0,32,0,138.6611,moe_ck2stages_gemm1_256x32x64x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_silu_F8_F8_B16,0.0%,69.9939,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,2.8%,208.655,0,13.51,13509.94 256,16,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,0,32,0,138.6611,moe_ck2stages_gemm1_256x32x64x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_silu_F8_F8_B16,0.0%,69.9939,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,2.8%,208.655,0,13.51,13509.94 256,32,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,0,32,0,194.5186,moe_ck2stages_gemm1_256x32x64x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_silu_F8_F8_B16,0.0%,101.0683,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,2.8%,295.5869,0,19.07,9537.84 256,64,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,0,32,0,278.8859,_ZN5aiter44fmoe_stage1_bf16_pertokenFp8_g1u1_32x512_pf2E,0.0%,141.0514,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,2.8%,0.0,419,26.85,6715.17 @@ -150,6 +238,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 256,256,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,0,32,0,313.5176,moe_ck2stages_gemm1_256x32x64x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_silu_F8_F8_B16,0.0%,168.5062,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,2.8%,482.0238000000001,0,93.56,5858.79 256,512,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,0,32,0,316.7569,_ZN5aiter44fmoe_stage1_bf16_pertokenFp8_g1u1_32x512_pf2E,0.0%,183.8783,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,2.8%,500.6352,0,180.16,5651.98 
256,1024,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,0,64,0,328.9474,_ZN5aiter44fmoe_stage1_bf16_pertokenFp8_g1u1_64x256_pf2E,0.0%,225.6663,moe_ck2stages_gemm2_256x64x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,2.8%,554.6137,0,325.25,5121.75 +256,1,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,1,32,0,138.8701,moe_ck2stages_gemm1_256x32x64x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_silu_F8_F8_B16,0.0%,70.2409,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,209.111,0,13.48,13480.48 +256,2,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,1,32,0,138.8701,moe_ck2stages_gemm1_256x32x64x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_silu_F8_F8_B16,0.0%,70.2409,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,209.111,0,13.48,13480.48 +256,4,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,1,32,0,138.8701,moe_ck2stages_gemm1_256x32x64x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_silu_F8_F8_B16,0.0%,70.2409,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,209.111,0,13.48,13480.48 +256,8,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,1,32,0,138.8701,moe_ck2stages_gemm1_256x32x64x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_silu_F8_F8_B16,0.0%,70.2409,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,209.111,0,13.48,13480.48 256,16,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,1,32,0,138.8701,moe_ck2stages_gemm1_256x32x64x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_silu_F8_F8_B16,0.0%,70.2409,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,209.111,0,13.48,13480.48 256,32,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,1,32,0,195.414,moe_ck2stages_gemm1_256x32x64x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_silu_F8_F8_B16,0.0%,101.3303,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,296.74429999999995,0,19.0,9500.64 256,64,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,1,32,0,275.5829,_ZN5aiter53fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x512_pf2E,0.0%,141.8053,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,417.3882,0,27.01,6756.18 @@ -157,6 +249,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 256,256,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,1,32,0,314.331,moe_ck2stages_gemm1_256x32x64x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_silu_F8_F8_B16,0.0%,169.9851,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,484.3161,0,93.12,5831.06 
256,512,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,1,32,0,315.989,_ZN5aiter53fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x512_pf2E,0.0%,183.8335,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,499.8225,0,180.45,5661.17 256,1024,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,1,64,0,328.6203,_ZN5aiter53fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_64x256_pf2E,0.0%,227.39,moe_ck2stages_gemm2_256x64x128x256_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,556.0102999999999,0,324.43,5108.88 +256,1,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,0,32,0,130.3234,_ZN5aiter48fmoe_stage1_bf16_pertokenFp8_g1u1_32x256_2tg_pf3E,0.0%,74.9639,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight1_F8_F8_B16,2.8%,205.2873,0,13.73,13731.57 +256,2,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,0,32,0,130.3234,_ZN5aiter48fmoe_stage1_bf16_pertokenFp8_g1u1_32x256_2tg_pf3E,0.0%,74.9639,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight1_F8_F8_B16,2.8%,205.2873,0,13.73,13731.57 +256,4,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,0,32,0,130.3234,_ZN5aiter48fmoe_stage1_bf16_pertokenFp8_g1u1_32x256_2tg_pf3E,0.0%,74.9639,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight1_F8_F8_B16,2.8%,205.2873,0,13.73,13731.57 +256,8,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,0,32,0,130.3234,_ZN5aiter48fmoe_stage1_bf16_pertokenFp8_g1u1_32x256_2tg_pf3E,0.0%,74.9639,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight1_F8_F8_B16,2.8%,205.2873,0,13.73,13731.57 256,16,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,0,32,0,130.3234,_ZN5aiter48fmoe_stage1_bf16_pertokenFp8_g1u1_32x256_2tg_pf3E,0.0%,74.9639,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight1_F8_F8_B16,2.8%,205.2873,0,13.73,13731.57 256,32,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,0,32,0,298.4201,_ZN5aiter48fmoe_bf16_pertokenFp8_g1u1_vs_silu_1tg_ps_32x512E,0.0%,0.0,Null,0,298.4201,1,18.89,9447.29 256,64,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,0,32,0,266.0248,_ZN5aiter44fmoe_stage1_bf16_pertokenFp8_g1u1_32x512_pf2E,0.0%,146.7599,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight1_F8_F8_B16,2.8%,412.7847,0,27.31,6831.52 @@ -164,6 +260,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 256,256,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,0,32,0,304.5903,_ZN5aiter44fmoe_stage1_bf16_pertokenFp8_g1u1_32x512_pf2E,0.0%,178.5369,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight1_F8_F8_B16,2.8%,483.1272,0,93.34,5845.41 
256,512,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,0,32,0,498.7106,_ZN5aiter48fmoe_bf16_pertokenFp8_g1u1_vs_silu_1tg_ps_32x512E,0.0%,0.0,Null,0,498.7106,1,180.86,5673.8 256,1024,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,0,32,0,333.1281,_ZN5aiter47fmoe_stage1_bf16_pertokenFp8_g1u1_32x64_4tg_pf2E,0.0%,249.1231,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight1_F8_F8_B16,2.8%,582.2512,0,309.81,4878.64 +256,1,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,1,32,0,130.2529,_ZN5aiter57fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x256_2tg_pf3E,0.0%,74.4948,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,204.7477,0,13.77,13767.76 +256,2,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,1,32,0,130.2529,_ZN5aiter57fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x256_2tg_pf3E,0.0%,74.4948,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,204.7477,0,13.77,13767.76 +256,4,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,1,32,0,130.2529,_ZN5aiter57fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x256_2tg_pf3E,0.0%,74.4948,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,204.7477,0,13.77,13767.76 +256,8,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,1,32,0,130.2529,_ZN5aiter57fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x256_2tg_pf3E,0.0%,74.4948,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,204.7477,0,13.77,13767.76 256,16,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,1,32,0,130.2529,_ZN5aiter57fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x256_2tg_pf3E,0.0%,74.4948,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,204.7477,0,13.77,13767.76 256,32,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,1,32,0,195.8963,_ZN5aiter53fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x512_pf3E,0.0%,103.7824,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,299.6787,0,18.81,9407.61 256,64,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,1,32,0,265.4483,_ZN5aiter53fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x512_pf2E,0.0%,142.3836,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,407.8319,0,27.64,6914.49 @@ -171,6 +271,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 256,256,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,1,32,0,304.4449,_ZN5aiter53fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x512_pf2E,0.0%,177.3471,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,481.792,0,93.6,5861.61 
256,512,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,1,32,0,306.5738,_ZN5aiter53fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x512_pf2E,0.0%,190.3011,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,496.8749,0,181.52,5694.76 256,1024,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,1,32,0,332.9031,_ZN5aiter56fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x64_4tg_pf2E,0.0%,245.6419,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,578.545,0,311.8,4909.89 +256,1,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,89.9773,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,50.9117,moe_ck2stages_gemm2_256x32x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.3%,140.889,0,8.57,8575.7 +256,2,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,89.9773,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,50.9117,moe_ck2stages_gemm2_256x32x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.3%,140.889,0,8.57,8575.7 +256,4,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,89.9773,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,50.9117,moe_ck2stages_gemm2_256x32x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.3%,140.889,0,8.57,8575.7 +256,8,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,89.9773,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,50.9117,moe_ck2stages_gemm2_256x32x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.3%,140.889,0,8.57,8575.7 256,16,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,89.9773,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,50.9117,moe_ck2stages_gemm2_256x32x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.3%,140.889,0,8.57,8575.7 256,32,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,127.6678,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,69.4848,moe_ck2stages_gemm2_256x32x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.4%,197.1526,0,12.25,6129.69 256,64,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,136.4991,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,75.3558,moe_ck2stages_gemm2_256x32x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.4%,211.8549,0,22.81,5706.77 @@ -178,6 +282,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 
256,256,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,141.1941,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,80.8519,moe_ck2stages_gemm2_256x32x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.4%,222.046,0,87.04,5459.02 256,512,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,64,0,144.5293,moe_ck2stages_gemm1_256x64x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,104.6954,moe_ck2stages_gemm2_256x64x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.4%,249.2247,0,155.1,4880.53 256,1024,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,128,0,169.6864,moe_ck2stages_gemm1_256x128x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,153.1648,moe_ck2stages_gemm2_256x128x128x64_1x4_TypeCast_v3_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.4%,322.8512,0,239.46,3793.5 +256,1,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,90.339,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_silu_B16_B16_B16,0.0%,50.8203,moe_ck2stages_gemm2_256x32x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,141.1593,0,8.56,8559.28 +256,2,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,90.339,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_silu_B16_B16_B16,0.0%,50.8203,moe_ck2stages_gemm2_256x32x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,141.1593,0,8.56,8559.28 +256,4,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,90.339,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_silu_B16_B16_B16,0.0%,50.8203,moe_ck2stages_gemm2_256x32x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,141.1593,0,8.56,8559.28 +256,8,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,90.339,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_silu_B16_B16_B16,0.0%,50.8203,moe_ck2stages_gemm2_256x32x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,141.1593,0,8.56,8559.28 256,16,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,90.339,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_silu_B16_B16_B16,0.0%,50.8203,moe_ck2stages_gemm2_256x32x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,141.1593,0,8.56,8559.28 256,32,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,127.4505,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_silu_B16_B16_B16,0.0%,69.9278,moe_ck2stages_gemm2_256x32x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,197.3783,0,12.24,6122.68 
256,64,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,136.5934,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_silu_B16_B16_B16,0.0%,76.0763,moe_ck2stages_gemm2_256x32x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,212.6697,0,22.72,5684.91 @@ -185,6 +293,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 256,256,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,140.488,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_silu_B16_B16_B16,0.0%,80.8556,moe_ck2stages_gemm2_256x32x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,221.3436,0,87.32,5476.34 256,512,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,64,0,144.855,moe_ck2stages_gemm1_256x64x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_silu_B16_B16_B16,0.0%,97.8569,moe_ck2stages_gemm2_256x64x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,242.7119,0,159.26,5011.49 256,1024,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,128,0,171.8959,moe_ck2stages_gemm1_256x128x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_silu_B16_B16_B16,0.0%,144.4775,moe_ck2stages_gemm2_256x128x128x64_1x4_TypeCast_v3_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,316.3734,0,244.36,3871.17 +256,1,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,0,128,0,50.9507,moe_ck2stages_gemm1_256x128x64x128_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_silu_F8_F8_B16,0.0%,42.2681,moe_ck2stages_gemm2_256x128x128x128_1x4_TypeCast_v3_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,0.4%,93.2188,0,12.96,6481.27 +256,2,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,0,128,0,50.9507,moe_ck2stages_gemm1_256x128x64x128_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_silu_F8_F8_B16,0.0%,42.2681,moe_ck2stages_gemm2_256x128x128x128_1x4_TypeCast_v3_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,0.4%,93.2188,0,12.96,6481.27 +256,4,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,0,128,0,50.9507,moe_ck2stages_gemm1_256x128x64x128_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_silu_F8_F8_B16,0.0%,42.2681,moe_ck2stages_gemm2_256x128x128x128_1x4_TypeCast_v3_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,0.4%,93.2188,0,12.96,6481.27 +256,8,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,0,128,0,50.9507,moe_ck2stages_gemm1_256x128x64x128_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_silu_F8_F8_B16,0.0%,42.2681,moe_ck2stages_gemm2_256x128x128x128_1x4_TypeCast_v3_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,0.4%,93.2188,0,12.96,6481.27 256,16,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,0,128,0,50.9507,moe_ck2stages_gemm1_256x128x64x128_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_silu_F8_F8_B16,0.0%,42.2681,moe_ck2stages_gemm2_256x128x128x128_1x4_TypeCast_v3_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,0.4%,93.2188,0,12.96,6481.27 
256,32,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,0,128,0,79.6392,moe_ck2stages_gemm1_256x128x64x128_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_silu_F8_F8_B16,0.0%,54.2772,moe_ck2stages_gemm2_256x128x128x128_1x4_TypeCast_v3_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,0.4%,133.9164,0,18.04,4513.06 256,64,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,0,128,0,83.5661,moe_ck2stages_gemm1_256x128x64x128_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_silu_F8_F8_B16,0.0%,60.2323,moe_ck2stages_gemm2_256x128x128x128_1x4_TypeCast_v3_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,0.4%,143.79840000000002,0,33.6,4205.65 @@ -192,6 +304,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 256,256,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,0,128,0,85.0781,moe_ck2stages_gemm1_256x128x64x128_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_silu_F8_F8_B16,0.0%,66.2997,moe_ck2stages_gemm2_256x128x128x128_1x4_TypeCast_v3_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,0.4%,151.3778,0,127.68,4010.66 256,512,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,0,128,0,87.3865,moe_ck2stages_gemm1_256x128x64x128_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_silu_F8_F8_B16,0.0%,82.0796,moe_ck2stages_gemm2_256x128x128x128_1x4_TypeCast_v3_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,0.4%,169.46609999999998,0,228.1,3601.14 256,1024,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,0,128,0,92.393,moe_ck2stages_gemm1_256x128x64x128_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_silu_F8_F8_B16,0.0%,125.505,moe_ck2stages_gemm2_256x128x128x128_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,0.4%,217.898,0,354.8,2829.59 +256,1,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,1,128,0,51.387,moe_ck2stages_gemm1_256x128x64x128_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_silu_F8_F8_B16,0.0%,41.0638,moe_ck2stages_gemm2_256x128x128x128_1x4_TypeCast_v3_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,92.4508,0,13.07,6535.11 +256,2,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,1,128,0,51.387,moe_ck2stages_gemm1_256x128x64x128_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_silu_F8_F8_B16,0.0%,41.0638,moe_ck2stages_gemm2_256x128x128x128_1x4_TypeCast_v3_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,92.4508,0,13.07,6535.11 +256,4,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,1,128,0,51.387,moe_ck2stages_gemm1_256x128x64x128_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_silu_F8_F8_B16,0.0%,41.0638,moe_ck2stages_gemm2_256x128x128x128_1x4_TypeCast_v3_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,92.4508,0,13.07,6535.11 +256,8,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,1,128,0,51.387,moe_ck2stages_gemm1_256x128x64x128_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_silu_F8_F8_B16,0.0%,41.0638,moe_ck2stages_gemm2_256x128x128x128_1x4_TypeCast_v3_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,92.4508,0,13.07,6535.11 
256,16,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,1,128,0,51.387,moe_ck2stages_gemm1_256x128x64x128_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_silu_F8_F8_B16,0.0%,41.0638,moe_ck2stages_gemm2_256x128x128x128_1x4_TypeCast_v3_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,92.4508,0,13.07,6535.11 256,32,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,1,128,0,81.0339,moe_ck2stages_gemm1_256x128x64x128_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_silu_F8_F8_B16,0.0%,51.9652,moe_ck2stages_gemm2_256x128x128x128_1x4_TypeCast_v3_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,132.9991,0,18.16,4544.19 256,64,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,1,128,0,84.9852,moe_ck2stages_gemm1_256x128x64x128_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_silu_F8_F8_B16,0.0%,57.2253,moe_ck2stages_gemm2_256x128x128x128_1x4_TypeCast_v3_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,142.2105,0,33.98,4252.61 @@ -199,6 +315,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 256,256,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,1,128,0,86.6075,moe_ck2stages_gemm1_256x128x64x128_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_silu_F8_F8_B16,0.0%,64.6415,moe_ck2stages_gemm2_256x128x128x128_1x4_TypeCast_v3_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,151.249,0,127.78,4014.08 256,512,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,1,128,0,91.2218,moe_ck2stages_gemm1_256x128x64x128_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_silu_F8_F8_B16,0.0%,82.8771,moe_ck2stages_gemm2_256x128x128x128_1x4_TypeCast_v3_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,174.0989,0,222.03,3505.31 256,1024,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,1,128,0,93.6859,moe_ck2stages_gemm1_256x128x64x128_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_silu_F8_F8_B16,0.0%,125.2007,moe_ck2stages_gemm2_256x128x128x128_1x4_TypeCast_v3_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,218.8866,0,353.19,2816.81 +256,1,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,0,32,0,76.0789,_ZN5aiter45fmoe_bf16_pertokenFp8_g1u1_vs_silu_1tg_32x192E,0.0%,0.0,Null,0,76.0789,1,15.88,7941.44 +256,2,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,0,32,0,76.0789,_ZN5aiter45fmoe_bf16_pertokenFp8_g1u1_vs_silu_1tg_32x192E,0.0%,0.0,Null,0,76.0789,1,15.88,7941.44 +256,4,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,0,32,0,76.0789,_ZN5aiter45fmoe_bf16_pertokenFp8_g1u1_vs_silu_1tg_32x192E,0.0%,0.0,Null,0,76.0789,1,15.88,7941.44 +256,8,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,0,32,0,76.0789,_ZN5aiter45fmoe_bf16_pertokenFp8_g1u1_vs_silu_1tg_32x192E,0.0%,0.0,Null,0,76.0789,1,15.88,7941.44 256,16,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,0,32,0,76.0789,_ZN5aiter45fmoe_bf16_pertokenFp8_g1u1_vs_silu_1tg_32x192E,0.0%,0.0,Null,0,76.0789,1,15.88,7941.44 
256,32,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,0,32,0,105.3523,_ZN5aiter45fmoe_bf16_pertokenFp8_g1u1_vs_silu_1tg_32x192E,0.0%,0.0,Null,0,105.3523,1,22.93,5736.69 256,64,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,0,32,0,115.4816,_ZN5aiter45fmoe_bf16_pertokenFp8_g1u1_vs_silu_1tg_32x192E,0.0%,0.0,Null,0,115.4816,1,41.84,5236.91 @@ -206,6 +326,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 256,256,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,0,32,0,120.3527,_ZN5aiter45fmoe_bf16_pertokenFp8_g1u1_vs_silu_1tg_32x192E,0.0%,0.0,Null,0,120.3527,1,160.59,5044.55 256,512,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,0,32,0,149.2682,_ZN5aiter45fmoe_bf16_pertokenFp8_g1u1_vs_silu_1tg_32x384E,0.0%,0.0,Null,0,149.2682,1,258.96,4088.42 256,1024,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,0,128,0,91.8397,_ZN5aiter45fmoe_stage1_bf16_pertokenFp8_g1u1_128x128_pf3E,0.0%,159.1852,moe_ck2stages_gemm2_256x128x128x128_1x4_TypeCast_v3_Nswizzle0_Quant2_MulRoutedWeight1_F8_F8_B16,0.4%,251.0249,0,307.98,2456.18 +256,1,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,1,128,0,49.7186,_ZN5aiter54fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_128x128_pf3E,0.0%,64.1778,moe_ck2stages_gemm2_256x128x128x128_1x4_TypeCast_v3_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,113.8964,0,10.61,5304.61 +256,2,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,1,128,0,49.7186,_ZN5aiter54fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_128x128_pf3E,0.0%,64.1778,moe_ck2stages_gemm2_256x128x128x128_1x4_TypeCast_v3_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,113.8964,0,10.61,5304.61 +256,4,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,1,128,0,49.7186,_ZN5aiter54fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_128x128_pf3E,0.0%,64.1778,moe_ck2stages_gemm2_256x128x128x128_1x4_TypeCast_v3_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,113.8964,0,10.61,5304.61 +256,8,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,1,128,0,49.7186,_ZN5aiter54fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_128x128_pf3E,0.0%,64.1778,moe_ck2stages_gemm2_256x128x128x128_1x4_TypeCast_v3_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,113.8964,0,10.61,5304.61 256,16,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,1,128,0,49.7186,_ZN5aiter54fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_128x128_pf3E,0.0%,64.1778,moe_ck2stages_gemm2_256x128x128x128_1x4_TypeCast_v3_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,113.8964,0,10.61,5304.61 256,32,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,1,128,0,80.6595,_ZN5aiter53fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_128x64_pf3E,0.0%,84.4366,moe_ck2stages_gemm2_256x128x128x128_1x4_TypeCast_v3_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,165.09609999999998,0,14.63,3660.73 
256,64,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,1,128,0,83.8169,_ZN5aiter53fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_128x64_pf3E,0.0%,88.246,moe_ck2stages_gemm2_256x128x128x128_1x4_TypeCast_v3_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,172.0629,0,28.08,3514.8 @@ -219,6 +343,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 256,256,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,1,32,0,159.5203,_ZN5aiter57fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x256_2tg_pf3E,0.0%,103.4399,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,262.9602,0,85.75,5380.25 256,512,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,1,32,0,160.4127,_ZN5aiter57fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x256_2tg_pf2E,0.0%,132.4101,moe_ck2stages_gemm2_256x32x128x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,292.82280000000003,0,154.01,4850.36 256,1024,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,1,64,0,162.8098,_ZN5aiter57fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_64x128_2tg_pf2E,0.0%,221.1124,moe_ck2stages_gemm2_256x64x128x256_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,383.9222,0,234.93,3728.12 +256,1,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,269.3232,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,134.7722,moe_ck2stages_gemm2_256x32x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,2.8%,404.0954,0,6.98,13951.17 +256,2,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,269.3232,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,134.7722,moe_ck2stages_gemm2_256x32x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,2.8%,404.0954,0,6.98,13951.17 +256,4,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,269.3232,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,134.7722,moe_ck2stages_gemm2_256x32x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,2.8%,404.0954,0,6.98,13951.17 +256,8,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,269.3232,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,134.7722,moe_ck2stages_gemm2_256x32x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,2.8%,404.0954,0,6.98,13951.17 256,16,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,269.3232,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,134.7722,moe_ck2stages_gemm2_256x32x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,2.8%,404.0954,0,6.98,13951.17 
256,32,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,381.0416,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,195.7302,moe_ck2stages_gemm2_256x32x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,2.9%,576.7718,0,9.77,9775.2 256,64,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,562.6212,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,271.5572,moe_ck2stages_gemm2_256x32x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,2.8%,834.1784,0,13.52,6759.92 @@ -227,6 +355,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 256,512,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,64,0,622.7788,moe_ck2stages_gemm1_256x64x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,341.9723,moe_ck2stages_gemm2_256x64x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,2.8%,964.7511,0,93.49,5858.32 256,1024,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,64,0,649.2553,moe_ck2stages_gemm1_256x64x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,371.4126,moe_ck2stages_gemm2_256x64x128x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,2.8%,1020.6679,0,176.74,5551.76 256,1024,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,128,0,171.7102,moe_ck2stages_gemm1_256x128x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,144.3167,moe_ck2stages_gemm2_256x128x128x64_1x4_TypeCast_v3_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,316.0269,0,244.63,3875.42 +256,1,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,0,128,0,50.0987,moe_ck2stages_gemm1_256x128x64x128_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,42.6821,moe_ck2stages_gemm2_256x128x128x128_1x4_TypeCast_v3_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,0.5%,92.7808,0,13.02,6511.87 +256,2,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,0,128,0,50.0987,moe_ck2stages_gemm1_256x128x64x128_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,42.6821,moe_ck2stages_gemm2_256x128x128x128_1x4_TypeCast_v3_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,0.5%,92.7808,0,13.02,6511.87 +256,4,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,0,128,0,50.0987,moe_ck2stages_gemm1_256x128x64x128_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,42.6821,moe_ck2stages_gemm2_256x128x128x128_1x4_TypeCast_v3_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,0.5%,92.7808,0,13.02,6511.87 +256,8,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,0,128,0,50.0987,moe_ck2stages_gemm1_256x128x64x128_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,42.6821,moe_ck2stages_gemm2_256x128x128x128_1x4_TypeCast_v3_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,0.5%,92.7808,0,13.02,6511.87 
256,16,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,0,128,0,50.0987,moe_ck2stages_gemm1_256x128x64x128_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,42.6821,moe_ck2stages_gemm2_256x128x128x128_1x4_TypeCast_v3_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,0.5%,92.7808,0,13.02,6511.87 256,32,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,0,128,0,78.873,moe_ck2stages_gemm1_256x128x64x128_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,54.6731,moe_ck2stages_gemm2_256x128x128x128_1x4_TypeCast_v3_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,0.5%,133.5461,0,18.09,4525.58 256,64,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,0,128,0,82.7604,moe_ck2stages_gemm1_256x128x64x128_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,60.2229,moe_ck2stages_gemm2_256x128x128x128_1x4_TypeCast_v3_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,0.5%,142.9833,0,33.79,4229.63 @@ -234,6 +366,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 256,256,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,0,128,0,84.7124,moe_ck2stages_gemm1_256x128x64x128_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,66.4881,moe_ck2stages_gemm2_256x128x128x128_1x4_TypeCast_v3_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,0.5%,151.2005,0,127.83,4015.37 256,512,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,0,128,0,86.6188,moe_ck2stages_gemm1_256x128x64x128_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,83.3428,moe_ck2stages_gemm2_256x128x128x128_1x4_TypeCast_v3_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,0.5%,169.96159999999998,0,227.43,3590.64 256,1024,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,0,128,0,92.6237,moe_ck2stages_gemm1_256x128x64x128_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,125.9531,moe_ck2stages_gemm2_256x128x128x128_1x4_TypeCast_v3_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,0.5%,218.5768,0,353.69,2820.81 +256,1,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,1,128,0,51.0702,moe_ck2stages_gemm1_256x128x64x128_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,41.1541,moe_ck2stages_gemm2_256x128x128x128_1x4_TypeCast_v3_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,92.2243,0,13.1,6551.16 +256,2,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,1,128,0,51.0702,moe_ck2stages_gemm1_256x128x64x128_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,41.1541,moe_ck2stages_gemm2_256x128x128x128_1x4_TypeCast_v3_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,92.2243,0,13.1,6551.16 +256,4,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,1,128,0,51.0702,moe_ck2stages_gemm1_256x128x64x128_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,41.1541,moe_ck2stages_gemm2_256x128x128x128_1x4_TypeCast_v3_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,92.2243,0,13.1,6551.16 
+256,8,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,1,128,0,51.0702,moe_ck2stages_gemm1_256x128x64x128_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,41.1541,moe_ck2stages_gemm2_256x128x128x128_1x4_TypeCast_v3_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,92.2243,0,13.1,6551.16 256,16,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,1,128,0,51.0702,moe_ck2stages_gemm1_256x128x64x128_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,41.1541,moe_ck2stages_gemm2_256x128x128x128_1x4_TypeCast_v3_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,92.2243,0,13.1,6551.16 256,32,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,1,128,0,81.6357,moe_ck2stages_gemm1_256x128x64x128_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,52.029,moe_ck2stages_gemm2_256x128x128x128_1x4_TypeCast_v3_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,133.6647,0,18.07,4521.56 256,64,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,1,128,0,84.8817,moe_ck2stages_gemm1_256x128x64x128_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,56.8257,moe_ck2stages_gemm2_256x128x128x128_1x4_TypeCast_v3_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,141.7074,0,34.1,4267.71 @@ -241,6 +377,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 256,256,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,1,128,0,86.7569,moe_ck2stages_gemm1_256x128x64x128_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,64.6359,moe_ck2stages_gemm2_256x128x128x128_1x4_TypeCast_v3_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,151.39280000000002,0,127.66,4010.27 256,512,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,1,128,0,88.2869,moe_ck2stages_gemm1_256x128x64x128_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,82.4745,moe_ck2stages_gemm2_256x128x128x128_1x4_TypeCast_v3_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,170.7614,0,226.37,3573.82 256,1024,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Tensor,1,1,128,0,93.1182,moe_ck2stages_gemm1_256x128x64x128_1x4_TypeCast_v1_Nswizzle0_Quant1_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,125.1912,moe_ck2stages_gemm2_256x128x128x128_1x4_TypeCast_v3_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,218.3094,0,354.13,2824.26 +256,1,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,0,32,0,76.4501,_ZN5aiter45fmoe_bf16_pertokenFp8_g1u1_vs_gelu_1tg_32x192E,0.0%,0.0,Null,0,76.4501,1,15.8,7902.89 +256,2,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,0,32,0,76.4501,_ZN5aiter45fmoe_bf16_pertokenFp8_g1u1_vs_gelu_1tg_32x192E,0.0%,0.0,Null,0,76.4501,1,15.8,7902.89 +256,4,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,0,32,0,76.4501,_ZN5aiter45fmoe_bf16_pertokenFp8_g1u1_vs_gelu_1tg_32x192E,0.0%,0.0,Null,0,76.4501,1,15.8,7902.89 
+256,8,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,0,32,0,76.4501,_ZN5aiter45fmoe_bf16_pertokenFp8_g1u1_vs_gelu_1tg_32x192E,0.0%,0.0,Null,0,76.4501,1,15.8,7902.89 256,16,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,0,32,0,76.4501,_ZN5aiter45fmoe_bf16_pertokenFp8_g1u1_vs_gelu_1tg_32x192E,0.0%,0.0,Null,0,76.4501,1,15.8,7902.89 256,32,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,0,32,0,104.0048,_ZN5aiter45fmoe_bf16_pertokenFp8_g1u1_vs_gelu_1tg_32x192E,0.0%,0.0,Null,0,104.0048,1,23.23,5811.01 256,64,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,0,32,0,117.2665,_ZN5aiter45fmoe_bf16_pertokenFp8_g1u1_vs_gelu_1tg_32x192E,0.0%,0.0,Null,0,117.2665,1,41.2,5157.2 @@ -248,6 +388,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 256,256,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,0,32,0,122.2554,_ZN5aiter48fmoe_bf16_pertokenFp8_g1u1_vs_gelu_1tg_ps_32x192E,0.0%,0.0,Null,0,122.2554,1,158.09,4966.04 256,512,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,0,32,0,148.475,_ZN5aiter45fmoe_bf16_pertokenFp8_g1u1_vs_gelu_1tg_32x384E,0.0%,0.0,Null,0,148.475,1,260.34,4110.26 256,1024,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,0,128,0,92.8309,_ZN5aiter45fmoe_stage1_bf16_pertokenFp8_g1u1_128x128_pf3E,0.0%,159.4309,moe_ck2stages_gemm2_256x128x128x128_1x4_TypeCast_v1_Nswizzle0_Quant2_MulRoutedWeight1_F8_F8_B16,0.5%,252.2618,0,306.46,2444.14 +256,1,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,1,128,0,49.6325,_ZN5aiter54fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_128x128_pf3E,0.0%,64.2871,moe_ck2stages_gemm2_256x128x128x128_1x4_TypeCast_v3_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,113.9196,0,10.6,5303.53 +256,2,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,1,128,0,49.6325,_ZN5aiter54fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_128x128_pf3E,0.0%,64.2871,moe_ck2stages_gemm2_256x128x128x128_1x4_TypeCast_v3_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,113.9196,0,10.6,5303.53 +256,4,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,1,128,0,49.6325,_ZN5aiter54fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_128x128_pf3E,0.0%,64.2871,moe_ck2stages_gemm2_256x128x128x128_1x4_TypeCast_v3_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,113.9196,0,10.6,5303.53 +256,8,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,1,128,0,49.6325,_ZN5aiter54fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_128x128_pf3E,0.0%,64.2871,moe_ck2stages_gemm2_256x128x128x128_1x4_TypeCast_v3_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,113.9196,0,10.6,5303.53 256,16,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,1,128,0,49.6325,_ZN5aiter54fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_128x128_pf3E,0.0%,64.2871,moe_ck2stages_gemm2_256x128x128x128_1x4_TypeCast_v3_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,113.9196,0,10.6,5303.53 
256,32,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,1,128,0,80.9233,_ZN5aiter53fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_128x64_pf3E,0.0%,84.6624,moe_ck2stages_gemm2_256x128x128x128_1x4_TypeCast_v3_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,165.5857,0,14.59,3649.91 256,64,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,1,128,0,84.5381,_ZN5aiter53fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_128x64_pf3E,0.0%,88.2359,moe_ck2stages_gemm2_256x128x128x128_1x4_TypeCast_v3_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,172.774,0,27.97,3500.33 @@ -255,6 +399,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 256,256,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,1,128,0,86.7616,_ZN5aiter53fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_128x64_pf3E,0.0%,94.7695,moe_ck2stages_gemm2_256x128x128x128_1x4_TypeCast_v3_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,181.5311,0,106.47,3344.47 256,512,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,1,128,0,89.8093,_ZN5aiter53fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_128x64_pf3E,0.0%,111.8215,moe_ck2stages_gemm2_256x128x128x128_1x4_TypeCast_v3_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,201.6308,0,191.71,3026.68 256,1024,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,1,128,0,94.4833,_ZN5aiter54fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_128x128_pf3E,0.0%,162.4132,moe_ck2stages_gemm2_256x128x128x128_1x4_TypeCast_v3_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,256.8965,0,300.94,2400.04 +80,1,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,238.4483,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,155.017,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,1.3%,393.4653,0,3.58,7164.62 +80,2,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,238.4483,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,155.017,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,1.3%,393.4653,0,3.58,7164.62 +80,4,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,238.4483,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,155.017,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,1.3%,393.4653,0,3.58,7164.62 +80,8,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,238.4483,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,155.017,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,1.3%,393.4653,0,3.58,7164.62 
80,16,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,238.4483,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,155.017,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,1.3%,393.4653,0,3.58,7164.62 80,32,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,367.326,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,243.7681,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,1.3%,611.0941,0,4.61,4613.84 80,64,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,485.1197,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,319.3392,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,1.3%,804.4589000000001,0,7.01,3505.97 @@ -262,6 +410,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 80,256,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,588.4224,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,400.0835,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,1.3%,988.5059,0,22.81,2858.77 80,512,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,617.8863,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,423.4145,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,1.3%,1041.3008,0,43.31,2720.88 80,1024,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,783.2767,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,600.1759,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,1.3%,1383.4526,0,65.2,2058.57 +80,1,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,238.507,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_silu_B16_B16_B16,0.0%,146.2522,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,384.7592,0,3.66,7326.74 +80,2,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,238.507,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_silu_B16_B16_B16,0.0%,146.2522,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,384.7592,0,3.66,7326.74 +80,4,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,238.507,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_silu_B16_B16_B16,0.0%,146.2522,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,384.7592,0,3.66,7326.74 
+80,8,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,238.507,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_silu_B16_B16_B16,0.0%,146.2522,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,384.7592,0,3.66,7326.74 80,16,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,238.507,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_silu_B16_B16_B16,0.0%,146.2522,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,384.7592,0,3.66,7326.74 80,32,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,376.8336,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_silu_B16_B16_B16,0.0%,233.1548,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,609.9884,0,4.62,4622.2 80,64,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,484.4582,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_silu_B16_B16_B16,0.0%,310.4899,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,794.9481,0,7.09,3547.91 @@ -269,6 +421,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 80,256,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,591.6172,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_silu_B16_B16_B16,0.0%,377.301,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,968.9182,0,23.27,2916.56 80,512,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,621.3374,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_silu_B16_B16_B16,0.0%,402.5935,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,1023.9309,0,44.04,2767.03 80,1024,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,784.9765,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_silu_B16_B16_B16,0.0%,567.3861,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,1352.3626,0,66.69,2105.89 +80,1,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,32,0,121.6233,moe_ck2stages_gemm1_256x32x64x256_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight0_silu_F8_F8_B16,0.0%,117.3629,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,1.3%,238.9862,0,5.9,5898.37 +80,2,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,32,0,121.6233,moe_ck2stages_gemm1_256x32x64x256_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight0_silu_F8_F8_B16,0.0%,117.3629,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,1.3%,238.9862,0,5.9,5898.37 
+80,4,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,32,0,121.6233,moe_ck2stages_gemm1_256x32x64x256_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight0_silu_F8_F8_B16,0.0%,117.3629,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,1.3%,238.9862,0,5.9,5898.37 +80,8,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,32,0,121.6233,moe_ck2stages_gemm1_256x32x64x256_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight0_silu_F8_F8_B16,0.0%,117.3629,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,1.3%,238.9862,0,5.9,5898.37 80,16,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,32,0,121.6233,moe_ck2stages_gemm1_256x32x64x256_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight0_silu_F8_F8_B16,0.0%,117.3629,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,1.3%,238.9862,0,5.9,5898.37 80,32,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,32,0,194.485,moe_ck2stages_gemm1_256x32x64x256_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight0_silu_F8_F8_B16,0.0%,176.724,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,1.3%,371.209,0,7.59,3798.33 80,64,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,32,0,251.2934,_ZN5aiter48fmoe_stage1_bf16_pertokenFp8_g1u1_32x128_3tg_pf3E,0.0%,233.6045,moe_ck2stages_gemm2_256x32x64x256_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,1.3%,484.8979,0,11.63,2909.19 @@ -276,6 +432,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 80,256,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,32,0,300.6316,moe_ck2stages_gemm1_256x32x64x128_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight0_silu_F8_F8_B16,0.0%,279.9782,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,1.3%,580.6098,0,38.84,2436.73 80,512,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,32,0,312.6244,_ZN5aiter48fmoe_stage1_bf16_pertokenFp8_g1u1_32x128_3tg_pf3E,0.0%,296.8907,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,1.3%,609.5151,0,73.99,2330.21 80,1024,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,32,0,394.9761,_ZN5aiter48fmoe_stage1_bf16_pertokenFp8_g1u1_32x256_2tg_pf3E,0.0%,427.6178,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,1.3%,822.5939,0,109.65,1739.99 +80,1,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,1,32,0,120.6124,moe_ck2stages_gemm1_256x32x64x256_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight1_silu_F8_F8_B16,0.0%,113.2457,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,233.8581,0,6.03,6027.72 
+80,2,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,1,32,0,120.6124,moe_ck2stages_gemm1_256x32x64x256_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight1_silu_F8_F8_B16,0.0%,113.2457,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,233.8581,0,6.03,6027.72 +80,4,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,1,32,0,120.6124,moe_ck2stages_gemm1_256x32x64x256_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight1_silu_F8_F8_B16,0.0%,113.2457,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,233.8581,0,6.03,6027.72 +80,8,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,1,32,0,120.6124,moe_ck2stages_gemm1_256x32x64x256_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight1_silu_F8_F8_B16,0.0%,113.2457,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,233.8581,0,6.03,6027.72 80,16,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,1,32,0,120.6124,moe_ck2stages_gemm1_256x32x64x256_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight1_silu_F8_F8_B16,0.0%,113.2457,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,233.8581,0,6.03,6027.72 80,32,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,1,32,0,188.7522,moe_ck2stages_gemm1_256x32x64x128_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight1_silu_F8_F8_B16,0.0%,170.2119,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,358.96410000000003,0,7.85,3927.9 80,64,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,1,32,0,254.6539,_ZN5aiter57fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x128_3tg_pf3E,0.1%,220.0918,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,474.7457,0,11.87,2971.41 @@ -283,6 +443,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 80,256,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,1,32,0,299.2146,_ZN5aiter57fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x128_3tg_pf3E,0.0%,268.1991,moe_ck2stages_gemm2_256x32x64x256_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,567.4137000000001,0,39.74,2493.4 80,512,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,1,32,0,316.0831,_ZN5aiter57fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x128_3tg_pf3E,0.0%,287.2443,moe_ck2stages_gemm2_256x32x64x256_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,603.3274,0,74.75,2354.11 80,1024,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,1,32,0,398.3967,_ZN5aiter57fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x256_2tg_pf3E,0.0%,410.8972,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,809.2939,0,111.45,1768.59 
+80,1,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,32,0,239.0773,_ZN5aiter45fmoe_bf16_pertokenFp8_g1u1_vs_silu_1tg_32x256E,0.0%,0.0,Null,0.0%,239.0773,1,5.89,5896.13 +80,2,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,32,0,239.0773,_ZN5aiter45fmoe_bf16_pertokenFp8_g1u1_vs_silu_1tg_32x256E,0.0%,0.0,Null,0.0%,239.0773,1,5.89,5896.13 +80,4,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,32,0,239.0773,_ZN5aiter45fmoe_bf16_pertokenFp8_g1u1_vs_silu_1tg_32x256E,0.0%,0.0,Null,0.0%,239.0773,1,5.89,5896.13 +80,8,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,32,0,239.0773,_ZN5aiter45fmoe_bf16_pertokenFp8_g1u1_vs_silu_1tg_32x256E,0.0%,0.0,Null,0.0%,239.0773,1,5.89,5896.13 80,16,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,32,0,239.0773,_ZN5aiter45fmoe_bf16_pertokenFp8_g1u1_vs_silu_1tg_32x256E,0.0%,0.0,Null,0.0%,239.0773,1,5.89,5896.13 80,32,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,32,0,361.5408,_ZN5aiter45fmoe_bf16_pertokenFp8_g1u1_vs_silu_1tg_32x256E,0.0%,0.0,Null,0.0%,361.5408,1,7.8,3899.9 80,64,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,32,0,409.1668,_ZN5aiter45fmoe_bf16_pertokenFp8_g1u1_vs_silu_1tg_32x256E,0.0%,0.0,Null,0.0%,409.1668,1,13.78,3447.65 @@ -290,6 +454,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 80,256,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,32,0,534.8764,_ZN5aiter45fmoe_bf16_pertokenFp8_g1u1_vs_silu_1tg_32x256E,0.0%,0.0,Null,0.0%,534.8764,1,42.16,2645.08 80,512,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,32,0,560.9956,_ZN5aiter45fmoe_bf16_pertokenFp8_g1u1_vs_silu_1tg_32x256E,0.0%,0.0,Null,0.0%,560.9956,1,80.39,2531.74 80,1024,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,32,0,723.0851,_ZN5aiter45fmoe_bf16_pertokenFp8_g1u1_vs_silu_1tg_32x256E,0.0%,0.0,Null,0.0%,723.0851,1,124.74,1979.44 +80,1,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,438.8756,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,256.4253,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,2.9%,695.3009,0,4.05,8108.15 +80,2,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,438.8756,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,256.4253,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,2.9%,695.3009,0,4.05,8108.15 
+80,4,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,438.8756,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,256.4253,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,2.9%,695.3009,0,4.05,8108.15 +80,8,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,438.8756,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,256.4253,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,2.9%,695.3009,0,4.05,8108.15 80,16,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,438.8756,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,256.4253,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,2.9%,695.3009,0,4.05,8108.15 80,32,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,714.79,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,400.2169,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,2.9%,1115.0069,0,5.06,5056.53 80,64,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,998.4164,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,561.412,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,2.9%,1559.8284,0,7.23,3615.13 @@ -300,6 +468,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 80,64,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,32,0,504.1748,_ZN5aiter44fmoe_stage1_bf16_pertokenFp8_g1u1_32x512_pf3E,0.0%,378.662,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight1_F8_F8_B16,2.9%,882.8368,0,12.77,3194.19 80,128,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,32,0,570.6719,_ZN5aiter48fmoe_stage1_bf16_pertokenFp8_g1u1_32x128_3tg_pf3E,0.1%,417.051,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight1_F8_F8_B16,2.8%,987.7229,0,22.83,2856.39 80,256,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,32,0,597.6775,_ZN5aiter48fmoe_stage1_bf16_pertokenFp8_g1u1_32x128_3tg_pf3E,0.0%,432.5419,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight1_F8_F8_B16,2.8%,1030.2194,0,43.77,2741.24 +80,1,4096,192,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,88.0761,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,66.5013,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.1%,154.5774,0,3.91,3908.99 
+80,2,4096,192,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,88.0761,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,66.5013,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.1%,154.5774,0,3.91,3908.99 +80,4,4096,192,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,88.0761,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,66.5013,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.1%,154.5774,0,3.91,3908.99 +80,8,4096,192,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,88.0761,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,66.5013,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.1%,154.5774,0,3.91,3908.99 80,16,4096,192,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,88.0761,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,66.5013,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.1%,154.5774,0,3.91,3908.99 80,32,4096,192,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,116.9005,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,89.872,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.1%,206.7725,0,5.84,2923.52 80,64,4096,192,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,130.2259,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,101.7908,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.1%,232.0167,0,10.41,2607.69 @@ -307,6 +479,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 80,256,4096,192,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,139.8396,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,107.8712,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.1%,247.7108,0,39.01,2455.18 80,512,4096,192,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,176.0293,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,152.7806,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.1%,328.8099,0,58.78,1862.38 80,1024,4096,192,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,283.4477,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,254.2291,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.1%,537.6768,0,71.89,1154.52 
+80,1,4096,192,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,88.5875,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_silu_B16_B16_B16,0.0%,61.2038,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,149.7913,0,4.03,4033.89 +80,2,4096,192,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,88.5875,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_silu_B16_B16_B16,0.0%,61.2038,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,149.7913,0,4.03,4033.89 +80,4,4096,192,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,88.5875,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_silu_B16_B16_B16,0.0%,61.2038,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,149.7913,0,4.03,4033.89 +80,8,4096,192,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,88.5875,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_silu_B16_B16_B16,0.0%,61.2038,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,149.7913,0,4.03,4033.89 80,16,4096,192,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,88.5875,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_silu_B16_B16_B16,0.0%,61.2038,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,149.7913,0,4.03,4033.89 80,32,4096,192,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,116.7801,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_silu_B16_B16_B16,0.0%,82.5179,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,199.298,0,6.06,3033.17 80,64,4096,192,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,130.0463,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_silu_B16_B16_B16,0.0%,93.2196,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,223.2659,0,10.82,2709.9 @@ -314,6 +490,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 80,256,4096,192,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,136.6552,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_silu_B16_B16_B16,0.0%,99.1414,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,235.7966,0,40.98,2579.23 80,512,4096,192,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,177.8056,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_silu_B16_B16_B16,0.0%,140.1446,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,317.9502,0,60.79,1925.99 
80,1024,4096,192,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,284.4745,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_silu_B16_B16_B16,0.0%,232.9485,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,517.423,0,74.71,1199.71 +80,1,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,160.613,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,97.024,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.3%,257.637,0,4.69,4689.63 +80,2,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,160.613,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,97.024,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.3%,257.637,0,4.69,4689.63 +80,4,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,160.613,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,97.024,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.3%,257.637,0,4.69,4689.63 +80,8,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,160.613,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,97.024,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.3%,257.637,0,4.69,4689.63 80,16,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,160.613,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,97.024,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.3%,257.637,0,4.69,4689.63 80,32,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,214.8572,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,129.0874,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.3%,343.94460000000004,0,7.02,3513.6 80,64,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,238.0028,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,151.5356,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.3%,389.5384,0,12.4,3103.69 @@ -321,6 +501,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 80,256,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,268.9473,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,167.0488,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.4%,435.9961,0,44.33,2780.19 
80,512,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,347.7216,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,219.8661,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.4%,567.5877,0,68.1,2143.01 80,1024,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,523.4769,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,371.6201,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.4%,895.097,0,86.37,1368.27 +80,1,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,165.7381,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_silu_B16_B16_B16,0.0%,94.3172,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,260.0553,0,4.65,4646.02 +80,2,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,165.7381,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_silu_B16_B16_B16,0.0%,94.3172,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,260.0553,0,4.65,4646.02 +80,4,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,165.7381,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_silu_B16_B16_B16,0.0%,94.3172,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,260.0553,0,4.65,4646.02 +80,8,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,165.7381,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_silu_B16_B16_B16,0.0%,94.3172,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,260.0553,0,4.65,4646.02 80,16,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,165.7381,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_silu_B16_B16_B16,0.0%,94.3172,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,260.0553,0,4.65,4646.02 80,32,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,216.9447,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_silu_B16_B16_B16,0.0%,126.1828,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,343.1275,0,7.04,3521.97 80,64,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,246.22,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_silu_B16_B16_B16,0.0%,145.7239,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,391.9439,0,12.33,3084.65 @@ -328,6 +512,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 
80,256,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,271.8789,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_silu_B16_B16_B16,0.0%,161.456,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,433.3349,0,44.6,2797.27 80,512,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,347.2111,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_silu_B16_B16_B16,0.0%,210.4569,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,557.668,0,69.31,2181.13 80,1024,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,524.4171,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_silu_B16_B16_B16,0.0%,354.8123,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,879.2293999999999,0,87.93,1392.97 +80,1,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,32,0,82.6308,moe_ck2stages_gemm1_256x32x64x128_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight0_silu_F8_F8_B16,0.0%,68.157,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,0.4%,150.7878,0,8.01,4006.8 +80,2,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,32,0,82.6308,moe_ck2stages_gemm1_256x32x64x128_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight0_silu_F8_F8_B16,0.0%,68.157,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,0.4%,150.7878,0,8.01,4006.8 +80,4,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,32,0,82.6308,moe_ck2stages_gemm1_256x32x64x128_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight0_silu_F8_F8_B16,0.0%,68.157,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,0.4%,150.7878,0,8.01,4006.8 +80,8,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,32,0,82.6308,moe_ck2stages_gemm1_256x32x64x128_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight0_silu_F8_F8_B16,0.0%,68.157,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,0.4%,150.7878,0,8.01,4006.8 80,16,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,32,0,82.6308,moe_ck2stages_gemm1_256x32x64x128_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight0_silu_F8_F8_B16,0.0%,68.157,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,0.4%,150.7878,0,8.01,4006.8 80,32,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,32,0,109.4253,moe_ck2stages_gemm1_256x32x64x256_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight0_silu_F8_F8_B16,0.0%,88.2541,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,0.4%,197.6794,0,12.22,3057.34 
80,64,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,32,0,127.2465,moe_ck2stages_gemm1_256x32x64x128_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight0_silu_F8_F8_B16,0.0%,102.2132,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,0.4%,229.4597,0,21.06,2635.61 @@ -335,6 +523,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 80,256,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,32,0,134.0808,moe_ck2stages_gemm1_256x32x64x256_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight0_silu_F8_F8_B16,0.0%,109.8052,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,0.4%,243.886,0,79.25,2489.38 80,512,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,32,0,179.6681,moe_ck2stages_gemm1_256x32x64x256_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight0_silu_F8_F8_B16,0.0%,155.5374,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,0.4%,335.20550000000003,0,115.32,1820.59 80,1024,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,32,0,269.7206,moe_ck2stages_gemm1_256x32x64x256_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight0_silu_F8_F8_B16,0.0%,263.4387,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,0.4%,533.1593,0,145.0,1156.43 +80,1,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,1,32,0,83.7141,moe_ck2stages_gemm1_256x32x64x128_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight1_silu_F8_F8_B16,0.0%,64.4684,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,148.1825,0,8.15,4077.25 +80,2,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,1,32,0,83.7141,moe_ck2stages_gemm1_256x32x64x128_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight1_silu_F8_F8_B16,0.0%,64.4684,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,148.1825,0,8.15,4077.25 +80,4,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,1,32,0,83.7141,moe_ck2stages_gemm1_256x32x64x128_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight1_silu_F8_F8_B16,0.0%,64.4684,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,148.1825,0,8.15,4077.25 +80,8,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,1,32,0,83.7141,moe_ck2stages_gemm1_256x32x64x128_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight1_silu_F8_F8_B16,0.0%,64.4684,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,148.1825,0,8.15,4077.25 
80,16,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,1,32,0,83.7141,moe_ck2stages_gemm1_256x32x64x128_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight1_silu_F8_F8_B16,0.0%,64.4684,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,148.1825,0,8.15,4077.25 80,32,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,1,32,0,109.0496,moe_ck2stages_gemm1_256x32x64x256_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight1_silu_F8_F8_B16,0.0%,83.8951,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,192.9447,0,12.52,3132.36 80,64,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,1,32,0,127.6734,moe_ck2stages_gemm1_256x32x64x128_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight1_silu_F8_F8_B16,0.0%,97.1704,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,224.8438,0,21.49,2689.72 @@ -342,6 +534,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 80,256,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,1,32,0,134.708,moe_ck2stages_gemm1_256x32x64x256_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight1_silu_F8_F8_B16,0.0%,105.5023,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,240.2103,0,80.46,2527.47 80,512,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,1,32,0,178.7076,moe_ck2stages_gemm1_256x32x64x256_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight1_silu_F8_F8_B16,0.0%,147.3146,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,326.0222,0,118.56,1871.87 80,1024,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,1,32,0,269.4957,moe_ck2stages_gemm1_256x32x64x256_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight1_silu_F8_F8_B16,0.0%,248.2335,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,517.7292,0,149.32,1190.9 +80,1,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,32,0,82.7173,_ZN5aiter47fmoe_stage1_bf16_pertokenFp8_g1u1_32x64_4tg_pf3E,0.3%,76.402,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight1_F8_F8_B16,0.4%,159.1193,0,7.59,3797.0 +80,2,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,32,0,82.7173,_ZN5aiter47fmoe_stage1_bf16_pertokenFp8_g1u1_32x64_4tg_pf3E,0.3%,76.402,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight1_F8_F8_B16,0.4%,159.1193,0,7.59,3797.0 +80,4,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,32,0,82.7173,_ZN5aiter47fmoe_stage1_bf16_pertokenFp8_g1u1_32x64_4tg_pf3E,0.3%,76.402,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight1_F8_F8_B16,0.4%,159.1193,0,7.59,3797.0 
+80,8,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,32,0,82.7173,_ZN5aiter47fmoe_stage1_bf16_pertokenFp8_g1u1_32x64_4tg_pf3E,0.3%,76.402,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight1_F8_F8_B16,0.4%,159.1193,0,7.59,3797.0 80,16,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,32,0,82.7173,_ZN5aiter47fmoe_stage1_bf16_pertokenFp8_g1u1_32x64_4tg_pf3E,0.3%,76.402,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight1_F8_F8_B16,0.4%,159.1193,0,7.59,3797.0 80,32,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,32,0,109.8209,moe_ck2stages_gemm1_256x32x64x128_1x4_MulABScale_v1_Nswizzle0_Quant2_MulRoutedWeight0_silu_F8_F8_B16,0.0%,99.0655,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight1_F8_F8_B16,0.4%,208.8864,0,11.57,2893.31 80,64,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,32,0,119.69,_ZN5aiter48fmoe_stage1_bf16_pertokenFp8_g1u1_32x128_3tg_pf3E,0.0%,114.5365,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight1_F8_F8_B16,0.4%,234.2265,0,20.63,2581.97 @@ -349,6 +545,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 80,256,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,32,0,129.4978,_ZN5aiter48fmoe_stage1_bf16_pertokenFp8_g1u1_32x128_3tg_pf3E,0.0%,126.7846,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight1_F8_F8_B16,0.4%,256.2824,0,75.41,2368.97 80,512,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,32,0,174.904,_ZN5aiter48fmoe_stage1_bf16_pertokenFp8_g1u1_32x128_3tg_pf3E,0.0%,184.6529,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight1_F8_F8_B16,0.4%,359.5569,0,107.51,1697.29 80,1024,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,32,0,272.1056,moe_ck2stages_gemm1_256x32x64x256_1x4_MulABScale_v1_Nswizzle0_Quant2_MulRoutedWeight0_silu_F8_F8_B16,0.0%,314.319,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight1_F8_F8_B16,0.4%,586.4246,0,131.83,1051.39 +80,1,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,1,32,0,82.1262,_ZN5aiter56fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x64_4tg_pf3E,0.0%,74.17,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,156.2962,0,7.73,3865.59 +80,2,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,1,32,0,82.1262,_ZN5aiter56fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x64_4tg_pf3E,0.0%,74.17,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,156.2962,0,7.73,3865.59 
+80,4,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,1,32,0,82.1262,_ZN5aiter56fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x64_4tg_pf3E,0.0%,74.17,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,156.2962,0,7.73,3865.59 +80,8,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,1,32,0,82.1262,_ZN5aiter56fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x64_4tg_pf3E,0.0%,74.17,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,156.2962,0,7.73,3865.59 80,16,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,1,32,0,82.1262,_ZN5aiter56fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x64_4tg_pf3E,0.0%,74.17,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,156.2962,0,7.73,3865.59 80,32,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,1,32,0,110.0096,_ZN5aiter57fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x128_3tg_pf3E,0.0%,95.6992,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,205.7088,0,11.74,2938.0 80,64,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,1,32,0,125.1757,moe_ck2stages_gemm1_256x32x64x256_1x4_MulABScale_v1_Nswizzle0_Quant2_MulRoutedWeight1_silu_F8_F8_B16,0.0%,111.2889,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,236.4646,0,20.43,2557.53 @@ -356,11 +556,19 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 80,256,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,1,32,0,136.164,_ZN5aiter57fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x128_3tg_pf3E,0.0%,123.0773,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,259.2413,0,74.55,2341.93 80,512,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,1,32,0,173.4458,_ZN5aiter57fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x128_3tg_pf3E,0.0%,178.424,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,351.8698,0,109.86,1734.37 80,1024,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,1,32,0,268.8237,moe_ck2stages_gemm1_256x32x64x256_1x4_MulABScale_v1_Nswizzle0_Quant2_MulRoutedWeight1_silu_F8_F8_B16,0.0%,301.9153,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,570.739,0,135.45,1080.29 +80,1,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,165.0797,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,96.9624,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.5%,262.0421,0,4.61,4610.79 
+80,2,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,165.0797,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,96.9624,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.5%,262.0421,0,4.61,4610.79 +80,4,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,165.0797,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,96.9624,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.5%,262.0421,0,4.61,4610.79 +80,8,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,165.0797,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,96.9624,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.5%,262.0421,0,4.61,4610.79 80,16,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,165.0797,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,96.9624,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.5%,262.0421,0,4.61,4610.79 80,32,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,217.0102,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,129.4305,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.5%,346.4407,0,6.97,3488.28 80,64,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,241.1054,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,150.0225,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.5%,391.1279,0,12.35,3091.08 80,128,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,256.7095,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,157.2573,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.5%,413.9668,0,23.34,2923.08 80,256,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,265.3977,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,166.4318,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.5%,431.8295,0,44.76,2807.02 +80,1,2048,192,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,38.52,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,36.2602,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.0%,74.78020000000001,0,4.04,4040.12 
+80,2,2048,192,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,38.52,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,36.2602,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.0%,74.78020000000001,0,4.04,4040.12 +80,4,2048,192,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,38.52,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,36.2602,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.0%,74.78020000000001,0,4.04,4040.12 +80,8,2048,192,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,38.52,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,36.2602,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.0%,74.78020000000001,0,4.04,4040.12 80,16,2048,192,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,38.52,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,36.2602,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.0%,74.78020000000001,0,4.04,4040.12 80,32,2048,192,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,58.4576,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,49.4407,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.0%,107.8983,0,5.6,2801.27 80,64,2048,192,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,62.2694,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,55.5013,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.0%,117.7707,0,10.26,2568.67 @@ -368,6 +576,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 80,256,2048,192,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,66.8949,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,59.7486,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.0%,126.64350000000002,0,38.15,2401.13 80,512,2048,192,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,94.7614,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,83.103,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.0%,177.8644,0,54.33,1721.45 80,1024,2048,192,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,143.5168,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,134.0455,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.0%,277.5623,0,69.63,1118.23 
+80,1,2048,192,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,38.6753,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_silu_B16_B16_B16,0.0%,33.9165,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,72.5918,0,4.16,4161.92 +80,2,2048,192,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,38.6753,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_silu_B16_B16_B16,0.0%,33.9165,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,72.5918,0,4.16,4161.92 +80,4,2048,192,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,38.6753,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_silu_B16_B16_B16,0.0%,33.9165,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,72.5918,0,4.16,4161.92 +80,8,2048,192,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,38.6753,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_silu_B16_B16_B16,0.0%,33.9165,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,72.5918,0,4.16,4161.92 80,16,2048,192,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,38.6753,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_silu_B16_B16_B16,0.0%,33.9165,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,72.5918,0,4.16,4161.92 80,32,2048,192,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,59.5728,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_silu_B16_B16_B16,0.0%,45.8094,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,105.3822,0,5.73,2868.15 80,64,2048,192,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,62.2801,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_silu_B16_B16_B16,0.0%,51.218,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,113.4981,0,10.64,2665.37 @@ -381,6 +593,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 80,256,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,1,32,0,292.8938,moe_ck2stages_gemm1_256x32x64x256_1x4_MulABScale_v1_Nswizzle0_Quant2_MulRoutedWeight1_silu_F8_F8_B16,0.0%,319.6213,moe_ck2stages_gemm2_256x32x64x256_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,612.5151000000001,0,36.81,2309.81 80,512,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,1,32,0,305.4473,_ZN5aiter57fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x128_3tg_pf3E,0.0%,349.9039,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,655.3512000000001,0,68.81,2167.23 
80,1024,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,1,32,0,394.0843,_ZN5aiter57fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x256_2tg_pf3E,0.0%,517.6212,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,911.7055,0,98.93,1569.92 +80,1,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,241.3413,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,156.5265,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,1.7%,397.8678,0,3.54,7085.35 +80,2,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,241.3413,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,156.5265,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,1.7%,397.8678,0,3.54,7085.35 +80,4,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,241.3413,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,156.5265,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,1.7%,397.8678,0,3.54,7085.35 +80,8,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,241.3413,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,156.5265,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,1.7%,397.8678,0,3.54,7085.35 80,16,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,241.3413,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,156.5265,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,1.7%,397.8678,0,3.54,7085.35 80,32,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,368.8396,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,244.9596,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,1.6%,613.7992,0,4.59,4593.51 80,64,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,486.2218,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,330.374,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,1.6%,816.5958,0,6.9,3453.86 @@ -388,6 +604,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 80,256,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,587.3979,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,399.9114,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,1.6%,987.3093,0,22.84,2862.24 
80,512,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,619.5779,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,428.3394,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,1.6%,1047.9173,0,43.04,2703.7 80,1024,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,785.2244,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,600.8826,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,1.6%,1386.107,0,65.07,2054.63 +80,1,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,239.8912,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,147.6741,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,387.5653,0,3.64,7273.69 +80,2,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,239.8912,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,147.6741,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,387.5653,0,3.64,7273.69 +80,4,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,239.8912,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,147.6741,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,387.5653,0,3.64,7273.69 +80,8,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,239.8912,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,147.6741,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,387.5653,0,3.64,7273.69 80,16,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,239.8912,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,147.6741,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,387.5653,0,3.64,7273.69 80,32,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,368.433,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,237.6243,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,606.0572999999999,0,4.65,4652.18 80,64,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,493.7026,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,313.5552,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,807.2578000000001,0,6.98,3493.81 @@ -395,7 +615,15 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 
80,256,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,593.7672,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,382.7183,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,976.4855,0,23.09,2893.96 80,512,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,617.3315,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,411.0096,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,1028.3411,0,43.85,2755.17 80,1024,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,781.9082,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,569.3894,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,1351.2976,0,66.75,2107.55 +80,1,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,1,32,0,118.5389,moe_ck2stages_gemm1_256x32x64x128_1x4_MulABScale_v1_Nswizzle0_Quant2_MulRoutedWeight1_silu_F8_F8_B16,0.0%,118.5512,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,237.0901,0,5.94,5945.55 +80,2,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,1,32,0,118.5389,moe_ck2stages_gemm1_256x32x64x128_1x4_MulABScale_v1_Nswizzle0_Quant2_MulRoutedWeight1_silu_F8_F8_B16,0.0%,118.5512,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,237.0901,0,5.94,5945.55 +80,4,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,1,32,0,118.5389,moe_ck2stages_gemm1_256x32x64x128_1x4_MulABScale_v1_Nswizzle0_Quant2_MulRoutedWeight1_silu_F8_F8_B16,0.0%,118.5512,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,237.0901,0,5.94,5945.55 +80,8,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,1,32,0,118.5389,moe_ck2stages_gemm1_256x32x64x128_1x4_MulABScale_v1_Nswizzle0_Quant2_MulRoutedWeight1_silu_F8_F8_B16,0.0%,118.5512,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,237.0901,0,5.94,5945.55 80,16,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,1,32,0,118.5389,moe_ck2stages_gemm1_256x32x64x128_1x4_MulABScale_v1_Nswizzle0_Quant2_MulRoutedWeight1_silu_F8_F8_B16,0.0%,118.5512,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,237.0901,0,5.94,5945.55 +80,1,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,32,0,120.5717,moe_ck2stages_gemm1_256x32x64x256_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,117.6181,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,1.7%,238.1898,0,5.92,5918.1 
+80,2,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,32,0,120.5717,moe_ck2stages_gemm1_256x32x64x256_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,117.6181,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,1.7%,238.1898,0,5.92,5918.1 +80,4,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,32,0,120.5717,moe_ck2stages_gemm1_256x32x64x256_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,117.6181,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,1.7%,238.1898,0,5.92,5918.1 +80,8,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,32,0,120.5717,moe_ck2stages_gemm1_256x32x64x256_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,117.6181,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,1.7%,238.1898,0,5.92,5918.1 80,16,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,32,0,120.5717,moe_ck2stages_gemm1_256x32x64x256_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,117.6181,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,1.7%,238.1898,0,5.92,5918.1 80,32,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,32,0,188.5396,moe_ck2stages_gemm1_256x32x64x128_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,176.4776,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,1.6%,365.0172,0,7.72,3862.76 80,64,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,32,0,252.7997,moe_ck2stages_gemm1_256x32x64x128_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,230.1231,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,1.7%,482.9228,0,11.67,2921.09 @@ -403,6 +631,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 80,256,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,32,0,301.0401,_ZN5aiter48fmoe_stage1_bf16_pertokenFp8_g1u1_32x128_3tg_pf3E,0.0%,279.1768,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,1.6%,580.2169,0,38.86,2438.38 80,512,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,32,0,314.4244,_ZN5aiter48fmoe_stage1_bf16_pertokenFp8_g1u1_32x128_3tg_pf3E,0.0%,297.6385,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,1.6%,612.0629,0,73.68,2320.51 80,1024,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,32,0,388.4815,moe_ck2stages_gemm1_256x32x64x256_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,429.5016,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,1.6%,817.9830999999999,0,110.26,1749.8 
+80,1,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,1,32,0,121.9866,moe_ck2stages_gemm1_256x32x64x256_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,113.0687,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,235.0553,0,6.0,5997.02 +80,2,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,1,32,0,121.9866,moe_ck2stages_gemm1_256x32x64x256_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,113.0687,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,235.0553,0,6.0,5997.02 +80,4,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,1,32,0,121.9866,moe_ck2stages_gemm1_256x32x64x256_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,113.0687,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,235.0553,0,6.0,5997.02 +80,8,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,1,32,0,121.9866,moe_ck2stages_gemm1_256x32x64x256_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,113.0687,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,235.0553,0,6.0,5997.02 80,16,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,1,32,0,121.9866,moe_ck2stages_gemm1_256x32x64x256_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,113.0687,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,235.0553,0,6.0,5997.02 80,32,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,1,32,0,191.3402,moe_ck2stages_gemm1_256x32x64x256_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,169.9426,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,361.2828,0,7.8,3902.69 80,64,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,1,32,0,253.0207,_ZN5aiter56fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x64_4tg_pf3E,0.0%,220.8733,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,473.894,0,11.9,2976.75 @@ -411,20 +643,36 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 80,512,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,1,32,0,304.03,moe_ck2stages_gemm1_256x32x64x256_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,287.1332,moe_ck2stages_gemm2_256x32x64x256_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,591.1632,0,76.29,2402.55 
80,1024,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,1,32,0,390.4133,moe_ck2stages_gemm1_256x32x64x256_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,410.897,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,801.3103,0,112.56,1786.21 80,1024,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,1538.2556,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_silu_B16_B16_B16,0.0%,1005.4051,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,2543.6607,0,70.92,2227.7 +80,1,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,32,0,237.9312,moe_ck2stages_gemm1_256x32x64x256_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight0_silu_F8_F8_B16,0.0%,159.3483,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,2.8%,397.2795,0,7.09,7095.55 +80,2,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,32,0,237.9312,moe_ck2stages_gemm1_256x32x64x256_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight0_silu_F8_F8_B16,0.0%,159.3483,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,2.8%,397.2795,0,7.09,7095.55 +80,4,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,32,0,237.9312,moe_ck2stages_gemm1_256x32x64x256_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight0_silu_F8_F8_B16,0.0%,159.3483,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,2.8%,397.2795,0,7.09,7095.55 +80,8,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,32,0,237.9312,moe_ck2stages_gemm1_256x32x64x256_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight0_silu_F8_F8_B16,0.0%,159.3483,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,2.8%,397.2795,0,7.09,7095.55 80,16,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,32,0,237.9312,moe_ck2stages_gemm1_256x32x64x256_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight0_silu_F8_F8_B16,0.0%,159.3483,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,2.8%,397.2795,0,7.09,7095.55 80,32,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,32,0,369.2362,moe_ck2stages_gemm1_256x32x64x256_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight0_silu_F8_F8_B16,0.0%,241.8314,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,2.9%,611.0676,0,9.23,4613.66 80,64,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,32,0,510.6484,_ZN5aiter44fmoe_stage1_bf16_pertokenFp8_g1u1_32x512_pf3E,0.0%,335.3164,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,2.8%,845.9648,0,13.33,3333.41 
80,128,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,32,0,582.0633,moe_ck2stages_gemm1_256x32x64x256_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight0_silu_F8_F8_B16,0.0%,408.0174,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,2.8%,990.0807,0,22.77,2849.59 80,256,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,32,0,615.4904,_ZN5aiter44fmoe_stage1_bf16_pertokenFp8_g1u1_32x512_pf3E,0.0%,418.179,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,2.8%,1033.6694,0,43.63,2732.09 80,1024,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,32,0,762.665,_ZN5aiter44fmoe_stage1_bf16_pertokenFp8_g1u1_32x512_pf3E,0.0%,620.1682,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,2.8%,1382.8332,0,130.45,2054.18 +80,1,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,1,32,0,231.5308,moe_ck2stages_gemm1_256x32x64x256_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight1_silu_F8_F8_B16,0.0%,154.9071,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,386.4379,0,7.29,7294.62 +80,2,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,1,32,0,231.5308,moe_ck2stages_gemm1_256x32x64x256_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight1_silu_F8_F8_B16,0.0%,154.9071,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,386.4379,0,7.29,7294.62 +80,4,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,1,32,0,231.5308,moe_ck2stages_gemm1_256x32x64x256_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight1_silu_F8_F8_B16,0.0%,154.9071,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,386.4379,0,7.29,7294.62 +80,8,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,1,32,0,231.5308,moe_ck2stages_gemm1_256x32x64x256_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight1_silu_F8_F8_B16,0.0%,154.9071,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,386.4379,0,7.29,7294.62 80,16,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,1,32,0,231.5308,moe_ck2stages_gemm1_256x32x64x256_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight1_silu_F8_F8_B16,0.0%,154.9071,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,386.4379,0,7.29,7294.62 80,128,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,1,32,0,579.7023,_ZN5aiter57fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x128_3tg_pf3E,0.1%,364.33,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,944.0323,0,23.89,2988.59 
80,256,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,1,32,0,601.7144,_ZN5aiter57fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x128_3tg_pf3E,0.0%,402.2826,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,1003.997,0,44.92,2812.83 80,512,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,1,32,0,625.392,_ZN5aiter57fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x128_3tg_pf3E,0.0%,411.3256,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,1036.7176,0,87.0,2729.37 80,1024,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,1,32,0,763.616,_ZN5aiter53fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x512_pf3E,0.0%,611.4583,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,1375.0743,0,131.18,2065.77 +80,1,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,32,0,228.7896,_ZN5aiter48fmoe_stage1_bf16_pertokenFp8_g1u1_32x128_3tg_pf3E,0.0%,171.5575,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight1_F8_F8_B16,2.9%,400.3471,0,7.04,7041.18 +80,2,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,32,0,228.7896,_ZN5aiter48fmoe_stage1_bf16_pertokenFp8_g1u1_32x128_3tg_pf3E,0.0%,171.5575,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight1_F8_F8_B16,2.9%,400.3471,0,7.04,7041.18 +80,4,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,32,0,228.7896,_ZN5aiter48fmoe_stage1_bf16_pertokenFp8_g1u1_32x128_3tg_pf3E,0.0%,171.5575,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight1_F8_F8_B16,2.9%,400.3471,0,7.04,7041.18 +80,8,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,32,0,228.7896,_ZN5aiter48fmoe_stage1_bf16_pertokenFp8_g1u1_32x128_3tg_pf3E,0.0%,171.5575,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight1_F8_F8_B16,2.9%,400.3471,0,7.04,7041.18 80,16,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,32,0,228.7896,_ZN5aiter48fmoe_stage1_bf16_pertokenFp8_g1u1_32x128_3tg_pf3E,0.0%,171.5575,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight1_F8_F8_B16,2.9%,400.3471,0,7.04,7041.18 80,512,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,32,0,618.0944,_ZN5aiter48fmoe_stage1_bf16_pertokenFp8_g1u1_32x128_3tg_pf3E,0.1%,476.0445,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight1_F8_F8_B16,2.8%,1094.1389,0,82.43,2586.13 80,1024,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,32,0,750.5786,_ZN5aiter44fmoe_stage1_bf16_pertokenFp8_g1u1_32x512_pf3E,0.0%,718.1941,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight1_F8_F8_B16,2.8%,1468.7727,0,122.82,1933.99 
+80,1,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,1,32,0,231.8332,_ZN5aiter56fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x64_4tg_pf3E,0.0%,169.7631,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,401.5963,0,7.02,7019.28 +80,2,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,1,32,0,231.8332,_ZN5aiter56fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x64_4tg_pf3E,0.0%,169.7631,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,401.5963,0,7.02,7019.28 +80,4,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,1,32,0,231.8332,_ZN5aiter56fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x64_4tg_pf3E,0.0%,169.7631,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,401.5963,0,7.02,7019.28 +80,8,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,1,32,0,231.8332,_ZN5aiter56fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x64_4tg_pf3E,0.0%,169.7631,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,401.5963,0,7.02,7019.28 80,16,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,1,32,0,231.8332,_ZN5aiter56fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x64_4tg_pf3E,0.0%,169.7631,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,401.5963,0,7.02,7019.28 80,32,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,1,32,0,367.8369,_ZN5aiter56fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x64_4tg_pf3E,0.0%,260.1737,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,628.0106000000001,0,8.98,4489.19 80,64,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,1,32,0,495.9422,_ZN5aiter53fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x512_pf3E,0.0%,361.1859,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,857.1281,0,13.15,3290.0 @@ -434,6 +682,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 80,1024,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,1,32,0,752.5409,_ZN5aiter53fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x512_pf3E,0.0%,695.4453,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,1447.9861999999998,0,124.58,1961.75 80,512,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,351.9236,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,220.5397,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.5%,572.4633,0,67.52,2124.76 
80,1024,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,523.5384,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,372.5179,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.5%,896.0563,0,86.28,1366.81 +80,1,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,161.7277,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,96.3101,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,258.0378,0,4.68,4682.34 +80,2,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,161.7277,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,96.3101,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,258.0378,0,4.68,4682.34 +80,4,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,161.7277,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,96.3101,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,258.0378,0,4.68,4682.34 +80,8,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,161.7277,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,96.3101,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,258.0378,0,4.68,4682.34 80,16,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,161.7277,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,96.3101,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,258.0378,0,4.68,4682.34 80,32,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,209.4214,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,126.4843,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,335.9057,0,7.19,3597.69 80,64,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,240.9494,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,148.9541,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,389.9035,0,12.39,3100.79 @@ -441,6 +693,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 80,256,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,269.3657,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,164.1089,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,433.4746,0,44.59,2796.37 
80,512,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,352.9213,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,213.1932,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,566.1144999999999,0,68.28,2148.59 80,1024,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,524.8862,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,354.9441,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,879.8303000000001,0,87.87,1392.01 +80,1,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,32,0,83.611,moe_ck2stages_gemm1_256x32x64x128_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,67.927,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,0.5%,151.538,0,7.97,3986.96 +80,2,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,32,0,83.611,moe_ck2stages_gemm1_256x32x64x128_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,67.927,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,0.5%,151.538,0,7.97,3986.96 +80,4,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,32,0,83.611,moe_ck2stages_gemm1_256x32x64x128_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,67.927,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,0.5%,151.538,0,7.97,3986.96 +80,8,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,32,0,83.611,moe_ck2stages_gemm1_256x32x64x128_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,67.927,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,0.5%,151.538,0,7.97,3986.96 80,16,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,32,0,83.611,moe_ck2stages_gemm1_256x32x64x128_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,67.927,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,0.5%,151.538,0,7.97,3986.96 80,32,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,32,0,107.6661,moe_ck2stages_gemm1_256x32x64x256_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,87.4868,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,0.5%,195.1529,0,12.38,3096.92 80,64,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,32,0,125.0556,moe_ck2stages_gemm1_256x32x64x128_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,101.2628,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,0.5%,226.3184,0,21.35,2672.19 @@ -448,6 +704,10 @@ 
cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 80,256,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,32,0,135.4179,moe_ck2stages_gemm1_256x32x64x256_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,109.748,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,0.5%,245.1659,0,78.83,2476.39 80,512,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,32,0,177.8313,moe_ck2stages_gemm1_256x32x64x256_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,155.4294,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,0.5%,333.2607,0,115.99,1831.21 80,1024,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,32,0,264.4639,moe_ck2stages_gemm1_256x32x64x256_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,263.7878,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,0.5%,528.2517,0,146.35,1167.18 +80,1,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,1,32,0,84.3415,moe_ck2stages_gemm1_256x32x64x128_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,65.6237,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,149.96519999999998,0,8.05,4028.78 +80,2,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,1,32,0,84.3415,moe_ck2stages_gemm1_256x32x64x128_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,65.6237,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,149.96519999999998,0,8.05,4028.78 +80,4,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,1,32,0,84.3415,moe_ck2stages_gemm1_256x32x64x128_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,65.6237,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,149.96519999999998,0,8.05,4028.78 +80,8,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,1,32,0,84.3415,moe_ck2stages_gemm1_256x32x64x128_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,65.6237,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,149.96519999999998,0,8.05,4028.78 80,16,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,1,32,0,84.3415,moe_ck2stages_gemm1_256x32x64x128_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,65.6237,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,149.96519999999998,0,8.05,4028.78 
80,32,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,1,32,0,109.5369,moe_ck2stages_gemm1_256x32x64x128_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,83.7014,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,193.2383,0,12.5,3127.6 80,64,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,1,32,0,130.7478,_ZN5aiter57fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x128_3tg_pf3E,0.0%,96.7112,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,227.459,0,21.24,2658.79 @@ -455,6 +715,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 80,256,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,1,32,0,134.1437,moe_ck2stages_gemm1_256x32x64x256_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,106.1464,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,240.2901,0,80.43,2526.64 80,512,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,1,32,0,179.8699,moe_ck2stages_gemm1_256x32x64x256_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,146.8478,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,326.71770000000004,0,118.31,1867.89 80,1024,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,1,32,0,266.6047,moe_ck2stages_gemm1_256x32x64x256_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,248.2173,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,514.822,0,150.17,1197.62 +80,1,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,32,0,82.4918,_ZN5aiter47fmoe_stage1_bf16_pertokenFp8_g1u1_32x64_4tg_pf3E,0.0%,76.3624,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight1_F8_F8_B16,0.5%,158.8542,0,7.6,3803.34 +80,2,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,32,0,82.4918,_ZN5aiter47fmoe_stage1_bf16_pertokenFp8_g1u1_32x64_4tg_pf3E,0.0%,76.3624,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight1_F8_F8_B16,0.5%,158.8542,0,7.6,3803.34 +80,4,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,32,0,82.4918,_ZN5aiter47fmoe_stage1_bf16_pertokenFp8_g1u1_32x64_4tg_pf3E,0.0%,76.3624,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight1_F8_F8_B16,0.5%,158.8542,0,7.6,3803.34 +80,8,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,32,0,82.4918,_ZN5aiter47fmoe_stage1_bf16_pertokenFp8_g1u1_32x64_4tg_pf3E,0.0%,76.3624,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight1_F8_F8_B16,0.5%,158.8542,0,7.6,3803.34 
80,16,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,32,0,82.4918,_ZN5aiter47fmoe_stage1_bf16_pertokenFp8_g1u1_32x64_4tg_pf3E,0.0%,76.3624,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight1_F8_F8_B16,0.5%,158.8542,0,7.6,3803.34 80,32,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,32,0,108.8976,_ZN5aiter47fmoe_stage1_bf16_pertokenFp8_g1u1_32x64_4tg_pf3E,0.0%,99.2988,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight1_F8_F8_B16,0.5%,208.1964,0,11.6,2902.9 80,64,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,32,0,126.1647,moe_ck2stages_gemm1_256x32x64x128_1x4_MulABScale_v1_Nswizzle0_Quant2_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,114.79,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight1_F8_F8_B16,0.5%,240.9547,0,20.05,2509.88 @@ -462,6 +726,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 80,256,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,32,0,136.6099,_ZN5aiter48fmoe_stage1_bf16_pertokenFp8_g1u1_32x128_3tg_pf3E,0.0%,127.3005,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight1_F8_F8_B16,0.5%,263.9104,0,73.23,2300.5 80,512,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,32,0,174.3438,_ZN5aiter48fmoe_stage1_bf16_pertokenFp8_g1u1_32x128_3tg_pf3E,0.0%,184.6666,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight1_F8_F8_B16,0.5%,359.0104,0,107.67,1699.87 80,1024,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,32,0,272.9741,_ZN5aiter48fmoe_stage1_bf16_pertokenFp8_g1u1_32x128_3tg_pf3E,0.0%,313.1772,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight1_F8_F8_B16,0.5%,586.1513,0,131.89,1051.88 +80,1,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,1,32,0,84.9173,moe_ck2stages_gemm1_256x32x64x128_1x4_MulABScale_v1_Nswizzle0_Quant2_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,74.1307,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,159.048,0,7.59,3798.7 +80,2,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,1,32,0,84.9173,moe_ck2stages_gemm1_256x32x64x128_1x4_MulABScale_v1_Nswizzle0_Quant2_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,74.1307,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,159.048,0,7.59,3798.7 +80,4,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,1,32,0,84.9173,moe_ck2stages_gemm1_256x32x64x128_1x4_MulABScale_v1_Nswizzle0_Quant2_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,74.1307,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,159.048,0,7.59,3798.7 
+80,8,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,1,32,0,84.9173,moe_ck2stages_gemm1_256x32x64x128_1x4_MulABScale_v1_Nswizzle0_Quant2_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,74.1307,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,159.048,0,7.59,3798.7 80,16,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,1,32,0,84.9173,moe_ck2stages_gemm1_256x32x64x128_1x4_MulABScale_v1_Nswizzle0_Quant2_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,74.1307,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,159.048,0,7.59,3798.7 80,32,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,1,32,0,110.1287,moe_ck2stages_gemm1_256x32x64x256_1x4_MulABScale_v1_Nswizzle0_Quant2_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,96.0451,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,206.1738,0,11.72,2931.38 80,64,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,1,32,0,126.6253,moe_ck2stages_gemm1_256x32x64x128_1x4_MulABScale_v1_Nswizzle0_Quant2_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,111.3536,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,237.9789,0,20.3,2541.26 @@ -469,6 +737,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 80,256,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,1,32,0,130.1837,_ZN5aiter57fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x128_3tg_pf3E,0.0%,123.1176,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,253.3013,0,76.3,2396.85 80,512,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,1,32,0,177.7956,_ZN5aiter57fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x128_3tg_pf3E,0.0%,178.1446,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,355.9402,0,108.6,1714.53 80,1024,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,1,32,0,271.6424,_ZN5aiter57fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x128_3tg_pf3E,0.0%,302.5119,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,574.1543,0,134.65,1073.86 +80,1,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,32,0,120.3829,_ZN5aiter47fmoe_stage1_bf16_pertokenFp8_g1u1_32x64_4tg_pf3E,0.0%,134.4734,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight1_F8_F8_B16,1.6%,254.8563,0,5.53,5531.08 +80,2,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,32,0,120.3829,_ZN5aiter47fmoe_stage1_bf16_pertokenFp8_g1u1_32x64_4tg_pf3E,0.0%,134.4734,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight1_F8_F8_B16,1.6%,254.8563,0,5.53,5531.08 
+80,4,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,32,0,120.3829,_ZN5aiter47fmoe_stage1_bf16_pertokenFp8_g1u1_32x64_4tg_pf3E,0.0%,134.4734,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight1_F8_F8_B16,1.6%,254.8563,0,5.53,5531.08 +80,8,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,32,0,120.3829,_ZN5aiter47fmoe_stage1_bf16_pertokenFp8_g1u1_32x64_4tg_pf3E,0.0%,134.4734,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight1_F8_F8_B16,1.6%,254.8563,0,5.53,5531.08 80,16,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,32,0,120.3829,_ZN5aiter47fmoe_stage1_bf16_pertokenFp8_g1u1_32x64_4tg_pf3E,0.0%,134.4734,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight1_F8_F8_B16,1.6%,254.8563,0,5.53,5531.08 80,32,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,32,0,188.6312,moe_ck2stages_gemm1_256x32x64x128_1x4_MulABScale_v1_Nswizzle0_Quant2_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,205.9065,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight1_F8_F8_B16,1.6%,394.5377,0,7.14,3573.74 80,64,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,32,0,249.8568,_ZN5aiter48fmoe_stage1_bf16_pertokenFp8_g1u1_32x128_3tg_pf2E,0.0%,269.7859,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight1_F8_F8_B16,1.6%,519.6427,0,10.85,2714.68 @@ -476,6 +748,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 80,256,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,32,0,302.8986,_ZN5aiter48fmoe_stage1_bf16_pertokenFp8_g1u1_32x128_3tg_pf3E,0.0%,328.0652,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight1_F8_F8_B16,1.6%,630.9638,0,35.74,2242.27 80,512,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,32,0,305.7627,_ZN5aiter48fmoe_stage1_bf16_pertokenFp8_g1u1_32x128_3tg_pf3E,0.1%,359.7434,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight1_F8_F8_B16,1.6%,665.5061000000001,0,67.76,2134.16 80,1024,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,32,0,390.1172,_ZN5aiter48fmoe_stage1_bf16_pertokenFp8_g1u1_32x256_2tg_pf3E,0.0%,531.1275,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight1_F8_F8_B16,1.6%,921.2447,0,97.9,1553.67 +80,1,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,1,32,0,119.4832,_ZN5aiter56fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x64_4tg_pf3E,0.0%,131.8606,moe_ck2stages_gemm2_256x32x64x256_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,251.3438,0,5.61,5608.37 
+80,2,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,1,32,0,119.4832,_ZN5aiter56fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x64_4tg_pf3E,0.0%,131.8606,moe_ck2stages_gemm2_256x32x64x256_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,251.3438,0,5.61,5608.37 +80,4,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,1,32,0,119.4832,_ZN5aiter56fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x64_4tg_pf3E,0.0%,131.8606,moe_ck2stages_gemm2_256x32x64x256_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,251.3438,0,5.61,5608.37 +80,8,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,1,32,0,119.4832,_ZN5aiter56fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x64_4tg_pf3E,0.0%,131.8606,moe_ck2stages_gemm2_256x32x64x256_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,251.3438,0,5.61,5608.37 80,16,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,1,32,0,119.4832,_ZN5aiter56fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x64_4tg_pf3E,0.0%,131.8606,moe_ck2stages_gemm2_256x32x64x256_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,251.3438,0,5.61,5608.37 80,32,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,1,32,0,192.247,_ZN5aiter56fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x64_4tg_pf3E,0.0%,200.7748,moe_ck2stages_gemm2_256x32x64x256_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,393.0218,0,7.17,3587.52 80,64,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,1,32,0,247.1864,_ZN5aiter57fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x128_3tg_pf3E,0.0%,262.0145,moe_ck2stages_gemm2_256x32x64x256_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,509.2009,0,11.07,2770.35 @@ -484,6 +760,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 80,512,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,1,32,0,310.6611,_ZN5aiter57fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x128_3tg_pf3E,0.0%,349.0712,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,659.7322999999999,0,68.36,2152.84 80,1024,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,1,32,0,385.5365,_ZN5aiter57fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x256_2tg_pf3E,0.0%,515.5173,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,901.0538,0,100.1,1588.48 80,1024,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,1545.4564,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,1031.1042,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,2.8%,2576.5606,0,70.01,2199.25 
+80,1,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,442.5097,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_silu_B16_B16_B16,0.0%,249.4045,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,691.9142,0,4.07,8147.84 +80,2,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,442.5097,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_silu_B16_B16_B16,0.0%,249.4045,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,691.9142,0,4.07,8147.84 +80,4,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,442.5097,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_silu_B16_B16_B16,0.0%,249.4045,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,691.9142,0,4.07,8147.84 +80,8,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,442.5097,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_silu_B16_B16_B16,0.0%,249.4045,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,691.9142,0,4.07,8147.84 80,16,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,442.5097,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_silu_B16_B16_B16,0.0%,249.4045,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,691.9142,0,4.07,8147.84 80,32,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,708.8776,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_silu_B16_B16_B16,0.0%,395.3054,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,1104.183,0,5.11,5106.09 80,64,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,1025.0488,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_silu_B16_B16_B16,0.0%,559.4474,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,1584.4962,0,7.12,3558.85 @@ -493,6 +773,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 80,512,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,32,0,625.1681,_ZN5aiter48fmoe_stage1_bf16_pertokenFp8_g1u1_32x256_2tg_pf2E,0.0%,436.7729,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,2.8%,1061.941,0,84.93,2664.54 80,32,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,1,32,0,369.4651,moe_ck2stages_gemm1_256x32x64x256_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight1_silu_F8_F8_B16,0.0%,236.5581,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,606.0232,0,9.3,4652.07 
80,64,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,1,32,0,516.6341,_ZN5aiter53fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x512_pf3E,0.0%,344.0188,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,860.6529,0,13.1,3276.52 +80,1,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,440.6719,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,255.5291,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,3.4%,696.201,0,4.05,8097.67 +80,2,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,440.6719,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,255.5291,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,3.4%,696.201,0,4.05,8097.67 +80,4,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,440.6719,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,255.5291,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,3.4%,696.201,0,4.05,8097.67 +80,8,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,440.6719,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,255.5291,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,3.4%,696.201,0,4.05,8097.67 80,16,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,440.6719,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,255.5291,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,3.4%,696.201,0,4.05,8097.67 80,32,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,714.05,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,401.6611,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,3.4%,1115.7111,0,5.05,5053.34 80,64,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,985.8866,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,559.8234,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,3.5%,1545.71,0,7.29,3648.15 @@ -500,6 +784,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 80,256,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,1149.8246,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,670.1762,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,3.3%,1820.0008,0,24.78,3101.36 
80,512,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,1256.2287,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,729.0988,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,3.3%,1985.3275,0,45.43,2846.8 80,1024,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,1545.4364,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,1026.4073,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,3.3%,2571.8437000000004,0,70.14,2203.29 +80,1,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,438.1766,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,247.4425,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,685.6191,0,4.11,8222.65 +80,2,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,438.1766,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,247.4425,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,685.6191,0,4.11,8222.65 +80,4,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,438.1766,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,247.4425,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,685.6191,0,4.11,8222.65 +80,8,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,438.1766,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,247.4425,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,685.6191,0,4.11,8222.65 80,16,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,438.1766,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,247.4425,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,685.6191,0,4.11,8222.65 80,32,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,712.3917,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,389.8042,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,1102.1959,0,5.11,5115.3 80,64,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,979.7129,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,568.5654,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,1548.2783,0,7.28,3642.1 @@ -507,6 +795,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 
80,256,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,1171.5851,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,662.7687,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,1834.3538,0,24.58,3077.1 80,512,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,128,0,27.2341,moe_ck2stages_gemm1_256x128x128x128_1x4_TypeCastExpertWeight_v3_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,1749.6879,moe_ck2stages_gemm2_256x128x128x64_1x4_TypeCast_v3_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,1776.9219999999998,0,50.76,3180.68 80,1024,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,1551.7224,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,998.809,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,2550.5314,0,70.73,2221.7 +80,1,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,32,0,230.2686,moe_ck2stages_gemm1_256x32x64x128_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,158.8229,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,3.4%,389.0915,0,7.24,7244.87 +80,2,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,32,0,230.2686,moe_ck2stages_gemm1_256x32x64x128_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,158.8229,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,3.4%,389.0915,0,7.24,7244.87 +80,4,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,32,0,230.2686,moe_ck2stages_gemm1_256x32x64x128_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,158.8229,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,3.4%,389.0915,0,7.24,7244.87 +80,8,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,32,0,230.2686,moe_ck2stages_gemm1_256x32x64x128_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,158.8229,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,3.4%,389.0915,0,7.24,7244.87 80,16,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,32,0,230.2686,moe_ck2stages_gemm1_256x32x64x128_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,158.8229,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,3.4%,389.0915,0,7.24,7244.87 80,32,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,32,0,372.7495,_ZN5aiter48fmoe_stage1_bf16_pertokenFp8_g1u1_32x128_3tg_pf3E,0.1%,246.5629,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,3.5%,619.3124,0,9.1,4552.24 
80,64,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,32,0,502.1489,moe_ck2stages_gemm1_256x32x64x128_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,339.9999,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,3.4%,842.1488,0,13.39,3348.52 @@ -514,6 +806,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 80,256,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,32,0,591.8466,_ZN5aiter48fmoe_stage1_bf16_pertokenFp8_g1u1_32x128_3tg_pf3E,0.1%,398.0473,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,3.3%,989.8939,0,45.56,2852.91 80,512,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,32,0,626.1058,_ZN5aiter48fmoe_stage1_bf16_pertokenFp8_g1u1_32x256_2tg_pf2E,0.0%,442.0281,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,3.3%,1068.1339,0,84.44,2649.09 80,1024,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,0,32,0,761.7135,_ZN5aiter44fmoe_stage1_bf16_pertokenFp8_g1u1_32x512_pf3E,0.0%,621.4747,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight1_F8_F8_B16,3.3%,1383.1882,0,130.42,2053.66 +80,1,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,1,32,0,227.7593,moe_ck2stages_gemm1_256x32x64x128_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,155.545,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,383.3043,0,7.35,7354.25 +80,2,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,1,32,0,227.7593,moe_ck2stages_gemm1_256x32x64x128_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,155.545,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,383.3043,0,7.35,7354.25 +80,4,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,1,32,0,227.7593,moe_ck2stages_gemm1_256x32x64x128_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,155.545,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,383.3043,0,7.35,7354.25 +80,8,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,1,32,0,227.7593,moe_ck2stages_gemm1_256x32x64x128_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,155.545,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,383.3043,0,7.35,7354.25 80,16,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,1,32,0,227.7593,moe_ck2stages_gemm1_256x32x64x128_1x4_MulABScale_v1_Nswizzle0_Quant1_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,155.545,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,383.3043,0,7.35,7354.25 
80,32,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,1,32,0,370.0464,_ZN5aiter57fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x128_3tg_pf3E,0.0%,232.0003,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,602.0467,0,9.36,4682.79 80,64,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,1,32,0,509.0501,_ZN5aiter57fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x128_3tg_pf3E,0.0%,326.8678,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,835.9178999999999,0,13.49,3373.48 @@ -521,6 +817,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 80,256,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,1,32,0,610.1172,_ZN5aiter57fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x128_3tg_pf3E,0.0%,408.2756,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,1018.3928,0,44.28,2773.07 80,512,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,1,32,0,624.0096,_ZN5aiter57fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x256_2tg_pf2E,0.0%,411.5878,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,1035.5974,0,87.09,2732.32 80,1024,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Tensor,1,1,32,0,759.5362,_ZN5aiter53fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x512_pf3E,0.0%,596.1779,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant1_MulRoutedWeight0_F8_F8_B16,0.0%,1355.7141,0,133.06,2095.27 +80,1,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,32,0,231.6437,moe_ck2stages_gemm1_256x32x64x128_1x4_MulABScale_v1_Nswizzle0_Quant2_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,173.6438,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight1_F8_F8_B16,3.4%,405.2875,0,6.95,6955.35 +80,2,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,32,0,231.6437,moe_ck2stages_gemm1_256x32x64x128_1x4_MulABScale_v1_Nswizzle0_Quant2_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,173.6438,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight1_F8_F8_B16,3.4%,405.2875,0,6.95,6955.35 +80,4,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,32,0,231.6437,moe_ck2stages_gemm1_256x32x64x128_1x4_MulABScale_v1_Nswizzle0_Quant2_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,173.6438,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight1_F8_F8_B16,3.4%,405.2875,0,6.95,6955.35 +80,8,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,32,0,231.6437,moe_ck2stages_gemm1_256x32x64x128_1x4_MulABScale_v1_Nswizzle0_Quant2_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,173.6438,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight1_F8_F8_B16,3.4%,405.2875,0,6.95,6955.35 
80,16,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,32,0,231.6437,moe_ck2stages_gemm1_256x32x64x128_1x4_MulABScale_v1_Nswizzle0_Quant2_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,173.6438,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight1_F8_F8_B16,3.4%,405.2875,0,6.95,6955.35 80,32,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,32,0,363.3125,moe_ck2stages_gemm1_256x32x64x256_1x4_MulABScale_v1_Nswizzle0_Quant2_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,265.3606,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight1_F8_F8_B16,3.4%,628.6731,0,8.97,4484.46 80,64,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,32,0,507.5861,_ZN5aiter44fmoe_stage1_bf16_pertokenFp8_g1u1_32x512_pf3E,0.0%,362.0567,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight1_F8_F8_B16,3.4%,869.6428,0,12.96,3242.65 @@ -528,6 +828,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 80,256,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,32,0,600.7597,_ZN5aiter48fmoe_stage1_bf16_pertokenFp8_g1u1_32x128_3tg_pf3E,0.0%,432.403,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight1_F8_F8_B16,3.3%,1033.1627,0,43.65,2733.43 80,512,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,32,0,620.397,_ZN5aiter48fmoe_stage1_bf16_pertokenFp8_g1u1_32x128_3tg_pf3E,0.0%,467.8512,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight1_F8_F8_B16,3.3%,1088.2482,0,82.88,2600.13 80,1024,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,32,0,750.417,_ZN5aiter44fmoe_stage1_bf16_pertokenFp8_g1u1_32x512_pf3E,0.0%,717.4604,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight1_F8_F8_B16,3.3%,1467.8774,0,122.89,1935.17 +80,1,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,1,32,0,230.2993,moe_ck2stages_gemm1_256x32x64x256_1x4_MulABScale_v1_Nswizzle0_Quant2_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,169.4253,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,399.7246,0,7.05,7052.15 +80,2,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,1,32,0,230.2993,moe_ck2stages_gemm1_256x32x64x256_1x4_MulABScale_v1_Nswizzle0_Quant2_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,169.4253,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,399.7246,0,7.05,7052.15 +80,4,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,1,32,0,230.2993,moe_ck2stages_gemm1_256x32x64x256_1x4_MulABScale_v1_Nswizzle0_Quant2_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,169.4253,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,399.7246,0,7.05,7052.15 
+80,8,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,1,32,0,230.2993,moe_ck2stages_gemm1_256x32x64x256_1x4_MulABScale_v1_Nswizzle0_Quant2_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,169.4253,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,399.7246,0,7.05,7052.15 80,16,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,1,32,0,230.2993,moe_ck2stages_gemm1_256x32x64x256_1x4_MulABScale_v1_Nswizzle0_Quant2_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,169.4253,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,399.7246,0,7.05,7052.15 80,32,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,1,32,0,363.9112,moe_ck2stages_gemm1_256x32x64x128_1x4_MulABScale_v1_Nswizzle0_Quant2_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,259.8372,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,623.7484,0,9.04,4519.87 80,64,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,1,32,0,510.4385,_ZN5aiter53fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x512_pf3E,0.0%,361.3331,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,871.7716,0,12.93,3234.73 @@ -535,6 +839,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 80,256,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,1,32,0,591.7744,_ZN5aiter57fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x128_3tg_pf3E,0.0%,430.4532,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,1022.2276,0,44.12,2762.67 80,512,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,1,32,0,619.0301,_ZN5aiter57fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x256_2tg_pf2E,0.0%,468.3093,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,1087.3393999999998,0,82.95,2602.3 80,1024,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,1,32,0,749.1329,_ZN5aiter53fmoe_stage1_bf16_pertokenFp8_doweight_g1u1_32x512_pf3E,0.0%,695.2144,moe_ck2stages_gemm2_256x32x64x128_1x4_MulABScaleExpertWeight_v1_Nswizzle0_Quant2_MulRoutedWeight0_F8_F8_B16,0.0%,1444.3473,0,124.89,1966.7 +80,1,4096,192,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,88.9501,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,66.7261,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.1%,155.6762,0,3.88,3881.4 +80,2,4096,192,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,88.9501,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,66.7261,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.1%,155.6762,0,3.88,3881.4 
+80,4,4096,192,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,88.9501,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,66.7261,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.1%,155.6762,0,3.88,3881.4 +80,8,4096,192,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,88.9501,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,66.7261,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.1%,155.6762,0,3.88,3881.4 80,16,4096,192,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,88.9501,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,66.7261,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.1%,155.6762,0,3.88,3881.4 80,32,4096,192,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,115.8935,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,90.1026,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.1%,205.9961,0,5.86,2934.54 80,64,4096,192,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,124.4518,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,102.1206,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.1%,226.5724,0,10.66,2670.35 @@ -542,6 +850,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 80,256,4096,192,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,139.1257,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,108.5817,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.1%,247.7074,0,39.01,2455.21 80,512,4096,192,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,177.2587,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,152.9295,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.1%,330.1882,0,58.53,1854.6 80,1024,4096,192,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,286.2896,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,254.7276,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.1%,541.0172,0,71.45,1147.39 +80,1,4096,192,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,88.3641,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,61.6683,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,150.0324,0,4.03,4027.41 
+80,2,4096,192,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,88.3641,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,61.6683,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,150.0324,0,4.03,4027.41 +80,4,4096,192,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,88.3641,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,61.6683,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,150.0324,0,4.03,4027.41 +80,8,4096,192,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,88.3641,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,61.6683,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,150.0324,0,4.03,4027.41 80,16,4096,192,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,88.3641,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,61.6683,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,150.0324,0,4.03,4027.41 80,32,4096,192,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,116.2949,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,83.4229,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,199.7178,0,6.05,3026.79 80,64,4096,192,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,128.528,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,93.9786,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,222.5066,0,10.86,2719.15 @@ -549,6 +861,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 80,256,4096,192,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,140.9059,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,100.6281,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,241.534,0,40.01,2517.96 80,512,4096,192,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,176.6934,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,140.5842,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,317.2776,0,60.92,1930.07 80,1024,4096,192,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,285.499,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,233.8056,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,519.3046,0,74.44,1195.36 
+80,1,2048,192,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,37.5796,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,36.5426,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.0%,74.12219999999999,0,4.07,4075.98 +80,2,2048,192,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,37.5796,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,36.5426,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.0%,74.12219999999999,0,4.07,4075.98 +80,4,2048,192,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,37.5796,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,36.5426,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.0%,74.12219999999999,0,4.07,4075.98 +80,8,2048,192,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,37.5796,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,36.5426,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.0%,74.12219999999999,0,4.07,4075.98 80,16,2048,192,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,37.5796,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,36.5426,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.0%,74.12219999999999,0,4.07,4075.98 80,32,2048,192,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,58.5697,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,49.5953,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.0%,108.165,0,5.58,2794.36 80,64,2048,192,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,62.6056,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,55.5345,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.0%,118.1401,0,10.22,2560.64 @@ -556,6 +872,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 80,256,2048,192,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,65.3992,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,59.9428,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.0%,125.34199999999998,0,38.55,2426.06 80,512,2048,192,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,93.5238,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,83.188,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.0%,176.71179999999998,0,54.69,1732.68 
80,1024,2048,192,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,140.8771,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_gelu_B16_B16_B16,0.0%,134.5253,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.0%,275.4024,0,70.18,1127.0 +80,1,2048,192,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,38.1892,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,34.1814,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,72.3706,0,4.17,4174.64 +80,2,2048,192,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,38.1892,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,34.1814,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,72.3706,0,4.17,4174.64 +80,4,2048,192,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,38.1892,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,34.1814,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,72.3706,0,4.17,4174.64 +80,8,2048,192,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,38.1892,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,34.1814,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,72.3706,0,4.17,4174.64 80,16,2048,192,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,38.1892,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,34.1814,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,72.3706,0,4.17,4174.64 80,32,2048,192,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,59.0984,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,45.8963,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,104.9947,0,5.75,2878.74 80,64,2048,192,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,62.4816,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,51.5263,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,114.0079,0,10.6,2653.45 @@ -563,6 +883,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 80,256,2048,192,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,65.5874,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,55.6592,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,121.2466,0,39.85,2508.0 
80,512,2048,192,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,94.7864,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,76.6968,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,171.4832,0,56.35,1785.51 80,1024,2048,192,128,8,ActivationType.Gelu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,1,32,0,144.9248,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCastExpertWeight_v1_Nswizzle0_Quant0_MulRoutedWeight1_gelu_B16_B16_B16,0.0%,123.3403,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_B16_B16_B16,0.0%,268.2651,0,72.05,1156.98 +256,1,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,0,16,0,71.439,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_gelu_F8_F8_B16,0.1%,41.4885,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.6%,112.9275,0,12.48,12482.61 +256,2,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,0,16,0,71.439,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_gelu_F8_F8_B16,0.1%,41.4885,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.6%,112.9275,0,12.48,12482.61 +256,4,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,0,16,0,71.439,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_gelu_F8_F8_B16,0.1%,41.4885,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.6%,112.9275,0,12.48,12482.61 +256,8,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,0,16,0,71.439,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_gelu_F8_F8_B16,0.1%,41.4885,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.6%,112.9275,0,12.48,12482.61 256,16,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,0,16,0,71.439,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_gelu_F8_F8_B16,0.1%,41.4885,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.6%,112.9275,0,12.48,12482.61 256,32,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,0,16,0,120.1868,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_gelu_F8_F8_B16,0.1%,59.0038,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.6%,179.1906,0,15.73,7868.57 
256,64,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,0,16,0,141.1318,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_gelu_F8_F8_B16,0.1%,79.7079,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.7%,220.8397,0,25.53,6387.72 @@ -570,6 +894,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 256,256,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,0,16,0,161.4495,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_gelu_F8_F8_B16,0.1%,118.1586,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.7%,279.6081,0,80.64,5059.91 256,512,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,0,16,0,203.4884,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_gelu_F8_F8_B16,0.1%,212.7151,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.7%,416.2035,0,108.35,3412.5 256,1024,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,0,16,0,328.7665,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_gelu_F8_F8_B16,0.1%,403.3563,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.7%,732.1228,0,123.2,1955.01 +256,1,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,0,16,0,137.08,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_gelu_F8_F8_B16,0.1%,71.7868,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,16.0%,208.8668,0,13.49,13496.24 +256,2,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,0,16,0,137.08,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_gelu_F8_F8_B16,0.1%,71.7868,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,16.0%,208.8668,0,13.49,13496.24 +256,4,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,0,16,0,137.08,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_gelu_F8_F8_B16,0.1%,71.7868,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,16.0%,208.8668,0,13.49,13496.24 +256,8,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,0,16,0,137.08,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_gelu_F8_F8_B16,0.1%,71.7868,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,16.0%,208.8668,0,13.49,13496.24 
256,16,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,0,16,0,137.08,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_gelu_F8_F8_B16,0.1%,71.7868,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,16.0%,208.8668,0,13.49,13496.24 256,32,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,0,16,0,205.7119,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_gelu_F8_F8_B16,0.1%,103.7369,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,16.0%,309.4488,0,18.22,9110.59 256,64,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,0,16,0,282.9681,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_gelu_F8_F8_B16,0.1%,143.4641,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.9%,426.4322,0,26.44,6612.89 @@ -577,6 +905,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 256,256,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,0,16,0,318.0049,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_gelu_F8_F8_B16,0.1%,179.0098,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.7%,497.0147,0,90.74,5682.08 256,512,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,0,16,0,403.3586,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_gelu_F8_F8_B16,0.1%,227.8796,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.8%,631.2382,0,142.88,4482.59 256,1024,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,0,16,0,553.5171,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_gelu_F8_F8_B16,0.1%,421.0368,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.7%,974.5539,0,185.1,2914.76 +256,1,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,0,16,0,137.51,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_silu_F8_F8_B16,0.1%,71.4507,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.8%,208.9607,0,13.49,13490.17 +256,2,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,0,16,0,137.51,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_silu_F8_F8_B16,0.1%,71.4507,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.8%,208.9607,0,13.49,13490.17 
+256,4,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,0,16,0,137.51,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_silu_F8_F8_B16,0.1%,71.4507,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.8%,208.9607,0,13.49,13490.17 +256,8,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,0,16,0,137.51,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_silu_F8_F8_B16,0.1%,71.4507,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.8%,208.9607,0,13.49,13490.17 256,16,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,0,16,0,137.51,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_silu_F8_F8_B16,0.1%,71.4507,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.8%,208.9607,0,13.49,13490.17 256,32,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,0,16,0,206.2526,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_silu_F8_F8_B16,0.1%,103.5784,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.9%,309.831,0,18.19,9099.35 256,64,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,0,16,0,282.8631,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_silu_F8_F8_B16,0.1%,144.8538,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.8%,427.7169,0,26.36,6593.03 @@ -584,6 +916,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 256,256,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,0,16,0,315.2045,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_silu_F8_F8_B16,0.1%,179.0137,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.7%,494.2182,0,91.25,5714.23 256,512,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,0,16,0,408.1562,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_silu_F8_F8_B16,0.1%,228.0692,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.8%,636.2254,0,141.76,4447.45 256,1024,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,0,16,0,572.8802,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_silu_F8_F8_B16,0.1%,420.9959,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.8%,993.8761,0,181.5,2858.1 
+256,1,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,0,64,0,54.4576,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight0_silu_F8_F8_B16,0.1%,46.027,moe_ck2stages_gemm2_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.5%,100.4846,0,12.02,6012.63 +256,2,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,0,64,0,54.4576,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight0_silu_F8_F8_B16,0.1%,46.027,moe_ck2stages_gemm2_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.5%,100.4846,0,12.02,6012.63 +256,4,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,0,64,0,54.4576,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight0_silu_F8_F8_B16,0.1%,46.027,moe_ck2stages_gemm2_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.5%,100.4846,0,12.02,6012.63 +256,8,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,0,64,0,54.4576,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight0_silu_F8_F8_B16,0.1%,46.027,moe_ck2stages_gemm2_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.5%,100.4846,0,12.02,6012.63 256,16,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,0,64,0,54.4576,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight0_silu_F8_F8_B16,0.1%,46.027,moe_ck2stages_gemm2_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.5%,100.4846,0,12.02,6012.63 256,32,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,0,64,0,78.6177,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight0_silu_F8_F8_B16,0.1%,60.2532,moe_ck2stages_gemm2_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.7%,138.8709,0,17.4,4352.05 256,64,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,0,64,0,81.5016,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight0_silu_F8_F8_B16,0.1%,66.9561,moe_ck2stages_gemm2_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.6%,148.4577,0,32.55,4073.66 @@ -591,6 +927,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 256,256,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,0,64,0,86.024,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight0_silu_F8_F8_B16,0.1%,141.2712,moe_ck2stages_gemm2_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.5%,227.2952,0,85.03,2671.09 
256,512,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,0,64,0,91.3559,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight0_silu_F8_F8_B16,0.1%,248.1618,moe_ck2stages_gemm2_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.6%,339.5177,0,113.85,1797.47 256,1024,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,0,64,0,124.3946,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight0_silu_F8_F8_B16,0.1%,473.4248,moe_ck2stages_gemm2_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.5%,597.8194,0,129.32,1031.35 +256,1,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,0,64,0,55.9955,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight0_gelu_F8_F8_B16,0.1%,45.9382,moe_ck2stages_gemm2_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.5%,101.9337,0,11.85,5927.15 +256,2,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,0,64,0,55.9955,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight0_gelu_F8_F8_B16,0.1%,45.9382,moe_ck2stages_gemm2_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.5%,101.9337,0,11.85,5927.15 +256,4,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,0,64,0,55.9955,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight0_gelu_F8_F8_B16,0.1%,45.9382,moe_ck2stages_gemm2_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.5%,101.9337,0,11.85,5927.15 +256,8,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,0,64,0,55.9955,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight0_gelu_F8_F8_B16,0.1%,45.9382,moe_ck2stages_gemm2_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.5%,101.9337,0,11.85,5927.15 256,16,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,0,64,0,55.9955,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight0_gelu_F8_F8_B16,0.1%,45.9382,moe_ck2stages_gemm2_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.5%,101.9337,0,11.85,5927.15 256,32,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,0,64,0,81.7704,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight0_gelu_F8_F8_B16,0.1%,60.4356,moe_ck2stages_gemm2_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.7%,142.206,0,16.99,4249.98 
256,64,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,0,64,0,86.8808,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight0_gelu_F8_F8_B16,0.1%,66.8218,moe_ck2stages_gemm2_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.5%,153.7026,0,31.44,3934.65 @@ -598,6 +938,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 256,256,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,0,64,0,90.4642,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight0_gelu_F8_F8_B16,0.1%,141.6691,moe_ck2stages_gemm2_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.5%,232.1333,0,83.26,2615.42 256,512,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,0,64,0,95.7813,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight0_gelu_F8_F8_B16,0.1%,247.6126,moe_ck2stages_gemm2_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.6%,343.3939,0,112.57,1777.18 256,1024,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,0,64,0,128.833,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight0_gelu_F8_F8_B16,0.1%,473.6369,moe_ck2stages_gemm2_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.6%,602.4699,0,128.32,1023.39 +256,1,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,1,16,0,71.6188,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_gelu_F8_F8_B16,0.1%,41.3596,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.7%,112.9784,0,12.47,12476.99 +256,2,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,1,16,0,71.6188,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_gelu_F8_F8_B16,0.1%,41.3596,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.7%,112.9784,0,12.47,12476.99 +256,4,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,1,16,0,71.6188,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_gelu_F8_F8_B16,0.1%,41.3596,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.7%,112.9784,0,12.47,12476.99 +256,8,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,1,16,0,71.6188,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_gelu_F8_F8_B16,0.1%,41.3596,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.7%,112.9784,0,12.47,12476.99 
256,16,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,1,16,0,71.6188,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_gelu_F8_F8_B16,0.1%,41.3596,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.7%,112.9784,0,12.47,12476.99 256,32,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,1,16,0,120.1253,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_gelu_F8_F8_B16,0.1%,58.8622,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.7%,178.9875,0,15.75,7877.5 256,64,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,1,16,0,141.4796,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_gelu_F8_F8_B16,0.1%,79.0378,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.8%,220.5174,0,25.56,6397.06 @@ -605,6 +949,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 256,256,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,1,16,0,161.8062,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_gelu_F8_F8_B16,0.1%,117.8189,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.7%,279.6251,0,80.64,5059.6 256,512,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,1,16,0,203.3392,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_gelu_F8_F8_B16,0.1%,212.9202,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.6%,416.2594,0,108.34,3412.05 256,1024,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,1,16,0,329.0517,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_gelu_F8_F8_B16,0.1%,400.2276,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.7%,729.2793,0,123.68,1962.63 +256,1,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,1,16,0,137.1212,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_gelu_F8_F8_B16,0.1%,71.9542,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.8%,209.0754,0,13.48,13482.77 +256,2,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,1,16,0,137.1212,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_gelu_F8_F8_B16,0.1%,71.9542,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.8%,209.0754,0,13.48,13482.77 
+256,4,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,1,16,0,137.1212,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_gelu_F8_F8_B16,0.1%,71.9542,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.8%,209.0754,0,13.48,13482.77 +256,8,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,1,16,0,137.1212,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_gelu_F8_F8_B16,0.1%,71.9542,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.8%,209.0754,0,13.48,13482.77 256,16,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,1,16,0,137.1212,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_gelu_F8_F8_B16,0.1%,71.9542,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.8%,209.0754,0,13.48,13482.77 256,32,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,1,16,0,205.6879,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_gelu_F8_F8_B16,0.1%,103.7835,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.8%,309.4714,0,18.22,9109.92 256,64,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,1,16,0,280.6797,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_gelu_F8_F8_B16,0.1%,145.7947,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.7%,426.4744,0,26.44,6612.23 @@ -612,6 +960,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 256,256,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,1,16,0,316.5168,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_gelu_F8_F8_B16,0.1%,179.9899,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.7%,496.5067,0,90.83,5687.89 256,512,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,1,16,0,403.1043,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_gelu_F8_F8_B16,0.1%,229.7632,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.7%,632.8675,0,142.52,4471.05 256,1024,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,1,16,0,553.6587,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_gelu_F8_F8_B16,0.1%,425.3929,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.7%,979.0516,0,184.25,2901.37 
+256,1,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,1,16,0,137.2707,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_silu_F8_F8_B16,0.1%,71.8847,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.7%,209.1554,0,13.48,13477.62 +256,2,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,1,16,0,137.2707,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_silu_F8_F8_B16,0.1%,71.8847,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.7%,209.1554,0,13.48,13477.62 +256,4,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,1,16,0,137.2707,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_silu_F8_F8_B16,0.1%,71.8847,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.7%,209.1554,0,13.48,13477.62 +256,8,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,1,16,0,137.2707,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_silu_F8_F8_B16,0.1%,71.8847,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.7%,209.1554,0,13.48,13477.62 256,16,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,1,16,0,137.2707,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_silu_F8_F8_B16,0.1%,71.8847,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.7%,209.1554,0,13.48,13477.62 256,32,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,1,16,0,205.8602,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_silu_F8_F8_B16,0.1%,104.2026,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.8%,310.0628,0,18.18,9092.55 256,64,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,1,16,0,281.4413,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_silu_F8_F8_B16,0.1%,145.7705,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.7%,427.2118,0,26.39,6600.82 @@ -619,6 +971,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 256,256,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,1,16,0,314.8954,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_silu_F8_F8_B16,0.1%,179.522,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.7%,494.4174,0,91.21,5711.93 
256,512,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,1,16,0,406.507,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_silu_F8_F8_B16,0.1%,230.6196,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.7%,637.1266,0,141.56,4441.16 256,1024,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,1,16,0,575.7451,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_silu_F8_F8_B16,0.1%,426.4277,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.7%,1002.1728,0,180.0,2834.43 +256,1,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,1,64,0,56.2731,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight1_silu_F8_F8_B16,0.1%,45.5448,moe_ck2stages_gemm2_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.7%,101.8179,0,11.86,5933.89 +256,2,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,1,64,0,56.2731,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight1_silu_F8_F8_B16,0.1%,45.5448,moe_ck2stages_gemm2_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.7%,101.8179,0,11.86,5933.89 +256,4,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,1,64,0,56.2731,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight1_silu_F8_F8_B16,0.1%,45.5448,moe_ck2stages_gemm2_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.7%,101.8179,0,11.86,5933.89 +256,8,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,1,64,0,56.2731,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight1_silu_F8_F8_B16,0.1%,45.5448,moe_ck2stages_gemm2_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.7%,101.8179,0,11.86,5933.89 256,16,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,1,64,0,56.2731,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight1_silu_F8_F8_B16,0.1%,45.5448,moe_ck2stages_gemm2_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.7%,101.8179,0,11.86,5933.89 256,32,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,1,64,0,80.9637,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight1_silu_F8_F8_B16,0.1%,59.914,moe_ck2stages_gemm2_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.8%,140.8777,0,17.15,4290.05 
256,64,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,1,64,0,84.98,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight1_silu_F8_F8_B16,0.1%,66.4795,moe_ck2stages_gemm2_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.8%,151.4595,0,31.9,3992.92 @@ -632,6 +988,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 256,256,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,1,16,0,160.0658,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_silu_F8_F8_B16,0.1%,117.833,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.6%,277.8988,0,81.14,5091.03 256,512,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,1,16,0,205.7498,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_silu_F8_F8_B16,0.1%,212.5856,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.6%,418.3354,0,107.8,3395.11 256,1024,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,1,16,0,327.5497,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_silu_F8_F8_B16,0.1%,402.6535,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.7%,730.2032,0,123.52,1960.15 +256,1,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,1,64,0,57.2319,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight1_gelu_F8_F8_B16,0.1%,44.9913,moe_ck2stages_gemm2_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.8%,102.2232,0,11.82,5910.36 +256,2,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,1,64,0,57.2319,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight1_gelu_F8_F8_B16,0.1%,44.9913,moe_ck2stages_gemm2_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.8%,102.2232,0,11.82,5910.36 +256,4,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,1,64,0,57.2319,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight1_gelu_F8_F8_B16,0.1%,44.9913,moe_ck2stages_gemm2_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.8%,102.2232,0,11.82,5910.36 +256,8,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,1,64,0,57.2319,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight1_gelu_F8_F8_B16,0.1%,44.9913,moe_ck2stages_gemm2_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.8%,102.2232,0,11.82,5910.36 
256,16,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,1,64,0,57.2319,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight1_gelu_F8_F8_B16,0.1%,44.9913,moe_ck2stages_gemm2_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.8%,102.2232,0,11.82,5910.36 256,32,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,1,64,0,83.8327,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight1_gelu_F8_F8_B16,0.1%,58.83,moe_ck2stages_gemm2_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.8%,142.6627,0,16.93,4236.38 256,64,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,1,64,0,88.5384,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight1_gelu_F8_F8_B16,0.1%,66.7218,moe_ck2stages_gemm2_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.8%,155.2602,0,31.12,3895.18 @@ -639,6 +999,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 256,256,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,1,64,0,92.6716,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight1_gelu_F8_F8_B16,0.1%,141.8887,moe_ck2stages_gemm2_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.7%,234.5603,0,82.4,2588.36 256,512,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,1,64,0,95.6695,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight1_gelu_F8_F8_B16,0.1%,247.8364,moe_ck2stages_gemm2_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.7%,343.5059,0,112.53,1776.6 256,1024,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,1,64,0,133.8828,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight1_gelu_F8_F8_B16,0.1%,474.349,moe_ck2stages_gemm2_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.7%,608.2318,0,127.11,1013.7 +80,1,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,0,16,0,123.5681,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_silu_F8_F8_B16,0.0%,83.7681,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.6%,207.3362,0,6.8,6798.77 +80,2,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,0,16,0,123.5681,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_silu_F8_F8_B16,0.0%,83.7681,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.6%,207.3362,0,6.8,6798.77 
+80,4,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,0,16,0,123.5681,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_silu_F8_F8_B16,0.0%,83.7681,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.6%,207.3362,0,6.8,6798.77 +80,8,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,0,16,0,123.5681,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_silu_F8_F8_B16,0.0%,83.7681,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.6%,207.3362,0,6.8,6798.77 80,16,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,0,16,0,123.5681,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_silu_F8_F8_B16,0.0%,83.7681,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.6%,207.3362,0,6.8,6798.77 80,32,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,0,16,0,204.1789,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_silu_F8_F8_B16,0.0%,127.8428,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.7%,332.0217,0,8.49,4246.63 80,64,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,0,16,0,256.3235,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_silu_F8_F8_B16,0.0%,173.5099,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.6%,429.8334,0,13.11,3281.88 @@ -646,10 +1010,22 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 80,256,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,0,16,0,307.9267,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_silu_F8_F8_B16,0.0%,225.0382,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.6%,532.9649,0,42.31,2654.57 80,512,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,0,16,0,388.2337,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_silu_F8_F8_B16,0.0%,321.3807,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.6%,709.6144,0,63.55,2001.5 80,1024,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,0,16,0,557.7338,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_silu_F8_F8_B16,0.0%,556.9381,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.6%,1114.6719,0,80.92,1284.06 
+80,1,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,0,16,0,245.7053,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_silu_F8_F8_B16,0.0%,130.3948,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.3%,376.1001,0,7.49,7495.12 +80,2,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,0,16,0,245.7053,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_silu_F8_F8_B16,0.0%,130.3948,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.3%,376.1001,0,7.49,7495.12 +80,4,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,0,16,0,245.7053,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_silu_F8_F8_B16,0.0%,130.3948,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.3%,376.1001,0,7.49,7495.12 +80,8,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,0,16,0,245.7053,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_silu_F8_F8_B16,0.0%,130.3948,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.3%,376.1001,0,7.49,7495.12 80,16,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,0,16,0,245.7053,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_silu_F8_F8_B16,0.0%,130.3948,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.3%,376.1001,0,7.49,7495.12 80,32,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,0,16,0,374.0904,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_silu_F8_F8_B16,0.0%,216.3424,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.6%,590.4328,0,9.55,4774.9 80,64,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,0,16,0,500.4382,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_silu_F8_F8_B16,0.0%,283.1229,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.9%,783.5611,0,14.39,3598.89 +80,1,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,0,64,0,218.6275,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,307.0939,moe_ck2stages_gemm2_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.5%,525.7214,0,2.3,1149.23 
+80,2,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,0,64,0,218.6275,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,307.0939,moe_ck2stages_gemm2_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.5%,525.7214,0,2.3,1149.23 +80,4,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,0,64,0,218.6275,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,307.0939,moe_ck2stages_gemm2_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.5%,525.7214,0,2.3,1149.23 +80,8,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,0,64,0,218.6275,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,307.0939,moe_ck2stages_gemm2_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.5%,525.7214,0,2.3,1149.23 80,16,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,0,64,0,218.6275,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,307.0939,moe_ck2stages_gemm2_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.5%,525.7214,0,2.3,1149.23 +80,1,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,0,16,0,124.4194,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,83.1785,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.7%,207.5979,0,6.79,6790.19 +80,2,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,0,16,0,124.4194,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,83.1785,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.7%,207.5979,0,6.79,6790.19 +80,4,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,0,16,0,124.4194,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,83.1785,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.7%,207.5979,0,6.79,6790.19 +80,8,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,0,16,0,124.4194,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,83.1785,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.7%,207.5979,0,6.79,6790.19 
80,16,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,0,16,0,124.4194,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,83.1785,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.7%,207.5979,0,6.79,6790.19 80,32,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,0,16,0,204.8947,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,128.9358,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.8%,333.8305,0,8.44,4223.62 80,64,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,0,16,0,252.4238,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,175.0782,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.7%,427.502,0,13.19,3299.78 @@ -667,6 +1043,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 80,256,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,0,64,0,296.2706,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,517.6107,moe_ck2stages_gemm2_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.4%,813.8813,0,23.75,745.96 80,512,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,0,64,0,324.2019,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,548.1039,moe_ck2stages_gemm2_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.5%,872.3058,0,44.31,699.61 80,1024,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,0,64,0,509.2109,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,865.0837,moe_ck2stages_gemm2_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.4%,1374.2946,0,56.25,448.64 +80,1,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,0,16,0,244.2105,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,132.0083,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.3%,376.2188,0,7.49,7492.76 +80,2,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,0,16,0,244.2105,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,132.0083,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.3%,376.2188,0,7.49,7492.76 
+80,4,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,0,16,0,244.2105,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,132.0083,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.3%,376.2188,0,7.49,7492.76 +80,8,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,0,16,0,244.2105,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,132.0083,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.3%,376.2188,0,7.49,7492.76 80,16,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,0,16,0,244.2105,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,132.0083,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.3%,376.2188,0,7.49,7492.76 80,32,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,0,16,0,366.6026,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,203.0742,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.8%,569.6768,0,9.9,4948.88 80,64,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,0,16,0,517.7397,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,297.5602,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.9%,815.2999,0,13.83,3458.79 @@ -674,6 +1054,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 80,256,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,0,16,0,596.114,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,376.262,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.6%,972.376,0,46.38,2904.31 80,512,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,0,16,0,703.0772,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,515.3501,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.7%,1218.4273,0,74.03,2322.32 80,1024,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,0,16,0,1032.1656,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_gelu_F8_F8_B16,0.0%,839.6044,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.6%,1871.77,0,96.37,1517.6 
+80,1,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,1,64,0,218.3233,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight1_silu_F8_F8_B16,0.0%,305.0925,moe_ck2stages_gemm2_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.9%,523.4158,0,2.31,1154.3 +80,2,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,1,64,0,218.3233,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight1_silu_F8_F8_B16,0.0%,305.0925,moe_ck2stages_gemm2_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.9%,523.4158,0,2.31,1154.3 +80,4,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,1,64,0,218.3233,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight1_silu_F8_F8_B16,0.0%,305.0925,moe_ck2stages_gemm2_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.9%,523.4158,0,2.31,1154.3 +80,8,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,1,64,0,218.3233,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight1_silu_F8_F8_B16,0.0%,305.0925,moe_ck2stages_gemm2_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.9%,523.4158,0,2.31,1154.3 80,16,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,1,64,0,218.3233,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight1_silu_F8_F8_B16,0.0%,305.0925,moe_ck2stages_gemm2_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.9%,523.4158,0,2.31,1154.3 80,32,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,1,64,0,272.2826,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight1_silu_F8_F8_B16,0.0%,435.0107,moe_ck2stages_gemm2_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.9%,707.2933,0,3.42,854.49 80,64,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,1,64,0,275.5136,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight1_silu_F8_F8_B16,0.0%,479.9084,moe_ck2stages_gemm2_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.8%,755.422,0,6.4,800.57 @@ -681,6 +1065,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 80,256,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,1,64,0,293.1722,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight1_silu_F8_F8_B16,0.0%,514.6465,moe_ck2stages_gemm2_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.7%,807.8187,0,23.93,751.56 
80,512,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,1,64,0,316.2687,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight1_silu_F8_F8_B16,0.0%,545.4219,moe_ck2stages_gemm2_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.7%,861.6906,0,44.86,708.23 80,1024,4096,384,128,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,1,64,0,495.1237,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight1_silu_F8_F8_B16,0.0%,858.9692,moe_ck2stages_gemm2_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.7%,1354.0929,0,57.09,455.33 +80,1,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,1,16,0,124.9899,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_silu_F8_F8_B16,0.0%,81.8591,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.6%,206.849,0,6.81,6814.78 +80,2,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,1,16,0,124.9899,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_silu_F8_F8_B16,0.0%,81.8591,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.6%,206.849,0,6.81,6814.78 +80,4,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,1,16,0,124.9899,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_silu_F8_F8_B16,0.0%,81.8591,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.6%,206.849,0,6.81,6814.78 +80,8,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,1,16,0,124.9899,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_silu_F8_F8_B16,0.0%,81.8591,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.6%,206.849,0,6.81,6814.78 80,16,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,1,16,0,124.9899,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_silu_F8_F8_B16,0.0%,81.8591,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.6%,206.849,0,6.81,6814.78 80,32,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,1,16,0,205.3401,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_silu_F8_F8_B16,0.0%,128.7497,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.6%,334.0898,0,8.44,4220.35 
80,64,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,1,16,0,255.6311,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_silu_F8_F8_B16,0.0%,170.7752,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.6%,426.4063,0,13.22,3308.26 @@ -688,6 +1076,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 80,256,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,1,16,0,301.0373,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_silu_F8_F8_B16,0.0%,225.1659,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.6%,526.2032,0,42.85,2688.68 80,512,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,1,16,0,398.4468,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_silu_F8_F8_B16,0.0%,319.45,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.6%,717.8968,0,62.82,1978.41 80,1024,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,1,16,0,558.3594,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_silu_F8_F8_B16,0.0%,555.8992,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.7%,1114.2586,0,80.95,1284.54 +80,1,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,1,16,0,124.1564,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,83.5223,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.6%,207.6787,0,6.79,6787.55 +80,2,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,1,16,0,124.1564,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,83.5223,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.6%,207.6787,0,6.79,6787.55 +80,4,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,1,16,0,124.1564,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,83.5223,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.6%,207.6787,0,6.79,6787.55 +80,8,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,1,16,0,124.1564,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,83.5223,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.6%,207.6787,0,6.79,6787.55 
80,16,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,1,16,0,124.1564,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,83.5223,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.6%,207.6787,0,6.79,6787.55 80,32,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,1,16,0,198.0641,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,130.4169,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.6%,328.481,0,8.58,4292.41 80,64,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,1,16,0,254.9635,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,170.7009,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.6%,425.6644,0,13.24,3314.02 @@ -695,6 +1087,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 80,256,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,1,16,0,310.1395,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,218.1749,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.6%,528.3144,0,42.68,2677.93 80,512,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,1,16,0,384.4754,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,319.2123,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.6%,703.6877,0,64.09,2018.36 80,1024,7168,256,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,1,16,0,560.9966,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,549.7671,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.6%,1110.7637,0,81.2,1288.58 +80,1,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,1,16,0,241.8455,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_silu_F8_F8_B16,0.0%,132.6515,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.7%,374.497,0,7.53,7527.21 +80,2,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,1,16,0,241.8455,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_silu_F8_F8_B16,0.0%,132.6515,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.7%,374.497,0,7.53,7527.21 
+80,4,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,1,16,0,241.8455,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_silu_F8_F8_B16,0.0%,132.6515,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.7%,374.497,0,7.53,7527.21 +80,8,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,1,16,0,241.8455,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_silu_F8_F8_B16,0.0%,132.6515,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.7%,374.497,0,7.53,7527.21 80,16,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,1,16,0,241.8455,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_silu_F8_F8_B16,0.0%,132.6515,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.7%,374.497,0,7.53,7527.21 80,32,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,1,16,0,375.1735,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_silu_F8_F8_B16,0.0%,214.1587,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.6%,589.3322,0,9.57,4783.82 80,64,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,1,16,0,497.6768,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_silu_F8_F8_B16,0.0%,302.6195,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.7%,800.2963,0,14.09,3523.63 @@ -702,6 +1098,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 80,256,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,1,16,0,578.2908,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_silu_F8_F8_B16,0.0%,375.1603,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.7%,953.4511,0,47.3,2961.95 80,512,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,1,16,0,727.1819,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_silu_F8_F8_B16,0.0%,509.1783,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.7%,1236.3602,0,72.95,2288.64 80,1024,7168,512,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,1,16,0,1059.0782,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_silu_F8_F8_B16,0.0%,841.9574,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.7%,1901.0356,0,94.89,1494.23 
+80,1,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,1,64,0,221.5506,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,326.2097,moe_ck2stages_gemm2_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.8%,547.7603,0,2.21,1102.99 +80,2,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,1,64,0,221.5506,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,326.2097,moe_ck2stages_gemm2_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.8%,547.7603,0,2.21,1102.99 +80,4,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,1,64,0,221.5506,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,326.2097,moe_ck2stages_gemm2_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.8%,547.7603,0,2.21,1102.99 +80,8,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,1,64,0,221.5506,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,326.2097,moe_ck2stages_gemm2_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.8%,547.7603,0,2.21,1102.99 80,16,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,1,64,0,221.5506,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,326.2097,moe_ck2stages_gemm2_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.8%,547.7603,0,2.21,1102.99 80,32,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,1,64,0,276.3545,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,408.058,moe_ck2stages_gemm2_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.7%,684.4125,0,3.53,883.05 80,64,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,1,64,0,278.4781,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,520.2659,moe_ck2stages_gemm2_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.7%,798.744,0,6.05,757.15 @@ -709,6 +1109,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 80,256,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,1,64,0,296.1626,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,515.0018,moe_ck2stages_gemm2_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.7%,811.1644,0,23.83,748.46 
80,512,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,1,64,0,319.2415,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,544.9096,moe_ck2stages_gemm2_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.7%,864.1511,0,44.73,706.21 80,1024,4096,384,128,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,1,64,0,500.5853,moe_ck2stages_gemm1_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,854.6331,moe_ck2stages_gemm2_256x64x128x128_1x4_MulABScaleExpertWeightA8W8blkscale_v3_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.7%,1355.2184,0,57.05,454.95 +80,1,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,1,16,0,243.5773,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,134.825,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.7%,378.4023,0,7.45,7449.52 +80,2,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,1,16,0,243.5773,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,134.825,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.7%,378.4023,0,7.45,7449.52 +80,4,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,1,16,0,243.5773,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,134.825,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.7%,378.4023,0,7.45,7449.52 +80,8,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,1,16,0,243.5773,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,134.825,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.7%,378.4023,0,7.45,7449.52 80,16,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,1,16,0,243.5773,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,134.825,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.7%,378.4023,0,7.45,7449.52 80,32,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,1,16,0,377.9925,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,208.2557,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.8%,586.2482,0,9.62,4808.99 
80,64,7168,512,256,8,ActivationType.Gelu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,1,16,0,508.1175,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_gelu_F8_F8_B16,0.0%,301.8424,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_F8_F8_B16,16.8%,809.9599,0,13.92,3481.59
@@ -717,6 +1121,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w,
80,256,7168,256,257,9,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,32,0,536.7655,_ZN5aiter45fmoe_bf16_pertokenFp8_g1u1_vs_silu_1tg_32x256E,0.0%,0.0,Null,0.0%,536.7655,1,47.26,2646.03 80,512,7168,256,257,9,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,32,0,560.4425,_ZN5aiter45fmoe_bf16_pertokenFp8_g1u1_vs_silu_1tg_32x256E,0.0%,0.0,Null,0.0%,560.4425,1,90.53,2544.06 80,1024,7168,256,257,9,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,32,0,827.4898,_ZN5aiter45fmoe_bf16_pertokenFp8_g1u1_vs_silu_1tg_32x256E,0.0%,0.0,Null,0.0%,827.4898,1,122.62,1736.35 +80,1,7168,2048,33,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,0,16,0,274.0603,_ZN5aiter59fmoe_stage1_bf16_pertokenFp8_blockscale_g1u1_16x256_2tg_pf3E,4.9%,150.3324,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,0.3%,424.3927,0,33.21,3425.3 +80,2,7168,2048,33,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,0,16,0,274.0603,_ZN5aiter59fmoe_stage1_bf16_pertokenFp8_blockscale_g1u1_16x256_2tg_pf3E,4.9%,150.3324,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,0.3%,424.3927,0,33.21,3425.3 +80,4,7168,2048,33,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,0,16,0,274.0603,_ZN5aiter59fmoe_stage1_bf16_pertokenFp8_blockscale_g1u1_16x256_2tg_pf3E,4.9%,150.3324,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,0.3%,424.3927,0,33.21,3425.3 +80,8,7168,2048,33,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,0,16,0,274.0603,_ZN5aiter59fmoe_stage1_bf16_pertokenFp8_blockscale_g1u1_16x256_2tg_pf3E,4.9%,150.3324,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,0.3%,424.3927,0,33.21,3425.3 80,16,7168,2048,33,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,0,16,0,274.0603,_ZN5aiter59fmoe_stage1_bf16_pertokenFp8_blockscale_g1u1_16x256_2tg_pf3E,4.9%,150.3324,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,0.3%,424.3927,0,33.21,3425.3 80,32,7168,2048,33,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,0,16,0,359.0112,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_silu_F8_F8_B16,0.0%,190.8827,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,0.2%,549.8939,0,51.26,2644.17 80,64,7168,2048,33,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,0,32,0,631.2833,_ZN5aiter47fmoe_bf16_blockscaleFp8_g1u1_vs_silu_1tg_32x256E,0.9%,0.0,Null,0.0%,631.2833,1,89.3,2304.36 @@ -731,6 +1143,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 
80,256,7168,2048,33,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,0,32,0,1166.708,_ZN5aiter47fmoe_bf16_blockscaleFp8_g1u1_vs_silu_1tg_32x256E,0.9%,0.0,Null,0.0%,1166.708,1,193.27,1250.38 80,512,7168,2048,33,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,0,32,0,2209.3824,_ZN5aiter47fmoe_bf16_blockscaleFp8_g1u1_vs_silu_1tg_32x256E,0.8%,0.0,Null,0.0%,2209.3824,1,204.12,662.78 80,1024,7168,2048,33,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_1x128,1,0,32,0,4205.8762,_ZN5aiter50fmoe_bf16_blockscaleFp8_g1u1_vs_silu_1tg_ps_32x256E,0.9%,0.0,Null,0.0%,4205.8762,1,214.45,350.78 +80,1,7168,2048,33,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,32,0,442.3731,_ZN5aiter45fmoe_bf16_pertokenFp8_g1u1_vs_silu_1tg_32x256E,6.8%,0.0,Null,0.0%,442.3731,1,31.86,3286.07 +80,2,7168,2048,33,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,32,0,442.3731,_ZN5aiter45fmoe_bf16_pertokenFp8_g1u1_vs_silu_1tg_32x256E,6.8%,0.0,Null,0.0%,442.3731,1,31.86,3286.07 +80,4,7168,2048,33,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,32,0,442.3731,_ZN5aiter45fmoe_bf16_pertokenFp8_g1u1_vs_silu_1tg_32x256E,6.8%,0.0,Null,0.0%,442.3731,1,31.86,3286.07 +80,8,7168,2048,33,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,32,0,442.3731,_ZN5aiter45fmoe_bf16_pertokenFp8_g1u1_vs_silu_1tg_32x256E,6.8%,0.0,Null,0.0%,442.3731,1,31.86,3286.07 80,16,7168,2048,33,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,32,0,442.3731,_ZN5aiter45fmoe_bf16_pertokenFp8_g1u1_vs_silu_1tg_32x256E,6.8%,0.0,Null,0.0%,442.3731,1,31.86,3286.07 80,32,7168,2048,33,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,32,0,520.7061,_ZN5aiter45fmoe_bf16_pertokenFp8_g1u1_vs_silu_1tg_32x512E,5.4%,0.0,Null,0.0%,520.7061,1,54.13,2792.39 80,64,7168,2048,33,10,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fnuz,torch.float8_e4m3fnuz,QuantType.per_Token,1,0,32,0,622.6569,_ZN5aiter45fmoe_bf16_pertokenFp8_g1u1_vs_silu_1tg_32x256E,6.9%,0.0,Null,0.0%,622.6569,1,90.53,2336.28 @@ -744,6 +1160,10 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 256,128,4096,192,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,72.9667,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,41.6406,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.1%,114.6073,0,42.16,5288.29 256,64,4096,192,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,71.4856,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,39.357,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.1%,110.8426,0,21.8,5458.45 
256,32,4096,192,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,62.7201,moe_ck2stages_gemm1_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,35.3504,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.1%,98.0705,0,12.32,6163.97 +256,1,4096,192,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,47.8381,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,28.9978,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.1%,76.8359,0,7.86,7864.06 +256,2,4096,192,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,47.8381,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,28.9978,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.1%,76.8359,0,7.86,7864.06 +256,4,4096,192,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,47.8381,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,28.9978,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.1%,76.8359,0,7.86,7864.06 +256,8,4096,192,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,47.8381,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,28.9978,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.1%,76.8359,0,7.86,7864.06 256,16,4096,192,128,8,ActivationType.Silu,torch.bfloat16,torch.bfloat16,torch.bfloat16,QuantType.No,1,0,32,0,47.8381,moe_ck2stages_gemm1_256x32x64x128_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight0_silu_B16_B16_B16,0.0%,28.9978,moe_ck2stages_gemm2_256x32x64x64_1x4_TypeCast_v1_Nswizzle0_Quant0_MulRoutedWeight1_B16_B16_B16,0.1%,76.8359,0,7.86,7864.06 256,1,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,0,16,0,45.285,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_silu_F8_F8_B16,0.1%,9.0945,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.7%,54.3795,0,1.62,25916.16 256,2,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,0,16,0,46.5232,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_silu_F8_F8_B16,0.1%,11.8082,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.3%,58.3314,0,3.02,24160.73 @@ -756,15 +1176,19 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w, 256,256,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,0,32,0,249.5786,_ZN5aiter49fmoe_bf16_blockscaleFp8_g1u1_novs_silu_1tg_32x256E,0.0%,0.0,Null,0.0%,249.5786,1,90.35,5668.72 256,512,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,0,32,0,260.9691,_ZN5aiter47fmoe_bf16_blockscaleFp8_g1u1_vs_silu_1tg_32x256E,0.0%,0.0,Null,0.0%,260.9691,1,172.81,5442.39 
256,1024,7168,256,256,8,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,0,32,0,359.4797,_ZN5aiter49fmoe_bf16_blockscaleFp8_g1u1_novs_silu_1tg_32x256E,0.0%,0.0,Null,0.0%,359.4797,1,250.9,3981.61 +256,1,5120,1024,128,1,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,1,32,0,0.0,_ZN5aiter46fmoe_bf16_pertokenFp8_g1u1_tkw1_silu_1tg_32x64E,0.0%,0.0,Null,0,0.0,1,0.0,0.0 +256,2,5120,1024,128,1,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,1,32,0,0.0,_ZN5aiter46fmoe_bf16_pertokenFp8_g1u1_tkw1_silu_1tg_32x64E,0.0%,0.0,Null,0,0.0,1,0.0,0.0 +256,4,5120,1024,128,1,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,1,32,0,0.0,_ZN5aiter46fmoe_bf16_pertokenFp8_g1u1_tkw1_silu_1tg_32x64E,0.0%,0.0,Null,0,0.0,1,0.0,0.0 +256,8,5120,1024,128,1,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,1,32,0,0.0,_ZN5aiter46fmoe_bf16_pertokenFp8_g1u1_tkw1_silu_1tg_32x64E,0.0%,0.0,Null,0,0.0,1,0.0,0.0 256,16,5120,1024,128,1,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_Token,1,1,32,0,0.0,_ZN5aiter46fmoe_bf16_pertokenFp8_g1u1_tkw1_silu_1tg_32x64E,0.0%,0.0,Null,0,0.0,1,0.0,0.0 -256,1,7168,256,257,9,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,0,16,0,45.285,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_silu_F8_F8_B16,0.1%,9.0945,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.7%,54.3795,0,1.62,25916.16 -256,2,7168,256,257,9,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,0,16,0,46.5232,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_silu_F8_F8_B16,0.1%,11.8082,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.3%,58.3314,0,3.02,24160.73 -256,4,7168,256,257,9,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,0,16,0,48.2418,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_silu_F8_F8_B16,0.1%,17.8498,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.4%,66.0916,0,5.33,21324.53 -256,8,7168,256,257,9,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,0,16,0,53.6435,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_silu_F8_F8_B16,0.1%,25.7951,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.7%,79.4386,0,8.87,17742.74 +256,1,7168,256,257,9,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,0,16,0,71.1678,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_silu_F8_F8_B16,0.1%,41.5098,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.6%,112.6776,0,12.51,12510.3 
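The small-token rows (1 through 8) added across the hunks in this file appear to pair with the get_padded_M change to aiter/fused_moe.py later in this diff, which pads 1 <= M <= 16 to the next power of two before the tuned-config lookup, so only power-of-two batch sizes need entries in the table. The snippet below restates that padding policy as a standalone sketch; next_pow2 is a plain reimplementation of the nextPow2 helper the hunk relies on, and the branch for large token counts is truncated in the hunk, so behaviour above 2048 tokens is left as a pass-through assumption.

def next_pow2(n: int) -> int:
    # Smallest power of two that is >= n, for n >= 1.
    return 1 << (n - 1).bit_length()

def get_padded_m_sketch(m: int) -> int:
    # Mirrors the branch structure visible in the fused_moe.py hunk below:
    # small decode batches (1..16) pad to the next power of two, as does
    # anything under 1024; 1024..2047 snaps to 1024.
    if 1 <= m <= 16:
        return next_pow2(m)
    if m < 1024:
        return next_pow2(m)
    if m < 2048:
        return 1024
    # The rest of the original function is not shown in this hunk;
    # returning m unchanged here is a placeholder, not the real policy.
    return m

# Example: batches of 3, 5..8 and 9..16 tokens collapse onto the table's
# 4, 8 and 16 entries respectively.
assert [get_padded_m_sketch(m) for m in (1, 3, 5, 9, 16)] == [1, 4, 8, 16, 16]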
+256,2,7168,256,257,9,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,0,16,0,71.1678,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_silu_F8_F8_B16,0.1%,41.5098,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.6%,112.6776,0,12.51,12510.3 +256,4,7168,256,257,9,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,0,16,0,71.1678,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_silu_F8_F8_B16,0.1%,41.5098,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.6%,112.6776,0,12.51,12510.3 +256,8,7168,256,257,9,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,0,16,0,71.1678,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_silu_F8_F8_B16,0.1%,41.5098,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.6%,112.6776,0,12.51,12510.3 256,16,7168,256,257,9,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,0,16,0,71.1678,moe_ck2stages_gemm1_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight0_silu_F8_F8_B16,0.1%,41.5098,moe_ck2stages_gemm2_256x16x128x256_1x4_MulABScaleExpertWeightA8W8blkscale_v1_Nswizzle0_Quant4_MulRoutedWeight1_F8_F8_B16,15.6%,112.6776,0,12.51,12510.3 256,32,7168,256,257,9,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,0,32,0,158.4834,_ZN5aiter49fmoe_bf16_blockscaleFp8_g1u1_novs_silu_1tg_32x256E,0.0%,0.0,Null,0.0%,158.4834,1,17.78,8896.67 256,64,7168,256,257,9,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,0,32,0,212.9873,_ZN5aiter50fmoe_bf16_blockscaleFp8_g1u1_vs_silu_1tg_ps_32x256E,0.0%,0.0,Null,0.0%,212.9873,1,26.47,6623.22 256,128,7168,256,257,9,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,0,32,0,241.6039,_ZN5aiter50fmoe_bf16_blockscaleFp8_g1u1_vs_silu_1tg_ps_32x256E,0.0%,0.0,Null,0.0%,241.6039,1,46.66,5844.44 256,256,7168,256,257,9,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,0,32,0,249.5786,_ZN5aiter49fmoe_bf16_blockscaleFp8_g1u1_novs_silu_1tg_32x256E,0.0%,0.0,Null,0.0%,249.5786,1,90.35,5668.72 256,512,7168,256,257,9,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,0,32,0,260.9691,_ZN5aiter47fmoe_bf16_blockscaleFp8_g1u1_vs_silu_1tg_32x256E,0.0%,0.0,Null,0.0%,260.9691,1,172.81,5442.39 -256,1024,7168,256,257,9,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,0,32,0,359.4797,_ZN5aiter49fmoe_bf16_blockscaleFp8_g1u1_novs_silu_1tg_32x256E,0.0%,0.0,Null,0.0%,359.4797,1,250.9,3981.61 \ No newline at end of file +256,1024,7168,256,257,9,ActivationType.Silu,torch.bfloat16,torch.float8_e4m3fn,torch.float8_e4m3fn,QuantType.per_1x128,1,0,32,0,359.4797,_ZN5aiter49fmoe_bf16_blockscaleFp8_g1u1_novs_silu_1tg_32x256E,0.0%,0.0,Null,0.0%,359.4797,1,250.9,3981.61 diff --git a/aiter/dist/device_communicators/communicator_cuda.py b/aiter/dist/device_communicators/communicator_cuda.py index 
fcc7ee05b2..55c3fa3bfc 100644 --- a/aiter/dist/device_communicators/communicator_cuda.py +++ b/aiter/dist/device_communicators/communicator_cuda.py @@ -49,10 +49,16 @@ def __init__( PyNcclCommunicator, ) - self.pynccl_comm = PyNcclCommunicator( - group=self.cpu_group, - device=self.device, - ) + try: + self.pynccl_comm = PyNcclCommunicator( + group=self.cpu_group, + device=self.device, + ) + except Exception as e: + logger.warning( + f"Failed to initialize PyNcclCommunicator for group " + f"{self.unique_name}. Exception: {e}" + ) # if is_symmetric_memory_enabled(): # register_nccl_symmetric_ops(self.pynccl_comm) @@ -149,6 +155,8 @@ def all_reduce( qr_comm is not None and not qr_comm.disabled and qr_comm.should_quick_allreduce(input_) + and (input_.nelement() * input_.element_size()) >= 4*1024*1024 # input shape should be such that quick reduce will show benefits. + # input shape estimated at 2 * max concurrency for now. if performance issues, subject to change ): out = qr_comm.quick_all_reduce(input_) assert out is not None diff --git a/aiter/fused_moe.py b/aiter/fused_moe.py index d6c206283b..6c2b593c74 100644 --- a/aiter/fused_moe.py +++ b/aiter/fused_moe.py @@ -104,6 +104,7 @@ def fused_moe( intermediate_pad=0, bias1=None, bias2=None, + splitk=0, ): if not block_size_M: block_size_M = -1 @@ -217,7 +218,15 @@ def fused_moe_( quant_type = quant_remap.get(quant_type, quant_type) q_dtype_w = w1.dtype q_dtype_a = w1.dtype if w1.dtype != torch.uint32 else dtypes.fp8 - q_dtype_a = dtypes.fp4x2 if quant_type == QuantType.per_1x32 else q_dtype_a + bf16_fp8_bound = 512 + if quant_type == QuantType.per_1x32: + if activation == ActivationType.Swiglu: + if get_gfx() != "gfx950" or M < bf16_fp8_bound: + q_dtype_a = dtypes.bf16 + elif M >= bf16_fp8_bound: + q_dtype_a = dtypes.fp8 + else: + q_dtype_a = dtypes.fp4x2 metadata = get_2stage_cfgs( get_padded_M(M), # consider token_num > 1024 as prefill @@ -496,6 +505,33 @@ def get_ksplit(token, topk, expert, inter_dim, model_dim): return 1 +@functools.lru_cache(maxsize=2048) +def get_ksplit(token, topk, expert, inter_dim, model_dim): + aiter_ksplit = int(os.environ.get("AITER_KSPLIT", "0")) + if aiter_ksplit != 0: + return aiter_ksplit + # only for moe_blk gemm1 a8w8 decode scenario + if token * topk > expert: + return 0 + cu_num = get_cu_num() + tileN = 128 + + tgM = token * topk # decode tile num + tgN = (inter_dim * 2 + tileN - 1) // tileN + + tg_num = tgN * tgM + # if all cu already active + if tg_num >= cu_num: + return 0 + tilek = 256 + split_max = (cu_num + tg_num - 1) // tg_num + # at least split = 2 + for i in reversed(range(2, split_max + 1)): + if (model_dim % i == 0) and ((model_dim // i) % tilek == 0): + return i + return 0 + + cfg_2stages = None # fmt: off fused_moe_1stage_dict = { @@ -535,7 +571,10 @@ def nextPow2(n): def get_padded_M(M): padded_m = M - if M < 1024: + if M >= 1 and M <= 16: + # decoding policy may be changed in the future. 
+ padded_m = nextPow2(padded_m) + elif M < 1024: padded_m = nextPow2(padded_m) elif M < 2048: padded_m = 1024 @@ -553,6 +592,7 @@ class MOEMetadata: block_m: int ksplit: int run_1stage: bool = False + has_bias: bool = False @functools.lru_cache(maxsize=2048) @@ -642,8 +682,22 @@ def FinalFunc(): ) logger.info("\033[0m") + def use_cfg(): + problem_type = (activation, dtype, q_dtype_a, q_dtype_w, q_type) + bypass_type = ( + ActivationType.Silu, + dtypes.bf16, + dtypes.fp8, + dtypes.fp8, + QuantType.per_1x128, + ) + if problem_type == bypass_type and (token * topk) <= 128: # bypass tuned + aiter.logger.info("bypass tuned results for fp8 blockscale") + return False + return True + # cfg = cfg_2stages.get(keys, None) - cfg = cfg_2stages.get(keys, None) if cfg_2stages else None + cfg = cfg_2stages.get(keys, None) if cfg_2stages and use_cfg() else None if cfg is None and os.environ.get("AITER_ONLINE_TUNE", "0") == "1": lock_path = os.path.join(bd_dir, f"lock_fmoe_tune_{keys}") mp_lock(lock_path, MainFunc=MainFunc, FinalFunc=FinalFunc) @@ -652,7 +706,7 @@ def FinalFunc(): cfg = cfg_2stages.get(keys, None) if cfg_2stages else None if cfg is None: logger.warning(f"Fmoe tuning not support for {keys}") - if cfg is None: + if cfg is None or int(os.environ.get("AITER_BYPASS_TUNE_CONFIG", "0")): ksplit = 0 kernelName1 = "" kernelName2 = "" @@ -667,7 +721,7 @@ def FinalFunc(): doweight_stage1, ) in fused_moe_1stage_dict[get_gfx()]: if q_type == QuantType.per_1x128: - run_1stage = True and (inter_dim % 256 == 0) + run_1stage = token > 32 and (inter_dim % 256 == 0) elif q_type == QuantType.per_Token and q_dtype_w == dtypes.i8: run_1stage = token > 32 elif q_type == QuantType.per_Token and q_dtype_w == dtypes.fp8: @@ -679,7 +733,7 @@ def FinalFunc(): BLOCK_SIZE_M if run_1stage else ( - 64 + (64 if token > 32 else 16) if q_type == QuantType.per_1x128 else get_block_size_M(token, topk, expert, inter_dim) ) @@ -704,6 +758,13 @@ def FinalFunc(): logger.info( f"[fused_moe] using {'1stage' if run_1stage else '2stage'} {'default' if cfg is None else tag} for {keys} " ) + + def get_block_m() -> int: + if q_dtype_a == dtypes.fp8: + return 32 + else: + return 16 if token < 2048 else 32 if token < 16384 else 64 + if run_1stage: return MOEMetadata( functools.partial( @@ -737,9 +798,10 @@ def FinalFunc(): activation=activation, bias2=bias2, ), - 16 if token < 2048 else 32 if token < 16384 else 64, + get_block_m(), ksplit, False, + True, ) elif ( dtype in [dtypes.bf16, dtypes.fp16] @@ -776,14 +838,16 @@ def FinalFunc(): dtypes.fp16, torch.uint32, dtypes.fp4x2, + dtypes.fp8, ] ): return MOEMetadata( functools.partial( - aiter.ck_moe_stage1_fwd, + ck_moe_stage1, kernelName=kernelName1, activation=activation, quant_type=q_type, + splitk=ksplit, ), functools.partial( aiter.ck_moe_stage2_fwd, @@ -792,7 +856,7 @@ def FinalFunc(): quant_type=q_type, ), block_m, - ksplit, + int(ksplit), run_1stage, ) @@ -877,11 +941,23 @@ def fused_moe_2stages( if ( quant_type == QuantType.per_1x32 and dtype in [dtypes.bf16, dtypes.fp16] + and q_dtype_a in [dtypes.bf16, dtypes.fp16] and w1.dtype == dtypes.fp4x2 and (activation == ActivationType.Swiglu or metadata.ksplit > 1) ): a1 = hidden_states.to(dtype) a1_scale = None + elif ( + quant_type == aiter.QuantType.per_1x32 + and dtype in [dtypes.bf16, dtypes.fp16] + and q_dtype_a == dtypes.fp8 + and w1.dtype == dtypes.fp4x2 + and activation == aiter.ActivationType.Swiglu + ): + a1 = hidden_states.to(dtypes.fp8) + M = sorted_ids.shape[0] + N = a1.shape[-1] + a1_scale = torch.ones([M, N // 32], 
dtype=dtypes.fp8_e8m0, device=a1.device) elif quant_type == QuantType.per_1x32: if token_num <= token_num_quant_moe_sort_switch: a1, a1_scale = fused_dynamic_mxfp4_quant_moe_sort( @@ -933,7 +1009,17 @@ def fused_moe_2stages( dtype=dtype, device=device, ) - + extra_stage1_args = {} + extra_stage2_args = {} + if ( + not metadata.run_1stage + and metadata.has_bias + and dtype in [dtypes.bf16, dtypes.fp16] + and quant_type == QuantType.per_1x32 + and activation == ActivationType.Swiglu + ): + extra_stage1_args["bias1"] = bias1 + extra_stage2_args["bias2"] = bias2 a2 = metadata.stage1( a1, w1, @@ -945,16 +1031,30 @@ def fused_moe_2stages( topk, block_m=block_size_M, a1_scale=a1_scale, - w1_scale=w1_scale, + w1_scale=( + w1_scale.view(dtypes.fp8_e8m0) if w1.dtype == dtypes.fp4x2 else w1_scale + ), sorted_weights=sorted_weights if doweight_stage1 else None, + dtype=dtype, + **extra_stage1_args, ) if ( quant_type == QuantType.per_1x32 and dtype in [dtypes.bf16, dtypes.fp16] + and q_dtype_a in [dtypes.bf16, dtypes.fp16] and w1.dtype == dtypes.fp4x2 and (activation == ActivationType.Swiglu or metadata.ksplit > 1) ): a2_scale = None + elif ( + quant_type == aiter.QuantType.per_1x32 + and dtype in [dtypes.bf16] + and q_dtype_a == dtypes.fp8 + and w1.dtype == dtypes.fp4x2 + and activation == aiter.ActivationType.Swiglu + ): + a2 = a2.to(dtypes.fp8) + a2_scale = a1_scale elif quant_type == QuantType.per_1x32: a2 = a2.view(-1, inter_dim) if token_num <= token_num_quant_moe_sort_switch: @@ -1010,10 +1110,13 @@ def fused_moe_2stages( num_valid_ids, moe_out, topk, - w2_scale=w2_scale, + w2_scale=( + w2_scale.view(dtypes.fp8_e8m0) if w2.dtype == dtypes.fp4x2 else w2_scale + ), a2_scale=a2_scale, block_m=block_size_M, sorted_weights=sorted_weights if not doweight_stage1 else None, + **extra_stage2_args, ) return moe_out @@ -1085,9 +1188,9 @@ def asm_stage1( ) if ksplit > 0: if activation == ActivationType.Silu: - aiter.silu_and_mul(out, tmp_out.view(dtypes.fp32).to(dtype)) + aiter.silu_and_mul(out, tmp_out.view(dtypes.fp32)) else: - aiter.gelu_and_mul(out, tmp_out.view(dtypes.fp32).to(dtype)) + aiter.gelu_and_mul(out, tmp_out.view(dtypes.fp32)) return out @@ -1351,6 +1454,60 @@ def torch_moe_stage2( return out.sum(1).to(dtype) +def ck_moe_stage1( + hidden_states, + w1, # [E, inter_dim*2, model_dim] + w2, # [E, model_dim, inter_dim] + sorted_token_ids, # [max_num_tokens_padded] + sorted_expert_ids, # [max_num_m_blocks] + num_valid_ids, # [1] + out, + topk, + block_m, + a1_scale, + w1_scale, + kernelName="", + sorted_weights=None, + quant_type=aiter.QuantType.No, + activation=ActivationType.Gelu, + splitk=1, + dtype=None, +): + token_num = hidden_states.shape[0] + tmp_out = ( + torch.zeros( + (token_num, topk, w1.shape[1]), dtype=dtypes.fp32, device=out.device + ) + if splitk > 1 + else out + ) + aiter.ck_moe_stage1_fwd( + hidden_states, + w1, + w2, + sorted_token_ids, + sorted_expert_ids, + num_valid_ids, + tmp_out, + topk, + kernelName, + w1_scale, + a1_scale, + block_m, + sorted_weights, + quant_type, + activation, + int(splitk), + out.dtype, + ) + if splitk > 1: + if activation == ActivationType.Silu: + aiter.silu_and_mul(out, tmp_out.view(dtypes.fp32)) + else: + aiter.gelu_and_mul(out, tmp_out.view(dtypes.fp32)) + return out + + def cktile_moe_stage1( hidden_states, w1, @@ -1369,6 +1526,7 @@ def cktile_moe_stage1( bias1=None, activation=ActivationType.Silu, split_k=1, + dtype=torch.bfloat16, ): token_num = hidden_states.shape[0] _, n1, k1 = w1.shape @@ -1378,10 +1536,8 @@ def cktile_moe_stage1( if w1.dtype 
is torch.uint32: D = D * 8 - out = torch.empty( - (token_num, topk, D), dtype=hidden_states.dtype, device=hidden_states.device - ) + out = torch.empty((token_num, topk, D), dtype=dtype, device=hidden_states.device) tmp_out = ( torch.zeros( (token_num, topk, w1.shape[1]), dtype=hidden_states.dtype, device=out.device diff --git a/aiter/jit/core.py b/aiter/jit/core.py index 3b91336d6b..a938c0e287 100644 --- a/aiter/jit/core.py +++ b/aiter/jit/core.py @@ -375,12 +375,38 @@ def validate_and_update_archs(): @functools.lru_cache() def hip_flag_checker(flag_hip: str) -> bool: - ret = os.system(f"hipcc {flag_hip} -x hip -E -P /dev/null -o /dev/null") - if ret == 0: - return True - else: - logger.warning(f"{flag_hip} is not supported by hipcc.") + import subprocess + + cmd = ( + ["hipcc"] + + flag_hip.split() + + ["-x", "hip", "-E", "-P", "/dev/null", "-o", "/dev/null"] + ) + try: + subprocess.check_output(cmd, stderr=subprocess.DEVNULL) + except subprocess.CalledProcessError: + logger.warning(f"Current hipcc does not support {flag_hip}, skip it.") return False + return True + + +@functools.lru_cache() +def check_LLVM_MAIN_REVISION(): + # for https://github.com/ROCm/ROCm/issues/5646 and https://github.com/ROCm/composable_kernel/pull/3469 + # CK uses the following logic: + """#if LLVM_MAIN_REVISION < 554785 + #define CK_TILE_HOST_DEVICE_EXTERN __host__ __device__ + #else + #define CK_TILE_HOST_DEVICE_EXTERN""" + import subprocess + + cmd = """echo "#include <tuple> +__host__ __device__ void func(){std::tuple t = std::tuple(1, 1);}" | hipcc -x hip -P -c -Wno-unused-command-line-argument -""" + try: + subprocess.check_output(cmd, shell=True, text=True, stderr=subprocess.STDOUT) + except subprocess.CalledProcessError: + return 554785 + return 554785 - 1 def check_and_set_ninja_worker(): @@ -541,6 +567,7 @@ def MainFunc(): "-Wno-macro-redefined", "-Wno-missing-template-arg-list-after-template-kw", "-fgpu-flush-denormals-to-zero", + f"-DLLVM_MAIN_REVISION={check_LLVM_MAIN_REVISION()}", ] # Imitate https://github.com/ROCm/composable_kernel/blob/c8b6b64240e840a7decf76dfaa13c37da5294c4a/CMakeLists.txt#L190-L214 @@ -864,7 +891,7 @@ def check_args(): pattern = r"([\w\.]+(?:\[[^\]]+\])?)\s*\|\s*None" doc_str = re.sub(pattern, r"Optional[\1]", doc_str) for el in enum_types: - doc_str = re.sub(f" aiter.*{el} ", f" {el} ", doc_str) + doc_str = re.sub(f" (module_)?aiter.*{el} ", f" {el} ", doc_str) namespace = { "List": List, "Optional": Optional, diff --git a/aiter/jit/optCompilerConfig.json b/aiter/jit/optCompilerConfig.json index 090a742caa..52e76078bc 100755 --- a/aiter/jit/optCompilerConfig.json +++ b/aiter/jit/optCompilerConfig.json @@ -1036,6 +1036,7 @@ "module_top_k_per_row": { "srcs": [ "f'{AITER_CSRC_DIR}/kernels/topk_per_row_kernels.cu'", + "f'{AITER_CSRC_DIR}/py_itfs_cu/asm_topk_per_row_decode.cu'", "f'{AITER_CSRC_DIR}/pybind/topk_per_row_pybind.cu'" ], "flags_extra_cc": [], @@ -1076,7 +1077,8 @@ "module_topk_plain": { "srcs": [ "f'{AITER_CSRC_DIR}/pybind/topk_plain_pybind.cu'", - "f'{AITER_CSRC_DIR}/kernels/topk_plain_kernels.cu'" + "f'{AITER_CSRC_DIR}/kernels/topk_plain_kernels.cu'", + "f'{AITER_CSRC_DIR}/kernels/topk_per_row_kernels.cu'" ], "flags_extra_cc": [], "flags_extra_hip": [], diff --git a/aiter/jit/utils/cpp_extension.py b/aiter/jit/utils/cpp_extension.py index 5799e47205..2bd4c14a67 100644 --- a/aiter/jit/utils/cpp_extension.py +++ b/aiter/jit/utils/cpp_extension.py @@ -1534,7 +1534,20 @@ def _write_ninja_file_to_build_library( extra_ldflags = [flag.strip() for flag in extra_ldflags]
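For orientation, the rewritten hip_flag_checker above is consumed the same way as the old os.system probe: test a flag against hipcc once and keep it only if the compiler accepts it, with functools.lru_cache memoizing the verdict per flag string. A minimal usage sketch (the candidate flag is one that already appears in the MainFunc flag list above; this is an illustration, not code from this PR):

    from aiter.jit.core import hip_flag_checker

    flags_extra_hip = []
    candidate = "-Wno-missing-template-arg-list-after-template-kw"
    if hip_flag_checker(candidate):  # spawns hipcc -E once; repeated queries hit the cache
        flags_extra_hip.append(candidate)

check_LLVM_MAIN_REVISION() feeds the flag list the same way: its return value is baked into -DLLVM_MAIN_REVISION=(revision) so that composable_kernel's ck_tile header picks the CK_TILE_HOST_DEVICE_EXTERN branch that matches what the probe observed.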
extra_include_paths = [flag.strip() for flag in extra_include_paths] # include_paths() gives us the location of torch/extension.h - system_includes = [] if torch_exclude else include_paths(with_cuda) + # system_includes = [] if torch_exclude else include_paths(with_cuda) + import torch + + _TORCH_PATH = os.path.dirname(torch.__file__) + TORCH_INCLUDE_ROOT = os.path.join(_TORCH_PATH, "include") + system_includes = [ + TORCH_INCLUDE_ROOT, + os.path.join(TORCH_INCLUDE_ROOT, "torch/csrc/api/include"), + os.path.join(TORCH_INCLUDE_ROOT, "TH"), + os.path.join(TORCH_INCLUDE_ROOT, "THC"), + ] + if not torch_exclude: + system_includes += include_paths(with_cuda) + system_includes = list(set(system_includes)) # FIXME: build python module excluded with torch, use `pybind11` # But we can't use this now because all aiter op based on torch diff --git a/aiter/mla.py b/aiter/mla.py index 6f4cd2150a..1e09e1bd51 100644 --- a/aiter/mla.py +++ b/aiter/mla.py @@ -162,6 +162,8 @@ def mla_decode_fwd( q_scale=None, kv_scale=None, intra_batch_mode=False, + return_logits=False, + return_lse=False, ): device = q.device assert logit_cap <= 0, f"{logit_cap=} is not support yet" @@ -271,7 +273,7 @@ def mla_decode_fwd( ): # Natively support cases pass - elif nhead in range(32, 128 + 1, 16) and persistent_mode and max_seqlen_q == 1: + elif nhead in range(32, 128 + 1, 16) and persistent_mode: # we use nhead=16 to simulate such cases by customized metadata # metadata also views qo's tensor as shape (total_s * (nhead // 16), 16, ...) total_s = ori_total_s * (ori_nhead // 16) @@ -292,7 +294,11 @@ def mla_decode_fwd( dtype=dtypes.fp32, device=device, ) - final_lse = torch.empty((total_s, nhead), dtype=dtypes.fp32, device=device) + final_lse = ( + torch.empty((total_s, nhead), dtype=dtypes.fp32, device=device) + if return_lse + else None + ) aiter.mla_decode_stage1_asm_fwd( q, @@ -326,10 +332,9 @@ def mla_decode_fwd( ) if io_transformed: - if persistent_mode: + if return_logits: logits = logits.view(-1, 1, ori_nhead, v_head_dim) - else: - logits = logits.view(ori_total_s, num_kv_splits, ori_nhead, v_head_dim) + q = q.view(ori_total_s, ori_nhead, -1) o = o.view(ori_total_s, ori_nhead, -1) diff --git a/aiter/ops/attention.py b/aiter/ops/attention.py index 20101480eb..a433bd213a 100644 --- a/aiter/ops/attention.py +++ b/aiter/ops/attention.py @@ -629,7 +629,8 @@ def get_mla_metadata_info_v1( max_qo_tiles_per_batch = ( int(math.ceil(max_seqlen_qo * num_head_qo / 128)) - if num_head_qo == 16 or (num_head_qo == 128 and kv_dtype == dtypes.fp8) + if num_head_qo == 16 + or (num_head_qo == 128 and kv_dtype == dtypes.fp8 and q_dtype == dtypes.fp8) else int(math.ceil(max_seqlen_qo * num_head_qo / 16)) ) batch_size = batch_size * max_seqlen_qo if is_sparse else batch_size diff --git a/aiter/ops/gemm_op_a16w16.py b/aiter/ops/gemm_op_a16w16.py index e9f86a5cf9..d83ffd0309 100644 --- a/aiter/ops/gemm_op_a16w16.py +++ b/aiter/ops/gemm_op_a16w16.py @@ -20,6 +20,7 @@ def gen_gemm_a16w16_asm_fake_tensors( A: Tensor, B: Tensor, out: Tensor, + semaphore: Tensor, bias: Optional[Tensor] = None, splitK: Optional[int] = None, kernelName: Optional[str] = None, @@ -37,6 +38,7 @@ def gemm_a16w16_asm( A: Tensor, B: Tensor, out: Tensor, + semaphore: Tensor, bias: Optional[Tensor] = None, splitK: Optional[int] = None, kernelName: Optional[str] = None, @@ -44,6 +46,11 @@ def gemm_a16w16_asm( ) -> Tensor: ... 
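The a16w16 ASM GEMM entry points above now take a semaphore workspace tensor between out and bias. A minimal call sketch with assumed sizes and layouts (in practice the helper defined just below caches one (16, 64) uint32 buffer per device, and tuned_gemm passes that cached buffer in the same position):

    import torch
    from aiter import gemm_a16w16_asm

    M, N, K = 128, 1024, 8192  # illustrative sizes only
    A = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
    B = torch.randn(N, K, dtype=torch.bfloat16, device="cuda")  # weights stored as [N, K]
    out = torch.empty(M, N, dtype=torch.bfloat16, device="cuda")
    sema = torch.zeros(16, 64, dtype=torch.uint32, device="cuda")  # zero-initialized workspace
    gemm_a16w16_asm(A, B, out, sema)  # bias, splitK and kernelName stay optional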
+@functools.lru_cache(maxsize=1) +def get_semaphore_workspace(device: torch.device) -> Tensor: + return torch.zeros((16, 64), dtype=torch.uint32, device=device) + + def gemm_a16w16( A: Tensor, B: Tensor, @@ -52,4 +59,5 @@ def gemm_a16w16( splitK: Optional[int] = None, kernelName: Optional[str] = None, ): - return gemm_a16w16_asm(A, B, out, bias, splitK, kernelName) + sema = get_semaphore_workspace(out.device) + return gemm_a16w16_asm(A, B, out, bias, sema, splitK, kernelName) diff --git a/aiter/ops/gemm_op_a4w4.py b/aiter/ops/gemm_op_a4w4.py index bd3759f98c..752c724645 100644 --- a/aiter/ops/gemm_op_a4w4.py +++ b/aiter/ops/gemm_op_a4w4.py @@ -4,20 +4,17 @@ import functools from typing import Optional -from aiter.jit.utils.torch_guard import torch_compile_guard import pandas as pd import torch from torch import Tensor from aiter import logger +from aiter.jit.utils.torch_guard import torch_compile_guard -from ..jit.core import ( - AITER_CONFIGS, - AITER_LOG_TUNED_CONFIG, - compile_ops, -) +from ..jit.core import AITER_CONFIGS, AITER_LOG_TUNED_CONFIG, compile_ops from ..jit.utils.chip_info import get_cu_num, get_gfx from ..ops.gemm_op_common import get_padded_m +from ..utility import dtypes @functools.lru_cache(maxsize=1024) @@ -66,12 +63,15 @@ def gemm_a4w4_fake( B: Tensor, # B:[N, K/2] f4x2 A_scale: Tensor, # A_scale:[M, K/32] e8m0 paded B_scale: Tensor, # B_scale:[N, K/32] e8m0 paded - out: Tensor, # Out:[M, N] bf16 bias: Optional[Tensor] = None, # bias:[1, N] f32 + dtype: torch.dtype = dtypes.bf16, alpha: Optional[float] = 1.0, beta: Optional[float] = 0.0, bpreshuffle: Optional[bool] = True, ) -> torch.Tensor: + m = A.numel() // A.shape[-1] + n = B.shape[0] + out = torch.empty((m, n), dtype=dtype, device=A.device) return out @@ -81,8 +81,8 @@ def gemm_a4w4( B: Tensor, # B:[N, K/2] f4x2 A_scale: Tensor, # A_scale:[M, K/32] e8m0 paded B_scale: Tensor, # B_scale:[N, K/32] e8m0 paded - out: Tensor, # Out:[M, N] bf16 bias: Optional[Tensor] = None, # bias:[1, N] f32 + dtype: torch.dtype = dtypes.bf16, alpha: Optional[float] = 1.0, beta: Optional[float] = 0.0, bpreshuffle: Optional[bool] = True, @@ -93,9 +93,10 @@ def gemm_a4w4( It is used to perform matrix multiplication with 4-bit quantization. """ # Load the A4W4 GEMM kernel - m = A.shape[0] + m = A.numel() // A.shape[-1] n = B.shape[0] k = A.shape[-1] * 2 + out = torch.empty(((m + 31) // 32 * 32, n), dtype=dtype, device=A.device) gfx_arch = get_gfx() if gfx_arch in ["gfx942"]: raise RuntimeError( @@ -114,12 +115,14 @@ def gemm_a4w4( # or bias is None ): splitK = 0 if splitK is None else splitK - return gemm_a4w4_blockscale(A, B, A_scale, B_scale, out, splitK=splitK) + return gemm_a4w4_blockscale( + A.view(m, k // 2), B, A_scale, B_scale, out, splitK=splitK + )[:m] assert ( out.shape[0] % 32 == 0 ), "Dim0 of gemm_a4w4_asm output needs to be padded to multiples of 32!" 
- return gemm_a4w4_asm( - A, + gemm_a4w4_asm( + A.view(m, k // 2), B, A_scale, B_scale, @@ -131,6 +134,7 @@ def gemm_a4w4( bpreshuffle, log2_k_split=splitK, ) + return out[:m].view(*A.shape[:-1], n) def gen_gemm_a4w4_asm_fake_tensors( diff --git a/aiter/ops/moe_op.py b/aiter/ops/moe_op.py index 906bf0f5d7..71bd2d1e5f 100755 --- a/aiter/ops/moe_op.py +++ b/aiter/ops/moe_op.py @@ -223,17 +223,22 @@ def cmdGenFunc_ck_moe_stage( sorted_weights: Optional[Tensor] = None, quant_type: int = 0, activation: int = 0, + splitk: int = 1, + dst_type: Optional[str] = None, ): mul_routed_weight_stage = 2 if sorted_weights is None else 1 + is_splitk = splitk > 1 + outtype = str2dtype_dict[dst_type] if is_splitk else out.dtype md_name, blob_gen_cmd = get_moe_stage_module( hidden_states.dtype, w1.dtype, - out.dtype, + outtype, activation, quant_type, mul_routed_weight_stage, getattr(w1, "is_shuffled", False), + is_splitk, ) return { "md_name": md_name, @@ -292,6 +297,8 @@ def ck_moe_stage1( sorted_weights: Optional[Tensor] = None, quant_type: int = 0, activation: int = 0, + splitk: int = 1, + dst_type: Optional[str] = None, ) -> None: ... @@ -443,6 +450,11 @@ def moe_cktile2stages_gemm2( torch.int4: "i4", } +str2dtype_dict = { + "f16": dtypes.fp16, + "b16": dtypes.bf16, +} + @functools.lru_cache(maxsize=1024) def get_moe_stage_module( @@ -453,6 +465,7 @@ def get_moe_stage_module( quant_type, mul_routed_weight_stage, preshuffle_mode=False, + is_splitk=False, ): if isinstance(activation, int): activation = ActivationType(activation) @@ -467,6 +480,7 @@ def get_moe_stage_module( if preshuffle_mode and weight_dtype == dtypes.fp4x2: preshuffle_str = "--preshuffle" + splitk_str = "--issplitk" if is_splitk else "" quant_type = ( QuantType.per_1x128 if quant_type == QuantType.per_128x128 else quant_type ) @@ -483,10 +497,11 @@ def get_moe_stage_module( act, quant_type, f"mulWeightStage{mul_routed_weight_stage}", + "splitk" if is_splitk else "", ] ) blob_gen_cmd = [ - f"{AITER_CSRC_DIR}/ck_gemm_moe_2stages_codegen/gen_instances.py -a {Adtype} -b {Bdtype} -c {Cdtype} -q {quant_type} -act {act} -m {mul_routed_weight_stage} {preshuffle_str} -w {{}}" + f"{AITER_CSRC_DIR}/ck_gemm_moe_2stages_codegen/gen_instances.py -a {Adtype} -b {Bdtype} -c {Cdtype} -q {quant_type} -act {act} -m {mul_routed_weight_stage} {preshuffle_str} {splitk_str} -w {{}}" ] return md_name, blob_gen_cmd @@ -508,6 +523,8 @@ def ck_moe_stage1_fwd( sorted_weights: Optional[Tensor] = None, quant_type: QuantType = QuantType.No, activation: ActivationType = ActivationType.Silu, + splitk: Optional[int] = 1, + dst_type: Optional[torch.dtype] = None, ): ck_moe_stage1( hidden_states, @@ -525,6 +542,8 @@ def ck_moe_stage1_fwd( sorted_weights, quant_type.value, activation.value, + splitk, + dtype2str_dict[dst_type], ) return out diff --git a/aiter/ops/quant.py b/aiter/ops/quant.py index 0d974fc933..7070a49a42 100644 --- a/aiter/ops/quant.py +++ b/aiter/ops/quant.py @@ -101,6 +101,42 @@ def per_1x32_f4_quant(x, scale=None, quant_dtype=dtypes.fp4x2, shuffle=False): return y, scale.view(dtypes.fp8_e8m0) +def per_1x32_f8_scale_f8_quant( + x, scale=None, quant_dtype=dtypes.fp8, scale_type=dtypes.fp32, shuffle=False +): + assert quant_dtype == dtypes.fp8 + block_size = 32 + dtypeMax = 448.0 + MAX_POW2 = int(torch.log2(torch.tensor(dtypeMax, dtype=torch.float32)).item()) + dtypeMax = 2.0**MAX_POW2 + + shape_original = x.shape + x = x.view(-1, shape_original[-1]) + + m, n = x.shape + x = x.view(-1, block_size) + max_abs = torch.amax(torch.abs(x.float()), 1) + + # 
fp8e8m0fnu_from_fp32_value + if scale_type == dtypes.fp32: + scale_f32 = max_abs / dtypeMax + scale_e8m0_biased = None + else: + scale_e8m0_biased = fp4_utils.f32_to_e8m0(max_abs / dtypeMax) + scale_f32 = fp4_utils.e8m0_to_f32(scale_e8m0_biased) + # scale_f32 = max_abs / dtypeMax + + y = x.float() / scale_f32.view(-1, 1) + y = y.view(*shape_original[:-1], -1) + if scale_type == dtypes.fp32: + scale = scale_f32.view(m, -1) + else: + scale = scale_e8m0_biased.view(m, -1) # .view(torch.uint8) + if shuffle: + scale = fp4_utils.e8m0_shuffle(scale) + return y.to(quant_dtype), scale + + def per_tensor_quant( x, scale=None, scale_dtype=dtypes.fp32, quant_dtype=dtypes.i8, dtypeMax=None ): diff --git a/aiter/ops/topk.py b/aiter/ops/topk.py index 1c3666f832..5101a266c2 100755 --- a/aiter/ops/topk.py +++ b/aiter/ops/topk.py @@ -219,3 +219,15 @@ def top_k_per_row_decode( stride0: int, stride1: int, ) -> None: ... + + +@compile_ops("module_top_k_per_row") +def top_k_per_row_decode_fast( + logits: torch.Tensor, + next_n: int, + seqLens: torch.Tensor, + indices: torch.Tensor, + numRows: int, + stride0: int, + stride1: int, +) -> None: ... diff --git a/aiter/ops/topk_plain.py b/aiter/ops/topk_plain.py index dea2c654b7..cd768b01e9 100644 --- a/aiter/ops/topk_plain.py +++ b/aiter/ops/topk_plain.py @@ -13,7 +13,12 @@ def topk_plain( x: torch.Tensor, topk_ids: torch.Tensor, + topk_out: torch.Tensor, topk: int, - largest: bool, + largest: bool = True, + rowStarts: torch.Tensor = None, + rowEnds: torch.Tensor = None, + stride0: int = -1, + stride1: int = 1, ) -> None: pass diff --git a/aiter/ops/triton/batched_gemm_a16wfp4.py b/aiter/ops/triton/batched_gemm_a16wfp4.py index a10cc66bea..ffd8b0ba3d 100755 --- a/aiter/ops/triton/batched_gemm_a16wfp4.py +++ b/aiter/ops/triton/batched_gemm_a16wfp4.py @@ -11,9 +11,11 @@ _get_config, ) from aiter.ops.triton.utils.logger import AiterTritonLogger +from aiter.ops.triton.utils.common_utils import deserialize_str from aiter.ops.triton.gemm_a16wfp4 import ( get_splitk, ) +from aiter.jit.utils.torch_guard import torch_compile_guard _LOGGER = AiterTritonLogger() @@ -26,17 +28,36 @@ def set_use_gemm_splitk_bf16(value: bool): _USE_GEMM_SPLITK_BF16 = value +def batched_gemm_a16wfp4_fake_tensor( + x: torch.Tensor, + w: torch.Tensor, + w_scales: torch.Tensor, + dtype: Optional[torch.dtype] = torch.bfloat16, + y: Optional[torch.Tensor] = None, + config: Optional[str] = None, + transpose_bm: Optional[bool] = False, + prequant: Optional[bool] = True, + y_scale: Optional[torch.Tensor] = None, +) -> torch.Tensor: + if y is None: + Bx, M, _ = x.shape + _, N, _ = w.shape + return torch.empty((Bx, M, N), dtype=dtype, device=x.device) + return y + + +@torch_compile_guard(gen_fake=batched_gemm_a16wfp4_fake_tensor) def batched_gemm_a16wfp4( - x, - w, - w_scales, - dtype: Optional[float] = torch.bfloat16, + x: torch.Tensor, + w: torch.Tensor, + w_scales: torch.Tensor, + dtype: Optional[torch.dtype] = torch.bfloat16, y: Optional[torch.Tensor] = None, - config: Optional[dict] = None, + config: Optional[str] = None, transpose_bm: Optional[bool] = False, prequant: Optional[bool] = True, y_scale: Optional[torch.Tensor] = None, -): +) -> torch.Tensor: """ Computes batched FP4 matrix multiplication Y[i] = X[i] @ W[i]^T with active activation quantization. X is quantized to MXFP4 during computation, W is pre-quantized FP4. 
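Returning to the per_1x32_f8_scale_f8_quant helper added in aiter/ops/quant.py above: each 32-element block is scaled so that its maximum magnitude maps to 2**int(log2(448)) = 256 before the cast to fp8. A small worked sketch with toy values and fp32 scales (assumes a device where the fp8 cast is available):

    import torch
    from aiter import dtypes
    from aiter.ops.quant import per_1x32_f8_scale_f8_quant

    x = torch.full((1, 32), 32.0, dtype=torch.float32, device="cuda")  # a single 32-wide block
    y, scale = per_1x32_f8_scale_f8_quant(x, scale_type=dtypes.fp32)
    print(scale)            # tensor([[0.1250]]): amax / 256 = 32 / 256
    print(y.float().max())  # 256.0, comfortably inside the +/-448 e4m3 range

Requesting an e8m0 scale instead rounds the per-block scale to a power of two and returns the biased exponent, optionally shuffled for kernels that expect the preshuffled scale layout.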
@@ -72,6 +93,8 @@ def batched_gemm_a16wfp4( if config is None: config = _get_config(M, N, K) + else: + config = deserialize_str(config) if y is None: if transpose_bm: diff --git a/aiter/ops/triton/batched_gemm_afp4wfp4_pre_quant.py b/aiter/ops/triton/batched_gemm_afp4wfp4_pre_quant.py index 01add76bdf..92a8b30256 100755 --- a/aiter/ops/triton/batched_gemm_afp4wfp4_pre_quant.py +++ b/aiter/ops/triton/batched_gemm_afp4wfp4_pre_quant.py @@ -3,9 +3,8 @@ from typing import Optional import torch -import triton -import aiter.ops.triton.utils._triton.arch_info as arch_info from aiter.ops.triton.utils.logger import AiterTritonLogger +from aiter.ops.triton.utils.common_utils import serialize_dict from aiter.ops.triton.batched_gemm_a16wfp4 import ( batched_gemm_a16wfp4, ) @@ -32,6 +31,8 @@ def batched_gemm_afp4wfp4_pre_quant( _LOGGER.info( "batched_gemm_afp4wfp4_pre_quant will be deprecated in future AITER release, please switch to batched_gemm_a16wfp4" ) + + config_hashable = serialize_dict(config) if config else None return batched_gemm_a16wfp4( - x, w, w_scales, dtype, y, config, transpose_bm=False, prequant=True + x, w, w_scales, dtype, y, config_hashable, transpose_bm=False, prequant=True ) diff --git a/aiter/ops/triton/fused_mxfp4_quant.py b/aiter/ops/triton/fused_mxfp4_quant.py index 173c3502bf..0218d385d3 100644 --- a/aiter/ops/triton/fused_mxfp4_quant.py +++ b/aiter/ops/triton/fused_mxfp4_quant.py @@ -5,7 +5,6 @@ from typing import Optional from aiter.utility import dtypes from aiter.ops.triton._triton_kernels.fused_mxfp4_quant import ( - _rmsmorm_op, _fused_rms_mxfp4_quant_kernel, _fused_flatten_mxfp4_quant, _fused_reduce_act_mul_and_dynamic_mxfp4_quant_kernel, @@ -650,3 +649,196 @@ def fused_dynamic_mxfp4_quant_moe_sort( x_fp4.view(dtypes.fp4x2), blockscale_e8m0_sorted.view(dtypes.fp8_e8m0).view(-1, N_o), ) + + +@triton.jit +def _fused_quant_fp8_sort_kernel( + # Pointers + input_ptr, + sorted_ids_ptr, + num_valid_ids_ptr, + x_fp8_ptr, + scale_sorted_ptr, + # Input/Output strides + stride_input_m: tl.constexpr, + stride_input_n: tl.constexpr, + stride_x_fp8_m: tl.constexpr, + stride_x_fp8_n: tl.constexpr, + stride_scale_o3: tl.constexpr, + stride_scale_o2: tl.constexpr, + stride_scale_o1: tl.constexpr, + stride_scale_o0: tl.constexpr, + # Problem size + M_input: tl.constexpr, + N_input: tl.constexpr, + N_scale_cols: tl.constexpr, + token_num: tl.constexpr, + # Block configuration + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, # quant_block_size / 2 + QUANT_BLOCK_SIZE: tl.constexpr, + TOPK: tl.constexpr, + # Quantization parameters + DTYPE_MAX: tl.constexpr, + DTYPE_MIN: tl.constexpr, +): + pid_m = tl.program_id(0) * 2 + pid_n = tl.program_id(1) * 2 + + num_valid_ids = tl.load(num_valid_ids_ptr) + if pid_m * BLOCK_SIZE_M >= num_valid_ids: + return + + out = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.uint32) + + for i in range(4): + m = i % 2 * BLOCK_SIZE_M # 0 or BLOCK_SIZE_M + n = i // 2 * BLOCK_SIZE_N # 0 or BLOCK_SIZE_N + + sorted_ids_offs_m = pid_m * BLOCK_SIZE_M + m + tl.arange(0, BLOCK_SIZE_M) + sorted_ids_mask = sorted_ids_offs_m < num_valid_ids + sorted_ids = tl.load( + sorted_ids_ptr + sorted_ids_offs_m, + mask=sorted_ids_mask, + other=0, + ) + topk_ids = sorted_ids >> 24 + token_ids = sorted_ids & 0xFFFFFF + + if TOPK == 1: + original_m_idx = token_ids + else: + original_m_idx = token_ids * TOPK + topk_ids + + input_offs_n = (pid_n * BLOCK_SIZE_N + n) * QUANT_BLOCK_SIZE + tl.arange( + 0, BLOCK_SIZE_N * QUANT_BLOCK_SIZE + ) + input_offs = ( + original_m_idx[:, 
None] * stride_input_m + + input_offs_n[None, :] * stride_input_n + ) + input_mask = (original_m_idx < M_input)[:, None] & (input_offs_n < N_input)[ + None, : + ] + + x = tl.load(input_ptr + input_offs, mask=input_mask, other=0.0).to(tl.float32) + + x_reshaped = x.reshape(BLOCK_SIZE_M * BLOCK_SIZE_N, QUANT_BLOCK_SIZE) + + amax = tl.max(tl.abs(x_reshaped), axis=-1, keep_dims=True) + + amax = amax.to(tl.int32, bitcast=True) + amax = (amax + 0x200000).to(tl.uint32, bitcast=True) & 0xFF800000 + amax = amax.to(tl.float32, bitcast=True) + + scale_e8m0_unbiased = tl.log2(amax).floor() - tl.log2(DTYPE_MAX).floor() + scale_e8m0_unbiased = tl.clamp(scale_e8m0_unbiased, min=-127, max=127) + + quant_scale = tl.exp2(-scale_e8m0_unbiased) + x_fp8 = tl.clamp(x_reshaped * quant_scale, DTYPE_MIN, DTYPE_MAX) + x_fp8 = x_fp8.reshape(BLOCK_SIZE_M, BLOCK_SIZE_N * QUANT_BLOCK_SIZE) + + scale_e8m0 = (scale_e8m0_unbiased.to(tl.uint8) + 127).to(tl.uint8) + scale_e8m0 = scale_e8m0.reshape(BLOCK_SIZE_M, BLOCK_SIZE_N) # [BLOCK_SIZE_M] + + out_offs_n = (pid_n * BLOCK_SIZE_N + n) * QUANT_BLOCK_SIZE + tl.arange( + 0, BLOCK_SIZE_N * QUANT_BLOCK_SIZE + ) + out_offs = ( + original_m_idx[:, None] * stride_x_fp8_m + + out_offs_n[None, :] * stride_x_fp8_n + ) + out_mask = (original_m_idx < M_input)[:, None] & (out_offs_n < N_input)[None, :] + tl.store( + x_fp8_ptr + out_offs, x_fp8.to(x_fp8_ptr.type.element_ty), mask=out_mask + ) + + out = out | (scale_e8m0.to(tl.uint32) << (i * 8)) + + offs_0 = tl.arange(0, BLOCK_SIZE_M) + offs_1 = tl.arange(0, BLOCK_SIZE_N) + offs_2 = pid_n // 2 + offs_3 = pid_m // 2 + offs = ( + offs_0[:, None] * stride_scale_o0 + + offs_1[None, :] * stride_scale_o1 + + offs_2 * stride_scale_o2 + + offs_3 * stride_scale_o3 + ) + tl.store(scale_sorted_ptr + offs, out) + + +def fused_quant_fp8_sort( + input: torch.Tensor, + sorted_ids: torch.Tensor, + num_valid_ids: torch.Tensor, + token_num: int, + block_size: int = 32, + quant_block_size: int = 8, + quant_dtype: torch.dtype = dtypes.fp8, +) -> tuple[torch.Tensor, torch.Tensor]: + BLOCK_SIZE_M = block_size + BLOCK_SIZE_N = quant_block_size + BLOCK_SIZE_M_u32 = BLOCK_SIZE_M // 2 + BLOCK_SIZE_N_u32 = BLOCK_SIZE_N // 2 + + M, N = input.shape + assert ( + N % quant_block_size == 0 + ), f"N ({N}) must be multiple of quant_block_size ({quant_block_size})" + assert block_size % 32 == 0, "block_size must be multiple of 32" + + N_blocks = triton.cdiv(N, block_size) + + if quant_dtype == dtypes.fp8: + DTYPE_MAX = 448.0 + DTYPE_MIN = -448.0 + elif quant_dtype == torch.float8_e4m3fn: + DTYPE_MAX = 448.0 + DTYPE_MIN = -448.0 + else: + DTYPE_MAX = 448.0 + DTYPE_MIN = -448.0 + + x_fp8 = torch.empty_like(input, dtype=quant_dtype, device="cuda") + M_o, N_o = sorted_ids.shape[0], N_blocks + + # [M_sorted_blocks/2, N_blocks/2, BLOCK_SIZE_N_u32, BLOCK_SIZE_M_u32] + scale_e8m0_packed = torch.empty( + ( + triton.cdiv(M_o, BLOCK_SIZE_M), + triton.cdiv(N_o, BLOCK_SIZE_N), + BLOCK_SIZE_N_u32, + BLOCK_SIZE_M_u32, + ), + dtype=torch.uint32, + device=input.device, + ) + + grid = ( + triton.cdiv(M_o, BLOCK_SIZE_M), # 32 + triton.cdiv(N_o, BLOCK_SIZE_N), # 8 + ) + + _fused_quant_fp8_sort_kernel[grid]( + input, + sorted_ids, + num_valid_ids, + x_fp8, + scale_e8m0_packed, + *input.stride(), + *x_fp8.stride(), + *scale_e8m0_packed.stride(), + M_input=M, + N_input=N, + N_scale_cols=N_blocks, + token_num=token_num, + BLOCK_SIZE_M=BLOCK_SIZE_M // 2, + BLOCK_SIZE_N=BLOCK_SIZE_N // 2, + QUANT_BLOCK_SIZE=32, + TOPK=M // token_num, + DTYPE_MAX=DTYPE_MAX, + DTYPE_MIN=DTYPE_MIN, + ) + + return 
x_fp8, scale_e8m0_packed.view(dtypes.fp8_e8m0).view(-1, N_o) diff --git a/aiter/ops/triton/gemm_a16w16_atomic.py b/aiter/ops/triton/gemm_a16w16_atomic.py index 78026c80f0..38341b3efb 100644 --- a/aiter/ops/triton/gemm_a16w16_atomic.py +++ b/aiter/ops/triton/gemm_a16w16_atomic.py @@ -11,34 +11,34 @@ _get_config, ) from aiter.ops.triton.utils.logger import AiterTritonLogger +from aiter.ops.triton.utils.common_utils import serialize_dict, deserialize_str +from aiter.jit.utils.torch_guard import torch_compile_guard _LOGGER = AiterTritonLogger() -def gemm_a16w16_atomic( - x, - w, - dtype: Optional[float] = torch.bfloat16, +def gemm_a16w16_atomic_fake_tensor( + x: torch.Tensor, + w: torch.Tensor, + dtype: Optional[torch.dtype] = torch.bfloat16, y: Optional[torch.Tensor] = None, - config: Optional[dict] = None, -): - """ - Computes 16 bit matrix multiplication Y = X @ W^T using atomic operations for split-K reduction. - - Args: - x (torch.Tensor): Input matrix with shape (M, K). - w (torch.Tensor): Weight matrix with shape (N, K), internally transposed. - dtype (Optional[torch.dtype]): Output datatype (BF16 or FP16). - Note: BF16 atomic aggregation may have slight precision loss. - y (Optional[torch.Tensor]): Pre-allocated output tensor with shape (M, N). - Must be zero-initialized for split-K (NUM_KSPLIT > 1). - config (Optional[dict]): Kernel tuning parameters (BLOCK_SIZE_M, BLOCK_SIZE_N, - BLOCK_SIZE_K, GROUP_SIZE_M, NUM_KSPLIT, cache_modifier). + config: Optional[str] = None, +) -> torch.Tensor: + if y is None: + M, _ = x.shape + _, N = w.shape + return torch.zeros((M, N), dtype=dtype, device=x.device) + return y - Returns: - torch.Tensor: Output with shape (M, N). - """ +@torch_compile_guard(gen_fake=gemm_a16w16_atomic_fake_tensor) +def gemm_a16w16_atomic_( + x: torch.Tensor, + w: torch.Tensor, + dtype: Optional[torch.dtype] = torch.bfloat16, + y: Optional[torch.Tensor] = None, + config: Optional[str] = None, +) -> torch.Tensor: _LOGGER.info( f"GEMM_A16W16_ATOMIC: x.shape={tuple(x.shape)}, w.shape={tuple(w.shape)} " ) @@ -50,6 +50,9 @@ def gemm_a16w16_atomic( if config is None: config = _get_config(M, N, K) + else: + config = deserialize_str(config) + # For compatability reasons, these keys may not exist in the config # TODO: This needs to be embedded in the configs later if "NUM_KSPLIT" not in config: @@ -89,3 +92,30 @@ def gemm_a16w16_atomic( ) return y + + +def gemm_a16w16_atomic( + x: torch.Tensor, + w: torch.Tensor, + dtype: Optional[torch.dtype] = torch.bfloat16, + y: Optional[torch.Tensor] = None, + config: Optional[dict] = None, +) -> torch.Tensor: + """ + Computes 16 bit matrix multiplication Y = X @ W^T using atomic operations for split-K reduction. + + Args: + x (torch.Tensor): Input matrix with shape (M, K). + w (torch.Tensor): Weight matrix with shape (N, K), internally transposed. + dtype (Optional[torch.dtype]): Output datatype (BF16 or FP16). + Note: BF16 atomic aggregation may have slight precision loss. + y (Optional[torch.Tensor]): Pre-allocated output tensor with shape (M, N). + Must be zero-initialized for split-K (NUM_KSPLIT > 1). + config (Optional[dict]): Kernel tuning parameters (BLOCK_SIZE_M, BLOCK_SIZE_N, + BLOCK_SIZE_K, GROUP_SIZE_M, NUM_KSPLIT, cache_modifier). + + Returns: + torch.Tensor: Output with shape (M, N). 
+ """ + config_hashable = serialize_dict(config) if config else None + return gemm_a16w16_atomic_(x, w, dtype, y, config_hashable) diff --git a/aiter/ops/triton/gemm_a16wfp4.py b/aiter/ops/triton/gemm_a16wfp4.py index 40744fba68..2bc0983119 100644 --- a/aiter/ops/triton/gemm_a16wfp4.py +++ b/aiter/ops/triton/gemm_a16wfp4.py @@ -4,10 +4,9 @@ from typing import Optional import torch import triton -import triton.language as tl import aiter.ops.triton.utils._triton.arch_info as arch_info -from aiter.ops.triton.quant import _mxfp4_quant_op from aiter.ops.triton.utils.logger import AiterTritonLogger +from aiter.ops.triton.utils.common_utils import deserialize_str from aiter.ops.triton._triton_kernels.gemm_a16wfp4 import ( _gemm_a16wfp4_kernel, _get_config, @@ -18,20 +17,38 @@ from aiter.ops.triton.gemm_afp4wfp4 import ( get_splitk, ) +from aiter.jit.utils.torch_guard import torch_compile_guard _LOGGER = AiterTritonLogger() +def gemm_a16wfp4_fake_tensor( + x: torch.Tensor, + w: torch.Tensor, + w_scales: torch.Tensor, + atomic_add: bool = False, + dtype: Optional[torch.dtype] = torch.bfloat16, + y: Optional[torch.Tensor] = None, + config: Optional[str] = None, +) -> torch.Tensor: + if y is None: + M, _ = x.shape + N, _ = w.shape + return torch.zeros((M, N), dtype=dtype, device=x.device) + return y + + +@torch_compile_guard(gen_fake=gemm_a16wfp4_fake_tensor) def gemm_a16wfp4( - x, - w, - w_scales, + x: torch.Tensor, + w: torch.Tensor, + w_scales: torch.Tensor, atomic_add: bool = False, - dtype: Optional[float] = torch.bfloat16, + dtype: Optional[torch.dtype] = torch.bfloat16, y: Optional[torch.Tensor] = None, - config: Optional[dict] = None, -): + config: Optional[str] = None, +) -> torch.Tensor: """ Computes the matmul Y = X x W W is an e2m1 fp4 tensor and w_scales is an e8m0 tensor. 
@@ -62,6 +79,8 @@ def gemm_a16wfp4( if config is None: config = _get_config(M, N, K) + else: + config = deserialize_str(config) if y is None: if atomic_add: diff --git a/aiter/ops/triton/gemm_afp4wfp4.py b/aiter/ops/triton/gemm_afp4wfp4.py index dec0560f3f..1085dd5d12 100644 --- a/aiter/ops/triton/gemm_afp4wfp4.py +++ b/aiter/ops/triton/gemm_afp4wfp4.py @@ -4,17 +4,17 @@ from typing import Optional import torch import triton -import triton.language as tl import aiter.ops.triton.utils._triton.arch_info as arch_info from aiter.ops.triton.utils.logger import AiterTritonLogger +from aiter.ops.triton.utils.common_utils import serialize_dict, deserialize_str from aiter.ops.triton._triton_kernels.gemm_afp4wfp4 import ( _gemm_afp4wfp4_kernel, - _gemm_afp4wfp4_kernel_preshuffle_scales, _gemm_afp4wfp4_preshuffle_kernel, _gemm_afp4wfp4_reduce_kernel, _get_config, ) from .utils.core import AITER_TRITON_CONFIGS_PATH +from aiter.jit.utils.torch_guard import torch_compile_guard import os from aiter.utility.triton.triton_metadata_redirect import AOTMetadataContext @@ -63,16 +63,34 @@ def get_splitk(K: int, BLOCK_SIZE_K: int, NUM_KSPLIT: int): return SPLITK_BLOCK_SIZE, BLOCK_SIZE_K, NUM_KSPLIT -def gemm_afp4wfp4( - x, - w, - x_scales, - w_scales, - dtype: Optional[float] = torch.bfloat16, +def gemm_afp4wfp4_fake_tensor( + x: torch.Tensor, + w: torch.Tensor, + x_scales: torch.Tensor, + w_scales: torch.Tensor, + dtype: Optional[torch.dtype] = torch.bfloat16, y: Optional[torch.Tensor] = None, - config: Optional[dict] = None, + config: Optional[str] = None, skip_reduce: Optional[bool] = False, -): +) -> torch.Tensor: + if y is None: + M, _ = x.shape + N, _ = w.shape + return torch.empty((M, N), dtype=dtype, device=x.device) + return y + + +@torch_compile_guard(gen_fake=gemm_afp4wfp4_fake_tensor) +def gemm_afp4wfp4_( + x: torch.Tensor, + w: torch.Tensor, + x_scales: torch.Tensor, + w_scales: torch.Tensor, + dtype: Optional[torch.dtype] = torch.bfloat16, + y: Optional[torch.Tensor] = None, + config: Optional[str] = None, + skip_reduce: Optional[bool] = False, +) -> torch.Tensor: """ Computes matrix multiplication Y = X @ W^T with FP4 activations and FP4 weights. @@ -91,7 +109,6 @@ def gemm_afp4wfp4( Returns: torch.Tensor: Output with shape (M, N). 
""" - _LOGGER.info( f"GEMM_AFPWFP4: x.shape={tuple(x.shape)} w.shape={tuple(w.shape)} x_scale={tuple(x_scales.shape)} w_scale={tuple(w_scales.shape)} " ) @@ -106,6 +123,8 @@ def gemm_afp4wfp4( if config is None: config = _get_config(M, N, K) + else: + config = deserialize_str(config) if config["NUM_KSPLIT"] > 1: SPLITK_BLOCK_SIZE, BLOCK_SIZE_K, NUM_KSPLIT = get_splitk( @@ -527,3 +546,16 @@ def gemm_afp4wfp4_preshuffled_weight_scales( "gemm_afp4wfp4_preshuffled_weight_scales will be deprecated in future AITER release, please switch to gemm_afp4wfp4_preshuffle" ) return gemm_afp4wfp4_preshuffle(x, w, x_scales, w_scales, dtype, y, config, use_aot) + + +def gemm_afp4wfp4( + x: torch.Tensor, + w: torch.Tensor, + x_scales: torch.Tensor, + w_scales: torch.Tensor, + dtype: Optional[torch.dtype] = torch.bfloat16, + y: Optional[torch.Tensor] = None, + config: Optional[dict] = None, +) -> torch.Tensor: + config_hashable = serialize_dict(config) if config else None + return gemm_afp4wfp4_(x, w, x_scales, w_scales, dtype, y, config_hashable) diff --git a/aiter/ops/triton/gemm_afp4wfp4_pre_quant_atomic.py b/aiter/ops/triton/gemm_afp4wfp4_pre_quant_atomic.py index d3738fd4aa..2d5cbe3e32 100644 --- a/aiter/ops/triton/gemm_afp4wfp4_pre_quant_atomic.py +++ b/aiter/ops/triton/gemm_afp4wfp4_pre_quant_atomic.py @@ -3,9 +3,8 @@ from typing import Optional import torch -import triton -import triton.language as tl from aiter.ops.triton.utils.logger import AiterTritonLogger +from aiter.ops.triton.utils.common_utils import serialize_dict from aiter.ops.triton.gemm_a16wfp4 import ( gemm_a16wfp4, ) @@ -14,9 +13,9 @@ def gemm_afp4wfp4_pre_quant( - x, - w, - w_scales, + x: torch.Tensor, + w: torch.Tensor, + w_scales: torch.Tensor, dtype: Optional[float] = torch.bfloat16, y: Optional[torch.Tensor] = None, config: Optional[dict] = None, @@ -24,4 +23,6 @@ def gemm_afp4wfp4_pre_quant( _LOGGER.info( "gemm_afp4wfp4_pre_quant will be deprecated in future AITER release, please switch to gemm_a16wfp4" ) - return gemm_a16wfp4(x, w, w_scales, True, dtype, y, config) + + config_hashable = serialize_dict(config) if config else None + return gemm_a16wfp4(x, w, w_scales, True, dtype, y, config_hashable) diff --git a/aiter/ops/triton/gluon/pa_decode_gluon.py b/aiter/ops/triton/gluon/pa_decode_gluon.py index 37e7ddd268..8c6194cc97 100644 --- a/aiter/ops/triton/gluon/pa_decode_gluon.py +++ b/aiter/ops/triton/gluon/pa_decode_gluon.py @@ -620,7 +620,6 @@ def paged_attention_decode_v2_gluon_large_block_dot_kernel( warps_per_cta=[4, 1], order=[1, 0], ) - shared_query_layout: gl.constexpr = gl.SwizzledSharedLayout(8, 1, 16, order=[1, 0]) # Key cache layout - optimized for CDNA3 architecture blocked_key_layout: gl.constexpr = gl.BlockedLayout( @@ -798,9 +797,6 @@ def paged_attention_decode_v2_gluon_large_block_dot_kernel( query_tensor = gl.amd.cdna3.buffer_load( ptr=query_ptr, offsets=query_offsets_base, mask=query_mask ) - query_shared = gl.allocate_shared_memory( - query_tensor.dtype, query_tensor.shape, shared_query_layout, query_tensor - ) # ==================== Query Quantization Scale Handling ==================== if QUERY_QUANT_MODE == 0: @@ -969,7 +965,6 @@ def paged_attention_decode_v2_gluon_large_block_dot_kernel( # Convert layouts for MFMA operation query_converted = gl.convert_layout(query_tensor, layout=qk_lhs_layout) - # query_converted = query_shared.load(qk_lhs_layout) key_converted = gl.convert_layout(key_block, layout=qk_rhs_layout) query_converted = query_converted.to(COMPUTE_TYPE) key_converted = 
key_converted.to(COMPUTE_TYPE) @@ -1463,6 +1458,7 @@ def paged_attention_decode_sliding_window( * stride_output_head + output_head_size_offsets[None, :] ) + max_logits = gl.full( (QUERY_GROUP_SIZE_POW2,), float("-inf"), @@ -1481,12 +1477,15 @@ def paged_attention_decode_sliding_window( # ==================== SEQUENCE PROCESSING ==================== query_converted = query_shared.load(qk_lhs_operand_layout) - # query_converted = gl.convert_layout(query_tensor, layout=qk_lhs_operand_layout) - sequence_partition_start_idx = ( - context_length - SLIDING_WINDOW - ) // CONTEXT_PARTITION_SIZE + + if SLIDING_WINDOW > 0: + sequence_partition_start_idx = ( + context_length - SLIDING_WINDOW + ) // CONTEXT_PARTITION_SIZE + else: + sequence_partition_start_idx = 0 sequence_partition_end_idx = gl.cdiv(context_length, CONTEXT_PARTITION_SIZE) - # num_iterations = sequence_partition_end_idx - sequence_partition_start_idx + if QUERY_QUANT_MODE < 0 and COMPUTE_TYPE.is_fp8(): # Quantize bf16 query to fp8 # Convert query to float32 for computation @@ -1524,11 +1523,11 @@ def paged_attention_decode_sliding_window( ) # Create mask for valid blocks valid_block_mask = block_indices < num_kv_blocks - # masked_block_indices = gl.where(valid_block_mask, block_indices, 0) + masked_block_indices = gl.where(valid_block_mask, block_indices, 0) block_table_start_ptr = block_tables_ptr + sequence_idx * stride_block_table_seq kv_block_numbers = gl.amd.cdna3.buffer_load( - ptr=block_table_start_ptr + kv_block_start_idx, offsets=block_indices - ).to(gl.uint32) + ptr=block_table_start_ptr + kv_block_start_idx, offsets=masked_block_indices + ).to(gl.int64) # ==================== KEY LOADING AND PROCESSING ==================== # Calculate key cache offsets and load keys @@ -1540,20 +1539,15 @@ def paged_attention_decode_sliding_window( * CONTIGUOUS_KV_ELEMENTS_PER_16B_LOAD + contiguous_kv_element_offsets[None, None, None, :] ) - # Optimize: Start key load, then prepare QK MFMA accumulators/query (overlaps with key load) - key_tensor = gl.amd.cdna3.buffer_load( - ptr=key_cache_ptr, - offsets=key_block_offsets, - mask=valid_block_mask[:, None, None, None], - ) + # Optimize: Start key load, then prepare QK MFMA accumulators/query (overlaps with key load) + key_tensor = gl.load(key_cache_ptr + key_block_offsets) # Prepare QK MFMA while key loads (these don't depend on key data) qk_accumulator = gl.zeros( (QUERY_GROUP_SIZE_POW2, CONTEXT_PARTITION_SIZE), dtype=gl.float32, layout=qk_mfma_layout, ) - # Load key quantization scales if needed (overlaps with key tensor load) if KV_QUANT_MODE >= 0: if KV_QUANT_MODE == 0: @@ -1622,11 +1616,7 @@ def paged_attention_decode_sliding_window( * CONTIGUOUS_KV_ELEMENTS_PER_16B_LOAD + value_dim3_offsets[None, None, None, :] ) - value_tensor = gl.amd.cdna3.buffer_load( - ptr=value_cache_ptr, - offsets=value_block_offsets, - mask=valid_block_mask[:, None, None, None], - ) + value_tensor = gl.load(value_cache_ptr + value_block_offsets) # Compute QK attention scores using MFMA (overlaps with value load) attention_scores = gl.amd.cdna3.mfma( query_converted, key_converted, qk_accumulator @@ -1655,11 +1645,7 @@ def paged_attention_decode_sliding_window( ) # Schedule: Start value VMEM load, then QK MFMA - value_tensor = gl.amd.cdna3.buffer_load( - ptr=value_cache_ptr, - offsets=value_block_offsets, - mask=valid_block_mask[:, None, None], - ) + value_tensor = gl.load(value_cache_ptr + value_block_offsets) # Compute QK attention scores using MFMA (overlaps with value load) attention_scores = 
gl.amd.cdna3.mfma( query_converted, key_converted, qk_accumulator @@ -1790,8 +1776,6 @@ def paged_attention_decode_sliding_window( attention_accumulator += attention_output max_logits = new_max_logits - # ==================== OUTPUT NORMALIZATION AND STORING ==================== - # Normalize attention output by softmax denominator if sinks_ptr is not None: sinks_values = gl.load( sinks_ptr + (kv_head_idx * query_group_size + query_group_offsets), @@ -1800,6 +1784,8 @@ def paged_attention_decode_sliding_window( exp_sums += gl.exp( gl.convert_layout(sinks_values, layout=max_logits.type.layout) - max_logits ) + # ==================== OUTPUT NORMALIZATION AND STORING ==================== + # Normalize attention output by softmax denominator exp_sums_reciprocal = 1.0 / exp_sums exp_sums_reciprocal_cvt = gl.convert_layout( @@ -1945,11 +1931,8 @@ def paged_attention_decode_v2_gluon_dot_kernel( else: OUTPUT_DTYPE: gl.constexpr = COMPUTE_TYPE LOG2_E: gl.constexpr = 1.4426950408889634 # log2(e) for exponential conversion - CONTIGUOUS_KV_ELEMENTS_PER_16B_LOAD: gl.constexpr = KV_16B_ELEMENT_COUNT - K_HEAD_SIZE_SPLITS: gl.constexpr = ( - HEAD_SIZE_POW2 // CONTIGUOUS_KV_ELEMENTS_PER_16B_LOAD - ) + K_HEAD_SIZE_SPLITS: gl.constexpr = HEAD_SIZE_POW2 // KV_16B_ELEMENT_COUNT MAX_NUM_KV_BLOCKS_PER_COMPUTE: gl.constexpr = KV_COMPUTE_BLOCK_SIZE // KV_BLOCK_SIZE # ==================== MEMORY LAYOUT DEFINITIONS ==================== @@ -1960,16 +1943,31 @@ def paged_attention_decode_v2_gluon_dot_kernel( warps_per_cta=[4, 1], order=[1, 0], ) - shared_query_layout: gl.constexpr = gl.SwizzledSharedLayout(8, 1, 16, order=[1, 0]) + shared_query_layout: gl.constexpr = gl.SwizzledSharedLayout( + KV_16B_ELEMENT_COUNT, 1, 16, order=[1, 0] + ) # Key cache layout - optimized for block-wise access patterns - blocked_key_layout: gl.constexpr = gl.BlockedLayout( - size_per_thread=[1, 1, 1, CONTIGUOUS_KV_ELEMENTS_PER_16B_LOAD], + blocked_key_layout_fp8: gl.constexpr = gl.BlockedLayout( + size_per_thread=[1, 1, 1, KV_16B_ELEMENT_COUNT], threads_per_warp=[1, 4, 16, 1], warps_per_cta=[4, 1, 1, 1], order=[3, 2, 1, 0], ) + key_warps_per_cta_f16: gl.constexpr = ( + [4, 1, 1, 1] if KV_BLOCK_SIZE == 16 else [1, 1, 4, 1] + ) + blocked_key_layout_f16: gl.constexpr = gl.BlockedLayout( + size_per_thread=[1, 1, 1, KV_16B_ELEMENT_COUNT], + threads_per_warp=[1, 4, 16, 1], + warps_per_cta=key_warps_per_cta_f16, + order=[3, 2, 1, 0], + ) + blocked_key_layout: gl.constexpr = ( + blocked_key_layout_fp8 if KV_16B_ELEMENT_COUNT == 16 else blocked_key_layout_f16 + ) + DOT_QK_K_WIDTH: gl.constexpr = KV_16B_ELEMENT_COUNT # QK Matrix multiplication layout using AMD MFMA instructions qk_mfma_layout: gl.constexpr = gl.amd.AMDMFMALayout( version=CDNA_VERSION, @@ -1978,10 +1976,10 @@ def paged_attention_decode_v2_gluon_dot_kernel( warps_per_cta=[1, 4], ) qk_lhs_operand_layout: gl.constexpr = gl.DotOperandLayout( - operand_index=0, parent=qk_mfma_layout, k_width=16 + operand_index=0, parent=qk_mfma_layout, k_width=DOT_QK_K_WIDTH ) qk_rhs_operand_layout: gl.constexpr = gl.DotOperandLayout( - operand_index=1, parent=qk_mfma_layout, k_width=16 + operand_index=1, parent=qk_mfma_layout, k_width=DOT_QK_K_WIDTH ) # Register allocation configuration based on group size and compute block size @@ -2020,15 +2018,29 @@ def paged_attention_decode_v2_gluon_dot_kernel( # Value cache layout configuration based on transpose flag if VALUE_TRANSPOSED: # Transposed value layout for better memory access patterns - blocked_value_layout: gl.constexpr = gl.BlockedLayout( - 
size_per_thread=[1, 1, 1, CONTIGUOUS_KV_ELEMENTS_PER_16B_LOAD], - threads_per_warp=[4, 1, 16, 1], + value_threads_per_warp: gl.constexpr = ( + [4, 1, 16, 1] if KV_BLOCK_SIZE == 16 else [1, 4, 16, 1] + ) + blocked_value_layout_f16: gl.constexpr = gl.BlockedLayout( + size_per_thread=[1, 1, 1, 8], + threads_per_warp=value_threads_per_warp, warps_per_cta=[1, 1, 4, 1], order=[3, 2, 1, 0], ) + blocked_value_layout_fp8: gl.constexpr = gl.BlockedLayout( + size_per_thread=[1, 1, 1, 16], + threads_per_warp=value_threads_per_warp, + warps_per_cta=[1, 1, 4, 1], + order=[3, 2, 1, 0], + ) + blocked_value_layout: gl.constexpr = ( + blocked_value_layout_fp8 + if KV_16B_ELEMENT_COUNT == 16 + else blocked_value_layout_f16 + ) value_dim1_offsets = gl.arange( 0, - KV_BLOCK_SIZE // CONTIGUOUS_KV_ELEMENTS_PER_16B_LOAD, + KV_BLOCK_SIZE // KV_16B_ELEMENT_COUNT, layout=gl.SliceLayout( 0, gl.SliceLayout(2, gl.SliceLayout(3, blocked_value_layout)) ), @@ -2042,26 +2054,23 @@ def paged_attention_decode_v2_gluon_dot_kernel( ) value_dim3_offsets = gl.arange( 0, - CONTIGUOUS_KV_ELEMENTS_PER_16B_LOAD, + KV_16B_ELEMENT_COUNT, layout=gl.SliceLayout( 0, gl.SliceLayout(1, gl.SliceLayout(2, blocked_value_layout)) ), ) else: # Standard value layout + value_threads_per_warp: gl.constexpr = ( + [4, 16, 1] if KV_BLOCK_SIZE == 16 else [1, 16, 4] + ) blocked_value_layout: gl.constexpr = gl.BlockedLayout( - size_per_thread=[1, 1, CONTIGUOUS_KV_ELEMENTS_PER_16B_LOAD], - threads_per_warp=[4, 16, 1], + size_per_thread=[1, 1, 16], + threads_per_warp=value_threads_per_warp, warps_per_cta=[1, 4, 1], order=[2, 1, 0], ) - # blocked_value_layout: gl.constexpr = gl.DistributedLinearLayout( - # reg_bases=((0,0,1), (0,0,2), (0,0,4), (0,0,8), (4,0,0), (8,0,0), (0,64,0)), - # lane_bases=((0,1,0), (0,2,0), (0,4,0), (0,8,0), (1,0,0), (2,0,0)), - # warp_bases=((0,16,0), (0,32,0)), - # block_bases=[], - # shape=[16, 128, 16], - # ) + value_dim1_offsets = gl.arange( 0, HEAD_SIZE_POW2, @@ -2117,7 +2126,7 @@ def paged_attention_decode_v2_gluon_dot_kernel( ) block_element_offsets = gl.arange(0, KV_BLOCK_SIZE, layout=block_element_layout) contiguous_kv_element_offsets = gl.arange( - 0, CONTIGUOUS_KV_ELEMENTS_PER_16B_LOAD, layout=contiguous_kv_elements_layout + 0, KV_16B_ELEMENT_COUNT, layout=contiguous_kv_elements_layout ) qk_row_offsets = gl.arange( 0, QUERY_GROUP_SIZE_POW2, layout=gl.SliceLayout(1, qk_linear_layout) @@ -2249,8 +2258,7 @@ def paged_attention_decode_v2_gluon_dot_kernel( kv_block_numbers[:, None, None, None] * stride_key_block + kv_head_idx * stride_key_head + head_size_split_offsets[None, :, None, None] * stride_key_head_split - + block_element_offsets[None, None, :, None] - * CONTIGUOUS_KV_ELEMENTS_PER_16B_LOAD + + block_element_offsets[None, None, :, None] * KV_16B_ELEMENT_COUNT + contiguous_kv_element_offsets[None, None, None, :] ) key_tensor = gl.load(key_cache_ptr + key_block_offsets) @@ -2281,6 +2289,39 @@ def paged_attention_decode_v2_gluon_dot_kernel( key_tensor = gl.permute(key_tensor, [1, 3, 0, 2]) key_tensor = gl.reshape(key_tensor, [HEAD_SIZE_POW2, KV_COMPUTE_BLOCK_SIZE]) + # ==================== ATTENTION SCORE COMPUTATION ==================== + # Initialize QK accumulator + qk_accumulator = gl.zeros( + (QUERY_GROUP_SIZE_POW2, KV_COMPUTE_BLOCK_SIZE), + dtype=gl.float32, + layout=qk_mfma_layout, + ) + + # if sequence_idx == 0 \ + # and kv_head_idx == 0 \ + # and sequence_partition_idx == 0: + # print("query_tensor=", query_tensor.to(tl.float32)) + # print("key_tensor=", key_tensor.to(tl.float32)) + # if QUERY_QUANT_MODE == 0 
and KV_QUANT_MODE == 0: + # print("QKV_per_tensor") + # else: + # print("QKV_per_token") + + # Convert layouts for MFMA operation + query_converted = query_shared.load(qk_lhs_operand_layout) + key_converted = gl.convert_layout(key_tensor, layout=qk_rhs_operand_layout) + + query_converted = query_converted.to(COMPUTE_TYPE) + key_converted = key_converted.to(COMPUTE_TYPE) + + # Compute QK attention scores using MFMA + attention_scores = gl.amd.cdna3.mfma( + query_converted, key_converted, qk_accumulator + ) + attention_scores = gl.reshape( + attention_scores, [QUERY_GROUP_SIZE_POW2, KV_COMPUTE_BLOCK_SIZE] + ) + # ==================== VALUE LOADING AND PROCESSING ==================== if VALUE_TRANSPOSED: # Load values from transposed cache layout @@ -2294,8 +2335,7 @@ def paged_attention_decode_v2_gluon_dot_kernel( kv_block_numbers_reshaped[:, None, None, None] * stride_value_block + kv_head_idx * stride_value_head + value_dim1_offsets[None, :, None, None] * stride_value_head_size - + value_dim2_offsets[None, None, :, None] - * CONTIGUOUS_KV_ELEMENTS_PER_16B_LOAD + + value_dim2_offsets[None, None, :, None] * KV_16B_ELEMENT_COUNT + value_dim3_offsets[None, None, None, :] ) value_tensor = gl.load(value_cache_ptr + value_block_offsets) @@ -2323,29 +2363,6 @@ def paged_attention_decode_v2_gluon_dot_kernel( value_tensor, [KV_COMPUTE_BLOCK_SIZE, HEAD_SIZE_POW2] ) - # ==================== ATTENTION SCORE COMPUTATION ==================== - # Initialize QK accumulator - qk_accumulator = gl.zeros( - (QUERY_GROUP_SIZE_POW2, KV_COMPUTE_BLOCK_SIZE), - dtype=gl.float32, - layout=qk_mfma_layout, - ) - - # Convert layouts for MFMA operation - query_converted = query_shared.load(qk_lhs_operand_layout) - key_converted = gl.convert_layout(key_tensor, layout=qk_rhs_operand_layout) - - query_converted = query_converted.to(COMPUTE_TYPE) - key_converted = key_converted.to(COMPUTE_TYPE) - - # Compute QK attention scores using MFMA - attention_scores = gl.amd.cdna3.mfma( - query_converted, key_converted, qk_accumulator - ) - attention_scores = gl.reshape( - attention_scores, [QUERY_GROUP_SIZE_POW2, KV_COMPUTE_BLOCK_SIZE] - ) - # Apply quantization scaling to attention scores if KV_QUANT_MODE >= 0: if KV_QUANT_MODE == 1: @@ -2533,8 +2550,6 @@ def paged_attention_decode_v2_reduce_kernel( Various stride parameters for tensor access Compile-time constants for kernel configuration (no MAX_CONTEXT_PARTITION_NUM needed) """ - # Mathematical constant for exponential calculations - LOG2_E: tl.constexpr = 1.4426950408889634 MAX_CONTEXT_PARTITION_NUM: tl.constexpr = 16 # ==================== INITIALIZATION ==================== @@ -2549,6 +2564,13 @@ def paged_attention_decode_v2_reduce_kernel( head_size_offsets = tl.arange(0, HEAD_SIZE_POW2) # Initialize global accumulation variables + # if USE_SINKS: + # global_max = tl.load( + # sink_token_ptr + (kv_head_idx * query_group_size + query_group_offsets), + # mask=query_group_offsets < query_group_size, + # other=float("-inf"), + # ).to(tl.float32) + # else: global_max = tl.full((QUERY_GROUP_SIZE_POW2,), float("-inf"), dtype=tl.float32) global_max_prev = global_max global_exp_sum = tl.zeros((QUERY_GROUP_SIZE_POW2,), dtype=tl.float32) @@ -2602,7 +2624,6 @@ def paged_attention_decode_v2_reduce_kernel( mask=query_group_offsets < query_group_size, ) global_exp_sum += gl.exp(sink_token_values - global_max) - # ==================== SECOND PASS: COMPUTE RESCALED EXP SUMS AND ACCUMULATE ==================== for iter_idx in range(num_iterations): partition_base = iter_idx * 
MAX_CONTEXT_PARTITION_NUM @@ -2752,10 +2773,9 @@ def _paged_attention_decode_v2_with_dot_kernel_reshape_wrapper( parameters for Triton compilation and execution. """ HEAD_SIZE_POW2 = triton.next_power_of_2(HEAD_SIZE) - # Production path - select and launch appropriate kernel + waves_per_eu = 1 QUERY_GROUP_SIZE = QUERY_SEQ_LEN * QUERY_GROUP_SIZE_ORIGINAL KV_COMPUTE_BLOCK_SIZE = CONTEXT_PARTITION_SIZE - waves_per_eu = 2 if QUERY_GROUP_SIZE < 16: QUERY_GROUP_SIZE_POW2 = 16 else: @@ -2972,6 +2992,7 @@ def pa_decode_gluon( alibi_slopes: torch.Tensor = None, sinks: torch.Tensor = None, sliding_window: int = 0, + one_shot=None, ) -> None: """ Paged Attention Decode with FP8/BF16/FP16 Support. @@ -3263,7 +3284,8 @@ def pa_decode_gluon( fp8_max_value = torch.finfo(aiter.dtypes.fp8).max # ==================== ATTENTION DECODE KERNEL EXECUTION ==================== - one_shot = sliding_window > 0 + if one_shot is None: + one_shot = sliding_window > 0 _paged_attention_decode_v2_with_dot_kernel_reshape_wrapper( grid, exp_sums, diff --git a/aiter/ops/triton/utils/common_utils.py b/aiter/ops/triton/utils/common_utils.py index 2da76efe38..4729ccfdb3 100644 --- a/aiter/ops/triton/utils/common_utils.py +++ b/aiter/ops/triton/utils/common_utils.py @@ -5,6 +5,7 @@ import torch import triton +import json def prev_power_of_2(x: int) -> int: @@ -34,3 +35,11 @@ def switch_to_contiguous_if_needed(x: torch.Tensor) -> torch.Tensor: if x.stride(-1) == 1: return x return x.contiguous() + + +def serialize_dict(d: dict) -> str: + return json.dumps(d) + + +def deserialize_str(s: str) -> dict: + return json.loads(s) diff --git a/aiter/tuned_gemm.py b/aiter/tuned_gemm.py index c7c8b5994f..4465facd34 100644 --- a/aiter/tuned_gemm.py +++ b/aiter/tuned_gemm.py @@ -24,7 +24,14 @@ import torch.nn.functional as F from torch import Tensor -from aiter import dtypes, gemm_a16w16_asm, hipb_create_extension, hipb_mm, logger +from aiter import ( + dtypes, + gemm_a16w16_asm, + get_semaphore_workspace, + hipb_create_extension, + hipb_mm, + logger, +) from aiter.jit.core import AITER_CONFIGS, AITER_LOG_TUNED_CONFIG from aiter.jit.utils.chip_info import get_cu_num, get_gfx from aiter.jit.utils.torch_guard import torch_compile_guard @@ -392,7 +399,10 @@ def asm_gemm( out_asm = torch.empty( inp.shape[0], weights.shape[0], dtype=otype, device=inp.device ) - return gemm_a16w16_asm(inp, weights, out_asm, bias, splitK, KernelName, bpreshuffle) + sema = get_semaphore_workspace(out_asm.device) + return gemm_a16w16_asm( + inp, weights, out_asm, sema, bias, splitK, KernelName, bpreshuffle + ) def triton_gemm( diff --git a/csrc/ck_gemm_a4w4_blockscale/gemm_a4w4_blockscale_tune.py b/csrc/ck_gemm_a4w4_blockscale/gemm_a4w4_blockscale_tune.py index f98dac5d85..4ca15dfe02 100755 --- a/csrc/ck_gemm_a4w4_blockscale/gemm_a4w4_blockscale_tune.py +++ b/csrc/ck_gemm_a4w4_blockscale/gemm_a4w4_blockscale_tune.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: MIT # Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
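Back in pa_decode_gluon above, one_shot is now an explicit argument: leaving it at None keeps the previous behaviour, where the single-pass path is used only when a sliding window is set, and passing a bool overrides that choice. A tiny sketch of just the resolution logic (it mirrors the added lines; the values are hypothetical):

    def resolve_one_shot(one_shot, sliding_window):
        # None -> derive from the sliding window, otherwise honour the caller's choice
        return (sliding_window > 0) if one_shot is None else one_shot

    assert resolve_one_shot(None, 0) is False      # default, no sliding window
    assert resolve_one_shot(None, 4096) is True    # a sliding window still implies one-shot
    assert resolve_one_shot(True, 0) is True       # callers can now force the one-shot kernel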
-import argparse import os import pandas as pd @@ -149,10 +148,10 @@ def get_asm_kernels(self, file): shuffle_df = ( df[df["bpreshuffle"] == 1] .reset_index() - .sort_values(by=["tile_m", "tile_n", "splitK"]) + .sort_values(by=["tile_M", "tile_N", "splitK"]) ) kernel_dict = ( - shuffle_df.groupby(["tile_m", "tile_n", "splitK"])["knl_name"] + shuffle_df.groupby(["tile_M", "tile_N", "splitK"])["knl_name"] .apply(list) .to_dict() ) diff --git a/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages.cu b/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages.cu index a5be94138e..322ea56c71 100644 --- a/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages.cu +++ b/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages.cu @@ -56,13 +56,23 @@ void ck_moe_stage1(torch::Tensor &hidden_states, // [m, k], input token std::optional block_m = 32, std::optional sorted_weights = std::nullopt, int quant_type = 0, - int activation = 0) + int activation = 0, + int splitk = 1, + std::optional dst_type = std::nullopt) { const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(out)); at::hip::getCurrentHIPStream(); - TORCH_CHECK(out.dtype() == at::ScalarType::BFloat16 || out.dtype() == at::ScalarType::Half, - "Out dtype only support BFloat16/Float16!") + if (splitk > 1) + { + TORCH_CHECK(out.dtype() == at::ScalarType::Float, + "Out dtype only support Float when splitk > 1!") + } + else + { + TORCH_CHECK(out.dtype() == at::ScalarType::BFloat16 || out.dtype() == at::ScalarType::Half, + "Out dtype only support BFloat16/Float16!") + } int tokens = hidden_states.size(0); int sorted_size = std::min(int64_t(tokens * topk * block_m.value()), sorted_token_ids.size(0)); @@ -99,7 +109,7 @@ void ck_moe_stage1(torch::Tensor &hidden_states, // [m, k], input token kernel(at::hip::getCurrentHIPStream(), tokens, sorted_size, N, K, topk, - hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); + hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr, splitk); } void ck_moe_stage2(torch::Tensor &inter_states, // [m, k], input token @@ -116,7 +126,9 @@ void ck_moe_stage2(torch::Tensor &inter_states, // [m, k], input token std::optional block_m = 32, std::optional sorted_weights = std::nullopt, int quant_type = 0, - int activation = 0) + int activation = 0, + int splitk = 1, + std::optional dst_type = std::nullopt) { TORCH_CHECK(out.dtype() == at::ScalarType::BFloat16 || out.dtype() == at::ScalarType::Half, "Out dtype only support BFloat16/Float16!") @@ -155,5 +167,5 @@ void ck_moe_stage2(torch::Tensor &inter_states, // [m, k], input token kernel(at::hip::getCurrentHIPStream(), tokens, sorted_size, N, K, topk, - inter_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w2_scale_ptr, a2_scale_ptr); + inter_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w2_scale_ptr, a2_scale_ptr, splitk); } \ No newline at end of file diff --git a/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages.h b/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages.h index 7c22ab857b..f1ef022159 100644 --- a/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages.h +++ b/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages.h @@ -335,6 +335,19 @@ struct MulABScaleExpertWeightA8W8blkscale } }; +struct 
MulABScaleExpertWeightA8W8blkscaleSplitk +{ + template + __host__ __device__ constexpr void operator()(E& e, const C& c, const D2& d2) const; + template <> + __host__ __device__ constexpr void + operator()(float& e, const float& c, const float& d2) const + { + (void)d2; + e = ck::type_convert(c); + } +}; + using MoeKernel = std::function, - std::optional)>; + std::optional, + std::optional)>; template w1_scale = std::nullopt, // [e, 1, n], gate(up) scale - std::optional a1_scale = std::nullopt // [m, 1], token scale + std::optional a1_scale = std::nullopt, // [m, 1], token scale + std::optional splitk = 1 // splitk ); template w2_scale = std::nullopt, // [e, 1, n], gate(up) scale - std::optional a2_scale = std::nullopt // [max_num_tokens_padded, 1], token scale + std::optional a2_scale = std::nullopt, // [max_num_tokens_padded, 1], token scale + std::optional splitk = 1 // splitk ); diff --git a/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common.cuh b/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common.cuh index 55cf72b4a9..705460baed 100644 --- a/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common.cuh +++ b/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common.cuh @@ -1,40 +1,44 @@ // SPDX-License-Identifier: MIT // Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once -#include "gemm_moe_ck2stages.h" -#include "ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp" #include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp" +#include "gemm_moe_ck2stages.h" #include -template < - typename A0DataType, - typename B0DataType, - typename AccDataType, - typename EDataType, - typename CDEElementOp, - PipelineVersion PipelineVer, - int BLOCKSIZE, - int MPerBlock, - int NPerBlock, - int KPerBlock, - int MWaves, - int NWaves, - bool Nswizzle, - bool PerTensorQuant, - bool MulRoutedWeight, - int ActOP> -void ck_moe_stage1_gemm(const hipStream_t &stream, int tokens, int sorted_size, int N, int K, +template +void ck_moe_stage1_gemm(const hipStream_t& stream, + int tokens, + int sorted_size, + int N, + int K, int topk, - void *&hidden_states, // [m, k], input token - void *&w1, // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w]) - void *&w2, // [expert, dim, inter_dim], pre-shuffle([e, nr, kr, w]) - void *&sorted_token_ids, // [max_num_tokens_padded] - void *&sorted_expert_ids, // [max_num_m_blocks] - void *&sorted_weights, - void *&num_valid_ids, // [1] - void *&out, // [max_num_tokens_padded, inter_dim] - std::optional w1_scale, // [e, 1, n], gate(up) scale - std::optional a1_scale // [m, 1], token scale + void*& hidden_states, // [m, k], input token + void*& w1, // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w]) + void*& w2, // [expert, dim, inter_dim], pre-shuffle([e, nr, kr, w]) + void*& sorted_token_ids, // [max_num_tokens_padded] + void*& sorted_expert_ids, // [max_num_m_blocks] + void*& sorted_weights, + void*& num_valid_ids, // [1] + void*& out, // [max_num_tokens_padded, inter_dim] + std::optional w1_scale, // [e, 1, n], gate(up) scale + std::optional a1_scale, // [m, 1], token scale + std::optional splitk // splitk ) { // ~~~~~~~~~~~~~~~~~~~~~~~~following start with ck things @@ -42,43 +46,47 @@ void ck_moe_stage1_gemm(const hipStream_t &stream, int tokens, int sorted_size, ck::index_t StrideB = K; ck::index_t StrideD = 0; ck::index_t StrideE = N; - ck::index_t KBatch = 1; + ck::index_t KBatch = 1; // using AccDataType = F32; 
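The split-K plumbing above threads an optional `splitk` argument through both MoE GEMM stages and, when `splitk > 1`, requires a float32 stage-1 output so partial K-slices can be accumulated before the final cast; the blockscale path further down also checks that K is an exact multiple of KPerBlock * splitk. A minimal Python sketch of that K-batch bookkeeping is shown below; the helper name is hypothetical and this is not AITER code.

```python
# Illustrative only: mirrors the KBatch arithmetic and divisibility check
# added in gemm_moe_ck2stages_common_blockscale.cuh, not the actual kernel.
def stage1_splitk_kbatch(K: int, k_per_block: int, splitk: int) -> int:
    """Number of K blocks each split owns when the K dimension is split."""
    if splitk <= 1:
        return 1
    k_batch = K // (splitk * k_per_block)
    if k_batch * k_per_block * splitk != K:
        raise ValueError(
            f"K({K}) must be a multiple of KPerBlock({k_per_block}) * splitk({splitk})"
        )
    return k_batch

# Example: K=7168, KPerBlock=128, splitk=4 -> each split covers 14 K blocks.
assert stage1_splitk_kbatch(7168, 128, 4) == 14
```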
using CShuffleDataType = F32; - using DsDataType = ck::Tuple; + using DsDataType = ck::Tuple; using A0Layout = Row; using B0Layout = Col; using D0Layout = Row; using D1Layout = Col; - using ELayout = Row; + using ELayout = Row; using D2Layout = ELayout; using DsLayout = ck::Tuple; using PassThrough = ck::tensor_operation::element_wise::PassThrough; - using AElementOp = PassThrough; - using BElementOp = PassThrough; + using AElementOp = PassThrough; + using BElementOp = PassThrough; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; - static constexpr ck::index_t MNPerXDL = 16; - static constexpr ck::index_t WAVES = BLOCKSIZE / 64; + static constexpr ck::index_t MNPerXDL = 16; + static constexpr ck::index_t WAVES = BLOCKSIZE / 64; static constexpr ck::index_t MXDLPerWave = MPerBlock / (MNPerXDL * MWaves); static constexpr ck::index_t NXDLPerWave = NPerBlock / (MNPerXDL * NWaves); - // static constexpr ck::index_t NPerBlock = PipelineVer == ck::BlockGemmPipelineVersion::v1 ? 64 : 128; - static constexpr ck::index_t CShuffleMXDLPerWave = ck::is_same_v ? 2 : MXDLPerWave; - static constexpr ck::index_t CShuffleNXDLPerWave = ck::is_same_v ? 1 : NXDLPerWave; + // static constexpr ck::index_t NPerBlock = PipelineVer == ck::BlockGemmPipelineVersion::v1 ? 64 + // : 128; + static constexpr ck::index_t CShuffleMXDLPerWave = + ck::is_same_v ? 2 : MXDLPerWave; + static constexpr ck::index_t CShuffleNXDLPerWave = + ck::is_same_v ? 1 : NXDLPerWave; // Note: some fp8 instances didn't compile with AK1/BK1=16 - static constexpr ck::index_t K1 = (NPerBlock == 64 && sizeof(A0DataType) == 1 && sizeof(B0DataType) == 1) ? 8 : 16; + static constexpr ck::index_t K1 = + (PipelineVer == ck::BlockGemmPipelineVersion::v3 && NPerBlock == 64 && sizeof(A0DataType) == 1 && sizeof(B0DataType) == 1) ? 8 : 16; static constexpr ck::index_t AK1 = K1 / sizeof(A0DataType); static constexpr ck::index_t BK1 = ck::is_same_v ? 32 : K1 / sizeof(B0DataType); - static constexpr ck::index_t EVec = 16 / sizeof(EDataType); - static constexpr ck::index_t K0_A = KPerBlock / AK1; - static constexpr ck::index_t K0_B = KPerBlock / BK1; + static constexpr ck::index_t EVec = 16 / sizeof(EDataType); + static constexpr ck::index_t K0_A = KPerBlock / AK1; + static constexpr ck::index_t K0_B = KPerBlock / BK1; static constexpr ck::index_t K0_M_A = BLOCKSIZE / K0_A; static constexpr ck::index_t K0_N_B = BLOCKSIZE / K0_B; - static constexpr ck::index_t D0Vec = 1; - static constexpr ck::index_t D1Vec = PerTensorQuant ? 1 : EVec; - static constexpr ck::index_t D2Vec = 1; + static constexpr ck::index_t D0Vec = 1; + static constexpr ck::index_t D1Vec = PerTensorQuant ? 
1 : EVec; + static constexpr ck::index_t D2Vec = 1; using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm // clang-format off @@ -88,7 +96,7 @@ void ck_moe_stage1_gemm(const hipStream_t &stream, int tokens, int sorted_size, ///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | S| ///###### RCR < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, - AElementOp, BElementOp, CDEElementOp, GemmSpec, + AElementOp, BElementOp, CDEElementOp, GemmSpec, BLOCKSIZE, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MNPerXDL, MNPerXDL, @@ -99,45 +107,45 @@ void ck_moe_stage1_gemm(const hipStream_t &stream, int tokens, int sorted_size, ck::BlockGemmPipelineScheduler::Intrawave, PipelineVer, ActOP, Nswizzle, true, MulRoutedWeight, !PerTensorQuant, ck::index_t, A0DataType>; // clang-format on - auto a_element_op = AElementOp{}; - auto b_element_op = BElementOp{}; + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; auto cde_element_op = CDEElementOp{}; constexpr ck::index_t NumDTensor = DsDataType::Size(); - constexpr auto I0 = ck::Number<0>{}; - constexpr auto I1 = ck::Number<1>{}; + constexpr auto I0 = ck::Number<0>{}; + constexpr auto I1 = ck::Number<1>{}; static constexpr auto DStride = PerTensorQuant ? I0 : I1; // do GEMM auto device_op = DeviceOpInstance{}; - auto invoker = device_op.MakeInvoker(); - auto argument = - device_op.MakeArgument(sorted_token_ids, - sorted_expert_ids, - num_valid_ids, - hidden_states, - w1, - std::array{a1_scale.has_value() ? a1_scale.value() : nullptr, - w1_scale.has_value() ? w1_scale.value() : nullptr, - MulRoutedWeight ? sorted_weights : nullptr}, - out, - tokens, - topk, - sorted_size, - N, - K, - StrideA, - StrideB, - std::array{DStride, DStride, I0}, - StrideE, - KBatch, - a_element_op, - b_element_op, - cde_element_op); + auto invoker = device_op.MakeInvoker(); + auto argument = device_op.MakeArgument( + sorted_token_ids, + sorted_expert_ids, + num_valid_ids, + hidden_states, + w1, + std::array{a1_scale.has_value() ? a1_scale.value() : nullptr, + w1_scale.has_value() ? w1_scale.value() : nullptr, + MulRoutedWeight ? sorted_weights : nullptr}, + out, + tokens, + topk, + sorted_size, + N, + K, + StrideA, + StrideB, + std::array{DStride, DStride, I0}, + StrideE, + KBatch, + a_element_op, + b_element_op, + cde_element_op); - if (!device_op.IsSupportedArgument(argument)) + if(!device_op.IsSupportedArgument(argument)) { throw std::runtime_error( "wrong! 
device_gemm with the specified compilation parameters does " @@ -147,51 +155,74 @@ void ck_moe_stage1_gemm(const hipStream_t &stream, int tokens, int sorted_size, invoker.Run(argument, StreamConfig{stream}); } -#define CK_MOE_STAGE1_GEMM_DEFINE(BLOCKSIZE, MPerfBlock, NPerBlock, KPerBlock, MWaves, NWaves, PipelineVer) \ - template void ck_moe_stage1_gemm( \ - const hipStream_t &stream, \ - int tokens, int sorted_size, int N, int K, \ - int topk, \ - void *&hidden_states, \ - void *&w1, \ - void *&w2, \ - void *&sorted_token_ids, \ - void *&sorted_expert_ids, \ - void *&sorted_weights, \ - void *&num_valid_ids, \ - void *&out, \ - std::optional w1_scale, \ - std::optional a1_scale); +#define CK_MOE_STAGE1_GEMM_DEFINE( \ + BLOCKSIZE, MPerfBlock, NPerBlock, KPerBlock, MWaves, NWaves, PipelineVer) \ + template void ck_moe_stage1_gemm(const hipStream_t& stream, \ + int tokens, \ + int sorted_size, \ + int N, \ + int K, \ + int topk, \ + void*& hidden_states, \ + void*& w1, \ + void*& w2, \ + void*& sorted_token_ids, \ + void*& sorted_expert_ids, \ + void*& sorted_weights, \ + void*& num_valid_ids, \ + void*& out, \ + std::optional w1_scale, \ + std::optional a1_scale, \ + std::optional splitk); -template < - typename A0DataType, - typename B0DataType, - typename AccDataType, - typename EDataType, - typename CDEElementOp, - PipelineVersion PipelineVer, - int BLOCKSIZE, - int MPerBlock, - int NPerBlock, - int KPerBlock, - int MWaves, - int NWaves, - bool Nswizzle, - bool PerTensorQuant, - bool MulRoutedWeight, - int ActOP = 0> -void ck_moe_stage2_gemm(const hipStream_t &stream, int tokens, int sorted_size, int N, int K, +template +void ck_moe_stage2_gemm(const hipStream_t& stream, + int tokens, + int sorted_size, + int N, + int K, int topk, - void *&inter_states, // [max_num_tokens_padded, k], input token - void *&w1, // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w]) - void *&w2, // [expert, dim, inter_dim], pre-shuffle([e, nr, kr, w]) - void *&sorted_token_ids, // [max_num_tokens_padded] - void *&sorted_expert_ids, // [max_num_m_blocks] - void *&sorted_weights, // [max_num_tokens_padded] - void *&num_valid_ids, //[1] - void *&out, // [m, out_dim] - std::optional w2_scale, // [e, 1, n], gate(up) scale - std::optional a2_scale // [max_num_tokens_padded, 1], token scale + void*& inter_states, // [max_num_tokens_padded, k], input token + void*& w1, // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w]) + void*& w2, // [expert, dim, inter_dim], pre-shuffle([e, nr, kr, w]) + void*& sorted_token_ids, // [max_num_tokens_padded] + void*& sorted_expert_ids, // [max_num_m_blocks] + void*& sorted_weights, // [max_num_tokens_padded] + void*& num_valid_ids, //[1] + void*& out, // [m, out_dim] + std::optional w2_scale, // [e, 1, n], gate(up) scale + std::optional a2_scale, // [max_num_tokens_padded, 1], token scale + std::optional splitk // splitk ) { // ~~~~~~~~~~~~~~~~~~~~~~~~following start with ck things @@ -199,45 +230,50 @@ void ck_moe_stage2_gemm(const hipStream_t &stream, int tokens, int sorted_size, ck::index_t StrideB = K; ck::index_t StrideD = 0; ck::index_t StrideE = N; - ck::index_t KBatch = 1; + ck::index_t KBatch = 1; // using AccDataType = F32; using CShuffleDataType = F32; - using DsDataType = ck::Tuple; + using DsDataType = ck::Tuple; using A0Layout = Row; using B0Layout = Col; - using ELayout = Row; + using ELayout = Row; using D0Layout = Row; using D1Layout = Col; using DsLayout = ck::Tuple; using PassThrough = ck::tensor_operation::element_wise::PassThrough; - using AElementOp = 
PassThrough; - using BElementOp = PassThrough; + using AElementOp = PassThrough; + using BElementOp = PassThrough; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; // static constexpr ck::index_t BLOCKSIZE = 256; - static constexpr ck::index_t WAVES = BLOCKSIZE / 64; - static constexpr ck::index_t MNPerXDL = 16; + static constexpr ck::index_t WAVES = BLOCKSIZE / 64; + static constexpr ck::index_t MNPerXDL = 16; static constexpr ck::index_t MXDLPerWave = MPerBlock / (MNPerXDL * MWaves); static constexpr ck::index_t NXDLPerWave = NPerBlock / (MNPerXDL * NWaves); - static constexpr ck::index_t CShuffleMXDLPerWave = ck::is_same_v ? 2 : MXDLPerWave; - static constexpr ck::index_t CShuffleNXDLPerWave = ck::is_same_v ? 2 : NXDLPerWave; - static constexpr ck::index_t CShuffleNLane = ck::is_same_v ? 32 : NPerBlock / 2 / NXDLPerWave; // 64 + static constexpr ck::index_t CShuffleMXDLPerWave = + ck::is_same_v ? 2 : MXDLPerWave; + static constexpr ck::index_t CShuffleNXDLPerWave = + ck::is_same_v ? 2 : NXDLPerWave; + static constexpr ck::index_t CShuffleNLane = + ck::is_same_v ? 32 : NPerBlock / 2 / NXDLPerWave; // 64 static constexpr ck::index_t CShuffleMLane = BLOCKSIZE / CShuffleNLane; // Note: some fp8 instances didn't compile with AK1/BK1=16 - static constexpr ck::index_t K1 = (KPerBlock == 64 && sizeof(A0DataType) == 1 && sizeof(B0DataType) == 1) ? 8 : 16; + static constexpr ck::index_t K1 = + (KPerBlock == 64 && sizeof(A0DataType) == 1 && sizeof(B0DataType) == 1) ? 8 : 16; static constexpr ck::index_t AK1 = K1 / sizeof(A0DataType); - static constexpr ck::index_t BK1 = ck::is_same_v ? 32 / sizeof(B0DataType) : K1 / sizeof(B0DataType); - static constexpr ck::index_t EVec = 2; + static constexpr ck::index_t BK1 = + ck::is_same_v ? 32 / sizeof(B0DataType) : K1 / sizeof(B0DataType); + static constexpr ck::index_t EVec = 2; static constexpr ck::index_t D0Vec = 1; static constexpr ck::index_t D1Vec = PerTensorQuant ? 
1 : EVec; static constexpr ck::index_t D2Vec = 1; - static constexpr ck::index_t K0_A = KPerBlock / AK1; - static constexpr ck::index_t K0_B = KPerBlock / BK1; - static constexpr ck::index_t K0_M = BLOCKSIZE / K0_A; - static constexpr ck::index_t K0_N = BLOCKSIZE / K0_B; + static constexpr ck::index_t K0_A = KPerBlock / AK1; + static constexpr ck::index_t K0_B = KPerBlock / BK1; + static constexpr ck::index_t K0_M = BLOCKSIZE / K0_A; + static constexpr ck::index_t K0_N = BLOCKSIZE / K0_B; using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm // clang-format off @@ -247,7 +283,7 @@ void ck_moe_stage2_gemm(const hipStream_t &stream, int tokens, int sorted_size, ///#####| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | S| ///##### RCR < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, - AElementOp, BElementOp, CDEElementOp, GemmSpec, + AElementOp, BElementOp, CDEElementOp, GemmSpec, BLOCKSIZE, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MNPerXDL, MNPerXDL, @@ -319,4 +355,5 @@ void ck_moe_stage2_gemm(const hipStream_t &stream, int tokens, int sorted_size, void *&num_valid_ids, \ void *&out, \ std::optional w2_scale, \ - std::optional a2_scale); \ No newline at end of file + std::optional a2_scale, \ + std::optional splitk); diff --git a/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common.py b/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common.py index f849b42d44..20618c331d 100644 --- a/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common.py +++ b/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common.py @@ -358,6 +358,7 @@ def get_gemm1_kernels_list( ActOP: str, MulRoutedWeight: bool, preshuffle: bool = False, + splitk: bool = False, ) -> list: arch = get_gfx() if Adtype in bit16_list and Bdtype in bit16_list and Adtype == Adtype: @@ -403,7 +404,10 @@ def get_gemm1_kernels_list( if tag == "a8w4": kernel.CDEElementOp = "MulABScaleWint4" elif tag == "a8w8blkscale": - kernel.CDEElementOp = "MulABScaleExpertWeightA8W8blkscale" + if splitk: + kernel.CDEElementOp = "MulABScaleExpertWeightA8W8blkscaleSplitk" + else: + kernel.CDEElementOp = "MulABScaleExpertWeightA8W8blkscale" elif tag == "a8w8" or tag == "a4w4" or tag == "a4w4_bns": kernel.CDEElementOp = "MulABScale" elif tag == "a16w16": diff --git a/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common_blockscale.cuh b/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common_blockscale.cuh index dcd6d096cc..41c918c992 100644 --- a/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common_blockscale.cuh +++ b/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common_blockscale.cuh @@ -33,7 +33,9 @@ void ck_moe_stage1_gemm(const hipStream_t &stream, int tokens, int sorted_size, void *&num_valid_ids, // [1] void *&out, // [max_num_tokens_padded, inter_dim] std::optional w1_scale, // [e, 1, n], gate(up) scale - std::optional a1_scale // [m, 1], token scale + std::optional a1_scale, // [m, 1], token scale + std::optional splitk // splitk + ) { // ~~~~~~~~~~~~~~~~~~~~~~~~following start with ck things @@ -45,8 +47,14 @@ void ck_moe_stage1_gemm(const hipStream_t &stream, int tokens, int sorted_size, ck::index_t StrideA = K; ck::index_t StrideB = K; - ck::index_t StrideE = N; - ck::index_t KBatch = 1; + ck::index_t SplitK = splitk.has_value() ? splitk.value() : 1; + + ck::index_t KBatch = SplitK > 1 ? 
K / (SplitK * KPerBlock) : 1; + if (KBatch > 1){ + TORCH_CHECK((KBatch * KPerBlock * SplitK == K), + "K(", K, ") must be a multiple of KPerBlock(", KPerBlock, ") * splitk(", splitk.value(), ").\n"); + } + ck::index_t StrideE = N * (KBatch > 1 ? 2 : 1); using A0Layout = Row; using B0Layout = Col; @@ -83,6 +91,7 @@ void ck_moe_stage1_gemm(const hipStream_t &stream, int tokens, int sorted_size, static constexpr ck::index_t Scale_Block_M = 1; static constexpr ck::index_t Scale_Block_N = 128; static constexpr ck::index_t Scale_Block_K = 128; + static constexpr bool IsSplitK = std::is_same_v; using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale // clang-format off < Row, Col, DsLayout, ELayout, @@ -96,7 +105,7 @@ void ck_moe_stage1_gemm(const hipStream_t &stream, int tokens, int sorted_size, S, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0, S, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0, MXDLPerWave, NXDLPerWave, S<1, K0_M_A, 1, K0_A>, S<2, 1, 1, 1>, - ck::BlockGemmPipelineScheduler::Intrawave, PipelineVer, ActOP, Nswizzle, true, MulRoutedWeight, int32_t, A0DataType>; + ck::BlockGemmPipelineScheduler::Intrawave, PipelineVer, ActOP, Nswizzle, true, IsSplitK, MulRoutedWeight, int32_t, A0DataType>; // clang-format on @@ -157,7 +166,8 @@ void ck_moe_stage1_gemm(const hipStream_t &stream, int tokens, int sorted_size, void *&num_valid_ids, \ void *&out, \ std::optional w1_scale, \ - std::optional a1_scale); + std::optional a1_scale, \ + std::optional splitk); template < typename A0DataType, @@ -187,7 +197,8 @@ void ck_moe_stage2_gemm(const hipStream_t &stream, int tokens, int sorted_size, void *&num_valid_ids, //[1] void *&out, // [m, out_dim] std::optional w2_scale, // [e, 1, n], gate(up) scale - std::optional a2_scale // [max_num_tokens_padded, 1], token scale + std::optional a2_scale, // [max_num_tokens_padded, 1], token scale + std::optional splitk // splitk ) { // ~~~~~~~~~~~~~~~~~~~~~~~~following start with ck things @@ -252,7 +263,7 @@ void ck_moe_stage2_gemm(const hipStream_t &stream, int tokens, int sorted_size, S, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, MXDLPerWave, NXDLPerWave, S<1, K0_M, 1, K0_A>, S<2, 1, 1, 1>, - ck::BlockGemmPipelineScheduler::Intrawave, PipelineVer, 0, false, false, MulRoutedWeight, int32_t, A0DataType>; + ck::BlockGemmPipelineScheduler::Intrawave, PipelineVer, 0, false, false, false, MulRoutedWeight, int32_t, A0DataType>; @@ -313,4 +324,5 @@ void ck_moe_stage2_gemm(const hipStream_t &stream, int tokens, int sorted_size, void *&num_valid_ids, \ void *&out, \ std::optional w2_scale, \ - std::optional a2_scale); + std::optional a2_scale, \ + std::optional splitk); diff --git a/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common_mxfp4.cuh b/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common_mxfp4.cuh index 9321f5950a..f2315b1dd0 100644 --- a/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common_mxfp4.cuh +++ b/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common_mxfp4.cuh @@ -38,7 +38,8 @@ void ck_moe_stage1_gemm(const hipStream_t& stream, void*& num_valid_ids, // [1] void*& out, // [max_num_tokens_padded, inter_dim] std::optional w1_scale, // [e, 1, n], gate(up) scale - std::optional a1_scale // [m, 1], token scale + std::optional a1_scale, // [m, 1], token scale + std::optional splitk // splitk ) { // ~~~~~~~~~~~~~~~~~~~~~~~~following start with ck things @@ -195,7 +196,8 @@ void ck_moe_stage1_gemm(const hipStream_t& stream, void*& num_valid_ids, \ void*& out, \ std::optional w1_scale, \ - 
std::optional a1_scale); + std::optional a1_scale, \ + std::optional splitk); template w2_scale, // [e, 1, n], gate(up) scale - std::optional a2_scale // [max_num_tokens_padded, 1], token scale + std::optional a2_scale, // [max_num_tokens_padded, 1], token scale + std::optional splitk // splitk ) { // ~~~~~~~~~~~~~~~~~~~~~~~~following start with ck things @@ -366,4 +369,5 @@ void ck_moe_stage2_gemm(const hipStream_t& stream, void *&num_valid_ids, \ void *&out, \ std::optional w2_scale, \ - std::optional a2_scale); + std::optional a2_scale, \ + std::optional splitk); diff --git a/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common_mxfp4_bns.cuh b/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common_mxfp4_bns.cuh index 63d7b29d33..ac9d71be29 100644 --- a/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common_mxfp4_bns.cuh +++ b/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common_mxfp4_bns.cuh @@ -37,7 +37,8 @@ void ck_moe_stage1_gemm(const hipStream_t& stream, void*& num_valid_ids, // [1] void*& out, // [max_num_tokens_padded, inter_dim] std::optional w1_scale, // [e, 1, n], gate(up) scale - std::optional a1_scale // [m, 1], token scale + std::optional a1_scale, // [m, 1], token scale + std::optional splitk // splitk ) { // ~~~~~~~~~~~~~~~~~~~~~~~~following start with ck things @@ -96,10 +97,10 @@ void ck_moe_stage1_gemm(const hipStream_t& stream, ///######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| ///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | S| ///###### RCR - < Row, Col, DsLayout, ELayout, + < Row, Col, DsLayout, ELayout, A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, - AElementOp, BElementOp, CDEElementOp, GemmSpec, - 32, BLOCKSIZE, + AElementOp, BElementOp, CDEElementOp, GemmSpec, + 32, BLOCKSIZE, MPerBlock, NPerBlock, 128, AK1, BK1, MNPerXDL, MNPerXDL, @@ -193,7 +194,8 @@ void ck_moe_stage1_gemm(const hipStream_t& stream, void*& num_valid_ids, \ void*& out, \ std::optional w1_scale, \ - std::optional a1_scale); + std::optional a1_scale, \ + std::optional splitk); template w2_scale, // [e, 1, n], gate(up) scale - std::optional a2_scale // [max_num_tokens_padded, 1], token scale + std::optional a2_scale, // [max_num_tokens_padded, 1], token scale + std::optional splitk // splitk ) { // ~~~~~~~~~~~~~~~~~~~~~~~~following start with ck things @@ -285,10 +288,10 @@ void ck_moe_stage2_gemm(const hipStream_t& stream, ///#####| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| ///#####| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | S| ///##### RCR - < Row, Col, DsLayout, ELayout, + < Row, Col, DsLayout, ELayout, A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, - 32, BLOCKSIZE, + 32, BLOCKSIZE, MPerBlock, NPerBlock, 128, AK1, BK1, MNPerXDL, MNPerXDL, @@ -364,4 +367,5 @@ void ck_moe_stage2_gemm(const hipStream_t& stream, void *&num_valid_ids, \ void *&out, \ std::optional w2_scale, \ - std::optional 
a2_scale); \ No newline at end of file + std::optional a2_scale, \ + std::optional splitk); diff --git a/csrc/ck_gemm_moe_2stages_codegen/gen_instances.py b/csrc/ck_gemm_moe_2stages_codegen/gen_instances.py index 8062c548d0..38b7826094 100644 --- a/csrc/ck_gemm_moe_2stages_codegen/gen_instances.py +++ b/csrc/ck_gemm_moe_2stages_codegen/gen_instances.py @@ -715,6 +715,7 @@ def __init__( activation, mul_routed_weight_stage, preshuffle, + splitk, ): self.working_path = working_path self.a_dtype = a_dtype.upper() @@ -725,6 +726,7 @@ def __init__( self.mul_routed_weight_stage = mul_routed_weight_stage self.nswizzle = False self.preshuffle = preshuffle + self.splitk = splitk def generate_instance_and_lookUpTable(self): _, gemm1_kernel_list = get_gemm1_kernels_list( @@ -736,6 +738,7 @@ def generate_instance_and_lookUpTable(self): self.activation, self.mul_routed_weight_stage == 1, self.preshuffle, + self.splitk, ) tag, gemm2_kernel_list = get_gemm2_kernels_list( self.a_dtype, @@ -770,6 +773,9 @@ def generate_instance_and_lookUpTable(self): quanttype = "_mxfp4" else: quanttype = "" + gemm1_fp32 = ( + self.splitk and (kernel.stage == 1) and (quanttype == "_blockscale") + ) if not os.path.exists(f_instance): with open(f_instance, "a") as f_ins: stage_instance = STG_INSTANCE_IMPL.format( @@ -777,7 +783,7 @@ def generate_instance_and_lookUpTable(self): A0DataType=self.a_dtype, B0DataType=self.b_dtype, AccDataType="F32" if self.a_dtype != "I8" else "I32", - EDataType=self.c_dtype, + EDataType="F32" if gemm1_fp32 else self.c_dtype, CDEElementOp=kernel.CDEElementOp, Nswizzle=str(self.nswizzle).lower(), Quant=self.quant_type, @@ -806,7 +812,7 @@ def generate_instance_and_lookUpTable(self): A0DataType=self.a_dtype, B0DataType=self.b_dtype, AccDataType="F32" if self.a_dtype != "I8" else "I32", - EDataType=self.c_dtype, + EDataType="F32" if gemm1_fp32 else self.c_dtype, CDEElementOp=kernel.CDEElementOp, Nswizzle=str(self.nswizzle).lower(), Quant=self.quant_type, @@ -832,11 +838,12 @@ def generate_instance_and_lookUpTable(self): tag ] with open(f_gemm1_heuristic_dispatch, "a") as f_h: + gemm1_fp32 = self.splitk and (quanttype == "_blockscale") gemm1_heuristic_dispatch_str = gemm1_heuristic_dispatch.format( A0DataType=self.a_dtype, B0DataType=self.b_dtype, AccDataType="F32" if self.a_dtype != "I8" else "I32", - EDataType=self.c_dtype, + EDataType="F32" if gemm1_fp32 else self.c_dtype, CDEElementOp=kernel_list[0].CDEElementOp, Nswizzle=str(self.nswizzle).lower(), Quant=self.quant_type, @@ -949,6 +956,12 @@ def generate_instance_and_lookUpTable(self): help="enable pre-shuffle weight mode", ) + parser.add_argument( + "--issplitk", + action="store_true", + help="enable moe_stage1 splitk mode", + ) + args = parser.parse_args() args.quant_type = ( "per_1x128" if args.quant_type == "per_128x128" else args.quant_type @@ -998,13 +1011,15 @@ def generate_instance_and_lookUpTable(self): act, routed_weight, preshuffle_mode, + False, # splitk ) codegen.generate_instance_and_lookUpTable() # blk-quant moe blk_quant_l = ["per_1x128"] - for c_dtype, act, routed_weight, quant in itertools.product( - c_dtypes, acts, routed_weight_l, blk_quant_l + blk_splitk_l = [False, True] + for c_dtype, act, routed_weight, quant, splitk in itertools.product( + c_dtypes, acts, routed_weight_l, blk_quant_l, blk_splitk_l ): codegen = ck_moe_2stage_gemm_codegen( args.working_path, @@ -1015,6 +1030,7 @@ def generate_instance_and_lookUpTable(self): act, routed_weight, preshuffle_mode, + splitk, ) codegen.generate_instance_and_lookUpTable() @@ 
-1039,6 +1055,7 @@ def generate_instance_and_lookUpTable(self): act, routed_weight, preshuffle_mode, + False, # splitk ) codegen.generate_instance_and_lookUpTable() else: @@ -1053,6 +1070,7 @@ def generate_instance_and_lookUpTable(self): args.activation, args.mul_routed_weight_stage, args.preshuffle, + args.issplitk, ) codegen.generate_instance_and_lookUpTable() diff --git a/csrc/ck_tile_gemm_moe_2stages/gen_instances.py b/csrc/ck_tile_gemm_moe_2stages/gen_instances.py index cf1a1e68e8..59e75de545 100644 --- a/csrc/ck_tile_gemm_moe_2stages/gen_instances.py +++ b/csrc/ck_tile_gemm_moe_2stages/gen_instances.py @@ -8,12 +8,14 @@ import re from moe_cktile2stages_common import ( act_dict, + dtype_dict, kernelInstance, get_gemm1_kernels_list, get_gemm2_kernels_list, get_heuristic_dispatch_template, ) import sys +from chip_info import get_gfx this_dir = os.path.dirname(os.path.abspath(__file__)) AITER_CORE_DIR = os.path.abspath(f"{this_dir}/../../../") @@ -30,7 +32,7 @@ class cktile_moe_2stage_gemm_codegen: def __init__( self, working_path, - ab_dtype, + a_dtype, acc_dtype, c_dtype, quant_type, @@ -42,8 +44,12 @@ def __init__( self.working_path = working_path self.impl_path = os.path.join(working_path, "impl") self.instances_path = os.path.join(working_path, "instances") + self.dispatchers_path = os.path.join(working_path, "dispatchers") + self.manifests_path = os.path.join(working_path, "manifests") self.istune = istune - self.ab_dtype = ab_dtype.lower() + self.kernel_name_list = [] + self.a_dtype = a_dtype + self.b_dtype = "fp4" self.acc_dtype = acc_dtype.lower() self.c_dtype = c_dtype.lower() self.quant_type = quant_type @@ -121,11 +127,12 @@ def gen_instance(self, k: kernelInstance): xptr = "static_cast(x_scale.value().data_ptr())" wptr = "static_cast(w_scale.value().data_ptr())" elif k.QuantType == "1x32": - scaleGranA = "-1" - scaleGranB = "1, 32" + # scaleGranA = "-1" + scaleGranA = "1, 32, ck_tile::e8m0_t" + scaleGranB = "1, 32, ck_tile::e8m0_t" biasGran = "1" - xptr = "nullptr" - wptr = "static_cast(w_scale.value().data_ptr())" + xptr = "x_scale.has_value() ? static_cast(x_scale.value().data_ptr()) : nullptr" + wptr = "static_cast(w_scale.value().data_ptr())" biasptr = "static_cast(exp_bias.has_value() ? exp_bias.value().data_ptr() : nullptr)" if act_dict[k.ActOP] != 2: @@ -240,8 +247,8 @@ def fill_template(name, a_type, b_type, acc_type, c_type): ) ).write_text(intsance) - if (k.QuantType == "1x32") and (self.ab_dtype in ["bf16", "fp16"]): - fill_template(k.name, self.ab_dtype, "pk_fp4", self.acc_dtype, self.c_dtype) + if (k.QuantType == "1x32") and (a_type in ["bf16", "fp16", "fp8"]): + fill_template(k.name, self.a_dtype, "pk_fp4", self.acc_dtype, self.c_dtype) else: for CDtype in ["bf16", "fp16"]: for ABDtype in ["fp8"]: # "bf16", "fp16", @@ -266,6 +273,10 @@ def gen_heuristic_dispatch(self, tag, dict): def validate_and_format(template: str, mapping: dict) -> str: # check all format element in dict. 
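The codegen change in this hunk extends `str_mapping` with data-type placeholders and, instead of raising on unknown keys, drops them from the placeholder set before substitution. A self-contained Python sketch of that tolerant placeholder substitution follows; the regex and helper name are illustrative assumptions, not the exact mechanism used in gen_instances.py.

```python
# Illustrative placeholder substitution: find "{(name)}" tokens, substitute
# the known ones, and silently skip placeholders missing from the mapping
# (matching the behaviour change above, which removes missing keys instead
# of raising KeyError). Names and the regex are assumptions for this sketch.
import re

def fill_placeholders(template: str, mapping: dict) -> str:
    placeholders = set(re.findall(r"\{\((\w+)\)\}", template))
    result = template
    for name in placeholders:
        key = f"({name})"
        if key in mapping:
            result = result.replace(f"{{({name})}}", str(mapping[key]))
    return result

snippet = "template <> struct moe_gemm1_heuristic_dispatcher<{(a_data_type)}, {(c_data_type)}>;"
print(fill_placeholders(snippet, {"(a_data_type)": "ck_tile::fp8_t",
                                  "(c_data_type)": "ck_tile::bf16_t"}))
# -> template <> struct moe_gemm1_heuristic_dispatcher<ck_tile::fp8_t, ck_tile::bf16_t>;
```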
str_mapping = { + '(a_data_type)': dtype_dict[self.a_dtype], + '(b_data_type)': dtype_dict[self.b_dtype], + '(acc_data_type)': dtype_dict[self.acc_dtype], + '(c_data_type)': dtype_dict[self.c_dtype], '(activation)': self.activation, '(has_bias)': 'true' if self.activation == 2 else 'false', '(split_k)': 'true' if self.is_split_k else 'false', @@ -278,8 +289,9 @@ def validate_and_format(template: str, mapping: dict) -> str: # print(placeholders) # print(str_mapping) if missing: - raise KeyError(f"Missing keys in mapping: {missing}") - result = template + for mis in missing: + placeholders.remove(mis) + # result = template # for placeholder in placeholders: # result = result.replace(placeholder, str_mapping[placeholder]) # return result @@ -288,12 +300,14 @@ def validate_and_format(template: str, mapping: dict) -> str: _, k = next(iter(dict.items())) # create heuristic heirarchy with open( - os.path.join(self.working_path, f"{k.dispatch_suffix}_heuristic_dispatch.h"), + os.path.join( + self.dispatchers_path, f"{k.dispatch_suffix}_heuristic_dispatch_{tag}.h" + ), "w", ) as f: f.write(validate_and_format(HEURISTIC_template, dict)) - return f"{k.dispatch_suffix}_heuristic_dispatch.h" + return f"./dispatchers/{k.dispatch_suffix}_heuristic_dispatch_{tag}.h" """generate lookup.h linking MNK/datatype to specific instance""" @@ -341,7 +355,7 @@ def gen_lookup_dict(self, kernels_dict): """generate manifest.h for instance header""" - def gen_manifest_head(self, kernels_dict): + def gen_manifest_head(self, tag, kernels_dict): MAINFEST_head = """#pragma once // SPDX-License-Identifier: MIT // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. @@ -353,7 +367,8 @@ def gen_manifest_head(self, kernels_dict): #include """ MAINFEST_template = """ -template +// template +template torch::Tensor {kernel_name}( torch::Tensor& XQ, @@ -378,24 +393,26 @@ def gen_manifest_head(self, kernels_dict): """ _, k0 = next(iter(kernels_dict.items())) with open( - os.path.join(self.working_path, f"{k0.dispatch_suffix}_manifest.h"), "w" + os.path.join(self.manifests_path, f"{k0.dispatch_suffix}_manifest_{tag}.h"), "w" ) as f: f.write(MAINFEST_head) - for mnk, k in kernels_dict.items(): - f.write(MAINFEST_template.format(kernel_name=k.name)) + for k_name in self.kernel_name_list: + f.write(MAINFEST_template.format(kernel_name=k_name)) f.write(MAINFEST_end) - return f"{k0.dispatch_suffix}_manifest.h" + return f"./manifests/{k0.dispatch_suffix}_manifest_{tag}.h" """generate all instances and headers""" def gen_instances(self, tag, kernels_dict): - for mnk, k in kernels_dict.items(): self.gen_instance(k) + if k.name not in self.kernel_name_list: + self.kernel_name_list.append(k.name) self.gen_lookup_dict(kernels_dict) - return self.gen_heuristic_dispatch(tag, kernels_dict), self.gen_manifest_head(kernels_dict) + self.gen_heuristic_dispatch(tag, kernels_dict) + return self.gen_heuristic_dispatch(tag, kernels_dict), self.gen_manifest_head(tag, kernels_dict) # def get_tune_dict(tune_dict_csv): @@ -442,6 +459,46 @@ def generate_common_header(working_path, dispatch_files, manifest_files): f.write(manifest_header) + """genarete heuristic dispatch header for multi dtype""" + +def gen_heuristic_dispatch_header(tags): + HEURISTIC_dispatch_header = """#pragma once +// SPDX-License-Identifier: MIT +// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
+#include "moe_cktile2stages.h" + +""" + for tag in tags: + HEURISTIC_headers = f"""#include "./dispatchers/moe_cktile2stages_heuristic_dispatch_{tag}.h" +""" + HEURISTIC_dispatch_header += HEURISTIC_headers + + HEURISTIC_function = """#pragma once +// SPDX-License-Identifier: MIT +// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "moe_cktile2stages.h" + +template +MoeKernel moe_gemm1_heuristic_dispatch(int M, int N, int K, int block_m); + +template +MoeKernel moe_gemm2_heuristic_dispatch(int M, int N, int K, int block_m); +""" + # create heuristic heirarchy + with open( + os.path.join(self.working_path, "moe_cktile2stages_heuristic_dispatch.h"), + "w", + ) as f: + f.write(HEURISTIC_dispatch_header) + with open( + os.path.join( + self.dispatchers_path, "moe_cktile2stages_heuristic_dispatch_common.h" + ), + "w", + ) as f: + f.write(HEURISTIC_function) + if __name__ == "__main__": parser = argparse.ArgumentParser( @@ -583,6 +640,8 @@ def generate_common_header(working_path, dispatch_files, manifest_files): # quant_type = "per_token" a_types = ["bf16"] + if get_gfx() == "gfx950": + a_types.append("fp8") b_type = "fp4" quant_type = "1x32" @@ -593,6 +652,8 @@ def generate_common_header(working_path, dispatch_files, manifest_files): impl_path = os.path.join(args.working_path, "impl") instances_path = os.path.join(args.working_path, "instances") + dispatchers_path = os.path.join(args.working_path, "dispatchers") + manifests_path = os.path.join(args.working_path, "manifests") if os.path.exists(impl_path): shutil.rmtree(impl_path) @@ -600,14 +661,27 @@ def generate_common_header(working_path, dispatch_files, manifest_files): if os.path.exists(instances_path): shutil.rmtree(instances_path) os.mkdir(instances_path) + if os.path.exists(dispatchers_path): + shutil.rmtree(dispatchers_path) + os.mkdir(dispatchers_path) + if os.path.exists(manifests_path): + shutil.rmtree(manifests_path) + os.mkdir(manifests_path) + gen_dispatch_files = [] gen_manifest_files = [] + tags = [] + kernel_list = [] for a_type, c_dtype, act_type, is_split_k in itertools.product( a_types, c_dtypes, act_types, is_split_k_l ): has_bias = True if act_type == "swiglu" else False + + # a8w8 do not support + if a_type in ["fp8", "bf8"] and is_split_k: + continue codegen = cktile_moe_2stage_gemm_codegen( args.working_path, a_type, acc_type, c_dtype, quant_type, act_type, 2, is_split_k, False ) diff --git a/csrc/ck_tile_gemm_moe_2stages/include/moe_cktile2stages_common.cuh b/csrc/ck_tile_gemm_moe_2stages/include/moe_cktile2stages_common.cuh index 44bef46b6e..6ef3297133 100644 --- a/csrc/ck_tile_gemm_moe_2stages/include/moe_cktile2stages_common.cuh +++ b/csrc/ck_tile_gemm_moe_2stages/include/moe_cktile2stages_common.cuh @@ -65,8 +65,6 @@ struct MoeFlatmmConfig static constexpr bool TiledMMAPermuteN = false; }; - - template ; // Preshuffle_ - constexpr bool MXFP4_Pipeline = std::is_same_v; + constexpr bool AQUANT_Pipeline = std::is_same_v || + std::is_same_v || + std::is_same_v; + constexpr bool BMXFP4_Pipeline = std::is_same_v; - if constexpr(!MXFP4_Pipeline && moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up) + if constexpr(!BMXFP4_Pipeline && moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up) { static_assert( FlatmmConfig::N_Tile % (FlatmmConfig::N_Warp * FlatmmConfig::N_Warp_Tile * 2) == 0, @@ -129,11 +130,8 @@ void moe_gemm(const MoeFlatmmHostArgs& args, const ck_stream_config& s) static_assert(sizeof(ComputeDataType) >= sizeof(BDataType), "mixed_prec_flatmm requires ADataType is a 
wider type than BDataType"); - using GemmPipelineProblem = ck_tile::GemmPipelineProblem; + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; using BaseGemmPipeline = ck_tile::BaseFlatmmPipelineAGmemBGmemCRegV1; @@ -143,11 +141,8 @@ void moe_gemm(const MoeFlatmmHostArgs& args, const ck_stream_config& s) const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); - const ck_tile::amd_buffer_coherence_enum b_mem_nt_type = - BaseGemmPipeline::GetBMemNTType( - args.NumTokens, - args.N, - args.K); + const int32_t b_mem_nt_type = + static_cast(BaseGemmPipeline::GetBMemNTType(args.NumTokens, args.N, args.K)); float ave_time{0}; @@ -159,10 +154,21 @@ void moe_gemm(const MoeFlatmmHostArgs& args, const ck_stream_config& s) constexpr auto tail_number_v = tail_number_.value; constexpr auto scheduler = FlatmmConfig::Scheduler; constexpr auto memory_operation = memory_operation_.value; - constexpr auto b_mem_nt_type_v = b_mem_nt_type_.value; - - using CodegenPipelineProblem = - std::conditional_t(b_mem_nt_type_.value); + + using CodegenPipelineProblem = std::conditional_t< + BMXFP4_Pipeline, + std::conditional_t, ck_tile::F16xMXF4FlatmmPipelineProblem, - ck_tile::FlatmmPipelineProblem>; + b_mem_nt_type_v>>, + ck_tile::FlatmmPipelineProblem>; constexpr int BlockedXDLN_PerWarp = - (MXFP4_Pipeline || (moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up)) + (BMXFP4_Pipeline || (moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up)) ? 2 : 1; // determined by scale shuffle pattern @@ -212,8 +218,11 @@ void moe_gemm(const MoeFlatmmHostArgs& args, const ck_stream_config& s) BlockedXDLN_PerWarp>>; using CodegenFlatmmPipeline = std::conditional_t< - MXFP4_Pipeline, - ck_tile::F16xMXF4FlatmmPipelineAGmemBGmemCRegV1, + BMXFP4_Pipeline, + std::conditional_t< + AQUANT_Pipeline, + ck_tile::F8xMXF4FlatmmPipelineAGmemBGmemCRegV1, + ck_tile::F16xMXF4FlatmmPipelineAGmemBGmemCRegV1>, ck_tile::MoeFlatmmPipelineAGmemBGmemCRegV1>; @@ -306,26 +315,25 @@ void moe_gemm(const MoeFlatmmHostArgs& args, const ck_stream_config& s) // return ave_time; }; - const auto RunBMem = [&](const auto has_hot_loop_, - const auto tail_number_, - const auto memory_operation_) { - if(b_mem_nt_type == ck_tile::amd_buffer_coherence_enum::WAVE_NT1) - { - Run(has_hot_loop_, - tail_number_, - memory_operation_, - ck_tile::integral_constant{}); - } - else - { - Run(has_hot_loop_, - tail_number_, - memory_operation_, - ck_tile::integral_constant{}); - } - }; + const auto RunBMem = + [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) { + switch(b_mem_nt_type) + { + case 2: { + Run(has_hot_loop_, + tail_number_, + memory_operation_, + ck_tile::integral_constant{}); + } + break; + default: { + Run(has_hot_loop_, + tail_number_, + memory_operation_, + ck_tile::integral_constant{}); + } + } + }; const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { if(args.k_batch == 1) diff --git a/csrc/ck_tile_gemm_moe_2stages/moe_cktile2stages.cu b/csrc/ck_tile_gemm_moe_2stages/moe_cktile2stages.cu index 00c58b5cdb..80dcc57ed1 100644 --- a/csrc/ck_tile_gemm_moe_2stages/moe_cktile2stages.cu +++ b/csrc/ck_tile_gemm_moe_2stages/moe_cktile2stages.cu @@ -213,6 +213,25 @@ torch::Tensor cktile_moe_gemm1(torch::Tensor& XQ, // moe_dispatch(M, N, K, MPerBlock)(XQ, WQ, Y, sorted_ids, // sorted_expert_ids, max_token_ids, topk, topk_weight, x_scale, w_scale, exp_bias); // } + if (WQ.dtype() == torch_fp4x2 && Y.dtype() 
== at::ScalarType::BFloat16) + { + moe_dispatch(M, N, K, MPerBlock, act_op, has_bias, k_batch)( + XQ, + WQ, + Y, + sorted_ids, + sorted_expert_ids, + max_token_ids, + topk, + n_padded_zeros, + k_padded_zeros, + topk_weight, + x_scale, + w_scale, + exp_bias, + act_op, + k_batch); + } } else if((XQ.dtype() == at::ScalarType::BFloat16 || XQ.dtype() == at::ScalarType::Half) && (WQ.dtype() == torch_fp4x2)) // a16w4 @@ -224,21 +243,22 @@ torch::Tensor cktile_moe_gemm1(torch::Tensor& XQ, // } if(Y.dtype() == at::ScalarType::BFloat16) { - moe_dispatch(M, N, K, MPerBlock, act_op, has_bias, k_batch)(XQ, - WQ, - Y, - sorted_ids, - sorted_expert_ids, - max_token_ids, - topk, - n_padded_zeros, - k_padded_zeros, - topk_weight, - x_scale, - w_scale, - exp_bias, - act_op, - k_batch); + moe_dispatch(M, N, K, MPerBlock, act_op, has_bias, k_batch)( + XQ, + WQ, + Y, + sorted_ids, + sorted_expert_ids, + max_token_ids, + topk, + n_padded_zeros, + k_padded_zeros, + topk_weight, + x_scale, + w_scale, + exp_bias, + act_op, + k_batch); } } else @@ -295,6 +315,25 @@ torch::Tensor cktile_moe_gemm2(torch::Tensor& XQ, // moe_dispatch(M, N, K, MPerBlock)(XQ, WQ, Y, sorted_ids, // sorted_expert_ids, max_token_ids, topk, topk_weight, x_scale, w_scale, exp_bias); // } + if (WQ.dtype() == torch_fp4x2 && Y.dtype() == at::ScalarType::BFloat16) + { + moe_dispatch(M, N, K, MPerBlock, act_op, has_bias, k_batch)( + XQ, + WQ, + Y, + sorted_ids, + sorted_expert_ids, + max_token_ids, + topk, + n_padded_zeros, + k_padded_zeros, + topk_weight, + x_scale, + w_scale, + exp_bias, + act_op, + k_batch); + } } else if((XQ.dtype() == at::ScalarType::BFloat16 || XQ.dtype() == at::ScalarType::Half) && (WQ.dtype() == torch_fp4x2)) // a16w4 @@ -306,21 +345,22 @@ torch::Tensor cktile_moe_gemm2(torch::Tensor& XQ, // } if(Y.dtype() == at::ScalarType::BFloat16) { - moe_dispatch(M, N, K, MPerBlock, 0, has_bias, k_batch)(XQ, - WQ, - Y, - sorted_ids, - sorted_expert_ids, - max_token_ids, - topk, - n_padded_zeros, - k_padded_zeros, - topk_weight, - x_scale, - w_scale, - exp_bias, - act_op, - k_batch); + moe_dispatch(M, N, K, MPerBlock, 0, has_bias, k_batch)( + XQ, + WQ, + Y, + sorted_ids, + sorted_expert_ids, + max_token_ids, + topk, + n_padded_zeros, + k_padded_zeros, + topk_weight, + x_scale, + w_scale, + exp_bias, + act_op, + k_batch); } } else @@ -328,4 +368,4 @@ torch::Tensor cktile_moe_gemm2(torch::Tensor& XQ, TORCH_CHECK(false, "Unsupported scales/output dtype!"); } return Y; -} \ No newline at end of file +} diff --git a/csrc/ck_tile_gemm_moe_2stages/moe_cktile2stages_common.py b/csrc/ck_tile_gemm_moe_2stages/moe_cktile2stages_common.py index f5c00b1b41..22cecc543e 100644 --- a/csrc/ck_tile_gemm_moe_2stages/moe_cktile2stages_common.py +++ b/csrc/ck_tile_gemm_moe_2stages/moe_cktile2stages_common.py @@ -25,6 +25,14 @@ } +dtype_dict = { + "fp8": "ck_tile::fp8_t", + "bf16": "ck_tile::bf16_t", + "float": "float", + "fp4": "ck_tile::pk_fp4_t", +} + + @dataclass class kernelInstance: stage: int @@ -185,12 +193,36 @@ def dispatch_suffix(self) -> str: # 4: kernelInstance( 2, 256, 256, 128, 128, 16, 16, 32, 1, 4,), } +# gemm1 out:bf16/fp16 AB:fp8/fp4 +a8w4_gemm1_kernels_list_gfx950= { + # kernel: stage| BLOCK_SIZE|MPerBLOCK| NPerBLOCK| KPerBLOCK| WAVE_TILE_M| WAVE_TILE_N| WAVE_TILE_K| WAVE_MAP_M| WAVE_MAP_N| BlockPerCU| + # 0: kernelInstance( 1, 256, 16, 128, 256, 16, 16, 128, 1, 4, 2,), + # 5: kernelInstance( 2, 256, 16, 512, 256, 16, 16, 32, 1, 4, 4,), + 1: kernelInstance( 1, 256, 32, 256, 256, 16, 16, 128, 1, 4, 2,), + 3: kernelInstance( 1, 256, 64, 
256, 256, 16, 16, 128, 1, 4, 1,), + # 4: kernelInstance( 2, 256, 128, 256, 128, 16, 16, 32, 1, 4, 1,), + # 4: kernelInstance( 2, 256, 256, 256, 256, 16, 16, 32, 1, 4,), + # 4: kernelInstance( 2, 256, 256, 128, 128, 16, 16, 32, 1, 4,), +} +# gemm2 out:bf16/fp16 AB:fp8/fp4 +a8w4_gemm2_kernels_list_gfx950= { + # kernel: stage| BLOCK_SIZE|MPerBLOCK| NPerBLOCK| KPerBLOCK| WAVE_TILE_M| WAVE_TILE_N| WAVE_TILE_K| WAVE_MAP_M| WAVE_MAP_N| BlockPerCU| + # 0: kernelInstance( 2, 256, 16, 128, 256, 16, 16, 128, 1, 4, 2,), + # 5: kernelInstance( 2, 256, 16, 512, 256, 16, 16, 32, 1, 4, 4,), + 1: kernelInstance( 2, 256, 32, 256, 256, 16, 16, 128, 1, 4, 2,), + 3: kernelInstance( 2, 256, 64, 256, 256, 16, 16, 128, 1, 4, 1,), + # 4: kernelInstance( 2, 256, 128, 256, 128, 16, 16, 32, 1, 4, 1,), + # 4: kernelInstance( 2, 256, 256, 256, 256, 16, 16, 32, 1, 4,), + # 4: kernelInstance( 2, 256, 256, 128, 128, 16, 16, 32, 1, 4,), +} + # fmt: on gemm1_kernels_dict = { "a8w8_gfx950": a8w8_gemm1_kernels_list_gfx950, "a8w8": a8w8_gemm1_kernels_list, "a16w4_gfx950": a16w4_gemm1_kernels_list_gfx950, "a16w4": a16w4_gemm1_kernels_list, + "a8w4_gfx950": a8w4_gemm1_kernels_list_gfx950, } gemm2_kernels_dict = { @@ -198,6 +230,7 @@ def dispatch_suffix(self) -> str: "a8w8": a8w8_gemm2_kernels_list, "a16w4_gfx950": a16w4_gemm2_kernels_list_gfx950, "a16w4": a16w4_gemm2_kernels_list, + "a8w4_gfx950": a8w4_gemm2_kernels_list_gfx950, } @@ -205,28 +238,29 @@ def dispatch_suffix(self) -> str: // SPDX-License-Identifier: MIT // Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "moe_cktile2stages.h" +#include "moe_cktile2stages_heuristic_dispatch_common.h" -template -struct moe_gemm1_heuristic_dispatcher +template <> +struct moe_gemm1_heuristic_dispatcher<{(a_data_type)}, {(b_data_type)}, {(acc_data_type)}, {(c_data_type)}, {(activation)}, {(has_bias)}, {(split_k)}> {{ static MoeKernel dispatch(int M, int N, int K, int block_m) {{ // Apply shape heuristics to find a suitable kernel implementation. if (block_m == 32) {{ - return {(1, 1)}; + return {(1, 1)}<{(a_data_type)}, {(b_data_type)}, {(acc_data_type)}, {(c_data_type)}>; }} else if (block_m == 64) {{ - return {(1, 2)}; + return {(1, 2)}<{(a_data_type)}, {(b_data_type)}, {(acc_data_type)}, {(c_data_type)}>; }} //else if (block_m == 128) //{{ - // return {(1, 4)}; + // return {(1, 4)}<{(a_data_type)}, {(b_data_type)}, {(acc_data_type)}, {(c_data_type)}>; //}} //else if (block_m == 256) //{{ - // return {(1, 6)}; + // return {(1, 6)}<{(a_data_type)}, {(b_data_type)}, {(acc_data_type)}, {(c_data_type)}>; //}} else {{ @@ -238,27 +272,27 @@ def dispatch_suffix(self) -> str: }} }}; -template -struct moe_gemm2_heuristic_dispatcher +template <> +struct moe_gemm2_heuristic_dispatcher<{(a_data_type)}, {(b_data_type)}, {(acc_data_type)}, {(c_data_type)}, {(activation)}, {(has_bias)}, {(split_k)}> {{ static MoeKernel dispatch(int M, int N, int K, int block_m) {{ // Apply shape heuristics to find a suitable kernel implementation. 
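The generated dispatchers in these templates all follow one pattern: a specialization keyed on the A/B/accumulator/output data types picks a concrete kernel instance from the block_m used to sort tokens, and any other block_m is rejected. A compact Python sketch of that selection logic is below; the table contents and names are placeholders, not real AITER kernel symbols.

```python
# Illustrative dispatch-by-(dtype tag, block_m): the generated C++ performs
# the same selection via template specializations; these entries are made up.
DISPATCH_TABLE = {
    ("a8w4_gfx950", 32): "gemm1_256x32x256x256",
    ("a8w4_gfx950", 64): "gemm1_256x64x256x256",
}

def moe_gemm1_dispatch(tag: str, block_m: int) -> str:
    key = (tag, block_m)
    if key not in DISPATCH_TABLE:
        raise ValueError(
            f"Unsupported block_m value for moe_gemm1 heuristic dispatch: {block_m}"
        )
    return DISPATCH_TABLE[key]

print(moe_gemm1_dispatch("a8w4_gfx950", 32))
```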
if (block_m == 32) {{ - return {(2, 0)}; + return {(2, 0)}<{(a_data_type)}, {(b_data_type)}, {(acc_data_type)}, {(c_data_type)}>; }} else if (block_m == 64) {{ - return {(2, 1)}; + return {(2, 1)}<{(a_data_type)}, {(b_data_type)}, {(acc_data_type)}, {(c_data_type)}>; }} //else if (block_m == 128) //{{ - // return {(2, 2)}; + // return {(2, 2)}<{(a_data_type)}, {(b_data_type)}, {(acc_data_type)}, {(c_data_type)}>; //}} //else if (block_m == 256) //{{ - // return {(2, 3)}; + // return {(2, 3)}<{(a_data_type)}, {(b_data_type)}, {(acc_data_type)}, {(c_data_type)}>; //}} else {{ @@ -275,24 +309,25 @@ def dispatch_suffix(self) -> str: // SPDX-License-Identifier: MIT // Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "moe_cktile2stages.h" +#include "moe_cktile2stages_heuristic_dispatch_common.h" -template -struct moe_gemm1_heuristic_dispatcher +template <> +struct moe_gemm1_heuristic_dispatcher<{(a_data_type)}, {(b_data_type)}, {(acc_data_type)}, {(c_data_type)}, {(activation)}, {(has_bias)}, {(split_k)}> {{ static MoeKernel dispatch(int M, int N, int K, int block_m) {{ // Apply shape heuristics to find a suitable kernel implementation. if (block_m == 16) {{ - return {(1, 0)}; + return {(1, 0)}<{(a_data_type)}, {(b_data_type)}, {(acc_data_type)}, {(c_data_type)}>; }} else if (block_m == 32) {{ - return {(1, 1)}; + return {(1, 1)}<{(a_data_type)}, {(b_data_type)}, {(acc_data_type)}, {(c_data_type)}>; }} else if (block_m == 64) {{ - return {(1, 3)}; + return {(1, 3)}<{(a_data_type)}, {(b_data_type)}, {(acc_data_type)}, {(c_data_type)}>; }} else {{ @@ -304,23 +339,23 @@ def dispatch_suffix(self) -> str: }} }}; -template -struct moe_gemm2_heuristic_dispatcher +template <> +struct moe_gemm2_heuristic_dispatcher<{(a_data_type)}, {(b_data_type)}, {(acc_data_type)}, {(c_data_type)}, {(activation)}, {(has_bias)}, {(split_k)}> {{ static MoeKernel dispatch(int M, int N, int K, int block_m) {{ // Apply shape heuristics to find a suitable kernel implementation. if (block_m == 16) {{ - return {(2, 0)}; + return {(2, 0)}<{(a_data_type)}, {(b_data_type)}, {(acc_data_type)}, {(c_data_type)}>; }} else if (block_m == 32) {{ - return {(2, 1)}; + return {(2, 1)}<{(a_data_type)}, {(b_data_type)}, {(acc_data_type)}, {(c_data_type)}>; }} else if (block_m == 64) {{ - return {(2, 3)}; + return {(2, 3)}<{(a_data_type)}, {(b_data_type)}, {(acc_data_type)}, {(c_data_type)}>; }} else {{ @@ -337,24 +372,25 @@ def dispatch_suffix(self) -> str: // SPDX-License-Identifier: MIT // Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "moe_cktile2stages.h" +#include "moe_cktile2stages_heuristic_dispatch_common.h" -template -struct moe_gemm1_heuristic_dispatcher +template <> +struct moe_gemm1_heuristic_dispatcher<{(a_data_type)}, {(b_data_type)}, {(acc_data_type)}, {(c_data_type)}, {(activation)}, {(has_bias)}, {(split_k)}> {{ static MoeKernel dispatch(int M, int N, int K, int block_m) {{ // Apply shape heuristics to find a suitable kernel implementation. 
if (block_m == 16) {{ - return {(1, 0)}; + return {(1, 0)}<{(a_data_type)}, {(b_data_type)}, {(acc_data_type)}, {(c_data_type)}>; }} else if (block_m == 32) {{ - return {(1, 1)}; + return {(1, 1)}<{(a_data_type)}, {(b_data_type)}, {(acc_data_type)}, {(c_data_type)}>; }} else if (block_m == 64) {{ - return {(1, 3)}; + return {(1, 3)}<{(a_data_type)}, {(b_data_type)}, {(acc_data_type)}, {(c_data_type)}>; }} else {{ @@ -366,23 +402,78 @@ def dispatch_suffix(self) -> str: }} }}; -template -struct moe_gemm2_heuristic_dispatcher +template <> +struct moe_gemm2_heuristic_dispatcher<{(a_data_type)}, {(b_data_type)}, {(acc_data_type)}, {(c_data_type)}, {(activation)}, {(has_bias)}, {(split_k)}> {{ static MoeKernel dispatch(int M, int N, int K, int block_m) {{ // Apply shape heuristics to find a suitable kernel implementation. if (block_m == 16) {{ - return {(2, 0)}; + return {(2, 0)}<{(a_data_type)}, {(b_data_type)}, {(acc_data_type)}, {(c_data_type)}>; }} else if (block_m == 32) {{ - return {(2, 1)}; + return {(2, 1)}<{(a_data_type)}, {(b_data_type)}, {(acc_data_type)}, {(c_data_type)}>; }} else if (block_m == 64) {{ - return {(2, 3)}; + return {(2, 3)}<{(a_data_type)}, {(b_data_type)}, {(acc_data_type)}, {(c_data_type)}>; + }} + else + {{ + TORCH_CHECK( + false, + "Unsupported block_m value for moe_gemm2 heuristic dispatch: ", + block_m); + }} + }} +}}; +""" + +a8w4_gfx950_heuristic_dispatch = """#pragma once +// SPDX-License-Identifier: MIT +// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +#include "moe_cktile2stages.h" +#include "moe_cktile2stages_heuristic_dispatch_common.h" + +template <> +struct moe_gemm1_heuristic_dispatcher<{(a_data_type)}, {(b_data_type)}, {(acc_data_type)}, {(c_data_type)}, {(activation)}, {(has_bias)}, {(split_k)}> +{{ + static MoeKernel dispatch(int M, int N, int K, int block_m) + {{ + // Apply shape heuristics to find a suitable kernel implementation. + if (block_m == 32) + {{ + return {(1, 1)}<{(a_data_type)}, {(b_data_type)}, {(acc_data_type)}, {(c_data_type)}>; + }} + else if (block_m == 64) + {{ + return {(1, 3)}<{(a_data_type)}, {(b_data_type)}, {(acc_data_type)}, {(c_data_type)}>; + }} + else + {{ + TORCH_CHECK( + false, + "Unsupported block_m value for moe_geem1 heuristic dispatch: ", + block_m); + }} + }} +}}; + +template <> +struct moe_gemm2_heuristic_dispatcher<{(a_data_type)}, {(b_data_type)}, {(acc_data_type)}, {(c_data_type)}, {(activation)}, {(has_bias)}, {(split_k)}> +{{ + static MoeKernel dispatch(int M, int N, int K, int block_m) + {{ + // Apply shape heuristics to find a suitable kernel implementation. 
+ if (block_m == 32) + {{ + return {(2, 1)}<{(a_data_type)}, {(b_data_type)}, {(acc_data_type)}, {(c_data_type)}>; + }} + else if (block_m == 64) + {{ + return {(2, 3)}<{(a_data_type)}, {(b_data_type)}, {(acc_data_type)}, {(c_data_type)}>; }} else {{ @@ -400,6 +491,7 @@ def dispatch_suffix(self) -> str: # "a8w8": a8w8_gemm2_kernels_list, "a16w4_gfx950": a16w4_gfx950_heuristic_dispatch, "a16w4": a16w4_heuristic_dispatch, + "a8w4_gfx950": a8w4_gfx950_heuristic_dispatch, } @@ -429,6 +521,13 @@ def get_gemm1_kernels_list( tag = "a16w4_gfx950" else: tag = "a16w4" + elif Adtype.lower() in bit8_list and Bdtype in bit4_list: + if arch == "gfx950": + tag = "a8w4_gfx950" + else: + raise ValueError( + f"Unsupported data type combination: {Adtype}, {Bdtype} on {arch}" + ) else: raise ValueError(f"Unsupported data type combination: {Adtype}, {Bdtype}") kernels_list = gemm1_kernels_dict[tag] @@ -471,6 +570,13 @@ def get_gemm2_kernels_list( tag = "a16w4_gfx950" else: tag = "a16w4" + elif Adtype.lower() in bit8_list and Bdtype in bit4_list: + if arch == "gfx950": + tag = "a8w4_gfx950" + else: + raise ValueError( + f"Unsupported data type combination: {Adtype}, {Bdtype} on {arch}" + ) else: raise ValueError(f"Unsupported data type combination: {Adtype}, {Bdtype}") kernels_list = gemm2_kernels_dict[tag] diff --git a/csrc/cpp_itfs/gluon_aot_tools/compile.py b/csrc/cpp_itfs/gluon_aot_tools/compile.py new file mode 100644 index 0000000000..e6ed6851eb --- /dev/null +++ b/csrc/cpp_itfs/gluon_aot_tools/compile.py @@ -0,0 +1,270 @@ +import binascii +import hashlib +import importlib.util +import sys +from argparse import ArgumentParser +from dataclasses import dataclass +from pathlib import Path +from typing import List + +import triton +import triton.backends + + +@dataclass +class CompileArgs: + """ + A class to contain arguments from command-line parser. + """ + + path: str = "" + kernel_name: str = "" + signature: str = "" + grid: str = "" + target: str | None = None + num_warps: int = 1 + num_stages: int = 3 + out_name: str | None = None + out_path: Path | None = None + + +desc = """ +Triton ahead-of-time compiler: + +This program compiles the kernel with name `kernel-name` in the file at the +provided `path` into self-contained C source-code that embeds the `cubin` +data along with utilities to load, unload and launch the kernel. + +signature is provided as a list of (optionally divisibility-hinted) types +or constexpr values, e.g. + +`compile.py --kernel-name kernel --signature "*fp32:16, i32:16, 1024, i32" --out-name kernel /path/to/kernel.py` + +will compile triton.JITFunction of name `kernel` inside the file `/path/to/kernel.py`. +Said kernel will be specialized such that argument 0, 1 are assumed to be multiple of 16, +and argument 2 is assumed to be a compile-time constant of value 1024, i.e. it won't be part of the generated prototype. + +The resulting entry point will have signature + +CUresult kernel_{specialization_suffix}(CUstream stream, unsigned gX, unsigned gY, unsigned gZ, float* arg0, int32_t arg1, int32_t arg2) + +Different such specialized entry points can be combined using the `linker.py` script. + +NOTE: when resolving the scope of /path/to/kernel.py, the file will be executed from within its parent directory with the python interpreter +used to run this `compile.py` script +""" + + +def main(): + # command-line arguments + parser = ArgumentParser(description=desc) + parser.add_argument( + "path", + help="Path to Python source containing desired kernel in its scope. 
File will be executed.", + ) + parser.add_argument( + "--kernel-name", + "-n", + type=str, + default="", + help="Name of the kernel to compile", + required=True, + ) + parser.add_argument( + "--target", + "-t", + type=str, + default=None, + help="The target to compile towards, in format of '::'; " + "e.g., 'cuda:80:32', 'hip:gfx942:64'. Default to None, which means using current machine's GPU target", + ) + parser.add_argument( + "--num-warps", + "-w", + type=int, + default=1, + help="Number of warps to launch the kernel", + ) + parser.add_argument( + "--num-stages", + "-ns", + type=int, + default=3, + help="Number of stages (meta-parameter of the kernel)", + ) + parser.add_argument( + "--out-name", + "-on", + type=str, + default=None, + help="Out name for the compiled kernel", + ) + parser.add_argument( + "--out-path", "-o", type=Path, default=None, help="Out filename" + ) + parser.add_argument( + "--signature", "-s", type=str, help="Signature of the kernel", required=True + ) + parser.add_argument( + "--grid", "-g", type=str, help="Launch grid of the kernel", required=True + ) + cli_args = parser.parse_args() + args = CompileArgs( + **vars(cli_args) + ) # A sanity check to ensure class CompileArgs is updated as well. + compile_kernel(args) + + +def compile_kernel(args: CompileArgs): + out_name = args.out_name if args.out_name else args.kernel_name + out_path = args.out_path if args.out_path else Path(out_name) + + # execute python sources and extract functions wrapped in JITFunction + arg_path = Path(args.path) + sys.path.insert(0, str(arg_path.parent)) + spec = importlib.util.spec_from_file_location(arg_path.stem, arg_path) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + kernel = getattr(mod, args.kernel_name) + grid = args.grid.split(",") + assert len(grid) == 3 + + # validate and parse signature + signature = list(map(lambda s: s.strip(" "), args.signature.split(","))) + + def hash_signature(signature: List[str]): + m = hashlib.sha256() + m.update(" ".join(signature).encode()) + return m.hexdigest()[:8] + + meta_sig = f"warps{args.num_warps}xstages{args.num_stages}" + sig_hash = hash_signature(signature + [meta_sig]) + + def constexpr(s): + try: + ret = int(s) + return ret + except ValueError: + pass + try: + ret = float(s) + return ret + except ValueError: + pass + return None + + hints = { + (i,): constexpr(s.split(":")[1]) for i, s in enumerate(signature) if ":" in s + } + hints = {k: v for k, v in hints.items() if v is not None} + constants = {kernel.arg_names[i]: constexpr(s) for i, s in enumerate(signature)} + constants = {k: v for k, v in constants.items() if v is not None} + for key, value in hints.items(): + if value == 1: + constants[kernel.arg_names[key[0]]] = value + signature = {kernel.arg_names[i]: s.split(":")[0] for i, s in enumerate(signature)} + for key in constants: + signature[key] = "constexpr" + const_sig = "x".join([str(v) for v in constants.values()]) + doc_string = [f"{k}={v}" for k, v in constants.items()] + doc_string += [f"num_warps={args.num_warps}", f"num_stages={args.num_stages}"] + # compile ast into cubin + for h in hints.values(): + assert h in [1, 16], f"Only 1 and 16 are valid hints, got {h}" + attrs = {k: [["tt.divisibility", 16]] for k, v in hints.items() if v == 16} + src = triton.compiler.ASTSource( + fn=kernel, constexprs=constants, signature=signature, attrs=attrs + ) + + target = ( + triton.backends.compiler.GPUTarget(*args.target.split(":")) + if args.target + else 
triton.runtime.driver.active.get_current_target() + ) + backend = triton.compiler.make_backend(target) + kwargs = {"num_warps": args.num_warps, "num_stages": args.num_stages} + options = backend.parse_options(kwargs) + ccinfo = triton.compile(src, target=target, options=options.__dict__) + + if getattr(ccinfo.metadata, "global_scratch_size", 0) > 0: + raise RuntimeError( + "AOT compiling kernels with global scratch requirements is not yet implemented" + ) + if ccinfo.metadata.profile_scratch_size > 0: + raise RuntimeError( + "AOT compiling kernels with profile scratch requirements is not yet implemented" + ) + + arg_names = [] + arg_types = [] + arg_names_not_1 = [] + arg_types_not_1 = [] + for i, arg_name in enumerate(kernel.arg_names): + if arg_name not in constants: + arg_names.append(arg_name) + arg_types.append(signature[arg_name]) + arg_names_not_1.append(arg_name) + arg_types_not_1.append(signature[arg_name]) + elif hints.get((i,), None) == 1: + arg_names.append(arg_name) + arg_types.append("i32") + + # dump C stub code + suffix = "" + for i, ty in enumerate(signature.values()): + suffix += str(i) + if hints.get((i,), None) == 1: + suffix += "c" + if hints.get((i,), None) == 16: + suffix += "d" + func_name = "_".join([out_name, sig_hash, suffix]) + asm = ccinfo.asm[backend.binary_ext] # store binary data once + + hex_ = str(binascii.hexlify(asm))[2:-1] + + ty_to_cpp = triton.runtime.driver.active.map_python_to_cpp_type + + params = { + "kernel_name": func_name, + "triton_kernel_name": args.kernel_name, + "bin_size": len(asm), + "bin_data": ", ".join([f"0x{x}{y}" for x, y in zip(hex_[::2], hex_[1::2])]), + "signature": ", ".join( + [ + f"{ty_to_cpp(ty)} {name}" + for name, ty in zip(arg_names_not_1, arg_types_not_1) + ] + ), + "full_signature": ", ".join( + [f"{ty_to_cpp(ty)} {name}" for name, ty in zip(arg_names, arg_types)] + ), + "arg_pointers": ", ".join( + [f"&{arg}" for arg in arg_names_not_1] + + ["&global_scratch"] + + ["&profile_scratch"] + ), + "num_args": len(arg_names_not_1) + 2, # +2 for global and profile scratch + "kernel_docstring": doc_string, + "shared": ccinfo.metadata.shared, + "num_warps": args.num_warps, + "algo_info": "_".join([const_sig, meta_sig]), + "gridX": grid[0], + "gridY": grid[1], + "gridZ": grid[2], + "_placeholder": "", + } + output_files = [] + backend_name = target.backend + template_dir = Path(__file__).parent / "extra" / backend_name + for template_path in template_dir.glob("compile.*"): + ext = template_path.suffix + output_file = out_path.with_suffix(f".{sig_hash}_{suffix}{ext}") + with output_file.open("w") as fp: + fp.write(template_path.read_text().format(**params)) + output_files.append(output_file) + + return func_name, output_files + + +if __name__ == "__main__": + main() diff --git a/csrc/cpp_itfs/mha_bwd.cpp b/csrc/cpp_itfs/mha_bwd.cpp index e2f97f8541..2890795854 100644 --- a/csrc/cpp_itfs/mha_bwd.cpp +++ b/csrc/cpp_itfs/mha_bwd.cpp @@ -127,8 +127,9 @@ std::tuple get_heuristic_kernel(std::stri float mha_bwd(mha_bwd_args a, const ck_tile::stream_config& s) { + float asm_ret = fmha_v3_bwd(a, s); #if ONLY_FAV3 - return fmha_v3_bwd(a, s); + return asm_ret; #else fmha_bwd_traits traits{a.hdim_q, a.hdim_v, @@ -225,11 +226,11 @@ float mha_bwd(mha_bwd_args a, const ck_tile::stream_config& s) /* drop_seed_offset */ a.drop_seed_offset, }; - float asm_ret = fmha_v3_bwd(a, s); if(asm_ret == -1) { return fmha_bwd(traits, ck_args, s); } + return asm_ret; #endif } @@ -462,9 +463,13 @@ float fmha_v3_bwd(mha_bwd_args a, const ck_tile::stream_config& s) 
if(a.mask_type == 3) { + // Note: sink_size=0 is passed as the 3rd parameter (attention sink not supported in bwd + // yet) + auto sink_size = 0; auto generic_mask = ck_tile::make_generic_attention_mask_coordinates_from_lr_window( a.window_size_left, a.window_size_right, + sink_size, a.seqlen_q, a.seqlen_k, (a.ck_mask_type == static_cast(mask_enum::mask_top_left) || diff --git a/csrc/cpp_itfs/pa_gluon_aot/pa_decode_gluon_aot.py b/csrc/cpp_itfs/pa_gluon_aot/pa_decode_gluon_aot.py index 20bd6c3942..955745c7ce 100644 --- a/csrc/cpp_itfs/pa_gluon_aot/pa_decode_gluon_aot.py +++ b/csrc/cpp_itfs/pa_gluon_aot/pa_decode_gluon_aot.py @@ -1,50 +1,50 @@ import os -import time import shutil import subprocess +import time from pathlib import Path -from jinja2 import Template -import torch + import aiter import aiter.ops.triton.utils._triton.arch_info as arch_info +import torch import triton import triton.language as tl +from jinja2 import Template -GLUON_AOT_COMPILE_ENABLED = True -try: - from triton.experimental import gluon - from triton.experimental.gluon import language as gl -except ImportError: - print( - "Warning: triton.experimental.gluon or triton.experimental.gluon.language not exists, pa_decode_gluon_aot cannot use compile mode!" - ) - GLUON_AOT_COMPILE_ENABLED = False - -try: - from triton.tools.compile import compile_kernel, CompileArgs -except ImportError: - print("Warning: compile_kernel or CompileArgs is not in triton.tools.compile!") - +from aiter.ops.triton.gluon.pa_decode_gluon import get_cdna_version +from csrc.cpp_itfs.gluon_aot_tools.compile import ( + CompileArgs, + compile_kernel, +) from csrc.cpp_itfs.gluon_aot_tools.compile_gluon import ( - compile_gluon_kernel, CompileGluonArgs, + compile_gluon_kernel, +) +from csrc.cpp_itfs.pa_gluon_aot.transpose_query_output_gluon_aot import ( + transpose_output_gluon_aot, + transpose_query_gluon_aot, ) from csrc.cpp_itfs.torch_utils import torch_to_c_types from csrc.cpp_itfs.utils import ( - BUILD_DIR, AITER_CORE_DIR, - get_default_func_name, + BUILD_DIR, compile_template_op, + get_default_func_name, + logger, mp_lock, not_built, run_lib, - logger, -) -from csrc.cpp_itfs.pa_gluon_aot.transpose_query_output_gluon_aot import ( - transpose_query_gluon_aot, - transpose_output_gluon_aot, ) -from aiter.ops.triton.gluon.pa_decode_gluon import get_cdna_version + +GLUON_AOT_COMPILE_ENABLED = True +try: + from triton.experimental import gluon # noqa: F401 + from triton.experimental.gluon import language as gl # noqa: F401 +except ImportError: + print( + "Warning: triton.experimental.gluon or triton.experimental.gluon.language not exists, pa_decode_gluon_aot cannot use compile mode!" + ) + GLUON_AOT_COMPILE_ENABLED = False MD_NAME = "pa_decode_attention_reduce_kernel" @@ -150,8 +150,8 @@ def compile( "This version triton is not support gluon aot compile, please upgrade to 3.5.0 or higher!" 
) - kv_compute_block_size = 256 waves_per_eu = 1 + kv_compute_block_size = context_partition_size # Select kernel implementation based on block size if kv_block_size > context_partition_size: # Use big block kernel for large block sizes diff --git a/csrc/cpp_itfs/sampling/top_k_renorm_probs.cpp.jinja b/csrc/cpp_itfs/sampling/top_k_renorm_probs.cpp.jinja index f7d0261f9c..1c176062db 100644 --- a/csrc/cpp_itfs/sampling/top_k_renorm_probs.cpp.jinja +++ b/csrc/cpp_itfs/sampling/top_k_renorm_probs.cpp.jinja @@ -24,6 +24,7 @@ void* top_k_arr_ptr, \ int batch_size, \ int top_k_val, \ + int vocab_size, \ void* stream) extern "C" { @@ -32,12 +33,10 @@ FUNCTION_DEFINE; FUNCTION_DEFINE { - constexpr uint32_t vec_size = std::gcd(16 / sizeof(float), {{d}}); - const uint32_t smem_size = sizeof(aiter::sampling::RenormTempStorage); dim3 nblks(batch_size); dim3 nthrs(aiter::sampling::BLOCK_THREADS); - auto kernel = aiter::sampling::TopKRenormProbKernel; + auto kernel = aiter::sampling::TopKRenormProbKernel; hipFuncSetAttribute(reinterpret_cast(kernel), hipFuncAttributeMaxDynamicSharedMemorySize, smem_size); - kernel<<(stream)>>>(reinterpret_cast(probs_ptr), reinterpret_cast(renormed_probs_ptr), reinterpret_cast(top_k_arr_ptr), top_k_val, {{d}}); + kernel<<(stream)>>>(reinterpret_cast(probs_ptr), reinterpret_cast(renormed_probs_ptr), reinterpret_cast(top_k_arr_ptr), top_k_val, vocab_size); } \ No newline at end of file diff --git a/csrc/cpp_itfs/sampling/top_k_renorm_probs.py b/csrc/cpp_itfs/sampling/top_k_renorm_probs.py index cfc816798f..524285c006 100644 --- a/csrc/cpp_itfs/sampling/top_k_renorm_probs.py +++ b/csrc/cpp_itfs/sampling/top_k_renorm_probs.py @@ -4,7 +4,7 @@ from jinja2 import Template from csrc.cpp_itfs.utils import compile_template_op, AITER_CORE_DIR - +import math MD_NAME = "top_k_renorm_probs" @@ -16,7 +16,7 @@ def compile( - d: int, + vec_size: int, folder: str = None, ): return compile_template_op( @@ -27,7 +27,7 @@ def compile( f"{AITER_CORE_DIR}/csrc/cpp_itfs/sampling/sampling.cuh", f"{AITER_CORE_DIR}/csrc/cpp_itfs/sampling/vec_dtypes.cuh", ], - d=d, + vec_size=vec_size, folder=folder, ) @@ -46,16 +46,17 @@ def top_k_renorm_probs( batch_size = probs.size(0) vocab_size = probs.size(1) - + vec_size = math.gcd(16 // probs.element_size(), vocab_size) renorm_probs = torch.empty_like(probs) - func = compile(vocab_size) + func = compile(vec_size) ( probs_ptr, renorm_probs_ptr, top_k_arr_ptr, top_k_val, batch_size, + vocab_size, stream, ) = torch_to_c_types( probs, @@ -63,6 +64,7 @@ def top_k_renorm_probs( maybe_top_k_arr, top_k_val, batch_size, + vocab_size, torch.cuda.current_stream(), ) func( @@ -71,6 +73,7 @@ def top_k_renorm_probs( top_k_arr_ptr, batch_size, top_k_val, + vocab_size, stream, ) return renorm_probs diff --git a/csrc/cpp_itfs/sampling/top_k_top_p_sampling_from_probs.cpp.jinja b/csrc/cpp_itfs/sampling/top_k_top_p_sampling_from_probs.cpp.jinja index 301b5c9790..6408c41d96 100644 --- a/csrc/cpp_itfs/sampling/top_k_top_p_sampling_from_probs.cpp.jinja +++ b/csrc/cpp_itfs/sampling/top_k_top_p_sampling_from_probs.cpp.jinja @@ -29,6 +29,7 @@ float top_p_val, \ int philox_seed, \ int philox_offset, \ + int vocab_size, \ void* stream) extern "C" { @@ -37,13 +38,11 @@ FUNCTION_DEFINE; FUNCTION_DEFINE { - constexpr uint32_t vec_size = std::gcd(16 / sizeof(float), {{d}}); - const uint32_t smem_size = sizeof(aiter::sampling::SamplingTempStorage); dim3 nblks(batch_size); dim3 nthrs(aiter::sampling::BLOCK_THREADS); auto kernel = aiter::sampling::TopKTopPSamplingFromProbKernel; + 
{{vec_size}}, {{"true" if deterministic else "false"}}, float, int>; hipFuncSetAttribute(reinterpret_cast(kernel), hipFuncAttributeMaxDynamicSharedMemorySize, smem_size); - kernel<<(stream)>>>(reinterpret_cast(probs_ptr), reinterpret_cast(top_k_arr_ptr), reinterpret_cast(top_p_arr_ptr), reinterpret_cast(output_ptr), reinterpret_cast(indices_ptr), top_k_val, top_p_val, {{d}}, static_cast(philox_seed), static_cast(philox_offset)); + kernel<<(stream)>>>(reinterpret_cast(probs_ptr), reinterpret_cast(top_k_arr_ptr), reinterpret_cast(top_p_arr_ptr), reinterpret_cast(output_ptr), reinterpret_cast(indices_ptr), top_k_val, top_p_val, vocab_size, static_cast(philox_seed), static_cast(philox_offset)); } \ No newline at end of file diff --git a/csrc/cpp_itfs/sampling/top_k_top_p_sampling_from_probs.py b/csrc/cpp_itfs/sampling/top_k_top_p_sampling_from_probs.py index 48fbe6e6f3..0ac5520dc8 100644 --- a/csrc/cpp_itfs/sampling/top_k_top_p_sampling_from_probs.py +++ b/csrc/cpp_itfs/sampling/top_k_top_p_sampling_from_probs.py @@ -4,6 +4,7 @@ from jinja2 import Template from csrc.cpp_itfs.utils import compile_template_op, AITER_CORE_DIR, str_to_bool +import math MD_NAME = "top_k_top_p_sampling_from_probs" @@ -16,7 +17,7 @@ def compile( - d: int, + vec_size: int, deterministic: bool, folder: str = None, ): @@ -28,7 +29,7 @@ def compile( f"{AITER_CORE_DIR}/csrc/cpp_itfs/sampling/sampling.cuh", f"{AITER_CORE_DIR}/csrc/cpp_itfs/sampling/vec_dtypes.cuh", ], - d=d, + vec_size=vec_size, deterministic=deterministic, folder=folder, ) @@ -61,8 +62,8 @@ def top_k_top_p_sampling_from_probs( philox_seed = generator.seed() output = torch.empty(batch_size, dtype=torch.int32, device=probs.device) - - func = compile(vocab_size, deterministic) + vec_size = math.gcd(16 // probs.element_size(), vocab_size) + func = compile(vec_size, deterministic) ( probs_ptr, output_ptr, @@ -74,6 +75,7 @@ def top_k_top_p_sampling_from_probs( batch_size, philox_seed, philox_offset, + vocab_size, stream, ) = torch_to_c_types( probs, @@ -86,6 +88,7 @@ def top_k_top_p_sampling_from_probs( batch_size, philox_seed, philox_offset, + vocab_size, torch.cuda.current_stream(), ) func( @@ -99,6 +102,7 @@ def top_k_top_p_sampling_from_probs( top_p_val, philox_seed, philox_offset, + vocab_size, stream, ) return output diff --git a/csrc/cpp_itfs/sampling/top_p_sampling_from_probs.cpp.jinja b/csrc/cpp_itfs/sampling/top_p_sampling_from_probs.cpp.jinja index 99c23b44e7..020b494ffe 100644 --- a/csrc/cpp_itfs/sampling/top_p_sampling_from_probs.cpp.jinja +++ b/csrc/cpp_itfs/sampling/top_p_sampling_from_probs.cpp.jinja @@ -26,6 +26,7 @@ float top_p_val, \ int philox_seed, \ int philox_offset, \ + int vocab_size, \ void* stream) extern "C" { @@ -34,13 +35,12 @@ FUNCTION_DEFINE; FUNCTION_DEFINE { - constexpr uint32_t vec_size = std::gcd(16 / sizeof(float), {{d}}); const uint32_t smem_size = sizeof(aiter::sampling::SamplingTempStorage); dim3 nblks(batch_size); dim3 nthrs(aiter::sampling::BLOCK_THREADS); auto kernel = aiter::sampling::TopPSamplingFromProbKernel; + {{vec_size}}, {{"true" if deterministic else "false"}}, float, int>; hipFuncSetAttribute(reinterpret_cast(kernel), hipFuncAttributeMaxDynamicSharedMemorySize, smem_size); - kernel<<(stream)>>>(reinterpret_cast(probs_ptr), reinterpret_cast(output_ptr), reinterpret_cast(indices_ptr), reinterpret_cast(top_p_arr_ptr), top_p_val, {{d}}, static_cast(philox_seed), static_cast(philox_offset)); + kernel<<(stream)>>>(reinterpret_cast(probs_ptr), reinterpret_cast(output_ptr), reinterpret_cast(indices_ptr), 
reinterpret_cast(top_p_arr_ptr), top_p_val, vocab_size, static_cast(philox_seed), static_cast(philox_offset)); } \ No newline at end of file diff --git a/csrc/cpp_itfs/sampling/top_p_sampling_from_probs.py b/csrc/cpp_itfs/sampling/top_p_sampling_from_probs.py index 7e1500b231..3c9f5b9af6 100644 --- a/csrc/cpp_itfs/sampling/top_p_sampling_from_probs.py +++ b/csrc/cpp_itfs/sampling/top_p_sampling_from_probs.py @@ -4,6 +4,7 @@ from jinja2 import Template from csrc.cpp_itfs.utils import compile_template_op, AITER_CORE_DIR, str_to_bool +import math MD_NAME = "top_p_sampling_from_probs" @@ -16,7 +17,7 @@ def compile( - d: int, + vec_size: int, deterministic: bool, folder: str = None, ): @@ -28,7 +29,7 @@ def compile( f"{AITER_CORE_DIR}/csrc/cpp_itfs/sampling/sampling.cuh", f"{AITER_CORE_DIR}/csrc/cpp_itfs/sampling/vec_dtypes.cuh", ], - d=d, + vec_size=vec_size, deterministic=deterministic, folder=folder, ) @@ -56,9 +57,9 @@ def top_p_sampling_from_probs( batch_size = probs.size(0) vocab_size = probs.size(1) - + vec_size = math.gcd(16 // probs.element_size(), vocab_size) samples = torch.empty(batch_size, dtype=torch.int32, device=probs.device) - func = compile(vocab_size, deterministic) + func = compile(vec_size, deterministic) ( probs_ptr, samples_ptr, @@ -68,6 +69,7 @@ def top_p_sampling_from_probs( batch_size, philox_seed, philox_offset, + vocab_size, stream, ) = torch_to_c_types( probs, @@ -78,6 +80,7 @@ def top_p_sampling_from_probs( batch_size, philox_seed, philox_offset, + vocab_size, torch.cuda.current_stream(), ) func( @@ -89,6 +92,7 @@ def top_p_sampling_from_probs( top_p_val, philox_seed, philox_offset, + vocab_size, stream, ) return samples diff --git a/csrc/include/asm_gemm_a16w16.h b/csrc/include/asm_gemm_a16w16.h index c7788bb3ec..26a207882c 100644 --- a/csrc/include/asm_gemm_a16w16.h +++ b/csrc/include/asm_gemm_a16w16.h @@ -6,6 +6,7 @@ torch::Tensor gemm_a16w16_asm(torch::Tensor& A, // A:[M, K] bf16 torch::Tensor& B, // B:[N, K] bf16 torch::Tensor& out, // Out:[M, N] f32 + torch::Tensor& semaphore, std::optional bias, std::optional splitK, std::optional kernelName, diff --git a/csrc/include/mha_bwd.h b/csrc/include/mha_bwd.h index 99fc32e87e..e552c01be5 100644 --- a/csrc/include/mha_bwd.h +++ b/csrc/include/mha_bwd.h @@ -10,7 +10,8 @@ namespace aiter { -struct mha_bwd_args { +struct mha_bwd_args +{ // aiter args int mask_type; // 0: no mask 1: top_left_causal 2: bottom_right_causal 3: sliding_window bool use_asm_v3; @@ -146,93 +147,100 @@ struct mha_bwd_args { struct __attribute__((packed)) fmha_bwd_dqdkdv_args { - void* ptr_dq; // 0x00: dq or dq_acc + void* ptr_dq; // 0x00: dq or dq_acc p2 _p0; - void* ptr_dk; // 0x10 + void* ptr_dk; // 0x10 p2 _p1; - void* ptr_dv; // 0x20 + void* ptr_dv; // 0x20 p2 _p2; - const void* ptr_q; // 0x30 + const void* ptr_q; // 0x30 p2 _p3; - const void* ptr_k; // 0x40 + const void* ptr_k; // 0x40 p2 _p4; - const void* ptr_v; // 0x50 + const void* ptr_v; // 0x50 p2 _p5; - const void* ptr_do; // 0x60 + const void* ptr_do; // 0x60 p2 _p6; - const void* ptr_lse; // 0x70 + const void* ptr_lse; // 0x70 p2 _p7; - const void* ptr_d; // 0x80 + const void* ptr_d; // 0x80 p2 _p8; - float scalar; // 0x90 + float scalar; // 0x90 p3 _p9; - float log2e; // 0xa0 + float log2e; // 0xa0 p3 _p10; - unsigned int seqlen_q; // 0xb0: s_seq_len_q + unsigned int seqlen_q; // 0xb0: s_seq_len_q p3 _p11; - unsigned int Ts; // 0xc0: s_Seqs_k*sub_K + unsigned int Ts; // 0xc0: s_Seqs_k*sub_K p3 _p12; - unsigned int Hs_q; // 0xd0: s_Hs_q + unsigned int Hs_q; // 0xd0: s_Hs_q p3 
_p13; - unsigned int BAs_q; // 0xe0: s_BAs_q + unsigned int BAs_q; // 0xe0: s_BAs_q p3 _p14; - unsigned int Seqs_q; // 0xf0: s_Seqs_q + unsigned int Seqs_q; // 0xf0: s_Seqs_q p3 _p15; - unsigned int ratio; // 0x100 + unsigned int ratio; // 0x100 p3 _p16; - unsigned int Hs_k; // 0x110: s_Hs_k + unsigned int Hs_k; // 0x110: s_Hs_k p3 _p17; - unsigned int BAs_k; // 0x120: s_BAs_k + unsigned int BAs_k; // 0x120: s_BAs_k p3 _p18; - unsigned int Seqs_k; // 0x130: s_Seqs_k + unsigned int Seqs_k; // 0x130: s_Seqs_k p3 _p19; - unsigned int Seqs_dk; // 0x140: s_Seqs_dk + unsigned int Seqs_dk; // 0x140: s_Seqs_dk p3 _p20; - unsigned int seqlen_k; // 0x150: batch mode + unsigned int seqlen_k; // 0x150: batch mode p3 _p21; - unsigned int head_dim_q; // 0x160: batch&group mode for headdim padding + unsigned int head_dim_q; // 0x160: batch&group mode for headdim padding p3 _p22; - unsigned int head_dim_v; // 0x170: batch&group mode for headdim padding + unsigned int head_dim_v; // 0x170: batch&group mode for headdim padding p3 _p23; - unsigned int nhead_q; // 0x180: batch mode lsed([B,H,S]) addr = batch_idx * nhead_q * seqlen_q * 4 + head_idx * seqlen_q * 4 + unsigned int nhead_q; // 0x180: batch mode lsed([B,H,S]) addr = batch_idx * nhead_q * seqlen_q * + // 4 + head_idx * seqlen_q * 4 p3 _p24; - unsigned int Hs_v; // 0x190: batch&group mode + unsigned int Hs_v; // 0x190: batch&group mode p3 _p25; - unsigned int BAs_v; // 0x1a0: batch mode + unsigned int BAs_v; // 0x1a0: batch mode p3 _p26; - unsigned int Seqs_v; // 0x1b0: batch&group mode + unsigned int Seqs_v; // 0x1b0: batch&group mode p3 _p27; - unsigned int Hs_do; // 0x1c0: batch&group mode + unsigned int Hs_do; // 0x1c0: batch&group mode p3 _p28; - unsigned int BAs_do; // 0x1d0: group mode + unsigned int BAs_do; // 0x1d0: group mode p3 _p29; - unsigned int Seqs_do; // 0x1e0: batch&group mode + unsigned int Seqs_do; // 0x1e0: batch&group mode p3 _p30; - unsigned int Hs_dk; // 0x1f0: batch&group mode + unsigned int Hs_dk; // 0x1f0: batch&group mode p3 _p31; - unsigned int BAs_dk; // 0x200: group mode + unsigned int BAs_dk; // 0x200: group mode p3 _p32; - unsigned int Hs_dv; // 0x210: batch&group mode + unsigned int Hs_dv; // 0x210: batch&group mode p3 _p33; - unsigned int BAs_dv; // 0x220: group mode + unsigned int BAs_dv; // 0x220: group mode p3 _p34; - unsigned int Seqs_dv; // 0x230: batch&group mode + unsigned int Seqs_dv; // 0x230: batch&group mode p3 _p35; - unsigned int Hs_lsed; // 0x240: group mode lsed([H,TotalValid_Q(90)]) addr = seqstart_q[batch_idx] * 4 + head_idx * nhead_stride_lsed(s_Hs_lsed) + unsigned int Hs_lsed; // 0x240: group mode lsed([H,TotalValid_Q(90)]) addr = + // seqstart_q[batch_idx] * 4 + head_idx * nhead_stride_lsed(s_Hs_lsed) p3 _p36; - const void* ptr_qseq; // 0x250: group mode seqstart_q [0, 20, 50, 90] + const void* ptr_qseq; // 0x250: group mode seqstart_q [0, 20, 50, 90] p2 _p37; - const void* ptr_kseq; // 0x260: group mode seqstart_k [0, 50, 110, 180] + const void* ptr_kseq; // 0x260: group mode seqstart_k [0, 50, 110, 180] p2 _p38; - const void* ptr_qseq_padded; // 0x270: group mode seqstart_q_padded [0, 30(20+10), 70(20+10+30+10), 120(20+10+30+10+40+10)] if 10 is padded after each seqlen[30(20+10), 40(30+10), 50(40+10)] + const void* ptr_qseq_padded; // 0x270: group mode seqstart_q_padded [0, 30(20+10), + // 70(20+10+30+10), 120(20+10+30+10+40+10)] if 10 is padded after + // each seqlen[30(20+10), 40(30+10), 50(40+10)] p2 _p39; - const void* ptr_kseq_padded; // 0x280: group mode seqstart_k_padded [0, 60(50+10), 
130(50+10+60+10), 200(50+10+60+10+70+10)] if 10 is padded after each seqlen[60(50+10), 70(60+10), 80(70+10)] + const void* ptr_kseq_padded; // 0x280: group mode seqstart_k_padded [0, 60(50+10), + // 130(50+10+60+10), 200(50+10+60+10+70+10)] if 10 is padded after + // each seqlen[60(50+10), 70(60+10), 80(70+10)] p2 _p40; - unsigned int max_seqlen_dq; // 0x290: gorup mode max seqlen q for a16 dq_acc store, padding to 16x + unsigned int + max_seqlen_dq; // 0x290: gorup mode max seqlen q for a16 dq_acc store, padding to 16x p3 _p41; - int mask_x; // 0x2a0 + int mask_x; // 0x2a0 p3 _p42; - int mask_y; // 0x2b0 + int mask_y; // 0x2b0 p3 _p43; }; diff --git a/csrc/include/moe_ck.h b/csrc/include/moe_ck.h index a6d461415b..f9b06dfe9f 100644 --- a/csrc/include/moe_ck.h +++ b/csrc/include/moe_ck.h @@ -18,7 +18,9 @@ void ck_moe_stage1(torch::Tensor& hidden_states, // [m, k], input token std::optional block_m, std::optional sorted_weights, int quant_type, - int activation); + int activation, + int splitk, + std::optional dst_type); void ck_moe_stage2(torch::Tensor& inter_states, // [m, k], input token torch::Tensor& w1, // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w]) @@ -34,4 +36,6 @@ void ck_moe_stage2(torch::Tensor& inter_states, // [m, k], input token std::optional block_m, std::optional sorted_weights, // [max_num_tokens_padded]); int quant_type, - int activation); + int activation, + int splitk, + std::optional dst_type); diff --git a/csrc/include/opus/opus.hpp b/csrc/include/opus/opus.hpp index bc3631e2a2..f2b96e4483 100644 --- a/csrc/include/opus/opus.hpp +++ b/csrc/include/opus/opus.hpp @@ -907,7 +907,7 @@ template<> OPUS_D float min(const float&a, const float&b) { return template OPUS_D T med3(const T&a, const T&b, const T&c) { auto max_0 = max(a, b); auto min_0 = max(a, b); return max(max_0, max(min_0, c)); } template<> OPUS_D float med3(const float&a, const float&b, const float&c) { return __builtin_amdgcn_fmed3f(a, b, c); } -template<> OPUS_D __fp16 med3<__fp16>(const __fp16&a, const __fp16&b, const __fp16&c) { return __builtin_amdgcn_fmed3h(a, b, c); } +template<> OPUS_D _Float16 med3<_Float16>(const _Float16&a, const _Float16&b, const _Float16&c) { return __builtin_amdgcn_fmed3h(a, b, c); } ///////////////////////////////////////////////////////////////////////////////////////////////////////// // buffer load/store related OPUS_D constexpr auto buffer_default_config() { diff --git a/csrc/include/rocm_ops.hpp b/csrc/include/rocm_ops.hpp index e7d636b832..9f526c81c1 100644 --- a/csrc/include/rocm_ops.hpp +++ b/csrc/include/rocm_ops.hpp @@ -3,6 +3,7 @@ #pragma once #include + namespace py = pybind11; #define ACTIVATION_PYBIND \ @@ -455,6 +456,7 @@ namespace py = pybind11; py::arg("A"), \ py::arg("B"), \ py::arg("out"), \ + py::arg("semaphore"), \ py::arg("bias") = std::nullopt, \ py::arg("splitK") = std::nullopt, \ py::arg("kernelName") = std::nullopt, \ @@ -658,6 +660,64 @@ namespace py = pybind11; py::arg("rng_state") = std::nullopt, \ py::arg("gen") = std::nullopt); +#define ROCSOLGEMM_PYBIND \ + m.def("rocb_create_extension", &rocb_create_extension, "create_extension"); \ + m.def("rocb_destroy_extension", &rocb_destroy_extension, "destroy_extension"); \ + m.def("rocb_mm", &RocSolIdxBlas, "mm"); \ + m.def("rocb_findallsols", &RocFindAllSolIdxBlas, "rocblas_find_all_sols"); + +#define HIPBSOLGEMM_PYBIND \ + m.def("hipb_create_extension", &hipb_create_extension, "create_extension"); \ + m.def("hipb_destroy_extension", &hipb_destroy_extension, "destroy_extension"); \ + 
m.def("hipb_mm", \ + &hipb_mm, \ + "hipb_mm", \ + py::arg("mat1"), \ + py::arg("mat2"), \ + py::arg("solution_index"), \ + py::arg("bias") = std::nullopt, \ + py::arg("out_dtype") = std::nullopt, \ + py::arg("scaleA") = std::nullopt, \ + py::arg("scaleB") = std::nullopt, \ + py::arg("scaleOut") = std::nullopt, \ + py::arg("bpreshuffle") = std::nullopt); \ + m.def("hipb_findallsols", \ + &hipb_findallsols, \ + "hipb_findallsols", \ + py::arg("mat1"), \ + py::arg("mat2"), \ + py::arg("bias") = std::nullopt, \ + py::arg("out_dtype") = std::nullopt, \ + py::arg("scaleA") = std::nullopt, \ + py::arg("scaleB") = std::nullopt, \ + py::arg("scaleC") = std::nullopt, \ + py::arg("bpreshuffle") = false); \ + m.def("getHipblasltKernelName", &getHipblasltKernelName); + +#define LIBMHA_BWD_PYBIND \ + m.def("libmha_bwd", \ + &aiter::torch_itfs::mha_bwd, \ + py::arg("dout"), \ + py::arg("q"), \ + py::arg("k"), \ + py::arg("v"), \ + py::arg("out"), \ + py::arg("softmax_lse"), \ + py::arg("dropout_p"), \ + py::arg("softmax_scale"), \ + py::arg("is_causal"), \ + py::arg("window_size_left"), \ + py::arg("window_size_right"), \ + py::arg("deterministic"), \ + py::arg("dq") = std::nullopt, \ + py::arg("dk") = std::nullopt, \ + py::arg("dv") = std::nullopt, \ + py::arg("dbias") = std::nullopt, \ + py::arg("bias") = std::nullopt, \ + py::arg("alibi_slopes") = std::nullopt, \ + py::arg("rng_state") = std::nullopt, \ + py::arg("gen") = std::nullopt); + #define MHA_VARLEN_BWD_ASM_PYBIND \ m.def("fmha_v3_varlen_bwd", \ &aiter::torch_itfs::fmha_v3_varlen_bwd, \ @@ -756,32 +816,56 @@ namespace py = pybind11; py::arg("v_descale") = std::nullopt, \ py::arg("gen") = std::nullopt); -#define MHA_VARLEN_FWD_ASM_PYBIND \ - m.def("fmha_v3_varlen_fwd", \ - &aiter::torch_itfs::fmha_v3_varlen_fwd, \ - py::arg("q"), \ - py::arg("k"), \ - py::arg("v"), \ - py::arg("cu_seqlens_q"), \ - py::arg("cu_seqlens_k"), \ - py::arg("max_seqlen_q"), \ - py::arg("max_seqlen_k"), \ - py::arg("min_seqlen_q"), \ - py::arg("dropout_p"), \ - py::arg("softmax_scale"), \ - py::arg("logits_soft_cap"), \ - py::arg("zero_tensors"), \ - py::arg("is_causal"), \ - py::arg("window_size_left"), \ - py::arg("window_size_right"), \ - py::arg("return_softmax_lse"), \ - py::arg("return_dropout_randval"), \ - py::arg("how_v3_bf16_cvt"), \ - py::arg("out") = std::nullopt, \ - py::arg("block_table") = std::nullopt, \ - py::arg("bias") = std::nullopt, \ - py::arg("alibi_slopes") = std::nullopt, \ - py::arg("gen") = std::nullopt, \ +#define LIBMHA_FWD_PYBIND \ + m.def("libmha_fwd", \ + &aiter::torch_itfs::mha_fwd, \ + py::arg("q"), \ + py::arg("k"), \ + py::arg("v"), \ + py::arg("dropout_p"), \ + py::arg("softmax_scale"), \ + py::arg("is_causal"), \ + py::arg("window_size_left"), \ + py::arg("window_size_right"), \ + py::arg("sink_size"), \ + py::arg("return_softmax_lse"), \ + py::arg("return_dropout_randval"), \ + py::arg("cu_seqlens_q") = std::nullopt, \ + py::arg("cu_seqlens_kv") = std::nullopt, \ + py::arg("out") = std::nullopt, \ + py::arg("bias") = std::nullopt, \ + py::arg("alibi_slopes") = std::nullopt, \ + py::arg("q_descale") = std::nullopt, \ + py::arg("k_descale") = std::nullopt, \ + py::arg("v_descale") = std::nullopt, \ + py::arg("gen") = std::nullopt); + +#define MHA_VARLEN_FWD_ASM_PYBIND \ + m.def("fmha_v3_varlen_fwd", \ + &aiter::torch_itfs::fmha_v3_varlen_fwd, \ + py::arg("q"), \ + py::arg("k"), \ + py::arg("v"), \ + py::arg("cu_seqlens_q"), \ + py::arg("cu_seqlens_k"), \ + py::arg("max_seqlen_q"), \ + py::arg("max_seqlen_k"), \ + 
py::arg("min_seqlen_q"), \ + py::arg("dropout_p"), \ + py::arg("softmax_scale"), \ + py::arg("logits_soft_cap"), \ + py::arg("zero_tensors"), \ + py::arg("is_causal"), \ + py::arg("window_size_left"), \ + py::arg("window_size_right"), \ + py::arg("return_softmax_lse"), \ + py::arg("return_dropout_randval"), \ + py::arg("how_v3_bf16_cvt"), \ + py::arg("out") = std::nullopt, \ + py::arg("block_table") = std::nullopt, \ + py::arg("bias") = std::nullopt, \ + py::arg("alibi_slopes") = std::nullopt, \ + py::arg("gen") = std::nullopt, \ py::arg("cu_seqlens_q_padded") = std::nullopt, \ py::arg("cu_seqlens_k_padded") = std::nullopt); @@ -831,7 +915,9 @@ namespace py = pybind11; py::arg("block_m") = 32, \ py::arg("sorted_weights") = std::nullopt, \ py::arg("quant_type") = 0, \ - py::arg("activation") = 0); \ + py::arg("activation") = 0, \ + py::arg("splitk") = 1, \ + py::arg("dst_type") = std::nullopt); \ \ m.def("ck_moe_stage2", \ &ck_moe_stage2, \ @@ -849,7 +935,9 @@ namespace py = pybind11; py::arg("block_m") = 32, \ py::arg("sorted_weights") = std::nullopt, \ py::arg("quant_type") = 0, \ - py::arg("activation") = 0); + py::arg("activation") = 0, \ + py::arg("splitk") = 1, \ + py::arg("dst_type") = std::nullopt); #define MOE_CKTILE_2STAGES_PYBIND \ m.def("cktile_moe_gemm1", \ @@ -1458,25 +1546,34 @@ namespace py = pybind11; #define GEMM_COMMON_PYBIND \ m.def("get_padded_m", &getPaddedM, py::arg("M"), py::arg("N"), py::arg("K"), py::arg("gl")); -#define TOP_K_PER_ROW_PYBIND \ - m.def("top_k_per_row_prefill", \ - &top_k_per_row_prefill, \ - py::arg("logits"), \ - py::arg("rowStarts"), \ - py::arg("rowEnds"), \ - py::arg("indices"), \ - py::arg("values"), \ - py::arg("numRows"), \ - py::arg("stride0"), \ - py::arg("stride1")); \ - m.def("top_k_per_row_decode", \ - &top_k_per_row_decode, \ - py::arg("logits"), \ - py::arg("next_n"), \ - py::arg("seqLens"), \ - py::arg("indices"), \ - py::arg("numRows"), \ - py::arg("stride0"), \ +#define TOP_K_PER_ROW_PYBIND \ + m.def("top_k_per_row_prefill", \ + &top_k_per_row_prefill, \ + py::arg("logits"), \ + py::arg("rowStarts"), \ + py::arg("rowEnds"), \ + py::arg("indices"), \ + py::arg("values"), \ + py::arg("numRows"), \ + py::arg("stride0"), \ + py::arg("stride1")); \ + m.def("top_k_per_row_decode", \ + &top_k_per_row_decode, \ + py::arg("logits"), \ + py::arg("next_n"), \ + py::arg("seqLens"), \ + py::arg("indices"), \ + py::arg("numRows"), \ + py::arg("stride0"), \ + py::arg("stride1")); \ + m.def("top_k_per_row_decode_fast", \ + &top_k_per_row_decode_fast, \ + py::arg("logits"), \ + py::arg("next_n"), \ + py::arg("seqLens"), \ + py::arg("indices"), \ + py::arg("numRows"), \ + py::arg("stride0"), \ py::arg("stride1")); #define MLA_METADATA_PYBIND \ @@ -1542,10 +1639,15 @@ namespace py = pybind11; py::arg("final_output"), \ py::arg("final_lse") = std::nullopt); -#define TOPK_PLAIN_PYBIND \ - m.def("topk_plain", \ - &topk_plain, \ - py::arg("values"), \ - py::arg("topk_ids"), \ - py::arg("topk"), \ - py::arg("largest")); +#define TOPK_PLAIN_PYBIND \ + m.def("topk_plain", \ + &topk_plain, \ + py::arg("values"), \ + py::arg("topk_ids"), \ + py::arg("topk_out"), \ + py::arg("topk"), \ + py::arg("largest") = true, \ + py::arg("rowStarts") = torch::Tensor(), \ + py::arg("rowEnds") = torch::Tensor(), \ + py::arg("stride0") = -1, \ + py::arg("stride1") = 1); diff --git a/csrc/include/topk_per_row.h b/csrc/include/topk_per_row.h index e3bae1887d..dcfcfa565e 100644 --- a/csrc/include/topk_per_row.h +++ b/csrc/include/topk_per_row.h @@ -18,3 +18,11 @@ void 
top_k_per_row_decode(const torch::Tensor& logits, int64_t numRows, int64_t stride0, int64_t stride1); + +void top_k_per_row_decode_fast(const torch::Tensor& logits, + int64_t next_n, + const torch::Tensor& seqLens, + torch::Tensor& indices, + int64_t numRows, + int64_t stride0, + int64_t stride1); diff --git a/csrc/include/topk_plain.h b/csrc/include/topk_plain.h index 5a658e491d..087c157196 100644 --- a/csrc/include/topk_plain.h +++ b/csrc/include/topk_plain.h @@ -6,5 +6,10 @@ void topk_plain(torch::Tensor& values, torch::Tensor& topk_ids, - int topk_num, - bool largest); + torch::Tensor& topk_out, + int topk, + bool largest = true, + torch::Tensor rowStarts = torch::Tensor(), + torch::Tensor rowEnds = torch::Tensor(), + int64_t stride0 = -1, + int64_t stride1 = 1); diff --git a/csrc/kernels/activation_kernels.cu b/csrc/kernels/activation_kernels.cu index 3a685ae1e9..f5dc3edcfa 100644 --- a/csrc/kernels/activation_kernels.cu +++ b/csrc/kernels/activation_kernels.cu @@ -21,18 +21,27 @@ using fp8_type = ck_tile::fp8_t; static constexpr int32_t max_vec_size = 8; static constexpr int32_t max_wave_num = 8; +// Type trait for computation type (all compute in native type) + namespace aiter { -// Activation and gating kernel template. -template -__global__ void act_and_mul_kernel(DTYPE_I* __restrict__ out, // [..., d] +// Activation and gating kernel template with flexible input/output types. +// DTYPE_I: input type (fp32/bf16/fp16), DTYPE_O: output type (fp32/bf16/fp16) +// Computes in float, converts to DTYPE_O on output. +template +__global__ void act_and_mul_kernel(DTYPE_O* __restrict__ out, // [..., d] const DTYPE_I* __restrict__ input, // [..., 2, d] const int d) { + // CK Tile buffer addressing constraint: float supports VEC_SIZE <= 16 + static_assert(!(std::is_same_v && VEC_SIZE_I > 16), + "float type only supports VEC_SIZE up to 16"); + const int64_t token_idx = blockIdx.x; auto const* ptr_x = (input + token_idx * 2 * d); auto const* ptr_y = (input + token_idx * 2 * d + d); using vec_i = ck_tile::vec_t; + using vec_o = ck_tile::vec_t; static constexpr int32_t ooba_i = 4 / sizeof(DTYPE_I); const int32_t oob_i = (d + ooba_i - 1) / ooba_i * ooba_i; auto buffer_x = ck_tile::make_buffer_view(ptr_x, oob_i); @@ -40,15 +49,17 @@ __global__ void act_and_mul_kernel(DTYPE_I* __restrict__ out, // [..., d buffer_x.init_raw(); buffer_y.init_raw(); - // Output buffer view for wide stores (raw path) - DTYPE_I* __restrict__ out_base = out + token_idx * d; + // Output buffer view (independent type from input) + DTYPE_O* __restrict__ out_base = out + token_idx * d; + static constexpr int32_t ooba_o = 4 / sizeof(DTYPE_O); + const int32_t oob_o = (d + ooba_o - 1) / ooba_o * ooba_o; auto buffer_out = - ck_tile::make_buffer_view(out_base, oob_i); + ck_tile::make_buffer_view(out_base, oob_o); buffer_out.init_raw(); - constexpr int32_t allowed_max = std::is_same::value ? 8 : 16; + constexpr int32_t allowed_max = std::is_same::value ? 
8 : 16; - auto store_vec_segmented = [&](int64_t base_idx, const vec_i& v) __device__ { + auto store_vec_segmented = [&](int64_t base_idx, const vec_o& v) __device__ { int64_t off = base_idx; int32_t rem = VEC_SIZE_I; int32_t pos = 0; @@ -56,7 +67,7 @@ __global__ void act_and_mul_kernel(DTYPE_I* __restrict__ out, // [..., d { if(allowed_max >= 16 && rem >= 16) { - using vec16 = ck_tile::vec_t; + using vec16 = ck_tile::vec_t; vec16 t{}; #pragma unroll for(int i = 0; i < 16; ++i) @@ -68,7 +79,7 @@ __global__ void act_and_mul_kernel(DTYPE_I* __restrict__ out, // [..., d } else if(rem >= 8) { - using vec8 = ck_tile::vec_t; + using vec8 = ck_tile::vec_t; vec8 t{}; #pragma unroll for(int i = 0; i < 8; ++i) @@ -80,7 +91,7 @@ __global__ void act_and_mul_kernel(DTYPE_I* __restrict__ out, // [..., d } else if(rem >= 4) { - using vec4 = ck_tile::vec_t; + using vec4 = ck_tile::vec_t; vec4 t{}; #pragma unroll for(int i = 0; i < 4; ++i) @@ -92,7 +103,7 @@ __global__ void act_and_mul_kernel(DTYPE_I* __restrict__ out, // [..., d } else if(rem >= 2) { - using vec2 = ck_tile::vec_t; + using vec2 = ck_tile::vec_t; vec2 t{}; t[0] = v[pos + 0]; t[1] = v[pos + 1]; @@ -103,7 +114,7 @@ __global__ void act_and_mul_kernel(DTYPE_I* __restrict__ out, // [..., d } else { - using vec1 = ck_tile::vec_t; + using vec1 = ck_tile::vec_t; vec1 t{}; t[0] = v[pos]; buffer_out.template set(off, 0, true, t); @@ -116,40 +127,40 @@ __global__ void act_and_mul_kernel(DTYPE_I* __restrict__ out, // [..., d for(int64_t idx = threadIdx.x * VEC_SIZE_I; idx < d; idx += blockDim.x * VEC_SIZE_I) { - vec_i x{}; - vec_i y{}; + vec_i x = buffer_x.template get(idx, 0, true); + vec_i y = buffer_y.template get(idx, 0, true); - x = buffer_x.template get(idx, 0, true); - y = buffer_y.template get(idx, 0, true); - - vec_i r{}; + vec_o r{}; #pragma unroll for(size_t j = 0; j < VEC_SIZE_I; j += 2) { - float ax0 = ACT_FN(x[j]); - float y0 = ck_tile::type_convert(y[j]); + // Call ACT_FN with appropriate type conversion + DTYPE_I x_val0 = x[j]; + float ax0 = ACT_FN(x_val0); + float y0 = ck_tile::type_convert(y[j]); if(j + 1 < VEC_SIZE_I) { - float ax1 = ACT_FN(x[j + 1]); + DTYPE_I x_val1 = x[j + 1]; + float ax1 = ACT_FN(x_val1); float y1 = ck_tile::type_convert(y[j + 1]); ck_tile::fp32x2_t a = {ax0, ax1}; ck_tile::fp32x2_t b = {y0, y1}; ck_tile::fp32x2_t c; asm volatile("v_pk_mul_f32 %0, %1, %2" : "=v"(c) : "v"(a), "v"(b)); - r[j] = ck_tile::type_convert(c.x); - r[j + 1] = ck_tile::type_convert(c.y); + r[j] = ck_tile::type_convert(c.x); + r[j + 1] = ck_tile::type_convert(c.y); } else { - r[j] = ck_tile::type_convert(ax0 * y0); + r[j] = ck_tile::type_convert(ax0 * y0); } } if constexpr(VEC_SIZE_I == 1 || VEC_SIZE_I == 2 || VEC_SIZE_I == 4 || VEC_SIZE_I == 8 || VEC_SIZE_I == 16) { - buffer_out.template set(idx, 0, true, r); + buffer_out.template set(idx, 0, true, r); } else { @@ -158,13 +169,18 @@ __global__ void act_and_mul_kernel(DTYPE_I* __restrict__ out, // [..., d } } -// Scaled activation and gating kernel template. -template -__global__ void scaled_act_and_mul_kernel(fp8_type* __restrict__ out, // [..., d] +// Scaled activation and gating kernel template with flexible output type. 
+// DTYPE_I: input type, DTYPE_O: output type (typically fp8 for quantization) +template +__global__ void scaled_act_and_mul_kernel(DTYPE_O* __restrict__ out, // [..., d] const DTYPE_I* __restrict__ input, // [..., 2, d] const int d, const float scale) { + // CK Tile buffer addressing constraint: float supports VEC_SIZE <= 16 + static_assert(!(std::is_same_v && VEC_SIZE_I > 16), + "float type only supports VEC_SIZE up to 16"); + const int64_t token_idx = blockIdx.x; auto const* ptr_x = (input + token_idx * 2 * d); auto const* ptr_y = (input + token_idx * 2 * d + d); @@ -179,17 +195,19 @@ __global__ void scaled_act_and_mul_kernel(fp8_type* __restrict__ out, // for(int64_t idx = threadIdx.x * VEC_SIZE_I; idx < d; idx += blockDim.x * VEC_SIZE_I) { - auto x = buffer_x.template get(idx, 0, true); - auto y = buffer_y.template get(idx, 0, true); + vec_i x = buffer_x.template get(idx, 0, true); + vec_i y = buffer_y.template get(idx, 0, true); for(size_t j = 0; j < VEC_SIZE_I; j += 2) { if(j + 1 < VEC_SIZE_I) { - float act_x0 = ACT_FN(x[j]); - float act_x1 = ACT_FN(x[j + 1]); - float y0 = ck_tile::type_convert(y[j]); - float y1 = ck_tile::type_convert(y[j + 1]); + DTYPE_I x_val0 = x[j]; + DTYPE_I x_val1 = x[j + 1]; + float act_x0 = ACT_FN(x_val0); + float act_x1 = ACT_FN(x_val1); + float y0 = ck_tile::type_convert(y[j]); + float y1 = ck_tile::type_convert(y[j + 1]); float2 act_vals = {act_x0, act_x1}; float2 y_vals = {y0, y1}; @@ -201,13 +219,14 @@ __global__ void scaled_act_and_mul_kernel(fp8_type* __restrict__ out, // : "=v"(result) : "v"(act_vals), "v"(y_vals), "v"(scale_vals)); - out[token_idx * d + idx + j] = ck_tile::type_convert(result.x); - out[token_idx * d + idx + j + 1] = ck_tile::type_convert(result.y); + out[token_idx * d + idx + j] = ck_tile::type_convert(result.x); + out[token_idx * d + idx + j + 1] = ck_tile::type_convert(result.y); } else { - float r = ACT_FN(x[j]) * ck_tile::type_convert(y[j]) * scale; - out[token_idx * d + idx + j] = ck_tile::type_convert(r); + DTYPE_I x_val = x[j]; + float r = ACT_FN(x_val) * ck_tile::type_convert(y[j]) * scale; + out[token_idx * d + idx + j] = ck_tile::type_convert(r); } } } @@ -257,53 +276,142 @@ static constexpr int nextPow2(unsigned int num) return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1)); } -// Launch activation and gating kernel. -#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \ - int d = input.size(-1) / 2; \ - int64_t num_tokens = input.numel() / input.size(-1); \ - int vec_size = nextPow2(d / 64); \ - vec_size = vec_size < 2 ? 2 : vec_size; \ - vec_size = vec_size > max_vec_size ? max_vec_size : vec_size; \ - int num_wave = nextPow2(d / 64 / vec_size); \ - num_wave = num_wave > max_wave_num ? max_wave_num : num_wave; \ - dim3 grid(num_tokens); \ - dim3 block(num_wave * 64); \ - const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input)); \ - const hipStream_t stream = at::hip::getCurrentHIPStream(); \ - AITER_DISPATCH_FLOATING16_TYPES(input.scalar_type(), "act_and_mul_kernel", [&] { \ - using input_dtype = typename t2ck::type; \ - AITER_DISPATCH_CASE_VEC_SIZE( \ - vec_size, \ - aiter::act_and_mul_kernel, VEC_SIZE> \ - <<>>(reinterpret_cast(out.data_ptr()), \ - reinterpret_cast(input.data_ptr()), \ - d);) \ - }); -#define LAUNCH_SCALED_ACTIVATION_GATE_KERNEL(KERNEL) \ - int d = input.size(-1) / 2; \ - int64_t num_tokens = input.numel() / input.size(-1); \ - int vec_size = nextPow2(d / 64); \ - vec_size = vec_size < 2 ? 2 : vec_size; \ - vec_size = vec_size > max_vec_size ? 
max_vec_size : vec_size; \ - int num_wave = nextPow2(d / 64 / vec_size); \ - num_wave = num_wave > max_wave_num ? max_wave_num : num_wave; \ - dim3 grid(num_tokens); \ - dim3 block(num_wave * 64); \ - const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input)); \ - const hipStream_t stream = at::hip::getCurrentHIPStream(); \ - AITER_DISPATCH_FLOATING16_TYPES(input.scalar_type(), "scaled_act_and_mul_kernel", [&] { \ - using input_dtype = typename t2ck::type; \ - AITER_DISPATCH_CASE_VEC_SIZE( \ - vec_size, \ - aiter::scaled_act_and_mul_kernel, VEC_SIZE> \ - <<>>(reinterpret_cast(out.data_ptr()), \ - reinterpret_cast(input.data_ptr()), \ - d, \ - 1.0f / (*scale.data_ptr()));) \ - }); +// Common kernel launch parameters computation +#define COMPUTE_ACTIVATION_KERNEL_PARAMS \ + int d = input.size(-1) / 2; \ + int64_t num_tokens = input.numel() / input.size(-1); \ + int vec_size = nextPow2(d / 64); \ + vec_size = vec_size < 2 ? 2 : vec_size; \ + vec_size = vec_size > max_vec_size ? max_vec_size : vec_size; \ + int num_wave = nextPow2(d / 64 / vec_size); \ + num_wave = num_wave > max_wave_num ? max_wave_num : num_wave; \ + dim3 grid(num_tokens); \ + dim3 block(num_wave * 64); \ + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input)); \ + const hipStream_t stream = at::hip::getCurrentHIPStream(); + +// Helper macro for fp32 vec_size dispatch (CK Tile only supports VEC_SIZE <= 16 for fp32) +#define DISPATCH_FP32_VEC_SIZE_CASE(VS, KERNEL_NAME, KERNEL, ...) \ + case VS: \ + aiter::KERNEL_NAME, VS> \ + <<>>(__VA_ARGS__); \ + break; + +#define DISPATCH_FP32_KERNEL(KERNEL_NAME, KERNEL, ...) \ + switch(vec_size) \ + { \ + DISPATCH_FP32_VEC_SIZE_CASE(16, KERNEL_NAME, KERNEL, __VA_ARGS__) \ + DISPATCH_FP32_VEC_SIZE_CASE(8, KERNEL_NAME, KERNEL, __VA_ARGS__) \ + DISPATCH_FP32_VEC_SIZE_CASE(4, KERNEL_NAME, KERNEL, __VA_ARGS__) \ + DISPATCH_FP32_VEC_SIZE_CASE(2, KERNEL_NAME, KERNEL, __VA_ARGS__) \ + DISPATCH_FP32_VEC_SIZE_CASE(1, KERNEL_NAME, KERNEL, __VA_ARGS__) \ + } + +#define DISPATCH_FP32_ACT_KERNEL(KERNEL, out_ptr, in_ptr) \ + DISPATCH_FP32_KERNEL(act_and_mul_kernel, KERNEL, out_ptr, in_ptr, d) + +#define DISPATCH_FP32_SCALED_ACT_KERNEL(KERNEL, out_ptr, in_ptr, inv_scale) \ + DISPATCH_FP32_KERNEL(scaled_act_and_mul_kernel, KERNEL, out_ptr, in_ptr, d, inv_scale) + +// Helper macro to dispatch scaled kernel with restricted output types (fp8 or int8) +#define DISPATCH_OUTPUT_TYPE_SCALED(KERNEL, in_ptr, inv_scale) \ + if(out.scalar_type() == at::ScalarType::Float8_e4m3fn || \ + out.scalar_type() == at::ScalarType::Float8_e4m3fnuz || \ + out.scalar_type() == at::ScalarType::Float8_e5m2) \ + { \ + using output_dtype = fp8_type; \ + auto* out_ptr = reinterpret_cast(out.data_ptr()); \ + DISPATCH_FP32_SCALED_ACT_KERNEL(KERNEL, out_ptr, in_ptr, inv_scale) \ + } \ + else if(out.scalar_type() == at::ScalarType::Char) \ + { \ + using output_dtype = ck_tile::int8_t; \ + auto* out_ptr = reinterpret_cast(out.data_ptr()); \ + DISPATCH_FP32_SCALED_ACT_KERNEL(KERNEL, out_ptr, in_ptr, inv_scale) \ + } \ + else \ + { \ + TORCH_CHECK(false, "scaled_act_and_mul only supports fp8 or int8 outputs"); \ + } + +// Launch activation and gating kernel with flexible input/output types +// Input and output types are determined by the tensor dtypes passed from Python +#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \ + COMPUTE_ACTIVATION_KERNEL_PARAMS \ + if(input.scalar_type() == at::ScalarType::Float) \ + { \ + /* fp32 input: dispatch based on output type */ \ + using input_dtype = 
ck_tile::fp32_t; \ + auto* in_ptr = reinterpret_cast(input.data_ptr()); \ + if(out.scalar_type() == at::ScalarType::BFloat16) \ + { \ + using output_dtype = ck_tile::bf16_t; \ + auto* out_ptr = reinterpret_cast(out.data_ptr()); \ + DISPATCH_FP32_ACT_KERNEL(KERNEL, out_ptr, in_ptr) \ + } \ + else if(out.scalar_type() == at::ScalarType::Half) \ + { \ + using output_dtype = ck_tile::fp16_t; \ + auto* out_ptr = reinterpret_cast(out.data_ptr()); \ + DISPATCH_FP32_ACT_KERNEL(KERNEL, out_ptr, in_ptr) \ + } \ + else if(out.scalar_type() == at::ScalarType::Float) \ + { \ + using output_dtype = ck_tile::fp32_t; \ + auto* out_ptr = reinterpret_cast(out.data_ptr()); \ + DISPATCH_FP32_ACT_KERNEL(KERNEL, out_ptr, in_ptr) \ + } \ + else \ + { \ + TORCH_CHECK(false, "Unsupported output type for fp32 input"); \ + } \ + } \ + else \ + { \ + /* bf16/fp16 input: output must match input type */ \ + TORCH_CHECK(input.scalar_type() == out.scalar_type(), \ + "For bf16/fp16 input, output type must match input type"); \ + AITER_DISPATCH_FLOATING16_TYPES(input.scalar_type(), "act_and_mul_kernel", [&] { \ + using input_dtype = typename t2ck::type; \ + using output_dtype = input_dtype; \ + AITER_DISPATCH_CASE_VEC_SIZE( \ + vec_size, \ + aiter:: \ + act_and_mul_kernel, VEC_SIZE> \ + <<>>(reinterpret_cast(out.data_ptr()), \ + reinterpret_cast(input.data_ptr()), \ + d);) \ + }); \ + } + +// Launch scaled activation and gating kernel with flexible input/output types +#define LAUNCH_SCALED_ACTIVATION_GATE_KERNEL(KERNEL) \ + COMPUTE_ACTIVATION_KERNEL_PARAMS \ + if(input.scalar_type() == at::ScalarType::Float) \ + { \ + /* fp32 input: dispatch based on output type (fp8/bf16/fp16/fp32) */ \ + using input_dtype = ck_tile::fp32_t; \ + auto* in_ptr = reinterpret_cast(input.data_ptr()); \ + float inv_scale = 1.0f / (*scale.data_ptr()); \ + DISPATCH_OUTPUT_TYPE_SCALED(KERNEL, in_ptr, inv_scale) \ + } \ + else \ + { \ + /* bf16/fp16 input: dispatch based on output type (fp8/bf16/fp16/fp32) */ \ + AITER_DISPATCH_FLOATING16_TYPES(input.scalar_type(), "scaled_act_and_mul_kernel", [&] { \ + using input_dtype = typename t2ck::type; \ + auto* in_ptr = reinterpret_cast(input.data_ptr()); \ + float inv_scale = 1.0f / (*scale.data_ptr()); \ + DISPATCH_OUTPUT_TYPE_SCALED(KERNEL, in_ptr, inv_scale) \ + }); \ + } namespace aiter { +// Flexible type conversion: +// - fp32 input can output as fp32/bf16/fp16 (determined by out.dtype) +// - bf16 input must output as bf16 +// - fp16 input must output as fp16 void silu_and_mul(torch::Tensor& out, // [..., d] torch::Tensor& input) // [..., 2 * d] { @@ -392,4 +500,4 @@ void gelu_fast(torch::Tensor& out, // [..., d] LAUNCH_ACTIVATION_KERNEL(aiter::gelu_fast_kernel); } -} // namespace aiter \ No newline at end of file +} // namespace aiter diff --git a/csrc/kernels/mla/metadata/v1_0_device.cuh b/csrc/kernels/mla/metadata/v1_0_device.cuh index a3a9fe2e6f..1c8f1e2f9b 100644 --- a/csrc/kernels/mla/metadata/v1_0_device.cuh +++ b/csrc/kernels/mla/metadata/v1_0_device.cuh @@ -9,7 +9,7 @@ __device__ int32_t get_local_splits(int32_t seqlen_kv, int32_t num_splits, int32_t num_splits_per_cu) { -#if defined(__gfx942__) +#if defined(__gfx942__) return 16; #else int32_t ex_splits = seqlen_kv / 196; // magic num 196. Experiments shows 196 per splits can get better performance. 
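Note (editor's sketch, not part of the patch): the reworked LAUNCH_ACTIVATION_GATE_KERNEL / LAUNCH_SCALED_ACTIVATION_GATE_KERNEL macros in the activation_kernels.cu hunk above encode a small set of input/output dtype pairing rules. The Python helper below restates those rules for readability only; the function name check_act_and_mul_dtypes is hypothetical and does not exist in the codebase, and the torch fp8 dtype names are the ones the C++ branches check for (Float8_e4m3fn / Float8_e4m3fnuz / Float8_e5m2 / Char).

import torch

def check_act_and_mul_dtypes(inp: torch.dtype, out: torch.dtype, scaled: bool) -> None:
    """Mirror of the dtype dispatch rules in the activation launch macros (sketch)."""
    fp8_dtypes = (torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2)
    if scaled:
        # scaled path quantizes on store: only fp8 or int8 outputs are dispatched
        if out not in fp8_dtypes and out != torch.int8:
            raise ValueError("scaled_act_and_mul only supports fp8 or int8 outputs")
    elif inp == torch.float32:
        # fp32 input may down-convert on output
        if out not in (torch.float32, torch.bfloat16, torch.float16):
            raise ValueError("Unsupported output type for fp32 input")
    else:
        # bf16 / fp16 input must produce the same dtype
        if out != inp:
            raise ValueError("For bf16/fp16 input, output type must match input type")

# e.g. fp32 gate/up input written out as bf16 is now accepted:
check_act_and_mul_dtypes(torch.float32, torch.bfloat16, scaled=False)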
@@ -47,20 +47,19 @@ void kn_get_mla_metadata_v1_0(MlaMetadataV1KernelParameter params) const int32_t bid_ori = bid / params.qk_batch_ratio; const int32_t kv_begin = params.p_seqlens_kv_indptr[bid_ori]; - const int32_t kv_end = params.p_seqlens_kv_indptr[bid_ori + 1]; - int32_t kv_tail = [&](){ if constexpr(DP_MODE) { - // max(*, 0) for cuda graph capture: kvlen < mtp+1 - return max(bid % params.ori_seqlen_qo - params.ori_seqlen_qo + 1, 0); + return bid % params.ori_seqlen_qo - params.ori_seqlen_qo + 1; } else { return 0; } }(); - const int32_t seqlen_kv = kv_end - kv_begin + kv_tail; + const int32_t kv_end = max(params.p_seqlens_kv_indptr[bid_ori + 1] + kv_tail, kv_begin + 1); + + const int32_t seqlen_kv = kv_end - kv_begin; const int32_t num_blocks = integer_divide_ceil_power2( seqlen_kv, params.kv_granularity, params.kv_granularity_log2); @@ -98,19 +97,17 @@ void kn_get_mla_metadata_v1_0(MlaMetadataV1KernelParameter params) const int32_t bid_ori = bid / params.qk_batch_ratio; const int32_t kv_begin = p_lds_kv_seqlen[bid_ori]; - int32_t kv_end = p_lds_kv_seqlen[bid_ori + 1]; int32_t kv_tail = [&](){ if constexpr(DP_MODE) { - // max(*, 0) for cuda graph capture: kvlen < mtp+1 - return max(bid % params.ori_seqlen_qo - params.ori_seqlen_qo + 1, 0); + return bid % params.ori_seqlen_qo - params.ori_seqlen_qo + 1; } else { return 0; } }(); - kv_end += kv_tail; + const int32_t kv_end = max(p_lds_kv_seqlen[bid_ori + 1] + kv_tail, kv_begin + 1); MlaWorkInfo work_info{}; const int32_t split_start = p_lds_shift[bid]; const int32_t split_local = p_lds_split[bid]; diff --git a/csrc/kernels/mla/metadata/v1_1_device.cuh b/csrc/kernels/mla/metadata/v1_1_device.cuh index ef448bd21d..743a99eb79 100644 --- a/csrc/kernels/mla/metadata/v1_1_device.cuh +++ b/csrc/kernels/mla/metadata/v1_1_device.cuh @@ -8,32 +8,30 @@ #define PRINT_DBG 0 -CK_TILE_DEVICE auto get_cost_top( - const int32_t* p_cost_heap, - const int32_t num_clusters) +CK_TILE_DEVICE auto get_cost_top(const int32_t* p_cost_heap, const int32_t num_clusters) { - int32_t cid_min = -1; + int32_t cid_min = -1; int32_t cost_min = 0x7fffffff; // Get local top - for (int32_t cid = ck_tile::get_lane_id(); cid < num_clusters; cid += ck_tile::get_warp_size()) + for(int32_t cid = ck_tile::get_lane_id(); cid < num_clusters; cid += ck_tile::get_warp_size()) { const int32_t cost = p_cost_heap[cid]; - if (cost < cost_min) + if(cost < cost_min) { cost_min = cost; - cid_min = cid; + cid_min = cid; } } - // Get global top - #pragma unroll - for (int32_t offset = (ck_tile::get_warp_size() >> 1); offset > 0; offset >>= 1) +// Get global top +#pragma unroll + for(int32_t offset = (ck_tile::get_warp_size() >> 1); offset > 0; offset >>= 1) { const int32_t srd_lane = (offset ^ ck_tile::get_warp_size()) ^ ck_tile::get_lane_id(); - const int32_t cid_remote = ck_tile::warp_shuffle(cid_min, srd_lane); + const int32_t cid_remote = ck_tile::warp_shuffle(cid_min, srd_lane); const int32_t cost_remote = ck_tile::warp_shuffle(cost_min, srd_lane); - if ((cost_remote < cost_min) || ((cost_remote == cost_min) && (cid_remote < cid_min))) + if((cost_remote < cost_min) || ((cost_remote == cost_min) && (cid_remote < cid_min))) { cost_min = cost_remote; cid_min = cid_remote; @@ -43,23 +41,23 @@ CK_TILE_DEVICE auto get_cost_top( return std::make_tuple(cid_min, cost_min); } -template +template struct MlaMetadataV11Traits { - static constexpr int32_t kPackedQoLenPerWg = kPackedQoLenPerWg_; - static constexpr int32_t kPackedQoLenPerWg_log2 = __builtin_ctz(kPackedQoLenPerWg); - static constexpr 
int32_t kMaxClusterSize = kMaxClusterSize_; - static constexpr int32_t kSplitTolerance = 16; - static constexpr bool kQoSplits = kQoSplits_; + static constexpr int32_t kPackedQoLenPerWg = kPackedQoLenPerWg_; + static constexpr int32_t kPackedQoLenPerWg_log2 = __builtin_ctz(kPackedQoLenPerWg); + static constexpr int32_t kMaxClusterSize = kMaxClusterSize_; + static constexpr int32_t kSplitTolerance = 16; + static constexpr bool kQoSplits = kQoSplits_; // <= -1: read from seqlens_qo_indptr // == 0: read from MlaMetadataV1KernelParameter::uni_seqlen_QO // >= 1: read from MlaMetadataV11Traits::kUniSeqlenQo - static constexpr int32_t kUniSeqlenQo = kUniSeqlenQo_; - static constexpr int32_t kIsSparse = kIsSparse_; + static constexpr int32_t kUniSeqlenQo = kUniSeqlenQo_; + static constexpr int32_t kIsSparse = kIsSparse_; static constexpr bool kSortBatch = true; }; @@ -72,130 +70,145 @@ struct MlaMetadataV11Coefficients }; // This version just follows Flashinfer -CK_TILE_HOST_DEVICE int32_t cal_workload_limit_global_v0( - const int32_t cum_workload, - const int32_t num_clusters, - const int32_t kv_granularity) +CK_TILE_HOST_DEVICE int32_t cal_workload_limit_global_v0(const int32_t cum_workload, + const int32_t num_clusters, + const int32_t kv_granularity) { int32_t limit; - const int32_t avg_workload = ck_tile::max(ck_tile::integer_divide_ceil(cum_workload, num_clusters), 1); - if (avg_workload <= 8) limit = 32; - else if (avg_workload <= 16) limit = 64; - else if (avg_workload <= 32) limit = 128; - else if (avg_workload <= 64) limit = 192; - else limit = avg_workload; + const int32_t avg_workload = + ck_tile::max(ck_tile::integer_divide_ceil(cum_workload, num_clusters), 1); + if(avg_workload <= 8) + limit = 32; + else if(avg_workload <= 16) + limit = 64; + else if(avg_workload <= 32) + limit = 128; + else if(avg_workload <= 64) + limit = 192; + else + limit = avg_workload; return ck_tile::integer_least_multiple(limit, kv_granularity); } -CK_TILE_HOST_DEVICE int32_t cal_workload_limit_global_v1( - const MlaMetadataV11Coefficients& coefs, - const int32_t num_batches, - const int32_t cum_workload, - const int32_t num_clusters, - const int32_t packed_seqlen_qo, - const int32_t kv_granularity) +CK_TILE_HOST_DEVICE int32_t cal_workload_limit_global_v1(const MlaMetadataV11Coefficients& coefs, + const int32_t num_batches, + const int32_t cum_workload, + const int32_t num_clusters, + const int32_t packed_seqlen_qo, + const int32_t kv_granularity) { - const int32_t split_overhead = 2 * cal_cost(packed_seqlen_qo, 1) - cal_cost(packed_seqlen_qo, 2); + const int32_t split_overhead = + 2 * cal_cost(packed_seqlen_qo, 1) - cal_cost(packed_seqlen_qo, 2); const int32_t fixed_split_overhead = split_overhead * num_batches; int32_t limit; - const int32_t avg_workload = - ck_tile::max(ck_tile::integer_divide_ceil(cum_workload - fixed_split_overhead, num_clusters), 1); - if (avg_workload <= 8) limit = 32; - else if (avg_workload <= 16) limit = 64; - else if (avg_workload <= 32) limit = 128; - else if (avg_workload <= 64) limit = 192; - else limit = avg_workload; - - const float split_amplifier = - num_batches * coefs.workload_limit_global_0 + - avg_workload * coefs.workload_limit_global_1 + - coefs.workload_limit_global_2; + const int32_t avg_workload = ck_tile::max( + ck_tile::integer_divide_ceil(cum_workload - fixed_split_overhead, num_clusters), 1); + if(avg_workload <= 8) + limit = 32; + else if(avg_workload <= 16) + limit = 64; + else if(avg_workload <= 32) + limit = 128; + else if(avg_workload <= 64) + limit = 
192; + else + limit = avg_workload; + + const float split_amplifier = num_batches * coefs.workload_limit_global_0 + + avg_workload * coefs.workload_limit_global_1 + + coefs.workload_limit_global_2; return ck_tile::integer_least_multiple( int32_t(cal_cost(packed_seqlen_qo, limit) + split_overhead * split_amplifier), kv_granularity); } template -CK_TILE_DEVICE void generate_work( - const int32_t batch_idx, - const int32_t tile_idx, - const int32_t qo_len, - const int32_t kv_len, - const int32_t qo_tile_len, - const int32_t packed_qo_tile_len, - const int32_t qo_batch_start, - const int32_t kv_batch_start, - const int32_t kv_batch_end, - const int32_t workload_limit_global, - const int32_t num_clusters, - const int32_t kv_granularity, - const int32_t* p_work_indptr, - const int32_t* p_lds_num_qo_clusters_indptr, - int32_t* p_loc_partial_outputs, - int32_t* p_num_partial_outputs, - MlaWorkInfo* p_work_info_set, - MlaPartialTileInfo* p_reduce_final_map, - MlaPartialTileInfo* p_reduce_partial_map, - int32_t* p_cost_heap, - int32_t* p_cluster_work_counter) +CK_TILE_DEVICE void generate_work(const int32_t batch_idx, + const int32_t tile_idx, + const int32_t qo_len, + const int32_t kv_len, + const int32_t qo_tile_len, + const int32_t packed_qo_tile_len, + const int32_t qo_batch_start, + const int32_t kv_batch_start, + const int32_t kv_batch_end, + const int32_t workload_limit_global, + const int32_t num_clusters, + const int32_t kv_granularity, + const int32_t* p_work_indptr, + const int32_t* p_lds_num_qo_clusters_indptr, + int32_t* p_loc_partial_outputs, + int32_t* p_num_partial_outputs, + MlaWorkInfo* p_work_info_set, + MlaPartialTileInfo* p_reduce_final_map, + MlaPartialTileInfo* p_reduce_partial_map, + int32_t* p_cost_heap, + int32_t* p_cluster_work_counter) { int32_t remaining_kv_len = kv_len; - int32_t kv_start_local = 0; + int32_t kv_start_local = 0; - const int32_t kv_len_limit_floor = - ck_tile::integer_least_multiple(ck_tile::integer_divide_ceil(kv_len, num_clusters), kv_granularity); - const auto [cid_top, accum_cost_top] = get_cost_top(p_cost_heap, num_clusters); - const int32_t remaining_capability_top = - ck_tile::max(cal_kv_len(workload_limit_global - accum_cost_top, packed_qo_tile_len), kv_len_limit_floor); + const int32_t kv_len_limit_floor = ck_tile::integer_least_multiple( + ck_tile::integer_divide_ceil(kv_len, num_clusters), kv_granularity); + const auto [cid_top, accum_cost_top] = get_cost_top(p_cost_heap, num_clusters); + const int32_t remaining_capability_top = ck_tile::max( + cal_kv_len(workload_limit_global - accum_cost_top, packed_qo_tile_len), kv_len_limit_floor); const int32_t num_splits_estimated = ck_tile::integer_divide_ceil(remaining_kv_len, remaining_capability_top); - // For the case of #splits==2, make sure that the tailing tile is smaller than Traits::kSplitTolerance. - const bool split_kv = (num_splits_estimated == 2) ? - ((remaining_kv_len - remaining_capability_top) > Traits::kSplitTolerance) : - (num_splits_estimated > 1); + // For the case of #splits==2, make sure that the tailing tile is smaller than + // Traits::kSplitTolerance. + const bool split_kv = + (num_splits_estimated == 2) + ? 
((remaining_kv_len - remaining_capability_top) > Traits::kSplitTolerance) + : (num_splits_estimated > 1); do { // Check and update cost_heap auto [cid, accum_cost] = get_cost_top(p_cost_heap, num_clusters); - const int32_t remaining_capability = cal_kv_len(workload_limit_global - accum_cost, packed_qo_tile_len); - const int32_t kv_len_limit_local = - [&]() { + const int32_t remaining_capability = + cal_kv_len(workload_limit_global - accum_cost, packed_qo_tile_len); + const int32_t kv_len_limit_local = [&]() { const int32_t limit_ori = ck_tile::max(remaining_capability, kv_len_limit_floor); - const int32_t tail_size = (remaining_kv_len > limit_ori) ? (remaining_kv_len - limit_ori) : 0x7fffffff; - const int32_t limit_fin = (tail_size <= Traits::kSplitTolerance) ? remaining_kv_len : limit_ori; + const int32_t tail_size = + (remaining_kv_len > limit_ori) ? (remaining_kv_len - limit_ori) : 0x7fffffff; + const int32_t limit_fin = + (tail_size <= Traits::kSplitTolerance) ? remaining_kv_len : limit_ori; return limit_fin; }(); const int32_t kv_len_consuming = ck_tile::min(remaining_kv_len, kv_len_limit_local); - if (ck_tile::get_lane_id() == 0) + if(ck_tile::get_lane_id() == 0) { - const int32_t cost = cal_cost(packed_qo_tile_len, kv_len_consuming); + const int32_t cost = cal_cost(packed_qo_tile_len, kv_len_consuming); const int32_t new_cost = accum_cost + cost; - p_cost_heap[cid] = new_cost; + p_cost_heap[cid] = new_cost; - if constexpr (kOnlyGatherWorkCount == false) + if constexpr(kOnlyGatherWorkCount == false) { // Record work MlaWorkInfo work_info{}; work_info.batch_idx = batch_idx; work_info.qo_start = tile_idx * qo_tile_len + qo_batch_start; - work_info.qo_end = ck_tile::min(work_info.qo_start + qo_tile_len, qo_batch_start + qo_len); + work_info.qo_end = + ck_tile::min(work_info.qo_start + qo_tile_len, qo_batch_start + qo_len); work_info.kv_start = kv_start_local + kv_batch_start; work_info.kv_end = work_info.kv_start + kv_len_consuming; work_info.kv_offset = kv_batch_end - work_info.kv_end; - if (split_kv) + if(split_kv) { - const int32_t global_cluster_q_idx = p_lds_num_qo_clusters_indptr[batch_idx] + tile_idx; + const int32_t global_cluster_q_idx = + p_lds_num_qo_clusters_indptr[batch_idx] + tile_idx; work_info.partial_qo_loc = *p_loc_partial_outputs; - if (p_reduce_partial_map[global_cluster_q_idx].q_start == -1) + if(p_reduce_partial_map[global_cluster_q_idx].q_start == -1) { p_reduce_partial_map[global_cluster_q_idx].q_start = *p_loc_partial_outputs; - p_reduce_final_map[global_cluster_q_idx] = {{ work_info.qo_start, work_info.qo_end }}; + p_reduce_final_map[global_cluster_q_idx] = { + {work_info.qo_start, work_info.qo_end}}; } ++(*p_num_partial_outputs); *p_loc_partial_outputs += (work_info.qo_end - work_info.qo_start); @@ -210,8 +223,14 @@ CK_TILE_DEVICE void generate_work( p_work_info_set[work_info_set_idx] = work_info; #if PRINT_DBG - printf("[metadata] - cost heap updated: work_loc=%d, cid=%d, pre_cost=%d, new_cost=%d, tot_cost=%d, kv_len_cons=%d\n", - work_info_set_idx, cid, accum_cost, cost, accum_cost+cost, kv_len_consuming); + printf("[metadata] - cost heap updated: work_loc=%d, cid=%d, pre_cost=%d, " + "new_cost=%d, tot_cost=%d, kv_len_cons=%d\n", + work_info_set_idx, + cid, + accum_cost, + cost, + accum_cost + cost, + kv_len_consuming); #endif } @@ -221,15 +240,13 @@ CK_TILE_DEVICE void generate_work( // Update state remaining_kv_len -= kv_len_consuming; kv_start_local += kv_len_consuming; - } - while (remaining_kv_len > 0); + } while(remaining_kv_len > 0); } template 
-__launch_bounds__(ck_tile::get_warp_size(), 1) -__global__ void kn_get_mla_metadata_v1_1( - const MlaMetadataV1KernelParameter params, - const MlaMetadataV11Coefficients coefs) +__launch_bounds__(ck_tile::get_warp_size(), 1) __global__ + void kn_get_mla_metadata_v1_1(const MlaMetadataV1KernelParameter params, + const MlaMetadataV11Coefficients coefs) { extern __shared__ uint8_t p_smem[]; @@ -237,27 +254,33 @@ __global__ void kn_get_mla_metadata_v1_1( // Step.0. Get sequence lengths of query/output and key/value for each batch. int32_t* p_lds_batch_idx = reinterpret_cast(p_smem); - int32_t* p_lds_qo_lens = Traits::kSortBatch ? (p_lds_batch_idx + params.num_batches) : p_lds_batch_idx; - int32_t* p_lds_kv_lens = p_lds_qo_lens + params.num_batches; - for (int32_t bid = lane_idx; bid < params.num_batches; bid += ck_tile::get_warp_size()) + int32_t* p_lds_qo_lens = + Traits::kSortBatch ? (p_lds_batch_idx + params.num_batches) : p_lds_batch_idx; + int32_t* p_lds_kv_lens = p_lds_qo_lens + params.num_batches; + for(int32_t bid = lane_idx; bid < params.num_batches; bid += ck_tile::get_warp_size()) { - const int32_t bid_ori = Traits::kIsSparse ? (bid / params.ori_seqlen_qo / params.qk_batch_ratio) - : (bid / params.qk_batch_ratio); - if constexpr (Traits::kSortBatch) + const int32_t bid_ori = Traits::kIsSparse + ? (bid / params.ori_seqlen_qo / params.qk_batch_ratio) + : (bid / params.qk_batch_ratio); + if constexpr(Traits::kSortBatch) { p_lds_batch_idx[bid] = bid; } - const int32_t raw_seqlen_kv = params.p_seqlens_kv_indptr[bid_ori + 1] - params.p_seqlens_kv_indptr[bid_ori]; - p_lds_kv_lens[bid] = Traits::kIsSparse ? ck_tile::min(raw_seqlen_kv, params.topk) : raw_seqlen_kv; - p_lds_qo_lens[bid] = params.p_seqlens_qo_indptr[bid_ori + 1] - params.p_seqlens_qo_indptr[bid_ori]; + const int32_t raw_seqlen_kv = + params.p_seqlens_kv_indptr[bid_ori + 1] - params.p_seqlens_kv_indptr[bid_ori]; + p_lds_kv_lens[bid] = + Traits::kIsSparse ? ck_tile::min(raw_seqlen_kv, params.topk) : raw_seqlen_kv; + p_lds_qo_lens[bid] = + params.p_seqlens_qo_indptr[bid_ori + 1] - params.p_seqlens_qo_indptr[bid_ori]; } - QoState qo_state(params.uni_seqlen_qo, params.ori_seqlen_qo, p_lds_qo_lens, params.p_seqlens_qo_indptr); + QoState qo_state( + params.uni_seqlen_qo, params.ori_seqlen_qo, p_lds_qo_lens, params.p_seqlens_qo_indptr); - // Step.1. Calculate the size of cluster and some related information. The size is the number of workgroups + // Step.1. Calculate the size of cluster and some related information. The size is the number of + // workgroups // composing each cluster. The size is determined by average packed qo length. - const int32_t sum_qo_len = warp_sum(p_lds_qo_lens, params.num_batches); - const int32_t cluster_size = - [&]() { + const int32_t sum_qo_len = warp_sum(p_lds_qo_lens, params.num_batches); + const int32_t cluster_size = [&]() { const int32_t avg_qo_len = sum_qo_len / params.num_batches; const int32_t cluster_size = ck_tile::integer_divide_ceil(avg_qo_len, Traits::kPackedQoLenPerWg); @@ -268,94 +291,102 @@ __global__ void kn_get_mla_metadata_v1_1( const int32_t cluster_len_q = cluster_size * Traits::kPackedQoLenPerWg; // Step.2. - // a. Get total valid (after causal masking) kv lengths and the maximun workload handled by each cluster - // b. Get a indptr array about #cluster for each batch in direction of qo. + // a. Get total valid (after causal masking) kv lengths and the maximun workload handled by + // each cluster b. Get a indptr array about #cluster for each batch in direction of qo. 
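    // Illustrative tiling example (hypothetical shapes, assuming cluster_len_q = 128
    // as with the default kPackedQoLenPerWg = 128 and cluster_size = 1):
    //   qo_len = 3, num_heads = 128
    //     -> packed_qo_len      = 3 * 128 = 384
    //     -> num_qo_tiles       = ceil(384 / 128) = 3
    //     -> packed_qo_tile_len = min(384, 128) = 128
    //   Each tile then adds cal_cost(packed_qo_tile_len, kv_len_valid) to the running
    //   workload, where kv_len_valid is the causally-masked kv length returned by
    //   cal_packed_causal_kv_len for that tile.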
int32_t* p_lds_num_qo_clusters_indptr = p_lds_kv_lens + params.num_batches; - if (lane_idx == 0) + if(lane_idx == 0) { p_lds_num_qo_clusters_indptr[0] = 0; } - int32_t scan_base = 0; - int32_t workload_sum = 0; - const int32_t num_loop_batch = - integer_divide_ceil_power2(params.num_batches, - ck_tile::get_warp_size(), - __builtin_ctz(ck_tile::get_warp_size())); + int32_t scan_base = 0; + int32_t workload_sum = 0; + const int32_t num_loop_batch = integer_divide_ceil_power2( + params.num_batches, ck_tile::get_warp_size(), __builtin_ctz(ck_tile::get_warp_size())); // lds pointed by p_lds_qo_tiles will be reused by p_lds_sort_workspace later - int32_t* p_lds_qo_tiles = p_lds_num_qo_clusters_indptr + params.num_batches + 1; - for (int32_t loop_idx = 0; loop_idx < num_loop_batch; ++loop_idx) + int32_t* p_lds_qo_tiles = p_lds_num_qo_clusters_indptr + params.num_batches + 1; + for(int32_t loop_idx = 0; loop_idx < num_loop_batch; ++loop_idx) { - const int32_t bid = lane_idx + loop_idx * ck_tile::get_warp_size(); + const int32_t bid = lane_idx + loop_idx * ck_tile::get_warp_size(); int32_t num_qo_tiles = 0; - int32_t workload = 0; + int32_t workload = 0; - if (bid < params.num_batches) + if(bid < params.num_batches) { - const int32_t kv_len = p_lds_kv_lens[bid]; - const int32_t qo_len = qo_state.get_seqlen(bid); + const int32_t kv_len = p_lds_kv_lens[bid]; + const int32_t qo_len = qo_state.get_seqlen(bid); const int32_t packed_qo_len = qo_len * params.num_heads; - num_qo_tiles = ck_tile::integer_divide_ceil(packed_qo_len, cluster_len_q); + num_qo_tiles = ck_tile::integer_divide_ceil(packed_qo_len, cluster_len_q); p_lds_qo_tiles[bid] = num_qo_tiles; const int32_t packed_qo_tile_len = ck_tile::min(packed_qo_len, cluster_len_q); - for (int32_t tid = 0; tid < num_qo_tiles; ++tid) + for(int32_t tid = 0; tid < num_qo_tiles; ++tid) { - const int32_t kv_len_valid = - cal_packed_causal_kv_len( - qo_len, kv_len, tid, packed_qo_tile_len, num_qo_tiles, params.num_heads, params.is_causal); + const int32_t kv_len_valid = cal_packed_causal_kv_len(qo_len, + kv_len, + tid, + packed_qo_tile_len, + num_qo_tiles, + params.num_heads, + params.is_causal); workload += cal_cost(packed_qo_tile_len, kv_len_valid); } } const int32_t prefix_sum_qo_tiles = warp_prefix_sum(num_qo_tiles, ck_tile::get_warp_size()); const int32_t global_sum_qo_tiles = prefix_sum_qo_tiles + scan_base; - if (bid < params.num_batches) + if(bid < params.num_batches) { p_lds_num_qo_clusters_indptr[bid + 1] = global_sum_qo_tiles; } scan_base = ck_tile::warp_shuffle(global_sum_qo_tiles, ck_tile::get_warp_size() - 1); - workload_sum += aiter::warpReduce(workload); + workload_sum += + aiter::warpReduce( + workload); } const int32_t num_qo_tiles = scan_base; const int32_t tot_qo_tiles = warp_sum(p_lds_qo_tiles, params.num_batches); const int32_t workload_limit_global = - cal_workload_limit_global_v1( - coefs, - params.num_batches, - workload_sum, - num_clusters, - qo_state.is_unique() ? qo_state.get_seqlen(0) : cluster_len_q, - params.kv_granularity); + cal_workload_limit_global_v1(coefs, + params.num_batches, + workload_sum, + num_clusters, + qo_state.is_unique() ? qo_state.get_seqlen(0) : cluster_len_q, + params.kv_granularity); #if PRINT_DBG - if (lane_idx == 0) + if(lane_idx == 0) { printf("[metadata] workload_limit_global=%d\n", workload_limit_global); } #endif // Step.3. Sort batch idx based on cost. High cost batch first. 
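    // The descending sort feeds the greedy dispatch below: batches are handed out
    // most expensive first, and generate_work always charges the currently cheapest
    // cluster (via get_cost_top), which resembles longest-processing-time scheduling.
    // Illustrative example (hypothetical costs): tile costs {7, 5, 4, 2} on two
    // clusters balance to {7 + 2, 5 + 4} = {9, 9}, whereas the ascending order would
    // end at {2 + 5, 4 + 7} = {7, 11}.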
- if constexpr (Traits::kSortBatch) + if constexpr(Traits::kSortBatch) { - int32_t *p_lds_sort_workspace = p_lds_num_qo_clusters_indptr + params.num_batches + 1; // will be reused later. - warp_sort(p_lds_batch_idx, p_lds_sort_workspace, p_lds_qo_lens, p_lds_kv_lens, params.num_batches); + int32_t* p_lds_sort_workspace = + p_lds_num_qo_clusters_indptr + params.num_batches + 1; // will be reused later. + warp_sort(p_lds_batch_idx, + p_lds_sort_workspace, + p_lds_qo_lens, + p_lds_kv_lens, + params.num_batches); } // Step.4.1. Initialize lds - int32_t* p_cost_heap = p_lds_qo_tiles; + int32_t* p_cost_heap = p_lds_qo_tiles; int32_t* p_cluster_work_counter = p_cost_heap + num_clusters + 1; - for (int32_t cid = lane_idx; cid < num_clusters; cid += ck_tile::get_warp_size()) + for(int32_t cid = lane_idx; cid < num_clusters; cid += ck_tile::get_warp_size()) { - p_cost_heap[cid] = 0; + p_cost_heap[cid] = 0; p_cluster_work_counter[cid] = 0; } // Step.5. Fill the output buffers except indptrs auto get_kv_batch_start = [&](const int32_t bid) { const int32_t bid_ori = bid / params.qk_batch_ratio; - if constexpr (Traits::kIsSparse) + if constexpr(Traits::kIsSparse) { return bid_ori * params.topk; } @@ -366,55 +397,77 @@ __global__ void kn_get_mla_metadata_v1_1( }; // Step.5.1. Get total work for each cluster - for (int32_t idx = 0; idx < params.num_batches; ++idx) + for(int32_t idx = 0; idx < params.num_batches; ++idx) { - const int32_t bid = Traits::kSortBatch ? p_lds_batch_idx[idx] : idx; - const int32_t bid_ori = bid / params.qk_batch_ratio; - const int32_t qo_len = qo_state.get_seqlen(bid); - const int32_t qo_batch_start = qo_state.get_begin(bid); - const int32_t kv_len = p_lds_kv_lens[bid]; - const int32_t kv_batch_start = Traits::kIsSparse ? bid_ori * params.topk - : params.p_seqlens_kv_indptr[bid_ori]; - const int32_t kv_batch_end = kv_batch_start + kv_len; - const int32_t packed_qo_len = qo_len * params.num_heads; - const int32_t num_qo_tiles = ck_tile::integer_divide_ceil(packed_qo_len, cluster_len_q); + const int32_t bid = Traits::kSortBatch ? p_lds_batch_idx[idx] : idx; + const int32_t bid_ori = bid / params.qk_batch_ratio; + const int32_t qo_len = qo_state.get_seqlen(bid); + const int32_t qo_batch_start = qo_state.get_begin(bid); + const int32_t kv_len = p_lds_kv_lens[bid]; + const int32_t kv_batch_start = + Traits::kIsSparse ? 
bid_ori * params.topk : params.p_seqlens_kv_indptr[bid_ori]; + const int32_t kv_batch_end = kv_batch_start + kv_len; + const int32_t packed_qo_len = qo_len * params.num_heads; + const int32_t num_qo_tiles = ck_tile::integer_divide_ceil(packed_qo_len, cluster_len_q); const int32_t packed_qo_tile_len = ck_tile::min(packed_qo_len, cluster_len_q); - const int32_t qo_tile_len = ck_tile::integer_divide_ceil(packed_qo_tile_len, params.num_heads); + const int32_t qo_tile_len = + ck_tile::integer_divide_ceil(packed_qo_tile_len, params.num_heads); - for (int32_t tid = 0; tid < num_qo_tiles; ++tid) + for(int32_t tid = 0; tid < num_qo_tiles; ++tid) { - const int32_t tile_kv_len = - cal_packed_causal_kv_len( - qo_len, kv_len, tid, packed_qo_tile_len, num_qo_tiles, params.num_heads, params.is_causal); - - generate_work( - bid, tid, qo_len, tile_kv_len, qo_tile_len, packed_qo_tile_len, qo_batch_start, kv_batch_start, - kv_batch_end, workload_limit_global, num_clusters, params.kv_granularity, nullptr, - p_lds_num_qo_clusters_indptr, nullptr, nullptr, nullptr, nullptr, nullptr, p_cost_heap, - p_cluster_work_counter); + const int32_t tile_kv_len = cal_packed_causal_kv_len(qo_len, + kv_len, + tid, + packed_qo_tile_len, + num_qo_tiles, + params.num_heads, + params.is_causal); + + generate_work(bid, + tid, + qo_len, + tile_kv_len, + qo_tile_len, + packed_qo_tile_len, + qo_batch_start, + kv_batch_start, + kv_batch_end, + workload_limit_global, + num_clusters, + params.kv_granularity, + nullptr, + p_lds_num_qo_clusters_indptr, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + p_cost_heap, + p_cluster_work_counter); } } // Step.5.2. Re-init cost heap and cumulative sum cluster_work_tot - scan_base = 0; - const int32_t num_loop_clusters = - integer_divide_ceil_power2(num_clusters, ck_tile::get_warp_size(), __builtin_ctz(ck_tile::get_warp_size())); - for (int32_t loop_idx = 0; loop_idx < num_loop_clusters; ++loop_idx) + scan_base = 0; + const int32_t num_loop_clusters = integer_divide_ceil_power2( + num_clusters, ck_tile::get_warp_size(), __builtin_ctz(ck_tile::get_warp_size())); + for(int32_t loop_idx = 0; loop_idx < num_loop_clusters; ++loop_idx) { const int32_t cid = lane_idx + loop_idx * ck_tile::get_warp_size(); const int32_t cluster_work = (cid < num_clusters) ? 
p_cluster_work_counter[cid] : 0; - const int32_t cum_cluster_work = warp_prefix_sum(cluster_work, ck_tile::get_warp_size()) + scan_base; + const int32_t cum_cluster_work = + warp_prefix_sum(cluster_work, ck_tile::get_warp_size()) + scan_base; scan_base = ck_tile::warp_shuffle(cum_cluster_work, ck_tile::get_warp_size() - 1); - if (cid < num_clusters) + if(cid < num_clusters) { params.p_work_indptr[cid + 1] = cum_cluster_work; - p_cost_heap[cid] = 0; - p_cluster_work_counter[cid] = 0; + p_cost_heap[cid] = 0; + p_cluster_work_counter[cid] = 0; } } - if (lane_idx == 0) + if(lane_idx == 0) { params.p_work_indptr[0] = 0; } @@ -422,57 +475,79 @@ __global__ void kn_get_mla_metadata_v1_1( MlaPartialTileInfo* p_reduce_partial_map = reinterpret_cast(p_cluster_work_counter + num_clusters); MlaPartialTileInfo* p_reduce_final_map = p_reduce_partial_map + tot_qo_tiles; - for (int32_t cluster_q_idx = threadIdx.x; cluster_q_idx < tot_qo_tiles; cluster_q_idx += ck_tile::get_warp_size()) + for(int32_t cluster_q_idx = threadIdx.x; cluster_q_idx < tot_qo_tiles; + cluster_q_idx += ck_tile::get_warp_size()) { p_reduce_partial_map[cluster_q_idx] = MlaPartialTileInfo{{-1, -2}}; - p_reduce_final_map[cluster_q_idx] = MlaPartialTileInfo{{-1, -2}}; + p_reduce_final_map[cluster_q_idx] = MlaPartialTileInfo{{-1, -2}}; } // Step.5.3. Output work info - int32_t num_partial_outputs = 0; - int32_t loc_partial_outputs = 0; + int32_t num_partial_outputs = 0; + int32_t loc_partial_outputs = 0; MlaWorkInfo* p_work_info_set = reinterpret_cast(params.p_work_info_set_raw); - for (int32_t idx = 0; idx < params.num_batches; ++idx) + for(int32_t idx = 0; idx < params.num_batches; ++idx) { - const int32_t bid = Traits::kSortBatch ? p_lds_batch_idx[idx] : idx; - const int32_t bid_ori = bid / params.qk_batch_ratio; - const int32_t qo_len = qo_state.get_seqlen(bid); - const int32_t qo_batch_start = qo_state.get_begin(bid); - const int32_t kv_len = p_lds_kv_lens[bid]; - const int32_t kv_batch_start = Traits::kIsSparse ? bid_ori * params.topk - : params.p_seqlens_kv_indptr[bid_ori]; - const int32_t kv_batch_end = kv_batch_start + kv_len; - const int32_t packed_qo_len = qo_len * params.num_heads; - const int32_t num_qo_tiles = ck_tile::integer_divide_ceil(packed_qo_len, cluster_len_q); + const int32_t bid = Traits::kSortBatch ? p_lds_batch_idx[idx] : idx; + const int32_t bid_ori = bid / params.qk_batch_ratio; + const int32_t qo_len = qo_state.get_seqlen(bid); + const int32_t qo_batch_start = qo_state.get_begin(bid); + const int32_t kv_len = p_lds_kv_lens[bid]; + const int32_t kv_batch_start = + Traits::kIsSparse ? 
bid_ori * params.topk : params.p_seqlens_kv_indptr[bid_ori]; + const int32_t kv_batch_end = kv_batch_start + kv_len; + const int32_t packed_qo_len = qo_len * params.num_heads; + const int32_t num_qo_tiles = ck_tile::integer_divide_ceil(packed_qo_len, cluster_len_q); const int32_t packed_qo_tile_len = ck_tile::min(packed_qo_len, cluster_len_q); - const int32_t qo_tile_len = ck_tile::integer_divide_ceil(packed_qo_tile_len, params.num_heads); + const int32_t qo_tile_len = + ck_tile::integer_divide_ceil(packed_qo_tile_len, params.num_heads); #if PRINT_DBG - if (lane_idx == 0) + if(lane_idx == 0) { printf("[metadata] Dividing batch=%d, qo_len=%d, kv_len=%d\n", bid, qo_len, kv_len); } #endif - for (int32_t tid = 0; tid < num_qo_tiles; ++tid) + for(int32_t tid = 0; tid < num_qo_tiles; ++tid) { - const int32_t tile_kv_len = - cal_packed_causal_kv_len( - qo_len, kv_len, tid, packed_qo_tile_len, num_qo_tiles, params.num_heads, params.is_causal); - - generate_work( - bid, tid, qo_len, tile_kv_len, qo_tile_len, packed_qo_tile_len, qo_batch_start, kv_batch_start, - kv_batch_end, workload_limit_global, num_clusters, params.kv_granularity, params.p_work_indptr, - p_lds_num_qo_clusters_indptr, &loc_partial_outputs, &num_partial_outputs, p_work_info_set, - p_reduce_final_map, p_reduce_partial_map, p_cost_heap, p_cluster_work_counter); + const int32_t tile_kv_len = cal_packed_causal_kv_len(qo_len, + kv_len, + tid, + packed_qo_tile_len, + num_qo_tiles, + params.num_heads, + params.is_causal); + + generate_work(bid, + tid, + qo_len, + tile_kv_len, + qo_tile_len, + packed_qo_tile_len, + qo_batch_start, + kv_batch_start, + kv_batch_end, + workload_limit_global, + num_clusters, + params.kv_granularity, + params.p_work_indptr, + p_lds_num_qo_clusters_indptr, + &loc_partial_outputs, + &num_partial_outputs, + p_work_info_set, + p_reduce_final_map, + p_reduce_partial_map, + p_cost_heap, + p_cluster_work_counter); } } // Step.6. Output metadata for reduce kernel - scan_base = 0; - const int32_t num_loop_reduce = - integer_divide_ceil_power2(tot_qo_tiles, ck_tile::get_warp_size(), __builtin_ctz(ck_tile::get_warp_size())); - for (int32_t loop_idx = 0; loop_idx < num_loop_reduce; ++loop_idx) + scan_base = 0; + const int32_t num_loop_reduce = integer_divide_ceil_power2( + tot_qo_tiles, ck_tile::get_warp_size(), __builtin_ctz(ck_tile::get_warp_size())); + for(int32_t loop_idx = 0; loop_idx < num_loop_reduce; ++loop_idx) { const int32_t global_cluster_q_idx = lane_idx + loop_idx * ck_tile::get_warp_size(); @@ -481,53 +556,61 @@ __global__ void kn_get_mla_metadata_v1_1( int32_t reduce_tile_size; int32_t num_reduce_tiles = 0; - if (global_cluster_q_idx < tot_qo_tiles) + if(global_cluster_q_idx < tot_qo_tiles) { - final_info = p_reduce_final_map[global_cluster_q_idx]; + final_info = p_reduce_final_map[global_cluster_q_idx]; partial_range = p_reduce_partial_map[global_cluster_q_idx]; - reduce_tile_size = (final_info.q_start == -1) ? 0 : (final_info.q_end - final_info.q_start); + reduce_tile_size = + (final_info.q_start == -1) ? 0 : (final_info.q_end - final_info.q_start); num_reduce_tiles = - (reduce_tile_size == 0) ? 0 : ((partial_range.q_end - partial_range.q_start) / reduce_tile_size); + (reduce_tile_size == 0) + ? 
0 + : ((partial_range.q_end - partial_range.q_start) / reduce_tile_size); } - const int32_t curr_cum_reduce_tiles = warp_prefix_sum(num_reduce_tiles, ck_tile::get_warp_size()) + scan_base; + const int32_t curr_cum_reduce_tiles = + warp_prefix_sum(num_reduce_tiles, ck_tile::get_warp_size()) + scan_base; const int32_t prev_cum_reduce_tiles = curr_cum_reduce_tiles - num_reduce_tiles; scan_base = ck_tile::warp_shuffle(curr_cum_reduce_tiles, ck_tile::get_warp_size() - 1); - if (global_cluster_q_idx < tot_qo_tiles) + if(global_cluster_q_idx < tot_qo_tiles) { - for (int32_t tid = prev_cum_reduce_tiles; tid < curr_cum_reduce_tiles; ++tid) + for(int32_t tid = prev_cum_reduce_tiles; tid < curr_cum_reduce_tiles; ++tid) { const int32_t local_tid = tid - prev_cum_reduce_tiles; - params.p_reduce_partial_map[tid] = partial_range.q_start + local_tid * reduce_tile_size; + params.p_reduce_partial_map[tid] = + partial_range.q_start + local_tid * reduce_tile_size; } - params.p_reduce_indptr[global_cluster_q_idx + 1] = curr_cum_reduce_tiles; - params.p_reduce_final_map[2 * global_cluster_q_idx] = final_info.q_start; + params.p_reduce_indptr[global_cluster_q_idx + 1] = curr_cum_reduce_tiles; + params.p_reduce_final_map[2 * global_cluster_q_idx] = final_info.q_start; params.p_reduce_final_map[2 * global_cluster_q_idx + 1] = final_info.q_end; } } // reduce_indptr may be larger than #clusters. const int32_t num_reduce_tiles = scan_base; - for (int32_t idx = tot_qo_tiles + 1 + lane_idx; idx < params.reduce_indptr_size; idx += ck_tile::get_warp_size()) + for(int32_t idx = tot_qo_tiles + 1 + lane_idx; idx < params.reduce_indptr_size; + idx += ck_tile::get_warp_size()) { params.p_reduce_indptr[idx] = num_reduce_tiles; } // Step.7. Fill metadata pointers for MLA kernel and the 1st element of reduce_indptr. 
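    // The two entries hold raw device addresses packed as 64-bit integers, so a
    // consumer is expected to unpack them roughly as follows (sketch only; the
    // concrete consumer-side types live with the MLA main/reduce kernels):
    //   auto* p_work_indptr   = reinterpret_cast<const int32_t*>(
    //       static_cast<uintptr_t>(p_work_metadata_ptrs[0]));
    //   auto* p_work_info_set = reinterpret_cast<const MlaWorkInfo*>(
    //       static_cast<uintptr_t>(p_work_metadata_ptrs[1]));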
- if (lane_idx == 0) + if(lane_idx == 0) { params.p_reduce_indptr[0] = 0; - params.p_work_metadata_ptrs[0] = static_cast(reinterpret_cast(params.p_work_indptr)); - params.p_work_metadata_ptrs[1] = static_cast(reinterpret_cast(params.p_work_info_set_raw)); + params.p_work_metadata_ptrs[0] = + static_cast(reinterpret_cast(params.p_work_indptr)); + params.p_work_metadata_ptrs[1] = + static_cast(reinterpret_cast(params.p_work_info_set_raw)); } #if PRINT_DBG - if (lane_idx == 0) + if(lane_idx == 0) { printf("[metadata] Final Cost Heap Status:\n"); - for (int32_t cid = 0; cid < num_clusters; ++cid) + for(int32_t cid = 0; cid < num_clusters; ++cid) { printf("[metadata] - cid=%d, cost=%d\n", cid, p_cost_heap[cid]); } @@ -535,40 +618,46 @@ __global__ void kn_get_mla_metadata_v1_1( #endif } -template -void dispatch_mla_metadata_v1_1_device( - const MlaMetadataV1KernelParameter& params, - const MlaMetadataV11Coefficients& coefs, - const hipStream_t stream, - const int32_t warp_size, - const int32_t lds_size) +template +void dispatch_mla_metadata_v1_1_device(const MlaMetadataV1KernelParameter& params, + const MlaMetadataV11Coefficients& coefs, + const hipStream_t stream, + const int32_t warp_size, + const int32_t lds_size) { - using Traits = MlaMetadataV11Traits; + using Traits = MlaMetadataV11Traits; const dim3 grid = dim3(1, 1, 1); kn_get_mla_metadata_v1_1<<>>(params, coefs); } -void get_mla_metadata_v1_1_device( - const torch::Tensor& seqlens_qo_indptr, // [batch size + 1] - const torch::Tensor& seqlens_kv_indptr, // [batch size + 1] - const int32_t num_heads_per_head_k, - const int32_t num_heads_k, - const bool is_causal, - const bool no_redundant, - const int32_t kv_granularity, - const int32_t max_seqlen_qo, - const int32_t ori_uni_seqlen_qo, - const int32_t topk, - torch::Tensor& work_metadata_ptrs, - torch::Tensor& work_info_set, - torch::Tensor& work_indptr, - torch::Tensor& reduce_indptr, - torch::Tensor& reduce_final_map, - torch::Tensor& reduce_partial_map) +void get_mla_metadata_v1_1_device(const torch::Tensor& seqlens_qo_indptr, // [batch size + 1] + const torch::Tensor& seqlens_kv_indptr, // [batch size + 1] + const int32_t num_heads_per_head_k, + const int32_t num_heads_k, + const bool is_causal, + const bool no_redundant, + const int32_t kv_granularity, + const int32_t max_seqlen_qo, + const int32_t ori_uni_seqlen_qo, + const int32_t topk, + torch::Tensor& work_metadata_ptrs, + torch::Tensor& work_info_set, + torch::Tensor& work_indptr, + torch::Tensor& reduce_indptr, + torch::Tensor& reduce_final_map, + torch::Tensor& reduce_partial_map) { - // This default settings is for our ASM MLA decode kernel. This kernel supports num_heads=16 and qo size from 1 - // to 4 without support to split qo for each workgroup. This means that kPackedQoLenPerWg should be 4*16=64 to - // prevent spliting in any case supported by it. + // This default settings is for our ASM MLA decode kernel. This kernel supports num_heads=16 and + // qo size from 1 to 4 without support to split qo for each workgroup. This means that + // kPackedQoLenPerWg should be 4*16=64 to prevent spliting in any case supported by it. constexpr int32_t kPackedQoLenPerWg = 128; constexpr int32_t kMaxClusterSize = 1; @@ -587,29 +676,32 @@ void get_mla_metadata_v1_1_device( int32_t qk_batch_ratio = 1; int32_t uni_seqlen_qo = ori_uni_seqlen_qo; - // In the following cases, we use #head=16 to simulate cases which is not natively supported by mla main kernel. 
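    // Illustrative example (hypothetical sizes): num_heads = 64 with 8 batches is
    // outside the natively supported {16, 128}, so it is remapped to
    //   qk_batch_ratio = 64 / 16 = 4
    //   num_heads      -> 16
    //   num_batches    -> 8 * 4 = 32
    // i.e. every original batch is treated as 4 sub-batches of 16 heads, and
    // bid / qk_batch_ratio later recovers the original batch index.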
- if ((num_heads != 16) && (num_heads != 128) && // main kernel natively supports #head=16 or #head=128 - (num_heads % 16 == 0) && (num_heads < 128)) + // In the following cases, we use #head=16 to simulate cases which is not natively supported by + // mla main kernel. + if((num_heads != 16) && + (num_heads != 128) && // main kernel natively supports #head=16 or #head=128 + (num_heads % 16 == 0) && (num_heads < 128)) { qk_batch_ratio = num_heads / 16; num_heads = 16; - num_batches *= qk_batch_ratio; + num_batches *= qk_batch_ratio; } - if (is_sparse) + if(is_sparse) { - num_batches *= uni_seqlen_qo; + num_batches *= uni_seqlen_qo; uni_seqlen_qo = 1; } - TORCH_CHECK((num_heads == 16) || (num_heads == 128), __func__, - ": only supports #heads in [16, 128], or (#head, uni_seqlen_qo) = (16*N, 1) where N is in [2, 8).") + TORCH_CHECK((num_heads == 16) || (num_heads == 128), + __func__, + ": only supports #heads in [16, 128], or (#head, uni_seqlen_qo) = (16*N, 1) where " + "N is in [2, 8).") - const int32_t lds_size_in_bytes = [&]() - { - const int32_t qo_tile_per_batch = - ck_tile::integer_divide_ceil(ck_tile::max(max_seqlen_qo, 1) * num_heads, kPackedQoLenPerWg); - const int32_t tot_qo_tiles = num_batches * qo_tile_per_batch; + const int32_t lds_size_in_bytes = [&]() { + const int32_t qo_tile_per_batch = ck_tile::integer_divide_ceil( + ck_tile::max(max_seqlen_qo, 1) * num_heads, kPackedQoLenPerWg); + const int32_t tot_qo_tiles = num_batches * qo_tile_per_batch; // this is maximun #clusters const int32_t num_clusters = dev_prop.multiProcessorCount; @@ -620,10 +712,12 @@ void get_mla_metadata_v1_1_device( // Memory for indptr about #cluster for each batch in direction of qo lds_size += (num_batches + 1) * sizeof(int32_t); // LDS for sorting - const int32_t power_2_num_batches = (num_batches <= 1) ? num_batches : ck_tile::next_power_of_two(num_batches); + const int32_t power_2_num_batches = + (num_batches <= 1) ? num_batches : ck_tile::next_power_of_two(num_batches); const int32_t lds_sort_size = lds_size + - ck_tile::integer_least_multiple(power_2_num_batches, ck_tile::get_warp_size()) * 2 * sizeof(int32_t); + ck_tile::integer_least_multiple(power_2_num_batches, ck_tile::get_warp_size()) * 2 * + sizeof(int32_t); // Memory for cost. Its size should be the same as #clusters lds_size += num_clusters * sizeof(int32_t); // Memory for counter of #works for each cluster. 
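        // Illustrative sizing of the sort workspace term above (hypothetical values,
        // warp size assumed to be 64):
        //   num_batches = 96 -> power_2_num_batches = next_power_of_two(96) = 128
        //   integer_least_multiple(128, 64) * 2 * sizeof(int32_t) = 128 * 2 * 4
        //     = 1024 bytes added on top of lds_size to form lds_sort_size.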
@@ -637,7 +731,8 @@ void get_mla_metadata_v1_1_device( }(); TORCH_CHECK(lds_size_in_bytes <= dev_prop.maxSharedMemoryPerMultiProcessor, - __func__, ": There is no enough LDS."); + __func__, + ": There is no enough LDS."); // auto opts = seqlens_kv_indptr.options(); // auto work_ptrs = torch::empty({2}, opts.dtype(torch::kUInt64)); @@ -649,30 +744,30 @@ void get_mla_metadata_v1_1_device( // kernel input parameters MlaMetadataV1KernelParameter params = {}; - params.p_work_metadata_ptrs = work_metadata_ptrs.data_ptr(); - params.p_work_indptr = work_indptr.data_ptr(); - params.p_work_info_set_raw = work_info_set.data_ptr(); - params.p_reduce_indptr = reduce_indptr.data_ptr(); - params.p_reduce_final_map = reduce_final_map.data_ptr(); - params.p_reduce_partial_map = reduce_partial_map.data_ptr(); - params.p_seqlens_qo_indptr = seqlens_qo_indptr.data_ptr(); - params.p_seqlens_kv_indptr = seqlens_kv_indptr.data_ptr(); - params.num_batches = num_batches; - params.num_heads = num_heads; - params.num_cu = num_cu; - params.reduce_indptr_size = reduce_indptr.size(0); - params.kv_granularity = kv_granularity; - params.kv_granularity_log2 = __builtin_ctz(kv_granularity); - params.uni_seqlen_qo = uni_seqlen_qo; - params.ori_seqlen_qo = ori_uni_seqlen_qo; - params.topk = topk; - params.is_causal = is_causal; - params.qk_batch_ratio = qk_batch_ratio; + params.p_work_metadata_ptrs = work_metadata_ptrs.data_ptr(); + params.p_work_indptr = work_indptr.data_ptr(); + params.p_work_info_set_raw = work_info_set.data_ptr(); + params.p_reduce_indptr = reduce_indptr.data_ptr(); + params.p_reduce_final_map = reduce_final_map.data_ptr(); + params.p_reduce_partial_map = reduce_partial_map.data_ptr(); + params.p_seqlens_qo_indptr = seqlens_qo_indptr.data_ptr(); + params.p_seqlens_kv_indptr = seqlens_kv_indptr.data_ptr(); + params.num_batches = num_batches; + params.num_heads = num_heads; + params.num_cu = num_cu; + params.reduce_indptr_size = reduce_indptr.size(0); + params.kv_granularity = kv_granularity; + params.kv_granularity_log2 = __builtin_ctz(kv_granularity); + params.uni_seqlen_qo = uni_seqlen_qo; + params.ori_seqlen_qo = ori_uni_seqlen_qo; + params.topk = topk; + params.is_causal = is_causal; + params.qk_batch_ratio = qk_batch_ratio; MlaMetadataV11Coefficients coefs = {}; - coefs.workload_limit_global_0 = 0.01f; - coefs.workload_limit_global_1 = 0.01f; - coefs.workload_limit_global_2 = 10.0f; + coefs.workload_limit_global_0 = 0.01f; + coefs.workload_limit_global_1 = 0.01f; + coefs.workload_limit_global_2 = 10.0f; // launch kernel MLA_METADATA_DISPATCHER( @@ -680,7 +775,10 @@ void get_mla_metadata_v1_1_device( kPackedQoLenPerWg, params.uni_seqlen_qo, topk, - dispatch_mla_metadata_v1_1_device( - params, coefs, stream, dev_prop.warpSize, dev_prop.maxSharedMemoryPerMultiProcessor) - ); + dispatch_mla_metadata_v1_1_device( + params, coefs, stream, dev_prop.warpSize, dev_prop.maxSharedMemoryPerMultiProcessor)); } diff --git a/csrc/kernels/mla/metadata/v1_1_host.cuh b/csrc/kernels/mla/metadata/v1_1_host.cuh index 3c00b3848e..2a4e155ae4 100644 --- a/csrc/kernels/mla/metadata/v1_1_host.cuh +++ b/csrc/kernels/mla/metadata/v1_1_host.cuh @@ -1,18 +1,18 @@ #pragma once -#include #include "aiter_hip_common.h" #include "v1_comm.cuh" +#include template -std::vector get_mla_metadata_v1_1_host( - const torch::Tensor& seqlens_qo_indptr, // [batch size + 1] - const torch::Tensor& seqlens_kv_indptr, // [batch size + 1] - const int32_t num_heads_per_head_k, - const int32_t num_heads_k, - const bool is_causal, - const int32_t 
kv_granularity, - const bool no_redundant) +std::vector +get_mla_metadata_v1_1_host(const torch::Tensor& seqlens_qo_indptr, // [batch size + 1] + const torch::Tensor& seqlens_kv_indptr, // [batch size + 1] + const int32_t num_heads_per_head_k, + const int32_t num_heads_k, + const bool is_causal, + const int32_t kv_granularity, + const bool no_redundant) { using index_t = uint32_t; @@ -22,7 +22,7 @@ std::vector get_mla_metadata_v1_1_host( HIP_CALL(hipGetDeviceProperties(&dev_prop, dev)); const int32_t num_batches = seqlens_qo_indptr.size(0) - 1; - const int32_t num_heads = num_heads_k * num_heads_per_head_k; + const int32_t num_heads = num_heads_k * num_heads_per_head_k; auto seqlens_qo_indptr_cpu = seqlens_qo_indptr.to(at::DeviceType::CPU); auto seqlens_kv_indptr_cpu = seqlens_kv_indptr.to(at::DeviceType::CPU); @@ -34,7 +34,7 @@ std::vector get_mla_metadata_v1_1_host( std::vector batch_infos; batch_infos.reserve(num_batches); int32_t sum_packed_qo_len = 0; - for (int32_t bid = 0; bid < num_batches; ++bid) + for(int32_t bid = 0; bid < num_batches; ++bid) { const int32_t qo_len = p_seqlens_qo_indptr[bid + 1] - p_seqlens_qo_indptr[bid]; const int32_t kv_len = p_seqlens_kv_indptr[bid + 1] - p_seqlens_kv_indptr[bid]; @@ -47,67 +47,79 @@ std::vector get_mla_metadata_v1_1_host( } std::sort(batch_infos.begin(), batch_infos.end(), std::greater()); - // Step.1. Calculate the size of cluster and some related information. The size is the number of workgroups + // Step.1. Calculate the size of cluster and some related information. The size is the number of + // workgroups // composing each cluster. The size is determined by average packed qo length. - const int32_t cluster_size = - [&]() { + const int32_t cluster_size = [&]() { const int32_t avg_packed_qo_len = sum_packed_qo_len / num_batches; const int32_t cluster_size = ck_tile::integer_divide_ceil(avg_packed_qo_len, Traits::kPackedQoLenPerWg); return ck_tile::min(cluster_size, Traits::kMaxClusterSize); }(); - TORCH_CHECK((dev_prop.multiProcessorCount % cluster_size) == 0, __func__, ": Invalid cluster_size!"); + TORCH_CHECK( + (dev_prop.multiProcessorCount % cluster_size) == 0, __func__, ": Invalid cluster_size!"); const int32_t num_clusters = dev_prop.multiProcessorCount / cluster_size; const int32_t cluster_len_q = cluster_size * Traits::kPackedQoLenPerWg; // Step.2. - // a. Get total valid (after causal masking) kv lengths and the maximun workload handled by each cluster - // b. Get a indptr array about #cluster for each batch in direction of qo. + // a. Get total valid (after causal masking) kv lengths and the maximun workload handled by + // each cluster b. Get a indptr array about #cluster for each batch in direction of qo. 
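    // Illustrative estimate (hypothetical values): with kv_len_valid = 1000 and
    // kv_granularity = 16, each tile is booked as if it were split once along kv:
    //   kv_len_splited = integer_least_multiple(ceil(1000 / 2), 16)
    //                  = integer_least_multiple(500, 16) = 512
    //   contribution   = 2 * cal_cost(packed_qo_tile_len, 512) + kv_granularity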
int32_t workload_sum = 0; std::vector num_qo_clusters_indptr; num_qo_clusters_indptr.reserve(num_batches + 1); num_qo_clusters_indptr.push_back(0); - for (const auto& binfo : batch_infos) + for(const auto& binfo : batch_infos) { - const int32_t packed_qo_len = binfo.qo_len * num_heads; - const int32_t num_qo_tiles = ck_tile::integer_divide_ceil(packed_qo_len, cluster_len_q); + const int32_t packed_qo_len = binfo.qo_len * num_heads; + const int32_t num_qo_tiles = ck_tile::integer_divide_ceil(packed_qo_len, cluster_len_q); const int32_t packed_qo_tile_len = ck_tile::min(packed_qo_len, cluster_len_q); num_qo_clusters_indptr.push_back(num_qo_clusters_indptr.back() + num_qo_tiles); - for (int32_t tid = 0; tid < num_qo_tiles; ++tid) + for(int32_t tid = 0; tid < num_qo_tiles; ++tid) { - const int32_t kv_len_valid = - cal_packed_causal_kv_len( - binfo.qo_len, binfo.kv_len, tid, packed_qo_tile_len, num_qo_tiles, num_heads, is_causal); + const int32_t kv_len_valid = cal_packed_causal_kv_len(binfo.qo_len, + binfo.kv_len, + tid, + packed_qo_tile_len, + num_qo_tiles, + num_heads, + is_causal); // always assume that each batch of tile will be splited once along kv. - const int32_t kv_len_splited = - ck_tile::integer_least_multiple(ck_tile::integer_divide_ceil(kv_len_valid, 2), kv_granularity); + const int32_t kv_len_splited = ck_tile::integer_least_multiple( + ck_tile::integer_divide_ceil(kv_len_valid, 2), kv_granularity); workload_sum += 2 * cal_cost(packed_qo_tile_len, kv_len_splited) + kv_granularity; } } - const int32_t workload_limit_global = cal_workload_limit_global_v0(workload_sum, num_clusters, kv_granularity); + const int32_t workload_limit_global = + cal_workload_limit_global_v0(workload_sum, num_clusters, kv_granularity); #if PRINT_DBG printf("[metadata] workload_limit_global=%d\n", workload_limit_global); #endif // Step.3.1. Allocates output buffers except indptrs std::vector> work_info_set(num_clusters, std::vector()); - std::vector> reduce_partial_map(num_qo_clusters_indptr.back(), std::vector()); + std::vector> reduce_partial_map(num_qo_clusters_indptr.back(), + std::vector()); std::vector reduce_partial_info(num_qo_clusters_indptr.back(), {{-1, -2}}); // Step.3.2. Declare priority queue using ClusterCost = std::tuple; // cluster_id(cid), cost - auto pq_cmp = [](const ClusterCost& l, const ClusterCost& r) { return std::get<1>(l) > std::get<1>(r); }; + auto pq_cmp = [](const ClusterCost& l, const ClusterCost& r) { + return std::get<1>(l) > std::get<1>(r); + }; std::priority_queue, decltype(pq_cmp)> cost_heap(pq_cmp); - for (int32_t cid = 0; cid < num_clusters; ++cid) { cost_heap.push(std::tuple{cid, 0}); } + for(int32_t cid = 0; cid < num_clusters; ++cid) + { + cost_heap.push(std::tuple{cid, 0}); + } // Step.4. 
Fill the output buffers except indptrs int32_t num_reduce_row = 0; int32_t num_partial_outputs = 0; int32_t loc_partial_outputs = 0; - for (const auto& binfo : batch_infos) + for(const auto& binfo : batch_infos) { const int32_t bid = binfo.batch_idx; const int32_t qo_len = binfo.qo_len; @@ -121,42 +133,55 @@ std::vector get_mla_metadata_v1_1_host( printf("[metadata] Dividing batch=%d, qo_len=%d, kv_len=%d\n", bid, qo_len, kv_len); #endif - for (int32_t tid = 0; tid < num_qo_tiles; ++tid) + for(int32_t tid = 0; tid < num_qo_tiles; ++tid) { const int32_t global_cluster_q_idx = num_qo_clusters_indptr[bid] + tid; - int32_t remaining_kv_len = - cal_packed_causal_kv_len(qo_len, kv_len, tid, cluster_len_q, num_qo_tiles, num_heads, is_causal); + int32_t remaining_kv_len = cal_packed_causal_kv_len( + qo_len, kv_len, tid, cluster_len_q, num_qo_tiles, num_heads, is_causal); int32_t kv_start_local = 0; const auto [cid_top, accum_cost_top] = cost_heap.top(); - const int32_t remaining_capability_top = cal_kv_len(workload_limit_global - accum_cost_top, cluster_len_q); + const int32_t remaining_capability_top = + cal_kv_len(workload_limit_global - accum_cost_top, cluster_len_q); const int32_t num_splits_estimated = ck_tile::integer_divide_ceil(remaining_kv_len, remaining_capability_top); - // For the case of #splits==2, make sure that the tailing tile is smaller than Traits::kSplitTolerance. - const bool split_kv = (num_splits_estimated == 2) ? - ((remaining_kv_len - remaining_capability_top) > Traits::kSplitTolerance) : (num_splits_estimated > 1); - const int32_t kv_len_limit_floor = - ck_tile::integer_least_multiple(ck_tile::integer_divide_ceil(kv_len, num_clusters), kv_granularity); + // For the case of #splits==2, make sure that the tailing tile is smaller than + // Traits::kSplitTolerance. + const bool split_kv = + (num_splits_estimated == 2) + ? ((remaining_kv_len - remaining_capability_top) > Traits::kSplitTolerance) + : (num_splits_estimated > 1); + const int32_t kv_len_limit_floor = ck_tile::integer_least_multiple( + ck_tile::integer_divide_ceil(kv_len, num_clusters), kv_granularity); do { // Check and update cost_heap auto [cid, accum_cost] = cost_heap.top(); cost_heap.pop(); - const int32_t remaining_capability = cal_kv_len(workload_limit_global - accum_cost, cluster_len_q); - const int32_t kv_len_limit_local = - [&]() { - const int32_t limit_ori = ck_tile::max(remaining_capability, kv_len_limit_floor); - const int32_t tail_size = (remaining_kv_len > limit_ori) ? (remaining_kv_len - limit_ori) : 0x7fffffff; - const int32_t limit_fin = (tail_size <= Traits::kSplitTolerance) ? remaining_kv_len : limit_ori; + const int32_t remaining_capability = + cal_kv_len(workload_limit_global - accum_cost, cluster_len_q); + const int32_t kv_len_limit_local = [&]() { + const int32_t limit_ori = + ck_tile::max(remaining_capability, kv_len_limit_floor); + const int32_t tail_size = (remaining_kv_len > limit_ori) + ? (remaining_kv_len - limit_ori) + : 0x7fffffff; + const int32_t limit_fin = + (tail_size <= Traits::kSplitTolerance) ? 
remaining_kv_len : limit_ori; return limit_fin; }(); const int32_t kv_len_consuming = ck_tile::min(remaining_kv_len, kv_len_limit_local); - const int32_t cost = cal_cost(cluster_len_q, kv_len_consuming); + const int32_t cost = cal_cost(cluster_len_q, kv_len_consuming); #if PRINT_DBG - printf("[metadata] cost heap updated: cid=%d, pre_cost=%d, new_cost=%d, tot_cost=%d, kv_len_cons=%d\n", - cid, accum_cost, cost, accum_cost+cost, kv_len_consuming); + printf("[metadata] cost heap updated: cid=%d, pre_cost=%d, new_cost=%d, " + "tot_cost=%d, kv_len_cons=%d\n", + cid, + accum_cost, + cost, + accum_cost + cost, + kv_len_consuming); #endif const int32_t new_cost = accum_cost + cost; cost_heap.push(std::tuple{cid, new_cost}); @@ -165,17 +190,19 @@ std::vector get_mla_metadata_v1_1_host( MlaWorkInfo work_info{}; work_info.batch_idx = bid; work_info.qo_start = tid * cluster_len_q + qo_batch_start; - work_info.qo_end = ck_tile::min(work_info.qo_start + cluster_len_q, qo_batch_start + qo_len); + work_info.qo_end = + ck_tile::min(work_info.qo_start + cluster_len_q, qo_batch_start + qo_len); work_info.kv_start = kv_start_local + kv_batch_start; work_info.kv_end = work_info.kv_start + kv_len_consuming; work_info.kv_offset = kv_batch_end - work_info.kv_end; - if (split_kv) + if(split_kv) { work_info.partial_qo_loc = loc_partial_outputs; - if (reduce_partial_map[global_cluster_q_idx].empty()) + if(reduce_partial_map[global_cluster_q_idx].empty()) { ++num_reduce_row; - reduce_partial_info[global_cluster_q_idx] = {{ work_info.qo_start, work_info.qo_end }}; + reduce_partial_info[global_cluster_q_idx] = { + {work_info.qo_start, work_info.qo_end}}; } reduce_partial_map[global_cluster_q_idx].push_back(loc_partial_outputs); ++num_partial_outputs; @@ -190,14 +217,13 @@ std::vector get_mla_metadata_v1_1_host( // Update state remaining_kv_len -= kv_len_consuming; kv_start_local += kv_len_consuming; - } - while (remaining_kv_len > 0); + } while(remaining_kv_len > 0); } } #if PRINT_DBG printf("[metadata] Final Cost Heap Status: %zu elements\n", cost_heap.size()); - while (cost_heap.empty() == false) + while(cost_heap.empty() == false) { auto [id, cost] = cost_heap.top(); cost_heap.pop(); @@ -209,50 +235,65 @@ std::vector get_mla_metadata_v1_1_host( std::vector work_indptr; work_indptr.reserve(num_clusters + 1); work_indptr.push_back(0); - for (int32_t cid = 0; cid < num_clusters; ++cid) + for(int32_t cid = 0; cid < num_clusters; ++cid) { - if ((work_info_set[cid].empty() == false) || (no_redundant == false)) + if((work_info_set[cid].empty() == false) || (no_redundant == false)) { work_indptr.push_back(work_indptr.back() + work_info_set[cid].size()); } } const int32_t num_works = work_indptr.back(); - const int32_t reduce_final_map_size = no_redundant ? num_reduce_row : num_qo_clusters_indptr.back(); + const int32_t reduce_final_map_size = + no_redundant ? 
num_reduce_row : num_qo_clusters_indptr.back(); const int32_t reduce_indptr_size = reduce_final_map_size + 1; std::vector reduce_final_map; std::vector reduce_indptr; reduce_final_map.reserve(reduce_final_map_size); reduce_indptr.reserve(reduce_indptr_size); reduce_indptr.push_back(0); - for (auto [global_cluster_q_idx ,rid] = std::tuple{0, 0}; - (global_cluster_q_idx < num_qo_clusters_indptr.back()) && ((rid < num_reduce_row) || (no_redundant == false)); - ++global_cluster_q_idx) + for(auto [global_cluster_q_idx, rid] = std::tuple{0, 0}; + (global_cluster_q_idx < num_qo_clusters_indptr.back()) && + ((rid < num_reduce_row) || (no_redundant == false)); + ++global_cluster_q_idx) { - if ((reduce_partial_map[global_cluster_q_idx].empty() == false) || (no_redundant == false)) + if((reduce_partial_map[global_cluster_q_idx].empty() == false) || (no_redundant == false)) { - reduce_indptr.push_back(reduce_indptr.back() + reduce_partial_map[global_cluster_q_idx].size()); + reduce_indptr.push_back(reduce_indptr.back() + + reduce_partial_map[global_cluster_q_idx].size()); reduce_final_map.push_back(reduce_partial_info[global_cluster_q_idx]); ++rid; } } // Step.6. Flatten 2D arries - auto work_info_set_flatten = flatten(work_info_set, num_works); + auto work_info_set_flatten = flatten(work_info_set, num_works); auto reduce_partial_map_flatten = flatten(reduce_partial_map, num_partial_outputs); // Step.7. Create tensors. - auto input_opts = seqlens_qo_indptr.options(); - auto int_opts = torch::TensorOptions().dtype(torch::kInt32); + auto input_opts = seqlens_qo_indptr.options(); + auto int_opts = torch::TensorOptions().dtype(torch::kInt32); auto work_metadata_ptrs_tsr = torch::empty({2}, torch::TensorOptions().dtype(torch::kUInt64)); - auto work_info_set_tsr = torch::from_blob(work_info_set_flatten.data(), {num_works, kSizeMlaWorkInfoInDw}, int_opts).to(input_opts); - auto work_indptr_tsr = torch::from_blob(work_indptr.data(), {static_cast(work_indptr.size())}, int_opts).to(input_opts); - auto reduce_indptr_tsr = torch::from_blob(reduce_indptr.data(), {reduce_indptr_size}, int_opts).to(input_opts); - auto reduce_final_map_tsr = torch::from_blob(reduce_final_map.data(), {reduce_final_map_size, kSizeMlaPartialTileInfoInDw}, int_opts).to(input_opts); - auto reduce_partial_map_tsr = torch::from_blob(reduce_partial_map_flatten.data(), {num_partial_outputs}, int_opts).to(input_opts); + auto work_info_set_tsr = + torch::from_blob(work_info_set_flatten.data(), {num_works, kSizeMlaWorkInfoInDw}, int_opts) + .to(input_opts); + auto work_indptr_tsr = + torch::from_blob(work_indptr.data(), {static_cast(work_indptr.size())}, int_opts) + .to(input_opts); + auto reduce_indptr_tsr = + torch::from_blob(reduce_indptr.data(), {reduce_indptr_size}, int_opts).to(input_opts); + auto reduce_final_map_tsr = + torch::from_blob( + reduce_final_map.data(), {reduce_final_map_size, kSizeMlaPartialTileInfoInDw}, int_opts) + .to(input_opts); + auto reduce_partial_map_tsr = + torch::from_blob(reduce_partial_map_flatten.data(), {num_partial_outputs}, int_opts) + .to(input_opts); - work_metadata_ptrs_tsr.index_put_({0}, static_cast(reinterpret_cast(work_indptr_tsr.data_ptr()))); - work_metadata_ptrs_tsr.index_put_({1}, static_cast(reinterpret_cast(work_info_set_tsr.data_ptr()))); + work_metadata_ptrs_tsr.index_put_( + {0}, static_cast(reinterpret_cast(work_indptr_tsr.data_ptr()))); + work_metadata_ptrs_tsr.index_put_( + {1}, static_cast(reinterpret_cast(work_info_set_tsr.data_ptr()))); // Last step. 
Copy to the device of input and return the results. return {work_metadata_ptrs_tsr.to(input_opts), diff --git a/csrc/kernels/mla/metadata/v1_2_device.cuh b/csrc/kernels/mla/metadata/v1_2_device.cuh index ad64bce238..b96051874d 100644 --- a/csrc/kernels/mla/metadata/v1_2_device.cuh +++ b/csrc/kernels/mla/metadata/v1_2_device.cuh @@ -28,12 +28,34 @@ __launch_bounds__(ck_tile::get_warp_size(), 1) __global__ { using QoState = QoState; + const int32_t ori_seqlen_qo = [&]() { + if constexpr (Traits::kIsSparse) + { + return params.p_seqlens_qo_indptr[1] - params.p_seqlens_qo_indptr[0]; + } + else + { + return params.ori_seqlen_qo; + } + }(); + + const int32_t num_batches = [&]() { + if constexpr (Traits::kIsSparse) + { + return params.num_batches * ori_seqlen_qo; + } + else + { + return params.num_batches; + } + }(); + extern __shared__ uint8_t p_smem[]; int32_t* p_lds_seqlens_qo = reinterpret_cast(p_smem); - int32_t* p_lds_seqlens_kv = p_lds_seqlens_qo + (QoState::is_unique() ? 0 : params.num_batches); + int32_t* p_lds_seqlens_kv = p_lds_seqlens_qo + (QoState::is_unique() ? 0 : num_batches); QoState qo_state( - params.uni_seqlen_qo, params.ori_seqlen_qo, p_lds_seqlens_qo, params.p_seqlens_qo_indptr); + params.uni_seqlen_qo, ori_seqlen_qo, p_lds_seqlens_qo, params.p_seqlens_qo_indptr); auto get_num_qo_tiles = [&](const int32_t batch_idx) { if constexpr(Traits::kQoSplits) @@ -53,10 +75,10 @@ __launch_bounds__(ck_tile::get_warp_size(), 1) __global__ MlaWorkInfo* p_work_info_set = reinterpret_cast(params.p_work_info_set_raw); int32_t sum_blocks = 0; - for(int32_t bid = lane_idx; bid < params.num_batches; bid += ck_tile::get_warp_size()) + for(int32_t bid = lane_idx; bid < num_batches; bid += ck_tile::get_warp_size()) { const int32_t bid_ori = Traits::kIsSparse - ? (bid / params.ori_seqlen_qo / params.qk_batch_ratio) + ? (bid / ori_seqlen_qo / params.qk_batch_ratio) : (bid / params.qk_batch_ratio); const int32_t kv_end = params.p_seqlens_kv_indptr[bid_ori + 1]; const int32_t seqlen_kv = @@ -119,7 +141,7 @@ __launch_bounds__(ck_tile::get_warp_size(), 1) __global__ for(int32_t cid = 0; cid < params.num_cu; ++cid) { int32_t remain_payload = payload; - while(curr_batch < params.num_batches) + while(curr_batch < num_batches) { const int32_t num_qo_tiles = get_num_qo_tiles(curr_batch); const int32_t qo_tile_size = @@ -143,9 +165,17 @@ __launch_bounds__(ck_tile::get_warp_size(), 1) __global__ work_info.qo_end = ck_tile::min(work_info.qo_start + qo_tile_size, qo_state.get_end(curr_batch)); work_info.kv_start = curr_kv_begin + (curr_kv_block * params.kv_granularity); + int32_t batch_tail = (num_qo_tiles - 1 - curr_qo_tile_idx); + if constexpr(!Traits::kIsSparse) + { + if (params.qk_batch_ratio != 1) + { + batch_tail = num_qo_tiles - (work_info.qo_start / params.qk_batch_ratio) % ori_seqlen_qo - 1; + } + } work_info.kv_end = ck_tile::min( work_info.kv_start + (remain_kv_blocks * params.kv_granularity), - curr_kv_end - (num_qo_tiles - 1 - curr_qo_tile_idx)); + curr_kv_end - batch_tail); work_info.kv_offset = curr_kv_end - work_info.kv_end; // split related info @@ -202,7 +232,7 @@ __launch_bounds__(ck_tile::get_warp_size(), 1) __global__ curr_sub_head_idx = (curr_sub_head_idx == (params.qk_batch_ratio - 1)) ? 0 : (curr_sub_head_idx + 1); - if(curr_batch < params.num_batches) + if(curr_batch < num_batches) { if(curr_sub_head_idx == 0) { @@ -213,7 +243,7 @@ __launch_bounds__(ck_tile::get_warp_size(), 1) __global__ else { const int32_t bid_ori = Traits::kIsSparse - ? (curr_batch / params.ori_seqlen_qo / + ? 
(curr_batch / ori_seqlen_qo / params.qk_batch_ratio) : (curr_batch / params.qk_batch_ratio); curr_kv_seqlen = params.p_seqlens_kv_indptr[bid_ori + 1] - @@ -251,9 +281,17 @@ __launch_bounds__(ck_tile::get_warp_size(), 1) __global__ qo_state.get_end(curr_batch)); work_info.kv_start = curr_kv_begin + (curr_kv_block * params.kv_granularity); + int32_t batch_tail = (num_qo_tiles - 1 - curr_qo_tile_idx); + if constexpr(!Traits::kIsSparse) + { + if (params.qk_batch_ratio != 1) + { + batch_tail = num_qo_tiles - (work_info.qo_start / params.qk_batch_ratio) % ori_seqlen_qo - 1; + } + } work_info.kv_end = ck_tile::min( work_info.kv_start + (consuming_blks * params.kv_granularity), - curr_kv_end - (num_qo_tiles - 1 - curr_qo_tile_idx)); + curr_kv_end - batch_tail); work_info.kv_offset = curr_kv_end - work_info.kv_end; work_info.partial_qo_loc = partial_idx; p_work_info_set[num_works] = work_info; @@ -365,12 +403,6 @@ void get_mla_metadata_v1_2_device(const torch::Tensor& seqlens_qo_indptr, // [ba num_batches *= qk_batch_ratio; } - if(is_sparse) - { - num_batches *= uni_seqlen_qo; - uni_seqlen_qo = 1; - } - TORCH_CHECK((num_heads == 16) || (num_heads == 128), __func__, ": only supports #heads in [16, 128], or (#head, uni_seqlen_qo) = (16*N, 1) where " diff --git a/csrc/kernels/rope/rope_common.h b/csrc/kernels/rope/rope_common.h index a14c0534e9..598f70c207 100644 --- a/csrc/kernels/rope/rope_common.h +++ b/csrc/kernels/rope/rope_common.h @@ -3,6 +3,7 @@ #pragma once +#include "aiter_hip_common.h" #include "dispatch_utils.h" #include @@ -1271,20 +1272,21 @@ template -__global__ void kn_entry_1c_sbhd_uncached(scalar_t* __restrict__ p_output, - const scalar_t* __restrict__ p_input, - const scalar_f_t* __restrict__ p_freqs, - const int32_t size_h, - const int32_t size_d, - const int32_t size_f, // size of last dimension of freqs. - const int32_t stride_i_s, - const int32_t stride_i_b, - const int32_t stride_i_h, - const int32_t stride_i_d, - const int32_t stride_o_s, - const int32_t stride_o_b, - const int32_t stride_o_h, - const int32_t stride_o_d) +__launch_bounds__(256, 8) __global__ + void kn_entry_1c_sbhd_uncached(scalar_t* __restrict__ p_output, + const scalar_t* __restrict__ p_input, + const scalar_f_t* __restrict__ p_freqs, + const int32_t size_h, + const int32_t size_d, + const int32_t size_f, // size of last dimension of freqs. + const int32_t stride_i_s, + const int32_t stride_i_b, + const int32_t stride_i_h, + const int32_t stride_i_d, + const int32_t stride_o_s, + const int32_t stride_o_b, + const int32_t stride_o_h, + const int32_t stride_o_d) { const uint64_t sid = blockIdx.x; const uint64_t bid = blockIdx.y; @@ -1316,16 +1318,16 @@ template -__global__ void -kn_entry_1c_sbhd_uncached_inplace(scalar_t* __restrict__ p_inout, - const scalar_f_t* __restrict__ p_freqs, - const int32_t size_h, - const int32_t size_d, - const int32_t size_f, // size of last dimension of freqs. - const int32_t stride_s, - const int32_t stride_b, - const int32_t stride_h, - const int32_t stride_d) +__launch_bounds__(256, 8) __global__ + void kn_entry_1c_sbhd_uncached_inplace(scalar_t* __restrict__ p_inout, + const scalar_f_t* __restrict__ p_freqs, + const int32_t size_h, + const int32_t size_d, + const int32_t size_f, // size of last dimension of freqs. 
+ const int32_t stride_s, + const int32_t stride_b, + const int32_t stride_h, + const int32_t stride_d) { const uint64_t sid = blockIdx.x; const uint64_t bid = blockIdx.y; @@ -1359,31 +1361,32 @@ template -__global__ void kn_entry_2c_sbhd_uncached(scalar_t* __restrict__ p_output_x, - scalar_t* __restrict__ p_output_y, - const scalar_t* __restrict__ p_input_x, - const scalar_t* __restrict__ p_input_y, - const scalar_f_t* __restrict__ p_freqs, - const int32_t size_h_x, - const int32_t size_h_y, - const int32_t size_d, - const int32_t size_f, // size of last dimension of freqs. - const int32_t stride_ix_s, - const int32_t stride_ix_b, - const int32_t stride_ix_h, - const int32_t stride_ix_d, - const int32_t stride_iy_s, - const int32_t stride_iy_b, - const int32_t stride_iy_h, - const int32_t stride_iy_d, - const int32_t stride_ox_s, - const int32_t stride_ox_b, - const int32_t stride_ox_h, - const int32_t stride_ox_d, - const int32_t stride_oy_s, - const int32_t stride_oy_b, - const int32_t stride_oy_h, - const int32_t stride_oy_d) +__launch_bounds__(256, 8) __global__ + void kn_entry_2c_sbhd_uncached(scalar_t* __restrict__ p_output_x, + scalar_t* __restrict__ p_output_y, + const scalar_t* __restrict__ p_input_x, + const scalar_t* __restrict__ p_input_y, + const scalar_f_t* __restrict__ p_freqs, + const int32_t size_h_x, + const int32_t size_h_y, + const int32_t size_d, + const int32_t size_f, // size of last dimension of freqs. + const int32_t stride_ix_s, + const int32_t stride_ix_b, + const int32_t stride_ix_h, + const int32_t stride_ix_d, + const int32_t stride_iy_s, + const int32_t stride_iy_b, + const int32_t stride_iy_h, + const int32_t stride_iy_d, + const int32_t stride_ox_s, + const int32_t stride_ox_b, + const int32_t stride_ox_h, + const int32_t stride_ox_d, + const int32_t stride_oy_s, + const int32_t stride_oy_b, + const int32_t stride_oy_h, + const int32_t stride_oy_d) { const uint64_t sid = blockIdx.x; const uint64_t bid = blockIdx.y; @@ -1427,22 +1430,22 @@ template -__global__ void -kn_entry_2c_sbhd_uncached_inplace(scalar_t* __restrict__ p_inout_x, - scalar_t* __restrict__ p_inout_y, - const scalar_f_t* __restrict__ p_freqs, - const int32_t size_h_x, - const int32_t size_h_y, - const int32_t size_d, - const int32_t size_f, // size of last dimension of freqs. - const int32_t stride_x_s, - const int32_t stride_x_b, - const int32_t stride_x_h, - const int32_t stride_x_d, - const int32_t stride_y_s, - const int32_t stride_y_b, - const int32_t stride_y_h, - const int32_t stride_y_d) +__launch_bounds__(256, 8) __global__ + void kn_entry_2c_sbhd_uncached_inplace(scalar_t* __restrict__ p_inout_x, + scalar_t* __restrict__ p_inout_y, + const scalar_f_t* __restrict__ p_freqs, + const int32_t size_h_x, + const int32_t size_h_y, + const int32_t size_d, + const int32_t size_f, // size of last dimension of freqs. + const int32_t stride_x_s, + const int32_t stride_x_b, + const int32_t stride_x_h, + const int32_t stride_x_d, + const int32_t stride_y_s, + const int32_t stride_y_b, + const int32_t stride_y_h, + const int32_t stride_y_d) { const uint64_t sid = blockIdx.x; const uint64_t bid = blockIdx.y; @@ -1484,21 +1487,22 @@ template -__global__ void kn_entry_1c_sbhd_cached(scalar_t* __restrict__ p_output, - const scalar_t* __restrict__ p_input, - const scalar_f_t* __restrict__ p_cos, - const scalar_f_t* __restrict__ p_sin, - const int32_t size_h, - const int32_t size_d, - const int32_t size_f, // size of last dimension of freqs. 
- const int32_t stride_i_s, - const int32_t stride_i_b, - const int32_t stride_i_h, - const int32_t stride_i_d, - const int32_t stride_o_s, - const int32_t stride_o_b, - const int32_t stride_o_h, - const int32_t stride_o_d) +__launch_bounds__(256, 8) __global__ + void kn_entry_1c_sbhd_cached(scalar_t* __restrict__ p_output, + const scalar_t* __restrict__ p_input, + const scalar_f_t* __restrict__ p_cos, + const scalar_f_t* __restrict__ p_sin, + const int32_t size_h, + const int32_t size_d, + const int32_t size_f, // size of last dimension of freqs. + const int32_t stride_i_s, + const int32_t stride_i_b, + const int32_t stride_i_h, + const int32_t stride_i_d, + const int32_t stride_o_s, + const int32_t stride_o_b, + const int32_t stride_o_h, + const int32_t stride_o_d) { const uint64_t sid = blockIdx.x; const uint64_t bid = blockIdx.y; @@ -1531,17 +1535,17 @@ template -__global__ void -kn_entry_1c_sbhd_cached_inplace(scalar_t* __restrict__ p_inout, - const scalar_f_t* __restrict__ p_cos, - const scalar_f_t* __restrict__ p_sin, - const int32_t size_h, - const int32_t size_d, - const int32_t size_f, // size of last dimension of freqs. - const int32_t stride_s, - const int32_t stride_b, - const int32_t stride_h, - const int32_t stride_d) +__launch_bounds__(256, 8) __global__ + void kn_entry_1c_sbhd_cached_inplace(scalar_t* __restrict__ p_inout, + const scalar_f_t* __restrict__ p_cos, + const scalar_f_t* __restrict__ p_sin, + const int32_t size_h, + const int32_t size_d, + const int32_t size_f, // size of last dimension of freqs. + const int32_t stride_s, + const int32_t stride_b, + const int32_t stride_h, + const int32_t stride_d) { const uint64_t sid = blockIdx.x; const uint64_t bid = blockIdx.y; @@ -1576,32 +1580,33 @@ template -__global__ void kn_entry_2c_sbhd_cached(scalar_t* __restrict__ p_output_x, - scalar_t* __restrict__ p_output_y, - const scalar_t* __restrict__ p_input_x, - const scalar_t* __restrict__ p_input_y, - const scalar_f_t* __restrict__ p_cos, - const scalar_f_t* __restrict__ p_sin, - const int32_t size_h_x, - const int32_t size_h_y, - const int32_t size_d, - const int32_t size_f, // size of last dimension of freqs. - const int32_t stride_ix_s, - const int32_t stride_ix_b, - const int32_t stride_ix_h, - const int32_t stride_ix_d, - const int32_t stride_iy_s, - const int32_t stride_iy_b, - const int32_t stride_iy_h, - const int32_t stride_iy_d, - const int32_t stride_ox_s, - const int32_t stride_ox_b, - const int32_t stride_ox_h, - const int32_t stride_ox_d, - const int32_t stride_oy_s, - const int32_t stride_oy_b, - const int32_t stride_oy_h, - const int32_t stride_oy_d) +__launch_bounds__(256, 8) __global__ + void kn_entry_2c_sbhd_cached(scalar_t* __restrict__ p_output_x, + scalar_t* __restrict__ p_output_y, + const scalar_t* __restrict__ p_input_x, + const scalar_t* __restrict__ p_input_y, + const scalar_f_t* __restrict__ p_cos, + const scalar_f_t* __restrict__ p_sin, + const int32_t size_h_x, + const int32_t size_h_y, + const int32_t size_d, + const int32_t size_f, // size of last dimension of freqs. 
+ const int32_t stride_ix_s, + const int32_t stride_ix_b, + const int32_t stride_ix_h, + const int32_t stride_ix_d, + const int32_t stride_iy_s, + const int32_t stride_iy_b, + const int32_t stride_iy_h, + const int32_t stride_iy_d, + const int32_t stride_ox_s, + const int32_t stride_ox_b, + const int32_t stride_ox_h, + const int32_t stride_ox_d, + const int32_t stride_oy_s, + const int32_t stride_oy_b, + const int32_t stride_oy_h, + const int32_t stride_oy_d) { const uint64_t sid = blockIdx.x; const uint64_t bid = blockIdx.y; @@ -1646,23 +1651,23 @@ template -__global__ void -kn_entry_2c_sbhd_cached_inplace(scalar_t* __restrict__ p_inout_x, - scalar_t* __restrict__ p_inout_y, - const scalar_f_t* __restrict__ p_cos, - const scalar_f_t* __restrict__ p_sin, - const int32_t size_h_x, - const int32_t size_h_y, - const int32_t size_d, - const int32_t size_f, // size of last dimension of freqs. - const int32_t stride_x_s, - const int32_t stride_x_b, - const int32_t stride_x_h, - const int32_t stride_x_d, - const int32_t stride_y_s, - const int32_t stride_y_b, - const int32_t stride_y_h, - const int32_t stride_y_d) +__launch_bounds__(256, 8) __global__ + void kn_entry_2c_sbhd_cached_inplace(scalar_t* __restrict__ p_inout_x, + scalar_t* __restrict__ p_inout_y, + const scalar_f_t* __restrict__ p_cos, + const scalar_f_t* __restrict__ p_sin, + const int32_t size_h_x, + const int32_t size_h_y, + const int32_t size_d, + const int32_t size_f, // size of last dimension of freqs. + const int32_t stride_x_s, + const int32_t stride_x_b, + const int32_t stride_x_h, + const int32_t stride_x_d, + const int32_t stride_y_s, + const int32_t stride_y_b, + const int32_t stride_y_h, + const int32_t stride_y_d) { const uint64_t sid = blockIdx.x; const uint64_t bid = blockIdx.y; @@ -1705,24 +1710,24 @@ template -__global__ void -kn_entry_1c_sbhd_cached_indirect(scalar_t* __restrict__ p_output, - const scalar_t* __restrict__ p_input, - const scalar_f_t* __restrict__ p_cos, - const scalar_f_t* __restrict__ p_sin, - const int64_t* __restrict__ p_indirect_buffer, - const int32_t max_position, - const int32_t size_h, - const int32_t size_d, - const int32_t size_f, // size of last dimension of freqs. - const int32_t stride_i_s, - const int32_t stride_i_b, - const int32_t stride_i_h, - const int32_t stride_i_d, - const int32_t stride_o_s, - const int32_t stride_o_b, - const int32_t stride_o_h, - const int32_t stride_o_d) +__launch_bounds__(256, 8) __global__ + void kn_entry_1c_sbhd_cached_indirect(scalar_t* __restrict__ p_output, + const scalar_t* __restrict__ p_input, + const scalar_f_t* __restrict__ p_cos, + const scalar_f_t* __restrict__ p_sin, + const int64_t* __restrict__ p_indirect_buffer, + const int32_t max_position, + const int32_t size_h, + const int32_t size_d, + const int32_t size_f, // size of last dimension of freqs. 
+ const int32_t stride_i_s, + const int32_t stride_i_b, + const int32_t stride_i_h, + const int32_t stride_i_d, + const int32_t stride_o_s, + const int32_t stride_o_b, + const int32_t stride_o_h, + const int32_t stride_o_d) { const uint64_t sid = blockIdx.x; const uint64_t bid = blockIdx.y; @@ -1764,35 +1769,35 @@ template -__global__ void -kn_entry_2c_sbhd_cached_indirect(scalar_t* __restrict__ p_output_x, - scalar_t* __restrict__ p_output_y, - const scalar_t* __restrict__ p_input_x, - const scalar_t* __restrict__ p_input_y, - const scalar_f_t* __restrict__ p_cos, - const scalar_f_t* __restrict__ p_sin, - const int64_t* __restrict__ p_indirect_buffer, - const int32_t max_position, - const int32_t size_h_x, - const int32_t size_h_y, - const int32_t size_d, - const int32_t size_f, // size of last dimension of freqs. - const int32_t stride_ix_s, - const int32_t stride_ix_b, - const int32_t stride_ix_h, - const int32_t stride_ix_d, - const int32_t stride_iy_s, - const int32_t stride_iy_b, - const int32_t stride_iy_h, - const int32_t stride_iy_d, - const int32_t stride_ox_s, - const int32_t stride_ox_b, - const int32_t stride_ox_h, - const int32_t stride_ox_d, - const int32_t stride_oy_s, - const int32_t stride_oy_b, - const int32_t stride_oy_h, - const int32_t stride_oy_d) +__launch_bounds__(256, 8) __global__ + void kn_entry_2c_sbhd_cached_indirect(scalar_t* __restrict__ p_output_x, + scalar_t* __restrict__ p_output_y, + const scalar_t* __restrict__ p_input_x, + const scalar_t* __restrict__ p_input_y, + const scalar_f_t* __restrict__ p_cos, + const scalar_f_t* __restrict__ p_sin, + const int64_t* __restrict__ p_indirect_buffer, + const int32_t max_position, + const int32_t size_h_x, + const int32_t size_h_y, + const int32_t size_d, + const int32_t size_f, // size of last dimension of freqs. + const int32_t stride_ix_s, + const int32_t stride_ix_b, + const int32_t stride_ix_h, + const int32_t stride_ix_d, + const int32_t stride_iy_s, + const int32_t stride_iy_b, + const int32_t stride_iy_h, + const int32_t stride_iy_d, + const int32_t stride_ox_s, + const int32_t stride_ox_b, + const int32_t stride_ox_h, + const int32_t stride_ox_d, + const int32_t stride_oy_s, + const int32_t stride_oy_b, + const int32_t stride_oy_h, + const int32_t stride_oy_d) { const uint64_t sid = blockIdx.x; const uint64_t bid = blockIdx.y; @@ -1842,19 +1847,19 @@ template -__global__ void -kn_entry_1c_sbhd_cached_indirect_inplace(scalar_t* __restrict__ p_inout, - const scalar_f_t* __restrict__ p_cos, - const scalar_f_t* __restrict__ p_sin, - const int64_t* __restrict__ p_indirect_buffer, - const int32_t max_position, - const int32_t size_h, - const int32_t size_d, - const int32_t size_f, // size of last dimension of freqs. - const int32_t stride_s, - const int32_t stride_b, - const int32_t stride_h, - const int32_t stride_d) +__launch_bounds__(256, 8) __global__ void kn_entry_1c_sbhd_cached_indirect_inplace( + scalar_t* __restrict__ p_inout, + const scalar_f_t* __restrict__ p_cos, + const scalar_f_t* __restrict__ p_sin, + const int64_t* __restrict__ p_indirect_buffer, + const int32_t max_position, + const int32_t size_h, + const int32_t size_d, + const int32_t size_f, // size of last dimension of freqs. 
+ const int32_t stride_s, + const int32_t stride_b, + const int32_t stride_h, + const int32_t stride_d) { const uint64_t sid = blockIdx.x; const uint64_t bid = blockIdx.y; @@ -1893,25 +1898,25 @@ template -__global__ void -kn_entry_2c_sbhd_cached_indirect_inplace(scalar_t* __restrict__ p_inout_x, - scalar_t* __restrict__ p_inout_y, - const scalar_f_t* __restrict__ p_cos, - const scalar_f_t* __restrict__ p_sin, - const int64_t* __restrict__ p_indirect_buffer, - const int32_t max_position, - const int32_t size_h_x, - const int32_t size_h_y, - const int32_t size_d, - const int32_t size_f, // size of last dimension of freqs. - const int32_t stride_x_s, - const int32_t stride_x_b, - const int32_t stride_x_h, - const int32_t stride_x_d, - const int32_t stride_y_s, - const int32_t stride_y_b, - const int32_t stride_y_h, - const int32_t stride_y_d) +__launch_bounds__(256, 8) __global__ void kn_entry_2c_sbhd_cached_indirect_inplace( + scalar_t* __restrict__ p_inout_x, + scalar_t* __restrict__ p_inout_y, + const scalar_f_t* __restrict__ p_cos, + const scalar_f_t* __restrict__ p_sin, + const int64_t* __restrict__ p_indirect_buffer, + const int32_t max_position, + const int32_t size_h_x, + const int32_t size_h_y, + const int32_t size_d, + const int32_t size_f, // size of last dimension of freqs. + const int32_t stride_x_s, + const int32_t stride_x_b, + const int32_t stride_x_h, + const int32_t stride_x_d, + const int32_t stride_y_s, + const int32_t stride_y_b, + const int32_t stride_y_h, + const int32_t stride_y_d) { const uint64_t sid = blockIdx.x; const uint64_t bid = blockIdx.y; @@ -1960,25 +1965,25 @@ template -__global__ void -kn_entry_1c_sbhd_cached_indirect2(scalar_t* __restrict__ p_output, - const scalar_t* __restrict__ p_input, - const scalar_f_t* __restrict__ p_cos, - const scalar_f_t* __restrict__ p_sin, - const int64_t* __restrict__ p_indirect_buffer_0, - const int64_t* __restrict__ p_indirect_buffer_1, - const int32_t max_position, - const int32_t size_h, - const int32_t size_d, - const int32_t size_f, // size of last dimension of freqs. - const int32_t stride_i_s, - const int32_t stride_i_b, - const int32_t stride_i_h, - const int32_t stride_i_d, - const int32_t stride_o_s, - const int32_t stride_o_b, - const int32_t stride_o_h, - const int32_t stride_o_d) +__launch_bounds__(256, 8) __global__ + void kn_entry_1c_sbhd_cached_indirect2(scalar_t* __restrict__ p_output, + const scalar_t* __restrict__ p_input, + const scalar_f_t* __restrict__ p_cos, + const scalar_f_t* __restrict__ p_sin, + const int64_t* __restrict__ p_indirect_buffer_0, + const int64_t* __restrict__ p_indirect_buffer_1, + const int32_t max_position, + const int32_t size_h, + const int32_t size_d, + const int32_t size_f, // size of last dimension of freqs. 
+ const int32_t stride_i_s, + const int32_t stride_i_b, + const int32_t stride_i_h, + const int32_t stride_i_d, + const int32_t stride_o_s, + const int32_t stride_o_b, + const int32_t stride_o_h, + const int32_t stride_o_d) { const uint64_t sid = blockIdx.x; const uint64_t bid = blockIdx.y; @@ -2020,36 +2025,36 @@ template -__global__ void -kn_entry_2c_sbhd_cached_indirect2(scalar_t* __restrict__ p_output_x, - scalar_t* __restrict__ p_output_y, - const scalar_t* __restrict__ p_input_x, - const scalar_t* __restrict__ p_input_y, - const scalar_f_t* __restrict__ p_cos, - const scalar_f_t* __restrict__ p_sin, - const int64_t* __restrict__ p_indirect_buffer_0, - const int64_t* __restrict__ p_indirect_buffer_1, - const int32_t max_position, - const int32_t size_h_x, - const int32_t size_h_y, - const int32_t size_d, - const int32_t size_f, // size of last dimension of freqs. - const int32_t stride_ix_s, - const int32_t stride_ix_b, - const int32_t stride_ix_h, - const int32_t stride_ix_d, - const int32_t stride_iy_s, - const int32_t stride_iy_b, - const int32_t stride_iy_h, - const int32_t stride_iy_d, - const int32_t stride_ox_s, - const int32_t stride_ox_b, - const int32_t stride_ox_h, - const int32_t stride_ox_d, - const int32_t stride_oy_s, - const int32_t stride_oy_b, - const int32_t stride_oy_h, - const int32_t stride_oy_d) +__launch_bounds__(256, 8) __global__ + void kn_entry_2c_sbhd_cached_indirect2(scalar_t* __restrict__ p_output_x, + scalar_t* __restrict__ p_output_y, + const scalar_t* __restrict__ p_input_x, + const scalar_t* __restrict__ p_input_y, + const scalar_f_t* __restrict__ p_cos, + const scalar_f_t* __restrict__ p_sin, + const int64_t* __restrict__ p_indirect_buffer_0, + const int64_t* __restrict__ p_indirect_buffer_1, + const int32_t max_position, + const int32_t size_h_x, + const int32_t size_h_y, + const int32_t size_d, + const int32_t size_f, // size of last dimension of freqs. + const int32_t stride_ix_s, + const int32_t stride_ix_b, + const int32_t stride_ix_h, + const int32_t stride_ix_d, + const int32_t stride_iy_s, + const int32_t stride_iy_b, + const int32_t stride_iy_h, + const int32_t stride_iy_d, + const int32_t stride_ox_s, + const int32_t stride_ox_b, + const int32_t stride_ox_h, + const int32_t stride_ox_d, + const int32_t stride_oy_s, + const int32_t stride_oy_b, + const int32_t stride_oy_h, + const int32_t stride_oy_d) { const uint64_t sid = blockIdx.x; const uint64_t bid = blockIdx.y; @@ -2099,20 +2104,20 @@ template -__global__ void -kn_entry_1c_sbhd_cached_indirect2_inplace(scalar_t* __restrict__ p_inout, - const scalar_f_t* __restrict__ p_cos, - const scalar_f_t* __restrict__ p_sin, - const int64_t* __restrict__ p_indirect_buffer_0, - const int64_t* __restrict__ p_indirect_buffer_1, - const int32_t max_position, - const int32_t size_h, - const int32_t size_d, - const int32_t size_f, // size of last dimension of freqs. - const int32_t stride_s, - const int32_t stride_b, - const int32_t stride_h, - const int32_t stride_d) +__launch_bounds__(256, 8) __global__ void kn_entry_1c_sbhd_cached_indirect2_inplace( + scalar_t* __restrict__ p_inout, + const scalar_f_t* __restrict__ p_cos, + const scalar_f_t* __restrict__ p_sin, + const int64_t* __restrict__ p_indirect_buffer_0, + const int64_t* __restrict__ p_indirect_buffer_1, + const int32_t max_position, + const int32_t size_h, + const int32_t size_d, + const int32_t size_f, // size of last dimension of freqs. 
+ const int32_t stride_s, + const int32_t stride_b, + const int32_t stride_h, + const int32_t stride_d) { const uint64_t sid = blockIdx.x; const uint64_t bid = blockIdx.y; @@ -2151,26 +2156,26 @@ template -__global__ void -kn_entry_2c_sbhd_cached_indirect2_inplace(scalar_t* __restrict__ p_inout_x, - scalar_t* __restrict__ p_inout_y, - const scalar_f_t* __restrict__ p_cos, - const scalar_f_t* __restrict__ p_sin, - const int64_t* __restrict__ p_indirect_buffer_0, - const int64_t* __restrict__ p_indirect_buffer_1, - const int32_t max_position, - const int32_t size_h_x, - const int32_t size_h_y, - const int32_t size_d, - const int32_t size_f, // size of last dimension of freqs. - const int32_t stride_x_s, - const int32_t stride_x_b, - const int32_t stride_x_h, - const int32_t stride_x_d, - const int32_t stride_y_s, - const int32_t stride_y_b, - const int32_t stride_y_h, - const int32_t stride_y_d) +__launch_bounds__(256, 8) __global__ void kn_entry_2c_sbhd_cached_indirect2_inplace( + scalar_t* __restrict__ p_inout_x, + scalar_t* __restrict__ p_inout_y, + const scalar_f_t* __restrict__ p_cos, + const scalar_f_t* __restrict__ p_sin, + const int64_t* __restrict__ p_indirect_buffer_0, + const int64_t* __restrict__ p_indirect_buffer_1, + const int32_t max_position, + const int32_t size_h_x, + const int32_t size_h_y, + const int32_t size_d, + const int32_t size_f, // size of last dimension of freqs. + const int32_t stride_x_s, + const int32_t stride_x_b, + const int32_t stride_x_h, + const int32_t stride_x_d, + const int32_t stride_y_s, + const int32_t stride_y_b, + const int32_t stride_y_h, + const int32_t stride_y_d) { const uint64_t sid = blockIdx.x; const uint64_t bid = blockIdx.y; @@ -2219,19 +2224,20 @@ template -__global__ void kn_entry_1c_thd_uncached(scalar_t* __restrict__ p_output, - const scalar_t* __restrict__ p_input, - const int32_t* __restrict__ p_cu_seqlens, - const scalar_f_t* __restrict__ p_freqs, - const int32_t size_h, - const int32_t size_d, - const int32_t size_f, // size of last dimension of freqs. - const int32_t stride_i_t, - const int32_t stride_i_h, - const int32_t stride_i_d, - const int32_t stride_o_t, - const int32_t stride_o_h, - const int32_t stride_o_d) +__launch_bounds__(256, 8) __global__ + void kn_entry_1c_thd_uncached(scalar_t* __restrict__ p_output, + const scalar_t* __restrict__ p_input, + const int32_t* __restrict__ p_cu_seqlens, + const scalar_f_t* __restrict__ p_freqs, + const int32_t size_h, + const int32_t size_d, + const int32_t size_f, // size of last dimension of freqs. + const int32_t stride_i_t, + const int32_t stride_i_h, + const int32_t stride_i_d, + const int32_t stride_o_t, + const int32_t stride_o_h, + const int32_t stride_o_d) { const uint64_t sid = blockIdx.x; const uint64_t bid = blockIdx.y; @@ -2268,16 +2274,16 @@ template -__global__ void -kn_entry_1c_thd_uncached_inplace(scalar_t* __restrict__ p_inout, - const int32_t* __restrict__ p_cu_seqlens, - const scalar_f_t* __restrict__ p_freqs, - const int32_t size_h, - const int32_t size_d, - const int32_t size_f, // size of last dimension of freqs. - const int32_t stride_t, - const int32_t stride_h, - const int32_t stride_d) +__launch_bounds__(256, 8) __global__ + void kn_entry_1c_thd_uncached_inplace(scalar_t* __restrict__ p_inout, + const int32_t* __restrict__ p_cu_seqlens, + const scalar_f_t* __restrict__ p_freqs, + const int32_t size_h, + const int32_t size_d, + const int32_t size_f, // size of last dimension of freqs. 
+ const int32_t stride_t, + const int32_t stride_h, + const int32_t stride_d) { const uint64_t sid = blockIdx.x; const uint64_t bid = blockIdx.y; @@ -2314,23 +2320,24 @@ template -__global__ void kn_entry_1c_2d_cached(scalar_t* __restrict__ p_output, - const scalar_t* __restrict__ p_input, - const scalar_f_t* __restrict__ p_cos_h, - const scalar_f_t* __restrict__ p_sin_h, - const scalar_f_t* __restrict__ p_cos_w, - const scalar_f_t* __restrict__ p_sin_w, - const int32_t img_width, - const int32_t size_h, - const int32_t size_d, - const int32_t stride_i_b, - const int32_t stride_i_s, - const int32_t stride_i_h, - const int32_t stride_i_d, - const int32_t stride_o_b, - const int32_t stride_o_s, - const int32_t stride_o_h, - const int32_t stride_o_d) +__launch_bounds__(256, 8) __global__ + void kn_entry_1c_2d_cached(scalar_t* __restrict__ p_output, + const scalar_t* __restrict__ p_input, + const scalar_f_t* __restrict__ p_cos_h, + const scalar_f_t* __restrict__ p_sin_h, + const scalar_f_t* __restrict__ p_cos_w, + const scalar_f_t* __restrict__ p_sin_w, + const int32_t img_width, + const int32_t size_h, + const int32_t size_d, + const int32_t stride_i_b, + const int32_t stride_i_s, + const int32_t stride_i_h, + const int32_t stride_i_d, + const int32_t stride_o_b, + const int32_t stride_o_s, + const int32_t stride_o_h, + const int32_t stride_o_d) { const uint64_t Hid = blockIdx.x; const uint64_t Wid = blockIdx.y; @@ -2386,18 +2393,19 @@ template -__global__ void kn_entry_1c_2d_cached_inplace(scalar_t* __restrict__ p_inout, - const scalar_f_t* __restrict__ p_cos_h, - const scalar_f_t* __restrict__ p_sin_h, - const scalar_f_t* __restrict__ p_cos_w, - const scalar_f_t* __restrict__ p_sin_w, - const int32_t img_width, - const int32_t size_h, - const int32_t size_d, - const int32_t stride_b, - const int32_t stride_s, - const int32_t stride_h, - const int32_t stride_d) +__launch_bounds__(256, 8) __global__ + void kn_entry_1c_2d_cached_inplace(scalar_t* __restrict__ p_inout, + const scalar_f_t* __restrict__ p_cos_h, + const scalar_f_t* __restrict__ p_sin_h, + const scalar_f_t* __restrict__ p_cos_w, + const scalar_f_t* __restrict__ p_sin_w, + const int32_t img_width, + const int32_t size_h, + const int32_t size_d, + const int32_t stride_b, + const int32_t stride_s, + const int32_t stride_h, + const int32_t stride_d) { const uint64_t Hid = blockIdx.x; const uint64_t Wid = blockIdx.y; @@ -2648,6 +2656,32 @@ __global__ void kn_entry_1c_2d_cached_inplace(scalar_t* __restrict__ p_inout, } \ } +template +std::tuple get_grid_config(const int32_t size_s_h, + const int32_t size_s_w, + const int32_t size_b, + const int32_t size_f) +{ + constexpr int32_t num_warps = 4; + constexpr int32_t num_threads = num_warps * ck_tile::get_warp_size(); + + const int32_t size_r = ReuseFreqsFrontPart ? 
(size_f << 1) : size_f; + const int32_t size_half_r = size_r >> 1; + const int32_t aligned_size_half_r = ck_tile::next_power_of_two(size_half_r); + + const int32_t block_dim_x = std::min(aligned_size_half_r, ck_tile::get_warp_size()); + const int32_t block_dim_y = std::max(num_threads / block_dim_x, 1); + + if constexpr(Is2D) + { + return {dim3(size_s_h, size_s_w, size_b), dim3(block_dim_x, block_dim_y)}; + } + else + { + return {dim3(size_s_h * size_s_w, size_b), dim3(block_dim_x, block_dim_y)}; + } +} + template (size_s, 1, size_b, size_f); if(p_output == p_input) { @@ -2764,8 +2797,7 @@ void dispatch_2c_sbhd_uncached(scalar_t* __restrict__ p_output_x, { const hipStream_t stream = at::hip::getCurrentHIPStream(); - const dim3 grid(size_s, size_b); - const dim3 block(C10_WARP_SIZE, size_h_x < 16 ? 4 : 8); + auto [grid, block] = get_grid_config(size_s, 1, size_b, size_f); if((p_output_x == p_input_x) && (p_output_y == p_input_y)) { @@ -2873,8 +2905,7 @@ void dispatch_1c_sbhd_cached(scalar_t* __restrict__ p_output, { const hipStream_t stream = at::hip::getCurrentHIPStream(); - const dim3 grid(size_s, size_b); - const dim3 block(C10_WARP_SIZE, size_h < 16 ? 4 : 8); + auto [grid, block] = get_grid_config(size_s, 1, size_b, size_f); if(p_output == p_input) { @@ -2967,8 +2998,7 @@ void dispatch_2c_sbhd_cached(scalar_t* __restrict__ p_output_x, { const hipStream_t stream = at::hip::getCurrentHIPStream(); - const dim3 grid(size_s, size_b); - const dim3 block(C10_WARP_SIZE, size_h_x < 16 ? 4 : 8); + auto [grid, block] = get_grid_config(size_s, 1, size_b, size_f); if((p_output_x == p_input_x) && (p_output_y == p_input_y)) { @@ -3079,8 +3109,7 @@ void dispatch_1c_sbhd_cached_indirect(scalar_t* __restrict__ p_output, { const hipStream_t stream = at::hip::getCurrentHIPStream(); - const dim3 grid(size_s, size_b); - const dim3 block(C10_WARP_SIZE, size_h < 16 ? 4 : 8); + auto [grid, block] = get_grid_config(size_s, 1, size_b, size_f); if(p_output == p_input) { @@ -3180,8 +3209,7 @@ void dispatch_2c_sbhd_cached_indirect(scalar_t* __restrict__ p_output_x, { const hipStream_t stream = at::hip::getCurrentHIPStream(); - const dim3 grid(size_s, size_b); - const dim3 block(C10_WARP_SIZE, size_h_x < 16 ? 4 : 8); + auto [grid, block] = get_grid_config(size_s, 1, size_b, size_f); if((p_output_x == p_input_x) && (p_output_y == p_input_y)) { @@ -3298,8 +3326,7 @@ void dispatch_1c_sbhd_cached_indirect2(scalar_t* __restrict__ p_output, { const hipStream_t stream = at::hip::getCurrentHIPStream(); - const dim3 grid(size_s, size_b); - const dim3 block(C10_WARP_SIZE, size_h < 16 ? 4 : 8); + auto [grid, block] = get_grid_config(size_s, 1, size_b, size_f); if(p_output == p_input) { @@ -3403,8 +3430,7 @@ void dispatch_2c_sbhd_cached_indirect2(scalar_t* __restrict__ p_output_x, { const hipStream_t stream = at::hip::getCurrentHIPStream(); - const dim3 grid(size_s, size_b); - const dim3 block(C10_WARP_SIZE, size_h_x < 16 ? 4 : 8); + auto [grid, block] = get_grid_config(size_s, 1, size_b, size_f); if((p_output_x == p_input_x) && (p_output_y == p_input_y)) { @@ -3519,8 +3545,7 @@ void dispatch_1c_thd_uncached(scalar_t* __restrict__ p_output, { const hipStream_t stream = at::hip::getCurrentHIPStream(); - const dim3 grid(size_max_s, size_b); - const dim3 block(C10_WARP_SIZE, size_h < 16 ? 
4 : 8); + auto [grid, block] = get_grid_config(size_max_s, 1, size_b, size_f); if(p_output == p_input) { @@ -3600,8 +3625,8 @@ void dispatch_1c_2d_cached(scalar_t* __restrict__ p_output, { const hipStream_t stream = at::hip::getCurrentHIPStream(); - const dim3 grid(img_height, img_width, size_b); - const dim3 block(C10_WARP_SIZE, size_h < 16 ? 4 : 8); + auto [grid, block] = + get_grid_config(img_height, img_width, size_b, size_d >> 1); if(p_output == p_input) { diff --git a/csrc/kernels/topk_per_row_kernels.cu b/csrc/kernels/topk_per_row_kernels.cu index 14eae78163..89331c52df 100644 --- a/csrc/kernels/topk_per_row_kernels.cu +++ b/csrc/kernels/topk_per_row_kernels.cu @@ -420,7 +420,8 @@ __device__ void filter_and_histogram(T const* in_buf, IdxT* histogram, bool select_min, int pass, - bool early_stop) + bool early_stop, + IdxT k) { constexpr int num_buckets = calc_num_buckets(); __shared__ IdxT histogram_smem[num_buckets]; @@ -893,9 +894,19 @@ __global__ void radix_kernel(T const* in, int const pass) { const int64_t batch_id = blockIdx.y; - const IdxT row_len = phase == Phase::Prefill - ? rowEnds[batch_id] - rowStarts[batch_id] - : rowEnds[batch_id / next_n] - next_n + (batch_id % next_n) + 1; + + IdxT row_len = len; + if(phase == Phase::Prefill) + { + if(rowStarts && rowEnds) + { + row_len = rowEnds[batch_id] - rowStarts[batch_id]; + } + } + else + { + row_len = rowEnds[batch_id / next_n] - next_n + (batch_id % next_n) + 1; + } auto counter = counters + batch_id; IdxT current_k; @@ -965,7 +976,8 @@ __global__ void radix_kernel(T const* in, histogram, select_min, pass, - early_stop); + early_stop, + k); __threadfence(); bool isLastBlock = false; @@ -1187,7 +1199,8 @@ __device__ bool filter_and_histogram_for_one_block(T const* in_buf, Counter* counter, IdxT* histogram, bool select_min, - int pass) + int pass, + IdxT k) { constexpr int num_buckets = calc_num_buckets(); for(int i = threadIdx.x; i < num_buckets * 2; i += blockDim.x) @@ -1371,11 +1384,25 @@ __global__ void radix_topk_one_block_kernel(T const* in, __shared__ IdxT histogram[num_buckets * 2]; const int64_t batch_id = blockIdx.x; - const IdxT rowStart = phase == Phase::Prefill ? rowStarts[batch_id] : 0; - const IdxT rowEnd = phase == Phase::Prefill - ? 
rowEnds[batch_id] - : rowEnds[batch_id / next_n] - next_n + (batch_id % next_n) + 1; - const IdxT row_len = rowEnd - rowStart; + + IdxT rowStart = 0; + IdxT rowEnd = len; + if(phase == Phase::Prefill) + { + if(rowStarts && rowEnds) + { + rowStart = rowStarts[batch_id]; + rowEnd = rowEnds[batch_id]; + } + } + else + { + rowEnd = rowEnds[batch_id / next_n] - next_n + (batch_id % next_n) + 1; + rowStart = 0; + } + + const IdxT row_len = rowEnd - rowStart; + if(threadIdx.x == 0) { counter.k = k; @@ -1448,7 +1475,8 @@ __global__ void radix_topk_one_block_kernel(T const* in, &counter, histogram, select_min, - pass); //@TODO CHECK UPDATE CODE + pass, + k); //@TODO CHECK UPDATE CODE __syncthreads(); scan(histogram + use_one_pass * num_buckets); @@ -1811,6 +1839,35 @@ void standalone_stable_radix_11bits(void* buf, } } +// Explicit template instantiation for standalone_stable_radix_11bits +template void standalone_stable_radix_11bits(void* buf, + size_t& buf_size, + float const* in, + int batch_size, + int64_t len, + int* rowStarts, + int* rowEnds, + int k, + float* out, + int* out_idx, + bool greater, + hipStream_t stream, + int next_n); + +template void standalone_stable_radix_11bits(void* buf, + size_t& buf_size, + float const* in, + int batch_size, + int64_t len, + int* rowStarts, + int* rowEnds, + int k, + float* out, + int* out_idx, + bool greater, + hipStream_t stream, + int next_n); + // AIR TopK end static inline __device__ uint32_t floatAsSortableUint(float x) @@ -2410,6 +2467,9 @@ int64_t invokeComputeTopkLastDimWorkspaceSize(int32_t numRows, int32_t stride0) return buf_size; } +// Explicit template instantiation to ensure the symbol is available for linking +template int64_t invokeComputeTopkLastDimWorkspaceSize(int32_t numRows, int32_t stride0); + void top_k_per_row_prefill(const torch::Tensor& logits, const torch::Tensor& rowStarts, const torch::Tensor& rowEnds, diff --git a/csrc/kernels/topk_plain_kernels.cu b/csrc/kernels/topk_plain_kernels.cu index 4bf732756c..7c03823ae0 100644 --- a/csrc/kernels/topk_plain_kernels.cu +++ b/csrc/kernels/topk_plain_kernels.cu @@ -49,10 +49,251 @@ utils::hip_check_((val), __FILE__, __LINE__); \ } +// Forward declaration of topk_per_row kernel from topk_per_row_kernels.cu +namespace aiter { + +// Phase enum for distinguishing prefill vs decode paths +enum class Phase +{ + Prefill, + Decode, +}; + +template +__global__ void topk_per_row(const float* logits, + const int* rowStarts, + const int* rowEnds, + int* outIndices, + int stride0, + int stride1, + int rowOffset); + +// Forward declaration of standalone_stable_radix_11bits from topk_per_row_kernels.cu +template +void standalone_stable_radix_11bits(void* buf, + size_t& buf_size, + T const* in, + int batch_size, + int64_t len, + IdxT* rowStarts, + IdxT* rowEnds, + IdxT k, + T* out, + IdxT* out_idx, + bool greater, + hipStream_t stream, + int next_n = 0); + +} // namespace aiter + +// Forward declaration of workspace size calculation function (at global scope) +template +int64_t invokeComputeTopkLastDimWorkspaceSize(int32_t numRows, int32_t stride0); +extern template int64_t +invokeComputeTopkLastDimWorkspaceSize(int32_t numRows, + int32_t stride0); + +// Forward declaration of helper function to call topk_per_row kernel +template +void topk_per_row_kernel_launcher(const float* in, + const IdxT* rowStarts, + const IdxT* rowEnds, + IdxT* out_idx, + const float* out, + int batch_size, + int stride0, + int stride1, + int k, + hipStream_t stream); + +// Helper function to determine if topk_per_row kernel 
should be used +// Based on: n + K log²K ≥ 3 × Factor(n) × n +// where Factor(n) = 1/3 + 1.6/(log₂(n) - 9.5) +// Simplifies to: K log²K ≥ 4.8n/(log₂(n) - 9.5) +// TODO: We need to confirm whether, when n <= 2048, we might choose +// radix sort because the denominator becomes very small; does that +// still yield the best performance? +template +__forceinline__ __host__ bool should_use_topk_radix(IdxT len, IdxT k) +{ + const double n = static_cast(len); + const double K = static_cast(k); + + if(K <= 1.0) + { + return false; + } + + const double log_n = std::log2(n); + + const double denom = std::max(0.0001, log_n - 9.5); + + const double rhs = (4.8 * n) / denom; + + const double log_k = std::log2(K); + const double lhs = K * log_k * log_k; + + return lhs >= rhs; +} + +// Gather kernel to extract values based on indices (uniform length) +template +__global__ void gather_topk_values_kernel(const T* __restrict__ in, + const IdxT* __restrict__ indices, + T* __restrict__ out, + int batch_size, + int len, + int k) +{ + int batch_id = blockIdx.x; + if(batch_id >= batch_size) + return; + + const T* in_row = in + batch_id * len; + const IdxT* idx_row = indices + batch_id * k; + T* out_row = out + batch_id * k; + + for(int i = threadIdx.x; i < k; i += blockDim.x) + { + IdxT idx = idx_row[i]; + if(idx >= 0 && idx < len) + { + out_row[i] = in_row[idx]; + } + } +} + +// Gather kernel for variable length with strides +template +__global__ void gather_topk_values_strided_kernel(const T* __restrict__ in, + const IdxT* __restrict__ indices, + T* __restrict__ out, + const IdxT* __restrict__ rowStarts, + int batch_size, + int stride0, + int stride1, + int k) +{ + int batch_id = blockIdx.x; + if(batch_id >= batch_size) + return; + + IdxT start = rowStarts[batch_id]; + const T* in_row = in + batch_id * stride0; + const IdxT* idx_row = indices + batch_id * k; + T* out_row = out + batch_id * k; + + for(int i = threadIdx.x; i < k; i += blockDim.x) + { + IdxT idx = idx_row[i]; + if(idx >= 0) + { + // idx is relative to rowStart, need to add start and apply stride1 + out_row[i] = in_row[(start + idx) * stride1]; + } + } +} + namespace topk { + +// ============================================================================ +// TYPE TRAITS FOR DATA/COMPUTE TYPE SEPARATION +// ============================================================================ +// +// Design Philosophy: +// - DataType (DataT): The storage/I/O type for memory operations +// - ComputeType (ComputeT): The type used for internal computations +// +// Mapping: +// - fp16, bf16, float -> compute as float (better precision, consistent ops) +// - int -> compute as int +// +// This separation allows: +// 1. Memory-efficient storage with compact types (fp16, bf16) +// 2. High-precision computation with float +// 3. Easy extension for new types (e.g., fp8, int8) +// +// Usage: +// using ComputeT = compute_t; +// ComputeT val = type_convert::to_compute(data_val); +// DataT result = type_convert::to_data(compute_val); +// ============================================================================ + +namespace type_traits { + +// Primary template: maps DataType -> ComputeType +template +struct ComputeTypeTraits +{ + static_assert(sizeof(DataT) == 0, + "ComputeTypeTraits not specialized for this type. 
" + "Supported types: _Float16, __bf16, float, int"); +}; + +// Specializations for floating-point types -> float +template <> +struct ComputeTypeTraits<_Float16> +{ + using type = float; +}; + +template <> +struct ComputeTypeTraits<__bf16> +{ + using type = float; +}; + +template <> +struct ComputeTypeTraits +{ + using type = float; +}; + +// Specialization for integer types -> int +template <> +struct ComputeTypeTraits +{ + using type = int; +}; + +// Convenience alias +template +using compute_t = typename ComputeTypeTraits::type; + +} // namespace type_traits + +// Bring compute_t into topk namespace for convenience +using type_traits::compute_t; + +// ============================================================================ +// TYPE CONVERSION UTILITIES +// ============================================================================ + +namespace type_convert { + +// Convert from DataType to ComputeType +template +__device__ __host__ __forceinline__ type_traits::compute_t to_compute(DataT val) +{ + return static_cast>(val); +} + +// Convert from ComputeType to DataType +template +__device__ __host__ __forceinline__ DataT to_data(type_traits::compute_t val) +{ + return static_cast(val); +} + +} // namespace type_convert + namespace utils { -// Supported types +// Supported types (for validation) template struct is_supported_type { @@ -198,60 +439,62 @@ __inline__ __host__ __device__ constexpr int calc_capacity(int k) namespace numeric { +// ============================================================================ +// BOUNDS AND SENTINEL VALUES +// ============================================================================ +// These functions now work with ComputeType for internal operations. +// The sentinel values are defined in ComputeType space (float for floating-point +// DataTypes, int for integer DataTypes). +// ============================================================================ + /** - * @brief Gets the absolute lowest possible value for a numeric type T. + * @brief Gets the absolute lowest possible value for a compute type. + * + * Uses -infinity for floating-point compute types, and the lowest finite + * value for integer compute types. * - * Uses -infinity for signed floating-point types, and the lowest finite - * value for all other arithmetic types. + * @tparam ComputeT The compute type (float or int). */ -template -__inline__ constexpr T get_lower_bound() +template +__inline__ __device__ __host__ constexpr ComputeT get_lower_bound() { - static_assert(utils::is_supported_type_v, - "Unsupported type T: only _Float16, __bf16, float, and int are implemented"); - if constexpr(std::is_floating_point_v && std::is_signed_v) - { - return -std::numeric_limits::infinity(); - } - else if constexpr(std::is_integral_v) + if constexpr(std::is_same_v) { - return std::numeric_limits::lowest(); + return -std::numeric_limits::infinity(); } - else if constexpr(std::is_same_v) + else if constexpr(std::is_same_v) { - return -__bf16(0x7F80); + return std::numeric_limits::lowest(); } else { + static_assert(sizeof(ComputeT) == 0, "Unsupported compute type"); __builtin_unreachable(); } } /** - * @brief Gets the absolute highest possible value for a numeric type T. + * @brief Gets the absolute highest possible value for a compute type. + * + * Uses +infinity for floating-point compute types, and the maximum finite + * value for integer compute types. * - * Uses +infinity for floating-point types, and the maximum finite - * value for all other arithmetic types. 
+ * @tparam ComputeT The compute type (float or int). */ -template -__inline__ constexpr T get_upper_bound() +template +__inline__ __device__ __host__ constexpr ComputeT get_upper_bound() { - static_assert(utils::is_supported_type_v, - "Unsupported type T: only _Float16, __bf16, float, and int are implemented"); - if constexpr(std::is_floating_point_v) - { - return std::numeric_limits::infinity(); - } - else if constexpr(std::is_integral_v) + if constexpr(std::is_same_v) { - return std::numeric_limits::max(); + return std::numeric_limits::infinity(); } - else if constexpr(std::is_same_v) + else if constexpr(std::is_same_v) { - return __bf16(0x7F80); + return std::numeric_limits::max(); } else { + static_assert(sizeof(ComputeT) == 0, "Unsupported compute type"); __builtin_unreachable(); } } @@ -259,42 +502,56 @@ __inline__ constexpr T get_upper_bound() /** * @brief Gets a sentinel value for a search algorithm (e.g., Top-K). * - * @tparam FindLargest A compile-time boolean. If true, returns the lowest possible - * value (the starting point for finding a maximum). If false, returns the - * highest possible value (the starting point for finding a minimum). - * @tparam T The numeric type. + * The sentinel is defined in ComputeType space. For finding the largest values, + * we use the lowest possible value as sentinel (so any real value will be preferred). + * For finding the smallest values, we use the highest possible value. + * + * @tparam FindLargest If true, returns lowest value. If false, returns highest value. + * @tparam ComputeT The compute type (float or int). */ -template -__inline__ constexpr T get_sentinel_value() +template +__inline__ __device__ __host__ constexpr ComputeT get_sentinel_value() { if constexpr(FindLargest) { - static_assert( - !std::is_unsigned_v, - "Cannot determine a meaningful lower bound for finding the 'largest' unsigned value. " - "The lowest value is 0, which is a poor sentinel."); - return get_lower_bound(); + return get_lower_bound(); } else { - return get_upper_bound(); + return get_upper_bound(); } } /** - * @brief A generic comparison function for search algorithms. 💡 + * @brief Gets sentinel value based on DataType (converts to appropriate ComputeType). + * + * This is a convenience overload that deduces the ComputeType from DataType. + * + * @tparam FindLargest If true, returns lowest value. If false, returns highest value. + * @tparam DataT The data type (fp16, bf16, float, int). + */ +template +__inline__ __device__ __host__ constexpr compute_t get_sentinel_value_for_data() +{ + return get_sentinel_value>(); +} + +/** + * @brief A generic comparison function for search algorithms. * * Compares `val` against `baseline` according to the search direction * specified by the `FindLargest` template parameter. + * Works with ComputeType values. * * @tparam FindLargest If true, checks if `val` is greater than `baseline`. - * If false, checks if `val` is less than `baseline`. + * If false, checks if `val` is less than `baseline`. + * @tparam ComputeT The compute type (float or int). * @param val The new value to check. * @param baseline The current best value. * @return True if `val` is "preferred" over `baseline`. 
*/ -template -__device__ __host__ constexpr bool is_preferred(T val, T baseline) +template +__device__ __host__ __forceinline__ constexpr bool is_preferred(ComputeT val, ComputeT baseline) { if constexpr(FindLargest) { @@ -310,6 +567,19 @@ __device__ __host__ constexpr bool is_preferred(T val, T baseline) namespace sorting { +// ============================================================================ +// SORTING OPERATIONS (Work with ComputeType) +// ============================================================================ +// All sorting operations in this namespace work with ComputeType values. +// The template parameter T should be the compute type (float or int). +// The idxT parameter is the index type (typically int32_t). +// +// The sorting algorithms use: +// - DPP (Data Parallel Primitives) for small-stride shuffles (≤8) +// - Wave intrinsics (__ballot, __popcll, __shfl) for larger operations +// - Bitonic sort/merge for efficient parallel sorting +// ============================================================================ + template struct BitonicMerge { @@ -492,26 +762,30 @@ __forceinline__ __device__ T shfl_xor(T val, int stride) } } -template -__forceinline__ __device__ constexpr T get_guard(const bool x) +/** + * @brief Gets guard value for bitonic sort comparisons. + * + * This function returns boundary values used in bitonic sorting. + * Works with ComputeType (float or int). + * + * @tparam ComputeT The compute type (float or int). + * @param x If true, returns lowest value; if false, returns highest value. + */ +template +__forceinline__ __device__ constexpr ComputeT get_guard(const bool x) { - if constexpr(std::is_same_v) + if constexpr(std::is_same_v) { - auto inf = _Float16(0x7C00); - return x ? -inf : inf; + return x ? -std::numeric_limits::infinity() : std::numeric_limits::infinity(); } - else if constexpr(std::is_same_v) + else if constexpr(std::is_same_v) { - auto inf = __bf16(0x7F80); - return x ? -inf : inf; - } - else if constexpr(!std::is_floating_point_v) - { - return x ? std::numeric_limits::lowest() : std::numeric_limits::max(); + return x ? std::numeric_limits::lowest() : std::numeric_limits::max(); } else { - return x ? 
-std::numeric_limits::infinity() : std::numeric_limits::infinity(); + static_assert(sizeof(ComputeT) == 0, "get_guard only supports float and int compute types"); + __builtin_unreachable(); } } @@ -709,14 +983,27 @@ struct BitonicMerge<64, ascending, T, idxT> namespace buffer_load_helpers { -constexpr int MAX_CAPACITY = 512; +constexpr int MAX_CAPACITY = 2048; using int32x4_t = int __attribute__((ext_vector_type(4))); using floatx4_t = float __attribute__((ext_vector_type(4))); -using bf16x8_t = uint16_t __attribute__((ext_vector_type(8))); +using bf16x8_t = __bf16 __attribute__((ext_vector_type(8))); using halfx8_t = _Float16 __attribute__((ext_vector_type(8))); using index_t = uint32_t; +__device__ __forceinline__ static int32x4_t +asm_buffer_load_dwordx4(int32x4_t srsrc, + int32_t voffset, + int32_t soffset, + int32_t aux) __asm("llvm.amdgcn.raw.buffer.load.v4i32"); + +template +__device__ __forceinline__ VecType +buffer_load_dwordx4(int32x4_t srsrc, int32_t voffset, int32_t soffset, int32_t aux) +{ + return __builtin_bit_cast(VecType, asm_buffer_load_dwordx4(srsrc, voffset, soffset, aux)); +} + } // namespace buffer_load_helpers // --- Wave-Level Priority Selection Primitives (AMD/HIP Optimized) --- @@ -766,21 +1053,39 @@ struct BlockTopkSort; template struct BlockTopkMerge; -// WaveBuffer: Manages per-wave register storage for priority candidates -template +// ============================================================================ +// WAVE BUFFER (Stores priorities in ComputeType) +// ============================================================================ +// +// WaveBuffer manages per-wave register storage for priority candidates. +// Key design: +// - DataT: The I/O type for loading/storing data +// - ComputeT: The internal type for priorities (float or int) +// - Priorities are stored as ComputeType for consistent computation +// - Conversion happens at I/O boundaries +// +// Template parameters: +// - capacity: Power-of-2 buffer capacity (>= wave size) +// - DataT: Data type for I/O (fp16, bf16, float, int) +// - IdxT: Index type (typically int32_t) +// ============================================================================ + +template struct WaveBuffer { + using ComputeT = compute_t; + static constexpr int slots_per_lane = capacity / opus::get_warp_size(); static_assert(capacity >= opus::get_warp_size() && utils::is_power_of_2(capacity), "Capacity must be power-of-2 and >= wave size"); - T priorities[slots_per_lane]; + ComputeT priorities[slots_per_lane]; IdxT positions[slots_per_lane]; int lane_id; IdxT target_count; - T sentinel; + ComputeT sentinel; - __device__ WaveBuffer(IdxT k, T sentinel_value) + __device__ WaveBuffer(IdxT k, ComputeT sentinel_value) : lane_id(threadIdx.x & (opus::get_warp_size() - 1)), target_count(k), sentinel(sentinel_value) @@ -792,13 +1097,16 @@ struct WaveBuffer } } - __device__ inline void reset_slot(int slot, T val = {}, IdxT pos = {}) + __device__ inline void reset_slot(int slot, ComputeT val = {}, IdxT pos = {}) { priorities[slot] = val; positions[slot] = pos; } - __device__ inline void flush_results(T* __restrict__ out_vals, + // Flush results to output buffer + // OutT can be DataT (for final output) or ComputeT (for LDS operations) + template + __device__ inline void flush_results(OutT* __restrict__ out_vals, IdxT* __restrict__ out_indices) const { #pragma unroll @@ -807,7 +1115,7 @@ struct WaveBuffer const IdxT global_slot = i * opus::get_warp_size() + lane_id; if(global_slot < target_count) { - out_vals[global_slot] = 
priorities[i]; + out_vals[global_slot] = static_cast(priorities[i]); out_indices[global_slot] = positions[i]; } } @@ -815,10 +1123,14 @@ struct WaveBuffer }; // Helper for merging sorted sequences (used by multiple strategies) -template +// Works with ComputeType internally, reads from ComputeType buffers +template struct WaveMergeHelper { + using ComputeT = compute_t; + // Merges a sorted k-element chunk with the buffer's existing Top-K + // Input is in ComputeType (from LDS or previous computation) // EXAMPLE (finding Top-4 largest, capacity=64, k=4): // Wave-distributed storage (64 lanes, each lane holds slots_per_lane=1 value): // Lanes 0-3: [80, 85, 90, 95] (current top-4, in ascending order) @@ -843,8 +1155,8 @@ struct WaveMergeHelper // // Extract top-k=4 (last 4 in ascending order): // Lanes 60-63 now contain: [85, 90, 95, 100] - __device__ static void merge_sorted_range(WaveBuffer& buffer, - const T* __restrict__ in, + __device__ static void merge_sorted_range(WaveBuffer& buffer, + const ComputeT* __restrict__ in, const IdxT* __restrict__ in_idx, IdxT start) { @@ -854,56 +1166,64 @@ struct WaveMergeHelper { if(idx < start + buffer.target_count) { - T candidate = in[idx]; - if(numeric::is_preferred(candidate, buffer.priorities[i])) + ComputeT candidate = in[idx]; + if(numeric::is_preferred(candidate, buffer.priorities[i])) { buffer.priorities[i] = candidate; buffer.positions[i] = in_idx[idx]; } } } - sorting::BitonicMerge::merge(buffer.priorities, - buffer.positions); + sorting::BitonicMerge::merge(buffer.priorities, + buffer.positions); } }; // Forward declarations for kernel wrapper functions -template -__global__ void __launch_bounds__(512, 2) topk_filter_kernel(const T* __restrict__ in, +// Note: Kernels use DataT for I/O and compute_t for sentinel/internal computation +template +__global__ void __launch_bounds__(512, 2) topk_filter_kernel(const DataT* __restrict__ in, const IdxT* __restrict__ in_idx, int batch_size, IdxT len, IdxT k, - T* __restrict__ out, + DataT* __restrict__ out, IdxT* __restrict__ out_idx, - T sentinel); + compute_t sentinel); -template -__global__ void __launch_bounds__(512, 2) topk_sort_kernel(const T* __restrict__ in, +template +__global__ void __launch_bounds__(512, 2) topk_sort_kernel(const DataT* __restrict__ in, const IdxT* __restrict__ in_idx, int batch_size, IdxT len, IdxT k, - T* __restrict__ out, + DataT* __restrict__ out, IdxT* __restrict__ out_idx, - T sentinel); + compute_t sentinel); -template -__global__ void __launch_bounds__(512, 2) topk_merge_kernel(const T* __restrict__ in, +template +__global__ void __launch_bounds__(512, 2) topk_merge_kernel(const DataT* __restrict__ in, const IdxT* __restrict__ in_idx, int batch_size, IdxT len, IdxT k, - T* __restrict__ out, + DataT* __restrict__ out, IdxT* __restrict__ out_idx, - T sentinel); + compute_t sentinel); -// Kernel function pointer type alias -template -using KernelFuncPtr = void (*)(const T*, const IdxT*, int, IdxT, IdxT, T*, IdxT*, T); +template +using KernelFuncPtr = + void (*)(const DataT*, const IdxT*, int, IdxT, IdxT, DataT*, IdxT*, compute_t); // Helper: Map block-level strategy class to its corresponding kernel function template -template