Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions csrc/cpp_itfs/pa/pa_kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -407,8 +407,8 @@ __inline__ __device__ void _paged_attention_kernel(
}

// calculate qk_max and exp_sum per warp and write to shared memory
float qk_max[GQA_RATIO_LOOP][MTP_PER_THREAD] = {-FLT_MAX};
float exp_sum[GQA_RATIO_LOOP][MTP_PER_THREAD] = {0.0f};
float qk_max[GQA_RATIO_LOOP][MTP_PER_THREAD] = {{-FLT_MAX}};
float exp_sum[GQA_RATIO_LOOP][MTP_PER_THREAD] = {{0.0f}};

for (int mtp = 0; mtp < mtp_loop; mtp++) {
for (int gqa_ratio_loop = 0; gqa_ratio_loop < GQA_RATIO_LOOP; gqa_ratio_loop++) {
Expand Down Expand Up @@ -459,9 +459,9 @@ __inline__ __device__ void _paged_attention_kernel(
__syncthreads();

// calculate partition qk_max and exp_sum
float inv_sum_scale[GQA_RATIO_LOOP][MTP_PER_THREAD] = {0.0f};
float partition_qk_max[GQA_RATIO_LOOP][MTP_PER_THREAD] = {-FLT_MAX};
float partition_exp_sum[GQA_RATIO_LOOP][MTP_PER_THREAD] = {0.0f};
float inv_sum_scale[GQA_RATIO_LOOP][MTP_PER_THREAD] = {{0.0f}};
float partition_qk_max[GQA_RATIO_LOOP][MTP_PER_THREAD] = {{-FLT_MAX}};
float partition_exp_sum[GQA_RATIO_LOOP][MTP_PER_THREAD] = {{0.0f}};

for(int mtp = 0; mtp < mtp_loop; mtp++) {
for(int gqa_ratio_loop = 0; gqa_ratio_loop < GQA_RATIO_LOOP; gqa_ratio_loop++) {
Expand Down
Loading