1 change: 1 addition & 0 deletions ACKNOWLEDGMENTS.md
@@ -13,3 +13,4 @@ MLX Examples was developed with contributions from the following individuals:
 - Gabrijel Boduljak: Implemented `CLIP`.
 - Markus Enzweiler: Added the `cvae` examples.
 - Rasmus Kinnunen: Fixed a security hole in the `llms/mlx_lm` example
+- Prince Canuma: Helped add support for `Starcoder2` models.
23 changes: 6 additions & 17 deletions llms/mlx_lm/models/starcoder2.py
@@ -18,20 +18,12 @@ class ModelArgs(BaseModelArgs):
     num_attention_heads: int
     num_key_value_heads: int = None
     max_position_embeddings: int = 16384
-    norm_eps: float = None
-    rms_norm_eps: float = 1e-5
+    norm_epsilon: float = 1e-5
     norm_type: str = "layer_norm"
     vocab_size: int = 49152
     rope_theta: float = 100000
     tie_word_embeddings: bool = True
 
     def __post_init__(self):
         if self.num_key_value_heads is None:
             self.num_key_value_heads = self.num_attention_heads
-
-        if self.norm_eps is None:
-            self.norm_eps = self.rms_norm_eps
 
 
 class Attention(nn.Module):
     def __init__(self, args: ModelArgs):
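Side note on this config change: the upstream Starcoder2 configuration carries a single `norm_epsilon` field, so collapsing `norm_eps`/`rms_norm_eps` into one attribute also removes the fallback logic from `__post_init__`. A minimal sketch of the resulting behavior, using a trimmed-down dataclass and a hypothetical config dict rather than the full `ModelArgs`:

```python
from dataclasses import dataclass, fields

@dataclass
class ModelArgs:
    num_attention_heads: int
    num_key_value_heads: int = None
    norm_epsilon: float = 1e-5  # single field, no norm_eps/rms_norm_eps fallback

    def __post_init__(self):
        # Only the grouped-query-attention default remains after the cleanup.
        if self.num_key_value_heads is None:
            self.num_key_value_heads = self.num_attention_heads

# Hypothetical config dict shaped like a Hugging Face Starcoder2 config.
config = {"num_attention_heads": 24, "norm_epsilon": 1e-5, "unused_key": 0}
known = {f.name for f in fields(ModelArgs)}
args = ModelArgs(**{k: v for k, v in config.items() if k in known})
print(args.num_key_value_heads, args.norm_epsilon)  # 24 1e-05
```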
@@ -68,12 +60,9 @@ def __call__(
         keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
         values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
 
-        def repeat(a):
-            a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
-            return a.reshape([B, self.n_heads, L, -1])
-
         if self.repeats > 1:
-            keys, values = map(repeat, (keys, values))
+            keys = mx.repeat(keys, self.repeats, axis=1)
+            values = mx.repeat(values, self.repeats, axis=1)
 
         if cache is not None:
             key_cache, value_cache = cache
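The removed `repeat` helper and `mx.repeat` produce the same head layout: stacking `repeats` copies on a new axis and folding it into the head dimension interleaves each key/value head exactly as an element-wise repeat along axis 1 does. A quick standalone check of that equivalence (shapes are arbitrary, chosen for illustration):

```python
import mlx.core as mx

B, n_kv_heads, L, D = 1, 2, 4, 8
repeats = 3  # n_heads // n_kv_heads

a = mx.random.normal((B, n_kv_heads, L, D))

# Removed helper: stack copies on a new axis 2, then fold it into the heads.
old = mx.concatenate([mx.expand_dims(a, 2)] * repeats, axis=2).reshape(
    [B, n_kv_heads * repeats, L, D]
)

# Replacement: repeat each kv head in place along the head axis.
new = mx.repeat(a, repeats, axis=1)

assert mx.array_equal(old, new)
```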
@@ -111,9 +100,9 @@ def __init__(self, args: ModelArgs):
 
         self.self_attn = Attention(args)
         self.mlp = MLP(args.hidden_size, args.intermediate_size)
-        self.input_layernorm = LayerNorm(args.hidden_size, eps=args.rms_norm_eps)
+        self.input_layernorm = LayerNorm(args.hidden_size, eps=args.norm_epsilon)
         self.post_attention_layernorm = LayerNorm(
-            args.hidden_size, eps=args.rms_norm_eps
+            args.hidden_size, eps=args.norm_epsilon
         )
         self.args = args
 
@@ -141,7 +130,7 @@ def __init__(self, args: ModelArgs):
         self.layers = [
             TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
         ]
-        self.norm = LayerNorm(args.hidden_size, eps=args.rms_norm_eps)
+        self.norm = LayerNorm(args.hidden_size, eps=args.norm_epsilon)
 
     def __call__(
         self,