Skip to content

Commit 4f3f0ed

Browse files
committed
Totally rewrote the LLamaEmbedder, based on https://github.com/ggerganov/llama.cpp/tree/master/examples/embedding. The new embedder properly handles pooling, either returning one embedding for the whole sequence or one per token.
- Added `Encode` methods to `LLamaContext` - Moved some native methods from `NativeApi` to `SafeLLamaContextHandle` and wrapped them properly - Added `HasDecoder` property to `SafeLlamaModelHandle`. This property doesn't exist in the current version of llama.cpp; it will need to be hooked up in the next binary update - Added some normalization methods as extensions on span/array. This required adding a dependency on `System.Numerics.Tensors`
1 parent df8cc71 commit 4f3f0ed

12 files changed

Lines changed: 408 additions & 155 deletions

File tree

LLama.KernelMemory/LLamaSharpTextEmbeddingGenerator.cs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using LLama;
22
using LLama.Common;
3+
using LLama.Native;
34
using Microsoft.KernelMemory;
45
using Microsoft.KernelMemory.AI;
56

@@ -35,7 +36,8 @@ public LLamaSharpTextEmbeddingGenerator(LLamaSharpConfig config)
3536
GpuLayerCount = config.GpuLayerCount ?? 20,
3637
Embeddings = true,
3738
MainGpu = config.MainGpu,
38-
SplitMode = config.SplitMode
39+
SplitMode = config.SplitMode,
40+
PoolingType = LLamaPoolingType.Mean,
3941
};
4042
_weights = LLamaWeights.LoadFromFile(@params);
4143
_embedder = new LLamaEmbedder(_weights, @params);
@@ -59,7 +61,8 @@ public LLamaSharpTextEmbeddingGenerator(LLamaSharpConfig config, LLamaWeights we
5961
GpuLayerCount = config.GpuLayerCount ?? 20,
6062
Embeddings = true,
6163
MainGpu = config.MainGpu,
62-
SplitMode = config.SplitMode
64+
SplitMode = config.SplitMode,
65+
PoolingType = LLamaPoolingType.Mean,
6366
};
6467
_weights = weights;
6568
_embedder = new LLamaEmbedder(_weights, @params);
@@ -92,7 +95,7 @@ public void Dispose()
9295
public async Task<Embedding> GenerateEmbeddingAsync(string text, CancellationToken cancellationToken = default)
9396
{
9497
var embeddings = await _embedder.GetEmbeddings(text, cancellationToken);
95-
return new Embedding(embeddings);
98+
return new Embedding(embeddings.First());
9699
}
97100

98101
/// <inheritdoc/>

LLama.SemanticKernel/TextEmbedding/LLamaSharpEmbeddingGeneration.cs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44

55
namespace LLamaSharp.SemanticKernel.TextEmbedding;
66

7-
public sealed class LLamaSharpEmbeddingGeneration : ITextEmbeddingGenerationService
7+
public sealed class LLamaSharpEmbeddingGeneration
8+
: ITextEmbeddingGenerationService
89
{
910
private readonly LLamaEmbedder _embedder;
1011

@@ -23,7 +24,7 @@ public async Task<IList<ReadOnlyMemory<float>>> GenerateEmbeddingsAsync(IList<st
2324
var result = new List<ReadOnlyMemory<float>>();
2425

2526
foreach (var item in data)
26-
result.Add(await _embedder.GetEmbeddings(item, cancellationToken));
27+
result.Add((await _embedder.GetEmbeddings(item, cancellationToken)).First());
2728

2829
return result;
2930
}

LLama.Unittest/LLamaEmbedderTests.cs

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
using LLama.Common;
2+
using LLama.Extensions;
3+
using LLama.Native;
24
using Xunit.Abstractions;
35

46
namespace LLama.Unittest;
@@ -26,17 +28,18 @@ private async Task CompareEmbeddings(string modelPath)
2628
Threads = 4,
2729
Embeddings = true,
2830
GpuLayerCount = Constants.CIGpuLayerCount,
31+
PoolingType = LLamaPoolingType.Mean,
2932
};
3033
using var weights = LLamaWeights.LoadFromFile(@params);
3134
using var embedder = new LLamaEmbedder(weights, @params);
3235

33-
var cat = await embedder.GetEmbeddings("The cat is cute");
36+
var cat = (await embedder.GetEmbeddings("The cat is cute")).Single().EuclideanNormalization();
3437
Assert.DoesNotContain(float.NaN, cat);
3538

