Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 22 additions & 3 deletions custom_ops/gpu_ops/append_attn/append_attention_func.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -2451,6 +2451,7 @@ __global__ void merge_multi_chunks_v2_kernel(
if (bid == -1) {
continue;
}
const uint32_t local_seq_id = qid - cu_seqlens_q[bid];
const int seq_len_q = seq_lens_q[bid];
if (seq_len_q == 0) continue;
int seq_len_kv = seq_lens_kv[bid];
Expand Down Expand Up @@ -2494,14 +2495,32 @@ __global__ void merge_multi_chunks_v2_kernel(
}
#pragma unroll 2
for (int i = ty; i < num_chunks_this_seq; i += bdy) {
uint32_t offset = (qid * num_chunks + i) * num_heads + hid;
uint32_t offset;
if (ENABLE_PREFILL) {
offset = (qid * num_chunks + i) * num_heads + hid;
} else {
offset =
((bid * speculate_max_draft_token_num + local_seq_id) * num_chunks +
i) *
num_heads +
hid;
}
float m_prev = m;
float d_prev = d;
const float m_now = multi_m[offset];
const float d_now = multi_d[offset];
m = max(m_prev, m_now);
offset = (qid * num_chunks * num_heads + i * num_heads + hid) * head_dim +
vid * vec_size;
if (ENABLE_PREFILL) {
offset =
(qid * num_chunks * num_heads + i * num_heads + hid) * head_dim +
vid * vec_size;
} else {
offset = ((bid * speculate_max_draft_token_num + local_seq_id) *
num_chunks * num_heads +
i * num_heads + hid) *
head_dim +
vid * vec_size;
}
Load<T, vec_size>(&multi_out[offset], &load_vec);
const float scale1 = __expf(m_prev - m), scale2 = __expf(m_now - m);
const T scale1_T = static_cast<T>(scale1),
Expand Down
212 changes: 148 additions & 64 deletions custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,17 @@ __global__ void multi_query_append_attention_kernel(
T *o_base_ptr_T = nullptr;
OutT *o_base_ptr_int8 = nullptr;
if constexpr (partition_kv) {
o_base_ptr_T = tmp_workspace + q_start_seq_id * num_chunks * q_n_stride +
chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
tid % 8 * num_elems_per_128b<T>();
if (ENABLE_PREFILL) {
o_base_ptr_T = tmp_workspace + q_start_seq_id * num_chunks * q_n_stride +
chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
tid % 8 * num_elems_per_128b<T>();
} else {
o_base_ptr_T =
tmp_workspace +
batch_id * speculate_max_draft_token_num * num_chunks * q_n_stride +
chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
tid % 8 * num_elems_per_128b<T>();
}
} else {
o_base_ptr_int8 = out + o_offset;
}
Expand Down Expand Up @@ -386,8 +394,18 @@ __global__ void multi_query_append_attention_kernel(
const uint32_t qo_head_idx = q_head_idx + qo_idx_now % GROUP_SIZE;
const uint32_t qo_idx = q_start_seq_id + qo_idx_now / GROUP_SIZE;
if (qo_idx - q_start_seq_id < q_len) {
uint32_t offset =
(qo_idx * num_chunks + chunk_idx) * q_num_heads + qo_head_idx;
uint32_t offset;
if (ENABLE_PREFILL) {
offset =
(qo_idx * num_chunks + chunk_idx) * q_num_heads + qo_head_idx;
} else {
offset = ((batch_id * speculate_max_draft_token_num +
qo_idx_now / GROUP_SIZE) *
num_chunks +
chunk_idx) *
q_num_heads +
qo_head_idx;
}
tmp_m[offset] = m_frag[fx][j];
tmp_d[offset] = d_frag[fx][j];
}
Expand Down Expand Up @@ -524,9 +542,11 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
tid % 8 * num_elems_per_128b<T>();
} else {
o_base_ptr_T = tmp_workspace + q_start_seq_id * num_chunks * q_n_stride +
chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
tid % 8 * num_elems_per_128b<T>();
o_base_ptr_T =
tmp_workspace +
batch_id * speculate_max_draft_token_num * num_chunks * q_n_stride +
chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
tid % 8 * num_elems_per_128b<T>();
}
}
const int *mask_offset_this_seq =
Expand Down Expand Up @@ -794,8 +814,12 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
offset = (batch_id * num_chunks + chunk_idx) * q_num_heads +
qo_head_idx;
} else {
offset =
(qo_idx * num_chunks + chunk_idx) * q_num_heads + qo_head_idx;
offset = ((batch_id * speculate_max_draft_token_num +
qo_idx_now / GROUP_SIZE) *
num_chunks +
chunk_idx) *
q_num_heads +
qo_head_idx;
}
tmp_m[offset] = m_frag[fx][j];
tmp_d[offset] = d_frag[fx][j];
Expand Down Expand Up @@ -1026,51 +1050,95 @@ void MultiQueryAppendAttention(
sliding_window);
// merge
constexpr int vec_size = num_elems_per_128b<NV_TYPE>();
constexpr int blockx = HEAD_DIM / vec_size;
constexpr int blocky = (128 + blockx - 1) / blockx;
dim3 grids_merge(min(sm_count * 4, token_num),
num_heads); // 128k is too large
dim3 blocks_merge(blockx, blocky);
auto *kernelFn = merge_multi_chunks_v2_kernel<NV_TYPE,
vec_size,
blocky,
HEAD_DIM,
OUT_NV_TYPE,
ENABLE_PREFILL>;
launchWithPdlWhenEnabled(
kernelFn,
grids_merge,
blocks_merge,
0,
stream,
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
seq_lens_encoder.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
smooth_weight ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(smooth_weight.get().data<T>()))
: nullptr,
sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
quant_max_bound,
quant_min_bound,
in_scale,
max_seq_len,
num_chunks,
num_heads,
chunk_size,
HEAD_DIM,
token_num,
speculate_max_draft_token_num);
if (is_decoder) {
constexpr int blockx = HEAD_DIM / vec_size;
constexpr int blocky = (128 + blockx - 1) / blockx;
dim3 grids_merge(bsz, num_heads);
dim3 blocks_merge(blockx, blocky);
auto *kernelFn = merge_multi_chunks_decoder_kernel<NV_TYPE,
vec_size,
blocky,
HEAD_DIM,
OUT_NV_TYPE,
ENABLE_PREFILL>;
launchWithPdlWhenEnabled(
kernelFn,
grids_merge,
blocks_merge,
0,
stream,
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
seq_lens_encoder.data<int>(),
cu_seqlens_q.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
smooth_weight ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(smooth_weight.get().data<T>()))
: nullptr,
sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
quant_max_bound,
quant_min_bound,
in_scale,
max_seq_len,
num_chunks,
num_heads,
chunk_size,
HEAD_DIM);
} else {
constexpr int blockx = HEAD_DIM / vec_size;
constexpr int blocky = (128 + blockx - 1) / blockx;
dim3 grids_merge(min(sm_count * 4, token_num),
num_heads); // 128k is too large
dim3 blocks_merge(blockx, blocky);
auto *kernelFn = merge_multi_chunks_v2_kernel<NV_TYPE,
vec_size,
blocky,
HEAD_DIM,
OUT_NV_TYPE,
ENABLE_PREFILL>;
launchWithPdlWhenEnabled(
kernelFn,
grids_merge,
blocks_merge,
0,
stream,
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
seq_lens_encoder.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
smooth_weight ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(smooth_weight.get().data<T>()))
: nullptr,
sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
quant_max_bound,
quant_min_bound,
in_scale,
max_seq_len,
num_chunks,
num_heads,
chunk_size,
HEAD_DIM,
token_num,
speculate_max_draft_token_num);
}
}
} else {
constexpr uint32_t num_frags_z = BLOCK_SIZE / 16 / NUM_WARP_KV;
Expand Down Expand Up @@ -1189,15 +1257,31 @@ void MultiQueryAppendAttention(
phi::SizeOf(paddle::DataType::FLOAT32) *
static_cast<size_t>(bsz * num_chunks * num_heads));
} else {
tmp_workspace = allocator->Allocate(
phi::SizeOf(qkv.dtype()) *
static_cast<size_t>(token_num * num_chunks * num_heads * HEAD_DIM));
tmp_m = allocator->Allocate(
phi::SizeOf(paddle::DataType::FLOAT32) *
static_cast<size_t>(token_num * num_chunks * num_heads));
tmp_d = allocator->Allocate(
phi::SizeOf(paddle::DataType::FLOAT32) *
static_cast<size_t>(token_num * num_chunks * num_heads));
if (ENABLE_PREFILL) {
tmp_workspace =
allocator->Allocate(phi::SizeOf(qkv.dtype()) *
static_cast<size_t>(token_num * num_chunks *
num_heads * HEAD_DIM));
tmp_m = allocator->Allocate(
phi::SizeOf(paddle::DataType::FLOAT32) *
static_cast<size_t>(token_num * num_chunks * num_heads));
tmp_d = allocator->Allocate(
phi::SizeOf(paddle::DataType::FLOAT32) *
static_cast<size_t>(token_num * num_chunks * num_heads));
} else {
tmp_workspace = allocator->Allocate(
phi::SizeOf(qkv.dtype()) *
static_cast<size_t>(speculate_max_draft_token_num * bsz *
num_chunks * num_heads * HEAD_DIM));
tmp_m = allocator->Allocate(
phi::SizeOf(paddle::DataType::FLOAT32) *
static_cast<size_t>(speculate_max_draft_token_num * bsz *
num_chunks * num_heads));
tmp_d = allocator->Allocate(
phi::SizeOf(paddle::DataType::FLOAT32) *
static_cast<size_t>(speculate_max_draft_token_num * bsz *
num_chunks * num_heads));
}
}
launchWithPdlWhenEnabled(
split_kv_kernel,
Expand Down
Loading
Loading