Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion dotnet/src/Session.cs
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,8 @@ public async Task<string> SendAsync(MessageOptions options, CancellationToken ca
Attachments = options.Attachments,
Mode = options.Mode,
Traceparent = traceparent,
Tracestate = tracestate
Tracestate = tracestate,
RequestHeaders = options.RequestHeaders,
};

var response = await InvokeRpcAsync<SendMessageResponse>(
Expand Down Expand Up @@ -1223,6 +1224,7 @@ internal record SendMessageRequest
public string? Mode { get; init; }
public string? Traceparent { get; init; }
public string? Tracestate { get; init; }
public IDictionary<string, string>? RequestHeaders { get; init; }
}

internal record SendMessageResponse
Expand Down
13 changes: 13 additions & 0 deletions dotnet/src/Types.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1483,6 +1483,12 @@ public class ProviderConfig
/// </summary>
[JsonPropertyName("azure")]
public AzureOptions? Azure { get; set; }

/// <summary>
/// Custom HTTP headers to include in outbound provider requests.
/// </summary>
[JsonPropertyName("headers")]
public IDictionary<string, string>? Headers { get; set; }
}

/// <summary>
Expand Down Expand Up @@ -2157,6 +2163,9 @@ protected MessageOptions(MessageOptions? other)
Attachments = other.Attachments is not null ? [.. other.Attachments] : null;
Mode = other.Mode;
Prompt = other.Prompt;
RequestHeaders = other.RequestHeaders is not null
? new Dictionary<string, string>(other.RequestHeaders)
: null;
}

/// <summary>
Expand All @@ -2171,6 +2180,10 @@ protected MessageOptions(MessageOptions? other)
/// Interaction mode for the message (e.g., "plan", "edit").
/// </summary>
public string? Mode { get; set; }
/// <summary>
/// Custom per-turn HTTP headers for outbound model requests.
/// </summary>
public IDictionary<string, string>? RequestHeaders { get; set; }

/// <summary>
/// Creates a shallow clone of this <see cref="MessageOptions"/> instance.
Expand Down
98 changes: 98 additions & 0 deletions dotnet/test/SerializationTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,74 @@ public void SerializerOptions_CanResolveRequestIdTypeInfo()
Assert.Equal(typeof(RequestId), typeInfo.Type);
}

[Fact]
public void ProviderConfig_CanSerializeHeaders_WithSdkOptions()
{
var options = GetSerializerOptions();
var original = new ProviderConfig
{
BaseUrl = "https://example.com/provider",
Headers = new Dictionary<string, string> { ["Authorization"] = "Bearer provider-token" }
};

var json = JsonSerializer.Serialize(original, options);
using var document = JsonDocument.Parse(json);
var root = document.RootElement;
Assert.Equal("https://example.com/provider", root.GetProperty("baseUrl").GetString());
Assert.Equal("Bearer provider-token", root.GetProperty("headers").GetProperty("Authorization").GetString());

var deserialized = JsonSerializer.Deserialize<ProviderConfig>(json, options);
Assert.NotNull(deserialized);
Assert.Equal("https://example.com/provider", deserialized.BaseUrl);
Assert.Equal("Bearer provider-token", deserialized.Headers!["Authorization"]);
}

[Fact]
public void MessageOptions_CanSerializeRequestHeaders_WithSdkOptions()
{
var options = GetSerializerOptions();
var original = new MessageOptions
{
Prompt = "real prompt",
Mode = "plan",
RequestHeaders = new Dictionary<string, string> { ["X-Trace"] = "trace-value" }
};

var json = JsonSerializer.Serialize(original, options);
using var document = JsonDocument.Parse(json);
var root = document.RootElement;
Assert.Equal("real prompt", root.GetProperty("prompt").GetString());
Assert.Equal("plan", root.GetProperty("mode").GetString());
Assert.Equal("trace-value", root.GetProperty("requestHeaders").GetProperty("X-Trace").GetString());

var deserialized = JsonSerializer.Deserialize<MessageOptions>(json, options);
Assert.NotNull(deserialized);
Assert.Equal("real prompt", deserialized.Prompt);
Assert.Equal("plan", deserialized.Mode);
Assert.Equal("trace-value", deserialized.RequestHeaders!["X-Trace"]);
}

