diff --git a/src/llama-memory-hybrid.cpp b/src/llama-memory-hybrid.cpp
index a1b45e4a3cc..b4656681c0f 100644
--- a/src/llama-memory-hybrid.cpp
+++ b/src/llama-memory-hybrid.cpp
@@ -127,25 +127,79 @@ bool llama_memory_hybrid::get_can_shift() const {
 void llama_memory_hybrid::clear(bool data) {
     mem_attn->clear(data);
     mem_recr->clear(data);
+    recr_rebuild_needed.clear();
 }
 
 bool llama_memory_hybrid::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
-    // Try removing from the recurrent cache first since it may fail. If it does
-    // fail, the cache will not have been mutated.
-    if (!mem_recr->seq_rm(seq_id, p0, p1)) {
+    // For hybrid models, we handle the case where the recurrent cache cannot do
+    // partial removal (which is expected for Mamba/SSM layers).
+    //
+    // Strategy:
+    //   1. Try recurrent removal first
+    //   2. If it fails (partial removal not supported), clear the recurrent state
+    //      entirely and mark the sequence for an SSM state rebuild
+    //   3. Proceed with attention cache removal (which can handle partial removal)
+    //   4. Return success if the attention removal succeeded
+    //
+    // The SSM state will be rebuilt during the next decode pass while the
+    // KV cache remains valid and can be reused.
+
+    bool recr_ok = mem_recr->seq_rm(seq_id, p0, p1);
+
+    if (!recr_ok) {
+        // The recurrent cache cannot do partial removal - this is expected for Mamba.
+        // Clear the recurrent state entirely for this sequence.
+        // It will be rebuilt during the next decode pass.
+        LLAMA_LOG_WARN("%s: recurrent seq_rm failed for seq %d [%d, %d), "
+                "clearing recurrent state and marking for rebuild\n",
+                __func__, seq_id, p0, p1);
+
+        // Clear the recurrent state for this sequence.
+        // Full-range removal should always succeed.
+        mem_recr->seq_rm(seq_id, -1, -1);
+
+        // Mark this sequence as needing an SSM state rebuild from position 0
+        mark_recurrent_rebuild(seq_id);
+    }
+
+    // Now handle the attention cache - partial removal works for the attention layers
+    bool attn_ok = mem_attn->seq_rm(seq_id, p0, p1);
+
+    if (!attn_ok) {
+        LLAMA_LOG_ERROR("%s: attention seq_rm failed for seq %d [%d, %d)\n",
+                __func__, seq_id, p0, p1);
         return false;
     }
-    return mem_attn->seq_rm(seq_id, p0, p1);
+
+    // Success: the attention cache was trimmed; the recurrent cache was either trimmed or marked for rebuild
+    return true;
 }
 
 void llama_memory_hybrid::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
     mem_attn->seq_cp(seq_id_src, seq_id_dst, p0, p1);
     mem_recr->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+
+    // If the source needs a rebuild, the destination also needs a rebuild
+    if (needs_recurrent_rebuild(seq_id_src)) {
+        mark_recurrent_rebuild(seq_id_dst);
+    }
 }
 
 void llama_memory_hybrid::seq_keep(llama_seq_id seq_id) {
     mem_attn->seq_keep(seq_id);
     mem_recr->seq_keep(seq_id);
+
+    // Clear the rebuild flags for the sequences that are being removed
+    // (keep only the specified sequence)
+    std::unordered_set<llama_seq_id> to_remove;
+    for (const auto & id : recr_rebuild_needed) {
+        if (id != seq_id) {
+            to_remove.insert(id);
+        }
+    }
+    for (const auto & id : to_remove) {
+        recr_rebuild_needed.erase(id);
+    }
 }
 
 void llama_memory_hybrid::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
@@ -159,12 +213,29 @@ void llama_memory_hybrid::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
 }
 
 llama_pos llama_memory_hybrid::seq_pos_min(llama_seq_id seq_id) const {
-    // the min of the total cache is the max of the two caches' min values
-    return std::max(mem_attn->seq_pos_min(seq_id), mem_recr->seq_pos_min(seq_id));
+    // For hybrid models, return the ATTENTION cache pos_min.
+    //
+    // Rationale:
+    //   - The attention KV cache has valid data from position 0 (or wherever it starts)
+    //   - The recurrent cache only tracks recent SSM state (pos_min near the current pos)
+    //   - Using max() of both causes the server to see "invalid cache" and force n_past = 0
+    //
+    // By returning the attention pos_min, the server can continue from n_past based on
+    // KV cache validity. The recurrent state will be rebuilt during decode if needed
+    // (tracked via recr_rebuild_needed).
+    //
+    // This allows hybrid models to benefit from KV cache reuse even when the prompt
+    // changes (e.g. after compaction).
+
+    return mem_attn->seq_pos_min(seq_id);
+}
+
+llama_pos llama_memory_hybrid::seq_pos_min_attn(llama_seq_id seq_id) const {
+    return mem_attn->seq_pos_min(seq_id);
 }
 
 llama_pos llama_memory_hybrid::seq_pos_max(llama_seq_id seq_id) const {
-    // the max of the total cache is the min of the two caches' max values
+    // the max of the total cache is the min of the two caches' max values
     return std::min(mem_attn->seq_pos_max(seq_id), mem_recr->seq_pos_max(seq_id));
 }
 
@@ -198,6 +269,24 @@ llama_memory_recurrent * llama_memory_hybrid::get_mem_recr() const {
     return mem_recr.get();
 }
 
+bool llama_memory_hybrid::needs_recurrent_rebuild(llama_seq_id seq_id) const {
+    return recr_rebuild_needed.count(seq_id) > 0;
+}
+
+void llama_memory_hybrid::mark_recurrent_rebuild(llama_seq_id seq_id) {
+    recr_rebuild_needed.insert(seq_id);
+    LLAMA_LOG_DEBUG("%s: marked seq %d for recurrent rebuild\n", __func__, seq_id);
+}
+
+void llama_memory_hybrid::clear_recurrent_rebuild(llama_seq_id seq_id) {
+    recr_rebuild_needed.erase(seq_id);
+    LLAMA_LOG_DEBUG("%s: cleared recurrent rebuild flag for seq %d\n", __func__, seq_id);
+}
+
+//
+// llama_memory_hybrid_context
+//
+
 llama_memory_hybrid_context::llama_memory_hybrid_context(llama_memory_status status) : status(status) {}
 
 llama_memory_hybrid_context::llama_memory_hybrid_context(llama_memory_hybrid * mem) :
diff --git a/src/llama-memory-hybrid.h b/src/llama-memory-hybrid.h
index 558cafdf984..2682f34a20b 100644
--- a/src/llama-memory-hybrid.h
+++ b/src/llama-memory-hybrid.h
@@ -7,6 +7,7 @@
 #include "llama-memory-recurrent.h"
 
 #include <memory>
+#include <unordered_set>
 #include <vector>
 
 //
@@ -82,11 +83,38 @@ class llama_memory_hybrid : public llama_memory_i {
     llama_kv_cache * get_mem_attn() const;
     llama_memory_recurrent * get_mem_recr() const;
 
+    //
+    // Hybrid-specific: recurrent state rebuild tracking
+    //
+    // When seq_rm fails for the recurrent cache (partial removal not supported),
+    // we clear the recurrent state and mark it for rebuild. The next
+    // decode pass will rebuild the SSM state while reusing the KV cache.
+    //
+
+    // Check if a sequence needs its recurrent state rebuilt from position 0
+    bool needs_recurrent_rebuild(llama_seq_id seq_id) const;
+
+    // Mark a sequence as needing a recurrent state rebuild
+    void mark_recurrent_rebuild(llama_seq_id seq_id);
+
+    // Clear the rebuild flag after the rebuild is complete
+    void clear_recurrent_rebuild(llama_seq_id seq_id);
+
+    // Get the attention cache pos_min (for server pos_min checks).
+    // This allows continuing from n_past based on KV cache validity,
+    // independent of the recurrent state, which will be rebuilt.
+    llama_pos seq_pos_min_attn(llama_seq_id seq_id) const;
+
 private:
     const llama_hparams & hparams;
 
     const std::unique_ptr<llama_kv_cache>        mem_attn;
     const std::unique_ptr<llama_memory_recurrent> mem_recr;
+
+    // Track sequences that need their recurrent state rebuilt from position 0.
+    // This happens when seq_rm fails for the recurrent cache (partial removal not
+    // supported) but succeeds for attention (the KV cache can be trimmed).
+    mutable std::unordered_set<llama_seq_id> recr_rebuild_needed;
 };
 
 class llama_memory_hybrid_context : public llama_memory_context_i {
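
Not part of the patch: a minimal caller-side sketch of how the rebuild-tracking API above is meant to be consumed by the decode/server path that the comments refer to. The helper name rebuild_recurrent_state_if_needed, the cached_tokens parameter and the omitted batch construction are illustrative assumptions, not code from this change:

    #include <vector>

    #include "llama-memory-hybrid.h"

    // Hypothetical helper (not in this patch): before decoding new tokens for
    // seq_id, check whether a partial seq_rm discarded the SSM state. If it did,
    // the tokens still covered by the attention KV cache must be re-processed so
    // the recurrent layers can rebuild their state from position 0.
    static void rebuild_recurrent_state_if_needed(
            llama_memory_hybrid            & mem,
            llama_seq_id                     seq_id,
            const std::vector<llama_token> & cached_tokens) {
        if (!mem.needs_recurrent_rebuild(seq_id)) {
            return; // recurrent state is intact - nothing to rebuild
        }

        // The attention KV cache still holds positions [0, seq_pos_max], so the
        // re-decode mainly serves to regenerate the Mamba/SSM state. The actual
        // llama_decode() batch construction is omitted here because it belongs
        // to the caller (e.g. the server slot logic).
        (void) cached_tokens;

        // Once the state has been recomputed, drop the flag so later calls skip this path.
        mem.clear_recurrent_rebuild(seq_id);
    }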