From d1d045bcda1b27a1468eaf85495b25562e7bac87 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Wed, 14 Feb 2024 15:34:29 -0800 Subject: [PATCH] Change gqa to use repeat instead of concatenate --- llms/gguf_llm/models.py | 7 ++----- llms/mistral/mistral.py | 7 ++----- llms/mixtral/mixtral.py | 7 ++----- llms/mlx_lm/models/llama.py | 7 ++----- llms/mlx_lm/models/mixtral.py | 7 ++----- llms/mlx_lm/models/phi.py | 7 ++----- llms/mlx_lm/models/qwen2.py | 7 ++----- llms/mlx_lm/models/stablelm_epoch.py | 7 ++----- 8 files changed, 16 insertions(+), 40 deletions(-) diff --git a/llms/gguf_llm/models.py b/llms/gguf_llm/models.py index 45976f33a..e60b60d5b 100644 --- a/llms/gguf_llm/models.py +++ b/llms/gguf_llm/models.py @@ -107,12 +107,9 @@ def __call__( keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) - def repeat(a): - a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2) - return a.reshape([B, self.n_heads, L, -1]) - if self.repeats > 1: - keys, values = map(repeat, (keys, values)) + keys = mx.repeat(keys, self.repeats, axis=1) + values = mx.repeat(values, self.repeats, axis=1) if cache is not None: key_cache, value_cache = cache diff --git a/llms/mistral/mistral.py b/llms/mistral/mistral.py index 9b9a602ac..39456d970 100644 --- a/llms/mistral/mistral.py +++ b/llms/mistral/mistral.py @@ -73,11 +73,8 @@ def __call__( keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) - def repeat(a): - a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2) - return a.reshape([B, self.n_heads, L, -1]) - - keys, values = map(repeat, (keys, values)) + keys = mx.repeat(keys, self.repeats, axis=1) + values = mx.repeat(values, self.repeats, axis=1) if cache is not None: key_cache, value_cache = cache diff --git a/llms/mixtral/mixtral.py b/llms/mixtral/mixtral.py index b1f147065..8a884817c 100644 --- a/llms/mixtral/mixtral.py +++ b/llms/mixtral/mixtral.py @@ -93,11 +93,8 @@ def __call__( keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) - def repeat(a): - a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2) - return a.reshape([B, self.n_heads, L, -1]) - - keys, values = map(repeat, (keys, values)) + keys = mx.repeat(keys, self.repeats, axis=1) + values = mx.repeat(values, self.repeats, axis=1) if cache is not None: key_cache, value_cache = cache diff --git a/llms/mlx_lm/models/llama.py b/llms/mlx_lm/models/llama.py index f44a94e70..f9f965257 100644 --- a/llms/mlx_lm/models/llama.py +++ b/llms/mlx_lm/models/llama.py @@ -93,12 +93,9 @@ def __call__( keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) - def repeat(a): - a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2) - return a.reshape([B, self.n_heads, L, -1]) - if self.repeats > 1: - keys, values = map(repeat, (keys, values)) + keys = mx.repeat(keys, self.repeats, axis=1) + values = mx.repeat(values, self.repeats, axis=1) if cache is not None: key_cache, value_cache = cache diff --git a/llms/mlx_lm/models/mixtral.py b/llms/mlx_lm/models/mixtral.py index fbd4c7a3b..5b4875ebc 100644 --- a/llms/mlx_lm/models/mixtral.py +++ b/llms/mlx_lm/models/mixtral.py @@ -95,12 +95,9 @@ def __call__( 0, 2, 1, 3 ) - def repeat(a): - a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2) - return a.reshape([B, self.num_heads, L, -1]) - if self.repeats > 1: - keys, values = map(repeat, (keys, values)) + keys = mx.repeat(keys, self.repeats, axis=1) + values = mx.repeat(values, self.repeats, axis=1) if cache is not None: key_cache, value_cache = cache diff --git a/llms/mlx_lm/models/phi.py b/llms/mlx_lm/models/phi.py index 93bba876c..ce8c226d8 100644 --- a/llms/mlx_lm/models/phi.py +++ b/llms/mlx_lm/models/phi.py @@ -86,12 +86,9 @@ def __call__(self, x, mask=None, cache=None): B, L, self.num_key_value_heads, self.head_dim ).transpose(0, 2, 1, 3) - def repeat(a): - a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2) - return a.reshape([B, self.num_heads, L, -1]) - if self.repeats > 1: - keys, values = map(repeat, (keys, values)) + keys = mx.repeat(keys, self.repeats, axis=1) + values = mx.repeat(values, self.repeats, axis=1) # Add RoPE to the queries and keys and combine them with the cache if cache is not None: diff --git a/llms/mlx_lm/models/qwen2.py b/llms/mlx_lm/models/qwen2.py index 42abe6884..f3f868adf 100644 --- a/llms/mlx_lm/models/qwen2.py +++ b/llms/mlx_lm/models/qwen2.py @@ -93,12 +93,9 @@ def __call__( keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) - def repeat(a): - a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2) - return a.reshape([B, self.n_heads, L, -1]) - if self.repeats > 1: - keys, values = map(repeat, (keys, values)) + keys = mx.repeat(keys, self.repeats, axis=1) + values = mx.repeat(values, self.repeats, axis=1) if cache is not None: key_cache, value_cache = cache diff --git a/llms/mlx_lm/models/stablelm_epoch.py b/llms/mlx_lm/models/stablelm_epoch.py index a0fe0d30f..2d4922953 100644 --- a/llms/mlx_lm/models/stablelm_epoch.py +++ b/llms/mlx_lm/models/stablelm_epoch.py @@ -87,12 +87,9 @@ def __call__(self, x, mask=None, cache=None): B, L, self.num_key_value_heads, self.head_dim ).transpose(0, 2, 1, 3) - def repeat(a): - a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2) - return a.reshape([B, self.num_heads, L, -1]) - if self.repeats > 1: - keys, values = map(repeat, (keys, values)) + keys = mx.repeat(keys, self.repeats, axis=1) + values = mx.repeat(values, self.repeats, axis=1) # Add RoPE to the queries and keys and combine them with the cache if cache is not None: