From 555714b732f1e4f6e246befa8bdcf1e9e08903f1 Mon Sep 17 00:00:00 2001 From: winskuo-quic Date: Tue, 9 Sep 2025 13:05:22 +0800 Subject: [PATCH 1/2] Qualcomm AI Engine Direct - GLM1.5B --- backends/qualcomm/tests/test_qnn_delegate.py | 3 + examples/models/glm/__init__.py | 16 ++++ examples/models/glm/config/1_5b_config.json | 17 ++++ examples/models/glm/convert_weights.py | 79 +++++++++++++++++++ examples/models/llama/model_args.py | 1 + examples/qualcomm/oss_scripts/llama/README.md | 22 ++++-- .../qualcomm/oss_scripts/llama/__init__.py | 23 ++++++ .../oss_scripts/llama/decoder_constants.py | 1 + examples/qualcomm/oss_scripts/llama/llama.py | 12 ++- .../oss_scripts/llama/model/feed_forward.py | 51 ++++++++++++ .../oss_scripts/llama/model/static_llama.py | 1 - .../oss_scripts/llama/qnn_llama_runner.cpp | 9 +++ .../oss_scripts/llama/runner/runner.cpp | 4 + .../oss_scripts/llama/runner/runner.h | 1 + .../llama/static_llm_quant_recipe.py | 35 ++++++++ 15 files changed, 266 insertions(+), 9 deletions(-) create mode 100644 examples/models/glm/__init__.py create mode 100644 examples/models/glm/config/1_5b_config.json create mode 100644 examples/models/glm/convert_weights.py diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index a9403f98b17..c878edd53c9 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -5862,6 +5862,9 @@ def setUp(self): "gemma3-1b": TestExampleLLMScript.LlmSpecs( SM8650=70, SM8750=100, ppl=23, pte_size=1_200_000_000 ), # 1.2 GB + "glm-1_5b": TestExampleLLMScript.LlmSpecs( + SM8650=42, SM8750=52, ppl=21, pte_size=1_100_000_000 + ), # 1.1 GB "phi_4_mini": TestExampleLLMScript.LlmSpecs( SM8650=14, SM8750=19, ppl=12, pte_size=4_000_000_000 ), # 4GB diff --git a/examples/models/glm/__init__.py b/examples/models/glm/__init__.py new file mode 100644 index 00000000000..aef380e7f6b --- /dev/null +++ b/examples/models/glm/__init__.py @@ -0,0 +1,16 @@ 
+# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from executorch.examples.models.glm.convert_weights import convert_weights +from executorch.examples.models.llama.model import Llama2Model + + +class GLMModel(Llama2Model): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + +__all__ = [ + "GLMModel", + "convert_weights", +] diff --git a/examples/models/glm/config/1_5b_config.json b/examples/models/glm/config/1_5b_config.json new file mode 100644 index 00000000000..23576622255 --- /dev/null +++ b/examples/models/glm/config/1_5b_config.json @@ -0,0 +1,17 @@ +{ + "dim": 2048, + "ffn_dim_multiplier": 1, + "hidden_dim": 6144, + "n_heads": 16, + "head_dim": 128, + "n_kv_heads": 4, + "n_layers": 28, + "norm_eps": 1e-05, + "rope_theta": 10000.0, + "use_scaled_rope": false, + "vocab_size": 59264, + "use_hf_rope": true, + "attention_qkv_bias": false, + "use_qk_norm": false, + "model_architecture" : "GlmForCausalLM" +} diff --git a/examples/models/glm/convert_weights.py b/examples/models/glm/convert_weights.py new file mode 100644 index 00000000000..0568c9dccec --- /dev/null +++ b/examples/models/glm/convert_weights.py @@ -0,0 +1,79 @@ +import argparse +import os +from typing import Dict + +import torch +from safetensors.torch import load_file +from torchtune.models.convert_weights import get_mapped_key + +# Standard _FROM_META weight mapping of Meta weights to TorchTune + additional bias weight mappings. 
+_GLM_FROM_META = { + "tok_embeddings.weight": "model.embed_tokens.weight", + "norm.weight": "model.norm.weight", + "output.weight": "lm_head.weight", + "layers.{}.attention.wk.weight": "model.layers.{}.self_attn.k_proj.weight", + "layers.{}.attention.wq.weight": "model.layers.{}.self_attn.q_proj.weight", + "layers.{}.attention.wv.weight": "model.layers.{}.self_attn.v_proj.weight", + "layers.{}.attention.wo.weight": "model.layers.{}.self_attn.o_proj.weight", + "layers.{}.attention_norm.weight": "model.layers.{}.input_layernorm.weight", + "layers.{}.ffn_norm.weight": "model.layers.{}.post_attention_layernorm.weight", + "layers.{}.feed_forward.gate_up_proj.weight": "model.layers.{}.mlp.gate_up_proj.weight", + "layers.{}.feed_forward.down_proj.weight": "model.layers.{}.mlp.down_proj.weight", +} + + +def glm_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """ + Convert a state dict from torchtune's format to Meta's format. This function + doesn't handle any sharding or splitting of state dicts. It follows the + state_dict IN -> state_dict OUT pattern. + + Args: + state_dict (Dict[str, torch.Tensor]): State dict in torchtune's format. + + Returns: + Dict[str, torch.Tensor]: State dict in Meta's format. 
+ """ + converted_state_dict = {} + inverted_mapping_dict = {v: k for k, v in _GLM_FROM_META.items()} + + for key, value in state_dict.items(): + new_key = get_mapped_key(key, inverted_mapping_dict) + converted_state_dict[new_key] = value + + if "lm_head.weight" not in state_dict: + converted_state_dict["output.weight"] = converted_state_dict[ + "tok_embeddings.weight" + ] + + return converted_state_dict + + +def convert_weights(input_dir: str, output_file: str) -> None: + pt_path = os.path.join(input_dir, "model.safetensors") + print("Loading checkpoint from file...") + sd = load_file(pt_path) + + print("Converting checkpoint...") + sd = glm_tune_to_meta(sd) + + print("Saving checkpoint...") + torch.save(sd, output_file) + print("Done.") + + +def main(): + parser = argparse.ArgumentParser(description="Convert GLM weights to Meta format.") + parser.add_argument( + "input_dir", + type=str, + help="Path to directory containing checkpoint files", + ) + parser.add_argument("output", type=str, help="Path to the output checkpoint") + + args = parser.parse_args() + convert_weights(args.input_dir, args.output) + + +if __name__ == "__main__": + main() diff --git a/examples/models/llama/model_args.py b/examples/models/llama/model_args.py index 9f45960c7a9..71d4d161958 100644 --- a/examples/models/llama/model_args.py +++ b/examples/models/llama/model_args.py @@ -131,6 +131,7 @@ class ModelArgs: attention_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) # Hybrid models can have layer types different from attention layer_types: Optional[list] = None + model_architecture: Optional[str] = None def __post_init__(self): if self.n_kv_heads is None: diff --git a/examples/qualcomm/oss_scripts/llama/README.md b/examples/qualcomm/oss_scripts/llama/README.md index 9d97775265f..7a08cbfd881 100644 --- a/examples/qualcomm/oss_scripts/llama/README.md +++ b/examples/qualcomm/oss_scripts/llama/README.md @@ -9,6 +9,7 @@ This file provides you the instructions to run LLM Decoder 
model with different 1. Codegen2 1B 1. Gemma 2B 1. Gemma3 1B + 1. GLM 1.5B 1. Granite3.3 2B 1. Phi4-mini-instruct 1. QWEN2.5 0.5B / 1.5B @@ -65,7 +66,10 @@ Follow the [instructions](https://www.llama.com/) to download models. At the end of this step, users should have the following files ready: `consolidated.00.pth`, `params.json`, and `tokenizer.model`. -### Step3: Run default examples using hybrid mode for smaller models and kv mode for larger models. +### Step3: Run default examples. +#### Note: +All example scripts below use hybrid mode, which is optimized for on-device performance. However, compiling a model in hybrid mode can consume a significant amount of memory on the host machine—sometimes up to ~100 GB. If your host machine has limited memory, it is highly recommended to switch from `--model_mode hybrid` to `--model_mode kv` and remove the `--prefill_ar_len` flag. + #### LLAMA2 ```bash python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --checkpoint stories110M.pt --params params.json --tokenizer_model tokenizer.model --tokenizer_bin tokenizer.bin --decoder_model stories110m --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --prompt "Once upon a time" @@ -80,7 +84,7 @@ python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL #### LLAMA3.2 3B Instruct Default example using kv mode. ```bash -python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --decoder_model llama3_2-3b_instruct --model_mode kv --max_seq_len 1024 --prompt "I would like to learn python, could you teach me with a simple example?" 
--tasks wikitext --limit 1 +python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --decoder_model llama3_2-3b_instruct --model_mode hybrid --prefill_ar_len 128 --max_seq_len 1024 --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1 ``` #### Codegen2 @@ -102,6 +106,12 @@ Default example using hybrid mode python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --temperature 0 --model_mode hybrid --max_seq_len 1024 --prefill_ar_len 128 --decoder_model gemma3-1b --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1 ``` +#### GLM 1.5B +Default example using hybrid mode +```bash +python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --temperature 0 --model_mode hybrid --max_seq_len 1024 --prefill_ar_len 128 --decoder_model glm-1_5b --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1 +``` + #### Granite3.3 2B Default example using hybrid mode ```bash @@ -111,7 +121,7 @@ python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL #### Phi4-mini-instruct Default example using kv mode. ```bash -python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --decoder_model phi_4_mini --model_mode kv --max_seq_len 1024 --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1 +python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --decoder_model phi_4_mini --model_mode hybrid --prefill_ar_len 128 --max_seq_len 1024 --prompt "I would like to learn python, could you teach me with a simple example?" 
--tasks wikitext --limit 1 ``` #### QWEN2.5 0.5B @@ -123,7 +133,7 @@ python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL #### QWEN2.5 1.5B Default example using kv mode ```bash -python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --temperature 0 --model_mode kv --max_seq_len 1024 --decoder_model qwen2_5-1_5b --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1 +python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --temperature 0 --model_mode hybrid --prefill_ar_len 128 --max_seq_len 1024 --decoder_model qwen2_5-1_5b --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1 ``` #### QWEN3 0.6B @@ -135,7 +145,7 @@ python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL #### QWEN3 1.7B Default example using hybrid mode ```bash -python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --temperature 0 --model_mode kv --max_seq_len 1024 --decoder_model qwen3-1_7b --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1 +python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --temperature 0 --model_mode hybrid --prefill_ar_len 128 --max_seq_len 1024 --decoder_model qwen3-1_7b --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1 ``` #### SmolLM2 @@ -147,7 +157,7 @@ python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL #### SmolLM3 Default example using kv mode. ```bash -python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --decoder_model smollm3-3b --model_mode kv --max_seq_len 1024 --prompt "I would like to learn python, could you teach me with a simple example?" 
--tasks wikitext --limit 1 +python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --decoder_model smollm3-3b --model_mode hybrid --prefill_ar_len 128 --max_seq_len 1024 --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1 ``` ### KV Cache update mechanism diff --git a/examples/qualcomm/oss_scripts/llama/__init__.py b/examples/qualcomm/oss_scripts/llama/__init__.py index fb6a5a3a3b0..4e7c4b9be46 100644 --- a/examples/qualcomm/oss_scripts/llama/__init__.py +++ b/examples/qualcomm/oss_scripts/llama/__init__.py @@ -16,6 +16,8 @@ ) from executorch.examples.models.gemma import convert_weights as convert_gemma_weights from executorch.examples.models.gemma3 import convert_weights as convert_gemma3_weights + +from executorch.examples.models.glm import convert_weights as convert_glm_weights from executorch.examples.models.granite import ( convert_weights as convert_granite_weights, ) @@ -44,6 +46,7 @@ CodegenQuantRecipe, Gemma3QuantRecipe, Gemma_2BQuantRecipe, + GLM_1_5B_InstructQuantRecipe, Granite_3_3_2B_InstructQuantRecipe, Llama3_1BQuantRecipe, Llama3_3BQuantRecipe, @@ -293,6 +296,26 @@ class Gemma3(LLMModelConfig): quant_recipe = Gemma3QuantRecipe +@register_llm_model("glm-1_5b") +@dataclass(init=False, frozen=True) +class GLM_1_5B(LLMModelConfig): + repo_id: str = "THUDM/glm-edge-1.5b-chat" + params_path: str = os.path.join( + BASE_DIR, "../../../models/glm/config/1_5b_config.json" + ) + convert_weights = convert_glm_weights + transform_weight = True + instruct_model = True + num_sharding = 1 + group_size = 32 + masked_softmax = False + seq_mse_candidates = 0 + r1 = False + r2 = False + r3 = False + quant_recipe = GLM_1_5B_InstructQuantRecipe + + @register_llm_model("granite_3_3-2b_instruct") @dataclass(init=False, frozen=True) class Granite_3_3_2b_Instruct(LLMModelConfig): diff --git a/examples/qualcomm/oss_scripts/llama/decoder_constants.py 
b/examples/qualcomm/oss_scripts/llama/decoder_constants.py index a228266d106..f6f0dc3067f 100644 --- a/examples/qualcomm/oss_scripts/llama/decoder_constants.py +++ b/examples/qualcomm/oss_scripts/llama/decoder_constants.py @@ -27,4 +27,5 @@ "smollm2_135m": "smollm2_135m", "smollm3-3b": "smollm3", "codegen2_1b": "codegen", + "glm-1_5b": "glm", } diff --git a/examples/qualcomm/oss_scripts/llama/llama.py b/examples/qualcomm/oss_scripts/llama/llama.py index 29212c7855b..0847f93d98f 100755 --- a/examples/qualcomm/oss_scripts/llama/llama.py +++ b/examples/qualcomm/oss_scripts/llama/llama.py @@ -1309,15 +1309,23 @@ def export_llama(args) -> None: # For Gemma, use tokenizer.model as it doesn't provide pre_tokenizer in tokenizer.json. runtime_tokenizer_path = tokenizer_artifacts[-3] else: + if args.decoder_model == "glm-1_5b": + with open(tokenizer_config, "r+") as file: + data = json.load(file) + # Verified with HF flow and it uses <|user|> as eos condition + data["bos_token"] = "<|user|>" + data["eos_token"] = "<|user|>" + file.seek(0) + json.dump(data, file, indent=4) + file.truncate() runtime_tokenizer_path = tokenizer_artifacts[-1] + tokenizer = get_tokenizer(runtime_tokenizer_path, tokenizer_config) if args.decoder_model == "codegen2_1b": # Override the default BOS and EOS token IDs for codegen2_1b tokenizer.bos_id = 1 tokenizer.eos_id = 2 - - # TODO: Remove this once error is resolved. 
elif args.decoder_model == "phi_4_mini": with open(runtime_tokenizer_path, "r+") as file: data = json.load(file) diff --git a/examples/qualcomm/oss_scripts/llama/model/feed_forward.py b/examples/qualcomm/oss_scripts/llama/model/feed_forward.py index 062123b52cc..2f36779cc71 100644 --- a/examples/qualcomm/oss_scripts/llama/model/feed_forward.py +++ b/examples/qualcomm/oss_scripts/llama/model/feed_forward.py @@ -88,3 +88,54 @@ def forward(self, x): hidden_states = self.act(hidden_states) hidden_states = self.fc_out(hidden_states) return hidden_states + + +@register_feed_forward("GlmForCausalLM") +class GLMFeedForward(FeedForwardBase): + """FeedForward with gate_up_proj and down_proj""" + + def __init__(self, args: ModelArgs): # in MLP: intermediate_size= 4 * embed_dim + super().__init__() + + assert args.hidden_dim is not None + self.dim = args.dim + self.hidden_dim = args.hidden_dim + + self.gate_up_proj = torch.nn.Linear(args.dim, 2 * args.hidden_dim, bias=False) + self.down_proj = torch.nn.Linear(args.hidden_dim, args.dim, bias=False) + self.activation_fn = args.act_fn.get_function() + + def prepare_feedfoward_conv(self): + self.gate_up_proj_conv = torch.nn.Conv2d( + self.dim, 2 * self.hidden_dim, 1, bias=False + ) + self.down_proj_conv = torch.nn.Conv2d(self.hidden_dim, self.dim, 1, bias=False) + + self.forward_no_conv = self.forward + self.forward = self.forward_feedfoward_conv + + self.gate_up_proj_conv.weight.data.copy_( + self.gate_up_proj.weight[:, :, None, None] + ) + self.down_proj_conv.weight.data.copy_(self.down_proj.weight[:, :, None, None]) + + del self.gate_up_proj + del self.down_proj + + def forward_feedfoward_conv(self, x): + bsz, _, _ = x.size() + x = torch.reshape(x, (bsz, -1, 1, self.dim)) + x = x.transpose(1, 3) # Transpose right before and after Conv + up_states = self.gate_up_proj_conv(x) + gate, up_states = up_states.chunk(2, dim=1) + up_states = up_states * self.activation_fn(gate) + x = self.down_proj_conv(up_states) + x = x.transpose(1, 
3) + x = torch.reshape(x, (bsz, -1, self.dim)) + return x + + def forward(self, x): + up_states = self.gate_up_proj(x) + gate, up_states = up_states.chunk(2, dim=-1) + up_states = up_states * self.activation_fn(gate) + return self.down_proj(up_states) diff --git a/examples/qualcomm/oss_scripts/llama/model/static_llama.py b/examples/qualcomm/oss_scripts/llama/model/static_llama.py index ce5691d0c34..65cf71e0480 100755 --- a/examples/qualcomm/oss_scripts/llama/model/static_llama.py +++ b/examples/qualcomm/oss_scripts/llama/model/static_llama.py @@ -353,7 +353,6 @@ def prepare_feedfoward_conv(self): self.forward_no_conv = self.forward self.forward = self.forward_feedfoward_conv - self.w1_conv.weight.data.copy_(self.w1.weight[:, :, None, None]) self.w2_conv.weight.data.copy_(self.w2.weight[:, :, None, None]) self.w3_conv.weight.data.copy_(self.w3.weight[:, :, None, None]) diff --git a/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp b/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp index 7d0172a70a5..af260242316 100644 --- a/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp +++ b/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp @@ -183,6 +183,15 @@ std::string get_formatted_prompt( formatted_prompt.append("<|im_end|>\n"); formatted_prompt.append("<|im_start|>assistant\n"); break; + case example::DecoderModelVersion::kGlm: + formatted_prompt.append("<|user|>\n"); + formatted_prompt.append(prompt); + if (!system_prompt.empty()) { + formatted_prompt.append("<|system|>\n"); + formatted_prompt.append(system_prompt); + } + formatted_prompt.append("<|assistant|>\n"); + break; default: ET_CHECK_MSG(false, "unsupported llama version"); break; diff --git a/examples/qualcomm/oss_scripts/llama/runner/runner.cpp b/examples/qualcomm/oss_scripts/llama/runner/runner.cpp index d2fdcf4281b..e021d5d512f 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/runner.cpp +++ b/examples/qualcomm/oss_scripts/llama/runner/runner.cpp @@ -141,6 +141,8 @@ 
Runner::Runner( decoder_model_version_ = DecoderModelVersion::kSmollm3; } else if (decoder_model_version == "codegen") { decoder_model_version_ = DecoderModelVersion::kCodegen; + } else if (decoder_model_version == "glm") { + decoder_model_version_ = DecoderModelVersion::kGlm; } else { ET_CHECK_MSG(false, "Unsupported Decoder Model"); } @@ -211,6 +213,8 @@ Error Runner::load() { eos_ids->insert(tokenizer_->encode("", 0, 0).get()[0]); } else if (decoder_model_version_ == DecoderModelVersion::kCodegen) { eos_ids->insert(tokenizer_->encode("<|endoftext|>", 0, 0).get()[0]); + } else if (decoder_model_version_ == DecoderModelVersion::kGlm) { + eos_ids->insert(tokenizer_->encode("<|user|>", 0, 0).get()[0]); } // Try avoid getMetadataHelper as it is time consuming. diff --git a/examples/qualcomm/oss_scripts/llama/runner/runner.h b/examples/qualcomm/oss_scripts/llama/runner/runner.h index 770aa3eb3a3..c436d40f20c 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/runner.h +++ b/examples/qualcomm/oss_scripts/llama/runner/runner.h @@ -41,6 +41,7 @@ enum DecoderModelVersion { kSmollm2_135m, kSmollm3, kCodegen, + kGlm, }; enum KvBitWidth { diff --git a/examples/qualcomm/oss_scripts/llama/static_llm_quant_recipe.py b/examples/qualcomm/oss_scripts/llama/static_llm_quant_recipe.py index fc2827cd895..1736a44e642 100644 --- a/examples/qualcomm/oss_scripts/llama/static_llm_quant_recipe.py +++ b/examples/qualcomm/oss_scripts/llama/static_llm_quant_recipe.py @@ -307,6 +307,41 @@ def __init__(self, verbose: bool = False): self.recipe.custom_quant_annotations.append(annotate_kv_8bit) +class GLM_1_5B_InstructQuantRecipe(StaticLLMQuantRecipe): + default_quant_dtype = QuantDtype.use_16a4w + + def __init__(self, verbose: bool = False): + super().__init__() + + self.recipe = ( + QuantRecipe( + self.default_quant_dtype, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_TENSOR, + verbose=verbose, + ) + .add_node_target( + { + torch.ops.aten.conv2d.default, + }, + 
QuantDtype.use_16a4w_block, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_BLOCK, + extra_kwargs={"block_size": (1, 32, 1, 1)}, + ) + .add_regex( + {r"output\.conv"}, + QuantDtype.use_16a8w, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_CHANNEL, + ) + ) + self.recipe.custom_quant_annotations.append(annotate_kv_8bit) + + class Granite_3_3_2B_InstructQuantRecipe(StaticLLMQuantRecipe): default_quant_dtype = QuantDtype.use_16a4w From 0ba5eb0bb18dfe08ab94f561c0ea415668940191 Mon Sep 17 00:00:00 2001 From: winskuo-quic Date: Mon, 24 Nov 2025 11:14:05 +0800 Subject: [PATCH 2/2] Code Review --- examples/models/llama/model_args.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/models/llama/model_args.py b/examples/models/llama/model_args.py index 71d4d161958..a0e9eb70498 100644 --- a/examples/models/llama/model_args.py +++ b/examples/models/llama/model_args.py @@ -131,7 +131,9 @@ class ModelArgs: attention_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) # Hybrid models can have layer types different from attention layer_types: Optional[list] = None - model_architecture: Optional[str] = None + model_architecture: Optional[str] = ( + None # Architecture of model. For HF models, please refer to the HF model.config.architectures. This is used in QNN backend only for now. + ) def __post_init__(self): if self.n_kv_heads is None: