From d29cba7dd495be3127fbdfd1d25a33fd8740f471 Mon Sep 17 00:00:00 2001 From: Daniel Cazzulino Date: Fri, 2 Jan 2026 18:10:06 -0300 Subject: [PATCH] Add support for collection search results citations --- readme.md | 44 ++++ src/xAI.Tests/ChatClientTests.cs | 23 +- ....cs => CollectionSearchToolCallContent.cs} | 6 +- ...s => CollectionSearchToolResultContent.cs} | 8 +- src/xAI/GrokChatClient.cs | 116 ++-------- src/xAI/GrokProtocolExtensions.cs | 208 +++++++++++++++++- 6 files changed, 294 insertions(+), 111 deletions(-) rename src/xAI/{HostedToolCallContent.cs => CollectionSearchToolCallContent.cs} (56%) rename src/xAI/{HostedToolResultContent.cs => CollectionSearchToolResultContent.cs} (60%) diff --git a/readme.md b/readme.md index 5b0e219..cb981a3 100644 --- a/readme.md +++ b/readme.md @@ -158,6 +158,50 @@ var options = new ChatOptions }; ``` +To receive the actual search results and file references, include `CollectionsSearchCallOutput` in the options: + +```csharp +var options = new GrokChatOptions +{ + Include = [IncludeOption.CollectionsSearchCallOutput], + Tools = [new HostedFileSearchTool { + Inputs = [new HostedVectorStoreContent("[collection_id]")] + }] +}; + +var response = await grok.GetResponseAsync(messages, options); + +// Access the search results with file references +var results = response.Messages + .SelectMany(x => x.Contents) + .OfType(); + +foreach (var result in results) +{ + // Each result contains files that were found and referenced + var files = result.Outputs?.OfType(); + foreach (var file in files ?? []) + { + Console.WriteLine($"File: {file.Name} (ID: {file.FileId})"); + + // Files include citation annotations with snippets + foreach (var citation in file.Annotations?.OfType() ?? 
[]) + { + Console.WriteLine($" Title: {citation.Title}"); + Console.WriteLine($" Snippet: {citation.Snippet}"); + Console.WriteLine($" URL: {citation.Url}"); // collections://[collection_id]/files/[file_id] + } + } +} +``` + +Citations from collection search include: +- **Title**: Extracted from the first line of the chunk content (if available), typically the file name or heading +- **Snippet**: The relevant text excerpt from the document +- **FileId**: Identifier of the source file in the collection +- **Url**: A `collections://` URI pointing to the specific file within the collection +- **ToolName**: Always set to `"collections_search"` + Learn more about [collection search](https://docs.x.ai/docs/guides/tools/collections-search-tool). ## Remote MCP diff --git a/src/xAI.Tests/ChatClientTests.cs b/src/xAI.Tests/ChatClientTests.cs index 92b48ff..cf3b3ab 100644 --- a/src/xAI.Tests/ChatClientTests.cs +++ b/src/xAI.Tests/ChatClientTests.cs @@ -96,8 +96,7 @@ public async Task GrokInvokesToolAndSearch() Assert.Equal(options.ModelId, response.ModelId); var calls = response.Messages - .SelectMany(x => x.Contents.OfType()) - .Select(x => x.RawRepresentation as xAI.Protocol.ToolCall) + .SelectMany(x => x.Contents.Select(x => x.RawRepresentation as xAI.Protocol.ToolCall)) .Where(x => x is not null) .ToList(); @@ -317,7 +316,6 @@ public async Task GrokInvokesHostedCollectionSearch() var options = new GrokChatOptions { - Include = { IncludeOption.CollectionsSearchCallOutput }, Tools = [new HostedFileSearchTool { Inputs = [new HostedVectorStoreContent("collection_91559d9b-a55d-42fe-b2ad-ecf8904d9049")] }] @@ -329,9 +327,21 @@ public async Task GrokInvokesHostedCollectionSearch() Assert.Contains("11,74", text); Assert.Contains(response.Messages .SelectMany(x => x.Contents) - .OfType() + .OfType() .Select(x => x.RawRepresentation as xAI.Protocol.ToolCall), x => x?.Type == xAI.Protocol.ToolCallType.CollectionsSearchTool); + // No actual search results content since we didn't 
specify it in Include + Assert.Empty(response.Messages.SelectMany(x => x.Contents).OfType()); + + options.Include = [IncludeOption.CollectionsSearchCallOutput]; + response = await grok.GetResponseAsync(messages, options); + + // Now it also contains the file reference as result content + Assert.Contains(response.Messages + .SelectMany(x => x.Contents) + .OfType() + .SelectMany(x => (x.Outputs ?? []).OfType()), + x => x.Name == "LNS0004592.txt"); } [SecretsFact("XAI_API_KEY", "GITHUB_TOKEN")] @@ -458,9 +468,8 @@ public async Task GrokStreamsUpdatesFromAllTools() .OfType()); Assert.Contains(response.Messages - .SelectMany(x => x.Contents) - .OfType() - .Select(x => x.RawRepresentation as xAI.Protocol.ToolCall), + .SelectMany(x => x.Contents.Select(x => x.RawRepresentation as xAI.Protocol.ToolCall)) + .Where(x => x != null), x => x?.Type == xAI.Protocol.ToolCallType.WebSearchTool); Assert.Equal(1, getDateCalls); diff --git a/src/xAI/HostedToolCallContent.cs b/src/xAI/CollectionSearchToolCallContent.cs similarity index 56% rename from src/xAI/HostedToolCallContent.cs rename to src/xAI/CollectionSearchToolCallContent.cs index 450fd23..f80f6ee 100644 --- a/src/xAI/HostedToolCallContent.cs +++ b/src/xAI/CollectionSearchToolCallContent.cs @@ -1,11 +1,9 @@ -using System.Diagnostics.CodeAnalysis; -using Microsoft.Extensions.AI; +using Microsoft.Extensions.AI; namespace xAI; /// Represents a hosted tool agentic call. -[Experimental("xAI001")] -public class HostedToolCallContent : AIContent +public class CollectionSearchToolCallContent : AIContent { /// Gets or sets the tool call ID. public virtual string? 
CallId { get; set; } diff --git a/src/xAI/HostedToolResultContent.cs b/src/xAI/CollectionSearchToolResultContent.cs similarity index 60% rename from src/xAI/HostedToolResultContent.cs rename to src/xAI/CollectionSearchToolResultContent.cs index 4c8694a..ac490b3 100644 --- a/src/xAI/HostedToolResultContent.cs +++ b/src/xAI/CollectionSearchToolResultContent.cs @@ -1,13 +1,9 @@ -using System.Diagnostics; -using System.Diagnostics.CodeAnalysis; -using Microsoft.Extensions.AI; +using Microsoft.Extensions.AI; namespace xAI; /// Represents a hosted tool agentic call. -[DebuggerDisplay("{DebuggerDisplay,nq}")] -[Experimental("xAI001")] -public class HostedToolResultContent : AIContent +public class CollectionSearchToolResultContent : AIContent { /// Gets or sets the tool call ID. public virtual string? CallId { get; set; } diff --git a/src/xAI/GrokChatClient.cs b/src/xAI/GrokChatClient.cs index d396c65..6cd35a7 100644 --- a/src/xAI/GrokChatClient.cs +++ b/src/xAI/GrokChatClient.cs @@ -1,4 +1,5 @@ using System.Text.Json; +using Google.Protobuf; using Grpc.Core; using Grpc.Net.Client; using Microsoft.Extensions.AI; @@ -39,72 +40,7 @@ public async Task GetResponseAsync(IEnumerable messag var response = await client.GetCompletionAsync(request, cancellationToken: cancellationToken); var lastOutput = response.Outputs.OrderByDescending(x => x.Index).FirstOrDefault(); - if (lastOutput == null) - { - return new ChatResponse() - { - ResponseId = response.Id, - ModelId = response.Model, - CreatedAt = response.Created.ToDateTimeOffset(), - Usage = MapToUsage(response.Usage), - }; - } - - var message = new ChatMessage(MapRole(lastOutput.Message.Role), default(string)); - var citations = response.Citations?.Distinct().Select(MapCitation).ToList(); - - foreach (var output in response.Outputs.OrderBy(x => x.Index)) - { - if (output.Message.Content is { Length: > 0 } text) - { - // Special-case output from tools - if (output.Message.Role == MessageRole.RoleTool && - 
output.Message.ToolCalls.Count == 1 && - output.Message.ToolCalls[0] is { } toolCall) - { - if (toolCall.Type == ToolCallType.McpTool) - { - message.Contents.Add(new McpServerToolCallContent(toolCall.Id, toolCall.Function.Name, null) - { - RawRepresentation = toolCall - }); - message.Contents.Add(new McpServerToolResultContent(toolCall.Id) - { - RawRepresentation = toolCall, - Output = [new TextContent(text)] - }); - continue; - } - else if (toolCall.Type == ToolCallType.CodeExecutionTool) - { - message.Contents.Add(new CodeInterpreterToolCallContent() - { - CallId = toolCall.Id, - RawRepresentation = toolCall - }); - message.Contents.Add(new CodeInterpreterToolResultContent() - { - CallId = toolCall.Id, - RawRepresentation = toolCall, - Outputs = [new TextContent(text)] - }); - continue; - } - } - - var content = new TextContent(text) { Annotations = citations }; - - foreach (var citation in output.Message.Citations) - (content.Annotations ??= []).Add(MapInlineCitation(citation)); - - message.Contents.Add(content); - } - - foreach (var toolCall in output.Message.ToolCalls) - message.Contents.Add(MapToolCall(toolCall)); - } - - return new ChatResponse(message) + var result = new ChatResponse() { ResponseId = response.Id, ModelId = response.Model, @@ -112,9 +48,15 @@ public async Task GetResponseAsync(IEnumerable messag FinishReason = lastOutput != null ? MapFinishReason(lastOutput.FinishReason) : null, Usage = MapToUsage(response.Usage), }; + + var citations = response.Citations?.Distinct().Select(MapCitation).ToList(); + + ((List)result.Messages).AddRange(response.Outputs.AsChatMessages(citations)); + + return result; } - AIContent MapToolCall(ToolCall toolCall) => toolCall.Type switch + AIContent? 
MapToolCall(ToolCall toolCall) => toolCall.Type switch { ToolCallType.ClientSideTool => new FunctionCallContent( toolCall.Id, @@ -134,11 +76,12 @@ public async Task GetResponseAsync(IEnumerable messag CallId = toolCall.Id, RawRepresentation = toolCall }, - _ => new HostedToolCallContent() + ToolCallType.CollectionsSearchTool => new CollectionSearchToolCallContent() { CallId = toolCall.Id, RawRepresentation = toolCall - } + }, + _ => null }; public IAsyncEnumerable GetStreamingResponseAsync(IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) @@ -161,10 +104,12 @@ async IAsyncEnumerable CompleteChatStreamingCore(IEnumerable ResponseId = chunk.Id, ModelId = chunk.Model, CreatedAt = chunk.Created?.ToDateTimeOffset(), + RawRepresentation = chunk, FinishReason = output.FinishReason != FinishReason.ReasonInvalid ? MapFinishReason(output.FinishReason) : null, }; - if (chunk.Citations is { Count: > 0 } citations) + var citations = chunk.Citations?.Distinct().Select(MapCitation).ToList(); + if (citations?.Count > 0) { var textContent = update.Contents.OfType().FirstOrDefault(); if (textContent == null) @@ -172,13 +117,10 @@ async IAsyncEnumerable CompleteChatStreamingCore(IEnumerable textContent = new TextContent(string.Empty); update.Contents.Add(textContent); } - - foreach (var citation in citations.Distinct()) - (textContent.Annotations ??= []).Add(MapCitation(citation)); + ((List)(textContent.Annotations ??= [])).AddRange(citations); } - foreach (var toolCall in output.Delta.ToolCalls) - update.Contents.Add(MapToolCall(toolCall)); + ((List)update.Contents).AddRange(output.Delta.ToolCalls.AsContents(text, citations)); if (update.Contents.Any()) yield return update; @@ -186,19 +128,6 @@ async IAsyncEnumerable CompleteChatStreamingCore(IEnumerable } } - static CitationAnnotation MapInlineCitation(InlineCitation citation) => citation.CitationCase switch - { - InlineCitation.CitationOneofCase.WebCitation => new CitationAnnotation 
{ Url = new(citation.WebCitation.Url) }, - InlineCitation.CitationOneofCase.XCitation => new CitationAnnotation { Url = new(citation.XCitation.Url) }, - InlineCitation.CitationOneofCase.CollectionsCitation => new CitationAnnotation - { - FileId = citation.CollectionsCitation.FileId, - Snippet = citation.CollectionsCitation.ChunkContent, - ToolName = "file_search", - }, - _ => new CitationAnnotation() - }; - static CitationAnnotation MapCitation(string citation) { var url = new Uri(citation); @@ -210,12 +139,13 @@ static CitationAnnotation MapCitation(string citation) var file = url.AbsolutePath[7..]; return new CitationAnnotation { - ToolName = "collections_search", - FileId = file, AdditionalProperties = new AdditionalPropertiesDictionary - { - { "collection_id", collection } - } + { + { "collection_id", collection } + }, + FileId = file, + ToolName = "collections_search", + Url = new Uri($"collections://{collection}/files/{file}"), }; } diff --git a/src/xAI/GrokProtocolExtensions.cs b/src/xAI/GrokProtocolExtensions.cs index 43d52d6..2aaabd8 100644 --- a/src/xAI/GrokProtocolExtensions.cs +++ b/src/xAI/GrokProtocolExtensions.cs @@ -4,15 +4,18 @@ using System.Linq; using System.Text; using System.Text.Json; +using System.Text.Json.Serialization; using System.Threading.Tasks; +using Google.Protobuf; using Microsoft.Extensions.AI; using xAI.Protocol; +using static Google.Protobuf.Reflection.GeneratedCodeInfo.Types; namespace xAI; /// Provides extension methods for working with xAI protocol types. [EditorBrowsable(EditorBrowsableState.Never)] -public static class GrokProtocolExtensions +public static partial class GrokProtocolExtensions { /// Creates an xAI protocol from an . /// The tool to convert. @@ -20,6 +23,18 @@ public static class GrokProtocolExtensions /// is . public static Tool? AsProtocolTool(this AITool tool, ChatOptions? 
options = null) => ToProtocolTool(Throw.IfNull(tool), options); + /// Creates a sequence of instances from the specified protocol outputs. + /// The output messages to convert. + /// A sequence of instances. + /// is . + public static IEnumerable AsChatMessages(this IEnumerable outputs, List? citations = default) => ToChatMessages(Throw.IfNull(outputs).Select(x => x.Message), citations); + + /// Creates a sequence of instances from the specified protocol messages. + /// The messages to convert. + /// A sequence of instances. + /// is . + public static IEnumerable AsChatMessages(this IEnumerable messages, List? citations = default) => ToChatMessages(Throw.IfNull(messages), citations); + static Tool? ToProtocolTool(AITool tool, ChatOptions? options = null) { switch (tool) @@ -122,4 +137,195 @@ public static class GrokProtocolExtensions return null; } } + + static IEnumerable ToChatMessages(IEnumerable messages, List? citations = default) /* Folds every protocol message into a single assistant ChatMessage; yields nothing when the sequence is empty. The citations list passed in is only read, never modified. */ + { + ChatMessage? message = null; + + foreach (var completion in messages) + { + message ??= new(ChatRole.Assistant, (string?)null); + var annotations = citations?.ToList(); /* per-message copy: appending inline citations below must not mutate the caller's list nor leak into contents created for other messages */ + if (completion.Citations.Count > 0) + { + annotations ??= []; + foreach (var citation in completion.Citations) + annotations.AddRange(AsCitations(citation)); + } + + var content = string.IsNullOrEmpty(completion.Content) ? 
null : completion.Content; + + ((List)message.Contents).AddRange(AsContents(completion.ToolCalls, content, annotations)); /* tool calls first; the message text is handed over so tool results can carry it */ + + if (!string.IsNullOrEmpty(completion.ReasoningContent)) + { + message.Contents.Add(new TextReasoningContent(completion.ReasoningContent) + { + Annotations = annotations, + RawRepresentation = completion, + ProtectedData = completion.EncryptedContent, + }); + } + else if (!string.IsNullOrEmpty(completion.EncryptedContent)) /* encrypted-only reasoning: no readable text, just the protected payload */ + { + message.Contents.Add(new TextReasoningContent(null) + { + Annotations = annotations, + ProtectedData = completion.EncryptedContent, + RawRepresentation = completion + }); + } + + if (completion.Role != MessageRole.RoleTool && completion.Content is { Length: > 0 } text) /* tool-role text is not repeated as plain TextContent */ + { + message.Contents.Add(new TextContent(text) + { + Annotations = annotations, + RawRepresentation = completion + }); + } + } + + if (message is not null) + yield return message; + } + + internal static IEnumerable AsContents(this IEnumerable toolCalls, string? content = default, List? annotations = default) + { + foreach (var toolCall in toolCalls) + { + switch (toolCall.Type) + { + case ToolCallType.ClientSideTool: + yield return new FunctionCallContent( + toolCall.Id, + toolCall.Function.Name, + !string.IsNullOrEmpty(toolCall.Function.Arguments) + ? 
JsonSerializer.Deserialize>(toolCall.Function.Arguments) + : null) + { + Annotations = annotations, + RawRepresentation = toolCall, + }; + break; + + case ToolCallType.McpTool: + yield return new McpServerToolCallContent(toolCall.Id, toolCall.Function.Name, null) + { + Annotations = annotations, + RawRepresentation = toolCall + }; + if (content is not null) + yield return new McpServerToolResultContent(toolCall.Id) + { + Annotations = annotations, + RawRepresentation = toolCall, + Output = [new TextContent(content)] + }; + break; + + case ToolCallType.CodeExecutionTool: + yield return new CodeInterpreterToolCallContent() + { + Annotations = annotations, + RawRepresentation = toolCall, + CallId = toolCall.Id, + }; + if (content is not null) + yield return new CodeInterpreterToolResultContent() + { + Annotations = annotations, + RawRepresentation = toolCall, + CallId = toolCall.Id, + Outputs = [new TextContent(content)] + }; + break; + + case ToolCallType.CollectionsSearchTool: + if (content is not null) + { + var reader = new Utf8JsonReader(Encoding.UTF8.GetBytes(content)); + if (JsonDocument.TryParseValue(ref reader, out var doc) && + doc.RootElement.TryGetProperty("search_matches", out var matchesElement) && + matchesElement.ValueKind == JsonValueKind.Array && + JsonSerializer.Deserialize(matchesElement, CollectionSearchJsonContext.Default.Options) is { Length: > 0 } matches) + { + var result = new CollectionSearchToolResultContent + { + Annotations = annotations, + RawRepresentation = toolCall, + CallId = toolCall.Id, + }; + var outputs = new List(); + foreach (var file in matches.GroupBy(x => x.FileId)) + { + var fileCitations = file.SelectMany(AsCitations).ToArray(); + outputs.Add(new HostedFileContent(file.Key) + { + Annotations = fileCitations, + RawRepresentation = toolCall, + Name = fileCitations.Select(x => x.Title).Where(x => x != null).FirstOrDefault(), + }); + } + result.Outputs = outputs; + yield return result; + } + } + else + { + yield return new 
CollectionSearchToolCallContent + { + Annotations = annotations, + RawRepresentation = toolCall, + CallId = toolCall.Id, + }; + } + break; + + default: + yield return new() { Annotations = annotations, RawRepresentation = toolCall }; + break; + } + } + } + + static IEnumerable AsCitations(CollectionSearchItem item) /* Yields one citation per collection id of the matched chunk; every citation carries the same title and snippet. */ + { + var newline = item.ChunkContent.IndexOf('\n'); /* the first line of the chunk, when there is one, becomes the title */ + var title = newline >= 0 ? item.ChunkContent[..newline] : null; + foreach (var collectionId in item.CollectionIds) + { + yield return new CitationAnnotation + { + Title = title, + FileId = item.FileId, + Snippet = newline >= 0 ? item.ChunkContent[(newline + 1)..] : item.ChunkContent, /* non-mutating offset: incrementing newline here would shift the snippet by one more character for each additional collection id */ + ToolName = "collections_search", + Url = new Uri($"collections://{collectionId}/files/{item.FileId}"), + RawRepresentation = item, + }; + } + } + + static IEnumerable AsCitations(InlineCitation citation) => citation.CitationCase switch + { + InlineCitation.CitationOneofCase.WebCitation => [new CitationAnnotation { Url = new(citation.WebCitation.Url), RawRepresentation = citation }], + InlineCitation.CitationOneofCase.XCitation => [new CitationAnnotation { Url = new(citation.XCitation.Url), RawRepresentation = citation }], + InlineCitation.CitationOneofCase.CollectionsCitation => AsCitations(new CollectionSearchItem( + citation.CollectionsCitation.FileId, citation.CollectionsCitation.ChunkId, citation.CollectionsCitation.ChunkContent, citation.CollectionsCitation.Score, [.. 
citation.CollectionsCitation.CollectionIds])), + _ => [new CitationAnnotation { RawRepresentation = citation }] + }; + + [JsonSourceGenerationOptions(JsonSerializerDefaults.Web, + UseStringEnumConverter = true, + UnmappedMemberHandling = JsonUnmappedMemberHandling.Skip, + PropertyNameCaseInsensitive = true, + PropertyNamingPolicy = JsonKnownNamingPolicy.SnakeCaseLower, + WriteIndented = true + )] + [JsonSerializable(typeof(CollectionSearchItem[]))] + [JsonSerializable(typeof(CollectionSearchItem))] + partial class CollectionSearchJsonContext : JsonSerializerContext { } /* System.Text.Json source-generation context for CollectionSearchItem and arrays of it; configured for snake_case property names, with unmapped JSON members skipped. Its Default.Options are what AsContents passes when deserializing the "search_matches" array. */ + + record CollectionSearchItem(string FileId, string ChunkId, string ChunkContent, float Score, string[] CollectionIds); /* One collection search match. AsCitations turns it into one citation per entry in CollectionIds and uses the first line of ChunkContent, when present, as the citation title. */ }