From 85d9aa6d88103079834b74da1a9b5b840e93f3cb Mon Sep 17 00:00:00 2001 From: muhtasham Date: Thu, 22 Feb 2024 01:06:46 +0100 Subject: [PATCH 01/21] Add Starcoder2 model and update utils.py --- llms/mlx_lm/models/starcoder2.py | 182 +++++++++++++++++++++++++++++++ llms/mlx_lm/tuner/utils.py | 1 + 2 files changed, 183 insertions(+) create mode 100644 llms/mlx_lm/models/starcoder2.py diff --git a/llms/mlx_lm/models/starcoder2.py b/llms/mlx_lm/models/starcoder2.py new file mode 100644 index 000000000..d68dc8001 --- /dev/null +++ b/llms/mlx_lm/models/starcoder2.py @@ -0,0 +1,182 @@ +import math +from dataclasses import dataclass +from typing import Dict, Optional, Tuple, Union + +import mlx.core as mx +import mlx.nn as nn + +from .base import BaseModelArgs + + +@dataclass +class ModelArgs(BaseModelArgs): + model_type: str + hidden_size: int + num_hidden_layers: int + intermediate_size: int + num_attention_heads: int + rms_norm_eps: float + vocab_size: int + num_key_value_heads: int = None + rope_theta: float = 10000 + rope_traditional: bool = False + rope_scaling: Optional[Dict[str, Union[float, str]]] = None + +class RMSNorm(nn.Module): + def __init__(self, dims: int, eps: float = 1e-5): + super().__init__() + self.weight = mx.ones((dims,)) + self.eps = eps + + def _norm(self, x): + return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps) + + def __call__(self, x): + output = self._norm(x.astype(mx.float32)).astype(x.dtype) + return self.weight * output + +class Attention(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + + dim = args.hidden_size + self.n_heads = n_heads = args.num_attention_heads + self.n_kv_heads = n_kv_heads = args.num_key_value_heads + self.head_dim = head_dim = args.head_dim + + self.repeats = n_heads // n_kv_heads + + self.scale = head_dim**-0.5 + + self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False) + self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) + self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) + self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False) + + self.rope = nn.RoPE( + head_dim, + traditional=args.rope_traditional, + base=args.rope_theta, + ) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Tuple[mx.array, mx.array]] = None, + ) -> mx.array: + B, L, D = x.shape + + queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) + + # Prepare the queries, keys and values for the attention computation + queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) + keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + + if self.repeats > 1: + keys = mx.repeat(keys, self.repeats, axis=1) + values = mx.repeat(values, self.repeats, axis=1) + + if cache is not None: + key_cache, value_cache = cache + queries = self.rope(queries, offset=key_cache.shape[2]) + keys = self.rope(keys, offset=key_cache.shape[2]) + keys = mx.concatenate([key_cache, keys], axis=2) + values = mx.concatenate([value_cache, values], axis=2) + else: + queries = self.rope(queries) + keys = self.rope(keys) + + scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2) + if mask is not None: + scores += mask + scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype) + output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) + return self.o_proj(output), (keys, values) + +class MLP(nn.Module): + def __init__(self, dim, hidden_dim): + 
super().__init__() + self.gate_proj = nn.Linear(dim, hidden_dim, bias=False) + self.down_proj = nn.Linear(hidden_dim, dim, bias=False) + self.up_proj = nn.Linear(dim, hidden_dim, bias=False) + + def __call__(self, x) -> mx.array: + return self.down_proj(nn.gelu(self.gate_proj(x)) * self.up_proj(x)) + +class TransformerBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.num_attention_heads = args.num_attention_heads + self.hidden_size = args.hidden_size + self.self_attn = Attention(args) + self.mlp = MLP(args.hidden_size, args.intermediate_size) + self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.args = args + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Tuple[mx.array, mx.array]] = None, + ) -> mx.array: + r, cache = self.self_attn(self.input_layernorm(x), mask, cache) + h = x + r + r = self.mlp(self.post_attention_layernorm(h)) + out = h + r + return out, cache + +class Starcoder2Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.vocab_size = args.vocab_size + self.num_hidden_layers = args.num_hidden_layers + assert self.vocab_size > 0 + self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) + self.layers = [ + TransformerBlock(args=args) for _ in range(args.num_hidden_layers) + ] + self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + + def __call__( + self, + inputs: mx.array, + cache=None, + ): + h = self.embed_tokens(inputs) + h = h * (self.args.hidden_size**0.5) + + mask = None + if h.shape[1] > 1: + mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1]) + mask = mask.astype(h.dtype) + + if cache is None: + cache = [None] * len(self.layers) + + for e, layer in enumerate(self.layers): + h, cache[e] = layer(h, mask, cache[e]) + + return self.norm(h), cache + +class Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.model_type = args.model_type + self.model = Starcoder2Model(args) + + def __call__( + self, + inputs: mx.array, + cache=None, + ): + out, cache = self.model(inputs, cache) + out = out @ self.model.embed_tokens.weight.T + return out, cache + + @property + def layers(self): + return self.model.layers \ No newline at end of file diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py index 579fca59f..ca44a42fd 100644 --- a/llms/mlx_lm/tuner/utils.py +++ b/llms/mlx_lm/tuner/utils.py @@ -32,6 +32,7 @@ def check_lora_layers(num_model): "stablelm_epoch", "qwen2", "gemma", + "starcoder2", ]: check_lora_layers(len(model.model.layers)) From 973e4a7f98e337d696f9d491a4f740b9b619309c Mon Sep 17 00:00:00 2001 From: muhtasham Date: Thu, 22 Feb 2024 01:16:11 +0100 Subject: [PATCH 02/21] Refactor model arguments and modules in starcoder2.py --- llms/mlx_lm/models/starcoder2.py | 96 ++++++++++++++------------------ 1 file changed, 43 insertions(+), 53 deletions(-) diff --git a/llms/mlx_lm/models/starcoder2.py b/llms/mlx_lm/models/starcoder2.py index d68dc8001..07256ec62 100644 --- a/llms/mlx_lm/models/starcoder2.py +++ b/llms/mlx_lm/models/starcoder2.py @@ -10,17 +10,15 @@ @dataclass class ModelArgs(BaseModelArgs): - model_type: str - hidden_size: int - num_hidden_layers: int - intermediate_size: int - num_attention_heads: int - rms_norm_eps: float + dim: int + n_layers: int + head_dim: int + hidden_dim: int + n_heads: int + n_kv_heads: int + norm_eps: float vocab_size: 
int - num_key_value_heads: int = None rope_theta: float = 10000 - rope_traditional: bool = False - rope_scaling: Optional[Dict[str, Union[float, str]]] = None class RMSNorm(nn.Module): def __init__(self, dims: int, eps: float = 1e-5): @@ -38,26 +36,20 @@ def __call__(self, x): class Attention(nn.Module): def __init__(self, args: ModelArgs): super().__init__() + self.args = args - dim = args.hidden_size - self.n_heads = n_heads = args.num_attention_heads - self.n_kv_heads = n_kv_heads = args.num_key_value_heads - self.head_dim = head_dim = args.head_dim - - self.repeats = n_heads // n_kv_heads + self.n_heads: int = args.n_heads + self.n_kv_heads: int = args.n_kv_heads - self.scale = head_dim**-0.5 + self.repeats = self.n_heads // self.n_kv_heads - self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False) - self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) - self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) - self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False) + self.scale = self.args.head_dim**-0.5 - self.rope = nn.RoPE( - head_dim, - traditional=args.rope_traditional, - base=args.rope_theta, - ) + self.wq = nn.Linear(args.dim, args.n_heads * args.head_dim, bias=False) + self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False) + self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False) + self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False) + self.rope = nn.RoPE(args.head_dim, traditional=True, base=args.rope_theta) def __call__( self, @@ -67,16 +59,15 @@ def __call__( ) -> mx.array: B, L, D = x.shape - queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) + queries, keys, values = self.wq(x), self.wk(x), self.wv(x) # Prepare the queries, keys and values for the attention computation queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) - if self.repeats > 1: - keys = mx.repeat(keys, self.repeats, axis=1) - values = mx.repeat(values, self.repeats, axis=1) + keys = mx.repeat(keys, self.repeats, axis=1) + values = mx.repeat(values, self.repeats, axis=1) if cache is not None: key_cache, value_cache = cache @@ -93,27 +84,28 @@ def __call__( scores += mask scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype) output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) - return self.o_proj(output), (keys, values) + return self.wo(output), (keys, values) -class MLP(nn.Module): - def __init__(self, dim, hidden_dim): +class FeedForward(nn.Module): + def __init__(self, args: ModelArgs): super().__init__() - self.gate_proj = nn.Linear(dim, hidden_dim, bias=False) - self.down_proj = nn.Linear(hidden_dim, dim, bias=False) - self.up_proj = nn.Linear(dim, hidden_dim, bias=False) + + self.w1 = nn.Linear(args.dim, args.hidden_dim, bias=False) + self.w2 = nn.Linear(args.hidden_dim, args.dim, bias=False) + self.w3 = nn.Linear(args.dim, args.hidden_dim, bias=False) def __call__(self, x) -> mx.array: - return self.down_proj(nn.gelu(self.gate_proj(x)) * self.up_proj(x)) + return self.w2(nn.silu(self.w1(x)) * self.w3(x)) class TransformerBlock(nn.Module): def __init__(self, args: ModelArgs): super().__init__() - self.num_attention_heads = args.num_attention_heads - self.hidden_size = args.hidden_size - self.self_attn = Attention(args) - self.mlp = MLP(args.hidden_size, args.intermediate_size) - self.input_layernorm = 
RMSNorm(args.hidden_size, eps=args.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.n_heads = args.n_heads + self.dim = args.dim + self.attention = Attention(args) + self.feed_forward = FeedForward(args=args) + self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) + self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) self.args = args def __call__( @@ -122,9 +114,9 @@ def __call__( mask: Optional[mx.array] = None, cache: Optional[Tuple[mx.array, mx.array]] = None, ) -> mx.array: - r, cache = self.self_attn(self.input_layernorm(x), mask, cache) + r, cache = self.attention(self.attention_norm(x), mask, cache) h = x + r - r = self.mlp(self.post_attention_layernorm(h)) + r = self.feed_forward(self.ffn_norm(h)) out = h + r return out, cache @@ -133,21 +125,19 @@ def __init__(self, args: ModelArgs): super().__init__() self.args = args self.vocab_size = args.vocab_size - self.num_hidden_layers = args.num_hidden_layers + self.n_layers = args.n_layers assert self.vocab_size > 0 - self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) - self.layers = [ - TransformerBlock(args=args) for _ in range(args.num_hidden_layers) - ] - self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim) + self.layers = [TransformerBlock(args=args) for _ in range(args.n_layers)] + self.norm = RMSNorm(args.dim, eps=args.norm_eps) + self.output = nn.Linear(args.dim, args.vocab_size, bias=False) def __call__( self, inputs: mx.array, cache=None, ): - h = self.embed_tokens(inputs) - h = h * (self.args.hidden_size**0.5) + h = self.tok_embeddings(inputs) mask = None if h.shape[1] > 1: @@ -160,7 +150,7 @@ def __call__( for e, layer in enumerate(self.layers): h, cache[e] = layer(h, mask, cache[e]) - return self.norm(h), cache + return self.output(self.norm(h)), cache class Model(nn.Module): def __init__(self, args: ModelArgs): From 6cd9870cde8410b0f0fe3a3127643de245082bdc Mon Sep 17 00:00:00 2001 From: muhtasham Date: Thu, 22 Feb 2024 13:06:25 +0100 Subject: [PATCH 03/21] Refactor FeedForward class to MLP in starcoder2.py --- llms/mlx_lm/models/starcoder2.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llms/mlx_lm/models/starcoder2.py b/llms/mlx_lm/models/starcoder2.py index 07256ec62..fa7be744f 100644 --- a/llms/mlx_lm/models/starcoder2.py +++ b/llms/mlx_lm/models/starcoder2.py @@ -86,7 +86,7 @@ def __call__( output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) return self.wo(output), (keys, values) -class FeedForward(nn.Module): +class MLP(nn.Module): def __init__(self, args: ModelArgs): super().__init__() @@ -103,7 +103,7 @@ def __init__(self, args: ModelArgs): self.n_heads = args.n_heads self.dim = args.dim self.attention = Attention(args) - self.feed_forward = FeedForward(args=args) + self.feed_forward = MLP(args=args) self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) self.args = args From 73ad35f66e25b8d963e0314e3a040e6e2c86da1d Mon Sep 17 00:00:00 2001 From: Muhtasham Oblokulov Date: Wed, 28 Feb 2024 22:07:22 +0100 Subject: [PATCH 04/21] Fix typo --- llms/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llms/README.md b/llms/README.md index a15d00c82..27348e044 100644 --- a/llms/README.md +++ b/llms/README.md @@ -46,7 +46,7 @@ You can convert models in the Python API with: ```python from mlx_lm import convert -upload_repo = "mistralai/Mistral-7B-Instruct-v0.1" 
+upload_repo = "mlx-community/My-Mistral-7B-v0.1-4bit" convert("mistralai/Mistral-7B-v0.1", quantize=True, upload_repo=upload_repo) ``` From 3c0fbd5923f6365953ddddc3229f132bad551f07 Mon Sep 17 00:00:00 2001 From: Muhtasham Oblokulov Date: Wed, 28 Feb 2024 22:13:17 +0100 Subject: [PATCH 05/21] pre-commit --- llms/mlx_lm/models/starcoder2.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/llms/mlx_lm/models/starcoder2.py b/llms/mlx_lm/models/starcoder2.py index fa7be744f..374c69581 100644 --- a/llms/mlx_lm/models/starcoder2.py +++ b/llms/mlx_lm/models/starcoder2.py @@ -19,7 +19,8 @@ class ModelArgs(BaseModelArgs): norm_eps: float vocab_size: int rope_theta: float = 10000 - + + class RMSNorm(nn.Module): def __init__(self, dims: int, eps: float = 1e-5): super().__init__() @@ -33,6 +34,7 @@ def __call__(self, x): output = self._norm(x.astype(mx.float32)).astype(x.dtype) return self.weight * output + class Attention(nn.Module): def __init__(self, args: ModelArgs): super().__init__() @@ -86,6 +88,7 @@ def __call__( output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) return self.wo(output), (keys, values) + class MLP(nn.Module): def __init__(self, args: ModelArgs): super().__init__() @@ -97,6 +100,7 @@ def __init__(self, args: ModelArgs): def __call__(self, x) -> mx.array: return self.w2(nn.silu(self.w1(x)) * self.w3(x)) + class TransformerBlock(nn.Module): def __init__(self, args: ModelArgs): super().__init__() @@ -120,6 +124,7 @@ def __call__( out = h + r return out, cache + class Starcoder2Model(nn.Module): def __init__(self, args: ModelArgs): super().__init__() @@ -152,6 +157,7 @@ def __call__( return self.output(self.norm(h)), cache + class Model(nn.Module): def __init__(self, args: ModelArgs): super().__init__() @@ -169,4 +175,4 @@ def __call__( @property def layers(self): - return self.model.layers \ No newline at end of file + return self.model.layers From 761b6162447323f5630c60dd8cd8b5e08c01c7da Mon Sep 17 00:00:00 2001 From: Muhtasham Oblokulov Date: Thu, 29 Feb 2024 19:52:11 +0100 Subject: [PATCH 06/21] Refactor starcoder2.py: Update model arguments and modules --- llms/mlx_lm/models/starcoder2.py | 77 +++++++++++++++++--------------- 1 file changed, 41 insertions(+), 36 deletions(-) diff --git a/llms/mlx_lm/models/starcoder2.py b/llms/mlx_lm/models/starcoder2.py index 374c69581..d2ae8c85e 100644 --- a/llms/mlx_lm/models/starcoder2.py +++ b/llms/mlx_lm/models/starcoder2.py @@ -10,25 +10,27 @@ @dataclass class ModelArgs(BaseModelArgs): - dim: int - n_layers: int - head_dim: int - hidden_dim: int - n_heads: int - n_kv_heads: int - norm_eps: float - vocab_size: int - rope_theta: float = 10000 + model_type: str + hidden_size: int + num_hidden_layers: int + intermediate_size: int + num_attention_heads: int + num_key_value_heads: int = None + max_position_embeddings: int = 16384 + norm_eps: float = None + rms_norm_eps: float = 1e-5 + vocab_size: int = 49152 + rope_theta: float = 100000 class RMSNorm(nn.Module): - def __init__(self, dims: int, eps: float = 1e-5): + def __init__(self, hidden_sizes: int, eps: float = 1e-5): super().__init__() - self.weight = mx.ones((dims,)) + self.weight = mx.ones((hidden_sizes,)) self.eps = eps def _norm(self, x): - return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps) + return x * mx.rsqrt(x.square().mean(-1, keephidden_sizes=True) + self.eps) def __call__(self, x): output = self._norm(x.astype(mx.float32)).astype(x.dtype) @@ -40,18 +42,20 @@ def __init__(self, args: ModelArgs): super().__init__() 
self.args = args - self.n_heads: int = args.n_heads - self.n_kv_heads: int = args.n_kv_heads + dim = args.hidden_size + self.n_heads = n_heads = args.num_attention_heads + self.n_kv_heads = n_kv_heads = args.num_key_value_heads self.repeats = self.n_heads // self.n_kv_heads - self.scale = self.args.head_dim**-0.5 + head_dim = args.hidden_size // args.num_attention_heads + self.scale = head_dim**-0.5 - self.wq = nn.Linear(args.dim, args.n_heads * args.head_dim, bias=False) - self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False) - self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False) - self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False) - self.rope = nn.RoPE(args.head_dim, traditional=True, base=args.rope_theta) + self.wq = nn.Linear(dim, n_heads * head_dim, bias=False) + self.wk = nn.Linear(dim, n_kv_heads * head_dim, bias=False) + self.wv = nn.Linear(dim, n_kv_heads * head_dim, bias=False) + self.wo = nn.Linear(n_heads * head_dim, dim, bias=False) + self.rope = nn.RoPE(head_dim, traditional=True, base=args.rope_theta) def __call__( self, @@ -90,26 +94,25 @@ def __call__( class MLP(nn.Module): - def __init__(self, args: ModelArgs): + def __init__(self, dim, hidden_dim): super().__init__() - - self.w1 = nn.Linear(args.dim, args.hidden_dim, bias=False) - self.w2 = nn.Linear(args.hidden_dim, args.dim, bias=False) - self.w3 = nn.Linear(args.dim, args.hidden_dim, bias=False) + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) def __call__(self, x) -> mx.array: - return self.w2(nn.silu(self.w1(x)) * self.w3(x)) + return self.w2(nn.gelu(self.w1(x)) * self.w3(x)) class TransformerBlock(nn.Module): def __init__(self, args: ModelArgs): super().__init__() - self.n_heads = args.n_heads - self.dim = args.dim + self.hidden_size = args.hidden_size + self.n_heads = args.num_attention_heads self.attention = Attention(args) - self.feed_forward = MLP(args=args) - self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) - self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) + self.mlp = MLP(args.hidden_size, args.intermediate_size) + self.attention_norm = RMSNorm(args.hidden_size, eps=args.norm_eps) + self.ffn_norm = RMSNorm(args.hidden_size, eps=args.norm_eps) self.args = args def __call__( @@ -130,12 +133,14 @@ def __init__(self, args: ModelArgs): super().__init__() self.args = args self.vocab_size = args.vocab_size - self.n_layers = args.n_layers + self.num_hidden_layers = args.num_hidden_layers assert self.vocab_size > 0 - self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim) - self.layers = [TransformerBlock(args=args) for _ in range(args.n_layers)] - self.norm = RMSNorm(args.dim, eps=args.norm_eps) - self.output = nn.Linear(args.dim, args.vocab_size, bias=False) + self.tok_embeddings = nn.Embedding(args.vocab_size, args.hidden_size) + self.layers = [ + TransformerBlock(args=args) for _ in range(args.num_hidden_layers) + ] + self.norm = RMSNorm(args.hidden_size, eps=args.norm_eps) + self.output = nn.Linear(args.hidden_size, args.vocab_size, bias=False) def __call__( self, From 8aee3b71622495ba2e031ae4959382b95041d7aa Mon Sep 17 00:00:00 2001 From: Muhtasham Oblokulov Date: Fri, 1 Mar 2024 11:04:01 +0100 Subject: [PATCH 07/21] Fix LM head and MLP layers --- llms/mlx_lm/models/starcoder2.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/llms/mlx_lm/models/starcoder2.py 
b/llms/mlx_lm/models/starcoder2.py index d2ae8c85e..ac144813a 100644 --- a/llms/mlx_lm/models/starcoder2.py +++ b/llms/mlx_lm/models/starcoder2.py @@ -96,12 +96,12 @@ def __call__( class MLP(nn.Module): def __init__(self, dim, hidden_dim): super().__init__() - self.w1 = nn.Linear(dim, hidden_dim, bias=False) - self.w2 = nn.Linear(hidden_dim, dim, bias=False) - self.w3 = nn.Linear(dim, hidden_dim, bias=False) + self.c_fc = nn.Linear(dim, hidden_dim, bias=False) + self.c_proj = nn.Linear(hidden_dim, dim, bias=False) + self.act = nn.Linear(dim, hidden_dim, bias=False) def __call__(self, x) -> mx.array: - return self.w2(nn.gelu(self.w1(x)) * self.w3(x)) + return self.c_proj(nn.gelu(self.c_fc(x)) * self.act(x)) class TransformerBlock(nn.Module): @@ -166,8 +166,8 @@ def __call__( class Model(nn.Module): def __init__(self, args: ModelArgs): super().__init__() - self.model_type = args.model_type self.model = Starcoder2Model(args) + self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) def __call__( self, @@ -175,9 +175,4 @@ def __call__( cache=None, ): out, cache = self.model(inputs, cache) - out = out @ self.model.embed_tokens.weight.T - return out, cache - - @property - def layers(self): - return self.model.layers + return self.lm_head(out), cache From a9ba4b3f645c0c83a0ddd375830dfb58078d61b6 Mon Sep 17 00:00:00 2001 From: Muhtasham Oblokulov Date: Fri, 1 Mar 2024 11:20:38 +0100 Subject: [PATCH 08/21] Rename input layer norm --- llms/mlx_lm/models/starcoder2.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llms/mlx_lm/models/starcoder2.py b/llms/mlx_lm/models/starcoder2.py index ac144813a..74750ede9 100644 --- a/llms/mlx_lm/models/starcoder2.py +++ b/llms/mlx_lm/models/starcoder2.py @@ -111,8 +111,8 @@ def __init__(self, args: ModelArgs): self.n_heads = args.num_attention_heads self.attention = Attention(args) self.mlp = MLP(args.hidden_size, args.intermediate_size) - self.attention_norm = RMSNorm(args.hidden_size, eps=args.norm_eps) - self.ffn_norm = RMSNorm(args.hidden_size, eps=args.norm_eps) + self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) self.args = args def __call__( From 1366c03f52cc0ac400fae3b58e8389b3fc3ca434 Mon Sep 17 00:00:00 2001 From: Muhtasham Oblokulov Date: Fri, 1 Mar 2024 11:44:19 +0100 Subject: [PATCH 09/21] Update bias in linear layers --- llms/mlx_lm/models/starcoder2.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/llms/mlx_lm/models/starcoder2.py b/llms/mlx_lm/models/starcoder2.py index 74750ede9..a9e83f55d 100644 --- a/llms/mlx_lm/models/starcoder2.py +++ b/llms/mlx_lm/models/starcoder2.py @@ -51,10 +51,10 @@ def __init__(self, args: ModelArgs): head_dim = args.hidden_size // args.num_attention_heads self.scale = head_dim**-0.5 - self.wq = nn.Linear(dim, n_heads * head_dim, bias=False) - self.wk = nn.Linear(dim, n_kv_heads * head_dim, bias=False) - self.wv = nn.Linear(dim, n_kv_heads * head_dim, bias=False) - self.wo = nn.Linear(n_heads * head_dim, dim, bias=False) + self.wq = nn.Linear(dim, n_heads * head_dim, bias=True) + self.wk = nn.Linear(dim, n_kv_heads * head_dim, bias=True) + self.wv = nn.Linear(dim, n_kv_heads * head_dim, bias=True) + self.wo = nn.Linear(n_heads * head_dim, dim, bias=True) self.rope = nn.RoPE(head_dim, traditional=True, base=args.rope_theta) def __call__( @@ -96,9 +96,9 @@ def __call__( class MLP(nn.Module): def __init__(self, dim, hidden_dim): 
super().__init__() - self.c_fc = nn.Linear(dim, hidden_dim, bias=False) - self.c_proj = nn.Linear(hidden_dim, dim, bias=False) - self.act = nn.Linear(dim, hidden_dim, bias=False) + self.c_fc = nn.Linear(dim, hidden_dim, bias=True) + self.c_proj = nn.Linear(hidden_dim, dim, bias=True) + self.act = nn.Linear(dim, hidden_dim, bias=True) def __call__(self, x) -> mx.array: return self.c_proj(nn.gelu(self.c_fc(x)) * self.act(x)) @@ -140,7 +140,7 @@ def __init__(self, args: ModelArgs): TransformerBlock(args=args) for _ in range(args.num_hidden_layers) ] self.norm = RMSNorm(args.hidden_size, eps=args.norm_eps) - self.output = nn.Linear(args.hidden_size, args.vocab_size, bias=False) + self.output = nn.Linear(args.hidden_size, args.vocab_size, bias=True) def __call__( self, @@ -167,7 +167,7 @@ class Model(nn.Module): def __init__(self, args: ModelArgs): super().__init__() self.model = Starcoder2Model(args) - self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) + self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=True) def __call__( self, From 446e7e9477ee55c27b1075e563656e5f9d45de4f Mon Sep 17 00:00:00 2001 From: Muhtasham Oblokulov Date: Fri, 1 Mar 2024 11:46:55 +0100 Subject: [PATCH 10/21] Refactor token embeddings in Starcoder2Model --- llms/mlx_lm/models/starcoder2.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llms/mlx_lm/models/starcoder2.py b/llms/mlx_lm/models/starcoder2.py index a9e83f55d..02c89b4ef 100644 --- a/llms/mlx_lm/models/starcoder2.py +++ b/llms/mlx_lm/models/starcoder2.py @@ -135,7 +135,7 @@ def __init__(self, args: ModelArgs): self.vocab_size = args.vocab_size self.num_hidden_layers = args.num_hidden_layers assert self.vocab_size > 0 - self.tok_embeddings = nn.Embedding(args.vocab_size, args.hidden_size) + self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) self.layers = [ TransformerBlock(args=args) for _ in range(args.num_hidden_layers) ] @@ -147,7 +147,7 @@ def __call__( inputs: mx.array, cache=None, ): - h = self.tok_embeddings(inputs) + h = self.embed_tokens(inputs) mask = None if h.shape[1] > 1: From f72792c32e9d85c083cf43d835866e7ee5e379ad Mon Sep 17 00:00:00 2001 From: Muhtasham Oblokulov Date: Fri, 1 Mar 2024 16:00:22 +0100 Subject: [PATCH 11/21] Rename to standard HF attention layer name --- llms/mlx_lm/models/starcoder2.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/llms/mlx_lm/models/starcoder2.py b/llms/mlx_lm/models/starcoder2.py index 02c89b4ef..26f3b3e10 100644 --- a/llms/mlx_lm/models/starcoder2.py +++ b/llms/mlx_lm/models/starcoder2.py @@ -51,10 +51,10 @@ def __init__(self, args: ModelArgs): head_dim = args.hidden_size // args.num_attention_heads self.scale = head_dim**-0.5 - self.wq = nn.Linear(dim, n_heads * head_dim, bias=True) - self.wk = nn.Linear(dim, n_kv_heads * head_dim, bias=True) - self.wv = nn.Linear(dim, n_kv_heads * head_dim, bias=True) - self.wo = nn.Linear(n_heads * head_dim, dim, bias=True) + self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=True) + self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True) + self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True) + self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=True) self.rope = nn.RoPE(head_dim, traditional=True, base=args.rope_theta) def __call__( @@ -65,15 +65,19 @@ def __call__( ) -> mx.array: B, L, D = x.shape - queries, keys, values = self.wq(x), self.wk(x), self.wv(x) + queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) # 
Prepare the queries, keys and values for the attention computation queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) - keys = mx.repeat(keys, self.repeats, axis=1) - values = mx.repeat(values, self.repeats, axis=1) + def repeat(a): + a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2) + return a.reshape([B, self.n_heads, L, -1]) + + if self.repeats > 1: + keys, values = map(repeat, (keys, values)) if cache is not None: key_cache, value_cache = cache @@ -90,7 +94,7 @@ def __call__( scores += mask scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype) output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) - return self.wo(output), (keys, values) + return self.o_proj(output), (keys, values) class MLP(nn.Module): From a8ce255cd2311d7c97910ce3875cc0e1e20da9c4 Mon Sep 17 00:00:00 2001 From: Muhtasham Oblokulov Date: Fri, 1 Mar 2024 16:29:51 +0100 Subject: [PATCH 12/21] Add LayerNorm --- llms/mlx_lm/models/starcoder2.py | 35 +++++++++++++++----------------- 1 file changed, 16 insertions(+), 19 deletions(-) diff --git a/llms/mlx_lm/models/starcoder2.py b/llms/mlx_lm/models/starcoder2.py index 26f3b3e10..5c7fffd4d 100644 --- a/llms/mlx_lm/models/starcoder2.py +++ b/llms/mlx_lm/models/starcoder2.py @@ -6,6 +6,7 @@ import mlx.nn as nn from .base import BaseModelArgs +from .layers import LayerNorm @dataclass @@ -19,22 +20,16 @@ class ModelArgs(BaseModelArgs): max_position_embeddings: int = 16384 norm_eps: float = None rms_norm_eps: float = 1e-5 + norm_type: str = "layer_norm" vocab_size: int = 49152 rope_theta: float = 100000 + def __post_init__(self): + if self.num_key_value_heads is None: + self.num_key_value_heads = self.num_attention_heads -class RMSNorm(nn.Module): - def __init__(self, hidden_sizes: int, eps: float = 1e-5): - super().__init__() - self.weight = mx.ones((hidden_sizes,)) - self.eps = eps - - def _norm(self, x): - return x * mx.rsqrt(x.square().mean(-1, keephidden_sizes=True) + self.eps) - - def __call__(self, x): - output = self._norm(x.astype(mx.float32)).astype(x.dtype) - return self.weight * output + if self.norm_eps is None: + self.norm_eps = self.rms_norm_eps class Attention(nn.Module): @@ -113,11 +108,13 @@ def __init__(self, args: ModelArgs): super().__init__() self.hidden_size = args.hidden_size self.n_heads = args.num_attention_heads - self.attention = Attention(args) + + self.self_attn = Attention(args) self.mlp = MLP(args.hidden_size, args.intermediate_size) - self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) - self.args = args + self.input_layer_norm = LayerNorm(args.hidden_size, eps=args.rms_norm_eps) + self.post_attention_layer_norm = LayerNorm( + args.hidden_size, eps=args.rms_norm_eps + ) def __call__( self, @@ -125,9 +122,9 @@ def __call__( mask: Optional[mx.array] = None, cache: Optional[Tuple[mx.array, mx.array]] = None, ) -> mx.array: - r, cache = self.attention(self.attention_norm(x), mask, cache) + r, cache = self.self_attn(self.input_layer_norm(x), mask, cache) h = x + r - r = self.feed_forward(self.ffn_norm(h)) + r = self.mlp(self.post_attention_layer_norm(h)) out = h + r return out, cache @@ -143,7 +140,7 @@ def __init__(self, args: ModelArgs): self.layers = [ TransformerBlock(args=args) for _ in range(args.num_hidden_layers) ] - self.norm = 
RMSNorm(args.hidden_size, eps=args.norm_eps) + self.norm = LayerNorm(args.hidden_size, eps=args.rms_norm_eps) self.output = nn.Linear(args.hidden_size, args.vocab_size, bias=True) def __call__( From 4c9aea69d46d63608cbf306cf975478061edf1b5 Mon Sep 17 00:00:00 2001 From: Muhtasham Oblokulov Date: Fri, 1 Mar 2024 16:34:24 +0100 Subject: [PATCH 13/21] Add transposed token embeddings (like in Gemma) --- llms/mlx_lm/models/starcoder2.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/llms/mlx_lm/models/starcoder2.py b/llms/mlx_lm/models/starcoder2.py index 5c7fffd4d..03b0df4f2 100644 --- a/llms/mlx_lm/models/starcoder2.py +++ b/llms/mlx_lm/models/starcoder2.py @@ -168,7 +168,6 @@ class Model(nn.Module): def __init__(self, args: ModelArgs): super().__init__() self.model = Starcoder2Model(args) - self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=True) def __call__( self, @@ -176,4 +175,9 @@ def __call__( cache=None, ): out, cache = self.model(inputs, cache) - return self.lm_head(out), cache + out = out @ self.model.embed_tokens.weight.T + return out, cache + + @property + def layers(self): + return self.model.layers From c95440624b6ec60a101ad8befdfcd71e19bc90e0 Mon Sep 17 00:00:00 2001 From: Muhtasham Oblokulov Date: Fri, 1 Mar 2024 17:09:43 +0100 Subject: [PATCH 14/21] Refactor MLP and TransformerBlock classes --- llms/mlx_lm/models/starcoder2.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/llms/mlx_lm/models/starcoder2.py b/llms/mlx_lm/models/starcoder2.py index 03b0df4f2..c58c180af 100644 --- a/llms/mlx_lm/models/starcoder2.py +++ b/llms/mlx_lm/models/starcoder2.py @@ -97,10 +97,9 @@ def __init__(self, dim, hidden_dim): super().__init__() self.c_fc = nn.Linear(dim, hidden_dim, bias=True) self.c_proj = nn.Linear(hidden_dim, dim, bias=True) - self.act = nn.Linear(dim, hidden_dim, bias=True) - def __call__(self, x) -> mx.array: - return self.c_proj(nn.gelu(self.c_fc(x)) * self.act(x)) + def __call__(self, x): + return self.c_proj(nn.gelu(self.c_fc(x))) class TransformerBlock(nn.Module): @@ -111,10 +110,11 @@ def __init__(self, args: ModelArgs): self.self_attn = Attention(args) self.mlp = MLP(args.hidden_size, args.intermediate_size) - self.input_layer_norm = LayerNorm(args.hidden_size, eps=args.rms_norm_eps) - self.post_attention_layer_norm = LayerNorm( + self.input_layernorm = LayerNorm(args.hidden_size, eps=args.rms_norm_eps) + self.post_attention_layernorm = LayerNorm( args.hidden_size, eps=args.rms_norm_eps ) + self.args = args def __call__( self, @@ -122,9 +122,9 @@ def __call__( mask: Optional[mx.array] = None, cache: Optional[Tuple[mx.array, mx.array]] = None, ) -> mx.array: - r, cache = self.self_attn(self.input_layer_norm(x), mask, cache) + r, cache = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r - r = self.mlp(self.post_attention_layer_norm(h)) + r = self.mlp(self.post_attention_layernorm(h)) out = h + r return out, cache @@ -141,7 +141,6 @@ def __init__(self, args: ModelArgs): TransformerBlock(args=args) for _ in range(args.num_hidden_layers) ] self.norm = LayerNorm(args.hidden_size, eps=args.rms_norm_eps) - self.output = nn.Linear(args.hidden_size, args.vocab_size, bias=True) def __call__( self, @@ -161,7 +160,7 @@ def __call__( for e, layer in enumerate(self.layers): h, cache[e] = layer(h, mask, cache[e]) - return self.output(self.norm(h)), cache + return self.norm(h), cache class Model(nn.Module): From 512a542b3b1e77c1f876a52de51d0e285fc45b46 Mon Sep 17 00:00:00 2001 From: Muhtasham 
Oblokulov Date: Fri, 1 Mar 2024 18:34:51 +0100 Subject: [PATCH 15/21] Add tie_word_embeddings option to ModelArgs and update Model implementation --- llms/mlx_lm/models/starcoder2.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/llms/mlx_lm/models/starcoder2.py b/llms/mlx_lm/models/starcoder2.py index c58c180af..f82139a4d 100644 --- a/llms/mlx_lm/models/starcoder2.py +++ b/llms/mlx_lm/models/starcoder2.py @@ -23,6 +23,7 @@ class ModelArgs(BaseModelArgs): norm_type: str = "layer_norm" vocab_size: int = 49152 rope_theta: float = 100000 + tie_word_embeddings: bool = True def __post_init__(self): if self.num_key_value_heads is None: @@ -167,6 +168,7 @@ class Model(nn.Module): def __init__(self, args: ModelArgs): super().__init__() self.model = Starcoder2Model(args) + self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=True) def __call__( self, @@ -174,8 +176,11 @@ def __call__( cache=None, ): out, cache = self.model(inputs, cache) - out = out @ self.model.embed_tokens.weight.T - return out, cache + if not self.model.args.tie_word_embeddings: + out = out @ self.model.embed_tokens.weight.T + return out, cache + else: + return self.lm_head(out), cache @property def layers(self): From 3a81505c08759916aa17b279ab2f9871e8a4af75 Mon Sep 17 00:00:00 2001 From: Muhtasham Oblokulov Date: Fri, 1 Mar 2024 18:48:19 +0100 Subject: [PATCH 16/21] Add conditional check for tying word embeddings in Starcoder2Model --- llms/mlx_lm/models/starcoder2.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/llms/mlx_lm/models/starcoder2.py b/llms/mlx_lm/models/starcoder2.py index f82139a4d..4ae666d4a 100644 --- a/llms/mlx_lm/models/starcoder2.py +++ b/llms/mlx_lm/models/starcoder2.py @@ -168,7 +168,9 @@ class Model(nn.Module): def __init__(self, args: ModelArgs): super().__init__() self.model = Starcoder2Model(args) - self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=True) + # This is for 15B starcoder2 since it doesn't tie word embeddings + if not args.tie_word_embeddings: + self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=True) def __call__( self, From ab562b1d0d7d6c55ff854b7eb62aa1f6a23bf98b Mon Sep 17 00:00:00 2001 From: Muhtasham Oblokulov Date: Fri, 1 Mar 2024 21:42:42 +0100 Subject: [PATCH 17/21] Fix bias in lm_head linear layer --- llms/mlx_lm/models/starcoder2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llms/mlx_lm/models/starcoder2.py b/llms/mlx_lm/models/starcoder2.py index 4ae666d4a..471e269b6 100644 --- a/llms/mlx_lm/models/starcoder2.py +++ b/llms/mlx_lm/models/starcoder2.py @@ -170,7 +170,7 @@ def __init__(self, args: ModelArgs): self.model = Starcoder2Model(args) # This is for 15B starcoder2 since it doesn't tie word embeddings if not args.tie_word_embeddings: - self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=True) + self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) def __call__( self, From f268ab0744ae5ac3508de223e0eadb97db9a936b Mon Sep 17 00:00:00 2001 From: Muhtasham Oblokulov Date: Fri, 1 Mar 2024 22:05:19 +0100 Subject: [PATCH 18/21] Remove unused LayerNorm in stablelm --- llms/mlx_lm/models/stablelm.py | 1 - 1 file changed, 1 deletion(-) diff --git a/llms/mlx_lm/models/stablelm.py b/llms/mlx_lm/models/stablelm.py index 0f2c4f039..5fbca3aef 100644 --- a/llms/mlx_lm/models/stablelm.py +++ b/llms/mlx_lm/models/stablelm.py @@ -128,7 +128,6 @@ class DecoderLayer(nn.Module): def __init__(self, config: ModelArgs): super().__init__() self.self_attn = 
Attention(config=config) - self.input_layernorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.mlp = MLP(config.hidden_size, config.intermediate_size) self.input_layernorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.post_attention_layernorm = LayerNorm( From fe6c52f0f7cb23e0a8e6d7333787bb588d915d5a Mon Sep 17 00:00:00 2001 From: Muhtasham Oblokulov Date: Fri, 1 Mar 2024 22:05:54 +0100 Subject: [PATCH 19/21] Update transformers dependency to use GitHub repository --- llms/mlx_lm/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llms/mlx_lm/requirements.txt b/llms/mlx_lm/requirements.txt index 049049e76..09e73b444 100644 --- a/llms/mlx_lm/requirements.txt +++ b/llms/mlx_lm/requirements.txt @@ -1,5 +1,5 @@ mlx>=0.4 numpy -transformers>=4.38.0 +git+https://github.com/huggingface/transformers.git protobuf pyyaml From b21a6bfd7418853e122eb85dea4e3b7f6940bd3b Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 1 Mar 2024 13:31:30 -0800 Subject: [PATCH 20/21] fix lm head bug, revert transformer req --- llms/mlx_lm/models/starcoder2.py | 4 ++-- llms/mlx_lm/requirements.txt | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/llms/mlx_lm/models/starcoder2.py b/llms/mlx_lm/models/starcoder2.py index 471e269b6..fe664c825 100644 --- a/llms/mlx_lm/models/starcoder2.py +++ b/llms/mlx_lm/models/starcoder2.py @@ -179,10 +179,10 @@ def __call__( ): out, cache = self.model(inputs, cache) if not self.model.args.tie_word_embeddings: + return self.lm_head(out), cache + else: out = out @ self.model.embed_tokens.weight.T return out, cache - else: - return self.lm_head(out), cache @property def layers(self): diff --git a/llms/mlx_lm/requirements.txt b/llms/mlx_lm/requirements.txt index 09e73b444..049049e76 100644 --- a/llms/mlx_lm/requirements.txt +++ b/llms/mlx_lm/requirements.txt @@ -1,5 +1,5 @@ mlx>=0.4 numpy -git+https://github.com/huggingface/transformers.git +transformers>=4.38.0 protobuf pyyaml From fc31d7a2a225134f81653a4fed12d8a008c76727 Mon Sep 17 00:00:00 2001 From: Muhtasham Oblokulov Date: Sun, 3 Mar 2024 02:45:14 +0100 Subject: [PATCH 21/21] Update RoPE initialization in Attention class --- llms/mlx_lm/models/starcoder2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llms/mlx_lm/models/starcoder2.py b/llms/mlx_lm/models/starcoder2.py index fe664c825..aeebfc96e 100644 --- a/llms/mlx_lm/models/starcoder2.py +++ b/llms/mlx_lm/models/starcoder2.py @@ -51,7 +51,7 @@ def __init__(self, args: ModelArgs): self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True) self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True) self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=True) - self.rope = nn.RoPE(head_dim, traditional=True, base=args.rope_theta) + self.rope = nn.RoPE(head_dim, traditional=False, base=args.rope_theta) def __call__( self,
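
---

Usage note (not part of the patch series): once the series above is applied, the new `starcoder2` model type can be exercised through the existing `mlx_lm` Python API, the same way the README snippet touched in PATCH 04 uses `convert`. The sketch below is illustrative only — the checkpoint name is an assumption (substitute whichever Starcoder2 repo you actually want to load), and it presumes the `load`/`generate` helpers behave as documented in the package README at the time of this PR.

```python
# Minimal sketch: run a completion with the newly supported starcoder2
# model type via mlx_lm. The Hub repo id below is an assumption --
# replace it with the checkpoint you actually intend to use.
from mlx_lm import load, generate

model, tokenizer = load("bigcode/starcoder2-3b")  # assumed repo id

prompt = "def fibonacci(n):"
text = generate(model, tokenizer, prompt=prompt, max_tokens=128, verbose=True)
print(text)
```

Since `tie_word_embeddings` defaults to `True` in `ModelArgs`, configs like the smaller Starcoder2 variants project logits through the transposed token embeddings, while an untied config (the 15B case called out in PATCH 16) routes through the separate `lm_head` that the later patches add and then fix in PATCH 20.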