diff --git a/.ci/scripts/test.sh b/.ci/scripts/test.sh
index af50faa1c47..e8a39f8d282 100755
--- a/.ci/scripts/test.sh
+++ b/.ci/scripts/test.sh
@@ -53,6 +53,12 @@ build_cmake_executor_runner() {
 }
 
 test_model() {
+  if [[ "${MODEL_NAME}" == "llama2" ]]; then
+    cd examples/third-party/llama
+    pip install -e .
+    cd ../../..
+  fi
+
   "${PYTHON_EXECUTABLE}" -m examples.export.export_example --model_name="${MODEL_NAME}"
 
   # Run test model
diff --git a/examples/models/__init__.py b/examples/models/__init__.py
index 48544bd94bf..f62bab8029e 100644
--- a/examples/models/__init__.py
+++ b/examples/models/__init__.py
@@ -16,6 +16,7 @@
     "emformer_transcribe": ("emformer_rnnt", "EmformerRnntTranscriberModel"),
     "emformer_predict": ("emformer_rnnt", "EmformerRnntPredictorModel"),
     "emformer_join": ("emformer_rnnt", "EmformerRnntJoinerModel"),
+    "llama2": ("llama2", "Llama2Model"),
     "mobilebert": ("mobilebert", "MobileBertModelExample"),
     "mv2": ("mobilenet_v2", "MV2Model"),
     "mv3": ("mobilenet_v3", "MV3Model"),
diff --git a/examples/models/llama2/README.md b/examples/models/llama2/README.md
new file mode 100644
index 00000000000..4af8ba4a336
--- /dev/null
+++ b/examples/models/llama2/README.md
@@ -0,0 +1,24 @@
+# Summary
+This example demonstrates how to export a Llama 2 model with ExecuTorch.
+For details on Llama 2 itself, please refer to [the llama GitHub page](https://github.com/facebookresearch/llama).
+Pretrained parameters are not included in this repo; you can download them from [the llama download page](https://ai.meta.com/resources/models-and-libraries/llama-downloads/).
+
+# Notes
+1. This example shows the feasibility of exporting a Llama 2 model with ExecuTorch; performance is not guaranteed.
+2. It targets a model size reasonable for edge devices. Depending on the model size, memory usage during export can be high. TODO: improve memory usage in the EXIR emitter.
+3. The provided checkpoint, demo_rand_params.pth, is a dummy checkpoint with random parameters. It does not produce meaningful results; it exists only for demonstration and fast iteration.
+
+# Limitations
+This example reuses the original Python model code, with modifications that make it compatible with the current ExecuTorch:
+1. ExecuTorch does not support the complex Tensor data type, so custom functions compute the rotary embedding with real numbers. TODO: support the complex Tensor data type in ExecuTorch.
+2. No KV cache. The cache implementation in the original Llama 2 repo is not supported, because the ExecuTorch runtime assumes that model data attributes are static. TODO: add support for mutable buffers in ExecuTorch.
+3. No CUDA. ExecuTorch is focused on edge use cases, and CUDA is not available on most edge devices.
+4. No dependency on fairscale. ColumnParallelLinear, ParallelEmbedding, and training are neither needed nor supported in ExecuTorch.
+
+
+# Instructions:
+1. Follow the [tutorial](https://github.com/pytorch/executorch/blob/main/docs/website/docs/tutorials/00_setting_up_executorch.md) to set up ExecuTorch.
+2. `cd examples/third-party/llama`
+3. `pip install -e .`
+4. Go back to the `executorch` root and run `python3 -m examples.export.export_example --model_name="llama2"`. The exported program, llama2.pte, is saved in the current directory.
+5. Use the `executor_runner` (build instructions in step 1) to load and run llama2.pte: `executor_runner --model_path llama2.pte`
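Editorial note: as a complement to the README instructions above, the wrapper this PR adds can be exercised in eager mode before export. A minimal sketch, assuming the `llama` package has been installed per steps 2-3 and that this is run from the `executorch` root with the demo checkpoint and config shipped in this directory:

```python
import torch

from examples.models.llama2 import Llama2Model

wrapper = Llama2Model()  # loads demo_rand_params.pth and demo_config.json
model = wrapper.get_eager_model()
example_inputs = wrapper.get_example_inputs()  # a (1, 1) batch of token ids

with torch.no_grad():
    logits = model(*example_inputs)

# The demo config uses vocab_size=512, so the output should be (1, 1, 512).
print(logits.shape)
```

The `"llama2"` entry added to `examples/models/__init__.py` above is what lets the generic `examples.export.export_example` entry point find this wrapper by name.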
diff --git a/examples/models/llama2/__init__.py b/examples/models/llama2/__init__.py
new file mode 100644
index 00000000000..db6124ecc71
--- /dev/null
+++ b/examples/models/llama2/__init__.py
@@ -0,0 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .model import Llama2Model
+
+__all__ = [
+    "Llama2Model",
+]
diff --git a/examples/models/llama2/demo_config.json b/examples/models/llama2/demo_config.json
new file mode 100644
index 00000000000..13287f117e9
--- /dev/null
+++ b/examples/models/llama2/demo_config.json
@@ -0,0 +1 @@
+{"dim": 64, "multiple_of": 4, "n_heads": 8, "n_layers": 5, "norm_eps": 1e-05, "vocab_size": 512}
\ No newline at end of file
diff --git a/examples/models/llama2/demo_rand_params.pth b/examples/models/llama2/demo_rand_params.pth
new file mode 100644
index 00000000000..900663169d5
Binary files /dev/null and b/examples/models/llama2/demo_rand_params.pth differ
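Editorial note: model.py below replaces the original complex-valued rotary embedding with an explicit cos/sin formulation, since ExecuTorch has no complex dtype (Limitation 1 in the README). A minimal standalone sketch of the identity the real-valued version relies on; this is illustrative only and not part of the diff:

```python
import torch

# Rotation angles as in precompute_freqs_cis, for a toy head_dim and seqlen.
dim, seqlen, theta = 8, 4, 10000.0
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
angles = torch.outer(torch.arange(seqlen), freqs)  # (seqlen, dim // 2)

x = torch.randn(seqlen, dim)
x_pairs = x.reshape(seqlen, -1, 2)  # group features into (real, imag) pairs

# Complex reference (what the original llama code does): multiply by e^{i*angle}.
rotated_complex = torch.view_as_real(
    torch.view_as_complex(x_pairs) * torch.polar(torch.ones_like(angles), angles)
)

# Real-valued equivalent (what apply_rotary_emb does): an explicit 2x2 rotation.
x_r, x_i = x_pairs.unbind(-1)
cos, sin = torch.cos(angles), torch.sin(angles)
rotated_real = torch.stack([x_r * cos - x_i * sin, x_r * sin + x_i * cos], dim=-1)

assert torch.allclose(rotated_complex, rotated_real, atol=1e-6)
```

Because both paths compute the same rotation, the exported graph only ever sees real-tensor ops.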
diff --git a/examples/models/llama2/model.py b/examples/models/llama2/model.py
new file mode 100644
index 00000000000..fc2797b2c8b
--- /dev/null
+++ b/examples/models/llama2/model.py
@@ -0,0 +1,223 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Please refer to README.md in the same folder for more information.
+
+
+import json
+import math
+from pathlib import Path
+from typing import Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+
+from examples.models.model_base import EagerModelBase
+
+from llama.model import ModelArgs, repeat_kv, RMSNorm
+from torch import nn
+
+
+def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
+    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+    t = torch.arange(end, device=freqs.device)
+    freqs = torch.outer(t, freqs).float()
+    freqs_cos = torch.cos(freqs)
+    freqs_sin = torch.sin(freqs)
+    return freqs_cos, freqs_sin
+
+
+def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
+    ndim = x.ndim
+    assert ndim > 1
+    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
+    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
+    return freqs_cis.view(shape)
+
+
+def apply_rotary_emb(
+    xq: torch.Tensor, xk: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1)
+    xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1)
+
+    freqs_cos = reshape_for_broadcast(freqs_cos, xq_r)
+    freqs_sin = reshape_for_broadcast(freqs_sin, xq_r)
+
+    xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
+    xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
+    xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
+    xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos
+
+    xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(3)
+    xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(3)
+
+    return xq_out.type_as(xq), xk_out.type_as(xk)
+
+
+class Attention(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
+        assert args.n_heads % self.n_kv_heads == 0
+        model_parallel_size = 1
+        self.n_local_heads = args.n_heads // model_parallel_size
+        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
+        self.n_rep = self.n_local_heads // self.n_local_kv_heads
+        self.head_dim = args.dim // args.n_heads
+        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
+        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
+        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
+        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
+
+        mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
+        mask = torch.triu(mask, diagonal=1)
+        self.register_buffer("mask", mask)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        freqs_cos: torch.Tensor,
+        freqs_sin: torch.Tensor,
+    ):
+        bsz, seqlen, _ = x.shape
+
+        # QKV
+        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
+        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
+        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
+
+        # RoPE relative positional embeddings
+        xq, xk = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)
+
+        # grouped multiquery attention: expand out keys and values
+        xk = repeat_kv(xk, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
+        xv = repeat_kv(xv, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
+
+        # make heads into a batch dimension
+        xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
+        xk = xk.transpose(1, 2)
+        xv = xv.transpose(1, 2)
+
+        scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
+        assert hasattr(self, "mask")
+        scores = (
+            scores + self.mask[:, :, :seqlen, :seqlen]
+        )  # (bs, n_local_heads, seqlen, cache_len + seqlen)
+        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
+        output = torch.matmul(scores, xv)  # (bs, n_local_heads, seqlen, head_dim)
+
+        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
+
+        output = self.wo(output)
+        return output
+
+
+class FeedForward(nn.Module):
+    def __init__(self, dim: int, hidden_dim: int, multiple_of: int):
+        super().__init__()
+        # SwiGLU sizing: shrink hidden_dim to 2/3 of the given value, then
+        # round up to the nearest multiple of `multiple_of`.
+        hidden_dim = int(2 * hidden_dim / 3)
+        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
+        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
+        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
+        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
+
+    def forward(self, x):
+        return self.w2(F.silu(self.w1(x)) * self.w3(x))
+
+
+class TransformerBlock(nn.Module):
+    def __init__(self, layer_id: int, args: ModelArgs):
+        super().__init__()
+        self.n_heads = args.n_heads
+        self.dim = args.dim
+        self.head_dim = args.dim // args.n_heads
+        self.attention = Attention(args)
+        self.feed_forward = FeedForward(
+            dim=args.dim,
+            hidden_dim=4 * args.dim,
+            multiple_of=args.multiple_of,
+        )
+        self.layer_id = layer_id
+        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
+        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
+
+    def forward(self, x, freqs_cos, freqs_sin):
+        h = x + self.attention.forward(self.attention_norm(x), freqs_cos, freqs_sin)
+        out = h + self.feed_forward.forward(self.ffn_norm(h))
+        return out
+
+
+class Transformer(nn.Module):
+    last_loss: Optional[torch.Tensor]
+
+    def __init__(self, params: ModelArgs):
+        super().__init__()
+        self.params = params
+        self.vocab_size = params.vocab_size
+        self.n_layers = params.n_layers
+
+        self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
+        self.layers = torch.nn.ModuleList()
+        for layer_id in range(params.n_layers):
+            self.layers.append(TransformerBlock(layer_id, params))
+        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
+        self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
+
+        freqs_cos, freqs_sin = precompute_freqs_cis(
+            self.params.dim // self.params.n_heads, self.params.max_seq_len
+        )
+        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
+        self.register_buffer("freqs_sin", freqs_sin, persistent=False)
+
+    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
+        _bsz, seqlen = tokens.shape
+        h = self.tok_embeddings(tokens)
+        freqs_cos = self.freqs_cos[:seqlen]
+        freqs_sin = self.freqs_sin[:seqlen]
+
+        for layer in self.layers:
+            h = layer(h, freqs_cos, freqs_sin)
+
+        h = self.norm(h)
+
+        logits = self.output(h)
+        return logits
+
+
+class Llama2Model(EagerModelBase):
+    def __init__(self):
+        ckpt_dir = Path(__file__).absolute().parent
+        # This example uses a small dummy model with random weights, for demo
+        # purposes only. Follow the instructions in
+        # https://github.com/facebookresearch/llama to download a real model.
+        device = "cpu"
+        checkpoint = torch.load(ckpt_dir / "demo_rand_params.pth", map_location=device)
+        with open(ckpt_dir / "demo_config.json", "r") as f:
+            params = json.loads(f.read())
+        max_seq_len = 128
+        max_batch_size = 1
+        model_args: ModelArgs = ModelArgs(
+            max_seq_len=max_seq_len,
+            max_batch_size=max_batch_size,
+            **params,
+        )
+        self.model_ = Transformer(model_args)
+        self.model_.load_state_dict(checkpoint, strict=False)
+
+    def get_eager_model(self):
+        return self.model_
+
+    @staticmethod
+    def get_example_inputs():
+        return (torch.tensor([[1]]),)
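Editorial note: because the ExecuTorch runtime assumes model data attributes are static (see the README's KV-cache note), `Attention` registers the full `max_seq_len x max_seq_len` causal mask as a buffer up front and slices it per forward, rather than building it dynamically. A minimal standalone illustration of that mask, using an assumed toy `max_seq_len` of 4 (the example model uses 128):

```python
import torch
import torch.nn.functional as F

max_seq_len = 4  # toy value for illustration
mask = torch.full((1, 1, max_seq_len, max_seq_len), float("-inf"))
mask = torch.triu(mask, diagonal=1)  # strictly-upper entries stay -inf

# Adding the mask to all-zero scores and softmaxing shows the causal pattern:
# row i spreads probability uniformly over positions 0..i and assigns exactly
# zero weight to future positions, e.g. row 1 is roughly [0.5, 0.5, 0, 0].
probs = F.softmax(torch.zeros(1, 1, max_seq_len, max_seq_len) + mask, dim=-1)
print(probs[0, 0])
```

Baking the mask in as a buffer keeps the exported graph free of data-dependent control flow, at the cost of a fixed upper bound on sequence length.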