8 changes: 8 additions & 0 deletions aiter/ops/mha.py
@@ -254,6 +254,7 @@ def fmha_v3_fwd(
     bias: Optional[Tensor] = None,
     alibi_slopes: Optional[Tensor] = None,
     gen: Optional[Generator] = None,
+    l_tpf: int = 0,
 ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: ...


@@ -1234,6 +1235,7 @@ def _flash_attn_forward(
     how_v3_bf16_cvt: Optional[int] = 1,
     cu_seqlens_q: Optional[torch.Tensor] = None,
     cu_seqlens_kv: Optional[torch.Tensor] = None,
+    l_tpf: int = 0,
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
 
     (_, seqlen_q, nhead_q, hdim_q) = q.shape
@@ -1290,6 +1292,7 @@ def _validate_cu(name: str, x: Optional[torch.Tensor]):
             bias,
             alibi_slopes,
             None,
+            l_tpf,
         )
     else:
         out, softmax_lse, S_dmask, rng_state = mha_fwd(
@@ -1700,6 +1703,7 @@ def forward(
         how_v3_bf16_cvt: Optional[int] = 1,
         cu_seqlens_q: Optional[torch.Tensor] = None,
         cu_seqlens_kv: Optional[torch.Tensor] = None,
+        l_tpf: int = 0,
     ):
         is_grad = is_grad_enabled and any(x.requires_grad for x in [q, k, v])
         if softmax_scale is None:
@@ -1731,6 +1735,7 @@ def forward(
             how_v3_bf16_cvt=how_v3_bf16_cvt,
             cu_seqlens_q=cu_seqlens_q,
             cu_seqlens_kv=cu_seqlens_kv,
+            l_tpf=l_tpf,
         )
         if is_grad:
             assert return_lse
@@ -1829,6 +1834,7 @@ def backward(ctx, dout, *args):
             None,  # how_v3_bf16_cvt
             None,  # cu_seqlens_q
             None,  # cu_seqlens_kv
+            None,  # l_tpf
         )


@@ -1848,6 +1854,7 @@ def flash_attn_func(
     how_v3_bf16_cvt=1,
     cu_seqlens_q: Optional[torch.Tensor] = None,
     cu_seqlens_kv: Optional[torch.Tensor] = None,
+    l_tpf: int = 0,
 ):
     """dropout_p should be set to 0.0 during evaluation
     Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
@@ -1918,6 +1925,7 @@ def flash_attn_func(
         how_v3_bf16_cvt,
         cu_seqlens_q,
         cu_seqlens_kv,
+        l_tpf,
     )


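Reviewer note: a minimal usage sketch of the extended Python API. Everything here except q/k/v, dropout_p, and the new l_tpf keyword is assumed from context; the PR does not document what l_tpf ("tokens per frame") means beyond threading it down to the v3 assembly kernel, so the value 256 is purely illustrative.

import torch
from aiter.ops.mha import flash_attn_func

# [batch, seqlen, nheads, headdim], bf16 on an AMD GPU (gfx942/gfx950)
q = torch.randn(2, 1024, 8, 128, dtype=torch.bfloat16, device="cuda")
k = torch.randn(2, 1024, 8, 128, dtype=torch.bfloat16, device="cuda")
v = torch.randn(2, 1024, 8, 128, dtype=torch.bfloat16, device="cuda")

# l_tpf=0 (the default) preserves the old behavior; a nonzero value is
# forwarded through _flash_attn_forward to the fmha_v3_fwd entry point.
out = flash_attn_func(q, k, v, dropout_p=0.0, l_tpf=256)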
6 changes: 4 additions & 2 deletions csrc/cpp_itfs/mha_fwd_generate.py
@@ -99,7 +99,9 @@
                int how_v3_bf16_cvt,
                const void* seqstart_q_padding_ptr,
                const void* seqstart_k_padding_ptr,
-               bool is_v3_api_check)
+               bool is_v3_api_check,
+               int magic_const,
+               int tokens_per_frame)
 {{
     int head_size_q = args.hdim_q;
     int head_size_v = args.hdim_v;
@@ -178,7 +180,7 @@


 def get_v3_api():
-    v3_call = "fmha_fwd_v3(traits, args, stream_config, is_v3_api_check)"
+    v3_call = "fmha_fwd_v3(traits, args, stream_config, is_v3_api_check, magic_const, tokens_per_frame)"
     gfx_list = get_gfx_list()
     v3_arch_list = [arch for arch in ["gfx942", "gfx950"] if arch in gfx_list]

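Reviewer note: the doubled braces in the first hunk ("{{") are not C++ -- this file builds C++ source with Python str.format, where "{{"/"}}" escape to literal braces. A toy sketch of the same pattern; the template text is hypothetical, only the v3_call string is taken from the diff:

SHIM_TEMPLATE = """float shim(mha_fwd_args args,
           bool is_v3_api_check,
           int magic_const,
           int tokens_per_frame)
{{
    return {v3_call};
}}"""

v3_call = "fmha_fwd_v3(traits, args, stream_config, is_v3_api_check, magic_const, tokens_per_frame)"
print(SHIM_TEMPLATE.format(v3_call=v3_call))  # "{{"/"}}" render as literal braces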
13 changes: 11 additions & 2 deletions csrc/include/mha_fwd.h
@@ -7,6 +7,7 @@
 #include "aiter_hip_common.h"
 #include "fmha_fwd.hpp"
 #include "mask.hpp"
+#include <iostream>
 
 namespace aiter {
 struct mha_fwd_traits : public fmha_fwd_traits
@@ -89,7 +90,9 @@ __attribute__((visibility("default"))) float mha_fwd(mha_fwd_args args,
                                                      int how_v3_bf16_cvt = 1,
                                                      const void* seqstart_q_padding_ptr = nullptr,
                                                      const void* seqstart_k_padding_ptr = nullptr,
-                                                     bool is_v3_api_check = false);
+                                                     bool is_v3_api_check = false,
+                                                     int magic_const = 0,
+                                                     int tokens_per_frame = 0);
 
 __attribute__((visibility("default"))) float
 mha_fwd_splitkv(mha_fwd_splitkv_args args,
@@ -177,6 +180,10 @@ struct __attribute__((packed)) fmha_fwd_v3_args
     p2 _p30;
     const void* ptr_kseq_padding;
     p2 _p31;
+    unsigned int tokens_per_frame_magic_const;
+    p3 _p32;
+    unsigned int tokens_per_frame;
+    p3 _p33;
 };
 
 struct fmha_fwd_v3_traits
@@ -231,6 +238,8 @@ namespace gfx950 {
 float fmha_fwd_v3(mha_fwd_traits t,
                   mha_fwd_args a,
                   const ck_tile::stream_config& s,
-                  bool is_v3_api_check = false);
+                  bool is_v3_api_check = false,
+                  int magic_const = 0,
+                  int tokens_per_frame = 0);
 }
 } // namespace aiter
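Reviewer note: fmha_fwd_v3_args is the flat argument buffer handed to the hand-written assembly kernel, so every field must land at a fixed byte offset; the p2/p3 members are explicit pads. Their definitions are outside this diff, so the sizes below are assumptions (p2 = 8 bytes after an 8-byte pointer, p3 = 12 bytes after a 4-byte int, i.e. one 16-byte slot per field). A sketch of how the new tail of the struct would be laid out, from Python:

import struct

def pack_args_tail(ptr_kseq_padding: int, magic_const: int, tokens_per_frame: int) -> bytes:
    # const void* + p2 pad, then two (unsigned int + p3 pad) slots
    return (
        struct.pack("<Q8x", ptr_kseq_padding)
        + struct.pack("<I12x", magic_const)
        + struct.pack("<I12x", tokens_per_frame)
    )

assert len(pack_args_tail(0, 16777216, 256)) == 48  # three 16-byte slots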
4 changes: 3 additions & 1 deletion csrc/include/rocm_ops.hpp
@@ -790,7 +790,9 @@ namespace py = pybind11;
                  py::arg("out") = std::nullopt,           \
                  py::arg("bias") = std::nullopt,          \
                  py::arg("alibi_slopes") = std::nullopt,  \
-                 py::arg("gen") = std::nullopt);
+                 py::arg("gen") = std::nullopt,           \
+                 py::arg("tokens_per_frame") = 0);
 
 
 #define MHA_FWD_PYBIND        \
     m.def("mha_fwd",          \
3 changes: 2 additions & 1 deletion csrc/include/torch/mha_v3_fwd.h
@@ -19,6 +19,7 @@ std::vector<at::Tensor> fmha_v3_fwd(at::Tensor &q, // [b, sq, hq, d]
                                     std::optional<at::Tensor> out_,                // [b, sq, hq, d_v]
                                     std::optional<const at::Tensor> bias_,         // [sq, sk]
                                     std::optional<const at::Tensor> alibi_slopes_, // [hq] or [b, hq]
-                                    std::optional<at::Generator> gen_);
+                                    std::optional<at::Generator> gen_,
+                                    int tokens_per_frame);
 } // namespace torch_itfs
 } // namespace aiter
12 changes: 10 additions & 2 deletions csrc/py_itfs_cu/asm_mha_fwd.cu
@@ -155,7 +155,8 @@ std::vector<at::Tensor> fmha_v3_fwd(at::Tensor &q, // [b, sq, hq, d]
                                     std::optional<at::Tensor> out_,                // [b, sq, hq, d_v]
                                     std::optional<const at::Tensor> bias_,         // [sq, sk]
                                     std::optional<const at::Tensor> alibi_slopes_, // [hq] or [b, hq]
-                                    std::optional<at::Generator> gen_)
+                                    std::optional<at::Generator> gen_,
+                                    int tokens_per_frame)
 {
     auto q_dtype = q.dtype();
     TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
@@ -311,6 +312,8 @@ std::vector<at::Tensor> fmha_v3_fwd(at::Tensor &q, // [b, sq, hq, d]
                               softmax_scale,
                               p_dropout,
                               drop_seed_offset);
+        // tokens_per_frame == 0 (the default) disables the feature; guard the division to avoid UB.
+        int magic_const = tokens_per_frame ? (uint32_t)(((1ULL << 32) + tokens_per_frame - 1) / tokens_per_frame) : 0;
 
         float t = aiter::mha_fwd(args,
                                  stream_config,
@@ -322,7 +325,12 @@ std::vector<at::Tensor> fmha_v3_fwd(at::Tensor &q, // [b, sq, hq, d]
                                  quant_scale_enum::no_scale,
                                  true,
                                  false,
-                                 how_v3_bf16_cvt);
+                                 how_v3_bf16_cvt,
+                                 nullptr,          // seqstart_q_padding_ptr
+                                 nullptr,          // seqstart_k_padding_ptr
+                                 false,            // is_v3_api_check
+                                 magic_const,
+                                 tokens_per_frame);
         TORCH_CHECK(t >= 0, "invalid argument for fmha_fwd");
     }
     else {
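Reviewer note: magic_const is the standard reciprocal for division by an invariant integer, magic = ceil(2**32 / d). With it the kernel can compute n // tokens_per_frame as the high 32 bits of n * magic (one multiply plus a shift instead of an integer divide). A self-checking sketch; the exactness bound (roughly n * d < 2**32) and the d == 1 uint32-overflow caveat are my reading of the trick, not something the PR states:

def magic(d: int) -> int:
    # Mirrors the C expression above; d == 0 must be special-cased (see the
    # guard), and d == 1 would overflow a 32-bit magic (2**32).
    assert d > 1
    return ((1 << 32) + d - 1) // d

for d in (3, 77, 256, 1560):              # plausible tokens-per-frame values
    m = magic(d)
    for n in range(1 << 16):              # token indices well inside the bound
        assert (n * m) >> 32 == n // d    # == mul_hi(n, magic) on the GPU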
25 changes: 15 additions & 10 deletions hsa/gfx950/fmha_v3_fwd/codegen.py
@@ -102,6 +102,8 @@ class fmha_fwd_v3_kernel
     {
         int length = strlen(name);
         std::string kernel_func_name = "_ZN5aiter" + std::to_string(length) + name + "E";
+        std::cout << "Loading kernel: " << kernel_func_name << std::endl;
+        std::cout << "HSACO ptr: " << hsaco << std::endl;
         HIP_CALL(hipModuleLoadData(&module, hsaco));
         HIP_CALL(hipModuleGetFunction(&kernel_func, module, kernel_func_name.c_str()));
     }
@@ -172,7 +174,7 @@ class fmha_fwd_v3_kernel
 };
 
 template <typename fmha_fwd_kernel_selector>
-float fmha_fwd_v3_dispatcher(const ck_tile::stream_config& s, mha_fwd_args a)
+float fmha_fwd_v3_dispatcher(const ck_tile::stream_config& s, mha_fwd_args a, int magic_const, int tokens_per_frame)
 {
     if(s.log_level_ > 0)
         std::cout << ", " << FmhaFwdV3Name<fmha_fwd_kernel_selector>::fwd_v3_name << std::flush;
@@ -182,6 +184,7 @@ class fmha_fwd_v3_kernel
     {
         tune_opt -= 2;
     }
+    tune_opt = 0; // disable tune for fmha v3 for now
 
     fmha_fwd_v3_args args;
     args.ptr_o = a.o_ptr;
@@ -219,6 +222,8 @@ class fmha_fwd_v3_kernel
     args.ptr_kseq = nullptr;
     args.ptr_qseq_padding = nullptr;
     args.ptr_kseq_padding = nullptr;
+    args.tokens_per_frame_magic_const = magic_const;
+    args.tokens_per_frame = tokens_per_frame;
 
     auto traits = fmha_fwd_v3_traits{a.batch,
                                      a.nhead_q,
@@ -303,7 +308,7 @@ class fmha_fwd_v3_kernel
     );
 }
 
-float fmha_fwd_v3(mha_fwd_traits t, mha_fwd_args a, const ck_tile::stream_config& s, bool is_v3_api_check) {
+float fmha_fwd_v3(mha_fwd_traits t, mha_fwd_args a, const ck_tile::stream_config& s, bool is_v3_api_check, int magic_const, int tokens_per_frame) {
     float r = -1;
     if (t.use_ext_asm == true) {
         if (t.data_type.compare("bf16") == 0) {
@@ -317,15 +322,15 @@ class fmha_fwd_v3_kernel
                 if (is_v3_api_check) {
                     return 1;
                 }
-                r = fmha_fwd_v3_dispatcher<fmha_fwd_kernel>(s, a);
+                r = fmha_fwd_v3_dispatcher<fmha_fwd_kernel>(s, a, magic_const, tokens_per_frame);
             }
             else {
                 if (a.batch_stride_lse >= a.nhead_stride_lse) {
                     using fmha_fwd_kernel = fmha_fwd_kernel_selector<FmhaFwdBf16, 128, 1, false, false, 1, GPUArch::gfx950>;
                     if (is_v3_api_check) {
                         return 1;
                     }
-                    r = fmha_fwd_v3_dispatcher<fmha_fwd_kernel>(s, a);
+                    r = fmha_fwd_v3_dispatcher<fmha_fwd_kernel>(s, a, magic_const, tokens_per_frame);
                 }
             }
         }
@@ -335,15 +340,15 @@ class fmha_fwd_v3_kernel
                 if (is_v3_api_check) {
                     return 1;
                 }
-                r = fmha_fwd_v3_dispatcher<fmha_fwd_kernel>(s, a);
+                r = fmha_fwd_v3_dispatcher<fmha_fwd_kernel>(s, a, magic_const, tokens_per_frame);
             }
            else {
                 if (a.batch_stride_lse >= a.nhead_stride_lse) {
                     using fmha_fwd_kernel = fmha_fwd_kernel_selector<FmhaFwdBf16, 128, 0, false, false, 1, GPUArch::gfx950>;
                     if (is_v3_api_check) {
                         return 1;
                     }
-                    r = fmha_fwd_v3_dispatcher<fmha_fwd_kernel>(s, a);
+                    r = fmha_fwd_v3_dispatcher<fmha_fwd_kernel>(s, a, magic_const, tokens_per_frame);
                 }
             }
         }
@@ -392,15 +397,15 @@ class fmha_fwd_v3_kernel
                 if (is_v3_api_check) {
                     return 1;
                 }
-                r = fmha_fwd_v3_dispatcher<fmha_fwd_kernel>(s, a);
+                r = fmha_fwd_v3_dispatcher<fmha_fwd_kernel>(s, a, magic_const, tokens_per_frame);
             }
             else {
                 if (a.batch_stride_lse >= a.nhead_stride_lse) {
                     using fmha_fwd_kernel = fmha_fwd_kernel_selector<FmhaFwdBf16, 192, 1, false, false, 1, GPUArch::gfx950>;
                     if (is_v3_api_check) {
                         return 1;
                     }
-                    r = fmha_fwd_v3_dispatcher<fmha_fwd_kernel>(s, a);
+                    r = fmha_fwd_v3_dispatcher<fmha_fwd_kernel>(s, a, magic_const, tokens_per_frame);
                 }
             }
         }
@@ -410,15 +415,15 @@ class fmha_fwd_v3_kernel
                 if (is_v3_api_check) {
                     return 1;
                 }
-                r = fmha_fwd_v3_dispatcher<fmha_fwd_kernel>(s, a);
+                r = fmha_fwd_v3_dispatcher<fmha_fwd_kernel>(s, a, magic_const, tokens_per_frame);
             }
             else {
                 if (a.batch_stride_lse >= a.nhead_stride_lse) {
                     using fmha_fwd_kernel = fmha_fwd_kernel_selector<FmhaFwdBf16, 192, 0, false, false, 1, GPUArch::gfx950>;
                     if (is_v3_api_check) {
                         return 1;
                     }
-                    r = fmha_fwd_v3_dispatcher<fmha_fwd_kernel>(s, a);
+                    r = fmha_fwd_v3_dispatcher<fmha_fwd_kernel>(s, a, magic_const, tokens_per_frame);
                 }
             }
         }
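Reviewer note: the loader mangles the kernel symbol by hand -- "_ZN5aiter" + len(name) + name + "E" is the Itanium ABI encoding of aiter::<name> -- and the two new std::cout lines print that symbol and the HSACO pointer at module-load time. A sketch of the mangling; the kernel name is hypothetical, patterned after the .co file below:

def mangled(name: str) -> str:
    # Itanium mangling for the symbol aiter::<name>, as built in the C++ above
    return f"_ZN5aiter{len(name)}{name}E"

assert mangled("fwd_hd128_bf16_causal") == "_ZN5aiter21fwd_hd128_bf16_causalE"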
Binary file modified hsa/gfx950/fmha_v3_fwd/fwd_hd128_bf16_causal.co
Binary file not shown.