diff --git a/LLama.Unittest/SamplingTests.cs b/LLama.Unittest/SamplingTests.cs
new file mode 100644
index 000000000..f322bc250
--- /dev/null
+++ b/LLama.Unittest/SamplingTests.cs
@@ -0,0 +1,186 @@
+using LLama.Common;
+using LLama.Native;
+
+using System.Numerics.Tensors;
+using System.Runtime.InteropServices;
+using System.Text;
+
+using Xunit.Abstractions;
+
+namespace LLama.Unittest
+{
+    public class SamplingTests : IDisposable
+    {
+        private readonly ITestOutputHelper _testOutputHelper;
+        private readonly LLamaWeights _model;
+        private readonly ModelParams _params;
+
+        private readonly LLamaBatch _batch;
+        private readonly StreamingTokenDecoder _decoder;
+
+        public void Dispose() => _model.Dispose();
+
+        public SamplingTests(ITestOutputHelper testOutputHelper)
+        {
+            _testOutputHelper = testOutputHelper;
+            _params = new ModelParams(Constants.GenerativeModelPath) {
+                ContextSize = 200,
+                BatchSize = 200,
+                GpuLayerCount = Constants.CIGpuLayerCount,
+            };
+            _model = LLamaWeights.LoadFromFile(_params);
+            _batch = new LLamaBatch();
+            _decoder = new(Encoding.UTF8, _model);
+        }
+
+
+        [Fact]
+        public void Sampling()
+        {
+            using var context = new LLamaContext(_model, _params);
+            var tokens = _model.NativeHandle.Tokenize("I will repeat this phrase forever.\n", false, false, Encoding.UTF8);
+            var logitBias = tokens.Select(x => new LLamaLogitBias() { Token = x, Bias = -1000 }).ToArray();
+
+            // Add "I will repeat this phrase forever.\nI will", without requesting any logits.
+            for (int i = 0; i < tokens.Length; i++) { _batch.Add(token: tokens[i], pos: i, sequence: LLamaSeqId.Zero, logits: false); }
+            for (int i = 0; i < 2; i++) { _batch.Add(token: tokens[i], pos: tokens.Length + i, sequence: LLamaSeqId.Zero, logits: false); }
+
+            // Add " repeat" and test whether the next tokens will be "this phrase forever.".
+            for (int i = 0; i < 4; i++)
+            {
+                _batch.Add(token: tokens[i + 2], pos: tokens.Length + i + 2, sequence: LLamaSeqId.Zero, logits: true);
+                DecodeAndClear(context);
+
+                var expected = tokens[i + 3];
+                var logits = context.NativeHandle.GetLogits(numTokens: 1);
+
+                // Test raw sampling
+                Assert.Equal(expected, TensorPrimitives.IndexOfMax(logits));
+
+                // Test native sampling with `LLamaTokenDataArrayNative`.
+                var array = LLamaTokenDataArray.Create(logits);
+                {
+                    using var _ = LLamaTokenDataArrayNative.Create(array, out var cur_p);
+                    var rawLogits = new float[_model.VocabCount];
+                    for (int j = 0; j < cur_p.Data.Length; j++)
+                    {
+                        rawLogits[(int) cur_p.Data[j].ID] = cur_p.Data[j].Logit;
+                    }
+                    Assert.Equal(expected, TensorPrimitives.IndexOfMax(rawLogits));
+                }
+
+                // Test sampling chain
+                {
+                    using var _ = LLamaTokenDataArrayNative.Create(array, out var cur_p);
+                    using var chain = CreateChain(context.NativeHandle);
+                    chain.Apply(ref cur_p);
+                    Assert.Equal(expected, cur_p.Data[(int) cur_p.Selected].ID);
+                }
+
+                // Test logit bias
+                {
+                    using var _ = LLamaTokenDataArrayNative.Create(array, out var cur_p);
+                    using var chain = CreateChain(context.NativeHandle, logitBias);
+                    chain.Apply(ref cur_p);
+                    Assert.NotEqual(expected, cur_p.Data[(int) cur_p.Selected].ID);
+                }
+            }
+        }
+
+
+        [Fact]
+        public void BatchedSampling()
+        {
+            const int batch_count = 4;
+            using var context = new LLamaContext(_model, _params);
+            var tokens = _model.NativeHandle.Tokenize("I will repeat this phrase forever.\n", false, false, Encoding.UTF8);
+            var logitBias = tokens.Select(x => new LLamaLogitBias() { Token = x, Bias = -1000 }).ToArray();
+
+            // Add "I will repeat this phrase forever.\nI will" to each sequence, without requesting any logits.
+            for (int i = 0; i < tokens.Length + 2; i++)
+            {
+                for (int b = 0; b < batch_count; b++)
+                {
+                    _batch.Add(token: tokens[i % tokens.Length], pos: i, sequence: (LLamaSeqId) b, logits: false);
+                }
+            }
+
+            // Add " repeat" and test whether the next tokens will be "this phrase forever.".
+            for (int i = 0; i < 4; i++)
+            {
+                for (int b = 0; b < batch_count; b++)
+                {
+                    _batch.Add(token: tokens[i + 2], pos: tokens.Length + i + 2, sequence: (LLamaSeqId) b, logits: true);
+                }
+                DecodeAndClear(context);
+
+                var expected = tokens[i + 3];
+                var all_logits = context.NativeHandle.GetLogits(numTokens: batch_count);
+
+                for (int b = 0; b < batch_count; b++)
+                {
+                    var logits = all_logits.Slice(b * _model.VocabCount, _model.VocabCount);
+
+                    // Test raw sampling
+                    Assert.Equal(expected, TensorPrimitives.IndexOfMax(logits));
+
+                    // Test native sampling with `LLamaTokenDataArrayNative`.
+                    var array = LLamaTokenDataArray.Create(logits);
+                    {
+                        using var _ = LLamaTokenDataArrayNative.Create(array, out var cur_p);
+                        var rawLogits = new float[_model.VocabCount];
+                        for (int j = 0; j < cur_p.Data.Length; j++)
+                        {
+                            rawLogits[(int) cur_p.Data[j].ID] = cur_p.Data[j].Logit;
+                        }
+                        Assert.Equal(expected, TensorPrimitives.IndexOfMax(rawLogits));
+                    }
+
+                    // Test sampling chain
+                    {
+                        using var _ = LLamaTokenDataArrayNative.Create(array, out var cur_p);
+                        using var chain = CreateChain(context.NativeHandle);
+                        chain.Apply(ref cur_p);
+                        Assert.Equal(expected, cur_p.Data[(int) cur_p.Selected].ID);
+                    }
+
+                    // Test logit bias
+                    {
+                        using var _ = LLamaTokenDataArrayNative.Create(array, out var cur_p);
+                        using var chain = CreateChain(context.NativeHandle, logitBias);
+                        chain.Apply(ref cur_p);
+                        Assert.NotEqual(expected, cur_p.Data[(int) cur_p.Selected].ID);
+                    }
+                }
+            }
+        }
+
+
+        private void DecodeAndClear(LLamaContext context)
+        {
+            context.Decode(_batch);
+            _batch.Clear();
+        }
+
+        private static SafeLLamaSamplerChainHandle CreateChain(SafeLLamaContextHandle context, LLamaLogitBias[]? logit_bias = null)
+        {
+            var chain = SafeLLamaSamplerChainHandle.Create(LLamaSamplerChainParams.Default());
+
+            chain.AddPenalties(
+                vocabSize: context.VocabCount,
+                eos: context.ModelHandle.Tokens.EOS,
+                newline: context.ModelHandle.Tokens.Newline ?? 0,
+                penaltyCount: 60, repeat: 1, freq: 0, presence: 0,
+                penalizeNewline: false, ignoreEOS: false
+            );
+
+            if (logit_bias != null) { chain.AddLogitBias(context.VocabCount, logit_bias); }
+
+            chain.AddTopK(10);
+            chain.AddTemperature(0.1f);
+            chain.AddDistributionSampler(seed: 42);
+
+            return chain;
+        }
+    }
+}
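Aside for reviewers: the three assertions repeated in both tests above follow one pattern — wrap the logits in a `LLamaTokenDataArray`, pin it as a `LLamaTokenDataArrayNative`, and run a sampler chain over it. Below is a minimal sketch of that pattern outside xUnit, assuming a `LLamaContext` whose last decoded batch requested logits for its final token; the helper name `SampleNext` is hypothetical, while the other identifiers come from the test code above.

```csharp
using LLama;
using LLama.Native;

// Hypothetical helper illustrating the sampling pattern from SamplingTests.
// Assumes `chain` was built like CreateChain() above (top-k, temperature,
// distribution sampler) and `context` has decoded a batch whose final token
// was added with `logits: true`.
static LLamaToken SampleNext(LLamaContext context, SafeLLamaSamplerChainHandle chain)
{
    // One row of logits: the default `numTokens = 1` of the new overload.
    var logits = context.NativeHandle.GetLogits(numTokens: 1);

    // Wrap the candidates so the native samplers can operate on them.
    var array = LLamaTokenDataArray.Create(logits);
    using var native = LLamaTokenDataArrayNative.Create(array, out var cur_p);

    // Apply the chain; `Selected` then indexes the chosen candidate.
    chain.Apply(ref cur_p);
    return cur_p.Data[(int) cur_p.Selected].ID;
}
```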
diff --git a/LLama.Unittest/Transformers/PromptTemplateTransformerTests.cs b/LLama.Unittest/Transformers/PromptTemplateTransformerTests.cs
index 2084298e6..8be978019 100644
--- a/LLama.Unittest/Transformers/PromptTemplateTransformerTests.cs
+++ b/LLama.Unittest/Transformers/PromptTemplateTransformerTests.cs
@@ -60,14 +60,14 @@ public void ToModelPrompt_FormatsCorrectly()
         // Call once with empty array to discover length
         var templateResult = PromptTemplateTransformer.ToModelPrompt(templater);

-        const string expected = "<|start_header_id|>assistant<|end_header_id|>\n\nhello<|eot_id|>"
-            + "<|start_header_id|>user<|end_header_id|>\n\nworld<|eot_id|>"
-            + "<|start_header_id|>assistant<|end_header_id|>\n\n111<|eot_id|>"
-            + "<|start_header_id|>user<|end_header_id|>\n\naaa<|eot_id|>"
-            + "<|start_header_id|>assistant<|end_header_id|>\n\n222<|eot_id|>"
-            + "<|start_header_id|>user<|end_header_id|>\n\nbbb<|eot_id|>"
-            + "<|start_header_id|>assistant<|end_header_id|>\n\n333<|eot_id|>"
-            + "<|start_header_id|>user<|end_header_id|>\n\nccc<|eot_id|>"
+        const string expected = "<|start_header_id|>assistant<|end_header_id|>\n\nhello<|eot_id|>"
+                                + "<|start_header_id|>user<|end_header_id|>\n\nworld<|eot_id|>"
+                                + "<|start_header_id|>assistant<|end_header_id|>\n\n111<|eot_id|>"
+                                + "<|start_header_id|>user<|end_header_id|>\n\naaa<|eot_id|>"
+                                + "<|start_header_id|>assistant<|end_header_id|>\n\n222<|eot_id|>"
+                                + "<|start_header_id|>user<|end_header_id|>\n\nbbb<|eot_id|>"
+                                + "<|start_header_id|>assistant<|end_header_id|>\n\n333<|eot_id|>"
+                                + "<|start_header_id|>user<|end_header_id|>\n\nccc<|eot_id|>"
             + "<|start_header_id|>assistant<|end_header_id|>\n\n";

         Assert.Equal(expected, templateResult);
diff --git a/LLama/Native/SafeLLamaContextHandle.cs b/LLama/Native/SafeLLamaContextHandle.cs
index 8caff8d5f..450f4998a 100644
--- a/LLama/Native/SafeLLamaContextHandle.cs
+++ b/LLama/Native/SafeLLamaContextHandle.cs
@@ -431,21 +431,27 @@ public void ClearLoraAdapters()
     #region GetLogits
     /// <summary>
-    /// Token logits obtained from the last call to llama_decode
-    /// The logits for the last token are stored in the last row
+    /// Token logits obtained from the last call to llama_decode.
+    /// The logits for the last token are stored in the last row.
+    /// Only tokens for which logits were requested (`logits = true`) are present.
     /// Can be mutated in order to change the probabilities of the next token.
     /// Rows: n_tokens
     /// Cols: n_vocab
     /// </summary>
+    /// <param name="numTokens">
+    /// The number of tokens whose logits should be retrieved; the result is laid out as [numTokens x n_vocab].
+    /// Tokens are ordered as they were added to the LLamaBatch (the first token added comes first).
+    /// This is useful when requesting logits for many tokens in a sequence, or when decoding multiple sequences in one go.
+    /// </param>
     /// <returns></returns>
-    public Span<float> GetLogits()
+    public Span<float> GetLogits(int numTokens = 1)
     {
         var model = ThrowIfDisposed();

         unsafe
         {
             var logits = llama_get_logits(this);
-            return new Span<float>(logits, model.VocabCount);
+            return new Span<float>(logits, model.VocabCount * numTokens);
         }
     }
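Since `GetLogits(numTokens)` now returns a row-major [numTokens x n_vocab] block, a caller slicing it per sequence would look roughly like the sketch below, mirroring the `BatchedSampling` test above. This is an illustration, not part of the diff: `seqCount` and `vocabCount` are assumed inputs (`vocabCount` would typically be `model.VocabCount`).

```csharp
using System.Numerics.Tensors;

// Assumes each of the `seqCount` sequences requested logits for exactly one
// (final) token in the batch, so the span holds one n_vocab-sized row per
// sequence, in the order the tokens were added to the batch.
var all = context.NativeHandle.GetLogits(numTokens: seqCount);
for (int s = 0; s < seqCount; s++)
{
    var row = all.Slice(s * vocabCount, vocabCount);

    // Greedy pick for sequence s; a sampler chain could be applied instead.
    var next = TensorPrimitives.IndexOfMax(row);
}
```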