
Support nonpad kv seqlen within opset 24 Attention (CPU) #27384

Merged
titaiwangms merged 6 commits into main from titaiwang/support_nonpad_kv_seqlen on Feb 23, 2026

Conversation

Contributor

@titaiwangms titaiwangms commented Feb 18, 2026

  • Support opset 24 Attention - nonpad_kv_seqlen and add tests
  • Refactor the logic to improve performance and maintainability (GEMM)
  • Refactor fp16 fallback GEMM branch: Upcast -> GemmEx -> Downcast.

This pull request refactors and enhances the ONNX Runtime CPU attention operator with a focus on improved GEMM (matrix multiplication) handling for both float32 and MLFloat16 types, and adds support for the new nonpad_kv_seqlen input (Opset 24+) to enable more flexible masking. The changes simplify code paths, optimize performance (especially for MLFloat16), and improve maintainability.

Key changes include:

1. Unified and Optimized GEMM Handling

  • Introduces a new templated AttentionGemm function that dispatches GEMM operations for both float and MLFloat16, handling hardware capabilities and providing efficient fallbacks (including upcasting/downcasting for MLFloat16 when necessary). This replaces multiple scattered and duplicated GEMM code paths throughout the attention implementation.

2. Support for nonpad_kv_seqlen (Opset 24+)

  • Adds handling for the optional nonpad_kv_seqlen input: validates its shape and values, ensures it is not used with past key/value, and applies per-batch masking to attention scores based on the valid key/value sequence length.

3. Code Clean-up and Maintainability

  • Removes duplicated and complex branching logic for GEMM and MatMul operations, consolidating them into the new AttentionGemm helper for both QK and QKV multiplications. This reduces code complexity and the risk of subtle bugs.

4. Minor Includes and Utility Updates

  • Adds necessary includes for <algorithm> and <vector> to support new logic.

These changes collectively improve the performance, clarity, and extensibility of the attention implementation, particularly for models using MLFloat16 and for newer ONNX opsets that require more flexible masking.

Copilot AI left a comment
Pull request overview

This PR introduces support for the nonpad_kv_seqlen input (Opset 24+) to enable per-batch masking of key-value positions in the attention operator, and refactors the GEMM logic to unify handling across float and MLFloat16 data types. The main goal is to allow models to specify valid KV sequence lengths per batch element, masking out positions beyond the specified length during attention computation.

Changes:

  • Added nonpad_kv_seqlen parameter handling in CPU attention implementation with masking logic
  • Refactored GEMM operations into a unified AttentionGemm template function for better maintainability
  • Added validation for nonpad_kv_seqlen including shape checks and mutual exclusivity with past_key/past_value
  • Added two test cases covering single-batch and multi-batch scenarios with 4D inputs

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 6 comments.

Summary per file:

  • onnxruntime/test/providers/cpu/llm/attention_op_test.cc — Adds two test cases validating nonpad_kv_seqlen masking behavior for single- and multi-batch scenarios
  • onnxruntime/core/providers/cuda/llm/attention.cc — Passes the nonpad_kv_seqlen parameter to the validation helper (but does not implement masking)
  • onnxruntime/core/providers/cpu/llm/attention_parameters.h — Extends the AttentionParameters struct with nonpad_kv_seqlen fields
  • onnxruntime/core/providers/cpu/llm/attention_helper.h — Adds validation logic for nonpad_kv_seqlen, including shape and mutual-exclusivity checks
  • onnxruntime/core/providers/cpu/llm/attention.cc — Implements the AttentionGemm template function, refactors GEMM calls, and adds nonpad_kv_seqlen masking logic in ComputeAttentionProbs


xadupre previously approved these changes Feb 19, 2026
@titaiwangms titaiwangms merged commit a1b634c into main Feb 23, 2026
90 checks passed
@titaiwangms titaiwangms deleted the titaiwang/support_nonpad_kv_seqlen branch February 23, 2026 16:46