Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 7 additions & 11 deletions LLama/Batched/Conversation.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using System.Collections.Generic;
using System.Linq;
using System.Text.Json;
using CommunityToolkit.HighPerformance.Buffers;
using LLama.Native;

namespace LLama.Batched;
Expand Down Expand Up @@ -224,18 +225,13 @@ public void Prompt(List<LLamaToken> tokens, bool allLogits = false)
Prompt(span, allLogits);
#else
// Borrow an array and copy tokens into it
var arr = ArrayPool<LLamaToken>.Shared.Rent(tokens.Count);
try
{
for (var i = 0; i < tokens.Count; i++)
arr[i] = tokens[i];
using var span = SpanOwner<LLamaToken>.Allocate(tokens.Count);

for (var i = 0; i < tokens.Count; i++)
span.Span[i] = tokens[i];

Prompt(span.Span);

Prompt(arr.AsSpan());
}
finally
{
ArrayPool<LLamaToken>.Shared.Return(arr);
}
#endif
}

Expand Down
25 changes: 10 additions & 15 deletions LLama/Extensions/IReadOnlyListExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using System.Collections;
using System.Collections.Generic;
using System.Text;
using CommunityToolkit.HighPerformance.Buffers;
using LLama.Native;

namespace LLama.Extensions
Expand Down Expand Up @@ -49,23 +50,17 @@ internal static bool TokensEndsWithAnyString<TTokens, TQueries>(this TTokens tok
longest = Math.Max(longest, candidate.Length);

// Rent an array to detokenize into
var builderArray = ArrayPool<char>.Shared.Rent(longest);
try
{
// Convert as many tokens as possible into the builderArray
var characters = model.TokensToSpan(tokens, builderArray.AsSpan(0, longest), encoding);
using var builderArray = SpanOwner<char>.Allocate(longest);

// Check every query to see if it's present
foreach (var query in queries)
if (characters.EndsWith(query.AsSpan()))
return true;
// Convert as many tokens as possible into the builderArray
var characters = model.TokensToSpan(tokens, builderArray.Span, encoding);

return false;
}
finally
{
ArrayPool<char>.Shared.Return(builderArray);
}
// Check every query to see if it's present
foreach (var query in queries)
if (characters.EndsWith(query.AsSpan()))
return true;

return false;
}

/// <summary>
Expand Down
13 changes: 12 additions & 1 deletion LLama/Extensions/ListExtensions.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using System;
using System;
using System.Collections.Generic;

namespace LLama.Extensions
Expand All @@ -20,5 +20,16 @@ public static void AddSpan<T>(this List<T> list, ReadOnlySpan<T> items)
for (var i = 0; i < items.Length; i++)
list.Add(items[i]);
}

#if !NET6_0_OR_GREATER
/// <summary>
/// Copy every item of <paramref name="list"/> into the start of <paramref name="dest"/>.
/// Items of <paramref name="dest"/> beyond <c>list.Count</c> are left untouched.
/// </summary>
/// <typeparam name="T">Type of the items being copied</typeparam>
/// <param name="list">List to copy items from</param>
/// <param name="dest">Span to copy items into, must hold at least <c>list.Count</c> items</param>
/// <exception cref="ArgumentException">Thrown if <paramref name="dest"/> is shorter than <paramref name="list"/></exception>
public static void CopyTo<T>(this List<T> list, Span<T> dest)
{
    var count = list.Count;

    // Validate before writing anything, so a failed call never leaves a partial copy behind
    if (dest.Length < count)
        throw new ArgumentException($"dest is too small ({dest.Length} < {count})", nameof(dest));

    for (var i = 0; i < count; i++)
        dest[i] = list[i];
}
#endif
}
}
1 change: 1 addition & 0 deletions LLama/LLamaSharp.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
</ItemGroup>

<ItemGroup>
<PackageReference Include="CommunityToolkit.HighPerformance" Version="8.4.0" />
<PackageReference Include="Microsoft.Bcl.AsyncInterfaces" Version="9.0.2" />
<PackageReference Include="Microsoft.Extensions.AI.Abstractions" Version="9.3.0-preview.1.25114.11" />
<PackageReference Include="Microsoft.Extensions.Logging.Abstractions" Version="9.0.0" />
Expand Down
15 changes: 5 additions & 10 deletions LLama/Native/LLamaBatch.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using CommunityToolkit.HighPerformance.Buffers;

namespace LLama.Native;

Expand Down Expand Up @@ -205,16 +206,10 @@ public int Add(LLamaToken token, LLamaPos pos, List<LLamaSeqId> sequences, bool
// the list. Instead rent an array and copy the data into it. This avoids an allocation, but can't
// avoid the copying.

var rented = ArrayPool<LLamaSeqId>.Shared.Rent(sequences.Count);
try
{
sequences.CopyTo(rented, 0);
return Add(token, pos, rented.AsSpan(0, sequences.Count), logits);
}
finally
{
ArrayPool<LLamaSeqId>.Shared.Return(rented);
}
using var rented = SpanOwner<LLamaSeqId>.Allocate(sequences.Count);
sequences.CopyTo(rented.Span);
return Add(token, pos, rented.Span, logits);

#endif
}

