Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions ggml/src/ggml-cuda/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1105,7 +1105,7 @@ int ggml_cuda_get_device();
// Abstract base class for device memory pools (implemented by the legacy
// buffer pool and the VMM pool). The diff residue left both the old and new
// `alloc` declarations in place, which is ambiguous C++; only the post-image
// signature with the `overallocate` hint is kept here.
struct ggml_cuda_pool {
    virtual ~ggml_cuda_pool() = default;

    // Allocate at least `size` bytes; writes the number of bytes actually
    // reserved to *actual_size. `overallocate` is a hint that the caller may
    // soon request a larger buffer, so the pool may reserve extra headroom
    // (implementations are free to ignore it).
    virtual void * alloc(size_t size, size_t * actual_size, bool overallocate = false) = 0;
    virtual void free(void * ptr, size_t size) = 0;
};

Expand All @@ -1131,16 +1131,16 @@ struct ggml_cuda_pool_alloc {
}

// size is in number of elements
// Allocates `size * sizeof(T)` bytes from the previously-bound pool and
// records the pool's actual reservation in this->actual_size.
// `overallocate` is forwarded to the pool as a headroom hint.
// Preconditions: a pool has been bound and no prior allocation is held
// (both enforced by GGML_ASSERT). The duplicated pre-image diff lines
// (old signature and old pool->alloc call) are removed.
T * alloc(size_t size, bool overallocate = false) {
    GGML_ASSERT(pool != nullptr);
    GGML_ASSERT(ptr == nullptr);
    ptr = (T *) pool->alloc(size * sizeof(T), &this->actual_size, overallocate);
    return ptr;
}

// Binds this helper to `pool`, then allocates `size` elements from it
// (see the single-argument overload). `overallocate` is forwarded as a
// headroom hint. The duplicated pre-image diff lines (old signature and
// old forwarding call) are removed; only the post-image remains.
T * alloc(ggml_cuda_pool & pool, size_t size, bool overallocate = false) {
    this->pool = &pool;
    return alloc(size, overallocate);
}

T * get() {
Expand Down
4 changes: 2 additions & 2 deletions ggml/src/ggml-cuda/fattn-common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -966,7 +966,7 @@ void launch_fattn(
const size_t bs = ggml_blck_size(K->type);
const size_t ts = ggml_type_size(K->type);

K_f16.alloc(ggml_nelements(K));
K_f16.alloc(ggml_nelements(K), /*overallocate=*/ true);
if (ggml_is_contiguously_allocated(K)) {
to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type);
to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream);
Expand Down Expand Up @@ -999,7 +999,7 @@ void launch_fattn(
const size_t bs = ggml_blck_size(V->type);
const size_t ts = ggml_type_size(V->type);

V_f16.alloc(ggml_nelements(V));
V_f16.alloc(ggml_nelements(V), /*overallocate=*/ true);
if (ggml_is_contiguously_allocated(V)) {
to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream);
Expand Down
93 changes: 81 additions & 12 deletions ggml/src/ggml-cuda/ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,10 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
ggml_cuda_buffer buffer_pool[MAX_BUFFERS] = {};
size_t pool_size = 0;

// per-slot timestamp stamped on free(); older = longer uncollected, used for LRU-style reclaim on OOM
uint64_t clock = 0;
uint64_t buffer_pool_ts[MAX_BUFFERS] = {};

explicit ggml_cuda_pool_leg(int device) :
device(device) {
}
Expand All @@ -381,11 +385,12 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
pool_size -= b.size;
b.ptr = nullptr;
b.size = 0;
buffer_pool_ts[i] = 0;
}
}
}

