From f4e2a7d0ccf50150a50f2e1bf6edd3bf2f861664 Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Sun, 16 Feb 2025 23:29:32 -0500 Subject: [PATCH] Update Microsoft.Extensions.AI to 9.3.0-preview.1.25114.11 --- LLama.Unittest/LLamaEmbedderTests.cs | 8 ++--- LLama/Extensions/LLamaExecutorExtensions.cs | 21 ++++++------- LLama/LLamaEmbedder.EmbeddingGenerator.cs | 35 ++++++++++++++------- LLama/LLamaSharp.csproj | 2 +- 4 files changed, 39 insertions(+), 27 deletions(-) diff --git a/LLama.Unittest/LLamaEmbedderTests.cs b/LLama.Unittest/LLamaEmbedderTests.cs index 28e7427f4..d975a168d 100644 --- a/LLama.Unittest/LLamaEmbedderTests.cs +++ b/LLama.Unittest/LLamaEmbedderTests.cs @@ -43,10 +43,10 @@ private async Task CompareEmbeddings(string modelPath) Assert.DoesNotContain(float.NaN, spoon); var generator = (IEmbeddingGenerator>)embedder; - Assert.NotNull(generator.Metadata); - Assert.Equal(nameof(LLamaEmbedder), generator.Metadata.ProviderName); - Assert.NotNull(generator.Metadata.ModelId); - Assert.NotEmpty(generator.Metadata.ModelId); + Assert.NotNull(generator.GetService()); + Assert.Equal(nameof(LLamaEmbedder), generator.GetService()?.ProviderName); + Assert.NotNull(generator.GetService()?.ModelId); + Assert.NotEmpty(generator.GetService()?.ModelId!); Assert.Same(embedder, generator.GetService()); Assert.Same(generator, generator.GetService>>()); Assert.Null(generator.GetService()); diff --git a/LLama/Extensions/LLamaExecutorExtensions.cs b/LLama/Extensions/LLamaExecutorExtensions.cs index e38ccf98d..cf9231624 100644 --- a/LLama/Extensions/LLamaExecutorExtensions.cs +++ b/LLama/Extensions/LLamaExecutorExtensions.cs @@ -36,6 +36,7 @@ private sealed class LLamaExecutorChatClient( IHistoryTransform? historyTransform = null, ITextStreamTransform? outputTransform = null) : IChatClient { + private static readonly ChatClientMetadata s_metadata = new(nameof(LLamaExecutorChatClient)); private static readonly InferenceParams s_defaultParams = new(); private static readonly DefaultSamplingPipeline s_defaultPipeline = new(); private static readonly string[] s_antiPrompts = ["User:", "Assistant:", "System:"]; @@ -47,21 +48,19 @@ private sealed class LLamaExecutorChatClient( private readonly ITextStreamTransform _outputTransform = outputTransform ?? new LLamaTransforms.KeywordTextOutputStreamTransform(s_antiPrompts); - /// - public ChatClientMetadata Metadata { get; } = new(nameof(LLamaExecutorChatClient)); - /// public void Dispose() { } /// - public object? GetService(Type serviceType, object? key = null) => - key is not null ? null : + public object? GetService(Type serviceType, object? serviceKey = null) => + serviceKey is not null ? null : + serviceType == typeof(ChatClientMetadata) ? s_metadata : serviceType?.IsInstanceOfType(_executor) is true ? _executor : serviceType?.IsInstanceOfType(this) is true ? this : null; /// - public async Task CompleteAsync( + public async Task GetResponseAsync( IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) { var result = _executor.InferAsync(CreatePrompt(chatMessages), CreateInferenceParams(options), cancellationToken); @@ -79,7 +78,7 @@ public async Task CompleteAsync( } /// - public async IAsyncEnumerable CompleteStreamingAsync( + public async IAsyncEnumerable GetStreamingResponseAsync( IList chatMessages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { var result = _executor.InferAsync(CreatePrompt(chatMessages), CreateInferenceParams(options), cancellationToken); @@ -142,8 +141,8 @@ private string CreatePrompt(IList messages) MaxTokens = options?.MaxOutputTokens ?? 256, // arbitrary upper limit SamplingPipeline = new DefaultSamplingPipeline() { - FrequencyPenalty = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.FrequencyPenalty), out float af) is true ? af : s_defaultPipeline.FrequencyPenalty, - PresencePenalty = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.PresencePenalty), out float ap) is true ? ap : s_defaultPipeline.PresencePenalty, + FrequencyPenalty = options?.FrequencyPenalty ?? s_defaultPipeline.FrequencyPenalty, + PresencePenalty = options?.PresencePenalty ?? s_defaultPipeline.PresencePenalty, PreventEOS = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.PreventEOS), out bool eos) is true ? eos : s_defaultPipeline.PreventEOS, PenalizeNewline = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.PenalizeNewline), out bool pnl) is true ? pnl : s_defaultPipeline.PenalizeNewline, RepeatPenalty = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.RepeatPenalty), out float rp) is true ? rp : s_defaultPipeline.RepeatPenalty, @@ -152,8 +151,8 @@ private string CreatePrompt(IList messages) MinKeep = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.MinKeep), out int mk) is true ? mk : s_defaultPipeline.MinKeep, MinP = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.MinP), out float mp) is true ? mp : s_defaultPipeline.MinP, Seed = options?.Seed is long seed ? (uint)seed : (uint)(t_random ??= new()).Next(), - Temperature = options?.Temperature ?? 0, - TopP = options?.TopP ?? 0, + Temperature = options?.Temperature ?? s_defaultPipeline.Temperature, + TopP = options?.TopP ?? s_defaultPipeline.TopP, TopK = options?.TopK ?? s_defaultPipeline.TopK, TypicalP = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.TypicalP), out float tp) is true ? tp : s_defaultPipeline.TypicalP, }, diff --git a/LLama/LLamaEmbedder.EmbeddingGenerator.cs b/LLama/LLamaEmbedder.EmbeddingGenerator.cs index 62a6d1940..2ee86fcc6 100644 --- a/LLama/LLamaEmbedder.EmbeddingGenerator.cs +++ b/LLama/LLamaEmbedder.EmbeddingGenerator.cs @@ -14,18 +14,31 @@ public partial class LLamaEmbedder private EmbeddingGeneratorMetadata? _metadata; /// - EmbeddingGeneratorMetadata IEmbeddingGenerator>.Metadata => - _metadata ??= new( - nameof(LLamaEmbedder), - modelId: Context.NativeHandle.ModelHandle.ReadMetadata().TryGetValue("general.name", out var name) ? name : null, - dimensions: EmbeddingSize); + object? IEmbeddingGenerator>.GetService(Type serviceType, object? serviceKey) + { + if (serviceKey is null) + { + if (serviceType == typeof(EmbeddingGeneratorMetadata)) + { + return _metadata ??= new( + nameof(LLamaEmbedder), + modelId: Context.NativeHandle.ModelHandle.ReadMetadata().TryGetValue("general.name", out var name) ? name : null, + dimensions: EmbeddingSize); + } - /// - object? IEmbeddingGenerator>.GetService(Type serviceType, object? key) => - key is not null ? null : - serviceType?.IsInstanceOfType(Context) is true ? Context : - serviceType?.IsInstanceOfType(this) is true ? this : - null; + if (serviceType?.IsInstanceOfType(Context) is true) + { + return Context; + } + + if (serviceType?.IsInstanceOfType(this) is true) + { + return this; + } + } + + return null; + } /// async Task>> IEmbeddingGenerator>.GenerateAsync(IEnumerable values, EmbeddingGenerationOptions? options, CancellationToken cancellationToken) diff --git a/LLama/LLamaSharp.csproj b/LLama/LLamaSharp.csproj index d0be38f74..86ed11c10 100644 --- a/LLama/LLamaSharp.csproj +++ b/LLama/LLamaSharp.csproj @@ -50,7 +50,7 @@ - +