diff --git a/pkg/connectrpc/server.go b/pkg/connectrpc/server.go index 9d234e698..c6f75bb59 100644 --- a/pkg/connectrpc/server.go +++ b/pkg/connectrpc/server.go @@ -24,6 +24,7 @@ import ( "github.com/docker/cagent/pkg/server" "github.com/docker/cagent/pkg/session" "github.com/docker/cagent/pkg/tools" + "github.com/docker/cagent/pkg/upstream" ) // Server implements the Connect-RPC AgentService. @@ -44,7 +45,8 @@ func (s *Server) Handler() http.Handler { path, handler := cagentv1connect.NewAgentServiceHandler(s) mux.Handle(path, handler) - return h2c.NewHandler(mux, &http2.Server{}) + + return upstream.Handler(h2c.NewHandler(mux, &http2.Server{})) } // Serve starts the Connect-RPC server on the given listener. diff --git a/pkg/server/server.go b/pkg/server/server.go index 2bfbb4c7e..c8a47a9f3 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -17,6 +17,7 @@ import ( "github.com/docker/cagent/pkg/api" "github.com/docker/cagent/pkg/config" "github.com/docker/cagent/pkg/session" + "github.com/docker/cagent/pkg/upstream" ) type Server struct { @@ -27,6 +28,7 @@ type Server struct { func New(ctx context.Context, sessionStore session.Store, runConfig *config.RuntimeConfig, refreshInterval time.Duration, agentSources config.Sources) (*Server, error) { e := echo.New() e.Use(middleware.RequestLogger()) + e.Use(echo.WrapMiddleware(upstream.Handler)) s := &Server{ e: e, diff --git a/pkg/tools/a2a/a2a.go b/pkg/tools/a2a/a2a.go index 2cd1df2cd..0f180e5b8 100644 --- a/pkg/tools/a2a/a2a.go +++ b/pkg/tools/a2a/a2a.go @@ -16,6 +16,7 @@ import ( "github.com/docker/cagent/pkg/httpclient" "github.com/docker/cagent/pkg/tools" + "github.com/docker/cagent/pkg/upstream" ) // Toolset implements tools.ToolSet for A2A remote agents. @@ -121,7 +122,8 @@ func (t *Toolset) Start(ctx context.Context) error { // Use a longer timeout for the HTTP client since LLM responses can take a while. // The default a2a-go HTTP client has only a 5-second timeout which is too short. - httpClient := httpclient.NewHTTPClient(httpclient.WithHeaders(t.headers)) + httpClient := httpclient.NewHTTPClient() + httpClient.Transport = upstream.NewHeaderTransport(httpClient.Transport, t.headers) client, err := a2aclient.NewFromCard(ctx, card, a2aclient.WithJSONRPCTransport(httpClient)) if err != nil { diff --git a/pkg/tools/builtin/api.go b/pkg/tools/builtin/api.go index b8ee47711..386833007 100644 --- a/pkg/tools/builtin/api.go +++ b/pkg/tools/builtin/api.go @@ -14,7 +14,6 @@ import ( "github.com/docker/cagent/pkg/config/latest" "github.com/docker/cagent/pkg/js" "github.com/docker/cagent/pkg/tools" - "github.com/docker/cagent/pkg/useragent" ) type APITool struct { @@ -66,15 +65,11 @@ func (t *APITool) callTool(ctx context.Context, toolCall tools.ToolCall) (*tools return nil, fmt.Errorf("failed to create request: %w", err) } - req.Header.Set("User-Agent", useragent.Header) + setHeaders(req, t.config.Headers) if t.config.Method == http.MethodPost { req.Header.Set("Content-Type", "application/json") } - for key, value := range t.config.Headers { - req.Header.Set(key, value) - } - resp, err := client.Do(req) if err != nil { return nil, fmt.Errorf("request failed: %w", err) diff --git a/pkg/tools/builtin/openapi.go b/pkg/tools/builtin/openapi.go index 3b86d9ca2..897660886 100644 --- a/pkg/tools/builtin/openapi.go +++ b/pkg/tools/builtin/openapi.go @@ -15,6 +15,7 @@ import ( "github.com/getkin/kin-openapi/openapi3" "github.com/docker/cagent/pkg/tools" + "github.com/docker/cagent/pkg/upstream" "github.com/docker/cagent/pkg/useragent" ) @@ -349,9 +350,11 @@ func sanitizeToolName(name string) string { } // setHeaders sets the User-Agent and custom headers on an HTTP request. +// Header values may contain ${headers.NAME} placeholders that are resolved +// from upstream headers stored in the request context. func setHeaders(req *http.Request, headers map[string]string) { req.Header.Set("User-Agent", useragent.Header) - for k, v := range headers { + for k, v := range upstream.ResolveHeaders(req.Context(), headers) { req.Header.Set(k, v) } } diff --git a/pkg/tools/mcp/remote.go b/pkg/tools/mcp/remote.go index 4c92aad5e..e99bb8f2b 100644 --- a/pkg/tools/mcp/remote.go +++ b/pkg/tools/mcp/remote.go @@ -11,6 +11,7 @@ import ( "github.com/modelcontextprotocol/go-sdk/mcp" "github.com/docker/cagent/pkg/tools" + "github.com/docker/cagent/pkg/upstream" ) type remoteMCPClient struct { @@ -124,35 +125,11 @@ func (c *remoteMCPClient) Initialize(ctx context.Context, _ *mcp.InitializeReque return session.InitializeResult(), nil } -// headerTransport is a RoundTripper that adds custom headers to all requests -type headerTransport struct { - base http.RoundTripper - headers map[string]string -} - -func (t *headerTransport) RoundTrip(req *http.Request) (*http.Response, error) { - // Clone the request to avoid modifying the original - req = req.Clone(req.Context()) - - // Add custom headers - for key, value := range t.headers { - req.Header.Set(key, value) - } - - return t.base.RoundTrip(req) -} - -// createHTTPClient creates an HTTP client with custom headers and OAuth support +// createHTTPClient creates an HTTP client with custom headers and OAuth support. +// Header values may contain ${headers.NAME} placeholders that are resolved +// at request time from upstream headers stored in the request context. func (c *remoteMCPClient) createHTTPClient() *http.Client { - transport := http.DefaultTransport - - // Add custom headers first - if len(c.headers) > 0 { - transport = &headerTransport{ - base: transport, - headers: c.headers, - } - } + transport := c.headerTransport() // Then wrap with OAuth support transport = &oauthTransport{ @@ -168,6 +145,13 @@ func (c *remoteMCPClient) createHTTPClient() *http.Client { } } +func (c *remoteMCPClient) headerTransport() http.RoundTripper { + if len(c.headers) > 0 { + return upstream.NewHeaderTransport(http.DefaultTransport, c.headers) + } + return http.DefaultTransport +} + func (c *remoteMCPClient) Close(context.Context) error { c.mu.RLock() session := c.session diff --git a/pkg/upstream/headers.go b/pkg/upstream/headers.go new file mode 100644 index 000000000..d570450f2 --- /dev/null +++ b/pkg/upstream/headers.go @@ -0,0 +1,135 @@ +// Package upstream provides utilities for propagating HTTP headers +// from incoming API requests to outbound toolset HTTP calls. +package upstream + +import ( + "context" + "fmt" + "net/http" + "regexp" + "strings" + + "github.com/dop251/goja" +) + +type contextKey struct{} + +// WithHeaders returns a new context carrying the given HTTP headers. +func WithHeaders(ctx context.Context, h http.Header) context.Context { + return context.WithValue(ctx, contextKey{}, h) +} + +// HeadersFromContext retrieves upstream HTTP headers from the context. +// Returns nil if no headers are present. +func HeadersFromContext(ctx context.Context) http.Header { + h, _ := ctx.Value(contextKey{}).(http.Header) + return h +} + +// Handler wraps an http.Handler to store the incoming HTTP request +// headers in the request context for downstream toolset forwarding. +func Handler(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := WithHeaders(r.Context(), r.Header.Clone()) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +// NewHeaderTransport wraps an http.RoundTripper to set custom headers on +// every outbound request. Header values may contain ${headers.NAME} +// placeholders that are resolved at request time from upstream headers +// stored in the request context. +func NewHeaderTransport(base http.RoundTripper, headers map[string]string) http.RoundTripper { + return &headerTransport{base: base, headers: headers} +} + +type headerTransport struct { + base http.RoundTripper + headers map[string]string +} + +func (t *headerTransport) RoundTrip(req *http.Request) (*http.Response, error) { + req = req.Clone(req.Context()) + for key, value := range ResolveHeaders(req.Context(), t.headers) { + req.Header.Set(key, value) + } + return t.base.RoundTrip(req) +} + +// ResolveHeaders resolves ${headers.NAME} placeholders in header values +// using upstream headers from the context. Header names in the placeholder +// are case-insensitive, matching HTTP header convention. +// +// For example, given the config header: +// +// Authorization: ${headers.Authorization} +// +// and an upstream request with "Authorization: Bearer token", the resolved +// value will be "Bearer token". +func ResolveHeaders(ctx context.Context, headers map[string]string) map[string]string { + if len(headers) == 0 { + return headers + } + + upstream := HeadersFromContext(ctx) + if upstream == nil { + return headers + } + + vm := goja.New() + _ = vm.Set("headers", vm.NewDynamicObject(headerAccessor(func(name string) goja.Value { + return vm.ToValue(upstream.Get(name)) + }))) + + resolved := make(map[string]string, len(headers)) + for k, v := range headers { + resolved[k] = expandTemplate(vm, v) + } + return resolved +} + +// headerAccessor implements [goja.DynamicObject] for case-insensitive +// HTTP header lookups. +type headerAccessor func(string) goja.Value + +func (h headerAccessor) Get(k string) goja.Value { return h(k) } +func (headerAccessor) Set(string, goja.Value) bool { return false } +func (headerAccessor) Has(string) bool { return true } +func (headerAccessor) Delete(string) bool { return false } +func (headerAccessor) Keys() []string { return nil } + +// headerPlaceholderRe matches ${headers.NAME} and captures the header +// name so we can rewrite it to bracket notation for the JS runtime. +var headerPlaceholderRe = regexp.MustCompile(`\$\{\s*headers\.([^}]+)\}`) + +// expandTemplate evaluates a string as a JavaScript template literal, +// resolving any ${...} expressions via the goja runtime. +// Before evaluation it rewrites ${headers.NAME} to ${headers["NAME"]} +// so that header names containing hyphens (e.g. X-Request-Id) are +// accessed correctly. +func expandTemplate(vm *goja.Runtime, text string) string { + if !strings.Contains(text, "${") { + return text + } + + // Rewrite dotted header access to bracket notation so names with + // hyphens work: ${headers.X-Req-Id} → ${headers["X-Req-Id"]} + text = headerPlaceholderRe.ReplaceAllStringFunc(text, func(m string) string { + parts := headerPlaceholderRe.FindStringSubmatch(m) + name := strings.TrimSpace(parts[1]) + return `${headers["` + name + `"]}` + }) + + escaped := strings.ReplaceAll(text, "\\", "\\\\") + escaped = strings.ReplaceAll(escaped, "`", "\\`") + script := "`" + escaped + "`" + + v, err := vm.RunString(script) + if err != nil { + return text + } + if v == nil || v.Export() == nil { + return "" + } + return fmt.Sprintf("%v", v.Export()) +} diff --git a/pkg/upstream/headers_test.go b/pkg/upstream/headers_test.go new file mode 100644 index 000000000..f70c4d5b6 --- /dev/null +++ b/pkg/upstream/headers_test.go @@ -0,0 +1,138 @@ +package upstream + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestHeadersRoundTrip(t *testing.T) { + t.Parallel() + + h := http.Header{} + h.Set("Authorization", "Bearer token123") + h.Set("X-Custom", "value") + + ctx := WithHeaders(t.Context(), h) + got := HeadersFromContext(ctx) + + require.NotNil(t, got) + assert.Equal(t, "Bearer token123", got.Get("Authorization")) + assert.Equal(t, "value", got.Get("X-Custom")) +} + +func TestHeadersFromContext_Empty(t *testing.T) { + t.Parallel() + + got := HeadersFromContext(t.Context()) + assert.Nil(t, got) +} + +func TestHandler_InjectsHeaders(t *testing.T) { + t.Parallel() + + var captured http.Header + inner := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + captured = HeadersFromContext(r.Context()) + }) + + handler := Handler(inner) + req := httptest.NewRequest(http.MethodGet, "/", http.NoBody) + req.Header.Set("X-Test", "hello") + + handler.ServeHTTP(httptest.NewRecorder(), req) + + require.NotNil(t, captured) + assert.Equal(t, "hello", captured.Get("X-Test")) +} + +func TestResolveHeaders(t *testing.T) { + t.Parallel() + + upstream := http.Header{} + upstream.Set("Authorization", "Bearer secret") + upstream.Set("X-Request-Id", "abc-123") + + ctx := WithHeaders(t.Context(), upstream) + + tests := []struct { + name string + headers map[string]string + expected map[string]string + }{ + { + name: "no placeholders", + headers: map[string]string{"Content-Type": "application/json"}, + expected: map[string]string{"Content-Type": "application/json"}, + }, + { + name: "single placeholder", + headers: map[string]string{"Authorization": "${headers.Authorization}"}, + expected: map[string]string{"Authorization": "Bearer secret"}, + }, + { + name: "case insensitive header name", + headers: map[string]string{"Authorization": "${headers.authorization}"}, + expected: map[string]string{"Authorization": "Bearer secret"}, + }, + { + name: "multiple headers with placeholders", + headers: map[string]string{"Authorization": "${headers.Authorization}", "X-Req": "${headers.X-Request-Id}"}, + expected: map[string]string{"Authorization": "Bearer secret", "X-Req": "abc-123"}, + }, + { + name: "mixed static and placeholder", + headers: map[string]string{"Authorization": "${headers.Authorization}", "Accept": "text/html"}, + expected: map[string]string{"Authorization": "Bearer secret", "Accept": "text/html"}, + }, + { + name: "placeholder with surrounding text", + headers: map[string]string{"X-Info": "id=${headers.X-Request-Id}&ok"}, + expected: map[string]string{"X-Info": "id=abc-123&ok"}, + }, + { + name: "missing upstream header resolves to empty", + headers: map[string]string{"Authorization": "${headers.X-Missing}"}, + expected: map[string]string{"Authorization": ""}, + }, + { + name: "nil headers", + headers: nil, + expected: nil, + }, + { + name: "empty headers", + headers: map[string]string{}, + expected: map[string]string{}, + }, + { + name: "trimmed spaces in name", + headers: map[string]string{"Auth": "${headers. Authorization }"}, + expected: map[string]string{"Auth": "Bearer secret"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := ResolveHeaders(ctx, tt.headers) + assert.Equal(t, tt.expected, got) + }) + } +} + +func TestResolveHeaders_NoUpstreamContext(t *testing.T) { + t.Parallel() + + headers := map[string]string{ + "Authorization": "${headers.Authorization}", + "Accept": "text/html", + } + + // No upstream headers in context — placeholders are left as-is. + got := ResolveHeaders(t.Context(), headers) + assert.Equal(t, headers, got) +}