Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 23 additions & 9 deletions hsa/gfx950/fmha_v3_fwd/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,20 +31,20 @@
// template<> struct FmhaFwdV3Name<fmha_fwd_kernel_selector<FmhaFwdFp16, 128, 0, false, false, 1, GPUArch::gfx950>> { static constexpr const char * fwd_v3_name = "fmha_fwd_hd128_fp16"; };
// template<> struct FmhaFwdV3Name<fmha_fwd_kernel_selector<FmhaFwdFp16, 128, 1, false, false, 0, GPUArch::gfx950>> { static constexpr const char * fwd_v3_name = "fmha_fwd_hd128_fp16_causal"; };
// template<> struct FmhaFwdV3Name<fmha_fwd_kernel_selector<FmhaFwdFp16, 128, 1, false, false, 1, GPUArch::gfx950>> { static constexpr const char * fwd_v3_name = "fmha_fwd_hd128_fp16_causal"; };
template<> struct FmhaFwdV3Name<fmha_fwd_kernel_selector<FmhaFwdBf16, 192, 0, false, false, 0, GPUArch::gfx950>> { static constexpr const char * fwd_v3_name = "fmha_fwd_hd192_hd128_bf16"; };
template<> struct FmhaFwdV3Name<fmha_fwd_kernel_selector<FmhaFwdBf16, 192, 0, false, false, 1, GPUArch::gfx950>> { static constexpr const char * fwd_v3_name = "fmha_fwd_hd192_hd128_bf16"; };
template<> struct FmhaFwdV3Name<fmha_fwd_kernel_selector<FmhaFwdBf16, 192, 1, false, false, 0, GPUArch::gfx950>> { static constexpr const char * fwd_v3_name = "fmha_fwd_hd192_hd128_bf16_causal"; };
template<> struct FmhaFwdV3Name<fmha_fwd_kernel_selector<FmhaFwdBf16, 192, 1, false, false, 1, GPUArch::gfx950>> { static constexpr const char * fwd_v3_name = "fmha_fwd_hd192_hd128_bf16_causal"; };
template<> struct FmhaFwdV3Name<fmha_fwd_kernel_selector<FmhaFwdBf16, 192, 0, false, false, 0, GPUArch::gfx950>> { static constexpr const char * fwd_v3_name = "fmha_fwd_hd192x128_bf16"; };
template<> struct FmhaFwdV3Name<fmha_fwd_kernel_selector<FmhaFwdBf16, 192, 0, false, false, 1, GPUArch::gfx950>> { static constexpr const char * fwd_v3_name = "fmha_fwd_hd192x128_bf16"; };
template<> struct FmhaFwdV3Name<fmha_fwd_kernel_selector<FmhaFwdBf16, 192, 1, false, false, 0, GPUArch::gfx950>> { static constexpr const char * fwd_v3_name = "fmha_fwd_hd192x128_bf16_causal"; };
template<> struct FmhaFwdV3Name<fmha_fwd_kernel_selector<FmhaFwdBf16, 192, 1, false, false, 1, GPUArch::gfx950>> { static constexpr const char * fwd_v3_name = "fmha_fwd_hd192x128_bf16_causal"; };

// ######################################################| DataType | HDim | MaskType | kIsSEQPad | kIsHDPad | kStoreLSE | GPUArch | BF16Cvt | kIsGroupMode_ |
template<> struct FmhaFwdV3Name<fmha_fwd_kernel_selector<FmhaFwdBf16, 128, 0, false, false, 0, GPUArch::gfx950, 1, true>> { static constexpr const char * fwd_v3_name = "fmha_fwd_hd128_bf16_group"; };
template<> struct FmhaFwdV3Name<fmha_fwd_kernel_selector<FmhaFwdBf16, 128, 0, false, false, 1, GPUArch::gfx950, 1, true>> { static constexpr const char * fwd_v3_name = "fmha_fwd_hd128_bf16_group"; };
template<> struct FmhaFwdV3Name<fmha_fwd_kernel_selector<FmhaFwdBf16, 128, 1, false, false, 0, GPUArch::gfx950, 1, true>> { static constexpr const char * fwd_v3_name = "fmha_fwd_hd128_bf16_causal_group"; };
template<> struct FmhaFwdV3Name<fmha_fwd_kernel_selector<FmhaFwdBf16, 128, 1, false, false, 1, GPUArch::gfx950, 1, true>> { static constexpr const char * fwd_v3_name = "fmha_fwd_hd128_bf16_causal_group"; };
template<> struct FmhaFwdV3Name<fmha_fwd_kernel_selector<FmhaFwdBf16, 192, 0, false, false, 0, GPUArch::gfx950, 1, true>> { static constexpr const char * fwd_v3_name = "fmha_fwd_hd192_hd128_bf16_group"; };
template<> struct FmhaFwdV3Name<fmha_fwd_kernel_selector<FmhaFwdBf16, 192, 0, false, false, 1, GPUArch::gfx950, 1, true>> { static constexpr const char * fwd_v3_name = "fmha_fwd_hd192_hd128_bf16_group"; };
template<> struct FmhaFwdV3Name<fmha_fwd_kernel_selector<FmhaFwdBf16, 192, 1, false, false, 0, GPUArch::gfx950, 1, true>> { static constexpr const char * fwd_v3_name = "fmha_fwd_hd192_hd128_bf16_causal_group"; };
template<> struct FmhaFwdV3Name<fmha_fwd_kernel_selector<FmhaFwdBf16, 192, 1, false, false, 1, GPUArch::gfx950, 1, true>> { static constexpr const char * fwd_v3_name = "fmha_fwd_hd192_hd128_bf16_causal_group"; };
template<> struct FmhaFwdV3Name<fmha_fwd_kernel_selector<FmhaFwdBf16, 192, 0, false, false, 0, GPUArch::gfx950, 1, true>> { static constexpr const char * fwd_v3_name = "fmha_fwd_hd192x128_bf16_group"; };
template<> struct FmhaFwdV3Name<fmha_fwd_kernel_selector<FmhaFwdBf16, 192, 0, false, false, 1, GPUArch::gfx950, 1, true>> { static constexpr const char * fwd_v3_name = "fmha_fwd_hd192x128_bf16_group"; };
template<> struct FmhaFwdV3Name<fmha_fwd_kernel_selector<FmhaFwdBf16, 192, 1, false, false, 0, GPUArch::gfx950, 1, true>> { static constexpr const char * fwd_v3_name = "fmha_fwd_hd192x128_bf16_causal_group"; };
template<> struct FmhaFwdV3Name<fmha_fwd_kernel_selector<FmhaFwdBf16, 192, 1, false, false, 1, GPUArch::gfx950, 1, true>> { static constexpr const char * fwd_v3_name = "fmha_fwd_hd192x128_bf16_causal_group"; };