Expand Down
29 changes: 12 additions & 17 deletions LLama/Native/LLamaTokenDataArray.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using System.Collections.Generic;
using System.Numerics.Tensors;
using System.Runtime.CompilerServices;
using CommunityToolkit.HighPerformance.Buffers;

namespace LLama.Native
{
Expand Down Expand Up @@ -101,25 +102,19 @@ public void Softmax()

// Calculate softmax. Using TensorPrimitives is very fast (it uses SIMD etc) and is
// definitely correct! So just copy to a temp and use that.
var tempLogits = ArrayPool<float>.Shared.Rent(data.Length);
var tempLogitsSpan = tempLogits.AsSpan(0, data.Length);
try
{
// Copy to temporary
for (var i = 0; i < data.Length; i++)
tempLogitsSpan[i] = data[i].Logit;
using var tempLogitsSpan = SpanOwner<float>.Allocate(data.Length);

// Softmax
TensorPrimitives.SoftMax(tempLogitsSpan, tempLogitsSpan);
// Copy to temporary
for (var i = 0; i < data.Length; i++)
tempLogitsSpan.Span[i] = data[i].Logit;

// Softmax
TensorPrimitives.SoftMax(tempLogitsSpan.Span, tempLogitsSpan.Span);

// Copy back
for (var i = 0; i < data.Length; i++)
data[i].Probability = tempLogitsSpan.Span[i];

// Copy back
for (var i = 0; i < data.Length; i++)
data[i].Probability = tempLogitsSpan[i];
}
finally
{
ArrayPool<float>.Shared.Return(tempLogits, true);
}
}

private struct LLamaTokenDataLogitComparerDescending
Expand Down
2 changes: 0 additions & 2 deletions LLama/Native/NativeApi.Grammar.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
using System;

namespace LLama.Native
{
public static partial class NativeApi
Expand Down
43 changes: 19 additions & 24 deletions LLama/Native/SafeLlamaModelHandle.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
using System.Diagnostics;
using System.IO;
using System.Text;
using CommunityToolkit.HighPerformance.Buffers;
using LLama.Exceptions;

namespace LLama.Native
Expand Down Expand Up @@ -247,12 +248,12 @@ private static int llama_model_meta_val_str_by_index(SafeLlamaModelHandle model,
private static int llama_model_meta_val_str(SafeLlamaModelHandle model, string key, Span<byte> dest)
{
var bytesCount = Encoding.UTF8.GetByteCount(key);
var bytes = ArrayPool<byte>.Shared.Rent(bytesCount);
using var bytes = SpanOwner<byte>.Allocate(bytesCount);

unsafe
{
fixed (char* keyPtr = key)
fixed (byte* bytesPtr = bytes)
fixed (byte* bytesPtr = bytes.Span)
fixed (byte* destPtr = dest)
{
// Convert text into bytes
Expand Down Expand Up @@ -471,34 +472,28 @@ public LLamaToken[] Tokenize(string text, bool addBos, bool special, Encoding en

// Convert string to bytes, adding one extra byte to the end (null terminator)
var bytesCount = encoding.GetByteCount(text);
var bytes = ArrayPool<byte>.Shared.Rent(bytesCount + 1);
try
using var bytes = SpanOwner<byte>.Allocate(bytesCount + 1, AllocationMode.Clear);

unsafe
{
unsafe
fixed (char* textPtr = text)
fixed (byte* bytesPtr = bytes.Span)
{
fixed (char* textPtr = text)
fixed (byte* bytesPtr = bytes)
// Convert text into bytes
encoding.GetBytes(textPtr, text.Length, bytesPtr, bytes.Length);

// Tokenize once with no output, to get the token count. Output will be negative (indicating that there was insufficient space)
var count = -NativeApi.llama_tokenize(llama_model_get_vocab(this), bytesPtr, bytesCount, (LLamaToken*)IntPtr.Zero, 0, addBos, special);

// Tokenize again, this time outputting into an array of exactly the right size
var tokens = new LLamaToken[count];
fixed (LLamaToken* tokensPtr = tokens)
{
// Convert text into bytes
encoding.GetBytes(textPtr, text.Length, bytesPtr, bytes.Length);

// Tokenize once with no output, to get the token count. Output will be negative (indicating that there was insufficient space)
var count = -NativeApi.llama_tokenize(llama_model_get_vocab(this), bytesPtr, bytesCount, (LLamaToken*)IntPtr.Zero, 0, addBos, special);

// Tokenize again, this time outputting into an array of exactly the right size
var tokens = new LLamaToken[count];
fixed (LLamaToken* tokensPtr = tokens)
{
_ = NativeApi.llama_tokenize(llama_model_get_vocab(this), bytesPtr, bytesCount, tokensPtr, count, addBos, special);
return tokens;
}
_ = NativeApi.llama_tokenize(llama_model_get_vocab(this), bytesPtr, bytesCount, tokensPtr, count, addBos, special);
return tokens;
}
}
}
finally
{
ArrayPool<byte>.Shared.Return(bytes, true);
}
}
#endregion

Expand Down
Loading