-
Notifications
You must be signed in to change notification settings - Fork 499
Context Overflow Fix #1373
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Context Overflow Fix #1373
Changes from all commits
556a401
7f54d5c
6b66a79
5d13abd
ec97967
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,5 +1,6 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||
| using System.Collections.Generic; | ||||||||||||||||||||||||||||||||||||||||||||||||||
| using LLama.Common; | ||||||||||||||||||||||||||||||||||||||||||||||||||
| using LLama.Sampling; | ||||||||||||||||||||||||||||||||||||||||||||||||||
| using System.Collections.Generic; | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| namespace LLama.Abstractions | ||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -36,5 +37,19 @@ public interface IInferenceParams | |||||||||||||||||||||||||||||||||||||||||||||||||
| /// Controls the behavior of decoders like <see cref="StreamingTokenDecoder" /> | ||||||||||||||||||||||||||||||||||||||||||||||||||
| /// </remark> | ||||||||||||||||||||||||||||||||||||||||||||||||||
| public bool DecodeSpecialTokens { get; set; } | ||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| /// <summary> | ||||||||||||||||||||||||||||||||||||||||||||||||||
| /// Defines the strategy the executor should use when the context window is full | ||||||||||||||||||||||||||||||||||||||||||||||||||
| /// and the model architecture (e.g., models with 2D RoPE embeddings) does not | ||||||||||||||||||||||||||||||||||||||||||||||||||
| /// support native memory shifting. | ||||||||||||||||||||||||||||||||||||||||||||||||||
| /// </summary> | ||||||||||||||||||||||||||||||||||||||||||||||||||
| ContextOverflowStrategy OverflowStrategy { get; set; } | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| /// <summary> | ||||||||||||||||||||||||||||||||||||||||||||||||||
| /// The percentage of past tokens to discard when <see cref="OverflowStrategy"/> | ||||||||||||||||||||||||||||||||||||||||||||||||||
| /// is set to <see cref="ContextOverflowStrategy.TruncateAndReprefill"/>. | ||||||||||||||||||||||||||||||||||||||||||||||||||
| /// For example, 0.1f represents dropping the oldest 10% of the conversational context. | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+42
to
+51
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| /// Defines the strategy the executor should use when the context window is full | |
| /// and the model architecture (e.g., models with 2D RoPE embeddings) does not | |
| /// support native memory shifting. | |
| /// </summary> | |
| ContextOverflowStrategy OverflowStrategy { get; set; } | |
| /// <summary> | |
| /// The percentage of past tokens to discard when <see cref="OverflowStrategy"/> | |
| /// is set to <see cref="ContextOverflowStrategy.TruncateAndReprefill"/>. | |
| /// For example, 0.1f represents dropping the oldest 10% of the conversational context. | |
| /// Defines the strategy the executor should use when the context window is full. | |
| /// </summary> | |
| /// <remarks> | |
| /// This setting applies even for models that support native memory shifting. | |
| /// Setting <see cref="ContextOverflowStrategy.ThrowException"/> disables automatic | |
| /// shifting or truncation and causes the executor to fail immediately on overflow. | |
| /// </remarks> | |
| ContextOverflowStrategy OverflowStrategy { get; set; } | |
| /// <summary> | |
| /// The percentage of past tokens to discard when <see cref="OverflowStrategy"/> | |
| /// is set to <see cref="ContextOverflowStrategy.TruncateAndReprefill"/> to recover | |
| /// from a full context window. For example, 0.1f represents dropping the oldest | |
| /// 10% of the conversational context. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not doing this now.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,26 @@ | ||
| using System; | ||
| using System.Collections.Generic; | ||
| using System.Text; | ||
|
|
||
| namespace LLama.Common | ||
| { | ||
| /// <summary> | ||
| /// Defines how the executor should behave when the context window fills up | ||
| /// on a model that does not support native memory shifting (e.g., 2D RoPE models). | ||
| /// </summary> | ||
| public enum ContextOverflowStrategy | ||
| { | ||
| /// <summary> | ||
| /// The engine will throw a ContextOverflowException. | ||
| /// Use this to manually manage context pruning in your application layer. | ||
| /// (Equivalent to llama-cli's --no-context-shift). | ||
| /// </summary> | ||
| ThrowException, | ||
|
|
||
| /// <summary> | ||
| /// The engine will silently drop a percentage of the oldest tokens | ||
| /// (preserving the system prompt) and completely re-prefill the context. | ||
| /// </summary> | ||
| TruncateAndReprefill | ||
| } | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,39 @@ | ||
| using System; | ||
|
|
||
| namespace LLama.Exceptions | ||
| { | ||
| /// <summary> | ||
| /// Thrown when the KV cache context is full and the model architecture | ||
| /// cannot mathematically support native memory shifting, or when the | ||
| /// ContextOverflowStrategy.ThrowException is used. | ||
| /// </summary> | ||
| public class ContextOverflowException : Exception | ||
| { | ||
| private const string DefaultMessage = "The context window is full and the current strategy is set to ThrowException. To automatically truncate and manage context, set InferenceParams.OverflowStrategy to ContextOverflowStrategy.TruncateAndReprefill."; | ||
|
|
||
| /// <summary> | ||
| /// Initializes a new instance of the ContextOverflowException class with a default error message. | ||
| /// </summary> | ||
| public ContextOverflowException() : base(DefaultMessage) | ||
| { | ||
| } | ||
|
|
||
| /// <summary> | ||
| /// Initializes a new instance of the ContextOverflowException class with a specified error message. | ||
| /// </summary> | ||
| /// <param name="message">The message that describes the error.</param> | ||
| public ContextOverflowException(string message) : base(message) | ||
| { | ||
| } | ||
|
|
||
| /// <summary> | ||
| /// Initializes a new instance of the ContextOverflowException class with a specified error message | ||
| /// and a reference to the inner exception that is the cause of this exception. | ||
| /// </summary> | ||
| /// <param name="message">The message that describes the error.</param> | ||
| /// <param name="innerException">The exception that is the cause of the current exception.</param> | ||
| public ContextOverflowException(string message, Exception innerException) : base(message, innerException) | ||
| { | ||
| } | ||
| } | ||
| } |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -198,21 +198,55 @@ public void SaveSessionFile(string filename) | |||||||||||||||||||||||||||||||||||||||||||||||||
| /// <summary> | ||||||||||||||||||||||||||||||||||||||||||||||||||
| /// After running out of the context, take some tokens from the original prompt and recompute the logits in batches. | ||||||||||||||||||||||||||||||||||||||||||||||||||
| /// </summary> | ||||||||||||||||||||||||||||||||||||||||||||||||||
| /// <param name="tokensToKeep"></param> | ||||||||||||||||||||||||||||||||||||||||||||||||||
| protected virtual void HandleRunOutOfContext(int tokensToKeep) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| /// <param name="tokensToKeep">The number of tokens from the initial prompt to preserve (e.g., system prompt).</param> | ||||||||||||||||||||||||||||||||||||||||||||||||||
| /// <param name="inferenceParams">The parameters controlling the inference and overflow strategy.</param> | ||||||||||||||||||||||||||||||||||||||||||||||||||
| /// <exception cref="ContextOverflowException">Thrown when the overflow strategy is set to ThrowException, or if the model does not support native shifting.</exception> | ||||||||||||||||||||||||||||||||||||||||||||||||||
| /// <exception cref="ArgumentOutOfRangeException">Thrown when tokensToKeep is invalid.</exception> | ||||||||||||||||||||||||||||||||||||||||||||||||||
| protected virtual Task HandleRunOutOfContext(int tokensToKeep, IInferenceParams inferenceParams) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
martindevans marked this conversation as resolved.
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||
| // if we run out of context: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| // - take the tokensToKeep first tokens from the original prompt (via n_past) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| // - take half of the last (n_ctx - tokensToKeep) tokens and recompute the logits in batches | ||||||||||||||||||||||||||||||||||||||||||||||||||
| // 1. Fast Fail if not configured to auto-truncate | ||||||||||||||||||||||||||||||||||||||||||||||||||
| if (inferenceParams.OverflowStrategy == ContextOverflowStrategy.ThrowException) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||
| throw new ContextOverflowException(); | ||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| // 2. Guard: Stateful executors currently require native shifting to truncate. | ||||||||||||||||||||||||||||||||||||||||||||||||||
| // TODO (Future Improvement): To support truncation on models where MemoryCanShift == false, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| // StatefulExecutorBase needs an unconditional `List<LLamaToken> _history_tokens` to track | ||||||||||||||||||||||||||||||||||||||||||||||||||
| // all ingested/generated tokens so we can clear the KV cache and perform a full re-prefill. | ||||||||||||||||||||||||||||||||||||||||||||||||||
| if (!Context.NativeHandle.MemoryCanShift) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||
| _logger?.LogError("Model does not support native memory shifting. Stateful truncation requires MemoryCanShift = true."); | ||||||||||||||||||||||||||||||||||||||||||||||||||
| throw new ContextOverflowException("Model does not support native memory shifting. Context overflowed."); | ||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| // 3. Calculate tokens safely | ||||||||||||||||||||||||||||||||||||||||||||||||||
| var n_left = _pastTokensCount - tokensToKeep; | ||||||||||||||||||||||||||||||||||||||||||||||||||
| var n_discard = n_left / 2; | ||||||||||||||||||||||||||||||||||||||||||||||||||
| if (n_left <= 0) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||
| throw new ArgumentOutOfRangeException(nameof(tokensToKeep), "Cannot truncate context: tokensToKeep exceeds or equals the current context size."); | ||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| // Clamp the percentage between 1% and 99% to prevent math errors or total wipeouts | ||||||||||||||||||||||||||||||||||||||||||||||||||
| var percentage = Math.Max(0.01f, Math.Min(0.99f, inferenceParams.ContextTruncationPercentage)); | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
martindevans marked this conversation as resolved.
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| var n_discard = (int)(n_left * percentage); | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| // Sanity check: always discard at least 1 token, but never more than we have available. | ||||||||||||||||||||||||||||||||||||||||||||||||||
| n_discard = Math.Max(1, Math.Min(n_discard, n_left)); | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| // 4. Fast path: attempt the fast native memory shift | ||||||||||||||||||||||||||||||||||||||||||||||||||
| Context.NativeHandle.MemorySequenceRemove(LLamaSeqId.Zero, tokensToKeep, tokensToKeep + n_discard); | ||||||||||||||||||||||||||||||||||||||||||||||||||
| Context.NativeHandle.MemorySequenceAdd(LLamaSeqId.Zero, tokensToKeep + n_discard, _pastTokensCount, -n_discard); | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| _pastTokensCount -= n_discard; | ||||||||||||||||||||||||||||||||||||||||||||||||||
| // stop saving session if we run out of context | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| // Keep session tracking aligned with the shifted KV cache so future session saves/reuse | |
| // operate on the same logical token sequence as the current context. | |
| var sessionDiscardStart = Math.Min(tokensToKeep, _session_tokens.Count); | |
| var sessionDiscardEnd = Math.Min(tokensToKeep + n_discard, _session_tokens.Count); | |
| var removedSessionTokens = sessionDiscardEnd - sessionDiscardStart; | |
| if (removedSessionTokens > 0) | |
| { | |
| _session_tokens.RemoveRange(sessionDiscardStart, removedSessionTokens); | |
| if (_n_session_consumed > sessionDiscardStart) | |
| { | |
| _n_session_consumed = _n_session_consumed >= sessionDiscardEnd | |
| ? _n_session_consumed - removedSessionTokens | |
| : sessionDiscardStart; | |
| } | |
| } | |
| if (_n_session_consumed > _session_tokens.Count) | |
| { | |
| _n_session_consumed = _session_tokens.Count; | |
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK! We will stop saving the session if we run out of context.
Uh oh!
There was an error while loading. Please reload this page.