server: Fix multimodal context checkpointing for hybrid/recurrent models #19747
timkhronos wants to merge 13 commits into ggml-org:master
Conversation
```cpp
const llama_pos checkpoint_pos = std::max(it->pos_min, it->pos_max);
llama_memory_seq_rm(llama_get_memory(ctx), slot.id, checkpoint_pos, -1);

slot.prompt_clear(true);
const size_t checkpoint_size = it->data.size();
const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
```
Could we avoid the changes to libllama by simply calling llama_memory_seq_rm() after restoring the partial (i.e. recurrent in this case) state?
I believe the changes to find_slot() in llama-memory-recurrent.cpp are unavoidable, as the issue is at checkpoint creation, not recovery. Reordering the restore/seq_rm in the server would not help, because the recurrent cell's position is already wrong in the checkpoint data.
Without the fix, when we try to find last_pos, it will always read the position from the temporal plane, which is the same for all processed image tokens (lower than the real last position).
This later causes M-RoPE constraint violations when we try to restore. The checkpoint would also record the recurrent cell's position incorrectly: the KV cache would be properly truncated to the last image token, and the recurrent cell's state would be saved correctly after being modified by the processed image, but the checkpoint would record the recurrent cell's position as just before the image, even though the cell had already been modified by the image.
Can you give an example of such a prompt?
In reality, an image should never end up at the last token position, because an image is always followed by end-of-image and/or end-of-turn tokens.
If we want to avoid the problematic case where the image is at the end of the prompt (without any text tokens after it), then the simple fix is to just delete the image.
In other words, we can simplify this whole logic by doing llama_memory_seq_rm() on the image plus one token before it.
Could you please clarify what part of the implementation you have a problem with?
I don't quite see how we could achieve a valid fix without modifying the checkpointing logic. It was written for text-only checkpoints and doesn't take into account that multiple M-RoPE-compressed tokens can share the same temporal value.
In my previous message, my point wasn't that the last token of a prompt could be an image token; rather, without changing the logic in memory-recurrent.cpp, the checkpoints wouldn't properly track the number of M-RoPE-compressed image tokens, causing the KV cache and the recurrent cell to desync after a checkpoint restore. As a band-aid we could truncate the image, though that might still not be a good idea: right now checkpoints are created after prompt processing, i.e. after the recurrent cell has been modified by the image. If we load the checkpoint but truncate the image from the KV cache and make it reprocess the image, the recurrent cell will have been modified by the image twice. And even with that band-aid, if we ever, at any point, create a checkpoint after an image has been processed, the checkpoint will be created at the wrong position.
> as the issue is at checkpoint creation, not recovery.
I think we have to add logic to not do checkpoints in the middle of an image. This is important because some vision models use non-causal attention and this requires all image tokens for a given image to be processed in a single ubatch (so that each image token can "see" every other image token).
IIUC your proposed solution implicitly assumes causal attention for the image tokens. Although this seems to work for Qwen3.5, it is not completely generic as described in the paragraph above.
Unless I am missing something, if we impose the restriction to not create the checkpoints in the middle of an image, then we won't need the extra changes to libllama. Do you agree?
> but without the modification, if we ever, at any point, create a checkpoint after an image has been processed, the checkpoint will be created at the wrong position.
Please tell me if I understand this phrase correctly: what you are saying is the case of M-RoPE where one image takes the same temporal index. For example, my image has t=2:
0 1 2 2 2 2 2 4
So what you are saying is that from the perspective of the recurrent layer, it sees:
0 1 2 3 4 5 6 7 ...
Essentially a linearly increasing index, correct?
If that's the case, then can we reuse the same positional tracking system between the 2?
> As a band aid we could truncate the image (which might still not be a good idea as right now checkpoints are created after prompt processing, so after the recurrent cell is modified by the image. If we load the checkpoint but truncate the image from the kv cache and make it reprocess the image, the recurrent cell will have been modified by the image twice.)
We should not track the number of tokens. Instead, we should track the position index.
If we track the position index, the whole image can be viewed as one big blob; a checkpoint of half an image is not allowed, as @ggerganov explained above.
To mitigate the case where an image can span multiple batches, and the user can potentially stop processing mid-way, we can always create one checkpoint right before we process an image.
@ggerganov You're right that we could add a guard against mid-image checkpoints, even though currently, as far as I understand, checkpoints are only created after we are done with prompt processing, so there should never be a chance to create a checkpoint mid-image. It would be a problem if that ever changed, so we can add that check.
However, that alone wouldn't fix the issue I previously brought up. The bug is in how the checkpoint position is recorded, and it happens even when the checkpoint is created well after the image has been fully processed, regardless of how many text tokens were processed after it.
Here's a concrete example with some imaginary values:
- We have 6000 text tokens (positions 0–5999) in a conversation
- An image gets processed, producing ~1500 vision tokens that get M-RoPE compressed to let's say 60 KV entries (positions 6000–6059)
- A few hundred text tokens follow after the image (positions 6060–6259)
- Prompt processing finishes, and we create a checkpoint
The original code in memory-recurrent.cpp records the position of the checkpoint by looking at the temporal value of the last token. For text tokens this works fine, as each text token increments the temporal value by 1. But for M-RoPE image tokens, all tokens from a single image share the same temporal value. So in the example, we are at position 6259. When the code tries to find the position of the last token, however, the image block at 6000–6059 reports a temporal value of ~6000 for all 60 entries, causing the checkpoint to be created at position 6200 (the image's temporal pos of 6000 plus 200 text tokens; the 60 image tokens share one temporal position, so they are counted as just one).
This means the checkpoint records a position that is too low. The recurrent cell's state has been updated by the full image plus 200 text tokens, but the checkpoint metadata says it only covers up to ~6200. On restore, this desync between the recurrent state and the KV cache causes M-RoPE constraint violations, plus, in theory, a recurrent-cache drift of (number of images in context) * (image tokens per image - 1) * (number of checkpoint restores) tokens. The drift gets more severe the more times we restore the checkpoint.
The code needs to also look at the max(width, height) values from the M-RoPE position planes to find the actual span of the image; otherwise an image, regardless of how many tokens it contains, always counts as just 1 token, causing the drift. If video vision is ever implemented, the way I have it set up here should still work correctly, since we look at temporal + width + height.
The above is why I believe the changes to libllama are needed. Without them, images would always report as only a single token to the checkpointing logic, causing more and more severe drift as the number of restores and images increases.
I agree we should add a mid-image checkpoint guard for safety, even though I believe it is currently not possible for the checkpoint-creation logic to run before prompt processing is fully finished, which prevents that specific scenario, unless I am misunderstanding something. Regardless, if someone were to ever add mid-prompt-processing checkpointing, they could run into that issue; even though I think it would be best addressed in that PR, I'd be happy to add it here if you believe that's best.
But the core problem that requires the libllama changes is that post-image checkpoints record the wrong position, because the temporal plane doesn't reflect the true extent of M-RoPE-compressed tokens.
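The plane scan described above can be sketched roughly like this. This is a simplified, self-contained model, not the actual llama_ubatch layout: it assumes a flat position array with n_pos planes stored contiguously, n_tokens entries per plane, and a sequence slice starting at index i.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Toy stand-in for finding the last position of a sequence slice.
// With M-RoPE, positions come in several "planes" (temporal, height, width, ...).
// Reading only the temporal plane undercounts compressed image tokens, because
// every entry of one image shares the same temporal value; taking the max over
// all planes recovers the true span of the image.
int64_t last_pos_all_planes(const std::vector<int64_t> & pos,
                            uint32_t n_pos, uint32_t n_tokens,
                            uint32_t i, uint32_t n_seq_tokens) {
    int64_t last_pos = -1;
    for (uint32_t p = 0; p < n_pos; ++p) {
        for (uint32_t t = 0; t < n_seq_tokens; ++t) {
            last_pos = std::max(last_pos, pos[p * n_tokens + i + t]);
        }
    }
    return last_pos;
}
```

With toy values mirroring the example (temporal plane stuck at 6000 while a spatial plane reaches 6059), scanning all planes reports 6059, whereas the temporal plane alone reports 6000.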
> The original code in memory-recurrent.cpp records the position of the checkpoint by looking at the temporal value of the last token. For text tokens this works fine as each text token increments the temporal value by 1. But for M-RoPE image tokens, all tokens from a single image share the same temporal value. So in the example, we are at position 6259. If the code tries to find the position of the last token however, the image block at 6000–6059 will report a temporal value of ~6000 for all 60 entries, causing the checkpoint to be created for position 6200 (image tokens temporal pos is 6000 + 200 text tokens. This would ignore that the 60 image tokens are sharing a temporal position, thus are only counted as one)
I'm pretty sure this problem can be better understood from my question above:
> If that's the case, then can we reuse the same positional tracking system between the 2?
Also note that we already somewhat have this logic in the mask construction of M-RoPE:
llama.cpp/src/llama-kv-cache.cpp, lines 1380–1389 @ d8aeb65
So I don't think the current code is acceptable as-is, especially because it assumes the next position is max(x, y). This calculation must not be inside libllama.
```cpp
if (ubatch.n_pos > 1 && ubatch.embd != nullptr) {
    for (uint32_t p = 0; p < ubatch.n_pos; ++p) {
        for (uint32_t t = 0; t < n_seq_tokens; ++t) {
            last_pos = std::max(last_pos, ubatch.pos[p * ubatch.n_tokens + i + t]);
        }
    }
}
```
This must not be inside libllama. If you look into the server code, server_tokens::pos_next() already handles this in a more generic way.
The reason I believe this needs to be in libllama is that find_slot() is where the recurrent cell's position gets written. This happens internally during ubatch processing, and I don't see where the server would get a chance to intercept or correct the value before it gets stored.
Specifically, find_slot() reads ubatch.pos[i + n_seq_tokens - 1] to determine last_pos, which then gets written to the cell. For M-RoPE ubatches, with the current implementation, that value comes from the temporal plane and is too low. server_tokens::pos_next() tracks the next position to assign, but it doesn't fix that the cell already recorded the last pos during find_slot().
I'm open to handling this another way, but I personally can't see how to fix it without modifying memory-recurrent.cpp: at least as far as I saw, by the time the server could do anything about it, the wrong value has already been stored, so we need to fix it where the value gets stored.
So basically what you said is that llama_pos last_pos = ubatch.pos[i + n_seq_tokens - 1]; will return the incorrect position index if the last token in ubatch is an image token.
But in reality, what's an example of such an input? Are you talking about the case where the ubatch contains part of the image, or the case where the image is at the end of the prompt?
The problem isn't specifically about the last token being an image token, or images being at the end of the prompt. It happens for any ubatch that contains M-RoPE image tokens, regardless of position in the conversation.
When find_slot() processes an ubatch containing image tokens, it sets the cell's last_pos from the temporal plane. For an image with 60 M-RoPE compressed KV entries at positions 6000–6059, all 60 entries share a temporal value of 6000. So after processing that ubatch, the cell records last_pos = 6000 instead of 6059.
Then the next ubatch with text tokens is processed. The cell thinks it's at position 6000, but the KV cache knows the image occupied slots up to 6059. This drift persists through every subsequent checkpoint save/restore.
Basically let's say:
- text tokens 0–5999 -> cell records last_pos = 5999
- image tokens, 60 entries: real positions are 6000–6059, but the temporal position for all image tokens is 6000, thus the cell incorrectly records last_pos as 6000 when it should be 6059
- text tokens 6060–6259 -> cell records last_pos as 6200, when it should be 6259
The drift of 59 tokens is always present from step 2 onward. It's not an edge case; it happens every time an image is processed. I think my longer reply above to @ggerganov might illustrate my concern better.
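The drift arithmetic in the list above can be sketched with two toy helpers (hypothetical names, illustrative only; they model the scenario, not llama.cpp code):

```cpp
#include <cassert>

// What the cell records when the temporal plane counts a whole image as a
// single position step: the image advances the index by just 1, regardless
// of how many compressed KV entries (n_img) it actually occupies.
int recorded_last_pos(int n_text_before, int n_img, int n_text_after) {
    (void) n_img; // the image contributes only 1 temporal step
    return (n_text_before - 1) + 1 + n_text_after;
}

// The true last position: every KV entry, image or text, occupies one slot.
int true_last_pos(int n_text_before, int n_img, int n_text_after) {
    return n_text_before + n_img + n_text_after - 1;
}
```

Plugging in the example values (6000 text tokens, a 60-entry image, 200 trailing text tokens) reproduces the recorded 6200 vs. the true 6259, a drift of 59.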
ubatch.pos[i + n_seq_tokens - 1] explained in human language is: "get the (temporal) position of the last token in ubatch"
Now, if the "last token in ubatch" is a text token, there is nothing wrong with this logic, because it will always be set to the correct position (pos = 6259 in your case) before it is even added to the batch.
In which other cases do you think the logic above can be wrong?
> - text tokens 0–5999 -> cell records last_pos = 5999
> - image tokens, 60 entries, real positions are 6000–6059, but the temporal position for all image tokens is 6000, thus the cell incorrectly records last_pos as 6000 when it should be 6059
> - text tokens 6060–6259 -> cell records last_pos as 6200, when it should be 6259
In other words: the moment you decode another text token, because its position is set by server_tokens, its pos will be set to 6259 + 1 = 6260:

```cpp
llama_pos pos = server_tokens.pos_next(); // returns 6259 + 1 = 6260
common_batch_add(batch, text_token, pos, ...);
llama_decode(batch); // now, last_pos will be updated to 6260
```

So, suddenly, everything will be in sync again?
> The cell thinks it's at position 6000, but the KV cache knows the image occupied slots up to 6059.
Also, an important note: the KV slot does NOT know that the image occupied positions up to 6059; it only knows that 6059 cells are used.
Here is the code where cell.pos is updated in the KV cache: llama.cpp/src/llama-kv-cache.cpp, line 932 @ d8aeb65
What you said makes sense in theory, but in practice it fails. I've just tested a build with the find_slot() changes in memory-recurrent.cpp reverted but all other changes intact. The issue reproduces immediately, always triggering a full reprocess once an image is present.
Here are a few relevant pieces of the logs from a build without the find_slot() modification, with everything else from my PR intact. After an image is processed and a checkpoint is created:
```
created context checkpoint 1 of 8 (pos_min = 11667, pos_max = 11667)
find_slot: non-consecutive token position 11707 after 11667 for sequence 3 with 10 new tokens
```
The checkpoint is created after prompt processing completes but before generation begins. At that moment, cell.pos is still 11667 — the wrong value from the temporal plane. The actual position should be 11707 (the gap of 40 is the compressed image tokens being counted as 1).
The "self-correction" from token decoding comes too late; by that point the checkpoint has already been initialized with pos_min = 11667, pos_max = 11667.
On the next conversation turn, we try to restore this checkpoint:
```
restored context checkpoint (pos_min = 11667, pos_max = 11667)
memory_seq_rm [10986, end)
failed to recover recurrent state - clearing the memory
```
seq_rm fails because the position data is wrong, and we then fall back to a full prompt reprocess from scratch, which is the exact same behavior as in issue #19690.
With the find_slot() fix restored, the checkpoint records the correct position and restore works as expected.
> After an image is processed and a checkpoint is created
Unless I missed something: a checkpoint is only created upon SLOT_STATE_DONE_PROMPT, so that means your prompt ends with an image, not a text token.
From what you confirmed, pos_max = 11667 is the cell.pos value, which you assumed to be wrong in the recurrent case. However, as I explained above, adding one more text token will correct cell.pos (the text token must be inside the prompt, not generated text). Please verify this.
I have tested this in normal SillyTavern conversations: text before the image, text after the image, nothing unusual. Checkpoint restore still fails without the fix to memory-recurrent.cpp. The text tokens after the image do not appear to correct the position tracking for the M-RoPE-compressed image tokens.
Logically, find_slot() was written assuming each token increments the position by 1. M-RoPE image tokens break that assumption: many tokens share a temporal position but occupy distinct spatial positions. I believe this should be accounted for at the checkpoint-creation level, not bridged over later. Doing that would be a workaround, and probably a fairly fragile one if someone later changes the checkpointing or tracking logic. But in this case, it doesn't seem to work at any rate.
```cpp
llama_pos pos_next() const;
const mtmd::input_chunk_ptr & find_chunk(size_t idx) const;

size_t tokens_up_to_pos(llama_pos max_pos) const;
```
instead of adding this as a dedicated function, you just need to extend pos_next() to have an optional arg: pos_next(llama_pos i_pos_start = -1)
Makes sense, I can fold the code from tokens_up_to_pos() into pos_next(). I'll use the schema you suggested, but the naming could end up being a bit confusing: pos_next(-1) would mean "next position", while pos_next(6259) would mean "how many tokens up to position 6259". These would be pretty different operations.
before you do so, I would suggest reflecting one more time on my point about not using absolute token counts at all
most of the logic here suggests to me that the conversion between token count <--> position index is redundant, as we can simply use the position index
the n_past_new calculation serves 2 purposes: filling out slot.n_prompt_tokens_cache = n_past_new, and calling keep_first(), which I proved to be wrong in another comment; n_prompt_tokens_cache can be changed to n_prompt_pos_cache
I think refactoring the server to track positions natively instead of token counts is a worthwhile idea, but that would touch a lot of code beyond this PR, which is about adding support for context checkpointing with recurrent models on multimodal contexts.
Regarding keep_first(n_past_new): n_past_new is already a token count, not a position index. The tokens_up_to_pos() call converts the checkpoint position into the corresponding token count, specifically to handle the M-RoPE case where token count >= position index.
So the conversion is necessary here to make keep_first() work correctly with its existing API. It could be modified to take a position instead of a count, but since I didn't add n_prompt_tokens_cache and keep_first(), and they are already used throughout the server codebase with their current behaviour, that would probably require reworking large parts of the codebase.
This would be a larger scale refactor, and I think it's beyond the scope of this PR.
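For illustration, the position-to-token-count conversion could look roughly like this. This is a toy model with hypothetical chunk bookkeeping, not the actual server_tokens implementation: each chunk records how many KV entries it holds and how many position steps it spans (equal for text; fewer for an M-RoPE image).

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Toy chunk: n_tokens KV entries spanning n_pos position steps.
// For text, n_pos == n_tokens; for an M-RoPE image, n_pos < n_tokens.
struct chunk {
    size_t n_tokens;
    size_t n_pos;
};

// Convert a position index into the number of whole chunks' worth of tokens
// up to and including that position (the role tokens_up_to_pos() plays in
// this PR, sketched here; partial chunks are not kept).
size_t tokens_up_to_pos(const std::vector<chunk> & chunks, size_t max_pos) {
    size_t n_tokens = 0;
    size_t pos      = 0;
    for (const auto & c : chunks) {
        if (pos + c.n_pos > max_pos + 1) {
            break; // chunk extends past max_pos, stop before it
        }
        pos      += c.n_pos;
        n_tokens += c.n_tokens;
    }
    return n_tokens;
}
```

With a 10-token text chunk, a 9-token image spanning only 3 position steps, and a 5-token text chunk, asking for the tokens up to the image's last position (12) yields 19 tokens, more than the position index, which is exactly the M-RoPE case described above.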
```cpp
SLT_WRN(slot, "recovered recurrent state from checkpoint (pos_min = %d, pos_max = %d, n_tokens = %d), n_past: %d -> %d\n",
        it->pos_min, it->pos_max, it->n_tokens_cached, slot.prompt.n_tokens(), n_past_new);

slot.prompt.tokens.keep_first(n_past_new);
```
This seems to be wrong: keep_first() takes the absolute number of tokens (or KV slot count), not the position index.
If you do keep_first(n_past_new) here, it will end up removing more tokens than needed, because in the case of M-RoPE: number of tokens >= position index.
Instead, the cleaner way is to convert everything to the position index; we should no longer rely on absolute token counts.
n_past_new is already a token count here, not a position. It comes from tokens_up_to_pos(checkpoint_pos), which converts the position index into the absolute number of tokens (correctly accounting for M-RoPE compressed image tokens). So keep_first(n_past_new) should be correct.
That said, I understand the confusion since the variable name could be better. I can rename it to something like n_tokens_keep to make the intent more obvious.
hmm right, I was indeed confused.
to make it clear, probably better to have 2 different keep_first variants:
- keep_first(size_t) takes the absolute number of tokens
- keep_first_n_pos(llama_pos) takes the position index

By having a specific type for each, I hope that any misuse of these 2 can be detected by the compiler
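A minimal sketch of the strong-typedef idea (illustrative names only, not the actual server API; the position-to-count conversion is stubbed as identity, which only holds for text-only content):

```cpp
#include <cassert>
#include <cstddef>

// Distinct wrapper types so passing a position where a token count is
// expected fails to compile instead of silently truncating too much.
struct n_tokens_t { size_t v; };
struct n_pos_t    { long   v; };

struct tokens_stub {
    size_t size = 0;

    // Keep the first n tokens (absolute count).
    void keep_first(n_tokens_t n) {
        if (n.v < size) {
            size = n.v;
        }
    }

    // Keep everything up to a position index; the real version would convert
    // position -> token count (M-RoPE aware), stubbed here as identity.
    void keep_first_n_pos(n_pos_t p) {
        keep_first(n_tokens_t{ (size_t) p.v });
    }
};
```

With these wrappers, `keep_first(n_pos_t{…})` is a compile error, which is exactly the misuse the reviewer wants the compiler to catch.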
That's fine, I'll add the new keep_first_n_pos(), which will handle the position-to-token-count conversion internally and then call keep_first().
@timkhronos Could you confirm that #19849 works correctly? I tried to simplify the approach there and need some feedback on whether that implementation is correct.
@ggerganov The implementation in #19849 seems to work fine in the same test cases I tried for my implementation. However, I have found an issue during testing. If we have a conversation (Ta = assistant text turn, Tu = user text turn, Iu = user image turn):
If we delete turns 10 and 9 and swipe on reply 8, or if we delete turns 10, 9, and 8 and send 7 to generate, the recurrent state gets restored incorrectly. The checkpoint should have become invalid, since the prompt has moved to a point (7) before it; instead the checkpoint is still restored, even though its recurrent cell state was computed after processing tokens that no longer exist. The checkpoint at pos_min = pos_max = 8440 was created after processing all 11052 tokens, so the recurrent state reflects all of them. After deletion (the prompt is now only 10607 tokens), restoring this checkpoint contaminates the recurrent state with influence from the deleted messages. We'd need to also store the total prompt length or max position at checkpoint creation time, and invalidate any checkpoint where that exceeds the current context.
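The invalidation check suggested at the end could be sketched like this (hypothetical field and function names; the actual checkpoint struct in the server differs):

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical checkpoint metadata: in addition to the position range, record
// how long the prompt was when the checkpoint was created.
struct checkpoint_meta {
    int    pos_min;
    int    pos_max;
    size_t n_prompt_at_creation; // total prompt length at creation time
};

// A checkpoint whose recurrent state was computed over more tokens than the
// current prompt contains reflects deleted messages, so it must be rejected.
bool checkpoint_valid(const checkpoint_meta & cp, size_t n_prompt_now) {
    return cp.n_prompt_at_creation <= n_prompt_now;
}
```

With the numbers from the scenario above, a checkpoint created over 11052 tokens is rejected once the prompt shrinks to 10607 tokens.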
@timkhronos I've been running this branch since yesterday and it works great so far. But today I've hit a regression. During an agentic workflow, this happened: And it got into a loop. I had to kill llama.cpp, and after restarting it came back just fine.
I am not sure that is correct. The reason is that the checkpoint is created before processing the last batch. So these logs are a bit misleading: they mean that the checkpoint was created before processing the last batch of 512 tokens (i.e. before the last batch was actually decoded). It is intentionally done like this (see #16440). We want to store the checkpoint slightly earlier, before the full prompt is processed, specifically to allow regenerations or small user-message modifications when recurrent state is involved. I'll add some changes to improve the logs in this regard. Can you confirm?
@ggerganov After testing, I believe you are correct, and I was indeed the one getting tripped up by the logging. Your implementation seems to behave as expected, given that we checkpoint before the last batch; my extra check would cause undue reprocessing for no benefit as far as I can see. Recently some smaller Qwen MoEs have been released, sharing the same vision capabilities and hybrid recurrent attention as the 397B one, in case you want to validate the PR personally.
Yes, it's much easier now that we have these models. I am running an evaluation and so far it seems to work OK.
Superseded by #19849 |
This PR enables context checkpointing to work with multimodal inputs on hybrid/recurrent architectures (e.g. Qwen3.5).
Previously, checkpointing was hard-disabled when an mmproj was loaded, causing full prompt reprocessing on every turn.
The changes implemented in this PR allow context checkpointing to function normally by properly handling processed images in the KV cache.
Tested extensively on Qwen3.5, and context checkpointing now works as expected with multimodal contexts.
This closes #19690.