From ad82d03a2832bc8d8d61fe80a1c14980297935d1 Mon Sep 17 00:00:00 2001 From: Max Ren Date: Mon, 9 Jun 2025 15:10:18 -0700 Subject: [PATCH 1/6] Update [ghstack-poisoned] --- backends/xnnpack/runtime/XNNCompiler.cpp | 8 ++++++++ backends/xnnpack/serialization/runtime_schema.fbs | 9 +++++++++ backends/xnnpack/serialization/schema.fbs | 9 +++++++++ backends/xnnpack/serialization/xnnpack_graph_schema.py | 4 ++++ 4 files changed, 30 insertions(+) diff --git a/backends/xnnpack/runtime/XNNCompiler.cpp b/backends/xnnpack/runtime/XNNCompiler.cpp index 56d0508bef0..312cbc17b95 100644 --- a/backends/xnnpack/runtime/XNNCompiler.cpp +++ b/backends/xnnpack/runtime/XNNCompiler.cpp @@ -121,6 +121,14 @@ xnn_datatype getDataType(const DataType& data_type) { return xnn_datatype::xnn_datatype_qdint8; case DataType::xnn_datatype_qbint4: return xnn_datatype::xnn_datatype_qbint4; + case DataType::xnn_datatype_qpint8: + return xnn_datatype::xnn_datatype_qpint8; + case DataType::xnn_datatype_int32: + return xnn_datatype::xnn_datatype_int32; + case DataType::xnn_datatype_pfp32: + return xnn_datatype::xnn_datatype_pfp32; + case DataType::xnn_datatype_bf16: + return xnn_datatype::xnn_datatype_bf16; default: return xnn_datatype::xnn_datatype_invalid; } diff --git a/backends/xnnpack/serialization/runtime_schema.fbs b/backends/xnnpack/serialization/runtime_schema.fbs index d76c3c0807e..a0d44327912 100644 --- a/backends/xnnpack/serialization/runtime_schema.fbs +++ b/backends/xnnpack/serialization/runtime_schema.fbs @@ -29,6 +29,15 @@ enum XNNDatatype : short { xnn_datatype_qdint8 = 9, /// Quantized 4-bit signed integer with shared blockwise quantization parameters. xnn_datatype_qbint4 = 10, + /// Dynamically quantized 8-bit signed integers packed with their per-row + /// quantization parameters. + xnn_datatype_qpint8 = 11, + /// 32-bit signed integers. + xnn_datatype_int32 = 12, + /// IEEE754 single-precision packed floating-point. + xnn_datatype_pfp32 = 13, + /// BFloat16, i.e. the upper 16 bits of a float32. + xnn_datatype_bf16 = 14, } // type of quantization diff --git a/backends/xnnpack/serialization/schema.fbs b/backends/xnnpack/serialization/schema.fbs index 356df663dfc..eeab28154cc 100644 --- a/backends/xnnpack/serialization/schema.fbs +++ b/backends/xnnpack/serialization/schema.fbs @@ -29,6 +29,15 @@ enum XNNDatatype : short { xnn_datatype_qdint8 = 9, /// Quantized 4-bit signed integer with shared blockwise quantization parameters. xnn_datatype_qbint4 = 10, + /// Dynamically quantized 8-bit signed integers packed with their per-row + /// quantization parameters. + xnn_datatype_qpint8 = 11, + /// 32-bit signed integers. + xnn_datatype_int32 = 12, + /// IEEE754 single-precision packed floating-point. + xnn_datatype_pfp32 = 13, + /// BFloat16, i.e. the upper 16 bits of a float32. + xnn_datatype_bf16 = 14, } // type of quantization diff --git a/backends/xnnpack/serialization/xnnpack_graph_schema.py b/backends/xnnpack/serialization/xnnpack_graph_schema.py index b8b4ea7f02f..dc50fb47da4 100644 --- a/backends/xnnpack/serialization/xnnpack_graph_schema.py +++ b/backends/xnnpack/serialization/xnnpack_graph_schema.py @@ -419,6 +419,10 @@ class XNNDatatype(IntEnum): xnn_datatype_qcint4 = 8 xnn_datatype_qdint8 = 9 xnn_datatype_qbint4 = 10 + xnn_datatype_qpint8 = 11 + xnn_datatype_int32 = 12 + xnn_datatype_pfp32 = 13 + xnn_datatype_bf16 = 14 @dataclass From 425ca7e6ed29ad96678c0d7398ec56fd978e4819 Mon Sep 17 00:00:00 2001 From: Max Ren Date: Mon, 9 Jun 2025 15:10:34 -0700 Subject: [PATCH 2/6] Update [ghstack-poisoned] --- backends/xnnpack/operators/__init__.py | 1 - backends/xnnpack/operators/op_sdpa.py | 111 ------------------ backends/xnnpack/partition/config/__init__.py | 2 - .../partition/config/generic_node_configs.py | 30 ----- backends/xnnpack/runtime/XNNCompiler.cpp | 37 ------ 5 files changed, 181 deletions(-) delete mode 100644 backends/xnnpack/operators/op_sdpa.py diff --git a/backends/xnnpack/operators/__init__.py b/backends/xnnpack/operators/__init__.py index ec07502de54..a83f8706d94 100644 --- a/backends/xnnpack/operators/__init__.py +++ b/backends/xnnpack/operators/__init__.py @@ -39,7 +39,6 @@ op_quant_dequant, op_relu, op_rsqrt, - op_sdpa, op_sigmoid, op_skip_ops, op_slice_copy, diff --git a/backends/xnnpack/operators/op_sdpa.py b/backends/xnnpack/operators/op_sdpa.py deleted file mode 100644 index e0ec7b37b3b..00000000000 --- a/backends/xnnpack/operators/op_sdpa.py +++ /dev/null @@ -1,111 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from typing import cast, Dict - -import torch -from executorch.backends.transforms import get_shape -from executorch.backends.xnnpack.operators.node_visitor import ( - NodeVisitor, - register_node_visitor, -) -from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import ( - XNNGraph, - XNNScaledDotProductAttention, - XNode, -) -from executorch.backends.xnnpack.utils.utils import get_input_node - - -@register_node_visitor -class SDPAVisitor(NodeVisitor): - target = "aten.scaled_dot_product_attention.default" - - def __init__(self, *args) -> None: - super().__init__(*args) - - @staticmethod - def get_fake_attr(name: str, value: torch.Tensor) -> torch.fx.Node: - g = torch.fx.Graph() - gm = torch.fx.GraphModule({}, g) - fake_node = torch.fx.Node(g, name, "get_attr", target=name, args=(), kwargs={}) - g._owning_module = gm - setattr(g._owning_module, name, value) - fake_node.meta["val"] = value - return fake_node - - def define_node( - self, - node: torch.fx.Node, - xnn_graph: XNNGraph, - vals_to_ids: Dict[torch.fx.Node, int], - debug_handle: int, - ) -> None: - # inputs - for i in range(0, 4): - inp = get_input_node(node, i) - self.define_tensor( - inp, - xnn_graph, - vals_to_ids, - ) - - # Make sure mask is not bool - mask_node = get_input_node(node, 3) - mask_dtype = mask_node.meta["val"].dtype - assert mask_dtype in [ - torch.float, - torch.float16, - ], "SDPA Mask must be a float (or half) tensor" - - # Make sure mask is not >2D - assert len(get_shape(mask_node)) == 2, "SDPA Mask must be 2D" - - # Hack to broadcast the scale - q_shape = get_shape(get_input_node(node, 0)) - embedding_dim = q_shape[-1] - scale = 1 / (embedding_dim**0.5) - if "scale" in node.kwargs and node.kwargs["scale"]: - scale = cast(float, node.kwargs["scale"]) - - t = torch.full((embedding_dim,), scale, dtype=mask_dtype) - scale_node = self.get_fake_attr("scale", t) - self.define_tensor( - scale_node, - xnn_graph, - vals_to_ids, - ) - - # outputs - outp = node - self.define_tensor( - outp, - xnn_graph, - vals_to_ids, - ) - - # ids - q_id = vals_to_ids[get_input_node(node, 0)] - k_id = vals_to_ids[get_input_node(node, 1)] - v_id = vals_to_ids[get_input_node(node, 2)] - mask_id = vals_to_ids[mask_node] - scale_id = vals_to_ids[scale_node] - output_id = vals_to_ids[outp] - - # Create a new node - sdpa_node = XNode( - xnode_union=XNNScaledDotProductAttention( - query_id=q_id, - key_id=k_id, - value_id=v_id, - scale_id=scale_id, - mask_id=mask_id, - output_id=output_id, - flags=0, - ), - debug_handle=debug_handle, - ) - xnn_graph.xnodes.append(sdpa_node) diff --git a/backends/xnnpack/partition/config/__init__.py b/backends/xnnpack/partition/config/__init__.py index 553b10f60d1..b304317b257 100644 --- a/backends/xnnpack/partition/config/__init__.py +++ b/backends/xnnpack/partition/config/__init__.py @@ -43,7 +43,6 @@ QuantizedPerTensorConfig, ReciprocalSquareRootConfig, ReLUConfig, - # SDPAConfig, TODO: D60553559: preserving SDPA for fairseq fails SigmoidConfig, SliceCopyConfig, SoftmaxConfig, @@ -99,7 +98,6 @@ PreluConfig, ReciprocalSquareRootConfig, ReLUConfig, - # SDPAConfig, TODO: D60553559: preserving SDPA for fairseq fails SigmoidConfig, SliceCopyConfig, SoftmaxConfig, diff --git a/backends/xnnpack/partition/config/generic_node_configs.py b/backends/xnnpack/partition/config/generic_node_configs.py index 46922e47010..a8846b68d60 100644 --- a/backends/xnnpack/partition/config/generic_node_configs.py +++ b/backends/xnnpack/partition/config/generic_node_configs.py @@ -527,33 +527,3 @@ class BMMConfig(GenericNodePartitionerConfig): def supported_precision_types(self) -> List[ConfigPrecisionType]: return [ConfigPrecisionType.FP32] - - -class SDPAConfig(GenericNodePartitionerConfig): - target_name = "scaled_dot_product_attention.default" - - def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool: - """ - Requires Mask to have Rank 2 - """ - if not self.check_common_constraints(node, ep): - return False - - if len(node.all_input_nodes) < 4: - return False - mask_node = node.all_input_nodes[3] - mask_rank = mask_node.meta["val"].dim() - if mask_rank != 2: - why( - node, - reason=f"mask must have rank 2, got mask of rank {mask_rank}", - ) - return False - - return True - - def get_original_aten(self) -> Optional[torch._ops.OpOverload]: - return torch.ops.aten.scaled_dot_product_attention.default - - def supported_precision_types(self) -> List[ConfigPrecisionType]: - return [ConfigPrecisionType.FP32] diff --git a/backends/xnnpack/runtime/XNNCompiler.cpp b/backends/xnnpack/runtime/XNNCompiler.cpp index 312cbc17b95..a364594fb1c 100644 --- a/backends/xnnpack/runtime/XNNCompiler.cpp +++ b/backends/xnnpack/runtime/XNNCompiler.cpp @@ -1961,42 +1961,6 @@ Error defineStaticSliceNode( return Error::Ok; } -/* -Defines Scaled Dot Product Attention (SDPA) node into the subgraph, -using the remapped ids to map the serialized ids, -to the new ids generated when defining the tensor value -*/ -Error defineScaledDotProductAttentionNode( - xnn_subgraph_t subgraph_ptr, - const std::unordered_map& remapped_ids, - const NodePtr node, - const fb_xnnpack::XNNGraph* graph) noexcept { - MAYBE_UNUSED(graph); - - auto graph_node = node->xnode_union_as_XNNScaledDotProductAttention(); - - xnn_status status = xnn_define_scaled_dot_product_attention( - subgraph_ptr, - xnn_attention_logits_cap_type_none, // cap_type - nullptr, // cap_value - not used - remapped_ids.at(graph_node->query_id()), - remapped_ids.at(graph_node->key_id()), - remapped_ids.at(graph_node->value_id()), - remapped_ids.at(graph_node->scale_id()), - remapped_ids.at(graph_node->mask_id()), - remapped_ids.at(graph_node->output_id()), - graph_node->flags()); - - ET_CHECK_OR_RETURN_ERROR( - status == xnn_status_success, - Internal, - "Failed to create SDPA node %i with code: %s", - node->debug_handle(), - xnn_status_to_string(status)); - - return Error::Ok; -} - /* Defines batch matrix multiply node into the subgraph, using the remapped ids to map the serialized ids, @@ -2097,7 +2061,6 @@ DefineNodeFunc getDefineNodeFunc(fb_xnnpack::XNodeUnion nodeType) { _DEFINE(Concatenate4) _DEFINE(Concatenate5) _DEFINE(StaticSlice) - _DEFINE(ScaledDotProductAttention) _DEFINE(BatchMatrixMultiply) case fb_xnnpack::XNodeUnion::NONE: default: // Adding here as a catch all, just in case From f3822ec6b23d901a1caeb5ec49f5f21295dc4858 Mon Sep 17 00:00:00 2001 From: Max Ren Date: Mon, 23 Jun 2025 10:27:26 -0700 Subject: [PATCH 3/6] Update [ghstack-poisoned] --- backends/xnnpack/test/ops/test_sdpa.py | 130 ------------------------- 1 file changed, 130 deletions(-) delete mode 100644 backends/xnnpack/test/ops/test_sdpa.py diff --git a/backends/xnnpack/test/ops/test_sdpa.py b/backends/xnnpack/test/ops/test_sdpa.py deleted file mode 100644 index 205b6d4ab36..00000000000 --- a/backends/xnnpack/test/ops/test_sdpa.py +++ /dev/null @@ -1,130 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import unittest -from typing import Optional - -import torch -from executorch.backends.xnnpack.partition.config.generic_node_configs import SDPAConfig -from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner -from executorch.backends.xnnpack.test.tester import Tester -from executorch.backends.xnnpack.test.tester.tester import ToEdgeTransformAndLower - - -class TestSDPA(unittest.TestCase): - def setUp(self): - torch._dynamo.reset() - - class SDPA(torch.nn.Module): - def __init__(self, scale: Optional[float] = None): - super().__init__() - self.dropout_p: float = 0.0 - self.is_causal: bool = False - self.scale = scale - - def forward( - self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - mask: Optional[torch.Tensor] = None, - ): - return torch.nn.functional.scaled_dot_product_attention( - q, - k, - v, - attn_mask=mask, - dropout_p=self.dropout_p, - is_causal=self.is_causal, - scale=self.scale, - ) - - @staticmethod - def get_input_tensors(mask_rank: int, dtype: torch.dtype = torch.float32): - batch_size = 8 - heads = 16 - seq_len = 32 - dim = 64 - - q = torch.randn(batch_size, heads, seq_len, dim).to(dtype) - k = torch.randn(batch_size, heads, seq_len, dim).to(dtype) - v = torch.randn(batch_size, heads, seq_len, dim).to(dtype) - - mask = None - if mask_rank > 0: - assert mask_rank >= 2, "mask rank must be >= 2" - mask = torch.full((seq_len, seq_len), 0, dtype=dtype) - while mask.ndim < mask_rank: - mask.unsqueeze_(0) - - return (q, k, v, mask) - - def _test(self, module, inputs, atol=1e-03, rtol=1e-03): - module = module.eval() - ( - Tester(module, inputs) - .export() - .to_edge_transform_and_lower( - ToEdgeTransformAndLower([XnnpackPartitioner(configs=[SDPAConfig])]) - ) - .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) - .check_not( - ["executorch_exir_dialects_edge__ops_aten_bmm_default"], - ) - .to_executorch() - .serialize() - .run_method_and_compare_outputs(atol=atol, rtol=rtol) - ) - - def test_fp16_sdpa_mask2d(self): - """ - Tests that the SDPA operator is correctly lowered to XNNPACK - """ - module = self.SDPA() - inputs = module.get_input_tensors(mask_rank=2, dtype=torch.float16) - self._test(module, inputs, atol=1e-02, rtol=1e-02) - - def test_fp32_sdpa_mask2d(self): - """ - Tests that the SDPA operator is correctly lowered to XNNPACK - """ - module = self.SDPA() - inputs = module.get_input_tensors(mask_rank=2) - self._test(module, inputs) - - def test_fp16_sdpa_userscale(self): - """ - Tests that the scale parameter is passed correctly to the SDPA operator - """ - module = self.SDPA(scale=0.1234) - inputs = module.get_input_tensors(mask_rank=2, dtype=torch.float16) - self._test(module, inputs, atol=1e-02, rtol=1e-02) - - def test_fp32_sdpa_userscale(self): - """ - Tests that the scale parameter is passed correctly to the SDPA operator - """ - module = self.SDPA(scale=0.1234) - inputs = module.get_input_tensors(mask_rank=2) - self._test(module, inputs) - - @unittest.expectedFailure - def test_fp32_sdpa_nomask(self): - module = self.SDPA() - inputs = module.get_input_tensors(mask_rank=0) - # AssertionError: SubgraphMatcher cannot be initialized with an pattern with dead code - # This is from attn_mask=None arg - self._test(module, inputs) - - @unittest.expectedFailure - def test_fp32_sdpa_mask4d(self): - """ - Tests that the scale parameter is passed correctly to the SDPA operator - """ - module = self.SDPA(scale=0.1234) - # can't mask.squeeze_(0) yet with xnnpack - inputs = module.get_input_tensors(mask_rank=4) - self._test(module, inputs) From 9498fe79c2b41945619cf639fb76d9c52714dea7 Mon Sep 17 00:00:00 2001 From: Max Ren Date: Thu, 26 Jun 2025 10:08:53 -0700 Subject: [PATCH 4/6] Update [ghstack-poisoned] --- .ci/scripts/test_model.sh | 4 +- .github/workflows/android-perf.yml | 2 +- .github/workflows/apple-perf.yml | 2 +- .github/workflows/trunk.yml | 29 - backends/arm/_passes/__init__.py | 2 - backends/arm/_passes/arm_pass_manager.py | 4 - backends/arm/_passes/decompose_atan_pass.py | 119 ----- .../_passes/decompose_batch_norm_no_stats.py | 219 -------- backends/arm/_passes/insert_table_ops.py | 1 - .../tosa_supported_operators.py | 1 - .../arm/quantizer/quantization_annotator.py | 1 - backends/arm/test/ops/test_atan.py | 84 --- backends/arm/test/ops/test_batch_norm.py | 55 +- backends/nxp/requirements-tests.txt | 2 +- backends/nxp/run_unittests.sh | 14 - backends/qualcomm/qnn_preprocess.py | 51 +- backends/qualcomm/tests/test_qnn_delegate.py | 62 --- devtools/inspector/TARGETS | 1 - devtools/inspector/_inspector.py | 41 +- devtools/inspector/_inspector_utils.py | 50 -- .../_intermediate_output_capturer.py | 25 +- devtools/inspector/tests/inspector_test.py | 26 +- .../inspector/tests/inspector_utils_test.py | 108 ---- .../deepseek-r1-distill-llama-8B/README.md | 15 +- .../config/deepseek_xnnpack_q8da4w.yaml | 16 - examples/models/llama/README.md | 96 ++-- examples/models/llama/config/llama_bf16.yaml | 7 - .../models/llama/config/llama_q8da4w.yaml | 11 - .../llama/config/llama_xnnpack_qat.yaml | 23 - .../llama/config/llama_xnnpack_spinquant.yaml | 22 - .../models/llama/config/test_llm_config.py | 8 +- examples/models/phi_4_mini/README.md | 16 +- .../phi_4_mini/{config => }/config.json | 0 .../phi_4_mini/config/phi_4_mini_xnnpack.yaml | 12 - .../qwen2_5/{config => }/1_5b_config.json | 0 examples/models/qwen2_5/README.md | 17 +- .../config/qwen2_5_xnnpack_q8da4w.yaml | 11 - .../qwen3/{config => }/0_6b_config.json | 0 .../qwen3/{config => }/1_7b_config.json | 0 .../models/qwen3/{config => }/4b_config.json | 0 examples/models/qwen3/README.md | 48 +- .../qwen3/config/qwen3_xnnpack_q8da4w.yaml | 15 - examples/nxp/setup.sh | 2 +- .../qualcomm/qaihub_scripts/utils/export.py | 2 +- examples/qualcomm/util_scripts/README.md | 79 --- examples/qualcomm/util_scripts/cli.py | 504 ------------------ examples/qualcomm/utils.py | 5 +- exir/passes/TARGETS | 12 - exir/passes/reinplace.py | 103 ---- exir/tests/TARGETS | 12 - exir/tests/test_reinplace_pass.py | 104 ---- extension/flat_tensor/serialize/serialize.py | 58 +- extension/flat_tensor/test/test_serialize.py | 40 +- extension/llm/export/README.md | 60 ++- extension/llm/export/export_llm.py | 49 +- extension/llm/export/test/test_export_llm.py | 114 ++-- kernels/portable/cpu/util/elementwise_util.h | 31 +- .../cpu/util/normalization_ops_util.cpp | 2 +- kernels/portable/cpu/util/targets.bzl | 2 +- pytest.ini | 1 - 60 files changed, 361 insertions(+), 2039 deletions(-) delete mode 100644 backends/arm/_passes/decompose_atan_pass.py delete mode 100644 backends/arm/_passes/decompose_batch_norm_no_stats.py delete mode 100644 backends/arm/test/ops/test_atan.py delete mode 100755 backends/nxp/run_unittests.sh delete mode 100644 examples/models/deepseek-r1-distill-llama-8B/config/deepseek_xnnpack_q8da4w.yaml delete mode 100644 examples/models/llama/config/llama_bf16.yaml delete mode 100644 examples/models/llama/config/llama_q8da4w.yaml delete mode 100644 examples/models/llama/config/llama_xnnpack_qat.yaml delete mode 100644 examples/models/llama/config/llama_xnnpack_spinquant.yaml rename examples/models/phi_4_mini/{config => }/config.json (100%) delete mode 100644 examples/models/phi_4_mini/config/phi_4_mini_xnnpack.yaml rename examples/models/qwen2_5/{config => }/1_5b_config.json (100%) delete mode 100644 examples/models/qwen2_5/config/qwen2_5_xnnpack_q8da4w.yaml rename examples/models/qwen3/{config => }/0_6b_config.json (100%) rename examples/models/qwen3/{config => }/1_7b_config.json (100%) rename examples/models/qwen3/{config => }/4b_config.json (100%) delete mode 100644 examples/models/qwen3/config/qwen3_xnnpack_q8da4w.yaml delete mode 100644 examples/qualcomm/util_scripts/README.md delete mode 100644 examples/qualcomm/util_scripts/cli.py delete mode 100644 exir/passes/reinplace.py delete mode 100644 exir/tests/test_reinplace_pass.py diff --git a/.ci/scripts/test_model.sh b/.ci/scripts/test_model.sh index bc9bbb8bae0..bbf879295ae 100755 --- a/.ci/scripts/test_model.sh +++ b/.ci/scripts/test_model.sh @@ -102,7 +102,7 @@ test_model() { bash examples/models/llama/install_requirements.sh # Test export_llm script: python3 -m extension.llm.export.export_llm. # Use Llama random checkpoint with Qwen 2.5 1.5b model configuration. - "${PYTHON_EXECUTABLE}" -m extension.llm.export.export_llm base.model_class="${MODEL_NAME}" base.params=examples/models/qwen2_5/config/1_5b_config.json + "${PYTHON_EXECUTABLE}" -m extension.llm.export.export_llm base.model_class="${MODEL_NAME}" base.params=examples/models/qwen2_5/1_5b_config.json rm "./${MODEL_NAME}.pte" return # Skip running with portable executor runnner since portable doesn't support Qwen's biased linears. fi @@ -110,7 +110,7 @@ test_model() { # Install requirements for export_llama bash examples/models/llama/install_requirements.sh # Test export_llm script: python3 -m extension.llm.export.export_llm. - "${PYTHON_EXECUTABLE}" -m extension.llm.export.export_llm base.model_class="${MODEL_NAME}" base.params=examples/models/phi_4_mini/config/config.json + "${PYTHON_EXECUTABLE}" -m extension.llm.export.export_llm base.model_class="${MODEL_NAME}" base.params=examples/models/phi_4_mini/config.json run_portable_executor_runner rm "./${MODEL_NAME}.pte" return diff --git a/.github/workflows/android-perf.yml b/.github/workflows/android-perf.yml index 2eab69eb88b..a79a900b2d8 100644 --- a/.github/workflows/android-perf.yml +++ b/.github/workflows/android-perf.yml @@ -317,7 +317,7 @@ jobs: DOWNLOADED_PATH=$(bash .ci/scripts/download_hf_hub.sh --model_id "${HF_MODEL_REPO}" --subdir "." --files "tokenizer.json") python -m extension.llm.export.export_llm \ base.model_class=qwen3_0_6b \ - base.params=examples/models/qwen3/config/0_6b_config.json \ + base.params=examples/models/qwen3/0_6b_config.json \ model.use_kv_cache=true \ model.use_sdpa_with_kv_cache=true \ model.dtype_override=fp32 \ diff --git a/.github/workflows/apple-perf.yml b/.github/workflows/apple-perf.yml index 3db5abbefbd..6b1666da642 100644 --- a/.github/workflows/apple-perf.yml +++ b/.github/workflows/apple-perf.yml @@ -322,7 +322,7 @@ jobs: DOWNLOADED_PATH=$(bash .ci/scripts/download_hf_hub.sh --model_id "${HF_MODEL_REPO}" --subdir "." --files "tokenizer.json") ${CONDA_RUN} python -m extension.llm.export.export_llm \ base.model_class=qwen3_0_6b \ - base.params=examples/models/qwen3/config/0_6b_config.json \ + base.params=examples/models/qwen3/0_6b_config.json \ model.use_kv_cache=true \ model.use_sdpa_with_kv_cache=true \ model.dtype_override=fp32 \ diff --git a/.github/workflows/trunk.yml b/.github/workflows/trunk.yml index 4fe5ec979a3..a4996459f8a 100644 --- a/.github/workflows/trunk.yml +++ b/.github/workflows/trunk.yml @@ -718,32 +718,3 @@ jobs: build-mode: Release build-tool: cmake docker-image: executorch-ubuntu-22.04-clang12 - - unittest-nxp-neutron: - uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main - permissions: - id-token: write - contents: read - with: - runner: linux.2xlarge - docker-image: executorch-ubuntu-22.04-clang12 - submodules: 'recursive' - ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} - timeout: 90 - script: | - set -eux - - # The generic Linux job chooses to use base env, not the one setup by the image - CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]") - conda activate "${CONDA_ENV}" - - # Build and install Executorch - PYTHON_EXECUTABLE=python \ - CMAKE_ARGS="-DEXECUTORCH_BUILD_NXP_NEUTRON=ON" \ - .ci/scripts/setup-linux.sh --build-tool "cmake" - - # Install test requirements - pip install -r backends/nxp/requirements-tests.txt - - # Run pytest - PYTHON_EXECUTABLE=python bash backends/nxp/run_unittests.sh diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py index 2a75606cb70..d3c0ae0a1b3 100644 --- a/backends/arm/_passes/__init__.py +++ b/backends/arm/_passes/__init__.py @@ -22,9 +22,7 @@ from .convert_split_to_slice import ConvertSplitToSlicePass # noqa from .convert_squeezes_to_view import ConvertSqueezesToViewPass # noqa from .convert_to_clamp import ConvertToClampPass # noqa -from .decompose_atan_pass import DecomposeAtanPass # noqa from .decompose_avg_pool2d import DecomposeAvgPool2d # noqa -from .decompose_batch_norm_no_stats import DecomposeBatchNormNoStatsPass # noqa from .decompose_cosine_similarity_pass import DecomposeCosineSimilarityPass # noqa from .decompose_div_pass import DecomposeDivPass # noqa from .decompose_embedding_pass import DecomposeEmbeddingPass # noqa # noqa diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index 596decd65bb..2cefd3bdaca 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -25,9 +25,7 @@ ConvertSplitToSlicePass, ConvertSqueezesToViewPass, ConvertToClampPass, - DecomposeAtanPass, DecomposeAvgPool2d, - DecomposeBatchNormNoStatsPass, DecomposeCosineSimilarityPass, DecomposeDivPass, DecomposeEmbeddingPass, @@ -152,7 +150,6 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule: self.add_pass(DecomposeRoundPass()) self.add_pass(DecomposeSqrtPass()) - self.add_pass(DecomposeAtanPass()) self.add_pass(ConvertIntPowToMuls()) self.add_pass(CastBoolToInt8Pass()) self.add_pass(DecomposeSinhPass()) @@ -167,7 +164,6 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul self.add_pass(DecomposeLeakyReLUPass()) self.add_pass(DecomposeGroupNormPass()) self.add_pass(DecomposeLayerNormPass()) - self.add_pass(DecomposeBatchNormNoStatsPass()) self.add_pass(DecomposeVarPass()) self.add_pass( DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec) diff --git a/backends/arm/_passes/decompose_atan_pass.py b/backends/arm/_passes/decompose_atan_pass.py deleted file mode 100644 index 57b9dde5216..00000000000 --- a/backends/arm/_passes/decompose_atan_pass.py +++ /dev/null @@ -1,119 +0,0 @@ -# Copyright 2025 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import logging -from math import pi - -from executorch.backends.arm._passes import ArmPass -from executorch.exir.dialects._ops import ops as exir_ops - - -edge_atan = exir_ops.edge.aten.atan.default # MI case - - -def _get_atan_ops(op): - """Return the primitive ops required..""" - if op is not edge_atan: - raise RuntimeError(f"Can't decompose atan for op {op}") - - return ( - exir_ops.edge.aten.mul.Tensor, - exir_ops.edge.aten.mul.Scalar, - exir_ops.edge.aten.add.Tensor, - exir_ops.edge.aten.add.Scalar, - exir_ops.edge.aten.sub.Tensor, - exir_ops.edge.aten.abs.default, - exir_ops.edge.aten.gt.Scalar, - exir_ops.edge.aten.reciprocal.default, - exir_ops.edge.aten.where.self, - exir_ops.edge.aten.neg.default, - ) - - -class DecomposeAtanPass(ArmPass): - """Decomposes the atan operator into a rational (Padé) approximation.""" - - def _rational_approximation(self, z, ops, meta): - """Creates a (2,1) Padé approximation for atan(x) on [-1, 1].""" - - op_mul, op_mul_scalar, op_add, op_add_scalar, _, _, _, op_recip, _, _ = ops - - # Coefficients calculated using minimax on the interval [-1, 1]. - a1 = 0.3529666667 - a2 = -0.0287666667 - b1 = 0.6863 - - z2 = super().call_operator(op_mul, (z, z), {}, meta, updated=True) - z4 = super().call_operator(op_mul, (z2, z2), {}, meta, updated=True) - - num1 = super().call_operator(op_mul_scalar, (z2, a1), {}, meta, updated=True) - num2 = super().call_operator(op_mul_scalar, (z4, a2), {}, meta, updated=True) - num = super().call_operator(op_add_scalar, (num1, 1.0), {}, meta, updated=True) - num = super().call_operator(op_add, (num, num2), {}, meta, updated=True) - - den1 = super().call_operator(op_mul_scalar, (z2, b1), {}, meta, updated=True) - den = super().call_operator(op_add_scalar, (den1, 1.0), {}, meta, updated=True) - - inv_den = super().call_operator(op_recip, (den,), {}, meta, updated=True) - - prod = super().call_operator(op_mul, (num, inv_den), {}, meta, updated=True) - return super().call_operator(op_mul, (z, prod), {}, meta, updated=True) - - def call_operator(self, op, args, kwargs, meta): - if op is not edge_atan: - return super().call_operator(op, args, kwargs, meta, updated=False) - - logging.info( - f"Approximating atan. This may introduce small numerical errors. For details, see {__file__}." - ) - - ops = _get_atan_ops(op) - ( - _, - op_mul_scalar, - _, - op_add_scalar, - op_sub, - op_abs, - op_gt, - op_recip, - op_where, - op_neg, - ) = ops - - x = args[0] - - # |x| > 1 is reduced to [0, 1] using atan(x) = pi/2 - atan(1/x) and atan(-x) = -atan(x). - - abs_x = super().call_operator(op_abs, (x,), {}, meta, updated=True) - mask_hi = super().call_operator(op_gt, (abs_x, 1.0), {}, meta, updated=True) - - inv_x = super().call_operator(op_recip, (abs_x,), {}, meta, updated=True) - z = super().call_operator( - op_where, (mask_hi, inv_x, abs_x), {}, meta, updated=True - ) - - atan_z = self._rational_approximation(z, ops, meta) - - zero_tensor = super().call_operator( - op_mul_scalar, (x, 0.0), {}, meta, updated=True - ) - half_pi_tensor = super().call_operator( - op_add_scalar, (zero_tensor, pi / 2), {}, meta, updated=True - ) - - diff = super().call_operator( - op_sub, (half_pi_tensor, atan_z), {}, meta, updated=True - ) - atan_abs = super().call_operator( - op_where, (mask_hi, diff, atan_z), {}, meta, updated=True - ) - - mask_pos = super().call_operator(op_gt, (x, 0.0), {}, meta, updated=True) - neg_val = super().call_operator(op_neg, (atan_abs,), {}, meta, updated=True) - - return super().call_operator( - op_where, (mask_pos, atan_abs, neg_val), {}, meta, updated=True - ) diff --git a/backends/arm/_passes/decompose_batch_norm_no_stats.py b/backends/arm/_passes/decompose_batch_norm_no_stats.py deleted file mode 100644 index 5fdb8db2d7c..00000000000 --- a/backends/arm/_passes/decompose_batch_norm_no_stats.py +++ /dev/null @@ -1,219 +0,0 @@ -# Copyright 2025 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-unsafe - -import operator - -import torch -from executorch.backends.arm._passes import ArmPass -from executorch.backends.arm._passes.arm_pass_utils import create_node -from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import PassResult - - -class DecomposeBatchNormNoStatsPass(ArmPass): - """ - Decompose BatchNorm2d(track_running_stats=False) (aten._native_batch_norm_legit_no_training) - into a sequence of elementwise operations: - - # let input = x, rm = running_mean, rv = running_var, eps: float - rm_view = view(rm, weights_shape) - rv_view = view(rv, weights_shape) - centered = sub(x, rm_view) - eps_full = full(eps_shape, eps) - var_eps = add(rv_view, eps_full) - inv_sqrt = rsqrt(var_eps) - normed = mul(centered, inv_sqrt) - weighted = mul(normed, w_view) - biased = add(weighted, b_view) - - Source: https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html - """ - - def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # noqa: C901 - bn_ops = ( - exir_ops.edge.aten._native_batch_norm_legit.no_stats, - exir_ops.edge.aten._native_batch_norm_legit_no_training.default, - torch.ops.aten._native_batch_norm_legit_no_training.default, - torch.ops.aten.batch_norm.default, - torch.ops.aten.native_batch_norm.default, - ) - - for node in graph_module.graph.nodes: - if node.op != "call_function" or node.target not in bn_ops: - continue - - if node.target in ( - torch.ops.aten.batch_norm.default, - torch.ops.aten.native_batch_norm.default, - ): - # signature: (input, weight, bias, mean, var, training, momentum, eps, cudnn_enabled) - # pos‐arg 5 is training - training = node.kwargs.get("training", False) - if len(node.args) > 5: - training = node.args[5] - if training: - # skip training‐mode batchnorm - continue - - # Extract args - args = node.args - meta = node.meta - - # Default eps - eps: float = torch.finfo().eps - # weight and bias may be None - x = args[0] - weight = args[1] if len(args) > 1 else None - bias = args[2] if len(args) > 2 else None - running_mean = args[3] - running_var = args[4] - if len(args) > 6: - eps = args[6] - - # Determine shapes - val = meta.get("val") - ref_tensor = val[0] if isinstance(val, tuple) else val - shape = tuple(ref_tensor.size()) - dtype = ref_tensor.dtype - rank = len(shape) - - # channel dimension is 1 for BatchNorm2d - channel_axis = 1 - weights_shape = [1] * rank - weights_shape[channel_axis] = shape[channel_axis] - num_features = shape[channel_axis] - - # Ops to use - sub_op = exir_ops.edge.aten.sub.Tensor - view_op = exir_ops.edge.aten.view_copy.default - full_op = exir_ops.edge.aten.full.default - add_op = exir_ops.edge.aten.add.Tensor - rsqrt_op = exir_ops.edge.aten.rsqrt.default - mul_op = exir_ops.edge.aten.mul.Tensor - - # Begin decomposition - with graph_module.graph.inserting_before(node): - # reshape running stats - rm_view = create_node( - graph_module.graph, - view_op, - args=(running_mean, weights_shape), - from_node=node, - ) - rv_view = create_node( - graph_module.graph, - view_op, - args=(running_var, weights_shape), - from_node=node, - ) - # center input - centered = create_node( - graph_module.graph, - sub_op, - args=(x, rm_view), - from_node=node, - ) - # epsilon tensor - eps_shape = [1] * rank - eps_full = create_node( - graph_module.graph, - full_op, - args=(eps_shape, eps), - kwargs={"dtype": dtype}, - from_node=node, - ) - # var + eps - var_eps = create_node( - graph_module.graph, - add_op, - args=(rv_view, eps_full), - from_node=node, - ) - # inverse sqrt - inv_sqrt = create_node( - graph_module.graph, - rsqrt_op, - args=(var_eps,), - from_node=node, - ) - # normalized - normed = create_node( - graph_module.graph, - mul_op, - args=(centered, inv_sqrt), - from_node=node, - ) - - # weight - if weight is None: - one = create_node( - graph_module.graph, - full_op, - args=([num_features], 1), - kwargs={"dtype": dtype}, - from_node=node, - ) - w_view = create_node( - graph_module.graph, - view_op, - args=(one, weights_shape), - from_node=node, - ) - else: - w_view = create_node( - graph_module.graph, - view_op, - args=(weight, weights_shape), - from_node=node, - ) - weighted = create_node( - graph_module.graph, - mul_op, - args=(normed, w_view), - from_node=node, - ) - - # bias - if bias is None: - zero = create_node( - graph_module.graph, - full_op, - args=([num_features], 0), - kwargs={"dtype": dtype}, - from_node=node, - ) - b_view = create_node( - graph_module.graph, - view_op, - args=(zero, weights_shape), - from_node=node, - ) - else: - b_view = create_node( - graph_module.graph, - view_op, - args=(bias, weights_shape), - from_node=node, - ) - final_out = create_node( - graph_module.graph, - add_op, - args=(weighted, b_view), - from_node=node, - ) - - users = [u for u in node.users if u is not node] - node.replace_all_uses_with(final_out) - for u in users: - if u.target == operator.getitem: - u.replace_all_uses_with(final_out) - graph_module.graph.erase_node(node) - graph_module.graph.eliminate_dead_code() - - graph_module.recompile() - new_gm = super().call(graph_module).graph_module - return PassResult(new_gm, True) diff --git a/backends/arm/_passes/insert_table_ops.py b/backends/arm/_passes/insert_table_ops.py index b31b6c7106d..c579fcb0301 100644 --- a/backends/arm/_passes/insert_table_ops.py +++ b/backends/arm/_passes/insert_table_ops.py @@ -51,7 +51,6 @@ class TableOps: exir_ops.edge.aten.cos.default: torch.cos, exir_ops.edge.aten.sin.default: torch.sin, exir_ops.edge.aten.tanh.default: torch.tanh, - exir_ops.edge.aten.atan.default: torch.atan, exir_ops.edge.aten.hardsigmoid.default: torch.nn.functional.hardsigmoid, exir_ops.edge.aten.hardswish.default: torch.nn.functional.hardswish, exir_ops.edge.aten.sinh.default: torch.sinh, diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py index cdb27b7c31e..639df536109 100644 --- a/backends/arm/operator_support/tosa_supported_operators.py +++ b/backends/arm/operator_support/tosa_supported_operators.py @@ -244,7 +244,6 @@ def is_node_supported( exir_ops.edge.aten.gelu.default, exir_ops.edge.aten.alias_copy.default, exir_ops.edge.aten.sinh.default, - exir_ops.edge.aten.atan.default, ] return supported diff --git a/backends/arm/quantizer/quantization_annotator.py b/backends/arm/quantizer/quantization_annotator.py index 2c61aea60c3..c6415c63777 100644 --- a/backends/arm/quantizer/quantization_annotator.py +++ b/backends/arm/quantizer/quantization_annotator.py @@ -214,7 +214,6 @@ def _match_pattern( torch.ops.aten.pow.Tensor_Scalar, torch.ops.aten.gelu.default, torch.ops.aten.sinh.default, - torch.ops.aten.atan.default, ] _one_to_one_shared_input_qspec = [ diff --git a/backends/arm/test/ops/test_atan.py b/backends/arm/test/ops/test_atan.py deleted file mode 100644 index 3d6f8cd8fa8..00000000000 --- a/backends/arm/test/ops/test_atan.py +++ /dev/null @@ -1,84 +0,0 @@ -# Copyright 2025 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from typing import Tuple - -import torch - -from executorch.backends.arm.test import common -from executorch.backends.arm.test.tester.test_pipeline import ( - EthosU55PipelineBI, - EthosU85PipelineBI, - TosaPipelineBI, - TosaPipelineMI, -) - -aten_op = "torch.ops.aten.atan.default" -exir_op = "executorch_exir_dialects_edge__ops_aten__atan_default" - -input_t1 = Tuple[torch.Tensor] - -test_data_suite = { - "zeros": torch.zeros(1, 10, 10, 10), - "zeros_alt_shape": torch.zeros(1, 10, 3, 5), - "ones": torch.ones(10, 10, 10), - "rand": torch.rand(10, 10) - 0.5, - "rand_alt_shape": torch.rand(1, 10, 3, 5) - 0.5, - "randn_pos": torch.randn(10) + 10, - "randn_neg": torch.randn(10) - 10, - "ramp": torch.arange(-16, 16, 0.2), -} - - -class Atan(torch.nn.Module): - - def forward(self, x: torch.Tensor): - return torch.atan(x) - - -@common.parametrize("test_data", test_data_suite) -def test_atan_tosa_MI(test_data: Tuple): - pipeline = TosaPipelineMI[input_t1]( - Atan(), - (test_data,), - aten_op=aten_op, - exir_op=exir_op, - ) - pipeline.run() - - -@common.parametrize("test_data", test_data_suite) -def test_atan_tosa_BI(test_data: Tuple): - pipeline = TosaPipelineBI[input_t1]( - Atan(), - (test_data,), - aten_op=aten_op, - exir_op=exir_op, - ) - pipeline.run() - - -@common.XfailIfNoCorstone300 -@common.parametrize("test_data", test_data_suite) -def test_atan_u55_BI(test_data: Tuple): - pipeline = EthosU55PipelineBI[input_t1]( - Atan(), - (test_data,), - aten_ops=aten_op, - exir_ops=exir_op, - ) - pipeline.run() - - -@common.XfailIfNoCorstone320 -@common.parametrize("test_data", test_data_suite) -def test_atan_u85_BI(test_data: Tuple): - pipeline = EthosU85PipelineBI[input_t1]( - Atan(), - (test_data,), - aten_ops=aten_op, - exir_ops=exir_op, - ) - pipeline.run() diff --git a/backends/arm/test/ops/test_batch_norm.py b/backends/arm/test/ops/test_batch_norm.py index eb0d4306e6e..7f98a48b203 100644 --- a/backends/arm/test/ops/test_batch_norm.py +++ b/backends/arm/test/ops/test_batch_norm.py @@ -224,8 +224,6 @@ class BatchNorm2dNoStats(torch.nn.Module): Decomposes into _native_batch_norm_legit.no_stats """ - aten_ops = ["torch.ops.aten.batch_norm.default"] - def __init__( self, num_features: int, @@ -252,60 +250,29 @@ def forward(self, x): return self.batch_norm_2d(x) -@common.parametrize("test_data", test_data_suite) -def test_native_batch_norm_legit_no_stats_tosa_MI(test_data: Tuple): - test_data, model_params = test_data() - pipeline = TosaPipelineMI[input_t1]( - BatchNorm2dNoStats(*model_params), - (test_data,), - aten_op=BatchNorm2dNoStats.aten_ops, - ) - pipeline.run() +@pytest.mark.skip( + reason="MLETORCH-999: Add support for _native_batch_norm_legit.no_stats." +) +def test_native_batch_norm_legit_no_stats_tosa_MI(): + pass @pytest.mark.skip( reason="MLETORCH-999: Add support for _native_batch_norm_legit.no_stats." ) -def test_native_batch_norm_legit_no_stats_tosa_BI(test_data: Tuple): - test_data, model_params = test_data() - pipeline = TosaPipelineBI[input_t1]( - BatchNorm2dNoStats(*model_params), - (test_data,), - aten_op=BatchNorm2dNoStats.aten_ops, - qtol=1, - ) - pipeline.run() +def test_native_batch_norm_legit_no_stats_tosa_BI(): + pass @pytest.mark.skip( reason="MLETORCH-999: Add support for _native_batch_norm_legit.no_stats." ) -@common.parametrize("test_data", test_data_suite) -@common.XfailIfNoCorstone300 -def test_native_batch_norm_legit_no_stats_u55_BI(test_data: Tuple): - test_data, model_params = test_data() - pipeline = EthosU55PipelineBI[input_t1]( - BatchNorm2dNoStats(*model_params), - (test_data,), - aten_op=BatchNorm2dNoStats.aten_ops, - run_on_fvp=True, - qtol=1, - ) - pipeline.run() +def test_native_batch_norm_legit_no_stats_u55_BI(): + pass @pytest.mark.skip( reason="MLETORCH-999: Add support for _native_batch_norm_legit.no_stats." ) -@common.parametrize("test_data", test_data_suite) -@common.XfailIfNoCorstone320 -def test_native_batch_norm_legit_no_stats_u85_BI(test_data: Tuple): - test_data, model_params = test_data() - pipeline = EthosU85PipelineBI[input_t1]( - BatchNorm2dNoStats(*model_params), - (test_data,), - aten_op=BatchNorm2dNoStats.aten_ops, - run_on_fvp=False, - qtol=1, - ) - pipeline.run() +def test_native_batch_norm_legit_no_stats_u85_BI(): + pass diff --git a/backends/nxp/requirements-tests.txt b/backends/nxp/requirements-tests.txt index ea6d56a43ec..513ccefe848 100644 --- a/backends/nxp/requirements-tests.txt +++ b/backends/nxp/requirements-tests.txt @@ -3,4 +3,4 @@ tensorflow==2.18.0 pytest-mock tflite GvGen -neutron_converter_SDK_25_03 +neutron-converter_SDK_25_03 diff --git a/backends/nxp/run_unittests.sh b/backends/nxp/run_unittests.sh deleted file mode 100755 index dde10065743..00000000000 --- a/backends/nxp/run_unittests.sh +++ /dev/null @@ -1,14 +0,0 @@ -#!/bin/bash -# Copyright 2025 NXP -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. -set -eux - -SCRIPT_DIR=$(dirname $(readlink -fm $0)) -EXECUTORCH_DIR=$(dirname $(dirname $SCRIPT_DIR)) - -cd $EXECUTORCH_DIR - -# '-c /dev/null' is used to ignore root level pytest.ini. -PYTHONPATH=`cd ..; pwd` pytest -c /dev/null backends/nxp/tests/ diff --git a/backends/qualcomm/qnn_preprocess.py b/backends/qualcomm/qnn_preprocess.py index 034b75fa6d0..21b16a29c58 100644 --- a/backends/qualcomm/qnn_preprocess.py +++ b/backends/qualcomm/qnn_preprocess.py @@ -78,7 +78,10 @@ def _build_op_wrappers( ) assert node.target == context_loader_target, err_msg # if graph has context binary loader node, return directly - return node.meta[OpContextLoader.meta_ctx_bin] + return PreprocessResult( + processed_bytes=node.meta[OpContextLoader.meta_ctx_bin], + debug_handle_map={}, + ) except: raise RuntimeError(err_msg) @@ -158,7 +161,7 @@ def preprocess_multimethod( generate_qnn_executorch_option(compile_spec) ) qnn_manager.Init() - py_op_wrapper_list, ctx_binary_list = [], [] + py_op_wrapper_list = [] for j, programs in enumerate(edge_programs.values()): logger.info(f"Processing Method({j}): ({i+1}/{num_sub_graphs})") py_op_wrappers = QnnBackend._build_op_wrappers( @@ -166,36 +169,22 @@ def preprocess_multimethod( qnn_manager.IsTensorDump(), option.op_package_options.op_package_infos, ) - if isinstance(py_op_wrappers, bytes): - ctx_binary_list.append(py_op_wrappers) - else: - py_op_wrapper_list.append( - [ - py_op_wrapper.GetOpWrapper() - for py_op_wrapper in py_op_wrappers - ] - ) + py_op_wrapper_list.append( + [py_op_wrapper.GetOpWrapper() for py_op_wrapper in py_op_wrappers] + ) - if len(py_op_wrapper_list) == len(edge_programs.values()): - qnn_context_binary = qnn_manager.Compile(graph_name, py_op_wrapper_list) - assert ( - len(qnn_context_binary) != 0 - ), "Failed to generate Qnn context binary." - qnn_manager.Destroy() - # methods should share the same context binary for current partition - for key in edge_programs.keys(): - all_processed_results[key].append( - PreprocessResult( - processed_bytes=bytes(qnn_context_binary), - debug_handle_map={}, - ) - ) - elif len(ctx_binary_list) == len(edge_programs.values()): - for i, key in enumerate(edge_programs.keys()): - all_processed_results[key].append( - PreprocessResult(processed_bytes=ctx_binary_list[i]) + qnn_context_binary = qnn_manager.Compile(graph_name, py_op_wrapper_list) + assert ( + len(qnn_context_binary) != 0 + ), "Failed to generate Qnn context binary." + qnn_manager.Destroy() + # methods should share the same context binary for current partition + for key in edge_programs.keys(): + all_processed_results[key].append( + PreprocessResult( + processed_bytes=bytes(qnn_context_binary), + debug_handle_map={}, ) - else: - raise RuntimeError("Hybrid compilation is not supported") + ) return all_processed_results diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index 7163ce88c27..747a6804957 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -5622,68 +5622,6 @@ def test_debugger_generate_optrace(self): qhas_data = json.load(qhas_file) self.assertIn("data", qhas_data) - def test_cli(self): - with tempfile.TemporaryDirectory() as tmp_dir: - sample_input = torch.randn(1, 2, 3, 4) - ep = torch.export.export(Relu(), (sample_input,)) # noqa: F405 - torch.export.save(ep, f"{tmp_dir}/relu.pt2") - torch.save(sample_input, f"{tmp_dir}/input_0_0.pt") - with open(f"{tmp_dir}/input_list", "w") as f: - f.write(f"{tmp_dir}/input_0_0.pt\n") - - # quantize - cmds = [ - "python", - "-m", - "examples.qualcomm.util_scripts.cli", - "quantize", - "--artifact", - f"{tmp_dir}/relu.pt2", - "--output_folder", - f"{tmp_dir}/q_out", - "--input_list", - f"{tmp_dir}/input_list", - ] - subprocess.run(cmds, stdout=subprocess.DEVNULL) - self.assertTrue(os.path.isfile(f"{tmp_dir}/q_out/relu_quantized.pt2")) - # compile - cmds = [ - "python", - "-m", - "examples.qualcomm.util_scripts.cli", - "compile", - "--artifact", - f"{tmp_dir}/q_out/relu_quantized.pt2", - "--output_folder", - f"{tmp_dir}/c_out", - "--model", - self.model, - ] - subprocess.run(cmds, stdout=subprocess.DEVNULL) - self.assertTrue(os.path.isfile(f"{tmp_dir}/c_out/relu_quantized.pte")) - self.assertTrue(os.path.isfile(f"{tmp_dir}/c_out/relu_quantized.svg")) - # execute - cmds = [ - "python", - "-m", - "examples.qualcomm.util_scripts.cli", - "execute", - "--artifact", - f"{tmp_dir}/c_out/relu_quantized.pte", - "--output_folder", - f"{tmp_dir}/e_out", - "--model", - self.model, - "--device", - self.device, - "--build_folder", - self.build_folder, - "--input_list", - f"{tmp_dir}/input_list", - ] - subprocess.run(cmds, stdout=subprocess.DEVNULL) - self.assertTrue(os.path.isfile(f"{tmp_dir}/e_out/output_0_0.pt")) - def setup_environment(): parser = setup_common_args_and_variables() diff --git a/devtools/inspector/TARGETS b/devtools/inspector/TARGETS index d32698f784f..0712bdf1f9a 100644 --- a/devtools/inspector/TARGETS +++ b/devtools/inspector/TARGETS @@ -56,7 +56,6 @@ python_library( "_intermediate_output_capturer.py", ], deps = [ - "//executorch/devtools/inspector:inspector_utils", ], ) diff --git a/devtools/inspector/_inspector.py b/devtools/inspector/_inspector.py index a209da8adb7..dfff3d0818e 100644 --- a/devtools/inspector/_inspector.py +++ b/devtools/inspector/_inspector.py @@ -52,7 +52,6 @@ FORWARD, gen_etdump_object, gen_graphs_from_etrecord, - get_aot_debug_handle_to_op_name_mapping, inflate_runtime_output, is_debug_output, is_inference_output_equal, @@ -1085,7 +1084,6 @@ def __init__( self._reference_outputs: Dict[str, List[ProgramOutput]] = {} self._enable_module_hierarchy = enable_module_hierarchy self._aot_intermediate_outputs: Optional[Dict[Tuple[int, ...], Any]] = None - self._aot_debug_handles_to_op_names: Optional[Dict[Tuple[int, ...], str]] = None self._consume_etrecord() def _consume_etrecord(self) -> None: @@ -1152,24 +1150,18 @@ def _consume_etrecord(self) -> None: return export_program = self._etrecord.edge_dialect_program graph_module = export_program.module() - self._aot_debug_handles_to_op_names = get_aot_debug_handle_to_op_name_mapping( - graph_module - ) capturer = IntermediateOutputCapturer(graph_module) self._aot_intermediate_outputs = capturer.run_and_capture( self._etrecord._representative_inputs ) # TODO: Make it more extensible to further merge overlapping debug handles - def _get_runtime_intermediate_outputs_and_op_names( - self, - ) -> Tuple[Dict[Tuple[int, ...], Any], Dict[Tuple[int, ...], str]]: + def _get_runtime_intermediate_outputs(self) -> Dict[Tuple[int, ...], Any]: """ - Retrieve the runtime intermediate outputs(debug handles and intermediate values mappings) - from the event blocks, along with the corresponding debug handles and op names mapping. + Retrieve the raw runtime intermediate outputs(debug handles and value mappings) + from the event blocks. These outputs will be processed later to merge overlapping debug handles. """ debug_handle_to_output = {} - debug_handle_to_op_name = {} for event_block in self.event_blocks: for event in event_block.events: # Skip OPERATOR_CALL events to avoid double-counting and exclude framework tax @@ -1178,23 +1170,20 @@ def _get_runtime_intermediate_outputs_and_op_names( or not event.op_types ): continue - # Normalize debug_handle to a tuple - debug_handle = event.debug_handles - if isinstance(debug_handle, int): - debug_handle = (debug_handle,) + # Normalize debug_handles to a tuple + debug_handles = event.debug_handles + if isinstance(debug_handles, int): + debug_handles = (debug_handles,) else: - debug_handle = tuple(debug_handle) - current_entry = debug_handle_to_output.get(debug_handle, (-1, None)) - # When event has same debug_handle, only keep the one with the largest instruction id + debug_handles = tuple(debug_handles) + current_entry = debug_handle_to_output.get(debug_handles, (-1, None)) + # When event has same debug handles, only keep the one with the largest instruction id if event._instruction_id > current_entry[0]: - debug_handle_to_output[debug_handle] = ( + debug_handle_to_output[debug_handles] = ( event._instruction_id, event.debug_data, ) - debug_handle_to_op_name[debug_handle] = event.name - return { - k: v[1] for k, v in debug_handle_to_output.items() - }, debug_handle_to_op_name + return {k: v[1] for k, v in debug_handle_to_output.items()} def to_dataframe( self, @@ -1370,12 +1359,8 @@ def calculate_numeric_gap(self, distance: str = "MSE") -> pd.DataFrame: raise ValueError( "The aot intermediate outputs is required but not populated." ) - # The runtime_op_names will be used later to map runtime debug_handle to op_name - runtime_intermediate_outputs, runtime_op_names = ( - self._get_runtime_intermediate_outputs_and_op_names() - ) mapping = map_runtime_aot_intermediate_outputs( - self._aot_intermediate_outputs, runtime_intermediate_outputs + self._aot_intermediate_outputs, self._get_runtime_intermediate_outputs() ) metric = distance.strip().upper() if metric == "MSE": diff --git a/devtools/inspector/_inspector_utils.py b/devtools/inspector/_inspector_utils.py index 50b3669309c..61e2ea4d031 100644 --- a/devtools/inspector/_inspector_utils.py +++ b/devtools/inspector/_inspector_utils.py @@ -93,28 +93,6 @@ class NodeData: output: Any -class NodeFilter: - """ - A class used to filter nodes based on extensible criteria. - Attributes: - metadata_key (str): The key to look for in the node's metadata. - op_type (str): The operation code to match. - exclude_ops (List[str]): A list of operations to exclude from the filter. - """ - - def __init__(self, metadata_key: str, op_type: str, exclude_ops: List[str] = None): - self.metadata_key = metadata_key - self.op_type = op_type - self.exclude_ops = exclude_ops - - def matches(self, node: torch.fx.Node) -> bool: - return ( - node.meta.get(self.metadata_key) is not None - and node.op == self.op_type - and all(exclude_name not in node.name for exclude_name in self.exclude_ops) - ) - - def calculate_time_scale_factor( source_time_scale: TimeScale, target_time_scale: TimeScale ) -> float: @@ -756,31 +734,3 @@ def convert_to_float_tensor(input_data: Any) -> torch.Tensor: if torch.isnan(input_tensor).any(): input_tensor = torch.nan_to_num(input_tensor) return input_tensor - - -def get_aot_debug_handle_to_op_name_mapping( - graph_module: torch.fx.GraphModule, -) -> Dict[Tuple[int, ...], str]: - """ - Get a mapping from debug handle to operator name from the ETRecord edge_dialect_program's graph module. - Parameters: - graph_module (torch.fx.GraphModule): The graph module to get the mapping from. - Returns: - Dict[Tuple[int, ...], str]: A dictionary mapping debug handles to operator names. - """ - node_filters = [ - NodeFilter("debug_handle", "call_function", exclude_ops=["getitem"]) - ] - - debug_handle_to_op_name = {} - for node in graph_module.graph.nodes: - if all(filter.matches(node) for filter in node_filters): - debug_handle = node.meta["debug_handle"] - # Convert the debug handle to a tuple to use as a dictionary key - key = ( - (debug_handle,) - if isinstance(debug_handle, int) - else tuple(debug_handle) - ) - debug_handle_to_op_name[key] = node.name - return debug_handle_to_op_name diff --git a/devtools/inspector/_intermediate_output_capturer.py b/devtools/inspector/_intermediate_output_capturer.py index 054c97dc245..c1f943bd02c 100644 --- a/devtools/inspector/_intermediate_output_capturer.py +++ b/devtools/inspector/_intermediate_output_capturer.py @@ -7,14 +7,35 @@ # pyre-unsafe -from typing import Any, Dict, Tuple +from typing import Any, Dict, List, Tuple import torch -from executorch.devtools.inspector._inspector_utils import NodeFilter from torch.fx import GraphModule from torch.fx.interpreter import Interpreter +class NodeFilter: + """ + A class used to filter nodes based on extensible criteria. + Attributes: + metadata_key (str): The key to look for in the node's metadata. + op_type (str): The operation code to match. + exclude_ops (List[str]): A list of operations to exclude from the filter. + """ + + def __init__(self, metadata_key: str, op_type: str, exclude_ops: List[str] = None): + self.metadata_key = metadata_key + self.op_type = op_type + self.exclude_ops = exclude_ops + + def matches(self, node: torch.fx.Node) -> bool: + return ( + node.meta.get(self.metadata_key) is not None + and node.op == self.op_type + and all(exclude_name not in node.name for exclude_name in self.exclude_ops) + ) + + class IntermediateOutputCapturer(Interpreter): """ A class that captures intermediate outputs from a PyTorch graph module. diff --git a/devtools/inspector/tests/inspector_test.py b/devtools/inspector/tests/inspector_test.py index df434fd675d..1460dbd46a2 100644 --- a/devtools/inspector/tests/inspector_test.py +++ b/devtools/inspector/tests/inspector_test.py @@ -537,7 +537,7 @@ def test_consume_etrecord_populates_correct_aot_intermediate_outputs(self): ) ) - def test_get_runtime_intermediate_outputs_and_op_names(self): + def test_get_runtime_intermediate_outputs(self): # Create a context manager to patch functions called by Inspector.__init__ with patch.object( _inspector, "parse_etrecord", return_value=None @@ -560,39 +560,25 @@ def test_get_runtime_intermediate_outputs_and_op_names(self): EventBlock(name=EVENT_BLOCK_NAME, events=self._gen_random_events()) ] - runtime_outputs, op_names = ( - inspector_instance._get_runtime_intermediate_outputs_and_op_names() - ) - # These outputs and op_names dictionaries should all have 5 keys + runtime_outputs = inspector_instance._get_runtime_intermediate_outputs() + # This output should be a dictionary with 5 keys self.assertEqual( len(runtime_outputs), 5, ) - self.assertEqual( - len(op_names), - 5, - ) - - # Check that keys (0,) and (1,) are not in these two dictionaries(skip OPERATOR_CALL and op_types are empty) + # Check that keys (0,) and (1,) are not in the dictionary(skip OPERATOR_CALL and op_types are empty) self.assertNotIn((0,), runtime_outputs) self.assertNotIn((1,), runtime_outputs) - self.assertNotIn((0,), op_names) - self.assertNotIn((1,), op_names) # Same debug_handle but different instruction_id, should record the last one self.assertIn((4,), runtime_outputs) - self.assertIn((4,), op_names) self.assertTrue( torch.equal(runtime_outputs[(4,)][0], torch.tensor([4.0, 5.0, 6.0])) ) - self.assertEqual(op_names[(4,)], "op_3") - # Check that keys (5,) to (8,) are in the dictionary and have values of the correct size for key in range(5, 9): self.assertIn((key,), runtime_outputs) - self.assertIn((key,), op_names) self.assertEqual(len(runtime_outputs[(key,)]), RAW_DATA_SIZE) - self.assertEqual(op_names[(key,)], f"op_{key-1}") def test_calculate_numeric_gap(self): # Create a context manager to patch functions called by Inspector.__init__ @@ -622,8 +608,8 @@ def test_calculate_numeric_gap(self): } inspector_instance._aot_intermediate_outputs = aot_intermediate_outputs - inspector_instance._get_runtime_intermediate_outputs_and_op_names = ( - lambda: (runtime_intermediate_outputs, {}) + inspector_instance._get_runtime_intermediate_outputs = ( + lambda: runtime_intermediate_outputs ) df = inspector_instance.calculate_numeric_gap(distance="L1") diff --git a/devtools/inspector/tests/inspector_utils_test.py b/devtools/inspector/tests/inspector_utils_test.py index 6d12cb13c5f..8148d2c36f0 100644 --- a/devtools/inspector/tests/inspector_utils_test.py +++ b/devtools/inspector/tests/inspector_utils_test.py @@ -34,11 +34,9 @@ EDGE_DIALECT_GRAPH_KEY, find_populated_event, gen_graphs_from_etrecord, - get_aot_debug_handle_to_op_name_mapping, is_inference_output_equal, map_runtime_aot_intermediate_outputs, merge_overlapping_debug_handles, - NodeFilter, TimeScale, ) @@ -366,112 +364,6 @@ class X: msg = str(cm.exception) self.assertIn("Cannot convert value of type", msg) - def test_get_aot_debug_handle_to_op_name_mapping_single_debug_handle(self): - # Create a simple graph module with one node - graph_module = torch.fx.GraphModule({}, torch.fx.Graph()) - node = graph_module.graph.create_node( - "call_function", target=torch.mul, args=(), kwargs={}, name="op1" - ) - node.meta["debug_handle"] = 1 - debug_handle_to_op_name = get_aot_debug_handle_to_op_name_mapping(graph_module) - expected_result = {(1,): "op1"} - self.assertEqual(debug_handle_to_op_name, expected_result) - - def test_get_aot_debug_handle_to_op_name_mapping_multiple_debug_handles(self): - # Create a simple graph module with two nodes - graph_module = torch.fx.GraphModule({}, torch.fx.Graph()) - node1 = graph_module.graph.create_node( - "call_function", target=torch.mul, args=(), kwargs={}, name="op1" - ) - node1.meta["debug_handle"] = (1, 2) - node2 = graph_module.graph.create_node( - "call_function", target=torch.mul, args=(), kwargs={}, name="op2" - ) - node2.meta["debug_handle"] = 3 - debug_handle_to_op_name = get_aot_debug_handle_to_op_name_mapping(graph_module) - expected_result = { - ( - 1, - 2, - ): "op1", - (3,): "op2", - } - self.assertEqual(debug_handle_to_op_name, expected_result) - - def test_get_aot_debug_handle_to_op_name_mapping_no_debug_handles(self): - # Create a simple graph module with no nodes - graph_module = torch.fx.GraphModule({}, torch.fx.Graph()) - debug_handle_to_op_name = get_aot_debug_handle_to_op_name_mapping(graph_module) - expected_result = {} - self.assertEqual(debug_handle_to_op_name, expected_result) - - def test_node_filter_match(self): - node_filter = NodeFilter( - "debug_handle", "call_function", exclude_ops=["getitem"] - ) - - # Create a mock node that matches the filter criteria - mock_node = torch.fx.Node( - graph=torch.fx.Graph(), - name="mock_node", - op="call_function", - target=torch.nn.functional.relu, - args=(), - kwargs={}, - ) - mock_node.meta["debug_handle"] = (1, 2) - # Test that the filter matches the mock node - self.assertTrue(node_filter.matches(mock_node)) - - def test_node_filter_key_mismatch(self): - node_filter = NodeFilter( - "debug_handle", "call_function", exclude_ops=["getitem"] - ) - mock_node_metadata_key_mismatch = torch.fx.Node( - graph=torch.fx.Graph(), - name="mock_node_metadata_key_mismatch", - op="call_function", - target=torch.nn.functional.relu, - args=(), - kwargs={}, - ) - # Test that the filter doesn't match the mock node (meta doesn't have debug_handle key) - self.assertFalse(node_filter.matches(mock_node_metadata_key_mismatch)) - - def test_node_filter_ops_mismatch(self): - node_filter = NodeFilter( - "debug_handle", "call_function", exclude_ops=["getitem"] - ) - - mock_node_exclude_ops_mismatch = torch.fx.Node( - graph=torch.fx.Graph(), - name="getitem", - op="call_function", - target=torch.nn.functional.relu, - args=(), - kwargs={}, - ) - mock_node_exclude_ops_mismatch.meta["debug_handle"] = (1, 2) - # Test that the filter doesn't match the mock node (exclude_ops mismatch) - self.assertFalse(node_filter.matches(mock_node_exclude_ops_mismatch)) - - def test_node_op_type_mismatch(self): - node_filter = NodeFilter( - "debug_handle", "call_function", exclude_ops=["getitem"] - ) - - mock_node_op_type_mismatch = torch.fx.Node( - graph=torch.fx.Graph(), - name="mock_node_op_type_mismatch", - op="get_attr", - target="torch.nn.functional.relu", - args=(), - kwargs={}, - ) - mock_node_op_type_mismatch.meta["debug_handle"] = (1, 2) - # Test that the filter doesn't match the mock node (op_type mismatch) - self.assertFalse(node_filter.matches(mock_node_op_type_mismatch)) - def gen_mock_operator_graph_with_expected_map() -> ( Tuple[OperatorGraph, Dict[int, OperatorNode]] diff --git a/examples/models/deepseek-r1-distill-llama-8B/README.md b/examples/models/deepseek-r1-distill-llama-8B/README.md index 00397e9f60f..f05dd9990a2 100644 --- a/examples/models/deepseek-r1-distill-llama-8B/README.md +++ b/examples/models/deepseek-r1-distill-llama-8B/README.md @@ -53,10 +53,17 @@ torch.save(sd, "/tmp/deepseek-ai/DeepSeek-R1-Distill-Llama-8B/checkpoint.pth") 5. Generate a PTE file for use with the Llama runner. ``` python -m extension.llm.export.export_llm \ - --config examples/models/deepseek-r1-distill-llama-8B/config/deepseek-r1-distill-llama-8B - +base.checkpoint=/tmp/deepseek-ai/DeepSeek-R1-Distill-Llama-8B/checkpoint.pth \ - +base.params=params.json \ - +export.output_name="DeepSeek-R1-Distill-Llama-8B.pte" + base.checkpoint=/tmp/deepseek-ai/DeepSeek-R1-Distill-Llama-8B/checkpoint.pth \ + base.params=params.json \ + model.use_kv_cache=True \ + model.use_sdpa_with_kv_cache=True \ + backend.xnnpack.enabled=True \ + quantization.qmode="8da4w" \ + quantization.group_size=128 \ + model.dtype_override="fp16" \ + base.metadata='"{\"get_bos_id\":128000, \"get_eos_ids\":[128009, 128001]}"' \ + quantization.embedding_quantize=\'4,32\' \ + export.output_name="DeepSeek-R1-Distill-Llama-8B.pte" ``` 6. Run the model on your desktop for validation or integrate with iOS/Android apps. Instructions for these are available in the Llama [README](../llama/README.md) starting at Step 3. diff --git a/examples/models/deepseek-r1-distill-llama-8B/config/deepseek_xnnpack_q8da4w.yaml b/examples/models/deepseek-r1-distill-llama-8B/config/deepseek_xnnpack_q8da4w.yaml deleted file mode 100644 index 1da7c253d92..00000000000 --- a/examples/models/deepseek-r1-distill-llama-8B/config/deepseek_xnnpack_q8da4w.yaml +++ /dev/null @@ -1,16 +0,0 @@ -base: - metadata: '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' - -model: - use_kv_cache: True - use_sdpa_with_kv_cache: True - dtype_override: fp16 - -backend: - xnnpack: - enabled: True - -quantization: - qmode: 8da4w - group_size: 128 - embedding_quantize: 4,32 diff --git a/examples/models/llama/README.md b/examples/models/llama/README.md index bbd2107ad74..3e6869e5c49 100644 --- a/examples/models/llama/README.md +++ b/examples/models/llama/README.md @@ -168,10 +168,14 @@ LLAMA_CHECKPOINT=path/to/consolidated.00.pth LLAMA_PARAMS=path/to/params.json python -m extension.llm.export.export_llm \ - --config examples/models/llamaconfig/llama_bf16.yaml - +base.model_class="llama3_2" \ - +base.checkpoint="${LLAMA_CHECKPOINT:?}" \ - +base.params="${LLAMA_PARAMS:?}" \ + base.model_class="llama3_2" \ + base.checkpoint="${LLAMA_CHECKPOINT:?}" \ + base.params="${LLAMA_PARAMS:?}" \ + model.use_kv_cache=True \ + model.use_sdpa_with_kv_cache=True \ + model.dtype_override="bf16" \ + base.metadata='"{\"get_bos_id\":128000, \"get_eos_ids\":[128009, 128001]}"' \ + export.output_name="llama3_2.pte" ``` For convenience, an [exported ExecuTorch bf16 model](https://huggingface.co/executorch-community/Llama-3.2-1B-ET/blob/main/llama3_2-1B.pte) is available on Hugging Face. The export was created using [this detailed recipe notebook](https://huggingface.co/executorch-community/Llama-3.2-1B-ET/blob/main/ExportRecipe_1B.ipynb). @@ -186,10 +190,22 @@ LLAMA_QUANTIZED_CHECKPOINT=path/to/spinquant/consolidated.00.pth.pth LLAMA_PARAMS=path/to/spinquant/params.json python -m extension.llm.export.export_llm \ - --config examples/models/llama/config/llama_xnnpack_spinquant.yaml - +base.model_class="llama3_2" \ - +base.checkpoint="${LLAMA_QUANTIZED_CHECKPOINT:?}" \ - +base.params="${LLAMA_PARAMS:?}" + base.model_class="llama3_2" \ + base.checkpoint="${LLAMA_QUANTIZED_CHECKPOINT:?}" \ + base.params="${LLAMA_PARAMS:?}" \ + model.use_sdpa_with_kv_cache=True \ + backend.xnnpack.enabled=True \ + backend.xnnpack.extended_ops=True \ + base.preq_mode="preq_8da4w_out_8da8w" \ + base.preq_group_size=32 \ + export.max_seq_length=2048 \ + export.max_context_length=2048 \ + export.output_name="llama3_2.pte" \ + model.use_kv_cache=True \ + model.dtype_override="fp32" \ + base.preq_embedding_quantize=\'8,0\' \ + quantization.use_spin_quant="native" \ + base.metadata='"{\"get_bos_id\":128000, \"get_eos_ids\":[128009, 128001]}"' ``` For convenience, an [exported ExecuTorch SpinQuant model](https://huggingface.co/executorch-community/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8-ET/blob/main/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8.pte) is available on Hugging Face. The export was created using [this detailed recipe notebook](https://huggingface.co/executorch-community/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8-ET/blob/main/Export_Recipe_Llama_3_2_1B_Instruct_SpinQuant_INT4_EO8.ipynb). @@ -203,10 +219,23 @@ LLAMA_QUANTIZED_CHECKPOINT=path/to/qlora/consolidated.00.pth.pth LLAMA_PARAMS=path/to/qlora/params.json python -m extension.llm.export.export_llm \ - --config examples/models/llama/config/llama_xnnpack_qat.yaml - +base.model_class="llama3_2" \ - +base.checkpoint="${LLAMA_QUANTIZED_CHECKPOINT:?}" \ - +base.params="${LLAMA_PARAMS:?}" \ + base.model_class="llama3_2" \ + base.checkpoint="${LLAMA_QUANTIZED_CHECKPOINT:?}" \ + base.params="${LLAMA_PARAMS:?}" \ + quantization.use_qat=True \ + base.use_lora=16 \ + base.preq_mode="preq_8da4w_out_8da8w" \ + base.preq_group_size=32 \ + base.preq_embedding_quantize=\'8,0\' \ + model.use_sdpa_with_kv_cache=True \ + model.use_kv_cache=True \ + backend.xnnpack.enabled=True \ + backend.xnnpack.extended_ops=True \ + model.dtype_override="fp32" \ + export.max_seq_length=2048 \ + export.max_context_length=2048 \ + export.output_name="llama3_2.pte" \ + base.metadata='"{\"get_bos_id\":128000, \"get_eos_ids\":[128009, 128001]}"' ``` For convenience, an [exported ExecuTorch QAT+LoRA model](https://huggingface.co/executorch-community/Llama-3.2-1B-Instruct-QLORA_INT4_EO8-ET/blob/main/Llama-3.2-1B-Instruct-QLORA_INT4_EO8.pte) is available on Hugging Face. The export was created using [this detailed recipe notebook](https://huggingface.co/executorch-community/Llama-3.2-1B-Instruct-QLORA_INT4_EO8-ET/blob/main/Export_Recipe_Llama_3_2_1B_Instruct_QLORA_INT4_EO8.ipynb). @@ -217,13 +246,20 @@ You can export and run the original Llama 3 8B instruct model. 1. Llama 3 pretrained parameters can be downloaded from [Meta's official Llama 3 repository](https://github.com/meta-llama/llama3/). 2. Export model and generate `.pte` file -``` -python -m extension.llm.export.export_llm \ - --config examples/models/llama/config/llama_q8da4w.yaml - +base.model_clas="llama3" - +base.checkpoint= \ - +base.params= -``` + ``` + python -m extension.llm.export.export_llm \ + base.checkpoint= \ + base.params= \ + model.use_kv_cache=True \ + model.use_sdpa_with_kv_cache=True \ + backend.xnnpack.enabled=True \ + quantization.qmode="8da4w" \ + quantization.group_size=128 \ + model.dtype_override="fp32" \ + base.metadata='"{\"get_bos_id\":128000, \"get_eos_ids\":[128009, 128001]}"' \ + quantization.embedding_quantize=\'4,32\' \ + export.output_name="llama3_kv_sdpa_xnn_qe_4_32.pte" + ``` Due to the larger vocabulary size of Llama 3, we recommend quantizing the embeddings with `quantization.embedding_quantize=\'4,32\'` as shown above to further reduce the model size. @@ -240,20 +276,20 @@ python -m extension.llm.export.export_llm \ Note for Mac users: There's a known linking issue with Xcode 15.1. Refer to the section of Common Issues and Mitigations below for solutions. 2. Build llama runner. -``` -cmake -DCMAKE_INSTALL_PREFIX=cmake-out \ - -DBUILD_TESTING=OFF \ - -DCMAKE_BUILD_TYPE=Release \ - -Bcmake-out/examples/models/llama \ - examples/models/llama + ``` + cmake -DCMAKE_INSTALL_PREFIX=cmake-out \ + -DBUILD_TESTING=OFF \ + -DCMAKE_BUILD_TYPE=Release \ + -Bcmake-out/examples/models/llama \ + examples/models/llama -cmake --build cmake-out/examples/models/llama -j16 --config Release -``` + cmake --build cmake-out/examples/models/llama -j16 --config Release + ``` 3. Run model. Run options available [here](https://github.com/pytorch/executorch/blob/main/examples/models/llama/main.cpp#L18-L40). -``` -cmake-out/examples/models/llama/llama_main --model_path= --tokenizer_path= --prompt= -``` + ``` + cmake-out/examples/models/llama/llama_main --model_path= --tokenizer_path= --prompt= + ``` To build for CoreML backend and validate on Mac, replace `-DEXECUTORCH_BUILD_XNNPACK=ON` with `-DEXECUTORCH_BUILD_COREML=ON` diff --git a/examples/models/llama/config/llama_bf16.yaml b/examples/models/llama/config/llama_bf16.yaml deleted file mode 100644 index 8e89e8aa437..00000000000 --- a/examples/models/llama/config/llama_bf16.yaml +++ /dev/null @@ -1,7 +0,0 @@ -base: - metadata: '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' - -model: - use_kv_cache: True - use_sdpa_with_kv_cache: True - dtype_override: bf16 \ No newline at end of file diff --git a/examples/models/llama/config/llama_q8da4w.yaml b/examples/models/llama/config/llama_q8da4w.yaml deleted file mode 100644 index 476ae928c60..00000000000 --- a/examples/models/llama/config/llama_q8da4w.yaml +++ /dev/null @@ -1,11 +0,0 @@ -base: - metadata: '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' - -model: - dtype_override: fp32 - -quantization: - qmode: 8da4w - group_size: 128 - embedding_quantize: 4,32 - \ No newline at end of file diff --git a/examples/models/llama/config/llama_xnnpack_qat.yaml b/examples/models/llama/config/llama_xnnpack_qat.yaml deleted file mode 100644 index 2369ff1d279..00000000000 --- a/examples/models/llama/config/llama_xnnpack_qat.yaml +++ /dev/null @@ -1,23 +0,0 @@ -base: - preq_mode: preq_8da4w_out_8da8w - preq_group_size: 32 - preq_embedding_quantize: 8,0 - metadata: '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' - use_lora: 16 - -model: - use_sdpa_with_kv_cache: True - use_kv_cache: True - dtype_override: fp32 - -export: - max_seq_length: 2048 - max_context_length: 2048 - -quantization: - use_qat: True - -backend: - xnnpack: - enabled: True - extended_ops: True \ No newline at end of file diff --git a/examples/models/llama/config/llama_xnnpack_spinquant.yaml b/examples/models/llama/config/llama_xnnpack_spinquant.yaml deleted file mode 100644 index 441086d6f73..00000000000 --- a/examples/models/llama/config/llama_xnnpack_spinquant.yaml +++ /dev/null @@ -1,22 +0,0 @@ -base: - preq_mode: preq_8da4w_out_8da8w - preq_group_size: 32 - preq_embedding_quantize: 8,0 - metadata: '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' - -model: - use_sdpa_with_kv_cache: True - use_kv_cache: True - dtype_override: fp32 - -export: - max_seq_length: 2048 - max_context_length: 2048 - -quantization: - use_spin_quant: native - -backend: - xnnpack: - enabled: True - extended_ops: True \ No newline at end of file diff --git a/examples/models/llama/config/test_llm_config.py b/examples/models/llama/config/test_llm_config.py index 52b56d71a03..0853e9dbbd8 100644 --- a/examples/models/llama/config/test_llm_config.py +++ b/examples/models/llama/config/test_llm_config.py @@ -41,7 +41,7 @@ def test_local_global_attention_without_kv(self): def test_invalid_export_config_context_length(self): with self.assertRaises(ValueError): - ExportConfig(max_seq_length=256, max_context_length=128) + ExportConfig(max_seq_length=128, max_context_length=256) def test_invalid_qmode(self): with self.assertRaises(ValueError): @@ -84,8 +84,8 @@ def test_valid_llm_config(self): local_global_attention="[16, 32]", ), export=ExportConfig( - max_seq_length=128, - max_context_length=256, + max_seq_length=256, + max_context_length=128, output_dir="/tmp/export", output_name="model.pte", ), @@ -94,7 +94,7 @@ def test_valid_llm_config(self): backend=BackendConfig( xnnpack=XNNPackConfig(enabled=False), coreml=CoreMLConfig( - enabled=True, ios=17, compute_units=CoreMLComputeUnit.cpu_only + enabled=True, ios=17, compute_units=CoreMLComputeUnit.ALL ), ), ) diff --git a/examples/models/phi_4_mini/README.md b/examples/models/phi_4_mini/README.md index 8fb2f03ac4c..d168d54226e 100644 --- a/examples/models/phi_4_mini/README.md +++ b/examples/models/phi_4_mini/README.md @@ -8,7 +8,7 @@ Phi-4-mini uses the same example code as Llama, while the checkpoint, model para All commands for exporting and running Llama on various backends should also be applicable to Phi-4-mini, by swapping the following args: ``` base.model_class="phi_4_mini" -base.params="examples/models/phi-4-mini/config/config.json" +base.params="examples/models/phi-4-mini/config.json" base.checkpoint= ``` @@ -33,10 +33,16 @@ Export to XNNPack, no quantization: PHI_CHECKPOINT=path/to/checkpoint.pth python -m extension.llm.export.export_llm \ - --config config/phi_4_mini_xnnpack.yaml - +base.checkpoint="${PHI_CHECKPOINT=path/to/checkpoint.pth:?}" \ - +base.params="examples/models/phi-4-mini/config/config.json" \ - +export.output_name="phi-4-mini.pte" \ + base.model_class="phi_4_mini" \ + base.checkpoint="${PHI_CHECKPOINT=path/to/checkpoint.pth:?}" \ + base.params="examples/models/phi-4-mini/config.json" \ + model.use_kv_cache=True \ + model.use_sdpa_with_kv_cache=True \ + model.dtype_override="fp32" \ + backend.xnnpack.enabled=True \ + base.metadata='"{\"get_bos_id\":151643, \"get_eos_ids\":[151643]}"' \ + export.output_name="phi-4-mini.pte" \ + debug.verbose=True ``` Run using the executor runner: diff --git a/examples/models/phi_4_mini/config/config.json b/examples/models/phi_4_mini/config.json similarity index 100% rename from examples/models/phi_4_mini/config/config.json rename to examples/models/phi_4_mini/config.json diff --git a/examples/models/phi_4_mini/config/phi_4_mini_xnnpack.yaml b/examples/models/phi_4_mini/config/phi_4_mini_xnnpack.yaml deleted file mode 100644 index 9355bd99f64..00000000000 --- a/examples/models/phi_4_mini/config/phi_4_mini_xnnpack.yaml +++ /dev/null @@ -1,12 +0,0 @@ -base: - model_class: phi_4_mini - metadata: '{"get_bos_id":151643, "get_eos_ids":[151643]}' - -model: - use_kv_cache: True - use_sdpa_with_kv_cache: True - dtype_override: fp32 - -backend: - xnnpack: - enabled: True \ No newline at end of file diff --git a/examples/models/qwen2_5/config/1_5b_config.json b/examples/models/qwen2_5/1_5b_config.json similarity index 100% rename from examples/models/qwen2_5/config/1_5b_config.json rename to examples/models/qwen2_5/1_5b_config.json diff --git a/examples/models/qwen2_5/README.md b/examples/models/qwen2_5/README.md index 566a7a5c30b..57784169ece 100644 --- a/examples/models/qwen2_5/README.md +++ b/examples/models/qwen2_5/README.md @@ -8,7 +8,7 @@ Qwen 2.5 uses the same example code as Llama, while the checkpoint, model params All commands for exporting and running Llama on various backends should also be applicable to Qwen 2.5, by swapping the following args: ``` base.model_class="qwen2_5" -base.params="examples/models/qwen2_5/config/1_5b_config.json" +base.params="examples/models/qwen2_5/1_5b_config.json" base.checkpoint= ``` @@ -33,11 +33,16 @@ Export to XNNPack, no quantization: QWEN_CHECKPOINT=path/to/checkpoint.pth python -m extension.llm.export.export_llm \ - --config examples/models/qwen2_5/config/qwen2_5_xnnpack_q8da4w.yaml - +base.model_class="qwen2_5" \ - +base.checkpoint="${QWEN_CHECKPOINT:?}" \ - +base.params="examples/models/qwen2_5/1_5b_config.json" \ - +export.output_name="qwen2_5-1_5b.pte" \ + base.model_class="qwen2_5" \ + base.checkpoint="${QWEN_CHECKPOINT:?}" \ + base.params="examples/models/qwen2_5/1_5b_config.json" \ + model.use_kv_cache=True \ + model.use_sdpa_with_kv_cache=True \ + model.dtype_override="fp32" \ + backend.xnnpack.enabled=True \ + base.metadata='"{\"get_bos_id\":151643, \"get_eos_ids\":[151643]}"' \ + export.output_name="qwen2_5-1_5b.pte" \ + debug.verbose=True ``` Run using the executor runner: diff --git a/examples/models/qwen2_5/config/qwen2_5_xnnpack_q8da4w.yaml b/examples/models/qwen2_5/config/qwen2_5_xnnpack_q8da4w.yaml deleted file mode 100644 index 0e5c6f7624e..00000000000 --- a/examples/models/qwen2_5/config/qwen2_5_xnnpack_q8da4w.yaml +++ /dev/null @@ -1,11 +0,0 @@ -base: - metadata='{"get_bos_id":151643, "get_eos_ids":[151643]}' - -model: - use_kv_cache: True - use_sdpa_with_kv_cache: True - dtype_override: fp32 - -backend: - xnnpack: - enabled: True \ No newline at end of file diff --git a/examples/models/qwen3/config/0_6b_config.json b/examples/models/qwen3/0_6b_config.json similarity index 100% rename from examples/models/qwen3/config/0_6b_config.json rename to examples/models/qwen3/0_6b_config.json diff --git a/examples/models/qwen3/config/1_7b_config.json b/examples/models/qwen3/1_7b_config.json similarity index 100% rename from examples/models/qwen3/config/1_7b_config.json rename to examples/models/qwen3/1_7b_config.json diff --git a/examples/models/qwen3/config/4b_config.json b/examples/models/qwen3/4b_config.json similarity index 100% rename from examples/models/qwen3/config/4b_config.json rename to examples/models/qwen3/4b_config.json diff --git a/examples/models/qwen3/README.md b/examples/models/qwen3/README.md index d2d89db93c2..e24d8da2637 100644 --- a/examples/models/qwen3/README.md +++ b/examples/models/qwen3/README.md @@ -8,7 +8,7 @@ Qwen 3 uses the same example code as our optimized Llama model, while the checkp All commands for exporting and running Llama on various backends should also be applicable to Qwen 3, by swapping the following args: ``` base.model_class=[qwen3_0_6b,qwen3_1_7b,qwen3_4b] -base.params=[examples/models/qwen3/config/0_6b_config.json,examples/models/qwen3/config/1_7b_config.json,examples/models/config/qwen3/4b_config.json] +base.params=[examples/models/qwen3/0_6b_config.json,examples/models/qwen3/1_7b_config.json,examples/models/qwen3/4b_config.json] ``` ### Example export @@ -17,29 +17,49 @@ Here is a basic example for exporting Qwen 3, although please refer to the Llama Export 0.6b to XNNPack, quantized with 8da4w: ``` python -m extension.llm.export.export_llm \ - --config examples/models/qwen3/config/qwen3_xnnpack_q8da4w.yaml - +base.model_class="qwen3_0_6b" \ - +base.params="examples/models/qwen3/config/0_6b_config.json" \ - +export.output_name="qwen3_0_6b.pte" \ - + base.model_class="qwen3_0_6b" \ + base.params="examples/models/qwen3/0_6b_config.json" \ + model.use_kv_cache=True \ + model.use_sdpa_with_kv_cache=True \ + model.dtype_override="fp32" \ + backend.xnnpack.enabled=True \ + backend.xnnpack.extended_ops=True \ + quantization.qmode="8da4w" \ + base.metadata='"{\"get_bos_id\": 151644, \"get_eos_ids\":[151645]}"' \ + export.output_name="qwen3_0_6b.pte" \ + debug.verbose=True ``` Export 1.7b to XNNPack, quantized with 8da4w: ``` python -m extension.llm.export.export_llm \ - --config examples/models/qwen3/config/qwen3_xnnpack_q8da4w.yaml - +base.model_class="qwen3_1_7b" \ - +base.params="examples/models/qwen3/config/1_7b_config.json" \ - +export.output_name="qwen3_1_7b.pte" \ + base.model_class="qwen3_1_7b" \ + base.params="examples/models/qwen3/1_7b_config.json" \ + model.use_kv_cache=True \ + model.use_sdpa_with_kv_cache=True \ + model.dtype_override="fp32" \ + backend.xnnpack.enabled=True \ + backend.xnnpack.extended_ops=True \ + quantization.qmode="8da4w" \ + base.metadata='"{\"get_bos_id\": 151644, \"get_eos_ids\":[151645]}"' \ + export.output_name="qwen3_1_7b.pte" \ + debug.verbose=True ``` Export 4b to XNNPack, quantized with 8da4w: ``` python -m extension.llm.export.export_llm \ - --config examples/models/qwen3/config/qwen3_xnnpack_q8da4w.yaml - +base.model_class="qwen3_4b" \ - +base.params="examples/models/qwen3/config/4b_config.json" \ - +export.output_name="qwen3_4b.pte" \ + base.model_class="qwen3_4b" \ + base.params="examples/models/qwen3/4b_config.json" \ + model.use_kv_cache=True \ + model.use_sdpa_with_kv_cache=True \ + model.dtype_override="fp32" \ + backend.xnnpack.enabled=True \ + backend.xnnpack.extended_ops=True \ + quantization.qmode="8da4w" \ + base.metadata='"{\"get_bos_id\": 151644, \"get_eos_ids\":[151645]}"' \ + export.output_name="qwen3_4b.pte" \ + debug.verbose=True ``` ### Example run diff --git a/examples/models/qwen3/config/qwen3_xnnpack_q8da4w.yaml b/examples/models/qwen3/config/qwen3_xnnpack_q8da4w.yaml deleted file mode 100644 index 60292b1ecdc..00000000000 --- a/examples/models/qwen3/config/qwen3_xnnpack_q8da4w.yaml +++ /dev/null @@ -1,15 +0,0 @@ -base: - metadata: '{"get_bos_id": 151644, "get_eos_ids":[151645]}' - -model: - use_kv_cache: True - use_sdpa_with_kv_cache: True - dtype_override: fp32 - -quantization: - qmode: 8da4w - -backend: - xnnpack: - enabled: True - extended_ops: True \ No newline at end of file diff --git a/examples/nxp/setup.sh b/examples/nxp/setup.sh index 1a050a79c19..1ef2cc82c2a 100644 --- a/examples/nxp/setup.sh +++ b/examples/nxp/setup.sh @@ -7,4 +7,4 @@ set -u # Install neutron-converter -pip install --extra-index-url https://eiq.nxp.com/repository neutron_converter_SDK_25_03 +pip install --extra-index-url https://eiq.nxp.com/repository neutron-converter_SDK_25_03 diff --git a/examples/qualcomm/qaihub_scripts/utils/export.py b/examples/qualcomm/qaihub_scripts/utils/export.py index 2ee1968dd82..4d252175dbb 100644 --- a/examples/qualcomm/qaihub_scripts/utils/export.py +++ b/examples/qualcomm/qaihub_scripts/utils/export.py @@ -18,6 +18,7 @@ from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset from executorch.backends.qualcomm.utils.utils import ( draw_graph, + ExecutorchBackendConfig, from_context_binary, generate_htp_compiler_spec, generate_qnn_executorch_compiler_spec, @@ -25,7 +26,6 @@ ) from executorch.examples.qualcomm.qaihub_scripts.utils.utils import preprocess_binary from executorch.examples.qualcomm.utils import make_output_dir, SimpleADB -from executorch.exir import ExecutorchBackendConfig from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass diff --git a/examples/qualcomm/util_scripts/README.md b/examples/qualcomm/util_scripts/README.md deleted file mode 100644 index 712bbcd4277..00000000000 --- a/examples/qualcomm/util_scripts/README.md +++ /dev/null @@ -1,79 +0,0 @@ -# CLI Tool for Quantize / Compile / Deploy PyTorch Model with QNN Backend - -An easy-to-use tool for quantizing / compiling / executing .pte program with Qualcomm AI Engine Direct. Tool is verified with [host environement](../../../docs/source/backends-qualcomm.md#host-os). - -## Description - -This tool aims for users who want to deploy models with ExecuTorch runtime. It's possible for them to produce .pte program in few steps.
- -### Quantizing Model - -* Save torch.nn.Module with .pt2 format & prepare input data - ```bash - # create workspace for following operations - cd path/to/executorch - mkdir cli_example - ``` - ```python - # take SimpleModel as an example - import torch - from executorch.backends.qualcomm.tests.models import SimpleModel - from pathlib import Path - # make example inputs - example_inputs = (torch.randn(1, 32, 28, 28), torch.randn(1, 32, 28, 28)) - # generate ExportedProgram - ep = torch.export.export(SimpleModel(), example_inputs) - # save to workspace - ws = f"{Path().cwd()}/cli_example" - torch.export.save(ep, f"{ws}/simple_model.pt2") - # prepare calibration dataset: 2 sets of data with 2 inputs each - input_list = "" - for i in range(2): - current_input = "" - for j in range(2): - file_name = f"{ws}/input_{i}_{j}.pt" - torch.save(torch.randn(1, 32, 28, 28), file_name) - current_input += f"{file_name} " - input_list += f"{current_input.strip()}\n" - - with open(f"{ws}/input_list", 'w') as f: - f.write(input_list) - ``` - -* Quantize - ```bash - # user could get more information via: PYTHONPATH=.. python -m examples.qualcomm.util_scripts.cli quantize -h - PYTHONPATH=.. python -m examples.qualcomm.util_scripts.cli quantize -a cli_example/simple_model.pt2 -o cli_example/quantize_output -c use_8a8w -i cli_example/input_list --per_channel - ``` -* Artifacts for quantized .pt2 file - - `cli_example/quantize_output/simple_model_quantized.pt2` - - -### Compiling Program - -* Compile .pt2 to .pte program - ```bash - # `pip install pydot` if package is missing - # user could get more information via: PYTHONPATH=.. python -m examples.qualcomm.util_scripts.cli compile -h - PYTHONPATH=.. python -m examples.qualcomm.util_scripts.cli compile -a cli_example/quantize_output/simple_model_quantized.pt2 -o cli_example/compile_output -m SM8750 - ``` -* (Optional) Compile pre-generated context binary to .pte program - ```bash - # `pip install pydot` if package is missing - # user could get more information via: PYTHONPATH=.. python -m examples.qualcomm.util_scripts.cli compile -h - PYTHONPATH=.. python -m examples.qualcomm.util_scripts.cli compile -a model.bin -o path/to/model/output -m SM8750 - ``` -* Artifacts for .pte file and figure of graph information - - `cli_example/compile_output/simple_model_quantized.pte` - - `cli_example/compile_output/simple_model_quantized.svg` - -### Executing Program - -* Execute .pte program - ```bash - # user could get more information via: PYTHONPATH=.. python -m examples.qualcomm.util_scripts.cli execute -h - PYTHONPATH=.. python -m examples.qualcomm.util_scripts.cli execute -a cli_example/compile_output/simple_model_quantized.pte -o cli_example/execute_output -i cli_example/input_list -s $DEVICE_SERIAL -b build-android -m SM8750 - ``` -* Artifacts for .pte file and figure of graph information - - `cli_example/execute_output/output_{data_index}_{output_index}.pt`.
- `data_index` represents the sequence of dataset, `output_index` stands for the order of graph output. diff --git a/examples/qualcomm/util_scripts/cli.py b/examples/qualcomm/util_scripts/cli.py deleted file mode 100644 index e4c4c5dcaf8..00000000000 --- a/examples/qualcomm/util_scripts/cli.py +++ /dev/null @@ -1,504 +0,0 @@ -# Copyright (c) Qualcomm Innovation Center, Inc. -# All rights reserved -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import argparse -import importlib -import logging -import os -import re -from pathlib import Path - -import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManagerAdaptor -import numpy as np - -import torch - -from executorch.backends.qualcomm._passes.qnn_pass_manager import ( - get_capture_program_passes, -) -from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype -from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset -from executorch.backends.qualcomm.utils.constants import QCOM_PASS_ACTIVATE_KEY -from executorch.backends.qualcomm.utils.utils import ( - draw_graph, - dump_context_from_pte, - from_context_binary, - generate_htp_compiler_spec, - generate_qnn_executorch_compiler_spec, - generate_qnn_executorch_option, - QNN_QUANT_TYPE_MAP, - QNN_TENSOR_TYPE_MAP, - to_edge_transform_and_lower_to_qnn, -) -from executorch.examples.qualcomm.qaihub_scripts.utils.utils import preprocess_binary -from executorch.examples.qualcomm.utils import ( - make_output_dir, - make_quantizer, - SimpleADB, -) -from executorch.exir import ExecutorchBackendConfig -from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass -from torchao.quantization import pt2e -from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e - - -def get_logger(): - logger = logging.getLogger("examples.qualcomm.util_scripts.cli") - handler = logging.StreamHandler() - handler.setFormatter( - logging.Formatter( - fmt="[%(asctime)s %(prefix)s] %(levelname)-8s: %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", - ) - ) - logger.addHandler(handler) - logger.setLevel(logging.INFO) - logger.propagate = False - return logging.LoggerAdapter(logger, extra={"prefix": "QNN_BACKEND"}) - - -def get_io_info(pte_path, compiler_specs): - dtype_map = {} - for type_map in (QNN_QUANT_TYPE_MAP, QNN_TENSOR_TYPE_MAP): - for k, v in type_map.items(): - dtype_map.setdefault(v, k) - - def fill_tensor_info(info, qnn_tensors, category): - for tensor in qnn_tensors: - encoding = tensor.GetEncodings() - quantization_info = { - "scale": encoding.data["scale"].tolist(), - "offset": encoding.data["offset"].tolist(), - "axis": encoding.axis, - } - info[category].append( - { - "name": tensor.GetName(), - "shape": tensor.GetDims().tolist(), - "dtype": dtype_map[tensor.GetDataType()], - "encoding": quantization_info, - } - ) - - in_key, out_key = "inputs", "outputs" - tensor_info = {in_key: [], out_key: []} - - path_of_pte = Path(pte_path) - dump_context_from_pte(path_of_pte.absolute()) - ctx_bin = [f for f in os.listdir(path_of_pte.parent) if Path(f).suffix == ".bin"][0] - # assume graph is fully delegated or it will be too hard to handle - with open(f"{path_of_pte.parent}/{ctx_bin}", "rb") as f: - ctx_bin = preprocess_binary(f.read(), compiler_specs) - # leverage QNN pybind interface to retrieve tensor encodings - qnn_mgr = PyQnnManagerAdaptor.QnnManager( - generate_qnn_executorch_option(compiler_specs), ctx_bin - ) - assert qnn_mgr.Init().value == 0, "failed to load context binary" - graph_name = qnn_mgr.GetGraphNames()[0] - qnn_mgr.AllocateTensor(graph_name) - fill_tensor_info(tensor_info, qnn_mgr.GetGraphInputs(graph_name), in_key) - fill_tensor_info(tensor_info, qnn_mgr.GetGraphOutputs(graph_name), out_key) - qnn_mgr.Destroy() - - return tensor_info - - -def quantize(args): - logger = get_logger() - - # get corresponding QnnQuantizer - try: - quant_dtype = getattr(QuantDtype, args.config) - act_observer = getattr(pt2e, args.activation_observer) - quantizer = make_quantizer( - quant_dtype=quant_dtype, - per_channel_conv=args.per_channel, - per_channel_linear=args.per_row, - act_observer=act_observer, - ) - except Exception: - logger.error( - f"Failed to retrieve expected config {args.config} / {args.activation_observer}." - ) - exit(1) - - # step 0: load saved model - ep = torch.export.load(args.artifact) - # step 1: use prepare_pt2e to annotate QDQ pairs - ep_prepared = prepare_pt2e(ep.module(), quantizer) - logger.info(f"perform calibration on {args.artifact}") - # step 2: perform calibration - with open(args.input_list, "r") as f: - for line in f.read().split("\n")[:-1]: - inputs = [torch.load(t, weights_only=True) for t in line.split(" ")] - ep_prepared(*inputs) - # step 3: use convert_pt2e to fix encodings of QDQ pairs - logger.info(f"saving calibrated model for {args.artifact}") - ep_converted = convert_pt2e(ep_prepared) - ep_quantized = torch.export.export(ep_converted, tuple(inputs)) - make_output_dir(args.output_folder) - torch.export.save( - ep_quantized, f"{args.output_folder}/{Path(args.artifact).stem}_quantized.pt2" - ) - - -def compile(args): - logger = get_logger() - - # setup memory planning - memory_planning_pass = MemoryPlanningPass( - alloc_graph_input=args.shared_buffer is None, - alloc_graph_output=args.shared_buffer is None, - ) - - file_name, extension = Path(args.artifact).stem, Path(args.artifact).suffix - make_output_dir(args.output_folder) - # setup compiler spec dedicated to QNN HTP backend - backend_options = generate_htp_compiler_spec(use_fp16=True) - # setup general compiler spec for QNN - compiler_specs = generate_qnn_executorch_compiler_spec( - soc_model=getattr(QcomChipset, args.model), - backend_options=backend_options, - is_from_context_binary=extension == "bin", - ) - if extension == ".bin": - custom_op_name = f"ctx_loader_{file_name}" - # step 1: generate ExportedProgram with custom op as a binary loader & lower it w/QnnBackend - logger.info(f"exporting program for {args.artifact}") - prog_info = from_context_binary( - args.artifact, custom_op_name, getattr(QcomChipset, args.model) - ) - # step 2: write pte files and store final graph - logger.info(f"exporting {file_name}.pte") - with open(f"{args.output_folder}/{file_name}.pte", "wb") as f: - prog_info["edge_program_manager"].to_executorch( - config=ExecutorchBackendConfig( - memory_planning_pass=memory_planning_pass - ) - ).write_to_file(f) - logger.info(f"exporting network graph with {file_name}.svg") - draw_graph(file_name, args.output_folder, prog_info["exported_program"]) - elif extension == ".pt2": - # step 0: prepare exported_program - ep = torch.export.load(args.artifact) - sample_inputs = ep.example_inputs[0] - # step 1: start lowering to QnnBackend - logger.info(f"start lowering program for {args.artifact}") - passes, user_passes = get_capture_program_passes(), [] - if args.pass_job is not None: - for job in args.pass_job: - try: - user_passes.append( - importlib.import_module( - "executorch.backends.qualcomm._passes", job - ) - ) - except Exception: - logger.error(f"failed to extract designated pass '{args.artifact}'") - - for user_pass in user_passes: - passes[user_pass][QCOM_PASS_ACTIVATE_KEY] = True - - edge_prog_mgr = to_edge_transform_and_lower_to_qnn( - module=ep.module(), - inputs=sample_inputs, - compiler_specs=compiler_specs, - passes_job=passes, - ) - # step 2: write pte files and store final graph - logger.info(f"exporting {file_name}.pte") - with open(f"{args.output_folder}/{file_name}.pte", "wb") as f: - edge_prog_mgr.to_executorch( - config=ExecutorchBackendConfig( - memory_planning_pass=memory_planning_pass - ) - ).write_to_file(f) - logger.info(f"exporting network graph with {file_name}.svg") - draw_graph(file_name, args.output_folder, edge_prog_mgr.exported_program()) - else: - logger.error(f"unsupported file extension for '{args.artifact}'") - - -def execute(args): - logger = get_logger() - - pte_name = Path(args.artifact).stem - - # load input files - logger.info("loading user inputs") - user_inputs, input_list = [], "" - with open(args.input_list, "r") as f: - for line in f.read().split("\n")[:-1]: - inputs, input_names = [], "" - for data in line.split(" "): - input_names += f"{Path(data).stem}.raw " - inputs.append(torch.load(data, weights_only=True)) - user_inputs.append(inputs) - input_list += input_names.strip() + "\n" - - logger.info("retrieving graph I/O") - # setup compiler spec dedicated to QNN HTP backend - backend_options = generate_htp_compiler_spec(use_fp16=True) - # setup general compiler spec for QNN - compiler_specs = generate_qnn_executorch_compiler_spec( - soc_model=getattr(QcomChipset, args.model), - backend_options=backend_options, - ) - io_info = get_io_info(args.artifact, compiler_specs) - - logger.info("preparing ADB connection") - # leverage SimpleADB for e2e inference - adb = SimpleADB( - qnn_sdk=os.getenv("QNN_SDK_ROOT"), - build_path=args.build_folder, - pte_path=args.artifact, - workspace=f"/data/local/tmp/executorch/{pte_name}", - device_id=args.device, - soc_model=args.model, - host_id=args.host, - shared_buffer=args.shared_buffer, - ) - - logger.info("pushing QNN libraries & other artifacts") - adb.push(inputs=user_inputs, input_list=input_list) - - logger.info("starting inference") - adb.execute() - - def post_process(): - torch_to_numpy_dtype_dict = { - torch.bool: np.dtype("bool"), - torch.uint8: np.dtype("uint8"), - torch.int8: np.dtype("int8"), - torch.int16: np.dtype("int16"), - torch.int32: np.dtype("int32"), - torch.int64: np.dtype("int64"), - torch.float16: np.dtype("float16"), - torch.float32: np.dtype("float32"), - torch.float64: np.dtype("float64"), - torch.complex64: np.dtype("complex64"), - torch.complex128: np.dtype("complex128"), - } - output_info = io_info["outputs"] - output_folder = f"{args.output_folder}/outputs" - for _, f in enumerate(os.listdir(output_folder)): - filename = os.path.join(output_folder, f) - match_res = re.match(r".*([0-9]+)_([0-9]+)\.raw$", filename) - data_index, output_index = int(match_res.group(1)), int(match_res.group(2)) - output = np.fromfile( - filename, - dtype=eval( - f"np.{torch_to_numpy_dtype_dict[output_info[output_index]['dtype']]}" - ), - ) - output = torch.from_numpy( - output.reshape(output_info[output_index]["shape"]) - ) - torch.save( - output, f"{args.output_folder}/output_{data_index}_{output_index}.pt" - ) - - logger.info("collecting output data") - make_output_dir(args.output_folder) - adb.pull(args.output_folder, post_process) - logger.info(f"execution finished, please check {args.output_folder} for results") - - -def main(): - parser = argparse.ArgumentParser( - description=( - "Utility to quantize / compile / execute models via Qualcomm backend" - ), - ) - subparsers = parser.add_subparsers( - title="subcommands", - description=( - "[quantize]: Perform PTQ with QnnQuantizer for models in .pt2 extension. " - "[compile]: Compile model in .pt2 extenstion / context binary into .pte file. " - "[execute]: Perform on-device inference with given .pte." - ), - ) - - sub_quantize = subparsers.add_parser( - name="quantize", - help=( - "e.g. python -m executorch.example.qualcomm.util_scripts.cli quantize " - "-a model.pt2 -c use_8a8w -i calibration_data" - ), - ) - sub_quantize.add_argument( - "-a", - "--artifact", - type=str, - required=True, - help="Path to saved .pt2 model in floating point precision.", - ) - sub_quantize.add_argument( - "-o", - "--output_folder", - type=str, - default="./output_quantized", - help="Path to output artifact, store in 'output_quantized' if not given.", - ) - sub_quantize.add_argument( - "-c", - "--config", - type=str, - default="use_8a8w", - help=(f"Configuration to be applied: {list(QuantDtype.__members__.keys())}."), - ) - sub_quantize.add_argument( - "-i", - "--input_list", - type=str, - required=True, - help=( - "List of input files specified for calibration. " - 'e.g. File content with: "input_0_0.pt2 input_0_1.pt2\\ninput_1_0.pt2 input_1_1.pt2" ' - "means there are 2 sets of data for calibration on a graph with 2 inputs." - ), - ) - sub_quantize.add_argument( - "--per_channel", - action="store_true", - help="Use per_channel encoding for operator convolution and its' families.", - ) - sub_quantize.add_argument( - "--per_row", - action="store_true", - help="Use per_row encoding for operator linear.", - ) - sub_quantize.add_argument( - "--activation_observer", - type=str, - default="MovingAverageMinMaxObserver", - help=( - "Activation observer for PTQ " - "(MinMaxObserver / MovingAverageMinMaxObserver / HistogramObserver)." - ), - ) - sub_quantize.set_defaults(callback=quantize) - - sub_compile = subparsers.add_parser( - name="compile", - help=( - "e.g. python -m executorch.example.qualcomm.util_scripts.cli compile " - "-a model.(pt2 / bin) -m SM8750" - ), - ) - sub_compile.add_argument( - "-a", - "--artifact", - type=str, - required=True, - help="Path to saved .pt2 model or pre-generated context binary.", - ) - sub_compile.add_argument( - "-m", - "--model", - type=str, - required=True, - help="SoC model. e.g. SM8750", - ) - sub_compile.add_argument( - "-o", - "--output_folder", - type=str, - default="./output_pte", - help="Path to output artifacts, store in 'output_pte' if not given.", - ) - sub_compile.add_argument( - "-p", - "--pass_job", - nargs="+", - type=str, - help=( - 'Add extra passes for model lowering. e.g. "ExpandBroadcastTensorShape".' - ), - ) - sub_compile.add_argument( - "--shared_buffer", - help=( - "Enable usage of shared buffer between application and backend for graph I/O." - ), - action="store_true", - ) - sub_compile.set_defaults(callback=compile) - - sub_execute = subparsers.add_parser( - name="execute", - help=( - "e.g. python -m executorch.example.qualcomm.util_scripts.cli " - "execute -p model.pte -i execution_data -s device_serial" - ), - ) - sub_execute.add_argument( - "-a", - "--artifact", - type=str, - required=True, - help="Path to .pte file generated from 'compile' subcommand.", - ) - sub_execute.add_argument( - "-i", - "--input_list", - type=str, - help=( - "List of input files specified for execution. " - 'e.g. File content with: "input_0_0.pt2 input_0_1.pt2\\ninput_1_0.pt2 input_1_1.pt2" ' - "means there are 2 sets of data for execution on a graph with 2 inputs.\n" - ), - ) - sub_execute.add_argument( - "-m", - "--model", - type=str, - required=True, - help="SoC model. e.g. SM8750", - ) - sub_execute.add_argument( - "-s", - "--device", - type=str, - required=True, - help="Serial no of device which could be obtained by 'adb devices'.", - ) - sub_execute.add_argument( - "-o", - "--output_folder", - type=str, - default="./output_data", - help="Path to output data, store in 'output_data' if not given.", - ) - sub_execute.add_argument( - "-b", - "--build_folder", - help="Path to cmake binary directory for android, e.g., /path/to/build-android", - type=str, - required=True, - ) - sub_execute.add_argument( - "-H", - "--host", - type=str, - help="Gateway hostname.", - ) - sub_execute.add_argument( - "--shared_buffer", - help=( - "Enable usage of shared buffer between application and backend for graph I/O." - " Please use with `--shared_buffer` in compile command." - ), - action="store_true", - ) - sub_execute.set_defaults(callback=execute) - - args = parser.parse_args() - args.callback(args) - - -if __name__ == "__main__": - main() diff --git a/examples/qualcomm/utils.py b/examples/qualcomm/utils.py index e70510b0b70..6d9a6653ec7 100755 --- a/examples/qualcomm/utils.py +++ b/examples/qualcomm/utils.py @@ -9,7 +9,6 @@ import argparse import os -import shutil import subprocess import sys import tempfile @@ -396,7 +395,9 @@ def build_executorch_binary( def make_output_dir(path: str): if os.path.exists(path): - shutil.rmtree(path, ignore_errors=True) + for f in os.listdir(path): + os.remove(os.path.join(path, f)) + os.removedirs(path) os.makedirs(path) diff --git a/exir/passes/TARGETS b/exir/passes/TARGETS index 8699fe2fd02..749e8f5c2f1 100644 --- a/exir/passes/TARGETS +++ b/exir/passes/TARGETS @@ -31,7 +31,6 @@ python_library( ":sym_shape_eval_pass", ":sym_to_tensor_pass", ":weights_to_outputs_pass", - ":reinplace_pass", "//caffe2:torch", "//executorch/exir:common", "//executorch/exir:control_flow", @@ -69,17 +68,6 @@ python_library( ], ) -python_library( - name = "reinplace_pass", - srcs = [ - "reinplace.py", - ], - deps = [ - "//caffe2:torch", - "//executorch/exir/dialects:lib", - ], -) - python_library( name = "insert_write_back_for_buffers_pass", srcs = [ diff --git a/exir/passes/reinplace.py b/exir/passes/reinplace.py deleted file mode 100644 index 349869a2f4b..00000000000 --- a/exir/passes/reinplace.py +++ /dev/null @@ -1,103 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-strict - -from typing import Set - -import torch -from executorch.exir.dialects._ops import ops -from torch.export import ExportedProgram - - -def _is_index_put(node: torch.fx.Node) -> bool: - """Check if a node is an index_put operation.""" - return node.op == "call_function" and node.target in ( - torch.ops.aten.index_put.default, - ops.edge.aten.index_put.default, - ) - - -def _is_safe_to_reinplace( - node: torch.fx.Node, - later_nodes: Set[torch.fx.Node], - inputs: Set[torch.fx.Node], - mutable_inputs: Set[torch.fx.Node], -) -> bool: - # This node is used later in the graph so we can't reinplace it - # There is probably a faster way to do this but this works for now. - if node in later_nodes: - return False - # If its not an input then we can reinplace it - if node not in inputs: - return True - # If its a mutable input then we can reinplace it - elif node in mutable_inputs: - return True - else: # input but not mutable input - return False - - -def _is_mutable_user_input( - node: torch.fx.Node, exported_program: ExportedProgram -) -> bool: - return ( - node.target in exported_program.graph_signature.user_inputs_to_mutate.values() - ) - - -def _is_mutable_buffer(node: torch.fx.Node, exported_program: ExportedProgram) -> bool: - if node.target not in exported_program.graph_signature.inputs_to_buffers: - return False - buf = exported_program.graph_signature.inputs_to_buffers[node.target] - return buf in exported_program.graph_signature.buffers_to_mutate.values() - - -def reinplace_pass(ep: ExportedProgram) -> ExportedProgram: - """ - Pass that loops over nodes in an exported program and collects the first argument - of every call_function node that is a view_copy operation. - - Args: - exported_program: The ExportedProgram to analyze - - Returns: - Set of nodes that are first arguments to view_copy operations - """ - seen_nodes: Set[torch.fx.Node] = set() - # Get all placeholders - inputs = set() - for node in ep.graph.nodes: - if node.op == "placeholder": - inputs.add(node) - # Get all inputs that we could potentially mutate - mutable_nodes = set( - [ - node - for node in inputs - if _is_mutable_user_input(node, ep) or _is_mutable_buffer(node, ep) - ] - ) - - results = set() - for node in reversed(ep.graph.nodes): - if _is_index_put(node): - # Check if this index_put node is safe to inplace - # The first argument is the base tensor being indexed into - first_arg = node.args[0] - if _is_safe_to_reinplace(first_arg, seen_nodes, inputs, mutable_nodes): - # This index_put is safe to reinplace - with ep.graph.inserting_before(node): - new_node = ep.graph.call_function( - ops.edge.aten.index_put_.default, args=node.args - ) - new_node.meta["val"] = node.meta["val"] - node.replace_all_uses_with(new_node) - ep.graph.erase_node(node) - results.add(first_arg) - elif node.op == "call_function": - seen_nodes.update(node.all_input_nodes) - return ep diff --git a/exir/tests/TARGETS b/exir/tests/TARGETS index 2c2ad3e05f0..1423984c563 100644 --- a/exir/tests/TARGETS +++ b/exir/tests/TARGETS @@ -136,18 +136,6 @@ python_unittest( ], ) -python_unittest( - name = "reinplace_pass", - srcs = [ - "test_reinplace_pass.py", - ], - deps = [ - "//caffe2:torch", - "//executorch/exir:lib", - "//executorch/exir/passes:lib", - ], -) - cpp_library( name = "test_lib", srcs = [ diff --git a/exir/tests/test_reinplace_pass.py b/exir/tests/test_reinplace_pass.py deleted file mode 100644 index 2f4538770d6..00000000000 --- a/exir/tests/test_reinplace_pass.py +++ /dev/null @@ -1,104 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-strict - -import unittest - -import torch -from executorch.exir import to_edge -from executorch.exir.passes.reinplace import reinplace_pass -from torch.export import export - - -class TestReinplacePass(unittest.TestCase): - def test_index_put_reinplace(self) -> None: - """Test that index_put on a mutable buffer can be reinplaced.""" - - class IndexPutModel(torch.nn.Module): - def __init__(self): - super().__init__() - self.register_buffer("state", torch.zeros(5)) - - def forward( - self, indices: torch.Tensor, values: torch.Tensor - ) -> torch.Tensor: - # index_put on buffer (non-user input) should be safe - self.state.index_put_((indices,), values) - return self.state - - model = IndexPutModel() - indices = torch.tensor([0]) - values = torch.tensor([1.0]) - - exported_program = export(model, (indices, values), strict=True) - print(exported_program.graph) - edge_program = to_edge(exported_program).exported_program() - - # Find the index_put node - index_put_node = None - for node in edge_program.graph.nodes: - if node.op == "call_function" and "index_put" in str(node.target): - index_put_node = node - break - - self.assertIsNotNone(index_put_node, "Should find an index_put node") - - ep = reinplace_pass(edge_program) - # Find the index_put node - index_put_node = None - for node in ep.graph.nodes: - if node.op == "call_function" and "index_put_" in str(node.target): - index_put_node = node - break - - self.assertIsNotNone(index_put_node, "Should find an index_put_ node") - - def test_cant_reinplace(self) -> None: - """Test that index_put on a mutable buffer that is viewed later is not safe.""" - - class IndexPutModel(torch.nn.Module): - def __init__(self): - super().__init__() - self.register_buffer("state", torch.zeros(5)) - - def forward( - self, indices: torch.Tensor, values: torch.Tensor - ) -> torch.Tensor: - # index_put on buffer (non-user input) should be safe - x = self.state.index_put((indices,), values) - self.state.add_(1) - return x - - model = IndexPutModel() - indices = torch.tensor([0]) - values = torch.tensor([1.0]) - - exported_program = export(model, (indices, values), strict=True) - edge_program = to_edge(exported_program).exported_program() - - # Find the index_put node - index_put_node = None - for node in edge_program.graph.nodes: - if node.op == "call_function" and "index_put" in str(node.target): - index_put_node = node - break - - self.assertIsNotNone(index_put_node, "Should find an index_put node") - - ep = reinplace_pass(edge_program) - # Find the index_put node - index_put_node = None - for node in ep.graph.nodes: - if ( - node.op == "call_function" - and "index_put" in str(node.target) - and "index_put_" not in str(node.target) - ): - index_put_node = node - break - - self.assertIsNotNone(index_put_node, "Should still find an index_put node") diff --git a/extension/flat_tensor/serialize/serialize.py b/extension/flat_tensor/serialize/serialize.py index 7f3332c4303..5b29d7ccacd 100644 --- a/extension/flat_tensor/serialize/serialize.py +++ b/extension/flat_tensor/serialize/serialize.py @@ -19,11 +19,7 @@ from executorch.exir._serialize._flatbuffer import _flatc_compile, _flatc_decompile from executorch.exir._serialize._program import _insert_flatbuffer_header -from executorch.exir._serialize.data_serializer import ( - DataEntry, - DataPayload, - DataSerializer, -) +from executorch.exir._serialize.data_serializer import DataPayload, DataSerializer from executorch.exir._serialize.padding import aligned_size, pad_to, padding_required @@ -38,9 +34,6 @@ # endian. _HEADER_BYTEORDER: Literal["little"] = "little" -# Current version. Keep in sync with c++ version number in serialize. -_FLAT_TENSOR_VERSION: int = 0 - def _serialize_to_flatbuffer(flat_tensor: FlatTensor) -> Cord: """Serializes a FlatTensor to a flatbuffer and returns the serialized data.""" @@ -327,7 +320,7 @@ def serialize( # Create FlatTensor, which describes of the contents of the file and # points to all the data segments. It will be serialized to flatbuffer. flat_tensor = FlatTensor( - version=_FLAT_TENSOR_VERSION, + version=0, # Keep in sync with c++ version number in serialize.h segments=data_segments, named_data=named_data, ) @@ -390,49 +383,4 @@ def deserialize(self, blob: Cord) -> DataPayload: """ Deserializes a flat_tensor blob into a list of tensor metadata and tensors. """ - - data = bytes(blob) - - # Read header. Verify that it's valid. - header = FlatTensorHeader.from_bytes(data[8:]) - if not header.is_valid(): - raise RuntimeError( - "Flat tensor header is invalid. File is likely incorrect format or corrupt." - ) - - # Deserialize the flat tensor data, which contains the data offsets and tensor metadata. - flat_tensor_bytes = data[0 : header.flatbuffer_offset + header.flatbuffer_size] - flat_tensor = _deserialize_to_flat_tensor(flat_tensor_bytes) - - # Verify that this is a supported version. - if flat_tensor.version != _FLAT_TENSOR_VERSION: - raise NotImplementedError( - f"Flat tensor files reports unsupported version {flat_tensor.version}. Expected {_FLAT_TENSOR_VERSION}." - ) - - # Extract the buffers. - buffers = [ - data[ - header.segment_base_offset - + segment.offset : header.segment_base_offset - + segment.offset - + segment.size - ] - for segment in flat_tensor.segments - ] - - payload = DataPayload( - buffers=buffers, - named_data={}, - ) - - # Read the named data entries. - for named_data in flat_tensor.named_data: - entry = DataEntry( - buffer_index=named_data.segment_index, - alignment=1, - tensor_layout=named_data.tensor_layout, - ) - payload.named_data[named_data.key] = entry - - return payload + raise NotImplementedError("deserialize_data") diff --git a/extension/flat_tensor/test/test_serialize.py b/extension/flat_tensor/test/test_serialize.py index 13402e60a65..80ee59ae974 100644 --- a/extension/flat_tensor/test/test_serialize.py +++ b/extension/flat_tensor/test/test_serialize.py @@ -6,19 +6,17 @@ # pyre-unsafe -import dataclasses import math import unittest from typing import List, Optional -from executorch.exir._serialize._cord import Cord - from executorch.exir._serialize.data_serializer import ( DataEntry, DataPayload, DataSerializer, ) + from executorch.exir._serialize.padding import aligned_size from executorch.exir.schema import ScalarType @@ -225,39 +223,3 @@ def test_serialize(self) -> None: ) self.assertEqual(segments[2].offset + segments[2].size, len(segment_data)) - - def test_round_trip(self) -> None: - # Serialize and then deserialize the test payload. Make sure it's reconstructed - # properly. - config = FlatTensorConfig() - serializer: DataSerializer = FlatTensorSerializer(config) - - # Round trip the data. - serialized_data = bytes(serializer.serialize(TEST_DATA_PAYLOAD)) - deserialized_payload = serializer.deserialize(Cord(serialized_data)) - - # Validate the deserialized payload. Since alignment isn't serialized, we need to - # do this somewhat manually. - for i in range(len(deserialized_payload.buffers)): - self.assertEqual( - TEST_DATA_PAYLOAD.buffers[i], - deserialized_payload.buffers[i], - f"Buffer at index {i} does not match.", - ) - - self.assertEqual( - TEST_DATA_PAYLOAD.named_data.keys(), deserialized_payload.named_data.keys() - ) - - SKIP_FIELDS = {"alignment"} # Fields to ignore in comparison. - for key in TEST_DATA_PAYLOAD.named_data.keys(): - reference = TEST_DATA_PAYLOAD.named_data[key] - actual = deserialized_payload.named_data[key] - - for field in dataclasses.fields(reference): - if field.name not in SKIP_FIELDS: - self.assertEqual( - getattr(reference, field.name), - getattr(actual, field.name), - f"Named data record {key}.{field.name} does not match.", - ) diff --git a/extension/llm/export/README.md b/extension/llm/export/README.md index e97b9e10462..96f36acc1b4 100644 --- a/extension/llm/export/README.md +++ b/extension/llm/export/README.md @@ -23,9 +23,9 @@ The LLM export process transforms a model from its original format to an optimiz ## Usage -The export API supports a Hydra-style CLI where you can you configure using yaml and also CLI args. +The export API supports two configuration approaches: -### Hydra CLI Arguments +### Option 1: Hydra CLI Arguments Use structured configuration arguments directly on the command line: @@ -41,7 +41,7 @@ python -m extension.llm.export.export_llm \ quantization.qmode=8da4w ``` -### Configuration File +### Option 2: Configuration File Create a YAML configuration file and reference it: @@ -78,21 +78,53 @@ debug: verbose: true ``` -You can you also still provide additional overrides using the CLI args as well: +**Important**: You cannot mix both approaches. Use either CLI arguments OR a config file, not both. +## Example Commands + +### Export Qwen3 0.6B with XNNPACK backend and quantization ```bash -python -m extension.llm.export.export_llm - --config my_config.yaml - base.model_class="llama2" - +export.max_context_length=1024 +python -m extension.llm.export.export_llm \ + base.model_class=qwen3_0_6b \ + base.params=examples/models/qwen3/0_6b_config.json \ + base.metadata='{"get_bos_id": 151644, "get_eos_ids":[151645]}' \ + model.use_kv_cache=true \ + model.use_sdpa_with_kv_cache=true \ + model.dtype_override=FP32 \ + export.max_seq_length=512 \ + export.output_name=qwen3_0_6b.pte \ + quantization.qmode=8da4w \ + backend.xnnpack.enabled=true \ + backend.xnnpack.extended_ops=true \ + debug.verbose=true ``` -Note that if a config file is specified and you want to specify a CLI arg that is not in the config, you need to prepend with a `+`. You can read more about this in the Hydra [docs](https://hydra.cc/docs/advanced/override_grammar/basic/). - - -## Example Commands +### Export Phi-4-Mini with custom checkpoint +```bash +python -m extension.llm.export.export_llm \ + base.model_class=phi_4_mini \ + base.checkpoint=/path/to/phi4_checkpoint.pth \ + base.params=examples/models/phi-4-mini/config.json \ + base.metadata='{"get_bos_id":151643, "get_eos_ids":[151643]}' \ + model.use_kv_cache=true \ + model.use_sdpa_with_kv_cache=true \ + export.max_seq_length=256 \ + export.output_name=phi4_mini.pte \ + backend.xnnpack.enabled=true \ + debug.verbose=true +``` -Please refer to the docs for some of our example suported models ([Llama](https://github.com/pytorch/executorch/blob/main/examples/models/llama/README.md), [Qwen3](https://github.com/pytorch/executorch/tree/main/examples/models/qwen3/README.md), [Phi-4-mini](https://github.com/pytorch/executorch/tree/main/examples/models/phi_4_mini/README.md)). +### Export with CoreML backend (iOS optimization) +```bash +python -m extension.llm.export.export_llm \ + base.model_class=llama3 \ + model.use_kv_cache=true \ + export.max_seq_length=128 \ + backend.coreml.enabled=true \ + backend.coreml.compute_units=ALL \ + quantization.pt2e_quantize=coreml_c4w \ + debug.verbose=true +``` ## Configuration Options @@ -102,4 +134,4 @@ For a complete reference of all available configuration options, see the [LlmCon - [Llama Examples](../../../examples/models/llama/README.md) - Comprehensive Llama export guide - [LLM Runner](../runner/) - Running exported models -- [ExecuTorch Documentation](https://pytorch.org/executorch/) - Framework overview +- [ExecuTorch Documentation](https://pytorch.org/executorch/) - Framework overview \ No newline at end of file diff --git a/extension/llm/export/export_llm.py b/extension/llm/export/export_llm.py index e0467250a28..e995b329f30 100644 --- a/extension/llm/export/export_llm.py +++ b/extension/llm/export/export_llm.py @@ -30,7 +30,6 @@ """ import argparse -import os import sys from typing import Any, List, Tuple @@ -46,6 +45,7 @@ def parse_config_arg() -> Tuple[str, List[Any]]: + """First parse out the arg for whether to use Hydra or the old CLI.""" parser = argparse.ArgumentParser(add_help=True) parser.add_argument("--config", type=str, help="Path to the LlmConfig file") args, remaining = parser.parse_known_args() @@ -56,7 +56,6 @@ def pop_config_arg() -> str: """ Removes '--config' and its value from sys.argv. Assumes --config is specified and argparse has already validated the args. - Returns the config file path. """ idx = sys.argv.index("--config") value = sys.argv[idx + 1] @@ -64,42 +63,30 @@ def pop_config_arg() -> str: return value -def add_hydra_config_args(config_file_path: str) -> None: - """ - Breaks down the config file path into directory and filename, - resolves the directory to an absolute path, and adds the - --config_path and --config_name arguments to sys.argv. - """ - config_dir = os.path.dirname(config_file_path) - config_name = os.path.basename(config_file_path) - - # Resolve to absolute path - config_dir_abs = os.path.abspath(config_dir) - - # Add the hydra config arguments to sys.argv - sys.argv.extend(["--config-path", config_dir_abs, "--config-name", config_name]) - - -@hydra.main(version_base=None, config_name="llm_config", config_path=None) +@hydra.main(version_base=None, config_name="llm_config") def hydra_main(llm_config: LlmConfig) -> None: - structured = OmegaConf.structured(LlmConfig) - merged = OmegaConf.merge(structured, llm_config) - llm_config_obj = OmegaConf.to_object(merged) - export_llama(llm_config_obj) + export_llama(OmegaConf.to_object(llm_config)) def main() -> None: - # First parse out the arg for whether to use Hydra or the old CLI. config, remaining_args = parse_config_arg() if config: - # Pop out --config and its value so that they are not parsed by - # Hydra's main. - config_file_path = pop_config_arg() + # Check if there are any remaining hydra CLI args when --config is specified + # This might change in the future to allow overriding config file values + if remaining_args: + raise ValueError( + "Cannot specify additional CLI arguments when using --config. " + f"Found: {remaining_args}. Use either --config file or hydra CLI args, not both." + ) - # Add hydra config_path and config_name arguments to sys.argv. - add_hydra_config_args(config_file_path) - - hydra_main() + config_file_path = pop_config_arg() + default_llm_config = LlmConfig() + llm_config_from_file = OmegaConf.load(config_file_path) + # Override defaults with values specified in the .yaml provided by --config. + merged_llm_config = OmegaConf.merge(default_llm_config, llm_config_from_file) + export_llama(merged_llm_config) + else: + hydra_main() if __name__ == "__main__": diff --git a/extension/llm/export/test/test_export_llm.py b/extension/llm/export/test/test_export_llm.py index ab7db1b4e3a..7d17b7819d3 100644 --- a/extension/llm/export/test/test_export_llm.py +++ b/extension/llm/export/test/test_export_llm.py @@ -21,7 +21,7 @@ class TestExportLlm(unittest.TestCase): def test_parse_config_arg_with_config(self) -> None: """Test parse_config_arg when --config is provided.""" # Mock sys.argv to include --config - test_argv = ["export_llm.py", "--config", "test_config.yaml", "extra", "args"] + test_argv = ["script.py", "--config", "test_config.yaml", "extra", "args"] with patch.object(sys, "argv", test_argv): config_path, remaining = parse_config_arg() self.assertEqual(config_path, "test_config.yaml") @@ -29,7 +29,7 @@ def test_parse_config_arg_with_config(self) -> None: def test_parse_config_arg_without_config(self) -> None: """Test parse_config_arg when --config is not provided.""" - test_argv = ["export_llm.py", "debug.verbose=True"] + test_argv = ["script.py", "debug.verbose=True"] with patch.object(sys, "argv", test_argv): config_path, remaining = parse_config_arg() self.assertIsNone(config_path) @@ -37,21 +37,11 @@ def test_parse_config_arg_without_config(self) -> None: def test_pop_config_arg(self) -> None: """Test pop_config_arg removes --config and its value from sys.argv.""" - test_argv = ["export_llm.py", "--config", "test_config.yaml", "other", "args"] + test_argv = ["script.py", "--config", "test_config.yaml", "other", "args"] with patch.object(sys, "argv", test_argv): config_path = pop_config_arg() self.assertEqual(config_path, "test_config.yaml") - self.assertEqual(sys.argv, ["export_llm.py", "other", "args"]) - - def test_with_cli_args(self) -> None: - """Test main function with only hydra CLI args.""" - test_argv = ["export_llm.py", "debug.verbose=True"] - with patch.object(sys, "argv", test_argv): - with patch( - "executorch.extension.llm.export.export_llm.hydra_main" - ) as mock_hydra: - main() - mock_hydra.assert_called_once() + self.assertEqual(sys.argv, ["script.py", "other", "args"]) @patch("executorch.extension.llm.export.export_llm.export_llama") def test_with_config(self, mock_export_llama: MagicMock) -> None: @@ -67,7 +57,7 @@ def test_with_config(self, mock_export_llama: MagicMock) -> None: model: dtype_override: fp16 export: - max_seq_length: 128 + max_seq_length: 256 quantization: pt2e_quantize: xnnpack_dynamic use_spin_quant: cuda @@ -80,7 +70,7 @@ def test_with_config(self, mock_export_llama: MagicMock) -> None: config_file = f.name try: - test_argv = ["export_llm.py", "--config", config_file] + test_argv = ["script.py", "--config", config_file] with patch.object(sys, "argv", test_argv): main() @@ -88,65 +78,75 @@ def test_with_config(self, mock_export_llama: MagicMock) -> None: mock_export_llama.assert_called_once() called_config = mock_export_llama.call_args[0][0] self.assertEqual( - called_config.base.tokenizer_path, "/path/to/tokenizer.json" + called_config["base"]["tokenizer_path"], "/path/to/tokenizer.json" + ) + self.assertEqual(called_config["base"]["model_class"], "llama2") + self.assertEqual(called_config["base"]["preq_mode"].value, "8da4w") + self.assertEqual(called_config["model"]["dtype_override"].value, "fp16") + self.assertEqual(called_config["export"]["max_seq_length"], 256) + self.assertEqual( + called_config["quantization"]["pt2e_quantize"].value, "xnnpack_dynamic" + ) + self.assertEqual( + called_config["quantization"]["use_spin_quant"].value, "cuda" ) - self.assertEqual(called_config.base.model_class, "llama2") - self.assertEqual(called_config.base.preq_mode.value, "8da4w") - self.assertEqual(called_config.model.dtype_override.value, "fp16") - self.assertEqual(called_config.export.max_seq_length, 128) self.assertEqual( - called_config.quantization.pt2e_quantize.value, "xnnpack_dynamic" + called_config["backend"]["coreml"]["quantize"].value, "c4w" ) - self.assertEqual(called_config.quantization.use_spin_quant.value, "cuda") - self.assertEqual(called_config.backend.coreml.quantize.value, "c4w") self.assertEqual( - called_config.backend.coreml.compute_units.value, "cpu_and_gpu" + called_config["backend"]["coreml"]["compute_units"].value, "cpu_and_gpu" ) finally: os.unlink(config_file) - @patch("executorch.extension.llm.export.export_llm.export_llama") - def test_with_config_and_cli(self, mock_export_llama: MagicMock) -> None: - """Test main function with --config file and no hydra args.""" + def test_with_cli_args(self) -> None: + """Test main function with only hydra CLI args.""" + test_argv = ["script.py", "debug.verbose=True"] + with patch.object(sys, "argv", test_argv): + with patch( + "executorch.extension.llm.export.export_llm.hydra_main" + ) as mock_hydra: + main() + mock_hydra.assert_called_once() + + def test_config_with_cli_args_error(self) -> None: + """Test that --config rejects additional CLI arguments to prevent mixing approaches.""" # Create a temporary config file with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: - f.write( - """ -base: - model_class: llama2 -model: - dtype_override: fp16 -backend: - xnnpack: - enabled: False -""" - ) + f.write("base:\n checkpoint: /path/to/checkpoint.pth") + config_file = f.name + + try: + test_argv = ["script.py", "--config", config_file, "debug.verbose=True"] + with patch.object(sys, "argv", test_argv): + with self.assertRaises(ValueError) as cm: + main() + + error_msg = str(cm.exception) + self.assertIn( + "Cannot specify additional CLI arguments when using --config", + error_msg, + ) + finally: + os.unlink(config_file) + + def test_config_rejects_multiple_cli_args(self) -> None: + """Test that --config rejects multiple CLI arguments (not just single ones).""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + f.write("export:\n max_seq_length: 128") config_file = f.name try: test_argv = [ - "export_llm.py", + "script.py", "--config", config_file, - "base.model_class=stories110m", - "backend.xnnpack.enabled=True", + "debug.verbose=True", + "export.output_dir=/tmp", ] with patch.object(sys, "argv", test_argv): - main() - - # Verify export_llama was called with config - mock_export_llama.assert_called_once() - called_config = mock_export_llama.call_args[0][0] - self.assertEqual( - called_config.base.model_class, "stories110m" - ) # Override from CLI. - self.assertEqual( - called_config.model.dtype_override.value, "fp16" - ) # From yaml. - self.assertEqual( - called_config.backend.xnnpack.enabled, - True, # Override from CLI. - ) + with self.assertRaises(ValueError): + main() finally: os.unlink(config_file) diff --git a/kernels/portable/cpu/util/elementwise_util.h b/kernels/portable/cpu/util/elementwise_util.h index a376a89747b..948da50fdd4 100644 --- a/kernels/portable/cpu/util/elementwise_util.h +++ b/kernels/portable/cpu/util/elementwise_util.h @@ -9,7 +9,6 @@ #pragma once #include -#include #include #include #include @@ -346,22 +345,20 @@ inline void apply_elementwise_fn( } constexpr auto compute_type = CppTypeToScalarType::value; - if constexpr (should_include_kernel_dtype(op_name, compute_type)) { - const bool all_inputs_compute_dtype = - ((inputs.first->scalar_type() == compute_type) && ...); - - constexpr ScalarType out_specialized_scalar_type = - specialized_output_scalar_type(out_dtypes); - if (all_inputs_compute_dtype && - out.scalar_type() == out_specialized_scalar_type) { - using CTYPE_OUT = - typename ScalarTypeToCppType::type; - dtype_specialized_elementwise_fn_impl< - CTYPE_COMPUTE, - CTYPE_OUT, - support_noncontiguous_tensors>(compute_fun, ctx, out, inputs...); - return; - } + const bool all_inputs_compute_dtype = + ((inputs.first->scalar_type() == compute_type) && ...); + + constexpr ScalarType out_specialized_scalar_type = + specialized_output_scalar_type(out_dtypes); + if (all_inputs_compute_dtype && + out.scalar_type() == out_specialized_scalar_type) { + using CTYPE_OUT = + typename ScalarTypeToCppType::type; + dtype_specialized_elementwise_fn_impl< + CTYPE_COMPUTE, + CTYPE_OUT, + support_noncontiguous_tensors>(compute_fun, ctx, out, inputs...); + return; } apply_elementwise_fn_generic_impl< diff --git a/kernels/portable/cpu/util/normalization_ops_util.cpp b/kernels/portable/cpu/util/normalization_ops_util.cpp index 4adcf02b303..f7118257898 100644 --- a/kernels/portable/cpu/util/normalization_ops_util.cpp +++ b/kernels/portable/cpu/util/normalization_ops_util.cpp @@ -38,7 +38,7 @@ bool check_batch_norm_args( ET_LOG_AND_RETURN_IF_FALSE( tensors_have_same_dtype(in, running_mean.value())); } - if (running_var.has_value()) { + if (running_mean.has_value()) { ET_LOG_AND_RETURN_IF_FALSE( tensors_have_same_dtype(in, running_var.value())); } diff --git a/kernels/portable/cpu/util/targets.bzl b/kernels/portable/cpu/util/targets.bzl index 44b95aa55c4..1523fcfe706 100644 --- a/kernels/portable/cpu/util/targets.bzl +++ b/kernels/portable/cpu/util/targets.bzl @@ -115,10 +115,10 @@ def define_common_targets(): ":vectorized_math", "//executorch/runtime/core/portable_type/c10/c10:aten_headers_for_executorch", "//executorch/runtime/kernel:kernel_runtime_context", - "//executorch/kernels/portable/cpu:scalar_utils", "//executorch/extension/threadpool:threadpool", ], deps = [ + "//executorch/kernels/portable/cpu:scalar_utils", "//executorch/runtime/kernel:kernel_includes", ], visibility = ["//executorch/kernels/portable/cpu/...", "//executorch/kernels/optimized/cpu/...", "@EXECUTORCH_CLIENTS"], diff --git a/pytest.ini b/pytest.ini index e0f8eafb082..557a307bdf2 100644 --- a/pytest.ini +++ b/pytest.ini @@ -18,7 +18,6 @@ addopts = --ignore=devtools/visualization/visualization_utils_test.py # examples examples/models/llama/tests - examples/models/llama/config examples/models/llama3_2_vision/preprocess examples/models/llama3_2_vision/vision_encoder/test examples/models/llama3_2_vision/text_decoder/test From 91c21d9c632c920f95236a779ad142b0b2590c75 Mon Sep 17 00:00:00 2001 From: Max Ren Date: Thu, 26 Jun 2025 10:12:14 -0700 Subject: [PATCH 5/6] Update [ghstack-poisoned] --- .ci/scripts/test_model.sh | 4 +- .github/workflows/android-perf.yml | 2 +- .github/workflows/apple-perf.yml | 2 +- .github/workflows/trunk.yml | 29 + backends/arm/_passes/__init__.py | 2 + backends/arm/_passes/arm_pass_manager.py | 4 + backends/arm/_passes/decompose_atan_pass.py | 119 +++++ .../_passes/decompose_batch_norm_no_stats.py | 219 ++++++++ backends/arm/_passes/insert_table_ops.py | 1 + .../tosa_supported_operators.py | 1 + .../arm/quantizer/quantization_annotator.py | 1 + backends/arm/test/ops/test_atan.py | 84 +++ backends/arm/test/ops/test_batch_norm.py | 55 +- backends/nxp/requirements-tests.txt | 2 +- backends/nxp/run_unittests.sh | 14 + backends/qualcomm/qnn_preprocess.py | 51 +- backends/qualcomm/tests/test_qnn_delegate.py | 62 +++ devtools/inspector/TARGETS | 1 + devtools/inspector/_inspector.py | 41 +- devtools/inspector/_inspector_utils.py | 50 ++ .../_intermediate_output_capturer.py | 25 +- devtools/inspector/tests/inspector_test.py | 26 +- .../inspector/tests/inspector_utils_test.py | 108 ++++ .../deepseek-r1-distill-llama-8B/README.md | 15 +- .../config/deepseek_xnnpack_q8da4w.yaml | 16 + examples/models/llama/README.md | 96 ++-- examples/models/llama/config/llama_bf16.yaml | 7 + .../models/llama/config/llama_q8da4w.yaml | 11 + .../llama/config/llama_xnnpack_qat.yaml | 23 + .../llama/config/llama_xnnpack_spinquant.yaml | 22 + .../models/llama/config/test_llm_config.py | 8 +- examples/models/phi_4_mini/README.md | 16 +- .../phi_4_mini/{ => config}/config.json | 0 .../phi_4_mini/config/phi_4_mini_xnnpack.yaml | 12 + examples/models/qwen2_5/README.md | 17 +- .../qwen2_5/{ => config}/1_5b_config.json | 0 .../config/qwen2_5_xnnpack_q8da4w.yaml | 11 + examples/models/qwen3/README.md | 48 +- .../qwen3/{ => config}/0_6b_config.json | 0 .../qwen3/{ => config}/1_7b_config.json | 0 .../models/qwen3/{ => config}/4b_config.json | 0 .../qwen3/config/qwen3_xnnpack_q8da4w.yaml | 15 + examples/nxp/setup.sh | 2 +- .../qualcomm/qaihub_scripts/utils/export.py | 2 +- examples/qualcomm/util_scripts/README.md | 79 +++ examples/qualcomm/util_scripts/cli.py | 504 ++++++++++++++++++ examples/qualcomm/utils.py | 5 +- exir/passes/TARGETS | 12 + exir/passes/reinplace.py | 103 ++++ exir/tests/TARGETS | 12 + exir/tests/test_reinplace_pass.py | 104 ++++ extension/flat_tensor/serialize/serialize.py | 58 +- extension/flat_tensor/test/test_serialize.py | 40 +- extension/llm/export/README.md | 60 +-- extension/llm/export/export_llm.py | 49 +- extension/llm/export/test/test_export_llm.py | 114 ++-- kernels/portable/cpu/util/elementwise_util.h | 31 +- .../cpu/util/normalization_ops_util.cpp | 2 +- kernels/portable/cpu/util/targets.bzl | 2 +- pytest.ini | 1 + 60 files changed, 2039 insertions(+), 361 deletions(-) create mode 100644 backends/arm/_passes/decompose_atan_pass.py create mode 100644 backends/arm/_passes/decompose_batch_norm_no_stats.py create mode 100644 backends/arm/test/ops/test_atan.py create mode 100755 backends/nxp/run_unittests.sh create mode 100644 examples/models/deepseek-r1-distill-llama-8B/config/deepseek_xnnpack_q8da4w.yaml create mode 100644 examples/models/llama/config/llama_bf16.yaml create mode 100644 examples/models/llama/config/llama_q8da4w.yaml create mode 100644 examples/models/llama/config/llama_xnnpack_qat.yaml create mode 100644 examples/models/llama/config/llama_xnnpack_spinquant.yaml rename examples/models/phi_4_mini/{ => config}/config.json (100%) create mode 100644 examples/models/phi_4_mini/config/phi_4_mini_xnnpack.yaml rename examples/models/qwen2_5/{ => config}/1_5b_config.json (100%) create mode 100644 examples/models/qwen2_5/config/qwen2_5_xnnpack_q8da4w.yaml rename examples/models/qwen3/{ => config}/0_6b_config.json (100%) rename examples/models/qwen3/{ => config}/1_7b_config.json (100%) rename examples/models/qwen3/{ => config}/4b_config.json (100%) create mode 100644 examples/models/qwen3/config/qwen3_xnnpack_q8da4w.yaml create mode 100644 examples/qualcomm/util_scripts/README.md create mode 100644 examples/qualcomm/util_scripts/cli.py create mode 100644 exir/passes/reinplace.py create mode 100644 exir/tests/test_reinplace_pass.py diff --git a/.ci/scripts/test_model.sh b/.ci/scripts/test_model.sh index bbf879295ae..bc9bbb8bae0 100755 --- a/.ci/scripts/test_model.sh +++ b/.ci/scripts/test_model.sh @@ -102,7 +102,7 @@ test_model() { bash examples/models/llama/install_requirements.sh # Test export_llm script: python3 -m extension.llm.export.export_llm. # Use Llama random checkpoint with Qwen 2.5 1.5b model configuration. - "${PYTHON_EXECUTABLE}" -m extension.llm.export.export_llm base.model_class="${MODEL_NAME}" base.params=examples/models/qwen2_5/1_5b_config.json + "${PYTHON_EXECUTABLE}" -m extension.llm.export.export_llm base.model_class="${MODEL_NAME}" base.params=examples/models/qwen2_5/config/1_5b_config.json rm "./${MODEL_NAME}.pte" return # Skip running with portable executor runnner since portable doesn't support Qwen's biased linears. fi @@ -110,7 +110,7 @@ test_model() { # Install requirements for export_llama bash examples/models/llama/install_requirements.sh # Test export_llm script: python3 -m extension.llm.export.export_llm. - "${PYTHON_EXECUTABLE}" -m extension.llm.export.export_llm base.model_class="${MODEL_NAME}" base.params=examples/models/phi_4_mini/config.json + "${PYTHON_EXECUTABLE}" -m extension.llm.export.export_llm base.model_class="${MODEL_NAME}" base.params=examples/models/phi_4_mini/config/config.json run_portable_executor_runner rm "./${MODEL_NAME}.pte" return diff --git a/.github/workflows/android-perf.yml b/.github/workflows/android-perf.yml index a79a900b2d8..2eab69eb88b 100644 --- a/.github/workflows/android-perf.yml +++ b/.github/workflows/android-perf.yml @@ -317,7 +317,7 @@ jobs: DOWNLOADED_PATH=$(bash .ci/scripts/download_hf_hub.sh --model_id "${HF_MODEL_REPO}" --subdir "." --files "tokenizer.json") python -m extension.llm.export.export_llm \ base.model_class=qwen3_0_6b \ - base.params=examples/models/qwen3/0_6b_config.json \ + base.params=examples/models/qwen3/config/0_6b_config.json \ model.use_kv_cache=true \ model.use_sdpa_with_kv_cache=true \ model.dtype_override=fp32 \ diff --git a/.github/workflows/apple-perf.yml b/.github/workflows/apple-perf.yml index 6b1666da642..3db5abbefbd 100644 --- a/.github/workflows/apple-perf.yml +++ b/.github/workflows/apple-perf.yml @@ -322,7 +322,7 @@ jobs: DOWNLOADED_PATH=$(bash .ci/scripts/download_hf_hub.sh --model_id "${HF_MODEL_REPO}" --subdir "." --files "tokenizer.json") ${CONDA_RUN} python -m extension.llm.export.export_llm \ base.model_class=qwen3_0_6b \ - base.params=examples/models/qwen3/0_6b_config.json \ + base.params=examples/models/qwen3/config/0_6b_config.json \ model.use_kv_cache=true \ model.use_sdpa_with_kv_cache=true \ model.dtype_override=fp32 \ diff --git a/.github/workflows/trunk.yml b/.github/workflows/trunk.yml index a4996459f8a..4fe5ec979a3 100644 --- a/.github/workflows/trunk.yml +++ b/.github/workflows/trunk.yml @@ -718,3 +718,32 @@ jobs: build-mode: Release build-tool: cmake docker-image: executorch-ubuntu-22.04-clang12 + + unittest-nxp-neutron: + uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main + permissions: + id-token: write + contents: read + with: + runner: linux.2xlarge + docker-image: executorch-ubuntu-22.04-clang12 + submodules: 'recursive' + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + timeout: 90 + script: | + set -eux + + # The generic Linux job chooses to use base env, not the one setup by the image + CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]") + conda activate "${CONDA_ENV}" + + # Build and install Executorch + PYTHON_EXECUTABLE=python \ + CMAKE_ARGS="-DEXECUTORCH_BUILD_NXP_NEUTRON=ON" \ + .ci/scripts/setup-linux.sh --build-tool "cmake" + + # Install test requirements + pip install -r backends/nxp/requirements-tests.txt + + # Run pytest + PYTHON_EXECUTABLE=python bash backends/nxp/run_unittests.sh diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py index d3c0ae0a1b3..2a75606cb70 100644 --- a/backends/arm/_passes/__init__.py +++ b/backends/arm/_passes/__init__.py @@ -22,7 +22,9 @@ from .convert_split_to_slice import ConvertSplitToSlicePass # noqa from .convert_squeezes_to_view import ConvertSqueezesToViewPass # noqa from .convert_to_clamp import ConvertToClampPass # noqa +from .decompose_atan_pass import DecomposeAtanPass # noqa from .decompose_avg_pool2d import DecomposeAvgPool2d # noqa +from .decompose_batch_norm_no_stats import DecomposeBatchNormNoStatsPass # noqa from .decompose_cosine_similarity_pass import DecomposeCosineSimilarityPass # noqa from .decompose_div_pass import DecomposeDivPass # noqa from .decompose_embedding_pass import DecomposeEmbeddingPass # noqa # noqa diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index 2cefd3bdaca..596decd65bb 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -25,7 +25,9 @@ ConvertSplitToSlicePass, ConvertSqueezesToViewPass, ConvertToClampPass, + DecomposeAtanPass, DecomposeAvgPool2d, + DecomposeBatchNormNoStatsPass, DecomposeCosineSimilarityPass, DecomposeDivPass, DecomposeEmbeddingPass, @@ -150,6 +152,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule: self.add_pass(DecomposeRoundPass()) self.add_pass(DecomposeSqrtPass()) + self.add_pass(DecomposeAtanPass()) self.add_pass(ConvertIntPowToMuls()) self.add_pass(CastBoolToInt8Pass()) self.add_pass(DecomposeSinhPass()) @@ -164,6 +167,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul self.add_pass(DecomposeLeakyReLUPass()) self.add_pass(DecomposeGroupNormPass()) self.add_pass(DecomposeLayerNormPass()) + self.add_pass(DecomposeBatchNormNoStatsPass()) self.add_pass(DecomposeVarPass()) self.add_pass( DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec) diff --git a/backends/arm/_passes/decompose_atan_pass.py b/backends/arm/_passes/decompose_atan_pass.py new file mode 100644 index 00000000000..57b9dde5216 --- /dev/null +++ b/backends/arm/_passes/decompose_atan_pass.py @@ -0,0 +1,119 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from math import pi + +from executorch.backends.arm._passes import ArmPass +from executorch.exir.dialects._ops import ops as exir_ops + + +edge_atan = exir_ops.edge.aten.atan.default # MI case + + +def _get_atan_ops(op): + """Return the primitive ops required..""" + if op is not edge_atan: + raise RuntimeError(f"Can't decompose atan for op {op}") + + return ( + exir_ops.edge.aten.mul.Tensor, + exir_ops.edge.aten.mul.Scalar, + exir_ops.edge.aten.add.Tensor, + exir_ops.edge.aten.add.Scalar, + exir_ops.edge.aten.sub.Tensor, + exir_ops.edge.aten.abs.default, + exir_ops.edge.aten.gt.Scalar, + exir_ops.edge.aten.reciprocal.default, + exir_ops.edge.aten.where.self, + exir_ops.edge.aten.neg.default, + ) + + +class DecomposeAtanPass(ArmPass): + """Decomposes the atan operator into a rational (Padé) approximation.""" + + def _rational_approximation(self, z, ops, meta): + """Creates a (2,1) Padé approximation for atan(x) on [-1, 1].""" + + op_mul, op_mul_scalar, op_add, op_add_scalar, _, _, _, op_recip, _, _ = ops + + # Coefficients calculated using minimax on the interval [-1, 1]. + a1 = 0.3529666667 + a2 = -0.0287666667 + b1 = 0.6863 + + z2 = super().call_operator(op_mul, (z, z), {}, meta, updated=True) + z4 = super().call_operator(op_mul, (z2, z2), {}, meta, updated=True) + + num1 = super().call_operator(op_mul_scalar, (z2, a1), {}, meta, updated=True) + num2 = super().call_operator(op_mul_scalar, (z4, a2), {}, meta, updated=True) + num = super().call_operator(op_add_scalar, (num1, 1.0), {}, meta, updated=True) + num = super().call_operator(op_add, (num, num2), {}, meta, updated=True) + + den1 = super().call_operator(op_mul_scalar, (z2, b1), {}, meta, updated=True) + den = super().call_operator(op_add_scalar, (den1, 1.0), {}, meta, updated=True) + + inv_den = super().call_operator(op_recip, (den,), {}, meta, updated=True) + + prod = super().call_operator(op_mul, (num, inv_den), {}, meta, updated=True) + return super().call_operator(op_mul, (z, prod), {}, meta, updated=True) + + def call_operator(self, op, args, kwargs, meta): + if op is not edge_atan: + return super().call_operator(op, args, kwargs, meta, updated=False) + + logging.info( + f"Approximating atan. This may introduce small numerical errors. For details, see {__file__}." + ) + + ops = _get_atan_ops(op) + ( + _, + op_mul_scalar, + _, + op_add_scalar, + op_sub, + op_abs, + op_gt, + op_recip, + op_where, + op_neg, + ) = ops + + x = args[0] + + # |x| > 1 is reduced to [0, 1] using atan(x) = pi/2 - atan(1/x) and atan(-x) = -atan(x). + + abs_x = super().call_operator(op_abs, (x,), {}, meta, updated=True) + mask_hi = super().call_operator(op_gt, (abs_x, 1.0), {}, meta, updated=True) + + inv_x = super().call_operator(op_recip, (abs_x,), {}, meta, updated=True) + z = super().call_operator( + op_where, (mask_hi, inv_x, abs_x), {}, meta, updated=True + ) + + atan_z = self._rational_approximation(z, ops, meta) + + zero_tensor = super().call_operator( + op_mul_scalar, (x, 0.0), {}, meta, updated=True + ) + half_pi_tensor = super().call_operator( + op_add_scalar, (zero_tensor, pi / 2), {}, meta, updated=True + ) + + diff = super().call_operator( + op_sub, (half_pi_tensor, atan_z), {}, meta, updated=True + ) + atan_abs = super().call_operator( + op_where, (mask_hi, diff, atan_z), {}, meta, updated=True + ) + + mask_pos = super().call_operator(op_gt, (x, 0.0), {}, meta, updated=True) + neg_val = super().call_operator(op_neg, (atan_abs,), {}, meta, updated=True) + + return super().call_operator( + op_where, (mask_pos, atan_abs, neg_val), {}, meta, updated=True + ) diff --git a/backends/arm/_passes/decompose_batch_norm_no_stats.py b/backends/arm/_passes/decompose_batch_norm_no_stats.py new file mode 100644 index 00000000000..5fdb8db2d7c --- /dev/null +++ b/backends/arm/_passes/decompose_batch_norm_no_stats.py @@ -0,0 +1,219 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +import operator + +import torch +from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes.arm_pass_utils import create_node +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import PassResult + + +class DecomposeBatchNormNoStatsPass(ArmPass): + """ + Decompose BatchNorm2d(track_running_stats=False) (aten._native_batch_norm_legit_no_training) + into a sequence of elementwise operations: + + # let input = x, rm = running_mean, rv = running_var, eps: float + rm_view = view(rm, weights_shape) + rv_view = view(rv, weights_shape) + centered = sub(x, rm_view) + eps_full = full(eps_shape, eps) + var_eps = add(rv_view, eps_full) + inv_sqrt = rsqrt(var_eps) + normed = mul(centered, inv_sqrt) + weighted = mul(normed, w_view) + biased = add(weighted, b_view) + + Source: https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html + """ + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # noqa: C901 + bn_ops = ( + exir_ops.edge.aten._native_batch_norm_legit.no_stats, + exir_ops.edge.aten._native_batch_norm_legit_no_training.default, + torch.ops.aten._native_batch_norm_legit_no_training.default, + torch.ops.aten.batch_norm.default, + torch.ops.aten.native_batch_norm.default, + ) + + for node in graph_module.graph.nodes: + if node.op != "call_function" or node.target not in bn_ops: + continue + + if node.target in ( + torch.ops.aten.batch_norm.default, + torch.ops.aten.native_batch_norm.default, + ): + # signature: (input, weight, bias, mean, var, training, momentum, eps, cudnn_enabled) + # pos‐arg 5 is training + training = node.kwargs.get("training", False) + if len(node.args) > 5: + training = node.args[5] + if training: + # skip training‐mode batchnorm + continue + + # Extract args + args = node.args + meta = node.meta + + # Default eps + eps: float = torch.finfo().eps + # weight and bias may be None + x = args[0] + weight = args[1] if len(args) > 1 else None + bias = args[2] if len(args) > 2 else None + running_mean = args[3] + running_var = args[4] + if len(args) > 6: + eps = args[6] + + # Determine shapes + val = meta.get("val") + ref_tensor = val[0] if isinstance(val, tuple) else val + shape = tuple(ref_tensor.size()) + dtype = ref_tensor.dtype + rank = len(shape) + + # channel dimension is 1 for BatchNorm2d + channel_axis = 1 + weights_shape = [1] * rank + weights_shape[channel_axis] = shape[channel_axis] + num_features = shape[channel_axis] + + # Ops to use + sub_op = exir_ops.edge.aten.sub.Tensor + view_op = exir_ops.edge.aten.view_copy.default + full_op = exir_ops.edge.aten.full.default + add_op = exir_ops.edge.aten.add.Tensor + rsqrt_op = exir_ops.edge.aten.rsqrt.default + mul_op = exir_ops.edge.aten.mul.Tensor + + # Begin decomposition + with graph_module.graph.inserting_before(node): + # reshape running stats + rm_view = create_node( + graph_module.graph, + view_op, + args=(running_mean, weights_shape), + from_node=node, + ) + rv_view = create_node( + graph_module.graph, + view_op, + args=(running_var, weights_shape), + from_node=node, + ) + # center input + centered = create_node( + graph_module.graph, + sub_op, + args=(x, rm_view), + from_node=node, + ) + # epsilon tensor + eps_shape = [1] * rank + eps_full = create_node( + graph_module.graph, + full_op, + args=(eps_shape, eps), + kwargs={"dtype": dtype}, + from_node=node, + ) + # var + eps + var_eps = create_node( + graph_module.graph, + add_op, + args=(rv_view, eps_full), + from_node=node, + ) + # inverse sqrt + inv_sqrt = create_node( + graph_module.graph, + rsqrt_op, + args=(var_eps,), + from_node=node, + ) + # normalized + normed = create_node( + graph_module.graph, + mul_op, + args=(centered, inv_sqrt), + from_node=node, + ) + + # weight + if weight is None: + one = create_node( + graph_module.graph, + full_op, + args=([num_features], 1), + kwargs={"dtype": dtype}, + from_node=node, + ) + w_view = create_node( + graph_module.graph, + view_op, + args=(one, weights_shape), + from_node=node, + ) + else: + w_view = create_node( + graph_module.graph, + view_op, + args=(weight, weights_shape), + from_node=node, + ) + weighted = create_node( + graph_module.graph, + mul_op, + args=(normed, w_view), + from_node=node, + ) + + # bias + if bias is None: + zero = create_node( + graph_module.graph, + full_op, + args=([num_features], 0), + kwargs={"dtype": dtype}, + from_node=node, + ) + b_view = create_node( + graph_module.graph, + view_op, + args=(zero, weights_shape), + from_node=node, + ) + else: + b_view = create_node( + graph_module.graph, + view_op, + args=(bias, weights_shape), + from_node=node, + ) + final_out = create_node( + graph_module.graph, + add_op, + args=(weighted, b_view), + from_node=node, + ) + + users = [u for u in node.users if u is not node] + node.replace_all_uses_with(final_out) + for u in users: + if u.target == operator.getitem: + u.replace_all_uses_with(final_out) + graph_module.graph.erase_node(node) + graph_module.graph.eliminate_dead_code() + + graph_module.recompile() + new_gm = super().call(graph_module).graph_module + return PassResult(new_gm, True) diff --git a/backends/arm/_passes/insert_table_ops.py b/backends/arm/_passes/insert_table_ops.py index c579fcb0301..b31b6c7106d 100644 --- a/backends/arm/_passes/insert_table_ops.py +++ b/backends/arm/_passes/insert_table_ops.py @@ -51,6 +51,7 @@ class TableOps: exir_ops.edge.aten.cos.default: torch.cos, exir_ops.edge.aten.sin.default: torch.sin, exir_ops.edge.aten.tanh.default: torch.tanh, + exir_ops.edge.aten.atan.default: torch.atan, exir_ops.edge.aten.hardsigmoid.default: torch.nn.functional.hardsigmoid, exir_ops.edge.aten.hardswish.default: torch.nn.functional.hardswish, exir_ops.edge.aten.sinh.default: torch.sinh, diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py index 639df536109..cdb27b7c31e 100644 --- a/backends/arm/operator_support/tosa_supported_operators.py +++ b/backends/arm/operator_support/tosa_supported_operators.py @@ -244,6 +244,7 @@ def is_node_supported( exir_ops.edge.aten.gelu.default, exir_ops.edge.aten.alias_copy.default, exir_ops.edge.aten.sinh.default, + exir_ops.edge.aten.atan.default, ] return supported diff --git a/backends/arm/quantizer/quantization_annotator.py b/backends/arm/quantizer/quantization_annotator.py index c6415c63777..2c61aea60c3 100644 --- a/backends/arm/quantizer/quantization_annotator.py +++ b/backends/arm/quantizer/quantization_annotator.py @@ -214,6 +214,7 @@ def _match_pattern( torch.ops.aten.pow.Tensor_Scalar, torch.ops.aten.gelu.default, torch.ops.aten.sinh.default, + torch.ops.aten.atan.default, ] _one_to_one_shared_input_qspec = [ diff --git a/backends/arm/test/ops/test_atan.py b/backends/arm/test/ops/test_atan.py new file mode 100644 index 00000000000..3d6f8cd8fa8 --- /dev/null +++ b/backends/arm/test/ops/test_atan.py @@ -0,0 +1,84 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Tuple + +import torch + +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.test_pipeline import ( + EthosU55PipelineBI, + EthosU85PipelineBI, + TosaPipelineBI, + TosaPipelineMI, +) + +aten_op = "torch.ops.aten.atan.default" +exir_op = "executorch_exir_dialects_edge__ops_aten__atan_default" + +input_t1 = Tuple[torch.Tensor] + +test_data_suite = { + "zeros": torch.zeros(1, 10, 10, 10), + "zeros_alt_shape": torch.zeros(1, 10, 3, 5), + "ones": torch.ones(10, 10, 10), + "rand": torch.rand(10, 10) - 0.5, + "rand_alt_shape": torch.rand(1, 10, 3, 5) - 0.5, + "randn_pos": torch.randn(10) + 10, + "randn_neg": torch.randn(10) - 10, + "ramp": torch.arange(-16, 16, 0.2), +} + + +class Atan(torch.nn.Module): + + def forward(self, x: torch.Tensor): + return torch.atan(x) + + +@common.parametrize("test_data", test_data_suite) +def test_atan_tosa_MI(test_data: Tuple): + pipeline = TosaPipelineMI[input_t1]( + Atan(), + (test_data,), + aten_op=aten_op, + exir_op=exir_op, + ) + pipeline.run() + + +@common.parametrize("test_data", test_data_suite) +def test_atan_tosa_BI(test_data: Tuple): + pipeline = TosaPipelineBI[input_t1]( + Atan(), + (test_data,), + aten_op=aten_op, + exir_op=exir_op, + ) + pipeline.run() + + +@common.XfailIfNoCorstone300 +@common.parametrize("test_data", test_data_suite) +def test_atan_u55_BI(test_data: Tuple): + pipeline = EthosU55PipelineBI[input_t1]( + Atan(), + (test_data,), + aten_ops=aten_op, + exir_ops=exir_op, + ) + pipeline.run() + + +@common.XfailIfNoCorstone320 +@common.parametrize("test_data", test_data_suite) +def test_atan_u85_BI(test_data: Tuple): + pipeline = EthosU85PipelineBI[input_t1]( + Atan(), + (test_data,), + aten_ops=aten_op, + exir_ops=exir_op, + ) + pipeline.run() diff --git a/backends/arm/test/ops/test_batch_norm.py b/backends/arm/test/ops/test_batch_norm.py index 7f98a48b203..eb0d4306e6e 100644 --- a/backends/arm/test/ops/test_batch_norm.py +++ b/backends/arm/test/ops/test_batch_norm.py @@ -224,6 +224,8 @@ class BatchNorm2dNoStats(torch.nn.Module): Decomposes into _native_batch_norm_legit.no_stats """ + aten_ops = ["torch.ops.aten.batch_norm.default"] + def __init__( self, num_features: int, @@ -250,29 +252,60 @@ def forward(self, x): return self.batch_norm_2d(x) -@pytest.mark.skip( - reason="MLETORCH-999: Add support for _native_batch_norm_legit.no_stats." -) -def test_native_batch_norm_legit_no_stats_tosa_MI(): - pass +@common.parametrize("test_data", test_data_suite) +def test_native_batch_norm_legit_no_stats_tosa_MI(test_data: Tuple): + test_data, model_params = test_data() + pipeline = TosaPipelineMI[input_t1]( + BatchNorm2dNoStats(*model_params), + (test_data,), + aten_op=BatchNorm2dNoStats.aten_ops, + ) + pipeline.run() @pytest.mark.skip( reason="MLETORCH-999: Add support for _native_batch_norm_legit.no_stats." ) -def test_native_batch_norm_legit_no_stats_tosa_BI(): - pass +def test_native_batch_norm_legit_no_stats_tosa_BI(test_data: Tuple): + test_data, model_params = test_data() + pipeline = TosaPipelineBI[input_t1]( + BatchNorm2dNoStats(*model_params), + (test_data,), + aten_op=BatchNorm2dNoStats.aten_ops, + qtol=1, + ) + pipeline.run() @pytest.mark.skip( reason="MLETORCH-999: Add support for _native_batch_norm_legit.no_stats." ) -def test_native_batch_norm_legit_no_stats_u55_BI(): - pass +@common.parametrize("test_data", test_data_suite) +@common.XfailIfNoCorstone300 +def test_native_batch_norm_legit_no_stats_u55_BI(test_data: Tuple): + test_data, model_params = test_data() + pipeline = EthosU55PipelineBI[input_t1]( + BatchNorm2dNoStats(*model_params), + (test_data,), + aten_op=BatchNorm2dNoStats.aten_ops, + run_on_fvp=True, + qtol=1, + ) + pipeline.run() @pytest.mark.skip( reason="MLETORCH-999: Add support for _native_batch_norm_legit.no_stats." ) -def test_native_batch_norm_legit_no_stats_u85_BI(): - pass +@common.parametrize("test_data", test_data_suite) +@common.XfailIfNoCorstone320 +def test_native_batch_norm_legit_no_stats_u85_BI(test_data: Tuple): + test_data, model_params = test_data() + pipeline = EthosU85PipelineBI[input_t1]( + BatchNorm2dNoStats(*model_params), + (test_data,), + aten_op=BatchNorm2dNoStats.aten_ops, + run_on_fvp=False, + qtol=1, + ) + pipeline.run() diff --git a/backends/nxp/requirements-tests.txt b/backends/nxp/requirements-tests.txt index 513ccefe848..ea6d56a43ec 100644 --- a/backends/nxp/requirements-tests.txt +++ b/backends/nxp/requirements-tests.txt @@ -3,4 +3,4 @@ tensorflow==2.18.0 pytest-mock tflite GvGen -neutron-converter_SDK_25_03 +neutron_converter_SDK_25_03 diff --git a/backends/nxp/run_unittests.sh b/backends/nxp/run_unittests.sh new file mode 100755 index 00000000000..dde10065743 --- /dev/null +++ b/backends/nxp/run_unittests.sh @@ -0,0 +1,14 @@ +#!/bin/bash +# Copyright 2025 NXP +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +set -eux + +SCRIPT_DIR=$(dirname $(readlink -fm $0)) +EXECUTORCH_DIR=$(dirname $(dirname $SCRIPT_DIR)) + +cd $EXECUTORCH_DIR + +# '-c /dev/null' is used to ignore root level pytest.ini. +PYTHONPATH=`cd ..; pwd` pytest -c /dev/null backends/nxp/tests/ diff --git a/backends/qualcomm/qnn_preprocess.py b/backends/qualcomm/qnn_preprocess.py index 21b16a29c58..034b75fa6d0 100644 --- a/backends/qualcomm/qnn_preprocess.py +++ b/backends/qualcomm/qnn_preprocess.py @@ -78,10 +78,7 @@ def _build_op_wrappers( ) assert node.target == context_loader_target, err_msg # if graph has context binary loader node, return directly - return PreprocessResult( - processed_bytes=node.meta[OpContextLoader.meta_ctx_bin], - debug_handle_map={}, - ) + return node.meta[OpContextLoader.meta_ctx_bin] except: raise RuntimeError(err_msg) @@ -161,7 +158,7 @@ def preprocess_multimethod( generate_qnn_executorch_option(compile_spec) ) qnn_manager.Init() - py_op_wrapper_list = [] + py_op_wrapper_list, ctx_binary_list = [], [] for j, programs in enumerate(edge_programs.values()): logger.info(f"Processing Method({j}): ({i+1}/{num_sub_graphs})") py_op_wrappers = QnnBackend._build_op_wrappers( @@ -169,22 +166,36 @@ def preprocess_multimethod( qnn_manager.IsTensorDump(), option.op_package_options.op_package_infos, ) - py_op_wrapper_list.append( - [py_op_wrapper.GetOpWrapper() for py_op_wrapper in py_op_wrappers] - ) + if isinstance(py_op_wrappers, bytes): + ctx_binary_list.append(py_op_wrappers) + else: + py_op_wrapper_list.append( + [ + py_op_wrapper.GetOpWrapper() + for py_op_wrapper in py_op_wrappers + ] + ) - qnn_context_binary = qnn_manager.Compile(graph_name, py_op_wrapper_list) - assert ( - len(qnn_context_binary) != 0 - ), "Failed to generate Qnn context binary." - qnn_manager.Destroy() - # methods should share the same context binary for current partition - for key in edge_programs.keys(): - all_processed_results[key].append( - PreprocessResult( - processed_bytes=bytes(qnn_context_binary), - debug_handle_map={}, + if len(py_op_wrapper_list) == len(edge_programs.values()): + qnn_context_binary = qnn_manager.Compile(graph_name, py_op_wrapper_list) + assert ( + len(qnn_context_binary) != 0 + ), "Failed to generate Qnn context binary." + qnn_manager.Destroy() + # methods should share the same context binary for current partition + for key in edge_programs.keys(): + all_processed_results[key].append( + PreprocessResult( + processed_bytes=bytes(qnn_context_binary), + debug_handle_map={}, + ) ) - ) + elif len(ctx_binary_list) == len(edge_programs.values()): + for i, key in enumerate(edge_programs.keys()): + all_processed_results[key].append( + PreprocessResult(processed_bytes=ctx_binary_list[i]) + ) + else: + raise RuntimeError("Hybrid compilation is not supported") return all_processed_results diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index 747a6804957..7163ce88c27 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -5622,6 +5622,68 @@ def test_debugger_generate_optrace(self): qhas_data = json.load(qhas_file) self.assertIn("data", qhas_data) + def test_cli(self): + with tempfile.TemporaryDirectory() as tmp_dir: + sample_input = torch.randn(1, 2, 3, 4) + ep = torch.export.export(Relu(), (sample_input,)) # noqa: F405 + torch.export.save(ep, f"{tmp_dir}/relu.pt2") + torch.save(sample_input, f"{tmp_dir}/input_0_0.pt") + with open(f"{tmp_dir}/input_list", "w") as f: + f.write(f"{tmp_dir}/input_0_0.pt\n") + + # quantize + cmds = [ + "python", + "-m", + "examples.qualcomm.util_scripts.cli", + "quantize", + "--artifact", + f"{tmp_dir}/relu.pt2", + "--output_folder", + f"{tmp_dir}/q_out", + "--input_list", + f"{tmp_dir}/input_list", + ] + subprocess.run(cmds, stdout=subprocess.DEVNULL) + self.assertTrue(os.path.isfile(f"{tmp_dir}/q_out/relu_quantized.pt2")) + # compile + cmds = [ + "python", + "-m", + "examples.qualcomm.util_scripts.cli", + "compile", + "--artifact", + f"{tmp_dir}/q_out/relu_quantized.pt2", + "--output_folder", + f"{tmp_dir}/c_out", + "--model", + self.model, + ] + subprocess.run(cmds, stdout=subprocess.DEVNULL) + self.assertTrue(os.path.isfile(f"{tmp_dir}/c_out/relu_quantized.pte")) + self.assertTrue(os.path.isfile(f"{tmp_dir}/c_out/relu_quantized.svg")) + # execute + cmds = [ + "python", + "-m", + "examples.qualcomm.util_scripts.cli", + "execute", + "--artifact", + f"{tmp_dir}/c_out/relu_quantized.pte", + "--output_folder", + f"{tmp_dir}/e_out", + "--model", + self.model, + "--device", + self.device, + "--build_folder", + self.build_folder, + "--input_list", + f"{tmp_dir}/input_list", + ] + subprocess.run(cmds, stdout=subprocess.DEVNULL) + self.assertTrue(os.path.isfile(f"{tmp_dir}/e_out/output_0_0.pt")) + def setup_environment(): parser = setup_common_args_and_variables() diff --git a/devtools/inspector/TARGETS b/devtools/inspector/TARGETS index 0712bdf1f9a..d32698f784f 100644 --- a/devtools/inspector/TARGETS +++ b/devtools/inspector/TARGETS @@ -56,6 +56,7 @@ python_library( "_intermediate_output_capturer.py", ], deps = [ + "//executorch/devtools/inspector:inspector_utils", ], ) diff --git a/devtools/inspector/_inspector.py b/devtools/inspector/_inspector.py index dfff3d0818e..a209da8adb7 100644 --- a/devtools/inspector/_inspector.py +++ b/devtools/inspector/_inspector.py @@ -52,6 +52,7 @@ FORWARD, gen_etdump_object, gen_graphs_from_etrecord, + get_aot_debug_handle_to_op_name_mapping, inflate_runtime_output, is_debug_output, is_inference_output_equal, @@ -1084,6 +1085,7 @@ def __init__( self._reference_outputs: Dict[str, List[ProgramOutput]] = {} self._enable_module_hierarchy = enable_module_hierarchy self._aot_intermediate_outputs: Optional[Dict[Tuple[int, ...], Any]] = None + self._aot_debug_handles_to_op_names: Optional[Dict[Tuple[int, ...], str]] = None self._consume_etrecord() def _consume_etrecord(self) -> None: @@ -1150,18 +1152,24 @@ def _consume_etrecord(self) -> None: return export_program = self._etrecord.edge_dialect_program graph_module = export_program.module() + self._aot_debug_handles_to_op_names = get_aot_debug_handle_to_op_name_mapping( + graph_module + ) capturer = IntermediateOutputCapturer(graph_module) self._aot_intermediate_outputs = capturer.run_and_capture( self._etrecord._representative_inputs ) # TODO: Make it more extensible to further merge overlapping debug handles - def _get_runtime_intermediate_outputs(self) -> Dict[Tuple[int, ...], Any]: + def _get_runtime_intermediate_outputs_and_op_names( + self, + ) -> Tuple[Dict[Tuple[int, ...], Any], Dict[Tuple[int, ...], str]]: """ - Retrieve the raw runtime intermediate outputs(debug handles and value mappings) - from the event blocks. These outputs will be processed later to merge overlapping debug handles. + Retrieve the runtime intermediate outputs(debug handles and intermediate values mappings) + from the event blocks, along with the corresponding debug handles and op names mapping. """ debug_handle_to_output = {} + debug_handle_to_op_name = {} for event_block in self.event_blocks: for event in event_block.events: # Skip OPERATOR_CALL events to avoid double-counting and exclude framework tax @@ -1170,20 +1178,23 @@ def _get_runtime_intermediate_outputs(self) -> Dict[Tuple[int, ...], Any]: or not event.op_types ): continue - # Normalize debug_handles to a tuple - debug_handles = event.debug_handles - if isinstance(debug_handles, int): - debug_handles = (debug_handles,) + # Normalize debug_handle to a tuple + debug_handle = event.debug_handles + if isinstance(debug_handle, int): + debug_handle = (debug_handle,) else: - debug_handles = tuple(debug_handles) - current_entry = debug_handle_to_output.get(debug_handles, (-1, None)) - # When event has same debug handles, only keep the one with the largest instruction id + debug_handle = tuple(debug_handle) + current_entry = debug_handle_to_output.get(debug_handle, (-1, None)) + # When event has same debug_handle, only keep the one with the largest instruction id if event._instruction_id > current_entry[0]: - debug_handle_to_output[debug_handles] = ( + debug_handle_to_output[debug_handle] = ( event._instruction_id, event.debug_data, ) - return {k: v[1] for k, v in debug_handle_to_output.items()} + debug_handle_to_op_name[debug_handle] = event.name + return { + k: v[1] for k, v in debug_handle_to_output.items() + }, debug_handle_to_op_name def to_dataframe( self, @@ -1359,8 +1370,12 @@ def calculate_numeric_gap(self, distance: str = "MSE") -> pd.DataFrame: raise ValueError( "The aot intermediate outputs is required but not populated." ) + # The runtime_op_names will be used later to map runtime debug_handle to op_name + runtime_intermediate_outputs, runtime_op_names = ( + self._get_runtime_intermediate_outputs_and_op_names() + ) mapping = map_runtime_aot_intermediate_outputs( - self._aot_intermediate_outputs, self._get_runtime_intermediate_outputs() + self._aot_intermediate_outputs, runtime_intermediate_outputs ) metric = distance.strip().upper() if metric == "MSE": diff --git a/devtools/inspector/_inspector_utils.py b/devtools/inspector/_inspector_utils.py index 61e2ea4d031..50b3669309c 100644 --- a/devtools/inspector/_inspector_utils.py +++ b/devtools/inspector/_inspector_utils.py @@ -93,6 +93,28 @@ class NodeData: output: Any +class NodeFilter: + """ + A class used to filter nodes based on extensible criteria. + Attributes: + metadata_key (str): The key to look for in the node's metadata. + op_type (str): The operation code to match. + exclude_ops (List[str]): A list of operations to exclude from the filter. + """ + + def __init__(self, metadata_key: str, op_type: str, exclude_ops: List[str] = None): + self.metadata_key = metadata_key + self.op_type = op_type + self.exclude_ops = exclude_ops + + def matches(self, node: torch.fx.Node) -> bool: + return ( + node.meta.get(self.metadata_key) is not None + and node.op == self.op_type + and all(exclude_name not in node.name for exclude_name in self.exclude_ops) + ) + + def calculate_time_scale_factor( source_time_scale: TimeScale, target_time_scale: TimeScale ) -> float: @@ -734,3 +756,31 @@ def convert_to_float_tensor(input_data: Any) -> torch.Tensor: if torch.isnan(input_tensor).any(): input_tensor = torch.nan_to_num(input_tensor) return input_tensor + + +def get_aot_debug_handle_to_op_name_mapping( + graph_module: torch.fx.GraphModule, +) -> Dict[Tuple[int, ...], str]: + """ + Get a mapping from debug handle to operator name from the ETRecord edge_dialect_program's graph module. + Parameters: + graph_module (torch.fx.GraphModule): The graph module to get the mapping from. + Returns: + Dict[Tuple[int, ...], str]: A dictionary mapping debug handles to operator names. + """ + node_filters = [ + NodeFilter("debug_handle", "call_function", exclude_ops=["getitem"]) + ] + + debug_handle_to_op_name = {} + for node in graph_module.graph.nodes: + if all(filter.matches(node) for filter in node_filters): + debug_handle = node.meta["debug_handle"] + # Convert the debug handle to a tuple to use as a dictionary key + key = ( + (debug_handle,) + if isinstance(debug_handle, int) + else tuple(debug_handle) + ) + debug_handle_to_op_name[key] = node.name + return debug_handle_to_op_name diff --git a/devtools/inspector/_intermediate_output_capturer.py b/devtools/inspector/_intermediate_output_capturer.py index c1f943bd02c..054c97dc245 100644 --- a/devtools/inspector/_intermediate_output_capturer.py +++ b/devtools/inspector/_intermediate_output_capturer.py @@ -7,35 +7,14 @@ # pyre-unsafe -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, Tuple import torch +from executorch.devtools.inspector._inspector_utils import NodeFilter from torch.fx import GraphModule from torch.fx.interpreter import Interpreter -class NodeFilter: - """ - A class used to filter nodes based on extensible criteria. - Attributes: - metadata_key (str): The key to look for in the node's metadata. - op_type (str): The operation code to match. - exclude_ops (List[str]): A list of operations to exclude from the filter. - """ - - def __init__(self, metadata_key: str, op_type: str, exclude_ops: List[str] = None): - self.metadata_key = metadata_key - self.op_type = op_type - self.exclude_ops = exclude_ops - - def matches(self, node: torch.fx.Node) -> bool: - return ( - node.meta.get(self.metadata_key) is not None - and node.op == self.op_type - and all(exclude_name not in node.name for exclude_name in self.exclude_ops) - ) - - class IntermediateOutputCapturer(Interpreter): """ A class that captures intermediate outputs from a PyTorch graph module. diff --git a/devtools/inspector/tests/inspector_test.py b/devtools/inspector/tests/inspector_test.py index 1460dbd46a2..df434fd675d 100644 --- a/devtools/inspector/tests/inspector_test.py +++ b/devtools/inspector/tests/inspector_test.py @@ -537,7 +537,7 @@ def test_consume_etrecord_populates_correct_aot_intermediate_outputs(self): ) ) - def test_get_runtime_intermediate_outputs(self): + def test_get_runtime_intermediate_outputs_and_op_names(self): # Create a context manager to patch functions called by Inspector.__init__ with patch.object( _inspector, "parse_etrecord", return_value=None @@ -560,25 +560,39 @@ def test_get_runtime_intermediate_outputs(self): EventBlock(name=EVENT_BLOCK_NAME, events=self._gen_random_events()) ] - runtime_outputs = inspector_instance._get_runtime_intermediate_outputs() - # This output should be a dictionary with 5 keys + runtime_outputs, op_names = ( + inspector_instance._get_runtime_intermediate_outputs_and_op_names() + ) + # These outputs and op_names dictionaries should all have 5 keys self.assertEqual( len(runtime_outputs), 5, ) - # Check that keys (0,) and (1,) are not in the dictionary(skip OPERATOR_CALL and op_types are empty) + self.assertEqual( + len(op_names), + 5, + ) + + # Check that keys (0,) and (1,) are not in these two dictionaries(skip OPERATOR_CALL and op_types are empty) self.assertNotIn((0,), runtime_outputs) self.assertNotIn((1,), runtime_outputs) + self.assertNotIn((0,), op_names) + self.assertNotIn((1,), op_names) # Same debug_handle but different instruction_id, should record the last one self.assertIn((4,), runtime_outputs) + self.assertIn((4,), op_names) self.assertTrue( torch.equal(runtime_outputs[(4,)][0], torch.tensor([4.0, 5.0, 6.0])) ) + self.assertEqual(op_names[(4,)], "op_3") + # Check that keys (5,) to (8,) are in the dictionary and have values of the correct size for key in range(5, 9): self.assertIn((key,), runtime_outputs) + self.assertIn((key,), op_names) self.assertEqual(len(runtime_outputs[(key,)]), RAW_DATA_SIZE) + self.assertEqual(op_names[(key,)], f"op_{key-1}") def test_calculate_numeric_gap(self): # Create a context manager to patch functions called by Inspector.__init__ @@ -608,8 +622,8 @@ def test_calculate_numeric_gap(self): } inspector_instance._aot_intermediate_outputs = aot_intermediate_outputs - inspector_instance._get_runtime_intermediate_outputs = ( - lambda: runtime_intermediate_outputs + inspector_instance._get_runtime_intermediate_outputs_and_op_names = ( + lambda: (runtime_intermediate_outputs, {}) ) df = inspector_instance.calculate_numeric_gap(distance="L1") diff --git a/devtools/inspector/tests/inspector_utils_test.py b/devtools/inspector/tests/inspector_utils_test.py index 8148d2c36f0..6d12cb13c5f 100644 --- a/devtools/inspector/tests/inspector_utils_test.py +++ b/devtools/inspector/tests/inspector_utils_test.py @@ -34,9 +34,11 @@ EDGE_DIALECT_GRAPH_KEY, find_populated_event, gen_graphs_from_etrecord, + get_aot_debug_handle_to_op_name_mapping, is_inference_output_equal, map_runtime_aot_intermediate_outputs, merge_overlapping_debug_handles, + NodeFilter, TimeScale, ) @@ -364,6 +366,112 @@ class X: msg = str(cm.exception) self.assertIn("Cannot convert value of type", msg) + def test_get_aot_debug_handle_to_op_name_mapping_single_debug_handle(self): + # Create a simple graph module with one node + graph_module = torch.fx.GraphModule({}, torch.fx.Graph()) + node = graph_module.graph.create_node( + "call_function", target=torch.mul, args=(), kwargs={}, name="op1" + ) + node.meta["debug_handle"] = 1 + debug_handle_to_op_name = get_aot_debug_handle_to_op_name_mapping(graph_module) + expected_result = {(1,): "op1"} + self.assertEqual(debug_handle_to_op_name, expected_result) + + def test_get_aot_debug_handle_to_op_name_mapping_multiple_debug_handles(self): + # Create a simple graph module with two nodes + graph_module = torch.fx.GraphModule({}, torch.fx.Graph()) + node1 = graph_module.graph.create_node( + "call_function", target=torch.mul, args=(), kwargs={}, name="op1" + ) + node1.meta["debug_handle"] = (1, 2) + node2 = graph_module.graph.create_node( + "call_function", target=torch.mul, args=(), kwargs={}, name="op2" + ) + node2.meta["debug_handle"] = 3 + debug_handle_to_op_name = get_aot_debug_handle_to_op_name_mapping(graph_module) + expected_result = { + ( + 1, + 2, + ): "op1", + (3,): "op2", + } + self.assertEqual(debug_handle_to_op_name, expected_result) + + def test_get_aot_debug_handle_to_op_name_mapping_no_debug_handles(self): + # Create a simple graph module with no nodes + graph_module = torch.fx.GraphModule({}, torch.fx.Graph()) + debug_handle_to_op_name = get_aot_debug_handle_to_op_name_mapping(graph_module) + expected_result = {} + self.assertEqual(debug_handle_to_op_name, expected_result) + + def test_node_filter_match(self): + node_filter = NodeFilter( + "debug_handle", "call_function", exclude_ops=["getitem"] + ) + + # Create a mock node that matches the filter criteria + mock_node = torch.fx.Node( + graph=torch.fx.Graph(), + name="mock_node", + op="call_function", + target=torch.nn.functional.relu, + args=(), + kwargs={}, + ) + mock_node.meta["debug_handle"] = (1, 2) + # Test that the filter matches the mock node + self.assertTrue(node_filter.matches(mock_node)) + + def test_node_filter_key_mismatch(self): + node_filter = NodeFilter( + "debug_handle", "call_function", exclude_ops=["getitem"] + ) + mock_node_metadata_key_mismatch = torch.fx.Node( + graph=torch.fx.Graph(), + name="mock_node_metadata_key_mismatch", + op="call_function", + target=torch.nn.functional.relu, + args=(), + kwargs={}, + ) + # Test that the filter doesn't match the mock node (meta doesn't have debug_handle key) + self.assertFalse(node_filter.matches(mock_node_metadata_key_mismatch)) + + def test_node_filter_ops_mismatch(self): + node_filter = NodeFilter( + "debug_handle", "call_function", exclude_ops=["getitem"] + ) + + mock_node_exclude_ops_mismatch = torch.fx.Node( + graph=torch.fx.Graph(), + name="getitem", + op="call_function", + target=torch.nn.functional.relu, + args=(), + kwargs={}, + ) + mock_node_exclude_ops_mismatch.meta["debug_handle"] = (1, 2) + # Test that the filter doesn't match the mock node (exclude_ops mismatch) + self.assertFalse(node_filter.matches(mock_node_exclude_ops_mismatch)) + + def test_node_op_type_mismatch(self): + node_filter = NodeFilter( + "debug_handle", "call_function", exclude_ops=["getitem"] + ) + + mock_node_op_type_mismatch = torch.fx.Node( + graph=torch.fx.Graph(), + name="mock_node_op_type_mismatch", + op="get_attr", + target="torch.nn.functional.relu", + args=(), + kwargs={}, + ) + mock_node_op_type_mismatch.meta["debug_handle"] = (1, 2) + # Test that the filter doesn't match the mock node (op_type mismatch) + self.assertFalse(node_filter.matches(mock_node_op_type_mismatch)) + def gen_mock_operator_graph_with_expected_map() -> ( Tuple[OperatorGraph, Dict[int, OperatorNode]] diff --git a/examples/models/deepseek-r1-distill-llama-8B/README.md b/examples/models/deepseek-r1-distill-llama-8B/README.md index f05dd9990a2..00397e9f60f 100644 --- a/examples/models/deepseek-r1-distill-llama-8B/README.md +++ b/examples/models/deepseek-r1-distill-llama-8B/README.md @@ -53,17 +53,10 @@ torch.save(sd, "/tmp/deepseek-ai/DeepSeek-R1-Distill-Llama-8B/checkpoint.pth") 5. Generate a PTE file for use with the Llama runner. ``` python -m extension.llm.export.export_llm \ - base.checkpoint=/tmp/deepseek-ai/DeepSeek-R1-Distill-Llama-8B/checkpoint.pth \ - base.params=params.json \ - model.use_kv_cache=True \ - model.use_sdpa_with_kv_cache=True \ - backend.xnnpack.enabled=True \ - quantization.qmode="8da4w" \ - quantization.group_size=128 \ - model.dtype_override="fp16" \ - base.metadata='"{\"get_bos_id\":128000, \"get_eos_ids\":[128009, 128001]}"' \ - quantization.embedding_quantize=\'4,32\' \ - export.output_name="DeepSeek-R1-Distill-Llama-8B.pte" + --config examples/models/deepseek-r1-distill-llama-8B/config/deepseek-r1-distill-llama-8B + +base.checkpoint=/tmp/deepseek-ai/DeepSeek-R1-Distill-Llama-8B/checkpoint.pth \ + +base.params=params.json \ + +export.output_name="DeepSeek-R1-Distill-Llama-8B.pte" ``` 6. Run the model on your desktop for validation or integrate with iOS/Android apps. Instructions for these are available in the Llama [README](../llama/README.md) starting at Step 3. diff --git a/examples/models/deepseek-r1-distill-llama-8B/config/deepseek_xnnpack_q8da4w.yaml b/examples/models/deepseek-r1-distill-llama-8B/config/deepseek_xnnpack_q8da4w.yaml new file mode 100644 index 00000000000..1da7c253d92 --- /dev/null +++ b/examples/models/deepseek-r1-distill-llama-8B/config/deepseek_xnnpack_q8da4w.yaml @@ -0,0 +1,16 @@ +base: + metadata: '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' + +model: + use_kv_cache: True + use_sdpa_with_kv_cache: True + dtype_override: fp16 + +backend: + xnnpack: + enabled: True + +quantization: + qmode: 8da4w + group_size: 128 + embedding_quantize: 4,32 diff --git a/examples/models/llama/README.md b/examples/models/llama/README.md index 3e6869e5c49..bbd2107ad74 100644 --- a/examples/models/llama/README.md +++ b/examples/models/llama/README.md @@ -168,14 +168,10 @@ LLAMA_CHECKPOINT=path/to/consolidated.00.pth LLAMA_PARAMS=path/to/params.json python -m extension.llm.export.export_llm \ - base.model_class="llama3_2" \ - base.checkpoint="${LLAMA_CHECKPOINT:?}" \ - base.params="${LLAMA_PARAMS:?}" \ - model.use_kv_cache=True \ - model.use_sdpa_with_kv_cache=True \ - model.dtype_override="bf16" \ - base.metadata='"{\"get_bos_id\":128000, \"get_eos_ids\":[128009, 128001]}"' \ - export.output_name="llama3_2.pte" + --config examples/models/llamaconfig/llama_bf16.yaml + +base.model_class="llama3_2" \ + +base.checkpoint="${LLAMA_CHECKPOINT:?}" \ + +base.params="${LLAMA_PARAMS:?}" \ ``` For convenience, an [exported ExecuTorch bf16 model](https://huggingface.co/executorch-community/Llama-3.2-1B-ET/blob/main/llama3_2-1B.pte) is available on Hugging Face. The export was created using [this detailed recipe notebook](https://huggingface.co/executorch-community/Llama-3.2-1B-ET/blob/main/ExportRecipe_1B.ipynb). @@ -190,22 +186,10 @@ LLAMA_QUANTIZED_CHECKPOINT=path/to/spinquant/consolidated.00.pth.pth LLAMA_PARAMS=path/to/spinquant/params.json python -m extension.llm.export.export_llm \ - base.model_class="llama3_2" \ - base.checkpoint="${LLAMA_QUANTIZED_CHECKPOINT:?}" \ - base.params="${LLAMA_PARAMS:?}" \ - model.use_sdpa_with_kv_cache=True \ - backend.xnnpack.enabled=True \ - backend.xnnpack.extended_ops=True \ - base.preq_mode="preq_8da4w_out_8da8w" \ - base.preq_group_size=32 \ - export.max_seq_length=2048 \ - export.max_context_length=2048 \ - export.output_name="llama3_2.pte" \ - model.use_kv_cache=True \ - model.dtype_override="fp32" \ - base.preq_embedding_quantize=\'8,0\' \ - quantization.use_spin_quant="native" \ - base.metadata='"{\"get_bos_id\":128000, \"get_eos_ids\":[128009, 128001]}"' + --config examples/models/llama/config/llama_xnnpack_spinquant.yaml + +base.model_class="llama3_2" \ + +base.checkpoint="${LLAMA_QUANTIZED_CHECKPOINT:?}" \ + +base.params="${LLAMA_PARAMS:?}" ``` For convenience, an [exported ExecuTorch SpinQuant model](https://huggingface.co/executorch-community/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8-ET/blob/main/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8.pte) is available on Hugging Face. The export was created using [this detailed recipe notebook](https://huggingface.co/executorch-community/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8-ET/blob/main/Export_Recipe_Llama_3_2_1B_Instruct_SpinQuant_INT4_EO8.ipynb). @@ -219,23 +203,10 @@ LLAMA_QUANTIZED_CHECKPOINT=path/to/qlora/consolidated.00.pth.pth LLAMA_PARAMS=path/to/qlora/params.json python -m extension.llm.export.export_llm \ - base.model_class="llama3_2" \ - base.checkpoint="${LLAMA_QUANTIZED_CHECKPOINT:?}" \ - base.params="${LLAMA_PARAMS:?}" \ - quantization.use_qat=True \ - base.use_lora=16 \ - base.preq_mode="preq_8da4w_out_8da8w" \ - base.preq_group_size=32 \ - base.preq_embedding_quantize=\'8,0\' \ - model.use_sdpa_with_kv_cache=True \ - model.use_kv_cache=True \ - backend.xnnpack.enabled=True \ - backend.xnnpack.extended_ops=True \ - model.dtype_override="fp32" \ - export.max_seq_length=2048 \ - export.max_context_length=2048 \ - export.output_name="llama3_2.pte" \ - base.metadata='"{\"get_bos_id\":128000, \"get_eos_ids\":[128009, 128001]}"' + --config examples/models/llama/config/llama_xnnpack_qat.yaml + +base.model_class="llama3_2" \ + +base.checkpoint="${LLAMA_QUANTIZED_CHECKPOINT:?}" \ + +base.params="${LLAMA_PARAMS:?}" \ ``` For convenience, an [exported ExecuTorch QAT+LoRA model](https://huggingface.co/executorch-community/Llama-3.2-1B-Instruct-QLORA_INT4_EO8-ET/blob/main/Llama-3.2-1B-Instruct-QLORA_INT4_EO8.pte) is available on Hugging Face. The export was created using [this detailed recipe notebook](https://huggingface.co/executorch-community/Llama-3.2-1B-Instruct-QLORA_INT4_EO8-ET/blob/main/Export_Recipe_Llama_3_2_1B_Instruct_QLORA_INT4_EO8.ipynb). @@ -246,20 +217,13 @@ You can export and run the original Llama 3 8B instruct model. 1. Llama 3 pretrained parameters can be downloaded from [Meta's official Llama 3 repository](https://github.com/meta-llama/llama3/). 2. Export model and generate `.pte` file - ``` - python -m extension.llm.export.export_llm \ - base.checkpoint= \ - base.params= \ - model.use_kv_cache=True \ - model.use_sdpa_with_kv_cache=True \ - backend.xnnpack.enabled=True \ - quantization.qmode="8da4w" \ - quantization.group_size=128 \ - model.dtype_override="fp32" \ - base.metadata='"{\"get_bos_id\":128000, \"get_eos_ids\":[128009, 128001]}"' \ - quantization.embedding_quantize=\'4,32\' \ - export.output_name="llama3_kv_sdpa_xnn_qe_4_32.pte" - ``` +``` +python -m extension.llm.export.export_llm \ + --config examples/models/llama/config/llama_q8da4w.yaml + +base.model_clas="llama3" + +base.checkpoint= \ + +base.params= +``` Due to the larger vocabulary size of Llama 3, we recommend quantizing the embeddings with `quantization.embedding_quantize=\'4,32\'` as shown above to further reduce the model size. @@ -276,20 +240,20 @@ You can export and run the original Llama 3 8B instruct model. Note for Mac users: There's a known linking issue with Xcode 15.1. Refer to the section of Common Issues and Mitigations below for solutions. 2. Build llama runner. - ``` - cmake -DCMAKE_INSTALL_PREFIX=cmake-out \ - -DBUILD_TESTING=OFF \ - -DCMAKE_BUILD_TYPE=Release \ - -Bcmake-out/examples/models/llama \ - examples/models/llama +``` +cmake -DCMAKE_INSTALL_PREFIX=cmake-out \ + -DBUILD_TESTING=OFF \ + -DCMAKE_BUILD_TYPE=Release \ + -Bcmake-out/examples/models/llama \ + examples/models/llama - cmake --build cmake-out/examples/models/llama -j16 --config Release - ``` +cmake --build cmake-out/examples/models/llama -j16 --config Release +``` 3. Run model. Run options available [here](https://github.com/pytorch/executorch/blob/main/examples/models/llama/main.cpp#L18-L40). - ``` - cmake-out/examples/models/llama/llama_main --model_path= --tokenizer_path= --prompt= - ``` +``` +cmake-out/examples/models/llama/llama_main --model_path= --tokenizer_path= --prompt= +``` To build for CoreML backend and validate on Mac, replace `-DEXECUTORCH_BUILD_XNNPACK=ON` with `-DEXECUTORCH_BUILD_COREML=ON` diff --git a/examples/models/llama/config/llama_bf16.yaml b/examples/models/llama/config/llama_bf16.yaml new file mode 100644 index 00000000000..8e89e8aa437 --- /dev/null +++ b/examples/models/llama/config/llama_bf16.yaml @@ -0,0 +1,7 @@ +base: + metadata: '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' + +model: + use_kv_cache: True + use_sdpa_with_kv_cache: True + dtype_override: bf16 \ No newline at end of file diff --git a/examples/models/llama/config/llama_q8da4w.yaml b/examples/models/llama/config/llama_q8da4w.yaml new file mode 100644 index 00000000000..476ae928c60 --- /dev/null +++ b/examples/models/llama/config/llama_q8da4w.yaml @@ -0,0 +1,11 @@ +base: + metadata: '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' + +model: + dtype_override: fp32 + +quantization: + qmode: 8da4w + group_size: 128 + embedding_quantize: 4,32 + \ No newline at end of file diff --git a/examples/models/llama/config/llama_xnnpack_qat.yaml b/examples/models/llama/config/llama_xnnpack_qat.yaml new file mode 100644 index 00000000000..2369ff1d279 --- /dev/null +++ b/examples/models/llama/config/llama_xnnpack_qat.yaml @@ -0,0 +1,23 @@ +base: + preq_mode: preq_8da4w_out_8da8w + preq_group_size: 32 + preq_embedding_quantize: 8,0 + metadata: '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' + use_lora: 16 + +model: + use_sdpa_with_kv_cache: True + use_kv_cache: True + dtype_override: fp32 + +export: + max_seq_length: 2048 + max_context_length: 2048 + +quantization: + use_qat: True + +backend: + xnnpack: + enabled: True + extended_ops: True \ No newline at end of file diff --git a/examples/models/llama/config/llama_xnnpack_spinquant.yaml b/examples/models/llama/config/llama_xnnpack_spinquant.yaml new file mode 100644 index 00000000000..441086d6f73 --- /dev/null +++ b/examples/models/llama/config/llama_xnnpack_spinquant.yaml @@ -0,0 +1,22 @@ +base: + preq_mode: preq_8da4w_out_8da8w + preq_group_size: 32 + preq_embedding_quantize: 8,0 + metadata: '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' + +model: + use_sdpa_with_kv_cache: True + use_kv_cache: True + dtype_override: fp32 + +export: + max_seq_length: 2048 + max_context_length: 2048 + +quantization: + use_spin_quant: native + +backend: + xnnpack: + enabled: True + extended_ops: True \ No newline at end of file diff --git a/examples/models/llama/config/test_llm_config.py b/examples/models/llama/config/test_llm_config.py index 0853e9dbbd8..52b56d71a03 100644 --- a/examples/models/llama/config/test_llm_config.py +++ b/examples/models/llama/config/test_llm_config.py @@ -41,7 +41,7 @@ def test_local_global_attention_without_kv(self): def test_invalid_export_config_context_length(self): with self.assertRaises(ValueError): - ExportConfig(max_seq_length=128, max_context_length=256) + ExportConfig(max_seq_length=256, max_context_length=128) def test_invalid_qmode(self): with self.assertRaises(ValueError): @@ -84,8 +84,8 @@ def test_valid_llm_config(self): local_global_attention="[16, 32]", ), export=ExportConfig( - max_seq_length=256, - max_context_length=128, + max_seq_length=128, + max_context_length=256, output_dir="/tmp/export", output_name="model.pte", ), @@ -94,7 +94,7 @@ def test_valid_llm_config(self): backend=BackendConfig( xnnpack=XNNPackConfig(enabled=False), coreml=CoreMLConfig( - enabled=True, ios=17, compute_units=CoreMLComputeUnit.ALL + enabled=True, ios=17, compute_units=CoreMLComputeUnit.cpu_only ), ), ) diff --git a/examples/models/phi_4_mini/README.md b/examples/models/phi_4_mini/README.md index d168d54226e..8fb2f03ac4c 100644 --- a/examples/models/phi_4_mini/README.md +++ b/examples/models/phi_4_mini/README.md @@ -8,7 +8,7 @@ Phi-4-mini uses the same example code as Llama, while the checkpoint, model para All commands for exporting and running Llama on various backends should also be applicable to Phi-4-mini, by swapping the following args: ``` base.model_class="phi_4_mini" -base.params="examples/models/phi-4-mini/config.json" +base.params="examples/models/phi-4-mini/config/config.json" base.checkpoint= ``` @@ -33,16 +33,10 @@ Export to XNNPack, no quantization: PHI_CHECKPOINT=path/to/checkpoint.pth python -m extension.llm.export.export_llm \ - base.model_class="phi_4_mini" \ - base.checkpoint="${PHI_CHECKPOINT=path/to/checkpoint.pth:?}" \ - base.params="examples/models/phi-4-mini/config.json" \ - model.use_kv_cache=True \ - model.use_sdpa_with_kv_cache=True \ - model.dtype_override="fp32" \ - backend.xnnpack.enabled=True \ - base.metadata='"{\"get_bos_id\":151643, \"get_eos_ids\":[151643]}"' \ - export.output_name="phi-4-mini.pte" \ - debug.verbose=True + --config config/phi_4_mini_xnnpack.yaml + +base.checkpoint="${PHI_CHECKPOINT=path/to/checkpoint.pth:?}" \ + +base.params="examples/models/phi-4-mini/config/config.json" \ + +export.output_name="phi-4-mini.pte" \ ``` Run using the executor runner: diff --git a/examples/models/phi_4_mini/config.json b/examples/models/phi_4_mini/config/config.json similarity index 100% rename from examples/models/phi_4_mini/config.json rename to examples/models/phi_4_mini/config/config.json diff --git a/examples/models/phi_4_mini/config/phi_4_mini_xnnpack.yaml b/examples/models/phi_4_mini/config/phi_4_mini_xnnpack.yaml new file mode 100644 index 00000000000..9355bd99f64 --- /dev/null +++ b/examples/models/phi_4_mini/config/phi_4_mini_xnnpack.yaml @@ -0,0 +1,12 @@ +base: + model_class: phi_4_mini + metadata: '{"get_bos_id":151643, "get_eos_ids":[151643]}' + +model: + use_kv_cache: True + use_sdpa_with_kv_cache: True + dtype_override: fp32 + +backend: + xnnpack: + enabled: True \ No newline at end of file diff --git a/examples/models/qwen2_5/README.md b/examples/models/qwen2_5/README.md index 57784169ece..566a7a5c30b 100644 --- a/examples/models/qwen2_5/README.md +++ b/examples/models/qwen2_5/README.md @@ -8,7 +8,7 @@ Qwen 2.5 uses the same example code as Llama, while the checkpoint, model params All commands for exporting and running Llama on various backends should also be applicable to Qwen 2.5, by swapping the following args: ``` base.model_class="qwen2_5" -base.params="examples/models/qwen2_5/1_5b_config.json" +base.params="examples/models/qwen2_5/config/1_5b_config.json" base.checkpoint= ``` @@ -33,16 +33,11 @@ Export to XNNPack, no quantization: QWEN_CHECKPOINT=path/to/checkpoint.pth python -m extension.llm.export.export_llm \ - base.model_class="qwen2_5" \ - base.checkpoint="${QWEN_CHECKPOINT:?}" \ - base.params="examples/models/qwen2_5/1_5b_config.json" \ - model.use_kv_cache=True \ - model.use_sdpa_with_kv_cache=True \ - model.dtype_override="fp32" \ - backend.xnnpack.enabled=True \ - base.metadata='"{\"get_bos_id\":151643, \"get_eos_ids\":[151643]}"' \ - export.output_name="qwen2_5-1_5b.pte" \ - debug.verbose=True + --config examples/models/qwen2_5/config/qwen2_5_xnnpack_q8da4w.yaml + +base.model_class="qwen2_5" \ + +base.checkpoint="${QWEN_CHECKPOINT:?}" \ + +base.params="examples/models/qwen2_5/1_5b_config.json" \ + +export.output_name="qwen2_5-1_5b.pte" \ ``` Run using the executor runner: diff --git a/examples/models/qwen2_5/1_5b_config.json b/examples/models/qwen2_5/config/1_5b_config.json similarity index 100% rename from examples/models/qwen2_5/1_5b_config.json rename to examples/models/qwen2_5/config/1_5b_config.json diff --git a/examples/models/qwen2_5/config/qwen2_5_xnnpack_q8da4w.yaml b/examples/models/qwen2_5/config/qwen2_5_xnnpack_q8da4w.yaml new file mode 100644 index 00000000000..0e5c6f7624e --- /dev/null +++ b/examples/models/qwen2_5/config/qwen2_5_xnnpack_q8da4w.yaml @@ -0,0 +1,11 @@ +base: + metadata='{"get_bos_id":151643, "get_eos_ids":[151643]}' + +model: + use_kv_cache: True + use_sdpa_with_kv_cache: True + dtype_override: fp32 + +backend: + xnnpack: + enabled: True \ No newline at end of file diff --git a/examples/models/qwen3/README.md b/examples/models/qwen3/README.md index e24d8da2637..d2d89db93c2 100644 --- a/examples/models/qwen3/README.md +++ b/examples/models/qwen3/README.md @@ -8,7 +8,7 @@ Qwen 3 uses the same example code as our optimized Llama model, while the checkp All commands for exporting and running Llama on various backends should also be applicable to Qwen 3, by swapping the following args: ``` base.model_class=[qwen3_0_6b,qwen3_1_7b,qwen3_4b] -base.params=[examples/models/qwen3/0_6b_config.json,examples/models/qwen3/1_7b_config.json,examples/models/qwen3/4b_config.json] +base.params=[examples/models/qwen3/config/0_6b_config.json,examples/models/qwen3/config/1_7b_config.json,examples/models/config/qwen3/4b_config.json] ``` ### Example export @@ -17,49 +17,29 @@ Here is a basic example for exporting Qwen 3, although please refer to the Llama Export 0.6b to XNNPack, quantized with 8da4w: ``` python -m extension.llm.export.export_llm \ - base.model_class="qwen3_0_6b" \ - base.params="examples/models/qwen3/0_6b_config.json" \ - model.use_kv_cache=True \ - model.use_sdpa_with_kv_cache=True \ - model.dtype_override="fp32" \ - backend.xnnpack.enabled=True \ - backend.xnnpack.extended_ops=True \ - quantization.qmode="8da4w" \ - base.metadata='"{\"get_bos_id\": 151644, \"get_eos_ids\":[151645]}"' \ - export.output_name="qwen3_0_6b.pte" \ - debug.verbose=True + --config examples/models/qwen3/config/qwen3_xnnpack_q8da4w.yaml + +base.model_class="qwen3_0_6b" \ + +base.params="examples/models/qwen3/config/0_6b_config.json" \ + +export.output_name="qwen3_0_6b.pte" \ + ``` Export 1.7b to XNNPack, quantized with 8da4w: ``` python -m extension.llm.export.export_llm \ - base.model_class="qwen3_1_7b" \ - base.params="examples/models/qwen3/1_7b_config.json" \ - model.use_kv_cache=True \ - model.use_sdpa_with_kv_cache=True \ - model.dtype_override="fp32" \ - backend.xnnpack.enabled=True \ - backend.xnnpack.extended_ops=True \ - quantization.qmode="8da4w" \ - base.metadata='"{\"get_bos_id\": 151644, \"get_eos_ids\":[151645]}"' \ - export.output_name="qwen3_1_7b.pte" \ - debug.verbose=True + --config examples/models/qwen3/config/qwen3_xnnpack_q8da4w.yaml + +base.model_class="qwen3_1_7b" \ + +base.params="examples/models/qwen3/config/1_7b_config.json" \ + +export.output_name="qwen3_1_7b.pte" \ ``` Export 4b to XNNPack, quantized with 8da4w: ``` python -m extension.llm.export.export_llm \ - base.model_class="qwen3_4b" \ - base.params="examples/models/qwen3/4b_config.json" \ - model.use_kv_cache=True \ - model.use_sdpa_with_kv_cache=True \ - model.dtype_override="fp32" \ - backend.xnnpack.enabled=True \ - backend.xnnpack.extended_ops=True \ - quantization.qmode="8da4w" \ - base.metadata='"{\"get_bos_id\": 151644, \"get_eos_ids\":[151645]}"' \ - export.output_name="qwen3_4b.pte" \ - debug.verbose=True + --config examples/models/qwen3/config/qwen3_xnnpack_q8da4w.yaml + +base.model_class="qwen3_4b" \ + +base.params="examples/models/qwen3/config/4b_config.json" \ + +export.output_name="qwen3_4b.pte" \ ``` ### Example run diff --git a/examples/models/qwen3/0_6b_config.json b/examples/models/qwen3/config/0_6b_config.json similarity index 100% rename from examples/models/qwen3/0_6b_config.json rename to examples/models/qwen3/config/0_6b_config.json diff --git a/examples/models/qwen3/1_7b_config.json b/examples/models/qwen3/config/1_7b_config.json similarity index 100% rename from examples/models/qwen3/1_7b_config.json rename to examples/models/qwen3/config/1_7b_config.json diff --git a/examples/models/qwen3/4b_config.json b/examples/models/qwen3/config/4b_config.json similarity index 100% rename from examples/models/qwen3/4b_config.json rename to examples/models/qwen3/config/4b_config.json diff --git a/examples/models/qwen3/config/qwen3_xnnpack_q8da4w.yaml b/examples/models/qwen3/config/qwen3_xnnpack_q8da4w.yaml new file mode 100644 index 00000000000..60292b1ecdc --- /dev/null +++ b/examples/models/qwen3/config/qwen3_xnnpack_q8da4w.yaml @@ -0,0 +1,15 @@ +base: + metadata: '{"get_bos_id": 151644, "get_eos_ids":[151645]}' + +model: + use_kv_cache: True + use_sdpa_with_kv_cache: True + dtype_override: fp32 + +quantization: + qmode: 8da4w + +backend: + xnnpack: + enabled: True + extended_ops: True \ No newline at end of file diff --git a/examples/nxp/setup.sh b/examples/nxp/setup.sh index 1ef2cc82c2a..1a050a79c19 100644 --- a/examples/nxp/setup.sh +++ b/examples/nxp/setup.sh @@ -7,4 +7,4 @@ set -u # Install neutron-converter -pip install --extra-index-url https://eiq.nxp.com/repository neutron-converter_SDK_25_03 +pip install --extra-index-url https://eiq.nxp.com/repository neutron_converter_SDK_25_03 diff --git a/examples/qualcomm/qaihub_scripts/utils/export.py b/examples/qualcomm/qaihub_scripts/utils/export.py index 4d252175dbb..2ee1968dd82 100644 --- a/examples/qualcomm/qaihub_scripts/utils/export.py +++ b/examples/qualcomm/qaihub_scripts/utils/export.py @@ -18,7 +18,6 @@ from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset from executorch.backends.qualcomm.utils.utils import ( draw_graph, - ExecutorchBackendConfig, from_context_binary, generate_htp_compiler_spec, generate_qnn_executorch_compiler_spec, @@ -26,6 +25,7 @@ ) from executorch.examples.qualcomm.qaihub_scripts.utils.utils import preprocess_binary from executorch.examples.qualcomm.utils import make_output_dir, SimpleADB +from executorch.exir import ExecutorchBackendConfig from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass diff --git a/examples/qualcomm/util_scripts/README.md b/examples/qualcomm/util_scripts/README.md new file mode 100644 index 00000000000..712bbcd4277 --- /dev/null +++ b/examples/qualcomm/util_scripts/README.md @@ -0,0 +1,79 @@ +# CLI Tool for Quantize / Compile / Deploy PyTorch Model with QNN Backend + +An easy-to-use tool for quantizing / compiling / executing .pte program with Qualcomm AI Engine Direct. Tool is verified with [host environement](../../../docs/source/backends-qualcomm.md#host-os). + +## Description + +This tool aims for users who want to deploy models with ExecuTorch runtime. It's possible for them to produce .pte program in few steps.
+ +### Quantizing Model + +* Save torch.nn.Module with .pt2 format & prepare input data + ```bash + # create workspace for following operations + cd path/to/executorch + mkdir cli_example + ``` + ```python + # take SimpleModel as an example + import torch + from executorch.backends.qualcomm.tests.models import SimpleModel + from pathlib import Path + # make example inputs + example_inputs = (torch.randn(1, 32, 28, 28), torch.randn(1, 32, 28, 28)) + # generate ExportedProgram + ep = torch.export.export(SimpleModel(), example_inputs) + # save to workspace + ws = f"{Path().cwd()}/cli_example" + torch.export.save(ep, f"{ws}/simple_model.pt2") + # prepare calibration dataset: 2 sets of data with 2 inputs each + input_list = "" + for i in range(2): + current_input = "" + for j in range(2): + file_name = f"{ws}/input_{i}_{j}.pt" + torch.save(torch.randn(1, 32, 28, 28), file_name) + current_input += f"{file_name} " + input_list += f"{current_input.strip()}\n" + + with open(f"{ws}/input_list", 'w') as f: + f.write(input_list) + ``` + +* Quantize + ```bash + # user could get more information via: PYTHONPATH=.. python -m examples.qualcomm.util_scripts.cli quantize -h + PYTHONPATH=.. python -m examples.qualcomm.util_scripts.cli quantize -a cli_example/simple_model.pt2 -o cli_example/quantize_output -c use_8a8w -i cli_example/input_list --per_channel + ``` +* Artifacts for quantized .pt2 file + - `cli_example/quantize_output/simple_model_quantized.pt2` + + +### Compiling Program + +* Compile .pt2 to .pte program + ```bash + # `pip install pydot` if package is missing + # user could get more information via: PYTHONPATH=.. python -m examples.qualcomm.util_scripts.cli compile -h + PYTHONPATH=.. python -m examples.qualcomm.util_scripts.cli compile -a cli_example/quantize_output/simple_model_quantized.pt2 -o cli_example/compile_output -m SM8750 + ``` +* (Optional) Compile pre-generated context binary to .pte program + ```bash + # `pip install pydot` if package is missing + # user could get more information via: PYTHONPATH=.. python -m examples.qualcomm.util_scripts.cli compile -h + PYTHONPATH=.. python -m examples.qualcomm.util_scripts.cli compile -a model.bin -o path/to/model/output -m SM8750 + ``` +* Artifacts for .pte file and figure of graph information + - `cli_example/compile_output/simple_model_quantized.pte` + - `cli_example/compile_output/simple_model_quantized.svg` + +### Executing Program + +* Execute .pte program + ```bash + # user could get more information via: PYTHONPATH=.. python -m examples.qualcomm.util_scripts.cli execute -h + PYTHONPATH=.. python -m examples.qualcomm.util_scripts.cli execute -a cli_example/compile_output/simple_model_quantized.pte -o cli_example/execute_output -i cli_example/input_list -s $DEVICE_SERIAL -b build-android -m SM8750 + ``` +* Artifacts for .pte file and figure of graph information + - `cli_example/execute_output/output_{data_index}_{output_index}.pt`.
+ `data_index` represents the sequence of dataset, `output_index` stands for the order of graph output. diff --git a/examples/qualcomm/util_scripts/cli.py b/examples/qualcomm/util_scripts/cli.py new file mode 100644 index 00000000000..e4c4c5dcaf8 --- /dev/null +++ b/examples/qualcomm/util_scripts/cli.py @@ -0,0 +1,504 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import importlib +import logging +import os +import re +from pathlib import Path + +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManagerAdaptor +import numpy as np + +import torch + +from executorch.backends.qualcomm._passes.qnn_pass_manager import ( + get_capture_program_passes, +) +from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype +from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset +from executorch.backends.qualcomm.utils.constants import QCOM_PASS_ACTIVATE_KEY +from executorch.backends.qualcomm.utils.utils import ( + draw_graph, + dump_context_from_pte, + from_context_binary, + generate_htp_compiler_spec, + generate_qnn_executorch_compiler_spec, + generate_qnn_executorch_option, + QNN_QUANT_TYPE_MAP, + QNN_TENSOR_TYPE_MAP, + to_edge_transform_and_lower_to_qnn, +) +from executorch.examples.qualcomm.qaihub_scripts.utils.utils import preprocess_binary +from executorch.examples.qualcomm.utils import ( + make_output_dir, + make_quantizer, + SimpleADB, +) +from executorch.exir import ExecutorchBackendConfig +from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass +from torchao.quantization import pt2e +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e + + +def get_logger(): + logger = logging.getLogger("examples.qualcomm.util_scripts.cli") + handler = logging.StreamHandler() + handler.setFormatter( + logging.Formatter( + fmt="[%(asctime)s %(prefix)s] %(levelname)-8s: %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + ) + logger.addHandler(handler) + logger.setLevel(logging.INFO) + logger.propagate = False + return logging.LoggerAdapter(logger, extra={"prefix": "QNN_BACKEND"}) + + +def get_io_info(pte_path, compiler_specs): + dtype_map = {} + for type_map in (QNN_QUANT_TYPE_MAP, QNN_TENSOR_TYPE_MAP): + for k, v in type_map.items(): + dtype_map.setdefault(v, k) + + def fill_tensor_info(info, qnn_tensors, category): + for tensor in qnn_tensors: + encoding = tensor.GetEncodings() + quantization_info = { + "scale": encoding.data["scale"].tolist(), + "offset": encoding.data["offset"].tolist(), + "axis": encoding.axis, + } + info[category].append( + { + "name": tensor.GetName(), + "shape": tensor.GetDims().tolist(), + "dtype": dtype_map[tensor.GetDataType()], + "encoding": quantization_info, + } + ) + + in_key, out_key = "inputs", "outputs" + tensor_info = {in_key: [], out_key: []} + + path_of_pte = Path(pte_path) + dump_context_from_pte(path_of_pte.absolute()) + ctx_bin = [f for f in os.listdir(path_of_pte.parent) if Path(f).suffix == ".bin"][0] + # assume graph is fully delegated or it will be too hard to handle + with open(f"{path_of_pte.parent}/{ctx_bin}", "rb") as f: + ctx_bin = preprocess_binary(f.read(), compiler_specs) + # leverage QNN pybind interface to retrieve tensor encodings + qnn_mgr = PyQnnManagerAdaptor.QnnManager( + generate_qnn_executorch_option(compiler_specs), ctx_bin + ) + assert qnn_mgr.Init().value == 0, "failed to load context binary" + graph_name = qnn_mgr.GetGraphNames()[0] + qnn_mgr.AllocateTensor(graph_name) + fill_tensor_info(tensor_info, qnn_mgr.GetGraphInputs(graph_name), in_key) + fill_tensor_info(tensor_info, qnn_mgr.GetGraphOutputs(graph_name), out_key) + qnn_mgr.Destroy() + + return tensor_info + + +def quantize(args): + logger = get_logger() + + # get corresponding QnnQuantizer + try: + quant_dtype = getattr(QuantDtype, args.config) + act_observer = getattr(pt2e, args.activation_observer) + quantizer = make_quantizer( + quant_dtype=quant_dtype, + per_channel_conv=args.per_channel, + per_channel_linear=args.per_row, + act_observer=act_observer, + ) + except Exception: + logger.error( + f"Failed to retrieve expected config {args.config} / {args.activation_observer}." + ) + exit(1) + + # step 0: load saved model + ep = torch.export.load(args.artifact) + # step 1: use prepare_pt2e to annotate QDQ pairs + ep_prepared = prepare_pt2e(ep.module(), quantizer) + logger.info(f"perform calibration on {args.artifact}") + # step 2: perform calibration + with open(args.input_list, "r") as f: + for line in f.read().split("\n")[:-1]: + inputs = [torch.load(t, weights_only=True) for t in line.split(" ")] + ep_prepared(*inputs) + # step 3: use convert_pt2e to fix encodings of QDQ pairs + logger.info(f"saving calibrated model for {args.artifact}") + ep_converted = convert_pt2e(ep_prepared) + ep_quantized = torch.export.export(ep_converted, tuple(inputs)) + make_output_dir(args.output_folder) + torch.export.save( + ep_quantized, f"{args.output_folder}/{Path(args.artifact).stem}_quantized.pt2" + ) + + +def compile(args): + logger = get_logger() + + # setup memory planning + memory_planning_pass = MemoryPlanningPass( + alloc_graph_input=args.shared_buffer is None, + alloc_graph_output=args.shared_buffer is None, + ) + + file_name, extension = Path(args.artifact).stem, Path(args.artifact).suffix + make_output_dir(args.output_folder) + # setup compiler spec dedicated to QNN HTP backend + backend_options = generate_htp_compiler_spec(use_fp16=True) + # setup general compiler spec for QNN + compiler_specs = generate_qnn_executorch_compiler_spec( + soc_model=getattr(QcomChipset, args.model), + backend_options=backend_options, + is_from_context_binary=extension == "bin", + ) + if extension == ".bin": + custom_op_name = f"ctx_loader_{file_name}" + # step 1: generate ExportedProgram with custom op as a binary loader & lower it w/QnnBackend + logger.info(f"exporting program for {args.artifact}") + prog_info = from_context_binary( + args.artifact, custom_op_name, getattr(QcomChipset, args.model) + ) + # step 2: write pte files and store final graph + logger.info(f"exporting {file_name}.pte") + with open(f"{args.output_folder}/{file_name}.pte", "wb") as f: + prog_info["edge_program_manager"].to_executorch( + config=ExecutorchBackendConfig( + memory_planning_pass=memory_planning_pass + ) + ).write_to_file(f) + logger.info(f"exporting network graph with {file_name}.svg") + draw_graph(file_name, args.output_folder, prog_info["exported_program"]) + elif extension == ".pt2": + # step 0: prepare exported_program + ep = torch.export.load(args.artifact) + sample_inputs = ep.example_inputs[0] + # step 1: start lowering to QnnBackend + logger.info(f"start lowering program for {args.artifact}") + passes, user_passes = get_capture_program_passes(), [] + if args.pass_job is not None: + for job in args.pass_job: + try: + user_passes.append( + importlib.import_module( + "executorch.backends.qualcomm._passes", job + ) + ) + except Exception: + logger.error(f"failed to extract designated pass '{args.artifact}'") + + for user_pass in user_passes: + passes[user_pass][QCOM_PASS_ACTIVATE_KEY] = True + + edge_prog_mgr = to_edge_transform_and_lower_to_qnn( + module=ep.module(), + inputs=sample_inputs, + compiler_specs=compiler_specs, + passes_job=passes, + ) + # step 2: write pte files and store final graph + logger.info(f"exporting {file_name}.pte") + with open(f"{args.output_folder}/{file_name}.pte", "wb") as f: + edge_prog_mgr.to_executorch( + config=ExecutorchBackendConfig( + memory_planning_pass=memory_planning_pass + ) + ).write_to_file(f) + logger.info(f"exporting network graph with {file_name}.svg") + draw_graph(file_name, args.output_folder, edge_prog_mgr.exported_program()) + else: + logger.error(f"unsupported file extension for '{args.artifact}'") + + +def execute(args): + logger = get_logger() + + pte_name = Path(args.artifact).stem + + # load input files + logger.info("loading user inputs") + user_inputs, input_list = [], "" + with open(args.input_list, "r") as f: + for line in f.read().split("\n")[:-1]: + inputs, input_names = [], "" + for data in line.split(" "): + input_names += f"{Path(data).stem}.raw " + inputs.append(torch.load(data, weights_only=True)) + user_inputs.append(inputs) + input_list += input_names.strip() + "\n" + + logger.info("retrieving graph I/O") + # setup compiler spec dedicated to QNN HTP backend + backend_options = generate_htp_compiler_spec(use_fp16=True) + # setup general compiler spec for QNN + compiler_specs = generate_qnn_executorch_compiler_spec( + soc_model=getattr(QcomChipset, args.model), + backend_options=backend_options, + ) + io_info = get_io_info(args.artifact, compiler_specs) + + logger.info("preparing ADB connection") + # leverage SimpleADB for e2e inference + adb = SimpleADB( + qnn_sdk=os.getenv("QNN_SDK_ROOT"), + build_path=args.build_folder, + pte_path=args.artifact, + workspace=f"/data/local/tmp/executorch/{pte_name}", + device_id=args.device, + soc_model=args.model, + host_id=args.host, + shared_buffer=args.shared_buffer, + ) + + logger.info("pushing QNN libraries & other artifacts") + adb.push(inputs=user_inputs, input_list=input_list) + + logger.info("starting inference") + adb.execute() + + def post_process(): + torch_to_numpy_dtype_dict = { + torch.bool: np.dtype("bool"), + torch.uint8: np.dtype("uint8"), + torch.int8: np.dtype("int8"), + torch.int16: np.dtype("int16"), + torch.int32: np.dtype("int32"), + torch.int64: np.dtype("int64"), + torch.float16: np.dtype("float16"), + torch.float32: np.dtype("float32"), + torch.float64: np.dtype("float64"), + torch.complex64: np.dtype("complex64"), + torch.complex128: np.dtype("complex128"), + } + output_info = io_info["outputs"] + output_folder = f"{args.output_folder}/outputs" + for _, f in enumerate(os.listdir(output_folder)): + filename = os.path.join(output_folder, f) + match_res = re.match(r".*([0-9]+)_([0-9]+)\.raw$", filename) + data_index, output_index = int(match_res.group(1)), int(match_res.group(2)) + output = np.fromfile( + filename, + dtype=eval( + f"np.{torch_to_numpy_dtype_dict[output_info[output_index]['dtype']]}" + ), + ) + output = torch.from_numpy( + output.reshape(output_info[output_index]["shape"]) + ) + torch.save( + output, f"{args.output_folder}/output_{data_index}_{output_index}.pt" + ) + + logger.info("collecting output data") + make_output_dir(args.output_folder) + adb.pull(args.output_folder, post_process) + logger.info(f"execution finished, please check {args.output_folder} for results") + + +def main(): + parser = argparse.ArgumentParser( + description=( + "Utility to quantize / compile / execute models via Qualcomm backend" + ), + ) + subparsers = parser.add_subparsers( + title="subcommands", + description=( + "[quantize]: Perform PTQ with QnnQuantizer for models in .pt2 extension. " + "[compile]: Compile model in .pt2 extenstion / context binary into .pte file. " + "[execute]: Perform on-device inference with given .pte." + ), + ) + + sub_quantize = subparsers.add_parser( + name="quantize", + help=( + "e.g. python -m executorch.example.qualcomm.util_scripts.cli quantize " + "-a model.pt2 -c use_8a8w -i calibration_data" + ), + ) + sub_quantize.add_argument( + "-a", + "--artifact", + type=str, + required=True, + help="Path to saved .pt2 model in floating point precision.", + ) + sub_quantize.add_argument( + "-o", + "--output_folder", + type=str, + default="./output_quantized", + help="Path to output artifact, store in 'output_quantized' if not given.", + ) + sub_quantize.add_argument( + "-c", + "--config", + type=str, + default="use_8a8w", + help=(f"Configuration to be applied: {list(QuantDtype.__members__.keys())}."), + ) + sub_quantize.add_argument( + "-i", + "--input_list", + type=str, + required=True, + help=( + "List of input files specified for calibration. " + 'e.g. File content with: "input_0_0.pt2 input_0_1.pt2\\ninput_1_0.pt2 input_1_1.pt2" ' + "means there are 2 sets of data for calibration on a graph with 2 inputs." + ), + ) + sub_quantize.add_argument( + "--per_channel", + action="store_true", + help="Use per_channel encoding for operator convolution and its' families.", + ) + sub_quantize.add_argument( + "--per_row", + action="store_true", + help="Use per_row encoding for operator linear.", + ) + sub_quantize.add_argument( + "--activation_observer", + type=str, + default="MovingAverageMinMaxObserver", + help=( + "Activation observer for PTQ " + "(MinMaxObserver / MovingAverageMinMaxObserver / HistogramObserver)." + ), + ) + sub_quantize.set_defaults(callback=quantize) + + sub_compile = subparsers.add_parser( + name="compile", + help=( + "e.g. python -m executorch.example.qualcomm.util_scripts.cli compile " + "-a model.(pt2 / bin) -m SM8750" + ), + ) + sub_compile.add_argument( + "-a", + "--artifact", + type=str, + required=True, + help="Path to saved .pt2 model or pre-generated context binary.", + ) + sub_compile.add_argument( + "-m", + "--model", + type=str, + required=True, + help="SoC model. e.g. SM8750", + ) + sub_compile.add_argument( + "-o", + "--output_folder", + type=str, + default="./output_pte", + help="Path to output artifacts, store in 'output_pte' if not given.", + ) + sub_compile.add_argument( + "-p", + "--pass_job", + nargs="+", + type=str, + help=( + 'Add extra passes for model lowering. e.g. "ExpandBroadcastTensorShape".' + ), + ) + sub_compile.add_argument( + "--shared_buffer", + help=( + "Enable usage of shared buffer between application and backend for graph I/O." + ), + action="store_true", + ) + sub_compile.set_defaults(callback=compile) + + sub_execute = subparsers.add_parser( + name="execute", + help=( + "e.g. python -m executorch.example.qualcomm.util_scripts.cli " + "execute -p model.pte -i execution_data -s device_serial" + ), + ) + sub_execute.add_argument( + "-a", + "--artifact", + type=str, + required=True, + help="Path to .pte file generated from 'compile' subcommand.", + ) + sub_execute.add_argument( + "-i", + "--input_list", + type=str, + help=( + "List of input files specified for execution. " + 'e.g. File content with: "input_0_0.pt2 input_0_1.pt2\\ninput_1_0.pt2 input_1_1.pt2" ' + "means there are 2 sets of data for execution on a graph with 2 inputs.\n" + ), + ) + sub_execute.add_argument( + "-m", + "--model", + type=str, + required=True, + help="SoC model. e.g. SM8750", + ) + sub_execute.add_argument( + "-s", + "--device", + type=str, + required=True, + help="Serial no of device which could be obtained by 'adb devices'.", + ) + sub_execute.add_argument( + "-o", + "--output_folder", + type=str, + default="./output_data", + help="Path to output data, store in 'output_data' if not given.", + ) + sub_execute.add_argument( + "-b", + "--build_folder", + help="Path to cmake binary directory for android, e.g., /path/to/build-android", + type=str, + required=True, + ) + sub_execute.add_argument( + "-H", + "--host", + type=str, + help="Gateway hostname.", + ) + sub_execute.add_argument( + "--shared_buffer", + help=( + "Enable usage of shared buffer between application and backend for graph I/O." + " Please use with `--shared_buffer` in compile command." + ), + action="store_true", + ) + sub_execute.set_defaults(callback=execute) + + args = parser.parse_args() + args.callback(args) + + +if __name__ == "__main__": + main() diff --git a/examples/qualcomm/utils.py b/examples/qualcomm/utils.py index 6d9a6653ec7..e70510b0b70 100755 --- a/examples/qualcomm/utils.py +++ b/examples/qualcomm/utils.py @@ -9,6 +9,7 @@ import argparse import os +import shutil import subprocess import sys import tempfile @@ -395,9 +396,7 @@ def build_executorch_binary( def make_output_dir(path: str): if os.path.exists(path): - for f in os.listdir(path): - os.remove(os.path.join(path, f)) - os.removedirs(path) + shutil.rmtree(path, ignore_errors=True) os.makedirs(path) diff --git a/exir/passes/TARGETS b/exir/passes/TARGETS index 749e8f5c2f1..8699fe2fd02 100644 --- a/exir/passes/TARGETS +++ b/exir/passes/TARGETS @@ -31,6 +31,7 @@ python_library( ":sym_shape_eval_pass", ":sym_to_tensor_pass", ":weights_to_outputs_pass", + ":reinplace_pass", "//caffe2:torch", "//executorch/exir:common", "//executorch/exir:control_flow", @@ -68,6 +69,17 @@ python_library( ], ) +python_library( + name = "reinplace_pass", + srcs = [ + "reinplace.py", + ], + deps = [ + "//caffe2:torch", + "//executorch/exir/dialects:lib", + ], +) + python_library( name = "insert_write_back_for_buffers_pass", srcs = [ diff --git a/exir/passes/reinplace.py b/exir/passes/reinplace.py new file mode 100644 index 00000000000..349869a2f4b --- /dev/null +++ b/exir/passes/reinplace.py @@ -0,0 +1,103 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from typing import Set + +import torch +from executorch.exir.dialects._ops import ops +from torch.export import ExportedProgram + + +def _is_index_put(node: torch.fx.Node) -> bool: + """Check if a node is an index_put operation.""" + return node.op == "call_function" and node.target in ( + torch.ops.aten.index_put.default, + ops.edge.aten.index_put.default, + ) + + +def _is_safe_to_reinplace( + node: torch.fx.Node, + later_nodes: Set[torch.fx.Node], + inputs: Set[torch.fx.Node], + mutable_inputs: Set[torch.fx.Node], +) -> bool: + # This node is used later in the graph so we can't reinplace it + # There is probably a faster way to do this but this works for now. + if node in later_nodes: + return False + # If its not an input then we can reinplace it + if node not in inputs: + return True + # If its a mutable input then we can reinplace it + elif node in mutable_inputs: + return True + else: # input but not mutable input + return False + + +def _is_mutable_user_input( + node: torch.fx.Node, exported_program: ExportedProgram +) -> bool: + return ( + node.target in exported_program.graph_signature.user_inputs_to_mutate.values() + ) + + +def _is_mutable_buffer(node: torch.fx.Node, exported_program: ExportedProgram) -> bool: + if node.target not in exported_program.graph_signature.inputs_to_buffers: + return False + buf = exported_program.graph_signature.inputs_to_buffers[node.target] + return buf in exported_program.graph_signature.buffers_to_mutate.values() + + +def reinplace_pass(ep: ExportedProgram) -> ExportedProgram: + """ + Pass that loops over nodes in an exported program and collects the first argument + of every call_function node that is a view_copy operation. + + Args: + exported_program: The ExportedProgram to analyze + + Returns: + Set of nodes that are first arguments to view_copy operations + """ + seen_nodes: Set[torch.fx.Node] = set() + # Get all placeholders + inputs = set() + for node in ep.graph.nodes: + if node.op == "placeholder": + inputs.add(node) + # Get all inputs that we could potentially mutate + mutable_nodes = set( + [ + node + for node in inputs + if _is_mutable_user_input(node, ep) or _is_mutable_buffer(node, ep) + ] + ) + + results = set() + for node in reversed(ep.graph.nodes): + if _is_index_put(node): + # Check if this index_put node is safe to inplace + # The first argument is the base tensor being indexed into + first_arg = node.args[0] + if _is_safe_to_reinplace(first_arg, seen_nodes, inputs, mutable_nodes): + # This index_put is safe to reinplace + with ep.graph.inserting_before(node): + new_node = ep.graph.call_function( + ops.edge.aten.index_put_.default, args=node.args + ) + new_node.meta["val"] = node.meta["val"] + node.replace_all_uses_with(new_node) + ep.graph.erase_node(node) + results.add(first_arg) + elif node.op == "call_function": + seen_nodes.update(node.all_input_nodes) + return ep diff --git a/exir/tests/TARGETS b/exir/tests/TARGETS index 1423984c563..2c2ad3e05f0 100644 --- a/exir/tests/TARGETS +++ b/exir/tests/TARGETS @@ -136,6 +136,18 @@ python_unittest( ], ) +python_unittest( + name = "reinplace_pass", + srcs = [ + "test_reinplace_pass.py", + ], + deps = [ + "//caffe2:torch", + "//executorch/exir:lib", + "//executorch/exir/passes:lib", + ], +) + cpp_library( name = "test_lib", srcs = [ diff --git a/exir/tests/test_reinplace_pass.py b/exir/tests/test_reinplace_pass.py new file mode 100644 index 00000000000..2f4538770d6 --- /dev/null +++ b/exir/tests/test_reinplace_pass.py @@ -0,0 +1,104 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import unittest + +import torch +from executorch.exir import to_edge +from executorch.exir.passes.reinplace import reinplace_pass +from torch.export import export + + +class TestReinplacePass(unittest.TestCase): + def test_index_put_reinplace(self) -> None: + """Test that index_put on a mutable buffer can be reinplaced.""" + + class IndexPutModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("state", torch.zeros(5)) + + def forward( + self, indices: torch.Tensor, values: torch.Tensor + ) -> torch.Tensor: + # index_put on buffer (non-user input) should be safe + self.state.index_put_((indices,), values) + return self.state + + model = IndexPutModel() + indices = torch.tensor([0]) + values = torch.tensor([1.0]) + + exported_program = export(model, (indices, values), strict=True) + print(exported_program.graph) + edge_program = to_edge(exported_program).exported_program() + + # Find the index_put node + index_put_node = None + for node in edge_program.graph.nodes: + if node.op == "call_function" and "index_put" in str(node.target): + index_put_node = node + break + + self.assertIsNotNone(index_put_node, "Should find an index_put node") + + ep = reinplace_pass(edge_program) + # Find the index_put node + index_put_node = None + for node in ep.graph.nodes: + if node.op == "call_function" and "index_put_" in str(node.target): + index_put_node = node + break + + self.assertIsNotNone(index_put_node, "Should find an index_put_ node") + + def test_cant_reinplace(self) -> None: + """Test that index_put on a mutable buffer that is viewed later is not safe.""" + + class IndexPutModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("state", torch.zeros(5)) + + def forward( + self, indices: torch.Tensor, values: torch.Tensor + ) -> torch.Tensor: + # index_put on buffer (non-user input) should be safe + x = self.state.index_put((indices,), values) + self.state.add_(1) + return x + + model = IndexPutModel() + indices = torch.tensor([0]) + values = torch.tensor([1.0]) + + exported_program = export(model, (indices, values), strict=True) + edge_program = to_edge(exported_program).exported_program() + + # Find the index_put node + index_put_node = None + for node in edge_program.graph.nodes: + if node.op == "call_function" and "index_put" in str(node.target): + index_put_node = node + break + + self.assertIsNotNone(index_put_node, "Should find an index_put node") + + ep = reinplace_pass(edge_program) + # Find the index_put node + index_put_node = None + for node in ep.graph.nodes: + if ( + node.op == "call_function" + and "index_put" in str(node.target) + and "index_put_" not in str(node.target) + ): + index_put_node = node + break + + self.assertIsNotNone(index_put_node, "Should still find an index_put node") diff --git a/extension/flat_tensor/serialize/serialize.py b/extension/flat_tensor/serialize/serialize.py index 5b29d7ccacd..7f3332c4303 100644 --- a/extension/flat_tensor/serialize/serialize.py +++ b/extension/flat_tensor/serialize/serialize.py @@ -19,7 +19,11 @@ from executorch.exir._serialize._flatbuffer import _flatc_compile, _flatc_decompile from executorch.exir._serialize._program import _insert_flatbuffer_header -from executorch.exir._serialize.data_serializer import DataPayload, DataSerializer +from executorch.exir._serialize.data_serializer import ( + DataEntry, + DataPayload, + DataSerializer, +) from executorch.exir._serialize.padding import aligned_size, pad_to, padding_required @@ -34,6 +38,9 @@ # endian. _HEADER_BYTEORDER: Literal["little"] = "little" +# Current version. Keep in sync with c++ version number in serialize. +_FLAT_TENSOR_VERSION: int = 0 + def _serialize_to_flatbuffer(flat_tensor: FlatTensor) -> Cord: """Serializes a FlatTensor to a flatbuffer and returns the serialized data.""" @@ -320,7 +327,7 @@ def serialize( # Create FlatTensor, which describes of the contents of the file and # points to all the data segments. It will be serialized to flatbuffer. flat_tensor = FlatTensor( - version=0, # Keep in sync with c++ version number in serialize.h + version=_FLAT_TENSOR_VERSION, segments=data_segments, named_data=named_data, ) @@ -383,4 +390,49 @@ def deserialize(self, blob: Cord) -> DataPayload: """ Deserializes a flat_tensor blob into a list of tensor metadata and tensors. """ - raise NotImplementedError("deserialize_data") + + data = bytes(blob) + + # Read header. Verify that it's valid. + header = FlatTensorHeader.from_bytes(data[8:]) + if not header.is_valid(): + raise RuntimeError( + "Flat tensor header is invalid. File is likely incorrect format or corrupt." + ) + + # Deserialize the flat tensor data, which contains the data offsets and tensor metadata. + flat_tensor_bytes = data[0 : header.flatbuffer_offset + header.flatbuffer_size] + flat_tensor = _deserialize_to_flat_tensor(flat_tensor_bytes) + + # Verify that this is a supported version. + if flat_tensor.version != _FLAT_TENSOR_VERSION: + raise NotImplementedError( + f"Flat tensor files reports unsupported version {flat_tensor.version}. Expected {_FLAT_TENSOR_VERSION}." + ) + + # Extract the buffers. + buffers = [ + data[ + header.segment_base_offset + + segment.offset : header.segment_base_offset + + segment.offset + + segment.size + ] + for segment in flat_tensor.segments + ] + + payload = DataPayload( + buffers=buffers, + named_data={}, + ) + + # Read the named data entries. + for named_data in flat_tensor.named_data: + entry = DataEntry( + buffer_index=named_data.segment_index, + alignment=1, + tensor_layout=named_data.tensor_layout, + ) + payload.named_data[named_data.key] = entry + + return payload diff --git a/extension/flat_tensor/test/test_serialize.py b/extension/flat_tensor/test/test_serialize.py index 80ee59ae974..13402e60a65 100644 --- a/extension/flat_tensor/test/test_serialize.py +++ b/extension/flat_tensor/test/test_serialize.py @@ -6,17 +6,19 @@ # pyre-unsafe +import dataclasses import math import unittest from typing import List, Optional +from executorch.exir._serialize._cord import Cord + from executorch.exir._serialize.data_serializer import ( DataEntry, DataPayload, DataSerializer, ) - from executorch.exir._serialize.padding import aligned_size from executorch.exir.schema import ScalarType @@ -223,3 +225,39 @@ def test_serialize(self) -> None: ) self.assertEqual(segments[2].offset + segments[2].size, len(segment_data)) + + def test_round_trip(self) -> None: + # Serialize and then deserialize the test payload. Make sure it's reconstructed + # properly. + config = FlatTensorConfig() + serializer: DataSerializer = FlatTensorSerializer(config) + + # Round trip the data. + serialized_data = bytes(serializer.serialize(TEST_DATA_PAYLOAD)) + deserialized_payload = serializer.deserialize(Cord(serialized_data)) + + # Validate the deserialized payload. Since alignment isn't serialized, we need to + # do this somewhat manually. + for i in range(len(deserialized_payload.buffers)): + self.assertEqual( + TEST_DATA_PAYLOAD.buffers[i], + deserialized_payload.buffers[i], + f"Buffer at index {i} does not match.", + ) + + self.assertEqual( + TEST_DATA_PAYLOAD.named_data.keys(), deserialized_payload.named_data.keys() + ) + + SKIP_FIELDS = {"alignment"} # Fields to ignore in comparison. + for key in TEST_DATA_PAYLOAD.named_data.keys(): + reference = TEST_DATA_PAYLOAD.named_data[key] + actual = deserialized_payload.named_data[key] + + for field in dataclasses.fields(reference): + if field.name not in SKIP_FIELDS: + self.assertEqual( + getattr(reference, field.name), + getattr(actual, field.name), + f"Named data record {key}.{field.name} does not match.", + ) diff --git a/extension/llm/export/README.md b/extension/llm/export/README.md index 96f36acc1b4..e97b9e10462 100644 --- a/extension/llm/export/README.md +++ b/extension/llm/export/README.md @@ -23,9 +23,9 @@ The LLM export process transforms a model from its original format to an optimiz ## Usage -The export API supports two configuration approaches: +The export API supports a Hydra-style CLI where you can you configure using yaml and also CLI args. -### Option 1: Hydra CLI Arguments +### Hydra CLI Arguments Use structured configuration arguments directly on the command line: @@ -41,7 +41,7 @@ python -m extension.llm.export.export_llm \ quantization.qmode=8da4w ``` -### Option 2: Configuration File +### Configuration File Create a YAML configuration file and reference it: @@ -78,53 +78,21 @@ debug: verbose: true ``` -**Important**: You cannot mix both approaches. Use either CLI arguments OR a config file, not both. +You can you also still provide additional overrides using the CLI args as well: -## Example Commands - -### Export Qwen3 0.6B with XNNPACK backend and quantization ```bash -python -m extension.llm.export.export_llm \ - base.model_class=qwen3_0_6b \ - base.params=examples/models/qwen3/0_6b_config.json \ - base.metadata='{"get_bos_id": 151644, "get_eos_ids":[151645]}' \ - model.use_kv_cache=true \ - model.use_sdpa_with_kv_cache=true \ - model.dtype_override=FP32 \ - export.max_seq_length=512 \ - export.output_name=qwen3_0_6b.pte \ - quantization.qmode=8da4w \ - backend.xnnpack.enabled=true \ - backend.xnnpack.extended_ops=true \ - debug.verbose=true +python -m extension.llm.export.export_llm + --config my_config.yaml + base.model_class="llama2" + +export.max_context_length=1024 ``` -### Export Phi-4-Mini with custom checkpoint -```bash -python -m extension.llm.export.export_llm \ - base.model_class=phi_4_mini \ - base.checkpoint=/path/to/phi4_checkpoint.pth \ - base.params=examples/models/phi-4-mini/config.json \ - base.metadata='{"get_bos_id":151643, "get_eos_ids":[151643]}' \ - model.use_kv_cache=true \ - model.use_sdpa_with_kv_cache=true \ - export.max_seq_length=256 \ - export.output_name=phi4_mini.pte \ - backend.xnnpack.enabled=true \ - debug.verbose=true -``` +Note that if a config file is specified and you want to specify a CLI arg that is not in the config, you need to prepend with a `+`. You can read more about this in the Hydra [docs](https://hydra.cc/docs/advanced/override_grammar/basic/). -### Export with CoreML backend (iOS optimization) -```bash -python -m extension.llm.export.export_llm \ - base.model_class=llama3 \ - model.use_kv_cache=true \ - export.max_seq_length=128 \ - backend.coreml.enabled=true \ - backend.coreml.compute_units=ALL \ - quantization.pt2e_quantize=coreml_c4w \ - debug.verbose=true -``` + +## Example Commands + +Please refer to the docs for some of our example suported models ([Llama](https://github.com/pytorch/executorch/blob/main/examples/models/llama/README.md), [Qwen3](https://github.com/pytorch/executorch/tree/main/examples/models/qwen3/README.md), [Phi-4-mini](https://github.com/pytorch/executorch/tree/main/examples/models/phi_4_mini/README.md)). ## Configuration Options @@ -134,4 +102,4 @@ For a complete reference of all available configuration options, see the [LlmCon - [Llama Examples](../../../examples/models/llama/README.md) - Comprehensive Llama export guide - [LLM Runner](../runner/) - Running exported models -- [ExecuTorch Documentation](https://pytorch.org/executorch/) - Framework overview \ No newline at end of file +- [ExecuTorch Documentation](https://pytorch.org/executorch/) - Framework overview diff --git a/extension/llm/export/export_llm.py b/extension/llm/export/export_llm.py index e995b329f30..e0467250a28 100644 --- a/extension/llm/export/export_llm.py +++ b/extension/llm/export/export_llm.py @@ -30,6 +30,7 @@ """ import argparse +import os import sys from typing import Any, List, Tuple @@ -45,7 +46,6 @@ def parse_config_arg() -> Tuple[str, List[Any]]: - """First parse out the arg for whether to use Hydra or the old CLI.""" parser = argparse.ArgumentParser(add_help=True) parser.add_argument("--config", type=str, help="Path to the LlmConfig file") args, remaining = parser.parse_known_args() @@ -56,6 +56,7 @@ def pop_config_arg() -> str: """ Removes '--config' and its value from sys.argv. Assumes --config is specified and argparse has already validated the args. + Returns the config file path. """ idx = sys.argv.index("--config") value = sys.argv[idx + 1] @@ -63,30 +64,42 @@ def pop_config_arg() -> str: return value -@hydra.main(version_base=None, config_name="llm_config") +def add_hydra_config_args(config_file_path: str) -> None: + """ + Breaks down the config file path into directory and filename, + resolves the directory to an absolute path, and adds the + --config_path and --config_name arguments to sys.argv. + """ + config_dir = os.path.dirname(config_file_path) + config_name = os.path.basename(config_file_path) + + # Resolve to absolute path + config_dir_abs = os.path.abspath(config_dir) + + # Add the hydra config arguments to sys.argv + sys.argv.extend(["--config-path", config_dir_abs, "--config-name", config_name]) + + +@hydra.main(version_base=None, config_name="llm_config", config_path=None) def hydra_main(llm_config: LlmConfig) -> None: - export_llama(OmegaConf.to_object(llm_config)) + structured = OmegaConf.structured(LlmConfig) + merged = OmegaConf.merge(structured, llm_config) + llm_config_obj = OmegaConf.to_object(merged) + export_llama(llm_config_obj) def main() -> None: + # First parse out the arg for whether to use Hydra or the old CLI. config, remaining_args = parse_config_arg() if config: - # Check if there are any remaining hydra CLI args when --config is specified - # This might change in the future to allow overriding config file values - if remaining_args: - raise ValueError( - "Cannot specify additional CLI arguments when using --config. " - f"Found: {remaining_args}. Use either --config file or hydra CLI args, not both." - ) - + # Pop out --config and its value so that they are not parsed by + # Hydra's main. config_file_path = pop_config_arg() - default_llm_config = LlmConfig() - llm_config_from_file = OmegaConf.load(config_file_path) - # Override defaults with values specified in the .yaml provided by --config. - merged_llm_config = OmegaConf.merge(default_llm_config, llm_config_from_file) - export_llama(merged_llm_config) - else: - hydra_main() + + # Add hydra config_path and config_name arguments to sys.argv. + add_hydra_config_args(config_file_path) + + hydra_main() if __name__ == "__main__": diff --git a/extension/llm/export/test/test_export_llm.py b/extension/llm/export/test/test_export_llm.py index 7d17b7819d3..ab7db1b4e3a 100644 --- a/extension/llm/export/test/test_export_llm.py +++ b/extension/llm/export/test/test_export_llm.py @@ -21,7 +21,7 @@ class TestExportLlm(unittest.TestCase): def test_parse_config_arg_with_config(self) -> None: """Test parse_config_arg when --config is provided.""" # Mock sys.argv to include --config - test_argv = ["script.py", "--config", "test_config.yaml", "extra", "args"] + test_argv = ["export_llm.py", "--config", "test_config.yaml", "extra", "args"] with patch.object(sys, "argv", test_argv): config_path, remaining = parse_config_arg() self.assertEqual(config_path, "test_config.yaml") @@ -29,7 +29,7 @@ def test_parse_config_arg_with_config(self) -> None: def test_parse_config_arg_without_config(self) -> None: """Test parse_config_arg when --config is not provided.""" - test_argv = ["script.py", "debug.verbose=True"] + test_argv = ["export_llm.py", "debug.verbose=True"] with patch.object(sys, "argv", test_argv): config_path, remaining = parse_config_arg() self.assertIsNone(config_path) @@ -37,11 +37,21 @@ def test_parse_config_arg_without_config(self) -> None: def test_pop_config_arg(self) -> None: """Test pop_config_arg removes --config and its value from sys.argv.""" - test_argv = ["script.py", "--config", "test_config.yaml", "other", "args"] + test_argv = ["export_llm.py", "--config", "test_config.yaml", "other", "args"] with patch.object(sys, "argv", test_argv): config_path = pop_config_arg() self.assertEqual(config_path, "test_config.yaml") - self.assertEqual(sys.argv, ["script.py", "other", "args"]) + self.assertEqual(sys.argv, ["export_llm.py", "other", "args"]) + + def test_with_cli_args(self) -> None: + """Test main function with only hydra CLI args.""" + test_argv = ["export_llm.py", "debug.verbose=True"] + with patch.object(sys, "argv", test_argv): + with patch( + "executorch.extension.llm.export.export_llm.hydra_main" + ) as mock_hydra: + main() + mock_hydra.assert_called_once() @patch("executorch.extension.llm.export.export_llm.export_llama") def test_with_config(self, mock_export_llama: MagicMock) -> None: @@ -57,7 +67,7 @@ def test_with_config(self, mock_export_llama: MagicMock) -> None: model: dtype_override: fp16 export: - max_seq_length: 256 + max_seq_length: 128 quantization: pt2e_quantize: xnnpack_dynamic use_spin_quant: cuda @@ -70,7 +80,7 @@ def test_with_config(self, mock_export_llama: MagicMock) -> None: config_file = f.name try: - test_argv = ["script.py", "--config", config_file] + test_argv = ["export_llm.py", "--config", config_file] with patch.object(sys, "argv", test_argv): main() @@ -78,75 +88,65 @@ def test_with_config(self, mock_export_llama: MagicMock) -> None: mock_export_llama.assert_called_once() called_config = mock_export_llama.call_args[0][0] self.assertEqual( - called_config["base"]["tokenizer_path"], "/path/to/tokenizer.json" - ) - self.assertEqual(called_config["base"]["model_class"], "llama2") - self.assertEqual(called_config["base"]["preq_mode"].value, "8da4w") - self.assertEqual(called_config["model"]["dtype_override"].value, "fp16") - self.assertEqual(called_config["export"]["max_seq_length"], 256) - self.assertEqual( - called_config["quantization"]["pt2e_quantize"].value, "xnnpack_dynamic" - ) - self.assertEqual( - called_config["quantization"]["use_spin_quant"].value, "cuda" + called_config.base.tokenizer_path, "/path/to/tokenizer.json" ) + self.assertEqual(called_config.base.model_class, "llama2") + self.assertEqual(called_config.base.preq_mode.value, "8da4w") + self.assertEqual(called_config.model.dtype_override.value, "fp16") + self.assertEqual(called_config.export.max_seq_length, 128) self.assertEqual( - called_config["backend"]["coreml"]["quantize"].value, "c4w" + called_config.quantization.pt2e_quantize.value, "xnnpack_dynamic" ) + self.assertEqual(called_config.quantization.use_spin_quant.value, "cuda") + self.assertEqual(called_config.backend.coreml.quantize.value, "c4w") self.assertEqual( - called_config["backend"]["coreml"]["compute_units"].value, "cpu_and_gpu" + called_config.backend.coreml.compute_units.value, "cpu_and_gpu" ) finally: os.unlink(config_file) - def test_with_cli_args(self) -> None: - """Test main function with only hydra CLI args.""" - test_argv = ["script.py", "debug.verbose=True"] - with patch.object(sys, "argv", test_argv): - with patch( - "executorch.extension.llm.export.export_llm.hydra_main" - ) as mock_hydra: - main() - mock_hydra.assert_called_once() - - def test_config_with_cli_args_error(self) -> None: - """Test that --config rejects additional CLI arguments to prevent mixing approaches.""" + @patch("executorch.extension.llm.export.export_llm.export_llama") + def test_with_config_and_cli(self, mock_export_llama: MagicMock) -> None: + """Test main function with --config file and no hydra args.""" # Create a temporary config file with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: - f.write("base:\n checkpoint: /path/to/checkpoint.pth") - config_file = f.name - - try: - test_argv = ["script.py", "--config", config_file, "debug.verbose=True"] - with patch.object(sys, "argv", test_argv): - with self.assertRaises(ValueError) as cm: - main() - - error_msg = str(cm.exception) - self.assertIn( - "Cannot specify additional CLI arguments when using --config", - error_msg, - ) - finally: - os.unlink(config_file) - - def test_config_rejects_multiple_cli_args(self) -> None: - """Test that --config rejects multiple CLI arguments (not just single ones).""" - with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: - f.write("export:\n max_seq_length: 128") + f.write( + """ +base: + model_class: llama2 +model: + dtype_override: fp16 +backend: + xnnpack: + enabled: False +""" + ) config_file = f.name try: test_argv = [ - "script.py", + "export_llm.py", "--config", config_file, - "debug.verbose=True", - "export.output_dir=/tmp", + "base.model_class=stories110m", + "backend.xnnpack.enabled=True", ] with patch.object(sys, "argv", test_argv): - with self.assertRaises(ValueError): - main() + main() + + # Verify export_llama was called with config + mock_export_llama.assert_called_once() + called_config = mock_export_llama.call_args[0][0] + self.assertEqual( + called_config.base.model_class, "stories110m" + ) # Override from CLI. + self.assertEqual( + called_config.model.dtype_override.value, "fp16" + ) # From yaml. + self.assertEqual( + called_config.backend.xnnpack.enabled, + True, # Override from CLI. + ) finally: os.unlink(config_file) diff --git a/kernels/portable/cpu/util/elementwise_util.h b/kernels/portable/cpu/util/elementwise_util.h index 948da50fdd4..a376a89747b 100644 --- a/kernels/portable/cpu/util/elementwise_util.h +++ b/kernels/portable/cpu/util/elementwise_util.h @@ -9,6 +9,7 @@ #pragma once #include +#include #include #include #include @@ -345,20 +346,22 @@ inline void apply_elementwise_fn( } constexpr auto compute_type = CppTypeToScalarType::value; - const bool all_inputs_compute_dtype = - ((inputs.first->scalar_type() == compute_type) && ...); - - constexpr ScalarType out_specialized_scalar_type = - specialized_output_scalar_type(out_dtypes); - if (all_inputs_compute_dtype && - out.scalar_type() == out_specialized_scalar_type) { - using CTYPE_OUT = - typename ScalarTypeToCppType::type; - dtype_specialized_elementwise_fn_impl< - CTYPE_COMPUTE, - CTYPE_OUT, - support_noncontiguous_tensors>(compute_fun, ctx, out, inputs...); - return; + if constexpr (should_include_kernel_dtype(op_name, compute_type)) { + const bool all_inputs_compute_dtype = + ((inputs.first->scalar_type() == compute_type) && ...); + + constexpr ScalarType out_specialized_scalar_type = + specialized_output_scalar_type(out_dtypes); + if (all_inputs_compute_dtype && + out.scalar_type() == out_specialized_scalar_type) { + using CTYPE_OUT = + typename ScalarTypeToCppType::type; + dtype_specialized_elementwise_fn_impl< + CTYPE_COMPUTE, + CTYPE_OUT, + support_noncontiguous_tensors>(compute_fun, ctx, out, inputs...); + return; + } } apply_elementwise_fn_generic_impl< diff --git a/kernels/portable/cpu/util/normalization_ops_util.cpp b/kernels/portable/cpu/util/normalization_ops_util.cpp index f7118257898..4adcf02b303 100644 --- a/kernels/portable/cpu/util/normalization_ops_util.cpp +++ b/kernels/portable/cpu/util/normalization_ops_util.cpp @@ -38,7 +38,7 @@ bool check_batch_norm_args( ET_LOG_AND_RETURN_IF_FALSE( tensors_have_same_dtype(in, running_mean.value())); } - if (running_mean.has_value()) { + if (running_var.has_value()) { ET_LOG_AND_RETURN_IF_FALSE( tensors_have_same_dtype(in, running_var.value())); } diff --git a/kernels/portable/cpu/util/targets.bzl b/kernels/portable/cpu/util/targets.bzl index 1523fcfe706..44b95aa55c4 100644 --- a/kernels/portable/cpu/util/targets.bzl +++ b/kernels/portable/cpu/util/targets.bzl @@ -115,10 +115,10 @@ def define_common_targets(): ":vectorized_math", "//executorch/runtime/core/portable_type/c10/c10:aten_headers_for_executorch", "//executorch/runtime/kernel:kernel_runtime_context", + "//executorch/kernels/portable/cpu:scalar_utils", "//executorch/extension/threadpool:threadpool", ], deps = [ - "//executorch/kernels/portable/cpu:scalar_utils", "//executorch/runtime/kernel:kernel_includes", ], visibility = ["//executorch/kernels/portable/cpu/...", "//executorch/kernels/optimized/cpu/...", "@EXECUTORCH_CLIENTS"], diff --git a/pytest.ini b/pytest.ini index 557a307bdf2..e0f8eafb082 100644 --- a/pytest.ini +++ b/pytest.ini @@ -18,6 +18,7 @@ addopts = --ignore=devtools/visualization/visualization_utils_test.py # examples examples/models/llama/tests + examples/models/llama/config examples/models/llama3_2_vision/preprocess examples/models/llama3_2_vision/vision_encoder/test examples/models/llama3_2_vision/text_decoder/test From 08f8e51e53386ccaee7a25b248c4df85c9f17fe8 Mon Sep 17 00:00:00 2001 From: Max Ren Date: Thu, 26 Jun 2025 10:30:35 -0700 Subject: [PATCH 6/6] Update [ghstack-poisoned] --- kernels/portable/cpu/util/dtype_util.cpp | 2 -- kernels/portable/cpu/util/dtype_util.h | 30 +------------------ .../UnaryUfuncRealHBBF16ToFloatHBF16Test.h | 20 ++++++++----- kernels/test/op_mul_test.cpp | 15 ---------- 4 files changed, 13 insertions(+), 54 deletions(-) diff --git a/kernels/portable/cpu/util/dtype_util.cpp b/kernels/portable/cpu/util/dtype_util.cpp index 525199a6f78..d240b9f83bc 100644 --- a/kernels/portable/cpu/util/dtype_util.cpp +++ b/kernels/portable/cpu/util/dtype_util.cpp @@ -27,8 +27,6 @@ bool check_tensor_dtype( return executorch::runtime::tensor_is_floating_type(t); case SupportedTensorDtypes::INTB: return executorch::runtime::tensor_is_integral_type(t, true); - case SupportedTensorDtypes::BOOL: - return executorch::runtime::tensor_is_type(t, ScalarType::Bool); case SupportedTensorDtypes::BOOL_OR_BYTE: return (executorch::runtime::tensor_is_type( t, ScalarType::Bool, ScalarType::Byte)); diff --git a/kernels/portable/cpu/util/dtype_util.h b/kernels/portable/cpu/util/dtype_util.h index 15732219c8f..1e7901c80b2 100644 --- a/kernels/portable/cpu/util/dtype_util.h +++ b/kernels/portable/cpu/util/dtype_util.h @@ -72,16 +72,6 @@ load_to_compute_fn get_load_to_compute_fn_intb(const Tensor& t) { return result; } -template -load_to_compute_fn get_load_to_compute_fn_bool(const Tensor& t) { - ET_CHECK_MSG( - t.scalar_type() == ScalarType::Bool, - "Unhandled dtype %s for %s", - ::executorch::runtime::toString(t.scalar_type()), - op_name); - return internal::load_and_convert; -} - template load_to_compute_fn get_load_to_compute_fn_bool_or_byte( const Tensor& t) { @@ -175,17 +165,6 @@ store_compute_to_tensor_fn get_store_compute_to_tensor_fn_intb( return result; } -template -store_compute_to_tensor_fn get_store_compute_to_tensor_fn_bool( - const Tensor& t) { - ET_CHECK_MSG( - t.scalar_type() == ScalarType::Bool, - "Unhandled dtype %s for %s", - ::executorch::runtime::toString(t.scalar_type()), - op_name); - return internal::convert_and_store; -} - template store_compute_to_tensor_fn get_store_compute_to_tensor_fn_bool_or_byte(const Tensor& t) { @@ -240,7 +219,6 @@ enum class SupportedTensorDtypes { REALHBF16, FLOATHBF16, INTB, - BOOL, BOOL_OR_BYTE, // DEPRECATED: not likely to be correct; use SAME_AS_COMMON. SAME_AS_COMPUTE, @@ -262,8 +240,6 @@ load_to_compute_fn get_load_to_compute_fn_impl( return get_load_to_compute_fn_realhbf16(t); case SupportedTensorDtypes::INTB: return get_load_to_compute_fn_intb(t); - case SupportedTensorDtypes::BOOL: - return get_load_to_compute_fn_bool(t); case SupportedTensorDtypes::BOOL_OR_BYTE: return get_load_to_compute_fn_bool_or_byte(t); case SupportedTensorDtypes::SAME_AS_COMPUTE: @@ -295,8 +271,6 @@ store_compute_to_tensor_fn get_store_compute_to_tensor_fn( t); case SupportedTensorDtypes::INTB: return get_store_compute_to_tensor_fn_intb(t); - case SupportedTensorDtypes::BOOL: - return get_store_compute_to_tensor_fn_bool(t); case SupportedTensorDtypes::BOOL_OR_BYTE: return get_store_compute_to_tensor_fn_bool_or_byte< CTYPE_COMPUTE, @@ -344,14 +318,12 @@ bool check_tensor_dtype( const ScalarType compute_type); /// Return the one output type we are willing to emit specialized code -/// to handle, given a compute type of CTYPE_COMPUTE and supported +/// to handle, given a compute type of CTYPE_COMMON and supported /// output types of out_dtypes. template inline constexpr ScalarType specialized_output_scalar_type( SupportedTensorDtypes out_dtypes) { switch (out_dtypes) { - case SupportedTensorDtypes::BOOL: - return ScalarType::Bool; case SupportedTensorDtypes::BOOL_OR_BYTE: return ScalarType::Bool; case SupportedTensorDtypes::REALHBBF16: diff --git a/kernels/test/UnaryUfuncRealHBBF16ToFloatHBF16Test.h b/kernels/test/UnaryUfuncRealHBBF16ToFloatHBF16Test.h index d1e812ec2c2..6e49dd9e57b 100644 --- a/kernels/test/UnaryUfuncRealHBBF16ToFloatHBF16Test.h +++ b/kernels/test/UnaryUfuncRealHBBF16ToFloatHBF16Test.h @@ -72,16 +72,20 @@ class UnaryUfuncRealHBBF16ToFloatHBF16Test : public OperatorTest { auto expected = tf_out.make({1, 6}, expected_vector); if (IN_DTYPE == ScalarType::BFloat16 || OUT_DTYPE == ScalarType::BFloat16) { - // Raise tolerance because both we and ATen run these - // computations at internal float32 precision rather than - // float64. - double rtol = 3e-3; + double rtol = executorch::runtime::testing::internal::kDefaultRtol; + // It appears we need a higher tolerance for at least some ATen + // tests, like aten_op_acosh_test. + if (get_supported_features()->is_aten) { + rtol = 3e-3; + } EXPECT_TENSOR_CLOSE_WITH_TOL(out, expected, rtol, executorch::runtime::testing::internal::kDefaultBFloat16Atol); } else if (IN_DTYPE == ScalarType::Half || OUT_DTYPE == ScalarType::Half) { - // Raise tolerance because both we and ATen run these - // computations at internal float32 precision rather than - // float64. - double rtol = 1e-3; + double rtol = executorch::runtime::testing::internal::kDefaultRtol; + // It appears we need a higher tolerance for at least some ATen + // tests, like aten_op_acosh_test. + if (get_supported_features()->is_aten) { + rtol = 1e-3; + } EXPECT_TENSOR_CLOSE_WITH_TOL(out, expected, rtol, executorch::runtime::testing::internal::kDefaultHalfAtol); } else { EXPECT_TENSOR_CLOSE(out, expected); diff --git a/kernels/test/op_mul_test.cpp b/kernels/test/op_mul_test.cpp index 28baa0cbd16..2d2f2872b99 100644 --- a/kernels/test/op_mul_test.cpp +++ b/kernels/test/op_mul_test.cpp @@ -746,21 +746,6 @@ TEST_F(OpMulOutTest, DynamicShapeUnbound) { EXPECT_TENSOR_CLOSE(out, expected_result); } -// >>> torch.ops.aten.mul(torch.tensor([100], dtype=torch.int8), -// torch.tensor([100], dtype=torch.int8), out=torch.zeros([1], -// dtype=torch.long)) tensor([16]) -TEST_F(OpMulOutTest, MixedIntegerDtypeMatchesATen) { - TensorFactory tf_in; - TensorFactory tf_out; - - Tensor in = tf_in.make({1}, {100}); - Tensor out = tf_out.zeros({1}); - Tensor ret = op_mul_out(in, in, out); - - Tensor expected = tf_out.make({1}, {16}); - EXPECT_TENSOR_CLOSE(out, expected); -} - TEST_F(OpMulScalarOutTest, SanityCheck) { TensorFactory tf_a; TensorFactory tf_out;