Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions backends/qualcomm/tests/test_qnn_delegate.py
Original file line number Diff line number Diff line change
Expand Up @@ -5862,6 +5862,9 @@ def setUp(self):
"gemma3-1b": TestExampleLLMScript.LlmSpecs(
SM8650=70, SM8750=100, ppl=23, pte_size=1_200_000_000
), # 1.2 GB
"glm-1_5b": TestExampleLLMScript.LlmSpecs(
SM8650=42, SM8750=52, ppl=21, pte_size=1_100_000_000
), # 1.1 GB
"phi_4_mini": TestExampleLLMScript.LlmSpecs(
SM8650=14, SM8750=19, ppl=12, pte_size=4_000_000_000
), # 4GB
Expand Down
16 changes: 16 additions & 0 deletions examples/models/glm/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from executorch.examples.models.glm.convert_weights import convert_weights
from executorch.examples.models.llama.model import Llama2Model


class GLMModel(Llama2Model):
    """GLM decoder model wrapper around the Llama2 export path.

    The class intentionally adds no behavior of its own: weight conversion and
    GLM-specific architecture handling are done elsewhere (see
    ``convert_weights`` and the model config). The inherited
    ``Llama2Model.__init__`` already accepts ``**kwargs``, so no override is
    needed here.
    """


__all__ = [
"GLMModel",
"convert_weights",
]
17 changes: 17 additions & 0 deletions examples/models/glm/config/1_5b_config.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
{
"dim": 2048,
"ffn_dim_multiplier": 1,
"hidden_dim": 6144,
"n_heads": 16,
"head_dim": 128,
"n_kv_heads": 4,
"n_layers": 28,
"norm_eps": 1e-05,
"rope_theta": 10000.0,
"use_scaled_rope": false,
"vocab_size": 59264,
"use_hf_rope": true,
"attention_qkv_bias": false,
"use_qk_norm": false,
"model_architecture" : "GlmForCausalLM"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we have any existing variable that can be used for this?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the suggestion.
I was actually thinking of reusing base_model_name_or_path. However, it seems that this variable is used in optimum for another purpose — referring to the actual model path — so I created a new variable to prevent any future conflict.
Another reason for creating this config is that, as we enable more models, we have noticed minor differences among them. For example, GLM's FeedForward differs from other models' FeedForward, so we need a variable to differentiate GLM from the other LLM models.

}
79 changes: 79 additions & 0 deletions examples/models/glm/convert_weights.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import argparse
import os
from typing import Dict

import torch
from safetensors.torch import load_file
from torchtune.models.convert_weights import get_mapped_key

# Standard _FROM_META weight mapping of Meta weights to TorchTune + additional bias weight mappings.
# Keys are Meta/llama-style parameter names ("{}" is the per-layer index
# placeholder); values are the corresponding names found in the Hugging Face
# GLM checkpoint. Note that GLM stores the gate and up projections fused in a
# single gate_up_proj tensor (mapped below), unlike llama's separate w1/w3.
_GLM_FROM_META = {
    "tok_embeddings.weight": "model.embed_tokens.weight",
    "norm.weight": "model.norm.weight",
    "output.weight": "lm_head.weight",
    "layers.{}.attention.wk.weight": "model.layers.{}.self_attn.k_proj.weight",
    "layers.{}.attention.wq.weight": "model.layers.{}.self_attn.q_proj.weight",
    "layers.{}.attention.wv.weight": "model.layers.{}.self_attn.v_proj.weight",
    "layers.{}.attention.wo.weight": "model.layers.{}.self_attn.o_proj.weight",
    "layers.{}.attention_norm.weight": "model.layers.{}.input_layernorm.weight",
    "layers.{}.ffn_norm.weight": "model.layers.{}.post_attention_layernorm.weight",
    "layers.{}.feed_forward.gate_up_proj.weight": "model.layers.{}.mlp.gate_up_proj.weight",
    "layers.{}.feed_forward.down_proj.weight": "model.layers.{}.mlp.down_proj.weight",
}


