Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion pkg/connectrpc/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"github.com/docker/cagent/pkg/server"
"github.com/docker/cagent/pkg/session"
"github.com/docker/cagent/pkg/tools"
"github.com/docker/cagent/pkg/upstream"
)

// Server implements the Connect-RPC AgentService.
Expand All @@ -44,7 +45,8 @@ func (s *Server) Handler() http.Handler {

path, handler := cagentv1connect.NewAgentServiceHandler(s)
mux.Handle(path, handler)
return h2c.NewHandler(mux, &http2.Server{})

return upstream.Handler(h2c.NewHandler(mux, &http2.Server{}))
}

// Serve starts the Connect-RPC server on the given listener.
Expand Down
2 changes: 2 additions & 0 deletions pkg/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"github.com/docker/cagent/pkg/api"
"github.com/docker/cagent/pkg/config"
"github.com/docker/cagent/pkg/session"
"github.com/docker/cagent/pkg/upstream"
)

type Server struct {
Expand All @@ -27,6 +28,7 @@ type Server struct {
func New(ctx context.Context, sessionStore session.Store, runConfig *config.RuntimeConfig, refreshInterval time.Duration, agentSources config.Sources) (*Server, error) {
e := echo.New()
e.Use(middleware.RequestLogger())
e.Use(echo.WrapMiddleware(upstream.Handler))

s := &Server{
e: e,
Expand Down
4 changes: 3 additions & 1 deletion pkg/tools/a2a/a2a.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (

"github.com/docker/cagent/pkg/httpclient"
"github.com/docker/cagent/pkg/tools"
"github.com/docker/cagent/pkg/upstream"
)

// Toolset implements tools.ToolSet for A2A remote agents.
Expand Down Expand Up @@ -121,7 +122,8 @@ func (t *Toolset) Start(ctx context.Context) error {

// Use a longer timeout for the HTTP client since LLM responses can take a while.
// The default a2a-go HTTP client has only a 5-second timeout which is too short.
httpClient := httpclient.NewHTTPClient(httpclient.WithHeaders(t.headers))
httpClient := httpclient.NewHTTPClient()
httpClient.Transport = upstream.NewHeaderTransport(httpClient.Transport, t.headers)

client, err := a2aclient.NewFromCard(ctx, card, a2aclient.WithJSONRPCTransport(httpClient))
if err != nil {
Expand Down
7 changes: 1 addition & 6 deletions pkg/tools/builtin/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import (
"github.com/docker/cagent/pkg/config/latest"
"github.com/docker/cagent/pkg/js"
"github.com/docker/cagent/pkg/tools"
"github.com/docker/cagent/pkg/useragent"
)

type APITool struct {
Expand Down Expand Up @@ -66,15 +65,11 @@ func (t *APITool) callTool(ctx context.Context, toolCall tools.ToolCall) (*tools
return nil, fmt.Errorf("failed to create request: %w", err)
}

req.Header.Set("User-Agent", useragent.Header)
setHeaders(req, t.config.Headers)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Context may not contain upstream headers for placeholder resolution

The setHeaders function is called to resolve ${headers.NAME} placeholders from req.Context(). However, the context comes from the ctx parameter passed to callTool. If this tool is invoked outside the normal HTTP handler middleware chain (or if the upstream.Handler middleware wasn't applied), the context won't contain upstream headers and placeholders won't resolve.

Current flow:

  1. HTTP request → upstream.Handler middleware → stores headers in context
  2. Handler invokes tool via some chain → callTool(ctx, ...)
  3. setHeaders tries to resolve placeholders from ctx

The implementation assumes the ctx parameter contains upstream headers, but this isn't validated or guaranteed.

Evidence: The test cases in api_test.go all use t.Context() (a plain testing context with no upstream headers) and don't test placeholder resolution.

Suggestion: Either:

  • Add documentation that this tool requires contexts derived from HTTP requests processed by upstream.Handler
  • Add integration tests showing placeholder resolution works through the full handler chain
  • Add runtime validation that logs a warning when placeholders are used but upstream headers aren't available in context

if t.config.Method == http.MethodPost {
req.Header.Set("Content-Type", "application/json")
}

for key, value := range t.config.Headers {
req.Header.Set(key, value)
}

resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
Expand Down
5 changes: 4 additions & 1 deletion pkg/tools/builtin/openapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/getkin/kin-openapi/openapi3"

"github.com/docker/cagent/pkg/tools"
"github.com/docker/cagent/pkg/upstream"
"github.com/docker/cagent/pkg/useragent"
)

Expand Down Expand Up @@ -349,9 +350,11 @@ func sanitizeToolName(name string) string {
}

// setHeaders sets the User-Agent and custom headers on an HTTP request.
// Header values may contain ${headers.NAME} placeholders that are resolved
// from upstream headers stored in the request context.
func setHeaders(req *http.Request, headers map[string]string) {
req.Header.Set("User-Agent", useragent.Header)
for k, v := range headers {
for k, v := range upstream.ResolveHeaders(req.Context(), headers) {
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Context may not contain upstream headers for placeholder resolution

Similar to api.go, this code calls setHeaders(req, h.headers) which attempts to resolve ${headers.NAME} placeholders from req.Context(). The context is derived from the ctx parameter passed to callTool.

If the caller didn't propagate the context from an HTTP request that went through the upstream.Handler middleware, upstream headers won't be available in the context, and placeholders won't resolve.

Evidence: The test openapi_test.go:TestOpenAPITool_CustomHeaders shows static headers work, but there's no test validating ${headers.X} placeholder resolution works when invoked through an HTTP handler with the middleware chain.

Suggestion:

  • Add integration tests that verify placeholder resolution through the full HTTP handler → tool invocation chain
  • Document the requirement that tools using header placeholders must be invoked with contexts from HTTP requests
  • Consider adding a helper that validates upstream headers are present when placeholders are detected in config

req.Header.Set(k, v)
}
}
Expand Down
40 changes: 12 additions & 28 deletions pkg/tools/mcp/remote.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/modelcontextprotocol/go-sdk/mcp"

"github.com/docker/cagent/pkg/tools"
"github.com/docker/cagent/pkg/upstream"
)

type remoteMCPClient struct {
Expand Down Expand Up @@ -124,35 +125,11 @@ func (c *remoteMCPClient) Initialize(ctx context.Context, _ *mcp.InitializeReque
return session.InitializeResult(), nil
}

// headerTransport is a RoundTripper that adds custom headers to all requests
type headerTransport struct {
base http.RoundTripper
headers map[string]string
}

func (t *headerTransport) RoundTrip(req *http.Request) (*http.Response, error) {
// Clone the request to avoid modifying the original
req = req.Clone(req.Context())

// Add custom headers
for key, value := range t.headers {
req.Header.Set(key, value)
}

return t.base.RoundTrip(req)
}

// createHTTPClient creates an HTTP client with custom headers and OAuth support
// createHTTPClient creates an HTTP client with custom headers and OAuth support.
// Header values may contain ${headers.NAME} placeholders that are resolved
// at request time from upstream headers stored in the request context.
func (c *remoteMCPClient) createHTTPClient() *http.Client {
transport := http.DefaultTransport

// Add custom headers first
if len(c.headers) > 0 {
transport = &headerTransport{
base: transport,
headers: c.headers,
}
}
transport := c.headerTransport()

// Then wrap with OAuth support
transport = &oauthTransport{
Expand All @@ -168,6 +145,13 @@ func (c *remoteMCPClient) createHTTPClient() *http.Client {
}
}

func (c *remoteMCPClient) headerTransport() http.RoundTripper {
if len(c.headers) > 0 {
return upstream.NewHeaderTransport(http.DefaultTransport, c.headers)
}
return http.DefaultTransport
}

func (c *remoteMCPClient) Close(context.Context) error {
c.mu.RLock()
session := c.session
Expand Down
135 changes: 135 additions & 0 deletions pkg/upstream/headers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
// Package upstream provides utilities for propagating HTTP headers
// from incoming API requests to outbound toolset HTTP calls.
package upstream

import (
"context"
"fmt"
"net/http"
"regexp"
"strings"

"github.com/dop251/goja"
)

type contextKey struct{}

// WithHeaders returns a new context carrying the given HTTP headers.
func WithHeaders(ctx context.Context, h http.Header) context.Context {
return context.WithValue(ctx, contextKey{}, h)
}

// HeadersFromContext retrieves upstream HTTP headers from the context.
// Returns nil if no headers are present.
func HeadersFromContext(ctx context.Context) http.Header {
h, _ := ctx.Value(contextKey{}).(http.Header)
return h
}

// Handler wraps an http.Handler to store the incoming HTTP request
// headers in the request context for downstream toolset forwarding.
func Handler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := WithHeaders(r.Context(), r.Header.Clone())
next.ServeHTTP(w, r.WithContext(ctx))
})
}

// NewHeaderTransport wraps an http.RoundTripper to set custom headers on
// every outbound request. Header values may contain ${headers.NAME}
// placeholders that are resolved at request time from upstream headers
// stored in the request context.
func NewHeaderTransport(base http.RoundTripper, headers map[string]string) http.RoundTripper {
return &headerTransport{base: base, headers: headers}
}

type headerTransport struct {
base http.RoundTripper
headers map[string]string
}

func (t *headerTransport) RoundTrip(req *http.Request) (*http.Response, error) {
req = req.Clone(req.Context())
for key, value := range ResolveHeaders(req.Context(), t.headers) {
req.Header.Set(key, value)
}
return t.base.RoundTrip(req)
}

// ResolveHeaders resolves ${headers.NAME} placeholders in header values
// using upstream headers from the context. Header names in the placeholder
// are case-insensitive, matching HTTP header convention.
//
// For example, given the config header:
//
// Authorization: ${headers.Authorization}
//
// and an upstream request with "Authorization: Bearer token", the resolved
// value will be "Bearer token".
func ResolveHeaders(ctx context.Context, headers map[string]string) map[string]string {
if len(headers) == 0 {
return headers
}

upstream := HeadersFromContext(ctx)
if upstream == nil {
return headers
}

vm := goja.New()
_ = vm.Set("headers", vm.NewDynamicObject(headerAccessor(func(name string) goja.Value {
return vm.ToValue(upstream.Get(name))
})))

resolved := make(map[string]string, len(headers))
for k, v := range headers {
resolved[k] = expandTemplate(vm, v)
}
return resolved
}

// headerAccessor implements [goja.DynamicObject] for case-insensitive
// HTTP header lookups.
type headerAccessor func(string) goja.Value

func (h headerAccessor) Get(k string) goja.Value { return h(k) }
func (headerAccessor) Set(string, goja.Value) bool { return false }
func (headerAccessor) Has(string) bool { return true }
func (headerAccessor) Delete(string) bool { return false }
func (headerAccessor) Keys() []string { return nil }

// headerPlaceholderRe matches ${headers.NAME} and captures the header
// name so we can rewrite it to bracket notation for the JS runtime.
var headerPlaceholderRe = regexp.MustCompile(`\$\{\s*headers\.([^}]+)\}`)

// expandTemplate evaluates a string as a JavaScript template literal,
// resolving any ${...} expressions via the goja runtime.
// Before evaluation it rewrites ${headers.NAME} to ${headers["NAME"]}
// so that header names containing hyphens (e.g. X-Request-Id) are
// accessed correctly.
func expandTemplate(vm *goja.Runtime, text string) string {
if !strings.Contains(text, "${") {
return text
}

// Rewrite dotted header access to bracket notation so names with
// hyphens work: ${headers.X-Req-Id} → ${headers["X-Req-Id"]}
text = headerPlaceholderRe.ReplaceAllStringFunc(text, func(m string) string {
parts := headerPlaceholderRe.FindStringSubmatch(m)
name := strings.TrimSpace(parts[1])
return `${headers["` + name + `"]}`
})

escaped := strings.ReplaceAll(text, "\\", "\\\\")
escaped = strings.ReplaceAll(escaped, "`", "\\`")
script := "`" + escaped + "`"

v, err := vm.RunString(script)
if err != nil {
return text
}
if v == nil || v.Export() == nil {
return ""
}
return fmt.Sprintf("%v", v.Export())
}
Loading