
ggml-cuda: flush legacy pool on OOM and retry #22155

Merged
IMbackK merged 2 commits into ggml-org:master from leonardHONG:cuda-pool-leg-oom-retry on Apr 20, 2026

Conversation

@leonardHONG
Contributor

This adds a conservative fallback for the legacy CUDA/HIP pool allocator.

On non-VMM setups, the legacy pool can end up holding cached free buffers that are individually too small for a new request, but still occupy enough VRAM to make the next allocation fail. In that case, this patch flushes the cached legacy-pool buffers and retries the allocation once before aborting.

The normal hit path is unchanged. This is intended as a narrow mitigation for legacy-pool OOMs, not a broader allocator redesign. I validated the retry path locally with a synthetic OOM injection on a legacy-pool build.

This is intended to mitigate the legacy-pool OOM behavior reported in #22075 and #22107.

Signed-off-by: 梁厚宏 <2695316095@qq.com>
@leonardHONG leonardHONG requested review from a team and IMbackK as code owners April 20, 2026 07:08
@JohannesGaessler
Contributor

This is not safe to do. If the memory is still in use it must not be freed.

@leonardHONG
Contributor Author

This is not safe to do. If the memory is still in use it must not be freed.

Thanks for the catch! I completely overlooked the async lifecycle issue here. I'll drop this unsafe logic and rework it tonight.

@IMbackK
Collaborator

IMbackK commented Apr 20, 2026

I... actually don't see it (how clear_pool can prune in-use chunks).

@JohannesGaessler
Contributor

The legacy CUDA buffer pool is essentially just a list of previously used temporary buffers. "Allocating" and "freeing" a buffer just means retrieving a buffer for use within a kernel. It does not mean that there no longer are any kernels queued that will attempt to use the buffer.

Since the buffer pool is per ggml_backend_cuda_context I think the current approach is safe if it is synchronized on the corresponding stream. Looking at the documentation for cudaMalloc and cudaFree again however, it seems those do an implicit device-wide synchronization when they're called. So even without any modifications I think this PR is safe after all.

Did you consider replacing the legacy buffer pool with cudaMallocAsync and cudaFreeAsync? The only reason we are not using those for CUDA is that the performance was worse vs. manually managed chunks of memory but I don't think I ever checked the performance for HIP.

@IMbackK
Collaborator

IMbackK commented Apr 20, 2026

Looking at the documentation for cudaMalloc and cudaFree again however, it seems those do an implicit device-wide synchronization when they're called. So even without any modifications I think this PR is safe after all.

Yeah, exactly. I was really confused there, thinking that this was perhaps not true for CUDA and only for HIP.

@IMbackK
Collaborator

IMbackK commented Apr 20, 2026

Did you consider replacing the legacy buffer pool with cudaMallocAsync and cudaFreeAsync? The only reason we are not using those for CUDA is that the performance was worse vs. manually managed chunks of memory but I don't think I ever checked the performance for HIP.

I did try this, it is slow. Hopefully they will fix the damn virtual memory support soon.

@gaugarg-nv
Contributor

gaugarg-nv commented Apr 20, 2026

I'd suggest if you go with this fix, don't rely on this behavior of cudaFree. It is better to add an explicit cudaDeviceSynchronize or consider using the async version of cudaMalloc and cudaFree as @JohannesGaessler suggested.

The CUDA documentation for cudaFree says: "For all other pointers, this API may perform implicit synchronization."
So, I won't rely on this.

Comment on lines +382 to +392
void clear_pool() {
    for (int i = 0; i < MAX_BUFFERS; ++i) {
        ggml_cuda_buffer & b = buffer_pool[i];
        if (b.ptr != nullptr) {
            CUDA_CHECK(cudaFree(b.ptr));
            pool_size -= b.size;
            b.ptr = nullptr;
            b.size = 0;
        }
    }
}
Contributor


To ensure consistency, please call clear_pool in the destructor.

Comment thread on ggml/src/ggml-cuda/ggml-cuda.cu (outdated)
    size_t look_ahead_size = (size_t) (1.05 * size);
    look_ahead_size = 256 * ((look_ahead_size + 255)/256);
    ggml_cuda_set_device(device);
#if defined(GGML_USE_MUSA)
Contributor


If at all possible, please just add the missing defines to the MUSA header if they're missing rather than to have per-vendor logic.

