31 changes: 25 additions & 6 deletions examples/models/llama/export_llama_lib.py
@@ -137,6 +137,29 @@
}


def _get_additional_export_passes(
    model_class: str,
) -> List[InitializedMutableBufferPass]:
    patterns = []

    if model_class in TORCHTUNE_DEFINED_MODELS:
        patterns.append("kv_cache_pos")

    # Qwen3.5 uses internal mutable buffers for both the hybrid KV path and
    # DeltaNet recurrent/conv states.
    if model_class.startswith("qwen3_5"):
        patterns.extend(
            [
                "k_cache",
                "v_cache",
Comment on lines +150 to +154
Copilot AI (Mar 6, 2026):

Initializing KV cache buffers via InitializedMutableBufferPass will cause their full tensor contents to be serialized into the .pte (the emitter treats et_init_buffer+mutable_buffer as const). For k_cache/v_cache this can be extremely large (per-layer [B, H, S, D]) and may blow up export size and load time. Consider avoiding initializing the full KV caches at export (e.g., only init the small state buffers like conv_state/recurrent_state, or add a runtime/cache-reset path that deterministically zeros these buffers without serializing them).

Suggested change:
-    if model_class.startswith("qwen3_5"):
-        patterns.extend(
-            [
-                "k_cache",
-                "v_cache",
+    # Avoid initializing large KV cache buffers (k_cache/v_cache) here, since
+    # InitializedMutableBufferPass would serialize their full contents into
+    # the exported artifact, significantly increasing size and load time.
+    if model_class.startswith("qwen3_5"):
+        patterns.extend(
+            [
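As a rough illustration of the size concern, a back-of-envelope sketch (all dims below are hypothetical, not taken from the actual Qwen3.5-0.8B config):

# Hypothetical dims for illustration only; not the real Qwen3.5-0.8B config.
n_layers, batch, n_kv_heads, max_seq_len, head_dim = 28, 1, 8, 4096, 128
bytes_per_elem = 4  # fp32

# One cache tensor per layer, shape [B, H, S, D].
per_cache_bytes = batch * n_kv_heads * max_seq_len * head_dim * bytes_per_elem
total_bytes = n_layers * 2 * per_cache_bytes  # k_cache + v_cache per layer
print(f"{total_bytes / 2**20:.0f} MiB")  # 896 MiB under these assumed dims

(As the measurements later in this thread show, the actual increase for this model was only about 5 MB.)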

Contributor:

This is a good point. @Phineas1500, does Qwen3.5 require an initial state for the kv_cache, conv_state, and recurrent_state buffers?

The InitializedMutableBufferPass is only required for mutable buffers with initial state.
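A minimal sketch of that distinction (module and buffer shapes here are hypothetical):

import torch

class SketchAttention(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Zero-initialized cache: if the runtime deterministically resets it
        # to zeros anyway, serializing its initial contents buys nothing.
        self.register_buffer("k_cache", torch.zeros(1, 8, 128, 64))
        # A position counter must start at exactly 0 for decoding to be
        # correct, so its initial state genuinely needs to be preserved.
        self.register_buffer("kv_cache_pos", torch.zeros(1, dtype=torch.long))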

Contributor:

Seems like a ~5 MB size increase from including the initial state. Not too sure why; I was expecting a bit more.

-rw-r--r-- 1 lfq users 4032780800 Mar  6 14:29 qwen3_5_0_8b_fp32_no_init.pte
-rw-r--r-- 1 lfq users 4038122240 Mar  5 11:05 qwen3_5_0_8b_fp32.pte
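For reference, the exact delta between those two files: 4038122240 - 4032780800 = 5341440 bytes ≈ 5.09 MiB.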

Output is the same with temp=0

(executorch) [lfq@devvm311.ldc0 /data/users/lfq/executorch (qwen3_5_phase2)]$ python -m executorch.examples.models.llama.runner.native \
    --model qwen3_5_0_8b \
    --pte qwen3_5_0_8b_fp32_no_init.pte \
    --tokenizer ~/.cache/huggingface/hub/models--Qwen--Qwen3.5-0.8B/snapshots/2fc06364715b967f1860aea9cf38778875588b17/tokenizer.json \
    --tokenizer_config ~/.cache/huggingface/hub/models--Qwen--Qwen3.5-0.8B/snapshots/2fc06364715b967f1860aea9cf38778875588b17/tokenizer_config.json \
    --params examples/models/qwen3_5/config/0_8b_config.json \
    --prompt "<|im_start|>user\nHello, what's 15% of 80?<|im_end|>\n<|im_start|>assistant\n" \
    --max_len 128 -kv --temperature 0
I tokenizers:regex.cpp:27] Registering override fallback regex
Warning - given vocab_size in params is unequal to tokenizer vocab size.
[cpuinfo_utils.cpp:71] Reading file /sys/devices/soc0/image_version
[cpuinfo_utils.cpp:87] Failed to open midr file /sys/devices/soc0/image_version
[cpuinfo_utils.cpp:100] Reading file /sys/devices/system/cpu/cpu0/regs/identification/midr_el1
[cpuinfo_utils.cpp:109] Failed to open midr file /sys/devices/system/cpu/cpu0/regs/identification/midr_el1
[cpuinfo_utils.cpp:125] CPU info and manual query on # of cpus dont match.
<think>

</think>

To find 15% of 80, you can multiply 80 by 0.15:

$$80 \times 0.15 = 12$$

So, **15% of 80 is 12**.

Prefill time: 15.471495151519775
Generation tok/s: 2.097784149345492
Response: [248068, 271, 248069, 271, 1206, 1423, 220, 16, 20, 4, 314, 220, 23, 15, 11, 488, 628, 29283, 220, 23, 15, 539, 220, 15, 13, 16, 20, 25, 271, 13682, 23, 15, 1088, 14695, 220, 15, 13, 16, 20, 283, 220, 16, 17, 13682, 271, 4272, 11, 2972, 16, 20, 4, 314, 220, 23, 15, 369, 220, 16, 17, 159034, 248046]

Seems like the state is already zeroed here?
https://github.com/pytorch/executorch/blob/main/examples/models/llama/attention.py#L720

Comment on lines +153 to +154
Copilot AI (Mar 6, 2026):

InitializedMutableBufferPass matches patterns by substring. Using "k_cache"/"v_cache" here will also match other buffer names like "k_cache_scales", "k_cache_zero_points", or "past_k_caches_*" if present in the exported graph, potentially initializing/serializing more (large) buffers than intended. If you only mean the primary caches, consider narrowing the patterns to something less collision-prone (e.g., include a delimiter or full buffer name) or splitting by known FQN fragments.

Suggested change:
-                "k_cache",
-                "v_cache",
+                ".k_cache",
+                ".v_cache",

"conv_state",
"recurrent_state",
]
)

return [InitializedMutableBufferPass(patterns)] if patterns else []


def set_pkg_name(name: str) -> None:
    global pkg_name
    pkg_name = name
@@ -1285,9 +1308,7 @@ def _export_llama_multimethod(llm_config: LlmConfig) -> LLMEdgeManager:
"Each method requires separate model instantiation and export."
)

-    additional_passes = []
-    if llm_config.base.model_class.value in TORCHTUNE_DEFINED_MODELS:
-        additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])]
+    additional_passes = _get_additional_export_passes(llm_config.base.model_class.value)

# Build dict of exported programs
method_to_program: Dict[str, ExportedProgram] = {}
@@ -1358,9 +1379,7 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager:  # noqa: C901
llm_config
)

-    additional_passes = []
-    if llm_config.base.model_class.value in TORCHTUNE_DEFINED_MODELS:
-        additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])]
+    additional_passes = _get_additional_export_passes(llm_config.base.model_class.value)

# export_to_edge
builder_manager = _prepare_for_llama_export(llm_config)
20 changes: 20 additions & 0 deletions examples/models/llama/tests/test_export_llama_lib.py
@@ -25,9 +25,11 @@

from executorch.examples.models.llama.export_llama_lib import (
    _export_llama,
    _get_additional_export_passes,
    build_args_parser,
    get_quantizer_and_quant_params,
)
from executorch.exir.passes.init_mutable_pass import InitializedMutableBufferPass
from executorch.extension.llm.export.config.llm_config import LlmConfig, Pt2eQuantize

UNWANTED_OPS = [
@@ -37,6 +39,24 @@


class ExportLlamaLibTest(unittest.TestCase):
    def test_qwen3_5_mutable_buffer_passes(self):
        passes = _get_additional_export_passes("qwen3_5_0_8b")
        self.assertEqual(len(passes), 1)
        self.assertIsInstance(passes[0], InitializedMutableBufferPass)
        self.assertEqual(
            passes[0].patterns,
            ["k_cache", "v_cache", "conv_state", "recurrent_state"],
        )

    def test_torchtune_mutable_buffer_passes(self):
        passes = _get_additional_export_passes("llama3_2_vision")
        self.assertEqual(len(passes), 1)
        self.assertIsInstance(passes[0], InitializedMutableBufferPass)
        self.assertEqual(passes[0].patterns, ["kv_cache_pos"])

    def test_llama3_has_no_extra_mutable_buffer_passes(self):
        self.assertEqual(_get_additional_export_passes("llama3"), [])

    def test_has_expected_ops_and_op_counts(self):
        """
        Checks the presence of unwanted expensive ops.