def glm_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """
    Convert a state dict from torchtune's format to Meta's format. This function
    doesn't handle any sharding or splitting of state dicts. It follows the
    state_dict IN -> state_dict OUT pattern.

    Args:
        state_dict (Dict[str, torch.Tensor]): State dict in torchtune's format.

    Returns:
        Dict[str, torch.Tensor]: State dict in Meta's format.
    """
    # Invert the Meta -> HF table once, then remap every tensor name through it.
    hf_to_meta = {hf_name: meta_name for meta_name, hf_name in _GLM_FROM_META.items()}
    remapped = {
        get_mapped_key(hf_name, hf_to_meta): tensor
        for hf_name, tensor in state_dict.items()
    }

    # Checkpoints with tied embeddings omit lm_head.weight; reuse the token
    # embedding as the output projection in that case.
    if "lm_head.weight" not in state_dict:
        remapped["output.weight"] = remapped["tok_embeddings.weight"]

    return remapped


def convert_weights(input_dir: str, output_file: str) -> None:
    """Load a GLM safetensors checkpoint, remap it to Meta naming, and save it.

    Args:
        input_dir: Directory containing ``model.safetensors``.
        output_file: Destination path for the converted ``torch.save`` checkpoint.
    """
    checkpoint_path = os.path.join(input_dir, "model.safetensors")

    print("Loading checkpoint from file...")
    state_dict = load_file(checkpoint_path)

    print("Converting checkpoint...")
    state_dict = glm_tune_to_meta(state_dict)

    print("Saving checkpoint...")
    torch.save(state_dict, output_file)
    print("Done.")


def main():
    """CLI entry point: parse input/output paths and run the conversion."""
    arg_parser = argparse.ArgumentParser(
        description="Convert GLM weights to Meta format."
    )
    arg_parser.add_argument(
        "input_dir",
        type=str,
        help="Path to directory containing checkpoint files",
    )
    arg_parser.add_argument("output", type=str, help="Path to the output checkpoint")
    parsed = arg_parser.parse_args()

    convert_weights(parsed.input_dir, parsed.output)


if __name__ == "__main__":
    main()
3 changes: 3 additions & 0 deletions examples/models/llama/model_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,9 @@ class ModelArgs:
attention_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
# Hybrid models can have layer types different from attention
layer_types: Optional[list] = None
model_architecture: Optional[str] = (
None # Architecture of model. For HF models, please refer to the HF model.config.architectures. This is used in QNN backend only for now.
)