36-
var kitten = await embedder.GetEmbeddings("The kitten is kawaii");
39+
var kitten = (await embedder.GetEmbeddings("The kitten is cute")).Single().EuclideanNormalization();
3740
Assert.DoesNotContain(float.NaN, kitten);
3841

39-
var spoon = await embedder.GetEmbeddings("The spoon is not real");
42+
var spoon = (await embedder.GetEmbeddings("The spoon is not real")).Single().EuclideanNormalization();
4043
Assert.DoesNotContain(float.NaN, spoon);
4144

4245
_testOutputHelper.WriteLine($"Cat = [{string.Join(",", cat.AsMemory().Slice(0, 7).ToArray())}...]");
@@ -64,4 +67,34 @@ public async Task EmbedCompareGenerateModel()
6467
{
6568
await CompareEmbeddings(Constants.GenerativeModelPath);
6669
}
70+
71+
// Shared body for the non-pooled embedding tests: with PoolingType.None the
// embedder returns one embedding per token, each of which must be NaN-free.
private async Task NonPooledEmbeddings(string modelPath)
{
    var modelParams = new ModelParams(modelPath)
    {
        ContextSize = 8,
        Threads = 4,
        Embeddings = true,
        GpuLayerCount = Constants.CIGpuLayerCount,
        PoolingType = LLamaPoolingType.None,
    };

    using var weights = LLamaWeights.LoadFromFile(modelParams);
    using var embedder = new LLamaEmbedder(weights, modelParams);

    var perTokenEmbeddings = await embedder.GetEmbeddings("the kitten is kawaii");
    Assert.All(perTokenEmbeddings, embedding => Assert.DoesNotContain(float.NaN, embedding));
}
88+
89+
// Non-pooled (per-token) embeddings from a dedicated embedding model
[Fact]
public async Task EmbeddingModelNonPooledEmbeddings() =>
    await NonPooledEmbeddings(Constants.EmbeddingModelPath);
94+
95+
// Non-pooled (per-token) embeddings from a generative model
[Fact]
public async Task GenerativeModelNonPooledEmbeddings() =>
    await NonPooledEmbeddings(Constants.GenerativeModelPath);
67100
}
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
using System;
2+
using System.Numerics.Tensors;
3+
4+
namespace LLama.Extensions;
5+
6+
/// <summary>
/// Extensions to span which apply <b>in-place</b> normalization
/// </summary>
public static class SpanNormalizationExtensions
{
    /// <summary>
    /// <b>In-place</b> multiply every element by 32760 and divide every element in the array by the max absolute value in the array
    /// </summary>
    /// <param name="vector">The array to normalize in place</param>
    /// <returns>The same array</returns>
    public static float[] MaxAbsoluteNormalization(this float[] vector)
    {
        vector.AsSpan().MaxAbsoluteNormalization();
        return vector;
    }

    /// <summary>
    /// <b>In-place</b> multiply every element by 32760 and divide every element in the span by the max absolute value in the span
    /// </summary>
    /// <param name="vector">The span to normalize in place</param>
    /// <returns>The same span</returns>
    public static Span<float> MaxAbsoluteNormalization(this Span<float> vector)
    {
        // Scale so the element with the largest magnitude becomes ±32760
        var factor = 32760 / TensorPrimitives.MaxMagnitude(vector);
        TensorPrimitives.Multiply(vector, factor, vector);
        return vector;
    }

    /// <summary>
    /// <b>In-place</b> divide every element in the array by the sum of absolute values in the array
    /// </summary>
    /// <remarks>Also known as "Manhattan normalization".</remarks>
    /// <param name="vector">The array to normalize in place</param>
    /// <returns>The same array</returns>
    public static float[] TaxicabNormalization(this float[] vector)
    {
        vector.AsSpan().TaxicabNormalization();
        return vector;
    }

    /// <summary>
    /// <b>In-place</b> divide every element in the span by the sum of absolute values in the span
    /// </summary>
    /// <remarks>Also known as "Manhattan normalization".</remarks>
    /// <param name="vector">The span to normalize in place</param>
    /// <returns>The same span</returns>
    public static Span<float> TaxicabNormalization(this Span<float> vector)
    {
        var sumAbs = TensorPrimitives.SumOfMagnitudes(vector);
        TensorPrimitives.Divide(vector, sumAbs, vector);
        return vector;
    }

