diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py
index 3ef2f69e7c0..27fc9eea6d6 100755
--- a/convert-hf-to-gguf.py
+++ b/convert-hf-to-gguf.py
@@ -2369,6 +2369,7 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_final_logit_softcapping(
             self.hparams["final_logit_softcapping"]
         )
+        self.gguf_writer.add_sliding_window(self.hparams["sliding_window"])
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         del bid  # unused
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index 9bfa891d5dc..e87c5826615 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -66,6 +66,7 @@ class Attention:
         Q_LORA_RANK       = "{arch}.attention.q_lora_rank"
         KV_LORA_RANK      = "{arch}.attention.kv_lora_rank"
         REL_BUCKETS_COUNT = "{arch}.attention.relative_buckets_count"
+        SLIDING_WINDOW    = "{arch}.attention.sliding_window"
 
     class Rope:
         DIMENSION_COUNT = "{arch}.rope.dimension_count"
diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py
index 1aeb0d9b086..75a8b2636a6 100644
--- a/gguf-py/gguf/gguf_writer.py
+++ b/gguf-py/gguf/gguf_writer.py
@@ -552,6 +552,9 @@ def add_kv_lora_rank(self, length: int) -> None:
     def add_relative_attn_buckets_count(self, value: int) -> None:
         self.add_uint32(Keys.Attention.REL_BUCKETS_COUNT.format(arch=self.arch), value)
 
+    def add_sliding_window(self, value: int) -> None:
+        self.add_uint32(Keys.Attention.SLIDING_WINDOW.format(arch=self.arch), value)
+
     def add_pooling_type(self, value: PoolingType) -> None:
         self.add_uint32(Keys.LLM.POOLING_TYPE.format(arch=self.arch), value.value)
diff --git a/src/llama.cpp b/src/llama.cpp
index 451264e79f2..abedd8a6821 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -341,6 +341,7 @@ enum llm_kv {
     LLM_KV_ATTENTION_Q_LORA_RANK,
    LLM_KV_ATTENTION_KV_LORA_RANK,
     LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,
+    LLM_KV_ATTENTION_SLIDING_WINDOW,
 
     LLM_KV_ROPE_DIMENSION_COUNT,
     LLM_KV_ROPE_FREQ_BASE,
@@ -433,6 +434,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_ATTENTION_Q_LORA_RANK,            "%s.attention.q_lora_rank"            },
     { LLM_KV_ATTENTION_KV_LORA_RANK,           "%s.attention.kv_lora_rank"           },
     { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" },
+    { LLM_KV_ATTENTION_SLIDING_WINDOW,         "%s.attention.sliding_window"         },
 
     { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
     { LLM_KV_ROPE_FREQ_BASE,       "%s.rope.freq_base"       },
@@ -2107,7 +2109,7 @@ struct llama_hparams {
     bool use_par_res;
 
     uint32_t n_vocab;
-    uint32_t n_ctx_train; // context size the model was trained on
+    uint32_t n_ctx_train;            // context size the model was trained on
     uint32_t n_embd;
     uint32_t n_head;
     uint32_t n_head_kv;
@@ -2127,6 +2129,7 @@ struct llama_hparams {
     uint32_t n_ff_shexp = 0;
     uint32_t n_expert_shared = 0;
     float    expert_weights_scale = 0.0;
+    uint32_t n_sliding = 0; // sliding window attention (SWA)
 
     float f_norm_eps;
     float f_norm_rms_eps;
@@ -2694,6 +2697,9 @@ struct llama_context {
     struct ggml_tensor * inp_s_mask; // F32 [1, n_kv]
     struct ggml_tensor * inp_s_seq;  // I32 [n_kv, n_batch]
 
+    // KQ mask per layer, used by sliding window attention (gemma 2)
+    std::vector<struct ggml_tensor *> inp_KQ_mask_l;
+
     // control vectors
     struct llama_control_vector cvec;
 };
@@ -4762,6 +4768,8 @@ static void llm_load_hparams(
             } break;
        case LLM_ARCH_GEMMA2:
            {
+                hparams.n_sliding = 4096; // default value of gemma 2
+                ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_sliding, false);
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
                 ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING,  hparams.f_attn_logit_softcapping, false);
                 ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false);
@@ -11108,9 +11116,16 @@ struct llm_build_context {
         struct ggml_tensor * inp_pos = build_inp_pos();
 
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
+        // gemma 2 requires different mask for layers using sliding window (SWA)
+        struct ggml_tensor * KQ_mask_full = build_inp_KQ_mask();
+        struct ggml_tensor * KQ_mask_SWA  = build_inp_KQ_mask();
+        lctx.inp_KQ_mask_l.clear();
 
         for (int il = 0; il < n_layer; ++il) {
+            // even layers (il % 2 == 0) use SWA
+            struct ggml_tensor * KQ_mask = (il % 2 == 0) ? KQ_mask_SWA : KQ_mask_full;
+            lctx.inp_KQ_mask_l.push_back(KQ_mask);
+
             // norm
             cur = llm_build_norm(ctx0, inpL, hparams,
                     model.layers[il].attn_norm, NULL,
@@ -12750,6 +12765,16 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
                 GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
 
                 float * data = (float *) lctx.inp_KQ_mask->data;
+                float * data_swa = nullptr;
+                const llama_pos n_keep_swa = hparams.n_sliding - batch.n_tokens;
+
+                if (lctx.model.arch == LLM_ARCH_GEMMA2) {
+                    GGML_ASSERT(!lctx.inp_KQ_mask_l.empty() && "gemma 2 requires different KQ mask per layer");
+                    GGML_ASSERT(hparams.n_sliding > 0);
+                    data_swa = (float *) lctx.inp_KQ_mask_l[0]->data;
+                    data     = (float *) lctx.inp_KQ_mask_l[1]->data;
+                    // because masks alternate between layers for gemma 2, we only need the first two layers' masks
+                }
 
                 // For causal attention, use only the previous KV cells
                 // of the correct sequence for each token of the batch.
@@ -12771,6 +12796,14 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
                             }
                         }
                         data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
+
+                        // may need to cut off old tokens for sliding window
+                        if (data_swa) {
+                            if (pos - lctx.kv_self.cells[i].pos > n_keep_swa) {
+                                f = -INFINITY;
+                            }
+                            data_swa[h*(n_kv*n_tokens) + j*n_kv + i] = f;
+                        }
                     }
                 }
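
Note: the sketch below is not part of the patch; it is a minimal standalone illustration (in plain Python, with made-up names such as build_masks, token_pos and kv_pos) of the two masks the patch builds. It mirrors the cutoff used in llama_set_inputs: with n_keep_swa = n_sliding - n_tokens, a KV cell at position kv_pos is masked for a token at position pos whenever pos - kv_pos > n_keep_swa, on top of the usual causal mask.

import math

def build_masks(token_pos, kv_pos, n_sliding, n_tokens):
    # same cutoff as llama_set_inputs() above
    n_keep_swa = n_sliding - n_tokens
    full, swa = [], []
    for pos in token_pos:                          # one row per batch token
        full_row, swa_row = [], []
        for kp in kv_pos:                          # one column per KV cell
            f = 0.0 if kp <= pos else -math.inf    # causal: mask future cells
            full_row.append(f)
            if pos - kp > n_keep_swa:              # SWA: also mask cells outside the window
                f = -math.inf
            swa_row.append(f)
        full.append(full_row)
        swa.append(swa_row)
    return full, swa

# With n_sliding=4 and a single-token batch, a token at pos=6 can attend to
# cells 3..6 under the SWA mask, but to all of 0..6 under the full mask.
full, swa = build_masks(token_pos=[6], kv_pos=list(range(8)), n_sliding=4, n_tokens=1)

# gemma 2 alternates: even layers get the SWA mask, odd layers the full mask,
# which is why llama_set_inputs only reads the masks of the first two layers.
layer_masks = [swa if il % 2 == 0 else full for il in range(4)]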