From 04170fd5c4ace8d5b0eb67b35fc2a4c3ff7ce61f Mon Sep 17 00:00:00 2001 From: Jacob Szwejbka Date: Mon, 27 Apr 2026 19:24:57 -0700 Subject: [PATCH] Revert "Add recurrent gated delta rule custom op for Qwen3.5 attention (#18088)" This reverts commit 476a7ef427cc4b78c8767b7ed6f3b7db82642867. --- examples/models/llama/attention.py | 106 +---- .../llama/tests/test_export_llama_lib.py | 72 ---- .../llama/tests/test_qwen3_5_attention.py | 105 ----- .../make_aten_functor_from_et_functor.h | 3 +- extension/llm/custom_ops/custom_ops.py | 176 +------- .../op_fast_hadamard_transform_aten.cpp | 33 +- extension/llm/custom_ops/op_sdpa.cpp | 226 ---------- extension/llm/custom_ops/op_sdpa.h | 10 - extension/llm/custom_ops/op_sdpa_aot.cpp | 399 +++--------------- extension/llm/custom_ops/op_tile_crop_aot.cpp | 39 +- extension/llm/custom_ops/test_update_cache.py | 152 ------- 11 files changed, 94 insertions(+), 1227 deletions(-) diff --git a/examples/models/llama/attention.py b/examples/models/llama/attention.py index 7556ef60e19..d6dff173072 100644 --- a/examples/models/llama/attention.py +++ b/examples/models/llama/attention.py @@ -1,4 +1,3 @@ -import logging from abc import ABC, abstractmethod from enum import Enum from typing import Any, Dict, Optional, Tuple, Type, TypedDict @@ -53,8 +52,6 @@ def forward( ATTENTION_REGISTRY: Dict[str, Type[Attention]] = {} -_RECURRENT_GATED_DELTA_RULE_OP = None -_TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP = False def register_attention(name: str): @@ -67,38 +64,6 @@ def decorator(cls: Type[Attention]): return decorator -def _get_recurrent_gated_delta_rule_op(): - global _RECURRENT_GATED_DELTA_RULE_OP - global _TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP - - if _TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP: - return _RECURRENT_GATED_DELTA_RULE_OP - - _TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP = True - try: - _RECURRENT_GATED_DELTA_RULE_OP = ( - torch.ops.llama.recurrent_gated_delta_rule.default - ) - return _RECURRENT_GATED_DELTA_RULE_OP - except (AttributeError, RuntimeError): - pass - - try: - from executorch.extension.llm.custom_ops import custom_ops # noqa: F401 - except (ImportError, OSError, RuntimeError): - logging.debug("Failed to import custom ops library", exc_info=True) - return None - - try: - _RECURRENT_GATED_DELTA_RULE_OP = ( - torch.ops.llama.recurrent_gated_delta_rule.default - ) - except (AttributeError, RuntimeError): - _RECURRENT_GATED_DELTA_RULE_OP = None - - return _RECURRENT_GATED_DELTA_RULE_OP - - class KVCache(nn.Module): def __init__( self, @@ -760,7 +725,7 @@ def _apply_causal_conv(self, mixed_qkv: torch.Tensor) -> torch.Tensor: out = F.silu(out[:, :, -seq_len:]).to(mixed_qkv.dtype) return out.transpose(1, 2).contiguous() - def _gated_delta_rule_op( + def _recurrent_gated_delta_rule( self, query: torch.Tensor, key: torch.Tensor, @@ -768,35 +733,20 @@ def _gated_delta_rule_op( g: torch.Tensor, beta: torch.Tensor, ) -> torch.Tensor: - batch_size = query.shape[0] - recurrent_gated_delta_rule_op = _get_recurrent_gated_delta_rule_op() - if recurrent_gated_delta_rule_op is not None: - return recurrent_gated_delta_rule_op( - query, - key, - value, - g, - beta, - self.recurrent_state[:batch_size], - ) - return self._naive_gated_delta_rule_op( - query, - key, - value, - g, - beta, - ) + # query/key/value: (batch, seq_len, num_heads, head_dim) + # g/beta: (batch, seq_len, num_heads) + initial_dtype = query.dtype + query = _l2norm(query, dim=-1, eps=1e-6) + key = _l2norm(key, dim=-1, eps=1e-6) + query, key, value, beta, g = [ + x.transpose(1, 
2).contiguous().to(torch.float32) + for x in (query, key, value, beta, g) + ] - def _naive_gated_delta_rule_op( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - g: torch.Tensor, - beta: torch.Tensor, - ) -> torch.Tensor: - batch_size, num_heads, sequence_length, _ = key.shape + batch_size, num_heads, sequence_length, k_head_dim = key.shape v_head_dim = value.shape[-1] + scale = 1.0 / (query.shape[-1] ** 0.5) + query = query * scale core_attn_out = torch.zeros( batch_size, @@ -830,36 +780,6 @@ def _naive_gated_delta_rule_op( last_recurrent_state.to(self.recurrent_state.dtype) ) - return core_attn_out - - def _recurrent_gated_delta_rule( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - g: torch.Tensor, - beta: torch.Tensor, - ) -> torch.Tensor: - # query/key/value: (batch, seq_len, num_heads, head_dim) - # g/beta: (batch, seq_len, num_heads) - initial_dtype = query.dtype - query = _l2norm(query, dim=-1, eps=1e-6) - key = _l2norm(key, dim=-1, eps=1e-6) - query, key, value, beta, g = [ - x.transpose(1, 2).contiguous().to(torch.float32) - for x in (query, key, value, beta, g) - ] - - scale = 1.0 / (query.shape[-1] ** 0.5) - query = query * scale - - core_attn_out = self._gated_delta_rule_op( - query, - key, - value, - g, - beta, - ) return core_attn_out.transpose(1, 2).contiguous().to(initial_dtype) def forward( diff --git a/examples/models/llama/tests/test_export_llama_lib.py b/examples/models/llama/tests/test_export_llama_lib.py index c96fea8c215..130a55f658c 100644 --- a/examples/models/llama/tests/test_export_llama_lib.py +++ b/examples/models/llama/tests/test_export_llama_lib.py @@ -5,10 +5,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import json -import tempfile import unittest -from pathlib import Path from executorch.devtools.backend_debug import get_delegation_info @@ -28,7 +25,6 @@ from executorch.examples.models.llama.export_llama_lib import ( _export_llama, - _prepare_for_llama_export, build_args_parser, get_quantizer_and_quant_params, ) @@ -41,39 +37,6 @@ class ExportLlamaLibTest(unittest.TestCase): - def _make_tiny_qwen35_params(self) -> dict: - return { - "dim": 64, - "hidden_dim": 128, - "n_heads": 4, - "head_dim": 16, - "n_kv_heads": 2, - "n_layers": 4, - "norm_eps": 1e-6, - "rope_theta": 10000000.0, - "use_scaled_rope": False, - "vocab_size": 256, - "use_hf_rope": True, - "partial_rotary_factor": 0.25, - "attention_qkv_bias": False, - "use_qk_norm": True, - "qk_norm_before_rope": True, - "attention_type": "mha", - "use_q_gate": True, - "rms_norm_add_unit_offset": True, - "linear_conv_kernel_dim": 4, - "linear_key_head_dim": 8, - "linear_value_head_dim": 8, - "linear_num_key_heads": 4, - "linear_num_value_heads": 4, - "layer_types": [ - "linear_attention", - "full_attention", - "linear_attention", - "full_attention", - ], - } - def test_has_expected_ops_and_op_counts(self): """ Checks the presence of unwanted expensive ops. 
@@ -103,41 +66,6 @@ def test_has_expected_ops_and_op_counts(self): for op, _op_info in delegation_info.delegation_by_operator.items(): self.assertTrue(op not in UNWANTED_OPS) - def test_tiny_qwen35_export_uses_recurrent_gated_delta_rule(self): - with tempfile.TemporaryDirectory() as temp_dir: - params_path = Path(temp_dir) / "tiny_qwen35.json" - params_path.write_text(json.dumps(self._make_tiny_qwen35_params())) - - parser = build_args_parser() - args = parser.parse_args( - [ - "--model", - "qwen3_5_0_8b", - "--params", - str(params_path), - "--use_kv_cache", - "--disable_dynamic_shape", - "--max_seq_length", - "8", - "--max_context_length", - "8", - ] - ) - - llm_config = LlmConfig.from_args(args) - builder = _prepare_for_llama_export(llm_config).export() - assert builder.pre_autograd_graph_module is not None - - recurrent_nodes = [ - node - for node in builder.pre_autograd_graph_module.graph.nodes - if "auto_functionalized_v2" in str(node.target) - and node.args - and "llama.recurrent_gated_delta_rule" in str(node.args[0]) - ] - - self.assertEqual(len(recurrent_nodes), 2) - @unittest.skipUnless(HAS_ARM_BACKEND, "ARM backend not available") def test_get_quantizer_and_quant_params_returns_tosa_quantizer(self): llm_config = LlmConfig() diff --git a/examples/models/llama/tests/test_qwen3_5_attention.py b/examples/models/llama/tests/test_qwen3_5_attention.py index ba96a96aa43..5a9f67d57cf 100644 --- a/examples/models/llama/tests/test_qwen3_5_attention.py +++ b/examples/models/llama/tests/test_qwen3_5_attention.py @@ -6,9 +6,7 @@ import unittest -import executorch.examples.models.llama.attention as attention_module import torch - from executorch.examples.models.llama.attention import ATTENTION_REGISTRY from executorch.examples.models.llama.model_args import ModelArgs from executorch.examples.models.llama.norm import RMSNorm @@ -125,109 +123,6 @@ def test_gated_deltanet_no_input_pos_does_not_leak_state(self): torch.allclose(state_after_first, state_after_second, atol=1e-5) ) - def test_gated_deltanet_chunked_prefill_matches_full_sequence(self): - torch.manual_seed(0) - args = self._make_args( - use_kv_cache=True, - use_q_gate=True, - linear_conv_kernel_dim=4, - linear_key_head_dim=4, - linear_value_head_dim=4, - linear_num_key_heads=2, - linear_num_value_heads=4, - ) - rope = Rope(args) - attn_full = ATTENTION_REGISTRY["gated_deltanet"](args, 0, rope) - attn_chunked = ATTENTION_REGISTRY["gated_deltanet"](args, 0, rope) - attn_chunked.load_state_dict(attn_full.state_dict()) - - x = torch.randn(1, 5, args.dim) - dummy_freq = torch.zeros(1, 1) - - full_output, _ = attn_full( - x, - dummy_freq, - dummy_freq, - input_pos=torch.tensor([0], dtype=torch.long), - ) - - chunk_outputs = [] - for start, end in ((0, 3), (3, 4), (4, 5)): - output, _ = attn_chunked( - x[:, start:end], - dummy_freq, - dummy_freq, - input_pos=torch.tensor([start], dtype=torch.long), - ) - chunk_outputs.append(output) - - chunked_output = torch.cat(chunk_outputs, dim=1) - - self.assertTrue(torch.allclose(chunked_output, full_output, atol=1e-5)) - self.assertTrue( - torch.allclose( - attn_chunked.recurrent_state, attn_full.recurrent_state, atol=1e-5 - ) - ) - self.assertTrue( - torch.allclose(attn_chunked.conv_state, attn_full.conv_state, atol=1e-5) - ) - - def test_gated_deltanet_custom_op_matches_fallback(self): - recurrent_op = attention_module._get_recurrent_gated_delta_rule_op() - if recurrent_op is None: - self.skipTest("llama::recurrent_gated_delta_rule is not available") - - torch.manual_seed(0) - args = self._make_args( 
- use_kv_cache=True, - use_q_gate=True, - linear_conv_kernel_dim=4, - linear_key_head_dim=4, - linear_value_head_dim=4, - linear_num_key_heads=2, - linear_num_value_heads=4, - ) - rope = Rope(args) - attn_custom = ATTENTION_REGISTRY["gated_deltanet"](args, 0, rope) - attn_fallback = ATTENTION_REGISTRY["gated_deltanet"](args, 0, rope) - attn_fallback.load_state_dict(attn_custom.state_dict()) - - query = torch.randn(1, 3, attn_custom.num_v_heads, attn_custom.head_k_dim) - key = torch.randn(1, 3, attn_custom.num_v_heads, attn_custom.head_k_dim) - value = torch.randn(1, 3, attn_custom.num_v_heads, attn_custom.head_v_dim) - g = torch.randn(1, 3, attn_custom.num_v_heads) - beta = torch.sigmoid(torch.randn(1, 3, attn_custom.num_v_heads)) - - original_op = attention_module._RECURRENT_GATED_DELTA_RULE_OP - original_tried_loading = ( - attention_module._TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP - ) - try: - attention_module._RECURRENT_GATED_DELTA_RULE_OP = recurrent_op - attention_module._TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP = True - custom_output = attn_custom._recurrent_gated_delta_rule( - query, key, value, g, beta - ) - - attention_module._RECURRENT_GATED_DELTA_RULE_OP = None - attention_module._TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP = True - fallback_output = attn_fallback._recurrent_gated_delta_rule( - query, key, value, g, beta - ) - finally: - attention_module._RECURRENT_GATED_DELTA_RULE_OP = original_op - attention_module._TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP = ( - original_tried_loading - ) - - self.assertTrue(torch.allclose(custom_output, fallback_output, atol=1e-5)) - self.assertTrue( - torch.allclose( - attn_custom.recurrent_state, attn_fallback.recurrent_state, atol=1e-5 - ) - ) - if __name__ == "__main__": unittest.main() diff --git a/extension/aten_util/make_aten_functor_from_et_functor.h b/extension/aten_util/make_aten_functor_from_et_functor.h index 67e7344330e..8e1c2bf0143 100644 --- a/extension/aten_util/make_aten_functor_from_et_functor.h +++ b/extension/aten_util/make_aten_functor_from_et_functor.h @@ -15,8 +15,7 @@ #pragma once #include #include -#if (defined(_MSC_VER) && (!defined(_MSVC_LANG) || _MSVC_LANG < 201703L)) || \ - (!defined(_MSC_VER) && __cplusplus < 201703L) +#if __cplusplus < 201703L #error "This header requires C++17" #endif #include diff --git a/extension/llm/custom_ops/custom_ops.py b/extension/llm/custom_ops/custom_ops.py index e0b009d7a13..9aacded4b4c 100644 --- a/extension/llm/custom_ops/custom_ops.py +++ b/extension/llm/custom_ops/custom_ops.py @@ -11,9 +11,7 @@ # pyre-unsafe import logging -import os -from pathlib import Path from typing import Tuple import torch @@ -23,84 +21,33 @@ from torch.library import impl aten = torch.ops.aten -_CUSTOM_OPS_DLL_DIR_HANDLES = [] +try: + op = torch.ops.llama.sdpa_with_kv_cache.default + assert op is not None + op2 = torch.ops.llama.fast_hadamard_transform.default + assert op2 is not None +except: + # This is needed to ensure that custom ops are registered + from executorch.extension.pybindings import portable_lib # noqa # usort: skip -def _is_custom_ops_registered() -> bool: - try: - torch.ops.llama.sdpa_with_kv_cache.default - torch.ops.llama.fast_hadamard_transform.default - return True - except (AttributeError, RuntimeError): - return False - - -def _get_custom_ops_library_override() -> Path | None: - override = os.environ.get("EXECUTORCH_CUSTOM_OPS_AOT_LIB") - if override is None: - return None - - lib_path = Path(override).expanduser().resolve() - if not lib_path.is_file(): - raise 
FileNotFoundError(
-            "EXECUTORCH_CUSTOM_OPS_AOT_LIB must point to an existing "
-            f"custom_ops_aot_lib, but got {lib_path}"
-        )
-    return lib_path
-
-
-def _find_custom_ops_library() -> Path:
-    override = _get_custom_ops_library_override()
-    if override is not None:
-        return override
+    # Ideally package is installed in only one location but usage of
+    # PYTHONPATH can result in multiple locations.
+    # ATM this is mainly used in CI for qnn runner. Will need to revisit this
+    from pathlib import Path
 
     package_path = Path(__file__).parent.resolve()
-    candidates = []
-    patterns = (
-        "**/custom_ops_aot_lib.dll",
-        "**/libcustom_ops_aot_lib.so",
-        "**/libcustom_ops_aot_lib.dylib",
-    )
-
-    for pattern in patterns:
-        candidates.extend(package_path.glob(pattern))
-
-    libs = sorted({path.resolve() for path in candidates if path.is_file()})
-    if not libs:
-        raise FileNotFoundError(
-            f"Could not find custom_ops_aot_lib under {package_path}"
-        )
-    return max(libs, key=lambda path: path.stat().st_mtime)
-
-
-def _load_custom_ops_library() -> None:
-    try:
-        # This is needed to ensure that custom ops are registered when
-        # portable_lib is available in the current environment.
-        from executorch.extension.pybindings import portable_lib  # noqa # usort: skip
-    except ImportError:
-        portable_lib = None
-
-    lib_path = _find_custom_ops_library()
-    logging.info(f"Loading custom ops library: {lib_path}")
-
-    if os.name == "nt":
-        _CUSTOM_OPS_DLL_DIR_HANDLES.append(os.add_dll_directory(str(lib_path.parent)))
-        torch_lib_dir = Path(torch.__file__).resolve().parent / "lib"
-        if torch_lib_dir.is_dir():
-            _CUSTOM_OPS_DLL_DIR_HANDLES.append(os.add_dll_directory(str(torch_lib_dir)))
+    logging.info(f"Looking for libcustom_ops_aot_lib.so in {package_path}")
 
