From 65075a9652f2ccddb4ced370358901f69a207e54 Mon Sep 17 00:00:00 2001 From: liushiyao Date: Tue, 7 Apr 2026 17:42:40 +0800 Subject: [PATCH 1/5] feat: add transport extension with interceptor pre/post hooks Add extension/transport package following the same Provider pattern as credential and fileio extensions. The Interceptor interface uses a PreRoundTrip/post-closure design that guarantees built-in transport decorators (SecurityHeader, SecurityPolicy, Retry) cannot be skipped, overridden, or tampered with by extensions. The original request context is restored after PreRoundTrip to prevent context tampering. Change-Id: I2e51ff67a0e2d8d32944a0565c2a6781110f281f Co-Authored-By: Claude Opus 4.6 --- extension/transport/registry.go | 28 +++++++++ extension/transport/registry_test.go | 77 ++++++++++++++++++++++++ extension/transport/types.go | 32 ++++++++++ internal/cmdutil/factory_default.go | 4 +- internal/cmdutil/factory_default_test.go | 58 ++++++++++++++++++ internal/cmdutil/transport.go | 38 ++++++++++++ 6 files changed, 235 insertions(+), 2 deletions(-) create mode 100644 extension/transport/registry.go create mode 100644 extension/transport/registry_test.go create mode 100644 extension/transport/types.go diff --git a/extension/transport/registry.go b/extension/transport/registry.go new file mode 100644 index 000000000..d034b14b3 --- /dev/null +++ b/extension/transport/registry.go @@ -0,0 +1,28 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package transport + +import "sync" + +var ( + mu sync.Mutex + provider Provider +) + +// Register registers a transport Provider. +// Later registrations override earlier ones. +// Typically called from init() via blank import. +func Register(p Provider) { + mu.Lock() + defer mu.Unlock() + provider = p +} + +// GetProvider returns the currently registered Provider. +// Returns nil if no provider has been registered. +func GetProvider() Provider { + mu.Lock() + defer mu.Unlock() + return provider +} diff --git a/extension/transport/registry_test.go b/extension/transport/registry_test.go new file mode 100644 index 000000000..8c4020834 --- /dev/null +++ b/extension/transport/registry_test.go @@ -0,0 +1,77 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package transport + +import ( + "context" + "net/http" + "testing" +) + +type stubInterceptor struct{} + +func (s *stubInterceptor) PreRoundTrip(req *http.Request) func(*http.Response, error) { + return nil +} + +type stubProvider struct { + name string +} + +func (s *stubProvider) Name() string { return s.name } +func (s *stubProvider) ResolveInterceptor(context.Context) Interceptor { return &stubInterceptor{} } + +func TestGetProvider_NilByDefault(t *testing.T) { + mu.Lock() + provider = nil + mu.Unlock() + + if got := GetProvider(); got != nil { + t.Fatalf("expected nil, got %v", got) + } +} + +func TestRegisterAndGet(t *testing.T) { + mu.Lock() + provider = nil + mu.Unlock() + + p := &stubProvider{name: "a"} + Register(p) + + got := GetProvider() + if got != p { + t.Fatalf("expected registered provider, got %v", got) + } +} + +func TestLastRegistrationWins(t *testing.T) { + mu.Lock() + provider = nil + mu.Unlock() + + a := &stubProvider{name: "a"} + b := &stubProvider{name: "b"} + Register(a) + Register(b) + + got := GetProvider() + if got != b { + t.Fatalf("expected provider b, got %v", got) + } +} + +func TestResolveInterceptor_ReturnsNonNil(t *testing.T) { + mu.Lock() + provider = nil + mu.Unlock() + + p := &stubProvider{name: "test"} + Register(p) + + ic := GetProvider().ResolveInterceptor(context.Background()) + if ic == nil { + t.Fatal("expected non-nil Interceptor") + } +} diff --git a/extension/transport/types.go b/extension/transport/types.go new file mode 100644 index 000000000..e60c4018d --- /dev/null +++ b/extension/transport/types.go @@ -0,0 +1,32 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package transport + +import ( + "context" + "net/http" +) + +// Provider creates Interceptor instances. +// Follows the same API style as extension/credential.Provider and extension/fileio.Provider. +type Provider interface { + Name() string + ResolveInterceptor(ctx context.Context) Interceptor +} + +// Interceptor defines network-layer customization via a pre/post hook pair. +// The built-in transport chain always executes between PreRoundTrip and the +// returned post function, and cannot be skipped or overridden by the extension. +// +// PreRoundTrip is called before the built-in chain. Use it to add custom +// headers, rewrite the host, or start trace spans. Built-in decorators run +// after this and will override any same-named security headers set here. +// The extension must not replace req.Context() — the middleware restores +// the original context after PreRoundTrip returns. +// +// The returned function (if non-nil) is called after the built-in chain +// completes. Use it for logging, ending trace spans, or recording metrics. +type Interceptor interface { + PreRoundTrip(req *http.Request) func(resp *http.Response, err error) +} diff --git a/internal/cmdutil/factory_default.go b/internal/cmdutil/factory_default.go index 5b08a05cb..8c8ea4f88 100644 --- a/internal/cmdutil/factory_default.go +++ b/internal/cmdutil/factory_default.go @@ -95,8 +95,8 @@ func cachedHttpClientFunc() func() (*http.Client, error) { var transport http.RoundTripper = util.NewBaseTransport() transport = &RetryTransport{Base: transport} transport = &SecurityHeaderTransport{Base: transport} - transport = &auth.SecurityPolicyTransport{Base: transport} // Add our global response interceptor + transport = wrapWithExtension(transport) client := &http.Client{ Transport: transport, Timeout: 30 * time.Second, @@ -133,7 +133,7 @@ func buildSDKTransport() http.RoundTripper { sdkTransport = &RetryTransport{Base: sdkTransport} sdkTransport = &UserAgentTransport{Base: sdkTransport} sdkTransport = &auth.SecurityPolicyTransport{Base: sdkTransport} - return sdkTransport + return wrapWithExtension(sdkTransport) } type credentialDeps struct { diff --git a/internal/cmdutil/factory_default_test.go b/internal/cmdutil/factory_default_test.go index 5f4d60014..0e9f53131 100644 --- a/internal/cmdutil/factory_default_test.go +++ b/internal/cmdutil/factory_default_test.go @@ -8,7 +8,10 @@ import ( "errors" "testing" + "net/http" + _ "github.com/larksuite/cli/extension/credential/env" + exttransport "github.com/larksuite/cli/extension/transport" internalauth "github.com/larksuite/cli/internal/auth" "github.com/larksuite/cli/internal/core" "github.com/larksuite/cli/internal/credential" @@ -194,3 +197,58 @@ func TestNewDefault_ConfigUsesRuntimePlaceholderForTokenOnlyEnvAccount(t *testin t.Fatalf("Config().AppSecret = %q, want token-only no-secret marker", cfg.AppSecret) } } + +type stubTransportProvider struct{} + +func (s *stubTransportProvider) Name() string { return "stub" } +func (s *stubTransportProvider) ResolveInterceptor(context.Context) exttransport.Interceptor { + return &stubTransportImpl{} +} + +type stubTransportImpl struct{} + +func (s *stubTransportImpl) PreRoundTrip(req *http.Request) func(*http.Response, error) { + return nil +} + +func TestBuildSDKTransport_WithExtension(t *testing.T) { + exttransport.Register(&stubTransportProvider{}) + t.Cleanup(func() { exttransport.Register(nil) }) + + transport := buildSDKTransport() + + // Chain: extensionMiddleware → SecurityPolicy → UserAgent → Retry → Base + mid, ok := transport.(*extensionMiddleware) + if !ok { + t.Fatalf("outer transport type = %T, want *extensionMiddleware", transport) + } + sec, ok := mid.Base.(*internalauth.SecurityPolicyTransport) + if !ok { + t.Fatalf("transport type = %T, want *auth.SecurityPolicyTransport", mid.Base) + } + ua, ok := sec.Base.(*UserAgentTransport) + if !ok { + t.Fatalf("transport type = %T, want *UserAgentTransport", sec.Base) + } + if _, ok := ua.Base.(*RetryTransport); !ok { + t.Fatalf("innermost transport type = %T, want *RetryTransport", ua.Base) + } +} + +func TestBuildSDKTransport_WithoutExtension(t *testing.T) { + exttransport.Register(nil) + + transport := buildSDKTransport() + + sec, ok := transport.(*internalauth.SecurityPolicyTransport) + if !ok { + t.Fatalf("outer transport type = %T, want *auth.SecurityPolicyTransport", transport) + } + ua, ok := sec.Base.(*UserAgentTransport) + if !ok { + t.Fatalf("middle transport type = %T, want *UserAgentTransport", sec.Base) + } + if _, ok := ua.Base.(*RetryTransport); !ok { + t.Fatalf("inner transport type = %T, want *RetryTransport", ua.Base) + } +} diff --git a/internal/cmdutil/transport.go b/internal/cmdutil/transport.go index 366fc7ca3..8d53f7edf 100644 --- a/internal/cmdutil/transport.go +++ b/internal/cmdutil/transport.go @@ -4,9 +4,11 @@ package cmdutil import ( + "context" "net/http" "time" + exttransport "github.com/larksuite/cli/extension/transport" "github.com/larksuite/cli/internal/util" ) @@ -100,3 +102,39 @@ func (t *SecurityHeaderTransport) RoundTrip(req *http.Request) (*http.Response, } return t.base().RoundTrip(req) } + +// extensionMiddleware wraps the built-in transport chain with pre/post hooks. +// The built-in chain always executes and cannot be skipped or overridden. +// The original request context is restored after PreRoundTrip to prevent +// extensions from tampering with cancellation, deadlines, or built-in values. +type extensionMiddleware struct { + Base http.RoundTripper + Ext exttransport.Interceptor +} + +// RoundTrip calls PreRoundTrip, restores the original context, executes +// the built-in chain, then calls the post hook if non-nil. +func (m *extensionMiddleware) RoundTrip(req *http.Request) (*http.Response, error) { + origCtx := req.Context() + post := m.Ext.PreRoundTrip(req) + req = req.WithContext(origCtx) // restore original context + resp, err := m.Base.RoundTrip(req) + if post != nil { + post(resp, err) + } + return resp, err +} + +// wrapWithExtension wraps transport with the registered extension middleware. +// If no extension is registered, returns transport unchanged. +func wrapWithExtension(transport http.RoundTripper) http.RoundTripper { + p := exttransport.GetProvider() + if p == nil { + return transport + } + tr := p.ResolveInterceptor(context.Background()) + if tr == nil { + return transport + } + return &extensionMiddleware{Base: transport, Ext: tr} +} From 9f694ea24e3ce41a7521f750ce08ee8159e87bcf Mon Sep 17 00:00:00 2001 From: liushiyao Date: Tue, 7 Apr 2026 17:49:40 +0800 Subject: [PATCH 2/5] test: add behavior tests for transport extension interceptor Verify execution order (Pre/Post hooks called), security header override protection, custom header passthrough, and context tamper prevention. Change-Id: I8d126d777903e967bb5b9a1f3a07ba125f072ec6 Co-Authored-By: Claude Opus 4.6 --- internal/cmdutil/factory_default_test.go | 118 ++++++++++++++++++++++- 1 file changed, 115 insertions(+), 3 deletions(-) diff --git a/internal/cmdutil/factory_default_test.go b/internal/cmdutil/factory_default_test.go index 0e9f53131..2f4951789 100644 --- a/internal/cmdutil/factory_default_test.go +++ b/internal/cmdutil/factory_default_test.go @@ -6,9 +6,9 @@ package cmdutil import ( "context" "errors" - "testing" - "net/http" + "net/http/httptest" + "testing" _ "github.com/larksuite/cli/extension/credential/env" exttransport "github.com/larksuite/cli/extension/transport" @@ -198,10 +198,15 @@ func TestNewDefault_ConfigUsesRuntimePlaceholderForTokenOnlyEnvAccount(t *testin } } -type stubTransportProvider struct{} +type stubTransportProvider struct { + interceptor exttransport.Interceptor +} func (s *stubTransportProvider) Name() string { return "stub" } func (s *stubTransportProvider) ResolveInterceptor(context.Context) exttransport.Interceptor { + if s.interceptor != nil { + return s.interceptor + } return &stubTransportImpl{} } @@ -211,6 +216,113 @@ func (s *stubTransportImpl) PreRoundTrip(req *http.Request) func(*http.Response, return nil } +// headerCapturingInterceptor sets custom headers in PreRoundTrip and records +// whether PostRoundTrip was called, to verify execution order. +type headerCapturingInterceptor struct { + preCalled bool + postCalled bool +} + +func (h *headerCapturingInterceptor) PreRoundTrip(req *http.Request) func(*http.Response, error) { + h.preCalled = true + // Set a custom header that should survive (no built-in override) + req.Header.Set("X-Custom-Trace", "ext-trace-123") + // Try to override a security header — should be overwritten by SecurityHeaderTransport + req.Header.Set(HeaderSource, "ext-tampered") + return func(resp *http.Response, err error) { + h.postCalled = true + } +} + +func TestExtensionInterceptor_ExecutionOrder(t *testing.T) { + var receivedHeaders http.Header + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedHeaders = r.Header.Clone() + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + ic := &headerCapturingInterceptor{} + exttransport.Register(&stubTransportProvider{interceptor: ic}) + t.Cleanup(func() { exttransport.Register(nil) }) + + // Use HTTP transport chain (has SecurityHeaderTransport) + var base http.RoundTripper = http.DefaultTransport + base = &RetryTransport{Base: base} + base = &SecurityHeaderTransport{Base: base} + transport := wrapWithExtension(base) + client := &http.Client{Transport: transport} + + req, _ := http.NewRequest("GET", srv.URL, nil) + resp, err := client.Do(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + resp.Body.Close() + + // PreRoundTrip was called + if !ic.preCalled { + t.Fatal("PreRoundTrip was not called") + } + // PostRoundTrip (closure) was called + if !ic.postCalled { + t.Fatal("PostRoundTrip closure was not called") + } + // Custom header set by extension survives (no built-in override) + if got := receivedHeaders.Get("X-Custom-Trace"); got != "ext-trace-123" { + t.Fatalf("X-Custom-Trace = %q, want %q", got, "ext-trace-123") + } + // Security header overridden by extension is restored by SecurityHeaderTransport + if got := receivedHeaders.Get(HeaderSource); got != SourceValue { + t.Fatalf("%s = %q, want %q (built-in should override extension)", HeaderSource, got, SourceValue) + } +} + +func TestExtensionInterceptor_ContextTamperPrevented(t *testing.T) { + type ctxKeyType string + const testKey ctxKeyType = "original" + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + var ctxSeenByBuiltIn context.Context + + // Use a custom transport that captures the context seen by the built-in chain + capturer := roundTripFunc(func(req *http.Request) (*http.Response, error) { + ctxSeenByBuiltIn = req.Context() + return http.DefaultTransport.RoundTrip(req) + }) + + // Interceptor that tries to tamper with context + tamperIC := interceptorFunc(func(req *http.Request) func(*http.Response, error) { + // Try to replace context with a new one + *req = *req.WithContext(context.WithValue(req.Context(), testKey, "tampered")) + return nil + }) + + mid := &extensionMiddleware{Base: capturer, Ext: tamperIC} + + origCtx := context.WithValue(context.Background(), testKey, "original") + req, _ := http.NewRequestWithContext(origCtx, "GET", srv.URL, nil) + resp, err := mid.RoundTrip(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + resp.Body.Close() + + // Built-in chain should see original context, not tampered + if got := ctxSeenByBuiltIn.Value(testKey); got != "original" { + t.Fatalf("built-in chain saw context value %q, want %q", got, "original") + } +} + +// interceptorFunc adapts a function to exttransport.Interceptor. +type interceptorFunc func(*http.Request) func(*http.Response, error) + +func (f interceptorFunc) PreRoundTrip(req *http.Request) func(*http.Response, error) { return f(req) } + func TestBuildSDKTransport_WithExtension(t *testing.T) { exttransport.Register(&stubTransportProvider{}) t.Cleanup(func() { exttransport.Register(nil) }) From dce80a6824360934c831e6e1adfe9d4b1d51c5c9 Mon Sep 17 00:00:00 2001 From: liushiyao Date: Tue, 7 Apr 2026 17:52:01 +0800 Subject: [PATCH 3/5] fix: resolve gofmt and fatcontext lint issues Change-Id: I8c2265ef6bfbf6a7149df2b92db9fae2e1700c1c Co-Authored-By: Claude Opus 4.6 --- extension/transport/registry_test.go | 2 +- internal/cmdutil/factory_default_test.go | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/extension/transport/registry_test.go b/extension/transport/registry_test.go index 8c4020834..836cbca14 100644 --- a/extension/transport/registry_test.go +++ b/extension/transport/registry_test.go @@ -19,7 +19,7 @@ type stubProvider struct { name string } -func (s *stubProvider) Name() string { return s.name } +func (s *stubProvider) Name() string { return s.name } func (s *stubProvider) ResolveInterceptor(context.Context) Interceptor { return &stubInterceptor{} } func TestGetProvider_NilByDefault(t *testing.T) { diff --git a/internal/cmdutil/factory_default_test.go b/internal/cmdutil/factory_default_test.go index 2f4951789..9d01e82f4 100644 --- a/internal/cmdutil/factory_default_test.go +++ b/internal/cmdutil/factory_default_test.go @@ -287,11 +287,11 @@ func TestExtensionInterceptor_ContextTamperPrevented(t *testing.T) { })) defer srv.Close() - var ctxSeenByBuiltIn context.Context + var ctxValue any - // Use a custom transport that captures the context seen by the built-in chain + // Use a custom transport that captures the context value seen by the built-in chain capturer := roundTripFunc(func(req *http.Request) (*http.Response, error) { - ctxSeenByBuiltIn = req.Context() + ctxValue = req.Context().Value(testKey) return http.DefaultTransport.RoundTrip(req) }) @@ -313,8 +313,8 @@ func TestExtensionInterceptor_ContextTamperPrevented(t *testing.T) { resp.Body.Close() // Built-in chain should see original context, not tampered - if got := ctxSeenByBuiltIn.Value(testKey); got != "original" { - t.Fatalf("built-in chain saw context value %q, want %q", got, "original") + if ctxValue != "original" { + t.Fatalf("built-in chain saw context value %q, want %q", ctxValue, "original") } } From b0b3a6d9d5ee3844ffb1e89f81c7248a33e05afc Mon Sep 17 00:00:00 2001 From: liushiyao Date: Tue, 7 Apr 2026 18:00:15 +0800 Subject: [PATCH 4/5] fix: clone request before PreRoundTrip to isolate caller's headers Change-Id: Iafa17d03f0ffad0830b92882e69147eba8249dd3 Co-Authored-By: Claude Opus 4.6 --- internal/cmdutil/transport.go | 1 + 1 file changed, 1 insertion(+) diff --git a/internal/cmdutil/transport.go b/internal/cmdutil/transport.go index 8d53f7edf..4be4afc53 100644 --- a/internal/cmdutil/transport.go +++ b/internal/cmdutil/transport.go @@ -116,6 +116,7 @@ type extensionMiddleware struct { // the built-in chain, then calls the post hook if non-nil. func (m *extensionMiddleware) RoundTrip(req *http.Request) (*http.Response, error) { origCtx := req.Context() + req = req.Clone(origCtx) // isolate caller's request before extension mutations post := m.Ext.PreRoundTrip(req) req = req.WithContext(origCtx) // restore original context resp, err := m.Base.RoundTrip(req) From f6ca1bbff41e7c724b9242e98d59b5094dc1f110 Mon Sep 17 00:00:00 2001 From: liushiyao Date: Tue, 7 Apr 2026 18:04:46 +0800 Subject: [PATCH 5/5] style: fix gofmt formatting in transport.go Change-Id: If060b8d27fee9563a9c62ddbc67ed60021e55364 Co-Authored-By: Claude Opus 4.6 --- internal/cmdutil/transport.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/cmdutil/transport.go b/internal/cmdutil/transport.go index 4be4afc53..b922adf11 100644 --- a/internal/cmdutil/transport.go +++ b/internal/cmdutil/transport.go @@ -116,7 +116,7 @@ type extensionMiddleware struct { // the built-in chain, then calls the post hook if non-nil. func (m *extensionMiddleware) RoundTrip(req *http.Request) (*http.Response, error) { origCtx := req.Context() - req = req.Clone(origCtx) // isolate caller's request before extension mutations + req = req.Clone(origCtx) // isolate caller's request before extension mutations post := m.Ext.PreRoundTrip(req) req = req.WithContext(origCtx) // restore original context resp, err := m.Base.RoundTrip(req)