186 changes: 186 additions & 0 deletions LLama.Unittest/SamplingTests.cs
@@ -0,0 +1,186 @@
using LLama.Common;
using LLama.Native;

using System.Numerics.Tensors;
using System.Runtime.InteropServices;
using System.Text;

using Xunit.Abstractions;

namespace LLama.Unittest
{
    public class SamplingTests : IDisposable
    {
        private readonly ITestOutputHelper _testOutputHelper;
        private readonly LLamaWeights _model;
        private readonly ModelParams _params;

        private readonly LLamaBatch _batch;
        private readonly StreamingTokenDecoder _decoder;

        public void Dispose() => _model.Dispose();

        public SamplingTests(ITestOutputHelper testOutputHelper)
        {
            _testOutputHelper = testOutputHelper;
            _params = new ModelParams(Constants.GenerativeModelPath) {
                ContextSize = 200,
                BatchSize = 200,
                GpuLayerCount = Constants.CIGpuLayerCount,
            };
            _model = LLamaWeights.LoadFromFile(_params);
            _batch = new LLamaBatch();
            _decoder = new(Encoding.UTF8, _model);
        }


        [Fact]
        public void Sampling()
        {
            using var context = new LLamaContext(_model, _params);
            var tokens = _model.NativeHandle.Tokenize("I will repeat this phrase forever.\n", false, false, Encoding.UTF8);
            var logitBias = tokens.Select(x => new LLamaLogitBias() { Token = x, Bias = -1000 }).ToArray();

            // Add "I will repeat this phrase forever.\nI will", without requesting any logits.
            for (int i = 0; i < tokens.Length; i++) { _batch.Add(token: tokens[i], pos: i, sequence: LLamaSeqId.Zero, logits: false); }
            for (int i = 0; i < 2; i++) { _batch.Add(token: tokens[i], pos: tokens.Length + i, sequence: LLamaSeqId.Zero, logits: false); }

            // Add " repeat" and test whether next tokens will be "this phrase forever.".
            for (int i = 0; i < 4; i++)
            {
                _batch.Add(token: tokens[i + 2], pos: tokens.Length + i + 2, sequence: LLamaSeqId.Zero, logits: true);
                DecodeAndClear(context);

                var expected = tokens[i + 3];
                var logits = context.NativeHandle.GetLogits(numTokens: 1);

                // Test raw sampling
                Assert.Equal(expected, TensorPrimitives.IndexOfMax(logits));

                // Test native sampling with `LLamaTokenDataArrayNative`.
                var array = LLamaTokenDataArray.Create(logits);
                {
                    using var _ = LLamaTokenDataArrayNative.Create(array, out var cur_p);
                    var rawLogits = new float[_model.VocabCount];
                    for (int j = 0; j < cur_p.Data.Length; j++)
                    {
                        rawLogits[(int) cur_p.Data[j].ID] = cur_p.Data[j].Logit;
                    }
                    Assert.Equal(expected, TensorPrimitives.IndexOfMax(rawLogits));
                }

                // Test sampling chain
                {
                    using var _ = LLamaTokenDataArrayNative.Create(array, out var cur_p);
                    using var chain = CreateChain(context.NativeHandle);
                    chain.Apply(ref cur_p);
                    Assert.Equal(expected, cur_p.Data[(int) cur_p.Selected].ID);
                }

                // Test logit bias
                {
                    using var _ = LLamaTokenDataArrayNative.Create(array, out var cur_p);
                    using var chain = CreateChain(context.NativeHandle, logitBias);
                    chain.Apply(ref cur_p);
                    Assert.NotEqual(expected, cur_p.Data[(int) cur_p.Selected].ID);
                }
            }
        }


        [Fact]
        public void BatchedSampling()
        {
            const int batch_count = 4;
            using var context = new LLamaContext(_model, _params);
            var tokens = _model.NativeHandle.Tokenize("I will repeat this phrase forever.\n", false, false, Encoding.UTF8);
            var logitBias = tokens.Select(x => new LLamaLogitBias() { Token = x, Bias = -1000 }).ToArray();

            // Add "I will repeat this phrase forever.\nI will", without requesting any logits.
            for (int i = 0; i < tokens.Length + 2; i++)
            {
                for (int b = 0; b < batch_count; b++)
                {
                    _batch.Add(token: tokens[i % tokens.Length], pos: i, sequence: (LLamaSeqId) b, logits: false);
                }
            }

            // Add " repeat" and test whether next tokens will be "this phrase forever.".
            for (int i = 0; i < 4; i++)
            {
                for (int b = 0; b < batch_count; b++)
                {
                    _batch.Add(token: tokens[i + 2], pos: tokens.Length + i + 2, sequence: (LLamaSeqId) b, logits: true);
                }
                DecodeAndClear(context);

                var expected = tokens[i + 3];
                var all_logits = context.NativeHandle.GetLogits(numTokens: batch_count);

                for (int b = 0; b < batch_count; b++)
                {
                    var logits = all_logits.Slice(b * _model.VocabCount, _model.VocabCount);

                    // Test raw sampling
                    Assert.Equal(expected, TensorPrimitives.IndexOfMax(logits));

                    // Test native sampling with `LLamaTokenDataArrayNative`.
                    var array = LLamaTokenDataArray.Create(logits);
                    {
                        using var _ = LLamaTokenDataArrayNative.Create(array, out var cur_p);
                        var rawLogits = new float[_model.VocabCount];
                        for (int j = 0; j < cur_p.Data.Length; j++)
                        {
                            rawLogits[(int) cur_p.Data[j].ID] = cur_p.Data[j].Logit;
                        }
                        Assert.Equal(expected, TensorPrimitives.IndexOfMax(rawLogits));
                    }

                    // Test sampling chain
                    {
                        using var _ = LLamaTokenDataArrayNative.Create(array, out var cur_p);
                        using var chain = CreateChain(context.NativeHandle);
                        chain.Apply(ref cur_p);
                        Assert.Equal(expected, cur_p.Data[(int) cur_p.Selected].ID);
                    }

                    // Test logit bias
                    {
                        using var _ = LLamaTokenDataArrayNative.Create(array, out var cur_p);
                        using var chain = CreateChain(context.NativeHandle, logitBias);
                        chain.Apply(ref cur_p);
                        Assert.NotEqual(expected, cur_p.Data[(int) cur_p.Selected].ID);
                    }
                }
            }
        }


        private void DecodeAndClear(LLamaContext context)
        {
            context.Decode(_batch);
            _batch.Clear();
        }

        private static SafeLLamaSamplerChainHandle CreateChain(SafeLLamaContextHandle context, LLamaLogitBias[]? logit_bias = null)
        {
            var chain = SafeLLamaSamplerChainHandle.Create(LLamaSamplerChainParams.Default());

            chain.AddPenalties(
                vocabSize: context.VocabCount,
                eos: context.ModelHandle.Tokens.EOS,
                newline: context.ModelHandle.Tokens.Newline ?? 0,
                penaltyCount: 60, repeat: 1, freq: 0, presence: 0,
                penalizeNewline: false, ignoreEOS: false
            );

            if (logit_bias != null) { chain.AddLogitBias(context.VocabCount, logit_bias); }

            chain.AddTopK(10);
            chain.AddTemperature(0.1f);
            chain.AddDistributionSampler(seed: 42);

            return chain;
        }
    }
}
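The three per-token checks in these tests all funnel into the same pattern: decode a batch, fetch the logits row, wrap it in a token-data array, and let the sampler chain pick. As a minimal sketch of that pattern outside the test harness, using only the calls exercised above (the helper name and its parameters are illustrative, not part of this PR):

```csharp
// Illustrative helper, not part of this PR: samples the next token after a
// Decode() call, using the same APIs the tests above exercise.
private static LLamaToken SampleNext(LLamaContext context, SafeLLamaSamplerChainHandle chain)
{
    // One row of logits for the last token that was added with `logits: true`.
    var logits = context.NativeHandle.GetLogits(numTokens: 1);

    // Wrap the raw logits as (token, logit) pairs and pin them for native code.
    var array = LLamaTokenDataArray.Create(logits);
    using var _ = LLamaTokenDataArrayNative.Create(array, out var cur_p);

    // Apply penalties, top-k, temperature and the distribution sampler;
    // the chain records its pick in cur_p.Selected.
    chain.Apply(ref cur_p);
    return cur_p.Data[(int) cur_p.Selected].ID;
}
```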
16 changes: 8 additions & 8 deletions LLama.Unittest/Transformers/PromptTemplateTransformerTests.cs
@@ -60,14 +60,14 @@ public void ToModelPrompt_FormatsCorrectly()

            // Call once with empty array to discover length
            var templateResult = PromptTemplateTransformer.ToModelPrompt(templater);
            const string expected = "<|start_header_id|>assistant<|end_header_id|>\n\nhello<|eot_id|>"
                                  + "<|start_header_id|>user<|end_header_id|>\n\nworld<|eot_id|>"
                                  + "<|start_header_id|>assistant<|end_header_id|>\n\n111<|eot_id|>"
                                  + "<|start_header_id|>user<|end_header_id|>\n\naaa<|eot_id|>"
                                  + "<|start_header_id|>assistant<|end_header_id|>\n\n222<|eot_id|>"
                                  + "<|start_header_id|>user<|end_header_id|>\n\nbbb<|eot_id|>"
                                  + "<|start_header_id|>assistant<|end_header_id|>\n\n333<|eot_id|>"
                                  + "<|start_header_id|>user<|end_header_id|>\n\nccc<|eot_id|>"
                                  + "<|start_header_id|>assistant<|end_header_id|>\n\n";

            Assert.Equal(expected, templateResult);
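Not part of the diff, but for context: the `templater` in this test is presumably an `LLamaTemplate` filled with the alternating messages visible in the expected string. A rough sketch of such a fixture follows, where the `LLamaTemplate` constructor, `Add(role, content)`, and `AddAssistant` are assumptions about the surrounding test code rather than anything shown in this hunk:

```csharp
// Hypothetical fixture reconstruction: roles and contents are read off the
// expected transcript above; the exact LLamaTemplate API is assumed, not shown.
var templater = new LLamaTemplate(model)
{
    AddAssistant = true, // assumed flag producing the trailing assistant header
};
templater.Add("assistant", "hello");
templater.Add("user", "world");
templater.Add("assistant", "111");
templater.Add("user", "aaa");
templater.Add("assistant", "222");
templater.Add("user", "bbb");
templater.Add("assistant", "333");
templater.Add("user", "ccc");

var prompt = PromptTemplateTransformer.ToModelPrompt(templater);
```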
14 changes: 10 additions & 4 deletions LLama/Native/SafeLLamaContextHandle.cs
@@ -431,21 +431,27 @@ public void ClearLoraAdapters()

    #region GetLogits
    /// <summary>
    /// Token logits obtained from the last call to llama_decode.
    /// The logits for the last token are stored in the last row.
    /// Only tokens for which logits were requested (`logits = true`) are present.<br/>
    /// Can be mutated in order to change the probabilities of the next token.<br />
    /// Rows: n_tokens<br />
    /// Cols: n_vocab
    /// </summary>
    /// <param name="numTokens">
    /// The number of tokens whose logits should be retrieved; the returned span covers <b>[numTokens x n_vocab]</b> values.<br/>
    /// Tokens are returned in the order they were added to the LLamaBatch (first added, first returned).<br/>
    /// This is helpful when requesting logits for many tokens in a sequence, or when decoding multiple sequences in one go.
    /// </param>
    /// <returns></returns>
    public Span<float> GetLogits(int numTokens = 1)
    {
        var model = ThrowIfDisposed();

        unsafe
        {
            var logits = llama_get_logits(this);
            return new Span<float>(logits, model.VocabCount * numTokens);
        }
    }
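A usage sketch for the widened `GetLogits`, mirroring `BatchedSampling` above (here `context`, `model`, and `seqCount` are assumed to come from a decode where each of `seqCount` sequences requested logits for exactly one token):

```csharp
// All logits rows come back contiguously, in the order the logit-bearing
// tokens were added to the batch: [seqCount x n_vocab].
var allLogits = context.NativeHandle.GetLogits(numTokens: seqCount);
for (int b = 0; b < seqCount; b++)
{
    // Row b holds the n_vocab logits for sequence b's last decoded token.
    var row = allLogits.Slice(b * model.VocabCount, model.VocabCount);
    // ... sample the next token for sequence b from `row` ...
}
```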
