@LJ-underdog LJ-underdog commented Dec 30, 2025

Motivation

This PR adds support for a sink_ptr parameter across the multi-head attention (MHA) forward pass implementations. The sink_ptr enables the "gptoss_sink" feature, which is a mechanism for attention sink tokens in transformer models.
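For readers unfamiliar with sinks, here is a minimal pure-Python sketch of what a per-head sink value is commonly understood to do. It assumes the gpt-oss style of attention sink, where a learned logit joins the softmax denominator but contributes no value vector. The PR text does not spell this out, and `softmax_with_sink` is a hypothetical name, not a function from this repository.

```python
import math
from typing import List, Optional


def softmax_with_sink(scores: List[float], sink: Optional[float] = None) -> List[float]:
    """Softmax over one non-empty row of attention scores, with an optional sink logit.

    Assumption: the sink only enlarges the denominator, so the returned
    probabilities sum to less than 1 whenever a sink is given.
    """
    # Subtract the running maximum (including the sink) for numerical stability.
    m = max(scores) if sink is None else max(max(scores), sink)
    exps = [math.exp(s - m) for s in scores]
    denom = sum(exps)
    if sink is not None:
        denom += math.exp(sink - m)  # the sink absorbs part of the probability mass
    return [e / denom for e in exps]


without_sink = softmax_with_sink([0.0, 0.0])
with_sink = softmax_with_sink([0.0, 0.0], sink=0.0)
```

With two equal scores and an equal sink logit, each key receives one third of the mass instead of one half, and the remaining third is absorbed by the sink.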

Technical Details

  • Added sink_ptr parameter to all MHA forward function signatures (regular, varlen, and batch prefill variants)
  • Added validation logic to ensure sink_ptr matches device and shape requirements, with automatic dtype conversion to float32
  • Propagated sink_ptr through the call stack from Python interfaces to C++/CUDA implementations
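
The validation step described above could look roughly like the sketch below. It is an illustration under stated assumptions, not the code in aiter/ops/mha.py: `TensorMeta` and `check_sink` are hypothetical names, the expected shape `(nheads,)` is a guess at the "shape requirements", and treating `None` as "no sink" is inferred from the asm paths passing nullptr.

```python
from dataclasses import dataclass, replace
from typing import Optional, Tuple


@dataclass(frozen=True)
class TensorMeta:
    """Hypothetical stand-in for tensor metadata (not a real aiter or torch type)."""
    shape: Tuple[int, ...]
    dtype: str
    device: str


def check_sink(sink: Optional[TensorMeta], q: TensorMeta, nheads: int) -> Optional[TensorMeta]:
    """Validate an optional sink tensor against the query tensor."""
    if sink is None:
        return None  # assumption: no sink means the kernel receives a null pointer
    if sink.device != q.device:
        raise ValueError(f"sink_ptr is on {sink.device}, expected {q.device}")
    if sink.shape != (nheads,):  # assumption: one sink value per attention head
        raise ValueError(f"sink_ptr has shape {sink.shape}, expected {(nheads,)}")
    if sink.dtype != "float32":
        # Mirror the automatic dtype conversion to float32 mentioned in the PR.
        sink = replace(sink, dtype="float32")
    return sink


q = TensorMeta(shape=(2, 8, 4, 64), dtype="float16", device="cuda:0")
sink = check_sink(TensorMeta(shape=(4,), dtype="bfloat16", device="cuda:0"), q, nheads=4)
```

Doing the checks once at the Python boundary keeps the C++ side simple: by the time the pointer reaches the kernel arguments, it is either null or a float32 buffer on the right device.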

Test Plan

Test in the ck repo.

Test Result

Local test passed.

Submission Checklist

Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>
@LJ-underdog LJ-underdog requested review from a team and Copilot December 30, 2025 04:06
Copilot AI left a comment

Pull request overview

This PR adds support for a sink_ptr parameter across the multi-head attention (MHA) forward pass implementations. The sink_ptr enables the "gptoss_sink" feature, which is a mechanism for attention sink tokens in transformer models.

Key changes:

  • Added sink_ptr parameter to all MHA forward function signatures (regular, varlen, and batch prefill variants)
  • Added validation logic to ensure sink_ptr matches device and shape requirements, with automatic dtype conversion to float32
  • Propagated sink_ptr through the call stack from Python interfaces to C++/CUDA implementations

Reviewed changes

Copilot reviewed 10 out of 10 changed files in this pull request and generated 5 comments.

| File | Description |
| --- | --- |
| csrc/py_itfs_cu/asm_mha_varlen_fwd.cu | Added nullptr sink_ptr argument to mha_fwd_args constructor |
| csrc/py_itfs_cu/asm_mha_fwd.cu | Added nullptr sink_ptr argument to mha_fwd_args constructor |
| csrc/py_itfs_ck/mha_varlen_fwd_kernels.cu | Added sink_ptr parameter handling and validation in varlen forward kernels |
| csrc/py_itfs_ck/mha_fwd_kernels.cu | Added sink_ptr parameter to standard MHA forward kernels |
| csrc/py_itfs_ck/mha_batch_prefill_kernels.cu | Added sink_ptr parameter to batch prefill kernels |
| csrc/include/torch/mha_varlen_fwd.h | Updated header to include sink_ptr parameter in function signature |
| csrc/include/torch/mha_fwd.h | Updated header to include sink_ptr parameter in function signature |
| csrc/include/torch/mha_batch_prefill.h | Updated header to include sink_ptr parameter in function signature |
| csrc/include/rocm_ops.hpp | Added sink_ptr pybind argument definitions for all MHA variants |
| aiter/ops/mha.py | Added sink_ptr parameter to all Python MHA functions with device/shape validation |

LJ-underdog and others added 5 commits December 30, 2025 13:41
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
@LJ-underdog LJ-underdog changed the title enable gptoss_sink Enable gptoss_sink Dec 30, 2025
Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>