Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 113 additions & 26 deletions dotnet/src/Client.cs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ namespace GitHub.Copilot.SDK;
/// await session.SendAsync(new MessageOptions { Prompt = "Hello!" });
/// </code>
/// </example>
public class CopilotClient : IDisposable, IAsyncDisposable
public partial class CopilotClient : IDisposable, IAsyncDisposable
{
private readonly ConcurrentDictionary<string, CopilotSession> _sessions = new();
private readonly CopilotClientOptions _options;
Expand Down Expand Up @@ -461,7 +461,7 @@ public async Task<PingResponse> PingAsync(string? message = null, CancellationTo
var connection = await EnsureConnectedAsync(cancellationToken);

return await connection.Rpc.InvokeWithCancellationAsync<PingResponse>(
"ping", [new { message }], cancellationToken);
"ping", [new PingRequest { Message = message }], cancellationToken);
}

/// <summary>
Expand Down Expand Up @@ -554,7 +554,7 @@ public async Task DeleteSessionAsync(string sessionId, CancellationToken cancell
var connection = await EnsureConnectedAsync(cancellationToken);

var response = await connection.Rpc.InvokeWithCancellationAsync<DeleteSessionResponse>(
"session.delete", [new { sessionId }], cancellationToken);
"session.delete", [new DeleteSessionRequest(sessionId)], cancellationToken);

