Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -21,121 +21,113 @@
namespace CopilotChat.WebApi.Controllers;

[ApiController]
public class BotController : ControllerBase
public class ChatArchiveController : ControllerBase
{
private readonly ILogger<BotController> _logger;
private readonly ILogger<ChatArchiveController> _logger;
private readonly ISemanticMemoryClient _memoryClient;
private readonly ChatSessionRepository _chatRepository;
private readonly ChatMessageRepository _chatMessageRepository;
private readonly ChatParticipantRepository _chatParticipantRepository;
private readonly BotEmbeddingConfig _embeddingConfig;
private readonly BotSchemaOptions _botSchemaOptions;
private readonly ChatArchiveEmbeddingConfig _embeddingConfig;
private readonly PromptsOptions _promptOptions;

/// <summary>
/// The constructor of BotController.
/// Constructor.
/// </summary>
/// <param name="memoryClient">Memory client.</param>
/// <param name="chatRepository">The chat session repository.</param>
/// <param name="chatMessageRepository">The chat message repository.</param>
/// <param name="chatParticipantRepository">The chat participant repository.</param>
/// <param name="botSchemaOptions">The bot schema options.</param>
/// <param name="promptOptions">The document memory options.</param>
/// <param name="logger">The logger.</param>
public BotController(
public ChatArchiveController(
ISemanticMemoryClient memoryClient,
ChatSessionRepository chatRepository,
ChatMessageRepository chatMessageRepository,
ChatParticipantRepository chatParticipantRepository,
BotEmbeddingConfig embeddingConfig,
IOptions<BotSchemaOptions> botSchemaOptions,
ChatArchiveEmbeddingConfig embeddingConfig,
IOptions<PromptsOptions> promptOptions,
ILogger<BotController> logger)
ILogger<ChatArchiveController> logger)
{
this._memoryClient = memoryClient;
this._logger = logger;
this._chatRepository = chatRepository;
this._chatMessageRepository = chatMessageRepository;
this._chatParticipantRepository = chatParticipantRepository;
this._embeddingConfig = embeddingConfig;
this._botSchemaOptions = botSchemaOptions.Value;
this._promptOptions = promptOptions.Value;
}

/// <summary>
/// Download a bot.
/// Download a chat archive.
/// </summary>
/// <param name="kernel">The Semantic Kernel instance.</param>
/// <param name="chatId">The chat id to be downloaded.</param>
/// <returns>The serialized Bot object of the chat id.</returns>
/// <param name="chatId">The ID of chat to be downloaded.</param>
/// <param name="cancellationToken">Cancellation token.</param>
/// <returns>The serialized chat archive object of the chat id.</returns>
[HttpGet]
[ActionName("DownloadAsync")]
[Route("bot/download/{chatId:guid}")]
[Route("chats/{chatId:guid}/archive")]
[ProducesResponseType(StatusCodes.Status200OK)]
[ProducesResponseType(StatusCodes.Status400BadRequest)]
[ProducesResponseType(StatusCodes.Status404NotFound)]
[Authorize(Policy = AuthPolicyName.RequireChatParticipant)]
public async Task<ActionResult<Bot?>> DownloadAsync(Guid chatId, CancellationToken cancellationToken = default)
public async Task<ActionResult<ChatArchive?>> DownloadAsync(Guid chatId, CancellationToken cancellationToken = default)
{
this._logger.LogDebug("Received call to download a bot");
this._logger.LogDebug("Received call to download a chat archive");

var memory = await this.CreateBotAsync(chatId, cancellationToken);
var chatArchive = await this.CreateChatArchiveAsync(chatId, cancellationToken);

return this.Ok(memory);
return this.Ok(chatArchive);
}

/// <summary>
/// Prepare the bot information of a given chat.
/// Prepare a chat archive.
/// </summary>
/// <param name="kernel">The semantic kernel object.</param>
/// <param name="chatId">The chat id of the bot</param>
/// <returns>A Bot object that represents the chat session.</returns>
private async Task<Bot> CreateBotAsync(Guid chatId, CancellationToken cancellationToken)
/// <param name="chatId">The chat id of the chat archive</param>
/// <param name="cancellationToken">Cancellation token.</param>
/// <returns>A ChatArchive object that represents the chat session.</returns>
private async Task<ChatArchive> CreateChatArchiveAsync(Guid chatId, CancellationToken cancellationToken)
{
var chatIdString = chatId.ToString();
var bot = new Bot
var chatArchive = new ChatArchive
{
// get the bot schema version
Schema = this._botSchemaOptions,

// get the embedding configuration
// Get embedding configuration
EmbeddingConfigurations = this._embeddingConfig,
};

// get the chat title
ChatSession chat = await this._chatRepository.FindByIdAsync(chatIdString);
bot.ChatTitle = chat.Title;
chatArchive.ChatTitle = chat.Title;

// get the system description
bot.SystemDescription = chat.SystemDescription;
chatArchive.SystemDescription = chat.SystemDescription;

// get the chat history
bot.ChatHistory = await this.GetAllChatMessagesAsync(chatIdString);
chatArchive.ChatHistory = await this.GetAllChatMessagesAsync(chatIdString);

foreach (var memory in this._promptOptions.MemoryMap.Keys)
{
bot.Embeddings.Add(
chatArchive.Embeddings.Add(
memory,
await this.GetMemoryRecordsAndAppendToEmbeddingsAsync(chatIdString, memory, cancellationToken));
}

// get the document memory collection names (global scope)
bot.DocumentEmbeddings.Add(
chatArchive.DocumentEmbeddings.Add(
"GlobalDocuments",
await this.GetMemoryRecordsAndAppendToEmbeddingsAsync(
Guid.Empty.ToString(),
this._promptOptions.DocumentMemoryName,
cancellationToken));

// get the document memory collection names (user scope)
bot.DocumentEmbeddings.Add(
chatArchive.DocumentEmbeddings.Add(
"ChatDocuments",
await this.GetMemoryRecordsAndAppendToEmbeddingsAsync(
chatIdString,
this._promptOptions.DocumentMemoryName,
cancellationToken));

return bot;
return chatArchive;
}

/// <summary>
Expand Down
31 changes: 17 additions & 14 deletions webapi/Controllers/ChatController.cs
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,9 @@ public ChatController(
/// <param name="chatParticipantRepository">Repository of chat participants.</param>
/// <param name="authInfo">Auth info for the current request.</param>
/// <param name="ask">Prompt along with its parameters.</param>
/// <param name="chatId">Chat ID.</param>
/// <returns>Results containing the response from the model.</returns>
[Route("chat")]
[Route("chats/{chatId:guid}/messages")]
[HttpPost]
[ProducesResponseType(StatusCodes.Status200OK)]
[ProducesResponseType(StatusCodes.Status400BadRequest)]
Expand All @@ -98,10 +99,12 @@ public async Task<IActionResult> ChatAsync(
[FromServices] ChatSessionRepository chatSessionRepository,
[FromServices] ChatParticipantRepository chatParticipantRepository,
[FromServices] IAuthInfo authInfo,
[FromBody] Ask ask)
[FromBody] Ask ask,
[FromRoute] Guid chatId)
{
this._logger.LogDebug("/chat request received.");
return await this.HandleRequest(ChatFunctionName, kernel, messageRelayHubContext, planner, askConverter, chatSessionRepository, chatParticipantRepository, authInfo, ask);
this._logger.LogDebug("Chat message received.");

return await this.HandleRequest(ChatFunctionName, kernel, messageRelayHubContext, planner, askConverter, chatSessionRepository, chatParticipantRepository, authInfo, ask, chatId.ToString());
}

/// <summary>
Expand All @@ -115,8 +118,9 @@ public async Task<IActionResult> ChatAsync(
/// <param name="chatParticipantRepository">Repository of chat participants.</param>
/// <param name="authInfo">Auth info for the current request.</param>
/// <param name="ask">Prompt along with its parameters.</param>
/// <param name="chatId">Chat ID.</param>
/// <returns>Results containing the response from the model.</returns>
[Route("processplan")]
[Route("chats/{chatId:guid}/plan")]
[HttpPost]
[ProducesResponseType(StatusCodes.Status200OK)]
[ProducesResponseType(StatusCodes.Status400BadRequest)]
Expand All @@ -131,10 +135,12 @@ public async Task<IActionResult> ProcessPlanAsync(
[FromServices] ChatSessionRepository chatSessionRepository,
[FromServices] ChatParticipantRepository chatParticipantRepository,
[FromServices] IAuthInfo authInfo,
[FromBody] ExecutePlanParameters ask)
[FromBody] ExecutePlanParameters ask,
Comment thread
glahaye marked this conversation as resolved.
[FromRoute] Guid chatId)
{
this._logger.LogDebug("/processplan request received.");
return await this.HandleRequest(ProcessPlanFunctionName, kernel, messageRelayHubContext, planner, askConverter, chatSessionRepository, chatParticipantRepository, authInfo, ask);
this._logger.LogDebug("plan request received.");

return await this.HandleRequest(ProcessPlanFunctionName, kernel, messageRelayHubContext, planner, askConverter, chatSessionRepository, chatParticipantRepository, authInfo, ask, chatId.ToString());
}

#region Private Methods
Expand All @@ -151,6 +157,7 @@ public async Task<IActionResult> ProcessPlanAsync(
/// <param name="chatParticipantRepository">Repository of chat participants.</param>
/// <param name="authInfo">Auth info for the current request.</param>
/// <param name="ask">Prompt along with its parameters.</param>
/// <param name="chatId"Chat ID.</>
/// <returns>Results containing the response from the model.</returns>
private async Task<IActionResult> HandleRequest(
string functionName,
Expand All @@ -161,17 +168,13 @@ private async Task<IActionResult> HandleRequest(
ChatSessionRepository chatSessionRepository,
ChatParticipantRepository chatParticipantRepository,
IAuthInfo authInfo,
Ask ask)
Ask ask,
string chatId)
{
// Put ask's variables in the context we will use.
var contextVariables = askConverter.GetContextVariables(ask);

// Verify that the chat exists and that the user has access to it.
if (!contextVariables.TryGetValue("chatId", out string? chatId))
{
return this.BadRequest("ChatId not specified.");
}

ChatSession? chat = null;
#pragma warning disable CA1508 // Avoid dead conditional code. It's giving out false positives on chat == null.
if (!(await chatSessionRepository.TryFindByIdAsync(chatId, callback: c => chat = c)) || chat == null)
Expand Down
59 changes: 22 additions & 37 deletions webapi/Controllers/ChatHistoryController.cs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ public class ChatHistoryController : ControllerBase
{
private const string ChatEditedClientCall = "ChatEdited";
private const string ChatDeletedClientCall = "ChatDeleted";
private const string GetChatRoute = "GetChatRoute";

private readonly ILogger<ChatHistoryController> _logger;
private readonly ISemanticMemoryClient _memoryClient;
Expand Down Expand Up @@ -81,7 +82,7 @@ public ChatHistoryController(
/// <param name="chatParameter">Contains the title of the chat.</param>
/// <returns>The HTTP action result.</returns>
[HttpPost]
[Route("chatSession/create")]
[Route("chats")]
[ProducesResponseType(StatusCodes.Status201Created)]
[ProducesResponseType(StatusCodes.Status400BadRequest)]
public async Task<IActionResult> CreateChatSessionAsync(
Expand Down Expand Up @@ -109,19 +110,16 @@ public async Task<IActionResult> CreateChatSessionAsync(
await this._participantRepository.CreateAsync(new ChatParticipant(this._authInfo.UserId, newChat.Id));

this._logger.LogDebug("Created chat session with id {0}.", newChat.Id);
return this.CreatedAtAction(
nameof(this.GetChatSessionByIdAsync),
new { chatId = newChat.Id },
new CreateChatResponse(newChat, chatMessage));

return this.CreatedAtRoute(GetChatRoute, new { chatId = newChat.Id }, new CreateChatResponse(newChat, chatMessage));
}

/// <summary>
/// Get a chat session by id.
/// </summary>
/// <param name="chatId">The chat id.</param>
[HttpGet]
[ActionName("GetChatSessionByIdAsync")]
[Route("chatSession/getChat/{chatId:guid}")]
[Route("chats/{chatId:guid}", Name = GetChatRoute)]
[ProducesResponseType(StatusCodes.Status200OK)]
[ProducesResponseType(StatusCodes.Status403Forbidden)]
[ProducesResponseType(StatusCodes.Status404NotFound)]
Expand All @@ -143,11 +141,11 @@ public async Task<IActionResult> GetChatSessionByIdAsync(Guid chatId)
/// <param name="userId">The user id.</param>
/// <returns>A list of chat sessions. An empty list if the user is not in any chat session.</returns>
[HttpGet]
[Route("chatSession/getAllChats/{userId}")]
[Route("chats")]
[ProducesResponseType(StatusCodes.Status200OK)]
[ProducesResponseType(StatusCodes.Status403Forbidden)]
[ProducesResponseType(StatusCodes.Status404NotFound)]
public async Task<IActionResult> GetAllChatSessionsAsync(string userId)
public async Task<IActionResult> GetAllChatSessionsAsync()
{
// Get all participants that belong to the user.
// Then get all the chats from the list of participants.
Expand All @@ -163,8 +161,7 @@ public async Task<IActionResult> GetAllChatSessionsAsync(string userId)
}
else
{
this._logger.LogDebug(
"Failed to find chat session with id {0}", chatParticipant.ChatId);
this._logger.LogDebug("Failed to find chat session with id {0}", chatParticipant.ChatId);
}
}

Expand All @@ -179,13 +176,13 @@ public async Task<IActionResult> GetAllChatSessionsAsync(string userId)
/// <param name="startIdx">The start index at which the first message will be returned.</param>
/// <param name="count">The number of messages to return. -1 will return all messages starting from startIdx.</param>
[HttpGet]
[Route("chatSession/getChatMessages/{chatId:guid}")]
[Route("chats/{chatId:guid}/messages")]
[ProducesResponseType(StatusCodes.Status200OK)]
[ProducesResponseType(StatusCodes.Status403Forbidden)]
[ProducesResponseType(StatusCodes.Status404NotFound)]
[Authorize(Policy = AuthPolicyName.RequireChatParticipant)]
public async Task<IActionResult> GetChatMessagesAsync(
Guid chatId,
[FromRoute] Guid chatId,
[FromQuery] int startIdx = 0,
[FromQuery] int count = -1)
{
Expand All @@ -206,49 +203,36 @@ public async Task<IActionResult> GetChatMessagesAsync(
/// Edit a chat session.
/// </summary>
/// <param name="chatParameters">Object that contains the parameters to edit the chat.</param>
[HttpPost]
[Route("chatSession/edit")]
[HttpPatch]
[Route("chats/{chatId:guid}")]
[ProducesResponseType(StatusCodes.Status200OK)]
[ProducesResponseType(StatusCodes.Status403Forbidden)]
[ProducesResponseType(StatusCodes.Status404NotFound)]
[Authorize(Policy = AuthPolicyName.RequireChatParticipant)]
public async Task<IActionResult> EditChatSessionAsync(
[FromServices] IHubContext<MessageRelayHub> messageRelayHubContext,
[FromBody] EditChatParameters chatParameters)
[FromBody] EditChatParameters chatParameters,
[FromRoute] Guid chatId)
{
string? chatId = chatParameters.Id;

if (chatId == null)
{
return this.BadRequest("Chat id must be specified.");
}

// Verify access to chat session
// TODO: [Issue #141] This can be removed when route is updated to include chatId, so that we can leverage RequireChatParticipant policy.
bool isUserInChat = await this._participantRepository.IsUserInChatAsync(this._authInfo.UserId, chatId);
if (!isUserInChat)
{
return this.Forbid("User does not have access to the specified chat.");
}

ChatSession? chat = null;
if (await this._sessionRepository.TryFindByIdAsync(chatId, callback: v => chat = v))
if (await this._sessionRepository.TryFindByIdAsync(chatId.ToString(), callback: v => chat = v))
{
chat!.Title = chatParameters.Title ?? chat!.Title;
chat!.SystemDescription = chatParameters.SystemDescription ?? chat!.SystemDescription;
chat!.MemoryBalance = chatParameters.MemoryBalance ?? chat!.MemoryBalance;
await this._sessionRepository.UpsertAsync(chat);
await messageRelayHubContext.Clients.Group(chatId).SendAsync(ChatEditedClientCall, chat);
await messageRelayHubContext.Clients.Group(chatId.ToString()).SendAsync(ChatEditedClientCall, chat);

return this.Ok(chat);
}

return this.NotFound($"No chat session found for chat id '{chatId}'.");
}

/// <summary>
/// Service API to get a list of imported sources.
/// Gets list of imported documents for a given chat.
/// </summary>
[Route("chatSession/{chatId:guid}/sources")]
[Route("chats/{chatId:guid}/documents")]
[HttpGet]
[ProducesResponseType(StatusCodes.Status200OK)]
[ProducesResponseType(StatusCodes.Status400BadRequest)]
Expand All @@ -261,7 +245,8 @@ public async Task<ActionResult<IEnumerable<MemorySource>>> GetSourcesAsync(Guid

if (await this._sessionRepository.TryFindByIdAsync(chatId.ToString()))
{
var sources = await this._sourceRepository.FindByChatIdAsync(chatId.ToString());
IEnumerable<MemorySource> sources = await this._sourceRepository.FindByChatIdAsync(chatId.ToString());

return this.Ok(sources);
}

Expand All @@ -273,7 +258,7 @@ public async Task<ActionResult<IEnumerable<MemorySource>>> GetSourcesAsync(Guid
/// </summary>
/// <param name="chatId">The chat id.</param>
[HttpDelete]
[Route("chatSession/{chatId:guid}")]
[Route("chats/{chatId:guid}")]
[ProducesResponseType(StatusCodes.Status204NoContent)]
[ProducesResponseType(StatusCodes.Status403Forbidden)]
[ProducesResponseType(StatusCodes.Status404NotFound)]
Expand Down
Loading