Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 43 additions & 9 deletions dotnet/src/Client.cs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ public sealed partial class CopilotClient : IDisposable, IAsyncDisposable
private bool _disposed;
private readonly int? _optionsPort;
private readonly string? _optionsHost;
private readonly string? _effectiveConnectionToken;
private int? _actualPort;
private int? _negotiatedProtocolVersion;
private List<ModelInfo>? _modelsCache;
Expand Down Expand Up @@ -140,6 +141,22 @@ public CopilotClient(CopilotClientOptions? options = null)
throw new ArgumentException("GitHubToken and UseLoggedInUser cannot be used with CliUrl (external server manages its own auth)");
}

if (_options.TcpConnectionToken is not null)
{
if (_options.TcpConnectionToken.Length == 0)
{
throw new ArgumentException("TcpConnectionToken must be a non-empty string");
}
if (_options.UseStdio && string.IsNullOrEmpty(_options.CliUrl))
{
throw new ArgumentException("TcpConnectionToken cannot be used with UseStdio = true");
}
}

var sdkSpawnsCli = !_options.UseStdio && string.IsNullOrEmpty(_options.CliUrl);
_effectiveConnectionToken = _options.TcpConnectionToken
?? (sdkSpawnsCli ? Guid.NewGuid().ToString() : null);

_logger = _options.Logger ?? NullLogger.Instance;
_onListModels = _options.OnListModels;

