From 520d0d6a7b95840e9d1fb494a03fb261ee53ea0c Mon Sep 17 00:00:00 2001 From: Gil LaHaye Date: Thu, 16 Mar 2023 17:32:29 -0700 Subject: [PATCH] Update memory skill with latest changes and fixes --- .../CoreSkills/TextMemorySkill.cs | 113 +++++++++++++++--- .../Memory/ISemanticTextMemory.cs | 10 ++ .../src/SemanticKernel/Memory/NullMemory.cs | 9 ++ .../Memory/SemanticTextMemory.cs | 9 ++ .../src/components/ServiceConfig.tsx | 2 +- .../Example15_MemorySkill.cs | 82 ++++++++++--- .../dotnet/6-memory-and-embeddings.ipynb | 2 +- 7 files changed, 187 insertions(+), 40 deletions(-) diff --git a/dotnet/src/SemanticKernel/CoreSkills/TextMemorySkill.cs b/dotnet/src/SemanticKernel/CoreSkills/TextMemorySkill.cs index b8805105ab4f..c09359615d62 100644 --- a/dotnet/src/SemanticKernel/CoreSkills/TextMemorySkill.cs +++ b/dotnet/src/SemanticKernel/CoreSkills/TextMemorySkill.cs @@ -2,10 +2,10 @@ using System.Globalization; using System.Linq; +using System.Text.Json; using System.Threading.Tasks; using Microsoft.Extensions.Logging; using Microsoft.SemanticKernel.Diagnostics; -using Microsoft.SemanticKernel.Memory; using Microsoft.SemanticKernel.Orchestration; using Microsoft.SemanticKernel.SkillDefinition; @@ -37,49 +37,98 @@ public class TextMemorySkill /// public const string KeyParam = "key"; + /// + /// Name of the context variable used to specify the number of memories to recall + /// + public const string LimitParam = "limit"; + private const string DefaultCollection = "generic"; private const string DefaultRelevance = "0.75"; + private const string DefaultLimit = "1"; /// - /// Recall a fact from the long term memory + /// Key-based lookup for a specific memory + /// + /// + /// SKContext[TextMemorySkill.KeyParam] = "countryInfo1" + /// {{memory.retrieve }} + /// + /// Contains the 'collection' containing the memory to retrieve and the `key` associated with it. + [SKFunction("Key-based lookup for a specific memory")] + [SKFunctionName("Retrieve")] + [SKFunctionContextParameter(Name = CollectionParam, Description = "Memories collection associated with the memory to retrieve", + DefaultValue = DefaultCollection)] + [SKFunctionContextParameter(Name = KeyParam, Description = "The key associated with the memory to retrieve")] + public async Task RetrieveAsync(SKContext context) + { + var collection = context.Variables.ContainsKey(CollectionParam) ? context[CollectionParam] : DefaultCollection; + Verify.NotEmpty(collection, "Memory collection not defined"); + + var key = context.Variables.ContainsKey(KeyParam) ? context[KeyParam] : string.Empty; + Verify.NotEmpty(key, "Memory key not defined"); + + context.Log.LogTrace("Recalling memory with key '{0}' from collection '{1}'", key, collection); + + var memory = await context.Memory.GetAsync(collection, key); + + return memory?.Text ?? string.Empty; + } + + /// + /// Semantic search and return up to N memories related to the input text /// /// /// SKContext["input"] = "what is the capital of France?" /// {{memory.recall $input }} => "Paris" /// - /// The information to retrieve - /// Contains the 'collection' to search for information and 'relevance' score - [SKFunction("Recall a fact from the long term memory")] + /// The input text to find related memories for + /// Contains the 'collection' to search for the topic and 'relevance' score + [SKFunction("Semantic search and return up to N memories related to the input text")] [SKFunctionName("Recall")] - [SKFunctionInput(Description = "The information to retrieve")] - [SKFunctionContextParameter(Name = CollectionParam, Description = "Memories collection where to search for information", DefaultValue = DefaultCollection)] + [SKFunctionInput(Description = "The input text to find related memories for")] + [SKFunctionContextParameter(Name = CollectionParam, Description = "Memories collection to search", DefaultValue = DefaultCollection)] [SKFunctionContextParameter(Name = RelevanceParam, Description = "The relevance score, from 0.0 to 1.0, where 1.0 means perfect match", DefaultValue = DefaultRelevance)] - public async Task RecallAsync(string ask, SKContext context) + [SKFunctionContextParameter(Name = LimitParam, Description = "The maximum number of relevant memories to recall", DefaultValue = DefaultLimit)] + public string Recall(string text, SKContext context) { var collection = context.Variables.ContainsKey(CollectionParam) ? context[CollectionParam] : DefaultCollection; - Verify.NotEmpty(collection, "Memory collection not defined"); + Verify.NotEmpty(collection, "Memories collection not defined"); var relevance = context.Variables.ContainsKey(RelevanceParam) ? context[RelevanceParam] : DefaultRelevance; if (string.IsNullOrWhiteSpace(relevance)) { relevance = DefaultRelevance; } - context.Log.LogTrace("Searching memory for '{0}', collection '{1}', relevance '{2}'", ask, collection, relevance); + var limit = context.Variables.ContainsKey(LimitParam) ? context[LimitParam] : DefaultLimit; + if (string.IsNullOrWhiteSpace(limit)) { relevance = DefaultLimit; } + + context.Log.LogTrace("Searching memories in collection '{0}', relevance '{1}'", collection, relevance); // TODO: support locales, e.g. "0.7" and "0,7" must both work - MemoryQueryResult? memory = await context.Memory - .SearchAsync(collection, ask, limit: 1, minRelevanceScore: float.Parse(relevance, CultureInfo.InvariantCulture)) - .FirstOrDefaultAsync(); + int limitInt = int.Parse(limit, CultureInfo.InvariantCulture); + var memories = context.Memory + .SearchAsync(collection, text, limitInt, minRelevanceScore: float.Parse(relevance, CultureInfo.InvariantCulture)) + .ToEnumerable(); - if (memory == null) + context.Log.LogTrace("Done looking for memories in collection '{0}')", collection); + + string resultString; + + if (limitInt == 1) { - context.Log.LogWarning("Memory not found in collection: {0}", collection); + var memory = memories.FirstOrDefault(); + resultString = (memory != null) ? memory.Text : string.Empty; } else { - context.Log.LogTrace("Memory found (collection: {0})", collection); + resultString = JsonSerializer.Serialize(memories.Select(x => x.Text)); + } + + if (resultString.Length == 0) + { + context.Log.LogWarning("Memories not found in collection: {0}", collection); } - return memory != null ? memory.Text : string.Empty; + return resultString; } /// @@ -95,8 +144,8 @@ public async Task RecallAsync(string ask, SKContext context) [SKFunction("Save information to semantic memory")] [SKFunctionName("Save")] [SKFunctionInput(Description = "The information to save")] - [SKFunctionContextParameter(Name = CollectionParam, Description = "Memories collection where to save the information", DefaultValue = DefaultCollection)] - [SKFunctionContextParameter(Name = KeyParam, Description = "The key to save the information")] + [SKFunctionContextParameter(Name = CollectionParam, Description = "Memories collection associated with the information to save", DefaultValue = DefaultCollection)] + [SKFunctionContextParameter(Name = KeyParam, Description = "The key associated with the information to save")] public async Task SaveAsync(string text, SKContext context) { var collection = context.Variables.ContainsKey(CollectionParam) ? context[CollectionParam] : DefaultCollection; @@ -109,4 +158,30 @@ public async Task SaveAsync(string text, SKContext context) await context.Memory.SaveInformationAsync(collection, text: text, id: key); } + + /// + /// Remove specific memory + /// + /// + /// SKContext[TextMemorySkill.KeyParam] = "countryInfo1" + /// {{memory.remove }} + /// + /// Contains the 'collection' containing the memory to remove. + [SKFunction("Remove specific memory")] + [SKFunctionName("Remove")] + [SKFunctionContextParameter(Name = CollectionParam, Description = "Memories collection associated with the memory to remove", + DefaultValue = DefaultCollection)] + [SKFunctionContextParameter(Name = KeyParam, Description = "The key associated with the memory to remove")] + public async Task RemoveAsync(SKContext context) + { + var collection = context.Variables.ContainsKey(CollectionParam) ? context[CollectionParam] : DefaultCollection; + Verify.NotEmpty(collection, "Memory collection not defined"); + + var key = context.Variables.ContainsKey(KeyParam) ? context[KeyParam] : string.Empty; + Verify.NotEmpty(key, "Memory key not defined"); + + context.Log.LogTrace("Removing memory from collection '{0}'", collection); + + await context.Memory.RemoveAsync(collection, key); + } } diff --git a/dotnet/src/SemanticKernel/Memory/ISemanticTextMemory.cs b/dotnet/src/SemanticKernel/Memory/ISemanticTextMemory.cs index 76560cc63919..b7f3a09cb5b1 100644 --- a/dotnet/src/SemanticKernel/Memory/ISemanticTextMemory.cs +++ b/dotnet/src/SemanticKernel/Memory/ISemanticTextMemory.cs @@ -54,6 +54,16 @@ public Task SaveReferenceAsync( /// Memory record, or null when nothing is found public Task GetAsync(string collection, string key, CancellationToken cancel = default); + /// + /// Remove a memory by key. + /// For local memories the key is the "id" used when saving the record. + /// For external reference, the key is the "URI" used when saving the record. + /// + /// Collection to search + /// Unique memory record identifier + /// Cancellation token + public Task RemoveAsync(string collection, string key, CancellationToken cancel = default); + /// /// Find some information in memory /// diff --git a/dotnet/src/SemanticKernel/Memory/NullMemory.cs b/dotnet/src/SemanticKernel/Memory/NullMemory.cs index f577f3d31d8f..afce6c1fe49b 100644 --- a/dotnet/src/SemanticKernel/Memory/NullMemory.cs +++ b/dotnet/src/SemanticKernel/Memory/NullMemory.cs @@ -49,6 +49,15 @@ public Task SaveReferenceAsync( return Task.FromResult(null as MemoryQueryResult); } + /// + public Task RemoveAsync( + string collection, + string key, + CancellationToken cancel = default) + { + return Task.CompletedTask; + } + /// public IAsyncEnumerable SearchAsync( string collection, diff --git a/dotnet/src/SemanticKernel/Memory/SemanticTextMemory.cs b/dotnet/src/SemanticKernel/Memory/SemanticTextMemory.cs index f37e253a484b..de37eab0612d 100644 --- a/dotnet/src/SemanticKernel/Memory/SemanticTextMemory.cs +++ b/dotnet/src/SemanticKernel/Memory/SemanticTextMemory.cs @@ -71,6 +71,15 @@ public async Task SaveReferenceAsync( return MemoryQueryResult.FromMemoryRecord(result, 1); } + /// + public async Task RemoveAsync( + string collection, + string key, + CancellationToken cancel = default) + { + await this._storage.RemoveAsync(collection, key, cancel); + } + /// public async IAsyncEnumerable SearchAsync( string collection, diff --git a/samples/apps/github-qna-webapp-react/src/components/ServiceConfig.tsx b/samples/apps/github-qna-webapp-react/src/components/ServiceConfig.tsx index 19aa08a6b17e..5b10fb7e2701 100644 --- a/samples/apps/github-qna-webapp-react/src/components/ServiceConfig.tsx +++ b/samples/apps/github-qna-webapp-react/src/components/ServiceConfig.tsx @@ -174,7 +174,7 @@ const ServiceConfig: FC = ({ uri, onConfigComplete }) => { label: d.value } }) - }} placeholder='Enter your deployment id here, ie: my-deployment' /> + }} placeholder='Enter the embeddings model id here, ie: text-embedding-ada-002' />