diff --git a/cmd/root/push.go b/cmd/root/push.go index 99bd85571..43439d138 100644 --- a/cmd/root/push.go +++ b/cmd/root/push.go @@ -51,7 +51,7 @@ func runPushCommand(cmd *cobra.Command, args []string) error { out.Printf("Pushing agent %s to %s\n", agentFilename, tag) - err = remote.Push(tag) + err = remote.Push(ctx, tag) if err != nil { return fmt.Errorf("failed to push artifact: %w", err) } diff --git a/go.mod b/go.mod index 483429fcc..3ce99c95a 100644 --- a/go.mod +++ b/go.mod @@ -42,6 +42,7 @@ require ( github.com/gorilla/websocket v1.5.3 github.com/junegunn/fzf v0.70.0 github.com/k3a/html2text v1.4.0 + github.com/kofalt/go-memoize v0.0.0-20240506050413-9e5eb99a0f2a github.com/labstack/echo/v4 v4.15.1 github.com/mattn/go-isatty v0.0.20 github.com/mattn/go-runewidth v0.0.21 @@ -183,6 +184,7 @@ require ( github.com/oasdiff/yaml3 v0.0.0-20250309153720-d2182401db90 // indirect github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/image-spec v1.1.1 // indirect + github.com/patrickmn/go-cache v2.1.0+incompatible // indirect github.com/perimeterx/marshmallow v1.1.5 // indirect github.com/pjbgf/sha1cd v0.3.2 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect diff --git a/go.sum b/go.sum index 1be63a8d8..a21d36bf8 100644 --- a/go.sum +++ b/go.sum @@ -311,6 +311,8 @@ github.com/kevinburke/ssh_config v1.2.0 h1:x584FjTGwHzMwvHx18PXxbBVzfnxogHaAReU4 github.com/kevinburke/ssh_config v1.2.0/go.mod h1:CT57kijsi8u/K/BOFA39wgDQJ9CxiF4nAY/ojJ6r6mM= github.com/klauspost/compress v1.18.4 h1:RPhnKRAQ4Fh8zU2FY/6ZFDwTVTxgJ/EMydqSTzE9a2c= github.com/klauspost/compress v1.18.4/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4= +github.com/kofalt/go-memoize v0.0.0-20240506050413-9e5eb99a0f2a h1:yyeZ0oZLWgSakB9QzPuL/Kyx9kcXYblDOswXaOEx0tg= +github.com/kofalt/go-memoize v0.0.0-20240506050413-9e5eb99a0f2a/go.mod h1:EUxMohcCc4AiiO1SImzCQo3EdrEYj9Xkyrxbepg02nQ= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= @@ -385,6 +387,8 @@ github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8 github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040= github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M= +github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc= +github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= github.com/perimeterx/marshmallow v1.1.5 h1:a2LALqQ1BlHM8PZblsDdidgv1mWi1DgC2UmX50IvK2s= github.com/perimeterx/marshmallow v1.1.5/go.mod h1:dsXbUu8CRzfYP5a87xpp0xq9S3u0Vchtcl8we9tYaXw= github.com/pjbgf/sha1cd v0.3.2 h1:a9wb0bp1oC2TGwStyn0Umc/IGKQnEgF0vVaZ8QF8eo4= @@ -431,10 +435,13 @@ github.com/skeema/knownhosts v1.3.1 h1:X2osQ+RAjK76shCbvhHHHVl3ZlgDm8apHEHFqRjnB github.com/skeema/knownhosts v1.3.1/go.mod h1:r7KTdC8l4uxWRyK2TpQZ/1o5HaSzh06ePQNxPwTcfiY= github.com/smacker/go-tree-sitter v0.0.0-20240827094217-dd81d9e9be82 h1:6C8qej6f1bStuePVkLSFxoU22XBS165D3klxlzRg8F4= github.com/smacker/go-tree-sitter v0.0.0-20240827094217-dd81d9e9be82/go.mod h1:xe4pgH49k4SsmkQq5OT8abwhWmnzkhpgnXeekbx2efw= -github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d h1:zE9ykElWQ6/NYmHa3jpm/yHnI4xSofP+UP6SpjHcSeM= github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= +github.com/smartystreets/assertions v1.2.0 h1:42S6lae5dvLc7BrLu/0ugRtcFVjoJNMC/N3yZFZkDFs= +github.com/smartystreets/assertions v1.2.0/go.mod h1:tcbTF8ujkAEcZ8TElKY+i30BzYlVhC/LOxJk7iOWnoo= github.com/smartystreets/goconvey v1.6.4 h1:fv0U8FUIMPNf1L9lnHLvLhgicrIVChEkdzIKYqbNC9s= github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= +github.com/smartystreets/gunit v1.4.2 h1:tyWYZffdPhQPfK5VsMQXfauwnJkqg7Tv5DLuQVYxq3Q= +github.com/smartystreets/gunit v1.4.2/go.mod h1:ZjM1ozSIMJlAz/ay4SG8PeKF00ckUp+zMHZXV9/bvak= github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU= github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4= github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= diff --git a/pkg/config/sources.go b/pkg/config/sources.go index e8ccf52fe..f94a2336a 100644 --- a/pkg/config/sources.go +++ b/pkg/config/sources.go @@ -229,7 +229,7 @@ func (a urlSource) Read(ctx context.Context) ([]byte, error) { // Add GitHub token authorization for GitHub URLs a.addGitHubAuth(ctx, req) - resp, err := httpclient.NewHTTPClient().Do(req) + resp, err := httpclient.NewHTTPClient(ctx).Do(req) if err != nil { // Network error - try to use cached version if cachedData, cacheErr := os.ReadFile(cachePath); cacheErr == nil { diff --git a/pkg/desktop/paths.go b/pkg/desktop/paths.go index b431ff00c..a916d02dc 100644 --- a/pkg/desktop/paths.go +++ b/pkg/desktop/paths.go @@ -4,6 +4,7 @@ import "sync" type DockerDesktopPaths struct { BackendSocket string + ProxySocket string } var Paths = sync.OnceValue(func() DockerDesktopPaths { diff --git a/pkg/desktop/running.go b/pkg/desktop/running.go index f3f450802..2343f1f02 100644 --- a/pkg/desktop/running.go +++ b/pkg/desktop/running.go @@ -2,9 +2,12 @@ package desktop import ( "context" + "time" ) func IsDockerDesktopRunning(ctx context.Context) bool { + ctx, cancel := context.WithTimeout(ctx, time.Second*3) + defer cancel() err := ClientBackend.Get(ctx, "/ping", nil) return err == nil } diff --git a/pkg/desktop/socket/dial.go b/pkg/desktop/socket/dial.go new file mode 100644 index 000000000..f089e9dfb --- /dev/null +++ b/pkg/desktop/socket/dial.go @@ -0,0 +1,9 @@ +package socket + +import ( + "strings" +) + +func stripUnixScheme(path string) string { + return strings.TrimPrefix(path, "unix://") +} diff --git a/pkg/desktop/socket/dial_unix.go b/pkg/desktop/socket/dial_unix.go new file mode 100644 index 000000000..5579135e1 --- /dev/null +++ b/pkg/desktop/socket/dial_unix.go @@ -0,0 +1,14 @@ +//go:build !windows + +package socket + +import ( + "context" + "net" +) + +// DialUnix is a simple wrapper for `net.Dial("unix")`. +func DialUnix(ctx context.Context, path string) (net.Conn, error) { + dialer := &net.Dialer{} + return dialer.DialContext(ctx, "unix", stripUnixScheme(path)) +} diff --git a/pkg/desktop/socket/dial_windows.go b/pkg/desktop/socket/dial_windows.go new file mode 100644 index 000000000..7425d7893 --- /dev/null +++ b/pkg/desktop/socket/dial_windows.go @@ -0,0 +1,24 @@ +package socket + +import ( + "context" + "net" + "strings" + "time" + + "github.com/Microsoft/go-winio" +) + +// DialUnix is a simple wrapper for `winio.DialPipe(path, 10s)`. +// It provides API compatibility for named pipes with the Unix domain socket API. +func DialUnix(ctx context.Context, path string) (net.Conn, error) { + ctx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + + if strings.HasPrefix(path, "unix://") { + // windows supports AF_UNIX + d := &net.Dialer{} + return d.DialContext(ctx, "unix", stripUnixScheme(path)) + } + return winio.DialPipeContext(ctx, path) +} diff --git a/pkg/desktop/sockets_darwin.go b/pkg/desktop/sockets_darwin.go index 7555fc912..49a6d1fb7 100644 --- a/pkg/desktop/sockets_darwin.go +++ b/pkg/desktop/sockets_darwin.go @@ -15,5 +15,6 @@ func getDockerDesktopPaths() (DockerDesktopPaths, error) { return DockerDesktopPaths{ BackendSocket: filepath.Join(data, "backend.sock"), + ProxySocket: filepath.Join(data, "httpproxy.sock"), }, nil } diff --git a/pkg/desktop/sockets_linux.go b/pkg/desktop/sockets_linux.go index 0c05906b3..f98d0d0dc 100644 --- a/pkg/desktop/sockets_linux.go +++ b/pkg/desktop/sockets_linux.go @@ -12,6 +12,7 @@ func getDockerDesktopPaths() (DockerDesktopPaths, error) { // Inside LinuxKit return DockerDesktopPaths{ BackendSocket: "/run/host-services/backend.sock", + ProxySocket: "/run/host-services/httpproxy.sock", }, nil } @@ -23,6 +24,7 @@ func getDockerDesktopPaths() (DockerDesktopPaths, error) { // Inside WSL2 return DockerDesktopPaths{ BackendSocket: "/mnt/wsl/docker-desktop/shared-sockets/host-services/backend.sock", + ProxySocket: "/mnt/wsl/docker-desktop/shared-sockets/host-services/httpproxy.sock", }, nil } @@ -38,5 +40,6 @@ func getDockerDesktopPaths() (DockerDesktopPaths, error) { // On Linux return DockerDesktopPaths{ BackendSocket: filepath.Join(home, ".docker", "desktop", "backend.sock"), + ProxySocket: filepath.Join(home, ".docker", "desktop", "httpproxy.sock"), }, nil } diff --git a/pkg/desktop/sockets_windows.go b/pkg/desktop/sockets_windows.go index 891225e33..542a49839 100644 --- a/pkg/desktop/sockets_windows.go +++ b/pkg/desktop/sockets_windows.go @@ -13,5 +13,6 @@ func getDockerDesktopPaths() (DockerDesktopPaths, error) { return DockerDesktopPaths{ BackendSocket: `\\.\pipe\dockerBackendApiServer`, + ProxySocket: `\\.\pipe\dockerHTTPProxy`, }, nil } diff --git a/pkg/gateway/catalog.go b/pkg/gateway/catalog.go index ed7299a77..200eda39d 100644 --- a/pkg/gateway/catalog.go +++ b/pkg/gateway/catalog.go @@ -13,6 +13,7 @@ import ( "time" "github.com/docker/docker-agent/pkg/paths" + "github.com/docker/docker-agent/pkg/remote" ) const ( @@ -166,10 +167,6 @@ func saveToDisk(path string, catalog Catalog, etag string) { } } -// catalogClient is a dedicated HTTP client for catalog fetches, isolated from -// http.DefaultClient so that other parts of the process cannot interfere. -var catalogClient = &http.Client{} - // fetchFromNetwork fetches the catalog, using the ETag for conditional requests. // It returns (nil, "", nil) when the server responds with 304 Not Modified. func fetchFromNetwork(ctx context.Context, etag string) (Catalog, string, error) { @@ -185,6 +182,7 @@ func fetchFromNetwork(ctx context.Context, etag string) (Catalog, string, error) req.Header.Set("If-None-Match", etag) } + catalogClient := &http.Client{Transport: remote.NewTransport(ctx)} resp, err := catalogClient.Do(req) if err != nil { return nil, "", err diff --git a/pkg/httpclient/client.go b/pkg/httpclient/client.go index 5b199c98e..8cfa4fe11 100644 --- a/pkg/httpclient/client.go +++ b/pkg/httpclient/client.go @@ -1,12 +1,14 @@ package httpclient import ( + "context" "fmt" "maps" "net/http" "net/url" "runtime" + "github.com/docker/docker-agent/pkg/remote" "github.com/docker/docker-agent/pkg/version" ) @@ -17,7 +19,7 @@ type HTTPOptions struct { type Opt func(*HTTPOptions) -func NewHTTPClient(opts ...Opt) *http.Client { +func NewHTTPClient(ctx context.Context, opts ...Opt) *http.Client { httpOptions := HTTPOptions{ Header: make(http.Header), } @@ -32,7 +34,7 @@ func NewHTTPClient(opts ...Opt) *http.Client { // Disable automatic gzip: Go's default transport transparently compresses // and decompresses responses, which is incompatible with SSE streaming. // See https://github.com/docker/docker-agent/issues/1956 - rt := newTransport() + rt := newTransport(ctx) return &http.Client{ Transport: &userAgentTransport{ @@ -95,15 +97,18 @@ func WithQuery(query url.Values) Opt { } } -// newTransport returns an HTTP transport with automatic gzip compression disabled. -func newTransport() http.RoundTripper { - t, ok := http.DefaultTransport.(*http.Transport) - if !ok { - return http.DefaultTransport +// newTransport returns an HTTP transport with automatic gzip compression disabled and using Docker Desktop proxy if available. +func newTransport(ctx context.Context) http.RoundTripper { + // Get the base transport with Desktop proxy support from remote package + rt := remote.NewTransport(ctx) + + // If it's an http.Transport, disable compression for SSE streaming compatibility + if transport, ok := rt.(*http.Transport); ok { + transport.DisableCompression = true + return transport } - transport := t.Clone() - transport.DisableCompression = true - return transport + + return rt } type userAgentTransport struct { diff --git a/pkg/httpclient/client_test.go b/pkg/httpclient/client_test.go index dfe5c6f1f..e24b1d028 100644 --- a/pkg/httpclient/client_test.go +++ b/pkg/httpclient/client_test.go @@ -75,7 +75,7 @@ func doRequest(t *testing.T, opts ...Opt) http.Header { })) defer srv.Close() - client := NewHTTPClient(opts...) + client := NewHTTPClient(t.Context(), opts...) req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody) require.NoError(t, err) diff --git a/pkg/model/provider/anthropic/client.go b/pkg/model/provider/anthropic/client.go index b171e9563..8e3f7b326 100644 --- a/pkg/model/provider/anthropic/client.go +++ b/pkg/model/provider/anthropic/client.go @@ -163,7 +163,7 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro slog.Debug("Anthropic API key found, creating client") requestOptions := []option.RequestOption{ option.WithAPIKey(authToken), - option.WithHTTPClient(httpclient.NewHTTPClient()), + option.WithHTTPClient(httpclient.NewHTTPClient(ctx)), } if cfg.BaseURL != "" { requestOptions = append(requestOptions, option.WithBaseURL(cfg.BaseURL)) @@ -210,7 +210,7 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro option.WithAuthToken(authToken), option.WithAPIKey(authToken), option.WithBaseURL(baseURL), - option.WithHTTPClient(httpclient.NewHTTPClient(httpOptions...)), + option.WithHTTPClient(httpclient.NewHTTPClient(ctx, httpOptions...)), ) return client, nil diff --git a/pkg/model/provider/gemini/client.go b/pkg/model/provider/gemini/client.go index 0178e5d74..c7d523a96 100644 --- a/pkg/model/provider/gemini/client.go +++ b/pkg/model/provider/gemini/client.go @@ -97,7 +97,7 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro } backend = genai.BackendGeminiAPI - httpClient = httpclient.NewHTTPClient() + httpClient = httpclient.NewHTTPClient(ctx) } client, err := genai.NewClient(ctx, &genai.ClientConfig{ @@ -152,7 +152,7 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro return genai.NewClient(ctx, &genai.ClientConfig{ APIKey: authToken, Backend: genai.BackendGeminiAPI, - HTTPClient: httpclient.NewHTTPClient(httpOptions...), + HTTPClient: httpclient.NewHTTPClient(ctx, httpOptions...), HTTPOptions: genai.HTTPOptions{ BaseURL: baseURL, Headers: http.Header{ diff --git a/pkg/model/provider/openai/client.go b/pkg/model/provider/openai/client.go index 8042a295b..48e7c32a1 100644 --- a/pkg/model/provider/openai/client.go +++ b/pkg/model/provider/openai/client.go @@ -92,7 +92,7 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro clientOptions = append(clientOptions, option.WithBaseURL(cfg.BaseURL)) } - httpClient := httpclient.NewHTTPClient() + httpClient := httpclient.NewHTTPClient(ctx) clientOptions = append(clientOptions, option.WithHTTPClient(httpClient)) client := openai.NewClient(clientOptions...) @@ -135,7 +135,7 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro client := openai.NewClient( option.WithAPIKey(authToken), option.WithBaseURL(baseURL), - option.WithHTTPClient(httpclient.NewHTTPClient(httpOptions...)), + option.WithHTTPClient(httpclient.NewHTTPClient(ctx, httpOptions...)), option.WithMiddleware(oaistream.ErrorBodyMiddleware()), ) diff --git a/pkg/modelsdev/store.go b/pkg/modelsdev/store.go index e179ef9d8..c08ecceb8 100644 --- a/pkg/modelsdev/store.go +++ b/pkg/modelsdev/store.go @@ -13,6 +13,8 @@ import ( "strings" "sync" "time" + + "github.com/docker/docker-agent/pkg/remote" ) const ( @@ -183,7 +185,7 @@ func fetchFromAPI(ctx context.Context, etag string) (*Database, string, error) { req.Header.Set("If-None-Match", etag) } - resp, err := (&http.Client{Timeout: 30 * time.Second}).Do(req) + resp, err := (&http.Client{Timeout: 30 * time.Second, Transport: remote.NewTransport(ctx)}).Do(req) if err != nil { return nil, "", fmt.Errorf("failed to fetch from API: %w", err) } diff --git a/pkg/remote/pull.go b/pkg/remote/pull.go index 1b34bc9d3..156129a3a 100644 --- a/pkg/remote/pull.go +++ b/pkg/remote/pull.go @@ -12,7 +12,7 @@ import ( // Pull pulls an artifact from a registry and stores it in the content store func Pull(ctx context.Context, registryRef string, force bool, opts ...crane.Option) (string, error) { - opts = append(opts, crane.WithContext(ctx)) + opts = append(opts, crane.WithContext(ctx), crane.WithTransport(NewTransport(ctx))) ref, err := name.ParseReference(registryRef) if err != nil { diff --git a/pkg/remote/pull_test.go b/pkg/remote/pull_test.go index 1e9103aac..798b2ae8c 100644 --- a/pkg/remote/pull_test.go +++ b/pkg/remote/pull_test.go @@ -68,7 +68,7 @@ func TestPullIntegration(t *testing.T) { require.NoError(t, err) assert.NotNil(t, retrievedImg) - err = Push("invalid:reference:with:too:many:colons") + err = Push(t.Context(), "invalid:reference:with:too:many:colons") require.Error(t, err) } diff --git a/pkg/remote/push.go b/pkg/remote/push.go index faab866c6..0f7db9555 100644 --- a/pkg/remote/push.go +++ b/pkg/remote/push.go @@ -1,6 +1,7 @@ package remote import ( + "context" "fmt" "github.com/google/go-containerregistry/pkg/crane" @@ -13,7 +14,7 @@ import ( ) // Push pushes an artifact from the content store to an OCI registry -func Push(reference string) error { +func Push(ctx context.Context, reference string) error { store, err := content.NewStore() if err != nil { return fmt.Errorf("creating content store: %w", err) @@ -45,7 +46,7 @@ func Push(reference string) error { return fmt.Errorf("parsing registry reference %s: %w", reference, err) } - if err := crane.Push(img, ref.String()); err != nil { + if err := crane.Push(img, ref.String(), crane.WithContext(ctx), crane.WithTransport(NewTransport(ctx))); err != nil { return fmt.Errorf("pushing image to registry %s: %w", reference, err) } diff --git a/pkg/remote/push_test.go b/pkg/remote/push_test.go index 6b877947c..4e3b46ac1 100644 --- a/pkg/remote/push_test.go +++ b/pkg/remote/push_test.go @@ -38,18 +38,18 @@ func TestPush(t *testing.T) { require.NoError(t, err) assert.NotNil(t, loadedImg) - err = Push("invalid:reference:with:too:many:colons") + err = Push(t.Context(), "invalid:reference:with:too:many:colons") require.Error(t, err) - err = Push("invalid:reference:with:too:many:colons") + err = Push(t.Context(), "invalid:reference:with:too:many:colons") require.Error(t, err) } func TestPushNonExistentArtifact(t *testing.T) { - err := Push("registry.example.com/test:latest") + err := Push(t.Context(), "registry.example.com/test:latest") require.Error(t, err) - err = Push("registry.example.com/test:latest") + err = Push(t.Context(), "registry.example.com/test:latest") require.Error(t, err) } diff --git a/pkg/remote/transport.go b/pkg/remote/transport.go new file mode 100644 index 000000000..e7dfe33cb --- /dev/null +++ b/pkg/remote/transport.go @@ -0,0 +1,43 @@ +package remote + +import ( + "context" + "net" + "net/http" + "net/url" + "time" + + "github.com/kofalt/go-memoize" + + "github.com/docker/docker-agent/pkg/desktop" + socket "github.com/docker/docker-agent/pkg/desktop/socket" +) + +var memoizer = memoize.NewMemoizer(1*time.Minute, 1*time.Minute) + +// NewTransport returns an HTTP transport that uses Docker Desktop proxy if available. +func NewTransport(ctx context.Context) http.RoundTripper { + t, ok := http.DefaultTransport.(*http.Transport) + if !ok { + return http.DefaultTransport + } + transport := t.Clone() + + desktopRunning, err, _ := memoizer.Memoize("desktopRunning", func() (any, error) { + return desktop.IsDockerDesktopRunning(context.Background()), nil + }) + if err != nil { + return transport + } + if running, ok := desktopRunning.(bool); ok && running { + transport.Proxy = http.ProxyURL(&url.URL{ + Scheme: "http", + }) + // Override the dialer to connect to the Unix socket for the proxy + transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + return socket.DialUnix(ctx, desktop.Paths().ProxySocket) + } + } + + return transport +} diff --git a/pkg/remote/transport_test.go b/pkg/remote/transport_test.go new file mode 100644 index 000000000..b67da73ac --- /dev/null +++ b/pkg/remote/transport_test.go @@ -0,0 +1,56 @@ +package remote + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/docker/docker-agent/pkg/desktop" +) + +func TestNewTransport_UsesDesktopProxyWhenAvailable(t *testing.T) { + t.Parallel() + + ctx := t.Context() + + // Create a transport + transport := NewTransport(ctx) + require.NotNil(t, transport) + + // Verify that it's an http.Transport + httpTransport, ok := transport.(*http.Transport) + require.True(t, ok, "transport should be *http.Transport") + + // If Docker Desktop is running, verify proxy is configured + if desktop.IsDockerDesktopRunning(ctx) { + assert.NotNil(t, httpTransport.Proxy, "proxy should be configured when Docker Desktop is running") + assert.NotNil(t, httpTransport.DialContext, "custom DialContext should be set when Docker Desktop is running") + } +} + +func TestNewTransport_WorksWithoutDesktopProxy(t *testing.T) { + t.Parallel() + + // Create a test server to simulate a registry + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + ctx := t.Context() + + // Create a transport (should work whether Desktop is running or not) + transport := NewTransport(ctx) + require.NotNil(t, transport) + + // Make a simple HTTP request to verify the transport works + client := &http.Client{Transport: transport} + resp, err := client.Get(server.URL) + require.NoError(t, err) + defer resp.Body.Close() + + assert.Equal(t, http.StatusOK, resp.StatusCode) +} diff --git a/pkg/skills/cache.go b/pkg/skills/cache.go index 7672871eb..c0ba466a0 100644 --- a/pkg/skills/cache.go +++ b/pkg/skills/cache.go @@ -1,6 +1,7 @@ package skills import ( + "context" "crypto/sha256" "encoding/hex" "encoding/json" @@ -13,11 +14,12 @@ import ( "strconv" "strings" "time" + + "github.com/docker/docker-agent/pkg/remote" ) type diskCache struct { - baseDir string - httpClient *http.Client + baseDir string } type cacheMetadata struct { @@ -29,9 +31,6 @@ type cacheMetadata struct { func newDiskCache(baseDir string) *diskCache { return &diskCache{ baseDir: baseDir, - httpClient: &http.Client{ - Timeout: 30 * time.Second, - }, } } @@ -68,10 +67,14 @@ func (c *diskCache) Get(baseURL, skillName, filePath string) (string, bool) { // FetchAndStore downloads a file from the given URL and stores it in the cache. // It respects Cache-Control headers to determine expiry. -func (c *diskCache) FetchAndStore(baseURL, skillName, filePath, fileURL string) (string, error) { +func (c *diskCache) FetchAndStore(ctx context.Context, baseURL, skillName, filePath, fileURL string) (string, error) { slog.Debug("Fetching remote skill file", "url", fileURL) - resp, err := c.httpClient.Get(fileURL) + httpClient := &http.Client{ + Timeout: 30 * time.Second, + Transport: remote.NewTransport(ctx), + } + resp, err := httpClient.Get(fileURL) if err != nil { return "", fmt.Errorf("fetching %s: %w", fileURL, err) } diff --git a/pkg/skills/cache_test.go b/pkg/skills/cache_test.go index b3cfff518..5ec753ed6 100644 --- a/pkg/skills/cache_test.go +++ b/pkg/skills/cache_test.go @@ -22,7 +22,7 @@ func TestDiskCache_FetchAndStore(t *testing.T) { cache := newDiskCache(t.TempDir()) - content, err := cache.FetchAndStore("https://example.com", "my-skill", "SKILL.md", srv.URL+"/SKILL.md") + content, err := cache.FetchAndStore(t.Context(), "https://example.com", "my-skill", "SKILL.md", srv.URL+"/SKILL.md") require.NoError(t, err) assert.Equal(t, "file content", content) @@ -54,7 +54,7 @@ func TestDiskCache_Get_Cached(t *testing.T) { cache := newDiskCache(t.TempDir()) - _, err := cache.FetchAndStore("https://example.com", "skill", "SKILL.md", srv.URL+"/SKILL.md") + _, err := cache.FetchAndStore(t.Context(), "https://example.com", "skill", "SKILL.md", srv.URL+"/SKILL.md") require.NoError(t, err) content, ok := cache.Get("https://example.com", "skill", "SKILL.md") @@ -71,7 +71,7 @@ func TestDiskCache_Get_Expired(t *testing.T) { cache := newDiskCache(t.TempDir()) - _, err := cache.FetchAndStore("https://example.com", "skill", "SKILL.md", srv.URL+"/SKILL.md") + _, err := cache.FetchAndStore(t.Context(), "https://example.com", "skill", "SKILL.md", srv.URL+"/SKILL.md") require.NoError(t, err) // The max-age=0 should make it immediately expired @@ -87,7 +87,7 @@ func TestDiskCache_NestedFiles(t *testing.T) { cache := newDiskCache(t.TempDir()) - content, err := cache.FetchAndStore("https://example.com", "my-skill", "references/FORMS.md", srv.URL+"/file") + content, err := cache.FetchAndStore(t.Context(), "https://example.com", "my-skill", "references/FORMS.md", srv.URL+"/file") require.NoError(t, err) assert.Equal(t, "nested file content", content) @@ -152,7 +152,7 @@ func TestDiskCache_HTTPError(t *testing.T) { cache := newDiskCache(t.TempDir()) - _, err := cache.FetchAndStore("https://example.com", "skill", "SKILL.md", srv.URL+"/notfound") + _, err := cache.FetchAndStore(t.Context(), "https://example.com", "skill", "SKILL.md", srv.URL+"/notfound") require.Error(t, err) assert.Contains(t, err.Error(), "HTTP 404") } diff --git a/pkg/skills/remote.go b/pkg/skills/remote.go index c14029cd1..763dc9c9e 100644 --- a/pkg/skills/remote.go +++ b/pkg/skills/remote.go @@ -1,6 +1,7 @@ package skills import ( + "context" "encoding/json" "fmt" "io" @@ -11,6 +12,7 @@ import ( "time" "github.com/docker/docker-agent/pkg/paths" + "github.com/docker/docker-agent/pkg/remote" ) // remoteIndex represents the index.json served at /.well-known/skills/index.json @@ -24,10 +26,6 @@ type remoteSkillEntry struct { Files []string `json:"files"` } -var defaultHTTPClient = &http.Client{ - Timeout: 30 * time.Second, -} - func defaultCache() *diskCache { return newDiskCache(filepath.Join(paths.GetCacheDir(), "skills")) } @@ -37,16 +35,20 @@ func defaultCache() *diskCache { // into a disk cache so the agent can read them without network requests during // task execution. func loadRemoteSkills(baseURL string) []Skill { - return loadRemoteSkillsWithCache(baseURL, defaultCache()) + return loadRemoteSkillsWithCache(context.Background(), baseURL, defaultCache()) } -func loadRemoteSkillsWithCache(baseURL string, cache *diskCache) []Skill { +func loadRemoteSkillsWithCache(ctx context.Context, baseURL string, cache *diskCache) []Skill { baseURL = strings.TrimRight(baseURL, "/") indexURL := baseURL + "/.well-known/skills/index.json" slog.Debug("Fetching remote skills index", "url", indexURL) - resp, err := defaultHTTPClient.Get(indexURL) + httpClient := &http.Client{ + Timeout: 30 * time.Second, + Transport: remote.NewTransport(ctx), + } + resp, err := httpClient.Get(indexURL) if err != nil { slog.Warn("Failed to fetch remote skills index", "url", indexURL, "error", err) return nil @@ -77,7 +79,7 @@ func loadRemoteSkillsWithCache(baseURL string, cache *diskCache) []Skill { } cacheDir := cache.cacheDir(baseURL, entry.Name) - prefetchFiles(cache, baseURL, entry.Name, entry.Files) + prefetchFiles(ctx, cache, baseURL, entry.Name, entry.Files) skill := Skill{ Name: entry.Name, @@ -96,7 +98,7 @@ func loadRemoteSkillsWithCache(baseURL string, cache *diskCache) []Skill { // prefetchFiles downloads all files listed in the index for a skill, // storing them in the disk cache. Files already in cache (and not expired) // are skipped. -func prefetchFiles(cache *diskCache, baseURL, skillName string, files []string) { +func prefetchFiles(ctx context.Context, cache *diskCache, baseURL, skillName string, files []string) { for _, file := range files { if !isValidFilePath(file) { slog.Debug("Skipping invalid file path in skill", "skill", skillName, "file", file) @@ -108,7 +110,7 @@ func prefetchFiles(cache *diskCache, baseURL, skillName string, files []string) } fileURL := fmt.Sprintf("%s/.well-known/skills/%s/%s", baseURL, skillName, file) - if _, err := cache.FetchAndStore(baseURL, skillName, file, fileURL); err != nil { + if _, err := cache.FetchAndStore(ctx, baseURL, skillName, file, fileURL); err != nil { slog.Warn("Failed to prefetch skill file", "skill", skillName, "file", file, "error", err) } } diff --git a/pkg/skills/remote_test.go b/pkg/skills/remote_test.go index 0c6c66574..06433f659 100644 --- a/pkg/skills/remote_test.go +++ b/pkg/skills/remote_test.go @@ -47,7 +47,7 @@ func TestLoadRemoteSkills(t *testing.T) { cacheDir := t.TempDir() cache := newDiskCache(cacheDir) - skills := loadRemoteSkillsWithCache(srv.URL, cache) + skills := loadRemoteSkillsWithCache(t.Context(), srv.URL, cache) require.Len(t, skills, 2) @@ -84,7 +84,7 @@ func TestLoadRemoteSkills(t *testing.T) { defer srv.Close() cache := newDiskCache(t.TempDir()) - skills := loadRemoteSkillsWithCache(srv.URL+"/", cache) + skills := loadRemoteSkillsWithCache(t.Context(), srv.URL+"/", cache) require.Len(t, skills, 1) content, err := os.ReadFile(skills[0].FilePath) @@ -99,7 +99,7 @@ func TestLoadRemoteSkills(t *testing.T) { })) defer srv.Close() - skills := loadRemoteSkillsWithCache(srv.URL, newDiskCache(t.TempDir())) + skills := loadRemoteSkillsWithCache(t.Context(), srv.URL, newDiskCache(t.TempDir())) assert.Empty(t, skills) }) @@ -110,7 +110,7 @@ func TestLoadRemoteSkills(t *testing.T) { })) defer srv.Close() - skills := loadRemoteSkillsWithCache(srv.URL, newDiskCache(t.TempDir())) + skills := loadRemoteSkillsWithCache(t.Context(), srv.URL, newDiskCache(t.TempDir())) assert.Empty(t, skills) }) @@ -121,7 +121,7 @@ func TestLoadRemoteSkills(t *testing.T) { })) defer srv.Close() - skills := loadRemoteSkillsWithCache(srv.URL, newDiskCache(t.TempDir())) + skills := loadRemoteSkillsWithCache(t.Context(), srv.URL, newDiskCache(t.TempDir())) assert.Empty(t, skills) }) @@ -129,7 +129,7 @@ func TestLoadRemoteSkills(t *testing.T) { srv := httptest.NewServer(http.NotFoundHandler()) defer srv.Close() - skills := loadRemoteSkillsWithCache(srv.URL, newDiskCache(t.TempDir())) + skills := loadRemoteSkillsWithCache(t.Context(), srv.URL, newDiskCache(t.TempDir())) assert.Empty(t, skills) }) @@ -139,12 +139,12 @@ func TestLoadRemoteSkills(t *testing.T) { })) defer srv.Close() - skills := loadRemoteSkillsWithCache(srv.URL, newDiskCache(t.TempDir())) + skills := loadRemoteSkillsWithCache(t.Context(), srv.URL, newDiskCache(t.TempDir())) assert.Empty(t, skills) }) t.Run("unreachable server", func(t *testing.T) { - skills := loadRemoteSkillsWithCache("http://127.0.0.1:1", newDiskCache(t.TempDir())) + skills := loadRemoteSkillsWithCache(t.Context(), "http://127.0.0.1:1", newDiskCache(t.TempDir())) assert.Empty(t, skills) }) @@ -168,12 +168,12 @@ func TestLoadRemoteSkills(t *testing.T) { cache := newDiskCache(t.TempDir()) // First load - skills1 := loadRemoteSkillsWithCache(srv.URL, cache) + skills1 := loadRemoteSkillsWithCache(t.Context(), srv.URL, cache) require.Len(t, skills1, 1) assert.Equal(t, 2, fetchCount) // index.json + SKILL.md // Second load — SKILL.md should be cached - skills2 := loadRemoteSkillsWithCache(srv.URL, cache) + skills2 := loadRemoteSkillsWithCache(t.Context(), srv.URL, cache) require.Len(t, skills2, 1) assert.Equal(t, 3, fetchCount) // only index.json re-fetched, SKILL.md from cache }) @@ -193,7 +193,7 @@ func TestLoadRemoteSkills(t *testing.T) { defer srv.Close() cache := newDiskCache(t.TempDir()) - skills := loadRemoteSkillsWithCache(srv.URL, cache) + skills := loadRemoteSkillsWithCache(t.Context(), srv.URL, cache) require.Len(t, skills, 1) // Only SKILL.md should have been fetched, not the malicious paths }) diff --git a/pkg/tools/a2a/a2a.go b/pkg/tools/a2a/a2a.go index f61eff2d2..30d0576a7 100644 --- a/pkg/tools/a2a/a2a.go +++ b/pkg/tools/a2a/a2a.go @@ -123,7 +123,7 @@ func (t *Toolset) Start(ctx context.Context) error { // Use a longer timeout for the HTTP client since LLM responses can take a while. // The default a2a-go HTTP client has only a 5-second timeout which is too short. - httpClient := httpclient.NewHTTPClient() + httpClient := httpclient.NewHTTPClient(ctx) httpClient.Transport = upstream.NewHeaderTransport(httpClient.Transport, t.headers) client, err := a2aclient.NewFromCard(ctx, card, a2aclient.WithJSONRPCTransport(httpClient)) diff --git a/pkg/tools/builtin/api.go b/pkg/tools/builtin/api.go index 283293b29..3b6dba0a5 100644 --- a/pkg/tools/builtin/api.go +++ b/pkg/tools/builtin/api.go @@ -14,6 +14,7 @@ import ( "github.com/docker/docker-agent/pkg/config/latest" "github.com/docker/docker-agent/pkg/js" + "github.com/docker/docker-agent/pkg/remote" "github.com/docker/docker-agent/pkg/tools" ) @@ -30,7 +31,8 @@ var ( func (t *APITool) callTool(ctx context.Context, toolCall tools.ToolCall) (*tools.ToolCallResult, error) { client := &http.Client{ - Timeout: 30 * time.Second, + Timeout: 30 * time.Second, + Transport: remote.NewTransport(ctx), } endpoint := t.config.Endpoint diff --git a/pkg/tools/builtin/fetch.go b/pkg/tools/builtin/fetch.go index 770458f0a..7f3837e70 100644 --- a/pkg/tools/builtin/fetch.go +++ b/pkg/tools/builtin/fetch.go @@ -14,6 +14,7 @@ import ( "github.com/k3a/html2text" "github.com/temoto/robotstxt" + "github.com/docker/docker-agent/pkg/remote" "github.com/docker/docker-agent/pkg/tools" "github.com/docker/docker-agent/pkg/useragent" ) @@ -49,7 +50,8 @@ func (h *fetchHandler) CallTool(ctx context.Context, params FetchToolArgs) (*too // Set timeout if specified client := &http.Client{ - Timeout: h.timeout, + Timeout: h.timeout, + Transport: remote.NewTransport(ctx), } if params.Timeout > 0 { client.Timeout = time.Duration(params.Timeout) * time.Second diff --git a/pkg/tools/builtin/openapi.go b/pkg/tools/builtin/openapi.go index fb53da9cf..784627973 100644 --- a/pkg/tools/builtin/openapi.go +++ b/pkg/tools/builtin/openapi.go @@ -15,6 +15,7 @@ import ( "github.com/getkin/kin-openapi/openapi3" + "github.com/docker/docker-agent/pkg/remote" "github.com/docker/docker-agent/pkg/tools" "github.com/docker/docker-agent/pkg/upstream" "github.com/docker/docker-agent/pkg/useragent" @@ -70,7 +71,7 @@ func (t *OpenAPITool) fetchSpec(ctx context.Context) (*openapi3.T, error) { req.Header.Set("Accept", "application/json") setHeaders(req, t.headers) - resp, err := (&http.Client{Timeout: httpTimeout}).Do(req) + resp, err := (&http.Client{Timeout: httpTimeout, Transport: remote.NewTransport(ctx)}).Do(req) if err != nil { return nil, fmt.Errorf("request failed: %w", err) } @@ -398,7 +399,7 @@ func (h *openAPIHandler) callTool(ctx context.Context, params openAPICallArgs) ( req.Header.Set("Accept", "application/json") setHeaders(req, h.headers) - resp, err := (&http.Client{Timeout: httpTimeout}).Do(req) + resp, err := (&http.Client{Timeout: httpTimeout, Transport: remote.NewTransport(ctx)}).Do(req) if err != nil { return nil, fmt.Errorf("request failed: %w", err) }