// #####################################################| DataType | HDim | MaskType | kIsSEQPad | kIsHDPad | kStoreLSE | GPUArch
template<> struct FmhaFwdV3Buf<fmha_fwd_kernel_selector<FmhaFwdBf16, 128, 0, false, false, 0, GPUArch::gfx950>> { static constexpr const void * fwd_v3_buf = fwd_hd128_bf16; };
Expand Down Expand Up @@ -122,6 +122,12 @@ class fmha_fwd_v3_kernel
int gdx = ((fmha_v3_traits.s + fmha_v3_traits.ts_qo - 1) / fmha_v3_traits.ts_qo + tg_div - 1) / tg_div;
int gdy = fmha_v3_traits.h;
int gdz = fmha_v3_traits.b;
if (fmha_v3_traits.d == 192)
{
gdx = fmha_v3_traits.h;
gdy = (fmha_v3_traits.s + fmha_v3_traits.ts_qo - 1) / fmha_v3_traits.ts_qo; //do not merge the head and tail in seqlen_q direction
gdz = fmha_v3_traits.b;
}

HIP_CALL(hipModuleLaunchKernel(kernel_func,
gdx,
Expand All @@ -146,7 +152,7 @@ class fmha_fwd_v3_kernel
&arg_size,
HIP_LAUNCH_PARAM_END};

int tg_div = (fmha_v3_traits.mask != 0) ? 2 : 1;
int tg_div = (fmha_v3_traits.mask != 0 && fmha_v3_traits.d != 192) ? 2 : 1;

int bdx = (fmha_v3_traits.d == 192) ? 256 : 512;
int gdx = fmha_v3_traits.h;
Expand Down Expand Up @@ -182,6 +188,10 @@ class fmha_fwd_v3_kernel
{
tune_opt -= 2;
}
if (a.hdim_q == 192 && a.hdim_v == 128)
{
tune_opt = 0;
}
Comment on lines +191 to +194
Copy link

Copilot AI Dec 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The `a.hdim_q == 192 && a.hdim_v == 128` check that resets `tune_opt` to 0 is duplicated at lines 191-194 and 258-261. Consider extracting it into a helper function, or otherwise consolidating the logic, so the two copies cannot drift apart.

Copilot uses AI. Check for mistakes.

fmha_fwd_v3_args args;
args.ptr_o = a.o_ptr;
Expand Down Expand Up @@ -245,6 +255,10 @@ class fmha_fwd_v3_kernel
{
tune_opt -= 2;
}
if (a.hdim_q == 192 && a.hdim_v == 128)
{
tune_opt = 0;
}

fmha_fwd_v3_args args;
args.ptr_o = a.o_ptr;
Expand Down
Binary file modified hsa/gfx950/fmha_v3_fwd/fwd_hd192_hd128_bf16.co
Binary file not shown.
Binary file modified hsa/gfx950/fmha_v3_fwd/fwd_hd192_hd128_bf16_causal.co
Binary file not shown.
Binary file modified hsa/gfx950/fmha_v3_fwd/fwd_hd192_hd128_bf16_causal_group.co
Binary file not shown.
Binary file modified hsa/gfx950/fmha_v3_fwd/fwd_hd192_hd128_bf16_group.co
Binary file not shown.