Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using Azure.AI.OpenAI;
using Azure.Core;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.Connectors.AzureOpenAI;
using Microsoft.SemanticKernel.TextGeneration;

namespace SemanticKernel.Connectors.AzureOpenAI.UnitTests.Extensions;

/// <summary>
/// Unit tests for <see cref="AzureOpenAIServiceCollectionExtensions"/> class.
/// </summary>
public sealed class AzureOpenAIServiceCollectionExtensionsTests
{
#region Chat completion

[Theory]
[InlineData(InitializationType.ApiKey)]
[InlineData(InitializationType.TokenCredential)]
[InlineData(InitializationType.OpenAIClientInline)]
[InlineData(InitializationType.OpenAIClientInServiceProvider)]
public void ServiceCollectionAddAzureOpenAIChatCompletionAddsValidService(InitializationType type)
{
// Arrange
var credentials = DelegatedTokenCredential.Create((_, _) => new AccessToken());
var client = new AzureOpenAIClient(new Uri("http://localhost"), "key");
var builder = Kernel.CreateBuilder();

builder.Services.AddSingleton(client);

// Act
IServiceCollection collection = type switch
{
InitializationType.ApiKey => builder.Services.AddAzureOpenAIChatCompletion("deployment-name", "https://endpoint", "api-key"),
InitializationType.TokenCredential => builder.Services.AddAzureOpenAIChatCompletion("deployment-name", "https://endpoint", credentials),
InitializationType.OpenAIClientInline => builder.Services.AddAzureOpenAIChatCompletion("deployment-name", client),
InitializationType.OpenAIClientInServiceProvider => builder.Services.AddAzureOpenAIChatCompletion("deployment-name"),
_ => builder.Services
};

// Assert
var chatCompletionService = builder.Build().GetRequiredService<IChatCompletionService>();
Assert.True(chatCompletionService is AzureOpenAIChatCompletionService);

var textGenerationService = builder.Build().GetRequiredService<ITextGenerationService>();
Assert.True(textGenerationService is AzureOpenAIChatCompletionService);
}

#endregion

public enum InitializationType
{
ApiKey,
TokenCredential,
OpenAIClientInline,
OpenAIClientInServiceProvider,
OpenAIClientEndpoint,
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using Azure.AI.OpenAI;
using Azure.Core;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.Connectors.AzureOpenAI;
using Microsoft.SemanticKernel.TextGeneration;

namespace SemanticKernel.Connectors.AzureOpenAI.UnitTests.Extensions;

/// <summary>
/// Unit tests for <see cref="AzureOpenAIServiceKernelBuilderExtensions"/> class.
/// </summary>
public sealed class AzureOpenAIServiceKernelBuilderExtensionsTests
{
#region Chat completion

[Theory]
[InlineData(InitializationType.ApiKey)]
[InlineData(InitializationType.TokenCredential)]
[InlineData(InitializationType.OpenAIClientInline)]
[InlineData(InitializationType.OpenAIClientInServiceProvider)]
public void KernelBuilderAddAzureOpenAIChatCompletionAddsValidService(InitializationType type)
{
// Arrange
var credentials = DelegatedTokenCredential.Create((_, _) => new AccessToken());
var client = new AzureOpenAIClient(new Uri("http://localhost"), "key");
var builder = Kernel.CreateBuilder();

builder.Services.AddSingleton(client);

// Act
builder = type switch
{
InitializationType.ApiKey => builder.AddAzureOpenAIChatCompletion("deployment-name", "https://endpoint", "api-key"),
InitializationType.TokenCredential => builder.AddAzureOpenAIChatCompletion("deployment-name", "https://endpoint", credentials),
InitializationType.OpenAIClientInline => builder.AddAzureOpenAIChatCompletion("deployment-name", client),
InitializationType.OpenAIClientInServiceProvider => builder.AddAzureOpenAIChatCompletion("deployment-name"),
_ => builder
};

// Assert
var chatCompletionService = builder.Build().GetRequiredService<IChatCompletionService>();
Assert.True(chatCompletionService is AzureOpenAIChatCompletionService);

var textGenerationService = builder.Build().GetRequiredService<ITextGenerationService>();
Assert.True(textGenerationService is AzureOpenAIChatCompletionService);
}

#endregion

public enum InitializationType
{
ApiKey,
TokenCredential,
OpenAIClientInline,
OpenAIClientInServiceProvider,
OpenAIClientEndpoint,
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.Connectors.AzureOpenAI;

namespace SemanticKernel.Connectors.AzureOpenAI.UnitTests;
namespace SemanticKernel.Connectors.AzureOpenAI.UnitTests.Extensions;
public class ChatHistoryExtensionsTests
{
[Fact]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ internal ClientCore(ILogger? logger = null)
private static Dictionary<string, object?> GetChatChoiceMetadata(OpenAIChatCompletion completions)
{
#pragma warning disable AOAI001 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed.
return new Dictionary<string, object?>(12)
return new Dictionary<string, object?>(8)
{
{ nameof(completions.Id), completions.Id },
{ nameof(completions.CreatedAt), completions.CreatedAt },
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,249 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Net.Http;
using Azure;
using Azure.AI.OpenAI;
using Azure.Core;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.Connectors.AzureOpenAI;
using Microsoft.SemanticKernel.Http;
using Microsoft.SemanticKernel.TextGeneration;

#pragma warning disable IDE0039 // Use local function

namespace Microsoft.SemanticKernel;

/// <summary>
/// Provides extension methods for <see cref="IServiceCollection"/> and related classes to configure Azure OpenAI connectors.
/// </summary>
public static class AzureOpenAIServiceCollectionExtensions
{
#region Chat Completion

/// <summary>
/// Adds the Azure OpenAI chat completion service to the list.
/// </summary>
/// <param name="builder">The <see cref="IKernelBuilder"/> instance to augment.</param>
/// <param name="deploymentName">Azure OpenAI deployment name, see https://learn.microsoft.com/azure/cognitive-services/openai/how-to/create-resource</param>
/// <param name="endpoint">Azure OpenAI deployment URL, see https://learn.microsoft.com/azure/cognitive-services/openai/quickstart</param>
/// <param name="apiKey">Azure OpenAI API key, see https://learn.microsoft.com/azure/cognitive-services/openai/quickstart</param>
/// <param name="serviceId">A local identifier for the given AI service</param>
/// <param name="modelId">Model identifier, see https://learn.microsoft.com/azure/cognitive-services/openai/quickstart</param>
/// <param name="httpClient">The HttpClient to use with this service.</param>
/// <returns>The same instance as <paramref name="builder"/>.</returns>
public static IKernelBuilder AddAzureOpenAIChatCompletion(
this IKernelBuilder builder,
string deploymentName,
string endpoint,
string apiKey,
string? serviceId = null,
string? modelId = null,
HttpClient? httpClient = null)
{
Verify.NotNull(builder);
Verify.NotNullOrWhiteSpace(deploymentName);
Verify.NotNullOrWhiteSpace(endpoint);
Verify.NotNullOrWhiteSpace(apiKey);

Func<IServiceProvider, object?, AzureOpenAIChatCompletionService> factory = (serviceProvider, _) =>
{
AzureOpenAIClient client = CreateAzureOpenAIClient(
endpoint,
new AzureKeyCredential(apiKey),
HttpClientProvider.GetHttpClient(httpClient, serviceProvider));

return new(deploymentName, client, modelId, serviceProvider.GetService<ILoggerFactory>());
};

builder.Services.AddKeyedSingleton<IChatCompletionService>(serviceId, factory);
builder.Services.AddKeyedSingleton<ITextGenerationService>(serviceId, factory);

return builder;
}

/// <summary>
/// Adds the Azure OpenAI chat completion service to the list.
/// </summary>
/// <param name="services">The <see cref="IServiceCollection"/> instance to augment.</param>
/// <param name="deploymentName">Azure OpenAI deployment name, see https://learn.microsoft.com/azure/cognitive-services/openai/how-to/create-resource</param>
/// <param name="endpoint">Azure OpenAI deployment URL, see https://learn.microsoft.com/azure/cognitive-services/openai/quickstart</param>
/// <param name="apiKey">Azure OpenAI API key, see https://learn.microsoft.com/azure/cognitive-services/openai/quickstart</param>
/// <param name="serviceId">A local identifier for the given AI service</param>
/// <param name="modelId">Model identifier, see https://learn.microsoft.com/azure/cognitive-services/openai/quickstart</param>
/// <returns>The same instance as <paramref name="services"/>.</returns>
public static IServiceCollection AddAzureOpenAIChatCompletion(
this IServiceCollection services,
string deploymentName,
string endpoint,
string apiKey,
string? serviceId = null,
string? modelId = null)
{
Verify.NotNull(services);
Verify.NotNullOrWhiteSpace(deploymentName);
Verify.NotNullOrWhiteSpace(endpoint);
Verify.NotNullOrWhiteSpace(apiKey);

Func<IServiceProvider, object?, AzureOpenAIChatCompletionService> factory = (serviceProvider, _) =>
{
AzureOpenAIClient client = CreateAzureOpenAIClient(
endpoint,
new AzureKeyCredential(apiKey),
HttpClientProvider.GetHttpClient(serviceProvider));

return new(deploymentName, client, modelId, serviceProvider.GetService<ILoggerFactory>());
};

services.AddKeyedSingleton<IChatCompletionService>(serviceId, factory);
services.AddKeyedSingleton<ITextGenerationService>(serviceId, factory);

return services;
}

/// <summary>
/// Adds the Azure OpenAI chat completion service to the list.
/// </summary>
/// <param name="builder">The <see cref="IServiceCollection"/> instance to augment.</param>
/// <param name="deploymentName">Azure OpenAI deployment name, see https://learn.microsoft.com/azure/cognitive-services/openai/how-to/create-resource</param>
/// <param name="endpoint">Azure OpenAI deployment URL, see https://learn.microsoft.com/azure/cognitive-services/openai/quickstart</param>
/// <param name="credentials">Token credentials, e.g. DefaultAzureCredential, ManagedIdentityCredential, EnvironmentCredential, etc.</param>
/// <param name="serviceId">A local identifier for the given AI service</param>
/// <param name="modelId">Model identifier, see https://learn.microsoft.com/azure/cognitive-services/openai/quickstart</param>
/// <param name="httpClient">The HttpClient to use with this service.</param>
/// <returns>The same instance as <paramref name="builder"/>.</returns>
public static IKernelBuilder AddAzureOpenAIChatCompletion(
this IKernelBuilder builder,
string deploymentName,
string endpoint,
TokenCredential credentials,
string? serviceId = null,
string? modelId = null,
HttpClient? httpClient = null)
{
Verify.NotNull(builder);
Verify.NotNullOrWhiteSpace(deploymentName);
Verify.NotNullOrWhiteSpace(endpoint);
Verify.NotNull(credentials);

Func<IServiceProvider, object?, AzureOpenAIChatCompletionService> factory = (serviceProvider, _) =>
{
AzureOpenAIClient client = CreateAzureOpenAIClient(
endpoint,
credentials,
HttpClientProvider.GetHttpClient(httpClient, serviceProvider));

return new(deploymentName, client, modelId, serviceProvider.GetService<ILoggerFactory>());
};

builder.Services.AddKeyedSingleton<IChatCompletionService>(serviceId, factory);
builder.Services.AddKeyedSingleton<ITextGenerationService>(serviceId, factory);

return builder;
}

/// <summary>
/// Adds the Azure OpenAI chat completion service to the list.
/// </summary>
/// <param name="services">The <see cref="IServiceCollection"/> instance to augment.</param>
/// <param name="deploymentName">Azure OpenAI deployment name, see https://learn.microsoft.com/azure/cognitive-services/openai/how-to/create-resource</param>
/// <param name="endpoint">Azure OpenAI deployment URL, see https://learn.microsoft.com/azure/cognitive-services/openai/quickstart</param>
/// <param name="credentials">Token credentials, e.g. DefaultAzureCredential, ManagedIdentityCredential, EnvironmentCredential, etc.</param>
/// <param name="serviceId">A local identifier for the given AI service</param>
/// <param name="modelId">Model identifier, see https://learn.microsoft.com/azure/cognitive-services/openai/quickstart</param>
/// <returns>The same instance as <paramref name="services"/>.</returns>
public static IServiceCollection AddAzureOpenAIChatCompletion(
this IServiceCollection services,
string deploymentName,
string endpoint,
TokenCredential credentials,
string? serviceId = null,
string? modelId = null)
{
Verify.NotNull(services);
Verify.NotNullOrWhiteSpace(deploymentName);
Verify.NotNullOrWhiteSpace(endpoint);
Verify.NotNull(credentials);

Func<IServiceProvider, object?, AzureOpenAIChatCompletionService> factory = (serviceProvider, _) =>
{
AzureOpenAIClient client = CreateAzureOpenAIClient(
endpoint,
credentials,
HttpClientProvider.GetHttpClient(serviceProvider));

return new(deploymentName, client, modelId, serviceProvider.GetService<ILoggerFactory>());
};

services.AddKeyedSingleton<IChatCompletionService>(serviceId, factory);
services.AddKeyedSingleton<ITextGenerationService>(serviceId, factory);

return services;
}

/// <summary>
/// Adds the Azure OpenAI chat completion service to the list.
/// </summary>
/// <param name="builder">The <see cref="IKernelBuilder"/> instance to augment.</param>
/// <param name="deploymentName">Azure OpenAI deployment name, see https://learn.microsoft.com/azure/cognitive-services/openai/how-to/create-resource</param>
/// <param name="azureOpenAIClient"><see cref="AzureOpenAIClient"/> to use for the service. If null, one must be available in the service provider when this service is resolved.</param>
/// <param name="serviceId">A local identifier for the given AI service</param>
/// <param name="modelId">Model identifier, see https://learn.microsoft.com/azure/cognitive-services/openai/quickstart</param>
/// <returns>The same instance as <paramref name="builder"/>.</returns>
public static IKernelBuilder AddAzureOpenAIChatCompletion(
this IKernelBuilder builder,
string deploymentName,
AzureOpenAIClient? azureOpenAIClient = null,
string? serviceId = null,
string? modelId = null)
{
Verify.NotNull(builder);
Verify.NotNullOrWhiteSpace(deploymentName);

Func<IServiceProvider, object?, AzureOpenAIChatCompletionService> factory = (serviceProvider, _) =>
new(deploymentName, azureOpenAIClient ?? serviceProvider.GetRequiredService<AzureOpenAIClient>(), modelId, serviceProvider.GetService<ILoggerFactory>());

builder.Services.AddKeyedSingleton<IChatCompletionService>(serviceId, factory);
builder.Services.AddKeyedSingleton<ITextGenerationService>(serviceId, factory);

return builder;
}

/// <summary>
/// Adds the Azure OpenAI chat completion service to the list.
/// </summary>
/// <param name="services">The <see cref="IServiceCollection"/> instance to augment.</param>
/// <param name="deploymentName">Azure OpenAI deployment name, see https://learn.microsoft.com/azure/cognitive-services/openai/how-to/create-resource</param>
/// <param name="azureOpenAIClient"><see cref="AzureOpenAIClient"/> to use for the service. If null, one must be available in the service provider when this service is resolved.</param>
/// <param name="serviceId">A local identifier for the given AI service</param>
/// <param name="modelId">Model identifier, see https://learn.microsoft.com/azure/cognitive-services/openai/quickstart</param>
/// <returns>The same instance as <paramref name="services"/>.</returns>
public static IServiceCollection AddAzureOpenAIChatCompletion(
this IServiceCollection services,
string deploymentName,
AzureOpenAIClient? azureOpenAIClient = null,
string? serviceId = null,
string? modelId = null)
{
Verify.NotNull(services);
Verify.NotNullOrWhiteSpace(deploymentName);

Func<IServiceProvider, object?, AzureOpenAIChatCompletionService> factory = (serviceProvider, _) =>
new(deploymentName, azureOpenAIClient ?? serviceProvider.GetRequiredService<AzureOpenAIClient>(), modelId, serviceProvider.GetService<ILoggerFactory>());

services.AddKeyedSingleton<IChatCompletionService>(serviceId, factory);
services.AddKeyedSingleton<ITextGenerationService>(serviceId, factory);

return services;
}

#endregion

private static AzureOpenAIClient CreateAzureOpenAIClient(string endpoint, AzureKeyCredential credentials, HttpClient? httpClient) =>
new(new Uri(endpoint), credentials, ClientCore.GetOpenAIClientOptions(httpClient));

private static AzureOpenAIClient CreateAzureOpenAIClient(string endpoint, TokenCredential credentials, HttpClient? httpClient) =>
new(new Uri(endpoint), credentials, ClientCore.GetOpenAIClientOptions(httpClient));
}
Loading