    /// <summary>
    /// <b>In-place</b> divide every element by the euclidean length of the vector
    /// </summary>
    /// <remarks>Also known as "L2 normalization".</remarks>
    /// <param name="vector">The array to normalize in place</param>
    /// <returns>The same array</returns>
    public static float[] EuclideanNormalization(this float[] vector)
    {
        vector.AsSpan().EuclideanNormalization();
        return vector;
    }

    /// <summary>
    /// <b>In-place</b> divide every element by the euclidean length of the vector
    /// </summary>
    /// <remarks>Also known as "L2 normalization".</remarks>
    /// <param name="vector">The span to normalize in place</param>
    /// <returns>The same span</returns>
    public static Span<float> EuclideanNormalization(this Span<float> vector)
    {
        var norm = TensorPrimitives.Norm(vector);
        TensorPrimitives.Divide(vector, norm, vector);
        return vector;
    }

    /// <summary>
    /// <b>In-place</b> apply p-normalization. https://en.wikipedia.org/wiki/Norm_(mathematics)#p-norm
    /// <list type="bullet">
    /// <item>For p = 1, this is taxicab normalization</item>
    /// <item>For p = 2, this is euclidean normalization</item>
    /// <item>As p => infinity, this approaches infinity norm or maximum norm</item>
    /// </list>
    /// </summary>
    /// <param name="vector">The array to normalize in place</param>
    /// <param name="p">The order of the norm; must be positive</param>
    /// <returns>The same array</returns>
    /// <exception cref="ArgumentOutOfRangeException">Thrown if <paramref name="p"/> is not positive</exception>
    public static float[] PNormalization(this float[] vector, int p)
    {
        vector.AsSpan().PNormalization(p);
        return vector;
    }

    /// <summary>
    /// <b>In-place</b> apply p-normalization. https://en.wikipedia.org/wiki/Norm_(mathematics)#p-norm
    /// <list type="bullet">
    /// <item>For p = 1, this is taxicab normalization</item>
    /// <item>For p = 2, this is euclidean normalization</item>
    /// <item>As p => infinity, this approaches infinity norm or maximum norm</item>
    /// </list>
    /// </summary>
    /// <param name="vector">The span to normalize in place</param>
    /// <param name="p">The order of the norm; must be positive</param>
    /// <returns>The same span</returns>
    /// <exception cref="ArgumentOutOfRangeException">Thrown if <paramref name="p"/> is not positive</exception>
    public static Span<float> PNormalization(this Span<float> vector, int p)
    {
        // p <= 0 is not a valid norm order (and 1.0 / p would divide by zero for p == 0)
        if (p <= 0)
            throw new ArgumentOutOfRangeException(nameof(p), "p must be a positive integer");

        if (p == 2)
            return vector.EuclideanNormalization();

        // The p-norm is defined over |x_i|. Without Abs, odd p with negative
        // components produces a negative sum, and Pow(sum, 1/p) yields NaN.
        var sum = 0.0;
        for (var i = 0; i < vector.Length; i++)
            sum += Math.Pow(Math.Abs(vector[i]), p);
        var divisor = (float)Math.Pow(sum, 1.0 / p);

        TensorPrimitives.Divide(vector, divisor, vector);

        return vector;
    }
}

LLama/LLamaContext.cs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,28 @@ public bool ShouldAddBosToken()
379379
}
380380

381381
#region eval overloads
382+
/// <summary>
383+
/// </summary>
384+
/// <param name="batch"></param>
385+
public EncodeResult Encode(LLamaBatch batch)
386+
{
387+
if (batch.TokenCount == 0)
388+
return 0;
389+
if (batch.TokenCount > BatchSize)
390+
throw new ArgumentException("Input contains more tokens than configured batch size", nameof(batch));
391+
392+
return (EncodeResult)NativeHandle.Encode(batch);
393+
}
394+
395+
/// <summary>
396+
/// </summary>
397+
/// <param name="batch"></param>
398+
/// <param name="cancellationToken"></param>
399+
public Task<EncodeResult> EncodeAsync(LLamaBatch batch, CancellationToken cancellationToken = default)
400+
{
401+
return Task.Run(() => Encode(batch), cancellationToken);
402+
}
403+
382404
/// <summary>
383405
/// </summary>
384406
/// <param name="batch"></param>

0 commit comments

Comments
 (0)