diff --git a/hsa/gfx950/fmha_v3_fwd/codegen.py b/hsa/gfx950/fmha_v3_fwd/codegen.py index 2563390b5f..4b326b0759 100644 --- a/hsa/gfx950/fmha_v3_fwd/codegen.py +++ b/hsa/gfx950/fmha_v3_fwd/codegen.py @@ -31,20 +31,20 @@ // template<> struct FmhaFwdV3Name> { static constexpr const char * fwd_v3_name = "fmha_fwd_hd128_fp16"; }; // template<> struct FmhaFwdV3Name> { static constexpr const char * fwd_v3_name = "fmha_fwd_hd128_fp16_causal"; }; // template<> struct FmhaFwdV3Name> { static constexpr const char * fwd_v3_name = "fmha_fwd_hd128_fp16_causal"; }; -template<> struct FmhaFwdV3Name> { static constexpr const char * fwd_v3_name = "fmha_fwd_hd192_hd128_bf16"; }; -template<> struct FmhaFwdV3Name> { static constexpr const char * fwd_v3_name = "fmha_fwd_hd192_hd128_bf16"; }; -template<> struct FmhaFwdV3Name> { static constexpr const char * fwd_v3_name = "fmha_fwd_hd192_hd128_bf16_causal"; }; -template<> struct FmhaFwdV3Name> { static constexpr const char * fwd_v3_name = "fmha_fwd_hd192_hd128_bf16_causal"; }; +template<> struct FmhaFwdV3Name> { static constexpr const char * fwd_v3_name = "fmha_fwd_hd192x128_bf16"; }; +template<> struct FmhaFwdV3Name> { static constexpr const char * fwd_v3_name = "fmha_fwd_hd192x128_bf16"; }; +template<> struct FmhaFwdV3Name> { static constexpr const char * fwd_v3_name = "fmha_fwd_hd192x128_bf16_causal"; }; +template<> struct FmhaFwdV3Name> { static constexpr const char * fwd_v3_name = "fmha_fwd_hd192x128_bf16_causal"; }; // ######################################################| DataType | HDim | MaskType | kIsSEQPad | kIsHDPad | kStoreLSE | GPUArch | BF16Cvt | kIsGroupMode_ | template<> struct FmhaFwdV3Name> { static constexpr const char * fwd_v3_name = "fmha_fwd_hd128_bf16_group"; }; template<> struct FmhaFwdV3Name> { static constexpr const char * fwd_v3_name = "fmha_fwd_hd128_bf16_group"; }; template<> struct FmhaFwdV3Name> { static constexpr const char * fwd_v3_name = "fmha_fwd_hd128_bf16_causal_group"; }; template<> struct FmhaFwdV3Name> { static constexpr const char * fwd_v3_name = "fmha_fwd_hd128_bf16_causal_group"; }; -template<> struct FmhaFwdV3Name> { static constexpr const char * fwd_v3_name = "fmha_fwd_hd192_hd128_bf16_group"; }; -template<> struct FmhaFwdV3Name> { static constexpr const char * fwd_v3_name = "fmha_fwd_hd192_hd128_bf16_group"; }; -template<> struct FmhaFwdV3Name> { static constexpr const char * fwd_v3_name = "fmha_fwd_hd192_hd128_bf16_causal_group"; }; -template<> struct FmhaFwdV3Name> { static constexpr const char * fwd_v3_name = "fmha_fwd_hd192_hd128_bf16_causal_group"; }; +template<> struct FmhaFwdV3Name> { static constexpr const char * fwd_v3_name = "fmha_fwd_hd192x128_bf16_group"; }; +template<> struct FmhaFwdV3Name> { static constexpr const char * fwd_v3_name = "fmha_fwd_hd192x128_bf16_group"; }; +template<> struct FmhaFwdV3Name> { static constexpr const char * fwd_v3_name = "fmha_fwd_hd192x128_bf16_causal_group"; }; +template<> struct FmhaFwdV3Name> { static constexpr const char * fwd_v3_name = "fmha_fwd_hd192x128_bf16_causal_group"; }; // #####################################################| DataType | HDim | MaskType | kIsSEQPad | kIsHDPad | kStoreLSE | GPUArch template<> struct FmhaFwdV3Buf> { static constexpr const void * fwd_v3_buf = fwd_hd128_bf16; }; @@ -122,6 +122,12 @@ class fmha_fwd_v3_kernel int gdx = ((fmha_v3_traits.s + fmha_v3_traits.ts_qo - 1) / fmha_v3_traits.ts_qo + tg_div - 1) / tg_div; int gdy = fmha_v3_traits.h; int gdz = fmha_v3_traits.b; + if (fmha_v3_traits.d == 192) + { + gdx = fmha_v3_traits.h; + gdy = (fmha_v3_traits.s + fmha_v3_traits.ts_qo - 1) / fmha_v3_traits.ts_qo; //do not merge the head and tail in seqlen_q direction + gdz = fmha_v3_traits.b; + } HIP_CALL(hipModuleLaunchKernel(kernel_func, gdx, @@ -146,7 +152,7 @@ class fmha_fwd_v3_kernel &arg_size, HIP_LAUNCH_PARAM_END}; - int tg_div = (fmha_v3_traits.mask != 0) ? 2 : 1; + int tg_div = (fmha_v3_traits.mask != 0 && fmha_v3_traits.d != 192) ? 2 : 1; int bdx = (fmha_v3_traits.d == 192) ? 256 : 512; int gdx = fmha_v3_traits.h; @@ -182,6 +188,10 @@ class fmha_fwd_v3_kernel { tune_opt -= 2; } + if (a.hdim_q == 192 && a.hdim_v == 128) + { + tune_opt = 0; + } fmha_fwd_v3_args args; args.ptr_o = a.o_ptr; @@ -245,6 +255,10 @@ class fmha_fwd_v3_kernel { tune_opt -= 2; } + if (a.hdim_q == 192 && a.hdim_v == 128) + { + tune_opt = 0; + } fmha_fwd_v3_args args; args.ptr_o = a.o_ptr; diff --git a/hsa/gfx950/fmha_v3_fwd/fwd_hd192_hd128_bf16.co b/hsa/gfx950/fmha_v3_fwd/fwd_hd192_hd128_bf16.co index 482d44bd21..47a00d6fcc 100755 Binary files a/hsa/gfx950/fmha_v3_fwd/fwd_hd192_hd128_bf16.co and b/hsa/gfx950/fmha_v3_fwd/fwd_hd192_hd128_bf16.co differ diff --git a/hsa/gfx950/fmha_v3_fwd/fwd_hd192_hd128_bf16_causal.co b/hsa/gfx950/fmha_v3_fwd/fwd_hd192_hd128_bf16_causal.co index 5f994f73b1..30e4735969 100755 Binary files a/hsa/gfx950/fmha_v3_fwd/fwd_hd192_hd128_bf16_causal.co and b/hsa/gfx950/fmha_v3_fwd/fwd_hd192_hd128_bf16_causal.co differ diff --git a/hsa/gfx950/fmha_v3_fwd/fwd_hd192_hd128_bf16_causal_group.co b/hsa/gfx950/fmha_v3_fwd/fwd_hd192_hd128_bf16_causal_group.co index b5ea7beff0..018f398ff8 100755 Binary files a/hsa/gfx950/fmha_v3_fwd/fwd_hd192_hd128_bf16_causal_group.co and b/hsa/gfx950/fmha_v3_fwd/fwd_hd192_hd128_bf16_causal_group.co differ diff --git a/hsa/gfx950/fmha_v3_fwd/fwd_hd192_hd128_bf16_group.co b/hsa/gfx950/fmha_v3_fwd/fwd_hd192_hd128_bf16_group.co index 794fd00158..bebe7e9b54 100755 Binary files a/hsa/gfx950/fmha_v3_fwd/fwd_hd192_hd128_bf16_group.co and b/hsa/gfx950/fmha_v3_fwd/fwd_hd192_hd128_bf16_group.co differ