39 changes: 26 additions & 13 deletions llms/mlx_lm/models/recurrent_gemma.py
@@ -14,7 +14,6 @@ class ModelArgs(BaseModelArgs):
     hidden_size: int
     attention_bias: bool
     conv1d_width: int
-    embeddings_scale_by_sqrt_dim: bool
     hidden_size: int
     intermediate_size: int
     logits_soft_cap: float
@@ -25,7 +24,14 @@ class ModelArgs(BaseModelArgs):
     rope_theta: float
     attention_window_size: int
     vocab_size: int
-    _block_types: List[str]
+    embeddings_scale_by_sqrt_dim: bool = True
+    block_types: Optional[List[str]] = None
+    _block_types: Optional[List[str]] = None
+
+    def __post_init__(self):
+        # For some reason these have different names in 2B and 9B
+        if self.block_types is None:
+            self.block_types = self._block_types
 
 
 def create_window_causal_mask(N: int, window_size: int):
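The `__post_init__` hook above exists because the two checkpoint sizes spell the block-type list differently in their configs. A minimal standalone sketch of that aliasing; the example lists below are illustrative, not taken from the checkpoints:

# Sketch of the config-key aliasing handled by __post_init__ above.
# Which checkpoint uses which key is not stated here; both resolve to `block_types`.
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class _Args:
    block_types: Optional[List[str]] = None
    _block_types: Optional[List[str]] = None

    def __post_init__(self):
        # Accept either spelling and expose the result uniformly as `block_types`.
        if self.block_types is None:
            self.block_types = self._block_types


a = _Args(block_types=["recurrent", "recurrent", "attention"])
b = _Args(_block_types=["recurrent", "recurrent", "attention"])
assert a.block_types == b.block_types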
@@ -202,6 +208,8 @@ def apply_block_linear(h, w, b):
 
         # Apply gamma normalization to the input.
         multiplier = mx.sqrt(1 - a_square)
+        if cache is None:
+            multiplier[:, 0, :] = 1.0
         normalized_x = gated_x * multiplier.astype(x.dtype)
 
         y, last_h = rnn_scan(
@@ -404,8 +412,8 @@ def __call__(
         raw_x = x
 
         inputs_normalized = self.temporal_pre_norm(raw_x)
-        x = self.temporal_block(inputs_normalized, cache=cache, mask=mask)
 
+        x = self.temporal_block(inputs_normalized, cache=cache, mask=mask)
         residual = x + raw_x
 
         x = self.channel_pre_norm(residual)
@@ -427,7 +435,7 @@ def __init__(self, config):
         )
 
         self.scale_by_sqrt_dim = config.embeddings_scale_by_sqrt_dim
-        block_types = config._block_types
+        block_types = config.block_types
 
         self.layers = [
             ResidualBlock(
@@ -461,28 +469,31 @@ def __call__(
         for i, block in enumerate(self.layers):
             x = block(x, mask=mask, cache=cache[i])
 
-        x = self.final_norm(x)
-        logits = self.embed_tokens.as_linear(x)
-
-        c = self.config.logits_soft_cap
-        if c:
-            logits = mx.tanh(logits / c) * c
-
-        return logits
+        return self.final_norm(x)
 
 
 class Model(nn.Module):
 
     def __init__(self, config):
         self.args = config
         self.model = Griffin(config)
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 
     def __call__(self, tokens: mx.array, cache=None) -> mx.array:
         """
         Args:
             tokens: Sequence of input tokens.
         """
-        return self.model(tokens, cache=cache)
+        logits = self.model(tokens, cache=cache)
+        if "lm_head" in self:
+            logits = self.lm_head(logits)
+        else:
+            logits = self.model.embed_tokens.as_linear(logits)
+
+        c = self.args.logits_soft_cap
+        if c:
+            logits = mx.tanh(logits / c) * c
+        return logits
 
     @property
     def layers(self):
@@ -493,6 +504,8 @@ def sanitize(self, weights):
         for k, v in weights.items():
             if "conv_1d.weight" in k and v.ndim == 3:
                 weights[k] = v.squeeze(1).T
+        if "lm_head.weight" not in weights:
+            self.pop("lm_head")
         return weights
 
     def make_cache(self):
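The net effect of the recurrent_gemma.py changes is that logit computation moves out of Griffin and into Model: a separate lm_head is used when the checkpoint provides one (sanitize drops the module otherwise), the tied input embedding is used as a fallback, and the tanh soft cap is applied either way. A minimal sketch of that head selection, assuming the standard mlx.core / mlx.nn APIs and an arbitrary example cap value:

# Sketch of the output-head logic introduced above (not the module itself):
# use a dedicated lm_head when the checkpoint ships one, otherwise reuse the
# tied embedding matrix, then apply the tanh soft cap.
import mlx.core as mx
import mlx.nn as nn

hidden_size, vocab_size, cap = 8, 16, 30.0
embed_tokens = nn.Embedding(vocab_size, hidden_size)
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)  # absent in some checkpoints

h = mx.random.normal((1, 4, hidden_size))  # stand-in for Griffin's final hidden states
use_lm_head = True

logits = lm_head(h) if use_lm_head else embed_tokens.as_linear(h)
if cap:
    logits = mx.tanh(logits / cap) * cap  # squashes logits smoothly into (-cap, cap)
print(logits.shape)  # (1, 4, 16)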
1 change: 1 addition & 0 deletions llms/mlx_lm/server.py
@@ -173,6 +173,7 @@ def do_POST(self):
         endpoints = {
             "/v1/completions": self.handle_text_completions,
             "/v1/chat/completions": self.handle_chat_completions,
+            "/chat/completions": self.handle_chat_completions,
         }
 
         if self.path not in endpoints:
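The server change only registers `/chat/completions` as an alias for the existing `/v1/chat/completions` handler. A hypothetical client call illustrating that both paths now reach the same handler; the local port and the OpenAI-style response shape are assumptions here, not shown in the diff:

# Hypothetical client; assumes `mlx_lm.server` is listening on localhost:8080
# and returns an OpenAI-style chat completion. Adjust host, port, and fields
# to your setup.
import json
from urllib.request import Request, urlopen

payload = {
    "messages": [{"role": "user", "content": "Hello!"}],
    "max_tokens": 32,
}

for path in ("/v1/chat/completions", "/chat/completions"):
    req = Request(
        f"http://localhost:8080{path}",
        data=json.dumps(payload).encode(),
        headers={"Content-Type": "application/json"},
    )
    with urlopen(req) as resp:
        body = json.loads(resp.read())
        print(path, body["choices"][0]["message"]["content"])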