prefill gdr kernel enablement #656
Signed-off-by: ganyi <ygan@amd.com>
Pull request overview
This PR updates the vLLM Gated Delta Net attention backend to use an optimized GDR (gated delta rule) kernel during the prefill path, likely to improve performance and/or compatibility with newer aiter kernels.
Changes:
- Switch the prefill recurrent-attention implementation from the existing self.chunk_gated_delta_rule(...) wrapper to aiter’s chunk_gated_delta_rule_opt_vk(...).
- Add an inline import for the new optimized kernel in the prefill path.
Comments suppressed due to low confidence (1)
atom/plugin/vllm/attention_backend/attention_gdn.py:372
- After switching the prefill path to call chunk_gated_delta_rule_opt_vk directly, self.chunk_gated_delta_rule (and the ChunkGatedDeltaRule wrapper + fla_chunk_gated_delta_rule import) appear to be unused in this module. Either remove the now-dead wrapper/attribute, or use it as the fallback when the optimized kernel isn’t available to avoid carrying unused code.
from aiter.ops.triton.gated_delta_net.gated_delta_rule import chunk_gated_delta_rule_opt_vk
initial_state = ssm_state[non_spec_state_indices_tensor].contiguous()
initial_state[~has_initial_state, ...] = 0
(
core_attn_out_non_spec,
last_recurrent_state,
) = chunk_gated_delta_rule_opt_vk(
q=query_non_spec,
k=key_non_spec,
v=value_non_spec,
g=g_non_spec,
beta=beta_non_spec,
+ from aiter.ops.triton.gated_delta_net.gated_delta_rule import chunk_gated_delta_rule_opt_vk
  initial_state = ssm_state[non_spec_state_indices_tensor].contiguous()
  initial_state[~has_initial_state, ...] = 0
  (
  core_attn_out_non_spec,
  last_recurrent_state,
- ) = self.chunk_gated_delta_rule(
+ ) = chunk_gated_delta_rule_opt_vk(
The new inline import of chunk_gated_delta_rule_opt_vk will raise ImportError at runtime on prefill if the installed aiter version doesn’t provide this symbol (note this file already treats aiter as optional via the guarded flydsl_gdr_decode import). Consider doing a module-level try/except import with a clear fallback to the existing fla_chunk_gated_delta_rule implementation (or a feature flag) so prefill doesn’t hard-crash when the optimized kernel is unavailable.
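The guarded import suggested above could look roughly like the sketch below. The helper names load_optional_kernel, select_prefill_kernel, and _OPT_GDR_PREFILL are hypothetical and not part of this PR; only the aiter module path and the chunk_gated_delta_rule_opt_vk symbol come from the diff. The sketch also assumes the fallback and the optimized kernel accept the same keyword arguments and return the same tuple, which the PR does not confirm.

```python
import importlib
from typing import Callable, Optional


def load_optional_kernel(module_path: str, symbol: str) -> Optional[Callable]:
    """Return `symbol` from `module_path`, or None if either is unavailable.

    Hypothetical helper for illustration; not part of aiter or this PR.
    """
    try:
        module = importlib.import_module(module_path)
    except ImportError:
        # Package not installed, or too old to ship this submodule.
        return None
    # The module may exist but predate the symbol we want.
    return getattr(module, symbol, None)


# Resolved once at import time instead of on every prefill call.
# Module path and symbol name are taken from the diff in this PR.
_OPT_GDR_PREFILL = load_optional_kernel(
    "aiter.ops.triton.gated_delta_net.gated_delta_rule",
    "chunk_gated_delta_rule_opt_vk",
)


def select_prefill_kernel(fallback: Callable) -> Callable:
    """Prefer the optimized kernel; otherwise return the given fallback."""
    return _OPT_GDR_PREFILL if _OPT_GDR_PREFILL is not None else fallback
```

In the prefill path, the call site could then pick its kernel with something like select_prefill_kernel(self.chunk_gated_delta_rule), which would also keep the existing wrapper in use rather than leaving it as dead code. If the two kernels turn out to differ in argument names or state layout, a thin adapter around the fallback would be needed before this selection is safe.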
Motivation
Technical Details
Test Plan
Test Result
Submission Checklist