
cuda: add partial eviction on pool OOM#22193

Open
leonardHONG wants to merge 3 commits into ggml-org:master from leonardHONG:cuda-pool-partial-eviction

Conversation

@leonardHONG
Contributor

@leonardHONG leonardHONG commented Apr 21, 2026

This is a small follow-up to #22155.

In #22155, legacy-pool OOM recovery was handled by flushing the whole cache and retrying once. That fixed the immediate issue, but it is also a fairly blunt fallback.

This PR changes two things.

First, in the legacy pool OOM path, it adds an LRU-style bounded reclaim step before the existing full flush. On cudaErrorMemoryAllocation, it evicts the oldest cached buffers first, up to a bound of 3x the current request size, and then retries the allocation. If that still fails, it falls back to the existing clear_pool() + retry path, so the safety net from #22155 is preserved.
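
For illustration, the reclaim step is roughly shaped like the sketch below (simplified, with illustrative names such as `buffer_pool_ts` and `pool_reclaim_lru`; it is not the exact code in ggml-cuda.cu):

```cpp
// Simplified sketch of the bounded LRU-style reclaim on the legacy-pool OOM
// path. Names and layout are illustrative, not the actual ggml-cuda.cu code.
#include <cuda_runtime.h>
#include <cstddef>
#include <cstdint>

struct cuda_buffer {
    void * ptr  = nullptr;
    size_t size = 0;
};

static const int   MAX_BUFFERS = 256;
static cuda_buffer buffer_pool[MAX_BUFFERS];    // cached, currently unused buffers
static uint64_t    buffer_pool_ts[MAX_BUFFERS]; // last-use timestamps (parallel array)

// Free cached buffers, oldest first, until at least `target` bytes are released.
static size_t pool_reclaim_lru(size_t target) {
    size_t freed = 0;
    while (freed < target) {
        int oldest = -1;
        for (int i = 0; i < MAX_BUFFERS; ++i) {
            if (buffer_pool[i].ptr != nullptr &&
                (oldest < 0 || buffer_pool_ts[i] < buffer_pool_ts[oldest])) {
                oldest = i;
            }
        }
        if (oldest < 0) {
            break; // nothing left to evict
        }
        cudaFree(buffer_pool[oldest].ptr);
        freed += buffer_pool[oldest].size;
        buffer_pool[oldest]    = {};
        buffer_pool_ts[oldest] = 0;
    }
    return freed;
}

// OOM path: bounded reclaim (3x the request) first, full flush as a last resort.
static void * pool_alloc_with_reclaim(size_t size) {
    void * ptr = nullptr;
    if (cudaMalloc(&ptr, size) == cudaSuccess) {
        return ptr;
    }
    cudaGetLastError();               // clear the sticky allocation error
    pool_reclaim_lru(3*size);         // evict the oldest cached buffers first
    if (cudaMalloc(&ptr, size) == cudaSuccess) {
        return ptr;
    }
    cudaGetLastError();
    pool_reclaim_lru(SIZE_MAX);       // #22155 safety net: clear everything ...
    cudaMalloc(&ptr, size);           // ... and retry once
    return ptr;
}
```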

Second, it adds a bool overallocate parameter to ggml_cuda_pool::alloc(). When false, the existing 1.05x look-ahead behavior stays unchanged. When true, the legacy pool allocates 2x instead. This is enabled for the FA K_f16 / V_f16 conversion buffers, since those buffers tend to grow as context length grows and this helps avoid repeated reallocations.
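
In other words, the flag only changes the padding factor applied to the request; a minimal sketch (not the real `alloc()` code):

```cpp
// Minimal sketch of how a bool overallocate flag selects the look-ahead
// factor; in the PR the legacy pool applies this inside alloc().
#include <cstddef>

static size_t pool_padded_size(size_t request, bool overallocate) {
    // 1.05x: default look-ahead, absorbs small jitter between request sizes
    // 2.00x: opt-in, for buffers expected to keep growing (FA K_f16 / V_f16)
    const double factor = overallocate ? 2.00 : 1.05;
    return (size_t) (request * factor);
}
```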

I read @JohannesGaessler's design comment and followed the requested behavior, with two small implementation choices made differently. I used a bool overallocate flag instead of a float lookahead parameter, because the current requirements only need two modes: the default 1.05x behavior and an opt-in 2x overallocate path. If more tuning points are needed later, expanding this to a float parameter should still be straightforward. I also kept the timestamps in a parallel buffer_pool_ts[] array instead of adding a last_use field to ggml_cuda_buffer, to keep the LRU bookkeeping separate from the buffer struct. Happy to refactor either of these if a float lookahead API or an in-struct last_use is preferred.
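
To make that trade-off concrete, the two bookkeeping options look roughly like this (hypothetical, simplified):

```cpp
#include <cstddef>
#include <cstdint>

// Option taken in this PR (illustrative names): a parallel timestamp array,
// stamped with a monotonic counter every time a pool slot is used.
static const int MAX_BUFFERS = 256;
static uint64_t  buffer_pool_ts[MAX_BUFFERS];
static uint64_t  pool_op_counter = 0;

static void pool_touch(int slot) {
    buffer_pool_ts[slot] = ++pool_op_counter; // higher value = more recently used
}

// Alternative that was not taken: a last_use field inside the buffer struct.
struct cuda_buffer_with_last_use {
    void *   ptr      = nullptr;
    size_t   size     = 0;
    uint64_t last_use = 0;
};
```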

I tested two cases: (a) legacy-pool LRU reclaim, triggered via one-shot fault injection, and (b) FA K_f16 / V_f16 overallocate, triggered under quantized KV cache (-ctk q8_0 -ctv q8_0). Detailed logs are in the comment below.

@leonardHONG leonardHONG requested a review from a team as a code owner April 21, 2026 03:44
@github-actions github-actions bot added the Nvidia GPU (Issues specific to Nvidia GPUs) and ggml (changes relating to the ggml tensor library for machine learning) labels Apr 21, 2026
@JohannesGaessler
Contributor

I'm not convinced that this will actually be beneficial on average.

@IMbackK
Collaborator

IMbackK commented Apr 21, 2026

This might help with the slowdown reported in #22094 (comment) and should be tested. But yeah, this PR needs evidence.

@JohannesGaessler
Contributor

I mean, even without evidence I would expect clearing only the largest buffers to be prone to memory fragmentation.

@JohannesGaessler
Contributor

One variant of this PR that I would be willing to approve and merge would be something like this:

  • Every time that a buffer is used, add a timestamp.
  • If the memory allocation fails, clear the least recently used buffers first.
  • Stop if the freed memory is 3x the requested size.
  • Add an optional parameter that results in 2x overallocation for the legacy pool; enable it for the FA conversion, where the size can be expected to continually increase as the context fills up. This should reduce thrashing.
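
As a rough caller-side illustration of the last point (hypothetical names, not the actual llama.cpp FA code):

```cpp
#include <cuda_runtime.h>
#include <cstddef>

// Hypothetical stand-in for the pool's alloc() with an overallocate flag:
// pad the request by 2x when the flag is set, then allocate on the device.
static void * pool_alloc(size_t size, bool overallocate) {
    void * ptr = nullptr;
    cudaMalloc(&ptr, overallocate ? 2*size : size);
    return ptr;
}

// The FA K_f16 / V_f16 conversion buffers grow with the context, so they opt
// in to 2x overallocation: the next, larger request is then likely to still
// fit in the cached buffer instead of forcing a fresh device allocation.
static void alloc_kv_f16(size_t k_bytes, size_t v_bytes, void ** K, void ** V) {
    *K = pool_alloc(k_bytes, /*overallocate=*/true);
    *V = pool_alloc(v_bytes, /*overallocate=*/true);
}
```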

TheTom added a commit to TheTom/llama-cpp-turboquant that referenced this pull request Apr 21, 2026
Per JohannesGaessler's feedback on ggml-org#22193:
- Add LRU timestamps to pool buffers
- On OOM, evict least recently used first (up to 3x requested)
- Fall back to full flush if LRU eviction isn't enough
- Add alloc_oversize() with configurable factor
- FA temp buffers use 2x overalloc on HIP to reduce thrashing

3 files, pool-internal fix, no separate allocator.
@leonardHONG
Contributor Author

leonardHONG commented Apr 21, 2026

Thank you for the crystal clear roadmap, @JohannesGaessler!

@leonardHONG leonardHONG force-pushed the cuda-pool-partial-eviction branch from ef40865 to d3e3a69 Compare April 21, 2026 12:48
@leonardHONG
Contributor Author

1. Legacy pool / LRU reclaim

  • Environment: GGML_CUDA_NO_VMM=1 (force legacy pool)
  • Used one-shot fault injection to reliably trigger reclaim
  • Key log snippet:
CUDA pool[0]: injected one-shot reclaim test at 2.36 MiB
CUDA pool[0]: alloc of 2.36 MiB failed, trying LRU-style reclaim (target 7.09 MiB) from 0.94 MiB cached
CUDA pool[0]: retry succeeded after reclaiming 0.63 MiB

2. FA conversion / overallocate

  • Triggered by using quantized KV cache: -ctk q8_0 -ctv q8_0
  • Key log snippet (K_f16):
CUDA fattn K_f16 request 0.06 MiB actual 0.30 MiB
CUDA fattn K_f16 request 0.12 MiB actual 0.12 MiB
  • Key log snippet (V_f16):
CUDA fattn V_f16 request 0.06 MiB actual 0.12 MiB
CUDA fattn V_f16 request 0.12 MiB actual 0.28 MiB
CUDA fattn V_f16 request 0.25 MiB actual 0.30 MiB

@leonardHONG leonardHONG force-pushed the cuda-pool-partial-eviction branch from 6fc78c6 to 00afd1b Compare April 23, 2026 15:13
@leonardHONG
Contributor Author

leonardHONG commented Apr 23, 2026

Thanks @IMbackK for the root-cause work in ROCm/rocm-systems#4817, and cc @JohannesGaessler since this follows the roadmap.

For context, #22193 follows #22155 and was opened before #22207, implementing the same direction: LRU reclaim, 3x bound, and FA overallocate. I have now pushed a HIP multi-GPU guard as a follow-up, to preserve master-equivalent clear_pool() + retry behavior on multi-GPU AMD while keeping the LRU path for single-GPU. My local validation so far is on CUDA single-GPU (RTX 4070 Laptop); I have not been able to validate the AMD multi-GPU path directly.
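
The guard itself is roughly shaped like this (hypothetical helper names standing in for the existing pool functions; the exact macro and device-count check may differ from the actual code):

```cpp
#include <cstddef>

// Placeholder declarations for the pool's existing helpers (hypothetical
// signatures, shown only so the guard below reads in context).
void   clear_pool();
void   pool_reclaim_lru(size_t target_bytes);
void * retry_alloc(size_t size);

// On HIP builds with more than one device, skip the LRU reclaim and keep the
// master-equivalent full flush + retry; otherwise use the bounded eviction.
void * pool_alloc_oom_path(size_t size, int device_count) {
#if defined(GGML_USE_HIP)
    if (device_count > 1) {
        clear_pool();              // same safety net as master on multi-GPU AMD
        return retry_alloc(size);
    }
#endif
    pool_reclaim_lru(3*size);      // single-GPU (and CUDA): bounded LRU eviction
    return retry_alloc(size);
}
```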

Happy to coordinate as needed and leave the merge decision to the maintainers.

@sanmai

sanmai commented May 3, 2026

It sounds like the reason we are at this spot is that flash attention buffers grow exponentially, and unless we over-allocate them to double plus 1 MB (what I measured), the cache is useless. So it sounds like a better fix would be to give FA its own pool with a smaller depth.
