Add recurrent gated delta rule custom op for Qwen3.5 attention #18088
@@ -1,3 +1,4 @@
import logging
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Dict, Optional, Tuple, Type, TypedDict
@@ -52,6 +53,8 @@ def forward(
ATTENTION_REGISTRY: Dict[str, Type[Attention]] = {}
_RECURRENT_GATED_DELTA_RULE_OP = None
_TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP = False


def register_attention(name: str):
@@ -64,6 +67,38 @@ def decorator(cls: Type[Attention]):
    return decorator


def _get_recurrent_gated_delta_rule_op():
    global _RECURRENT_GATED_DELTA_RULE_OP
    global _TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP

    if _TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP:
        return _RECURRENT_GATED_DELTA_RULE_OP

    _TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP = True
    try:
        _RECURRENT_GATED_DELTA_RULE_OP = (
            torch.ops.llama.recurrent_gated_delta_rule.default
        )
        return _RECURRENT_GATED_DELTA_RULE_OP
    except (AttributeError, RuntimeError):
        pass

    try:
        from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401
    except (ImportError, OSError, RuntimeError):
        logging.debug("Failed to import custom ops library", exc_info=True)
        return None

    try:
        _RECURRENT_GATED_DELTA_RULE_OP = (
            torch.ops.llama.recurrent_gated_delta_rule.default
        )
    except (AttributeError, RuntimeError):
        _RECURRENT_GATED_DELTA_RULE_OP = None

    return _RECURRENT_GATED_DELTA_RULE_OP


class KVCache(nn.Module):
    def __init__(
        self,
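A quick, hypothetical smoke test for the loader above (not part of this diff; it only illustrates the intended fallback behavior of _get_recurrent_gated_delta_rule_op()):

op = _get_recurrent_gated_delta_rule_op()
if op is None:
    # Compiled custom op could not be resolved; the naive Python path is used.
    print("recurrent_gated_delta_rule custom op unavailable, using fallback")
else:
    print("recurrent_gated_delta_rule custom op loaded:", op)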
@@ -725,28 +760,43 @@ def _apply_causal_conv(self, mixed_qkv: torch.Tensor) -> torch.Tensor:
        out = F.silu(out[:, :, -seq_len:]).to(mixed_qkv.dtype)
        return out.transpose(1, 2).contiguous()

    def _recurrent_gated_delta_rule(
    def _gated_delta_rule_op(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        g: torch.Tensor,
        beta: torch.Tensor,
    ) -> torch.Tensor:
        # query/key/value: (batch, seq_len, num_heads, head_dim)
        # g/beta: (batch, seq_len, num_heads)
        initial_dtype = query.dtype
        query = _l2norm(query, dim=-1, eps=1e-6)
        key = _l2norm(key, dim=-1, eps=1e-6)
        query, key, value, beta, g = [
            x.transpose(1, 2).contiguous().to(torch.float32)
            for x in (query, key, value, beta, g)
        ]
        batch_size = query.shape[0]
        recurrent_gated_delta_rule_op = _get_recurrent_gated_delta_rule_op()
        if recurrent_gated_delta_rule_op is not None:
            return recurrent_gated_delta_rule_op(
                query,
                key,
                value,
                g,
                beta,
                self.recurrent_state[:batch_size],
            )
        return self._naive_gated_delta_rule_op(
            query,
            key,
            value,
            g,
            beta,
        )

        batch_size, num_heads, sequence_length, k_head_dim = key.shape
    def _naive_gated_delta_rule_op(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        g: torch.Tensor,
        beta: torch.Tensor,
    ) -> torch.Tensor:
        batch_size, num_heads, sequence_length, _ = key.shape
        v_head_dim = value.shape[-1]
        scale = 1.0 / (query.shape[-1] ** 0.5)
        query = query * scale

        core_attn_out = torch.zeros(
Contributor:
Can you put this logic in some function called something like "naive_gated_delta_rule_op", and then just have the if statement switch between them, to tidy this function up a bit?

Contributor (Author):
Fixed: _recurrent_gated_delta_rule() now switches between _gated_delta_rule_op() and _naive_gated_delta_rule_op().
            batch_size,

@@ -780,6 +830,36 @@ def _recurrent_gated_delta_rule(
            last_recurrent_state.to(self.recurrent_state.dtype)
        )

        return core_attn_out

    def _recurrent_gated_delta_rule(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        g: torch.Tensor,
        beta: torch.Tensor,
    ) -> torch.Tensor:
        # query/key/value: (batch, seq_len, num_heads, head_dim)
        # g/beta: (batch, seq_len, num_heads)
        initial_dtype = query.dtype
        query = _l2norm(query, dim=-1, eps=1e-6)
        key = _l2norm(key, dim=-1, eps=1e-6)
        query, key, value, beta, g = [
            x.transpose(1, 2).contiguous().to(torch.float32)
            for x in (query, key, value, beta, g)
        ]

        scale = 1.0 / (query.shape[-1] ** 0.5)
        query = query * scale

        core_attn_out = self._gated_delta_rule_op(
            query,
            key,
            value,
            g,
            beta,
        )
        return core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)

    def forward(
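For readers unfamiliar with the operation being offloaded here, the following is a minimal per-timestep sketch of a gated delta rule recurrence in the style of Gated DeltaNet. It assumes inputs that have already been L2-normalized, transposed to (batch, heads, seq, dim), and scaled, as _recurrent_gated_delta_rule does above, plus a recurrent state of shape (batch, heads, k_dim, v_dim). The function name and exact update are illustrative assumptions, not this PR's kernel semantics.

import torch

def gated_delta_rule_reference(query, key, value, g, beta, state):
    # query/key: (B, H, T, Dk); value: (B, H, T, Dv)
    # g/beta: (B, H, T); state: (B, H, Dk, Dv)
    out = torch.zeros_like(value)
    for t in range(query.shape[2]):
        q_t, k_t, v_t = query[:, :, t], key[:, :, t], value[:, :, t]
        # Gated decay of the running key-value memory.
        state = state * g[:, :, t].exp()[..., None, None]
        # Delta rule: move the memory's prediction for k_t toward v_t, scaled by beta_t.
        kv_mem = (state * k_t[..., None]).sum(dim=-2)
        delta = (v_t - kv_mem) * beta[:, :, t][..., None]
        state = state + k_t[..., None] * delta[:, :, None, :]
        # Read out the updated memory with the query.
        out[:, :, t] = (state * q_t[..., None]).sum(dim=-2)
    return out, state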
Comment on lines +86 to +90:
_get_recurrent_gated_delta_rule_op() attempts to import executorch.extension.llm.custom_ops.custom_ops as a best-effort fallback, but it doesn't catch FileNotFoundError. custom_ops.py can raise FileNotFoundError when custom_ops_aot_lib isn't present, which would crash attention initialization instead of cleanly falling back to the Python implementation. Consider catching FileNotFoundError here (or making custom_ops.py raise a RuntimeError that is already handled).
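A minimal sketch of one way to address this, assuming the fix lands in the except clause flagged above (illustrative only, not a committed change). Note that FileNotFoundError is a subclass of OSError, so the existing handler may already cover this case; listing it explicitly mainly documents the intent.

    try:
        from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401
    except (FileNotFoundError, ImportError, OSError, RuntimeError):
        # Custom op shared library unavailable; fall back to the Python path.
        logging.debug("Failed to import custom ops library", exc_info=True)
        return None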