diff --git a/LLama.Unittest/StatelessExecutorTest.cs b/LLama.Unittest/StatelessExecutorTest.cs
index d2c646e99..1fae35a79 100644
--- a/LLama.Unittest/StatelessExecutorTest.cs
+++ b/LLama.Unittest/StatelessExecutorTest.cs
@@ -57,18 +57,21 @@ public async Task Stateless()
         }
 
         [Fact(Skip = "Very very slow in CI")]
-        public async Task OutOfContext()
+        public async Task OutOfContext_WithTruncateStrategy_SuccessfullyGenerates()
         {
             var executor = new StatelessExecutor(_weights, _params);
 
             const string question = " Question. cats or dogs?\nAnswer:";
 
             // The context size is set to 60. Generate more than that, forcing it to generate a coherent response
-            // with a modified context
+            // with a modified context.
+            // We explicitly set the strategy to TruncateAndReprefill to test the new fallback logic.
             var @params = new InferenceParams()
             {
                 MaxTokens = 65,
                 TokensKeep = question.Length,
+                OverflowStrategy = ContextOverflowStrategy.TruncateAndReprefill,
+                ContextTruncationPercentage = 0.2f // Drop 20% of tokens when full
             };
 
             var result1 = string.Join("", await executor.InferAsync(question, @params).ToListAsync());
@@ -79,5 +82,83 @@ public async Task OutOfContext()
             // Check that it produced the exact same result both times
             Assert.Equal(result1, result2);
         }
+
+        [Fact]
+        public async Task OutOfContext_WithDefaultStrategy_ThrowsException()
+        {
+            var executor = new StatelessExecutor(_weights, _params);
+            using var context = _weights.CreateContext(_params);
+
+            // Read the ACTUAL context size allocated by the native engine
+            uint actualContextSize = context.ContextSize;
+
+            string question = "Cats and dogs are great pets. ";
+
+            // Fast pad for the bulk of it
+            while (context.Tokenize(question, special: true).Length < actualContextSize - 20)
+            {
+                question += "Cats and dogs are great pets. ";
+            }
+
+            // Slow pad by single words to precisely hit actualContextSize - 2
+            while (context.Tokenize(question, special: true).Length < actualContextSize - 2)
+            {
+                question += "pet ";
+            }
+
+            var finalLength = context.Tokenize(question, special: true).Length;
+            _testOutputHelper.WriteLine($"[DEBUG] Actual ContextSize: {actualContextSize}, Prompt length: {finalLength}");
+
+            // Sanity check to ensure we didn't overshoot
+            Assert.True(finalLength < actualContextSize, "Prompt exceeded context size during prefill!");
+
+            var @params = new InferenceParams()
+            {
+                MaxTokens = 10,
+                TokensKeep = 5,
+            };
+
+            var exception = await Assert.ThrowsAsync<ContextOverflowException>(async () =>
+            {
+                await executor.InferAsync(question, @params).ToListAsync();
+            });
+
+            _testOutputHelper.WriteLine($"Successfully caught expected exception: {exception.Message}");
+        }
+
+        [Fact]
+        public async Task OutOfContext_WithDefaultStrategy_2_ThrowsException()
+        {
+            using var context = _weights.CreateContext(_params);
+            var executor = new InstructExecutor(context);
+
+            uint actualContextSize = context.ContextSize;
+            string instruction = "Cats or dogs? ";
+
+            // Fast pad safely below limit (InstructExecutor adds hidden prefix/suffix)
+            while (context.Tokenize(instruction, special: true).Length < actualContextSize - 30)
+            {
+                instruction += "Cats or dogs? ";
"; + } + + // Slow pad + while (context.Tokenize(instruction, special: true).Length < actualContextSize - 15) + { + instruction += "pet "; + } + + var @params = new InferenceParams() + { + MaxTokens = 20, + TokensKeep = 5, + }; + + var exception = await Assert.ThrowsAsync(async () => + { + await executor.InferAsync(instruction, @params).ToListAsync(); + }); + + _testOutputHelper.WriteLine($"Successfully caught expected exception in InstructExecutor: {exception.Message}"); + } } } \ No newline at end of file diff --git a/LLama.Web/Common/InferenceOptions.cs b/LLama.Web/Common/InferenceOptions.cs index 1b32a996c..11dccd678 100644 --- a/LLama.Web/Common/InferenceOptions.cs +++ b/LLama.Web/Common/InferenceOptions.cs @@ -24,5 +24,20 @@ public class InferenceOptions /// public bool DecodeSpecialTokens { get; set; } + + /// + /// Defines the strategy the executor should use when the context window is full + /// and the model architecture does not support native memory shifting. + /// Defaults to to prevent + /// unintended data loss and latency spikes. + /// + public ContextOverflowStrategy OverflowStrategy { get; set; } = ContextOverflowStrategy.ThrowException; + + /// + /// The percentage of past tokens to discard when + /// is set to . + /// Defaults to 0.1f (10%). Valid range is typically between 0.01f and 0.99f. + /// + public float ContextTruncationPercentage { get; set; } = 0.1f; } } diff --git a/LLama/Abstractions/IInferenceParams.cs b/LLama/Abstractions/IInferenceParams.cs index fd6ce5b47..2afbcdcd1 100644 --- a/LLama/Abstractions/IInferenceParams.cs +++ b/LLama/Abstractions/IInferenceParams.cs @@ -1,5 +1,6 @@ -using System.Collections.Generic; +using LLama.Common; using LLama.Sampling; +using System.Collections.Generic; namespace LLama.Abstractions { @@ -36,5 +37,19 @@ public interface IInferenceParams /// Controls the behavior of decoders like /// public bool DecodeSpecialTokens { get; set; } - } + + /// + /// Defines the strategy the executor should use when the context window is full + /// and the model architecture (e.g., models with 2D RoPE embeddings) does not + /// support native memory shifting. + /// + ContextOverflowStrategy OverflowStrategy { get; set; } + + /// + /// The percentage of past tokens to discard when + /// is set to . + /// For example, 0.1f represents dropping the oldest 10% of the conversational context. + /// + float ContextTruncationPercentage { get; set; } + } } \ No newline at end of file diff --git a/LLama/Common/ContextOverflowStrategy.cs b/LLama/Common/ContextOverflowStrategy.cs new file mode 100644 index 000000000..cf52b8bd0 --- /dev/null +++ b/LLama/Common/ContextOverflowStrategy.cs @@ -0,0 +1,26 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace LLama.Common +{ + /// + /// Defines how the executor should behave when the context window fills up + /// on a model that does not support native memory shifting (e.g., 2D RoPE models). + /// + public enum ContextOverflowStrategy + { + /// + /// The engine will throw a ContextOverflowException. + /// Use this to manually manage context pruning in your application layer. + /// (Equivalent to llama-cli's --no-context-shift). + /// + ThrowException, + + /// + /// The engine will silently drop a percentage of the oldest tokens + /// (preserving the system prompt) and completely re-prefill the context. 
+        /// </summary>
+        TruncateAndReprefill
+    }
+}
diff --git a/LLama/Common/InferenceParams.cs b/LLama/Common/InferenceParams.cs
index 64d95ee3e..37cbeb034 100644
--- a/LLama/Common/InferenceParams.cs
+++ b/LLama/Common/InferenceParams.cs
@@ -33,6 +33,21 @@ public record InferenceParams
         /// </summary>
         public bool DecodeSpecialTokens { get; set; }
 
+
+        /// <summary>
+        /// Defines the strategy the executor should use when the context window is full
+        /// and the model architecture does not support native memory shifting.
+        /// Defaults to <see cref="ContextOverflowStrategy.ThrowException"/> to prevent
+        /// unintended data loss and latency spikes.
+        /// </summary>
+        public ContextOverflowStrategy OverflowStrategy { get; set; } = ContextOverflowStrategy.ThrowException;
+
+        /// <summary>
+        /// The percentage of past tokens to discard when <see cref="OverflowStrategy"/>
+        /// is set to <see cref="ContextOverflowStrategy.TruncateAndReprefill"/>.
+        /// Defaults to 0.1f (10%). Valid range is typically between 0.01f and 0.99f.
+        /// </summary>
+        public float ContextTruncationPercentage { get; set; } = 0.1f;
     }
 
     /// <summary>
diff --git a/LLama/Exceptions/ContextOverflowException.cs b/LLama/Exceptions/ContextOverflowException.cs
new file mode 100644
index 000000000..2e753e523
--- /dev/null
+++ b/LLama/Exceptions/ContextOverflowException.cs
@@ -0,0 +1,39 @@
+using System;
+
+namespace LLama.Exceptions
+{
+    /// <summary>
+    /// Thrown when the KV cache context is full and the model architecture
+    /// cannot mathematically support native memory shifting, or when
+    /// ContextOverflowStrategy.ThrowException is used.
+    /// </summary>
+    public class ContextOverflowException : Exception
+    {
+        private const string DefaultMessage = "The context window is full and the current strategy is set to ThrowException. To automatically truncate and manage context, set InferenceParams.OverflowStrategy to ContextOverflowStrategy.TruncateAndReprefill.";
+
+        /// <summary>
+        /// Initializes a new instance of the ContextOverflowException class with a default error message.
+        /// </summary>
+        public ContextOverflowException() : base(DefaultMessage)
+        {
+        }
+
+        /// <summary>
+        /// Initializes a new instance of the ContextOverflowException class with a specified error message.
+        /// </summary>
+        /// <param name="message">The message that describes the error.</param>
+        public ContextOverflowException(string message) : base(message)
+        {
+        }
+
+        /// <summary>
+        /// Initializes a new instance of the ContextOverflowException class with a specified error message
+        /// and a reference to the inner exception that is the cause of this exception.
+        /// </summary>
+        /// <param name="message">The message that describes the error.</param>
+        /// <param name="innerException">The exception that is the cause of the current exception.</param>
+        public ContextOverflowException(string message, Exception innerException) : base(message, innerException)
+        {
+        }
+    }
+}
\ No newline at end of file
diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs
index 54499e5e1..1124b99f4 100644
--- a/LLama/LLamaExecutorBase.cs
+++ b/LLama/LLamaExecutorBase.cs
@@ -198,21 +198,55 @@ public void SaveSessionFile(string filename)
         /// <summary>
         /// After running out of the context, take some tokens from the original prompt and recompute the logits in batches.
         /// </summary>
-        /// <param name="tokensToKeep"></param>
-        protected virtual void HandleRunOutOfContext(int tokensToKeep)
+        /// <param name="tokensToKeep">The number of tokens from the initial prompt to preserve (e.g., system prompt).</param>
+        /// <param name="inferenceParams">The parameters controlling the inference and overflow strategy.</param>
+        /// <exception cref="ContextOverflowException">Thrown when the overflow strategy is set to ThrowException, or if the model does not support native shifting.</exception>
+        /// <exception cref="ArgumentOutOfRangeException">Thrown when tokensToKeep is invalid.</exception>
+        protected virtual Task HandleRunOutOfContext(int tokensToKeep, IInferenceParams inferenceParams)
         {
-            // if we run out of context:
-            // - take the tokensToKeep first tokens from the original prompt (via n_past)
-            // - take half of the last (n_ctx - tokensToKeep) tokens and recompute the logits in batches
+            // 1. Fast fail if not configured to auto-truncate
+            if (inferenceParams.OverflowStrategy == ContextOverflowStrategy.ThrowException)
+            {
+                throw new ContextOverflowException();
+            }
+
+            // 2. Guard: Stateful executors currently require native shifting to truncate.
+            // TODO (Future Improvement): To support truncation on models where MemoryCanShift == false,
+            // StatefulExecutorBase needs an unconditional `List<LLamaToken> _history_tokens` to track
+            // all ingested/generated tokens so we can clear the KV cache and perform a full re-prefill.
+            if (!Context.NativeHandle.MemoryCanShift)
+            {
+                _logger?.LogError("Model does not support native memory shifting. Stateful truncation requires MemoryCanShift = true.");
+                throw new ContextOverflowException("Model does not support native memory shifting. Context overflowed.");
+            }
+
+            // 3. Calculate tokens safely
             var n_left = _pastTokensCount - tokensToKeep;
-            var n_discard = n_left / 2;
+            if (n_left <= 0)
+            {
+                throw new ArgumentOutOfRangeException(nameof(tokensToKeep), "Cannot truncate context: tokensToKeep exceeds or equals the current context size.");
+            }
+
+            // Clamp the percentage between 1% and 99% to prevent math errors or total wipeouts
+            var percentage = Math.Max(0.01f, Math.Min(0.99f, inferenceParams.ContextTruncationPercentage));
+            var n_discard = (int)(n_left * percentage);
+            // Sanity check: always discard at least 1 token, but never more than we have available.
+            n_discard = Math.Max(1, Math.Min(n_discard, n_left));
+
+            // 4. Fast path: attempt the fast native memory shift
             Context.NativeHandle.MemorySequenceRemove(LLamaSeqId.Zero, tokensToKeep, tokensToKeep + n_discard);
             Context.NativeHandle.MemorySequenceAdd(LLamaSeqId.Zero, tokensToKeep + n_discard, _pastTokensCount, -n_discard);
-
             _pastTokensCount -= n_discard;
-            // stop saving session if we run out of context
+
+            // Stop saving the session if we run out of context.
+            // Note: A more advanced (but riskier and more complex) solution would be to physically trim
+            // the _session_tokens list and adjust _n_session_consumed to perfectly match the newly
+            // shifted native memory. This would allow session saving to continue safely, but requires
+            // precise index tracking to avoid off-by-one errors. For now, we abort saving to prevent corruption.
             _pathSession = string.Empty;
+
+            return Task.CompletedTask;
         }
 
         /// <summary>
diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs
index e7724cb65..e77f55b6f 100644
--- a/LLama/LLamaInstructExecutor.cs
+++ b/LLama/LLamaInstructExecutor.cs
@@ -233,7 +233,7 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, In
                 // Ported from https://github.com/ggerganov/llama.cpp/blob/60325fa56f61c228464c9f065db3aa6a61f2156e/examples/main/main.cpp#L334
                 // Instruct always uses input token size.
                 var tokensToKeep = _embed_inps.Count;
-                HandleRunOutOfContext(tokensToKeep);
+                await HandleRunOutOfContext(tokensToKeep, inferenceParams);
             }
 
             TryReuseMatchingPrefix();
diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs
index 91e1d4b7f..0fd512be7 100644
--- a/LLama/LLamaInteractExecutor.cs
+++ b/LLama/LLamaInteractExecutor.cs
@@ -232,7 +232,7 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, In
                     tokensToKeep += Convert.ToInt32(Context.Vocab.ShouldAddBOS); // always keep the BOS token
                 }
 
-                HandleRunOutOfContext(tokensToKeep);
+                await HandleRunOutOfContext(tokensToKeep, inferenceParams);
             }
 
             if (MtmdChunks is null)
diff --git a/LLama/LLamaStatelessExecutor.cs b/LLama/LLamaStatelessExecutor.cs
index a895054d4..9a9a1906d 100644
--- a/LLama/LLamaStatelessExecutor.cs
+++ b/LLama/LLamaStatelessExecutor.cs
@@ -106,6 +106,12 @@ public async IAsyncEnumerable<string> InferAsync(string prompt, IInferenceParams
             // Tokenize the prompt
             var tokens = Context.Tokenize(prompt, special: true).ToList();
 
+            // Capture the initial prompt length
+            var initialPromptLength = tokens.Count;
+
+            // We must track the history of all tokens in this session in case we need to re-prefill the context
+            var all_tokens = new List<LLamaToken>(tokens);
+
             // Evaluate the prompt, in chunks smaller than the max batch size
             var n_past = 0;
             var (r, _, past) = await Context.DecodeAsync(tokens, LLamaSeqId.Zero, _batch, n_past);
@@ -138,17 +144,20 @@ public async IAsyncEnumerable<string> InferAsync(string prompt, IInferenceParams
                 tokens.Add(id);
 
                 // when run out of context
-                // based on this logic: https://github.com/ggerganov/llama.cpp/blob/master/examples/main/main.cpp#L497
                 if (n_past + tokens.Count >= Context.ContextSize)
                 {
+                    if (inferenceParams.OverflowStrategy == ContextOverflowStrategy.ThrowException)
+                    {
+                        throw new ContextOverflowException();
+                    }
+
                     var canAddBos = Context.Vocab.ShouldAddBOS;
                     var tokensKeep = inferenceParams.TokensKeep; // number of tokens to keep when resetting context
 
-                    // Ported from https://github.com/ggerganov/llama.cpp/blob/60325fa56f61c228464c9f065db3aa6a61f2156e/examples/main/main.cpp#L334
-                    if (tokensKeep < 0 || tokensKeep > tokens.Count)
+                    if (tokensKeep < 0 || tokensKeep > initialPromptLength)
                     {
-                        tokensKeep = tokens.Count;
+                        tokensKeep = initialPromptLength;
                     }
                     else
                     {
@@ -156,14 +165,49 @@ public async IAsyncEnumerable<string> InferAsync(string prompt, IInferenceParams
                     }
 
                     var n_left = n_past - tokensKeep;
-                    var n_discard = n_left / 2;
 
-                    Context.NativeHandle.MemorySequenceRemove(LLamaSeqId.Zero, tokensKeep, tokensKeep + n_discard);
-                    Context.NativeHandle.MemorySequenceAdd(LLamaSeqId.Zero, tokensKeep + n_discard, n_past, -n_discard);
+                    if (n_left <= 0)
+                    {
+                        throw new ArgumentOutOfRangeException(nameof(inferenceParams), "Cannot truncate context: TokensKeep exceeds or equals the current context size.");
+                    }
+
+                    // Safely calculate discard amount using our configured percentage
+                    var percentage = Math.Max(0.01f, Math.Min(0.99f, inferenceParams.ContextTruncationPercentage));
+                    var n_discard = (int)(n_left * percentage);
 
-                    n_past -= n_discard;
+                    // Clamp between 1 and n_left
+                    n_discard = Math.Max(1, Math.Min(n_discard, n_left));
+
+                    if (Context.NativeHandle.MemoryCanShift)
+                    {
+                        // Fast path: Attempt the fast native memory shift (works for standard models like Llama 2/3)
+                        Context.NativeHandle.MemorySequenceRemove(LLamaSeqId.Zero, tokensKeep, tokensKeep + n_discard);
+                        Context.NativeHandle.MemorySequenceAdd(LLamaSeqId.Zero, tokensKeep + n_discard, n_past, -n_discard);
+                        n_past -= n_discard;
+                        all_tokens.RemoveRange(tokensKeep, n_discard);
+                    }
+                    else
+                    {
+                        // Fallback: The model does not support native shifting (e.g., 2D RoPE models).
+                        // We must clear the cache and perform a full context re-prefill.
+                        _logger?.LogInformation("Model does not support native memory shifting. Falling back to context re-prefill.");
+
+                        all_tokens.RemoveRange(tokensKeep, n_discard);
+
+                        _batch.Clear();
+                        Context.NativeHandle.MemoryClear();
+
+                        var (rReprefill, _, pastReprefill) = await Context.DecodeAsync(all_tokens, LLamaSeqId.Zero, _batch, 0);
+                        if (rReprefill != DecodeResult.Ok)
+                            throw new LLamaDecodeError(rReprefill);
+
+                        n_past = pastReprefill;
+                    }
                 }
 
+                // Add the new token to our historical tracker
+                all_tokens.Add(id);
+
                 // Evaluate with this new token
                 _batch.Clear();
                 _batch.Add(id, n_past++, LLamaSeqId.Zero, true);
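
Usage sketch (not part of the diff): the snippet below shows how a caller could opt in to the new TruncateAndReprefill behavior once the InferenceParams changes above are merged. It is a minimal example rather than the PR's own code; the model path, context size, token counts, and prompt are placeholder values, and everything else uses the existing LLamaSharp StatelessExecutor API.

    using System;
    using LLama;
    using LLama.Common;

    // Placeholder model path and a deliberately small context so overflow is easy to reach.
    var modelParams = new ModelParams("path/to/model.gguf") { ContextSize = 512 };
    using var weights = LLamaWeights.LoadFromFile(modelParams);
    var executor = new StatelessExecutor(weights, modelParams);

    var inferenceParams = new InferenceParams
    {
        MaxTokens = 1024,
        TokensKeep = 16, // tokens from the start of the prompt to preserve on truncation

        // The default is ThrowException, which raises ContextOverflowException when the window fills up.
        // Opt in to automatic truncation and re-prefill instead:
        OverflowStrategy = ContextOverflowStrategy.TruncateAndReprefill,
        ContextTruncationPercentage = 0.25f // drop the oldest 25% of past tokens when full
    };

    await foreach (var piece in executor.InferAsync("Write a very long story about cats and dogs.", inferenceParams))
        Console.Write(piece);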