-    torch.ops.load_library(lib_path)
+    libs = list(package_path.glob("**/*custom_ops_aot_lib.*"))
 
-    # Keep the import alive to avoid lint complaints in environments where
-    # portable_lib is needed for symbol resolution.
-    _ = portable_lib
-
-
-if not _is_custom_ops_registered():
-    _load_custom_ops_library()
-    if not _is_custom_ops_registered():
-        raise RuntimeError("Failed to register ExecuTorch custom ops library")
+    assert len(libs) == 1, f"Expected 1 library but got {len(libs)}"
+    logging.info(f"Loading custom ops library: {libs[0]}")
+    torch.ops.load_library(libs[0])
 
+    op = torch.ops.llama.sdpa_with_kv_cache.default
+    assert op is not None
+    op2 = torch.ops.llama.fast_hadamard_transform.default
+    assert op2 is not None
 
 custom_ops_lib = torch.library.Library("llama", "IMPL")
@@ -324,87 +271,6 @@ def update_cache_with_indices_meta(
     return torch.empty((1,), dtype=value.dtype, device="meta")
 
 
-def _validate_recurrent_gated_delta_rule_params(
-    query,
-    key,
-    value,
-    g,
-    beta,
-    recurrent_state,
-):
-    assert (
-        query.dim() == 4
-    ), f"Expected query to be 4 dimensional but got {query.dim()} dimensions."
-    assert (
-        key.dim() == 4
-    ), f"Expected key to be 4 dimensional but got {key.dim()} dimensions."
-    assert (
-        value.dim() == 4
-    ), f"Expected value to be 4 dimensional but got {value.dim()} dimensions."
-    assert g.dim() == 3, f"Expected g to be 3 dimensional but got {g.dim()} dimensions."
-    assert (
-        beta.dim() == 3
-    ), f"Expected beta to be 3 dimensional but got {beta.dim()} dimensions."
-    assert (
-        recurrent_state.dim() == 4
-    ), f"Expected recurrent_state to be 4 dimensional but got {recurrent_state.dim()} dimensions."
-
-    for name, tensor in {
-        "query": query,
-        "key": key,
-        "value": value,
-        "g": g,
-        "beta": beta,
-        "recurrent_state": recurrent_state,
-    }.items():
-        assert (
-            tensor.dtype == torch.float32
-        ), f"Expected {name} to be float32 but got {tensor.dtype}"
-
-    assert (
-        query.shape == key.shape
-    ), f"Expected query and key to have matching shapes but got {query.shape} and {key.shape}"
-    assert (
-        query.shape[:3] == value.shape[:3]
-    ), f"Expected query and value to match in batch/head/sequence dims but got {query.shape} and {value.shape}"
-    assert (
-        g.shape == query.shape[:3]
-    ), f"Expected g to match query batch/head/sequence dims but got {g.shape} and {query.shape}"
-    assert (
-        beta.shape == query.shape[:3]
-    ), f"Expected beta to match query batch/head/sequence dims but got {beta.shape} and {query.shape}"
-    assert recurrent_state.shape == (
-        query.size(0),
-        query.size(1),
-        query.size(3),
-        value.size(3),
-    ), (
-        "Expected recurrent_state to have shape "
-        f"{(query.size(0), query.size(1), query.size(3), value.size(3))} "
-        f"but got {recurrent_state.shape}"
-    )
-
-
-@impl(custom_ops_lib, "recurrent_gated_delta_rule", "Meta")
-def recurrent_gated_delta_rule_meta(
-    query,
-    key,
-    value,
-    g,
-    beta,
-    recurrent_state,
-):
-    _validate_recurrent_gated_delta_rule_params(
-        query,
-        key,
-        value,
-        g,
-        beta,
-        recurrent_state,
-    )
-    return torch.empty_like(value)
-
-
 def _validate_quantized_sdpa_params(
     query,
     key,
diff --git a/extension/llm/custom_ops/op_fast_hadamard_transform_aten.cpp b/extension/llm/custom_ops/op_fast_hadamard_transform_aten.cpp
index d48f593868c..146ac3cc298 100644
--- a/extension/llm/custom_ops/op_fast_hadamard_transform_aten.cpp
+++ b/extension/llm/custom_ops/op_fast_hadamard_transform_aten.cpp
@@ -13,40 +13,14 @@
 namespace torch::executor::native {
 namespace {
-template <typename EType, typename AType>
-auto to_et_arg(AType&& value) {
-  return executorch::extension::internal::type_convert<AType, EType>(
-      std::forward<AType>(value));
-}
-
-at::Tensor& copy_et_result_to_out(Tensor& et_result, at::Tensor& out) {
-  auto converted_result =
-      executorch::extension::internal::type_convert<Tensor, at::Tensor>(
-          et_result)
-          .call();
-  at::native::resize_output(out, converted_result.sizes());
-  out.copy_(converted_result);
-  return out;
-}
-
 Tensor& fast_hadamard_transform_out_no_context(const Tensor& vec, Tensor& out) {
   executorch::aten::RuntimeContext context;
   return fast_hadamard_transform_out(context, vec, out);
 }
-
-at::Tensor& fast_hadamard_transform_out_aten(
-    const at::Tensor& vec,
-    at::Tensor& out) {
-  auto vec_et = to_et_arg<Tensor>(vec);
-  auto out_et = to_et_arg<Tensor>(out);
-  auto& et_result =
-      fast_hadamard_transform_out_no_context(vec_et.call(), out_et.call());
-  return copy_et_result_to_out(et_result, out);
-}
-
 at::Tensor fast_hadamard_transform_aten(const at::Tensor& vec) {
   auto out = at::empty_like(vec);
-  fast_hadamard_transform_out_aten(vec, out);
+  WRAP_TO_ATEN(fast_hadamard_transform_out_no_context, 1)
+  (vec, out);
   return out;
 }
 } // namespace
@@ -64,5 +38,6 @@ TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) {
       torch::executor::native::fast_hadamard_transform_aten);
   m.impl(
       "fast_hadamard_transform.out",
-      torch::executor::native::fast_hadamard_transform_out_aten);
+      WRAP_TO_ATEN(
+          torch::executor::native::fast_hadamard_transform_out_no_context, 1));
 }
