diff --git a/llms/README.md b/llms/README.md index a15d00c82..0d951857a 100644 --- a/llms/README.md +++ b/llms/README.md @@ -127,17 +127,18 @@ Most [Phi-2](https://huggingface.co/models?library=transformers,safetensors&other=phi&sort=trending), and [Mixtral](https://huggingface.co/models?library=transformers,safetensors&other=mixtral&sort=trending) +[Starcoder2](https://huggingface.co/models?library=transformers,safetensors&other=starcoder2&sort=trending) style models should work out of the box. For some models (such as `Qwen` and `plamo`) the tokenizer requires you to enable the `trust_remote_code` option. You can do this by passing `--trust-remote-code` in the command line. If you don't specify the flag explicitly, you will be prompted to trust remote code in the terminal when -running the model. +running the model. For `Qwen` models you must also specify the `eos_token`. You can do this by passing `--eos-token "<|endoftext|>"` in the command -line. +line. These options can also be set in the Python API. For example: diff --git a/llms/mlx_lm/models/starcoder2.py b/llms/mlx_lm/models/starcoder2.py new file mode 100644 index 000000000..e16baedd6 --- /dev/null +++ b/llms/mlx_lm/models/starcoder2.py @@ -0,0 +1,195 @@ +from dataclasses import dataclass +from typing import Optional, Tuple + +import mlx.core as mx +import mlx.nn as nn + +from .base import BaseModelArgs +from .layers import LayerNorm + + +@dataclass +class ModelArgs(BaseModelArgs): + model_type: str + hidden_size: int + num_hidden_layers: int + sliding_window: int + num_attention_heads: int + num_key_value_heads: int = None + vocab_size: int = 49152 + intermediate_size: int = 12288 + max_position_embeddings: int = 16384 + norm_epsilon: float = 1e-05 + rope_theta: float = 10000 + rope_traditional: bool = False + attention_dropout: int = 0.0 + residual_dropout: int = 0.0 + embedding_dropout: int = 0.0 + use_bias: bool = True + tie_word_embeddings: bool = True + + +class Starcoder2Attention(nn.Module): + """ + Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer + and "Generating Long Sequences with Sparse Transformers". + """ + + def __init__(self, args: ModelArgs): + super().__init__() + self.config = args + dim = args.hidden_size + self.n_heads = n_heads = args.num_attention_heads + self.n_kv_heads = n_kv_heads = args.num_key_value_heads + self.head_dim = head_dim = args.hidden_size // n_heads + + self.repeats = n_heads // n_kv_heads + + self.scale = head_dim**-0.5 + + self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=True) + self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True) + self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True) + self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=True) + + self.rope = nn.RoPE( + head_dim, + traditional=args.rope_traditional, + base=args.rope_theta, + ) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Tuple[mx.array, mx.array]] = None, + ) -> mx.array: + B, L, D = x.shape + + queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) + + # Prepare the queries, keys and values for the attention computation + queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) + keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + + if self.repeats > 1: + keys = mx.repeat(keys, self.repeats, axis=1) + values = mx.repeat(values, self.repeats, axis=1) + + if cache is not None: + key_cache, value_cache = cache + queries = self.rope(queries, offset=key_cache.shape[2]) + keys = self.rope(keys, offset=key_cache.shape[2]) + keys = mx.concatenate([key_cache, keys], axis=2) + values = mx.concatenate([value_cache, values], axis=2) + else: + queries = self.rope(queries) + keys = self.rope(keys) + + scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2) + if mask is not None: + scores += mask + scores = mx.softmax(scores, axis=-1).astype(values.dtype) + output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) + return self.o_proj(output), (keys, values) + + +class Starcoder2MLP(nn.Module): + def __init__(self, dim, hidden_dim): + super().__init__() + self.c_fc = nn.Linear(dim, hidden_dim, bias=True) + self.c_proj = nn.Linear(hidden_dim, dim, bias=True) + + def __call__(self, x) -> mx.array: + return self.c_proj(nn.gelu(self.c_fc(x))) + + +class TransformerBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.num_attention_heads = args.num_attention_heads + self.hidden_size = args.hidden_size + self.self_attn = Starcoder2Attention(args) + self.mlp = Starcoder2MLP(args.hidden_size, args.intermediate_size) + self.input_layernorm = LayerNorm(args.hidden_size, eps=args.norm_epsilon) + self.post_attention_layernorm = LayerNorm( + args.hidden_size, eps=args.norm_epsilon + ) + self.args = args + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Tuple[mx.array, mx.array]] = None, + ) -> mx.array: + r, cache = self.self_attn(self.input_layernorm(x), mask, cache) + h = x + r + r = self.mlp(self.post_attention_layernorm(h)) + out = h + r + return out, cache + + +class Starcoder2Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.vocab_size = args.vocab_size + self.num_hidden_layers = args.num_hidden_layers + assert self.vocab_size > 0 + self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) + self.layers = [ + TransformerBlock(args=args) for _ in range(args.num_hidden_layers) + ] + self.norm = LayerNorm(args.hidden_size, eps=args.norm_epsilon) + + def __call__( + self, + inputs: mx.array, + cache=None, + ): + h = self.embed_tokens(inputs) + + mask = None + if h.shape[1] > 1: + mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1]) + mask = mask.astype(h.dtype) + + if cache is None: + cache = [None] * len(self.layers) + + for e, layer in enumerate(self.layers): + h, cache[e] = layer(h, mask, cache[e]) + + return self.norm(h), cache + + +class Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.model_type = args.model_type + self.model = Starcoder2Model(args) + self.tie_word_embeddings = args.tie_word_embeddings + + # If tie_word_embeddings is False, tie (share) the embedding weights with lm_head + # Source: https://github.com/huggingface/transformers/blob/main/src/transformers/models/starcoder2/modeling_starcoder2.py#L1071 + if not args.tie_word_embeddings: + self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) + + def __call__( + self, + inputs: mx.array, + cache=None, + ): + out, cache = self.model(inputs, cache) + + if not self.tie_word_embeddings: + return self.lm_head(out), cache + else: + out = out @ self.model.embed_tokens.weight.T + return out, cache + + @property + def layers(self): + return self.model.layers diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py index 0fd688de9..bfa5cdf98 100644 --- a/llms/mlx_lm/tuner/utils.py +++ b/llms/mlx_lm/tuner/utils.py @@ -32,6 +32,7 @@ def check_lora_layers(num_model): "stablelm", "qwen2", "gemma", + "starcoder2", ]: check_lora_layers(len(model.model.layers))