From a296b55b9d2d8a368a0b9ae16f7bc7caf2064bb7 Mon Sep 17 00:00:00 2001
From: Martin Evans
Date: Sat, 11 Jan 2025 21:32:41 +0000
Subject: [PATCH 1/2] Initial investigation into potential issue with batched
 executor (or batching in general)

---
 .../Examples/BatchedExecutorSimple.cs | 11 +++++---
 LLama/Native/LLamaBatch.cs            | 27 ++++++++++++++++---
 2 files changed, 31 insertions(+), 7 deletions(-)

diff --git a/LLama.Examples/Examples/BatchedExecutorSimple.cs b/LLama.Examples/Examples/BatchedExecutorSimple.cs
index 4b0bf352b..a574ef9f1 100644
--- a/LLama.Examples/Examples/BatchedExecutorSimple.cs
+++ b/LLama.Examples/Examples/BatchedExecutorSimple.cs
@@ -84,12 +84,15 @@ await AnsiConsole.Live(table).StartAsync(async ctx =>
 
                 foreach (var conversationData in conversations.Where(c => c.IsComplete == false))
                 {
-                    if (conversationData.Conversation.RequiresSampling == false) continue;
+                    if (conversationData.Conversation.RequiresSampling == false)
+                        continue;
 
                     // sample a single token for the executor, passing the sample index of the conversation
+                    var sampleIndex = conversationData.Conversation.GetSampleIndex();
                     var token = conversationData.Sampler.Sample(
-                        executor.Context.NativeHandle,
-                        conversationData.Conversation.GetSampleIndex());
+                        executor.Context,
+                        sampleIndex
+                    );
 
                     if (modelTokens.IsEndOfGeneration(token))
                     {
@@ -99,7 +102,7 @@ await AnsiConsole.Live(table).StartAsync(async ctx =>
                     {
                         // it isn't the end of generation, so add this token to the decoder and then add that to our tracked data
                         conversationData.Decoder.Add(token);
-                        conversationData.AppendAnswer(conversationData.Decoder.Read().ReplaceLineEndings(" "));
+                        todo: conversationData.AppendAnswer(conversationData.Decoder.Read().ReplaceLineEndings(" "));
 
                         // add the token to the conversation
                         conversationData.Conversation.Prompt(token);
diff --git a/LLama/Native/LLamaBatch.cs b/LLama/Native/LLamaBatch.cs
index c66bd9277..a58e67a8c 100644
--- a/LLama/Native/LLamaBatch.cs
+++ b/LLama/Native/LLamaBatch.cs
@@ -1,5 +1,6 @@
 using System;
 using System.Collections.Generic;
+using System.Diagnostics;
 
 namespace LLama.Native;
 
@@ -105,6 +106,25 @@ private void GrowMaxSequences(int atLeast)
 
     internal GroupDisposable ToNativeBatch(out LLamaNativeBatch batch)
     {
+        // Sanity checking
+#if DEBUG
+        // Check every output logit position is actually generating logits for exactly one sequence
+        foreach (var (seq, idx) in _logitPositions)
+        {
+            Debug.Assert(_logits[idx] != 0);
+            Debug.Assert(_sequenceIdCount[idx] == 1);
+            Debug.Assert(_sequenceIds[idx][0] == seq);
+        }
+
+        // Check the reverse
+        for (var i = 0; i < _logits.Length; i++)
+        {
+            var actual = _logitPositions.FindIndex(x => x.Item2 == i) >= 0;
+            var expected = _logits[i] != 0;
+            Debug.Assert(actual == expected);
+        }
+#endif
+
         // This group holds all of the memory pins
         var group = new GroupDisposable();
 
@@ -146,6 +166,7 @@ internal GroupDisposable ToNativeBatch(out LLamaNativeBatch batch)
     /// <returns>The index that the token was added at. Use this for GetLogitsIth</returns>
     public int Add(LLamaToken token, LLamaPos pos, ReadOnlySpan<LLamaSeqId> sequences, bool logits)
     {
+        // todo: token sharing in batch is broken?
         // Try to find this (token, position) combo somewhere in the batch to re-use it by adding this
         // sequence ID to the list.
         // Do **not** do this if this token wants logits, to prevent logits being shared between sequences.
@@ -171,9 +192,9 @@ public int Add(LLamaToken token, LLamaPos pos, ReadOnlySpan<LLamaSeqId> sequence
         if (sequences.Length > SequenceCapacity)
             GrowMaxSequences(sequences.Length);
 
-        // Store the position in the index, so it can be found later.
-        // We need to check that it's not already there in case we skipped the check above (because logits is true).
-        if (!_index.ContainsKey((token, pos)))
+        // Store the position in the index, so it can be found later. We don't want to share tokens when logits are being generated so
+        // do not add to the index in that case.
+        if (!logits && !_index.ContainsKey((token, pos)))
             _index.Add((token, pos), TokenCount);
 
         // Add the items to the arrays

From 59cd4873713f36fb37262184f68b9cfd0a87a698 Mon Sep 17 00:00:00 2001
From: Martin Evans
Date: Fri, 17 Jan 2025 03:07:51 +0000
Subject: [PATCH 2/2] - Removed the automatic shared token mechanism in
 LLamaBatch, it was causing problems with tokens that should never have been
 shared.

- Added a helper to simplify sampling a conversation with a sampling pipeline
- Added many more comments explaining BatchedExecutorSimple in detail
---
 .../Examples/BatchedExecutorSimple.cs   | 68 ++++++++++++-------
 LLama/Batched/ConversationExtensions.cs | 13 ++++
 LLama/LLamaTemplate.cs                  |  2 +-
 LLama/Native/LLamaBatch.cs              | 42 ++----------
 4 files changed, 65 insertions(+), 60 deletions(-)

diff --git a/LLama.Examples/Examples/BatchedExecutorSimple.cs b/LLama.Examples/Examples/BatchedExecutorSimple.cs
index a574ef9f1..fcdf95e2d 100644
--- a/LLama.Examples/Examples/BatchedExecutorSimple.cs
+++ b/LLama.Examples/Examples/BatchedExecutorSimple.cs
@@ -1,4 +1,3 @@
-using System.Diagnostics.CodeAnalysis;
 using System.Text;
 using LLama.Batched;
 using LLama.Common;
@@ -34,6 +33,7 @@ public static async Task Run()
         var name = model.Metadata.GetValueOrDefault("general.name", "unknown model name");
         Console.WriteLine($"Created executor with model: {name}");
 
+        // A set of questions to evaluate all at once
         var messages = new[]
         {
             "What's 2+2?",
@@ -46,8 +46,10 @@ public static async Task Run()
             "I have two sons, Bert and Ernie. What should I name my daughter?",
             "What day comes after Friday?",
             "What color shoes should I wear with dark blue pants?",
+            "Wy ae cts btr tn dgs?"
         };
 
+        // Create a "Conversation" for each question
        var conversations = new List<ConversationData>();
         foreach (var message in messages)
         {
@@ -57,11 +59,14 @@ public static async Task Run()
             template.Add("user", message);
             template.AddAssistant = true;
             var templatedMessage = Encoding.UTF8.GetString(template.Apply());
-
 
+            // create a new conversation and prompt it. include special and bos because we are using the template
+            // - BOS is the "Beginning of Sequence" token and should be included at the start of any prompt
+            // - Special tokens are special non-text tokens which an LLM is trained to understand (e.g. BOS). The templated text may contain special tokens.
             var conversation = executor.Create();
             conversation.Prompt(executor.Context.Tokenize(templatedMessage, addBos: true, special: true));
 
+            // Store everything we need to process this conversation
             conversations.Add(new ConversationData {
                 Prompt = message,
                 Conversation = conversation,
@@ -73,50 +78,64 @@ public static async Task Run()
         var table = BuildTable(conversations);
         await AnsiConsole.Live(table).StartAsync(async ctx =>
         {
+            // Enter a loop generating tokens
             for (var i = 0; i < TokenCount; i++)
             {
                 // Run inference for all conversations in the batch which have pending tokens.
                 var decodeResult = await executor.Infer();
+
+                // Inference can fail, always check the return value!
+                // NoKvSlot is not a fatal error, it just means that there's not enough memory available in the KV cache to process everything. You can force
+                // this to happen by setting a small value for ContextSize in the ModelParams at the top of this file (e.g. 512).
+                // In this case it's handled by ending a conversation (which will free up some space) and trying again. You could also handle this by
+                // saving the conversation to disk and loading it up again later once some other conversations have finished.
                 if (decodeResult == DecodeResult.NoKvSlot)
-                    throw new Exception("Could not find a KV slot for the batch. Try reducing the size of the batch or increase the context.");
+                {
+                    conversations.FirstOrDefault(a => !a.IsComplete)?.MarkComplete(failed:true);
+                    continue;
+                }
+
+                // A generic error, this is fatal and the batch can no longer be used. This should never occur and generally indicates
+                // a bug in LLamaSharp, llama.cpp or a hardware error.
                 if (decodeResult == DecodeResult.Error)
                     throw new Exception("Unknown error occurred while inferring.");
 
-                foreach (var conversationData in conversations.Where(c => c.IsComplete == false))
+                // After inference all of the conversations must be sampled before running inference again.
+                foreach (var conversationData in conversations)
                 {
-                    if (conversationData.Conversation.RequiresSampling == false)
+                    // Completed conversations don't need sampling.
+                    if (conversationData.IsComplete)
                         continue;
-
-                    // sample a single token for the executor, passing the sample index of the conversation
-                    var sampleIndex = conversationData.Conversation.GetSampleIndex();
-                    var token = conversationData.Sampler.Sample(
-                        executor.Context,
-                        sampleIndex
-                    );
-
+
+                    // If the conversation wasn't prompted before the last call to Infer then it won't need sampling.
+                    if (!conversationData.Conversation.RequiresSampling)
+                        continue;
+
+                    // Use the sampling pipeline to choose a single token for this conversation.
+                    var token = conversationData.Conversation.Sample(conversationData.Sampler);
+
+                    // Some special tokens indicate that this sequence has ended. Check if that's what has been chosen by the sampling pipeline.
                     if (modelTokens.IsEndOfGeneration(token))
                     {
                         conversationData.MarkComplete();
                     }
                     else
                     {
-                        // it isn't the end of generation, so add this token to the decoder and then add that to our tracked data
+                        // It isn't the end of generation, so add this token to the decoder and then add that to our tracked data
                         conversationData.Decoder.Add(token);
-                        todo: conversationData.AppendAnswer(conversationData.Decoder.Read().ReplaceLineEndings(" "));
+                        conversationData.AppendAnswer(conversationData.Decoder.Read().ReplaceLineEndings(" "));
 
-                        // add the token to the conversation
+                        // Prompt the conversation with this token, ready for the next round of inference to generate another token
                         conversationData.Conversation.Prompt(token);
                     }
                 }
 
-                // render the current state
+                // Render the current state
                 table = BuildTable(conversations);
                 ctx.UpdateTarget(table);
 
                 if (conversations.All(c => c.IsComplete))
-                {
                     break;
-                }
             }
 
             // if we ran out of tokens before completing just mark them as complete for rendering purposes.
@@ -155,20 +174,23 @@ public class ConversationData
     public required BaseSamplingPipeline Sampler { get; init; }
     public required StreamingTokenDecoder Decoder { get; init; }
 
$"[green]{_inProgressAnswer.Message.EscapeMarkup()}{_inProgressAnswer.LatestToken.EscapeMarkup()}[/]" - : $"[grey]{_inProgressAnswer.Message.EscapeMarkup()}[/][white]{_inProgressAnswer.LatestToken.EscapeMarkup()}[/]"; + public string AnswerMarkdown => + IsComplete + ? $"[{(IsFailed ? "red" : "green")}]{_inProgressAnswer.Message.EscapeMarkup()}{_inProgressAnswer.LatestToken.EscapeMarkup()}[/]" + : $"[grey]{_inProgressAnswer.Message.EscapeMarkup()}[/][white]{_inProgressAnswer.LatestToken.EscapeMarkup()}[/]"; public bool IsComplete { get; private set; } + public bool IsFailed { get; private set; } // we are only keeping track of the answer in two parts to render them differently. private (string Message, string LatestToken) _inProgressAnswer = (string.Empty, string.Empty); public void AppendAnswer(string newText) => _inProgressAnswer = (_inProgressAnswer.Message + _inProgressAnswer.LatestToken, newText); - public void MarkComplete() + public void MarkComplete(bool failed = false) { IsComplete = true; + IsFailed = failed; if (Conversation.IsDisposed == false) { // clean up the conversation and sampler to release more memory for inference. diff --git a/LLama/Batched/ConversationExtensions.cs b/LLama/Batched/ConversationExtensions.cs index b6b0d9eb1..eb0192061 100644 --- a/LLama/Batched/ConversationExtensions.cs +++ b/LLama/Batched/ConversationExtensions.cs @@ -1,5 +1,6 @@ using System; using LLama.Native; +using LLama.Sampling; namespace LLama.Batched; @@ -20,6 +21,18 @@ public static LLamaToken Sample(this Conversation conversation, SafeLLamaSampler return sampler.Sample(conversation.Executor.Context.NativeHandle, conversation.GetSampleIndex(offset)); } + /// + /// Sample a token from this conversation using the given sampling pipeline + /// + /// to sample from + /// + /// Offset from the end of the conversation to the logits to sample, see for more details + /// + public static LLamaToken Sample(this Conversation conversation, ISamplingPipeline sampler, int offset = 0) + { + return sampler.Sample(conversation.Executor.Context.NativeHandle, conversation.GetSampleIndex(offset)); + } + /// /// Rewind a back to an earlier state by removing tokens from the end /// diff --git a/LLama/LLamaTemplate.cs b/LLama/LLamaTemplate.cs index e82cbccb4..0fc19168e 100644 --- a/LLama/LLamaTemplate.cs +++ b/LLama/LLamaTemplate.cs @@ -210,7 +210,7 @@ public void Clear() #endregion /// - /// Apply the template to the messages and write it into the output buffer + /// Apply the template to the messages and return a span containing the results /// /// A span over the buffer that holds the applied template public ReadOnlySpan Apply() diff --git a/LLama/Native/LLamaBatch.cs b/LLama/Native/LLamaBatch.cs index a58e67a8c..76ef26c3c 100644 --- a/LLama/Native/LLamaBatch.cs +++ b/LLama/Native/LLamaBatch.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Diagnostics; +using System.Linq; namespace LLama.Native; @@ -18,11 +19,6 @@ public class LLamaBatch private LLamaSeqId[][] _sequenceIds; private IntPtr[] _sequenceIdsPtrs; - /// - /// Keep track of the index of existing token/position combos in the batch - /// - private readonly Dictionary<(LLamaToken, LLamaPos), int> _index = new(); - /// /// Keep a list of where logits can be sampled from /// @@ -108,7 +104,7 @@ internal GroupDisposable ToNativeBatch(out LLamaNativeBatch batch) { // Sanity checking #if DEBUG - // Check every output logit position is actually generating logits for exactly one sequence + // Check every output logit position is 
@@ -116,10 +112,10 @@ internal GroupDisposable ToNativeBatch(out LLamaNativeBatch batch)
             Debug.Assert(_sequenceIds[idx][0] == seq);
         }
 
-        // Check the reverse
+        // Check every index, if it's generating logits it must be in the _logitPositions list. Otherwise it must not.
         for (var i = 0; i < _logits.Length; i++)
         {
-            var actual = _logitPositions.FindIndex(x => x.Item2 == i) >= 0;
+            var actual = _logitPositions.Any(x => x.Item2 == i);
             var expected = _logits[i] != 0;
             Debug.Assert(actual == expected);
         }
@@ -166,37 +162,12 @@ internal GroupDisposable ToNativeBatch(out LLamaNativeBatch batch)
     /// <returns>The index that the token was added at. Use this for GetLogitsIth</returns>
     public int Add(LLamaToken token, LLamaPos pos, ReadOnlySpan<LLamaSeqId> sequences, bool logits)
     {
-        // todo: token sharing in batch is broken?
-        // Try to find this (token, position) combo somewhere in the batch to re-use it by adding this
-        // sequence ID to the list.
-        // Do **not** do this if this token wants logits, to prevent logits being shared between sequences.
-        if (!logits && _index.TryGetValue((token, pos), out var existingIndex))
-        {
-            if (_sequenceIdCount[existingIndex] + sequences.Length > SequenceCapacity)
-                GrowMaxSequences(_sequenceIdCount[existingIndex] + sequences.Length);
-
-            foreach (var sequence in sequences)
-            {
-                _sequenceIds[existingIndex][_sequenceIdCount[existingIndex]] = sequence;
-                _sequenceIdCount[existingIndex]++;
-            }
-
-            return existingIndex;
-        }
-
-        // Couldn't find this token/position combo anywhere in the batch. Add a new item.
-
         // Grow capacity as necessary
         if (TokenCount == TokenCapacity)
             GrowTokenCapacity();
         if (sequences.Length > SequenceCapacity)
             GrowMaxSequences(sequences.Length);
 
-        // Store the position in the index, so it can be found later. We don't want to share tokens when logits are being generated so
-        // do not add to the index in that case.
-        if (!logits && !_index.ContainsKey((token, pos)))
-            _index.Add((token, pos), TokenCount);
-
         // Add the items to the arrays
         _tokens[TokenCount] = token;
         _positions[TokenCount] = pos;
@@ -234,7 +205,7 @@ public int Add(LLamaToken token, LLamaPos pos, List<LLamaSeqId> sequences, bool
         // the list. Instead rent an array and copy the data into it. This avoids an allocation, but can't
         // avoid the copying.
 
-        var rented = System.Buffers.ArrayPool<LLamaSeqId>.Shared.Rent(sequences.Count);
+        var rented = ArrayPool<LLamaSeqId>.Shared.Rent(sequences.Count);
         try
         {
             sequences.CopyTo(rented, 0);
@@ -242,7 +213,7 @@ public int Add(LLamaToken token, LLamaPos pos, List<LLamaSeqId> sequences, bool
         }
         finally
         {
-            System.Buffers.ArrayPool<LLamaSeqId>.Shared.Return(rented);
+            ArrayPool<LLamaSeqId>.Shared.Return(rented);
         }
 #endif
     }
@@ -294,7 +265,6 @@ public void Clear()
     {
         TokenCount = 0;
 
-        _index.Clear();
         _logitPositions.Clear();
     }
 }
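
The second commit's ConversationExtensions change means callers no longer have to fetch GetSampleIndex and the native context handle themselves when they hold an ISamplingPipeline. Below is a minimal sketch of the intended call pattern; it is not part of the patch, it only uses members that appear in the diffs above (RequiresSampling, the new Sample overload, Prompt), and the helper name Step is illustrative.

    using LLama.Batched;
    using LLama.Native;
    using LLama.Sampling;

    static class ConversationStepper
    {
        // Advance one conversation by a single token after a successful BatchedExecutor.Infer() call.
        public static void Step(Conversation conversation, ISamplingPipeline sampler)
        {
            // Conversations that were not prompted before the last Infer() call have no logits to sample.
            if (!conversation.RequiresSampling)
                return;

            // New extension from this patch: samples at this conversation's own logit index.
            var token = conversation.Sample(sampler);

            // Feed the token back so the next Infer() call extends this conversation.
            conversation.Prompt(token);
        }
    }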
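The new comments around DecodeResult.NoKvSlot note that a small context makes the condition easy to reproduce. A short configuration sketch of that idea, assuming the ModelParams type from LLama.Common and a placeholder model path:

    using LLama.Common;

    // "model.gguf" is a placeholder path. A ContextSize of 512 gives the KV cache very
    // little room, so BatchedExecutor.Infer() will soon return DecodeResult.NoKvSlot.
    var parameters = new ModelParams("model.gguf")
    {
        ContextSize = 512,
    };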
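A possible follow-up, not included in either commit, is a smoke test for exactly the situation the removed index used to mishandle: several conversations holding identical token sequences, decoded in one batch and then sampled independently. The DEBUG assertions added to LLamaBatch.ToNativeBatch in the first commit do the real checking; the sketch below only drives them. The method name, the DefaultSamplingPipeline choice and the way the executor and prompt are supplied are assumptions.

    using System;
    using System.Threading.Tasks;
    using LLama.Batched;
    using LLama.Native;
    using LLama.Sampling;

    static class IdenticalPromptSmokeTest
    {
        // The caller supplies a ready BatchedExecutor and a tokenized prompt.
        public static async Task RunAsync(BatchedExecutor executor, LLamaToken[] prompt)
        {
            // Two conversations containing exactly the same tokens at the same positions.
            // Under the old shared-token index these could be folded onto shared batch slots;
            // with the index removed each conversation gets its own.
            var a = executor.Create();
            var b = executor.Create();
            a.Prompt(prompt);
            b.Prompt(prompt);

            // In a DEBUG build the new sanity checks in LLamaBatch.ToNativeBatch run here.
            var result = await executor.Infer();
            if (result == DecodeResult.NoKvSlot || result == DecodeResult.Error)
                throw new InvalidOperationException($"Inference failed: {result}");

            // Sample each conversation with its own pipeline via the new extension method.
            var tokenA = a.Sample(new DefaultSamplingPipeline());
            var tokenB = b.Sample(new DefaultSamplingPipeline());
            Console.WriteLine($"Sampled {tokenA} and {tokenB} from identical prompts.");
        }
    }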