From 07343ef25a80e321ed060ae2b58a84652d53a567 Mon Sep 17 00:00:00 2001 From: David Gageot Date: Fri, 13 Feb 2026 19:08:50 +0100 Subject: [PATCH] Support ${headers.NAME} syntax to forward upstream API headers to toolsets Add a new pkg/upstream package that allows toolset header values to reference incoming API request headers using ${headers.NAME} placeholders. For example, a toolset config like: headers: Authorization: ${headers.Authorization} will resolve the Authorization value at request time from the upstream HTTP request that triggered the agent. The middleware (Echo and ConnectRPC) stores the incoming request headers in the context. At tool-call time, header values containing ${headers.X} are resolved from that context. Static header values without placeholders are unaffected. Assisted-By: cagent --- pkg/connectrpc/server.go | 4 +- pkg/server/server.go | 2 + pkg/tools/a2a/a2a.go | 4 +- pkg/tools/builtin/api.go | 7 +- pkg/tools/builtin/openapi.go | 5 +- pkg/tools/mcp/remote.go | 40 +++------- pkg/upstream/headers.go | 135 ++++++++++++++++++++++++++++++++++ pkg/upstream/headers_test.go | 138 +++++++++++++++++++++++++++++++++++ 8 files changed, 298 insertions(+), 37 deletions(-) create mode 100644 pkg/upstream/headers.go create mode 100644 pkg/upstream/headers_test.go diff --git a/pkg/connectrpc/server.go b/pkg/connectrpc/server.go index 9d234e698..c6f75bb59 100644 --- a/pkg/connectrpc/server.go +++ b/pkg/connectrpc/server.go @@ -24,6 +24,7 @@ import ( "github.com/docker/cagent/pkg/server" "github.com/docker/cagent/pkg/session" "github.com/docker/cagent/pkg/tools" + "github.com/docker/cagent/pkg/upstream" ) // Server implements the Connect-RPC AgentService. @@ -44,7 +45,8 @@ func (s *Server) Handler() http.Handler { path, handler := cagentv1connect.NewAgentServiceHandler(s) mux.Handle(path, handler) - return h2c.NewHandler(mux, &http2.Server{}) + + return upstream.Handler(h2c.NewHandler(mux, &http2.Server{})) } // Serve starts the Connect-RPC server on the given listener. diff --git a/pkg/server/server.go b/pkg/server/server.go index 2bfbb4c7e..c8a47a9f3 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -17,6 +17,7 @@ import ( "github.com/docker/cagent/pkg/api" "github.com/docker/cagent/pkg/config" "github.com/docker/cagent/pkg/session" + "github.com/docker/cagent/pkg/upstream" ) type Server struct { @@ -27,6 +28,7 @@ type Server struct { func New(ctx context.Context, sessionStore session.Store, runConfig *config.RuntimeConfig, refreshInterval time.Duration, agentSources config.Sources) (*Server, error) { e := echo.New() e.Use(middleware.RequestLogger()) + e.Use(echo.WrapMiddleware(upstream.Handler)) s := &Server{ e: e, diff --git a/pkg/tools/a2a/a2a.go b/pkg/tools/a2a/a2a.go index 2cd1df2cd..0f180e5b8 100644 --- a/pkg/tools/a2a/a2a.go +++ b/pkg/tools/a2a/a2a.go @@ -16,6 +16,7 @@ import ( "github.com/docker/cagent/pkg/httpclient" "github.com/docker/cagent/pkg/tools" + "github.com/docker/cagent/pkg/upstream" ) // Toolset implements tools.ToolSet for A2A remote agents. @@ -121,7 +122,8 @@ func (t *Toolset) Start(ctx context.Context) error { // Use a longer timeout for the HTTP client since LLM responses can take a while. // The default a2a-go HTTP client has only a 5-second timeout which is too short. - httpClient := httpclient.NewHTTPClient(httpclient.WithHeaders(t.headers)) + httpClient := httpclient.NewHTTPClient() + httpClient.Transport = upstream.NewHeaderTransport(httpClient.Transport, t.headers) client, err := a2aclient.NewFromCard(ctx, card, a2aclient.WithJSONRPCTransport(httpClient)) if err != nil { diff --git a/pkg/tools/builtin/api.go b/pkg/tools/builtin/api.go index b8ee47711..386833007 100644 --- a/pkg/tools/builtin/api.go +++ b/pkg/tools/builtin/api.go @@ -14,7 +14,6 @@ import ( "github.com/docker/cagent/pkg/config/latest" "github.com/docker/cagent/pkg/js" "github.com/docker/cagent/pkg/tools" - "github.com/docker/cagent/pkg/useragent" ) type APITool struct { @@ -66,15 +65,11 @@ func (t *APITool) callTool(ctx context.Context, toolCall tools.ToolCall) (*tools return nil, fmt.Errorf("failed to create request: %w", err) } - req.Header.Set("User-Agent", useragent.Header) + setHeaders(req, t.config.Headers) if t.config.Method == http.MethodPost { req.Header.Set("Content-Type", "application/json") } - for key, value := range t.config.Headers { - req.Header.Set(key, value) - } - resp, err := client.Do(req) if err != nil { return nil, fmt.Errorf("request failed: %w", err) diff --git a/pkg/tools/builtin/openapi.go b/pkg/tools/builtin/openapi.go index 3b86d9ca2..897660886 100644 --- a/pkg/tools/builtin/openapi.go +++ b/pkg/tools/builtin/openapi.go @@ -15,6 +15,7 @@ import ( "github.com/getkin/kin-openapi/openapi3" "github.com/docker/cagent/pkg/tools" + "github.com/docker/cagent/pkg/upstream" "github.com/docker/cagent/pkg/useragent" ) @@ -349,9 +350,11 @@ func sanitizeToolName(name string) string { } // setHeaders sets the User-Agent and custom headers on an HTTP request. +// Header values may contain ${headers.NAME} placeholders that are resolved +// from upstream headers stored in the request context. func setHeaders(req *http.Request, headers map[string]string) { req.Header.Set("User-Agent", useragent.Header) - for k, v := range headers { + for k, v := range upstream.ResolveHeaders(req.Context(), headers) { req.Header.Set(k, v) } } diff --git a/pkg/tools/mcp/remote.go b/pkg/tools/mcp/remote.go index 4c92aad5e..e99bb8f2b 100644 --- a/pkg/tools/mcp/remote.go +++ b/pkg/tools/mcp/remote.go @@ -11,6 +11,7 @@ import ( "github.com/modelcontextprotocol/go-sdk/mcp" "github.com/docker/cagent/pkg/tools" + "github.com/docker/cagent/pkg/upstream" ) type remoteMCPClient struct { @@ -124,35 +125,11 @@ func (c *remoteMCPClient) Initialize(ctx context.Context, _ *mcp.InitializeReque return session.InitializeResult(), nil } -// headerTransport is a RoundTripper that adds custom headers to all requests -type headerTransport struct { - base http.RoundTripper - headers map[string]string -} - -func (t *headerTransport) RoundTrip(req *http.Request) (*http.Response, error) { - // Clone the request to avoid modifying the original - req = req.Clone(req.Context()) - - // Add custom headers - for key, value := range t.headers { - req.Header.Set(key, value) - } - - return t.base.RoundTrip(req) -} - -// createHTTPClient creates an HTTP client with custom headers and OAuth support +// createHTTPClient creates an HTTP client with custom headers and OAuth support. +// Header values may contain ${headers.NAME} placeholders that are resolved +// at request time from upstream headers stored in the request context. func (c *remoteMCPClient) createHTTPClient() *http.Client { - transport := http.DefaultTransport - - // Add custom headers first - if len(c.headers) > 0 { - transport = &headerTransport{ - base: transport, - headers: c.headers, - } - } + transport := c.headerTransport() // Then wrap with OAuth support transport = &oauthTransport{ @@ -168,6 +145,13 @@ func (c *remoteMCPClient) createHTTPClient() *http.Client { } } +func (c *remoteMCPClient) headerTransport() http.RoundTripper { + if len(c.headers) > 0 { + return upstream.NewHeaderTransport(http.DefaultTransport, c.headers) + } + return http.DefaultTransport +} + func (c *remoteMCPClient) Close(context.Context) error { c.mu.RLock() session := c.session diff --git a/pkg/upstream/headers.go b/pkg/upstream/headers.go new file mode 100644 index 000000000..d570450f2 --- /dev/null +++ b/pkg/upstream/headers.go @@ -0,0 +1,135 @@ +// Package upstream provides utilities for propagating HTTP headers +// from incoming API requests to outbound toolset HTTP calls. +package upstream + +import ( + "context" + "fmt" + "net/http" + "regexp" + "strings" + + "github.com/dop251/goja" +) + +type contextKey struct{} + +// WithHeaders returns a new context carrying the given HTTP headers. +func WithHeaders(ctx context.Context, h http.Header) context.Context { + return context.WithValue(ctx, contextKey{}, h) +} + +// HeadersFromContext retrieves upstream HTTP headers from the context. +// Returns nil if no headers are present. +func HeadersFromContext(ctx context.Context) http.Header { + h, _ := ctx.Value(contextKey{}).(http.Header) + return h +} + +// Handler wraps an http.Handler to store the incoming HTTP request +// headers in the request context for downstream toolset forwarding. +func Handler(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := WithHeaders(r.Context(), r.Header.Clone()) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +// NewHeaderTransport wraps an http.RoundTripper to set custom headers on +// every outbound request. Header values may contain ${headers.NAME} +// placeholders that are resolved at request time from upstream headers +// stored in the request context. +func NewHeaderTransport(base http.RoundTripper, headers map[string]string) http.RoundTripper { + return &headerTransport{base: base, headers: headers} +} + +type headerTransport struct { + base http.RoundTripper + headers map[string]string +} + +func (t *headerTransport) RoundTrip(req *http.Request) (*http.Response, error) { + req = req.Clone(req.Context()) + for key, value := range ResolveHeaders(req.Context(), t.headers) { + req.Header.Set(key, value) + } + return t.base.RoundTrip(req) +} + +// ResolveHeaders resolves ${headers.NAME} placeholders in header values +// using upstream headers from the context. Header names in the placeholder +// are case-insensitive, matching HTTP header convention. +// +// For example, given the config header: +// +// Authorization: ${headers.Authorization} +// +// and an upstream request with "Authorization: Bearer token", the resolved +// value will be "Bearer token". +func ResolveHeaders(ctx context.Context, headers map[string]string) map[string]string { + if len(headers) == 0 { + return headers + } + + upstream := HeadersFromContext(ctx) + if upstream == nil { + return headers + } + + vm := goja.New() + _ = vm.Set("headers", vm.NewDynamicObject(headerAccessor(func(name string) goja.Value { + return vm.ToValue(upstream.Get(name)) + }))) + + resolved := make(map[string]string, len(headers)) + for k, v := range headers { + resolved[k] = expandTemplate(vm, v) + } + return resolved +} + +// headerAccessor implements [goja.DynamicObject] for case-insensitive +// HTTP header lookups. +type headerAccessor func(string) goja.Value + +func (h headerAccessor) Get(k string) goja.Value { return h(k) } +func (headerAccessor) Set(string, goja.Value) bool { return false } +func (headerAccessor) Has(string) bool { return true } +func (headerAccessor) Delete(string) bool { return false } +func (headerAccessor) Keys() []string { return nil } + +// headerPlaceholderRe matches ${headers.NAME} and captures the header +// name so we can rewrite it to bracket notation for the JS runtime. +var headerPlaceholderRe = regexp.MustCompile(`\$\{\s*headers\.([^}]+)\}`) + +// expandTemplate evaluates a string as a JavaScript template literal, +// resolving any ${...} expressions via the goja runtime. +// Before evaluation it rewrites ${headers.NAME} to ${headers["NAME"]} +// so that header names containing hyphens (e.g. X-Request-Id) are +// accessed correctly. +func expandTemplate(vm *goja.Runtime, text string) string { + if !strings.Contains(text, "${") { + return text + } + + // Rewrite dotted header access to bracket notation so names with + // hyphens work: ${headers.X-Req-Id} → ${headers["X-Req-Id"]} + text = headerPlaceholderRe.ReplaceAllStringFunc(text, func(m string) string { + parts := headerPlaceholderRe.FindStringSubmatch(m) + name := strings.TrimSpace(parts[1]) + return `${headers["` + name + `"]}` + }) + + escaped := strings.ReplaceAll(text, "\\", "\\\\") + escaped = strings.ReplaceAll(escaped, "`", "\\`") + script := "`" + escaped + "`" + + v, err := vm.RunString(script) + if err != nil { + return text + } + if v == nil || v.Export() == nil { + return "" + } + return fmt.Sprintf("%v", v.Export()) +} diff --git a/pkg/upstream/headers_test.go b/pkg/upstream/headers_test.go new file mode 100644 index 000000000..f70c4d5b6 --- /dev/null +++ b/pkg/upstream/headers_test.go @@ -0,0 +1,138 @@ +package upstream + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestHeadersRoundTrip(t *testing.T) { + t.Parallel() + + h := http.Header{} + h.Set("Authorization", "Bearer token123") + h.Set("X-Custom", "value") + + ctx := WithHeaders(t.Context(), h) + got := HeadersFromContext(ctx) + + require.NotNil(t, got) + assert.Equal(t, "Bearer token123", got.Get("Authorization")) + assert.Equal(t, "value", got.Get("X-Custom")) +} + +func TestHeadersFromContext_Empty(t *testing.T) { + t.Parallel() + + got := HeadersFromContext(t.Context()) + assert.Nil(t, got) +} + +func TestHandler_InjectsHeaders(t *testing.T) { + t.Parallel() + + var captured http.Header + inner := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + captured = HeadersFromContext(r.Context()) + }) + + handler := Handler(inner) + req := httptest.NewRequest(http.MethodGet, "/", http.NoBody) + req.Header.Set("X-Test", "hello") + + handler.ServeHTTP(httptest.NewRecorder(), req) + + require.NotNil(t, captured) + assert.Equal(t, "hello", captured.Get("X-Test")) +} + +func TestResolveHeaders(t *testing.T) { + t.Parallel() + + upstream := http.Header{} + upstream.Set("Authorization", "Bearer secret") + upstream.Set("X-Request-Id", "abc-123") + + ctx := WithHeaders(t.Context(), upstream) + + tests := []struct { + name string + headers map[string]string + expected map[string]string + }{ + { + name: "no placeholders", + headers: map[string]string{"Content-Type": "application/json"}, + expected: map[string]string{"Content-Type": "application/json"}, + }, + { + name: "single placeholder", + headers: map[string]string{"Authorization": "${headers.Authorization}"}, + expected: map[string]string{"Authorization": "Bearer secret"}, + }, + { + name: "case insensitive header name", + headers: map[string]string{"Authorization": "${headers.authorization}"}, + expected: map[string]string{"Authorization": "Bearer secret"}, + }, + { + name: "multiple headers with placeholders", + headers: map[string]string{"Authorization": "${headers.Authorization}", "X-Req": "${headers.X-Request-Id}"}, + expected: map[string]string{"Authorization": "Bearer secret", "X-Req": "abc-123"}, + }, + { + name: "mixed static and placeholder", + headers: map[string]string{"Authorization": "${headers.Authorization}", "Accept": "text/html"}, + expected: map[string]string{"Authorization": "Bearer secret", "Accept": "text/html"}, + }, + { + name: "placeholder with surrounding text", + headers: map[string]string{"X-Info": "id=${headers.X-Request-Id}&ok"}, + expected: map[string]string{"X-Info": "id=abc-123&ok"}, + }, + { + name: "missing upstream header resolves to empty", + headers: map[string]string{"Authorization": "${headers.X-Missing}"}, + expected: map[string]string{"Authorization": ""}, + }, + { + name: "nil headers", + headers: nil, + expected: nil, + }, + { + name: "empty headers", + headers: map[string]string{}, + expected: map[string]string{}, + }, + { + name: "trimmed spaces in name", + headers: map[string]string{"Auth": "${headers. Authorization }"}, + expected: map[string]string{"Auth": "Bearer secret"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := ResolveHeaders(ctx, tt.headers) + assert.Equal(t, tt.expected, got) + }) + } +} + +func TestResolveHeaders_NoUpstreamContext(t *testing.T) { + t.Parallel() + + headers := map[string]string{ + "Authorization": "${headers.Authorization}", + "Accept": "text/html", + } + + // No upstream headers in context — placeholders are left as-is. + got := ResolveHeaders(t.Context(), headers) + assert.Equal(t, headers, got) +}