Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
fe6c77c
Migration POC
crickman Sep 6, 2023
3dca171
Merge branch 'feature-semantic-memory' into feature-semantic-memory-m…
crickman Sep 6, 2023
b9d9455
Typo cleanup
crickman Sep 6, 2023
348fe22
Merge branch 'feature-semantic-memory-migration' of https://github.co…
crickman Sep 6, 2023
9ed3b86
Namespace cleanup (dotnet-format)
crickman Sep 6, 2023
9285c96
Clean-up and consolidation
crickman Sep 6, 2023
834802b
Spell-check
crickman Sep 6, 2023
9230fc8
Update migration route
crickman Sep 6, 2023
0b3172c
Merge from feature branch
crickman Sep 12, 2023
8d73a0e
Merged to general maintenance hook
crickman Sep 12, 2023
29dd9ec
Merge/refactor check-point
crickman Sep 12, 2023
c1f3ea3
Merge branch 'feature-semantic-memory' into feature-semantic-memory-m…
crickman Sep 12, 2023
ac34017
Checkpoint
crickman Sep 15, 2023
b8a8063
Merged
crickman Sep 15, 2023
1d4a307
Spelling
crickman Sep 15, 2023
7e8e655
Update from feature branch
crickman Sep 15, 2023
30cd475
Typos and rename
crickman Sep 15, 2023
d51f5dc
Refactor kernel initialization
crickman Sep 15, 2023
c241714
Spelling
crickman Sep 15, 2023
51119c6
Namespace clean-up
crickman Sep 15, 2023
b59796d
Namespace order
crickman Sep 15, 2023
81dc4ed
Bot Fix
crickman Sep 16, 2023
7b1d81f
Fixes and tweaks
crickman Sep 16, 2023
721cef7
Functional testing
crickman Sep 19, 2023
de044ad
Merged from feature branch
crickman Sep 19, 2023
1cf4bf4
Clean-up
crickman Sep 19, 2023
0221254
Typo
crickman Sep 19, 2023
e5fb8b2
Typo
crickman Sep 19, 2023
3a1c694
Whitespace
crickman Sep 19, 2023
45e2591
Merge from main/feature branch
crickman Sep 19, 2023
8b07967
Whitespace
crickman Sep 19, 2023
0b4a960
Merge from feature branch
crickman Sep 19, 2023
1dd2dd4
Merge from feature branch
crickman Sep 19, 2023
daae03f
Remove edit
crickman Sep 19, 2023
dd59889
Remove "ConfigureAwait"
crickman Sep 19, 2023
d113308
Comment
crickman Sep 19, 2023
bf51fc0
Remove try/catch
crickman Sep 19, 2023
f3b116c
Namespace
crickman Sep 19, 2023
5832041
Merge branch 'feature-semantic-memory' into feature-semantic-memory-m…
crickman Sep 20, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 25 additions & 4 deletions webapi/Controllers/MaintenanceController.cs
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Threading;
using System.Threading.Tasks;
using CopilotChat.WebApi.Auth;
using CopilotChat.WebApi.Hubs;
using CopilotChat.WebApi.Models.Response;
using CopilotChat.WebApi.Options;
using CopilotChat.WebApi.Services.MemoryMigration;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Mvc;
using Microsoft.AspNetCore.SignalR;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using Microsoft.SemanticKernel;

namespace CopilotChat.WebApi.Controllers;

