From 50a0c1c12e7fa24f923a4537268b06727b2d1773 Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Wed, 26 Feb 2025 23:53:46 +0000 Subject: [PATCH 1/3] Added a new MemoryRental.cs system, to prevent common bugs when using `ArrayPool` --- LLama/Pooling/MemoryRental.cs | 92 +++++++++++++++ LLama/Sampling/DefaultSamplingPipeline.cs | 132 ++++++++++------------ 2 files changed, 150 insertions(+), 74 deletions(-) create mode 100644 LLama/Pooling/MemoryRental.cs diff --git a/LLama/Pooling/MemoryRental.cs b/LLama/Pooling/MemoryRental.cs new file mode 100644 index 000000000..7ca39d45d --- /dev/null +++ b/LLama/Pooling/MemoryRental.cs @@ -0,0 +1,92 @@ +using System; + +namespace LLama.Pooling; + +/// +/// A memory rental which can be stored on the heap +/// +/// +internal readonly struct LongMemoryRental + : IDisposable +{ + public readonly Memory Memory; + private readonly T[] _arr; + + private LongMemoryRental(T[] arr, Memory mem) + { + _arr = arr; + Memory = mem; + } + + /// + /// Borrow a slice of memory which is the given length + /// + /// + /// + public static LongMemoryRental Rent(int length) + { + return Rent(length, out _); + } + + /// + /// Borrow a slice of memory which is the given length + /// + /// + /// + public static LongMemoryRental Rent(int length, out Memory memory) + { + var arr = ArrayPool.Shared.Rent(length); + memory = arr.AsMemory(0, length); + + return new(arr, memory); + } + + public void Dispose() + { + ArrayPool.Shared.Return(_arr); + } +} + +/// +/// A memory rental in a ref struct, cannot be stored on the heap +/// +/// +internal readonly ref struct MemoryRental +{ + public readonly Memory Memory; + private readonly T[] _arr; + + private MemoryRental(T[] arr, Memory mem) + { + _arr = arr; + Memory = mem; + } + + /// + /// Borrow a slice of memory which is the given length + /// + /// + /// + public static MemoryRental Rent(int length) + { + return Rent(length, out _); + } + + /// + /// Borrow a slice of memory which is the given length + /// + /// + /// + public static MemoryRental Rent(int length, out Memory memory) + { + var arr = ArrayPool.Shared.Rent(length); + memory = arr.AsMemory(0, length); + + return new MemoryRental(arr, memory); + } + + public void Dispose() + { + ArrayPool.Shared.Return(_arr); + } +} \ No newline at end of file diff --git a/LLama/Sampling/DefaultSamplingPipeline.cs b/LLama/Sampling/DefaultSamplingPipeline.cs index cd8f57f27..63b52adf9 100644 --- a/LLama/Sampling/DefaultSamplingPipeline.cs +++ b/LLama/Sampling/DefaultSamplingPipeline.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using LLama.Native; +using LLama.Pooling; namespace LLama.Sampling; @@ -219,99 +220,82 @@ public override LLamaToken Sample(SafeLLamaContextHandle ctx, int index) _grammarChain ??= CreateGrammarChain(ctx); // Rent some buffers to use later - var rentedBufferVocabSizeArr = ArrayPool.Shared.Rent(ctx.ModelHandle.Vocab.Count); - var rentedBufferVocabSize = rentedBufferVocabSizeArr.AsMemory(0, ctx.ModelHandle.Vocab.Count); - var rentedBufferSingleItemArr = ArrayPool.Shared.Rent(1); - var rentedBufferSingleItem = rentedBufferSingleItemArr.AsMemory(0, 1); + using var rbvs = MemoryRental.Rent(ctx.ModelHandle.Vocab.Count, out var rentedBufferVocabSize); + using var rbsi = MemoryRental.Rent(1, out var rentedBufferSingleItem); - try + // Handle grammar optimization modes + if (GrammarOptimization != GrammarOptimizationMode.None) { - // Handle grammar optimization modes - if (GrammarOptimization != GrammarOptimizationMode.None) + // Basic optimization : Apply the grammar to the selected token and check if it's valid + using (LLamaTokenDataArrayNative.Create(LLamaTokenDataArray.Create(ctx.GetLogitsIth(index), rentedBufferVocabSize), out var nativeAll)) { - // Basic optimization : Apply the grammar to the selected token and check if it's valid - using (LLamaTokenDataArrayNative.Create(LLamaTokenDataArray.Create(ctx.GetLogitsIth(index), rentedBufferVocabSize), out var nativeAll)) - { - // Apply the chain without the grammar to select one token which may or may not be valid - Apply(ctx, ref nativeAll); + // Apply the chain without the grammar to select one token which may or may not be valid + Apply(ctx, ref nativeAll); - // Select the candidate token - var candidateToken = nativeAll.Data[checked((int)nativeAll.Selected)].ID; + // Select the candidate token + var candidateToken = nativeAll.Data[checked((int)nativeAll.Selected)].ID; - // Now create another token data array with just that one token - rentedBufferSingleItem.Span[0] = new LLamaTokenData(candidateToken, 1, 0); - using (LLamaTokenDataArrayNative.Create(new LLamaTokenDataArray(rentedBufferSingleItem, true), out var nativeSingleCandidate)) - { - // Apply the grammar chain to the single candidate - _grammarChain.Apply(ref nativeSingleCandidate); + // Now create another token data array with just that one token + rentedBufferSingleItem.Span[0] = new LLamaTokenData(candidateToken, 1, 0); + using (LLamaTokenDataArrayNative.Create(new LLamaTokenDataArray(rentedBufferSingleItem, true), out var nativeSingleCandidate)) + { + // Apply the grammar chain to the single candidate + _grammarChain.Apply(ref nativeSingleCandidate); - // Check if the token passes the grammar - if (!float.IsNegativeInfinity(nativeSingleCandidate.Data[0].Logit)) - { - Accept(candidateToken); - return candidateToken; - } + // Check if the token passes the grammar + if (!float.IsNegativeInfinity(nativeSingleCandidate.Data[0].Logit)) + { + Accept(candidateToken); + return candidateToken; } + } - // Extended optimization : Apply the grammar to the TopK tokens and check if the selected token is valid - if (GrammarOptimization == GrammarOptimizationMode.Extended) - { - // Calculate a safe TopK value - var safeTopK = Math.Min(TopK, nativeAll.Data.Length); + // Extended optimization : Apply the grammar to the TopK tokens and check if the selected token is valid + if (GrammarOptimization == GrammarOptimizationMode.Extended) + { + // Calculate a safe TopK value + var safeTopK = Math.Min(TopK, nativeAll.Data.Length); - // Rent a buffer for the TopK candidates - var rentedBufferTopKArr = ArrayPool.Shared.Rent(safeTopK); - var rentedBufferTopK = rentedBufferTopKArr.AsMemory(0, safeTopK); - try - { - // Copy only the TopK tokens from the existing candidate pool to the new buffer - nativeAll.Data.Slice(0, safeTopK).CopyTo(rentedBufferTopK.Span); - - // Create a native array with the TopK tokens - using (LLamaTokenDataArrayNative.Create(new LLamaTokenDataArray(rentedBufferTopK, true), out var nativeTopK)) - { - // Apply the grammar chain to the TopK candidates - _grammarChain.Apply(ref nativeTopK); + // Rent a buffer for the TopK candidates + using var rbtk = MemoryRental.Rent(safeTopK, out var rentedBufferTopK); + + // Copy only the TopK tokens from the existing candidate pool to the new buffer + nativeAll.Data.Slice(0, safeTopK).CopyTo(rentedBufferTopK.Span); + + // Create a native array with the TopK tokens + using (LLamaTokenDataArrayNative.Create(new LLamaTokenDataArray(rentedBufferTopK, true), out var nativeTopK)) + { + // Apply the grammar chain to the TopK candidates + _grammarChain.Apply(ref nativeTopK); - // Select the candidate token - var candidateTokenTopK = nativeTopK.Data[checked((int)nativeTopK.Selected)]; + // Select the candidate token + var candidateTokenTopK = nativeTopK.Data[checked((int)nativeTopK.Selected)]; - // Check if the token passes the grammar - if (!float.IsNegativeInfinity(candidateTokenTopK.Logit)) - { - // Accept and return the token - Accept(candidateTokenTopK.ID); - return candidateTokenTopK.ID; - } - } - } - finally + // Check if the token passes the grammar + if (!float.IsNegativeInfinity(candidateTokenTopK.Logit)) { - ArrayPool.Shared.Return(rentedBufferTopKArr); + // Accept and return the token + Accept(candidateTokenTopK.ID); + return candidateTokenTopK.ID; } } } } + } - // If we get here the grammar rejected the token - using (LLamaTokenDataArrayNative.Create(LLamaTokenDataArray.Create(ctx.GetLogitsIth(index), rentedBufferVocabSize), out var nativeAll)) - { - // Apply the grammar _first_. This is slower (since it has to work on the entire vocab), but guaranteed to work - _grammarChain.Apply(ref nativeAll); + // If we get here the grammar rejected the token + using (LLamaTokenDataArrayNative.Create(LLamaTokenDataArray.Create(ctx.GetLogitsIth(index), rentedBufferVocabSize), out var nativeAll)) + { + // Apply the grammar _first_. This is slower (since it has to work on the entire vocab), but guaranteed to work + _grammarChain.Apply(ref nativeAll); - // Now apply the rest of the pipeline - Apply(ctx, ref nativeAll); + // Now apply the rest of the pipeline + Apply(ctx, ref nativeAll); - // Take the selected token - var token = nativeAll.Data[checked((int)nativeAll.Selected)].ID; - Accept(token); - return token; - } - } - finally - { - ArrayPool.Shared.Return(rentedBufferVocabSizeArr); - ArrayPool.Shared.Return(rentedBufferSingleItemArr); + // Take the selected token + var token = nativeAll.Data[checked((int)nativeAll.Selected)].ID; + Accept(token); + return token; } } From 8d3aa04f26b9ef3fceacea36de54a420dc9b8041 Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Thu, 27 Feb 2025 00:13:43 +0000 Subject: [PATCH 2/3] Replaced most uses of ArrayPool with `SpanRental`/`MemoryRental` --- LLama/Batched/Conversation.cs | 18 ++++----- LLama/Extensions/IReadOnlyListExtensions.cs | 25 +++++-------- LLama/Extensions/ListExtensions.cs | 13 ++++++- LLama/Native/LLamaBatch.cs | 15 +++----- LLama/Native/LLamaTokenDataArray.cs | 29 ++++++--------- LLama/Native/SafeLlamaModelHandle.cs | 41 +++++++++------------ LLama/Pooling/MemoryRental.cs | 37 +++++++++++-------- LLama/Sampling/DefaultSamplingPipeline.cs | 32 +++++++--------- 8 files changed, 100 insertions(+), 110 deletions(-) diff --git a/LLama/Batched/Conversation.cs b/LLama/Batched/Conversation.cs index c9a374549..a50d30467 100644 --- a/LLama/Batched/Conversation.cs +++ b/LLama/Batched/Conversation.cs @@ -3,6 +3,7 @@ using System.Linq; using System.Text.Json; using LLama.Native; +using LLama.Pooling; namespace LLama.Batched; @@ -224,18 +225,13 @@ public void Prompt(List tokens, bool allLogits = false) Prompt(span, allLogits); #else // Borrow an array and copy tokens into it - var arr = ArrayPool.Shared.Rent(tokens.Count); - try - { - for (var i = 0; i < tokens.Count; i++) - arr[i] = tokens[i]; + using var span = SpanRental.Rent(tokens.Count); + + for (var i = 0; i < tokens.Count; i++) + span.Span[i] = tokens[i]; + + Prompt(span.Span); - Prompt(arr.AsSpan()); - } - finally - { - ArrayPool.Shared.Return(arr); - } #endif } diff --git a/LLama/Extensions/IReadOnlyListExtensions.cs b/LLama/Extensions/IReadOnlyListExtensions.cs index afbd97a90..f4fdf3ab3 100644 --- a/LLama/Extensions/IReadOnlyListExtensions.cs +++ b/LLama/Extensions/IReadOnlyListExtensions.cs @@ -3,6 +3,7 @@ using System.Collections.Generic; using System.Text; using LLama.Native; +using LLama.Pooling; namespace LLama.Extensions { @@ -49,23 +50,17 @@ internal static bool TokensEndsWithAnyString(this TTokens tok longest = Math.Max(longest, candidate.Length); // Rent an array to detokenize into - var builderArray = ArrayPool.Shared.Rent(longest); - try - { - // Convert as many tokens as possible into the builderArray - var characters = model.TokensToSpan(tokens, builderArray.AsSpan(0, longest), encoding); + using var builderArray = SpanRental.Rent(longest); - // Check every query to see if it's present - foreach (var query in queries) - if (characters.EndsWith(query.AsSpan())) - return true; + // Convert as many tokens as possible into the builderArray + var characters = model.TokensToSpan(tokens, builderArray.Span, encoding); - return false; - } - finally - { - ArrayPool.Shared.Return(builderArray); - } + // Check every query to see if it's present + foreach (var query in queries) + if (characters.EndsWith(query.AsSpan())) + return true; + + return false; } /// diff --git a/LLama/Extensions/ListExtensions.cs b/LLama/Extensions/ListExtensions.cs index eb30a07a0..0261bcd52 100644 --- a/LLama/Extensions/ListExtensions.cs +++ b/LLama/Extensions/ListExtensions.cs @@ -1,4 +1,4 @@ -using System; +using System; using System.Collections.Generic; namespace LLama.Extensions @@ -20,5 +20,16 @@ public static void AddSpan(this List list, ReadOnlySpan items) for (var i = 0; i < items.Length; i++) list.Add(items[i]); } + +#if !NET6_0_OR_GREATER + public static void CopyTo(this List list, Span dest) + { + if (dest.Length < list.Count) + throw new ArgumentException($"dest is too small ({dest.Length} < {list.Count})"); + + for (var i = 0; i < list.Count; i++) + dest[i] = list[i]; + } +#endif } } diff --git a/LLama/Native/LLamaBatch.cs b/LLama/Native/LLamaBatch.cs index d58b60a90..cce8a14fb 100644 --- a/LLama/Native/LLamaBatch.cs +++ b/LLama/Native/LLamaBatch.cs @@ -2,6 +2,7 @@ using System.Collections.Generic; using System.Diagnostics; using System.Linq; +using LLama.Pooling; namespace LLama.Native; @@ -205,16 +206,10 @@ public int Add(LLamaToken token, LLamaPos pos, List sequences, bool // the list. Instead rent an array and copy the data into it. This avoids an allocation, but can't // avoid the copying. - var rented = ArrayPool.Shared.Rent(sequences.Count); - try - { - sequences.CopyTo(rented, 0); - return Add(token, pos, rented.AsSpan(0, sequences.Count), logits); - } - finally - { - ArrayPool.Shared.Return(rented); - } + var rented = SpanRental.Rent(sequences.Count); + sequences.CopyTo(rented.Span); + return Add(token, pos, rented.Span, logits); + #endif } diff --git a/LLama/Native/LLamaTokenDataArray.cs b/LLama/Native/LLamaTokenDataArray.cs index 32d2cccf5..285cda0e2 100644 --- a/LLama/Native/LLamaTokenDataArray.cs +++ b/LLama/Native/LLamaTokenDataArray.cs @@ -2,6 +2,7 @@ using System.Collections.Generic; using System.Numerics.Tensors; using System.Runtime.CompilerServices; +using LLama.Pooling; namespace LLama.Native { @@ -101,25 +102,19 @@ public void Softmax() // Calculate softmax. Using TensorPrimitives is very fast (it uses SIMD etc) and is // definitely correct! So just copy to a temp and use that. - var tempLogits = ArrayPool.Shared.Rent(data.Length); - var tempLogitsSpan = tempLogits.AsSpan(0, data.Length); - try - { - // Copy to temporary - for (var i = 0; i < data.Length; i++) - tempLogitsSpan[i] = data[i].Logit; + using var rental = SpanRental.Rent(data.Length, out var tempLogitsSpan); - // Softmax - TensorPrimitives.SoftMax(tempLogitsSpan, tempLogitsSpan); + // Copy to temporary + for (var i = 0; i < data.Length; i++) + tempLogitsSpan[i] = data[i].Logit; + + // Softmax + TensorPrimitives.SoftMax(tempLogitsSpan, tempLogitsSpan); + + // Copy back + for (var i = 0; i < data.Length; i++) + data[i].Probability = tempLogitsSpan[i]; - // Copy back - for (var i = 0; i < data.Length; i++) - data[i].Probability = tempLogitsSpan[i]; - } - finally - { - ArrayPool.Shared.Return(tempLogits, true); - } } private struct LLamaTokenDataLogitComparerDescending diff --git a/LLama/Native/SafeLlamaModelHandle.cs b/LLama/Native/SafeLlamaModelHandle.cs index 9439c2bb3..09663f5dd 100644 --- a/LLama/Native/SafeLlamaModelHandle.cs +++ b/LLama/Native/SafeLlamaModelHandle.cs @@ -4,6 +4,7 @@ using System.IO; using System.Text; using LLama.Exceptions; +using LLama.Pooling; namespace LLama.Native { @@ -247,7 +248,7 @@ private static int llama_model_meta_val_str_by_index(SafeLlamaModelHandle model, private static int llama_model_meta_val_str(SafeLlamaModelHandle model, string key, Span dest) { var bytesCount = Encoding.UTF8.GetByteCount(key); - var bytes = ArrayPool.Shared.Rent(bytesCount); + using var rental = SpanRental.Rent(bytesCount, out var bytes); unsafe { @@ -471,34 +472,28 @@ public LLamaToken[] Tokenize(string text, bool addBos, bool special, Encoding en // Convert string to bytes, adding one extra byte to the end (null terminator) var bytesCount = encoding.GetByteCount(text); - var bytes = ArrayPool.Shared.Rent(bytesCount + 1); - try + using var rental = SpanRental.Rent(bytesCount + 1, out var bytes, clear:true); + + unsafe { - unsafe + fixed (char* textPtr = text) + fixed (byte* bytesPtr = bytes) { - fixed (char* textPtr = text) - fixed (byte* bytesPtr = bytes) + // Convert text into bytes + encoding.GetBytes(textPtr, text.Length, bytesPtr, bytes.Length); + + // Tokenize once with no output, to get the token count. Output will be negative (indicating that there was insufficient space) + var count = -NativeApi.llama_tokenize(llama_model_get_vocab(this), bytesPtr, bytesCount, (LLamaToken*)IntPtr.Zero, 0, addBos, special); + + // Tokenize again, this time outputting into an array of exactly the right size + var tokens = new LLamaToken[count]; + fixed (LLamaToken* tokensPtr = tokens) { - // Convert text into bytes - encoding.GetBytes(textPtr, text.Length, bytesPtr, bytes.Length); - - // Tokenize once with no output, to get the token count. Output will be negative (indicating that there was insufficient space) - var count = -NativeApi.llama_tokenize(llama_model_get_vocab(this), bytesPtr, bytesCount, (LLamaToken*)IntPtr.Zero, 0, addBos, special); - - // Tokenize again, this time outputting into an array of exactly the right size - var tokens = new LLamaToken[count]; - fixed (LLamaToken* tokensPtr = tokens) - { - _ = NativeApi.llama_tokenize(llama_model_get_vocab(this), bytesPtr, bytesCount, tokensPtr, count, addBos, special); - return tokens; - } + _ = NativeApi.llama_tokenize(llama_model_get_vocab(this), bytesPtr, bytesCount, tokensPtr, count, addBos, special); + return tokens; } } } - finally - { - ArrayPool.Shared.Return(bytes, true); - } } #endregion diff --git a/LLama/Pooling/MemoryRental.cs b/LLama/Pooling/MemoryRental.cs index 7ca39d45d..43e706d97 100644 --- a/LLama/Pooling/MemoryRental.cs +++ b/LLama/Pooling/MemoryRental.cs @@ -6,13 +6,13 @@ namespace LLama.Pooling; /// A memory rental which can be stored on the heap /// /// -internal readonly struct LongMemoryRental +internal readonly struct MemoryRental : IDisposable { public readonly Memory Memory; private readonly T[] _arr; - private LongMemoryRental(T[] arr, Memory mem) + private MemoryRental(T[] arr, Memory mem) { _arr = arr; Memory = mem; @@ -23,7 +23,7 @@ private LongMemoryRental(T[] arr, Memory mem) /// /// /// - public static LongMemoryRental Rent(int length) + public static MemoryRental Rent(int length) { return Rent(length, out _); } @@ -32,8 +32,9 @@ public static LongMemoryRental Rent(int length) /// Borrow a slice of memory which is the given length /// /// + /// /// - public static LongMemoryRental Rent(int length, out Memory memory) + public static MemoryRental Rent(int length, out Memory memory) { var arr = ArrayPool.Shared.Rent(length); memory = arr.AsMemory(0, length); @@ -48,45 +49,51 @@ public void Dispose() } /// -/// A memory rental in a ref struct, cannot be stored on the heap +/// A rented span, cannot be stored on the heap. Use if the +/// rental must be stored on the heap /// /// -internal readonly ref struct MemoryRental +internal readonly ref struct SpanRental { - public readonly Memory Memory; + public readonly Span Span; + private readonly bool _clear; private readonly T[] _arr; - private MemoryRental(T[] arr, Memory mem) + private SpanRental(T[] arr, Span span, bool clear) { _arr = arr; - Memory = mem; + Span = span; + _clear = clear; } /// /// Borrow a slice of memory which is the given length /// /// + /// /// - public static MemoryRental Rent(int length) + public static SpanRental Rent(int length, bool clear = false) { - return Rent(length, out _); + return Rent(length, out _, clear); } /// /// Borrow a slice of memory which is the given length /// /// + /// + /// /// - public static MemoryRental Rent(int length, out Memory memory) + public static SpanRental Rent(int length, out Span span, bool clear = false) { var arr = ArrayPool.Shared.Rent(length); - memory = arr.AsMemory(0, length); + span = arr.AsSpan(0, length); - return new MemoryRental(arr, memory); + return new SpanRental(arr, span, clear); } public void Dispose() { - ArrayPool.Shared.Return(_arr); + ArrayPool.Shared.Return(_arr, _clear); } } \ No newline at end of file diff --git a/LLama/Sampling/DefaultSamplingPipeline.cs b/LLama/Sampling/DefaultSamplingPipeline.cs index 63b52adf9..20f2e63ca 100644 --- a/LLama/Sampling/DefaultSamplingPipeline.cs +++ b/LLama/Sampling/DefaultSamplingPipeline.cs @@ -174,27 +174,23 @@ protected override SafeLLamaSamplerChainHandle CreateChain(SafeLLamaContextHandl if (LogitBias.Count > 0) { - // Rent a temporary array and copy the biases into it - var biases = ArrayPool.Shared.Rent(LogitBias.Count); - try - { - var index = 0; - foreach (var bias in LogitBias) - { - biases[index++] = new LLamaLogitBias - { - Token = bias.Key, - Bias = bias.Value - }; - } + // Rent a temporary array + using var rental = SpanRental.Rent(LogitBias.Count, out var biases); - // Add the biases to the sampler - chain.AddLogitBias(context.Vocab.Count, biases.AsSpan(0, LogitBias.Count)); - } - finally + // copy the biases into it + var index = 0; + foreach (var bias in LogitBias) { - ArrayPool.Shared.Return(biases); + biases[index++] = new LLamaLogitBias + { + Token = bias.Key, + Bias = bias.Value + }; } + + // Add the biases to the sampler + chain.AddLogitBias(context.Vocab.Count, biases); + } chain.AddPenalties(PenaltyCount, RepeatPenalty, FrequencyPenalty, PresencePenalty); From 5581a7fdc9d18edaa41ed989af0db5e8d9c7a31e Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Sun, 9 Mar 2025 00:16:31 +0000 Subject: [PATCH 3/3] Using `CommunityToolkit.HighPerformance` instead of a custom buffer renting system --- LLama/Batched/Conversation.cs | 4 +- LLama/Extensions/IReadOnlyListExtensions.cs | 4 +- LLama/LLamaSharp.csproj | 1 + LLama/Native/LLamaBatch.cs | 4 +- LLama/Native/LLamaTokenDataArray.cs | 10 +-- LLama/Native/NativeApi.Grammar.cs | 2 - LLama/Native/SafeLlamaModelHandle.cs | 10 +-- LLama/Pooling/MemoryRental.cs | 99 --------------------- LLama/Sampling/DefaultSamplingPipeline.cs | 27 +++--- 9 files changed, 30 insertions(+), 131 deletions(-) delete mode 100644 LLama/Pooling/MemoryRental.cs diff --git a/LLama/Batched/Conversation.cs b/LLama/Batched/Conversation.cs index a50d30467..b816fe778 100644 --- a/LLama/Batched/Conversation.cs +++ b/LLama/Batched/Conversation.cs @@ -2,8 +2,8 @@ using System.Collections.Generic; using System.Linq; using System.Text.Json; +using CommunityToolkit.HighPerformance.Buffers; using LLama.Native; -using LLama.Pooling; namespace LLama.Batched; @@ -225,7 +225,7 @@ public void Prompt(List tokens, bool allLogits = false) Prompt(span, allLogits); #else // Borrow an array and copy tokens into it - using var span = SpanRental.Rent(tokens.Count); + using var span = SpanOwner.Allocate(tokens.Count); for (var i = 0; i < tokens.Count; i++) span.Span[i] = tokens[i]; diff --git a/LLama/Extensions/IReadOnlyListExtensions.cs b/LLama/Extensions/IReadOnlyListExtensions.cs index f4fdf3ab3..c87239df7 100644 --- a/LLama/Extensions/IReadOnlyListExtensions.cs +++ b/LLama/Extensions/IReadOnlyListExtensions.cs @@ -2,8 +2,8 @@ using System.Collections; using System.Collections.Generic; using System.Text; +using CommunityToolkit.HighPerformance.Buffers; using LLama.Native; -using LLama.Pooling; namespace LLama.Extensions { @@ -50,7 +50,7 @@ internal static bool TokensEndsWithAnyString(this TTokens tok longest = Math.Max(longest, candidate.Length); // Rent an array to detokenize into - using var builderArray = SpanRental.Rent(longest); + using var builderArray = SpanOwner.Allocate(longest); // Convert as many tokens as possible into the builderArray var characters = model.TokensToSpan(tokens, builderArray.Span, encoding); diff --git a/LLama/LLamaSharp.csproj b/LLama/LLamaSharp.csproj index 86ed11c10..a3b45c4f5 100644 --- a/LLama/LLamaSharp.csproj +++ b/LLama/LLamaSharp.csproj @@ -49,6 +49,7 @@ + diff --git a/LLama/Native/LLamaBatch.cs b/LLama/Native/LLamaBatch.cs index cce8a14fb..2f1f3f54a 100644 --- a/LLama/Native/LLamaBatch.cs +++ b/LLama/Native/LLamaBatch.cs @@ -2,7 +2,7 @@ using System.Collections.Generic; using System.Diagnostics; using System.Linq; -using LLama.Pooling; +using CommunityToolkit.HighPerformance.Buffers; namespace LLama.Native; @@ -206,7 +206,7 @@ public int Add(LLamaToken token, LLamaPos pos, List sequences, bool // the list. Instead rent an array and copy the data into it. This avoids an allocation, but can't // avoid the copying. - var rented = SpanRental.Rent(sequences.Count); + using var rented = SpanOwner.Allocate(sequences.Count); sequences.CopyTo(rented.Span); return Add(token, pos, rented.Span, logits); diff --git a/LLama/Native/LLamaTokenDataArray.cs b/LLama/Native/LLamaTokenDataArray.cs index 285cda0e2..1e4e2a476 100644 --- a/LLama/Native/LLamaTokenDataArray.cs +++ b/LLama/Native/LLamaTokenDataArray.cs @@ -2,7 +2,7 @@ using System.Collections.Generic; using System.Numerics.Tensors; using System.Runtime.CompilerServices; -using LLama.Pooling; +using CommunityToolkit.HighPerformance.Buffers; namespace LLama.Native { @@ -102,18 +102,18 @@ public void Softmax() // Calculate softmax. Using TensorPrimitives is very fast (it uses SIMD etc) and is // definitely correct! So just copy to a temp and use that. - using var rental = SpanRental.Rent(data.Length, out var tempLogitsSpan); + using var tempLogitsSpan = SpanOwner.Allocate(data.Length); // Copy to temporary for (var i = 0; i < data.Length; i++) - tempLogitsSpan[i] = data[i].Logit; + tempLogitsSpan.Span[i] = data[i].Logit; // Softmax - TensorPrimitives.SoftMax(tempLogitsSpan, tempLogitsSpan); + TensorPrimitives.SoftMax(tempLogitsSpan.Span, tempLogitsSpan.Span); // Copy back for (var i = 0; i < data.Length; i++) - data[i].Probability = tempLogitsSpan[i]; + data[i].Probability = tempLogitsSpan.Span[i]; } diff --git a/LLama/Native/NativeApi.Grammar.cs b/LLama/Native/NativeApi.Grammar.cs index d7397d9cc..450167288 100644 --- a/LLama/Native/NativeApi.Grammar.cs +++ b/LLama/Native/NativeApi.Grammar.cs @@ -1,5 +1,3 @@ -using System; - namespace LLama.Native { public static partial class NativeApi diff --git a/LLama/Native/SafeLlamaModelHandle.cs b/LLama/Native/SafeLlamaModelHandle.cs index 09663f5dd..4156fe355 100644 --- a/LLama/Native/SafeLlamaModelHandle.cs +++ b/LLama/Native/SafeLlamaModelHandle.cs @@ -3,8 +3,8 @@ using System.Diagnostics; using System.IO; using System.Text; +using CommunityToolkit.HighPerformance.Buffers; using LLama.Exceptions; -using LLama.Pooling; namespace LLama.Native { @@ -248,12 +248,12 @@ private static int llama_model_meta_val_str_by_index(SafeLlamaModelHandle model, private static int llama_model_meta_val_str(SafeLlamaModelHandle model, string key, Span dest) { var bytesCount = Encoding.UTF8.GetByteCount(key); - using var rental = SpanRental.Rent(bytesCount, out var bytes); + using var bytes = SpanOwner.Allocate(bytesCount); unsafe { fixed (char* keyPtr = key) - fixed (byte* bytesPtr = bytes) + fixed (byte* bytesPtr = bytes.Span) fixed (byte* destPtr = dest) { // Convert text into bytes @@ -472,12 +472,12 @@ public LLamaToken[] Tokenize(string text, bool addBos, bool special, Encoding en // Convert string to bytes, adding one extra byte to the end (null terminator) var bytesCount = encoding.GetByteCount(text); - using var rental = SpanRental.Rent(bytesCount + 1, out var bytes, clear:true); + using var bytes = SpanOwner.Allocate(bytesCount + 1, AllocationMode.Clear); unsafe { fixed (char* textPtr = text) - fixed (byte* bytesPtr = bytes) + fixed (byte* bytesPtr = bytes.Span) { // Convert text into bytes encoding.GetBytes(textPtr, text.Length, bytesPtr, bytes.Length); diff --git a/LLama/Pooling/MemoryRental.cs b/LLama/Pooling/MemoryRental.cs deleted file mode 100644 index 43e706d97..000000000 --- a/LLama/Pooling/MemoryRental.cs +++ /dev/null @@ -1,99 +0,0 @@ -using System; - -namespace LLama.Pooling; - -/// -/// A memory rental which can be stored on the heap -/// -/// -internal readonly struct MemoryRental - : IDisposable -{ - public readonly Memory Memory; - private readonly T[] _arr; - - private MemoryRental(T[] arr, Memory mem) - { - _arr = arr; - Memory = mem; - } - - /// - /// Borrow a slice of memory which is the given length - /// - /// - /// - public static MemoryRental Rent(int length) - { - return Rent(length, out _); - } - - /// - /// Borrow a slice of memory which is the given length - /// - /// - /// - /// - public static MemoryRental Rent(int length, out Memory memory) - { - var arr = ArrayPool.Shared.Rent(length); - memory = arr.AsMemory(0, length); - - return new(arr, memory); - } - - public void Dispose() - { - ArrayPool.Shared.Return(_arr); - } -} - -/// -/// A rented span, cannot be stored on the heap. Use if the -/// rental must be stored on the heap -/// -/// -internal readonly ref struct SpanRental -{ - public readonly Span Span; - private readonly bool _clear; - private readonly T[] _arr; - - private SpanRental(T[] arr, Span span, bool clear) - { - _arr = arr; - Span = span; - _clear = clear; - } - - /// - /// Borrow a slice of memory which is the given length - /// - /// - /// - /// - public static SpanRental Rent(int length, bool clear = false) - { - return Rent(length, out _, clear); - } - - /// - /// Borrow a slice of memory which is the given length - /// - /// - /// - /// - /// - public static SpanRental Rent(int length, out Span span, bool clear = false) - { - var arr = ArrayPool.Shared.Rent(length); - span = arr.AsSpan(0, length); - - return new SpanRental(arr, span, clear); - } - - public void Dispose() - { - ArrayPool.Shared.Return(_arr, _clear); - } -} \ No newline at end of file diff --git a/LLama/Sampling/DefaultSamplingPipeline.cs b/LLama/Sampling/DefaultSamplingPipeline.cs index 20f2e63ca..9330764f3 100644 --- a/LLama/Sampling/DefaultSamplingPipeline.cs +++ b/LLama/Sampling/DefaultSamplingPipeline.cs @@ -1,7 +1,7 @@ using System; using System.Collections.Generic; +using CommunityToolkit.HighPerformance.Buffers; using LLama.Native; -using LLama.Pooling; namespace LLama.Sampling; @@ -174,14 +174,13 @@ protected override SafeLLamaSamplerChainHandle CreateChain(SafeLLamaContextHandl if (LogitBias.Count > 0) { - // Rent a temporary array - using var rental = SpanRental.Rent(LogitBias.Count, out var biases); + using var biases = SpanOwner.Allocate(LogitBias.Count); // copy the biases into it var index = 0; foreach (var bias in LogitBias) { - biases[index++] = new LLamaLogitBias + biases.Span[index++] = new LLamaLogitBias { Token = bias.Key, Bias = bias.Value @@ -189,7 +188,7 @@ protected override SafeLLamaSamplerChainHandle CreateChain(SafeLLamaContextHandl } // Add the biases to the sampler - chain.AddLogitBias(context.Vocab.Count, biases); + chain.AddLogitBias(context.Vocab.Count, biases.Span); } @@ -216,14 +215,14 @@ public override LLamaToken Sample(SafeLLamaContextHandle ctx, int index) _grammarChain ??= CreateGrammarChain(ctx); // Rent some buffers to use later - using var rbvs = MemoryRental.Rent(ctx.ModelHandle.Vocab.Count, out var rentedBufferVocabSize); - using var rbsi = MemoryRental.Rent(1, out var rentedBufferSingleItem); + using var bufferVocabSize = MemoryOwner.Allocate(ctx.ModelHandle.Vocab.Count); + using var bufferSingleItem = MemoryOwner.Allocate(1); // Handle grammar optimization modes if (GrammarOptimization != GrammarOptimizationMode.None) { // Basic optimization : Apply the grammar to the selected token and check if it's valid - using (LLamaTokenDataArrayNative.Create(LLamaTokenDataArray.Create(ctx.GetLogitsIth(index), rentedBufferVocabSize), out var nativeAll)) + using (LLamaTokenDataArrayNative.Create(LLamaTokenDataArray.Create(ctx.GetLogitsIth(index), bufferVocabSize.Memory), out var nativeAll)) { // Apply the chain without the grammar to select one token which may or may not be valid Apply(ctx, ref nativeAll); @@ -232,8 +231,8 @@ public override LLamaToken Sample(SafeLLamaContextHandle ctx, int index) var candidateToken = nativeAll.Data[checked((int)nativeAll.Selected)].ID; // Now create another token data array with just that one token - rentedBufferSingleItem.Span[0] = new LLamaTokenData(candidateToken, 1, 0); - using (LLamaTokenDataArrayNative.Create(new LLamaTokenDataArray(rentedBufferSingleItem, true), out var nativeSingleCandidate)) + bufferSingleItem.Span[0] = new LLamaTokenData(candidateToken, 1, 0); + using (LLamaTokenDataArrayNative.Create(new LLamaTokenDataArray(bufferSingleItem.Memory, true), out var nativeSingleCandidate)) { // Apply the grammar chain to the single candidate _grammarChain.Apply(ref nativeSingleCandidate); @@ -253,13 +252,13 @@ public override LLamaToken Sample(SafeLLamaContextHandle ctx, int index) var safeTopK = Math.Min(TopK, nativeAll.Data.Length); // Rent a buffer for the TopK candidates - using var rbtk = MemoryRental.Rent(safeTopK, out var rentedBufferTopK); + using var bufferTopK = MemoryOwner.Allocate(safeTopK); // Copy only the TopK tokens from the existing candidate pool to the new buffer - nativeAll.Data.Slice(0, safeTopK).CopyTo(rentedBufferTopK.Span); + nativeAll.Data.Slice(0, safeTopK).CopyTo(bufferTopK.Span); // Create a native array with the TopK tokens - using (LLamaTokenDataArrayNative.Create(new LLamaTokenDataArray(rentedBufferTopK, true), out var nativeTopK)) + using (LLamaTokenDataArrayNative.Create(new LLamaTokenDataArray(bufferTopK.Memory, true), out var nativeTopK)) { // Apply the grammar chain to the TopK candidates _grammarChain.Apply(ref nativeTopK); @@ -280,7 +279,7 @@ public override LLamaToken Sample(SafeLLamaContextHandle ctx, int index) } // If we get here the grammar rejected the token - using (LLamaTokenDataArrayNative.Create(LLamaTokenDataArray.Create(ctx.GetLogitsIth(index), rentedBufferVocabSize), out var nativeAll)) + using (LLamaTokenDataArrayNative.Create(LLamaTokenDataArray.Create(ctx.GetLogitsIth(index), bufferVocabSize.Memory), out var nativeAll)) { // Apply the grammar _first_. This is slower (since it has to work on the entire vocab), but guaranteed to work _grammarChain.Apply(ref nativeAll);