[Fact]
public void SendMessageRequest_CanSerializeRequestHeaders_WithSdkOptions()
{
var options = GetSerializerOptions();
var requestType = GetNestedType(typeof(CopilotSession), "SendMessageRequest");
var request = CreateInternalRequest(
requestType,
("SessionId", "session-id"),
("Prompt", "real prompt"),
("Mode", "plan"),
("RequestHeaders", new Dictionary<string, string> { ["X-Trace"] = "trace-value" }));

var json = JsonSerializer.Serialize(request, requestType, options);
using var document = JsonDocument.Parse(json);
var root = document.RootElement;
Assert.Equal("session-id", root.GetProperty("sessionId").GetString());
Assert.Equal("real prompt", root.GetProperty("prompt").GetString());
Assert.Equal("plan", root.GetProperty("mode").GetString());
Assert.Equal("trace-value", root.GetProperty("requestHeaders").GetProperty("X-Trace").GetString());
}

private static JsonSerializerOptions GetSerializerOptions()
{
var prop = typeof(CopilotClient)
Expand All @@ -77,4 +145,34 @@ private static JsonSerializerOptions GetSerializerOptions()
Assert.NotNull(options);
return options;
}

private static Type GetNestedType(Type containingType, string name)
{
var type = containingType.GetNestedType(name, System.Reflection.BindingFlags.NonPublic);
Assert.NotNull(type);
return type!;
}

private static object CreateInternalRequest(Type type, params (string Name, object? Value)[] properties)
{
var instance = System.Runtime.CompilerServices.RuntimeHelpers.GetUninitializedObject(type);

foreach (var (name, value) in properties)
{
var property = type.GetProperty(name, System.Reflection.BindingFlags.Instance | System.Reflection.BindingFlags.Public | System.Reflection.BindingFlags.NonPublic);
Assert.NotNull(property);

if (property!.SetMethod is not null)
{
property.SetValue(instance, value);
continue;
}

var field = type.GetField($"<{name}>k__BackingField", System.Reflection.BindingFlags.Instance | System.Reflection.BindingFlags.NonPublic);
Assert.NotNull(field);
field!.SetValue(instance, value);
}

return instance;
}
}
13 changes: 7 additions & 6 deletions go/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,12 +132,13 @@ func newSession(sessionID string, client *jsonrpc2.Client, workspacePath string)
func (s *Session) Send(ctx context.Context, options MessageOptions) (string, error) {
traceparent, tracestate := getTraceContext(ctx)
req := sessionSendRequest{
SessionID: s.SessionID,
Prompt: options.Prompt,
Attachments: options.Attachments,
Mode: options.Mode,
Traceparent: traceparent,
Tracestate: tracestate,
SessionID: s.SessionID,
Prompt: options.Prompt,
Attachments: options.Attachments,
Mode: options.Mode,
Traceparent: traceparent,
Tracestate: tracestate,
RequestHeaders: options.RequestHeaders,
}

result, err := s.client.Request("session.send", req)
Expand Down
17 changes: 11 additions & 6 deletions go/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -783,6 +783,8 @@ type ProviderConfig struct {
BearerToken string `json:"bearerToken,omitempty"`
// Azure contains Azure-specific options
Azure *AzureProviderOptions `json:"azure,omitempty"`
// Headers are custom HTTP headers included in outbound provider requests.
Headers map[string]string `json:"headers,omitempty"`
}

// AzureProviderOptions contains Azure-specific provider configuration
Expand All @@ -807,6 +809,8 @@ type MessageOptions struct {
Attachments []Attachment
// Mode is the message delivery mode (default: "enqueue")
Mode string
// RequestHeaders are custom per-turn HTTP headers for outbound model requests.
RequestHeaders map[string]string
}

// SessionEventHandler is a callback for session events
Expand Down Expand Up @@ -1142,12 +1146,13 @@ type sessionAbortRequest struct {
}

type sessionSendRequest struct {
SessionID string `json:"sessionId"`
Prompt string `json:"prompt"`
Attachments []Attachment `json:"attachments,omitempty"`
Mode string `json:"mode,omitempty"`
Traceparent string `json:"traceparent,omitempty"`
Tracestate string `json:"tracestate,omitempty"`
SessionID string `json:"sessionId"`
Prompt string `json:"prompt"`
Attachments []Attachment `json:"attachments,omitempty"`
Mode string `json:"mode,omitempty"`
Traceparent string `json:"traceparent,omitempty"`
Tracestate string `json:"tracestate,omitempty"`
RequestHeaders map[string]string `json:"requestHeaders,omitempty"`
}

// sessionSendResponse is the response from session.send
Expand Down
57 changes: 57 additions & 0 deletions go/types_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,60 @@ func TestPermissionRequestResult_JSONSerialize(t *testing.T) {
t.Errorf("expected %s, got %s", expected, string(data))
}
}

