diff --git a/LLama/Batched/Conversation.cs b/LLama/Batched/Conversation.cs
index c9a374549..b816fe778 100644
--- a/LLama/Batched/Conversation.cs
+++ b/LLama/Batched/Conversation.cs
@@ -2,6 +2,7 @@
 using System.Collections.Generic;
 using System.Linq;
 using System.Text.Json;
+using CommunityToolkit.HighPerformance.Buffers;
 using LLama.Native;
 
 namespace LLama.Batched;
@@ -224,18 +225,13 @@ public void Prompt(List<LLamaToken> tokens, bool allLogits = false)
         Prompt(span, allLogits);
 #else
         // Borrow an array and copy tokens into it
-        var arr = ArrayPool<LLamaToken>.Shared.Rent(tokens.Count);
-        try
-        {
-            for (var i = 0; i < tokens.Count; i++)
-                arr[i] = tokens[i];
+        using var span = SpanOwner<LLamaToken>.Allocate(tokens.Count);
+
+        for (var i = 0; i < tokens.Count; i++)
+            span.Span[i] = tokens[i];
+
+        Prompt(span.Span);
-            Prompt(arr.AsSpan());
-        }
-        finally
-        {
-            ArrayPool<LLamaToken>.Shared.Return(arr);
-        }
 #endif
     }
diff --git a/LLama/Extensions/IReadOnlyListExtensions.cs b/LLama/Extensions/IReadOnlyListExtensions.cs
index afbd97a90..c87239df7 100644
--- a/LLama/Extensions/IReadOnlyListExtensions.cs
+++ b/LLama/Extensions/IReadOnlyListExtensions.cs
@@ -2,6 +2,7 @@
 using System.Collections;
 using System.Collections.Generic;
 using System.Text;
+using CommunityToolkit.HighPerformance.Buffers;
 using LLama.Native;
 
 namespace LLama.Extensions
@@ -49,23 +50,17 @@ internal static bool TokensEndsWithAnyString<TTokens, TQueries>(this TTokens tok
                 longest = Math.Max(longest, candidate.Length);
 
             // Rent an array to detokenize into
-            var builderArray = ArrayPool<char>.Shared.Rent(longest);
-            try
-            {
-                // Convert as many tokens as possible into the builderArray
-                var characters = model.TokensToSpan(tokens, builderArray.AsSpan(0, longest), encoding);
+            using var builderArray = SpanOwner<char>.Allocate(longest);
 
-                // Check every query to see if it's present
-                foreach (var query in queries)
-                    if (characters.EndsWith(query.AsSpan()))
-                        return true;
+            // Convert as many tokens as possible into the builderArray
+            var characters = model.TokensToSpan(tokens, builderArray.Span, encoding);
 
-                return false;
-            }
-            finally
-            {
-                ArrayPool<char>.Shared.Return(builderArray);
-            }
+            // Check every query to see if it's present
+            foreach (var query in queries)
+                if (characters.EndsWith(query.AsSpan()))
+                    return true;
+
+            return false;
         }
 
         /// <summary>
diff --git a/LLama/Extensions/ListExtensions.cs b/LLama/Extensions/ListExtensions.cs
index eb30a07a0..0261bcd52 100644
--- a/LLama/Extensions/ListExtensions.cs
+++ b/LLama/Extensions/ListExtensions.cs
@@ -1,4 +1,4 @@
-using System;
+using System;
 using System.Collections.Generic;
 
 namespace LLama.Extensions
@@ -20,5 +20,16 @@ public static void AddSpan<T>(this List<T> list, ReadOnlySpan<T> items)
             for (var i = 0; i < items.Length; i++)
                 list.Add(items[i]);
         }
+
+#if !NET6_0_OR_GREATER
+        public static void CopyTo<T>(this List<T> list, Span<T> dest)
+        {
+            if (dest.Length < list.Count)
+                throw new ArgumentException($"dest is too small ({dest.Length} < {list.Count})");
+
+            for (var i = 0; i < list.Count; i++)
+                dest[i] = list[i];
+        }
+#endif
     }
 }
diff --git a/LLama/LLamaSharp.csproj b/LLama/LLamaSharp.csproj
index e960a414e..1c1b0f036 100644
--- a/LLama/LLamaSharp.csproj
+++ b/LLama/LLamaSharp.csproj
@@ -49,6 +49,7 @@
+    <PackageReference Include="CommunityToolkit.HighPerformance" Version="8.2.2" />
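The changes above show the pattern this PR repeats in every file: an `ArrayPool<T>.Shared.Rent` guarded by `try`/`finally` collapses into a single `using` declaration over `SpanOwner<T>` from CommunityToolkit.HighPerformance. A minimal sketch of the two shapes, where `Fill` and `count` are placeholders rather than LLamaSharp APIs:

```csharp
using System;
using System.Buffers;
using CommunityToolkit.HighPerformance.Buffers;

// Before: manual rent/return. Rent may hand back a longer array than asked
// for, so it must be sliced, and the try/finally is needed to guarantee the
// buffer goes back to the pool on every exit path.
var count = 8;
var arr = ArrayPool<int>.Shared.Rent(count);
try
{
    Fill(arr.AsSpan(0, count));
}
finally
{
    ArrayPool<int>.Shared.Return(arr);
}

// After: SpanOwner<T> rents from the same ArrayPool<T>.Shared, but Span is
// already trimmed to the requested length, and Dispose() returns the buffer.
using (var owner = SpanOwner<int>.Allocate(count))
{
    Fill(owner.Span);
}

static void Fill(Span<int> span) => span.Fill(42);
```

`SpanOwner<T>` is a ref struct, so it cannot escape the method that allocated it; that stack-only lifetime is exactly what these call sites need. The `List<T>.CopyTo(Span<T>)` polyfill added to `ListExtensions` backfills the BCL overload that only exists on .NET 6+, which `LLamaBatch.Add` (below) relies on when targeting older frameworks.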
diff --git a/LLama/Native/LLamaBatch.cs b/LLama/Native/LLamaBatch.cs
index d58b60a90..2f1f3f54a 100644
--- a/LLama/Native/LLamaBatch.cs
+++ b/LLama/Native/LLamaBatch.cs
@@ -2,6 +2,7 @@
 using System.Collections.Generic;
 using System.Diagnostics;
 using System.Linq;
+using CommunityToolkit.HighPerformance.Buffers;
 
 namespace LLama.Native;
@@ -205,16 +206,10 @@ public int Add(LLamaToken token, LLamaPos pos, List<LLamaSeqId> sequences, bool logits)
         // the list. Instead rent an array and copy the data into it. This avoids an allocation, but can't
         // avoid the copying.
-        var rented = ArrayPool<LLamaSeqId>.Shared.Rent(sequences.Count);
-        try
-        {
-            sequences.CopyTo(rented, 0);
-            return Add(token, pos, rented.AsSpan(0, sequences.Count), logits);
-        }
-        finally
-        {
-            ArrayPool<LLamaSeqId>.Shared.Return(rented);
-        }
+        using var rented = SpanOwner<LLamaSeqId>.Allocate(sequences.Count);
+        sequences.CopyTo(rented.Span);
+        return Add(token, pos, rented.Span, logits);
+
 #endif
     }
diff --git a/LLama/Native/LLamaTokenDataArray.cs b/LLama/Native/LLamaTokenDataArray.cs
index 32d2cccf5..1e4e2a476 100644
--- a/LLama/Native/LLamaTokenDataArray.cs
+++ b/LLama/Native/LLamaTokenDataArray.cs
@@ -2,6 +2,7 @@
 using System.Collections.Generic;
 using System.Numerics.Tensors;
 using System.Runtime.CompilerServices;
+using CommunityToolkit.HighPerformance.Buffers;
 
 namespace LLama.Native
 {
@@ -101,25 +102,19 @@ public void Softmax()
             // Calculate softmax. Using TensorPrimitives is very fast (it uses SIMD etc) and is
             // definitely correct! So just copy to a temp and use that.
-            var tempLogits = ArrayPool<float>.Shared.Rent(data.Length);
-            var tempLogitsSpan = tempLogits.AsSpan(0, data.Length);
-            try
-            {
-                // Copy to temporary
-                for (var i = 0; i < data.Length; i++)
-                    tempLogitsSpan[i] = data[i].Logit;
+            using var tempLogitsSpan = SpanOwner<float>.Allocate(data.Length);
 
-                // Softmax
-                TensorPrimitives.SoftMax(tempLogitsSpan, tempLogitsSpan);
+            // Copy to temporary
+            for (var i = 0; i < data.Length; i++)
+                tempLogitsSpan.Span[i] = data[i].Logit;
+
+            // Softmax
+            TensorPrimitives.SoftMax(tempLogitsSpan.Span, tempLogitsSpan.Span);
+
+            // Copy back
+            for (var i = 0; i < data.Length; i++)
+                data[i].Probability = tempLogitsSpan.Span[i];
 
-                // Copy back
-                for (var i = 0; i < data.Length; i++)
-                    data[i].Probability = tempLogitsSpan[i];
-            }
-            finally
-            {
-                ArrayPool<float>.Shared.Return(tempLogits, true);
-            }
         }
 
         private struct LLamaTokenDataLogitComparerDescending
diff --git a/LLama/Native/NativeApi.Grammar.cs b/LLama/Native/NativeApi.Grammar.cs
index d7397d9cc..450167288 100644
--- a/LLama/Native/NativeApi.Grammar.cs
+++ b/LLama/Native/NativeApi.Grammar.cs
@@ -1,5 +1,3 @@
-using System;
-
 namespace LLama.Native
 {
     public static partial class NativeApi
diff --git a/LLama/Native/SafeLlamaModelHandle.cs b/LLama/Native/SafeLlamaModelHandle.cs
index 0fd39176b..45eaf3e97 100644
--- a/LLama/Native/SafeLlamaModelHandle.cs
+++ b/LLama/Native/SafeLlamaModelHandle.cs
@@ -3,6 +3,7 @@
 using System.Diagnostics;
 using System.IO;
 using System.Text;
+using CommunityToolkit.HighPerformance.Buffers;
 using LLama.Exceptions;
 
 namespace LLama.Native
 {
@@ -247,12 +248,12 @@ private static int llama_model_meta_val_str_by_index(SafeLlamaModelHandle model,
         private static int llama_model_meta_val_str(SafeLlamaModelHandle model, string key, Span<byte> dest)
         {
             var bytesCount = Encoding.UTF8.GetByteCount(key);
-            var bytes = ArrayPool<byte>.Shared.Rent(bytesCount);
+            using var bytes = SpanOwner<byte>.Allocate(bytesCount);
 
             unsafe
             {
                 fixed (char* keyPtr = key)
-                fixed (byte* bytesPtr = bytes)
+                fixed (byte* bytesPtr = bytes.Span)
                 fixed (byte* destPtr = dest)
                 {
                     // Convert text into bytes
@@ -471,34 +472,28 @@ public LLamaToken[] Tokenize(string text, bool addBos, bool special, Encoding encoding)
             // Convert string to bytes, adding one extra byte to the end (null terminator)
             var bytesCount = encoding.GetByteCount(text);
-            var bytes = ArrayPool<byte>.Shared.Rent(bytesCount + 1);
-            try
+            using var bytes = SpanOwner<byte>.Allocate(bytesCount + 1, AllocationMode.Clear);
+
+            unsafe
             {
-                unsafe
+                fixed (char* textPtr = text)
+                fixed (byte* bytesPtr = bytes.Span)
                 {
-                    fixed (char* textPtr = text)
-                    fixed (byte* bytesPtr = bytes)
+                    // Convert text into bytes
+                    encoding.GetBytes(textPtr, text.Length, bytesPtr, bytes.Length);
+
+                    // Tokenize once with no output, to get the token count. Output will be negative (indicating that there was insufficient space)
+                    var count = -NativeApi.llama_tokenize(llama_model_get_vocab(this), bytesPtr, bytesCount, (LLamaToken*)IntPtr.Zero, 0, addBos, special);
+
+                    // Tokenize again, this time outputting into an array of exactly the right size
+                    var tokens = new LLamaToken[count];
+                    fixed (LLamaToken* tokensPtr = tokens)
                     {
-                        // Convert text into bytes
-                        encoding.GetBytes(textPtr, text.Length, bytesPtr, bytes.Length);
-
-                        // Tokenize once with no output, to get the token count. Output will be negative (indicating that there was insufficient space)
-                        var count = -NativeApi.llama_tokenize(llama_model_get_vocab(this), bytesPtr, bytesCount, (LLamaToken*)IntPtr.Zero, 0, addBos, special);
-
-                        // Tokenize again, this time outputting into an array of exactly the right size
-                        var tokens = new LLamaToken[count];
-                        fixed (LLamaToken* tokensPtr = tokens)
-                        {
-                            _ = NativeApi.llama_tokenize(llama_model_get_vocab(this), bytesPtr, bytesCount, tokensPtr, count, addBos, special);
-                            return tokens;
-                        }
+                        _ = NativeApi.llama_tokenize(llama_model_get_vocab(this), bytesPtr, bytesCount, tokensPtr, count, addBos, special);
+                        return tokens;
                     }
                 }
             }
-            finally
-            {
-                ArrayPool<byte>.Shared.Return(bytes, true);
-            }
         }
 
         #endregion
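Two details in the `SafeLlamaModelHandle` changes are worth calling out. First, `Tokenize` now allocates with `AllocationMode.Clear`: pooled arrays can contain stale bytes from earlier rentals, and the old code scrubbed them by calling `Return(bytes, true)`; clearing at allocation time shifts that hygiene from return-time to rent-time and also guarantees the extra terminator byte starts as zero. Second, `fixed` can pin a `Span<T>` directly (spans expose `GetPinnableReference`), which is why `fixed (byte* bytesPtr = bytes.Span)` is a drop-in replacement for pinning the rented array. A small sketch of both, where `Consume` is a hypothetical stand-in for the real native call:

```csharp
using System;
using System.Text;
using CommunityToolkit.HighPerformance.Buffers;

var text = "hello";
var byteCount = Encoding.UTF8.GetByteCount(text);

// AllocationMode.Clear zeroes the rented buffer up front, so the byte after
// the encoded text is a reliable null terminator for C-style consumers.
using var bytes = SpanOwner<byte>.Allocate(byteCount + 1, AllocationMode.Clear);
Encoding.UTF8.GetBytes(text.AsSpan(), bytes.Span);

unsafe
{
    // `fixed` can pin a Span<T> directly, so bytes.Span substitutes cleanly
    // for the old rented byte[] when handing a pointer to native code.
    fixed (byte* ptr = bytes.Span)
    {
        Consume(ptr, byteCount + 1); // stand-in for the real native call
    }
}

static unsafe void Consume(byte* ptr, int length) { /* placeholder */ }
```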
diff --git a/LLama/Sampling/DefaultSamplingPipeline.cs b/LLama/Sampling/DefaultSamplingPipeline.cs
index cd8f57f27..9330764f3 100644
--- a/LLama/Sampling/DefaultSamplingPipeline.cs
+++ b/LLama/Sampling/DefaultSamplingPipeline.cs
@@ -1,5 +1,6 @@
 using System;
 using System.Collections.Generic;
+using CommunityToolkit.HighPerformance.Buffers;
 using LLama.Native;
 
 namespace LLama.Sampling;
@@ -173,27 +174,22 @@ protected override SafeLLamaSamplerChainHandle CreateChain(SafeLLamaContextHandle context)
 
         if (LogitBias.Count > 0)
         {
-            // Rent a temporary array and copy the biases into it
-            var biases = ArrayPool<LLamaLogitBias>.Shared.Rent(LogitBias.Count);
-            try
-            {
-                var index = 0;
-                foreach (var bias in LogitBias)
-                {
-                    biases[index++] = new LLamaLogitBias
-                    {
-                        Token = bias.Key,
-                        Bias = bias.Value
-                    };
-                }
+            using var biases = SpanOwner<LLamaLogitBias>.Allocate(LogitBias.Count);
 
-                // Add the biases to the sampler
-                chain.AddLogitBias(context.Vocab.Count, biases.AsSpan(0, LogitBias.Count));
-            }
-            finally
+            // Copy the biases into it
+            var index = 0;
+            foreach (var bias in LogitBias)
             {
-                ArrayPool<LLamaLogitBias>.Shared.Return(biases);
+                biases.Span[index++] = new LLamaLogitBias
+                {
+                    Token = bias.Key,
+                    Bias = bias.Value
+                };
             }
+
+            // Add the biases to the sampler
+            chain.AddLogitBias(context.Vocab.Count, biases.Span);
         }
 
         chain.AddPenalties(PenaltyCount, RepeatPenalty, FrequencyPenalty, PresencePenalty);
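`Sample` (next hunk) switches to `MemoryOwner<T>` rather than `SpanOwner<T>`, and the distinction matters: `SpanOwner<T>` is a ref struct and can only yield a `Span<T>`, while `LLamaTokenDataArray.Create` and the `LLamaTokenDataArray` constructor take a `Memory<LLamaTokenData>`. `MemoryOwner<T>` is an ordinary class implementing `IMemoryOwner<T>`, so its pool-backed buffer can be exposed as `Memory<T>` and still be released with `using`. A sketch of the difference (types only, not LLamaSharp calls):

```csharp
using System;
using CommunityToolkit.HighPerformance.Buffers;

// SpanOwner<T>: stack-only, hands out a Span<T>. It cannot produce a
// Memory<T>, because the owner must not outlive the current stack frame.
using var spanOwner = SpanOwner<float>.Allocate(16);
Span<float> span = spanOwner.Span;

// MemoryOwner<T>: a heap object (IMemoryOwner<T>), so it also hands out
// Memory<T> and can be stored in fields or passed to Memory-based APIs.
using var memoryOwner = MemoryOwner<float>.Allocate(16);
Memory<float> memory = memoryOwner.Memory;
Span<float> alsoSpan = memoryOwner.Span;
```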
@@ -219,99 +215,82 @@ public override LLamaToken Sample(SafeLLamaContextHandle ctx, int index)
         _grammarChain ??= CreateGrammarChain(ctx);
 
         // Rent some buffers to use later
-        var rentedBufferVocabSizeArr = ArrayPool<LLamaTokenData>.Shared.Rent(ctx.ModelHandle.Vocab.Count);
-        var rentedBufferVocabSize = rentedBufferVocabSizeArr.AsMemory(0, ctx.ModelHandle.Vocab.Count);
-        var rentedBufferSingleItemArr = ArrayPool<LLamaTokenData>.Shared.Rent(1);
-        var rentedBufferSingleItem = rentedBufferSingleItemArr.AsMemory(0, 1);
+        using var bufferVocabSize = MemoryOwner<LLamaTokenData>.Allocate(ctx.ModelHandle.Vocab.Count);
+        using var bufferSingleItem = MemoryOwner<LLamaTokenData>.Allocate(1);
 
-        try
+        // Handle grammar optimization modes
+        if (GrammarOptimization != GrammarOptimizationMode.None)
         {
-            // Handle grammar optimization modes
-            if (GrammarOptimization != GrammarOptimizationMode.None)
+            // Basic optimization: Apply the grammar to the selected token and check if it's valid
+            using (LLamaTokenDataArrayNative.Create(LLamaTokenDataArray.Create(ctx.GetLogitsIth(index), bufferVocabSize.Memory), out var nativeAll))
             {
-                // Basic optimization : Apply the grammar to the selected token and check if it's valid
-                using (LLamaTokenDataArrayNative.Create(LLamaTokenDataArray.Create(ctx.GetLogitsIth(index), rentedBufferVocabSize), out var nativeAll))
-                {
-                    // Apply the chain without the grammar to select one token which may or may not be valid
-                    Apply(ctx, ref nativeAll);
+                // Apply the chain without the grammar to select one token which may or may not be valid
+                Apply(ctx, ref nativeAll);
 
-                    // Select the candidate token
-                    var candidateToken = nativeAll.Data[checked((int)nativeAll.Selected)].ID;
+                // Select the candidate token
+                var candidateToken = nativeAll.Data[checked((int)nativeAll.Selected)].ID;
 
-                    // Now create another token data array with just that one token
-                    rentedBufferSingleItem.Span[0] = new LLamaTokenData(candidateToken, 1, 0);
-                    using (LLamaTokenDataArrayNative.Create(new LLamaTokenDataArray(rentedBufferSingleItem, true), out var nativeSingleCandidate))
-                    {
-                        // Apply the grammar chain to the single candidate
-                        _grammarChain.Apply(ref nativeSingleCandidate);
+                // Now create another token data array with just that one token
+                bufferSingleItem.Span[0] = new LLamaTokenData(candidateToken, 1, 0);
+                using (LLamaTokenDataArrayNative.Create(new LLamaTokenDataArray(bufferSingleItem.Memory, true), out var nativeSingleCandidate))
+                {
+                    // Apply the grammar chain to the single candidate
+                    _grammarChain.Apply(ref nativeSingleCandidate);
 
-                        // Check if the token passes the grammar
-                        if (!float.IsNegativeInfinity(nativeSingleCandidate.Data[0].Logit))
-                        {
-                            Accept(candidateToken);
-                            return candidateToken;
-                        }
+                    // Check if the token passes the grammar
+                    if (!float.IsNegativeInfinity(nativeSingleCandidate.Data[0].Logit))
+                    {
+                        Accept(candidateToken);
+                        return candidateToken;
                     }
+                }
 
-                    // Extended optimization : Apply the grammar to the TopK tokens and check if the selected token is valid
-                    if (GrammarOptimization == GrammarOptimizationMode.Extended)
-                    {
-                        // Calculate a safe TopK value
-                        var safeTopK = Math.Min(TopK, nativeAll.Data.Length);
+                // Extended optimization: Apply the grammar to the TopK tokens and check if the selected token is valid
+                if (GrammarOptimization == GrammarOptimizationMode.Extended)
+                {
+                    // Calculate a safe TopK value
+                    var safeTopK = Math.Min(TopK, nativeAll.Data.Length);
 
-                        // Rent a buffer for the TopK candidates
-                        var rentedBufferTopKArr = ArrayPool<LLamaTokenData>.Shared.Rent(safeTopK);
-                        var rentedBufferTopK = rentedBufferTopKArr.AsMemory(0, safeTopK);
-                        try
-                        {
-                            // Copy only the TopK tokens from the existing candidate pool to the new buffer
-                            nativeAll.Data.Slice(0, safeTopK).CopyTo(rentedBufferTopK.Span);
-
-                            // Create a native array with the TopK tokens
-                            using (LLamaTokenDataArrayNative.Create(new LLamaTokenDataArray(rentedBufferTopK, true), out var nativeTopK))
-                            {
-                                // Apply the grammar chain to the TopK candidates
-                                _grammarChain.Apply(ref nativeTopK);
+                    // Rent a buffer for the TopK candidates
+                    using var bufferTopK = MemoryOwner<LLamaTokenData>.Allocate(safeTopK);
+
+                    // Copy only the TopK tokens from the existing candidate pool to the new buffer
+                    nativeAll.Data.Slice(0, safeTopK).CopyTo(bufferTopK.Span);
+
+                    // Create a native array with the TopK tokens
+                    using (LLamaTokenDataArrayNative.Create(new LLamaTokenDataArray(bufferTopK.Memory, true), out var nativeTopK))
+                    {
+                        // Apply the grammar chain to the TopK candidates
+                        _grammarChain.Apply(ref nativeTopK);
 
-                                // Select the candidate token
-                                var candidateTokenTopK = nativeTopK.Data[checked((int)nativeTopK.Selected)];
+                        // Select the candidate token
+                        var candidateTokenTopK = nativeTopK.Data[checked((int)nativeTopK.Selected)];
 
-                                // Check if the token passes the grammar
-                                if (!float.IsNegativeInfinity(candidateTokenTopK.Logit))
-                                {
-                                    // Accept and return the token
-                                    Accept(candidateTokenTopK.ID);
-                                    return candidateTokenTopK.ID;
-                                }
-                            }
-                        }
-                        finally
+                        // Check if the token passes the grammar
+                        if (!float.IsNegativeInfinity(candidateTokenTopK.Logit))
                         {
-                            ArrayPool<LLamaTokenData>.Shared.Return(rentedBufferTopKArr);
+                            // Accept and return the token
+                            Accept(candidateTokenTopK.ID);
+                            return candidateTokenTopK.ID;
                         }
                     }
                 }
             }
+        }
 
-            // If we get here the grammar rejected the token
-            using (LLamaTokenDataArrayNative.Create(LLamaTokenDataArray.Create(ctx.GetLogitsIth(index), rentedBufferVocabSize), out var nativeAll))
-            {
-                // Apply the grammar _first_. This is slower (since it has to work on the entire vocab), but guaranteed to work
-                _grammarChain.Apply(ref nativeAll);
+        // If we get here the grammar rejected the token
+        using (LLamaTokenDataArrayNative.Create(LLamaTokenDataArray.Create(ctx.GetLogitsIth(index), bufferVocabSize.Memory), out var nativeAll))
+        {
+            // Apply the grammar _first_. This is slower (since it has to work on the entire vocab), but guaranteed to work
+            _grammarChain.Apply(ref nativeAll);
 
-                // Now apply the rest of the pipeline
-                Apply(ctx, ref nativeAll);
+            // Now apply the rest of the pipeline
+            Apply(ctx, ref nativeAll);
 
-                // Take the selected token
-                var token = nativeAll.Data[checked((int)nativeAll.Selected)].ID;
-                Accept(token);
-                return token;
-            }
-        }
-        finally
-        {
-            ArrayPool<LLamaTokenData>.Shared.Return(rentedBufferVocabSizeArr);
-            ArrayPool<LLamaTokenData>.Shared.Return(rentedBufferSingleItemArr);
+            // Take the selected token
+            var token = nativeAll.Data[checked((int)nativeAll.Selected)].ID;
+            Accept(token);
+            return token;
         }
     }
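The rewritten `Sample` keeps the existing three-stage grammar strategy (validate the single sampled token, then the TopK slice in Extended mode, then fall back to filtering the whole vocabulary) but flattens the control flow: the method returns from several nested points, and previously each early return had to unwind two `try`/`finally` frames to release the rented arrays. With `using var` on the `MemoryOwner<T>` buffers, disposal happens automatically on every exit path. A compact illustration of that property, with hypothetical delegates standing in for the real sampling steps:

```csharp
using System;
using CommunityToolkit.HighPerformance.Buffers;

static class EarlyReturnDisposal
{
    // `using var` guarantees the pooled buffer is returned on *every* exit
    // path, replacing the try/finally + ArrayPool.Return bookkeeping the old
    // Sample() carried. tryFastPath and slowPath are illustrative stand-ins.
    public static int Compute(
        Func<Memory<float>, int?> tryFastPath,
        Func<Memory<float>, int> slowPath)
    {
        using var buffer = MemoryOwner<float>.Allocate(1024);

        // Early return: buffer is still disposed (returned to the pool).
        if (tryFastPath(buffer.Memory) is int fast)
            return fast;

        // Fallback path: same guarantee, no finally block required.
        return slowPath(buffer.Memory);
    }
}
```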