From 86d8ac688418859ea81f3ed2c94b345984a8f9e1 Mon Sep 17 00:00:00 2001 From: Daniel Cazzulino Date: Wed, 31 Dec 2025 15:25:42 -0300 Subject: [PATCH] Improve tool conversion and make API public Added comprehensive tests for existing conversions. Add support for specifying additional headers for MCP tools. Add support for setting Instructions on the collection search API in the HostedFileSearchTool. Validate that both AllowedDomains/ExcludedDomains and AllowedXHandles/ExcludedXHandles aren't used simultaneously (not supported in the API). --- src/xAI.Tests/GrokConversionTests.cs | 230 +++++++++++++++++++++++++++ src/xAI/Extensions/ChatExtensions.cs | 36 ++++- src/xAI/GrokChatClient.cs | 78 +-------- src/xAI/GrokProtocolExtensions.cs | 125 +++++++++++++++ 4 files changed, 392 insertions(+), 77 deletions(-) create mode 100644 src/xAI.Tests/GrokConversionTests.cs create mode 100644 src/xAI/GrokProtocolExtensions.cs diff --git a/src/xAI.Tests/GrokConversionTests.cs b/src/xAI.Tests/GrokConversionTests.cs new file mode 100644 index 0000000..34b7ce3 --- /dev/null +++ b/src/xAI.Tests/GrokConversionTests.cs @@ -0,0 +1,230 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Google.Protobuf.WellKnownTypes; +using Microsoft.Extensions.AI; +using OpenAI.Responses; +using xAI.Protocol; + +namespace xAI; + +public class GrokConversionTests +{ + [Fact] + public void AsTool_WithWebSearch() + { + var webSearch = new HostedWebSearchTool(); + + var tool = webSearch.AsProtocolTool(); + + Assert.NotNull(tool?.WebSearch); + } + + [Fact] + public void AsTool_WithWebSearch_ThrowsIfAllowedAndExcluded() + { + var webSearch = new GrokSearchTool + { + AllowedDomains = ["Foo"], + ExcludedDomains = ["Bar"] + }; + + Assert.Throws(() => webSearch.AsProtocolTool()); + } + + [Fact] + public void AsTool_WithWebSearch_AllowedDomains() + { + var webSearch = new GrokSearchTool + { + AllowedDomains = ["foo.com", "bar.com"], + }; + + var tool = webSearch.AsProtocolTool(); + + Assert.NotNull(tool?.WebSearch); + Assert.Equal(["foo.com", "bar.com"], tool.WebSearch.AllowedDomains); + } + + [Fact] + public void AsTool_WithWebSearch_ExcludedDomains() + { + var webSearch = new GrokSearchTool + { + ExcludedDomains = ["foo.com", "bar.com"], + }; + + var tool = webSearch.AsProtocolTool(); + + Assert.NotNull(tool?.WebSearch); + Assert.Equal(["foo.com", "bar.com"], tool.WebSearch.ExcludedDomains); + } + + [Fact] + public void AsTool_WithWebSearch_ImageUnderstanding() + { + var webSearch = new GrokSearchTool + { + EnableImageUnderstanding = true + }; + + var tool = webSearch.AsProtocolTool(); + + Assert.NotNull(tool?.WebSearch); + Assert.True(tool.WebSearch.EnableImageUnderstanding); + } + + [Fact] + public void AsTool_WithXSearch_ThrowsIfAllowedAndExcluded() + { + var webSearch = new GrokXSearchTool + { + AllowedHandles = ["Foo"], + ExcludedHandles = ["Bar"] + }; + + Assert.Throws(() => webSearch.AsProtocolTool()); + } + + [Fact] + public void AsTool_WithXSearch_AllowedHandles() + { + var webSearch = new GrokXSearchTool + { + AllowedHandles = ["foo", "bar"], + }; + + var tool = webSearch.AsProtocolTool(); + + Assert.NotNull(tool?.XSearch); + Assert.Equal(["foo", "bar"], tool.XSearch.AllowedXHandles); + } + + [Fact] + public void AsTool_WithXSearch_ExcludedDomains() + { + var webSearch = new GrokXSearchTool + { + ExcludedHandles = ["foo", "bar"], + }; + + var tool = webSearch.AsProtocolTool(); + + Assert.NotNull(tool?.XSearch); + Assert.Equal(["foo", "bar"], tool.XSearch.ExcludedXHandles); + } + + [Fact] + public void AsTool_WithXSearch_ImageUnderstanding() + { + var webSearch = new GrokXSearchTool + { + EnableImageUnderstanding = true + }; + + var tool = webSearch.AsProtocolTool(); + + Assert.NotNull(tool?.XSearch); + Assert.True(tool.XSearch.EnableImageUnderstanding); + } + + [Fact] + public void AsTool_WithXSearch_VideoUnderstanding() + { + var webSearch = new GrokXSearchTool + { + EnableVideoUnderstanding = true + }; + + var tool = webSearch.AsProtocolTool(); + + Assert.NotNull(tool?.XSearch); + Assert.True(tool.XSearch.EnableVideoUnderstanding); + } + + [Fact] + public void AsTool_WithXSearch_FromTo() + { + var webSearch = new GrokXSearchTool + { + FromDate = DateOnly.FromDateTime(DateTime.UtcNow.Subtract(TimeSpan.FromDays(1))), + ToDate = DateOnly.FromDateTime(DateTime.UtcNow) + }; + + var tool = webSearch.AsProtocolTool(); + + Assert.NotNull(tool?.XSearch); + Assert.Equal(tool.XSearch.FromDate, Timestamp.FromDateTime(webSearch.FromDate.Value.ToDateTime(TimeOnly.MinValue, DateTimeKind.Utc))); + Assert.Equal(tool.XSearch.ToDate, Timestamp.FromDateTime(webSearch.ToDate.Value.ToDateTime(TimeOnly.MinValue, DateTimeKind.Utc))); + } + + [Fact] + public void AsTool_WithFunctionTool() + { + var functionTool = AIFunctionFactory.Create(() => "", "Name", "Description"); + + var tool = functionTool.AsProtocolTool(); + + Assert.NotNull(tool?.Function); + Assert.Equal("Name", tool.Function.Name); + Assert.Equal("Description", tool.Function.Description); + } + + [Fact] + public void AsTool_WithCodeExecution() + { + var codeTool = new HostedCodeInterpreterTool(); + + var tool = codeTool.AsProtocolTool(); + + Assert.NotNull(tool?.CodeExecution); + } + + [Fact] + public void AsTool_WithHostedFileSearchTool() + { + var collectionId = Guid.NewGuid().ToString(); + var instructions = "Return N/A if no results found"; + var fileSearch = new HostedFileSearchTool() + { + MaximumResultCount = 50, + Inputs = [new HostedVectorStoreContent(collectionId)] + }.WithInstructions(instructions); + + var tool = fileSearch.AsProtocolTool(); + + Assert.NotNull(tool?.CollectionsSearch); + Assert.Contains(collectionId, tool.CollectionsSearch.CollectionIds); + Assert.Equal(50, tool.CollectionsSearch.Limit); + Assert.Equal(instructions, tool.CollectionsSearch.Instructions); + } + + [Fact] + public void AsTool_WithHostedMcpTool() + { + var accessToken = Guid.NewGuid().ToString(); + var headers = new Dictionary + { + ["foo"] = "baz" + }; + var mcpTool = new HostedMcpServerTool("foo", "foo.com", new Dictionary + { + ["x-extra"] = "bar", + [nameof(MCP.ExtraHeaders)] = headers + }) + { + AllowedTools = ["list"], + AuthorizationToken = accessToken, + }; + + var tool = mcpTool.AsProtocolTool(); + + Assert.NotNull(tool?.Mcp); + Assert.Equal("foo", tool.Mcp.ServerLabel); + Assert.Equal("foo.com", tool.Mcp.ServerUrl); + Assert.Contains("list", tool.Mcp.AllowedToolNames); + Assert.Equal(accessToken, tool.Mcp.Authorization); + Assert.Contains(KeyValuePair.Create("x-extra", "bar"), tool.Mcp.ExtraHeaders); + Assert.Contains(KeyValuePair.Create("foo", "baz"), tool.Mcp.ExtraHeaders); + } +} diff --git a/src/xAI/Extensions/ChatExtensions.cs b/src/xAI/Extensions/ChatExtensions.cs index 1b2e874..967c516 100644 --- a/src/xAI/Extensions/ChatExtensions.cs +++ b/src/xAI/Extensions/ChatExtensions.cs @@ -1,8 +1,12 @@ -using Microsoft.Extensions.AI; +using System.ComponentModel; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.Options; +using xAI.Protocol; namespace xAI; /// Extensions for . +[EditorBrowsable(EditorBrowsableState.Never)] public static partial class ChatOptionsExtensions { extension(ChatOptions options) @@ -14,4 +18,34 @@ public string? EndUserId set => (options.AdditionalProperties ??= [])["EndUserId"] = value; } } +} + +/// Grok-specific extensions for . +[EditorBrowsable(EditorBrowsableState.Never)] +public static partial class HostedFileSearchToolExtensions +{ + extension(HostedFileSearchTool tool) + { + /// + /// User-defined instructions to be included in the search query. Defaults to generic search + /// instructions used by the collections search backend if unset. + /// + public HostedFileSearchTool WithInstructions(string instructions) => new(new Dictionary + { + [nameof(CollectionsSearch.Instructions)] = Throw.IfNullOrEmpty(instructions) + }) + { + Inputs = tool.Inputs, + MaximumResultCount = tool.MaximumResultCount, + }; + } +} + +static partial class AIToolExtensions +{ + extension(AITool tool) + { + public T? GetProperty(string name) => + tool.AdditionalProperties?.TryGetValue(name, out var value) is true && value is T typed ? typed : default; + } } \ No newline at end of file diff --git a/src/xAI/GrokChatClient.cs b/src/xAI/GrokChatClient.cs index 20315e2..d396c65 100644 --- a/src/xAI/GrokChatClient.cs +++ b/src/xAI/GrokChatClient.cs @@ -328,82 +328,8 @@ codeResult.RawRepresentation is ToolCall codeToolCall && if (options?.Tools is not null) { - foreach (var tool in options.Tools) - { - if (tool is AIFunction functionTool) - { - var function = new Function - { - Name = functionTool.Name, - Description = functionTool.Description, - Parameters = JsonSerializer.Serialize(functionTool.JsonSchema) - }; - request.Tools.Add(new Tool { Function = function }); - } - else if (tool is HostedWebSearchTool webSearchTool) - { - if (webSearchTool is GrokXSearchTool xSearch) - { - var toolProto = new XSearch - { - EnableImageUnderstanding = xSearch.EnableImageUnderstanding, - EnableVideoUnderstanding = xSearch.EnableVideoUnderstanding, - }; - - if (xSearch.AllowedHandles is { } allowed) toolProto.AllowedXHandles.AddRange(allowed); - if (xSearch.ExcludedHandles is { } excluded) toolProto.ExcludedXHandles.AddRange(excluded); - if (xSearch.FromDate is { } from) toolProto.FromDate = Google.Protobuf.WellKnownTypes.Timestamp.FromDateTimeOffset(from.ToDateTime(TimeOnly.MinValue, DateTimeKind.Utc)); - if (xSearch.ToDate is { } to) toolProto.ToDate = Google.Protobuf.WellKnownTypes.Timestamp.FromDateTimeOffset(to.ToDateTime(TimeOnly.MinValue, DateTimeKind.Utc)); - - request.Tools.Add(new Tool { XSearch = toolProto }); - } - else if (webSearchTool is GrokSearchTool grokSearch) - { - var toolProto = new WebSearch - { - EnableImageUnderstanding = grokSearch.EnableImageUnderstanding, - }; - - if (grokSearch.AllowedDomains is { } allowed) toolProto.AllowedDomains.AddRange(allowed); - if (grokSearch.ExcludedDomains is { } excluded) toolProto.ExcludedDomains.AddRange(excluded); - - request.Tools.Add(new Tool { WebSearch = toolProto }); - } - else - { - request.Tools.Add(new Tool { WebSearch = new WebSearch() }); - } - } - else if (tool is HostedCodeInterpreterTool) - { - request.Tools.Add(new Tool { CodeExecution = new CodeExecution { } }); - } - else if (tool is HostedFileSearchTool fileSearch) - { - var toolProto = new CollectionsSearch(); - - if (fileSearch.Inputs?.OfType() is { } vectorStores) - toolProto.CollectionIds.AddRange(vectorStores.Select(x => x.VectorStoreId).Distinct()); - - if (fileSearch.MaximumResultCount is { } maxResults) - toolProto.Limit = maxResults; - - request.Tools.Add(new Tool { CollectionsSearch = toolProto }); - } - else if (tool is HostedMcpServerTool mcpTool) - { - request.Tools.Add(new Tool - { - Mcp = new MCP - { - Authorization = mcpTool.AuthorizationToken, - ServerLabel = mcpTool.ServerName, - ServerUrl = mcpTool.ServerAddress, - AllowedToolNames = { mcpTool.AllowedTools ?? Array.Empty() } - } - }); - } - } + foreach (var tool in options.Tools.Select(x => x.AsProtocolTool(options))) + if (tool is not null) request.Tools.Add(tool); } if (options?.ResponseFormat is ChatResponseFormatJson) diff --git a/src/xAI/GrokProtocolExtensions.cs b/src/xAI/GrokProtocolExtensions.cs new file mode 100644 index 0000000..43d52d6 --- /dev/null +++ b/src/xAI/GrokProtocolExtensions.cs @@ -0,0 +1,125 @@ +using System; +using System.Collections.Generic; +using System.ComponentModel; +using System.Linq; +using System.Text; +using System.Text.Json; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; +using xAI.Protocol; + +namespace xAI; + +/// Provides extension methods for working with xAI protocol types. +[EditorBrowsable(EditorBrowsableState.Never)] +public static class GrokProtocolExtensions +{ + /// Creates an xAI protocol from an . + /// The tool to convert. + /// An xAI protocol representing or if there is no mapping. + /// is . + public static Tool? AsProtocolTool(this AITool tool, ChatOptions? options = null) => ToProtocolTool(Throw.IfNull(tool), options); + + static Tool? ToProtocolTool(AITool tool, ChatOptions? options = null) + { + switch (tool) + { + case AIFunction functionTool: + return new Tool + { + Function = new Function + { + Name = functionTool.Name, + Description = functionTool.Description, + Parameters = JsonSerializer.Serialize(functionTool.JsonSchema) + } + }; + + case HostedWebSearchTool webSearchTool: + if (webSearchTool is GrokXSearchTool xSearchTool) + { + var xsearch = new XSearch + { + EnableImageUnderstanding = xSearchTool.EnableImageUnderstanding, + EnableVideoUnderstanding = xSearchTool.EnableVideoUnderstanding, + }; + + if (xSearchTool.AllowedHandles is { Count: > 0 } && + xSearchTool.ExcludedHandles is { Count: > 0 }) + throw new NotSupportedException($"Cannot use {nameof(GrokXSearchTool.AllowedHandles)} and {nameof(GrokXSearchTool.ExcludedHandles)} together in the same request."); + + if (xSearchTool.AllowedHandles is { } allowed) + xsearch.AllowedXHandles.AddRange(allowed); + if (xSearchTool.ExcludedHandles is { } excluded) + xsearch.ExcludedXHandles.AddRange(excluded); + if (xSearchTool.FromDate is { } from) + xsearch.FromDate = Google.Protobuf.WellKnownTypes.Timestamp.FromDateTimeOffset(from.ToDateTime(TimeOnly.MinValue, DateTimeKind.Utc)); + if (xSearchTool.ToDate is { } to) + xsearch.ToDate = Google.Protobuf.WellKnownTypes.Timestamp.FromDateTimeOffset(to.ToDateTime(TimeOnly.MinValue, DateTimeKind.Utc)); + + return new Tool { XSearch = xsearch }; + } + else if (webSearchTool is GrokSearchTool grokSearch) + { + var websearch = new WebSearch + { + EnableImageUnderstanding = grokSearch.EnableImageUnderstanding, + }; + + if (grokSearch.AllowedDomains is { Count: > 0 } && + grokSearch.ExcludedDomains is { Count: > 0 }) + throw new NotSupportedException($"Cannot use {nameof(GrokSearchTool.AllowedDomains)} and {nameof(GrokSearchTool.ExcludedDomains)} together in the same request."); + + if (grokSearch.AllowedDomains is { } allowed) + websearch.AllowedDomains.AddRange(allowed); + if (grokSearch.ExcludedDomains is { } excluded) + websearch.ExcludedDomains.AddRange(excluded); + + return new Tool { WebSearch = websearch }; + } + else + { + return new Tool { WebSearch = new WebSearch() }; + } + + case HostedCodeInterpreterTool: + return new Tool { CodeExecution = new CodeExecution { } }; + + case HostedFileSearchTool fileSearch: + var collectionTool = new CollectionsSearch(); + + if (fileSearch.Inputs?.OfType() is { } vectorStores) + collectionTool.CollectionIds.AddRange(vectorStores.Select(x => x.VectorStoreId).Distinct()); + + if (fileSearch.MaximumResultCount is { } maxResults) + collectionTool.Limit = maxResults; + if (fileSearch.GetProperty(nameof(CollectionsSearch.Instructions)) is { } instructions) + collectionTool.Instructions = instructions; + + return new Tool { CollectionsSearch = collectionTool }; + + case HostedMcpServerTool mcpTool: + var mcp = new MCP + { + Authorization = mcpTool.AuthorizationToken, + ServerLabel = mcpTool.ServerName, + ServerUrl = mcpTool.ServerAddress, + AllowedToolNames = { mcpTool.AllowedTools ?? Array.Empty() }, + }; + + // We can set an entire dictionary with a specific key + if (mcpTool.GetProperty>(nameof(MCP.ExtraHeaders)) is { } headers) + mcp.ExtraHeaders.Add(headers); + + // Or also the more intuitive mapping of additional properties directly. + foreach (var kv in mcpTool.AdditionalProperties) + if (kv.Value is string value) + mcp.ExtraHeaders.Add(kv.Key, value); + + return new Tool { Mcp = mcp }; + + default: + return null; + } + } +}