diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index 0394bf7f320..d497349222c 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -137,6 +137,29 @@ } +def _get_additional_export_passes( + model_class: str, +) -> List[InitializedMutableBufferPass]: + patterns = [] + + if model_class in TORCHTUNE_DEFINED_MODELS: + patterns.append("kv_cache_pos") + + # Qwen3.5 uses internal mutable buffers for both the hybrid KV path and + # DeltaNet recurrent/conv states. + if model_class.startswith("qwen3_5"): + patterns.extend( + [ + "k_cache", + "v_cache", + "conv_state", + "recurrent_state", + ] + ) + + return [InitializedMutableBufferPass(patterns)] if patterns else [] + + def set_pkg_name(name: str) -> None: global pkg_name pkg_name = name @@ -1285,9 +1308,7 @@ def _export_llama_multimethod(llm_config: LlmConfig) -> LLMEdgeManager: "Each method requires separate model instantiation and export." ) - additional_passes = [] - if llm_config.base.model_class.value in TORCHTUNE_DEFINED_MODELS: - additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])] + additional_passes = _get_additional_export_passes(llm_config.base.model_class.value) # Build dict of exported programs method_to_program: Dict[str, ExportedProgram] = {} @@ -1358,9 +1379,7 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager: # noqa: C901 llm_config ) - additional_passes = [] - if llm_config.base.model_class.value in TORCHTUNE_DEFINED_MODELS: - additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])] + additional_passes = _get_additional_export_passes(llm_config.base.model_class.value) # export_to_edge builder_manager = _prepare_for_llama_export(llm_config) diff --git a/examples/models/llama/tests/test_export_llama_lib.py b/examples/models/llama/tests/test_export_llama_lib.py index 130a55f658c..d581d31fda6 100644 --- a/examples/models/llama/tests/test_export_llama_lib.py +++ b/examples/models/llama/tests/test_export_llama_lib.py @@ -25,9 +25,11 @@ from executorch.examples.models.llama.export_llama_lib import ( _export_llama, + _get_additional_export_passes, build_args_parser, get_quantizer_and_quant_params, ) +from executorch.exir.passes.init_mutable_pass import InitializedMutableBufferPass from executorch.extension.llm.export.config.llm_config import LlmConfig, Pt2eQuantize UNWANTED_OPS = [ @@ -37,6 +39,24 @@ class ExportLlamaLibTest(unittest.TestCase): + def test_qwen3_5_mutable_buffer_passes(self): + passes = _get_additional_export_passes("qwen3_5_0_8b") + self.assertEqual(len(passes), 1) + self.assertIsInstance(passes[0], InitializedMutableBufferPass) + self.assertEqual( + passes[0].patterns, + ["k_cache", "v_cache", "conv_state", "recurrent_state"], + ) + + def test_torchtune_mutable_buffer_passes(self): + passes = _get_additional_export_passes("llama3_2_vision") + self.assertEqual(len(passes), 1) + self.assertIsInstance(passes[0], InitializedMutableBufferPass) + self.assertEqual(passes[0].patterns, ["kv_cache_pos"]) + + def test_llama3_has_no_extra_mutable_buffer_passes(self): + self.assertEqual(_get_additional_export_passes("llama3"), []) + def test_has_expected_ops_and_op_counts(self): """ Checks the presence of unwanted expensive ops.