void * alloc(size_t size, size_t * actual_size) override {
void * alloc(size_t size, size_t * actual_size, bool overallocate) override {
#ifdef DEBUG_CUDA_MALLOC
int nnz = 0;
size_t max_size = 0;
Expand All @@ -409,6 +414,7 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
*actual_size = b.size;
b.ptr = nullptr;
b.size = 0;
buffer_pool_ts[i] = 0;
return ptr;
}
}
Expand All @@ -421,23 +427,84 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
*actual_size = b.size;
b.ptr = nullptr;
b.size = 0;
buffer_pool_ts[ibest] = 0;
return ptr;
}
void * ptr;
size_t look_ahead_size = (size_t) (1.05 * size);
look_ahead_size = 256 * ((look_ahead_size + 255)/256);
size_t look_ahead_size;
if (overallocate) {
look_ahead_size = (size > SIZE_MAX / 2) ? SIZE_MAX : 2 * size;
} else {
look_ahead_size = (size_t) (1.05 * size);
}
// 256-byte align, overflow-safe
look_ahead_size = (look_ahead_size > SIZE_MAX - 255)
? (SIZE_MAX & ~size_t(255))
: 256 * ((look_ahead_size + 255)/256);
ggml_cuda_set_device(device);
cudaError_t err = ggml_cuda_device_malloc(&ptr, look_ahead_size, device);
if (err == cudaErrorMemoryAllocation) {
if (err == cudaErrorMemoryAllocation && pool_size > 0) {
(void)cudaGetLastError();
const size_t cached_bytes = pool_size;
GGML_LOG_DEBUG(GGML_CUDA_NAME " pool[%d]: alloc of %.2f MiB failed, flushing %.2f MiB of cached buffers and retrying\n",
device, look_ahead_size/1024.0/1024.0, cached_bytes/1024.0/1024.0);
CUDA_CHECK(cudaDeviceSynchronize());
clear_pool();
err = ggml_cuda_device_malloc(&ptr, look_ahead_size, device);
if (err == cudaSuccess) {
GGML_LOG_DEBUG(GGML_CUDA_NAME " pool[%d]: retry succeeded\n", device);

#if defined(GGML_USE_HIP)
// HIP multi-GPU: LRU timing amplifies ROCm/rocm-systems#4817; fall back to clear_pool.
if (ggml_backend_cuda_get_device_count() > 1) {
GGML_LOG_DEBUG(GGML_CUDA_NAME " pool[%d]: HIP multi-GPU, flushing %.2f MiB of cached buffers and retrying\n",
device, pool_size/1024.0/1024.0);
clear_pool();
err = ggml_cuda_device_malloc(&ptr, look_ahead_size, device);
if (err == cudaSuccess) {
GGML_LOG_DEBUG(GGML_CUDA_NAME " pool[%d]: retry succeeded\n", device);
}
} else
#endif
{
const size_t cached_bytes = pool_size;
const size_t reclaim_target = (look_ahead_size > SIZE_MAX / 3)
? SIZE_MAX
: 3 * look_ahead_size;
GGML_LOG_DEBUG(GGML_CUDA_NAME " pool[%d]: alloc of %.2f MiB failed, trying LRU-style reclaim (target %.2f MiB) from %.2f MiB cached\n",
device, look_ahead_size/1024.0/1024.0, reclaim_target/1024.0/1024.0, cached_bytes/1024.0/1024.0);

size_t freed = 0;
while (freed < reclaim_target) {
int victim = -1;
uint64_t oldest = UINT64_MAX;
for (int i = 0; i < MAX_BUFFERS; ++i) {
ggml_cuda_buffer & b = buffer_pool[i];
if (b.ptr != nullptr && buffer_pool_ts[i] < oldest) {
victim = i;
oldest = buffer_pool_ts[i];
}
}
if (victim < 0) {
break;
}
ggml_cuda_buffer & b = buffer_pool[victim];
CUDA_CHECK(cudaFree(b.ptr));
freed += b.size;
pool_size -= b.size;
b.ptr = nullptr;
b.size = 0;
buffer_pool_ts[victim] = 0;
}

err = ggml_cuda_device_malloc(&ptr, look_ahead_size, device);

// terminal fallback if the 3x-bounded reclaim left cached buffers behind
if (err == cudaErrorMemoryAllocation && pool_size > 0) {
(void)cudaGetLastError();
GGML_LOG_DEBUG(GGML_CUDA_NAME " pool[%d]: reclaim bounded at %.2f MiB, flushing remaining %.2f MiB cached\n",
device, freed/1024.0/1024.0, pool_size/1024.0/1024.0);
clear_pool();
err = ggml_cuda_device_malloc(&ptr, look_ahead_size, device);
}

if (err == cudaSuccess) {
GGML_LOG_DEBUG(GGML_CUDA_NAME " pool[%d]: retry succeeded after reclaiming %.2f MiB\n",
device, freed/1024.0/1024.0);
}
}
}
CUDA_CHECK(err);
Expand All @@ -456,6 +523,7 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
if (b.ptr == nullptr) {
b.ptr = ptr;
b.size = size;
buffer_pool_ts[i] = ++clock;
return;
}
}
Expand Down Expand Up @@ -499,7 +567,8 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
}
}

void * alloc(size_t size, size_t * actual_size) override {
void * alloc(size_t size, size_t * actual_size, [[maybe_unused]] bool overallocate) override {
// overallocate hint is ignored: VMM growth is already granularity-based
// round up the allocation size to the alignment to ensure that all allocations are aligned for all data types
const size_t alignment = 128;
size = alignment * ((size + alignment - 1) / alignment);
Expand Down