Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions extension/transport/registry.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// Copyright (c) 2026 Lark Technologies Pte. Ltd.
// SPDX-License-Identifier: MIT

package transport

import "sync"

var (
mu sync.Mutex
provider Provider
)

// Register registers a transport Provider.
// Later registrations override earlier ones.
// Typically called from init() via blank import.
func Register(p Provider) {
mu.Lock()
defer mu.Unlock()
provider = p
}

// GetProvider returns the currently registered Provider.
// Returns nil if no provider has been registered.
func GetProvider() Provider {
mu.Lock()
defer mu.Unlock()
return provider
}
77 changes: 77 additions & 0 deletions extension/transport/registry_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
// Copyright (c) 2026 Lark Technologies Pte. Ltd.
// SPDX-License-Identifier: MIT

package transport

import (
"context"
"net/http"
"testing"
)

type stubInterceptor struct{}

func (s *stubInterceptor) PreRoundTrip(req *http.Request) func(*http.Response, error) {
return nil
}

type stubProvider struct {
name string
}

func (s *stubProvider) Name() string { return s.name }
func (s *stubProvider) ResolveInterceptor(context.Context) Interceptor { return &stubInterceptor{} }

func TestGetProvider_NilByDefault(t *testing.T) {
mu.Lock()
provider = nil
mu.Unlock()

if got := GetProvider(); got != nil {
t.Fatalf("expected nil, got %v", got)
}
}

func TestRegisterAndGet(t *testing.T) {
mu.Lock()
provider = nil
mu.Unlock()

p := &stubProvider{name: "a"}
Register(p)

got := GetProvider()
if got != p {
t.Fatalf("expected registered provider, got %v", got)
}
}

func TestLastRegistrationWins(t *testing.T) {
mu.Lock()
provider = nil
mu.Unlock()

a := &stubProvider{name: "a"}
b := &stubProvider{name: "b"}
Register(a)
Register(b)

got := GetProvider()
if got != b {
t.Fatalf("expected provider b, got %v", got)
}
}

func TestResolveInterceptor_ReturnsNonNil(t *testing.T) {
mu.Lock()
provider = nil
mu.Unlock()

p := &stubProvider{name: "test"}
Register(p)

ic := GetProvider().ResolveInterceptor(context.Background())
if ic == nil {
t.Fatal("expected non-nil Interceptor")
}
}
32 changes: 32 additions & 0 deletions extension/transport/types.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// Copyright (c) 2026 Lark Technologies Pte. Ltd.
// SPDX-License-Identifier: MIT

package transport

import (
"context"
"net/http"
)

// Provider creates Interceptor instances.
// Follows the same API style as extension/credential.Provider and extension/fileio.Provider.
type Provider interface {
Name() string
ResolveInterceptor(ctx context.Context) Interceptor
}

