From d396df9fc2db8d5c13df6a992151cc7a407ef6c6 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Fri, 1 Mar 2024 22:00:26 +0100 Subject: [PATCH 01/11] add starcoder 2 --- llms/mlx_lm/models/starcoder2.py | 209 +++++++++++++++++++++++++++++++ 1 file changed, 209 insertions(+) create mode 100644 llms/mlx_lm/models/starcoder2.py diff --git a/llms/mlx_lm/models/starcoder2.py b/llms/mlx_lm/models/starcoder2.py new file mode 100644 index 000000000..81ca434f2 --- /dev/null +++ b/llms/mlx_lm/models/starcoder2.py @@ -0,0 +1,209 @@ +from dataclasses import dataclass +from functools import partial +from typing import Dict, Optional, Tuple, Union + +import logger + +import mlx.core as mx +import mlx.nn as nn + +from .base import BaseModelArgs + +@dataclass +class ModelArgs(BaseModelArgs): + model_type: str + hidden_size: int + num_hidden_layers: int + intermediate_size: int = 12288 + num_attention_heads: int + head_dim: int + rms_norm_eps: float + vocab_size: int = 49152 + num_key_value_heads: int = None + max_position_embeddings: int = 16384 + sliding_window: int = 4096 + norm_epsilon: float = 1e-5 + rope_theta: float = 10000 + rope_traditional: bool = False + attention_dropout: int = 0.0 + residual_dropout: int = 0.0 + embedding_dropout: int = 0.0 + use_bias: bool = True + +@partial(mx.compile, shapeless=True) +def rms_norm(x, weight, eps): + x = x.astype(mx.float32) + x = x * mx.rsqrt(x.square().mean(-1, keepdims=True) + eps) + return (1.0 + weight) * x.astype(weight.dtype) + + +class RMSNorm(nn.Module): + def __init__(self, dims: int, eps: float = 1e-5): + super().__init__() + self.weight = mx.ones((dims, )) + self.eps = eps + + def __call__(self, x): + return rms_norm(x, self.weight, self.epx) + +class Starcoder2Attention(nn.Module): + """ + Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer + and "Generating Long Sequences with Sparse Transformers". + """ + + def __init__(self, args: ModelArgs, layer_idx: Optional[int] = None): + super().__init__() + self.config = args + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." 
+ ) + + dim = args.hidden_size + self.n_heads = n_heads = args.num_attention_heads + self.n_kv_heads = n_kv_heads = args.num_key_value_heads + self.head_dim = head_dim = args.hidden_size // self.n_heads + + self.repeats = self.n_heads // n_kv_heads + + self.scale = self.head_dim**-0.5 + + self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=True) + self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True) + self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True) + self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=True) + + self.rope = nn.RoPE( + head_dim, + traditional=args.rope_traditional, + base=args.rope_theta, + ) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Tuple[mx.array, mx.array]] = None, + ) -> mx.array: + B, L, D = x.shape + + queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) + + # Prepare the queries, keys and values for the attention computation + queries = queries.reshape(B, L, self.n_heads, -1).transponse(0, 2, 1, 3) + keys = keys.reshape(B, L, self.n_kv_heads, -1).transponse(0, 2, 1, 3) + values = values.reshape(B, L, self.n_kv_heads, -1).transponse(0, 2, 1, 3) + + if self.repeats > 1: + keys = mx.repeat(keys, self.repeats, axis=1) + values = mx.repeat(values, self.repeats, axis=1) + + if cache is not None: + key_cache, value_cache = cache + queries = self.rope(queries, offset=key_cache.shape[2]) + keys = self.rope(keys, offset=key_cache.shape[2]) + keys = mx.concatenate([key_cache, keys], axis=2) + values = mx.concatenate([value_cache, values], axis=2) + else: + queries = self.rope(queries) + keys = self.rope(keys) + + scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2) + if mask is not None: + scores += mask + scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype) + output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) + return self.o_proj(output), (keys, values) + + +class Starcode2MLP(nn.Module): + def __init__(self, dim, hidden_dim): + super().__init__() + self.c_fc = nn.Linear(dim, hidden_dim, bias=True) + self.c_proj = nn.Linear(hidden_dim, dim, bias=True) + + def __call_(self, x) -> mx.array: + return self.c_proj(nn.gelu(self.c_fc(x))) + + +class TransformerBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.num_attention_heads = args.num_attention_heads + self.hidden_size = args.hidden_size + self.self_attn = Starcoder2Attention(args) + self.mlp = Starcode2MLP(args.hidden_size, args.intermediate_size) + self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.args = args + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Tuple[mx.array, mx.array]] = None + ) -> mx.array: + r, cache = self.self_attn(self.input_layernorm(x), mask, cache) + h = x + r + r = self.mlp(self.post_attention_layernorm(h)) + out = h + r + return out, cache + + +class Starcoder2Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.vocab_size = args.vocab_size + self.num_hidden_layers = args.num_hidden_layers + assert self.vocab_size > 0 + self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) + self.layers = [ + TransformerBlock(args=args) for _ in range(args.num_hidden_layers) + ] + self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + + def __call__( + self, + inputs: mx.array, + cache=None, + ): + h = 
self.embed_tokens(inputs) + h = h * (self.args.hidden_size**0.5) + + mask = None + if h.shape[1] > 1: + mask = nn.MultiHeadAttention.creative_additive_causal_mask(h.shape[1]) + mask = mask.astype(h.dtype) + + if cache is None: + cache = [None] * len(self.layers) + + for e, layer in enumerate(self.layers): + h, cache[e] = layer(h, mask, cache[e]) + + return self.norm(h), cache + + +class Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.model_type = args.model_type + self.model = Starcoder2Model(args) + self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) + + def __call__( + self, + inputs: mx.array, + cache=None, + ): + out, cache = self.model(inputs, cache) + return self.lm_head(out), cache + + @property + def layers(self): + return self.model.layers From 1979c4538193fc5e86dd59a54405e5b7c20ff96d Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sat, 2 Mar 2024 17:52:19 +0100 Subject: [PATCH 02/11] add tie_word_embeddings --- llms/mlx_lm/models/starcoder2.py | 58 +++++++++++--------------------- 1 file changed, 19 insertions(+), 39 deletions(-) diff --git a/llms/mlx_lm/models/starcoder2.py b/llms/mlx_lm/models/starcoder2.py index 81ca434f2..980850434 100644 --- a/llms/mlx_lm/models/starcoder2.py +++ b/llms/mlx_lm/models/starcoder2.py @@ -2,8 +2,6 @@ from functools import partial from typing import Dict, Optional, Tuple, Union -import logger - import mlx.core as mx import mlx.nn as nn @@ -14,14 +12,12 @@ class ModelArgs(BaseModelArgs): model_type: str hidden_size: int num_hidden_layers: int - intermediate_size: int = 12288 + sliding_window: int num_attention_heads: int - head_dim: int - rms_norm_eps: float - vocab_size: int = 49152 num_key_value_heads: int = None + vocab_size: int = 49152 + intermediate_size: int = 12288 max_position_embeddings: int = 16384 - sliding_window: int = 4096 norm_epsilon: float = 1e-5 rope_theta: float = 10000 rope_traditional: bool = False @@ -29,22 +25,7 @@ class ModelArgs(BaseModelArgs): residual_dropout: int = 0.0 embedding_dropout: int = 0.0 use_bias: bool = True - -@partial(mx.compile, shapeless=True) -def rms_norm(x, weight, eps): - x = x.astype(mx.float32) - x = x * mx.rsqrt(x.square().mean(-1, keepdims=True) + eps) - return (1.0 + weight) * x.astype(weight.dtype) - - -class RMSNorm(nn.Module): - def __init__(self, dims: int, eps: float = 1e-5): - super().__init__() - self.weight = mx.ones((dims, )) - self.eps = eps - - def __call__(self, x): - return rms_norm(x, self.weight, self.epx) + tie_word_embeddings: bool = True class Starcoder2Attention(nn.Module): """ @@ -52,17 +33,9 @@ class Starcoder2Attention(nn.Module): and "Generating Long Sequences with Sparse Transformers". """ - def __init__(self, args: ModelArgs, layer_idx: Optional[int] = None): + def __init__(self, args: ModelArgs): super().__init__() self.config = args - self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " - "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." 
- ) - dim = args.hidden_size self.n_heads = n_heads = args.num_attention_heads self.n_kv_heads = n_kv_heads = args.num_key_value_heads @@ -120,7 +93,7 @@ def __call__( return self.o_proj(output), (keys, values) -class Starcode2MLP(nn.Module): +class Starcoder2MLP(nn.Module): def __init__(self, dim, hidden_dim): super().__init__() self.c_fc = nn.Linear(dim, hidden_dim, bias=True) @@ -136,9 +109,9 @@ def __init__(self, args: ModelArgs): self.num_attention_heads = args.num_attention_heads self.hidden_size = args.hidden_size self.self_attn = Starcoder2Attention(args) - self.mlp = Starcode2MLP(args.hidden_size, args.intermediate_size) - self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.mlp = Starcoder2MLP(args.hidden_size, args.intermediate_size) + self.input_layernorm = nn.LayerNorm(args.hidden_size, eps=args.norm_epsilon) + self.post_attention_layernorm = nn.LayerNorm(args.hidden_size, eps=args.norm_epsilon) self.args = args def __call__( @@ -165,7 +138,7 @@ def __init__(self, args: ModelArgs): self.layers = [ TransformerBlock(args=args) for _ in range(args.num_hidden_layers) ] - self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.norm = nn.LayerNorm(args.hidden_size, eps=args.norm_epsilon) def __call__( self, @@ -194,7 +167,9 @@ def __init__(self, args: ModelArgs): super().__init__() self.model_type = args.model_type self.model = Starcoder2Model(args) - self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) + self.tie_word_embeddings = args.tie_word_embeddings + if not args.tie_word_embeddings: + self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) def __call__( self, @@ -202,7 +177,12 @@ def __call__( cache=None, ): out, cache = self.model(inputs, cache) - return self.lm_head(out), cache + + if not self.tie_word_embeddings: + return self.lm_head(out), cache + else: + out = out @ self.model.embed_tokens.weight.T + return out, cache @property def layers(self): From b099f95c9b912c3189680f94f6f0697afdceaac2 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sat, 2 Mar 2024 18:14:09 +0100 Subject: [PATCH 03/11] add weight sharing comments --- llms/mlx_lm/models/starcoder2.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/llms/mlx_lm/models/starcoder2.py b/llms/mlx_lm/models/starcoder2.py index 980850434..df2e7c7f2 100644 --- a/llms/mlx_lm/models/starcoder2.py +++ b/llms/mlx_lm/models/starcoder2.py @@ -168,6 +168,9 @@ def __init__(self, args: ModelArgs): self.model_type = args.model_type self.model = Starcoder2Model(args) self.tie_word_embeddings = args.tie_word_embeddings + + # If tie_word_embeddings is False, tie (share) the embedding weights with lm_head + # Source: https://github.com/huggingface/transformers/blob/main/src/transformers/models/starcoder2/modeling_starcoder2.py#L1071 if not args.tie_word_embeddings: self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) @@ -179,8 +182,10 @@ def __call__( out, cache = self.model(inputs, cache) if not self.tie_word_embeddings: + # If tie_word_embeddings is False, apply linear transformation to obtain predictions return self.lm_head(out), cache else: + # If tie_word_embeddings is True, perform matrix multiplication for predictions out = out @ self.model.embed_tokens.weight.T return out, cache From 079ac4e7f51ae809d3c1daf1ce083e24a98ef74d Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sat, 2 Mar 2024 18:15:12 +0100 Subject: [PATCH 04/11] format with black --- 
llms/mlx_lm/models/starcoder2.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/llms/mlx_lm/models/starcoder2.py b/llms/mlx_lm/models/starcoder2.py index df2e7c7f2..2c50ee626 100644 --- a/llms/mlx_lm/models/starcoder2.py +++ b/llms/mlx_lm/models/starcoder2.py @@ -7,6 +7,7 @@ from .base import BaseModelArgs + @dataclass class ModelArgs(BaseModelArgs): model_type: str @@ -27,6 +28,7 @@ class ModelArgs(BaseModelArgs): use_bias: bool = True tie_word_embeddings: bool = True + class Starcoder2Attention(nn.Module): """ Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer @@ -111,17 +113,19 @@ def __init__(self, args: ModelArgs): self.self_attn = Starcoder2Attention(args) self.mlp = Starcoder2MLP(args.hidden_size, args.intermediate_size) self.input_layernorm = nn.LayerNorm(args.hidden_size, eps=args.norm_epsilon) - self.post_attention_layernorm = nn.LayerNorm(args.hidden_size, eps=args.norm_epsilon) + self.post_attention_layernorm = nn.LayerNorm( + args.hidden_size, eps=args.norm_epsilon + ) self.args = args def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None + cache: Optional[Tuple[mx.array, mx.array]] = None, ) -> mx.array: r, cache = self.self_attn(self.input_layernorm(x), mask, cache) - h = x + r + h = x + r r = self.mlp(self.post_attention_layernorm(h)) out = h + r return out, cache From 3b31f096ac73e7f2ecfac650dd4e4eff4de16236 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sat, 2 Mar 2024 18:15:40 +0100 Subject: [PATCH 05/11] add starcoder2 to mlx-lm tuner --- llms/mlx_lm/tuner/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py index 579fca59f..ca44a42fd 100644 --- a/llms/mlx_lm/tuner/utils.py +++ b/llms/mlx_lm/tuner/utils.py @@ -32,6 +32,7 @@ def check_lora_layers(num_model): "stablelm_epoch", "qwen2", "gemma", + "starcoder2", ]: check_lora_layers(len(model.model.layers)) From 5f5bb850fc3868b203c6b76279569c7c9d093532 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sat, 2 Mar 2024 18:16:04 +0100 Subject: [PATCH 06/11] Add to readme list of supported models --- llms/README.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/llms/README.md b/llms/README.md index a15d00c82..0d951857a 100644 --- a/llms/README.md +++ b/llms/README.md @@ -127,17 +127,18 @@ Most [Phi-2](https://huggingface.co/models?library=transformers,safetensors&other=phi&sort=trending), and [Mixtral](https://huggingface.co/models?library=transformers,safetensors&other=mixtral&sort=trending) +[Starcoder2](https://huggingface.co/models?library=transformers,safetensors&other=starcoder2&sort=trending) style models should work out of the box. For some models (such as `Qwen` and `plamo`) the tokenizer requires you to enable the `trust_remote_code` option. You can do this by passing `--trust-remote-code` in the command line. If you don't specify the flag explicitly, you will be prompted to trust remote code in the terminal when -running the model. +running the model. For `Qwen` models you must also specify the `eos_token`. You can do this by passing `--eos-token "<|endoftext|>"` in the command -line. +line. These options can also be set in the Python API. 
For example: From 34a92453868d0922c6efeb2ad3eef4cc5c07c69c Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sat, 2 Mar 2024 18:19:18 +0100 Subject: [PATCH 07/11] fix typo --- llms/mlx_lm/models/starcoder2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llms/mlx_lm/models/starcoder2.py b/llms/mlx_lm/models/starcoder2.py index 2c50ee626..7773a0f25 100644 --- a/llms/mlx_lm/models/starcoder2.py +++ b/llms/mlx_lm/models/starcoder2.py @@ -154,7 +154,7 @@ def __call__( mask = None if h.shape[1] > 1: - mask = nn.MultiHeadAttention.creative_additive_causal_mask(h.shape[1]) + mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1]) mask = mask.astype(h.dtype) if cache is None: From 2fe2d337c2328e0e12b0e3c7b8a3f00f589e529c Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sat, 2 Mar 2024 18:20:13 +0100 Subject: [PATCH 08/11] fix transpose --- llms/mlx_lm/models/starcoder2.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/llms/mlx_lm/models/starcoder2.py b/llms/mlx_lm/models/starcoder2.py index 7773a0f25..5d142592a 100644 --- a/llms/mlx_lm/models/starcoder2.py +++ b/llms/mlx_lm/models/starcoder2.py @@ -69,9 +69,9 @@ def __call__( queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) # Prepare the queries, keys and values for the attention computation - queries = queries.reshape(B, L, self.n_heads, -1).transponse(0, 2, 1, 3) - keys = keys.reshape(B, L, self.n_kv_heads, -1).transponse(0, 2, 1, 3) - values = values.reshape(B, L, self.n_kv_heads, -1).transponse(0, 2, 1, 3) + queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) + keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) if self.repeats > 1: keys = mx.repeat(keys, self.repeats, axis=1) From 0a630c88d3952cd0481c1217ce89033308935eb7 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sat, 2 Mar 2024 18:21:15 +0100 Subject: [PATCH 09/11] fix call method --- llms/mlx_lm/models/starcoder2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llms/mlx_lm/models/starcoder2.py b/llms/mlx_lm/models/starcoder2.py index 5d142592a..45932c2ae 100644 --- a/llms/mlx_lm/models/starcoder2.py +++ b/llms/mlx_lm/models/starcoder2.py @@ -101,7 +101,7 @@ def __init__(self, dim, hidden_dim): self.c_fc = nn.Linear(dim, hidden_dim, bias=True) self.c_proj = nn.Linear(hidden_dim, dim, bias=True) - def __call_(self, x) -> mx.array: + def __call__(self, x) -> mx.array: return self.c_proj(nn.gelu(self.c_fc(x))) From 86a7cf763f3985cd94004cd099e7e3c41ad13a6a Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sun, 3 Mar 2024 00:25:25 +0100 Subject: [PATCH 10/11] fix gibberish output and formatting --- llms/mlx_lm/models/starcoder2.py | 24 ++++++++++-------------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/llms/mlx_lm/models/starcoder2.py b/llms/mlx_lm/models/starcoder2.py index 45932c2ae..f9b1a6864 100644 --- a/llms/mlx_lm/models/starcoder2.py +++ b/llms/mlx_lm/models/starcoder2.py @@ -1,12 +1,11 @@ from dataclasses import dataclass -from functools import partial -from typing import Dict, Optional, Tuple, Union +from typing import Optional, Tuple import mlx.core as mx import mlx.nn as nn from .base import BaseModelArgs - +from .layers import LayerNorm @dataclass class ModelArgs(BaseModelArgs): @@ -19,7 +18,7 @@ class ModelArgs(BaseModelArgs): vocab_size: int = 49152 intermediate_size: int = 12288 max_position_embeddings: int = 16384 - norm_epsilon: float = 
1e-5 + norm_epsilon: float = 1e-05 rope_theta: float = 10000 rope_traditional: bool = False attention_dropout: int = 0.0 @@ -41,11 +40,11 @@ def __init__(self, args: ModelArgs): dim = args.hidden_size self.n_heads = n_heads = args.num_attention_heads self.n_kv_heads = n_kv_heads = args.num_key_value_heads - self.head_dim = head_dim = args.hidden_size // self.n_heads + self.head_dim = head_dim = args.hidden_size // n_heads - self.repeats = self.n_heads // n_kv_heads + self.repeats = n_heads // n_kv_heads - self.scale = self.head_dim**-0.5 + self.scale = head_dim**-0.5 self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=True) self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True) @@ -90,7 +89,7 @@ def __call__( scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2) if mask is not None: scores += mask - scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype) + scores = mx.softmax(scores, axis=-1).astype(values.dtype) output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) return self.o_proj(output), (keys, values) @@ -112,8 +111,8 @@ def __init__(self, args: ModelArgs): self.hidden_size = args.hidden_size self.self_attn = Starcoder2Attention(args) self.mlp = Starcoder2MLP(args.hidden_size, args.intermediate_size) - self.input_layernorm = nn.LayerNorm(args.hidden_size, eps=args.norm_epsilon) - self.post_attention_layernorm = nn.LayerNorm( + self.input_layernorm = LayerNorm(args.hidden_size, eps=args.norm_epsilon) + self.post_attention_layernorm = LayerNorm( args.hidden_size, eps=args.norm_epsilon ) self.args = args @@ -142,7 +141,7 @@ def __init__(self, args: ModelArgs): self.layers = [ TransformerBlock(args=args) for _ in range(args.num_hidden_layers) ] - self.norm = nn.LayerNorm(args.hidden_size, eps=args.norm_epsilon) + self.norm = LayerNorm(args.hidden_size, eps=args.norm_epsilon) def __call__( self, @@ -150,7 +149,6 @@ def __call__( cache=None, ): h = self.embed_tokens(inputs) - h = h * (self.args.hidden_size**0.5) mask = None if h.shape[1] > 1: @@ -186,10 +184,8 @@ def __call__( out, cache = self.model(inputs, cache) if not self.tie_word_embeddings: - # If tie_word_embeddings is False, apply linear transformation to obtain predictions return self.lm_head(out), cache else: - # If tie_word_embeddings is True, perform matrix multiplication for predictions out = out @ self.model.embed_tokens.weight.T return out, cache From d4383fb42bb60b2ba23a10102e6bc268f37977ad Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sun, 3 Mar 2024 00:44:00 +0100 Subject: [PATCH 11/11] black formatting --- llms/mlx_lm/models/starcoder2.py | 1 + 1 file changed, 1 insertion(+) diff --git a/llms/mlx_lm/models/starcoder2.py b/llms/mlx_lm/models/starcoder2.py index f9b1a6864..e16baedd6 100644 --- a/llms/mlx_lm/models/starcoder2.py +++ b/llms/mlx_lm/models/starcoder2.py @@ -7,6 +7,7 @@ from .base import BaseModelArgs from .layers import LayerNorm + @dataclass class ModelArgs(BaseModelArgs): model_type: str