ggml-hexagon: flash-attn opt #19025
Conversation
```c
p_sum_vec = hvx_vec_reduce_sum_f32(p_sum_vec);
S = S * ms + hvx_vec_get_f32(p_sum_vec);
```
Here, we move the max reduction and the `e^(x - max)` update outside of the vector loop. This saves `FLASH_ATTN_BLOCK_SIZE / VLEN_FP32 - 1` calls each to `hvx_vec_reduce_max_f32`, `hvx_scale_f32_aa`, and `expf` (assuming `expf` is computationally expensive).

Alternatively, we could optimize further by moving the softmax update after the leftover block, which would cut the expensive calls down even more.
Looks good. I'm seeing a small but significant bump in perf.
Nice! Also wondering if we could process 2 or more rows simultaneously, like we did in mulmat. Thought that might save some VTCM -> register loads, though I'm not sure if it would yield a significant perf improvement.
* optimize flash attention kernel by improving score computation and online softmax update
* wip
* Refactor online softmax update in flash attention kernel for improved performance
* Optimize flash attention kernel by replacing float array with HVX_Vector for score computation
* wip
This pull request refactors and optimizes the `flash_attn_ext_f16_thread` function in `flash-attn-ops.c`, focusing on more efficient vectorized computation and improved numerical stability in the softmax calculation. The main changes involve restructuring the computation of scores and their accumulation, as well as ensuring proper handling of vector sizes and maximum value tracking.

Vectorization and Softmax Optimization:

* Added a `static_assert` to ensure `FLASH_ATTN_BLOCK_SIZE` is compatible with the usage of `HVX_Vector_x4`, and refactored the loop to use `scores_x4` for storing intermediate score vectors, improving vectorization and memory alignment.
* Changed the maximum value (`v_max`) to be computed across all score vectors before reducing to a scalar, enhancing numerical stability for the softmax step. [1] [2]
* Accumulated the probability sum (`p_sum_vec`) in a vectorized manner, then reduced and updated the running sum (`S`) more efficiently.
static_assertto ensureFLASH_ATTN_BLOCK_SIZEis compatible with the usage ofHVX_Vector_x4, and refactored the loop to usescores_x4for storing intermediate score vectors, improving vectorization and memory alignment.v_max) to be computed across all score vectors before reducing to a scalar, enhancing numerical stability for the softmax step. [1] [2]p_sum_vec) in a vectorized manner, then reduced and updated the running sum (S) more efficiently.