From 15b59e67620d73fef0c4c9ae9ffdfb10b5bcae6f Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Thu, 16 Apr 2026 15:15:59 -0400 Subject: [PATCH] Add runtime header options across SDKs Expose provider headers and per-message requestHeaders across Node, Python, Go, and .NET, and add focused tests covering create, resume, and send request forwarding. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- dotnet/src/Session.cs | 4 +- dotnet/src/Types.cs | 13 ++++ dotnet/test/SerializationTests.cs | 98 +++++++++++++++++++++++++++++++ go/session.go | 13 ++-- go/types.go | 17 ++++-- go/types_test.go | 57 ++++++++++++++++++ nodejs/src/session.ts | 1 + nodejs/src/types.ts | 10 ++++ nodejs/test/client.test.ts | 88 +++++++++++++++++++++++++++ python/copilot/client.py | 2 + python/copilot/session.py | 14 ++++- python/test_client.py | 97 ++++++++++++++++++++++++++++++ 12 files changed, 400 insertions(+), 14 deletions(-) diff --git a/dotnet/src/Session.cs b/dotnet/src/Session.cs index 733b94a71..20d6525b8 100644 --- a/dotnet/src/Session.cs +++ b/dotnet/src/Session.cs @@ -192,7 +192,8 @@ public async Task SendAsync(MessageOptions options, CancellationToken ca Attachments = options.Attachments, Mode = options.Mode, Traceparent = traceparent, - Tracestate = tracestate + Tracestate = tracestate, + RequestHeaders = options.RequestHeaders, }; var response = await InvokeRpcAsync( @@ -1223,6 +1224,7 @@ internal record SendMessageRequest public string? Mode { get; init; } public string? Traceparent { get; init; } public string? Tracestate { get; init; } + public IDictionary? RequestHeaders { get; init; } } internal record SendMessageResponse diff --git a/dotnet/src/Types.cs b/dotnet/src/Types.cs index 978defcfb..1fd8afa39 100644 --- a/dotnet/src/Types.cs +++ b/dotnet/src/Types.cs @@ -1483,6 +1483,12 @@ public class ProviderConfig /// [JsonPropertyName("azure")] public AzureOptions? Azure { get; set; } + + /// + /// Custom HTTP headers to include in outbound provider requests. + /// + [JsonPropertyName("headers")] + public IDictionary? Headers { get; set; } } /// @@ -2157,6 +2163,9 @@ protected MessageOptions(MessageOptions? other) Attachments = other.Attachments is not null ? [.. other.Attachments] : null; Mode = other.Mode; Prompt = other.Prompt; + RequestHeaders = other.RequestHeaders is not null + ? new Dictionary(other.RequestHeaders) + : null; } /// @@ -2171,6 +2180,10 @@ protected MessageOptions(MessageOptions? other) /// Interaction mode for the message (e.g., "plan", "edit"). /// public string? Mode { get; set; } + /// + /// Custom per-turn HTTP headers for outbound model requests. + /// + public IDictionary? RequestHeaders { get; set; } /// /// Creates a shallow clone of this instance. diff --git a/dotnet/test/SerializationTests.cs b/dotnet/test/SerializationTests.cs index 6fb266be1..4a976d2bc 100644 --- a/dotnet/test/SerializationTests.cs +++ b/dotnet/test/SerializationTests.cs @@ -67,6 +67,74 @@ public void SerializerOptions_CanResolveRequestIdTypeInfo() Assert.Equal(typeof(RequestId), typeInfo.Type); } + [Fact] + public void ProviderConfig_CanSerializeHeaders_WithSdkOptions() + { + var options = GetSerializerOptions(); + var original = new ProviderConfig + { + BaseUrl = "https://example.com/provider", + Headers = new Dictionary { ["Authorization"] = "Bearer provider-token" } + }; + + var json = JsonSerializer.Serialize(original, options); + using var document = JsonDocument.Parse(json); + var root = document.RootElement; + Assert.Equal("https://example.com/provider", root.GetProperty("baseUrl").GetString()); + Assert.Equal("Bearer provider-token", root.GetProperty("headers").GetProperty("Authorization").GetString()); + + var deserialized = JsonSerializer.Deserialize(json, options); + Assert.NotNull(deserialized); + Assert.Equal("https://example.com/provider", deserialized.BaseUrl); + Assert.Equal("Bearer provider-token", deserialized.Headers!["Authorization"]); + } + + [Fact] + public void MessageOptions_CanSerializeRequestHeaders_WithSdkOptions() + { + var options = GetSerializerOptions(); + var original = new MessageOptions + { + Prompt = "real prompt", + Mode = "plan", + RequestHeaders = new Dictionary { ["X-Trace"] = "trace-value" } + }; + + var json = JsonSerializer.Serialize(original, options); + using var document = JsonDocument.Parse(json); + var root = document.RootElement; + Assert.Equal("real prompt", root.GetProperty("prompt").GetString()); + Assert.Equal("plan", root.GetProperty("mode").GetString()); + Assert.Equal("trace-value", root.GetProperty("requestHeaders").GetProperty("X-Trace").GetString()); + + var deserialized = JsonSerializer.Deserialize(json, options); + Assert.NotNull(deserialized); + Assert.Equal("real prompt", deserialized.Prompt); + Assert.Equal("plan", deserialized.Mode); + Assert.Equal("trace-value", deserialized.RequestHeaders!["X-Trace"]); + } + + [Fact] + public void SendMessageRequest_CanSerializeRequestHeaders_WithSdkOptions() + { + var options = GetSerializerOptions(); + var requestType = GetNestedType(typeof(CopilotSession), "SendMessageRequest"); + var request = CreateInternalRequest( + requestType, + ("SessionId", "session-id"), + ("Prompt", "real prompt"), + ("Mode", "plan"), + ("RequestHeaders", new Dictionary { ["X-Trace"] = "trace-value" })); + + var json = JsonSerializer.Serialize(request, requestType, options); + using var document = JsonDocument.Parse(json); + var root = document.RootElement; + Assert.Equal("session-id", root.GetProperty("sessionId").GetString()); + Assert.Equal("real prompt", root.GetProperty("prompt").GetString()); + Assert.Equal("plan", root.GetProperty("mode").GetString()); + Assert.Equal("trace-value", root.GetProperty("requestHeaders").GetProperty("X-Trace").GetString()); + } + private static JsonSerializerOptions GetSerializerOptions() { var prop = typeof(CopilotClient) @@ -77,4 +145,34 @@ private static JsonSerializerOptions GetSerializerOptions() Assert.NotNull(options); return options; } + + private static Type GetNestedType(Type containingType, string name) + { + var type = containingType.GetNestedType(name, System.Reflection.BindingFlags.NonPublic); + Assert.NotNull(type); + return type!; + } + + private static object CreateInternalRequest(Type type, params (string Name, object? Value)[] properties) + { + var instance = System.Runtime.CompilerServices.RuntimeHelpers.GetUninitializedObject(type); + + foreach (var (name, value) in properties) + { + var property = type.GetProperty(name, System.Reflection.BindingFlags.Instance | System.Reflection.BindingFlags.Public | System.Reflection.BindingFlags.NonPublic); + Assert.NotNull(property); + + if (property!.SetMethod is not null) + { + property.SetValue(instance, value); + continue; + } + + var field = type.GetField($"<{name}>k__BackingField", System.Reflection.BindingFlags.Instance | System.Reflection.BindingFlags.NonPublic); + Assert.NotNull(field); + field!.SetValue(instance, value); + } + + return instance; + } } diff --git a/go/session.go b/go/session.go index a2e52e72c..be8c78e2b 100644 --- a/go/session.go +++ b/go/session.go @@ -132,12 +132,13 @@ func newSession(sessionID string, client *jsonrpc2.Client, workspacePath string) func (s *Session) Send(ctx context.Context, options MessageOptions) (string, error) { traceparent, tracestate := getTraceContext(ctx) req := sessionSendRequest{ - SessionID: s.SessionID, - Prompt: options.Prompt, - Attachments: options.Attachments, - Mode: options.Mode, - Traceparent: traceparent, - Tracestate: tracestate, + SessionID: s.SessionID, + Prompt: options.Prompt, + Attachments: options.Attachments, + Mode: options.Mode, + Traceparent: traceparent, + Tracestate: tracestate, + RequestHeaders: options.RequestHeaders, } result, err := s.client.Request("session.send", req) diff --git a/go/types.go b/go/types.go index d609ce00a..f889d3e2a 100644 --- a/go/types.go +++ b/go/types.go @@ -783,6 +783,8 @@ type ProviderConfig struct { BearerToken string `json:"bearerToken,omitempty"` // Azure contains Azure-specific options Azure *AzureProviderOptions `json:"azure,omitempty"` + // Headers are custom HTTP headers included in outbound provider requests. + Headers map[string]string `json:"headers,omitempty"` } // AzureProviderOptions contains Azure-specific provider configuration @@ -807,6 +809,8 @@ type MessageOptions struct { Attachments []Attachment // Mode is the message delivery mode (default: "enqueue") Mode string + // RequestHeaders are custom per-turn HTTP headers for outbound model requests. + RequestHeaders map[string]string } // SessionEventHandler is a callback for session events @@ -1142,12 +1146,13 @@ type sessionAbortRequest struct { } type sessionSendRequest struct { - SessionID string `json:"sessionId"` - Prompt string `json:"prompt"` - Attachments []Attachment `json:"attachments,omitempty"` - Mode string `json:"mode,omitempty"` - Traceparent string `json:"traceparent,omitempty"` - Tracestate string `json:"tracestate,omitempty"` + SessionID string `json:"sessionId"` + Prompt string `json:"prompt"` + Attachments []Attachment `json:"attachments,omitempty"` + Mode string `json:"mode,omitempty"` + Traceparent string `json:"traceparent,omitempty"` + Tracestate string `json:"tracestate,omitempty"` + RequestHeaders map[string]string `json:"requestHeaders,omitempty"` } // sessionSendResponse is the response from session.send diff --git a/go/types_test.go b/go/types_test.go index 80b0cc545..b37e94f15 100644 --- a/go/types_test.go +++ b/go/types_test.go @@ -91,3 +91,60 @@ func TestPermissionRequestResult_JSONSerialize(t *testing.T) { t.Errorf("expected %s, got %s", expected, string(data)) } } + +func TestProviderConfig_JSONIncludesHeaders(t *testing.T) { + config := ProviderConfig{ + BaseURL: "https://example.com/provider", + Headers: map[string]string{"Authorization": "Bearer provider-token"}, + } + + data, err := json.Marshal(config) + if err != nil { + t.Fatalf("failed to marshal provider config: %v", err) + } + + var decoded map[string]any + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("failed to unmarshal provider config: %v", err) + } + + if decoded["baseUrl"] != "https://example.com/provider" { + t.Fatalf("expected baseUrl to round-trip, got %v", decoded["baseUrl"]) + } + headers, ok := decoded["headers"].(map[string]any) + if !ok { + t.Fatalf("expected headers object, got %T", decoded["headers"]) + } + if headers["Authorization"] != "Bearer provider-token" { + t.Fatalf("expected Authorization header, got %v", headers["Authorization"]) + } +} + +func TestSessionSendRequest_JSONIncludesRequestHeaders(t *testing.T) { + req := sessionSendRequest{ + SessionID: "session-1", + Prompt: "hello", + RequestHeaders: map[string]string{"Authorization": "Bearer turn-token"}, + } + + data, err := json.Marshal(req) + if err != nil { + t.Fatalf("failed to marshal session send request: %v", err) + } + + var decoded map[string]any + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("failed to unmarshal session send request: %v", err) + } + + if decoded["prompt"] != "hello" { + t.Fatalf("expected prompt to round-trip, got %v", decoded["prompt"]) + } + headers, ok := decoded["requestHeaders"].(map[string]any) + if !ok { + t.Fatalf("expected requestHeaders object, got %T", decoded["requestHeaders"]) + } + if headers["Authorization"] != "Bearer turn-token" { + t.Fatalf("expected Authorization header, got %v", headers["Authorization"]) + } +} diff --git a/nodejs/src/session.ts b/nodejs/src/session.ts index ffb2c045a..eae4cab94 100644 --- a/nodejs/src/session.ts +++ b/nodejs/src/session.ts @@ -184,6 +184,7 @@ export class CopilotSession { prompt: options.prompt, attachments: options.attachments, mode: options.mode, + requestHeaders: options.requestHeaders, }); return (response as { messageId: string }).messageId; diff --git a/nodejs/src/types.ts b/nodejs/src/types.ts index a4cb77fa2..0c901f989 100644 --- a/nodejs/src/types.ts +++ b/nodejs/src/types.ts @@ -1403,6 +1403,11 @@ export interface ProviderConfig { */ apiVersion?: string; }; + + /** + * Custom HTTP headers to include in outbound provider requests. + */ + headers?: Record; } /** @@ -1452,6 +1457,11 @@ export interface MessageOptions { * - "immediate": Send immediately */ mode?: "enqueue" | "immediate"; + + /** + * Custom HTTP headers to include in outbound model requests for this turn. + */ + requestHeaders?: Record; } /** diff --git a/nodejs/test/client.test.ts b/nodejs/test/client.test.ts index 0c0611df8..870ccb1ed 100644 --- a/nodejs/test/client.test.ts +++ b/nodejs/test/client.test.ts @@ -98,6 +98,67 @@ describe("CopilotClient", () => { spy.mockRestore(); }); + it("forwards provider headers in session.create request", async () => { + const client = new CopilotClient(); + await client.start(); + onTestFinished(() => client.forceStop()); + + const spy = vi + .spyOn((client as any).connection!, "sendRequest") + .mockImplementation(async (method: string, params: any) => { + if (method === "session.create") return { sessionId: params.sessionId }; + throw new Error(`Unexpected method: ${method}`); + }); + + await client.createSession({ + onPermissionRequest: approveAll, + provider: { + baseUrl: "https://example.com/provider", + headers: { Authorization: "Bearer provider-token" }, + }, + }); + + const payload = spy.mock.calls.find(([method]) => method === "session.create")![1] as any; + expect(payload.provider).toEqual( + expect.objectContaining({ + baseUrl: "https://example.com/provider", + headers: { Authorization: "Bearer provider-token" }, + }) + ); + spy.mockRestore(); + }); + + it("forwards provider headers in session.resume request", async () => { + const client = new CopilotClient(); + await client.start(); + onTestFinished(() => client.forceStop()); + + const session = await client.createSession({ onPermissionRequest: approveAll }); + const spy = vi + .spyOn((client as any).connection!, "sendRequest") + .mockImplementation(async (method: string, params: any) => { + if (method === "session.resume") return { sessionId: params.sessionId }; + throw new Error(`Unexpected method: ${method}`); + }); + + await client.resumeSession(session.sessionId, { + onPermissionRequest: approveAll, + provider: { + baseUrl: "https://example.com/provider", + headers: { Authorization: "Bearer resume-token" }, + }, + }); + + const payload = spy.mock.calls.find(([method]) => method === "session.resume")![1] as any; + expect(payload.provider).toEqual( + expect.objectContaining({ + baseUrl: "https://example.com/provider", + headers: { Authorization: "Bearer resume-token" }, + }) + ); + spy.mockRestore(); + }); + it("does not request permissions on session.resume when using the default joinSession handler", async () => { const client = new CopilotClient(); await client.start(); @@ -720,6 +781,33 @@ describe("CopilotClient", () => { ); }); + it("forwards requestHeaders in session.send request", async () => { + const client = new CopilotClient(); + await client.start(); + onTestFinished(() => client.forceStop()); + + const session = await client.createSession({ onPermissionRequest: approveAll }); + const spy = vi + .spyOn((client as any).connection!, "sendRequest") + .mockImplementation(async (method: string) => { + if (method === "session.send") return { messageId: "m1" }; + throw new Error(`Unexpected method: ${method}`); + }); + + await session.send({ + prompt: "hello", + requestHeaders: { Authorization: "Bearer turn-token" }, + }); + + expect(spy).toHaveBeenCalledWith( + "session.send", + expect.objectContaining({ + prompt: "hello", + requestHeaders: { Authorization: "Bearer turn-token" }, + }) + ); + }); + it("does not include trace context when no callback is provided", async () => { const client = new CopilotClient(); await client.start(); diff --git a/python/copilot/client.py b/python/copilot/client.py index 407ad1673..5d62db301 100644 --- a/python/copilot/client.py +++ b/python/copilot/client.py @@ -2124,6 +2124,8 @@ def _convert_provider_to_wire_format( wire_provider["wireApi"] = provider["wire_api"] if "bearer_token" in provider: wire_provider["bearerToken"] = provider["bearer_token"] + if "headers" in provider: + wire_provider["headers"] = provider["headers"] if "azure" in provider: azure = provider["azure"] wire_azure: dict[str, Any] = {} diff --git a/python/copilot/session.py b/python/copilot/session.py index 9fd9f79bd..9552f75b6 100644 --- a/python/copilot/session.py +++ b/python/copilot/session.py @@ -825,6 +825,7 @@ class ProviderConfig(TypedDict, total=False): # Takes precedence over api_key when both are set. bearer_token: str azure: AzureProviderOptions # Azure-specific options + headers: dict[str, str] class SessionConfig(TypedDict, total=False): @@ -1066,6 +1067,7 @@ async def send( *, attachments: list[Attachment] | None = None, mode: Literal["enqueue", "immediate"] | None = None, + request_headers: dict[str, str] | None = None, ) -> str: """ Send a message to this session. @@ -1078,6 +1080,7 @@ async def send( prompt: The message text to send. attachments: Optional file, directory, or selection attachments. mode: Message delivery mode (``"enqueue"`` or ``"immediate"``). + request_headers: Optional per-turn HTTP headers for outbound model requests. Returns: The message ID assigned by the server, which can be used to correlate events. @@ -1099,6 +1102,8 @@ async def send( params["attachments"] = attachments if mode is not None: params["mode"] = mode + if request_headers is not None: + params["requestHeaders"] = request_headers params.update(get_trace_context()) response = await self._client.request("session.send", params) @@ -1110,6 +1115,7 @@ async def send_and_wait( *, attachments: list[Attachment] | None = None, mode: Literal["enqueue", "immediate"] | None = None, + request_headers: dict[str, str] | None = None, timeout: float = 60.0, ) -> SessionEvent | None: """ @@ -1125,6 +1131,7 @@ async def send_and_wait( prompt: The message text to send. attachments: Optional file, directory, or selection attachments. mode: Message delivery mode (``"enqueue"`` or ``"immediate"``). + request_headers: Optional per-turn HTTP headers for outbound model requests. timeout: Timeout in seconds (default: 60). Controls how long to wait; does not abort in-flight agent work. @@ -1160,7 +1167,12 @@ def handler(event: SessionEventTypeAlias) -> None: unsubscribe = self.on(handler) try: - await self.send(prompt, attachments=attachments, mode=mode) + await self.send( + prompt, + attachments=attachments, + mode=mode, + request_headers=request_headers, + ) await asyncio.wait_for(idle_event.wait(), timeout=timeout) if error_event: raise error_event diff --git a/python/test_client.py b/python/test_client.py index 5d0dc868e..0896b54e2 100644 --- a/python/test_client.py +++ b/python/test_client.py @@ -444,6 +444,103 @@ async def mock_request(method, params): finally: await client.force_stop() + @pytest.mark.asyncio + async def test_create_session_forwards_provider_headers(self): + client = CopilotClient(SubprocessConfig(cli_path=CLI_PATH)) + await client.start() + + try: + captured = {} + original_request = client._client.request + + async def mock_request(method, params): + captured[method] = params + if method == "session.create": + return {"sessionId": params["sessionId"]} + return await original_request(method, params) + + client._client.request = mock_request + await client.create_session( + on_permission_request=PermissionHandler.approve_all, + provider={ + "base_url": "https://example.com/provider", + "headers": {"Authorization": "Bearer provider-token"}, + }, + ) + + provider = captured["session.create"]["provider"] + assert provider["baseUrl"] == "https://example.com/provider" + assert provider["headers"] == {"Authorization": "Bearer provider-token"} + finally: + await client.force_stop() + + @pytest.mark.asyncio + async def test_resume_session_forwards_provider_headers(self): + client = CopilotClient(SubprocessConfig(cli_path=CLI_PATH)) + await client.start() + + try: + session = await client.create_session( + on_permission_request=PermissionHandler.approve_all + ) + + captured = {} + original_request = client._client.request + + async def mock_request(method, params): + captured[method] = params + if method == "session.resume": + return {"sessionId": session.session_id} + return await original_request(method, params) + + client._client.request = mock_request + await client.resume_session( + session.session_id, + on_permission_request=PermissionHandler.approve_all, + provider={ + "base_url": "https://example.com/provider", + "headers": {"Authorization": "Bearer resume-token"}, + }, + ) + + provider = captured["session.resume"]["provider"] + assert provider["baseUrl"] == "https://example.com/provider" + assert provider["headers"] == {"Authorization": "Bearer resume-token"} + finally: + await client.force_stop() + + @pytest.mark.asyncio + async def test_session_send_forwards_request_headers(self): + client = CopilotClient(SubprocessConfig(cli_path=CLI_PATH)) + await client.start() + + try: + session = await client.create_session( + on_permission_request=PermissionHandler.approve_all + ) + + captured = {} + original_request = client._client.request + + async def mock_request(method, params): + captured[method] = params + if method == "session.send": + return {"messageId": "msg-1"} + return await original_request(method, params) + + client._client.request = mock_request + await session.send( + "hello", + request_headers={"Authorization": "Bearer turn-token"}, + ) + + assert captured["session.send"]["prompt"] == "hello" + assert captured["session.send"]["requestHeaders"] == { + "Authorization": "Bearer turn-token" + } + finally: + await client.force_stop() + @pytest.mark.asyncio async def test_create_session_forwards_agent(self): client = CopilotClient(SubprocessConfig(cli_path=CLI_PATH))