Enable gptoss_sink #1753
base: main
Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>
Pull request overview
This PR adds support for a sink_ptr parameter across the multi-head attention (MHA) forward pass implementations. The sink_ptr enables the "gptoss_sink" feature, which is a mechanism for attention sink tokens in transformer models.
Key changes:
- Added sink_ptr parameter to all MHA forward function signatures (regular, varlen, and batch prefill variants)
- Added validation logic to ensure sink_ptr matches device and shape requirements, with automatic dtype conversion to float32
- Propagated sink_ptr through the call stack from Python interfaces to C++/CUDA implementations
Reviewed changes
Copilot reviewed 10 out of 10 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| csrc/py_itfs_cu/asm_mha_varlen_fwd.cu | Added nullptr sink_ptr argument to mha_fwd_args constructor |
| csrc/py_itfs_cu/asm_mha_fwd.cu | Added nullptr sink_ptr argument to mha_fwd_args constructor |
| csrc/py_itfs_ck/mha_varlen_fwd_kernels.cu | Added sink_ptr parameter handling and validation in varlen forward kernels |
| csrc/py_itfs_ck/mha_fwd_kernels.cu | Added sink_ptr parameter to standard MHA forward kernels |
| csrc/py_itfs_ck/mha_batch_prefill_kernels.cu | Added sink_ptr parameter to batch prefill kernels |
| csrc/include/torch/mha_varlen_fwd.h | Updated header to include sink_ptr parameter in function signature |
| csrc/include/torch/mha_fwd.h | Updated header to include sink_ptr parameter in function signature |
| csrc/include/torch/mha_batch_prefill.h | Updated header to include sink_ptr parameter in function signature |
| csrc/include/rocm_ops.hpp | Added sink_ptr pybind argument definitions for all MHA variants |
| aiter/ops/mha.py | Added sink_ptr parameter to all Python MHA functions with device/shape validation |
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Motivation
This PR adds support for a sink_ptr parameter across the multi-head attention (MHA) forward pass implementations. The sink_ptr enables the "gptoss_sink" feature, which is a mechanism for attention sink tokens in transformer models.
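For readers unfamiliar with attention sinks, here is a minimal sketch of the idea in plain Python. It assumes the formulation commonly described for gpt-oss style models, where each head has one learned sink logit that joins the softmax normalization but contributes no value vector. The PR itself does not spell out the math, so the function name `softmax_with_sink` and the exact formula are illustrative assumptions, not the kernel's implementation.

```python
import math


def softmax_with_sink(scores, sink):
    """Softmax over one row of attention scores plus one extra sink logit.

    Assumption: the sink logit only enlarges the denominator. It has no
    value vector, so the returned weights can sum to less than 1.
    """
    # Subtract the running max (including the sink) for numerical stability.
    m = max(max(scores), sink)
    exps = [math.exp(s - m) for s in scores]
    denom = sum(exps) + math.exp(sink - m)
    return [e / denom for e in exps]


# Two equal scores and an equal sink logit: each key gets 1/3 of the mass,
# and the remaining 1/3 is absorbed by the sink.
weights = softmax_with_sink([0.0, 0.0], sink=0.0)
print(weights, sum(weights))

# A sink of -inf contributes nothing, which recovers the ordinary softmax.
print(softmax_with_sink([0.0, 0.0], sink=float("-inf")))
```

Under this reading, passing no sink (the nullptr case in the asm paths) corresponds to the ordinary softmax.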
Technical Details
- Added sink_ptr parameter to all MHA forward function signatures (regular, varlen, and batch prefill variants)
- Added validation logic to ensure sink_ptr matches device and shape requirements, with automatic dtype conversion to float32
- Propagated sink_ptr through the call stack from Python interfaces to C++/CUDA implementations
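The validation step described above could look roughly like the following sketch. It is not the code in aiter/ops/mha.py: `TensorInfo` is a hypothetical stand-in for the tensor properties being checked, `validate_sink` is an illustrative name, and the expected shape of one value per head, `(num_heads,)`, is an assumption that the PR text does not confirm.

```python
from dataclasses import dataclass, replace
from typing import Optional, Tuple


@dataclass(frozen=True)
class TensorInfo:
    # Hypothetical stand-in for the tensor properties the checks inspect.
    device: str
    dtype: str
    shape: Tuple[int, ...]


def validate_sink(
    sink: Optional[TensorInfo], q: TensorInfo, num_heads: int
) -> Optional[TensorInfo]:
    """Check an optional sink tensor against the query tensor.

    Returns None when no sink is given, otherwise a float32 description
    of the sink. Raises ValueError on a device or shape mismatch.
    """
    if sink is None:
        # The sink is optional; callers without it pass nothing through.
        return None
    if sink.device != q.device:
        raise ValueError(f"sink on {sink.device}, but q on {q.device}")
    # Assumed layout: one sink logit per attention head.
    if sink.shape != (num_heads,):
        raise ValueError(f"expected shape {(num_heads,)}, got {sink.shape}")
    if sink.dtype != "float32":
        # Mirrors the automatic conversion to float32 mentioned above.
        sink = replace(sink, dtype="float32")
    return sink


q = TensorInfo(device="cuda:0", dtype="bfloat16", shape=(2, 128, 8, 64))
sink = TensorInfo(device="cuda:0", dtype="bfloat16", shape=(8,))
print(validate_sink(sink, q, num_heads=8))
```

Converting at the Python boundary means the lower layers only ever see one dtype for the sink, which keeps the C++ side free of per-dtype branches.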
Test Plan
Tested in the CK repo.
Test Result
Local tests passed.
Submission Checklist