diff --git a/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common.cuh b/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common.cuh index aa7cd67366..705460baed 100644 --- a/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common.cuh +++ b/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common.cuh @@ -76,7 +76,7 @@ void ck_moe_stage1_gemm(const hipStream_t& stream, ck::is_same_v ? 1 : NXDLPerWave; // Note: some fp8 instances didn't compile with AK1/BK1=16 static constexpr ck::index_t K1 = - (NPerBlock == 64 && sizeof(A0DataType) == 1 && sizeof(B0DataType) == 1) ? 8 : 16; + (PipelineVer == ck::BlockGemmPipelineVersion::v3 && NPerBlock == 64 && sizeof(A0DataType) == 1 && sizeof(B0DataType) == 1) ? 8 : 16; static constexpr ck::index_t AK1 = K1 / sizeof(A0DataType); static constexpr ck::index_t BK1 = ck::is_same_v ? 32 : K1 / sizeof(B0DataType); static constexpr ck::index_t EVec = 16 / sizeof(EDataType);