5 changes: 3 additions & 2 deletions llms/README.md
@@ -127,17 +127,18 @@ Most
[Phi-2](https://huggingface.co/models?library=transformers,safetensors&other=phi&sort=trending),
[Mixtral](https://huggingface.co/models?library=transformers,safetensors&other=mixtral&sort=trending),
and
[Starcoder2](https://huggingface.co/models?library=transformers,safetensors&other=starcoder2&sort=trending)
style models should work out of the box.

For some models (such as `Qwen` and `plamo`) the tokenizer requires you to
enable the `trust_remote_code` option. You can do this by passing
`--trust-remote-code` in the command line. If you don't specify the flag
explicitly, you will be prompted to trust remote code in the terminal when
running the model.

For `Qwen` models you must also specify the `eos_token`. You can do this by
passing `--eos-token "<|endoftext|>"` in the command
line.

These options can also be set in the Python API. For example:
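A minimal sketch of that (assuming `mlx_lm`'s `load` helper accepts a `tokenizer_config` dict; the model repo id here is only illustrative):

```python
from mlx_lm import load, generate

# Pass the tokenizer options programmatically instead of via the CLI flags.
model, tokenizer = load(
    "Qwen/Qwen-7B",  # illustrative repo id
    tokenizer_config={"eos_token": "<|endoftext|>", "trust_remote_code": True},
)

response = generate(model, tokenizer, prompt="hello", max_tokens=64)
```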

195 changes: 195 additions & 0 deletions llms/mlx_lm/models/starcoder2.py
@@ -0,0 +1,195 @@
from dataclasses import dataclass
from typing import Optional, Tuple

import mlx.core as mx
import mlx.nn as nn

from .base import BaseModelArgs
from .layers import LayerNorm


@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
hidden_size: int
num_hidden_layers: int
sliding_window: int
num_attention_heads: int
    num_key_value_heads: Optional[int] = None
vocab_size: int = 49152
intermediate_size: int = 12288
max_position_embeddings: int = 16384
norm_epsilon: float = 1e-05
rope_theta: float = 10000
rope_traditional: bool = False
    attention_dropout: float = 0.0
    residual_dropout: float = 0.0
    embedding_dropout: float = 0.0
use_bias: bool = True
tie_word_embeddings: bool = True


class Starcoder2Attention(nn.Module):
"""
    Multi-headed attention from the "Attention Is All You Need" paper, modified to use
    sliding-window attention (see Longformer and "Generating Long Sequences with Sparse Transformers").
"""

def __init__(self, args: ModelArgs):
super().__init__()
self.config = args
dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
self.head_dim = head_dim = args.hidden_size // n_heads

self.repeats = n_heads // n_kv_heads

self.scale = head_dim**-0.5

self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=True)
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=True)

self.rope = nn.RoPE(
head_dim,
traditional=args.rope_traditional,
base=args.rope_theta,
)

def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
B, L, D = x.shape

queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)

# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

if self.repeats > 1:
keys = mx.repeat(keys, self.repeats, axis=1)
values = mx.repeat(values, self.repeats, axis=1)

if cache is not None:
key_cache, value_cache = cache
queries = self.rope(queries, offset=key_cache.shape[2])
keys = self.rope(keys, offset=key_cache.shape[2])
keys = mx.concatenate([key_cache, keys], axis=2)
values = mx.concatenate([value_cache, values], axis=2)
else:
queries = self.rope(queries)
keys = self.rope(keys)

scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
if mask is not None:
scores += mask
scores = mx.softmax(scores, axis=-1).astype(values.dtype)
Contributor: Maybe we need to keep `scores.astype(mx.float32)` in case softmax overflows on float16 values.
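A self-contained sketch of that suggestion (illustrative, not part of this diff): run the softmax in float32 so it does not overflow on float16 scores, then cast back before the matmul with the values.

```python
import mlx.core as mx

# Illustrative shapes: (batch, heads, query_len, key_len) scores and
# (batch, heads, key_len, head_dim) values in float16.
scores = mx.random.normal((1, 2, 4, 4)).astype(mx.float16)
values = mx.random.normal((1, 2, 4, 8)).astype(mx.float16)

# Suggested change: softmax in float32, then cast back to the values' dtype.
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(values.dtype)
output = scores @ values
```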

output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output), (keys, values)


class Starcoder2MLP(nn.Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.c_fc = nn.Linear(dim, hidden_dim, bias=True)
self.c_proj = nn.Linear(hidden_dim, dim, bias=True)

def __call__(self, x) -> mx.array:
return self.c_proj(nn.gelu(self.c_fc(x)))


class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.num_attention_heads = args.num_attention_heads
self.hidden_size = args.hidden_size
self.self_attn = Starcoder2Attention(args)
self.mlp = Starcoder2MLP(args.hidden_size, args.intermediate_size)
self.input_layernorm = LayerNorm(args.hidden_size, eps=args.norm_epsilon)
self.post_attention_layernorm = LayerNorm(
args.hidden_size, eps=args.norm_epsilon
)
self.args = args

def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
r, cache = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
r = self.mlp(self.post_attention_layernorm(h))
out = h + r
return out, cache


class Starcoder2Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.vocab_size = args.vocab_size
self.num_hidden_layers = args.num_hidden_layers
assert self.vocab_size > 0
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
]
self.norm = LayerNorm(args.hidden_size, eps=args.norm_epsilon)

def __call__(
self,
inputs: mx.array,
cache=None,
):
h = self.embed_tokens(inputs)

mask = None
if h.shape[1] > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
mask = mask.astype(h.dtype)

if cache is None:
cache = [None] * len(self.layers)

for e, layer in enumerate(self.layers):
h, cache[e] = layer(h, mask, cache[e])

return self.norm(h), cache


class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.model_type = args.model_type
self.model = Starcoder2Model(args)
self.tie_word_embeddings = args.tie_word_embeddings

        # If tie_word_embeddings is False, use a separate output projection (lm_head);
        # otherwise the embedding weights are reused as the output projection in __call__.
        # Source: https://github.com/huggingface/transformers/blob/main/src/transformers/models/starcoder2/modeling_starcoder2.py#L1071
if not args.tie_word_embeddings:
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

def __call__(
self,
inputs: mx.array,
cache=None,
):
out, cache = self.model(inputs, cache)

if not self.tie_word_embeddings:
return self.lm_head(out), cache
else:
out = out @ self.model.embed_tokens.weight.T
return out, cache

@property
def layers(self):
return self.model.layers
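For context, a rough sketch of instantiating the new model directly; the config values below only approximate a small Starcoder2 checkpoint and are illustrative (in practice `mlx_lm.load` builds the model from the checkpoint's `config.json` and loads the weights):

```python
import mlx.core as mx

# Build the model skeleton from a handful of config values and run a forward
# pass on dummy token ids; weights here are just the random initialization.
args = ModelArgs(
    model_type="starcoder2",
    hidden_size=3072,
    num_hidden_layers=30,
    sliding_window=4096,
    num_attention_heads=24,
    num_key_value_heads=2,
)
model = Model(args)
logits, cache = model(mx.array([[1, 2, 3]]))
print(logits.shape)  # (1, 3, vocab_size), with vocab_size defaulting to 49152
```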
1 change: 1 addition & 0 deletions llms/mlx_lm/tuner/utils.py
@@ -32,6 +32,7 @@ def check_lora_layers(num_model):
"stablelm",
"qwen2",
"gemma",
"starcoder2",
]:
check_lora_layers(len(model.model.layers))