@IMbackK
Collaborator

IMbackK commented Apr 20, 2026

The CUDA documentation for cudaFree says: "For all other pointers, this API may perform implicit synchronization." So, I won't rely on this.

On HIP there is no "may", so you would be safe there; doing an explicit synchronization would not hurt, of course.

@leonardHONG
Contributor Author

Thanks everyone for the amazing clarifications! I really appreciate the guidance. I will apply all the suggested fixes (adding explicit sync, updating the destructor, and cleaning up the MUSA logic) and push an update tonight.

@JohannesGaessler
Contributor

The CUDA documentation for cudaFree says: "For all other pointers, this API may perform implicit synchronization."

Thank you for correcting me, it seems I suffered from confirmation bias and should have been more careful. The way I had remembered it the synchronization was guaranteed, that is what Gemini said when I asked "Does cudaFree imply a device synchronization?", and that is what I read when I quickly checked the CUDA documentation it linked. I agree that there should be an explicit synchronization to be safe.

@gaugarg-nv
Contributor

AFAIK, there is always a sync with cudaFree, but documentation doesn't guarantee it. So, it is better to make it explicit in the code. Here is the document I was referring to.

@github-actions github-actions Bot added Nvidia GPU Issues specific to Nvidia GPUs ggml changes relating to the ggml tensor library for machine learning labels Apr 20, 2026
…up MUSA macros

Signed-off-by: 梁厚宏 <2695316095@qq.com>
@leonardHONG
Contributor Author

Thanks for the feedback — I pushed an update.

This revision:

  • reuses clear_pool() in the destructor
  • makes clear_pool() self-contained with device selection
  • adds an explicit synchronization before clearing cached legacy-pool buffers on the OOM retry path
  • removes the MUSA-specific branch by adding the corresponding error alias in the vendor header

I also reran test-backend-ops locally and it passes on my side.

TheTom added a commit to TheTom/llama-cpp-turboquant that referenced this pull request Apr 20, 2026
On HIP without VMM, the legacy pool retains these buffers at peak size, causing quantized KV to OOM before f16. ggml_cuda_direct_alloc<T>
uses raw hipMalloc/hipFree instead. HIP-only, complements ggml-org#22155.

Fixes ggml-org#22107 without performance degradation.
Tested: gfx1100, gfx1200, gfx1201.
@IMbackK IMbackK merged commit 9789512 into ggml-org:master Apr 20, 2026
50 of 52 checks passed
ArberSephirotheca pushed a commit to ArberSephirotheca/llama.cpp that referenced this pull request Apr 21, 2026
* ggml-cuda: flush legacy pool on OOM and retry

Signed-off-by: 梁厚宏 <2695316095@qq.com>

* Address review comments: add explicit sync, update destructor, clean up MUSA macros

Signed-off-by: 梁厚宏 <2695316095@qq.com>

---------

Signed-off-by: 梁厚宏 <2695316095@qq.com>
arthw pushed a commit to arthw/llama.cpp that referenced this pull request Apr 23, 2026
* ggml-cuda: flush legacy pool on OOM and retry

Signed-off-by: 梁厚宏 <2695316095@qq.com>

* Address review comments: add explicit sync, update destructor, clean up MUSA macros

Signed-off-by: 梁厚宏 <2695316095@qq.com>

---------

Signed-off-by: 梁厚宏 <2695316095@qq.com>
rsenthilkumar6 pushed a commit to rsenthilkumar6/llama.cpp that referenced this pull request May 1, 2026
* ggml-cuda: flush legacy pool on OOM and retry

Signed-off-by: 梁厚宏 <2695316095@qq.com>

* Address review comments: add explicit sync, update destructor, clean up MUSA macros

Signed-off-by: 梁厚宏 <2695316095@qq.com>

---------

Signed-off-by: 梁厚宏 <2695316095@qq.com>
jimbothigpen pushed a commit to jimbothigpen/frankenturbo2 that referenced this pull request May 2, 2026
* ggml-cuda: flush legacy pool on OOM and retry

Signed-off-by: 梁厚宏 <2695316095@qq.com>

* Address review comments: add explicit sync, update destructor, clean up MUSA macros

Signed-off-by: 梁厚宏 <2695316095@qq.com>

---------

Signed-off-by: 梁厚宏 <2695316095@qq.com>