103 changes: 96 additions & 7 deletions src/llama-memory-hybrid.cpp
@@ -127,25 +127,79 @@ bool llama_memory_hybrid::get_can_shift() const {
void llama_memory_hybrid::clear(bool data) {
mem_attn->clear(data);
mem_recr->clear(data);
recr_rebuild_needed.clear();
}

bool llama_memory_hybrid::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
// For hybrid models, we handle the case where the recurrent cache cannot do
// partial removal (which is expected for Mamba/SSM layers).
//
// Strategy:
// 1. Try recurrent removal first
// 2. If it fails (partial removal not supported), clear recurrent entirely
//    and mark the sequence for SSM state rebuild
// 3. Proceed with attention cache removal (which can handle partial removal)
// 4. Return success if attention succeeded
//
// The SSM state will be rebuilt during the next decode pass while the
// KV cache remains valid and can be reused.

bool recr_ok = mem_recr->seq_rm(seq_id, p0, p1);

if (!recr_ok) {
// The recurrent cache cannot do partial removal - this is expected for Mamba.
// Clear the recurrent state entirely for this sequence.
// It will be rebuilt during the next decode pass.
LLAMA_LOG_WARN("%s: recurrent seq_rm failed for seq %d [%d, %d), "
"clearing recurrent state and marking for rebuild\n",
__func__, seq_id, p0, p1);

// Clear the recurrent state for this sequence.
// Use full-range removal, which should always succeed.
mem_recr->seq_rm(seq_id, -1, -1);

// Mark this sequence as needing an SSM state rebuild from position 0
mark_recurrent_rebuild(seq_id);
}

// Now handle the attention cache - partial removal is supported here
bool attn_ok = mem_attn->seq_rm(seq_id, p0, p1);

if (!attn_ok) {
LLAMA_LOG_ERROR("%s: attention seq_rm failed for seq %d [%d, %d)\n",
__func__, seq_id, p0, p1);
return false;
}

// Success: attention cache was trimmed, recurrent either trimmed or marked for rebuild
return true;
}
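To make the intended call sequence concrete, here is a minimal caller-side sketch of how this path is meant to be driven. It only illustrates the API added in this PR; `trim_hybrid_sequence`, its out-parameter, and the idea of a caller-side replay are assumptions, not existing llama.cpp code.

#include "llama-memory-hybrid.h"

// Hypothetical caller-side helper (illustrative only). Trims a sequence and
// reports whether the caller still has to replay tokens to rebuild the SSM state.
static bool trim_hybrid_sequence(llama_memory_hybrid * mem, llama_seq_id seq_id,
        llama_pos p0, llama_pos p1, bool & ssm_replay_pending) {
    ssm_replay_pending = false;

    // trims the attention KV cache; the recurrent state is either trimmed as
    // well or cleared and flagged for rebuild (see seq_rm above)
    if (!mem->seq_rm(seq_id, p0, p1)) {
        return false; // attention trim failed - reprocess the sequence from scratch
    }

    // when the flag is set, only the SSM state was lost: the KV entries before
    // the trim point stay valid, and the state is rebuilt by replaying the
    // cached tokens from position 0 during the next decode pass
    ssm_replay_pending = mem->needs_recurrent_rebuild(seq_id);
    return true;
}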

void llama_memory_hybrid::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
mem_attn->seq_cp(seq_id_src, seq_id_dst, p0, p1);
mem_recr->seq_cp(seq_id_src, seq_id_dst, p0, p1);

// If source needs rebuild, destination also needs rebuild
if (needs_recurrent_rebuild(seq_id_src)) {
mark_recurrent_rebuild(seq_id_dst);
}
}

void llama_memory_hybrid::seq_keep(llama_seq_id seq_id) {
mem_attn->seq_keep(seq_id);
mem_recr->seq_keep(seq_id);

// Clear rebuild flags for sequences that are being removed
// (keep only the specified sequence)
std::unordered_set<llama_seq_id> to_remove;
for (const auto & id : recr_rebuild_needed) {
if (id != seq_id) {
to_remove.insert(id);
}
}
for (const auto & id : to_remove) {
recr_rebuild_needed.erase(id);
}
}

void llama_memory_hybrid::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
@@ -159,12 +213,29 @@ void llama_memory_hybrid::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p
}

llama_pos llama_memory_hybrid::seq_pos_min(llama_seq_id seq_id) const {
// For hybrid models, return the ATTENTION cache pos_min.
//
// Rationale:
// - The attention KV cache has valid data from position 0 (or wherever it starts)
// - The recurrent cache only tracks recent SSM state (its pos_min is near the current position)
// - Taking the max() of both (the previous behavior) causes the server to see an
//   "invalid cache" and force n_past = 0
//
// By returning the attention pos_min, the server can continue from n_past based on
// KV cache validity. The recurrent state will be rebuilt during decode if needed
// (tracked via recr_rebuild_needed).
//
// This allows hybrid models to benefit from KV cache reuse even when the prompt
// changes (e.g., after compaction).

return mem_attn->seq_pos_min(seq_id);
}

llama_pos llama_memory_hybrid::seq_pos_min_attn(llama_seq_id seq_id) const {
return mem_attn->seq_pos_min(seq_id);
}
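As an illustration of the pos_min rationale above, here is a rough sketch of the kind of caller-side reuse check it enables. `usable_n_past` and the exact condition are assumptions for the sketch, not the actual server logic.

#include "llama-memory-hybrid.h"

// hypothetical prefix-reuse check, not taken from the server sources
static llama_pos usable_n_past(const llama_memory_hybrid * mem, llama_seq_id seq_id,
        llama_pos n_past_wanted) {
    const llama_pos pos_min = mem->seq_pos_min_attn(seq_id);             // attention-only minimum
    const llama_pos pos_max = mem->get_mem_attn()->seq_pos_max(seq_id);  // attention-only maximum

    // e.g. the KV cache holds positions [0, 1023] while the cleared recurrent
    // state reports nothing: with the old max()-based seq_pos_min the cache
    // looked empty, with the attention-only value the prefix stays reusable
    if (pos_min > 0 || n_past_wanted > pos_max + 1) {
        return 0; // cannot reuse the prefix - reprocess from the start
    }
    return n_past_wanted;
}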

llama_pos llama_memory_hybrid::seq_pos_max(llama_seq_id seq_id) const {
// the max of the total cache is the min of the two caches' max values
return std::min(mem_attn->seq_pos_max(seq_id), mem_recr->seq_pos_max(seq_id));
}

@@ -198,6 +269,24 @@ llama_memory_recurrent * llama_memory_hybrid::get_mem_recr() const {
return mem_recr.get();
}

bool llama_memory_hybrid::needs_recurrent_rebuild(llama_seq_id seq_id) const {
return recr_rebuild_needed.count(seq_id) > 0;
}

void llama_memory_hybrid::mark_recurrent_rebuild(llama_seq_id seq_id) {
recr_rebuild_needed.insert(seq_id);
LLAMA_LOG_DEBUG("%s: marked seq %d for recurrent rebuild\n", __func__, seq_id);
}

void llama_memory_hybrid::clear_recurrent_rebuild(llama_seq_id seq_id) {
recr_rebuild_needed.erase(seq_id);
LLAMA_LOG_DEBUG("%s: cleared recurrent rebuild flag for seq %d\n", __func__, seq_id);
}
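The rebuild-flag helpers above are meant to be consumed by whatever drives decoding. A minimal sketch of that consumer follows; `replay_cached_tokens` is an assumed stand-in for re-decoding the tokens still covered by the KV cache, and the integration point itself is outside this file.

#include "llama-memory-hybrid.h"

// assumed helper, not part of llama.cpp: re-decodes tokens [p0, p1) of seq_id
// so that the recurrent (SSM) state is reconstructed from scratch
void replay_cached_tokens(llama_seq_id seq_id, llama_pos p0, llama_pos p1);

// hypothetical decode-side hook: rebuild the SSM state of a flagged sequence
// before new tokens are processed, then clear the flag
static void maybe_rebuild_recurrent_state(llama_memory_hybrid * mem, llama_seq_id seq_id) {
    if (!mem->needs_recurrent_rebuild(seq_id)) {
        return;
    }

    // use the attention-side max: the hybrid seq_pos_max takes the min of both
    // caches, and the recurrent side was just cleared
    const llama_pos pos_max = mem->get_mem_attn()->seq_pos_max(seq_id);
    if (pos_max >= 0) {
        replay_cached_tokens(seq_id, 0, pos_max + 1);
    }

    mem->clear_recurrent_rebuild(seq_id);
}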

//
// llama_memory_hybrid_context
//

llama_memory_hybrid_context::llama_memory_hybrid_context(llama_memory_status status) : status(status) {}

llama_memory_hybrid_context::llama_memory_hybrid_context(llama_memory_hybrid * mem) :
28 changes: 28 additions & 0 deletions src/llama-memory-hybrid.h
@@ -7,6 +7,7 @@
#include "llama-memory-recurrent.h"

#include <memory>
#include <unordered_set>
#include <vector>

//
@@ -82,11 +83,38 @@ class llama_memory_hybrid : public llama_memory_i {
llama_kv_cache * get_mem_attn() const;
llama_memory_recurrent * get_mem_recr() const;

//
// Hybrid-specific: recurrent state rebuild tracking
//
// When seq_rm fails for recurrent (partial removal not supported),
// we clear the recurrent state and mark it for rebuild. The next
// decode pass will rebuild the SSM state while reusing the KV cache.
//

// Check if a sequence needs its recurrent state rebuilt from position 0
bool needs_recurrent_rebuild(llama_seq_id seq_id) const;

// Mark a sequence as needing recurrent state rebuild
void mark_recurrent_rebuild(llama_seq_id seq_id);

// Clear the rebuild flag after rebuild is complete
void clear_recurrent_rebuild(llama_seq_id seq_id);

// Get the attention cache pos_min (for server pos_min checks)
// This allows continuing from n_past based on KV cache validity,
// independent of recurrent state which will be rebuilt.
llama_pos seq_pos_min_attn(llama_seq_id seq_id) const;

private:
const llama_hparams & hparams;

const std::unique_ptr<llama_kv_cache> mem_attn;
const std::unique_ptr<llama_memory_recurrent> mem_recr;

// Track sequences that need recurrent state rebuilt from position 0
// This happens when seq_rm fails for recurrent (partial removal not supported)
// but succeeds for attention (KV cache can be trimmed)
mutable std::unordered_set<llama_seq_id> recr_rebuild_needed;
};

class llama_memory_hybrid_context : public llama_memory_context_i {