Feature Request: [Optimization] Batch contiguous KV cache restore to reduce PCIe transfer overhead #20854

Description

@UltraRabbit

Prerequisites

  • I am running the latest code. Mention the version if possible as well.
  • I carefully followed the README.md.
  • I searched using keywords relevant to my issue to make sure that I am creating a new issue that is not already open (or closed).
  • I reviewed the Discussions, and have a new and useful enhancement to share.

Feature Description

Summary

The current KV cache state restore implementation (e.g., for prompt cache) performs one ggml_backend_tensor_set() call per cell/token, which incurs significant PCIe transfer latency overhead when the slot allocation is fragmented. This issue proposes an optimization that reorders data in RAM and batches contiguous VRAM writes to reduce the number of PCIe transfers.

Motivation

Problem Description

When restoring a KV cache slot from the prompt cache (see llama_kv_cache::state_read_data() in src/llama-kv-cache.cpp), the current implementation handles fragmented slot allocation with a "slow path":

// Current implementation (llama-kv-cache.cpp:2045-2052)
for (uint32_t i = 0; i < cell_count; ++i) {
    const size_t dst_offset = sinfo.idxs[0][i] * k_size_row;
    ggml_backend_tensor_set(k,
        (const char*)src + i * k_size_row,  // RAM source
        dst_offset,                          // VRAM destination
        k_size_row);                         // Single token size
}

Performance bottleneck:

  • Each ggml_backend_tensor_set() incurs PCIe transfer latency (~5-10 μs per call)
  • Plus GPU kernel launch overhead (~1-5 μs per call)
  • For 100 fragmented tokens: 100 independent transfers → ~600-1500 μs total overhead
  • PCIe bandwidth is nowhere near saturated due to latency-dominated transfers

This is especially problematic for:

  • Prompt cache restore in llama-server (multi-turn conversations)
  • KV cache session load (issue #8915, "Bug: KV cache load/save is slow", reported similar overhead for save/load)
  • Fragmented KV cache scenarios (common in long-running servers)

Proposed Optimization

The key insight is that RAM memcpy is orders of magnitude faster than PCIe transfers. We can:

  1. Analyze the allocated cell indices and identify contiguous blocks
  2. Reorder data in RAM according to contiguous blocks (fast memcpy)
  3. Batch write each contiguous block to VRAM (fewer PCIe transfers)

Example:

Allocated cells: [10, 15, 16, 23, 24, 25, 45, 46]  (8 tokens)

Current approach: 8 independent VRAM writes

Optimized approach:
  - Identify contiguous blocks: [10], [15,16], [23,24,25], [45,46]
  - Reorder RAM data to match block order
  - Execute 4 batched VRAM writes (50% reduction)

Possible Implementation

Technical Feasibility

Implementation location: src/llama-kv-cache.cpp, function llama_kv_cache::state_read_data()

Key steps:

  1. Sort cell indices to identify contiguous blocks (O(N log N), N < 1000 typically)
  2. Allocate a temporary RAM buffer for reordering, reused across calls (see the sketch after this list)
  3. Reorder data in RAM using memcpy (fast, ~4 μs for 100 KB)
  4. Execute batched ggml_backend_tensor_set() per contiguous block

No changes required to:

  • ggml_backend API
  • Backend implementations (CUDA, Vulkan, CPU, etc.)
  • KV cache data structures

Related Work

  • Issue #8915 (Bug: KV cache load/save is slow) reports similar per-call overhead during session save/load

Implementation Sketch

} else {
    // Slow path: scatter to non-contiguous positions
    // OPTIMIZATION: sort indices and merge contiguous blocks

    const void * src = io.read(cell_count * k_size_row);

    // Step 1: Create (index, original_position) pairs and sort
    std::vector<std::pair<uint32_t, uint32_t>> indexed_idxs(cell_count);
    for (uint32_t i = 0; i < cell_count; ++i) {
        indexed_idxs[i] = {sinfo.idxs[0][i], i};
    }
    std::sort(indexed_idxs.begin(), indexed_idxs.end(),
              [](const auto& a, const auto& b) { return a.first < b.first; });

    // Step 2: Scan the sorted indices and merge runs of consecutive cells
    struct Block {
        uint32_t start_idx;    // Position in sorted array (source offset in temp_buf)
        uint32_t cell_start;   // Starting cell index (destination offset in tensor)
        uint32_t n_cells;      // Number of contiguous cells
    };
    std::vector<Block> blocks;
    for (uint32_t i = 0; i < cell_count; ++i) {
        if (!blocks.empty() &&
            indexed_idxs[i].first == blocks.back().cell_start + blocks.back().n_cells) {
            blocks.back().n_cells++; // extends the current contiguous run
        } else {
            blocks.push_back({i, indexed_idxs[i].first, 1}); // starts a new run
        }
    }

    // Step 3: Allocate temporary buffer and reorder data in RAM
    std::vector<uint8_t> temp_buf(cell_count * k_size_row);
    for (uint32_t i = 0; i < cell_count; ++i) {
        const uint32_t orig_pos = indexed_idxs[i].second;
        memcpy(temp_buf.data() + i * k_size_row,
               (const char*)src + orig_pos * k_size_row,
               k_size_row);
    }

    // Step 4: Batch write to VRAM per contiguous block
    for (const auto& block : blocks) {
        const size_t src_offset = block.start_idx * k_size_row;
        const size_t dst_offset = block.cell_start * k_size_row;
        const size_t copy_size = block.n_cells * k_size_row;

        ggml_backend_tensor_set(k,
                                temp_buf.data() + src_offset,
                                dst_offset,
                                copy_size);
    }
}

Questions for Maintainers

  1. Is this optimization worth pursuing given the upcoming changes to ggml_backend_sched?
  2. Should this be implemented at the llama_kv_cache level or at the ggml_backend level (as a general tensor copy optimization)?
  3. Any concerns about the temporary RAM allocation for reordering?