def __post_init__(self):
if self.n_kv_heads is None:
Expand Down
22 changes: 16 additions & 6 deletions examples/qualcomm/oss_scripts/llama/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ This file provides you the instructions to run LLM Decoder model with different
1. Codegen2 1B
1. Gemma 2B
1. Gemma3 1B
1. GLM 1.5B
1. Granite3.3 2B
1. Phi4-mini-instruct
1. QWEN2.5 0.5B / 1.5B
Expand Down Expand Up @@ -65,7 +66,10 @@ Follow the [instructions](https://www.llama.com/) to download models.
At the end of this step, users should have the following files ready: `consolidated.00.pth`, `params.json`, and `tokenizer.model`.


### Step3: Run default examples using hybrid mode for smaller models and kv mode for larger models.
### Step3: Run default examples.
#### Note:
All example scripts below use hybrid mode, which is optimized for on-device performance. However, compiling a model in hybrid mode can consume a significant amount of memory on the host machine—sometimes up to ~100 GB. If your host machine has limited memory, it is highly recommended to switch from `--model_mode hybrid` to `--model_mode kv` and remove the `--prefill_ar_len` flag.

#### LLAMA2
```bash
python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --checkpoint stories110M.pt --params params.json --tokenizer_model tokenizer.model --tokenizer_bin tokenizer.bin --decoder_model stories110m --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --prompt "Once upon a time"
Expand All @@ -80,7 +84,7 @@ python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL
#### LLAMA3.2 3B Instruct
Default example using hybrid mode.
```bash
python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --decoder_model llama3_2-3b_instruct --model_mode kv --max_seq_len 1024 --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1
python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --decoder_model llama3_2-3b_instruct --model_mode hybrid --prefill_ar_len 128 --max_seq_len 1024 --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1
```

#### Codegen2
Expand All @@ -102,6 +106,12 @@ Default example using hybrid mode
python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --temperature 0 --model_mode hybrid --max_seq_len 1024 --prefill_ar_len 128 --decoder_model gemma3-1b --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1
```

#### GLM 1.5B
Default example using hybrid mode
```bash
python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --temperature 0 --model_mode hybrid --max_seq_len 1024 --prefill_ar_len 128 --decoder_model glm-1_5b --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1
```

#### Granite3.3 2B
Default example using hybrid mode
```bash
Expand All @@ -111,7 +121,7 @@ python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL
#### Phi4-mini-instruct
Default example using hybrid mode.
```bash
python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --decoder_model phi_4_mini --model_mode kv --max_seq_len 1024 --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1
python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --decoder_model phi_4_mini --model_mode hybrid --prefill_ar_len 128 --max_seq_len 1024 --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1
```

#### QWEN2.5 0.5B
Expand All @@ -123,7 +133,7 @@ python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL
#### QWEN2.5 1.5B
Default example using hybrid mode
```bash
python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --temperature 0 --model_mode kv --max_seq_len 1024 --decoder_model qwen2_5-1_5b --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1
python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --temperature 0 --model_mode hybrid --prefill_ar_len 128 --max_seq_len 1024 --decoder_model qwen2_5-1_5b --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1
```

#### QWEN3 0.6B
Expand All @@ -135,7 +145,7 @@ python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL
#### QWEN3 1.7B
Default example using hybrid mode
```bash
python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --temperature 0 --model_mode kv --max_seq_len 1024 --decoder_model qwen3-1_7b --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1
python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --temperature 0 --model_mode hybrid --prefill_ar_len 128 --max_seq_len 1024 --decoder_model qwen3-1_7b --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1
```

#### SmolLM2
Expand All @@ -147,7 +157,7 @@ python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL
#### SmolLM3
Default example using hybrid mode.
```bash
python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --decoder_model smollm3-3b --model_mode kv --max_seq_len 1024 --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1
python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --decoder_model smollm3-3b --model_mode hybrid --prefill_ar_len 128 --max_seq_len 1024 --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1
```

### KV Cache update mechanism
Expand Down
23 changes: 23 additions & 0 deletions examples/qualcomm/oss_scripts/llama/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
)
from executorch.examples.models.gemma import convert_weights as convert_gemma_weights
from executorch.examples.models.gemma3 import convert_weights as convert_gemma3_weights

from executorch.examples.models.glm import convert_weights as convert_glm_weights
from executorch.examples.models.granite import (
convert_weights as convert_granite_weights,
)
Expand Down Expand Up @@ -44,6 +46,7 @@
CodegenQuantRecipe,
Gemma3QuantRecipe,
Gemma_2BQuantRecipe,
GLM_1_5B_InstructQuantRecipe,
Granite_3_3_2B_InstructQuantRecipe,
Llama3_1BQuantRecipe,
Llama3_3BQuantRecipe,
Expand Down Expand Up @@ -293,6 +296,26 @@ class Gemma3(LLMModelConfig):
quant_recipe = Gemma3QuantRecipe


@register_llm_model("glm-1_5b")
@dataclass(init=False, frozen=True)
class GLM_1_5B(LLMModelConfig):
    # Model/config registration for GLM Edge 1.5B chat in the QNN llama
    # example flow. Field semantics are defined by LLMModelConfig.
    # Hugging Face repo the checkpoint and tokenizer are fetched from.
    repo_id: str = "THUDM/glm-edge-1.5b-chat"
    # Llama-style params json describing the GLM 1.5B architecture.
    params_path: str = os.path.join(
        BASE_DIR, "../../../models/glm/config/1_5b_config.json"
    )
    # Converts the HF safetensors checkpoint to Meta-format weights.
    convert_weights = convert_glm_weights
    transform_weight = True
    instruct_model = True
    # 1 == no model sharding.
    num_sharding = 1
    # presumably the quantization group size — confirm against the recipe.
    group_size = 32
    masked_softmax = False
    seq_mse_candidates = 0
    # r1/r2/r3 toggles left disabled for this model (semantics defined by
    # LLMModelConfig / the quantization flow).
    r1 = False
    r2 = False
    r3 = False
    quant_recipe = GLM_1_5B_InstructQuantRecipe

@register_llm_model("granite_3_3-2b_instruct")
@dataclass(init=False, frozen=True)
class Granite_3_3_2b_Instruct(LLMModelConfig):
Expand Down
1 change: 1 addition & 0 deletions examples/qualcomm/oss_scripts/llama/decoder_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,5 @@
"smollm2_135m": "smollm2_135m",
"smollm3-3b": "smollm3",
"codegen2_1b": "codegen",
"glm-1_5b": "glm",
}
12 changes: 10 additions & 2 deletions examples/qualcomm/oss_scripts/llama/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -1309,15 +1309,23 @@ def export_llama(args) -> None:
# For Gemma, use tokenizer.model as it doesn't provide pre_tokenizer in tokenizer.json.
runtime_tokenizer_path = tokenizer_artifacts[-3]
else:
if args.decoder_model == "glm-1_5b":
with open(tokenizer_config, "r+") as file:
data = json.load(file)
# Verified with HF flow and it uses <|user|> as eos condition
data["bos_token"] = "<|user|>"
data["eos_token"] = "<|user|>"
file.seek(0)
json.dump(data, file, indent=4)
file.truncate()
runtime_tokenizer_path = tokenizer_artifacts[-1]

tokenizer = get_tokenizer(runtime_tokenizer_path, tokenizer_config)

if args.decoder_model == "codegen2_1b":
# Override the default BOS and EOS token IDs for codegen2_1b
tokenizer.bos_id = 1
tokenizer.eos_id = 2

# TODO: Remove this once error is resolved.
elif args.decoder_model == "phi_4_mini":
with open(runtime_tokenizer_path, "r+") as file:
data = json.load(file)
Expand Down
51 changes: 51 additions & 0 deletions examples/qualcomm/oss_scripts/llama/model/feed_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,54 @@ def forward(self, x):
hidden_states = self.act(hidden_states)
hidden_states = self.fc_out(hidden_states)
return hidden_states


@register_feed_forward("GlmForCausalLM")
class GLMFeedForward(FeedForwardBase):
    """FeedForward with gate_up_proj and down_proj.

    Unlike the llama-style FeedForward (separate w1/w3), GLM fuses the gate and
    up projections into a single ``gate_up_proj`` whose output is chunked in
    two. ``prepare_feedfoward_conv`` optionally rewrites the Linear layers as
    1x1 Conv2d layers (method name spelling matches the base-class convention).
    """

    def __init__(self, args: ModelArgs):  # in MLP: intermediate_size= 4 * embed_dim
        super().__init__()

        assert args.hidden_dim is not None
        self.dim = args.dim
        self.hidden_dim = args.hidden_dim

        # Fused gate+up projection: output is 2 * hidden_dim, split in forward().
        self.gate_up_proj = torch.nn.Linear(args.dim, 2 * args.hidden_dim, bias=False)
        self.down_proj = torch.nn.Linear(args.hidden_dim, args.dim, bias=False)
        self.activation_fn = args.act_fn.get_function()

    def prepare_feedfoward_conv(self):
        """Replace the Linear layers with equivalent 1x1 Conv2d layers.

        Copies the Linear weights into Conv2d kernels, swaps ``self.forward``
        to the conv-based path, and deletes the Linear modules. Irreversible
        for the Linear path (the originals are deleted), though the previous
        forward is kept as ``forward_no_conv``.
        """
        self.gate_up_proj_conv = torch.nn.Conv2d(
            self.dim, 2 * self.hidden_dim, 1, bias=False
        )
        self.down_proj_conv = torch.nn.Conv2d(self.hidden_dim, self.dim, 1, bias=False)

        self.forward_no_conv = self.forward
        self.forward = self.forward_feedfoward_conv

        # A Linear (out, in) weight becomes a (out, in, 1, 1) conv kernel.
        self.gate_up_proj_conv.weight.data.copy_(
            self.gate_up_proj.weight[:, :, None, None]
        )
        self.down_proj_conv.weight.data.copy_(self.down_proj.weight[:, :, None, None])

        del self.gate_up_proj
        del self.down_proj

    def forward_feedfoward_conv(self, x):
        """Conv-based forward; numerically equivalent to forward().

        Reshapes (bsz, seq, dim) activations into NCHW so the 1x1 convs act as
        per-token linear projections.
        """
        bsz, _, _ = x.size()
        x = torch.reshape(x, (bsz, -1, 1, self.dim))
        x = x.transpose(1, 3)  # Transpose right before and after Conv
        up_states = self.gate_up_proj_conv(x)
        # Split the fused projection into gate and up halves along channels.
        gate, up_states = up_states.chunk(2, dim=1)
        up_states = up_states * self.activation_fn(gate)
        x = self.down_proj_conv(up_states)
        x = x.transpose(1, 3)
        x = torch.reshape(x, (bsz, -1, self.dim))
        return x

    def forward(self, x):
        """Gated feed-forward: down_proj(up * act(gate))."""
        up_states = self.gate_up_proj(x)
        gate, up_states = up_states.chunk(2, dim=-1)
        up_states = up_states * self.activation_fn(gate)
        return self.down_proj(up_states)
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,6 @@ def prepare_feedfoward_conv(self):

self.forward_no_conv = self.forward
self.forward = self.forward_feedfoward_conv

self.w1_conv.weight.data.copy_(self.w1.weight[:, :, None, None])
self.w2_conv.weight.data.copy_(self.w2.weight[:, :, None, None])
self.w3_conv.weight.data.copy_(self.w3.weight[:, :, None, None])
Expand Down
9 changes: 9 additions & 0 deletions examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,15 @@ std::string get_formatted_prompt(
formatted_prompt.append("<|im_end|>\n");
formatted_prompt.append("<|im_start|>assistant\n");
break;
case example::DecoderModelVersion::kGlm:
formatted_prompt.append("<|user|>\n");
formatted_prompt.append(prompt);
if (!system_prompt.empty()) {
formatted_prompt.append("<|system|>\n");
formatted_prompt.append(system_prompt);
}
formatted_prompt.append("<|assistant|>\n");
break;
default:
ET_CHECK_MSG(false, "unsupported llama version");
break;
Expand Down
4 changes: 4 additions & 0 deletions examples/qualcomm/oss_scripts/llama/runner/runner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,8 @@ Runner<T>::Runner(
decoder_model_version_ = DecoderModelVersion::kSmollm3;
} else if (decoder_model_version == "codegen") {
decoder_model_version_ = DecoderModelVersion::kCodegen;
} else if (decoder_model_version == "glm") {
decoder_model_version_ = DecoderModelVersion::kGlm;
} else {
ET_CHECK_MSG(false, "Unsupported Decoder Model");
}
Expand Down Expand Up @@ -211,6 +213,8 @@ Error Runner<T>::load() {
eos_ids->insert(tokenizer_->encode("<end_of_turn>", 0, 0).get()[0]);
} else if (decoder_model_version_ == DecoderModelVersion::kCodegen) {
eos_ids->insert(tokenizer_->encode("<|endoftext|>", 0, 0).get()[0]);
} else if (decoder_model_version_ == DecoderModelVersion::kGlm) {
eos_ids->insert(tokenizer_->encode("<|user|>", 0, 0).get()[0]);
}

// Try avoid getMetadataHelper as it is time consuming.
Expand Down
1 change: 1 addition & 0 deletions examples/qualcomm/oss_scripts/llama/runner/runner.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ enum DecoderModelVersion {
kSmollm2_135m,
kSmollm3,
kCodegen,
kGlm,
};

enum KvBitWidth {
Expand Down
Loading
Loading