Expand Down Expand Up @@ -46,14 +47,34 @@ public MaintenanceController(
[HttpGet]
[ProducesResponseType(StatusCodes.Status200OK)]
[ProducesResponseType(StatusCodes.Status400BadRequest)]
public ActionResult<MaintenanceResult?> GetMaintenanceStatus(
[FromServices] IKernel kernel,
public async Task<ActionResult<MaintenanceResult?>> GetMaintenanceStatusAsync(
[FromServices] IChatMigrationMonitor migrationMonitor,
[FromServices] IHubContext<MessageRelayHub> messageRelayHubContext,
CancellationToken cancellationToken = default)
{
MaintenanceResult? result = null;

var migrationStatus = await migrationMonitor.GetCurrentStatusAsync(cancellationToken);

if (migrationStatus != ChatMigrationStatus.None)
{
result =
new MaintenanceResult
{
Title = "Migrating Chat Memory",
Message = "An upgrade requires that all non-document memories be migrated. This may take several minutes...",
Note = "Note: All document memories will need to be re-imported.",
};
}

if (this._serviceOptions.Value.InMaintenance)
{
return this.Ok(new MaintenanceResult());
result = new MaintenanceResult(); // Default maintenance message
}

if (result != null)
{
return this.Ok(result);
}

return this.Ok();
Expand Down
24 changes: 17 additions & 7 deletions webapi/Extensions/ISemanticMemoryClientExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,7 @@ await memoryClient.SearchAsync(
indexName,
filter,
resultCount,
cancellationToken)
.ConfigureAwait(false);
cancellationToken);

return searchResult;
}
Expand Down Expand Up @@ -129,11 +128,23 @@ public static async Task StoreDocumentAsync(
await memoryClient.ImportDocumentAsync(uploadRequest, cancellationToken);
}

public static Task StoreMemoryAsync(
this ISemanticMemoryClient memoryClient,
string indexName,
string chatId,
string memoryName,
string memory,
CancellationToken cancellationToken = default)
{
return memoryClient.StoreMemoryAsync(indexName, chatId, memoryName, memoryId: Guid.NewGuid().ToString(), memory, cancellationToken);
}

public static async Task StoreMemoryAsync(
this ISemanticMemoryClient memoryClient,
string indexName,
string chatId,
string memoryName,
string memoryId,
string memory,
CancellationToken cancellationToken = default)
{
Expand All @@ -143,10 +154,9 @@ public static async Task StoreMemoryAsync(
await writer.FlushAsync();
stream.Position = 0;

var id = Guid.NewGuid().ToString();
var uploadRequest = new DocumentUploadRequest
{
DocumentId = id,
DocumentId = memoryId,
Index = indexName,
Files =
new()
Expand All @@ -166,12 +176,12 @@ public static async Task RemoveChatMemoriesAsync(
this ISemanticMemoryClient memoryClient,
string indexName,
string chatId,
CancellationToken cancelToken = default)
CancellationToken cancellationToken = default)
{
var memories = await memoryClient.SearchMemoryAsync(indexName, "*", 0.0F, chatId, cancellationToken: cancelToken);
var memories = await memoryClient.SearchMemoryAsync(indexName, "*", 0.0F, chatId, cancellationToken: cancellationToken);
foreach (var memory in memories.Results)
{
await memoryClient.DeleteDocumentAsync(indexName, memory.Link, cancelToken);
await memoryClient.DeleteDocumentAsync(indexName, memory.Link, cancellationToken);
}
}
}
97 changes: 13 additions & 84 deletions webapi/Extensions/SemanticKernelExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.AI.Embeddings;
using Microsoft.SemanticKernel.Connectors.AI.OpenAI.TextEmbedding;
using Microsoft.SemanticKernel.Diagnostics;
using Microsoft.SemanticKernel.Skills.Core;
using Microsoft.SemanticMemory;
Expand All @@ -39,14 +37,14 @@ internal static class SemanticKernelExtensions
/// </summary>
public static WebApplicationBuilder AddSemanticKernelServices(this WebApplicationBuilder builder)
{
builder.InitializeKernelProvider();

// Semantic Kernel
builder.Services.AddScoped<IKernel>(
sp =>
{
var kernel = Kernel.Builder
.WithLoggerFactory(sp.GetRequiredService<ILoggerFactory>())
.WithCompletionBackend(sp, builder.Configuration)
.Build();
var provider = sp.GetRequiredService<SemanticKernelProvider>();
var kernel = provider.GetCompletionKernel();

sp.GetRequiredService<RegisterSkillsWithKernel>()(sp, kernel);
return kernel;
Expand All @@ -66,16 +64,16 @@ public static WebApplicationBuilder AddSemanticKernelServices(this WebApplicatio
/// </summary>
public static WebApplicationBuilder AddPlannerServices(this WebApplicationBuilder builder)
{
builder.InitializeKernelProvider();

builder.Services.AddScoped<CopilotChatPlanner>(sp =>
{
sp.WithBotConfig(builder.Configuration);

var plannerOptions = sp.GetRequiredService<IOptions<PlannerOptions>>();

var plannerKernel = Kernel.Builder
.WithLoggerFactory(sp.GetRequiredService<ILoggerFactory>())
.WithPlannerBackend(sp, builder.Configuration)
.Build();
var provider = sp.GetRequiredService<SemanticKernelProvider>();
var plannerKernel = provider.GetPlannerKernel();

return new CopilotChatPlanner(plannerKernel, plannerOptions?.Value, sp.GetRequiredService<ILogger<CopilotChatPlanner>>());
});
Expand Down Expand Up @@ -139,6 +137,11 @@ public static async Task<T> SafeInvokeAsync<T>(Func<Task<T>> callback, string fu
}
}

private static void InitializeKernelProvider(this WebApplicationBuilder builder)
{
builder.Services.AddSingleton(sp => new SemanticKernelProvider(sp, builder.Configuration));
}

/// <summary>
/// Register the skills with the kernel.
/// </summary>
Expand Down Expand Up @@ -185,80 +188,6 @@ internal static void AddContentSafety(this IServiceCollection services)
}
}

/// <summary>
/// Add the completion backend to the kernel config
/// </summary>
private static KernelBuilder WithCompletionBackend(this KernelBuilder kernelBuilder, IServiceProvider provider, IConfiguration configuration)
{
var memoryOptions = provider.GetRequiredService<IOptions<SemanticMemoryConfig>>().Value;

switch (memoryOptions.TextGeneratorType)
{
case string x when x.Equals("AzureOpenAI", StringComparison.OrdinalIgnoreCase):
case string y when y.Equals("AzureOpenAIText", StringComparison.OrdinalIgnoreCase):
var azureAIOptions = memoryOptions.GetServiceConfig<AzureOpenAIConfig>(configuration, "AzureOpenAIText");
return kernelBuilder.WithAzureChatCompletionService(azureAIOptions.Deployment, azureAIOptions.Endpoint, azureAIOptions.APIKey);

case string x when x.Equals("OpenAI", StringComparison.OrdinalIgnoreCase):
var openAIOptions = memoryOptions.GetServiceConfig<OpenAIConfig>(configuration, "OpenAI");
return kernelBuilder.WithOpenAIChatCompletionService(openAIOptions.TextModel, openAIOptions.APIKey);

default:
throw new ArgumentException($"Invalid {nameof(memoryOptions.TextGeneratorType)} value in 'SemanticMemory' settings.");
}
}

/// <summary>
/// Add the completion backend to the kernel config for the planner.
/// </summary>
private static KernelBuilder WithPlannerBackend(this KernelBuilder kernelBuilder, IServiceProvider provider, IConfiguration configuration)
{
var memoryOptions = provider.GetRequiredService<IOptions<SemanticMemoryConfig>>().Value;
var plannerOptions = provider.GetRequiredService<IOptions<PlannerOptions>>().Value;

switch (memoryOptions.TextGeneratorType)
{
case string x when x.Equals("AzureOpenAI", StringComparison.OrdinalIgnoreCase):
case string y when y.Equals("AzureOpenAIText", StringComparison.OrdinalIgnoreCase):
var azureAIOptions = memoryOptions.GetServiceConfig<AzureOpenAIConfig>(configuration, "AzureOpenAIText");
return kernelBuilder.WithAzureChatCompletionService(plannerOptions.Model, azureAIOptions.Endpoint, azureAIOptions.APIKey);

case string x when x.Equals("OpenAI", StringComparison.OrdinalIgnoreCase):
var openAIOptions = memoryOptions.GetServiceConfig<OpenAIConfig>(configuration, "OpenAI");
return kernelBuilder.WithOpenAIChatCompletionService(plannerOptions.Model, openAIOptions.APIKey);

default:
throw new ArgumentException($"Invalid {nameof(memoryOptions.TextGeneratorType)} value in 'SemanticMemory' settings.");
}
}

/// <summary>
/// Construct IEmbeddingGeneration from <see cref="AIServiceOptions"/>
/// </summary>
private static ITextEmbeddingGeneration ToTextEmbeddingsService(
this IServiceProvider provider,
IConfiguration configuration,
ILoggerFactory? loggerFactory = null)
{
var logger = provider.GetRequiredService<ILogger<ITextEmbeddingGeneration>>();
var memoryOptions = provider.GetRequiredService<IOptions<SemanticMemoryConfig>>().Value;

switch (memoryOptions.Retrieval.EmbeddingGeneratorType)
{
case string x when x.Equals("AzureOpenAI", StringComparison.OrdinalIgnoreCase):
case string y when y.Equals("AzureOpenAIEmbedding", StringComparison.OrdinalIgnoreCase):
var azureAIOptions = memoryOptions.GetServiceConfig<AzureOpenAIConfig>(configuration, "AzureOpenAIEmbedding");
return new AzureTextEmbeddingGeneration(azureAIOptions.Deployment, azureAIOptions.Endpoint, azureAIOptions.APIKey, httpClient: null, loggerFactory);

case string x when x.Equals("OpenAI", StringComparison.OrdinalIgnoreCase):
var openAIOptions = memoryOptions.GetServiceConfig<OpenAIConfig>(configuration, "OpenAI");
return new OpenAITextEmbeddingGeneration(openAIOptions.EmbeddingModel, openAIOptions.APIKey, organization: null, httpClient: null, loggerFactory);

default:
throw new ArgumentException($"Invalid {nameof(memoryOptions.Retrieval.EmbeddingGeneratorType)} value in 'SemanticMemory' settings.");
}
}

/// <summary>
/// Get the embedding model from the configuration.
/// </summary>
Expand Down
21 changes: 21 additions & 0 deletions webapi/Extensions/ServiceExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
using CopilotChat.WebApi.Auth;
using CopilotChat.WebApi.Models.Storage;
using CopilotChat.WebApi.Options;
using CopilotChat.WebApi.Services;
using CopilotChat.WebApi.Services.MemoryMigration;
using CopilotChat.WebApi.Storage;
using CopilotChat.WebApi.Utilities;
using Microsoft.AspNetCore.Authentication;
Expand Down Expand Up @@ -87,6 +89,25 @@ internal static IServiceCollection AddUtilities(this IServiceCollection services
return services.AddScoped<AskConverter>();
}

internal static IServiceCollection AddMainetnanceServices(this IServiceCollection services)
{
// Inject migration services
services.AddSingleton<IChatMigrationMonitor, ChatMigrationMonitor>();
services.AddSingleton<IChatMemoryMigrationService, ChatMemoryMigrationService>();

// Inject actions so they can be part of the action-list.
services.AddSingleton<ChatMigrationMaintenanceAction>();
services.AddSingleton<IReadOnlyList<IMaintenanceAction>>(
sp =>
(IReadOnlyList<IMaintenanceAction>)
new[]
{
sp.GetRequiredService<ChatMigrationMaintenanceAction>(),
});

return services;
}

/// <summary>
/// Add CORS settings.
/// </summary>
Expand Down
1 change: 1 addition & 0 deletions webapi/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ public static async Task Main(string[] args)

// Add in the rest of the services.
builder.Services
.AddMainetnanceServices()
.AddEndpointsApiExplorer()
.AddSwaggerGen()
.AddCorsPolicy(builder.Configuration)
Expand Down
19 changes: 19 additions & 0 deletions webapi/Services/IMaintenanceAction.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Threading;
using System.Threading.Tasks;

namespace CopilotChat.WebApi.Services;

/// <summary>
/// Defines discrete maintenance action responsible for both inspecting state
/// and performing maintenance.
/// </summary>
public interface IMaintenanceAction
{
/// <summary>
/// Calling site to initiate maintenance action.
/// </summary>
/// <returns>true if maintenance needed or in progress</returns>
Task<bool> InvokeAsync(CancellationToken cancellation = default);
}
28 changes: 27 additions & 1 deletion webapi/Services/MaintenanceMiddleware.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Collections.Generic;
using System.Threading.Tasks;
using CopilotChat.WebApi.Controllers;
using CopilotChat.WebApi.Hubs;
Expand All @@ -18,30 +19,55 @@ namespace CopilotChat.WebApi.Services;
public class MaintenanceMiddleware
{
private readonly RequestDelegate _next;
private readonly IReadOnlyList<IMaintenanceAction> _actions;
private readonly IOptions<ServiceOptions> _serviceOptions;
private readonly IHubContext<MessageRelayHub> _messageRelayHubContext;
private readonly ILogger<MaintenanceMiddleware> _logger;

private bool? _isInMaintenance;

public MaintenanceMiddleware(
RequestDelegate next,
IReadOnlyList<IMaintenanceAction> actions,
IOptions<ServiceOptions> servicetOptions,
IHubContext<MessageRelayHub> messageRelayHubContext,
ILogger<MaintenanceMiddleware> logger)

{
this._next = next;
this._actions = actions;
this._serviceOptions = servicetOptions;
this._messageRelayHubContext = messageRelayHubContext;
this._logger = logger;
}

public async Task Invoke(HttpContext ctx, IKernel kernel)
{
// Skip inspection if _isInMaintenance explicitly false.
if (this._isInMaintenance == null || this._isInMaintenance.Value)
{
// Maintenance never false => true; always true => false or just false;
this._isInMaintenance = await this.InspectMaintenanceActionAsync();
}

// In maintenance if actions say so or explicitly configured.
if (this._serviceOptions.Value.InMaintenance)
{
await this._messageRelayHubContext.Clients.All.SendAsync(MaintenanceController.GlobalSiteMaintenance, "Site undergoing maintenance...").ConfigureAwait(false);
await this._messageRelayHubContext.Clients.All.SendAsync(MaintenanceController.GlobalSiteMaintenance, "Site undergoing maintenance...");
}

await this._next(ctx);
}

private async Task<bool> InspectMaintenanceActionAsync()
{
bool inMaintenance = false;

foreach (var action in this._actions)
{
inMaintenance |= await action.InvokeAsync();
}

return inMaintenance;
}
}
Loading