diff --git a/extension/llm/custom_ops/op_sdpa.cpp b/extension/llm/custom_ops/op_sdpa.cpp
index 76ee9cb915f..72bddce7b5b 100644
--- a/extension/llm/custom_ops/op_sdpa.cpp
+++ b/extension/llm/custom_ops/op_sdpa.cpp
@@ -15,10 +15,6 @@
 #include  // @lint-ignore CLANGTIDY 
facebook-unused-include-check #include -#include -#include -#include -#include #ifdef ET_USE_THREADPOOL #include @@ -182,68 +178,6 @@ bool validate_cache_params( return true; } -bool validate_recurrent_gated_delta_rule_args( - const Tensor& query, - const Tensor& key, - const Tensor& value, - const Tensor& g, - const Tensor& beta, - const Tensor& recurrent_state) { - ET_CHECK_OR_RETURN_FALSE(query.dim() == 4, "query must be a 4D tensor"); - ET_CHECK_OR_RETURN_FALSE(key.dim() == 4, "key must be a 4D tensor"); - ET_CHECK_OR_RETURN_FALSE(value.dim() == 4, "value must be a 4D tensor"); - ET_CHECK_OR_RETURN_FALSE(g.dim() == 3, "g must be a 3D tensor"); - ET_CHECK_OR_RETURN_FALSE(beta.dim() == 3, "beta must be a 3D tensor"); - ET_CHECK_OR_RETURN_FALSE( - recurrent_state.dim() == 4, "recurrent_state must be a 4D tensor"); - - ET_CHECK_OR_RETURN_FALSE( - query.scalar_type() == ScalarType::Float, "query must be float32"); - ET_CHECK_OR_RETURN_FALSE( - key.scalar_type() == ScalarType::Float, "key must be float32"); - ET_CHECK_OR_RETURN_FALSE( - value.scalar_type() == ScalarType::Float, "value must be float32"); - ET_CHECK_OR_RETURN_FALSE( - g.scalar_type() == ScalarType::Float, "g must be float32"); - ET_CHECK_OR_RETURN_FALSE( - beta.scalar_type() == ScalarType::Float, "beta must be float32"); - ET_CHECK_OR_RETURN_FALSE( - recurrent_state.scalar_type() == ScalarType::Float, - "recurrent_state must be float32"); - - ET_CHECK_OR_RETURN_FALSE( - query.size(0) == key.size(0) && query.size(1) == key.size(1) && - query.size(2) == key.size(2) && query.size(3) == key.size(3), - "query and key must have matching shapes"); - ET_CHECK_OR_RETURN_FALSE( - query.size(0) == value.size(0) && query.size(1) == value.size(1) && - query.size(2) == value.size(2), - "query and value must match in batch/head/sequence dims"); - ET_CHECK_OR_RETURN_FALSE( - g.size(0) == query.size(0) && g.size(1) == query.size(1) && - g.size(2) == query.size(2), - "g must match query batch/head/sequence dims"); - ET_CHECK_OR_RETURN_FALSE( - beta.size(0) == query.size(0) && beta.size(1) == query.size(1) && - beta.size(2) == query.size(2), - "beta must match query batch/head/sequence dims"); - ET_CHECK_OR_RETURN_FALSE( - recurrent_state.size(0) == query.size(0) && - recurrent_state.size(1) == query.size(1) && - recurrent_state.size(2) == query.size(3) && - recurrent_state.size(3) == value.size(3), - "recurrent_state shape must match [B, H, K, V]"); - - for (const Tensor* tensor : - {&query, &key, &value, &g, &beta, &recurrent_state}) { - ET_CHECK_OR_RETURN_FALSE( - is_contiguous_dim_order((*tensor).dim_order().data(), (*tensor).dim()), - "recurrent gated delta rule expects contiguous inputs"); - } - - return true; -} - // TODO: seq_length is not yet used for copy void update_cache( const Tensor& projected_value, @@ -676,133 +610,6 @@ Tensor& sdpa_with_kv_cache_out( return output; } - -Tensor& recurrent_gated_delta_rule_out( - RuntimeContext& ctx, - const Tensor& query, - const Tensor& key, - const Tensor& value, - const Tensor& g, - const Tensor& beta, - Tensor& recurrent_state, - Tensor& output) { - ET_KERNEL_CHECK_MSG( - ctx, - resize_tensor(output, value.sizes()) == Error::Ok, - InvalidArgument, - output, - "Failed to resize recurrent_gated_delta_rule output tensor."); - ET_KERNEL_CHECK( - ctx, - validate_recurrent_gated_delta_rule_args( - query, key, value, g, beta, recurrent_state), - InvalidArgument, - output); - ET_KERNEL_CHECK( - ctx, output.scalar_type() == ScalarType::Float, InvalidArgument, output); - ET_KERNEL_CHECK( - ctx, - 
is_contiguous_dim_order(output.dim_order().data(), output.dim()),
-      InvalidArgument,
-      output);
-
-  const auto batch_size = query.size(0);
-  const auto num_heads = query.size(1);
-  const auto sequence_length = query.size(2);
-  const auto k_head_dim = query.size(3);
-  const auto v_head_dim = value.size(3);
-
-  const auto q_batch_stride = num_heads * sequence_length * k_head_dim;
-  const auto q_head_stride = sequence_length * k_head_dim;
-  const auto q_seq_stride = k_head_dim;
-
-  const auto value_batch_stride = num_heads * sequence_length * v_head_dim;
-  const auto value_head_stride = sequence_length * v_head_dim;
-  const auto value_seq_stride = v_head_dim;
-
-  const auto gv_batch_stride = num_heads * sequence_length;
-  const auto gv_head_stride = sequence_length;
-
-  const auto state_batch_stride = num_heads * k_head_dim * v_head_dim;
-  const auto state_head_stride = k_head_dim * v_head_dim;
-
-  const auto* query_data = query.const_data_ptr<float>();
-  const auto* key_data = key.const_data_ptr<float>();
-  const auto* value_data = value.const_data_ptr<float>();
-  const auto* g_data = g.const_data_ptr<float>();
-  const auto* beta_data = beta.const_data_ptr<float>();
-  auto* recurrent_state_data = recurrent_state.mutable_data_ptr<float>();
-  auto* output_data = output.mutable_data_ptr<float>();
-  std::vector<float> kv_mem(v_head_dim);
-  std::vector<float> delta(v_head_dim);
-
-  for (int64_t batch = 0; batch < batch_size; ++batch) {
-    for (int64_t head = 0; head < num_heads; ++head) {
-      const auto q_offset = batch * q_batch_stride + head * q_head_stride;
-      const auto value_offset =
-          batch * value_batch_stride + head * value_head_stride;
-      const auto gv_offset = batch * gv_batch_stride + head * gv_head_stride;
-      const auto state_offset =
-          batch * state_batch_stride + head * state_head_stride;
-
-      const auto* q_head = query_data + q_offset;
-      const auto* k_head = key_data + q_offset;
-      const auto* value_head = value_data + value_offset;
-      const auto* g_head = g_data + gv_offset;
-      const auto* beta_head = beta_data + gv_offset;
-      auto* state_head = recurrent_state_data + state_offset;
-      auto* output_head = output_data + value_offset;
-
-      for (int64_t token = 0; token < sequence_length; ++token) {
-        const auto* q_t = q_head + token * q_seq_stride;
-        const auto* k_t = k_head + token * q_seq_stride;
-        const auto* v_t = value_head + token * value_seq_stride;
-        auto* output_t = output_head + token * value_seq_stride;
-
-        const float g_t = std::exp(g_head[token]);
-        const float beta_t = beta_head[token];
-
-        if (g_t != 1.0f) {
-          for (int64_t idx = 0; idx < state_head_stride; ++idx) {
-            state_head[idx] *= g_t;
-          }
-        }
-
-        std::fill(kv_mem.begin(), kv_mem.end(), 0.0f);
-        for (int64_t k_idx = 0; k_idx < k_head_dim; ++k_idx) {
-          const float key_value = k_t[k_idx];
-          const auto* state_row = state_head + k_idx * v_head_dim;
-          for (int64_t v_idx = 0; v_idx < v_head_dim; ++v_idx) {
-            kv_mem[v_idx] += state_row[v_idx] * key_value;
-          }
-        }
-
-        for (int64_t v_idx = 0; v_idx < v_head_dim; ++v_idx) {
-          delta[v_idx] = (v_t[v_idx] - kv_mem[v_idx]) * beta_t;
-        }
-
-        for (int64_t k_idx = 0; k_idx < k_head_dim; ++k_idx) {
-          const float key_value = k_t[k_idx];
-          auto* state_row = state_head + k_idx * v_head_dim;
-          for (int64_t v_idx = 0; v_idx < v_head_dim; ++v_idx) {
-            state_row[v_idx] += key_value * delta[v_idx];
-          }
-        }
-
-        std::fill(output_t, output_t + v_head_dim, 0.0f);
-        for (int64_t k_idx = 0; k_idx < k_head_dim; ++k_idx) {
-          const float query_value = q_t[k_idx];
-          const auto* state_row = state_head + k_idx * v_head_dim;
-          for (int64_t v_idx = 0; v_idx < v_head_dim; ++v_idx) {
-            output_t[v_idx] += state_row[v_idx] * query_value;
-          }
-        }
-      }
-    }
-  }
-
-  return output;
-}
 } // namespace native
 } // namespace executor
 } // namespace torch