Expand Down Expand Up @@ -216,7 +233,7 @@ async Task<Connection> StartCoreAsync(CancellationToken ct)
else
{
// Child process (stdio or TCP)
var (cliProcess, portOrNull, stderrBuffer) = await StartCliServerAsync(_options, _logger, ct);
var (cliProcess, portOrNull, stderrBuffer) = await StartCliServerAsync(_options, _effectiveConnectionToken, _logger, ct);
_actualPort = portOrNull;
result = ConnectToServerAsync(cliProcess, portOrNull is null ? null : "localhost", portOrNull, stderrBuffer, ct);
}
Expand Down Expand Up @@ -1122,30 +1139,42 @@ private void ConfigureSessionFsHandlers(CopilotSession session, Func<CopilotSess
private async Task VerifyProtocolVersionAsync(Connection connection, CancellationToken cancellationToken)
{
var maxVersion = SdkProtocolVersion.GetVersion();
var pingResponse = await InvokeRpcAsync<PingResponse>(
connection.Rpc, "ping", [new PingRequest()], connection.StderrBuffer, cancellationToken);
int? serverVersion;
try
{
var connectResponse = await InvokeRpcAsync<ConnectResult>(
connection.Rpc, "connect", [new ConnectRequest { Token = _effectiveConnectionToken }], connection.StderrBuffer, cancellationToken);
serverVersion = (int)connectResponse.ProtocolVersion;
}
catch (RemoteRpcException ex) when (ex.ErrorCode == RemoteRpcException.MethodNotFoundErrorCode)
{
// Legacy server without `connect`; fall back to `ping`. A token, if any,
// is silently dropped — the legacy server can't enforce one.
var pingResponse = await InvokeRpcAsync<PingResponse>(
connection.Rpc, "ping", [new PingRequest()], connection.StderrBuffer, cancellationToken);
serverVersion = pingResponse.ProtocolVersion;
}

if (!pingResponse.ProtocolVersion.HasValue)
if (!serverVersion.HasValue)
{
throw new InvalidOperationException(
$"SDK protocol version mismatch: SDK supports versions {MinProtocolVersion}-{maxVersion}, " +
$"but server does not report a protocol version. " +
$"Please update your server to ensure compatibility.");
}

var serverVersion = pingResponse.ProtocolVersion.Value;
if (serverVersion < MinProtocolVersion || serverVersion > maxVersion)
if (serverVersion.Value < MinProtocolVersion || serverVersion.Value > maxVersion)
{
throw new InvalidOperationException(
$"SDK protocol version mismatch: SDK supports versions {MinProtocolVersion}-{maxVersion}, " +
$"but server reports version {serverVersion}. " +
$"but server reports version {serverVersion.Value}. " +
$"Please update your SDK or server to ensure compatibility.");
}

_negotiatedProtocolVersion = serverVersion;
_negotiatedProtocolVersion = serverVersion.Value;
}

private static async Task<(Process Process, int? DetectedLocalhostTcpPort, StringBuilder StderrBuffer)> StartCliServerAsync(CopilotClientOptions options, ILogger logger, CancellationToken cancellationToken)
private static async Task<(Process Process, int? DetectedLocalhostTcpPort, StringBuilder StderrBuffer)> StartCliServerAsync(CopilotClientOptions options, string? connectionToken, ILogger logger, CancellationToken cancellationToken)
{
// Use explicit path, COPILOT_CLI_PATH env var (from options.Environment or process env), or bundled CLI - no PATH fallback
var envCliPath = options.Environment is not null && options.Environment.TryGetValue("COPILOT_CLI_PATH", out var envValue) ? envValue
Expand Down Expand Up @@ -1221,6 +1250,11 @@ private async Task VerifyProtocolVersionAsync(Connection connection, Cancellatio
startInfo.Environment["COPILOT_SDK_AUTH_TOKEN"] = options.GitHubToken;
}

if (!string.IsNullOrEmpty(connectionToken))
{
startInfo.Environment["COPILOT_CONNECTION_TOKEN"] = connectionToken;
}

// Set telemetry environment variables if configured
if (options.Telemetry is { } telemetry)
{
Expand Down
49 changes: 49 additions & 0 deletions dotnet/src/Generated/Rpc.cs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions dotnet/src/JsonRpc.cs
Original file line number Diff line number Diff line change
Expand Up @@ -831,5 +831,8 @@ internal sealed class ConnectionLostException() : IOException("The JSON-RPC conn
/// </summary>
internal sealed class RemoteRpcException(string message, int errorCode, Exception? innerException = null) : Exception(message, innerException)
{
/// <summary>JSON-RPC 2.0 reserved error code: requested method does not exist.</summary>
public const int MethodNotFoundErrorCode = -32601;

public int ErrorCode { get; } = errorCode;
}
8 changes: 8 additions & 0 deletions dotnet/src/Types.cs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ protected CopilotClientOptions(CopilotClientOptions? other)
OnListModels = other.OnListModels;
SessionFs = other.SessionFs;
SessionIdleTimeoutSeconds = other.SessionIdleTimeoutSeconds;
TcpConnectionToken = other.TcpConnectionToken;
}

/// <summary>
Expand Down Expand Up @@ -175,6 +176,13 @@ public string? GithubToken
/// </summary>
public int? SessionIdleTimeoutSeconds { get; set; }

/// <summary>
/// Connection token for the headless CLI server (TCP only). When the SDK spawns its own
/// CLI in TCP mode and this is omitted, a GUID is generated automatically so the loopback
/// listener is safe by default. Cannot be combined with <see cref="UseStdio"/> = true.
/// </summary>
public string? TcpConnectionToken { get; set; }

/// <summary>
/// Creates a shallow clone of this <see cref="CopilotClientOptions"/> instance.
/// </summary>
Expand Down
144 changes: 144 additions & 0 deletions dotnet/test/ConnectionTokenTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
*--------------------------------------------------------------------------------------------*/

using GitHub.Copilot.SDK.Test.Harness;
using Xunit;

namespace GitHub.Copilot.SDK.Test;

/// <summary>
/// Custom fixture that spawns a CLI in TCP mode with an explicit connection token, so
/// sibling clients can attempt to connect to the same port with the right/wrong/no token.
/// </summary>
public class ConnectionTokenTestFixture : IAsyncLifetime
{
public E2ETestContext Ctx { get; private set; } = null!;
public CopilotClient GoodClient { get; private set; } = null!;
public int Port { get; private set; }

public const string Token = "right-token";

public async Task InitializeAsync()
{
Ctx = await E2ETestContext.CreateAsync();
GoodClient = Ctx.CreateClient(useStdio: false, options: new CopilotClientOptions
{
TcpConnectionToken = Token,
});

await GoodClient.StartAsync();
Port = GoodClient.ActualPort
?? throw new InvalidOperationException("GoodClient is not using TCP mode; ActualPort is null");
}

public async Task DisposeAsync()
{
if (GoodClient is not null)
{
await GoodClient.ForceStopAsync();
}

await Ctx.DisposeAsync();
}
}

public class ConnectionTokenTests : IClassFixture<ConnectionTokenTestFixture>
{
private readonly ConnectionTokenTestFixture _fixture;

public ConnectionTokenTests(ConnectionTokenTestFixture fixture)
{
_fixture = fixture;
}

[Fact]
public async Task Connects_With_The_Matching_Token()
{
var pong = await _fixture.GoodClient.PingAsync("hi");
Assert.Equal("pong: hi", pong.Message);
}

[Fact]
public async Task Rejects_A_Wrong_Token()
{
var wrongClient = new CopilotClient(new CopilotClientOptions
{
CliUrl = $"localhost:{_fixture.Port}",
TcpConnectionToken = "wrong",
});

try
{
var ex = await Assert.ThrowsAnyAsync<Exception>(() => wrongClient.StartAsync());
Assert.Contains("AUTHENTICATION_FAILED", GetFullMessage(ex));
}
finally
{
try { await wrongClient.ForceStopAsync(); } catch { }
}
}

[Fact]
public async Task Rejects_A_Missing_Token_When_One_Is_Required()
{
var noTokenClient = new CopilotClient(new CopilotClientOptions
{
CliUrl = $"localhost:{_fixture.Port}",
});

try
{
var ex = await Assert.ThrowsAnyAsync<Exception>(() => noTokenClient.StartAsync());
Assert.Contains("AUTHENTICATION_FAILED", GetFullMessage(ex));
}
finally
{
try { await noTokenClient.ForceStopAsync(); } catch { }
}
}

private static string GetFullMessage(Exception ex)
{
var messages = new List<string>();
for (var cur = ex; cur is not null; cur = cur.InnerException)
{
messages.Add(cur.Message);
}
return string.Join(" | ", messages);
}
}

/// <summary>
/// When the SDK spawns its own CLI in TCP mode without an explicit token, it auto-generates
/// a GUID and round-trips it through the spawned CLI.
/// </summary>
public class ConnectionTokenAutoGeneratedTests : IAsyncLifetime
{
private E2ETestContext _ctx = null!;
private CopilotClient _client = null!;

public async Task InitializeAsync()
{
_ctx = await E2ETestContext.CreateAsync();
_client = _ctx.CreateClient(useStdio: false);
}

public async Task DisposeAsync()
{
if (_client is not null)
{
await _client.ForceStopAsync();
}

await _ctx.DisposeAsync();
}

[Fact]
public async Task The_SDK_Auto_Generated_Guid_Round_Trips_Through_The_Spawned_CLI()
{
await _client.StartAsync();
var pong = await _client.PingAsync("hi");
Assert.Equal("pong: hi", pong.Message);
}
}
Loading
Loading