Merged
85 changes: 83 additions & 2 deletions LLama.Unittest/StatelessExecutorTest.cs
@@ -57,18 +57,21 @@ public async Task Stateless()
}

[Fact(Skip = "Very very slow in CI")]
public async Task OutOfContext()
public async Task OutOfContext_WithTruncateStrategy_SuccessfullyGenerates()
{
var executor = new StatelessExecutor(_weights, _params);

const string question = " Question. cats or dogs?\nAnswer:";

// The context size is set to 60. Generate more than that, forcing it to generate a coherent response
// with a modified context
// with a modified context.
// We explicitly set the strategy to TruncateAndReprefill to test the new fallback logic.
var @params = new InferenceParams()
{
MaxTokens = 65,
TokensKeep = question.Length,
OverflowStrategy = ContextOverflowStrategy.TruncateAndReprefill,
ContextTruncationPercentage = 0.2f // Drop 20% of tokens when full
};

var result1 = string.Join("", await executor.InferAsync(question, @params).ToListAsync());
@@ -79,5 +82,83 @@ public async Task OutOfContext()
// Check that it produced the exact same result both times
Assert.Equal(result1, result2);
}

[Fact]
public async Task OutOfContext_WithDefaultStrategy_ThrowsException()
{
var executor = new StatelessExecutor(_weights, _params);
using var context = _weights.CreateContext(_params);

// Read the ACTUAL context size allocated by the native engine
uint actualContextSize = context.ContextSize;

string question = "Cats and dogs are great pets. ";

// Fast pad for the bulk of it
while (context.Tokenize(question, special: true).Length < actualContextSize - 20)
{
question += "Cats and dogs are great pets. ";
}

// Slow pad by single words to precisely hit actualContextSize - 2
while (context.Tokenize(question, special: true).Length < actualContextSize - 2)
{
question += "pet ";
}

var finalLength = context.Tokenize(question, special: true).Length;
_testOutputHelper.WriteLine($"[DEBUG] Actual ContextSize: {actualContextSize}, Prompt length: {finalLength}");

// Sanity check to ensure we didn't overshoot
Assert.True(finalLength < actualContextSize, "Prompt exceeded context size during prefill!");

var @params = new InferenceParams()
{
MaxTokens = 10,
TokensKeep = 5,
};

var exception = await Assert.ThrowsAsync<Exceptions.ContextOverflowException>(async () =>
{
await executor.InferAsync(question, @params).ToListAsync();
});

_testOutputHelper.WriteLine($"Successfully caught expected exception: {exception.Message}");
}

[Fact]
public async Task OutOfContext_WithDefaultStrategy_2_ThrowsException()
{
using var context = _weights.CreateContext(_params);
var executor = new InstructExecutor(context);

uint actualContextSize = context.ContextSize;
string instruction = "Cats or dogs? ";

// Fast pad safely below limit (InstructExecutor adds hidden prefix/suffix)
while (context.Tokenize(instruction, special: true).Length < actualContextSize - 30)
{
instruction += "Cats or dogs? ";
}

// Slow pad
while (context.Tokenize(instruction, special: true).Length < actualContextSize - 15)
{
instruction += "pet ";
}

var @params = new InferenceParams()
{
MaxTokens = 20,
TokensKeep = 5,
};

var exception = await Assert.ThrowsAsync<Exceptions.ContextOverflowException>(async () =>
{
await executor.InferAsync(instruction, @params).ToListAsync();
});

_testOutputHelper.WriteLine($"Successfully caught expected exception in InstructExecutor: {exception.Message}");
}
}
}
15 changes: 15 additions & 0 deletions LLama.Web/Common/InferenceOptions.cs
@@ -24,5 +24,20 @@ public class InferenceOptions

/// <inheritdoc />
public bool DecodeSpecialTokens { get; set; }

/// <summary>
/// Defines the strategy the executor should use when the context window is full
/// and the model architecture does not support native memory shifting.
/// Defaults to <see cref="ContextOverflowStrategy.ThrowException"/> to prevent
/// unintended data loss and latency spikes.
/// </summary>
public ContextOverflowStrategy OverflowStrategy { get; set; } = ContextOverflowStrategy.ThrowException;

/// <summary>
/// The percentage of past tokens to discard when <see cref="OverflowStrategy"/>
/// is set to <see cref="ContextOverflowStrategy.TruncateAndReprefill"/>.
/// Defaults to 0.1f (10%). Valid range is typically between 0.01f and 0.99f.
/// </summary>
public float ContextTruncationPercentage { get; set; } = 0.1f;
}
}
19 changes: 17 additions & 2 deletions LLama/Abstractions/IInferenceParams.cs
@@ -1,5 +1,6 @@
using System.Collections.Generic;
using LLama.Common;
using LLama.Sampling;
using System.Collections.Generic;

namespace LLama.Abstractions
{
@@ -36,5 +37,19 @@ public interface IInferenceParams
/// Controls the behavior of decoders like <see cref="StreamingTokenDecoder" />
/// </remark>
public bool DecodeSpecialTokens { get; set; }
}

/// <summary>
/// Defines the strategy the executor should use when the context window is full
/// and the model architecture (e.g., models with 2D RoPE embeddings) does not
/// support native memory shifting.
/// </summary>
ContextOverflowStrategy OverflowStrategy { get; set; }

/// <summary>
/// The percentage of past tokens to discard when <see cref="OverflowStrategy"/>
/// is set to <see cref="ContextOverflowStrategy.TruncateAndReprefill"/>.
/// For example, 0.1f represents dropping the oldest 10% of the conversational context.
Comment on lines +42 to +51

Copilot AI (Apr 25, 2026):

The docs for OverflowStrategy describe it as applying only when the model “does not support native memory shifting”, but executors also use it to force a fast-fail even when MemoryCanShift is true. Consider rewording to “when the context window is full” and documenting that ThrowException disables any automatic shifting/truncation regardless of model support.

Suggested change
/// Defines the strategy the executor should use when the context window is full
/// and the model architecture (e.g., models with 2D RoPE embeddings) does not
/// support native memory shifting.
/// </summary>
ContextOverflowStrategy OverflowStrategy { get; set; }
/// <summary>
/// The percentage of past tokens to discard when <see cref="OverflowStrategy"/>
/// is set to <see cref="ContextOverflowStrategy.TruncateAndReprefill"/>.
/// For example, 0.1f represents dropping the oldest 10% of the conversational context.
/// Defines the strategy the executor should use when the context window is full.
/// </summary>
/// <remarks>
/// This setting applies even for models that support native memory shifting.
/// Setting <see cref="ContextOverflowStrategy.ThrowException"/> disables automatic
/// shifting or truncation and causes the executor to fail immediately on overflow.
/// </remarks>
ContextOverflowStrategy OverflowStrategy { get; set; }
/// <summary>
/// The percentage of past tokens to discard when <see cref="OverflowStrategy"/>
/// is set to <see cref="ContextOverflowStrategy.TruncateAndReprefill"/> to recover
/// from a full context window. For example, 0.1f represents dropping the oldest
/// 10% of the conversational context.

@zsogitbe (Contributor Author, Apr 26, 2026):

Not doing this now.

/// </summary>
float ContextTruncationPercentage { get; set; }
}
}
26 changes: 26 additions & 0 deletions LLama/Common/ContextOverflowStrategy.cs
@@ -0,0 +1,26 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace LLama.Common
{
/// <summary>
/// Defines how the executor should behave when the context window fills up
/// on a model that does not support native memory shifting (e.g., 2D RoPE models).
/// </summary>
public enum ContextOverflowStrategy
{
/// <summary>
/// The engine will throw a ContextOverflowException.
/// Use this to manually manage context pruning in your application layer.
/// (Equivalent to llama-cli's --no-context-shift).
/// </summary>
ThrowException,

/// <summary>
/// The engine will silently drop a percentage of the oldest tokens
/// (preserving the system prompt) and completely re-prefill the context.
/// </summary>
TruncateAndReprefill
}
}
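
A minimal usage sketch of the two strategies (hedged: the prompt text and the _weights/_params variables are placeholders borrowed from the unit tests above; the properties and executor API are the ones added in this PR):

// Assumes _weights/_params are an already-loaded model and its context params.
var executor = new StatelessExecutor(_weights, _params);

var inferenceParams = new InferenceParams
{
    MaxTokens = 128,
    TokensKeep = 16, // tokens from the start of the prompt to preserve (e.g. system text)
    // Opt out of the default ThrowException behaviour:
    OverflowStrategy = ContextOverflowStrategy.TruncateAndReprefill,
    ContextTruncationPercentage = 0.2f // drop the oldest 20% of past tokens when full
};

await foreach (var piece in executor.InferAsync("Cats or dogs?", inferenceParams))
    Console.Write(piece);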
15 changes: 15 additions & 0 deletions LLama/Common/InferenceParams.cs
@@ -33,6 +33,21 @@ public record InferenceParams

/// <inheritdoc />
public bool DecodeSpecialTokens { get; set; }

/// <summary>
/// Defines the strategy the executor should use when the context window is full
/// and the model architecture does not support native memory shifting.
/// Defaults to <see cref="ContextOverflowStrategy.ThrowException"/> to prevent
/// unintended data loss and latency spikes.
/// </summary>
public ContextOverflowStrategy OverflowStrategy { get; set; } = ContextOverflowStrategy.ThrowException;

/// <summary>
/// The percentage of past tokens to discard when <see cref="OverflowStrategy"/>
/// is set to <see cref="ContextOverflowStrategy.TruncateAndReprefill"/>.
/// Defaults to 0.1f (10%). Valid range is typically between 0.01f and 0.99f.
/// </summary>
public float ContextTruncationPercentage { get; set; } = 0.1f;
}

/// <summary>
39 changes: 39 additions & 0 deletions LLama/Exceptions/ContextOverflowException.cs
@@ -0,0 +1,39 @@
using System;

namespace LLama.Exceptions
{
/// <summary>
/// Thrown when the KV cache context is full and the model architecture
/// cannot mathematically support native memory shifting, or when the
/// ContextOverflowStrategy.ThrowException strategy is in use.
/// </summary>
public class ContextOverflowException : Exception
{
private const string DefaultMessage = "The context window is full and the current strategy is set to ThrowException. To automatically truncate and manage context, set InferenceParams.OverflowStrategy to ContextOverflowStrategy.TruncateAndReprefill.";

/// <summary>
/// Initializes a new instance of the ContextOverflowException class with a default error message.
/// </summary>
public ContextOverflowException() : base(DefaultMessage)
{
}

/// <summary>
/// Initializes a new instance of the ContextOverflowException class with a specified error message.
/// </summary>
/// <param name="message">The message that describes the error.</param>
public ContextOverflowException(string message) : base(message)
{
}

/// <summary>
/// Initializes a new instance of the ContextOverflowException class with a specified error message
/// and a reference to the inner exception that is the cause of this exception.
/// </summary>
/// <param name="message">The message that describes the error.</param>
/// <param name="innerException">The exception that is the cause of the current exception.</param>
public ContextOverflowException(string message, Exception innerException) : base(message, innerException)
{
}
}
}
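
With the default ThrowException strategy the caller owns recovery. A sketch of the intended pattern (the pruning step is application-specific and purely illustrative; executor, prompt, and inferenceParams are assumed to be set up as in the sketch above):

try
{
    await foreach (var piece in executor.InferAsync(prompt, inferenceParams))
        Console.Write(piece);
}
catch (LLama.Exceptions.ContextOverflowException ex)
{
    // The context window filled up and no automatic strategy was enabled.
    // Prune or summarize the conversation at the application layer, then retry.
    Console.Error.WriteLine(ex.Message);
}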
50 changes: 42 additions & 8 deletions LLama/LLamaExecutorBase.cs
@@ -198,21 +198,55 @@ public void SaveSessionFile(string filename)
/// <summary>
/// After running out of the context, take some tokens from the original prompt and recompute the logits in batches.
/// </summary>
/// <param name="tokensToKeep"></param>
protected virtual void HandleRunOutOfContext(int tokensToKeep)
/// <param name="tokensToKeep">The number of tokens from the initial prompt to preserve (e.g., system prompt).</param>
/// <param name="inferenceParams">The parameters controlling the inference and overflow strategy.</param>
/// <exception cref="ContextOverflowException">Thrown when the overflow strategy is set to ThrowException, or if the model does not support native shifting.</exception>
/// <exception cref="ArgumentOutOfRangeException">Thrown when tokensToKeep is invalid.</exception>
protected virtual Task HandleRunOutOfContext(int tokensToKeep, IInferenceParams inferenceParams)
{
// if we run out of context:
// - take the tokensToKeep first tokens from the original prompt (via n_past)
// - take half of the last (n_ctx - tokensToKeep) tokens and recompute the logits in batches
// 1. Fast Fail if not configured to auto-truncate
if (inferenceParams.OverflowStrategy == ContextOverflowStrategy.ThrowException)
{
throw new ContextOverflowException();
}

// 2. Guard: Stateful executors currently require native shifting to truncate.
// TODO (Future Improvement): To support truncation on models where MemoryCanShift == false,
// StatefulExecutorBase needs an unconditional `List<LLamaToken> _history_tokens` to track
// all ingested/generated tokens so we can clear the KV cache and perform a full re-prefill.
if (!Context.NativeHandle.MemoryCanShift)
{
_logger?.LogError("Model does not support native memory shifting. Stateful truncation requires MemoryCanShift = true.");
throw new ContextOverflowException("Model does not support native memory shifting. Context overflowed.");
}

// 3. Calculate tokens safely
var n_left = _pastTokensCount - tokensToKeep;
var n_discard = n_left / 2;
if (n_left <= 0)
{
throw new ArgumentOutOfRangeException(nameof(tokensToKeep), "Cannot truncate context: tokensToKeep exceeds or equals the current context size.");
}

// Clamp the percentage between 1% and 99% to prevent math errors or total wipeouts
var percentage = Math.Max(0.01f, Math.Min(0.99f, inferenceParams.ContextTruncationPercentage));
var n_discard = (int)(n_left * percentage);

// Sanity check: always discard at least 1 token, but never more than we have available.
n_discard = Math.Max(1, Math.Min(n_discard, n_left));

// 4. Fast path: attempt the fast native memory shift
Context.NativeHandle.MemorySequenceRemove(LLamaSeqId.Zero, tokensToKeep, tokensToKeep + n_discard);
Context.NativeHandle.MemorySequenceAdd(LLamaSeqId.Zero, tokensToKeep + n_discard, _pastTokensCount, -n_discard);

_pastTokensCount -= n_discard;
// stop saving session if we run out of context

Copilot AI (Apr 25, 2026):

HandleRunOutOfContext now shifts the KV cache but no longer disables session saving (previously _pathSession was cleared on overflow). Since _session_tokens is still appended to later, SaveSessionFile can produce a session token list that no longer matches the shifted KV cache. Either keep the old behavior (stop session saving after a shift) or update _session_tokens (and related counters) to reflect discarded tokens so saved sessions remain loadable/correct.

Suggested change
// Keep session tracking aligned with the shifted KV cache so future session saves/reuse
// operate on the same logical token sequence as the current context.
var sessionDiscardStart = Math.Min(tokensToKeep, _session_tokens.Count);
var sessionDiscardEnd = Math.Min(tokensToKeep + n_discard, _session_tokens.Count);
var removedSessionTokens = sessionDiscardEnd - sessionDiscardStart;
if (removedSessionTokens > 0)
{
_session_tokens.RemoveRange(sessionDiscardStart, removedSessionTokens);
if (_n_session_consumed > sessionDiscardStart)
{
_n_session_consumed = _n_session_consumed >= sessionDiscardEnd
? _n_session_consumed - removedSessionTokens
: sessionDiscardStart;
}
}
if (_n_session_consumed > _session_tokens.Count)
{
_n_session_consumed = _session_tokens.Count;
}

Contributor Author:

OK! We will stop saving the session if we run out of context.

// Stop saving the session if we run out of context.
// Note: A more advanced (but riskier and more complex) solution would be to physically trim
// the _session_tokens list and adjust _n_session_consumed to perfectly match the newly
// shifted native memory. This would allow session saving to continue safely, but requires
// precise index tracking to avoid off-by-one errors. For now, we abort saving to prevent corruption.
_pathSession = string.Empty;

return Task.CompletedTask;
}
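
To make the clamped truncation arithmetic concrete, a worked example with hypothetical numbers (mirroring, not replacing, the logic above):

// Hypothetical values: 1024 tokens in the KV cache, 64 of which must be kept.
int pastTokensCount = 1024;
int tokensToKeep = 64;
float requested = 0.2f;                                         // ContextTruncationPercentage

int nLeft = pastTokensCount - tokensToKeep;                     // 960 tokens eligible for removal
float percentage = Math.Max(0.01f, Math.Min(0.99f, requested)); // clamped to [0.01, 0.99] -> 0.2
int nDiscard = Math.Max(1, Math.Min((int)(nLeft * percentage), nLeft)); // 192

// MemorySequenceRemove drops tokens [64, 64 + 192) and MemorySequenceAdd slides
// the remainder left by 192, leaving 1024 - 192 = 832 tokens in the cache.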

/// <summary>
2 changes: 1 addition & 1 deletion LLama/LLamaInstructExecutor.cs
@@ -47,7 +47,7 @@
_instructionSuffix = instructionSuffix;
}

public InstructExecutor(LLamaContext context,

Check warning on line 50 in LLama/LLamaInstructExecutor.cs (GitHub Actions, all platforms): Missing XML comment for publicly visible type or member 'InstructExecutor.InstructExecutor(LLamaContext, MtmdWeights, string, string, ILogger?)'
MtmdWeights clipModel,
string instructionPrefix = "\n\n### Instruction:\n\n",
string instructionSuffix = "\n\n### Response:\n\n",
@@ -233,7 +233,7 @@
// Ported from https://github.com/ggerganov/llama.cpp/blob/60325fa56f61c228464c9f065db3aa6a61f2156e/examples/main/main.cpp#L334
// Instruct always uses input token size.
var tokensToKeep = _embed_inps.Count;
HandleRunOutOfContext(tokensToKeep);
await HandleRunOutOfContext(tokensToKeep, inferenceParams);
}

TryReuseMatchingPrefix();
2 changes: 1 addition & 1 deletion LLama/LLamaInteractExecutor.cs
@@ -26,8 +26,8 @@
private bool _is_prompt_run = true;

// MTMD multimodal state
private SafeMtmdInputChunks? _mtmdChunks; // Pending chunk collection produced by the multimodal tokenizer.

Check warning on line 29 in LLama/LLamaInteractExecutor.cs (GitHub Actions, all platforms): The field 'InteractiveExecutor._mtmdChunks' is never used
private string? _mtmdMarker; // Cached multimodal marker returned by the native helper.

Check warning on line 30 in LLama/LLamaInteractExecutor.cs (GitHub Actions, all platforms): The field 'InteractiveExecutor._mtmdMarker' is never used


/// <summary>
@@ -118,7 +118,7 @@
/// </summary>
/// <param name="args">Mutable inference state.</param>
/// <returns><c>true</c> to keep generating; otherwise <c>false</c>.</returns>
protected override Task<bool> GetLoopCondition(InferStateArgs args, CancellationToken cancellationToken)

Check warning on line 121 in LLama/LLamaInteractExecutor.cs (GitHub Actions, all platforms): Parameter 'cancellationToken' has no matching param tag in the XML comment for 'InteractiveExecutor.GetLoopCondition(StatefulExecutorBase.InferStateArgs, CancellationToken)' (but other parameters do)
{
return Task.FromResult(args.RemainedTokens != 0 && !args.WaitForInput || _is_prompt_run);
}
@@ -128,7 +128,7 @@
/// </summary>
/// <param name="text">Prompt text or continuation provided by the caller.</param>
/// <param name="args">Mutable inference state.</param>
protected override Task PreprocessInputs(string? text, InferStateArgs args, CancellationToken cancellationToken)

Check warning on line 131 in LLama/LLamaInteractExecutor.cs (GitHub Actions, all platforms): Parameter 'cancellationToken' has no matching param tag in the XML comment for 'InteractiveExecutor.PreprocessInputs(string?, StatefulExecutorBase.InferStateArgs, CancellationToken)' (but other parameters do)
{
if (_is_prompt_run)
{
@@ -232,7 +232,7 @@
tokensToKeep += Convert.ToInt32(Context.Vocab.ShouldAddBOS); // always keep the BOS token
}

HandleRunOutOfContext(tokensToKeep);
await HandleRunOutOfContext(tokensToKeep, inferenceParams);
}

if (MtmdChunks is null)