@@ -821,36 +628,3 @@
 EXECUTORCH_LIBRARY(
     llama,
     "custom_quantized_sdpa.out",
     torch::executor::native::custom_quantized_sdpa_out);
-
-namespace {
-
-void recurrent_gated_delta_rule_out_boxed(
-    executorch::runtime::KernelRuntimeContext& ctx,
-    executorch::runtime::Span<executorch::runtime::EValue*> stack) {
-  ET_KERNEL_CHECK_MSG(
-      ctx,
-      stack.size() == 7,
-      InvalidProgram,
-      /* void */,
-      "Expected %zu args, got %zu",
-      static_cast<size_t>(7),
-      stack.size());
-
-  auto& query = stack[0]->toTensor();
-  auto& key = stack[1]->toTensor();
-  auto& value = stack[2]->toTensor();
-  auto& g = stack[3]->toTensor();
-  auto& beta = stack[4]->toTensor();
-  auto& recurrent_state = stack[5]->toTensor();
-  auto& output = stack[6]->toTensor();
-
-  (void)torch::executor::native::recurrent_gated_delta_rule_out(
-      ctx, query, key, value, g, beta, recurrent_state, output);
-}
-
-const auto recurrent_gated_delta_rule_out_registration =
-    executorch::runtime::register_kernel(executorch::runtime::Kernel(
-        "llama::recurrent_gated_delta_rule.out",
-        recurrent_gated_delta_rule_out_boxed));
-
-} // namespace
diff --git a/extension/llm/custom_ops/op_sdpa.h b/extension/llm/custom_ops/op_sdpa.h
index 9f029f52f31..9d357eb6ea1 100644
--- a/extension/llm/custom_ops/op_sdpa.h
+++ b/extension/llm/custom_ops/op_sdpa.h
@@ -75,16 +75,6 @@ Tensor& custom_quantized_sdpa_out(
     const optional<Tensor>& v_scales,
     const bool is_seq_at_dim_1,
     Tensor& output);
-
-Tensor& recurrent_gated_delta_rule_out(
-    RuntimeContext& ctx,
-    const Tensor& query,
-    const Tensor& key,
-    const Tensor& value,
-    const Tensor& g,
-    const Tensor& beta,
-    Tensor& recurrent_state,
-    Tensor& output);
 } // namespace native
 } // namespace executor
 } // namespace torch
