106 changes: 93 additions & 13 deletions examples/models/llama/attention.py
@@ -1,3 +1,4 @@
import logging
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Dict, Optional, Tuple, Type, TypedDict
@@ -52,6 +53,8 @@ def forward(


ATTENTION_REGISTRY: Dict[str, Type[Attention]] = {}
_RECURRENT_GATED_DELTA_RULE_OP = None
_TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP = False


def register_attention(name: str):
@@ -64,6 +67,38 @@ def decorator(cls: Type[Attention]):
return decorator


def _get_recurrent_gated_delta_rule_op():
global _RECURRENT_GATED_DELTA_RULE_OP
global _TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP

if _TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP:
return _RECURRENT_GATED_DELTA_RULE_OP

_TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP = True
try:
_RECURRENT_GATED_DELTA_RULE_OP = (
torch.ops.llama.recurrent_gated_delta_rule.default
)
return _RECURRENT_GATED_DELTA_RULE_OP
except (AttributeError, RuntimeError):
pass

try:
from executorch.extension.llm.custom_ops import custom_ops # noqa: F401
except (ImportError, OSError, RuntimeError):
Copilot AI Apr 23, 2026

_get_recurrent_gated_delta_rule_op() attempts to import executorch.extension.llm.custom_ops.custom_ops as a best-effort fallback, but it doesn't catch FileNotFoundError. custom_ops.py can raise FileNotFoundError when custom_ops_aot_lib isn't present, which would crash attention initialization instead of cleanly falling back to the Python implementation. Consider catching FileNotFoundError here (or making custom_ops.py raise a RuntimeError that is already handled).

Suggested change
except (ImportError, OSError, RuntimeError):
except (ImportError, FileNotFoundError, OSError, RuntimeError):
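With the suggestion applied, the lookup helper would read roughly as below — a simplified sketch that drops the module-level caching shown in the diff, and assumes custom_ops can raise FileNotFoundError when the AOT library is absent:

    import logging

    import torch

    def _get_recurrent_gated_delta_rule_op():
        # Return the compiled op if it is already registered.
        try:
            return torch.ops.llama.recurrent_gated_delta_rule.default
        except (AttributeError, RuntimeError):
            pass
        # Best-effort import of the compiled custom-op library.
        try:
            from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401
        except (ImportError, FileNotFoundError, OSError, RuntimeError):
            # Library unavailable: log at debug level and let callers fall back
            # to the pure-Python implementation.
            logging.debug("Failed to import custom ops library", exc_info=True)
            return None
        # Retry the lookup now that the library has (possibly) registered the op.
        try:
            return torch.ops.llama.recurrent_gated_delta_rule.default
        except (AttributeError, RuntimeError):
            return None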

logging.debug("Failed to import custom ops library", exc_info=True)
return None
Comment on lines +86 to +90

Copilot AI Mar 11, 2026

_get_recurrent_gated_delta_rule_op() swallows all exceptions when importing executorch.extension.llm.custom_ops.custom_ops. Catching broad Exception can hide real load/link errors and make debugging difficult; consider narrowing to ImportError/OSError (or logging the exception at debug level) so unexpected failures surface.


try:
_RECURRENT_GATED_DELTA_RULE_OP = (
torch.ops.llama.recurrent_gated_delta_rule.default
)
except (AttributeError, RuntimeError):
_RECURRENT_GATED_DELTA_RULE_OP = None

return _RECURRENT_GATED_DELTA_RULE_OP


class KVCache(nn.Module):
def __init__(
self,
@@ -725,28 +760,43 @@ def _apply_causal_conv(self, mixed_qkv: torch.Tensor) -> torch.Tensor:
out = F.silu(out[:, :, -seq_len:]).to(mixed_qkv.dtype)
return out.transpose(1, 2).contiguous()

def _recurrent_gated_delta_rule(
def _gated_delta_rule_op(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
) -> torch.Tensor:
# query/key/value: (batch, seq_len, num_heads, head_dim)
# g/beta: (batch, seq_len, num_heads)
initial_dtype = query.dtype
query = _l2norm(query, dim=-1, eps=1e-6)
key = _l2norm(key, dim=-1, eps=1e-6)
query, key, value, beta, g = [
x.transpose(1, 2).contiguous().to(torch.float32)
for x in (query, key, value, beta, g)
]
batch_size = query.shape[0]
recurrent_gated_delta_rule_op = _get_recurrent_gated_delta_rule_op()
if recurrent_gated_delta_rule_op is not None:
return recurrent_gated_delta_rule_op(
query,
key,
value,
g,
beta,
self.recurrent_state[:batch_size],
)
return self._naive_gated_delta_rule_op(
query,
key,
value,
g,
beta,
)

batch_size, num_heads, sequence_length, k_head_dim = key.shape
def _naive_gated_delta_rule_op(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
) -> torch.Tensor:
batch_size, num_heads, sequence_length, _ = key.shape
v_head_dim = value.shape[-1]
scale = 1.0 / (query.shape[-1] ** 0.5)
query = query * scale

core_attn_out = torch.zeros(
Contributor

Can you put this logic in a function called something like "naive_gated_delta_rule_op", and then have the if statement switch between them, to tidy this function up a bit?

Contributor Author

Fixed it so that _recurrent_gated_delta_rule() switches between _gated_delta_rule_op() and _naive_gated_delta_rule_op().
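For reference, the resulting dispatch looks roughly like the sketch below (simplified from the diff above; the caller has already transposed the tensors and cast them to float32):

    def _gated_delta_rule_op(self, query, key, value, g, beta):
        # Prefer the compiled custom op when it is available.
        op = _get_recurrent_gated_delta_rule_op()
        if op is not None:
            batch_size = query.shape[0]
            return op(query, key, value, g, beta, self.recurrent_state[:batch_size])
        # Otherwise use the pure-Python reference path.
        return self._naive_gated_delta_rule_op(query, key, value, g, beta)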

batch_size,
@@ -780,6 +830,36 @@ def _recurrent_gated_delta_rule(
last_recurrent_state.to(self.recurrent_state.dtype)
)

return core_attn_out

def _recurrent_gated_delta_rule(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
) -> torch.Tensor:
# query/key/value: (batch, seq_len, num_heads, head_dim)
# g/beta: (batch, seq_len, num_heads)
initial_dtype = query.dtype
query = _l2norm(query, dim=-1, eps=1e-6)
key = _l2norm(key, dim=-1, eps=1e-6)
query, key, value, beta, g = [
x.transpose(1, 2).contiguous().to(torch.float32)
for x in (query, key, value, beta, g)
]

scale = 1.0 / (query.shape[-1] ** 0.5)
query = query * scale

core_attn_out = self._gated_delta_rule_op(
query,
key,
value,
g,
beta,
)
return core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)

def forward(
72 changes: 72 additions & 0 deletions examples/models/llama/tests/test_export_llama_lib.py
@@ -5,7 +5,10 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import json
import tempfile
import unittest
from pathlib import Path

from executorch.devtools.backend_debug import get_delegation_info

@@ -25,6 +28,7 @@

from executorch.examples.models.llama.export_llama_lib import (
_export_llama,
_prepare_for_llama_export,
build_args_parser,
get_quantizer_and_quant_params,
)
@@ -37,6 +41,39 @@


class ExportLlamaLibTest(unittest.TestCase):
def _make_tiny_qwen35_params(self) -> dict:
return {
"dim": 64,
"hidden_dim": 128,
"n_heads": 4,
"head_dim": 16,
"n_kv_heads": 2,
"n_layers": 4,
"norm_eps": 1e-6,
"rope_theta": 10000000.0,
"use_scaled_rope": False,
"vocab_size": 256,
"use_hf_rope": True,
"partial_rotary_factor": 0.25,
"attention_qkv_bias": False,
"use_qk_norm": True,
"qk_norm_before_rope": True,
"attention_type": "mha",
"use_q_gate": True,
"rms_norm_add_unit_offset": True,
"linear_conv_kernel_dim": 4,
"linear_key_head_dim": 8,
"linear_value_head_dim": 8,
"linear_num_key_heads": 4,
"linear_num_value_heads": 4,
"layer_types": [
"linear_attention",
"full_attention",
"linear_attention",
"full_attention",
],
}

def test_has_expected_ops_and_op_counts(self):
"""
Checks for the presence of unwanted expensive ops.
@@ -66,6 +103,41 @@ def test_has_expected_ops_and_op_counts(self):
for op, _op_info in delegation_info.delegation_by_operator.items():
self.assertTrue(op not in UNWANTED_OPS)

def test_tiny_qwen35_export_uses_recurrent_gated_delta_rule(self):
with tempfile.TemporaryDirectory() as temp_dir:
params_path = Path(temp_dir) / "tiny_qwen35.json"
params_path.write_text(json.dumps(self._make_tiny_qwen35_params()))

parser = build_args_parser()
args = parser.parse_args(
[
"--model",
"qwen3_5_0_8b",
"--params",
str(params_path),
"--use_kv_cache",
"--disable_dynamic_shape",
"--max_seq_length",
"8",
"--max_context_length",
"8",
]
)

llm_config = LlmConfig.from_args(args)
builder = _prepare_for_llama_export(llm_config).export()
assert builder.pre_autograd_graph_module is not None

recurrent_nodes = [
node
for node in builder.pre_autograd_graph_module.graph.nodes
if "auto_functionalized_v2" in str(node.target)
and node.args
and "llama.recurrent_gated_delta_rule" in str(node.args[0])
]

self.assertEqual(len(recurrent_nodes), 2)

@unittest.skipUnless(HAS_ARM_BACKEND, "ARM backend not available")
def test_get_quantizer_and_quant_params_returns_tosa_quantizer(self):
llm_config = LlmConfig()
105 changes: 105 additions & 0 deletions examples/models/llama/tests/test_qwen3_5_attention.py
@@ -6,7 +6,9 @@

import unittest

import executorch.examples.models.llama.attention as attention_module
import torch

from executorch.examples.models.llama.attention import ATTENTION_REGISTRY
from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.models.llama.norm import RMSNorm
@@ -123,6 +125,109 @@ def test_gated_deltanet_no_input_pos_does_not_leak_state(self):
torch.allclose(state_after_first, state_after_second, atol=1e-5)
)

def test_gated_deltanet_chunked_prefill_matches_full_sequence(self):
torch.manual_seed(0)
args = self._make_args(
use_kv_cache=True,
use_q_gate=True,
linear_conv_kernel_dim=4,
linear_key_head_dim=4,
linear_value_head_dim=4,
linear_num_key_heads=2,
linear_num_value_heads=4,
)
rope = Rope(args)
attn_full = ATTENTION_REGISTRY["gated_deltanet"](args, 0, rope)
attn_chunked = ATTENTION_REGISTRY["gated_deltanet"](args, 0, rope)
attn_chunked.load_state_dict(attn_full.state_dict())

x = torch.randn(1, 5, args.dim)
dummy_freq = torch.zeros(1, 1)

full_output, _ = attn_full(
x,
dummy_freq,
dummy_freq,
input_pos=torch.tensor([0], dtype=torch.long),
)

chunk_outputs = []
for start, end in ((0, 3), (3, 4), (4, 5)):
output, _ = attn_chunked(
x[:, start:end],
dummy_freq,
dummy_freq,
input_pos=torch.tensor([start], dtype=torch.long),
)
chunk_outputs.append(output)

chunked_output = torch.cat(chunk_outputs, dim=1)

self.assertTrue(torch.allclose(chunked_output, full_output, atol=1e-5))
self.assertTrue(
torch.allclose(
attn_chunked.recurrent_state, attn_full.recurrent_state, atol=1e-5
)
)
self.assertTrue(
torch.allclose(attn_chunked.conv_state, attn_full.conv_state, atol=1e-5)
)

def test_gated_deltanet_custom_op_matches_fallback(self):
recurrent_op = attention_module._get_recurrent_gated_delta_rule_op()
if recurrent_op is None:
self.skipTest("llama::recurrent_gated_delta_rule is not available")

torch.manual_seed(0)
args = self._make_args(
use_kv_cache=True,
use_q_gate=True,
linear_conv_kernel_dim=4,
linear_key_head_dim=4,
linear_value_head_dim=4,
linear_num_key_heads=2,
linear_num_value_heads=4,
)
rope = Rope(args)
attn_custom = ATTENTION_REGISTRY["gated_deltanet"](args, 0, rope)
attn_fallback = ATTENTION_REGISTRY["gated_deltanet"](args, 0, rope)
attn_fallback.load_state_dict(attn_custom.state_dict())

query = torch.randn(1, 3, attn_custom.num_v_heads, attn_custom.head_k_dim)
key = torch.randn(1, 3, attn_custom.num_v_heads, attn_custom.head_k_dim)
value = torch.randn(1, 3, attn_custom.num_v_heads, attn_custom.head_v_dim)
g = torch.randn(1, 3, attn_custom.num_v_heads)
beta = torch.sigmoid(torch.randn(1, 3, attn_custom.num_v_heads))

original_op = attention_module._RECURRENT_GATED_DELTA_RULE_OP
original_tried_loading = (
attention_module._TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP
)
try:
attention_module._RECURRENT_GATED_DELTA_RULE_OP = recurrent_op
attention_module._TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP = True
custom_output = attn_custom._recurrent_gated_delta_rule(
query, key, value, g, beta
)

attention_module._RECURRENT_GATED_DELTA_RULE_OP = None
attention_module._TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP = True
fallback_output = attn_fallback._recurrent_gated_delta_rule(
query, key, value, g, beta
)
finally:
attention_module._RECURRENT_GATED_DELTA_RULE_OP = original_op
attention_module._TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP = (
original_tried_loading
)

self.assertTrue(torch.allclose(custom_output, fallback_output, atol=1e-5))
self.assertTrue(
torch.allclose(
attn_custom.recurrent_state, attn_fallback.recurrent_state, atol=1e-5
)
)


if __name__ == "__main__":
unittest.main()
3 changes: 2 additions & 1 deletion extension/aten_util/make_aten_functor_from_et_functor.h
@@ -15,7 +15,8 @@
#pragma once
#include <type_traits>
#include <vector>
#if __cplusplus < 201703L
#if (defined(_MSC_VER) && (!defined(_MSVC_LANG) || _MSVC_LANG < 201703L)) || \
(!defined(_MSC_VER) && __cplusplus < 201703L)
#error "This header requires C++17"
#endif
#include <ATen/native/Resize.h>