ggml-hexagon: flash-attn opt #19025

Merged: max-krasnyansky merged 5 commits into ggml-org:master from chraac:dev-fa-opt on Jan 24, 2026

Conversation

@chraac (Contributor) commented Jan 22, 2026

This pull request refactors and optimizes the flash_attn_ext_f16_thread function in flash-attn-ops.c, focusing on more efficient vectorized computation and improved numerical stability in the softmax calculation. The main changes involve restructuring the computation of scores and their accumulation, as well as ensuring proper handling of vector sizes and maximum value tracking.

Vectorization and Softmax Optimization:

  • Introduced a static_assert to ensure FLASH_ATTN_BLOCK_SIZE is compatible with the usage of HVX_Vector_x4, and refactored the loop to use scores_x4 for storing intermediate score vectors, improving vectorization and memory alignment.
  • Changed the accumulation of the maximum value (v_max) to be computed across all score vectors before reducing to a scalar, enhancing numerical stability for the softmax step.
  • Refactored the softmax calculation to accumulate probabilities (p_sum_vec) in a vectorized manner, then reduced and updated the running sum (S) more efficiently; see the sketch below.
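
A minimal scalar sketch of the online-softmax bookkeeping these changes optimize (illustrative only: the real kernel operates on HVX vectors, and everything here apart from S and ms is a hypothetical stand-in):

  #include <math.h>

  // Fold one block of n attention scores into the running softmax state:
  // M is the running max, S the running sum, acc the output accumulator
  // (accumulation of the V rows into acc is omitted for brevity).
  static void online_softmax_block(const float *scores, int n,
                                   float *M, float *S, float *acc, int dv) {
      // 1. Find the block max first, so the exp-based rescale below runs
      //    once per block rather than once per vector iteration.
      float m_new = *M;
      for (int i = 0; i < n; i++)
          if (scores[i] > m_new) m_new = scores[i];

      // 2. Rescale the previous state by ms = exp(old_max - new_max).
      float ms = expf(*M - m_new);
      for (int d = 0; d < dv; d++)
          acc[d] *= ms;

      // 3. Accumulate this block's probabilities, then update the running
      //    sum, mirroring S = S * ms + hvx_vec_get_f32(p_sum_vec).
      float p_sum = 0.0f;
      for (int i = 0; i < n; i++)
          p_sum += expf(scores[i] - m_new);

      *S = *S * ms + p_sum;
      *M = m_new;
  }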

@chraac chraac changed the title ggml-hexagon: flash-attn opt [WIP] ggml-hexagon: flash-attn opt Jan 22, 2026
@chraac chraac marked this pull request as draft January 22, 2026 14:34
}

p_sum_vec = hvx_vec_reduce_sum_f32(p_sum_vec);
S = S * ms + hvx_vec_get_f32(p_sum_vec);
@chraac (Contributor, Author) commented on this diff:

Here, we move the max reduction and the e^(x-max) update outside of the vector loop. This saves FLASH_ATTN_BLOCK_SIZE / VLEN_FP32 - 1 calls to hvx_vec_reduce_max_f32, hvx_scale_f32_aa, and expf (assuming expf is computationally expensive).

Alternatively, we could optimize further by moving the softmax update after the leftover block, which would reduce expensive calculations even more and improve performance.
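
A compilable sketch of that hoisting, with a plain struct standing in for HVX_Vector; VLEN_FP32 = 32 matches the 32 f32 lanes of a 128-byte HVX vector, while FLASH_ATTN_BLOCK_SIZE = 128 is an assumed value for illustration:

  #include <math.h>

  #define VLEN_FP32 32                  // f32 lanes in a 128-byte HVX vector
  #define FLASH_ATTN_BLOCK_SIZE 128     // assumed value, for illustration
  #define NVEC (FLASH_ATTN_BLOCK_SIZE / VLEN_FP32)

  typedef struct { float v[VLEN_FP32]; } vec;   // stand-in for HVX_Vector

  static vec vec_max(vec a, vec b) {            // element-wise max, cheap
      for (int k = 0; k < VLEN_FP32; k++)
          a.v[k] = fmaxf(a.v[k], b.v[k]);
      return a;
  }

  static float vec_reduce_max(vec a) {          // cross-lane reduction
      float m = a.v[0];
      for (int k = 1; k < VLEN_FP32; k++)
          m = fmaxf(m, a.v[k]);
      return m;
  }

  // After the change: keep an element-wise running max across all NVEC score
  // vectors and reduce to a scalar once per block. Before, the reduction and
  // the exp-based rescale ran on every iteration, i.e. NVEC - 1 extra calls.
  static float block_max(const vec scores[NVEC]) {
      vec v_max = scores[0];
      for (int i = 1; i < NVEC; i++)
          v_max = vec_max(v_max, scores[i]);
      return vec_reduce_max(v_max);
  }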

@github-actions bot added the ggml label (changes relating to the ggml tensor library for machine learning) Jan 22, 2026
@max-krasnyansky max-krasnyansky marked this pull request as ready for review January 24, 2026 05:59
@max-krasnyansky (Member) commented:

Looks good. I'm seeing a small but significant bump in perf.
Tests are passing as well.

Gen4
before
  common_perf_print: prompt eval time =  1714.21 ms / 205 tokens (  8.36 ms per token, 119.59 tokens per second)
  common_perf_print:        eval time =  1763.07 ms /  63 runs   ( 27.99 ms per token,  35.73 tokens per second)
after
  common_perf_print: prompt eval time =  1677.05 ms / 205 tokens (  8.18 ms per token, 122.24 tokens per second)
  common_perf_print:        eval time =  1574.56 ms /  63 runs   ( 24.99 ms per token,  40.01 tokens per second)

Gen5
before
  common_perf_print: prompt eval time =  1194.83 ms / 205 tokens (  5.83 ms per token, 171.57 tokens per second)
  common_perf_print:        eval time =  1554.43 ms /  63 runs   ( 24.67 ms per token,  40.53 tokens per second)
after
  common_perf_print: prompt eval time =  1169.90 ms / 205 tokens (  5.71 ms per token, 175.23 tokens per second)
  common_perf_print:        eval time =  1542.48 ms /  63 runs   ( 24.48 ms per token,  40.84 tokens per second)

@max-krasnyansky max-krasnyansky changed the title [WIP] ggml-hexagon: flash-attn opt ggml-hexagon: flash-attn opt Jan 24, 2026
@max-krasnyansky max-krasnyansky merged commit 8af1f5f into ggml-org:master Jan 24, 2026
76 of 78 checks passed
@chraac chraac deleted the dev-fa-opt branch January 24, 2026 16:32
@chraac (Contributor, Author) commented Jan 24, 2026

> Looks good. I'm seeing a small but significant bump in perf.

Nice! I'm also wondering if we could process 2 or more rows simultaneously, like we did in mulmat. That might save some VTCM -> register loads, though I'm not sure it would yield a significant perf improvement; a rough sketch of the idea follows.
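
A hedged sketch of that two-row idea (all names here are hypothetical stand-ins, and whether register pressure in the real kernel permits this is untested):

  typedef struct { float v[32]; } vec;    // stand-in for HVX_Vector

  static float vec_dot(vec a, vec b) {
      float s = 0.0f;
      for (int k = 0; k < 32; k++)
          s += a.v[k] * b.v[k];
      return s;
  }

  // Two query rows share one pass over K: each K vector is read from VTCM
  // into registers once and dotted against both queries, halving the loads
  // relative to processing the rows one at a time.
  static void scores_two_rows(const vec *K, int n, vec q0, vec q1,
                              float *s0, float *s1) {
      for (int b = 0; b < n; b++) {
          vec k = K[b];                   // single load, used twice
          s0[b] = vec_dot(q0, k);
          s1[b] = vec_dot(q1, k);
      }
  }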

ronaldmannak pushed a commit to PicoMLX/llama.cpp that referenced this pull request Jan 24, 2026
* optimize flash attention kernel by improving score computation and online softmax update

* wip

* Refactor online softmax update in flash attention kernel for improved performance

* Optimize flash attention kernel by replacing float array with HVX_Vector for score computation

* wip
shaofeiqi pushed a commit to qualcomm/llama.cpp that referenced this pull request Feb 6, 2026 (same commit messages as above)
Seunghhon pushed a commit to Seunghhon/llama.cpp that referenced this pull request Apr 26, 2026 (same commit messages as above)
