diff --git a/webapi/Controllers/MaintenanceController.cs b/webapi/Controllers/MaintenanceController.cs index 9ca21e69f..f3e54dd41 100644 --- a/webapi/Controllers/MaintenanceController.cs +++ b/webapi/Controllers/MaintenanceController.cs @@ -1,16 +1,17 @@ // Copyright (c) Microsoft. All rights reserved. using System.Threading; +using System.Threading.Tasks; using CopilotChat.WebApi.Auth; using CopilotChat.WebApi.Hubs; using CopilotChat.WebApi.Models.Response; using CopilotChat.WebApi.Options; +using CopilotChat.WebApi.Services.MemoryMigration; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Mvc; using Microsoft.AspNetCore.SignalR; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; -using Microsoft.SemanticKernel; namespace CopilotChat.WebApi.Controllers; @@ -46,14 +47,34 @@ public MaintenanceController( [HttpGet] [ProducesResponseType(StatusCodes.Status200OK)] [ProducesResponseType(StatusCodes.Status400BadRequest)] - public ActionResult GetMaintenanceStatus( - [FromServices] IKernel kernel, + public async Task> GetMaintenanceStatusAsync( + [FromServices] IChatMigrationMonitor migrationMonitor, [FromServices] IHubContext messageRelayHubContext, CancellationToken cancellationToken = default) { + MaintenanceResult? result = null; + + var migrationStatus = await migrationMonitor.GetCurrentStatusAsync(cancellationToken); + + if (migrationStatus != ChatMigrationStatus.None) + { + result = + new MaintenanceResult + { + Title = "Migrating Chat Memory", + Message = "An upgrade requires that all non-document memories be migrated. 
This may take several minutes...", + Note = "Note: All document memories will need to be re-imported.", + }; + } + if (this._serviceOptions.Value.InMaintenance) { - return this.Ok(new MaintenanceResult()); + result = new MaintenanceResult(); // Default maintenance message + } + + if (result != null) + { + return this.Ok(result); } return this.Ok(); diff --git a/webapi/Extensions/ISemanticMemoryClientExtensions.cs b/webapi/Extensions/ISemanticMemoryClientExtensions.cs index fe12f24ac..c083dc2ae 100644 --- a/webapi/Extensions/ISemanticMemoryClientExtensions.cs +++ b/webapi/Extensions/ISemanticMemoryClientExtensions.cs @@ -99,8 +99,7 @@ await memoryClient.SearchAsync( indexName, filter, resultCount, - cancellationToken) - .ConfigureAwait(false); + cancellationToken); return searchResult; } @@ -129,11 +128,23 @@ public static async Task StoreDocumentAsync( await memoryClient.ImportDocumentAsync(uploadRequest, cancellationToken); } + public static Task StoreMemoryAsync( + this ISemanticMemoryClient memoryClient, + string indexName, + string chatId, + string memoryName, + string memory, + CancellationToken cancellationToken = default) + { + return memoryClient.StoreMemoryAsync(indexName, chatId, memoryName, memoryId: Guid.NewGuid().ToString(), memory, cancellationToken); + } + public static async Task StoreMemoryAsync( this ISemanticMemoryClient memoryClient, string indexName, string chatId, string memoryName, + string memoryId, string memory, CancellationToken cancellationToken = default) { @@ -143,10 +154,9 @@ public static async Task StoreMemoryAsync( await writer.FlushAsync(); stream.Position = 0; - var id = Guid.NewGuid().ToString(); var uploadRequest = new DocumentUploadRequest { - DocumentId = id, + DocumentId = memoryId, Index = indexName, Files = new() @@ -166,12 +176,12 @@ public static async Task RemoveChatMemoriesAsync( this ISemanticMemoryClient memoryClient, string indexName, string chatId, - CancellationToken cancelToken = default) + CancellationToken 
cancellationToken = default) { - var memories = await memoryClient.SearchMemoryAsync(indexName, "*", 0.0F, chatId, cancellationToken: cancelToken); + var memories = await memoryClient.SearchMemoryAsync(indexName, "*", 0.0F, chatId, cancellationToken: cancellationToken); foreach (var memory in memories.Results) { - await memoryClient.DeleteDocumentAsync(indexName, memory.Link, cancelToken); + await memoryClient.DeleteDocumentAsync(indexName, memory.Link, cancellationToken); } } } diff --git a/webapi/Extensions/SemanticKernelExtensions.cs b/webapi/Extensions/SemanticKernelExtensions.cs index 7392a6b1a..c4bc893f7 100644 --- a/webapi/Extensions/SemanticKernelExtensions.cs +++ b/webapi/Extensions/SemanticKernelExtensions.cs @@ -16,8 +16,6 @@ using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; using Microsoft.SemanticKernel; -using Microsoft.SemanticKernel.AI.Embeddings; -using Microsoft.SemanticKernel.Connectors.AI.OpenAI.TextEmbedding; using Microsoft.SemanticKernel.Diagnostics; using Microsoft.SemanticKernel.Skills.Core; using Microsoft.SemanticMemory; @@ -39,14 +37,14 @@ internal static class SemanticKernelExtensions /// public static WebApplicationBuilder AddSemanticKernelServices(this WebApplicationBuilder builder) { + builder.InitializeKernelProvider(); + // Semantic Kernel builder.Services.AddScoped( sp => { - var kernel = Kernel.Builder - .WithLoggerFactory(sp.GetRequiredService()) - .WithCompletionBackend(sp, builder.Configuration) - .Build(); + var provider = sp.GetRequiredService(); + var kernel = provider.GetCompletionKernel(); sp.GetRequiredService()(sp, kernel); return kernel; @@ -66,16 +64,16 @@ public static WebApplicationBuilder AddSemanticKernelServices(this WebApplicatio /// public static WebApplicationBuilder AddPlannerServices(this WebApplicationBuilder builder) { + builder.InitializeKernelProvider(); + builder.Services.AddScoped(sp => { sp.WithBotConfig(builder.Configuration); var plannerOptions = sp.GetRequiredService>(); - var 
plannerKernel = Kernel.Builder - .WithLoggerFactory(sp.GetRequiredService()) - .WithPlannerBackend(sp, builder.Configuration) - .Build(); + var provider = sp.GetRequiredService(); + var plannerKernel = provider.GetPlannerKernel(); return new CopilotChatPlanner(plannerKernel, plannerOptions?.Value, sp.GetRequiredService>()); }); @@ -139,6 +137,11 @@ public static async Task SafeInvokeAsync(Func> callback, string fu } } + private static void InitializeKernelProvider(this WebApplicationBuilder builder) + { + builder.Services.AddSingleton(sp => new SemanticKernelProvider(sp, builder.Configuration)); + } + /// /// Register the skills with the kernel. /// @@ -185,80 +188,6 @@ internal static void AddContentSafety(this IServiceCollection services) } } - /// - /// Add the completion backend to the kernel config - /// - private static KernelBuilder WithCompletionBackend(this KernelBuilder kernelBuilder, IServiceProvider provider, IConfiguration configuration) - { - var memoryOptions = provider.GetRequiredService>().Value; - - switch (memoryOptions.TextGeneratorType) - { - case string x when x.Equals("AzureOpenAI", StringComparison.OrdinalIgnoreCase): - case string y when y.Equals("AzureOpenAIText", StringComparison.OrdinalIgnoreCase): - var azureAIOptions = memoryOptions.GetServiceConfig(configuration, "AzureOpenAIText"); - return kernelBuilder.WithAzureChatCompletionService(azureAIOptions.Deployment, azureAIOptions.Endpoint, azureAIOptions.APIKey); - - case string x when x.Equals("OpenAI", StringComparison.OrdinalIgnoreCase): - var openAIOptions = memoryOptions.GetServiceConfig(configuration, "OpenAI"); - return kernelBuilder.WithOpenAIChatCompletionService(openAIOptions.TextModel, openAIOptions.APIKey); - - default: - throw new ArgumentException($"Invalid {nameof(memoryOptions.TextGeneratorType)} value in 'SemanticMemory' settings."); - } - } - - /// - /// Add the completion backend to the kernel config for the planner. 
- /// - private static KernelBuilder WithPlannerBackend(this KernelBuilder kernelBuilder, IServiceProvider provider, IConfiguration configuration) - { - var memoryOptions = provider.GetRequiredService>().Value; - var plannerOptions = provider.GetRequiredService>().Value; - - switch (memoryOptions.TextGeneratorType) - { - case string x when x.Equals("AzureOpenAI", StringComparison.OrdinalIgnoreCase): - case string y when y.Equals("AzureOpenAIText", StringComparison.OrdinalIgnoreCase): - var azureAIOptions = memoryOptions.GetServiceConfig(configuration, "AzureOpenAIText"); - return kernelBuilder.WithAzureChatCompletionService(plannerOptions.Model, azureAIOptions.Endpoint, azureAIOptions.APIKey); - - case string x when x.Equals("OpenAI", StringComparison.OrdinalIgnoreCase): - var openAIOptions = memoryOptions.GetServiceConfig(configuration, "OpenAI"); - return kernelBuilder.WithOpenAIChatCompletionService(plannerOptions.Model, openAIOptions.APIKey); - - default: - throw new ArgumentException($"Invalid {nameof(memoryOptions.TextGeneratorType)} value in 'SemanticMemory' settings."); - } - } - - /// - /// Construct IEmbeddingGeneration from - /// - private static ITextEmbeddingGeneration ToTextEmbeddingsService( - this IServiceProvider provider, - IConfiguration configuration, - ILoggerFactory? 
loggerFactory = null) - { - var logger = provider.GetRequiredService>(); - var memoryOptions = provider.GetRequiredService>().Value; - - switch (memoryOptions.Retrieval.EmbeddingGeneratorType) - { - case string x when x.Equals("AzureOpenAI", StringComparison.OrdinalIgnoreCase): - case string y when y.Equals("AzureOpenAIEmbedding", StringComparison.OrdinalIgnoreCase): - var azureAIOptions = memoryOptions.GetServiceConfig(configuration, "AzureOpenAIEmbedding"); - return new AzureTextEmbeddingGeneration(azureAIOptions.Deployment, azureAIOptions.Endpoint, azureAIOptions.APIKey, httpClient: null, loggerFactory); - - case string x when x.Equals("OpenAI", StringComparison.OrdinalIgnoreCase): - var openAIOptions = memoryOptions.GetServiceConfig(configuration, "OpenAI"); - return new OpenAITextEmbeddingGeneration(openAIOptions.EmbeddingModel, openAIOptions.APIKey, organization: null, httpClient: null, loggerFactory); - - default: - throw new ArgumentException($"Invalid {nameof(memoryOptions.Retrieval.EmbeddingGeneratorType)} value in 'SemanticMemory' settings."); - } - } - /// /// Get the embedding model from the configuration. 
/// diff --git a/webapi/Extensions/ServiceExtensions.cs b/webapi/Extensions/ServiceExtensions.cs index 54f4d691b..f693aa3aa 100644 --- a/webapi/Extensions/ServiceExtensions.cs +++ b/webapi/Extensions/ServiceExtensions.cs @@ -7,6 +7,8 @@ using CopilotChat.WebApi.Auth; using CopilotChat.WebApi.Models.Storage; using CopilotChat.WebApi.Options; +using CopilotChat.WebApi.Services; +using CopilotChat.WebApi.Services.MemoryMigration; using CopilotChat.WebApi.Storage; using CopilotChat.WebApi.Utilities; using Microsoft.AspNetCore.Authentication; @@ -87,6 +89,25 @@ internal static IServiceCollection AddUtilities(this IServiceCollection services return services.AddScoped(); } + internal static IServiceCollection AddMainetnanceServices(this IServiceCollection services) + { + // Inject migration services + services.AddSingleton(); + services.AddSingleton(); + + // Inject actions so they can be part of the action-list. + services.AddSingleton(); + services.AddSingleton>( + sp => + (IReadOnlyList) + new[] + { + sp.GetRequiredService(), + }); + + return services; + } + /// /// Add CORS settings. /// diff --git a/webapi/Models/Response/MigrationResult.cs b/webapi/Models/Response/MaintenanceResult.cs similarity index 100% rename from webapi/Models/Response/MigrationResult.cs rename to webapi/Models/Response/MaintenanceResult.cs diff --git a/webapi/Program.cs b/webapi/Program.cs index 693e02408..67027336d 100644 --- a/webapi/Program.cs +++ b/webapi/Program.cs @@ -69,6 +69,7 @@ public static async Task Main(string[] args) // Add in the rest of the services. builder.Services + .AddMainetnanceServices() .AddEndpointsApiExplorer() .AddSwaggerGen() .AddCorsPolicy(builder.Configuration) diff --git a/webapi/Services/IMaintenanceAction.cs b/webapi/Services/IMaintenanceAction.cs new file mode 100644 index 000000000..a4848c85c --- /dev/null +++ b/webapi/Services/IMaintenanceAction.cs @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft. All rights reserved. 
+ +using System.Threading; +using System.Threading.Tasks; + +namespace CopilotChat.WebApi.Services; + +/// +/// Defines discrete maintenance action responsible for both inspecting state +/// and performing maintenance. +/// +public interface IMaintenanceAction +{ + /// + /// Calling site to initiate maintenance action. + /// + /// true if maintenance needed or in progress + Task InvokeAsync(CancellationToken cancellation = default); +} diff --git a/webapi/Services/MaintenanceMiddleware.cs b/webapi/Services/MaintenanceMiddleware.cs index c697a06a4..3f101e6fa 100644 --- a/webapi/Services/MaintenanceMiddleware.cs +++ b/webapi/Services/MaintenanceMiddleware.cs @@ -1,5 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. +using System.Collections.Generic; using System.Threading.Tasks; using CopilotChat.WebApi.Controllers; using CopilotChat.WebApi.Hubs; @@ -18,18 +19,23 @@ namespace CopilotChat.WebApi.Services; public class MaintenanceMiddleware { private readonly RequestDelegate _next; + private readonly IReadOnlyList _actions; private readonly IOptions _serviceOptions; private readonly IHubContext _messageRelayHubContext; private readonly ILogger _logger; + private bool? _isInMaintenance; + public MaintenanceMiddleware( RequestDelegate next, + IReadOnlyList actions, IOptions servicetOptions, IHubContext messageRelayHubContext, ILogger logger) { this._next = next; + this._actions = actions; this._serviceOptions = servicetOptions; this._messageRelayHubContext = messageRelayHubContext; this._logger = logger; @@ -37,11 +43,31 @@ public MaintenanceMiddleware( public async Task Invoke(HttpContext ctx, IKernel kernel) { + // Skip inspection if _isInMaintenance explicitly false. + if (this._isInMaintenance == null || this._isInMaintenance.Value) + { + // While the state is unknown or true, re-inspect on every request; once an inspection reports false, that result is kept and inspection stops. + this._isInMaintenance = await this.InspectMaintenanceActionAsync(); + } + + // In maintenance if actions say so or explicitly configured. NOTE(review): the condition that follows only reads the configured InMaintenance option; _isInMaintenance is not consulted - confirm intent. 
if (this._serviceOptions.Value.InMaintenance) { - await this._messageRelayHubContext.Clients.All.SendAsync(MaintenanceController.GlobalSiteMaintenance, "Site undergoing maintenance...").ConfigureAwait(false); + await this._messageRelayHubContext.Clients.All.SendAsync(MaintenanceController.GlobalSiteMaintenance, "Site undergoing maintenance..."); } await this._next(ctx); } + + private async Task InspectMaintenanceActionAsync() + { + bool inMaintenance = false; + + foreach (var action in this._actions) + { + inMaintenance |= await action.InvokeAsync(); + } + + return inMaintenance; + } } diff --git a/webapi/Services/MemoryMigration/ChatMemoryMigrationService.cs b/webapi/Services/MemoryMigration/ChatMemoryMigrationService.cs new file mode 100644 index 000000000..bce9677d6 --- /dev/null +++ b/webapi/Services/MemoryMigration/ChatMemoryMigrationService.cs @@ -0,0 +1,147 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using CopilotChat.WebApi.Extensions; +using CopilotChat.WebApi.Options; +using CopilotChat.WebApi.Storage; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; +using Microsoft.SemanticKernel.Memory; +using Microsoft.SemanticMemory; + +namespace CopilotChat.WebApi.Services.MemoryMigration; + +/// +/// Service implementation of . +/// +public class ChatMemoryMigrationService : IChatMemoryMigrationService +{ + private readonly ILogger _logger; + private readonly ISemanticTextMemory _memory; + private readonly ISemanticMemoryClient _memoryClient; + private readonly ChatSessionRepository _chatSessionRepository; + private readonly ChatMemorySourceRepository _memorySourceRepository; + private readonly string _globalIndex; + private readonly PromptsOptions _promptOptions; + + /// + /// Initializes a new instance of the class. 
+ /// + public ChatMemoryMigrationService( + ILogger logger, + IOptions documentMemoryOptions, + IOptions promptOptions, + ISemanticMemoryClient memoryClient, + ChatSessionRepository chatSessionRepository, + ChatMemorySourceRepository memorySourceRepository, + SemanticKernelProvider provider) + { + this._logger = logger; + this._promptOptions = promptOptions.Value; + this._memoryClient = memoryClient; + this._chatSessionRepository = chatSessionRepository; + this._memorySourceRepository = memorySourceRepository; + this._globalIndex = documentMemoryOptions.Value.GlobalDocumentCollectionName; + var kernel = provider.GetMigrationKernel(); + this._memory = kernel.Memory; + } + + /// + public async Task MigrateAsync(CancellationToken cancellationToken = default) + { + try + { + await this.InternalMigrateAsync(cancellationToken); + } + catch (Exception exception) when (!exception.IsCriticalException()) + { + this._logger.LogError(exception, "Error migrating chat memories"); + } + } + + private async Task InternalMigrateAsync(CancellationToken cancellationToken = default) + { + var collectionNames = (await this._memory.GetCollectionsAsync(cancellationToken)).ToHashSet(StringComparer.OrdinalIgnoreCase); + + var tokenMemory = await GetTokenMemory(cancellationToken); + if (tokenMemory != null) + { + // Create memory token already exists + return; + } + + // Create memory token + var token = Guid.NewGuid().ToString(); + await SetTokenMemory(token, cancellationToken); + + await RemoveMemorySourcesAsync(); + + bool needsZombie = true; + // Extract and store memories, using the original id to avoid duplication should a retry be required. 
+ await foreach ((string chatId, string memoryName, string memoryId, string memoryText) in QueryMemoriesAsync()) + { + await this._memoryClient.StoreMemoryAsync(this._promptOptions.MemoryIndexName, chatId, memoryName, memoryId, memoryText, cancellationToken); + needsZombie = false; + } + + // Store "Zombie" memory in order to create the index since zero writes have occurred. Won't affect any chats. + if (needsZombie) + { + await this._memoryClient.StoreMemoryAsync(this._promptOptions.MemoryIndexName, Guid.Empty.ToString(), "zombie", Guid.NewGuid().ToString(), "Initialized", cancellationToken); + } + + await SetTokenMemory(ChatMigrationMonitor.MigrationCompletionToken, cancellationToken); + + // Inline function to extract all memories for a given chat and memory type. + async IAsyncEnumerable<(string chatId, string memoryName, string memoryId, string memoryText)> QueryMemoriesAsync() + { + var chats = await this._chatSessionRepository.GetAllChatsAsync(); + foreach (var chat in chats) + { + foreach (var memoryType in this._promptOptions.MemoryMap.Keys) + { + var indexName = $"{chat.Id}-{memoryType}"; + if (collectionNames.Contains(indexName)) + { + var memories = await this._memory.SearchAsync(indexName, "*", limit: 10000, minRelevanceScore: 0, withEmbeddings: false, cancellationToken).ToArrayAsync(cancellationToken); + + foreach (var memory in memories) + { + yield return (chat.Id, memoryType, memory.Metadata.Id, memory.Metadata.Text); + } + } + } + } + } + + // Inline function to read the token memory + async Task GetTokenMemory(CancellationToken cancellationToken) + { + try + { + return await this._memory.GetAsync(this._globalIndex, ChatMigrationMonitor.MigrationKey, withEmbedding: false, cancellationToken); + } + catch (Exception ex) when (!ex.IsCriticalException()) + { + return null; + } + } + + // Inline function to write the token memory + async Task SetTokenMemory(string token, CancellationToken cancellationToken) + { + await 
this._memory.SaveInformationAsync(this._globalIndex, token, ChatMigrationMonitor.MigrationKey, description: null, additionalMetadata: null, cancellationToken); + } + + async Task RemoveMemorySourcesAsync() + { + var documentMemories = await this._memorySourceRepository.GetAllAsync(); + + await Task.WhenAll(documentMemories.Select(memory => this._memorySourceRepository.DeleteAsync(memory))); + } + } +} diff --git a/webapi/Services/MemoryMigration/ChatMigrationMaintenanceAction.cs b/webapi/Services/MemoryMigration/ChatMigrationMaintenanceAction.cs new file mode 100644 index 000000000..798dc06e7 --- /dev/null +++ b/webapi/Services/MemoryMigration/ChatMigrationMaintenanceAction.cs @@ -0,0 +1,48 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.Logging; + +namespace CopilotChat.WebApi.Services.MemoryMigration; + +/// +/// Middleware action to handle memory migration maintenance. +/// +public class ChatMigrationMaintenanceAction : IMaintenanceAction +{ + private readonly IChatMigrationMonitor _migrationMonitor; + private readonly IChatMemoryMigrationService _migrationService; + private readonly ILogger _logger; + + public ChatMigrationMaintenanceAction( + IChatMigrationMonitor migrationMonitor, + IChatMemoryMigrationService migrationService, + ILogger logger) + + { + this._migrationMonitor = migrationMonitor; + this._migrationService = migrationService; + this._logger = logger; + } + + public async Task InvokeAsync(CancellationToken cancellation = default) + { + var migrationStatus = await this._migrationMonitor.GetCurrentStatusAsync(cancellation); + + switch (migrationStatus) + { + case ChatMigrationStatus s when (s == ChatMigrationStatus.RequiresUpgrade): + // Migrate all chats to single index (in background) + var task = this._migrationService.MigrateAsync(cancellation); + return true; // In maintenance + + case ChatMigrationStatus s when (s == ChatMigrationStatus.Upgrading): + 
return true; // In maintenance + + case ChatMigrationStatus s when (s == ChatMigrationStatus.None): + default: + return false; // No maintenance + } + } +} diff --git a/webapi/Services/MemoryMigration/ChatMigrationMonitor.cs b/webapi/Services/MemoryMigration/ChatMigrationMonitor.cs new file mode 100644 index 000000000..af5871de6 --- /dev/null +++ b/webapi/Services/MemoryMigration/ChatMigrationMonitor.cs @@ -0,0 +1,144 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using CopilotChat.WebApi.Options; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; +using Microsoft.SemanticKernel.Memory; + +namespace CopilotChat.WebApi.Services.MemoryMigration; + +/// +/// Service implementation of <see cref="IChatMigrationMonitor"/>. +/// +/// +/// Migration is fundamentally determined by presence of the new consolidated index. +/// That is, if the new index exists then migration is considered to have occurred. +/// A tracking record is created in the historical global-document index +/// to manage race conditions during the migration process (having migration triggered a second time while in progress). +/// In the event that somehow two migration processes are initiated in parallel, no duplication will result...only extraneous processing. +/// If the desire exists to reset/re-execute migration, simply delete the new index. +/// +public class ChatMigrationMonitor : IChatMigrationMonitor +{ + internal const string MigrationCompletionToken = "DONE"; + internal const string MigrationKey = "migrate-00000000-0000-0000-0000-000000000000"; + + private static ChatMigrationStatus? _cachedStatus; + private static bool? _hasCurrentIndex; + + private readonly ILogger _logger; + private readonly string _indexNameGlobalDocs; + private readonly string _indexNameAllMemory; + private readonly ISemanticTextMemory _memory; + + /// + /// Initializes a new instance of the <see cref="ChatMigrationMonitor"/> class. 
+ /// + public ChatMigrationMonitor( + ILogger logger, + IOptions docOptions, + IOptions promptOptions, + SemanticKernelProvider provider) + { + this._logger = logger; + this._indexNameGlobalDocs = docOptions.Value.GlobalDocumentCollectionName; + this._indexNameAllMemory = promptOptions.Value.MemoryIndexName; + var kernel = provider.GetMigrationKernel(); + this._memory = kernel.Memory; + } + + /// + public async Task GetCurrentStatusAsync(CancellationToken cancellationToken = default) + { + if (_cachedStatus == null) + { + // Attempt to determine migration status looking at index existence. (Once) + Interlocked.CompareExchange( + ref _cachedStatus, + await QueryCollectionAsync(), + null); + + if (_cachedStatus == null) + { + // Attempt to determine migration status looking at index state. + _cachedStatus = await QueryStatusAsync(); + } + } + else + { + // Refresh status if we have a cached value for any state other than: ChatVersionStatus.None. + switch (_cachedStatus) + { + case ChatMigrationStatus s when s != ChatMigrationStatus.None: + _cachedStatus = await QueryStatusAsync(); + break; + + default: // ChatVersionStatus.None + break; + } + } + + return _cachedStatus ?? ChatMigrationStatus.None; + + // Reports and caches migration state as either: None or null depending on existence of the target index. + async Task QueryCollectionAsync() + { + if (_hasCurrentIndex == null) + { + try + { + // Cache "found" index state to reduce query count and avoid handling truth mutation. + var collections = await this._memory.GetCollectionsAsync(cancellationToken); + + // Does the new "target" index already exist? + _hasCurrentIndex = collections.Any(c => c.Equals(this._indexNameAllMemory, StringComparison.OrdinalIgnoreCase)); + + return (_hasCurrentIndex ?? false) ? 
ChatMigrationStatus.None : null; + } + catch (Exception exception) when (!exception.IsCriticalException()) + { + this._logger.LogError(exception, "Unable to search collections"); + } + } + + return (_hasCurrentIndex ?? false) ? ChatMigrationStatus.None : null; + } + + // Note: Only called once determined that target index does not exist. + async Task QueryStatusAsync() + { + try + { + var result = + await this._memory.SearchAsync( + this._indexNameGlobalDocs, + MigrationKey, + limit: 1, + minRelevanceScore: -1, + withEmbeddings: false, + cancellationToken) + .SingleOrDefaultAsync(cancellationToken) + ; + + if (result == null) + { + // No migration token + return ChatMigrationStatus.RequiresUpgrade; + } + + var isDone = MigrationCompletionToken.Equals(result.Metadata.Text, StringComparison.OrdinalIgnoreCase); + + return isDone ? ChatMigrationStatus.None : ChatMigrationStatus.Upgrading; + } + catch (Exception exception) when (!exception.IsCriticalException()) + { + this._logger.LogWarning("Failure searching collections: {0}\n{1}", this._indexNameGlobalDocs, exception.Message); + return ChatMigrationStatus.RequiresUpgrade; + } + } + } +} diff --git a/webapi/Services/MemoryMigration/ChatMigrationStatus.cs b/webapi/Services/MemoryMigration/ChatMigrationStatus.cs new file mode 100644 index 000000000..190505768 --- /dev/null +++ b/webapi/Services/MemoryMigration/ChatMigrationStatus.cs @@ -0,0 +1,37 @@ +// Copyright (c) Microsoft. All rights reserved. + +namespace CopilotChat.WebApi.Services.MemoryMigration; + +/// +/// Set of migration states/status for chat memory migration. +/// +/// +/// Interlocked.CompareExchange doesn't work with enums. +/// +public sealed class ChatMigrationStatus +{ + /// + /// Represents state where no migration is required or in progress. + /// + public static ChatMigrationStatus None { get; } = new ChatMigrationStatus(nameof(None)); + + /// + /// Represents state where migration is required but has not yet started. 
+ /// + public static ChatMigrationStatus RequiresUpgrade { get; } = new ChatMigrationStatus(nameof(RequiresUpgrade)); + + /// + /// Represents state where migration is in progress. + /// + public static ChatMigrationStatus Upgrading { get; } = new ChatMigrationStatus(nameof(Upgrading)); + + /// + /// The state label (no functional impact, but helps debugging). + /// + public string Label { get; } + + private ChatMigrationStatus(string label) + { + this.Label = label; + } +} diff --git a/webapi/Services/MemoryMigration/IChatMemoryMigrationService.cs b/webapi/Services/MemoryMigration/IChatMemoryMigrationService.cs new file mode 100644 index 000000000..0601b3914 --- /dev/null +++ b/webapi/Services/MemoryMigration/IChatMemoryMigrationService.cs @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Threading; +using System.Threading.Tasks; + +namespace CopilotChat.WebApi.Services.MemoryMigration; + +/// +/// Defines contract for migrating chat memory. +/// +public interface IChatMemoryMigrationService +{ + /// + /// Migrates all non-document memory to the semantic-memory index. + /// Subsequent/redundant migration is non-destructive/no-impact to the migrated index. + /// + Task MigrateAsync(CancellationToken cancellationToken = default); +} diff --git a/webapi/Services/MemoryMigration/IChatMigrationMonitor.cs b/webapi/Services/MemoryMigration/IChatMigrationMonitor.cs new file mode 100644 index 000000000..391f7558e --- /dev/null +++ b/webapi/Services/MemoryMigration/IChatMigrationMonitor.cs @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Threading; +using System.Threading.Tasks; + +namespace CopilotChat.WebApi.Services.MemoryMigration; + +/// +/// Contract for monitoring the status of chat memory migration. +/// +public interface IChatMigrationMonitor +{ + /// + /// Inspects the current state of affairs to determine the chat migration status. 
+ /// + Task GetCurrentStatusAsync(CancellationToken cancellationToken = default); +} diff --git a/webapi/Services/SemanticKernelProvider.cs b/webapi/Services/SemanticKernelProvider.cs new file mode 100644 index 000000000..be4dc682d --- /dev/null +++ b/webapi/Services/SemanticKernelProvider.cs @@ -0,0 +1,191 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Net.Http; +using CopilotChat.WebApi.Options; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.AI.Embeddings; +using Microsoft.SemanticKernel.Connectors.Memory.AzureCognitiveSearch; +using Microsoft.SemanticKernel.Connectors.Memory.Qdrant; +using Microsoft.SemanticKernel.Memory; +using Microsoft.SemanticMemory; +using Microsoft.SemanticMemory.MemoryStorage.Qdrant; + +namespace CopilotChat.WebApi.Services; + +/// +/// Builds semantic-kernel instances from the injected service provider and configuration. +/// +public sealed class SemanticKernelProvider +{ + private readonly IServiceProvider _serviceProvider; + private readonly IConfiguration _configuration; + + public SemanticKernelProvider(IServiceProvider serviceProvider, IConfiguration configuration) + { + this._serviceProvider = serviceProvider; + this._configuration = configuration; + } + + /// + /// Produce semantic-kernel with only completion services for chat. + /// + public IKernel GetCompletionKernel() + { + var builder = Kernel.Builder.WithLoggerFactory(this._serviceProvider.GetRequiredService()); + + this.WithCompletionBackend(builder); + + return builder.Build(); + } + + /// + /// Produce semantic-kernel with only completion services for planner. 
+ /// + public IKernel GetPlannerKernel() + { + var builder = Kernel.Builder.WithLoggerFactory(this._serviceProvider.GetRequiredService()); + + this.WithPlannerBackend(builder); + + return builder.Build(); + } + + /// + /// Produce semantic-kernel with semantic-memory. + /// + public IKernel GetMigrationKernel() + { + var builder = Kernel.Builder.WithLoggerFactory(this._serviceProvider.GetRequiredService()); + + this.WithEmbeddingBackend(builder); + this.WithSemanticTextMemory(builder); + + return builder.Build(); + } + + /// + /// Add the completion backend to the kernel config + /// + private KernelBuilder WithCompletionBackend(KernelBuilder kernelBuilder) + { + var memoryOptions = this._serviceProvider.GetRequiredService>().Value; + + switch (memoryOptions.TextGeneratorType) + { + case string x when x.Equals("AzureOpenAI", StringComparison.OrdinalIgnoreCase): + case string y when y.Equals("AzureOpenAIText", StringComparison.OrdinalIgnoreCase): + var azureAIOptions = memoryOptions.GetServiceConfig(this._configuration, "AzureOpenAIText"); + return kernelBuilder.WithAzureChatCompletionService(azureAIOptions.Deployment, azureAIOptions.Endpoint, azureAIOptions.APIKey); + + case string x when x.Equals("OpenAI", StringComparison.OrdinalIgnoreCase): + var openAIOptions = memoryOptions.GetServiceConfig(this._configuration, "OpenAI"); + return kernelBuilder.WithOpenAIChatCompletionService(openAIOptions.TextModel, openAIOptions.APIKey); + + default: + throw new ArgumentException($"Invalid {nameof(memoryOptions.TextGeneratorType)} value in 'SemanticMemory' settings."); + } + } + + /// + /// Add the completion backend to the kernel config for the planner. 
+ /// + private KernelBuilder WithPlannerBackend(KernelBuilder kernelBuilder) + { + var memoryOptions = this._serviceProvider.GetRequiredService>().Value; + var plannerOptions = this._serviceProvider.GetRequiredService>().Value; + + switch (memoryOptions.TextGeneratorType) + { + case string x when x.Equals("AzureOpenAI", StringComparison.OrdinalIgnoreCase): + case string y when y.Equals("AzureOpenAIText", StringComparison.OrdinalIgnoreCase): + var azureAIOptions = memoryOptions.GetServiceConfig(this._configuration, "AzureOpenAIText"); + return kernelBuilder.WithAzureChatCompletionService(plannerOptions.Model, azureAIOptions.Endpoint, azureAIOptions.APIKey); + + case string x when x.Equals("OpenAI", StringComparison.OrdinalIgnoreCase): + var openAIOptions = memoryOptions.GetServiceConfig(this._configuration, "OpenAI"); + return kernelBuilder.WithOpenAIChatCompletionService(plannerOptions.Model, openAIOptions.APIKey); + + default: + throw new ArgumentException($"Invalid {nameof(memoryOptions.TextGeneratorType)} value in 'SemanticMemory' settings."); + } + } + + /// + /// Add the embedding backend to the kernel config + /// + private KernelBuilder WithEmbeddingBackend(KernelBuilder kernelBuilder) + { + var memoryOptions = this._serviceProvider.GetRequiredService>().Value; + + switch (memoryOptions.Retrieval.EmbeddingGeneratorType) + { + case string x when x.Equals("AzureOpenAI", StringComparison.OrdinalIgnoreCase): + case string y when y.Equals("AzureOpenAIEmbedding", StringComparison.OrdinalIgnoreCase): + var azureAIOptions = memoryOptions.GetServiceConfig(this._configuration, "AzureOpenAIEmbedding"); + return kernelBuilder.WithAzureTextEmbeddingGenerationService(azureAIOptions.Deployment, azureAIOptions.Endpoint, azureAIOptions.APIKey); + + case string x when x.Equals("OpenAI", StringComparison.OrdinalIgnoreCase): + var openAIOptions = memoryOptions.GetServiceConfig(this._configuration, "OpenAI"); + return 
kernelBuilder.WithOpenAITextEmbeddingGenerationService(openAIOptions.EmbeddingModel, openAIOptions.APIKey); + + default: + throw new ArgumentException($"Invalid {nameof(memoryOptions.Retrieval.EmbeddingGeneratorType)} value in 'SemanticMemory' settings."); + } + } + + /// + /// Add the semantic text memory. + /// + private void WithSemanticTextMemory(KernelBuilder builder) + { + var memoryOptions = this._serviceProvider.GetRequiredService>().Value; + + IMemoryStore memoryStore = CreateMemoryStore(); + +#pragma warning disable CA2000 // Ownership passed to kernel + builder.WithMemory( + new SemanticTextMemory( + memoryStore, + this._serviceProvider.GetRequiredService())); +#pragma warning restore CA2000 // Ownership passed to kernel + + IMemoryStore CreateMemoryStore() + { + switch (memoryOptions.Retrieval.VectorDbType) + { + case string x when x.Equals("SimpleVectorDb", StringComparison.OrdinalIgnoreCase): + return new VolatileMemoryStore(); + + case string x when x.Equals("Qdrant", StringComparison.OrdinalIgnoreCase): + var qdrantConfig = memoryOptions.GetServiceConfig(this._configuration, "Qdrant"); + +#pragma warning disable CA2000 // Ownership passed to QdrantMemoryStore + HttpClient httpClient = new(new HttpClientHandler { CheckCertificateRevocationList = true }); +#pragma warning restore CA2000 // Ownership passed to QdrantMemoryStore + if (!string.IsNullOrWhiteSpace(qdrantConfig.APIKey)) + { + httpClient.DefaultRequestHeaders.Add("api-key", qdrantConfig.APIKey); + } + + return + new QdrantMemoryStore( + httpClient: httpClient, + 1536, + qdrantConfig.Endpoint, + loggerFactory: this._serviceProvider.GetRequiredService()); + + case string x when x.Equals("AzureCognitiveSearch", StringComparison.OrdinalIgnoreCase): + var acsConfig = memoryOptions.GetServiceConfig(this._configuration, "AzureCognitiveSearch"); + return new AzureCognitiveSearchMemoryStore(acsConfig.Endpoint, acsConfig.APIKey); + + default: + throw new InvalidOperationException($"Invalid 
'VectorDbType' type '{memoryOptions.Retrieval.VectorDbType}'."); + } + } + } +} diff --git a/webapi/Storage/ChatMemorySourceRepository.cs b/webapi/Storage/ChatMemorySourceRepository.cs index 7444edc2c..427a701e6 100644 --- a/webapi/Storage/ChatMemorySourceRepository.cs +++ b/webapi/Storage/ChatMemorySourceRepository.cs @@ -41,4 +41,13 @@ public Task> FindByNameAsync(string name) { return base.StorageContext.QueryEntitiesAsync(e => e.Name.Equals(name, StringComparison.OrdinalIgnoreCase)); } + + /// + /// Retrieves all memory sources. + /// + /// A list of memory sources. + public Task> GetAllAsync() + { + return base.StorageContext.QueryEntitiesAsync(e => true); + } } diff --git a/webapi/Storage/ChatSessionRepository.cs b/webapi/Storage/ChatSessionRepository.cs index 3875fb051..726b50d2e 100644 --- a/webapi/Storage/ChatSessionRepository.cs +++ b/webapi/Storage/ChatSessionRepository.cs @@ -1,5 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. +using System.Collections.Generic; +using System.Threading.Tasks; using CopilotChat.WebApi.Models.Storage; namespace CopilotChat.WebApi.Storage; @@ -17,4 +19,13 @@ public ChatSessionRepository(IStorageContext storageContext) : base(storageContext) { } + + /// + /// Retrieves all chat sessions. + /// + /// A list of ChatMessages. + public Task> GetAllChatsAsync() + { + return base.StorageContext.QueryEntitiesAsync(e => true); + } }