if (!response.Success)
{
Expand Down Expand Up @@ -604,7 +604,7 @@ private async Task VerifyProtocolVersionAsync(Connection connection, Cancellatio
{
var expectedVersion = SdkProtocolVersion.GetVersion();
var pingResponse = await connection.Rpc.InvokeWithCancellationAsync<PingResponse>(
"ping", [new { message = (string?)null }], cancellationToken);
"ping", [new PingRequest()], cancellationToken);

if (!pingResponse.ProtocolVersion.HasValue)
{
Expand Down Expand Up @@ -754,23 +754,45 @@ private async Task<Connection> ConnectToServerAsync(Process? cliProcess, string?
outputStream = networkStream;
}

var rpc = new JsonRpc(new HeaderDelimitedMessageHandler(outputStream, inputStream, CreateFormatter()));
rpc.AddLocalRpcTarget(new RpcHandler(this));
var rpc = new JsonRpc(new HeaderDelimitedMessageHandler(
outputStream,
inputStream,
CreateSystemTextJsonFormatter()))
{
TraceSource = new LoggerTraceSource(_logger),
};

var handler = new RpcHandler(this);
rpc.AddLocalRpcMethod("session.event", handler.OnSessionEvent);
rpc.AddLocalRpcMethod("tool.call", handler.OnToolCall);
rpc.AddLocalRpcMethod("permission.request", handler.OnPermissionRequest);
rpc.StartListening();
return new Connection(rpc, cliProcess, tcpClient, networkStream);
}

[UnconditionalSuppressMessage("Trimming", "IL2026", Justification = "Using the Json source generator.")]
[UnconditionalSuppressMessage("AOT", "IL3050", Justification = "Using the Json source generator.")]
static IJsonRpcMessageFormatter CreateFormatter()
[UnconditionalSuppressMessage("Trimming", "IL2026", Justification = "Using happy path from https://microsoft.github.io/vs-streamjsonrpc/docs/nativeAOT.html")]
[UnconditionalSuppressMessage("AOT", "IL3050", Justification = "Using happy path from https://microsoft.github.io/vs-streamjsonrpc/docs/nativeAOT.html")]
private static SystemTextJsonFormatter CreateSystemTextJsonFormatter() =>
new SystemTextJsonFormatter() { JsonSerializerOptions = SerializerOptionsForMessageFormatter };

private static JsonSerializerOptions SerializerOptionsForMessageFormatter { get; } = CreateSerializerOptions();

private static JsonSerializerOptions CreateSerializerOptions()
{
var options = new JsonSerializerOptions(JsonSerializerDefaults.Web)
{
AllowOutOfOrderMetadataProperties = true,
DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull
};

return new SystemTextJsonFormatter() { JsonSerializerOptions = options };
options.TypeInfoResolverChain.Add(ClientJsonContext.Default);
options.TypeInfoResolverChain.Add(TypesJsonContext.Default);
options.TypeInfoResolverChain.Add(CopilotSession.SessionJsonContext.Default);
options.TypeInfoResolverChain.Add(SessionEventsJsonContext.Default);

options.MakeReadOnly();

return options;
}

internal CopilotSession? GetSession(string sessionId) =>
Expand Down Expand Up @@ -803,9 +825,7 @@ public async ValueTask DisposeAsync()

private class RpcHandler(CopilotClient client)
{
[JsonRpcMethod("session.event")]
public void OnSessionEvent(string sessionId,
JsonElement? @event)
public void OnSessionEvent(string sessionId, JsonElement? @event)
{
var session = client.GetSession(sessionId);
if (session != null && @event != null)
Expand All @@ -818,7 +838,6 @@ public void OnSessionEvent(string sessionId,
}
}

[JsonRpcMethod("tool.call")]
public async Task<ToolCallResponse> OnToolCall(string sessionId,
string toolCallId,
string toolName,
Expand Down Expand Up @@ -891,7 +910,7 @@ public async Task<ToolCallResponse> OnToolCall(string sessionId,
// something we don't control? an error?)
TextResultForLlm = result is JsonElement { ValueKind: JsonValueKind.String } je
? je.GetString()!
: JsonSerializer.Serialize(result, tool.JsonSerializerOptions),
: JsonSerializer.Serialize(result, tool.JsonSerializerOptions.GetTypeInfo(typeof(object))),
};
return new ToolCallResponse(toolResultObject);
}
Expand All @@ -908,7 +927,6 @@ public async Task<ToolCallResponse> OnToolCall(string sessionId,
}
}

[JsonRpcMethod("permission.request")]
public async Task<PermissionRequestResponse> OnPermissionRequest(string sessionId, JsonElement permissionRequest)
{
var session = client.GetSession(sessionId);
Expand Down Expand Up @@ -959,7 +977,7 @@ public static string Escape(string arg)
}

// Request/Response types for RPC
private record CreateSessionRequest(
internal record CreateSessionRequest(
string? Model,
string? SessionId,
List<ToolDefinition>? Tools,
Expand All @@ -975,7 +993,7 @@ private record CreateSessionRequest(
List<string>? SkillDirectories,
List<string>? DisabledSkills);

private record ToolDefinition(
internal record ToolDefinition(
string Name,
string? Description,
JsonElement Parameters /* JSON schema */)
Expand All @@ -984,10 +1002,10 @@ public static ToolDefinition FromAIFunction(AIFunction function)
=> new ToolDefinition(function.Name, function.Description, function.JsonSchema);
}

private record CreateSessionResponse(
internal record CreateSessionResponse(
string SessionId);

private record ResumeSessionRequest(
internal record ResumeSessionRequest(
string SessionId,
List<ToolDefinition>? Tools,
ProviderConfig? Provider,
Expand All @@ -998,24 +1016,93 @@ private record ResumeSessionRequest(
List<string>? SkillDirectories,
List<string>? DisabledSkills);

private record ResumeSessionResponse(
internal record ResumeSessionResponse(
string SessionId);

private record GetLastSessionIdResponse(
internal record GetLastSessionIdResponse(
string? SessionId);

private record DeleteSessionResponse(
internal record DeleteSessionRequest(
string SessionId);

internal record DeleteSessionResponse(
bool Success,
string? Error);

private record ListSessionsResponse(
internal record ListSessionsResponse(
List<SessionMetadata> Sessions);

private record ToolCallResponse(
internal record ToolCallResponse(
ToolResultObject? Result);

private record PermissionRequestResponse(
internal record PermissionRequestResponse(
PermissionRequestResult Result);

/// <summary>Trace source that forwards all logs to the ILogger.</summary>
internal sealed class LoggerTraceSource : TraceSource
{
public LoggerTraceSource(ILogger logger) : base(nameof(LoggerTraceSource), SourceLevels.All)
{
Listeners.Clear();
Listeners.Add(new LoggerTraceListener(logger));
}

private sealed class LoggerTraceListener(ILogger logger) : TraceListener
{
public override void TraceEvent(TraceEventCache? eventCache, string source, TraceEventType eventType, int id, string? message) =>
logger.Log(MapLevel(eventType), "[{Source}] {Message}", source, message);

public override void TraceEvent(TraceEventCache? eventCache, string source, TraceEventType eventType, int id, string? format, params object?[]? args) =>
logger.Log(MapLevel(eventType), "[{Source}] {Message}", source, args is null || args.Length == 0 ? format : string.Format(format ?? "", args));

public override void TraceData(TraceEventCache? eventCache, string source, TraceEventType eventType, int id, object? data) =>
logger.Log(MapLevel(eventType), "[{Source}] {Data}", source, data);

public override void TraceData(TraceEventCache? eventCache, string source, TraceEventType eventType, int id, params object?[]? data) =>
logger.Log(MapLevel(eventType), "[{Source}] {Data}", source, data is null ? null : string.Join(", ", data));

public override void Write(string? message) =>
logger.LogTrace("{Message}", message);

public override void WriteLine(string? message) =>
logger.LogTrace("{Message}", message);

private static LogLevel MapLevel(TraceEventType eventType) => eventType switch
{
TraceEventType.Critical => LogLevel.Critical,
TraceEventType.Error => LogLevel.Error,
TraceEventType.Warning => LogLevel.Warning,
TraceEventType.Information => LogLevel.Information,
TraceEventType.Verbose => LogLevel.Debug,
_ => LogLevel.Trace
};
}
}

[JsonSourceGenerationOptions(
JsonSerializerDefaults.Web,
AllowOutOfOrderMetadataProperties = true,
NumberHandling = JsonNumberHandling.AllowReadingFromString,
DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull)]
[JsonSerializable(typeof(CreateSessionRequest))]
[JsonSerializable(typeof(CreateSessionResponse))]
[JsonSerializable(typeof(CustomAgentConfig))]
[JsonSerializable(typeof(DeleteSessionRequest))]
[JsonSerializable(typeof(DeleteSessionResponse))]
[JsonSerializable(typeof(GetLastSessionIdResponse))]
[JsonSerializable(typeof(ListSessionsResponse))]
[JsonSerializable(typeof(PermissionRequestResponse))]
[JsonSerializable(typeof(PermissionRequestResult))]
[JsonSerializable(typeof(ProviderConfig))]
[JsonSerializable(typeof(ResumeSessionRequest))]
[JsonSerializable(typeof(ResumeSessionResponse))]
[JsonSerializable(typeof(SessionMetadata))]
[JsonSerializable(typeof(SystemMessageConfig))]
[JsonSerializable(typeof(ToolCallResponse))]
[JsonSerializable(typeof(ToolDefinition))]
[JsonSerializable(typeof(ToolResultAIContent))]
[JsonSerializable(typeof(ToolResultObject))]
internal partial class ClientJsonContext : JsonSerializerContext;
}

// Must inherit from AIContent as a signal to MEAI to avoid JSON-serializing the
Expand Down
Loading
Loading