diff --git a/extension/transport/registry.go b/extension/transport/registry.go new file mode 100644 index 000000000..d034b14b3 --- /dev/null +++ b/extension/transport/registry.go @@ -0,0 +1,28 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package transport + +import "sync" + +var ( + mu sync.Mutex + provider Provider +) + +// Register registers a transport Provider. +// Later registrations override earlier ones. +// Typically called from init() via blank import. +func Register(p Provider) { + mu.Lock() + defer mu.Unlock() + provider = p +} + +// GetProvider returns the currently registered Provider. +// Returns nil if no provider has been registered. +func GetProvider() Provider { + mu.Lock() + defer mu.Unlock() + return provider +} diff --git a/extension/transport/registry_test.go b/extension/transport/registry_test.go new file mode 100644 index 000000000..836cbca14 --- /dev/null +++ b/extension/transport/registry_test.go @@ -0,0 +1,77 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package transport + +import ( + "context" + "net/http" + "testing" +) + +type stubInterceptor struct{} + +func (s *stubInterceptor) PreRoundTrip(req *http.Request) func(*http.Response, error) { + return nil +} + +type stubProvider struct { + name string +} + +func (s *stubProvider) Name() string { return s.name } +func (s *stubProvider) ResolveInterceptor(context.Context) Interceptor { return &stubInterceptor{} } + +func TestGetProvider_NilByDefault(t *testing.T) { + mu.Lock() + provider = nil + mu.Unlock() + + if got := GetProvider(); got != nil { + t.Fatalf("expected nil, got %v", got) + } +} + +func TestRegisterAndGet(t *testing.T) { + mu.Lock() + provider = nil + mu.Unlock() + + p := &stubProvider{name: "a"} + Register(p) + + got := GetProvider() + if got != p { + t.Fatalf("expected registered provider, got %v", got) + } +} + +func TestLastRegistrationWins(t *testing.T) { + mu.Lock() + provider = nil + mu.Unlock() + + a := &stubProvider{name: "a"} + b := &stubProvider{name: "b"} + Register(a) + Register(b) + + got := GetProvider() + if got != b { + t.Fatalf("expected provider b, got %v", got) + } +} + +func TestResolveInterceptor_ReturnsNonNil(t *testing.T) { + mu.Lock() + provider = nil + mu.Unlock() + + p := &stubProvider{name: "test"} + Register(p) + + ic := GetProvider().ResolveInterceptor(context.Background()) + if ic == nil { + t.Fatal("expected non-nil Interceptor") + } +} diff --git a/extension/transport/types.go b/extension/transport/types.go new file mode 100644 index 000000000..e60c4018d --- /dev/null +++ b/extension/transport/types.go @@ -0,0 +1,32 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package transport + +import ( + "context" + "net/http" +) + +// Provider creates Interceptor instances. +// Follows the same API style as extension/credential.Provider and extension/fileio.Provider. +type Provider interface { + Name() string + ResolveInterceptor(ctx context.Context) Interceptor +} + +// Interceptor defines network-layer customization via a pre/post hook pair. +// The built-in transport chain always executes between PreRoundTrip and the +// returned post function, and cannot be skipped or overridden by the extension. +// +// PreRoundTrip is called before the built-in chain. Use it to add custom +// headers, rewrite the host, or start trace spans. Built-in decorators run +// after this and will override any same-named security headers set here. +// The extension must not replace req.Context() — the middleware restores +// the original context after PreRoundTrip returns. +// +// The returned function (if non-nil) is called after the built-in chain +// completes. Use it for logging, ending trace spans, or recording metrics. +type Interceptor interface { + PreRoundTrip(req *http.Request) func(resp *http.Response, err error) +} diff --git a/internal/cmdutil/factory_default.go b/internal/cmdutil/factory_default.go index 5b08a05cb..8c8ea4f88 100644 --- a/internal/cmdutil/factory_default.go +++ b/internal/cmdutil/factory_default.go @@ -95,8 +95,8 @@ func cachedHttpClientFunc() func() (*http.Client, error) { var transport http.RoundTripper = util.NewBaseTransport() transport = &RetryTransport{Base: transport} transport = &SecurityHeaderTransport{Base: transport} - transport = &auth.SecurityPolicyTransport{Base: transport} // Add our global response interceptor + transport = wrapWithExtension(transport) client := &http.Client{ Transport: transport, Timeout: 30 * time.Second, @@ -133,7 +133,7 @@ func buildSDKTransport() http.RoundTripper { sdkTransport = &RetryTransport{Base: sdkTransport} sdkTransport = &UserAgentTransport{Base: sdkTransport} sdkTransport = &auth.SecurityPolicyTransport{Base: sdkTransport} - return sdkTransport + return wrapWithExtension(sdkTransport) } type credentialDeps struct { diff --git a/internal/cmdutil/factory_default_test.go b/internal/cmdutil/factory_default_test.go index 5f4d60014..9d01e82f4 100644 --- a/internal/cmdutil/factory_default_test.go +++ b/internal/cmdutil/factory_default_test.go @@ -6,9 +6,12 @@ package cmdutil import ( "context" "errors" + "net/http" + "net/http/httptest" "testing" _ "github.com/larksuite/cli/extension/credential/env" + exttransport "github.com/larksuite/cli/extension/transport" internalauth "github.com/larksuite/cli/internal/auth" "github.com/larksuite/cli/internal/core" "github.com/larksuite/cli/internal/credential" @@ -194,3 +197,170 @@ func TestNewDefault_ConfigUsesRuntimePlaceholderForTokenOnlyEnvAccount(t *testin t.Fatalf("Config().AppSecret = %q, want token-only no-secret marker", cfg.AppSecret) } } + +type stubTransportProvider struct { + interceptor exttransport.Interceptor +} + +func (s *stubTransportProvider) Name() string { return "stub" } +func (s *stubTransportProvider) ResolveInterceptor(context.Context) exttransport.Interceptor { + if s.interceptor != nil { + return s.interceptor + } + return &stubTransportImpl{} +} + +type stubTransportImpl struct{} + +func (s *stubTransportImpl) PreRoundTrip(req *http.Request) func(*http.Response, error) { + return nil +} + +// headerCapturingInterceptor sets custom headers in PreRoundTrip and records +// whether PostRoundTrip was called, to verify execution order. +type headerCapturingInterceptor struct { + preCalled bool + postCalled bool +} + +func (h *headerCapturingInterceptor) PreRoundTrip(req *http.Request) func(*http.Response, error) { + h.preCalled = true + // Set a custom header that should survive (no built-in override) + req.Header.Set("X-Custom-Trace", "ext-trace-123") + // Try to override a security header — should be overwritten by SecurityHeaderTransport + req.Header.Set(HeaderSource, "ext-tampered") + return func(resp *http.Response, err error) { + h.postCalled = true + } +} + +func TestExtensionInterceptor_ExecutionOrder(t *testing.T) { + var receivedHeaders http.Header + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedHeaders = r.Header.Clone() + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + ic := &headerCapturingInterceptor{} + exttransport.Register(&stubTransportProvider{interceptor: ic}) + t.Cleanup(func() { exttransport.Register(nil) }) + + // Use HTTP transport chain (has SecurityHeaderTransport) + var base http.RoundTripper = http.DefaultTransport + base = &RetryTransport{Base: base} + base = &SecurityHeaderTransport{Base: base} + transport := wrapWithExtension(base) + client := &http.Client{Transport: transport} + + req, _ := http.NewRequest("GET", srv.URL, nil) + resp, err := client.Do(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + resp.Body.Close() + + // PreRoundTrip was called + if !ic.preCalled { + t.Fatal("PreRoundTrip was not called") + } + // PostRoundTrip (closure) was called + if !ic.postCalled { + t.Fatal("PostRoundTrip closure was not called") + } + // Custom header set by extension survives (no built-in override) + if got := receivedHeaders.Get("X-Custom-Trace"); got != "ext-trace-123" { + t.Fatalf("X-Custom-Trace = %q, want %q", got, "ext-trace-123") + } + // Security header overridden by extension is restored by SecurityHeaderTransport + if got := receivedHeaders.Get(HeaderSource); got != SourceValue { + t.Fatalf("%s = %q, want %q (built-in should override extension)", HeaderSource, got, SourceValue) + } +} + +func TestExtensionInterceptor_ContextTamperPrevented(t *testing.T) { + type ctxKeyType string + const testKey ctxKeyType = "original" + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + var ctxValue any + + // Use a custom transport that captures the context value seen by the built-in chain + capturer := roundTripFunc(func(req *http.Request) (*http.Response, error) { + ctxValue = req.Context().Value(testKey) + return http.DefaultTransport.RoundTrip(req) + }) + + // Interceptor that tries to tamper with context + tamperIC := interceptorFunc(func(req *http.Request) func(*http.Response, error) { + // Try to replace context with a new one + *req = *req.WithContext(context.WithValue(req.Context(), testKey, "tampered")) + return nil + }) + + mid := &extensionMiddleware{Base: capturer, Ext: tamperIC} + + origCtx := context.WithValue(context.Background(), testKey, "original") + req, _ := http.NewRequestWithContext(origCtx, "GET", srv.URL, nil) + resp, err := mid.RoundTrip(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + resp.Body.Close() + + // Built-in chain should see original context, not tampered + if ctxValue != "original" { + t.Fatalf("built-in chain saw context value %q, want %q", ctxValue, "original") + } +} + +// interceptorFunc adapts a function to exttransport.Interceptor. +type interceptorFunc func(*http.Request) func(*http.Response, error) + +func (f interceptorFunc) PreRoundTrip(req *http.Request) func(*http.Response, error) { return f(req) } + +func TestBuildSDKTransport_WithExtension(t *testing.T) { + exttransport.Register(&stubTransportProvider{}) + t.Cleanup(func() { exttransport.Register(nil) }) + + transport := buildSDKTransport() + + // Chain: extensionMiddleware → SecurityPolicy → UserAgent → Retry → Base + mid, ok := transport.(*extensionMiddleware) + if !ok { + t.Fatalf("outer transport type = %T, want *extensionMiddleware", transport) + } + sec, ok := mid.Base.(*internalauth.SecurityPolicyTransport) + if !ok { + t.Fatalf("transport type = %T, want *auth.SecurityPolicyTransport", mid.Base) + } + ua, ok := sec.Base.(*UserAgentTransport) + if !ok { + t.Fatalf("transport type = %T, want *UserAgentTransport", sec.Base) + } + if _, ok := ua.Base.(*RetryTransport); !ok { + t.Fatalf("innermost transport type = %T, want *RetryTransport", ua.Base) + } +} + +func TestBuildSDKTransport_WithoutExtension(t *testing.T) { + exttransport.Register(nil) + + transport := buildSDKTransport() + + sec, ok := transport.(*internalauth.SecurityPolicyTransport) + if !ok { + t.Fatalf("outer transport type = %T, want *auth.SecurityPolicyTransport", transport) + } + ua, ok := sec.Base.(*UserAgentTransport) + if !ok { + t.Fatalf("middle transport type = %T, want *UserAgentTransport", sec.Base) + } + if _, ok := ua.Base.(*RetryTransport); !ok { + t.Fatalf("inner transport type = %T, want *RetryTransport", ua.Base) + } +} diff --git a/internal/cmdutil/transport.go b/internal/cmdutil/transport.go index 366fc7ca3..b922adf11 100644 --- a/internal/cmdutil/transport.go +++ b/internal/cmdutil/transport.go @@ -4,9 +4,11 @@ package cmdutil import ( + "context" "net/http" "time" + exttransport "github.com/larksuite/cli/extension/transport" "github.com/larksuite/cli/internal/util" ) @@ -100,3 +102,40 @@ func (t *SecurityHeaderTransport) RoundTrip(req *http.Request) (*http.Response, } return t.base().RoundTrip(req) } + +// extensionMiddleware wraps the built-in transport chain with pre/post hooks. +// The built-in chain always executes and cannot be skipped or overridden. +// The original request context is restored after PreRoundTrip to prevent +// extensions from tampering with cancellation, deadlines, or built-in values. +type extensionMiddleware struct { + Base http.RoundTripper + Ext exttransport.Interceptor +} + +// RoundTrip calls PreRoundTrip, restores the original context, executes +// the built-in chain, then calls the post hook if non-nil. +func (m *extensionMiddleware) RoundTrip(req *http.Request) (*http.Response, error) { + origCtx := req.Context() + req = req.Clone(origCtx) // isolate caller's request before extension mutations + post := m.Ext.PreRoundTrip(req) + req = req.WithContext(origCtx) // restore original context + resp, err := m.Base.RoundTrip(req) + if post != nil { + post(resp, err) + } + return resp, err +} + +// wrapWithExtension wraps transport with the registered extension middleware. +// If no extension is registered, returns transport unchanged. +func wrapWithExtension(transport http.RoundTripper) http.RoundTripper { + p := exttransport.GetProvider() + if p == nil { + return transport + } + tr := p.ResolveInterceptor(context.Background()) + if tr == nil { + return transport + } + return &extensionMiddleware{Base: transport, Ext: tr} +}