From 556a401e3825550ca7418a87ad85c2fe6a69244d Mon Sep 17 00:00:00 2001 From: Zoli Somogyi Date: Sun, 19 Apr 2026 08:30:18 +0200 Subject: [PATCH 1/5] Context Overflow Fix --- LLama.Unittest/StatelessExecutorTest.cs | 85 +++++++++++++++++++- LLama/Abstractions/IInferenceParams.cs | 19 ++++- LLama/Common/ContextOverflowStrategy.cs | 26 ++++++ LLama/Common/InferenceParams.cs | 15 ++++ LLama/Exceptions/ContextOverflowException.cs | 14 ++++ LLama/LLamaExecutorBase.cs | 51 +++++++++--- LLama/LLamaInstructExecutor.cs | 2 +- LLama/LLamaInteractExecutor.cs | 2 +- LLama/LLamaStatelessExecutor.cs | 53 ++++++++++-- 9 files changed, 242 insertions(+), 25 deletions(-) create mode 100644 LLama/Common/ContextOverflowStrategy.cs create mode 100644 LLama/Exceptions/ContextOverflowException.cs diff --git a/LLama.Unittest/StatelessExecutorTest.cs b/LLama.Unittest/StatelessExecutorTest.cs index d2c646e99..1fae35a79 100644 --- a/LLama.Unittest/StatelessExecutorTest.cs +++ b/LLama.Unittest/StatelessExecutorTest.cs @@ -57,18 +57,21 @@ public async Task Stateless() } [Fact(Skip = "Very very slow in CI")] - public async Task OutOfContext() + public async Task OutOfContext_WithTruncateStrategy_SuccessfullyGenerates() { var executor = new StatelessExecutor(_weights, _params); const string question = " Question. cats or dogs?\nAnswer:"; // The context size is set to 60. Generate more than that, forcing it to generate a coherent response - // with a modified context + // with a modified context. + // We explicitly set the strategy to TruncateAndReprefill to test the new fallback logic. var @params = new InferenceParams() { MaxTokens = 65, TokensKeep = question.Length, + OverflowStrategy = ContextOverflowStrategy.TruncateAndReprefill, + ContextTruncationPercentage = 0.2f // Drop 20% of tokens when full }; var result1 = string.Join("", await executor.InferAsync(question, @params).ToListAsync()); @@ -79,5 +82,83 @@ public async Task OutOfContext() // Check that it produced the exact same result both times Assert.Equal(result1, result2); } + + [Fact] + public async Task OutOfContext_WithDefaultStrategy_ThrowsException() + { + var executor = new StatelessExecutor(_weights, _params); + using var context = _weights.CreateContext(_params); + + // Read the ACTUAL context size allocated by the native engine + uint actualContextSize = context.ContextSize; + + string question = "Cats and dogs are great pets. "; + + // Fast pad for the bulk of it + while (context.Tokenize(question, special: true).Length < actualContextSize - 20) + { + question += "Cats and dogs are great pets. "; + } + + // Slow pad by single words to precisely hit actualContextSize - 2 + while (context.Tokenize(question, special: true).Length < actualContextSize - 2) + { + question += "pet "; + } + + var finalLength = context.Tokenize(question, special: true).Length; + _testOutputHelper.WriteLine($"[DEBUG] Actual ContextSize: {actualContextSize}, Prompt length: {finalLength}"); + + // Sanity check to ensure we didn't overshoot + Assert.True(finalLength < actualContextSize, "Prompt exceeded context size during prefill!"); + + var @params = new InferenceParams() + { + MaxTokens = 10, + TokensKeep = 5, + }; + + var exception = await Assert.ThrowsAsync(async () => + { + await executor.InferAsync(question, @params).ToListAsync(); + }); + + _testOutputHelper.WriteLine($"Successfully caught expected exception: {exception.Message}"); + } + + [Fact] + public async Task OutOfContext_WithDefaultStrategy_2_ThrowsException() + { + using var context = _weights.CreateContext(_params); + var executor = new InstructExecutor(context); + + uint actualContextSize = context.ContextSize; + string instruction = "Cats or dogs? "; + + // Fast pad safely below limit (InstructExecutor adds hidden prefix/suffix) + while (context.Tokenize(instruction, special: true).Length < actualContextSize - 30) + { + instruction += "Cats or dogs? "; + } + + // Slow pad + while (context.Tokenize(instruction, special: true).Length < actualContextSize - 15) + { + instruction += "pet "; + } + + var @params = new InferenceParams() + { + MaxTokens = 20, + TokensKeep = 5, + }; + + var exception = await Assert.ThrowsAsync(async () => + { + await executor.InferAsync(instruction, @params).ToListAsync(); + }); + + _testOutputHelper.WriteLine($"Successfully caught expected exception in InstructExecutor: {exception.Message}"); + } } } \ No newline at end of file diff --git a/LLama/Abstractions/IInferenceParams.cs b/LLama/Abstractions/IInferenceParams.cs index fd6ce5b47..2afbcdcd1 100644 --- a/LLama/Abstractions/IInferenceParams.cs +++ b/LLama/Abstractions/IInferenceParams.cs @@ -1,5 +1,6 @@ -using System.Collections.Generic; +using LLama.Common; using LLama.Sampling; +using System.Collections.Generic; namespace LLama.Abstractions { @@ -36,5 +37,19 @@ public interface IInferenceParams /// Controls the behavior of decoders like /// public bool DecodeSpecialTokens { get; set; } - } + + /// + /// Defines the strategy the executor should use when the context window is full + /// and the model architecture (e.g., models with 2D RoPE embeddings) does not + /// support native memory shifting. + /// + ContextOverflowStrategy OverflowStrategy { get; set; } + + /// + /// The percentage of past tokens to discard when + /// is set to . + /// For example, 0.1f represents dropping the oldest 10% of the conversational context. + /// + float ContextTruncationPercentage { get; set; } + } } \ No newline at end of file diff --git a/LLama/Common/ContextOverflowStrategy.cs b/LLama/Common/ContextOverflowStrategy.cs new file mode 100644 index 000000000..cf52b8bd0 --- /dev/null +++ b/LLama/Common/ContextOverflowStrategy.cs @@ -0,0 +1,26 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace LLama.Common +{ + /// + /// Defines how the executor should behave when the context window fills up + /// on a model that does not support native memory shifting (e.g., 2D RoPE models). + /// + public enum ContextOverflowStrategy + { + /// + /// The engine will throw a ContextOverflowException. + /// Use this to manually manage context pruning in your application layer. + /// (Equivalent to llama-cli's --no-context-shift). + /// + ThrowException, + + /// + /// The engine will silently drop a percentage of the oldest tokens + /// (preserving the system prompt) and completely re-prefill the context. + /// + TruncateAndReprefill + } +} diff --git a/LLama/Common/InferenceParams.cs b/LLama/Common/InferenceParams.cs index 64d95ee3e..37cbeb034 100644 --- a/LLama/Common/InferenceParams.cs +++ b/LLama/Common/InferenceParams.cs @@ -33,6 +33,21 @@ public record InferenceParams /// public bool DecodeSpecialTokens { get; set; } + + /// + /// Defines the strategy the executor should use when the context window is full + /// and the model architecture does not support native memory shifting. + /// Defaults to to prevent + /// unintended data loss and latency spikes. + /// + public ContextOverflowStrategy OverflowStrategy { get; set; } = ContextOverflowStrategy.ThrowException; + + /// + /// The percentage of past tokens to discard when + /// is set to . + /// Defaults to 0.1f (10%). Valid range is typically between 0.01f and 0.99f. + /// + public float ContextTruncationPercentage { get; set; } = 0.1f; } /// diff --git a/LLama/Exceptions/ContextOverflowException.cs b/LLama/Exceptions/ContextOverflowException.cs new file mode 100644 index 000000000..872750c97 --- /dev/null +++ b/LLama/Exceptions/ContextOverflowException.cs @@ -0,0 +1,14 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace LLama.Exceptions +{ + /// + /// Thrown when the KV cache context is full and the model architecture + /// cannot mathematically support native memory shifting. + /// + public class ContextOverflowException(string message) : Exception(message) + { + } +} \ No newline at end of file diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs index 54499e5e1..f0ee9959e 100644 --- a/LLama/LLamaExecutorBase.cs +++ b/LLama/LLamaExecutorBase.cs @@ -198,21 +198,50 @@ public void SaveSessionFile(string filename) /// /// After running out of the context, take some tokens from the original prompt and recompute the logits in batches. /// - /// - protected virtual void HandleRunOutOfContext(int tokensToKeep) + /// The number of tokens from the initial prompt to preserve (e.g., system prompt). + /// The parameters controlling the inference and overflow strategy. + /// Thrown when the overflow strategy is set to ThrowException or truncation bounds are invalid. + /// Thrown if the native context decoding fails during the re-prefill phase. + protected virtual async Task HandleRunOutOfContext(int tokensToKeep, IInferenceParams inferenceParams) { - // if we run out of context: - // - take the tokensToKeep first tokens from the original prompt (via n_past) - // - take half of the last (n_ctx - tokensToKeep) tokens and recompute the logits in batches + // 1. Fast Fail if not configured to auto-truncate + if (inferenceParams.OverflowStrategy == ContextOverflowStrategy.ThrowException) + { + throw new ContextOverflowException( + "The context window is full and the current model architecture does not support native memory shifting. " + + "To automatically truncate and re-prefill the context, set InferenceParams.OverflowStrategy to ContextOverflowStrategy.TruncateAndReprefill." + ); + } + + // 2. Calculate tokens safely var n_left = _pastTokensCount - tokensToKeep; - var n_discard = n_left / 2; + if (n_left <= 0) + { + throw new ContextOverflowException("Cannot truncate context: tokensToKeep exceeds or equals the current context size."); + } + + // Clamp the percentage between 1% and 99% to prevent math errors or total wipeouts + var percentage = Math.Max(0.01f, Math.Min(0.99f, inferenceParams.ContextTruncationPercentage)); + var n_discard = (int)(n_left * percentage); - Context.NativeHandle.MemorySequenceRemove(LLamaSeqId.Zero, tokensToKeep, tokensToKeep + n_discard); - Context.NativeHandle.MemorySequenceAdd(LLamaSeqId.Zero, tokensToKeep + n_discard, _pastTokensCount, -n_discard); + // Sanity check: always discard at least 1 token if we are truncating + if (n_discard < 1) n_discard = 1; - _pastTokensCount -= n_discard; - // stop saving session if we run out of context - _pathSession = string.Empty; + // 3. Remove the oldest non-kept tokens. + // If tokensToKeep is 10, we keep indexes 0-9, and start removing from index 10. + int startIndex = Math.Max(0, tokensToKeep); + _session_tokens.RemoveRange(startIndex, n_discard); + + // 4. Clear the native KV Cache and perform a complete re-prefill + LLamaBatch batch = new(); + Context.NativeHandle.MemoryClear(); + + (var result, _, _pastTokensCount) = await Context.DecodeAsync(_session_tokens, LLamaSeqId.Zero, batch, 0); + + if (result != DecodeResult.Ok) + { + throw new LLamaDecodeError(result); + } } /// diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs index e7724cb65..e77f55b6f 100644 --- a/LLama/LLamaInstructExecutor.cs +++ b/LLama/LLamaInstructExecutor.cs @@ -233,7 +233,7 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, In // Ported from https://github.com/ggerganov/llama.cpp/blob/60325fa56f61c228464c9f065db3aa6a61f2156e/examples/main/main.cpp#L334 // Instruct always uses input token size. var tokensToKeep = _embed_inps.Count; - HandleRunOutOfContext(tokensToKeep); + await HandleRunOutOfContext(tokensToKeep, inferenceParams); } TryReuseMatchingPrefix(); diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs index 91e1d4b7f..0fd512be7 100644 --- a/LLama/LLamaInteractExecutor.cs +++ b/LLama/LLamaInteractExecutor.cs @@ -232,7 +232,7 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, In tokensToKeep += Convert.ToInt32(Context.Vocab.ShouldAddBOS); // always keep the BOS token } - HandleRunOutOfContext(tokensToKeep); + await HandleRunOutOfContext(tokensToKeep, inferenceParams); } if (MtmdChunks is null) diff --git a/LLama/LLamaStatelessExecutor.cs b/LLama/LLamaStatelessExecutor.cs index a895054d4..5f9ff0f6e 100644 --- a/LLama/LLamaStatelessExecutor.cs +++ b/LLama/LLamaStatelessExecutor.cs @@ -106,6 +106,9 @@ public async IAsyncEnumerable InferAsync(string prompt, IInferenceParams // Tokenize the prompt var tokens = Context.Tokenize(prompt, special: true).ToList(); + // We must track the history of all tokens in this session in case we need to re-prefill the context + var all_tokens = new List(tokens); + // Evaluate the prompt, in chunks smaller than the max batch size var n_past = 0; var (r, _, past) = await Context.DecodeAsync(tokens, LLamaSeqId.Zero, _batch, n_past); @@ -138,17 +141,23 @@ public async IAsyncEnumerable InferAsync(string prompt, IInferenceParams tokens.Add(id); // when run out of context - // based on this logic: https://github.com/ggerganov/llama.cpp/blob/master/examples/main/main.cpp#L497 if (n_past + tokens.Count >= Context.ContextSize) { + if (inferenceParams.OverflowStrategy == ContextOverflowStrategy.ThrowException) + { + throw new ContextOverflowException( + "The context window is full and the current strategy is set to ThrowException. " + + "To automatically truncate and manage context, set InferenceParams.OverflowStrategy to ContextOverflowStrategy.TruncateAndReprefill." + ); + } + var canAddBos = Context.Vocab.ShouldAddBOS; var tokensKeep = inferenceParams.TokensKeep; // number of tokens to keep when resetting context - // Ported from https://github.com/ggerganov/llama.cpp/blob/60325fa56f61c228464c9f065db3aa6a61f2156e/examples/main/main.cpp#L334 - if (tokensKeep < 0 || tokensKeep > tokens.Count) + if (tokensKeep < 0 || tokensKeep > all_tokens.Count) { - tokensKeep = tokens.Count; + tokensKeep = all_tokens.Count; } else { @@ -156,14 +165,42 @@ public async IAsyncEnumerable InferAsync(string prompt, IInferenceParams } var n_left = n_past - tokensKeep; - var n_discard = n_left / 2; - Context.NativeHandle.MemorySequenceRemove(LLamaSeqId.Zero, tokensKeep, tokensKeep + n_discard); - Context.NativeHandle.MemorySequenceAdd(LLamaSeqId.Zero, tokensKeep + n_discard, n_past, -n_discard); + // Safely calculate discard amount using our configured percentage + var percentage = Math.Max(0.01f, Math.Min(0.99f, inferenceParams.ContextTruncationPercentage)); + var n_discard = (int)(n_left * percentage); + if (n_discard < 1) n_discard = 1; - n_past -= n_discard; + try + { + // First, attempt the fast native memory shift (works for standard models like Llama 2/3) + Context.NativeHandle.MemorySequenceRemove(LLamaSeqId.Zero, tokensKeep, tokensKeep + n_discard); + Context.NativeHandle.MemorySequenceAdd(LLamaSeqId.Zero, tokensKeep + n_discard, n_past, -n_discard); + n_past -= n_discard; + all_tokens.RemoveRange(tokensKeep, n_discard); + } + catch (Exception ex) when (ex.Message.Contains("MemoryCanShift")) + { + // Fallback: The model does not support native shifting (e.g., 2D RoPE models). + // We must clear the cache and perform a full context re-prefill. + _logger?.LogInformation("Model does not support native memory shifting. Falling back to context re-prefill."); + + all_tokens.RemoveRange(tokensKeep, n_discard); + + _batch.Clear(); + Context.NativeHandle.MemoryClear(); + + var (rReprefill, _, pastReprefill) = await Context.DecodeAsync(all_tokens, LLamaSeqId.Zero, _batch, 0); + if (rReprefill != DecodeResult.Ok) + throw new LLamaDecodeError(rReprefill); + + n_past = pastReprefill; + } } + // Add the new token to our historical tracker + all_tokens.Add(id); + // Evaluate with this new token _batch.Clear(); _batch.Add(id, n_past++, LLamaSeqId.Zero, true); From 7f54d5c7f9ef0594c82a198d120643ae7cb945f9 Mon Sep 17 00:00:00 2001 From: Zoli Somogyi Date: Sun, 19 Apr 2026 08:41:24 +0200 Subject: [PATCH 2/5] Context Overflow Fix update to LLama.Web.Common.InferenceOptions --- LLama.Web/Common/InferenceOptions.cs | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/LLama.Web/Common/InferenceOptions.cs b/LLama.Web/Common/InferenceOptions.cs index 1b32a996c..11dccd678 100644 --- a/LLama.Web/Common/InferenceOptions.cs +++ b/LLama.Web/Common/InferenceOptions.cs @@ -24,5 +24,20 @@ public class InferenceOptions /// public bool DecodeSpecialTokens { get; set; } + + /// + /// Defines the strategy the executor should use when the context window is full + /// and the model architecture does not support native memory shifting. + /// Defaults to to prevent + /// unintended data loss and latency spikes. + /// + public ContextOverflowStrategy OverflowStrategy { get; set; } = ContextOverflowStrategy.ThrowException; + + /// + /// The percentage of past tokens to discard when + /// is set to . + /// Defaults to 0.1f (10%). Valid range is typically between 0.01f and 0.99f. + /// + public float ContextTruncationPercentage { get; set; } = 0.1f; } } From 6b66a79c39c9248b82f3d7f7192ff0630c1bff4a Mon Sep 17 00:00:00 2001 From: Zoli Somogyi Date: Thu, 23 Apr 2026 19:07:52 +0200 Subject: [PATCH 3/5] Context Overflow Fix --- LLama/Exceptions/ContextOverflowException.cs | 7 ++-- LLama/LLamaExecutorBase.cs | 41 ++++++++++++-------- LLama/LLamaStatelessExecutor.cs | 11 ++---- 3 files changed, 32 insertions(+), 27 deletions(-) diff --git a/LLama/Exceptions/ContextOverflowException.cs b/LLama/Exceptions/ContextOverflowException.cs index 872750c97..0893e50c2 100644 --- a/LLama/Exceptions/ContextOverflowException.cs +++ b/LLama/Exceptions/ContextOverflowException.cs @@ -1,14 +1,13 @@ using System; -using System.Collections.Generic; -using System.Text; namespace LLama.Exceptions { /// /// Thrown when the KV cache context is full and the model architecture - /// cannot mathematically support native memory shifting. + /// cannot mathematically support native memory shifting, or when the + /// ContextOverflowStrategy.ThrowException is used. /// - public class ContextOverflowException(string message) : Exception(message) + public class ContextOverflowException() : Exception("The context window is full and the current strategy is set to ThrowException. To automatically truncate and manage context, set InferenceParams.OverflowStrategy to ContextOverflowStrategy.TruncateAndReprefill.") { } } \ No newline at end of file diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs index f0ee9959e..6e645c659 100644 --- a/LLama/LLamaExecutorBase.cs +++ b/LLama/LLamaExecutorBase.cs @@ -200,47 +200,56 @@ public void SaveSessionFile(string filename) /// /// The number of tokens from the initial prompt to preserve (e.g., system prompt). /// The parameters controlling the inference and overflow strategy. - /// Thrown when the overflow strategy is set to ThrowException or truncation bounds are invalid. + /// Thrown when the overflow strategy is set to ThrowException. + /// Thrown when tokensToKeep is invalid. /// Thrown if the native context decoding fails during the re-prefill phase. protected virtual async Task HandleRunOutOfContext(int tokensToKeep, IInferenceParams inferenceParams) { // 1. Fast Fail if not configured to auto-truncate if (inferenceParams.OverflowStrategy == ContextOverflowStrategy.ThrowException) { - throw new ContextOverflowException( - "The context window is full and the current model architecture does not support native memory shifting. " + - "To automatically truncate and re-prefill the context, set InferenceParams.OverflowStrategy to ContextOverflowStrategy.TruncateAndReprefill." - ); + throw new ContextOverflowException(); } // 2. Calculate tokens safely var n_left = _pastTokensCount - tokensToKeep; if (n_left <= 0) { - throw new ContextOverflowException("Cannot truncate context: tokensToKeep exceeds or equals the current context size."); + throw new ArgumentOutOfRangeException(nameof(tokensToKeep), "Cannot truncate context: tokensToKeep exceeds or equals the current context size."); } // Clamp the percentage between 1% and 99% to prevent math errors or total wipeouts var percentage = Math.Max(0.01f, Math.Min(0.99f, inferenceParams.ContextTruncationPercentage)); var n_discard = (int)(n_left * percentage); - // Sanity check: always discard at least 1 token if we are truncating - if (n_discard < 1) n_discard = 1; + // Sanity check: always discard at least 1 token, but never more than we have available. + n_discard = Math.Max(1, Math.Min(n_discard, n_left)); // 3. Remove the oldest non-kept tokens. - // If tokensToKeep is 10, we keep indexes 0-9, and start removing from index 10. int startIndex = Math.Max(0, tokensToKeep); _session_tokens.RemoveRange(startIndex, n_discard); - // 4. Clear the native KV Cache and perform a complete re-prefill - LLamaBatch batch = new(); - Context.NativeHandle.MemoryClear(); + if (Context.NativeHandle.MemoryCanShift) + { + // Fast path: attempt the fast native memory shift + Context.NativeHandle.MemorySequenceRemove(LLamaSeqId.Zero, tokensToKeep, tokensToKeep + n_discard); + Context.NativeHandle.MemorySequenceAdd(LLamaSeqId.Zero, tokensToKeep + n_discard, _pastTokensCount, -n_discard); + _pastTokensCount -= n_discard; + } + else + { + // 4. Fallback: Clear the native KV Cache and perform a complete re-prefill + _logger?.LogInformation("Model does not support native memory shifting. Falling back to context re-prefill."); - (var result, _, _pastTokensCount) = await Context.DecodeAsync(_session_tokens, LLamaSeqId.Zero, batch, 0); + LLamaBatch batch = new(); + Context.NativeHandle.MemoryClear(); - if (result != DecodeResult.Ok) - { - throw new LLamaDecodeError(result); + (var result, _, _pastTokensCount) = await Context.DecodeAsync(_session_tokens, LLamaSeqId.Zero, batch, 0); + + if (result != DecodeResult.Ok) + { + throw new LLamaDecodeError(result); + } } } diff --git a/LLama/LLamaStatelessExecutor.cs b/LLama/LLamaStatelessExecutor.cs index 5f9ff0f6e..6f295bf12 100644 --- a/LLama/LLamaStatelessExecutor.cs +++ b/LLama/LLamaStatelessExecutor.cs @@ -145,10 +145,7 @@ public async IAsyncEnumerable InferAsync(string prompt, IInferenceParams { if (inferenceParams.OverflowStrategy == ContextOverflowStrategy.ThrowException) { - throw new ContextOverflowException( - "The context window is full and the current strategy is set to ThrowException. " + - "To automatically truncate and manage context, set InferenceParams.OverflowStrategy to ContextOverflowStrategy.TruncateAndReprefill." - ); + throw new ContextOverflowException(); } var canAddBos = Context.Vocab.ShouldAddBOS; @@ -171,15 +168,15 @@ public async IAsyncEnumerable InferAsync(string prompt, IInferenceParams var n_discard = (int)(n_left * percentage); if (n_discard < 1) n_discard = 1; - try + if (Context.NativeHandle.MemoryCanShift) { - // First, attempt the fast native memory shift (works for standard models like Llama 2/3) + // Fast path: Attempt the fast native memory shift (works for standard models like Llama 2/3) Context.NativeHandle.MemorySequenceRemove(LLamaSeqId.Zero, tokensKeep, tokensKeep + n_discard); Context.NativeHandle.MemorySequenceAdd(LLamaSeqId.Zero, tokensKeep + n_discard, n_past, -n_discard); n_past -= n_discard; all_tokens.RemoveRange(tokensKeep, n_discard); } - catch (Exception ex) when (ex.Message.Contains("MemoryCanShift")) + else { // Fallback: The model does not support native shifting (e.g., 2D RoPE models). // We must clear the cache and perform a full context re-prefill. From 5d13abdc17d3e8a81908e308c59c060b30b0b07a Mon Sep 17 00:00:00 2001 From: Zoli Somogyi Date: Thu, 23 Apr 2026 20:50:12 +0200 Subject: [PATCH 4/5] Context Overflow Fix --- LLama/Exceptions/ContextOverflowException.cs | 28 +++++++++++- LLama/LLamaExecutorBase.cs | 47 ++++++++------------ LLama/LLamaStatelessExecutor.cs | 9 +++- 3 files changed, 53 insertions(+), 31 deletions(-) diff --git a/LLama/Exceptions/ContextOverflowException.cs b/LLama/Exceptions/ContextOverflowException.cs index 0893e50c2..2e753e523 100644 --- a/LLama/Exceptions/ContextOverflowException.cs +++ b/LLama/Exceptions/ContextOverflowException.cs @@ -7,7 +7,33 @@ namespace LLama.Exceptions /// cannot mathematically support native memory shifting, or when the /// ContextOverflowStrategy.ThrowException is used. /// - public class ContextOverflowException() : Exception("The context window is full and the current strategy is set to ThrowException. To automatically truncate and manage context, set InferenceParams.OverflowStrategy to ContextOverflowStrategy.TruncateAndReprefill.") + public class ContextOverflowException : Exception { + private const string DefaultMessage = "The context window is full and the current strategy is set to ThrowException. To automatically truncate and manage context, set InferenceParams.OverflowStrategy to ContextOverflowStrategy.TruncateAndReprefill."; + + /// + /// Initializes a new instance of the ContextOverflowException class with a default error message. + /// + public ContextOverflowException() : base(DefaultMessage) + { + } + + /// + /// Initializes a new instance of the ContextOverflowException class with a specified error message. + /// + /// The message that describes the error. + public ContextOverflowException(string message) : base(message) + { + } + + /// + /// Initializes a new instance of the ContextOverflowException class with a specified error message + /// and a reference to the inner exception that is the cause of this exception. + /// + /// The message that describes the error. + /// The exception that is the cause of the current exception. + public ContextOverflowException(string message, Exception innerException) : base(message, innerException) + { + } } } \ No newline at end of file diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs index 6e645c659..428fc3459 100644 --- a/LLama/LLamaExecutorBase.cs +++ b/LLama/LLamaExecutorBase.cs @@ -200,10 +200,9 @@ public void SaveSessionFile(string filename) /// /// The number of tokens from the initial prompt to preserve (e.g., system prompt). /// The parameters controlling the inference and overflow strategy. - /// Thrown when the overflow strategy is set to ThrowException. + /// Thrown when the overflow strategy is set to ThrowException, or if the model does not support native shifting. /// Thrown when tokensToKeep is invalid. - /// Thrown if the native context decoding fails during the re-prefill phase. - protected virtual async Task HandleRunOutOfContext(int tokensToKeep, IInferenceParams inferenceParams) + protected virtual Task HandleRunOutOfContext(int tokensToKeep, IInferenceParams inferenceParams) { // 1. Fast Fail if not configured to auto-truncate if (inferenceParams.OverflowStrategy == ContextOverflowStrategy.ThrowException) @@ -211,7 +210,17 @@ protected virtual async Task HandleRunOutOfContext(int tokensToKeep, IInferenceP throw new ContextOverflowException(); } - // 2. Calculate tokens safely + // 2. Guard: Stateful executors currently require native shifting to truncate. + // TODO (Future Improvement): To support truncation on models where MemoryCanShift == false, + // StatefulExecutorBase needs an unconditional `List _history_tokens` to track + // all ingested/generated tokens so we can clear the KV cache and perform a full re-prefill. + if (!Context.NativeHandle.MemoryCanShift) + { + _logger?.LogError("Model does not support native memory shifting. Stateful truncation requires MemoryCanShift = true."); + throw new ContextOverflowException("Model does not support native memory shifting. Context overflowed."); + } + + // 3. Calculate tokens safely var n_left = _pastTokensCount - tokensToKeep; if (n_left <= 0) { @@ -225,32 +234,12 @@ protected virtual async Task HandleRunOutOfContext(int tokensToKeep, IInferenceP // Sanity check: always discard at least 1 token, but never more than we have available. n_discard = Math.Max(1, Math.Min(n_discard, n_left)); - // 3. Remove the oldest non-kept tokens. - int startIndex = Math.Max(0, tokensToKeep); - _session_tokens.RemoveRange(startIndex, n_discard); - - if (Context.NativeHandle.MemoryCanShift) - { - // Fast path: attempt the fast native memory shift - Context.NativeHandle.MemorySequenceRemove(LLamaSeqId.Zero, tokensToKeep, tokensToKeep + n_discard); - Context.NativeHandle.MemorySequenceAdd(LLamaSeqId.Zero, tokensToKeep + n_discard, _pastTokensCount, -n_discard); - _pastTokensCount -= n_discard; - } - else - { - // 4. Fallback: Clear the native KV Cache and perform a complete re-prefill - _logger?.LogInformation("Model does not support native memory shifting. Falling back to context re-prefill."); - - LLamaBatch batch = new(); - Context.NativeHandle.MemoryClear(); + // 4. Fast path: attempt the fast native memory shift + Context.NativeHandle.MemorySequenceRemove(LLamaSeqId.Zero, tokensToKeep, tokensToKeep + n_discard); + Context.NativeHandle.MemorySequenceAdd(LLamaSeqId.Zero, tokensToKeep + n_discard, _pastTokensCount, -n_discard); + _pastTokensCount -= n_discard; - (var result, _, _pastTokensCount) = await Context.DecodeAsync(_session_tokens, LLamaSeqId.Zero, batch, 0); - - if (result != DecodeResult.Ok) - { - throw new LLamaDecodeError(result); - } - } + return Task.CompletedTask; } /// diff --git a/LLama/LLamaStatelessExecutor.cs b/LLama/LLamaStatelessExecutor.cs index 6f295bf12..8802a4881 100644 --- a/LLama/LLamaStatelessExecutor.cs +++ b/LLama/LLamaStatelessExecutor.cs @@ -163,10 +163,17 @@ public async IAsyncEnumerable InferAsync(string prompt, IInferenceParams var n_left = n_past - tokensKeep; + if (n_left <= 0) + { + throw new ArgumentOutOfRangeException(nameof(inferenceParams), "Cannot truncate context: TokensKeep exceeds or equals the current context size."); + } + // Safely calculate discard amount using our configured percentage var percentage = Math.Max(0.01f, Math.Min(0.99f, inferenceParams.ContextTruncationPercentage)); var n_discard = (int)(n_left * percentage); - if (n_discard < 1) n_discard = 1; + + // Clamp between 1 and n_left + n_discard = Math.Max(1, Math.Min(n_discard, n_left)); if (Context.NativeHandle.MemoryCanShift) { From ec97967e63f4e557bbe5ce3897893f1c1f98e4cd Mon Sep 17 00:00:00 2001 From: Zoli Somogyi Date: Sun, 26 Apr 2026 14:54:52 +0200 Subject: [PATCH 5/5] Open Context Overflow Fix --- LLama/LLamaExecutorBase.cs | 7 +++++++ LLama/LLamaStatelessExecutor.cs | 7 +++++-- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs index 428fc3459..1124b99f4 100644 --- a/LLama/LLamaExecutorBase.cs +++ b/LLama/LLamaExecutorBase.cs @@ -239,6 +239,13 @@ protected virtual Task HandleRunOutOfContext(int tokensToKeep, IInferenceParams Context.NativeHandle.MemorySequenceAdd(LLamaSeqId.Zero, tokensToKeep + n_discard, _pastTokensCount, -n_discard); _pastTokensCount -= n_discard; + // Stop saving the session if we run out of context. + // Note: A more advanced (but riskier and more complex) solution would be to physically trim + // the _session_tokens list and adjust _n_session_consumed to perfectly match the newly + // shifted native memory. This would allow session saving to continue safely, but requires + // precise index tracking to avoid off-by-one errors. For now, we abort saving to prevent corruption. + _pathSession = string.Empty; + return Task.CompletedTask; } diff --git a/LLama/LLamaStatelessExecutor.cs b/LLama/LLamaStatelessExecutor.cs index 8802a4881..9a9a1906d 100644 --- a/LLama/LLamaStatelessExecutor.cs +++ b/LLama/LLamaStatelessExecutor.cs @@ -106,6 +106,9 @@ public async IAsyncEnumerable InferAsync(string prompt, IInferenceParams // Tokenize the prompt var tokens = Context.Tokenize(prompt, special: true).ToList(); + // Capture the initial prompt length + var initialPromptLength = tokens.Count; + // We must track the history of all tokens in this session in case we need to re-prefill the context var all_tokens = new List(tokens); @@ -152,9 +155,9 @@ public async IAsyncEnumerable InferAsync(string prompt, IInferenceParams var tokensKeep = inferenceParams.TokensKeep; // number of tokens to keep when resetting context - if (tokensKeep < 0 || tokensKeep > all_tokens.Count) + if (tokensKeep < 0 || tokensKeep > initialPromptLength) { - tokensKeep = all_tokens.Count; + tokensKeep = initialPromptLength; } else {