// Interceptor defines network-layer customization via a pre/post hook pair.
// The built-in transport chain always executes between PreRoundTrip and the
// returned post function, and cannot be skipped or overridden by the extension.
//
// PreRoundTrip is called before the built-in chain. Use it to add custom
// headers, rewrite the host, or start trace spans. Built-in decorators run
// after this and will override any same-named security headers set here.
// The extension must not replace req.Context() — the middleware restores
// the original context after PreRoundTrip returns.
//
// The returned function (if non-nil) is called after the built-in chain
// completes. Use it for logging, ending trace spans, or recording metrics.
type Interceptor interface {
PreRoundTrip(req *http.Request) func(resp *http.Response, err error)
}
4 changes: 2 additions & 2 deletions internal/cmdutil/factory_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,8 @@ func cachedHttpClientFunc() func() (*http.Client, error) {
var transport http.RoundTripper = util.NewBaseTransport()
transport = &RetryTransport{Base: transport}
transport = &SecurityHeaderTransport{Base: transport}

transport = &auth.SecurityPolicyTransport{Base: transport} // Add our global response interceptor
transport = wrapWithExtension(transport)
client := &http.Client{
Transport: transport,
Timeout: 30 * time.Second,
Expand Down Expand Up @@ -133,7 +133,7 @@ func buildSDKTransport() http.RoundTripper {
sdkTransport = &RetryTransport{Base: sdkTransport}
sdkTransport = &UserAgentTransport{Base: sdkTransport}
sdkTransport = &auth.SecurityPolicyTransport{Base: sdkTransport}
return sdkTransport
return wrapWithExtension(sdkTransport)
}

type credentialDeps struct {
Expand Down
170 changes: 170 additions & 0 deletions internal/cmdutil/factory_default_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@ package cmdutil
import (
"context"
"errors"
"net/http"
"net/http/httptest"
"testing"

_ "github.com/larksuite/cli/extension/credential/env"
exttransport "github.com/larksuite/cli/extension/transport"
internalauth "github.com/larksuite/cli/internal/auth"
"github.com/larksuite/cli/internal/core"
"github.com/larksuite/cli/internal/credential"
Expand Down Expand Up @@ -194,3 +197,170 @@ func TestNewDefault_ConfigUsesRuntimePlaceholderForTokenOnlyEnvAccount(t *testin
t.Fatalf("Config().AppSecret = %q, want token-only no-secret marker", cfg.AppSecret)
}
}

type stubTransportProvider struct {
interceptor exttransport.Interceptor
}

func (s *stubTransportProvider) Name() string { return "stub" }
func (s *stubTransportProvider) ResolveInterceptor(context.Context) exttransport.Interceptor {
if s.interceptor != nil {
return s.interceptor
}
return &stubTransportImpl{}
}

type stubTransportImpl struct{}

func (s *stubTransportImpl) PreRoundTrip(req *http.Request) func(*http.Response, error) {
return nil
}

// headerCapturingInterceptor sets custom headers in PreRoundTrip and records
// whether PostRoundTrip was called, to verify execution order.
type headerCapturingInterceptor struct {
preCalled bool
postCalled bool
}

func (h *headerCapturingInterceptor) PreRoundTrip(req *http.Request) func(*http.Response, error) {
h.preCalled = true
// Set a custom header that should survive (no built-in override)
req.Header.Set("X-Custom-Trace", "ext-trace-123")
// Try to override a security header — should be overwritten by SecurityHeaderTransport
req.Header.Set(HeaderSource, "ext-tampered")
return func(resp *http.Response, err error) {
h.postCalled = true
}
}

func TestExtensionInterceptor_ExecutionOrder(t *testing.T) {
var receivedHeaders http.Header
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
receivedHeaders = r.Header.Clone()
w.WriteHeader(http.StatusOK)
}))
defer srv.Close()

ic := &headerCapturingInterceptor{}
exttransport.Register(&stubTransportProvider{interceptor: ic})
t.Cleanup(func() { exttransport.Register(nil) })

Comment thread
tuxedomm marked this conversation as resolved.
// Use HTTP transport chain (has SecurityHeaderTransport)
var base http.RoundTripper = http.DefaultTransport
base = &RetryTransport{Base: base}
base = &SecurityHeaderTransport{Base: base}
transport := wrapWithExtension(base)
client := &http.Client{Transport: transport}

req, _ := http.NewRequest("GET", srv.URL, nil)
resp, err := client.Do(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
resp.Body.Close()

// PreRoundTrip was called
if !ic.preCalled {
t.Fatal("PreRoundTrip was not called")
}
// PostRoundTrip (closure) was called
if !ic.postCalled {
t.Fatal("PostRoundTrip closure was not called")
}
// Custom header set by extension survives (no built-in override)
if got := receivedHeaders.Get("X-Custom-Trace"); got != "ext-trace-123" {
t.Fatalf("X-Custom-Trace = %q, want %q", got, "ext-trace-123")
}
// Security header overridden by extension is restored by SecurityHeaderTransport
if got := receivedHeaders.Get(HeaderSource); got != SourceValue {
t.Fatalf("%s = %q, want %q (built-in should override extension)", HeaderSource, got, SourceValue)
}
}

func TestExtensionInterceptor_ContextTamperPrevented(t *testing.T) {
type ctxKeyType string
const testKey ctxKeyType = "original"

srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer srv.Close()

var ctxValue any

// Use a custom transport that captures the context value seen by the built-in chain
capturer := roundTripFunc(func(req *http.Request) (*http.Response, error) {
ctxValue = req.Context().Value(testKey)
return http.DefaultTransport.RoundTrip(req)
})

// Interceptor that tries to tamper with context
tamperIC := interceptorFunc(func(req *http.Request) func(*http.Response, error) {
// Try to replace context with a new one
*req = *req.WithContext(context.WithValue(req.Context(), testKey, "tampered"))
return nil
})

mid := &extensionMiddleware{Base: capturer, Ext: tamperIC}

origCtx := context.WithValue(context.Background(), testKey, "original")
req, _ := http.NewRequestWithContext(origCtx, "GET", srv.URL, nil)
resp, err := mid.RoundTrip(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
resp.Body.Close()

// Built-in chain should see original context, not tampered
if ctxValue != "original" {
t.Fatalf("built-in chain saw context value %q, want %q", ctxValue, "original")
}
}

// interceptorFunc adapts a function to exttransport.Interceptor.
type interceptorFunc func(*http.Request) func(*http.Response, error)

func (f interceptorFunc) PreRoundTrip(req *http.Request) func(*http.Response, error) { return f(req) }

func TestBuildSDKTransport_WithExtension(t *testing.T) {
exttransport.Register(&stubTransportProvider{})
t.Cleanup(func() { exttransport.Register(nil) })

transport := buildSDKTransport()

// Chain: extensionMiddleware → SecurityPolicy → UserAgent → Retry → Base
mid, ok := transport.(*extensionMiddleware)
if !ok {
t.Fatalf("outer transport type = %T, want *extensionMiddleware", transport)
}
sec, ok := mid.Base.(*internalauth.SecurityPolicyTransport)
if !ok {
t.Fatalf("transport type = %T, want *auth.SecurityPolicyTransport", mid.Base)
}
ua, ok := sec.Base.(*UserAgentTransport)
if !ok {
t.Fatalf("transport type = %T, want *UserAgentTransport", sec.Base)
}
if _, ok := ua.Base.(*RetryTransport); !ok {
t.Fatalf("innermost transport type = %T, want *RetryTransport", ua.Base)
}
}

func TestBuildSDKTransport_WithoutExtension(t *testing.T) {
exttransport.Register(nil)

transport := buildSDKTransport()

sec, ok := transport.(*internalauth.SecurityPolicyTransport)
if !ok {
t.Fatalf("outer transport type = %T, want *auth.SecurityPolicyTransport", transport)
}
ua, ok := sec.Base.(*UserAgentTransport)
if !ok {
t.Fatalf("middle transport type = %T, want *UserAgentTransport", sec.Base)
}
if _, ok := ua.Base.(*RetryTransport); !ok {
t.Fatalf("inner transport type = %T, want *RetryTransport", ua.Base)
}
}
39 changes: 39 additions & 0 deletions internal/cmdutil/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
package cmdutil

import (
"context"
"net/http"
"time"

exttransport "github.com/larksuite/cli/extension/transport"
"github.com/larksuite/cli/internal/util"
)

Expand Down Expand Up @@ -100,3 +102,40 @@ func (t *SecurityHeaderTransport) RoundTrip(req *http.Request) (*http.Response,
}
return t.base().RoundTrip(req)
}

// extensionMiddleware wraps the built-in transport chain with pre/post hooks.
// The built-in chain always executes and cannot be skipped or overridden.
// The original request context is restored after PreRoundTrip to prevent
// extensions from tampering with cancellation, deadlines, or built-in values.
type extensionMiddleware struct {
Base http.RoundTripper
Ext exttransport.Interceptor
}

// RoundTrip calls PreRoundTrip, restores the original context, executes
// the built-in chain, then calls the post hook if non-nil.
func (m *extensionMiddleware) RoundTrip(req *http.Request) (*http.Response, error) {
origCtx := req.Context()
req = req.Clone(origCtx) // isolate caller's request before extension mutations
post := m.Ext.PreRoundTrip(req)
req = req.WithContext(origCtx) // restore original context
resp, err := m.Base.RoundTrip(req)
Comment thread
greptile-apps[bot] marked this conversation as resolved.
if post != nil {
post(resp, err)
}
return resp, err
}

// wrapWithExtension wraps transport with the registered extension middleware.
// If no extension is registered, returns transport unchanged.
func wrapWithExtension(transport http.RoundTripper) http.RoundTripper {
p := exttransport.GetProvider()
if p == nil {
return transport
}
tr := p.ResolveInterceptor(context.Background())
if tr == nil {
return transport
}
return &extensionMiddleware{Base: transport, Ext: tr}
}
Loading