
Conversation

@Jeff-Huang
Contributor

Motivation

Introduces support for a vectorized KV cache memory layout (e.g., [num_blocks, num_kv_heads, head_size/8, block_size, 8]) to improve memory access efficiency, and adds support for different block table formats such as those used by vLLM and SGLang.

Technical Details

Key changes:

  • KV Cache Layout Optimization and Adjustment:

    • The KV cache memory layout has been adjusted to support vectorized read patterns (vectorized KV layout).
    • Support for multiple layout formats has been implemented, such as [num_blocks, num_kv_heads, head_size/8, block_size, 8] and other structures; a sketch of this layout and the block table formats follows this list.
  • vLLM Block Table Integration:

    • Added support for vLLM block table integration ([num_batch, max_blocks_per_seq]).
    • Added support for SGLang block table integration ([num_blocks]).
    • Support for page size 1024.
  • Kernel Interface Updates:

    • New parameters for the block table and KV cache layout.
  • Structure and Traits Updates:

    • Adapted to changes in the fmha_fwd_batch_prefill_traits structure.
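
To make the shapes above concrete, here is a minimal sketch (PyTorch, not code from this PR) of the vectorized KV cache layout and the two block table formats; the tensor shapes mirror the description above, while the concrete sizes and variable names are illustrative assumptions.

```python
import torch

# Illustrative sizes (assumptions); only the shapes mirror the PR description.
num_blocks, num_kv_heads, head_size, block_size = 64, 8, 128, 1024  # page size 1024
dtype = torch.float16
x = 16 // torch.tensor([], dtype=dtype).element_size()  # 16-byte (dwordx4) vector width -> 8 for fp16

# Vectorized 5D KV cache layout: [num_blocks, num_kv_heads, head_size/8, block_size, 8]
k_cache = torch.empty(num_blocks, num_kv_heads, head_size // x, block_size, x, dtype=dtype)
v_cache = torch.empty_like(k_cache)

# vLLM-style block table: [num_batch, max_blocks_per_seq], one row of block indices per sequence.
num_batch, max_blocks_per_seq = 4, 16
vllm_block_table = torch.randint(0, num_blocks, (num_batch, max_blocks_per_seq), dtype=torch.int32)

# SGLang-style page table: a flat 1D tensor of block indices ([num_blocks] in the PR description).
sglang_page_table = torch.randint(0, num_blocks, (num_batch * max_blocks_per_seq,), dtype=torch.int32)
```

With this layout, reading the 8 contiguous fp16 elements along the last dimension maps to a single 16-byte vector load, which is the access pattern the layout is optimized for.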

Test Plan

Test Result

Submission Checklist

ltqin and others added 12 commits December 30, 2025 09:09
…/8, block_size, 8], [num_blocks, num_kv_heads, block_size/8, head_size, 8]
…ayout

Updated `mha_batch_prefill` API and tests to support vLLM-style block tables alongside SGLang-style page tables, while enforcing the new hardware-optimized 5D vectorized KV cache layout.

**Key Changes:**
*   **API**: Added `block_table` and `seqlen_k` arguments to the Python/C++ interfaces.
*   **Layout Enforcement**: Added strict checks for the 5D vectorized KV layout (swizzled x=8) in host bindings and Python wrappers; a sketch of such a check follows this list.
*   **CodeGen**: Automatically select `VLLM_BLOCK_TABLE_2D` or `SGLANG_PAGE_TABLE_1D` trait based on input arguments.
*   **Tests**: Added `test_batch_prefill_vllm` to verify block table correctness and updated existing tests to use the vectorized layout.
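
As an illustration of the layout enforcement, the sketch below shows the kind of strict check a Python wrapper might perform on the 5D vectorized cache; the helper name and exact assertions are assumptions, not the PR's actual code.

```python
import torch

def check_vectorized_kv_layout(kv_cache: torch.Tensor, head_size: int) -> None:
    """Hypothetical check mirroring the 'swizzled x=8' layout enforcement described above."""
    x = 16 // kv_cache.element_size()  # 16-byte (dwordx4) vector width: 8 for fp16/bf16, 4 for fp32
    if kv_cache.dim() != 5:
        raise ValueError("expected a 5D vectorized KV cache "
                         "[num_blocks, num_kv_heads, head_size/x, block_size, x]")
    if kv_cache.size(-1) != x:
        raise ValueError(f"last dimension must equal the vector width x={x}")
    if kv_cache.size(2) * x != head_size:
        raise ValueError("dimension 2 must equal head_size / x")
```

Per the CodeGen change, the wrapper then selects `VLLM_BLOCK_TABLE_2D` when it is handed a 2D `block_table` and `SGLANG_PAGE_TABLE_1D` when it is handed a 1D page table.
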
@Jeff-Huang Jeff-Huang requested a review from a team December 30, 2025 05:36
```python
if head_size_v_og % 8 != 0:
    v = torch.nn.functional.pad(v, [0, 8 - head_size_v_og % 8])
head_size_q_og = q.size(-1)
k_vector_size = 16 // k.element_size()
```

Suggest adding a comment explaining that the magic number 16 corresponds to dwordx4 (a 16-byte vectorized load).
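
For context, a small self-contained sketch of what such a comment would capture: 16 bytes is the size of one dwordx4 (4 × 32-bit) load, so dividing by the element size gives the number of elements moved per vectorized load (the names below are illustrative, not from the PR).

```python
import torch

DWORDX4_BYTES = 16  # one dwordx4 load = 4 x 32-bit dwords = 16 bytes

for dtype in (torch.float16, torch.bfloat16, torch.float32):
    elem_bytes = torch.tensor([], dtype=dtype).element_size()
    k_vector_size = DWORDX4_BYTES // elem_bytes
    print(f"{dtype}: k_vector_size = {k_vector_size}")  # 8 for fp16/bf16, 4 for fp32
```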
