Prerequisites

- [x] I searched using keywords relevant to my issue to make sure that I am creating a new issue that is not already open (or closed).
- [x] I reviewed the Discussions, and have a new and useful enhancement to share.
Feature Description
Summary
The current KV cache state restore implementation (e.g., for prompt cache) performs one ggml_backend_tensor_set() call per cell/token, which incurs significant PCIe transfer latency overhead when the slot allocation is fragmented. This issue proposes an optimization that reorders data in RAM and batches contiguous VRAM writes to reduce the number of PCIe transfers.
Motivation
Problem Description
When restoring a KV cache slot from the prompt cache (see llama_kv_cache::state_read_data() in src/llama-kv-cache.cpp), the current implementation handles fragmented slot allocation with a "slow path":
```cpp
// Current implementation (llama-kv-cache.cpp:2045-2052)
for (uint32_t i = 0; i < cell_count; ++i) {
    const size_t dst_offset = sinfo.idxs[0][i] * k_size_row;
    ggml_backend_tensor_set(k,
        (const char *) src + i * k_size_row, // RAM source
        dst_offset,                          // VRAM destination
        k_size_row);                         // single-token row size
}
```
Performance bottleneck:

- Each `ggml_backend_tensor_set()` call incurs PCIe transfer latency (~5-10 μs per call)
- Plus GPU kernel launch overhead (~1-5 μs per call)
- For 100 fragmented tokens: 100 independent transfers at ~6-15 μs each → ~600-1500 μs of total overhead
- PCIe bandwidth is nowhere near saturated, because the transfers are latency-dominated rather than bandwidth-dominated
This is especially problematic for:

- Prompt cache restore in `llama-server` (multi-turn conversations)
- Fragmented KV cache scenarios (common in long-running servers)
Proposed Optimization
The key insight is that a RAM `memcpy` is orders of magnitude faster than a PCIe transfer. We can:

1. Analyze the allocated cell indices and identify contiguous blocks
2. Reorder the data in RAM to follow the block order (fast `memcpy`)
3. Batch-write each contiguous block to VRAM (fewer PCIe transfers)
Example:

```
Allocated cells: [10, 15, 16, 23, 24, 25, 45, 46]  (8 tokens)

Current approach:   8 independent VRAM writes
Optimized approach:
  - Identify contiguous blocks: [10], [15,16], [23,24,25], [45,46]
  - Reorder RAM data to match block order
  - Execute 4 batched VRAM writes (50% reduction)
```
Possible Implementation
Technical Feasibility
Implementation location: `src/llama-kv-cache.cpp`, function `llama_kv_cache::state_read_data()`
Key steps:

1. Sort the cell indices to identify contiguous blocks (O(N log N), typically N < 1000)
2. Allocate a temporary RAM buffer for reordering (reused across calls)
3. Reorder the data in RAM using `memcpy` (fast, ~4 μs for 100 KB)
4. Execute one batched `ggml_backend_tensor_set()` per contiguous block
No changes are required to:

- the `ggml_backend` API
- backend implementations (CUDA, Vulkan, CPU, etc.)
- the KV cache data structures
Related Work
- #8915 ("Bug: KV cache load/save is slow"): identified PCIe latency as the bottleneck, but focused on saving/loading the KV cache to and from files
Implementation Sketch
Question for Maintainers
- Does this interact with `ggml_backend_sched`?
- Should this optimization live at the `llama_kv_cache` level or at the `ggml_backend` level (as a general tensor copy optimization)?