diff --git a/extension/llm/custom_ops/op_sdpa_aot.cpp b/extension/llm/custom_ops/op_sdpa_aot.cpp
index d4d1122f614..5bbf22d336e 100644
--- a/extension/llm/custom_ops/op_sdpa_aot.cpp
+++ b/extension/llm/custom_ops/op_sdpa_aot.cpp
@@ -17,24 +17,6 @@
 namespace torch {
 namespace executor {
 namespace native {
-namespace {
-template <typename EType, typename AType>
-auto to_et_arg(AType&& value) {
-  return executorch::extension::internal::type_convert<AType, EType>(
-      std::forward<AType>(value));
-}
-
-at::Tensor& copy_et_result_to_out(Tensor& et_result, at::Tensor& out) {
-  auto converted_result =
-      executorch::extension::internal::type_convert<Tensor, at::Tensor>(
-          et_result)
-          .call();
-  at::native::resize_output(out, converted_result.sizes());
-  out.copy_(converted_result);
-  return out;
-}
-} // namespace
-
 Tensor& sdpa_with_kv_cache_out_no_context(
     const Tensor& q_projected,
     const Tensor& k_projected,
     const Tensor& v_projected,
@@ -68,20 +50,6 @@ at::Tensor sdpa_with_kv_cache_aten(
     // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
     const std::optional<double> scale);
-at::Tensor& sdpa_with_kv_cache_out_aten(
-    const at::Tensor& q_projected,
-    const at::Tensor& k_projected,
-    const at::Tensor& v_projected,
-    at::Tensor& key_cache,
-    at::Tensor& value_cache,
-    const int64_t start_pos,
-    const int64_t seq_len,
-    const std::optional<at::Tensor> attn_mask,
-    const double dropout_p,
-    const bool is_causal,
-    const std::optional<double> scale,
-    at::Tensor& output);
-
 Tensor& custom_sdpa_out_no_context(
     const Tensor& q,
     const Tensor& k,
@@ -109,17 +77,6 @@ at::Tensor custom_sdpa_aten(
     // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
     const std::optional<double> scale);
-at::Tensor& custom_sdpa_out_aten(
-    const at::Tensor& q,
-    const at::Tensor& k,
-    const at::Tensor& v,
-    const int64_t start_pos,
-    const std::optional<at::Tensor> attn_mask,
-    const double dropout_p,
-    const bool is_causal,
-    const std::optional<double> scale,
-    at::Tensor& output);
-
 Tensor& custom_quantized_sdpa_out_no_context(
     const Tensor& q,
     const Tensor& k,
     const Tensor& v,
@@ -161,24 +118,6 @@ at::Tensor custom_quantized_sdpa_aten(
     const std::optional<at::Tensor>& v_scales,
     const bool is_seq_at_dim_2);
-at::Tensor& custom_quantized_sdpa_out_aten(
-    const at::Tensor& q,
-    const at::Tensor& k,
-    const at::Tensor& v,
-    const int64_t start_pos,
-    const std::optional<at::Tensor> attn_mask,
-    const double dropout_p,
-    const bool is_causal,
-    const std::optional<double> scale,
-    const std::optional<at::Tensor>& q_zero_points,
-    const std::optional<at::Tensor>& q_scales,
-    const std::optional<at::Tensor>& k_zero_points,
-    const std::optional<at::Tensor>& k_scales,
-    const std::optional<at::Tensor>& v_zero_points,
-    const std::optional<at::Tensor>& v_scales,
-    const bool is_seq_at_dim_2,
-    at::Tensor& output);
-
 Tensor& update_cache_out_no_context(
     const Tensor& value,
     Tensor& cache,
@@ -190,12 +129,6 @@ at::Tensor update_cache_aten(
     at::Tensor& cache,
     const int64_t start_pos);
-at::Tensor& update_cache_out_aten(
-    const at::Tensor& value,
-    at::Tensor& cache,
-    const int64_t start_pos,
-    at::Tensor& output);
-
 // New functions for update_cache_with_indices
 Tensor& update_cache_with_indices_out_no_context(
     const Tensor& value,
@@ -210,39 +143,6 @@ at::Tensor update_cache_with_indices_aten(
     const int64_t start_pos,
     const at::Tensor& indices);
-at::Tensor& update_cache_with_indices_out_aten(
-    const at::Tensor& value,
-    at::Tensor& cache,
-    const int64_t start_pos,
-    const at::Tensor& indices,
-    at::Tensor& output);
-
-Tensor& recurrent_gated_delta_rule_out_no_context(
-    const Tensor& query,
-    const Tensor& key,
-    const Tensor& value,
-    const Tensor& g,
-    const Tensor& beta,
-    Tensor& recurrent_state,
-    Tensor& output);
-
-at::Tensor recurrent_gated_delta_rule_aten(
-    const at::Tensor& query,
-    const at::Tensor& key,
-    const at::Tensor& value,
-    const at::Tensor& g,
-    const at::Tensor& beta,
-    at::Tensor& recurrent_state);
-
-at::Tensor& recurrent_gated_delta_rule_out_aten(
-    const at::Tensor& query,
-    const at::Tensor& key,
-    const at::Tensor& value,
-    const at::Tensor& g,
-    const at::Tensor& beta,
-    at::Tensor& recurrent_state,
-    at::Tensor& output);
-
 Tensor& sdpa_with_kv_cache_out_no_context(
     const Tensor& q_projected,
     const Tensor& k_projected,
     const Tensor& v_projected,
@@ -292,59 +192,22 @@ at::Tensor sdpa_with_kv_cache_aten(
     // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
     const std::optional<double> scale) {
   auto output = at::empty_like(q_projected);
-  sdpa_with_kv_cache_out_aten(
-      q_projected,
-      k_projected,
-      v_projected,
-      key_cache,
-      value_cache,
-      start_pos,
-      seq_len,
-      attn_mask,
-      dropout_p,
-      is_causal,
-      scale,
-      output);
+  WRAP_TO_ATEN(sdpa_with_kv_cache_out_no_context, 11)
+  (q_projected,
+   k_projected,
+   v_projected,
+   key_cache,
+   value_cache,
+   start_pos,
+   seq_len,
+   attn_mask,
+   dropout_p,
+   is_causal,
+   scale,
+   output);
   return output;
 }
 
-at::Tensor& sdpa_with_kv_cache_out_aten(
-    const at::Tensor& q_projected,
-    const at::Tensor& k_projected,
-    const at::Tensor& v_projected,
-    at::Tensor& key_cache,
-    at::Tensor& value_cache,
-    const int64_t start_pos,
-    const int64_t seq_len,
-    const std::optional<at::Tensor> attn_mask,
-    const double dropout_p,
-    const bool is_causal,
-    const std::optional<double> scale,
-    at::Tensor& output) {
-  auto q_et = to_et_arg<Tensor>(q_projected);
-  auto k_et = to_et_arg<Tensor>(k_projected);
-  auto v_et = to_et_arg<Tensor>(v_projected);
-  auto key_cache_et = to_et_arg<Tensor>(key_cache);
-  auto value_cache_et = to_et_arg<Tensor>(value_cache);
-  auto attn_mask_et = to_et_arg<std::optional<Tensor>>(attn_mask);
-  auto scale_et = to_et_arg<std::optional<double>>(scale);
-  auto output_et = to_et_arg<Tensor>(output);
-  auto& et_result = sdpa_with_kv_cache_out_no_context(
-      q_et.call(),
-      k_et.call(),
-      v_et.call(),
-      key_cache_et.call(),
-      value_cache_et.call(),
-      start_pos,
-      seq_len,
-      attn_mask_et.call(),
-      dropout_p,
-      is_causal,
-      scale_et.call(),
-      output_et.call());
-  return copy_et_result_to_out(et_result, output);
-}
-
 Tensor& custom_sdpa_out_no_context(
     const Tensor& q,
     const Tensor& k,
@@ -385,40 +248,11 @@ at::Tensor custom_sdpa_aten(
     // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
     const std::optional<double> scale) {
   auto output = at::empty(q.sizes());
-  custom_sdpa_out_aten(
-      q, k, v, start_pos, attn_mask, dropout_p, is_causal, scale, output);
+  WRAP_TO_ATEN(custom_sdpa_out_no_context, 8)
+  (q, k, v, start_pos, attn_mask, dropout_p, is_causal, scale, output);
   return output;
 }
 
-at::Tensor& custom_sdpa_out_aten(
-    const at::Tensor& q,
-    const at::Tensor& k,
-    const at::Tensor& v,
-    const int64_t start_pos,
-    const std::optional<at::Tensor> attn_mask,
-    const double dropout_p,
-    const bool is_causal,
-    const std::optional<double> scale,
-    at::Tensor& output) {
-  auto q_et = to_et_arg<Tensor>(q);
-  auto k_et = to_et_arg<Tensor>(k);
-  auto v_et = to_et_arg<Tensor>(v);
-  auto attn_mask_et = to_et_arg<std::optional<Tensor>>(attn_mask);
-  auto scale_et = to_et_arg<std::optional<double>>(scale);
-  auto output_et = to_et_arg<Tensor>(output);
-  auto& et_result = custom_sdpa_out_no_context(
-      q_et.call(),
-      k_et.call(),
-      v_et.call(),
-      start_pos,
-      attn_mask_et.call(),
-      dropout_p,
-      is_causal,
-      scale_et.call(),
-      output_et.call());
-  return copy_et_result_to_out(et_result, output);
-}
-
 Tensor& custom_quantized_sdpa_out_no_context(
     const Tensor& q,
     const Tensor& k,
@@ -480,75 +314,26 @@ at::Tensor custom_quantized_sdpa_aten(
     const std::optional<at::Tensor>& v_scales,
     const bool is_seq_at_dim_2) {
   auto output = at::empty(q.sizes());
-  custom_quantized_sdpa_out_aten(
-      q,
-      k,
-      v,
-      start_pos,
-      attn_mask,
-      dropout_p,
-      is_causal,
-      scale,
-      q_zero_points,
-      q_scales,
-      k_zero_points,
-      k_scales,
-      v_zero_points,
-      v_scales,
-      is_seq_at_dim_2,
-      output);
+  WRAP_TO_ATEN(custom_quantized_sdpa_out_no_context, 15)
+  (q,
+   k,
+   v,
+   start_pos,
+   attn_mask,
+   dropout_p,
+   is_causal,
+   scale,
+   q_zero_points,
+   q_scales,
+   k_zero_points,
+   k_scales,
+   v_zero_points,
+   v_scales,
+   is_seq_at_dim_2,
+   output);
   return output;
 }
 
-at::Tensor& custom_quantized_sdpa_out_aten(
-    const at::Tensor& q,
-    const at::Tensor& k,
-    const at::Tensor& v,
-    const int64_t start_pos,
-    const std::optional<at::Tensor> attn_mask,
-    const double dropout_p,
-    const bool is_causal,
-    const std::optional<double> scale,
-    const std::optional<at::Tensor>& q_zero_points,
-    const std::optional<at::Tensor>& q_scales,
-    const std::optional<at::Tensor>& k_zero_points,
-    const std::optional<at::Tensor>& k_scales,
-    const std::optional<at::Tensor>& v_zero_points,
-    const std::optional<at::Tensor>& v_scales,
-    const bool is_seq_at_dim_2,
-    at::Tensor& output) {
-  auto q_et = to_et_arg<Tensor>(q);
-  auto k_et = to_et_arg<Tensor>(k);
-  auto v_et = to_et_arg<Tensor>(v);
-  auto attn_mask_et = to_et_arg<std::optional<Tensor>>(attn_mask);
-  auto scale_et = to_et_arg<std::optional<double>>(scale);
-  auto q_zero_points_et = to_et_arg<std::optional<Tensor>>(q_zero_points);
-  auto q_scales_et = to_et_arg<std::optional<Tensor>>(q_scales);
-  auto k_zero_points_et = to_et_arg<std::optional<Tensor>>(k_zero_points);
-  auto k_scales_et = to_et_arg<std::optional<Tensor>>(k_scales);
-  auto v_zero_points_et = to_et_arg<std::optional<Tensor>>(v_zero_points);
-  auto v_scales_et = to_et_arg<std::optional<Tensor>>(v_scales);
-  auto output_et = to_et_arg<Tensor>(output);
-  auto& et_result = custom_quantized_sdpa_out_no_context(
-      q_et.call(),
-      k_et.call(),
-      v_et.call(),
-      start_pos,
-      attn_mask_et.call(),
-      dropout_p,
-      is_causal,
-      scale_et.call(),
-      q_zero_points_et.call(),
-      q_scales_et.call(),
-      k_zero_points_et.call(),
-      k_scales_et.call(),
-      v_zero_points_et.call(),
-      v_scales_et.call(),
-      is_seq_at_dim_2,
-      output_et.call());
-  return copy_et_result_to_out(et_result, output);
-}
-
 Tensor& update_cache_out_no_context(
     const Tensor& value,
     Tensor& cache,
@@ -564,23 +349,11 @@ at::Tensor update_cache_aten(
     at::Tensor& cache,
     const int64_t start_pos) {
   auto output = at::empty({1});
-  update_cache_out_aten(value, cache, start_pos, output);
+  WRAP_TO_ATEN(update_cache_out_no_context, 3)
+  (value, cache, start_pos, output);
   return output;
 }
 
-at::Tensor& update_cache_out_aten(
-    const at::Tensor& value,
-    at::Tensor& cache,
-    const int64_t start_pos,
-    at::Tensor& output) {
-  auto value_et = to_et_arg<Tensor>(value);
-  auto cache_et = to_et_arg<Tensor>(cache);
-  auto output_et = to_et_arg<Tensor>(output);
-  auto& et_result = update_cache_out_no_context(
-      value_et.call(), cache_et.call(), start_pos, output_et.call());
-  return copy_et_result_to_out(et_result, output);
-}
-
 // Implementations for update_cache_with_indices
 Tensor& update_cache_with_indices_out_no_context(
     const Tensor& value,
@@ -599,81 +372,11 @@ at::Tensor update_cache_with_indices_aten(
     const int64_t start_pos,
     const at::Tensor& indices) {
   auto output = at::empty({1});
-  update_cache_with_indices_out_aten(value, cache, start_pos, indices, output);
+  WRAP_TO_ATEN(update_cache_with_indices_out_no_context, 4)
+  (value, cache, start_pos, indices, output);
   return output;
 }
 
-at::Tensor& update_cache_with_indices_out_aten(
-    const at::Tensor& value,
-    at::Tensor& cache,
-    const int64_t start_pos,
-    const at::Tensor& indices,
-    at::Tensor& output) {
-  auto value_et = to_et_arg<Tensor>(value);
-  auto cache_et = to_et_arg<Tensor>(cache);
-  auto indices_et = to_et_arg<Tensor>(indices);
-  auto output_et = to_et_arg<Tensor>(output);
-  auto& et_result = update_cache_with_indices_out_no_context(
-      value_et.call(),
-      cache_et.call(),
-      start_pos,
-      indices_et.call(),
-      output_et.call());
-  return copy_et_result_to_out(et_result, output);
-}
-
-Tensor& recurrent_gated_delta_rule_out_no_context(
-    const Tensor& query,
-    const Tensor& key,
-    const Tensor& value,
-    const Tensor& g,
-    const Tensor& beta,
-    Tensor& recurrent_state,
-    Tensor& output) {
-  executorch::aten::RuntimeContext context{};
-  return torch::executor::native::recurrent_gated_delta_rule_out(
-      context, query, key, value, g, beta, recurrent_state, output);
-}
-
-at::Tensor recurrent_gated_delta_rule_aten(
-    const at::Tensor& query,
-    const at::Tensor& key,
-    const at::Tensor& value,
-    const at::Tensor& g,
-    const at::Tensor& beta,
-    at::Tensor& recurrent_state) {
-  auto output = at::empty_like(value);
-  recurrent_gated_delta_rule_out_aten(
-      query, key, value, g, beta, recurrent_state, output);
-  return output;
-}
-
-at::Tensor& recurrent_gated_delta_rule_out_aten(
-    const at::Tensor& query,
-    const at::Tensor& key,
-    const at::Tensor& value,
-    const at::Tensor& g,
-    const at::Tensor& beta,
-    at::Tensor& recurrent_state,
-    at::Tensor& output) {
-  auto query_et = to_et_arg<Tensor>(query);
-  auto key_et = to_et_arg<Tensor>(key);
-  auto value_et = to_et_arg<Tensor>(value);
-  auto g_et = to_et_arg<Tensor>(g);
-  auto beta_et = to_et_arg<Tensor>(beta);
-  auto recurrent_state_et = to_et_arg<Tensor>(recurrent_state);
-  auto output_et = to_et_arg<Tensor>(output);
-  auto& et_result = recurrent_gated_delta_rule_out_no_context(
-      query_et.call(),
-      key_et.call(),
-      value_et.call(),
-      g_et.call(),
-      beta_et.call(),
-      recurrent_state_et.call(),
-      output_et.call());
-  return copy_et_result_to_out(et_result, output);
-}
-
 } // namespace native
 } // namespace executor
 } // namespace torch
@@ -707,12 +410,6 @@ TORCH_LIBRARY_FRAGMENT(llama, m) {
   m.def(
       "update_cache_with_indices.out(Tensor value, Tensor(a!) cache, "
       "SymInt start_pos, Tensor indices, *, Tensor(b!) out) -> Tensor(b!)");
-  m.def(
-      "recurrent_gated_delta_rule(Tensor query, Tensor key, Tensor value, Tensor g, "
-      "Tensor beta, Tensor(a!) recurrent_state) -> Tensor");
-  m.def(
-      "recurrent_gated_delta_rule.out(Tensor query, Tensor key, Tensor value, Tensor g, "
-      "Tensor beta, Tensor(a!) recurrent_state, *, Tensor(b!) out) -> Tensor(b!)");
   m.def(
       "custom_quantized_sdpa(Tensor query, Tensor key, Tensor value, SymInt start_pos, "
       "Tensor? attn_mask=None, float drpout_p=0.0, bool is_causal=False, "
@@ -733,27 +430,29 @@ TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) {
       "sdpa_with_kv_cache", torch::executor::native::sdpa_with_kv_cache_aten);
   m.impl(
       "sdpa_with_kv_cache.out",
-      torch::executor::native::sdpa_with_kv_cache_out_aten);
+      WRAP_TO_ATEN(
+          torch::executor::native::sdpa_with_kv_cache_out_no_context, 11));
   m.impl("custom_sdpa", torch::executor::native::custom_sdpa_aten);
-  m.impl("custom_sdpa.out", torch::executor::native::custom_sdpa_out_aten);
+  m.impl(
+      "custom_sdpa.out",
+      WRAP_TO_ATEN(torch::executor::native::custom_sdpa_out_no_context, 8));
   m.impl("update_cache", torch::executor::native::update_cache_aten);
-  m.impl("update_cache.out", torch::executor::native::update_cache_out_aten);
+  m.impl(
+      "update_cache.out",
+      WRAP_TO_ATEN(torch::executor::native::update_cache_out_no_context, 3));
   m.impl(
       "update_cache_with_indices",
       torch::executor::native::update_cache_with_indices_aten);
   m.impl(
       "update_cache_with_indices.out",
-      torch::executor::native::update_cache_with_indices_out_aten);
-  m.impl(
-      "recurrent_gated_delta_rule",
-      torch::executor::native::recurrent_gated_delta_rule_aten);
-  m.impl(
-      "recurrent_gated_delta_rule.out",
-      torch::executor::native::recurrent_gated_delta_rule_out_aten);
+      WRAP_TO_ATEN(
+          torch::executor::native::update_cache_with_indices_out_no_context,
+          4));
   m.impl(
       "custom_quantized_sdpa",
       torch::executor::native::custom_quantized_sdpa_aten);
   m.impl(
       "custom_quantized_sdpa.out",
-      torch::executor::native::custom_quantized_sdpa_out_aten);
+      WRAP_TO_ATEN(
+          torch::executor::native::custom_quantized_sdpa_out_no_context, 15));
 }
diff --git a/extension/llm/custom_ops/op_tile_crop_aot.cpp b/extension/llm/custom_ops/op_tile_crop_aot.cpp
index 7d89c462e1d..5aa98ee8d4a 100644
--- a/extension/llm/custom_ops/op_tile_crop_aot.cpp
+++ b/extension/llm/custom_ops/op_tile_crop_aot.cpp
@@ -16,30 +16,10 @@
 namespace torch {
 namespace executor {
 namespace native {
-namespace {
-template <typename EType, typename AType>
-auto to_et_arg(AType&& value) {
-  return executorch::extension::internal::type_convert<AType, EType>(
-      std::forward<AType>(value));
-}
-
-at::Tensor& copy_et_result_to_out(Tensor& et_result, at::Tensor& out) {
-  auto converted_result =
-      executorch::extension::internal::type_convert<Tensor, at::Tensor>(
-          et_result)
-          .call();
-  at::native::resize_output(out, converted_result.sizes());
-  out.copy_(converted_result);
-  return out;
-}
-} // namespace
 Tensor&
 tile_crop_out_no_context(const Tensor& input, int64_t tile_size, Tensor& out);
-at::Tensor&
-tile_crop_out_aten(const at::Tensor& input, int64_t tile_size, at::Tensor& out);
-
 Tensor&
 tile_crop_out_no_context(const Tensor& input, int64_t tile_size, Tensor& out) {
   executorch::aten::RuntimeContext context{};
@@ -48,21 +28,12 @@ tile_crop_out_no_context(const Tensor& input, int64_t tile_size, Tensor& out) {
 
 at::Tensor
 tile_crop_aten(const at::Tensor& input, int64_t tile_size);
 
-at::Tensor& tile_crop_out_aten(
-    const at::Tensor& input,
-    int64_t tile_size,
-    at::Tensor& out) {
-  auto input_et = to_et_arg<Tensor>(input);
-  auto out_et = to_et_arg<Tensor>(out);
-  auto& et_result =
-      tile_crop_out_no_context(input_et.call(), tile_size, out_et.call());
-  return copy_et_result_to_out(et_result, out);
-}
-
 at::Tensor tile_crop_aten(const at::Tensor& input, int64_t tile_size) {
   // max_num_tiles = 4, num_channels = 3.
   auto output = at::empty({4, 3, tile_size, tile_size});
-  tile_crop_out_aten(input, tile_size, output);
+
+  WRAP_TO_ATEN(torch::executor::native::tile_crop_out_no_context, 2)
+  (input, tile_size, output);
   return output;
 }
 
@@ -78,5 +49,7 @@ TORCH_LIBRARY(preprocess, m) {
 
 TORCH_LIBRARY_IMPL(preprocess, CompositeExplicitAutograd, m) {
   m.impl("tile_crop", torch::executor::native::tile_crop_aten);
-  m.impl("tile_crop.out", torch::executor::native::tile_crop_out_aten);
+  m.impl(
+      "tile_crop.out",
+      WRAP_TO_ATEN(torch::executor::native::tile_crop_out_no_context, 2));
 }
diff --git a/extension/llm/custom_ops/test_update_cache.py b/extension/llm/custom_ops/test_update_cache.py
index 7edd273d8b9..84a349c97f0 100644
--- a/extension/llm/custom_ops/test_update_cache.py
+++ b/extension/llm/custom_ops/test_update_cache.py
@@ -431,155 +431,3 @@ def test_batched_update_kv_cache_more_updates(self):
             self._update_and_validate(
                 k, v, k_scales, v_scales, k_zero_points, v_zero_points, start_pos
             )
-
-
-class RecurrentGatedDeltaRuleTest(unittest.TestCase):
-    def _make_inputs(
-        self,
-        batch_size: int = 2,
-        num_heads: int = 3,
-        seq_len: int = 4,
-        k_head_dim: int = 5,
-        v_head_dim: int = 6,
-    ):
-        query = torch.randn(batch_size, num_heads, seq_len, k_head_dim)
-        key = torch.randn(batch_size, num_heads, seq_len, k_head_dim)
-        value = torch.randn(batch_size, num_heads, seq_len, v_head_dim)
-        g = torch.randn(batch_size, num_heads, seq_len)
-        beta = torch.sigmoid(torch.randn(batch_size, num_heads, seq_len))
-        recurrent_state = torch.randn(batch_size, num_heads, k_head_dim, v_head_dim)
-        return query, key, value, g, beta, recurrent_state
-
-    def _reference_recurrent_gated_delta_rule(
-        self,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
-        g: torch.Tensor,
-        beta: torch.Tensor,
-        recurrent_state: torch.Tensor,
-    ):
-        state = recurrent_state.clone()
-        output = torch.zeros_like(value)
-
-        for token in range(query.size(2)):
-            g_t = g[:, :, token].exp().unsqueeze(-1).unsqueeze(-1)
-            beta_t = beta[:, :, token].unsqueeze(-1)
-            k_t = key[:, :, token]
-            v_t = value[:, :, token]
-            q_t = query[:, :, token]
-
-            state = state * g_t
-            kv_mem = (state * k_t.unsqueeze(-1)).sum(dim=-2)
-            delta = (v_t - kv_mem) * beta_t
-            state = state + k_t.unsqueeze(-1) * delta.unsqueeze(-2)
-            output[:, :, token] = (state * q_t.unsqueeze(-1)).sum(dim=-2)
-
-        return output, state
-
-    def test_recurrent_gated_delta_rule_matches_reference(self):
-        torch.manual_seed(0)
-
-        test_cases = (
-            (2, 3, 4, 5, 6),
-            (1, 4, 7, 8, 3),
-        )
-
-        for case in test_cases:
-            with self.subTest(case=case):
-                (
-                    query,
-                    key,
-                    value,
-                    g,
-                    beta,
-                    recurrent_state,
-                ) = self._make_inputs(*case)
-
-                expected_output, expected_state = (
-                    self._reference_recurrent_gated_delta_rule(
-                        query,
-                        key,
-                        value,
-                        g,
-                        beta,
-                        recurrent_state,
-                    )
-                )
-
-                actual_state = recurrent_state.clone()
-                actual_output = torch.ops.llama.recurrent_gated_delta_rule(
-                    query,
-                    key,
-                    value,
-                    g,
-                    beta,
-                    actual_state,
-                )
-
-                self.assertTrue(
-                    torch.allclose(actual_output, expected_output, 
atol=1e-5) - ) - self.assertTrue(torch.allclose(actual_state, expected_state, atol=1e-5)) - - def test_recurrent_gated_delta_rule_out_matches_reference(self): - torch.manual_seed(0) - - query, key, value, g, beta, recurrent_state = self._make_inputs() - expected_output, expected_state = self._reference_recurrent_gated_delta_rule( - query, - key, - value, - g, - beta, - recurrent_state, - ) - - actual_state = recurrent_state.clone() - actual_output = torch.empty_like(value) - returned_output = torch.ops.llama.recurrent_gated_delta_rule.out( - query, - key, - value, - g, - beta, - actual_state, - out=actual_output, - ) - - self.assertEqual(returned_output.data_ptr(), actual_output.data_ptr()) - self.assertTrue(torch.allclose(actual_output, expected_output, atol=1e-5)) - self.assertTrue(torch.allclose(actual_state, expected_state, atol=1e-5)) - - def test_recurrent_gated_delta_rule_chunked_matches_full_sequence(self): - torch.manual_seed(0) - - query, key, value, g, beta, recurrent_state = self._make_inputs(seq_len=6) - - full_state = recurrent_state.clone() - full_output = torch.ops.llama.recurrent_gated_delta_rule( - query, - key, - value, - g, - beta, - full_state, - ) - - chunk_state = recurrent_state.clone() - chunk_outputs = [] - for start, end in ((0, 2), (2, 5), (5, 6)): - chunk_outputs.append( - torch.ops.llama.recurrent_gated_delta_rule( - query[:, :, start:end, :], - key[:, :, start:end, :], - value[:, :, start:end, :], - g[:, :, start:end], - beta[:, :, start:end], - chunk_state, - ) - ) - - chunked_output = torch.cat(chunk_outputs, dim=2) - self.assertTrue(torch.allclose(chunked_output, full_output, atol=1e-5)) - self.assertTrue(torch.allclose(chunk_state, full_state, atol=1e-5))