func TestProviderConfig_JSONIncludesHeaders(t *testing.T) {
config := ProviderConfig{
BaseURL: "https://example.com/provider",
Headers: map[string]string{"Authorization": "Bearer provider-token"},
}

data, err := json.Marshal(config)
if err != nil {
t.Fatalf("failed to marshal provider config: %v", err)
}

var decoded map[string]any
if err := json.Unmarshal(data, &decoded); err != nil {
t.Fatalf("failed to unmarshal provider config: %v", err)
}

if decoded["baseUrl"] != "https://example.com/provider" {
t.Fatalf("expected baseUrl to round-trip, got %v", decoded["baseUrl"])
}
headers, ok := decoded["headers"].(map[string]any)
if !ok {
t.Fatalf("expected headers object, got %T", decoded["headers"])
}
if headers["Authorization"] != "Bearer provider-token" {
t.Fatalf("expected Authorization header, got %v", headers["Authorization"])
}
}

func TestSessionSendRequest_JSONIncludesRequestHeaders(t *testing.T) {
req := sessionSendRequest{
SessionID: "session-1",
Prompt: "hello",
RequestHeaders: map[string]string{"Authorization": "Bearer turn-token"},
}

data, err := json.Marshal(req)
if err != nil {
t.Fatalf("failed to marshal session send request: %v", err)
}

var decoded map[string]any
if err := json.Unmarshal(data, &decoded); err != nil {
t.Fatalf("failed to unmarshal session send request: %v", err)
}

if decoded["prompt"] != "hello" {
t.Fatalf("expected prompt to round-trip, got %v", decoded["prompt"])
}
headers, ok := decoded["requestHeaders"].(map[string]any)
if !ok {
t.Fatalf("expected requestHeaders object, got %T", decoded["requestHeaders"])
}
if headers["Authorization"] != "Bearer turn-token" {
t.Fatalf("expected Authorization header, got %v", headers["Authorization"])
}
}
1 change: 1 addition & 0 deletions nodejs/src/session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ export class CopilotSession {
prompt: options.prompt,
attachments: options.attachments,
mode: options.mode,
requestHeaders: options.requestHeaders,
});

return (response as { messageId: string }).messageId;
Expand Down
10 changes: 10 additions & 0 deletions nodejs/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1403,6 +1403,11 @@ export interface ProviderConfig {
*/
apiVersion?: string;
};

/**
* Custom HTTP headers to include in outbound provider requests.
*/
headers?: Record<string, string>;
}

/**
Expand Down Expand Up @@ -1452,6 +1457,11 @@ export interface MessageOptions {
* - "immediate": Send immediately
*/
mode?: "enqueue" | "immediate";

/**
* Custom HTTP headers to include in outbound model requests for this turn.
*/
requestHeaders?: Record<string, string>;
}

/**
Expand Down
Loading
Loading