From 1ab3ca45ff1bf124fd65acd9da9eea2c231279b1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=A2=81=E5=8E=9A=E5=AE=8F?= <2695316095@qq.com> Date: Mon, 20 Apr 2026 12:35:21 +0800 Subject: [PATCH 1/2] ggml-cuda: flush legacy pool on OOM and retry MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 梁厚宏 <2695316095@qq.com> --- ggml/src/ggml-cuda/ggml-cuda.cu | 29 +++++++++++++++++++++++++++++ ggml/src/ggml-cuda/vendors/hip.h | 1 + 2 files changed, 30 insertions(+) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index de579d2ed50..353e7c81231 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -379,6 +379,18 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool { GGML_ASSERT(pool_size == 0); } + void clear_pool() { + for (int i = 0; i < MAX_BUFFERS; ++i) { + ggml_cuda_buffer & b = buffer_pool[i]; + if (b.ptr != nullptr) { + CUDA_CHECK(cudaFree(b.ptr)); + pool_size -= b.size; + b.ptr = nullptr; + b.size = 0; + } + } + } + void * alloc(size_t size, size_t * actual_size) override { #ifdef DEBUG_CUDA_MALLOC int nnz = 0; @@ -421,7 +433,24 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool { size_t look_ahead_size = (size_t) (1.05 * size); look_ahead_size = 256 * ((look_ahead_size + 255)/256); ggml_cuda_set_device(device); +#if defined(GGML_USE_MUSA) CUDA_CHECK(ggml_cuda_device_malloc(&ptr, look_ahead_size, device)); +#else + cudaError_t err = ggml_cuda_device_malloc(&ptr, look_ahead_size, device); + if (err == cudaErrorMemoryAllocation) { + // only invoked from alloc() after ggml_cuda_set_device. + (void)cudaGetLastError(); + const size_t cached_bytes = pool_size; + GGML_LOG_DEBUG(GGML_CUDA_NAME " pool[%d]: alloc of %.2f MiB failed, flushing %.2f MiB of cached buffers and retrying\n", + device, look_ahead_size/1024.0/1024.0, cached_bytes/1024.0/1024.0); + clear_pool(); + err = ggml_cuda_device_malloc(&ptr, look_ahead_size, device); + if (err == cudaSuccess) { + GGML_LOG_DEBUG(GGML_CUDA_NAME " pool[%d]: retry succeeded\n", device); + } + } + CUDA_CHECK(err); +#endif *actual_size = look_ahead_size; pool_size += look_ahead_size; #ifdef DEBUG_CUDA_MALLOC diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h index 898fec31e36..0c030e173ca 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h @@ -59,6 +59,7 @@ #define cudaDeviceProp hipDeviceProp_t #define cudaDeviceSynchronize hipDeviceSynchronize #define cudaError_t hipError_t +#define cudaErrorMemoryAllocation hipErrorOutOfMemory #define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled #define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled #define cudaEventCreateWithFlags hipEventCreateWithFlags From aa1c01e7c6fc1ade1af796a4e578e10e1ecf2e06 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=A2=81=E5=8E=9A=E5=AE=8F?= <2695316095@qq.com> Date: Mon, 20 Apr 2026 21:18:29 +0800 Subject: [PATCH 2/2] Address review comments: add explicit sync, update destructor, clean up MUSA macros MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 梁厚宏 <2695316095@qq.com> --- ggml/src/ggml-cuda/ggml-cuda.cu | 16 +++------------- ggml/src/ggml-cuda/vendors/musa.h | 1 + 2 files changed, 4 insertions(+), 13 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 353e7c81231..31759b1676a 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -368,18 +368,12 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool { } ~ggml_cuda_pool_leg() { - ggml_cuda_set_device(device); - for (int i = 0; i < MAX_BUFFERS; ++i) { - ggml_cuda_buffer & b = buffer_pool[i]; - if (b.ptr != nullptr) { - CUDA_CHECK(cudaFree(b.ptr)); - pool_size -= b.size; - } - } + clear_pool(); GGML_ASSERT(pool_size == 0); } void clear_pool() { + ggml_cuda_set_device(device); for (int i = 0; i < MAX_BUFFERS; ++i) { ggml_cuda_buffer & b = buffer_pool[i]; if (b.ptr != nullptr) { @@ -433,16 +427,13 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool { size_t look_ahead_size = (size_t) (1.05 * size); look_ahead_size = 256 * ((look_ahead_size + 255)/256); ggml_cuda_set_device(device); -#if defined(GGML_USE_MUSA) - CUDA_CHECK(ggml_cuda_device_malloc(&ptr, look_ahead_size, device)); -#else cudaError_t err = ggml_cuda_device_malloc(&ptr, look_ahead_size, device); if (err == cudaErrorMemoryAllocation) { - // only invoked from alloc() after ggml_cuda_set_device. (void)cudaGetLastError(); const size_t cached_bytes = pool_size; GGML_LOG_DEBUG(GGML_CUDA_NAME " pool[%d]: alloc of %.2f MiB failed, flushing %.2f MiB of cached buffers and retrying\n", device, look_ahead_size/1024.0/1024.0, cached_bytes/1024.0/1024.0); + CUDA_CHECK(cudaDeviceSynchronize()); clear_pool(); err = ggml_cuda_device_malloc(&ptr, look_ahead_size, device); if (err == cudaSuccess) { @@ -450,7 +441,6 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool { } } CUDA_CHECK(err); -#endif *actual_size = look_ahead_size; pool_size += look_ahead_size; #ifdef DEBUG_CUDA_MALLOC diff --git a/ggml/src/ggml-cuda/vendors/musa.h b/ggml/src/ggml-cuda/vendors/musa.h index 1abb8acfd4b..8aa056e9174 100644 --- a/ggml/src/ggml-cuda/vendors/musa.h +++ b/ggml/src/ggml-cuda/vendors/musa.h @@ -42,6 +42,7 @@ #define cudaDeviceProp musaDeviceProp #define cudaDeviceSynchronize musaDeviceSynchronize #define cudaError_t musaError_t +#define cudaErrorMemoryAllocation musaErrorMemoryAllocation #define cudaErrorPeerAccessAlreadyEnabled musaErrorPeerAccessAlreadyEnabled #define cudaErrorPeerAccessNotEnabled musaErrorPeerAccessNotEnabled #define cudaEventCreateWithFlags musaEventCreateWithFlags