
cuda: LRU eviction + overalloc for legacy pool #22207

Open
TheTom wants to merge 2 commits into ggml-org:master from TheTom:experiment/pool-threshold-free

Conversation

@TheTom

@TheTom TheTom commented Apr 21, 2026

Fixes #22107. Per #22193 (comment).

On OOM, evict LRU buffers first. FA temps use 2x overalloc.
Tested on gfx1201, q8_0 @ d40000: 369 t/s (was OOM).
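
For reviewers skimming the description, a minimal sketch of the intended control flow (the helper name try_alloc and the exact call sites are illustrative, not the literal diff; evict_lru is the function added in ggml-cuda.cu below):

// sketch only: allocate from the legacy pool, evicting LRU buffers and retrying
// once if the device is out of memory; FA temp buffers request factor = 2.0
void * pool_alloc_with_eviction(size_t size, size_t * actual_size, double factor) {
    const size_t want = (size_t)(size * factor);
    void * ptr = try_alloc(want, actual_size);   // hypothetical wrapper around cudaMalloc
    if (ptr == nullptr) {
        evict_lru(want);                         // free least-recently-used pool buffers
        ptr = try_alloc(want, actual_size);      // retry after eviction
    }
    return ptr;
}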


@TheTom TheTom requested a review from a team as a code owner April 21, 2026 09:26
@ggml-gh-bot

ggml-gh-bot Bot commented Apr 21, 2026

Hi @TheTom, thanks for your contribution!

Per our contribution guidelines, the automated PR checker found the following issue(s) that need your attention:

  • Multiple open PRs from a new contributor: We limit new contributors (those without a previously merged PR) to 1 open PR at a time. You currently have 3 open PRs.

Please note that maintainers reserve the right to make final decisions on PRs. If you believe there is a mistake, please comment below.

@TheTom
Author

TheTom commented Apr 21, 2026

Hi @TheTom, thanks for your contribution!

Per our contribution guidelines, the automated PR checker found the following issue(s) that need your attention:

  • Multiple open PRs from a new contributor: We limit new contributors (those without a previously merged PR) to 1 open PR at a time. You currently have 3 open PRs.

Please note that maintainers reserve the right to make final decisions on PRs. If you believe there is a mistake, please comment below.

Copied from the previous PR for clarity:
One is a draft (#21119), one has been waiting on review for 2 weeks (#21452). This is a bug fix for an OOM affecting all HIP users with quantized KV at long context. Happy to prioritize however maintainers prefer.

Comment thread ggml/src/ggml-cuda/ggml-cuda.cu Outdated
Comment on lines +390 to +413
size_t evict_lru(size_t target) {
    size_t freed = 0;
    ggml_cuda_set_device(device);
    while (freed < target) {
        // find the least-recently-used buffer still held by the pool
        int oldest = -1;
        uint64_t oldest_ts = UINT64_MAX;
        for (int i = 0; i < MAX_BUFFERS; ++i) {
            if (buffer_pool[i].ptr != nullptr && buffer_pool[i].last_used < oldest_ts) {
                oldest_ts = buffer_pool[i].last_used;
                oldest = i;
            }
        }
        if (oldest < 0) {
            break; // pool is empty, nothing left to evict
        }
        // free the buffer and drop it from the pool bookkeeping
        ggml_cuda_buffer & b = buffer_pool[oldest];
        CUDA_CHECK(cudaFree(b.ptr));
        freed += b.size;
        pool_size -= b.size;
        b.ptr = nullptr;
        b.size = 0;
    }
    return freed;
}
Contributor


Inline this function.

Comment thread ggml/src/ggml-cuda/common.cuh Outdated
Comment on lines +1108 to +1112
virtual void * alloc(size_t size, size_t * actual_size) = 0;
// default implementation ignores the over-allocation factor and falls back to a plain alloc
virtual void * alloc_oversize(size_t size, size_t * actual_size, double factor) {
    GGML_UNUSED(factor);
    return alloc(size, actual_size);
}
Contributor


Instead of adding a new method, add new argument float lookahead = 1.05f to alloc.
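
For illustration, the signature being suggested would look roughly like this (a sketch of the proposal, not code from the PR):

// reviewer's proposal: fold the over-allocation factor into alloc() itself,
// with a small default lookahead instead of a separate alloc_oversize() method
virtual void * alloc(size_t size, size_t * actual_size, float lookahead = 1.05f) = 0;

// callers that need extra headroom (e.g. FA temp buffers) would pass a larger factor:
//     void * buf = pool.alloc(nbytes, &actual, 2.0f);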

Comment thread ggml/src/ggml-cuda/ggml-cuda.cu Outdated
Comment on lines +361 to +366
    uint64_t last_used = 0;
};

ggml_cuda_buffer buffer_pool[MAX_BUFFERS] = {};
size_t pool_size = 0;
uint64_t timestamp = 0;
Contributor


Suggested change
-    uint64_t last_used = 0;
-};
-ggml_cuda_buffer buffer_pool[MAX_BUFFERS] = {};
-size_t pool_size = 0;
-uint64_t timestamp = 0;
+    uint64_t last_use = 0;
+};
+ggml_cuda_buffer buffer_pool[MAX_BUFFERS] = {};
+size_t pool_size = 0;
+uint64_t usage_counter = 0;

What you implemented is not a timestamp, but it would also work. You should change the variable names though to avoid confusion.
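
In other words, the field being renamed records recency by bumping a monotonic counter on every pool hit, roughly like this (a sketch using the suggested names, with the counter passed explicitly for illustration):

// sketch: mark a pool buffer as most recently used
static void touch_buffer(ggml_cuda_buffer & b, uint64_t & usage_counter) {
    b.last_use = ++usage_counter;   // monotonic counter, not a wall-clock timestamp
}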

@TheTom TheTom force-pushed the experiment/pool-threshold-free branch from 6ac5e04 to 4e68d9e Compare April 21, 2026 12:22
@github-actions github-actions Bot added the Nvidia GPU (Issues specific to Nvidia GPUs) and ggml (changes relating to the ggml tensor library for machine learning) labels Apr 21, 2026
@TheTom
Author

TheTom commented Apr 21, 2026

Comments addressed, PTAL.

Collaborator

@IMbackK IMbackK left a comment


In testing I have found memory access faults caused by this PR. Investigating...

@TheTom
Author

TheTom commented Apr 22, 2026

Sounds good, let me know what you find. Happy to adjust if needed.

@TheTom
Author

TheTom commented Apr 23, 2026

Hey @IMbackK, were you able to repro or get a backtrace for me to investigate?

@IMbackK
Collaborator

IMbackK commented Apr 23, 2026

I have already traced the problem and it looks like this makes ROCm/rocm-systems#4817, i.e. failures in hipMemcpyAsync with valid source and destination parameters in multi-GPU scenarios, much more common.

There is not a whole lot we can do about this since our code is correct. At the same time, pretty much breaking multi-GPU HIP for non-fp16 fattn isn't really an option either.

@TheTom
Author

TheTom commented Apr 23, 2026

I have already traced the problem and it looks like this makes ROCm/rocm-systems#4817, i.e. failures in hipMemcpyAsync with valid source and destination parameters in multi-GPU scenarios, much more common.

There is not a whole lot we can do about this since our code is correct. At the same time, pretty much breaking multi-GPU HIP for non-fp16 fattn isn't really an option either.

Thanks for tracking this down. It looks like the root cause of ROCm/rocm-systems#4817 is a pre-existing hipMemcpyAsync host-mapping race on multi-GPU; the LRU eviction changes free/realloc timing, which exposes it more often, but the underlying bug is in the ROCm runtime, as you surmise.

Thoughts on how to unblock: I can gate the LRU path behind a single-GPU check on HIP and fall back to clear_pool() for multi-GPU. That way multi-GPU HIP keeps the same behavior it had before (no regression) and single-GPU HIP plus all CUDA users get the fix.

Want me to push that, or do you have a different approach in mind? I also have some AMD contacts I can ask if needed.

ROCm/rocm-systems#4817: LRU free/realloc cycles amplify a
hipMemcpyAsync host-mapping race on multi-GPU setups. Gate the
LRU path behind a single-GPU check on HIP and fall back to
clear_pool() for multi-GPU. Single-GPU HIP + all CUDA users
still get LRU eviction.
@TheTom TheTom force-pushed the experiment/pool-threshold-free branch from f68e40f to a2e25b0 Compare April 23, 2026 23:22
@TheTom TheTom requested a review from IMbackK April 23, 2026 23:23
@TheTom
Author

TheTom commented Apr 23, 2026

Added an ifdef for multi-GPU. Open to changes. Let me know.

@IMbackK
Collaborator

IMbackK commented Apr 26, 2026

I don't know if there is a point to doing this. As far as I know the CUDA devices will all use the VMM allocator anyhow; I'm not sure under what circumstances CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED is false on CUDA. Maybe @JohannesGaessler can comment on whether this is in any way a common case.

If this is not an encountered case then I would recommend just leaving this PR open until I or AMD figure out where exactly the race in ROCr/CLR is.

I get that it's also useful for the single-GPU HIP case, but I'm not sure this justifies having the hack in the code. There is also something to be said for not doing this sort of workaround and instead pushing AMD to fix their shit.

@TheTom
Author

TheTom commented Apr 29, 2026

I don't know if there is a point to doing this. As far as I know the CUDA devices will all use the VMM allocator anyhow; I'm not sure under what circumstances CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED is false on CUDA. Maybe @JohannesGaessler can comment on whether this is in any way a common case.

If this is not an encountered case then I would recommend just leaving this PR open until I or AMD figure out where exactly the race in ROCr/CLR is.

I get that it's also useful for the single-GPU HIP case, but I'm not sure this justifies having the hack in the code. There is also something to be said for not doing this sort of workaround and instead pushing AMD to fix their shit.

@JohannesGaessler @IMbackK looking for guidance here.

Single-GPU HIP users hit both OOM and a prefill slowdown with q8_0 KV at long context, confirmed across gfx1100/1200/1201/906 by multiple community testers.

I don't have multi-GPU AMD hardware to verify or fix 4817. The current PR is gated to single-GPU HIP only; the multi-GPU path is unchanged. Happy to ping my AMD contacts on 4817 in parallel to push their side, but the user-facing fix shouldn't have to wait on ROCm's queue.

If the ifdef is the blocker, I'm happy to swap to a runtime device-count check. What shape would you accept here?
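
For concreteness, the runtime variant could be as small as the following (a sketch only; GGML_USE_HIP and CUDA_CHECK come from the existing ggml-cuda code, the helper name is made up):

// hypothetical runtime gate: allow LRU eviction everywhere except multi-GPU HIP,
// where it makes the hipMemcpyAsync race (ROCm/rocm-systems#4817) far more likely
static bool ggml_cuda_pool_lru_enabled() {
#ifdef GGML_USE_HIP
    int device_count = 0;
    CUDA_CHECK(cudaGetDeviceCount(&device_count));  // hipGetDeviceCount under HIP
    return device_count <= 1;
#else
    return true;  // CUDA builds are unaffected
#endif
}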


Labels

ggml (changes relating to the ggml tensor library for machine learning), Nvidia GPU (Issues specific to Nvidia GPUs)

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Misc. bug: [CUDA/ROCm] VRAM leak/fragmentation in ggml_cuda_pool_leg when using Flash Attention

3 participants