From fe461d95b849fa75b478400fa787c92b6c3964ce Mon Sep 17 00:00:00 2001 From: tuxedomm <273098272+tuxedomm@users.noreply.github.com> Date: Tue, 21 Apr 2026 13:59:53 +0800 Subject: [PATCH 1/7] refactor(cmd): split Execute into Build with IO/Keychain injection MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Introduce a public cmd.Build entry point so external consumers (cli-server, MCP server, other embedders) can assemble the full CLI command tree without going through os.Args or the platform keychain. Build takes an InvocationContext plus functional BuildOptions: * WithIO(in, out, errOut) — inject custom streams; terminal detection is derived from the input's underlying *os.File when present. * WithKeychain(kc) — swap the credential store. * HideProfile(bool) — registered later in cmd.HideProfile. The existing Execute() keeps using the internal buildInternal (which still returns the Factory so error handling can attribute exit codes), and SetDefaultFS replaces the global VFS implementation at startup. Hardening applied up front: * cmdutil.NewIOStreams(in, out, errOut) centralizes terminal detection so SystemIO() and WithIO share one path. * cmdutil.NewDefault normalizes partial IOStreams — callers may pass &IOStreams{Out: buf} without tripping nil-writer panics in the RoundTripper warnings, Cobra, or the credential provider. * Build guards against nil functional options. * An API contract test (cmd/build_api_test.go) exercises Build + WithIO + WithKeychain + HideProfile + SetDefaultFS so the public surface is reachable by deadcode analysis. Change-Id: I7c895e6019817401accbde2db3ef800da40ad319 --- cmd/build.go | 111 +++++++++++++++++++++++ cmd/build_api_test.go | 50 ++++++++++ cmd/init.go | 18 ++++ cmd/root.go | 43 +-------- cmd/root_integration_test.go | 6 +- internal/cmdutil/factory_default.go | 25 ++--- internal/cmdutil/factory_default_test.go | 12 +-- internal/cmdutil/factory_http_test.go | 7 +- internal/cmdutil/iostreams.go | 49 +++++++++- internal/credential/default_provider.go | 9 +- internal/credential/integration_test.go | 3 +- 11 files changed, 260 insertions(+), 73 deletions(-) create mode 100644 cmd/build.go create mode 100644 cmd/build_api_test.go create mode 100644 cmd/init.go diff --git a/cmd/build.go b/cmd/build.go new file mode 100644 index 000000000..2a0de48f2 --- /dev/null +++ b/cmd/build.go @@ -0,0 +1,111 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package cmd + +import ( + "context" + "io" + + "github.com/larksuite/cli/cmd/api" + "github.com/larksuite/cli/cmd/auth" + "github.com/larksuite/cli/cmd/completion" + cmdconfig "github.com/larksuite/cli/cmd/config" + "github.com/larksuite/cli/cmd/doctor" + "github.com/larksuite/cli/cmd/profile" + "github.com/larksuite/cli/cmd/schema" + "github.com/larksuite/cli/cmd/service" + cmdupdate "github.com/larksuite/cli/cmd/update" + "github.com/larksuite/cli/internal/build" + "github.com/larksuite/cli/internal/cmdutil" + "github.com/larksuite/cli/internal/keychain" + "github.com/larksuite/cli/shortcuts" + "github.com/spf13/cobra" +) + +// BuildOption configures optional aspects of the command tree construction. +type BuildOption func(*buildConfig) + +type buildConfig struct { + streams *cmdutil.IOStreams + keychain keychain.KeychainAccess +} + +// WithIO sets the IO streams for the CLI by wrapping raw reader/writers. +// Terminal detection is delegated to cmdutil.NewIOStreams. +func WithIO(in io.Reader, out, errOut io.Writer) BuildOption { + return func(c *buildConfig) { + c.streams = cmdutil.NewIOStreams(in, out, errOut) + } +} + +// WithKeychain sets the secret storage backend. If not provided, the platform keychain is used. +func WithKeychain(kc keychain.KeychainAccess) BuildOption { + return func(c *buildConfig) { + c.keychain = kc + } +} + +// Build constructs the full command tree without executing. +// Returns only the cobra.Command; Factory is internal. +// Use Execute for the standard production entry point. +func Build(ctx context.Context, inv cmdutil.InvocationContext, opts ...BuildOption) *cobra.Command { + _, rootCmd := buildInternal(ctx, inv, opts...) + return rootCmd +} + +// buildInternal is the internal constructor that also returns Factory for error handling. +func buildInternal(ctx context.Context, inv cmdutil.InvocationContext, opts ...BuildOption) (*cmdutil.Factory, *cobra.Command) { + cfg := &buildConfig{ + streams: cmdutil.SystemIO(), + } + for _, o := range opts { + if o != nil { + o(cfg) + } + } + + f := cmdutil.NewDefault(cfg.streams, inv) + if cfg.keychain != nil { + f.Keychain = cfg.keychain + } + + globals := &GlobalOptions{Profile: inv.Profile} + rootCmd := &cobra.Command{ + Use: "lark-cli", + Short: "Lark/Feishu CLI — OAuth authorization, UAT management, API calls", + Long: rootLong, + Version: build.Version, + } + + rootCmd.SetContext(ctx) + rootCmd.SetIn(cfg.streams.In) + rootCmd.SetOut(cfg.streams.Out) + rootCmd.SetErr(cfg.streams.ErrOut) + + installTipsHelpFunc(rootCmd) + rootCmd.SilenceErrors = true + + RegisterGlobalFlags(rootCmd.PersistentFlags(), globals) + rootCmd.PersistentPreRun = func(cmd *cobra.Command, args []string) { + cmd.SilenceUsage = true + } + + rootCmd.AddCommand(cmdconfig.NewCmdConfig(f)) + rootCmd.AddCommand(auth.NewCmdAuth(f)) + rootCmd.AddCommand(profile.NewCmdProfile(f)) + rootCmd.AddCommand(doctor.NewCmdDoctor(f)) + rootCmd.AddCommand(api.NewCmdApi(f, nil)) + rootCmd.AddCommand(schema.NewCmdSchema(f, nil)) + rootCmd.AddCommand(completion.NewCmdCompletion(f)) + rootCmd.AddCommand(cmdupdate.NewCmdUpdate(f)) + service.RegisterServiceCommands(rootCmd, f) + shortcuts.RegisterShortcuts(rootCmd, f) + + // Prune commands incompatible with strict mode. + if mode := f.ResolveStrictMode(ctx); mode.IsActive() { + pruneForStrictMode(rootCmd, mode) + } + + return f, rootCmd +} diff --git a/cmd/build_api_test.go b/cmd/build_api_test.go new file mode 100644 index 000000000..fa490c294 --- /dev/null +++ b/cmd/build_api_test.go @@ -0,0 +1,50 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package cmd + +import ( + "bytes" + "context" + "testing" + + "github.com/larksuite/cli/internal/cmdutil" + "github.com/larksuite/cli/internal/vfs" +) + +// noopKeychain is a zero-side-effect KeychainAccess for exercising +// WithKeychain without touching the platform keychain. +type noopKeychain struct{} + +func (noopKeychain) Get(service, account string) (string, error) { return "", nil } +func (noopKeychain) Set(service, account, value string) error { return nil } +func (noopKeychain) Remove(service, account string) error { return nil } + +// TestBuild_ExternalAPI asserts the library surface that external consumers +// (e.g. cli-server) depend on: Build composes a root command from an +// InvocationContext plus BuildOptions (WithIO, WithKeychain, HideProfile), +// and SetDefaultFS swaps the global VFS. This test is the contract guard. +func TestBuild_ExternalAPI(t *testing.T) { + // Exercise SetDefaultFS both directions. Passing nil restores the OS FS. + SetDefaultFS(vfs.OsFs{}) + SetDefaultFS(nil) + + var in, out, errOut bytes.Buffer + rootCmd := Build( + context.Background(), + cmdutil.InvocationContext{}, + WithIO(&in, &out, &errOut), + WithKeychain(noopKeychain{}), + HideProfile(true), + ) + + if rootCmd == nil { + t.Fatal("Build returned nil root command") + } + if rootCmd.Use != "lark-cli" { + t.Errorf("rootCmd.Use = %q, want %q", rootCmd.Use, "lark-cli") + } + if len(rootCmd.Commands()) == 0 { + t.Error("Build produced a root command with no subcommands") + } +} diff --git a/cmd/init.go b/cmd/init.go new file mode 100644 index 000000000..d9093eabf --- /dev/null +++ b/cmd/init.go @@ -0,0 +1,18 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package cmd + +import "github.com/larksuite/cli/internal/vfs" + +// SetDefaultFS replaces the global filesystem implementation used by internal +// packages. The provided fs must implement the vfs.FS interface. If fs is nil, +// the default OS filesystem is restored. +// +// Call this before Build or Execute to take effect. +func SetDefaultFS(fs vfs.FS) { + if fs == nil { + fs = vfs.OsFs{} + } + vfs.DefaultFS = fs +} diff --git a/cmd/root.go b/cmd/root.go index dca93f7cb..8088346a9 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -14,15 +14,6 @@ import ( "os" "strconv" - "github.com/larksuite/cli/cmd/api" - "github.com/larksuite/cli/cmd/auth" - "github.com/larksuite/cli/cmd/completion" - cmdconfig "github.com/larksuite/cli/cmd/config" - "github.com/larksuite/cli/cmd/doctor" - "github.com/larksuite/cli/cmd/profile" - "github.com/larksuite/cli/cmd/schema" - "github.com/larksuite/cli/cmd/service" - cmdupdate "github.com/larksuite/cli/cmd/update" internalauth "github.com/larksuite/cli/internal/auth" "github.com/larksuite/cli/internal/build" "github.com/larksuite/cli/internal/cmdutil" @@ -30,7 +21,6 @@ import ( "github.com/larksuite/cli/internal/output" "github.com/larksuite/cli/internal/registry" "github.com/larksuite/cli/internal/update" - "github.com/larksuite/cli/shortcuts" "github.com/spf13/cobra" ) @@ -95,38 +85,7 @@ func Execute() int { fmt.Fprintln(os.Stderr, "Error:", err) return 1 } - f := cmdutil.NewDefault(inv) - - globals := &GlobalOptions{Profile: inv.Profile} - rootCmd := &cobra.Command{ - Use: "lark-cli", - Short: "Lark/Feishu CLI — OAuth authorization, UAT management, API calls", - Long: rootLong, - Version: build.Version, - } - installTipsHelpFunc(rootCmd) - rootCmd.SilenceErrors = true - - RegisterGlobalFlags(rootCmd.PersistentFlags(), globals) - rootCmd.PersistentPreRun = func(cmd *cobra.Command, args []string) { - cmd.SilenceUsage = true - } - - rootCmd.AddCommand(cmdconfig.NewCmdConfig(f)) - rootCmd.AddCommand(auth.NewCmdAuth(f)) - rootCmd.AddCommand(profile.NewCmdProfile(f)) - rootCmd.AddCommand(doctor.NewCmdDoctor(f)) - rootCmd.AddCommand(api.NewCmdApi(f, nil)) - rootCmd.AddCommand(schema.NewCmdSchema(f, nil)) - rootCmd.AddCommand(completion.NewCmdCompletion(f)) - rootCmd.AddCommand(cmdupdate.NewCmdUpdate(f)) - service.RegisterServiceCommands(rootCmd, f) - shortcuts.RegisterShortcuts(rootCmd, f) - - // Prune commands incompatible with strict mode. - if mode := f.ResolveStrictMode(context.Background()); mode.IsActive() { - pruneForStrictMode(rootCmd, mode) - } + f, rootCmd := buildInternal(context.Background(), inv) // --- Update check (non-blocking) --- if !isCompletionCommand(os.Args) { diff --git a/cmd/root_integration_test.go b/cmd/root_integration_test.go index 555018738..e2521365d 100644 --- a/cmd/root_integration_test.go +++ b/cmd/root_integration_test.go @@ -135,10 +135,12 @@ func newStrictModeDefaultFactory(t *testing.T, profile string, mode core.StrictM t.Fatalf("SaveMultiAppConfig() error = %v", err) } - f := cmdutil.NewDefault(cmdutil.InvocationContext{Profile: profile}) stdout := &bytes.Buffer{} stderr := &bytes.Buffer{} - f.IOStreams = &cmdutil.IOStreams{In: nil, Out: stdout, ErrOut: stderr} + f := cmdutil.NewDefault( + cmdutil.NewIOStreams(&bytes.Buffer{}, stdout, stderr), + cmdutil.InvocationContext{Profile: profile}, + ) return f, stdout, stderr } diff --git a/internal/cmdutil/factory_default.go b/internal/cmdutil/factory_default.go index c9b4e92cf..eed433454 100644 --- a/internal/cmdutil/factory_default.go +++ b/internal/cmdutil/factory_default.go @@ -8,13 +8,11 @@ import ( "fmt" "io" "net/http" - "os" "sync" "time" lark "github.com/larksuite/oapi-sdk-go/v3" larkcore "github.com/larksuite/oapi-sdk-go/v3/core" - "golang.org/x/term" extcred "github.com/larksuite/cli/extension/credential" "github.com/larksuite/cli/extension/fileio" @@ -34,27 +32,24 @@ import ( // Phase 2: Credential (sole data source for account info) // Phase 3: Config derived from Credential // Phase 4: LarkClient derived from Credential -func NewDefault(inv InvocationContext) *Factory { +func NewDefault(streams *IOStreams, inv InvocationContext) *Factory { + streams = normalizeStreams(streams) f := &Factory{ Keychain: keychain.Default(), Invocation: inv, - } - f.IOStreams = &IOStreams{ - In: os.Stdin, - Out: os.Stdout, - ErrOut: os.Stderr, - IsTerminal: term.IsTerminal(int(os.Stdin.Fd())), + IOStreams: streams, } // Phase 0: FileIO provider (no dependency) f.FileIOProvider = fileio.GetProvider() // Phase 1: HttpClient (no credential dependency) - f.HttpClient = cachedHttpClientFunc() + f.HttpClient = cachedHttpClientFunc(f) // Phase 2: Credential (sole data source) + // Keychain is read via closure so callers can replace f.Keychain after construction. f.Credential = buildCredentialProvider(credentialDeps{ - Keychain: f.Keychain, + Keychain: func() keychain.KeychainAccess { return f.Keychain }, Profile: inv.Profile, HttpClient: f.HttpClient, ErrOut: f.IOStreams.ErrOut, @@ -93,9 +88,9 @@ func safeRedirectPolicy(req *http.Request, via []*http.Request) error { return nil } -func cachedHttpClientFunc() func() (*http.Client, error) { +func cachedHttpClientFunc(f *Factory) func() (*http.Client, error) { return sync.OnceValues(func() (*http.Client, error) { - util.WarnIfProxied(os.Stderr) + util.WarnIfProxied(f.IOStreams.ErrOut) var transport http.RoundTripper = util.NewBaseTransport() transport = &RetryTransport{Base: transport} @@ -122,7 +117,7 @@ func cachedLarkClientFunc(f *Factory) func() (*lark.Client, error) { lark.WithLogLevel(larkcore.LogLevelError), lark.WithHeaders(BaseSecurityHeaders()), } - util.WarnIfProxied(os.Stderr) + util.WarnIfProxied(f.IOStreams.ErrOut) opts = append(opts, lark.WithHttpClient(&http.Client{ Transport: buildSDKTransport(), CheckRedirect: safeRedirectPolicy, @@ -142,7 +137,7 @@ func buildSDKTransport() http.RoundTripper { } type credentialDeps struct { - Keychain keychain.KeychainAccess + Keychain func() keychain.KeychainAccess Profile string HttpClient func() (*http.Client, error) ErrOut io.Writer diff --git a/internal/cmdutil/factory_default_test.go b/internal/cmdutil/factory_default_test.go index fe91ead77..7204e2de9 100644 --- a/internal/cmdutil/factory_default_test.go +++ b/internal/cmdutil/factory_default_test.go @@ -63,7 +63,7 @@ func TestNewDefault_InvocationProfileUsedByStrictModeAndConfig(t *testing.T) { t.Fatalf("SaveMultiAppConfig() error = %v", err) } - f := NewDefault(InvocationContext{Profile: "target"}) + f := NewDefault(nil, InvocationContext{Profile: "target"}) if got := f.ResolveStrictMode(context.Background()); got != core.StrictModeBot { t.Fatalf("ResolveStrictMode() = %q, want %q", got, core.StrictModeBot) } @@ -103,7 +103,7 @@ func TestNewDefault_InvocationProfileMissingSticksAcrossEarlyStrictMode(t *testi t.Fatalf("SaveMultiAppConfig() error = %v", err) } - f := NewDefault(InvocationContext{Profile: "missing"}) + f := NewDefault(nil, InvocationContext{Profile: "missing"}) if got := f.ResolveStrictMode(context.Background()); got != core.StrictModeOff { t.Fatalf("ResolveStrictMode() = %q, want %q", got, core.StrictModeOff) } @@ -144,7 +144,7 @@ func TestNewDefault_ResolveAs_UsesDefaultAsFromEnvAccount(t *testing.T) { t.Setenv(envvars.CliTenantAccessToken, "") t.Setenv("LARKSUITE_CLI_CONFIG_DIR", t.TempDir()) - f := NewDefault(InvocationContext{}) + f := NewDefault(nil, InvocationContext{}) cmd := newCmdWithAsFlag("auto", false) got := f.ResolveAs(context.Background(), cmd, "auto") @@ -164,7 +164,7 @@ func TestNewDefault_ConfigReturnsCliConfigCopyOfCredentialAccount(t *testing.T) t.Setenv(envvars.CliTenantAccessToken, "") t.Setenv("LARKSUITE_CLI_CONFIG_DIR", t.TempDir()) - f := NewDefault(InvocationContext{}) + f := NewDefault(nil, InvocationContext{}) acct, err := f.Credential.ResolveAccount(context.Background()) if err != nil { @@ -189,7 +189,7 @@ func TestNewDefault_ConfigUsesRuntimePlaceholderForTokenOnlyEnvAccount(t *testin t.Setenv(envvars.CliTenantAccessToken, "") t.Setenv("LARKSUITE_CLI_CONFIG_DIR", t.TempDir()) - f := NewDefault(InvocationContext{}) + f := NewDefault(nil, InvocationContext{}) acct, err := f.Credential.ResolveAccount(context.Background()) if err != nil { @@ -217,7 +217,7 @@ func TestNewDefault_FileIOProviderDoesNotResolveDuringInitialization(t *testing. fileio.Register(provider) t.Cleanup(func() { fileio.Register(prev) }) - f := NewDefault(InvocationContext{}) + f := NewDefault(nil, InvocationContext{}) if f.FileIOProvider != provider { t.Fatalf("NewDefault() provider = %T, want %T", f.FileIOProvider, provider) } diff --git a/internal/cmdutil/factory_http_test.go b/internal/cmdutil/factory_http_test.go index c27e9e696..a2b50e823 100644 --- a/internal/cmdutil/factory_http_test.go +++ b/internal/cmdutil/factory_http_test.go @@ -4,11 +4,12 @@ package cmdutil import ( + "io" "testing" ) func TestCachedHttpClientFunc_ReturnsSameInstance(t *testing.T) { - fn := cachedHttpClientFunc() + fn := cachedHttpClientFunc(&Factory{IOStreams: &IOStreams{ErrOut: io.Discard}}) c1, err := fn() if err != nil { @@ -28,7 +29,7 @@ func TestCachedHttpClientFunc_ReturnsSameInstance(t *testing.T) { } func TestCachedHttpClientFunc_HasTimeout(t *testing.T) { - fn := cachedHttpClientFunc() + fn := cachedHttpClientFunc(&Factory{IOStreams: &IOStreams{ErrOut: io.Discard}}) c, _ := fn() if c.Timeout == 0 { t.Error("expected non-zero timeout") @@ -36,7 +37,7 @@ func TestCachedHttpClientFunc_HasTimeout(t *testing.T) { } func TestCachedHttpClientFunc_HasRedirectPolicy(t *testing.T) { - fn := cachedHttpClientFunc() + fn := cachedHttpClientFunc(&Factory{IOStreams: &IOStreams{ErrOut: io.Discard}}) c, _ := fn() if c.CheckRedirect == nil { t.Error("expected CheckRedirect to be set (safeRedirectPolicy)") diff --git a/internal/cmdutil/iostreams.go b/internal/cmdutil/iostreams.go index 76068e057..864300f52 100644 --- a/internal/cmdutil/iostreams.go +++ b/internal/cmdutil/iostreams.go @@ -3,7 +3,12 @@ package cmdutil -import "io" +import ( + "io" + "os" + + "golang.org/x/term" +) // IOStreams provides the standard input/output/error streams. // Commands should use these instead of os.Stdin/Stdout/Stderr @@ -14,3 +19,45 @@ type IOStreams struct { ErrOut io.Writer IsTerminal bool } + +// NewIOStreams builds an IOStreams from arbitrary readers/writers. +// IsTerminal is derived from in's underlying *os.File, if any; non-file +// readers (bytes.Buffer, strings.Reader, …) yield IsTerminal=false. +func NewIOStreams(in io.Reader, out, errOut io.Writer) *IOStreams { + isTerminal := false + if f, ok := in.(*os.File); ok { + isTerminal = term.IsTerminal(int(f.Fd())) + } + return &IOStreams{In: in, Out: out, ErrOut: errOut, IsTerminal: isTerminal} +} + +// SystemIO creates an IOStreams wired to the process's standard file descriptors. +// +//nolint:forbidigo // entry point for real stdio +func SystemIO() *IOStreams { + return NewIOStreams(os.Stdin, os.Stdout, os.Stderr) +} + +// normalizeStreams returns a fresh IOStreams with any nil field filled from +// SystemIO(). Callers constructing a partial struct like &IOStreams{Out: buf} +// get a usable result without nil writers leaking into RoundTripper warnings, +// Cobra I/O, or credential-provider error paths. +func normalizeStreams(s *IOStreams) *IOStreams { + if s == nil { + return SystemIO() + } + out := *s + if out.In == nil || out.Out == nil || out.ErrOut == nil { + sys := SystemIO() + if out.In == nil { + out.In = sys.In + } + if out.Out == nil { + out.Out = sys.Out + } + if out.ErrOut == nil { + out.ErrOut = sys.ErrOut + } + } + return &out +} diff --git a/internal/credential/default_provider.go b/internal/credential/default_provider.go index bedad7b86..a3ebb90c9 100644 --- a/internal/credential/default_provider.go +++ b/internal/credential/default_provider.go @@ -21,11 +21,14 @@ import ( // DefaultAccountProvider resolves account from config.json via keychain. type DefaultAccountProvider struct { - keychain keychain.KeychainAccess + keychain func() keychain.KeychainAccess profile string } -func NewDefaultAccountProvider(kc keychain.KeychainAccess, profile string) *DefaultAccountProvider { +func NewDefaultAccountProvider(kc func() keychain.KeychainAccess, profile string) *DefaultAccountProvider { + if kc == nil { + kc = keychain.Default + } return &DefaultAccountProvider{keychain: kc, profile: profile} } @@ -36,7 +39,7 @@ func (p *DefaultAccountProvider) ResolveAccount(ctx context.Context) (*Account, return nil, &core.ConfigError{Code: 2, Type: "config", Message: "not configured", Hint: "run `lark-cli config init --new` in the background. It blocks and outputs a verification URL — retrieve the URL and open it in a browser to complete setup."} } - cfg, err := core.ResolveConfigFromMulti(multi, p.keychain, p.profile) + cfg, err := core.ResolveConfigFromMulti(multi, p.keychain(), p.profile) if err != nil { return nil, err } diff --git a/internal/credential/integration_test.go b/internal/credential/integration_test.go index 46a3485ff..7daef1d6a 100644 --- a/internal/credential/integration_test.go +++ b/internal/credential/integration_test.go @@ -12,6 +12,7 @@ import ( "github.com/larksuite/cli/internal/core" "github.com/larksuite/cli/internal/credential" "github.com/larksuite/cli/internal/envvars" + "github.com/larksuite/cli/internal/keychain" ) type noopKC struct{} @@ -99,7 +100,7 @@ func TestFullChain_ConfigStrictMode(t *testing.T) { } ep := &envprovider.Provider{} - defaultAcct := credential.NewDefaultAccountProvider(&noopKC{}, "") + defaultAcct := credential.NewDefaultAccountProvider(func() keychain.KeychainAccess { return &noopKC{} }, "") cp := credential.NewCredentialProvider( []extcred.Provider{ep}, From 719bc630efb9126b172a3f82b8a11481982e96c4 Mon Sep 17 00:00:00 2001 From: tuxedomm <273098272+tuxedomm@users.noreply.github.com> Date: Thu, 9 Apr 2026 22:23:31 +0800 Subject: [PATCH 2/7] feat(schema): filter methods by strict mode in schema output When strict mode is active, schema output now excludes methods that are incompatible with the forced identity. This applies to both pretty and JSON output formats at the resource and method levels. Change-Id: I39647d5578466c3e23dc545bfb917ae075203ad7 --- cmd/schema/schema.go | 97 +++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 92 insertions(+), 5 deletions(-) diff --git a/cmd/schema/schema.go b/cmd/schema/schema.go index f45d6e816..6dca541f9 100644 --- a/cmd/schema/schema.go +++ b/cmd/schema/schema.go @@ -4,12 +4,14 @@ package schema import ( + "context" "fmt" "io" "sort" "strings" "github.com/larksuite/cli/internal/cmdutil" + "github.com/larksuite/cli/internal/core" "github.com/larksuite/cli/internal/output" "github.com/larksuite/cli/internal/registry" "github.com/larksuite/cli/internal/util" @@ -19,6 +21,7 @@ import ( // SchemaOptions holds all inputs for the schema command. type SchemaOptions struct { Factory *cmdutil.Factory + Ctx context.Context // Positional args Path string @@ -41,7 +44,7 @@ func printServices(w io.Writer) { fmt.Fprintf(w, "\n%sUsage: lark-cli schema ..%s\n", output.Dim, output.Reset) } -func printResourceList(w io.Writer, spec map[string]interface{}) { +func printResourceList(w io.Writer, spec map[string]interface{}, mode core.StrictMode) { name := registry.GetStrFromMap(spec, "name") version := registry.GetStrFromMap(spec, "version") title := registry.GetStrFromMap(spec, "title") @@ -55,9 +58,13 @@ func printResourceList(w io.Writer, spec map[string]interface{}) { resources, _ := spec["resources"].(map[string]interface{}) for _, resName := range sortedKeys(resources) { - fmt.Fprintf(w, " %s%s%s\n", output.Cyan, resName, output.Reset) resMap, _ := resources[resName].(map[string]interface{}) methods, _ := resMap["methods"].(map[string]interface{}) + methods = filterMethodsByStrictMode(methods, mode) + if len(methods) == 0 { + continue + } + fmt.Fprintf(w, " %s%s%s\n", output.Cyan, resName, output.Reset) for _, methodName := range sortedKeys(methods) { m, _ := methods[methodName].(map[string]interface{}) httpMethod := registry.GetStrFromMap(m, "httpMethod") @@ -359,6 +366,7 @@ func NewCmdSchema(f *cmdutil.Factory, runF func(*SchemaOptions) error) *cobra.Co if len(args) > 0 { opts.Path = args[0] } + opts.Ctx = cmd.Context() if runF != nil { return runF(opts) } @@ -451,6 +459,7 @@ func completeSchemaPath(_ *cobra.Command, args []string, toComplete string) ([]s func schemaRun(opts *SchemaOptions) error { out := opts.Factory.IOStreams.Out + mode := opts.Factory.ResolveStrictMode(opts.Ctx) if opts.Path == "" { printServices(out) @@ -469,9 +478,9 @@ func schemaRun(opts *SchemaOptions) error { if len(parts) == 1 { if opts.Format == "pretty" { - printResourceList(out, spec) + printResourceList(out, spec, mode) } else { - output.PrintJson(out, spec) + output.PrintJson(out, filterSpecByStrictMode(spec, mode)) } return nil } @@ -492,6 +501,7 @@ func schemaRun(opts *SchemaOptions) error { if opts.Format == "pretty" { fmt.Fprintf(out, "%s%s.%s%s\n\n", output.Bold, serviceName, resName, output.Reset) methods, _ := resource["methods"].(map[string]interface{}) + methods = filterMethodsByStrictMode(methods, mode) for _, mName := range sortedKeys(methods) { m, _ := methods[mName].(map[string]interface{}) httpMethod := registry.GetStrFromMap(m, "httpMethod") @@ -500,13 +510,26 @@ func schemaRun(opts *SchemaOptions) error { } fmt.Fprintf(out, "\n%sUsage: lark-cli schema %s.%s.%s\n", output.Dim, serviceName, resName, output.Reset) } else { - output.PrintJson(out, resource) + // For JSON output, filter methods in a copy to avoid mutating the registry. + if mode.IsActive() { + filtered := make(map[string]interface{}) + for k, v := range resource { + filtered[k] = v + } + if methods, ok := resource["methods"].(map[string]interface{}); ok { + filtered["methods"] = filterMethodsByStrictMode(methods, mode) + } + output.PrintJson(out, filtered) + } else { + output.PrintJson(out, resource) + } } return nil } methodName := remaining[0] methods, _ := resource["methods"].(map[string]interface{}) + methods = filterMethodsByStrictMode(methods, mode) method, ok := methods[methodName].(map[string]interface{}) if !ok { var mNames []string @@ -525,3 +548,67 @@ func schemaRun(opts *SchemaOptions) error { } return nil } + +// filterSpecByStrictMode returns a shallow copy of spec with each resource's methods +// filtered by strict mode. Returns the original spec when strict mode is off. +func filterSpecByStrictMode(spec map[string]interface{}, mode core.StrictMode) map[string]interface{} { + if !mode.IsActive() { + return spec + } + result := make(map[string]interface{}, len(spec)) + for k, v := range spec { + result[k] = v + } + resources, _ := spec["resources"].(map[string]interface{}) + if resources == nil { + return result + } + filteredRes := make(map[string]interface{}, len(resources)) + for resName, resVal := range resources { + resMap, ok := resVal.(map[string]interface{}) + if !ok { + continue + } + methods, _ := resMap["methods"].(map[string]interface{}) + filtered := filterMethodsByStrictMode(methods, mode) + if len(filtered) == 0 { + continue + } + resCopy := make(map[string]interface{}, len(resMap)) + for k, v := range resMap { + resCopy[k] = v + } + resCopy["methods"] = filtered + filteredRes[resName] = resCopy + } + result["resources"] = filteredRes + return result +} + +// filterMethodsByStrictMode removes methods incompatible with the active strict mode. +// Returns the original map unmodified when strict mode is off. +func filterMethodsByStrictMode(methods map[string]interface{}, mode core.StrictMode) map[string]interface{} { + if !mode.IsActive() || methods == nil { + return methods + } + token := registry.IdentityToAccessToken(string(mode.ForcedIdentity())) + filtered := make(map[string]interface{}, len(methods)) + for name, val := range methods { + m, ok := val.(map[string]interface{}) + if !ok { + continue + } + tokens, _ := m["accessTokens"].([]interface{}) + if tokens == nil { + filtered[name] = val + continue + } + for _, t := range tokens { + if ts, ok := t.(string); ok && ts == token { + filtered[name] = val + break + } + } + } + return filtered +} From 2a1602ce9ed2ab578a0e50789c8077fc78fbae99 Mon Sep 17 00:00:00 2001 From: tuxedomm <273098272+tuxedomm@users.noreply.github.com> Date: Sat, 11 Apr 2026 14:36:30 +0800 Subject: [PATCH 3/7] refactor: centralize strict-mode as flag registration Change-Id: Iec11151c5002c2f58a8aa067d08747db2e4d2d8c --- cmd/api/api.go | 5 +- cmd/api/api_test.go | 18 +++++ cmd/service/service.go | 6 +- cmd/service/service_test.go | 18 +++++ internal/cmdutil/identity_flag.go | 68 +++++++++++++++++++ internal/cmdutil/identity_flag_test.go | 67 ++++++++++++++++++ shortcuts/common/runner.go | 10 +-- shortcuts/common/runner_identity_flag_test.go | 45 ++++++++++++ shortcuts/common/runner_jq_test.go | 10 +-- 9 files changed, 226 insertions(+), 21 deletions(-) create mode 100644 internal/cmdutil/identity_flag.go create mode 100644 internal/cmdutil/identity_flag_test.go create mode 100644 shortcuts/common/runner_identity_flag_test.go diff --git a/cmd/api/api.go b/cmd/api/api.go index 1fe651d03..e4f83b4d2 100644 --- a/cmd/api/api.go +++ b/cmd/api/api.go @@ -79,7 +79,7 @@ func NewCmdApi(f *cmdutil.Factory, runF func(*APIOptions) error) *cobra.Command cmd.Flags().StringVar(&opts.Params, "params", "", "query parameters JSON (supports - for stdin)") cmd.Flags().StringVar(&opts.Data, "data", "", "request body JSON (supports - for stdin)") - cmd.Flags().StringVar(&asStr, "as", "auto", "identity type: user | bot | auto (default)") + cmdutil.AddAPIIdentityFlag(cmd, f, &asStr) cmd.Flags().StringVarP(&opts.Output, "output", "o", "", "output file path for binary responses") cmd.Flags().BoolVar(&opts.PageAll, "page-all", false, "automatically paginate through all pages") cmd.Flags().IntVar(&opts.PageSize, "page-size", 0, "page size (0 = use API default)") @@ -96,9 +96,6 @@ func NewCmdApi(f *cmdutil.Factory, runF func(*APIOptions) error) *cobra.Command } return nil, cobra.ShellCompDirectiveNoFileComp } - _ = cmd.RegisterFlagCompletionFunc("as", func(_ *cobra.Command, _ []string, _ string) ([]string, cobra.ShellCompDirective) { - return []string{"user", "bot"}, cobra.ShellCompDirectiveNoFileComp - }) _ = cmd.RegisterFlagCompletionFunc("format", func(_ *cobra.Command, _ []string, _ string) ([]string, cobra.ShellCompDirective) { return []string{"json", "ndjson", "table", "csv"}, cobra.ShellCompDirectiveNoFileComp }) diff --git a/cmd/api/api_test.go b/cmd/api/api_test.go index 9f5be7022..e81cc984f 100644 --- a/cmd/api/api_test.go +++ b/cmd/api/api_test.go @@ -180,6 +180,24 @@ func TestApiValidArgsFunction(t *testing.T) { } } +func TestNewCmdApi_StrictModeHidesAsFlag(t *testing.T) { + f, _, _, _ := cmdutil.TestFactory(t, &core.CliConfig{ + AppID: "test-app", AppSecret: "test-secret", Brand: core.BrandFeishu, SupportedIdentities: 2, + }) + + cmd := NewCmdApi(f, nil) + flag := cmd.Flags().Lookup("as") + if flag == nil { + t.Fatal("expected --as flag to be registered") + } + if !flag.Hidden { + t.Fatal("expected --as flag to be hidden in strict mode") + } + if got := flag.DefValue; got != "bot" { + t.Fatalf("default value = %q, want %q", got, "bot") + } +} + func TestApiCmd_PageLimitDefault(t *testing.T) { f, _, _, _ := cmdutil.TestFactory(t, &core.CliConfig{ AppID: "test-app", AppSecret: "test-secret", Brand: core.BrandFeishu, diff --git a/cmd/service/service.go b/cmd/service/service.go index 63b6fc6b7..f5d9ae1a6 100644 --- a/cmd/service/service.go +++ b/cmd/service/service.go @@ -159,7 +159,7 @@ func NewCmdServiceMethod(f *cmdutil.Factory, spec, method map[string]interface{} case "POST", "PUT", "PATCH", "DELETE": cmd.Flags().StringVar(&opts.Data, "data", "", "request body JSON (supports - for stdin)") } - cmd.Flags().StringVar(&asStr, "as", "auto", "identity type: user | bot | auto (default)") + cmdutil.AddAPIIdentityFlag(cmd, f, &asStr) cmd.Flags().StringVarP(&opts.Output, "output", "o", "", "output file path for binary responses") cmd.Flags().BoolVar(&opts.PageAll, "page-all", false, "automatically paginate through all pages") cmd.Flags().IntVar(&opts.PageLimit, "page-limit", 10, "max pages to fetch with --page-all (0 = unlimited)") @@ -177,10 +177,6 @@ func NewCmdServiceMethod(f *cmdutil.Factory, spec, method map[string]interface{} cmd.Flags().StringVar(&opts.File, "file", "", "file to upload ([field=]path, supports - for stdin)") } } - - _ = cmd.RegisterFlagCompletionFunc("as", func(_ *cobra.Command, _ []string, _ string) ([]string, cobra.ShellCompDirective) { - return []string{"user", "bot"}, cobra.ShellCompDirectiveNoFileComp - }) _ = cmd.RegisterFlagCompletionFunc("format", func(_ *cobra.Command, _ []string, _ string) ([]string, cobra.ShellCompDirective) { return []string{"json", "ndjson", "table", "csv"}, cobra.ShellCompDirectiveNoFileComp }) diff --git a/cmd/service/service_test.go b/cmd/service/service_test.go index f9adc5d47..3a2e5a7be 100644 --- a/cmd/service/service_test.go +++ b/cmd/service/service_test.go @@ -121,6 +121,24 @@ func TestRegisterService_MergesExistingCommand(t *testing.T) { } } +func TestNewCmdServiceMethod_StrictModeHidesAsFlag(t *testing.T) { + f, _, _, _ := cmdutil.TestFactory(t, &core.CliConfig{ + AppID: "test-app", AppSecret: "test-secret", Brand: core.BrandFeishu, SupportedIdentities: 2, + }) + + cmd := NewCmdServiceMethod(f, driveSpec(), driveMethod("GET", nil), "copy", "files", nil) + flag := cmd.Flags().Lookup("as") + if flag == nil { + t.Fatal("expected --as flag to be registered") + } + if !flag.Hidden { + t.Fatal("expected --as flag to be hidden in strict mode") + } + if got := flag.DefValue; got != "bot" { + t.Fatalf("default value = %q, want %q", got, "bot") + } +} + // ── NewCmdServiceMethod flags ── func TestNewCmdServiceMethod_GETHasNoDataFlag(t *testing.T) { diff --git a/internal/cmdutil/identity_flag.go b/internal/cmdutil/identity_flag.go new file mode 100644 index 000000000..99b7ed3df --- /dev/null +++ b/internal/cmdutil/identity_flag.go @@ -0,0 +1,68 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package cmdutil + +import ( + "context" + "fmt" + "strings" + + "github.com/spf13/cobra" +) + +// AddAPIIdentityFlag registers the standard --as flag shape used by api/service commands. +func AddAPIIdentityFlag(cmd *cobra.Command, f *Factory, target *string) { + addIdentityFlag(cmd, f, target, identityFlagConfig{ + defaultValue: "auto", + usage: "identity type: user | bot | auto (default)", + completionValues: []string{"user", "bot"}, + }) +} + +// AddShortcutIdentityFlag registers the standard --as flag shape used by shortcuts. +func AddShortcutIdentityFlag(cmd *cobra.Command, f *Factory, authTypes []string) { + if len(authTypes) == 0 { + authTypes = []string{"user"} + } + addIdentityFlag(cmd, f, nil, identityFlagConfig{ + defaultValue: authTypes[0], + usage: "identity type: " + strings.Join(authTypes, " | "), + completionValues: authTypes, + }) +} + +type identityFlagConfig struct { + defaultValue string + usage string + completionValues []string +} + +// addIdentityFlag centralizes --as registration and strict-mode UX. +// When strict mode is active, the flag is still accepted for compatibility +// but hidden from help/completion and locked to the forced identity by default. +func addIdentityFlag(cmd *cobra.Command, f *Factory, target *string, cfg identityFlagConfig) { + if forced := f.ResolveStrictMode(context.Background()).ForcedIdentity(); forced != "" { + // Keep registering --as in strict mode even though it is hidden. + // This preserves parser compatibility for existing invocations that still pass + // --as, and keeps downstream GetString("as") / ResolveAs paths stable. + // The usage text below is effectively placeholder text because the flag is hidden. + registerIdentityFlag(cmd, target, string(forced), + fmt.Sprintf("identity locked to %s by strict mode (admin-managed)", forced)) + _ = cmd.Flags().MarkHidden("as") + return + } + + registerIdentityFlag(cmd, target, cfg.defaultValue, cfg.usage) + _ = cmd.RegisterFlagCompletionFunc("as", func(_ *cobra.Command, _ []string, _ string) ([]string, cobra.ShellCompDirective) { + return cfg.completionValues, cobra.ShellCompDirectiveNoFileComp + }) +} + +func registerIdentityFlag(cmd *cobra.Command, target *string, defaultValue, usage string) { + if target != nil { + cmd.Flags().StringVar(target, "as", defaultValue, usage) + return + } + cmd.Flags().String("as", defaultValue, usage) +} diff --git a/internal/cmdutil/identity_flag_test.go b/internal/cmdutil/identity_flag_test.go new file mode 100644 index 000000000..2f1350752 --- /dev/null +++ b/internal/cmdutil/identity_flag_test.go @@ -0,0 +1,67 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package cmdutil + +import ( + "testing" + + "github.com/larksuite/cli/internal/core" + "github.com/spf13/cobra" +) + +func TestAddAPIIdentityFlag_NonStrictMode(t *testing.T) { + f, _, _, _ := TestFactory(t, &core.CliConfig{AppID: "a", AppSecret: "s"}) + cmd := &cobra.Command{Use: "test"} + + AddAPIIdentityFlag(cmd, f, nil) + + flag := cmd.Flags().Lookup("as") + if flag == nil { + t.Fatal("expected --as flag to be registered") + } + if flag.Hidden { + t.Fatal("expected --as flag to be visible outside strict mode") + } + if got := flag.DefValue; got != "auto" { + t.Fatalf("default value = %q, want %q", got, "auto") + } +} + +func TestAddAPIIdentityFlag_StrictModeHidesFlagAndLocksDefault(t *testing.T) { + f, _, _, _ := TestFactory(t, &core.CliConfig{ + AppID: "a", AppSecret: "s", SupportedIdentities: 2, + }) + cmd := &cobra.Command{Use: "test"} + + AddAPIIdentityFlag(cmd, f, nil) + + flag := cmd.Flags().Lookup("as") + if flag == nil { + t.Fatal("expected --as flag to be registered") + } + if !flag.Hidden { + t.Fatal("expected --as flag to be hidden in strict mode") + } + if got := flag.DefValue; got != "bot" { + t.Fatalf("default value = %q, want %q", got, "bot") + } +} + +func TestAddShortcutIdentityFlag_UsesAuthTypes(t *testing.T) { + f, _, _, _ := TestFactory(t, &core.CliConfig{AppID: "a", AppSecret: "s"}) + cmd := &cobra.Command{Use: "test"} + + AddShortcutIdentityFlag(cmd, f, []string{"bot"}) + + flag := cmd.Flags().Lookup("as") + if flag == nil { + t.Fatal("expected --as flag to be registered") + } + if flag.Hidden { + t.Fatal("expected --as flag to be visible outside strict mode") + } + if got := flag.DefValue; got != "bot" { + t.Fatalf("default value = %q, want %q", got, "bot") + } +} diff --git a/shortcuts/common/runner.go b/shortcuts/common/runner.go index 7d9685a13..46988902b 100644 --- a/shortcuts/common/runner.go +++ b/shortcuts/common/runner.go @@ -592,7 +592,7 @@ func (s Shortcut) mountDeclarative(parent *cobra.Command, f *cmdutil.Factory) { }, } cmdutil.SetSupportedIdentities(cmd, shortcut.AuthTypes) - registerShortcutFlags(cmd, &shortcut) + registerShortcutFlags(cmd, f, &shortcut) cmdutil.SetTips(cmd, shortcut.Tips) parent.AddCommand(cmd) } @@ -823,7 +823,7 @@ func rejectPositionalArgs() cobra.PositionalArgs { } } -func registerShortcutFlags(cmd *cobra.Command, s *Shortcut) { +func registerShortcutFlags(cmd *cobra.Command, f *cmdutil.Factory, s *Shortcut) { for _, fl := range s.Flags { desc := fl.Desc if len(fl.Enum) > 0 { @@ -874,11 +874,7 @@ func registerShortcutFlags(cmd *cobra.Command, s *Shortcut) { cmd.Flags().Bool("yes", false, "confirm high-risk operation") } cmd.Flags().StringP("jq", "q", "", "jq expression to filter JSON output") - cmd.Flags().String("as", s.AuthTypes[0], "identity type: user | bot") - - _ = cmd.RegisterFlagCompletionFunc("as", func(_ *cobra.Command, _ []string, _ string) ([]string, cobra.ShellCompDirective) { - return s.AuthTypes, cobra.ShellCompDirectiveNoFileComp - }) + cmdutil.AddShortcutIdentityFlag(cmd, f, s.AuthTypes) if s.HasFormat { _ = cmd.RegisterFlagCompletionFunc("format", func(_ *cobra.Command, _ []string, _ string) ([]string, cobra.ShellCompDirective) { return []string{"json", "pretty", "table", "ndjson", "csv"}, cobra.ShellCompDirectiveNoFileComp diff --git a/shortcuts/common/runner_identity_flag_test.go b/shortcuts/common/runner_identity_flag_test.go new file mode 100644 index 000000000..a6ed1020b --- /dev/null +++ b/shortcuts/common/runner_identity_flag_test.go @@ -0,0 +1,45 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package common + +import ( + "context" + "testing" + + "github.com/larksuite/cli/internal/cmdutil" + "github.com/larksuite/cli/internal/core" + "github.com/spf13/cobra" +) + +func TestShortcutMount_StrictModeHidesAsFlag(t *testing.T) { + f, _, _, _ := cmdutil.TestFactory(t, &core.CliConfig{ + AppID: "test-app", AppSecret: "test-secret", Brand: core.BrandFeishu, SupportedIdentities: 2, + }) + parent := &cobra.Command{Use: "root"} + shortcut := Shortcut{ + Service: "docs", + Command: "+fetch", + Description: "fetch doc", + AuthTypes: []string{"user", "bot"}, + Execute: func(context.Context, *RuntimeContext) error { + return nil + }, + } + + shortcut.Mount(parent, f) + cmd, _, err := parent.Find([]string{"+fetch"}) + if err != nil { + t.Fatalf("Find() error = %v", err) + } + flag := cmd.Flags().Lookup("as") + if flag == nil { + t.Fatal("expected --as flag to be registered") + } + if !flag.Hidden { + t.Fatal("expected --as flag to be hidden in strict mode") + } + if got := flag.DefValue; got != "bot" { + t.Fatalf("default value = %q, want %q", got, "bot") + } +} diff --git a/shortcuts/common/runner_jq_test.go b/shortcuts/common/runner_jq_test.go index 17af83dc7..949002672 100644 --- a/shortcuts/common/runner_jq_test.go +++ b/shortcuts/common/runner_jq_test.go @@ -145,10 +145,10 @@ func TestRuntimeContext_FileIO_UsesExecutionContext(t *testing.T) { } } -func newTestShortcutCmd(s *Shortcut) *cobra.Command { +func newTestShortcutCmd(s *Shortcut, f *cmdutil.Factory) *cobra.Command { cmd := &cobra.Command{Use: "test-shortcut"} cmd.SetContext(context.Background()) - registerShortcutFlags(cmd, s) + registerShortcutFlags(cmd, f, s) return cmd } @@ -177,7 +177,7 @@ func TestRunShortcut_JqAndFormatConflict(t *testing.T) { return nil }, } - cmd := newTestShortcutCmd(s) + cmd := newTestShortcutCmd(s, newTestFactory()) cmd.Flags().Set("jq", ".data") cmd.Flags().Set("format", "table") cmd.Flags().Set("as", "bot") @@ -200,7 +200,7 @@ func TestRunShortcut_JqInvalidExpression(t *testing.T) { return nil }, } - cmd := newTestShortcutCmd(s) + cmd := newTestShortcutCmd(s, newTestFactory()) cmd.Flags().Set("jq", "invalid[") cmd.Flags().Set("as", "bot") @@ -223,7 +223,7 @@ func TestRunShortcut_JqRuntimeError_PropagatesError(t *testing.T) { return nil }, } - cmd := newTestShortcutCmd(s) + cmd := newTestShortcutCmd(s, newTestFactory()) cmd.Flags().Set("jq", ".foo | invalid_func_xyz") cmd.Flags().Set("as", "bot") From bdd5b55569399794434bed1499c029ba098d0412 Mon Sep 17 00:00:00 2001 From: tuxedomm <273098272+tuxedomm@users.noreply.github.com> Date: Tue, 21 Apr 2026 14:00:18 +0800 Subject: [PATCH 4/7] fix(cmd): align strict-mode completion and build context; drop dead register shims MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Thread a context.Context through RegisterShortcuts, RegisterServiceCommands, and service.registerService/Resource/Method by introducing explicit *WithContext variants. Pass that context into NewCmdServiceMethodWithContext so shortcut and service command construction can honor cancellation and strict-mode pruning consistently. Also drop the context-less registerMethod and registerResource shims — they became unreachable once the WithContext variants took over, and were the source of new deadcode warnings. registerService is retained because service_test.go still calls it directly. Change-Id: I3fe5673aed663c7383bbbc5b0ae94d1f3491f22d --- cmd/api/api.go | 6 +- cmd/build.go | 6 +- cmd/schema/schema.go | 111 +++++++++++++------------ cmd/schema/schema_test.go | 46 ++++++++++ cmd/service/service.go | 26 ++++-- internal/cmdutil/identity_flag.go | 12 +-- internal/cmdutil/identity_flag_test.go | 7 +- shortcuts/common/runner.go | 16 +++- shortcuts/register.go | 8 +- 9 files changed, 161 insertions(+), 77 deletions(-) diff --git a/cmd/api/api.go b/cmd/api/api.go index e4f83b4d2..b6383707f 100644 --- a/cmd/api/api.go +++ b/cmd/api/api.go @@ -57,6 +57,10 @@ func normalisePath(raw string) string { // NewCmdApi creates the api command. If runF is non-nil it is called instead of apiRun (test hook). func NewCmdApi(f *cmdutil.Factory, runF func(*APIOptions) error) *cobra.Command { + return NewCmdApiWithContext(context.Background(), f, runF) +} + +func NewCmdApiWithContext(ctx context.Context, f *cmdutil.Factory, runF func(*APIOptions) error) *cobra.Command { opts := &APIOptions{Factory: f} var asStr string @@ -79,7 +83,7 @@ func NewCmdApi(f *cmdutil.Factory, runF func(*APIOptions) error) *cobra.Command cmd.Flags().StringVar(&opts.Params, "params", "", "query parameters JSON (supports - for stdin)") cmd.Flags().StringVar(&opts.Data, "data", "", "request body JSON (supports - for stdin)") - cmdutil.AddAPIIdentityFlag(cmd, f, &asStr) + cmdutil.AddAPIIdentityFlag(ctx, cmd, f, &asStr) cmd.Flags().StringVarP(&opts.Output, "output", "o", "", "output file path for binary responses") cmd.Flags().BoolVar(&opts.PageAll, "page-all", false, "automatically paginate through all pages") cmd.Flags().IntVar(&opts.PageSize, "page-size", 0, "page size (0 = use API default)") diff --git a/cmd/build.go b/cmd/build.go index 2a0de48f2..fcba27d62 100644 --- a/cmd/build.go +++ b/cmd/build.go @@ -95,12 +95,12 @@ func buildInternal(ctx context.Context, inv cmdutil.InvocationContext, opts ...B rootCmd.AddCommand(auth.NewCmdAuth(f)) rootCmd.AddCommand(profile.NewCmdProfile(f)) rootCmd.AddCommand(doctor.NewCmdDoctor(f)) - rootCmd.AddCommand(api.NewCmdApi(f, nil)) + rootCmd.AddCommand(api.NewCmdApiWithContext(ctx, f, nil)) rootCmd.AddCommand(schema.NewCmdSchema(f, nil)) rootCmd.AddCommand(completion.NewCmdCompletion(f)) rootCmd.AddCommand(cmdupdate.NewCmdUpdate(f)) - service.RegisterServiceCommands(rootCmd, f) - shortcuts.RegisterShortcuts(rootCmd, f) + service.RegisterServiceCommandsWithContext(ctx, rootCmd, f) + shortcuts.RegisterShortcutsWithContext(ctx, rootCmd, f) // Prune commands incompatible with strict mode. if mode := f.ResolveStrictMode(ctx); mode.IsActive() { diff --git a/cmd/schema/schema.go b/cmd/schema/schema.go index 6dca541f9..152f37d24 100644 --- a/cmd/schema/schema.go +++ b/cmd/schema/schema.go @@ -375,7 +375,7 @@ func NewCmdSchema(f *cmdutil.Factory, runF func(*SchemaOptions) error) *cobra.Co } cmdutil.DisableAuthCheck(cmd) - cmd.ValidArgsFunction = completeSchemaPath + cmd.ValidArgsFunction = completeSchemaPath(f) cmd.Flags().StringVar(&opts.Format, "format", "json", "output format: json (default) | pretty") _ = cmd.RegisterFlagCompletionFunc("format", func(_ *cobra.Command, _ []string, _ string) ([]string, cobra.ShellCompDirective) { return []string{"json", "pretty"}, cobra.ShellCompDirectiveNoFileComp @@ -387,74 +387,81 @@ func NewCmdSchema(f *cmdutil.Factory, runF func(*SchemaOptions) error) *cobra.Co // completeSchemaPath provides tab-completion for the schema path argument. // It handles dotted resource names (e.g. app.table.fields) by iterating all // resources and classifying each as a prefix-match or fully-matched. -func completeSchemaPath(_ *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { - if len(args) > 0 { - return nil, cobra.ShellCompDirectiveNoFileComp - } +func completeSchemaPath(f *cmdutil.Factory) func(*cobra.Command, []string, string) ([]string, cobra.ShellCompDirective) { + return func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { + if len(args) > 0 { + return nil, cobra.ShellCompDirectiveNoFileComp + } - parts := strings.Split(toComplete, ".") + parts := strings.Split(toComplete, ".") - // Level 1: complete service names - if len(parts) <= 1 { - var completions []string - for _, s := range registry.ListFromMetaProjects() { - if strings.HasPrefix(s, toComplete) { - completions = append(completions, s+".") + // Level 1: complete service names + if len(parts) <= 1 { + var completions []string + for _, s := range registry.ListFromMetaProjects() { + if strings.HasPrefix(s, toComplete) { + completions = append(completions, s+".") + } } + return completions, cobra.ShellCompDirectiveNoFileComp | cobra.ShellCompDirectiveNoSpace } - return completions, cobra.ShellCompDirectiveNoFileComp | cobra.ShellCompDirectiveNoSpace - } - serviceName := parts[0] - spec := registry.LoadFromMeta(serviceName) - if spec == nil { - return nil, cobra.ShellCompDirectiveNoFileComp - } - resources, _ := spec["resources"].(map[string]interface{}) - if resources == nil { - return nil, cobra.ShellCompDirectiveNoFileComp - } + serviceName := parts[0] + spec := registry.LoadFromMeta(serviceName) + if spec == nil { + return nil, cobra.ShellCompDirectiveNoFileComp + } + mode := f.ResolveStrictMode(cmd.Context()) + spec = filterSpecByStrictMode(spec, mode) + resources, _ := spec["resources"].(map[string]interface{}) + if resources == nil { + return nil, cobra.ShellCompDirectiveNoFileComp + } - // afterService = everything user typed after "serviceName." - afterService := strings.Join(parts[1:], ".") + afterService := strings.Join(parts[1:], ".") + completions := completeSchemaPathForSpec(serviceName, resources, afterService) + allTrailingDot := len(completions) > 0 + for _, c := range completions { + if !strings.HasSuffix(c, ".") { + allTrailingDot = false + break + } + } + directive := cobra.ShellCompDirectiveNoFileComp + if allTrailingDot { + directive |= cobra.ShellCompDirectiveNoSpace + } + return completions, directive + } +} + +func completeSchemaPathForSpec(serviceName string, resources map[string]interface{}, afterService string) []string { var completions []string for resName, resVal := range resources { if strings.HasPrefix(resName, afterService) { - // afterService is a prefix of this resource name → resource candidate completions = append(completions, serviceName+"."+resName+".") - } else if strings.HasPrefix(afterService, resName+".") { - // This resource is fully matched; remainder is method prefix - methodPrefix := afterService[len(resName)+1:] - resMap, _ := resVal.(map[string]interface{}) - if resMap == nil { - continue - } - methods, _ := resMap["methods"].(map[string]interface{}) - for methodName := range methods { - if strings.HasPrefix(methodName, methodPrefix) { - completions = append(completions, serviceName+"."+resName+"."+methodName) - } + continue + } + if !strings.HasPrefix(afterService, resName+".") { + continue + } + methodPrefix := afterService[len(resName)+1:] + resMap, _ := resVal.(map[string]interface{}) + if resMap == nil { + continue + } + methods, _ := resMap["methods"].(map[string]interface{}) + for methodName := range methods { + if strings.HasPrefix(methodName, methodPrefix) { + completions = append(completions, serviceName+"."+resName+"."+methodName) } } } sort.Strings(completions) - - // If all completions end with ".", user is still navigating resources → NoSpace - allTrailingDot := len(completions) > 0 - for _, c := range completions { - if !strings.HasSuffix(c, ".") { - allTrailingDot = false - break - } - } - directive := cobra.ShellCompDirectiveNoFileComp - if allTrailingDot { - directive |= cobra.ShellCompDirectiveNoSpace - } - return completions, directive + return completions } func schemaRun(opts *SchemaOptions) error { diff --git a/cmd/schema/schema_test.go b/cmd/schema/schema_test.go index 639822acc..da4129302 100644 --- a/cmd/schema/schema_test.go +++ b/cmd/schema/schema_test.go @@ -182,3 +182,49 @@ func TestHasFileFields(t *testing.T) { }) } } + +func TestCompleteSchemaPathForSpec(t *testing.T) { + resources := map[string]interface{}{ + "records": map[string]interface{}{ + "methods": map[string]interface{}{ + "create": map[string]interface{}{}, + "list": map[string]interface{}{}, + }, + }, + "record_permissions": map[string]interface{}{ + "methods": map[string]interface{}{ + "get": map[string]interface{}{}, + }, + }, + } + + got := completeSchemaPathForSpec("base", resources, "records.cr") + if len(got) != 1 || got[0] != "base.records.create" { + t.Fatalf("completions = %v, want [base.records.create]", got) + } + + got = completeSchemaPathForSpec("base", resources, "record") + if len(got) != 2 || got[0] != "base.record_permissions." || got[1] != "base.records." { + t.Fatalf("resource completions = %v", got) + } +} + +func TestFilterSpecByStrictMode_RemovesIncompatibleMethodsFromCompletionSource(t *testing.T) { + spec := map[string]interface{}{ + "resources": map[string]interface{}{ + "records": map[string]interface{}{ + "methods": map[string]interface{}{ + "list": map[string]interface{}{"accessTokens": []interface{}{"tenant"}}, + "create": map[string]interface{}{"accessTokens": []interface{}{"user"}}, + }, + }, + }, + } + + filtered := filterSpecByStrictMode(spec, core.StrictModeBot) + resources, _ := filtered["resources"].(map[string]interface{}) + got := completeSchemaPathForSpec("base", resources, "records.") + if len(got) != 1 || got[0] != "base.records.list" { + t.Fatalf("filtered completions = %v, want [base.records.list]", got) + } +} diff --git a/cmd/service/service.go b/cmd/service/service.go index f5d9ae1a6..808c80077 100644 --- a/cmd/service/service.go +++ b/cmd/service/service.go @@ -24,6 +24,10 @@ import ( // RegisterServiceCommands registers all service commands from from_meta specs. func RegisterServiceCommands(parent *cobra.Command, f *cmdutil.Factory) { + RegisterServiceCommandsWithContext(context.Background(), parent, f) +} + +func RegisterServiceCommandsWithContext(ctx context.Context, parent *cobra.Command, f *cmdutil.Factory) { for _, project := range registry.ListFromMetaProjects() { spec := registry.LoadFromMeta(project) if spec == nil { @@ -38,11 +42,15 @@ func RegisterServiceCommands(parent *cobra.Command, f *cmdutil.Factory) { if resources == nil { continue } - registerService(parent, spec, resources, f) + registerServiceWithContext(ctx, parent, spec, resources, f) } } func registerService(parent *cobra.Command, spec map[string]interface{}, resources map[string]interface{}, f *cmdutil.Factory) { + registerServiceWithContext(context.Background(), parent, spec, resources, f) +} + +func registerServiceWithContext(ctx context.Context, parent *cobra.Command, spec map[string]interface{}, resources map[string]interface{}, f *cmdutil.Factory) { specName := registry.GetStrFromMap(spec, "name") specDesc := registry.GetServiceDescription(specName, "en") if specDesc == "" { @@ -70,11 +78,11 @@ func registerService(parent *cobra.Command, spec map[string]interface{}, resourc if resMap == nil { continue } - registerResource(svc, spec, resName, resMap, f) + registerResourceWithContext(ctx, svc, spec, resName, resMap, f) } } -func registerResource(parent *cobra.Command, spec map[string]interface{}, name string, resource map[string]interface{}, f *cmdutil.Factory) { +func registerResourceWithContext(ctx context.Context, parent *cobra.Command, spec map[string]interface{}, name string, resource map[string]interface{}, f *cmdutil.Factory) { res := &cobra.Command{ Use: name, Short: name + " operations", @@ -87,7 +95,7 @@ func registerResource(parent *cobra.Command, spec map[string]interface{}, name s if methodMap == nil { continue } - registerMethod(res, spec, methodMap, methodName, name, f) + registerMethodWithContext(ctx, res, spec, methodMap, methodName, name, f) } } @@ -120,12 +128,16 @@ func detectFileFields(method map[string]interface{}) []string { return cmdutil.DetectFileFields(method) } -func registerMethod(parent *cobra.Command, spec map[string]interface{}, method map[string]interface{}, name string, resName string, f *cmdutil.Factory) { - parent.AddCommand(NewCmdServiceMethod(f, spec, method, name, resName, nil)) +func registerMethodWithContext(ctx context.Context, parent *cobra.Command, spec map[string]interface{}, method map[string]interface{}, name string, resName string, f *cmdutil.Factory) { + parent.AddCommand(NewCmdServiceMethodWithContext(ctx, f, spec, method, name, resName, nil)) } // NewCmdServiceMethod creates a command for a dynamically registered service method. func NewCmdServiceMethod(f *cmdutil.Factory, spec, method map[string]interface{}, name, resName string, runF func(*ServiceMethodOptions) error) *cobra.Command { + return NewCmdServiceMethodWithContext(context.Background(), f, spec, method, name, resName, runF) +} + +func NewCmdServiceMethodWithContext(ctx context.Context, f *cmdutil.Factory, spec, method map[string]interface{}, name, resName string, runF func(*ServiceMethodOptions) error) *cobra.Command { desc := registry.GetStrFromMap(method, "description") httpMethod := registry.GetStrFromMap(method, "httpMethod") specName := registry.GetStrFromMap(spec, "name") @@ -159,7 +171,7 @@ func NewCmdServiceMethod(f *cmdutil.Factory, spec, method map[string]interface{} case "POST", "PUT", "PATCH", "DELETE": cmd.Flags().StringVar(&opts.Data, "data", "", "request body JSON (supports - for stdin)") } - cmdutil.AddAPIIdentityFlag(cmd, f, &asStr) + cmdutil.AddAPIIdentityFlag(ctx, cmd, f, &asStr) cmd.Flags().StringVarP(&opts.Output, "output", "o", "", "output file path for binary responses") cmd.Flags().BoolVar(&opts.PageAll, "page-all", false, "automatically paginate through all pages") cmd.Flags().IntVar(&opts.PageLimit, "page-limit", 10, "max pages to fetch with --page-all (0 = unlimited)") diff --git a/internal/cmdutil/identity_flag.go b/internal/cmdutil/identity_flag.go index 99b7ed3df..c99d5c628 100644 --- a/internal/cmdutil/identity_flag.go +++ b/internal/cmdutil/identity_flag.go @@ -12,8 +12,8 @@ import ( ) // AddAPIIdentityFlag registers the standard --as flag shape used by api/service commands. -func AddAPIIdentityFlag(cmd *cobra.Command, f *Factory, target *string) { - addIdentityFlag(cmd, f, target, identityFlagConfig{ +func AddAPIIdentityFlag(ctx context.Context, cmd *cobra.Command, f *Factory, target *string) { + addIdentityFlag(ctx, cmd, f, target, identityFlagConfig{ defaultValue: "auto", usage: "identity type: user | bot | auto (default)", completionValues: []string{"user", "bot"}, @@ -21,11 +21,11 @@ func AddAPIIdentityFlag(cmd *cobra.Command, f *Factory, target *string) { } // AddShortcutIdentityFlag registers the standard --as flag shape used by shortcuts. -func AddShortcutIdentityFlag(cmd *cobra.Command, f *Factory, authTypes []string) { +func AddShortcutIdentityFlag(ctx context.Context, cmd *cobra.Command, f *Factory, authTypes []string) { if len(authTypes) == 0 { authTypes = []string{"user"} } - addIdentityFlag(cmd, f, nil, identityFlagConfig{ + addIdentityFlag(ctx, cmd, f, nil, identityFlagConfig{ defaultValue: authTypes[0], usage: "identity type: " + strings.Join(authTypes, " | "), completionValues: authTypes, @@ -41,8 +41,8 @@ type identityFlagConfig struct { // addIdentityFlag centralizes --as registration and strict-mode UX. // When strict mode is active, the flag is still accepted for compatibility // but hidden from help/completion and locked to the forced identity by default. -func addIdentityFlag(cmd *cobra.Command, f *Factory, target *string, cfg identityFlagConfig) { - if forced := f.ResolveStrictMode(context.Background()).ForcedIdentity(); forced != "" { +func addIdentityFlag(ctx context.Context, cmd *cobra.Command, f *Factory, target *string, cfg identityFlagConfig) { + if forced := f.ResolveStrictMode(ctx).ForcedIdentity(); forced != "" { // Keep registering --as in strict mode even though it is hidden. // This preserves parser compatibility for existing invocations that still pass // --as, and keeps downstream GetString("as") / ResolveAs paths stable. diff --git a/internal/cmdutil/identity_flag_test.go b/internal/cmdutil/identity_flag_test.go index 2f1350752..fa93d7263 100644 --- a/internal/cmdutil/identity_flag_test.go +++ b/internal/cmdutil/identity_flag_test.go @@ -4,6 +4,7 @@ package cmdutil import ( + "context" "testing" "github.com/larksuite/cli/internal/core" @@ -14,7 +15,7 @@ func TestAddAPIIdentityFlag_NonStrictMode(t *testing.T) { f, _, _, _ := TestFactory(t, &core.CliConfig{AppID: "a", AppSecret: "s"}) cmd := &cobra.Command{Use: "test"} - AddAPIIdentityFlag(cmd, f, nil) + AddAPIIdentityFlag(context.Background(), cmd, f, nil) flag := cmd.Flags().Lookup("as") if flag == nil { @@ -34,7 +35,7 @@ func TestAddAPIIdentityFlag_StrictModeHidesFlagAndLocksDefault(t *testing.T) { }) cmd := &cobra.Command{Use: "test"} - AddAPIIdentityFlag(cmd, f, nil) + AddAPIIdentityFlag(context.Background(), cmd, f, nil) flag := cmd.Flags().Lookup("as") if flag == nil { @@ -52,7 +53,7 @@ func TestAddShortcutIdentityFlag_UsesAuthTypes(t *testing.T) { f, _, _, _ := TestFactory(t, &core.CliConfig{AppID: "a", AppSecret: "s"}) cmd := &cobra.Command{Use: "test"} - AddShortcutIdentityFlag(cmd, f, []string{"bot"}) + AddShortcutIdentityFlag(context.Background(), cmd, f, []string{"bot"}) flag := cmd.Flags().Lookup("as") if flag == nil { diff --git a/shortcuts/common/runner.go b/shortcuts/common/runner.go index 46988902b..6a0009b9c 100644 --- a/shortcuts/common/runner.go +++ b/shortcuts/common/runner.go @@ -571,12 +571,16 @@ func enhancePermissionError(err error, requiredScopes []string) error { // Mount registers the shortcut on a parent command. func (s Shortcut) Mount(parent *cobra.Command, f *cmdutil.Factory) { + s.MountWithContext(context.Background(), parent, f) +} + +func (s Shortcut) MountWithContext(ctx context.Context, parent *cobra.Command, f *cmdutil.Factory) { if s.Execute != nil { - s.mountDeclarative(parent, f) + s.mountDeclarative(ctx, parent, f) } } -func (s Shortcut) mountDeclarative(parent *cobra.Command, f *cmdutil.Factory) { +func (s Shortcut) mountDeclarative(ctx context.Context, parent *cobra.Command, f *cmdutil.Factory) { shortcut := s if len(shortcut.AuthTypes) == 0 { shortcut.AuthTypes = []string{"user"} @@ -592,7 +596,7 @@ func (s Shortcut) mountDeclarative(parent *cobra.Command, f *cmdutil.Factory) { }, } cmdutil.SetSupportedIdentities(cmd, shortcut.AuthTypes) - registerShortcutFlags(cmd, f, &shortcut) + registerShortcutFlagsWithContext(ctx, cmd, f, &shortcut) cmdutil.SetTips(cmd, shortcut.Tips) parent.AddCommand(cmd) } @@ -824,6 +828,10 @@ func rejectPositionalArgs() cobra.PositionalArgs { } func registerShortcutFlags(cmd *cobra.Command, f *cmdutil.Factory, s *Shortcut) { + registerShortcutFlagsWithContext(context.Background(), cmd, f, s) +} + +func registerShortcutFlagsWithContext(ctx context.Context, cmd *cobra.Command, f *cmdutil.Factory, s *Shortcut) { for _, fl := range s.Flags { desc := fl.Desc if len(fl.Enum) > 0 { @@ -874,7 +882,7 @@ func registerShortcutFlags(cmd *cobra.Command, f *cmdutil.Factory, s *Shortcut) cmd.Flags().Bool("yes", false, "confirm high-risk operation") } cmd.Flags().StringP("jq", "q", "", "jq expression to filter JSON output") - cmdutil.AddShortcutIdentityFlag(cmd, f, s.AuthTypes) + cmdutil.AddShortcutIdentityFlag(ctx, cmd, f, s.AuthTypes) if s.HasFormat { _ = cmd.RegisterFlagCompletionFunc("format", func(_ *cobra.Command, _ []string, _ string) ([]string, cobra.ShellCompDirective) { return []string{"json", "pretty", "table", "ndjson", "csv"}, cobra.ShellCompDirectiveNoFileComp diff --git a/shortcuts/register.go b/shortcuts/register.go index f5b085689..534163ec2 100644 --- a/shortcuts/register.go +++ b/shortcuts/register.go @@ -4,6 +4,8 @@ package shortcuts import ( + "context" + "github.com/larksuite/cli/shortcuts/okr" "github.com/spf13/cobra" @@ -58,6 +60,10 @@ func AllShortcuts() []common.Shortcut { // RegisterShortcuts registers all +shortcut commands on the program. func RegisterShortcuts(program *cobra.Command, f *cmdutil.Factory) { + RegisterShortcutsWithContext(context.Background(), program, f) +} + +func RegisterShortcutsWithContext(ctx context.Context, program *cobra.Command, f *cmdutil.Factory) { // Group by service byService := make(map[string][]common.Shortcut) for _, s := range allShortcuts { @@ -86,7 +92,7 @@ func RegisterShortcuts(program *cobra.Command, f *cmdutil.Factory) { } for _, shortcut := range shortcuts { - shortcut.Mount(svc, f) + shortcut.MountWithContext(ctx, svc, f) } } } From a8455dbc03a81eecdd259f3287b3bfe9f7427215 Mon Sep 17 00:00:00 2001 From: tuxedomm <273098272+tuxedomm@users.noreply.github.com> Date: Mon, 13 Apr 2026 22:53:50 +0800 Subject: [PATCH 5/7] refactor(cmd): hide --profile in single-app mode via build option MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - GlobalOptions gains HideProfile; RegisterGlobalFlags stays pure and reads the policy off the struct. No boolean-trap parameter, one call per site. - buildConfig holds GlobalOptions inline so HideProfile(bool) BuildOption mutates it directly. buildInternal stays a pure assembly function and requires callers to supply WithIO — no implicit os.Std* fallback. - Add WithIO BuildOption (wrapping raw io.Reader/Writer with automatic *os.File TTY detection); Execute injects streams explicitly and decides profile visibility via HideProfile(isSingleAppMode()). - installTipsHelpFunc force-shows hidden root flags while rendering the root command's own help, so single-app users still discover --profile via lark-cli --help without it polluting subcommand helps. Change-Id: I7755387e993992ca969e0a4a6f54441cc1993eef --- cmd/build.go | 29 ++++++++--- cmd/global_flags.go | 31 +++++++++-- cmd/global_flags_test.go | 110 +++++++++++++++++++++++++++++++++++++++ cmd/root.go | 17 +++++- 4 files changed, 174 insertions(+), 13 deletions(-) create mode 100644 cmd/global_flags_test.go diff --git a/cmd/build.go b/cmd/build.go index fcba27d62..92443f7a8 100644 --- a/cmd/build.go +++ b/cmd/build.go @@ -29,6 +29,7 @@ type BuildOption func(*buildConfig) type buildConfig struct { streams *cmdutil.IOStreams keychain keychain.KeychainAccess + globals GlobalOptions } // WithIO sets the IO streams for the CLI by wrapping raw reader/writers. @@ -46,6 +47,16 @@ func WithKeychain(kc keychain.KeychainAccess) BuildOption { } } +// HideProfile sets the visibility policy for the root-level --profile flag. +// When hide is true the flag stays registered (so existing invocations still +// parse) but is omitted from help and shell completion. Typically called as +// HideProfile(isSingleAppMode()). +func HideProfile(hide bool) BuildOption { + return func(c *buildConfig) { + c.globals.HideProfile = hide + } +} + // Build constructs the full command tree without executing. // Returns only the cobra.Command; Factory is internal. // Use Execute for the standard production entry point. @@ -54,11 +65,17 @@ func Build(ctx context.Context, inv cmdutil.InvocationContext, opts ...BuildOpti return rootCmd } -// buildInternal is the internal constructor that also returns Factory for error handling. +// buildInternal is a pure assembly function: it wires the command tree from +// inv and BuildOptions alone. Any state-dependent decision (disk, network, +// env) belongs in the caller and must be threaded in via BuildOption. +// +// Callers must supply WithIO; buildInternal intentionally does not default +// the streams so tests and alternative entry points can't silently inherit +// os.Std*. func buildInternal(ctx context.Context, inv cmdutil.InvocationContext, opts ...BuildOption) (*cmdutil.Factory, *cobra.Command) { - cfg := &buildConfig{ - streams: cmdutil.SystemIO(), - } + // cfg.globals.Profile is left zero here; it's bound to the --profile + // flag in RegisterGlobalFlags and filled by cobra's parse step. + cfg := &buildConfig{} for _, o := range opts { if o != nil { o(cfg) @@ -69,8 +86,6 @@ func buildInternal(ctx context.Context, inv cmdutil.InvocationContext, opts ...B if cfg.keychain != nil { f.Keychain = cfg.keychain } - - globals := &GlobalOptions{Profile: inv.Profile} rootCmd := &cobra.Command{ Use: "lark-cli", Short: "Lark/Feishu CLI — OAuth authorization, UAT management, API calls", @@ -86,7 +101,7 @@ func buildInternal(ctx context.Context, inv cmdutil.InvocationContext, opts ...B installTipsHelpFunc(rootCmd) rootCmd.SilenceErrors = true - RegisterGlobalFlags(rootCmd.PersistentFlags(), globals) + RegisterGlobalFlags(rootCmd.PersistentFlags(), &cfg.globals) rootCmd.PersistentPreRun = func(cmd *cobra.Command, args []string) { cmd.SilenceUsage = true } diff --git a/cmd/global_flags.go b/cmd/global_flags.go index d634cc4fd..b77e8f189 100644 --- a/cmd/global_flags.go +++ b/cmd/global_flags.go @@ -3,15 +3,38 @@ package cmd -import "github.com/spf13/pflag" +import ( + "github.com/larksuite/cli/internal/core" + "github.com/spf13/pflag" +) // GlobalOptions are the root-level flags shared by bootstrap parsing and the -// actual Cobra command tree. +// actual Cobra command tree. Profile is the parsed --profile value; HideProfile +// is a build-time policy — when true, --profile stays parseable but is marked +// hidden from help and shell completion. type GlobalOptions struct { - Profile string + Profile string + HideProfile bool } -// RegisterGlobalFlags registers the root-level persistent flags. +// RegisterGlobalFlags registers the root-level persistent flags on fs and +// applies any visibility policy encoded in opts. Pure function: no disk, +// network, or environment reads — the caller decides HideProfile. func RegisterGlobalFlags(fs *pflag.FlagSet, opts *GlobalOptions) { fs.StringVar(&opts.Profile, "profile", "", "use a specific profile") + if opts.HideProfile { + _ = fs.MarkHidden("profile") + } +} + +// isSingleAppMode reports whether the on-disk config has at most one app. +// Missing configs are treated as single-app since --profile is meaningless +// until at least two profiles exist. Intended for the Execute entry point — +// buildInternal must not call this directly to stay state-free. +func isSingleAppMode() bool { + raw, err := core.LoadMultiAppConfig() + if err != nil || raw == nil { + return true + } + return len(raw.Apps) <= 1 } diff --git a/cmd/global_flags_test.go b/cmd/global_flags_test.go new file mode 100644 index 000000000..c24d1573a --- /dev/null +++ b/cmd/global_flags_test.go @@ -0,0 +1,110 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package cmd + +import ( + "context" + "os" + "testing" + + "github.com/larksuite/cli/internal/cmdutil" + "github.com/larksuite/cli/internal/core" + "github.com/spf13/pflag" +) + +func testStreams() BuildOption { return WithIO(os.Stdin, os.Stdout, os.Stderr) } + +func TestRegisterGlobalFlags_PolicyVisible(t *testing.T) { + fs := pflag.NewFlagSet("test", pflag.ContinueOnError) + opts := &GlobalOptions{} + RegisterGlobalFlags(fs, opts) + + flag := fs.Lookup("profile") + if flag == nil { + t.Fatal("profile flag should be registered") + } + if flag.Hidden { + t.Fatal("profile flag should be visible when HideProfile is false") + } +} + +func TestRegisterGlobalFlags_PolicyHidden(t *testing.T) { + fs := pflag.NewFlagSet("test", pflag.ContinueOnError) + opts := &GlobalOptions{HideProfile: true} + RegisterGlobalFlags(fs, opts) + + flag := fs.Lookup("profile") + if flag == nil { + t.Fatal("profile flag should be registered") + } + if !flag.Hidden { + t.Fatal("profile flag should be hidden when HideProfile is true") + } + if err := fs.Parse([]string{"--profile", "x"}); err != nil { + t.Fatalf("Parse() error = %v; hidden flag should still parse", err) + } + if opts.Profile != "x" { + t.Fatalf("opts.Profile = %q, want %q", opts.Profile, "x") + } +} + +func TestIsSingleAppMode_NoConfig(t *testing.T) { + t.Setenv("LARKSUITE_CLI_CONFIG_DIR", t.TempDir()) + if !isSingleAppMode() { + t.Fatal("isSingleAppMode() = false, want true when no config exists") + } +} + +func TestIsSingleAppMode_SingleApp(t *testing.T) { + t.Setenv("LARKSUITE_CLI_CONFIG_DIR", t.TempDir()) + saveAppsForTest(t, []core.AppConfig{ + {Name: "default", AppId: "cli_a", AppSecret: core.PlainSecret("x"), Brand: core.BrandFeishu}, + }) + if !isSingleAppMode() { + t.Fatal("isSingleAppMode() = false, want true for single-app config") + } +} + +func TestIsSingleAppMode_MultiApp(t *testing.T) { + t.Setenv("LARKSUITE_CLI_CONFIG_DIR", t.TempDir()) + saveAppsForTest(t, []core.AppConfig{ + {Name: "a", AppId: "cli_a", AppSecret: core.PlainSecret("x"), Brand: core.BrandFeishu}, + {Name: "b", AppId: "cli_b", AppSecret: core.PlainSecret("y"), Brand: core.BrandFeishu}, + }) + if isSingleAppMode() { + t.Fatal("isSingleAppMode() = true, want false for multi-app config") + } +} + +func TestBuildInternal_HideProfileOption(t *testing.T) { + _, root := buildInternal(context.Background(), cmdutil.InvocationContext{}, testStreams(), HideProfile(true)) + + flag := root.PersistentFlags().Lookup("profile") + if flag == nil { + t.Fatal("profile flag should be registered") + } + if !flag.Hidden { + t.Fatal("profile flag should be hidden when HideProfile(true) is applied") + } +} + +func TestBuildInternal_DefaultShowsProfileFlag(t *testing.T) { + _, root := buildInternal(context.Background(), cmdutil.InvocationContext{}, testStreams()) + + flag := root.PersistentFlags().Lookup("profile") + if flag == nil { + t.Fatal("profile flag should be registered by default") + } + if flag.Hidden { + t.Fatal("profile flag should be visible by default") + } +} + +func saveAppsForTest(t *testing.T, apps []core.AppConfig) { + t.Helper() + multi := &core.MultiAppConfig{CurrentApp: apps[0].Name, Apps: apps} + if err := core.SaveMultiAppConfig(multi); err != nil { + t.Fatalf("SaveMultiAppConfig() error = %v", err) + } +} diff --git a/cmd/root.go b/cmd/root.go index 8088346a9..57c91d8b1 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -85,7 +85,11 @@ func Execute() int { fmt.Fprintln(os.Stderr, "Error:", err) return 1 } - f, rootCmd := buildInternal(context.Background(), inv) + f, rootCmd := buildInternal( + context.Background(), inv, + WithIO(os.Stdin, os.Stdout, os.Stderr), + HideProfile(isSingleAppMode()), + ) // --- Update check (non-blocking) --- if !isCompletionCommand(os.Args) { @@ -236,10 +240,19 @@ func writeSecurityPolicyError(w io.Writer, spErr *internalauth.SecurityPolicyErr } // installTipsHelpFunc wraps the default help function to append a TIPS section -// when a command has tips set via cmdutil.SetTips. +// when a command has tips set via cmdutil.SetTips. It also force-shows global +// flags that are normally hidden in single-app mode (currently --profile) +// when rendering the root command's own help, so users discovering the CLI +// still see them at `lark-cli --help`. func installTipsHelpFunc(root *cobra.Command) { defaultHelp := root.HelpFunc() root.SetHelpFunc(func(cmd *cobra.Command, args []string) { + if cmd == root { + if f := root.PersistentFlags().Lookup("profile"); f != nil && f.Hidden { + f.Hidden = false + defer func() { f.Hidden = true }() + } + } defaultHelp(cmd, args) tips := cmdutil.GetTips(cmd) if len(tips) == 0 { From fb1bc4d15e196c974e85e504db639d233f27cef5 Mon Sep 17 00:00:00 2001 From: tuxedomm <273098272+tuxedomm@users.noreply.github.com> Date: Tue, 21 Apr 2026 14:01:09 +0800 Subject: [PATCH 6/7] feat(transport): extension abort hook and shared base transport MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two transport-layer changes bundled because both reshape the base round-tripper contract used by the HTTP client, the Lark SDK client, and the in-process updater. 1. Extension abort hook (PreRoundTripE). Extensions implementing exttransport.AbortableInterceptor can now return an error from PreRoundTripE to skip the built-in chain. The post hook still fires with (nil, reason) so extensions can unwind resources. extensionMiddleware captures the provider name so the returned *AbortError carries attribution. 2. Shared base transport to stop RPC leak. util.NewBaseTransport cloned http.DefaultTransport on every call, so each cmdutil.Factory produced a fresh *http.Transport whose persistConn readLoop/writeLoop goroutines lingered until IdleConnTimeout (~90s). Invisible in a single-process CLI, but the fork is consumed by cli-server where each RPC request constructs a new Factory, causing linear memory + goroutine growth under load. Replace NewBaseTransport with SharedTransport — returns http.DefaultTransport (the stdlib-wide singleton) by default, and a cached proxy-disabled clone only when LARK_CLI_NO_PROXY is set. Return type is http.RoundTripper to discourage in-place mutation of the shared instance. FallbackTransport is kept as a thin *http.Transport wrapper so existing callers in internal/auth and internal/cmdutil transport decorators (which were already on the singleton path) do not have to migrate. Leak-site migrations: factory_default.go (HTTP + SDK base) and update.go now call SharedTransport directly. Change-Id: Ia82462134c5c5ee838be878b887860f41446a235 --- extension/transport/errors.go | 51 +++ extension/transport/errors_test.go | 103 ++++++ extension/transport/types.go | 25 ++ internal/cmdutil/factory_default.go | 4 +- internal/cmdutil/factory_default_test.go | 187 ----------- internal/cmdutil/retry_transport_test.go | 81 ----- internal/cmdutil/transport.go | 45 ++- internal/cmdutil/transport_test.go | 408 +++++++++++++++++++++++ internal/update/update.go | 2 +- internal/util/proxy.go | 50 ++- internal/util/proxy_test.go | 94 +++--- 11 files changed, 713 insertions(+), 337 deletions(-) create mode 100644 extension/transport/errors.go create mode 100644 extension/transport/errors_test.go delete mode 100644 internal/cmdutil/retry_transport_test.go create mode 100644 internal/cmdutil/transport_test.go diff --git a/extension/transport/errors.go b/extension/transport/errors.go new file mode 100644 index 000000000..9ebe907e7 --- /dev/null +++ b/extension/transport/errors.go @@ -0,0 +1,51 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package transport + +import ( + "errors" + "fmt" +) + +// ErrAborted is a sentinel matched by errors.Is on any extension-triggered +// round-trip abort. Callers that only need to know whether an error was +// caused by an extension interception should use: +// +// if errors.Is(err, transport.ErrAborted) { ... } +var ErrAborted = errors.New("round trip aborted by extension") + +// AbortError is returned by the built-in middleware when an AbortableInterceptor +// short-circuits a request via PreRoundTripE. It wraps the extension's original +// reason and carries the extension's Provider.Name() for traceability. +// +// Use errors.As to recover the typed error: +// +// var aErr *transport.AbortError +// if errors.As(err, &aErr) { +// log.Printf("blocked by %s: %v", aErr.Extension, aErr.Reason) +// } +// +// errors.Is(err, transport.ErrAborted) also works, and errors.Is against the +// inner reason still works via Unwrap. +type AbortError struct { + // Extension is the name of the Provider whose interceptor aborted the + // request (from Provider.Name()). May be empty if the provider did not + // supply a name. + Extension string + // Reason is the original non-nil error returned by PreRoundTripE. + Reason error +} + +func (e *AbortError) Error() string { + if e.Extension != "" { + return fmt.Sprintf("extension %q aborted round trip: %v", e.Extension, e.Reason) + } + return fmt.Sprintf("extension aborted round trip: %v", e.Reason) +} + +// Unwrap lets errors.Is / errors.As traverse to the underlying Reason. +func (e *AbortError) Unwrap() error { return e.Reason } + +// Is enables errors.Is(err, ErrAborted) at any nesting depth. +func (e *AbortError) Is(target error) bool { return target == ErrAborted } diff --git a/extension/transport/errors_test.go b/extension/transport/errors_test.go new file mode 100644 index 000000000..31932e45f --- /dev/null +++ b/extension/transport/errors_test.go @@ -0,0 +1,103 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package transport + +import ( + "errors" + "fmt" + "testing" +) + +func TestAbortError_Error(t *testing.T) { + tests := []struct { + name string + err *AbortError + want string + }{ + { + name: "with extension name", + err: &AbortError{Extension: "audit", Reason: errors.New("bad")}, + want: `extension "audit" aborted round trip: bad`, + }, + { + name: "without extension name", + err: &AbortError{Reason: errors.New("bad")}, + want: "extension aborted round trip: bad", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.err.Error(); got != tt.want { + t.Fatalf("Error() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestAbortError_Unwrap(t *testing.T) { + reason := errors.New("bad") + e := &AbortError{Reason: reason} + if got := e.Unwrap(); got != reason { + t.Fatalf("Unwrap() = %v, want %v", got, reason) + } +} + +func TestAbortError_IsErrAborted(t *testing.T) { + e := &AbortError{Reason: errors.New("bad")} + if !errors.Is(e, ErrAborted) { + t.Fatal("errors.Is(e, ErrAborted) = false, want true") + } + // Sanity: not matched by unrelated sentinels. + if errors.Is(e, errors.New("other")) { + t.Fatal("errors.Is matched unrelated sentinel") + } +} + +func TestAbortError_UnwrapReachesInnerSentinel(t *testing.T) { + // Extensions often return typed/sentinel errors; callers should still be + // able to errors.Is against those after the middleware wraps them. + innerSentinel := errors.New("policy-deny-42") + e := &AbortError{Reason: fmt.Errorf("wrapped: %w", innerSentinel)} + if !errors.Is(e, innerSentinel) { + t.Fatal("errors.Is(e, innerSentinel) = false, want true (Unwrap chain broken)") + } +} + +func TestAbortError_As(t *testing.T) { + reason := errors.New("bad") + base := &AbortError{Extension: "audit", Reason: reason} + + // Direct As. + var aErr *AbortError + if !errors.As(base, &aErr) { + t.Fatal("errors.As(base, *AbortError) = false") + } + if aErr.Extension != "audit" || aErr.Reason != reason { + t.Fatalf("aErr = %+v, want {audit, bad}", aErr) + } + + // Nested As: even when the *AbortError is wrapped in another error, + // errors.As must still find it via Unwrap chain. + wrapped := fmt.Errorf("outer: %w", base) + var aErr2 *AbortError + if !errors.As(wrapped, &aErr2) { + t.Fatal("errors.As(wrapped, *AbortError) = false") + } + if aErr2 != base { + t.Fatalf("aErr2 = %p, want %p", aErr2, base) + } + + // errors.Is still matches the sentinel through the outer wrapper. + if !errors.Is(wrapped, ErrAborted) { + t.Fatal("errors.Is(wrapped, ErrAborted) = false via nested wrap") + } +} + +func TestErrAborted_IsItselfSentinel(t *testing.T) { + // Guard against accidental re-assignment of ErrAborted: a bare ErrAborted + // value should still satisfy errors.Is(err, ErrAborted) for symmetry. + if !errors.Is(ErrAborted, ErrAborted) { + t.Fatal("errors.Is(ErrAborted, ErrAborted) = false") + } +} diff --git a/extension/transport/types.go b/extension/transport/types.go index e60c4018d..c74e36866 100644 --- a/extension/transport/types.go +++ b/extension/transport/types.go @@ -27,6 +27,31 @@ type Provider interface { // // The returned function (if non-nil) is called after the built-in chain // completes. Use it for logging, ending trace spans, or recording metrics. +// +// Body note: the middleware Clones the caller's request before invoking the +// interceptor, which copies headers/URL/etc. but shares the underlying +// io.ReadCloser. Extensions that read req.Body are responsible for restoring +// a replayable body (e.g. via req.GetBody) before returning, otherwise the +// built-in chain will see an exhausted stream. type Interceptor interface { PreRoundTrip(req *http.Request) func(resp *http.Response, err error) } + +// AbortableInterceptor is an optional extension of Interceptor that lets an +// extension reject a request before the built-in chain runs. Extensions that +// implement this interface are detected by the built-in middleware via a +// type assertion; both methods must be present, but when an extension +// implements PreRoundTripE the middleware will NOT call PreRoundTrip. +// +// Returning a non-nil error from PreRoundTripE aborts the request: the +// built-in chain is not executed and the middleware returns an *AbortError +// wrapping the reason. The returned post function (if non-nil) is still +// invoked with (nil, reason) so that extensions can unwind any state they +// created in the pre hook (spans, metrics, audit records). +// +// Extensions that only care about the abortable variant can provide a no-op +// PreRoundTrip method alongside PreRoundTripE to satisfy Interceptor. +type AbortableInterceptor interface { + Interceptor + PreRoundTripE(req *http.Request) (post func(resp *http.Response, err error), err error) +} diff --git a/internal/cmdutil/factory_default.go b/internal/cmdutil/factory_default.go index eed433454..c1dc10817 100644 --- a/internal/cmdutil/factory_default.go +++ b/internal/cmdutil/factory_default.go @@ -92,7 +92,7 @@ func cachedHttpClientFunc(f *Factory) func() (*http.Client, error) { return sync.OnceValues(func() (*http.Client, error) { util.WarnIfProxied(f.IOStreams.ErrOut) - var transport http.RoundTripper = util.NewBaseTransport() + var transport http.RoundTripper = util.SharedTransport() transport = &RetryTransport{Base: transport} transport = &SecurityHeaderTransport{Base: transport} transport = &auth.SecurityPolicyTransport{Base: transport} // Add our global response interceptor @@ -129,7 +129,7 @@ func cachedLarkClientFunc(f *Factory) func() (*lark.Client, error) { } func buildSDKTransport() http.RoundTripper { - var sdkTransport http.RoundTripper = util.NewBaseTransport() + var sdkTransport http.RoundTripper = util.SharedTransport() sdkTransport = &RetryTransport{Base: sdkTransport} sdkTransport = &UserAgentTransport{Base: sdkTransport} sdkTransport = &auth.SecurityPolicyTransport{Base: sdkTransport} diff --git a/internal/cmdutil/factory_default_test.go b/internal/cmdutil/factory_default_test.go index 7204e2de9..9ec8e7ec3 100644 --- a/internal/cmdutil/factory_default_test.go +++ b/internal/cmdutil/factory_default_test.go @@ -6,14 +6,10 @@ package cmdutil import ( "context" "errors" - "net/http" - "net/http/httptest" "testing" _ "github.com/larksuite/cli/extension/credential/env" "github.com/larksuite/cli/extension/fileio" - exttransport "github.com/larksuite/cli/extension/transport" - internalauth "github.com/larksuite/cli/internal/auth" "github.com/larksuite/cli/internal/core" "github.com/larksuite/cli/internal/credential" "github.com/larksuite/cli/internal/envvars" @@ -120,22 +116,6 @@ func TestNewDefault_InvocationProfileMissingSticksAcrossEarlyStrictMode(t *testi } } -func TestBuildSDKTransport_IncludesRetryTransport(t *testing.T) { - transport := buildSDKTransport() - - sec, ok := transport.(*internalauth.SecurityPolicyTransport) - if !ok { - t.Fatalf("outer transport type = %T, want *auth.SecurityPolicyTransport", transport) - } - ua, ok := sec.Base.(*UserAgentTransport) - if !ok { - t.Fatalf("middle transport type = %T, want *UserAgentTransport", sec.Base) - } - if _, ok := ua.Base.(*RetryTransport); !ok { - t.Fatalf("inner transport type = %T, want *RetryTransport", ua.Base) - } -} - func TestNewDefault_ResolveAs_UsesDefaultAsFromEnvAccount(t *testing.T) { t.Setenv(envvars.CliAppID, "env-app") t.Setenv(envvars.CliAppSecret, "env-secret") @@ -232,170 +212,3 @@ func TestNewDefault_FileIOProviderDoesNotResolveDuringInitialization(t *testing. t.Fatalf("ResolveFileIO() calls after explicit resolve = %d, want 1", provider.resolveCalls) } } - -type stubTransportProvider struct { - interceptor exttransport.Interceptor -} - -func (s *stubTransportProvider) Name() string { return "stub" } -func (s *stubTransportProvider) ResolveInterceptor(context.Context) exttransport.Interceptor { - if s.interceptor != nil { - return s.interceptor - } - return &stubTransportImpl{} -} - -type stubTransportImpl struct{} - -func (s *stubTransportImpl) PreRoundTrip(req *http.Request) func(*http.Response, error) { - return nil -} - -// headerCapturingInterceptor sets custom headers in PreRoundTrip and records -// whether PostRoundTrip was called, to verify execution order. -type headerCapturingInterceptor struct { - preCalled bool - postCalled bool -} - -func (h *headerCapturingInterceptor) PreRoundTrip(req *http.Request) func(*http.Response, error) { - h.preCalled = true - // Set a custom header that should survive (no built-in override) - req.Header.Set("X-Custom-Trace", "ext-trace-123") - // Try to override a security header — should be overwritten by SecurityHeaderTransport - req.Header.Set(HeaderSource, "ext-tampered") - return func(resp *http.Response, err error) { - h.postCalled = true - } -} - -func TestExtensionInterceptor_ExecutionOrder(t *testing.T) { - var receivedHeaders http.Header - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - receivedHeaders = r.Header.Clone() - w.WriteHeader(http.StatusOK) - })) - defer srv.Close() - - ic := &headerCapturingInterceptor{} - exttransport.Register(&stubTransportProvider{interceptor: ic}) - t.Cleanup(func() { exttransport.Register(nil) }) - - // Use HTTP transport chain (has SecurityHeaderTransport) - var base http.RoundTripper = http.DefaultTransport - base = &RetryTransport{Base: base} - base = &SecurityHeaderTransport{Base: base} - transport := wrapWithExtension(base) - client := &http.Client{Transport: transport} - - req, _ := http.NewRequest("GET", srv.URL, nil) - resp, err := client.Do(req) - if err != nil { - t.Fatalf("request failed: %v", err) - } - resp.Body.Close() - - // PreRoundTrip was called - if !ic.preCalled { - t.Fatal("PreRoundTrip was not called") - } - // PostRoundTrip (closure) was called - if !ic.postCalled { - t.Fatal("PostRoundTrip closure was not called") - } - // Custom header set by extension survives (no built-in override) - if got := receivedHeaders.Get("X-Custom-Trace"); got != "ext-trace-123" { - t.Fatalf("X-Custom-Trace = %q, want %q", got, "ext-trace-123") - } - // Security header overridden by extension is restored by SecurityHeaderTransport - if got := receivedHeaders.Get(HeaderSource); got != SourceValue { - t.Fatalf("%s = %q, want %q (built-in should override extension)", HeaderSource, got, SourceValue) - } -} - -func TestExtensionInterceptor_ContextTamperPrevented(t *testing.T) { - type ctxKeyType string - const testKey ctxKeyType = "original" - - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - })) - defer srv.Close() - - var ctxValue any - - // Use a custom transport that captures the context value seen by the built-in chain - capturer := roundTripFunc(func(req *http.Request) (*http.Response, error) { - ctxValue = req.Context().Value(testKey) - return http.DefaultTransport.RoundTrip(req) - }) - - // Interceptor that tries to tamper with context - tamperIC := interceptorFunc(func(req *http.Request) func(*http.Response, error) { - // Try to replace context with a new one - *req = *req.WithContext(context.WithValue(req.Context(), testKey, "tampered")) - return nil - }) - - mid := &extensionMiddleware{Base: capturer, Ext: tamperIC} - - origCtx := context.WithValue(context.Background(), testKey, "original") - req, _ := http.NewRequestWithContext(origCtx, "GET", srv.URL, nil) - resp, err := mid.RoundTrip(req) - if err != nil { - t.Fatalf("request failed: %v", err) - } - resp.Body.Close() - - // Built-in chain should see original context, not tampered - if ctxValue != "original" { - t.Fatalf("built-in chain saw context value %q, want %q", ctxValue, "original") - } -} - -// interceptorFunc adapts a function to exttransport.Interceptor. -type interceptorFunc func(*http.Request) func(*http.Response, error) - -func (f interceptorFunc) PreRoundTrip(req *http.Request) func(*http.Response, error) { return f(req) } - -func TestBuildSDKTransport_WithExtension(t *testing.T) { - exttransport.Register(&stubTransportProvider{}) - t.Cleanup(func() { exttransport.Register(nil) }) - - transport := buildSDKTransport() - - // Chain: extensionMiddleware → SecurityPolicy → UserAgent → Retry → Base - mid, ok := transport.(*extensionMiddleware) - if !ok { - t.Fatalf("outer transport type = %T, want *extensionMiddleware", transport) - } - sec, ok := mid.Base.(*internalauth.SecurityPolicyTransport) - if !ok { - t.Fatalf("transport type = %T, want *auth.SecurityPolicyTransport", mid.Base) - } - ua, ok := sec.Base.(*UserAgentTransport) - if !ok { - t.Fatalf("transport type = %T, want *UserAgentTransport", sec.Base) - } - if _, ok := ua.Base.(*RetryTransport); !ok { - t.Fatalf("innermost transport type = %T, want *RetryTransport", ua.Base) - } -} - -func TestBuildSDKTransport_WithoutExtension(t *testing.T) { - exttransport.Register(nil) - - transport := buildSDKTransport() - - sec, ok := transport.(*internalauth.SecurityPolicyTransport) - if !ok { - t.Fatalf("outer transport type = %T, want *auth.SecurityPolicyTransport", transport) - } - ua, ok := sec.Base.(*UserAgentTransport) - if !ok { - t.Fatalf("middle transport type = %T, want *UserAgentTransport", sec.Base) - } - if _, ok := ua.Base.(*RetryTransport); !ok { - t.Fatalf("inner transport type = %T, want *RetryTransport", ua.Base) - } -} diff --git a/internal/cmdutil/retry_transport_test.go b/internal/cmdutil/retry_transport_test.go deleted file mode 100644 index a10782acf..000000000 --- a/internal/cmdutil/retry_transport_test.go +++ /dev/null @@ -1,81 +0,0 @@ -// Copyright (c) 2026 Lark Technologies Pte. Ltd. -// SPDX-License-Identifier: MIT - -package cmdutil - -import ( - "io" - "net/http" - "strings" - "testing" - "time" -) - -type roundTripFunc func(*http.Request) (*http.Response, error) - -func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { - return f(req) -} - -func TestRetryTransport_NoRetry(t *testing.T) { - calls := 0 - base := roundTripFunc(func(req *http.Request) (*http.Response, error) { - calls++ - return &http.Response{StatusCode: 200, Body: io.NopCloser(strings.NewReader("ok"))}, nil - }) - rt := &RetryTransport{Base: base, MaxRetries: 0} - req, _ := http.NewRequest("GET", "http://example.com/test", nil) - resp, err := rt.RoundTrip(req) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if resp.StatusCode != 200 { - t.Errorf("expected 200, got %d", resp.StatusCode) - } - if calls != 1 { - t.Errorf("expected 1 call, got %d", calls) - } -} - -func TestRetryTransport_RetryOn500(t *testing.T) { - calls := 0 - base := roundTripFunc(func(req *http.Request) (*http.Response, error) { - calls++ - if calls < 3 { - return &http.Response{StatusCode: 500, Body: io.NopCloser(strings.NewReader("error"))}, nil - } - return &http.Response{StatusCode: 200, Body: io.NopCloser(strings.NewReader("ok"))}, nil - }) - rt := &RetryTransport{Base: base, MaxRetries: 3, Delay: 1 * time.Millisecond} - req, _ := http.NewRequest("GET", "http://example.com/test", nil) - resp, err := rt.RoundTrip(req) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if resp.StatusCode != 200 { - t.Errorf("expected 200 after retries, got %d", resp.StatusCode) - } - if calls != 3 { - t.Errorf("expected 3 calls, got %d", calls) - } -} - -func TestRetryTransport_DefaultNoRetry(t *testing.T) { - calls := 0 - base := roundTripFunc(func(req *http.Request) (*http.Response, error) { - calls++ - return &http.Response{StatusCode: 500, Body: io.NopCloser(strings.NewReader("error"))}, nil - }) - rt := &RetryTransport{Base: base} // default MaxRetries=0 - req, _ := http.NewRequest("GET", "http://example.com/test", nil) - resp, err := rt.RoundTrip(req) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if resp.StatusCode != 500 { - t.Errorf("expected 500 with no retries, got %d", resp.StatusCode) - } - if calls != 1 { - t.Errorf("expected 1 call with default config, got %d", calls) - } -} diff --git a/internal/cmdutil/transport.go b/internal/cmdutil/transport.go index b922adf11..627755ea6 100644 --- a/internal/cmdutil/transport.go +++ b/internal/cmdutil/transport.go @@ -104,20 +104,47 @@ func (t *SecurityHeaderTransport) RoundTrip(req *http.Request) (*http.Response, } // extensionMiddleware wraps the built-in transport chain with pre/post hooks. -// The built-in chain always executes and cannot be skipped or overridden. -// The original request context is restored after PreRoundTrip to prevent +// The built-in chain always executes unless the extension is an +// exttransport.AbortableInterceptor and its PreRoundTripE returns a non-nil +// error; it cannot otherwise be skipped or overridden. +// +// The original request context is restored after the pre hook to prevent // extensions from tampering with cancellation, deadlines, or built-in values. +// Cloning the request isolates header/URL/etc. mutations from the caller's +// request object; req.Body is intentionally shared — extensions that consume +// it are responsible for rewinding (see Interceptor doc). type extensionMiddleware struct { - Base http.RoundTripper - Ext exttransport.Interceptor + Base http.RoundTripper + Ext exttransport.Interceptor + ExtName string // Provider.Name(), captured at wrap time for *AbortError.Extension } -// RoundTrip calls PreRoundTrip, restores the original context, executes -// the built-in chain, then calls the post hook if non-nil. +// RoundTrip invokes the interceptor pre hook, restores the original context, +// executes the built-in chain (unless aborted), then calls the post hook if +// non-nil. When the extension implements AbortableInterceptor and returns a +// non-nil error from PreRoundTripE, the built-in chain is skipped and an +// *exttransport.AbortError is returned; the post hook is still invoked with +// (nil, reason) so extensions can unwind resources. func (m *extensionMiddleware) RoundTrip(req *http.Request) (*http.Response, error) { origCtx := req.Context() - req = req.Clone(origCtx) // isolate caller's request before extension mutations - post := m.Ext.PreRoundTrip(req) + req = req.Clone(origCtx) + + var ( + post func(*http.Response, error) + abortEr error + ) + if a, ok := m.Ext.(exttransport.AbortableInterceptor); ok { + post, abortEr = a.PreRoundTripE(req) + } else { + post = m.Ext.PreRoundTrip(req) + } + if abortEr != nil { + if post != nil { + post(nil, abortEr) + } + return nil, &exttransport.AbortError{Extension: m.ExtName, Reason: abortEr} + } + req = req.WithContext(origCtx) // restore original context resp, err := m.Base.RoundTrip(req) if post != nil { @@ -137,5 +164,5 @@ func wrapWithExtension(transport http.RoundTripper) http.RoundTripper { if tr == nil { return transport } - return &extensionMiddleware{Base: transport, Ext: tr} + return &extensionMiddleware{Base: transport, Ext: tr, ExtName: p.Name()} } diff --git a/internal/cmdutil/transport_test.go b/internal/cmdutil/transport_test.go new file mode 100644 index 000000000..ced9e1e7e --- /dev/null +++ b/internal/cmdutil/transport_test.go @@ -0,0 +1,408 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package cmdutil + +import ( + "context" + "errors" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + exttransport "github.com/larksuite/cli/extension/transport" + internalauth "github.com/larksuite/cli/internal/auth" +) + +type roundTripFunc func(*http.Request) (*http.Response, error) + +func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +} + +// --------------------------------------------------------------------------- +// RetryTransport +// --------------------------------------------------------------------------- + +func TestRetryTransport_NoRetry(t *testing.T) { + calls := 0 + base := roundTripFunc(func(req *http.Request) (*http.Response, error) { + calls++ + return &http.Response{StatusCode: 200, Body: io.NopCloser(strings.NewReader("ok"))}, nil + }) + rt := &RetryTransport{Base: base, MaxRetries: 0} + req, _ := http.NewRequest("GET", "http://example.com/test", nil) + resp, err := rt.RoundTrip(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.StatusCode != 200 { + t.Errorf("expected 200, got %d", resp.StatusCode) + } + if calls != 1 { + t.Errorf("expected 1 call, got %d", calls) + } +} + +func TestRetryTransport_RetryOn500(t *testing.T) { + calls := 0 + base := roundTripFunc(func(req *http.Request) (*http.Response, error) { + calls++ + if calls < 3 { + return &http.Response{StatusCode: 500, Body: io.NopCloser(strings.NewReader("error"))}, nil + } + return &http.Response{StatusCode: 200, Body: io.NopCloser(strings.NewReader("ok"))}, nil + }) + rt := &RetryTransport{Base: base, MaxRetries: 3, Delay: 1 * time.Millisecond} + req, _ := http.NewRequest("GET", "http://example.com/test", nil) + resp, err := rt.RoundTrip(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.StatusCode != 200 { + t.Errorf("expected 200 after retries, got %d", resp.StatusCode) + } + if calls != 3 { + t.Errorf("expected 3 calls, got %d", calls) + } +} + +func TestRetryTransport_DefaultNoRetry(t *testing.T) { + calls := 0 + base := roundTripFunc(func(req *http.Request) (*http.Response, error) { + calls++ + return &http.Response{StatusCode: 500, Body: io.NopCloser(strings.NewReader("error"))}, nil + }) + rt := &RetryTransport{Base: base} // default MaxRetries=0 + req, _ := http.NewRequest("GET", "http://example.com/test", nil) + resp, err := rt.RoundTrip(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.StatusCode != 500 { + t.Errorf("expected 500 with no retries, got %d", resp.StatusCode) + } + if calls != 1 { + t.Errorf("expected 1 call with default config, got %d", calls) + } +} + +// --------------------------------------------------------------------------- +// buildSDKTransport chain composition +// --------------------------------------------------------------------------- + +func TestBuildSDKTransport_IncludesRetryTransport(t *testing.T) { + transport := buildSDKTransport() + + sec, ok := transport.(*internalauth.SecurityPolicyTransport) + if !ok { + t.Fatalf("outer transport type = %T, want *auth.SecurityPolicyTransport", transport) + } + ua, ok := sec.Base.(*UserAgentTransport) + if !ok { + t.Fatalf("middle transport type = %T, want *UserAgentTransport", sec.Base) + } + if _, ok := ua.Base.(*RetryTransport); !ok { + t.Fatalf("inner transport type = %T, want *RetryTransport", ua.Base) + } +} + +func TestBuildSDKTransport_WithExtension(t *testing.T) { + exttransport.Register(&stubTransportProvider{}) + t.Cleanup(func() { exttransport.Register(nil) }) + + transport := buildSDKTransport() + + // Chain: extensionMiddleware → SecurityPolicy → UserAgent → Retry → Base + mid, ok := transport.(*extensionMiddleware) + if !ok { + t.Fatalf("outer transport type = %T, want *extensionMiddleware", transport) + } + sec, ok := mid.Base.(*internalauth.SecurityPolicyTransport) + if !ok { + t.Fatalf("transport type = %T, want *auth.SecurityPolicyTransport", mid.Base) + } + ua, ok := sec.Base.(*UserAgentTransport) + if !ok { + t.Fatalf("transport type = %T, want *UserAgentTransport", sec.Base) + } + if _, ok := ua.Base.(*RetryTransport); !ok { + t.Fatalf("innermost transport type = %T, want *RetryTransport", ua.Base) + } +} + +func TestBuildSDKTransport_WithoutExtension(t *testing.T) { + exttransport.Register(nil) + + transport := buildSDKTransport() + + sec, ok := transport.(*internalauth.SecurityPolicyTransport) + if !ok { + t.Fatalf("outer transport type = %T, want *auth.SecurityPolicyTransport", transport) + } + ua, ok := sec.Base.(*UserAgentTransport) + if !ok { + t.Fatalf("middle transport type = %T, want *UserAgentTransport", sec.Base) + } + if _, ok := ua.Base.(*RetryTransport); !ok { + t.Fatalf("inner transport type = %T, want *RetryTransport", ua.Base) + } +} + +// --------------------------------------------------------------------------- +// extensionMiddleware — legacy Interceptor path +// --------------------------------------------------------------------------- + +type stubTransportProvider struct { + interceptor exttransport.Interceptor +} + +func (s *stubTransportProvider) Name() string { return "stub" } +func (s *stubTransportProvider) ResolveInterceptor(context.Context) exttransport.Interceptor { + if s.interceptor != nil { + return s.interceptor + } + return &stubTransportImpl{} +} + +type stubTransportImpl struct{} + +func (s *stubTransportImpl) PreRoundTrip(req *http.Request) func(*http.Response, error) { + return nil +} + +// headerCapturingInterceptor sets custom headers in PreRoundTrip and records +// whether PostRoundTrip was called, to verify execution order. +type headerCapturingInterceptor struct { + preCalled bool + postCalled bool +} + +func (h *headerCapturingInterceptor) PreRoundTrip(req *http.Request) func(*http.Response, error) { + h.preCalled = true + // Set a custom header that should survive (no built-in override) + req.Header.Set("X-Custom-Trace", "ext-trace-123") + // Try to override a security header — should be overwritten by SecurityHeaderTransport + req.Header.Set(HeaderSource, "ext-tampered") + return func(resp *http.Response, err error) { + h.postCalled = true + } +} + +func TestExtensionInterceptor_ExecutionOrder(t *testing.T) { + var receivedHeaders http.Header + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedHeaders = r.Header.Clone() + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + ic := &headerCapturingInterceptor{} + exttransport.Register(&stubTransportProvider{interceptor: ic}) + t.Cleanup(func() { exttransport.Register(nil) }) + + // Use HTTP transport chain (has SecurityHeaderTransport) + var base http.RoundTripper = http.DefaultTransport + base = &RetryTransport{Base: base} + base = &SecurityHeaderTransport{Base: base} + transport := wrapWithExtension(base) + client := &http.Client{Transport: transport} + + req, _ := http.NewRequest("GET", srv.URL, nil) + resp, err := client.Do(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + resp.Body.Close() + + // PreRoundTrip was called + if !ic.preCalled { + t.Fatal("PreRoundTrip was not called") + } + // PostRoundTrip (closure) was called + if !ic.postCalled { + t.Fatal("PostRoundTrip closure was not called") + } + // Custom header set by extension survives (no built-in override) + if got := receivedHeaders.Get("X-Custom-Trace"); got != "ext-trace-123" { + t.Fatalf("X-Custom-Trace = %q, want %q", got, "ext-trace-123") + } + // Security header overridden by extension is restored by SecurityHeaderTransport + if got := receivedHeaders.Get(HeaderSource); got != SourceValue { + t.Fatalf("%s = %q, want %q (built-in should override extension)", HeaderSource, got, SourceValue) + } +} + +// interceptorFunc adapts a function to exttransport.Interceptor. +type interceptorFunc func(*http.Request) func(*http.Response, error) + +func (f interceptorFunc) PreRoundTrip(req *http.Request) func(*http.Response, error) { return f(req) } + +func TestExtensionInterceptor_ContextTamperPrevented(t *testing.T) { + type ctxKeyType string + const testKey ctxKeyType = "original" + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + var ctxValue any + + // Use a custom transport that captures the context value seen by the built-in chain + capturer := roundTripFunc(func(req *http.Request) (*http.Response, error) { + ctxValue = req.Context().Value(testKey) + return http.DefaultTransport.RoundTrip(req) + }) + + // Interceptor that tries to tamper with context + tamperIC := interceptorFunc(func(req *http.Request) func(*http.Response, error) { + // Try to replace context with a new one + *req = *req.WithContext(context.WithValue(req.Context(), testKey, "tampered")) + return nil + }) + + mid := &extensionMiddleware{Base: capturer, Ext: tamperIC} + + origCtx := context.WithValue(context.Background(), testKey, "original") + req, _ := http.NewRequestWithContext(origCtx, "GET", srv.URL, nil) + resp, err := mid.RoundTrip(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + resp.Body.Close() + + // Built-in chain should see original context, not tampered + if ctxValue != "original" { + t.Fatalf("built-in chain saw context value %q, want %q", ctxValue, "original") + } +} + +// --------------------------------------------------------------------------- +// extensionMiddleware — PreRoundTripE abort path +// --------------------------------------------------------------------------- + +// abortingInterceptor implements exttransport.AbortableInterceptor and +// records invocation of the pre and post hooks. These middleware tests only +// assert middleware-level integration; pure *AbortError behavior +// (Error/Unwrap/Is/As) is covered in extension/transport/errors_test.go. +type abortingInterceptor struct { + reason error // if non-nil, PreRoundTripE returns this to abort + nilPost bool // if true, PreRoundTripE returns a nil post func + preECalled bool + postCalled bool + postResp *http.Response + postErr error +} + +// PreRoundTrip is a no-op that satisfies the legacy Interceptor method; the +// middleware never calls it when PreRoundTripE is present. +func (*abortingInterceptor) PreRoundTrip(*http.Request) func(*http.Response, error) { + return nil +} + +func (a *abortingInterceptor) PreRoundTripE(req *http.Request) (func(*http.Response, error), error) { + a.preECalled = true + if a.nilPost { + return nil, a.reason + } + return func(resp *http.Response, err error) { + a.postCalled = true + a.postResp = resp + a.postErr = err + }, a.reason +} + +func TestExtensionMiddleware_PreRoundTripEAbort(t *testing.T) { + innerErr := errors.New("denied by policy") + + t.Run("skips base and wires AbortError fields", func(t *testing.T) { + ic := &abortingInterceptor{reason: innerErr} + baseCalls := 0 + base := roundTripFunc(func(*http.Request) (*http.Response, error) { + baseCalls++ + return &http.Response{StatusCode: http.StatusOK, Body: http.NoBody}, nil + }) + + mid := &extensionMiddleware{Base: base, Ext: ic, ExtName: "stub"} + req, _ := http.NewRequest("GET", "http://example.invalid/", nil) + resp, err := mid.RoundTrip(req) + + if resp != nil { + t.Fatalf("resp = %v, want nil on abort", resp) + } + if baseCalls != 0 { + t.Fatalf("base RoundTrip called %d times on abort, want 0", baseCalls) + } + if !ic.preECalled { + t.Fatal("PreRoundTripE was not called") + } + + var aErr *exttransport.AbortError + if !errors.As(err, &aErr) { + t.Fatalf("errors.As(*AbortError) = false, err = %v (%T)", err, err) + } + if aErr.Extension != "stub" || aErr.Reason != innerErr { + t.Fatalf("AbortError = %+v, want {Extension:stub Reason:%v}", aErr, innerErr) + } + + // Post must see the original inner err, not the *AbortError wrapper. + if !ic.postCalled { + t.Fatal("post hook was not called on abort") + } + if ic.postResp != nil { + t.Fatalf("post resp = %v, want nil", ic.postResp) + } + if ic.postErr != innerErr { + t.Fatalf("post err = %v, want original inner err %v", ic.postErr, innerErr) + } + }) + + t.Run("nil post still returns AbortError", func(t *testing.T) { + ic := &abortingInterceptor{reason: innerErr, nilPost: true} + base := roundTripFunc(func(*http.Request) (*http.Response, error) { + t.Fatal("base must not be called on abort") + return nil, nil + }) + + mid := &extensionMiddleware{Base: base, Ext: ic, ExtName: "stub"} + req, _ := http.NewRequest("GET", "http://example.invalid/", nil) + _, err := mid.RoundTrip(req) + + var aErr *exttransport.AbortError + if !errors.As(err, &aErr) { + t.Fatalf("errors.As(*AbortError) = false, err = %v", err) + } + }) +} + +func TestExtensionMiddleware_PreRoundTripEHappyPath(t *testing.T) { + ic := &abortingInterceptor{} // reason == nil → no abort + baseCalls := 0 + base := roundTripFunc(func(*http.Request) (*http.Response, error) { + baseCalls++ + return &http.Response{StatusCode: http.StatusOK, Body: http.NoBody}, nil + }) + + mid := &extensionMiddleware{Base: base, Ext: ic, ExtName: "stub"} + req, _ := http.NewRequest("GET", "http://example.invalid/", nil) + resp, err := mid.RoundTrip(req) + if err != nil { + t.Fatalf("happy path returned err: %v", err) + } + if resp == nil || resp.StatusCode != http.StatusOK { + t.Fatalf("resp = %v, want 200", resp) + } + if baseCalls != 1 { + t.Fatalf("base RoundTrip called %d times, want 1", baseCalls) + } + if !ic.preECalled { + t.Fatal("PreRoundTripE was not called") + } + if !ic.postCalled || ic.postErr != nil { + t.Fatalf("post hook not called or err != nil: called=%v err=%v", ic.postCalled, ic.postErr) + } +} diff --git a/internal/update/update.go b/internal/update/update.go index c87bb59fc..39052d9ef 100644 --- a/internal/update/update.go +++ b/internal/update/update.go @@ -61,7 +61,7 @@ func httpClient() *http.Client { } return &http.Client{ Timeout: fetchTimeout, - Transport: util.NewBaseTransport(), + Transport: util.SharedTransport(), } } diff --git a/internal/util/proxy.go b/internal/util/proxy.go index 64308da85..d9e251859 100644 --- a/internal/util/proxy.go +++ b/internal/util/proxy.go @@ -72,31 +72,47 @@ func WarnIfProxied(w io.Writer) { }) } -// NewBaseTransport creates an *http.Transport cloned from http.DefaultTransport. -// If LARK_CLI_NO_PROXY is set, proxy support is disabled. -// Each call returns a new instance; use FallbackTransport for a shared singleton. -func NewBaseTransport() *http.Transport { +// noProxyTransport is a proxy-disabled clone of http.DefaultTransport, +// lazily built the first time LARK_CLI_NO_PROXY is observed set. +var noProxyTransport = sync.OnceValue(func() *http.Transport { def, ok := http.DefaultTransport.(*http.Transport) if !ok { return &http.Transport{} } t := def.Clone() + t.Proxy = nil + return t +}) + +// SharedTransport returns the base http.RoundTripper for CLI HTTP clients. +// +// By default it returns http.DefaultTransport — the stdlib-provided +// process-wide singleton — so every HTTP client in the process shares one +// TCP connection pool, TLS session cache, and HTTP/2 state. When +// LARK_CLI_NO_PROXY is set it returns a separate proxy-disabled singleton +// clone; LARK_CLI_NO_PROXY is checked on every call, but the clone is built +// at most once. +// +// The returned RoundTripper MUST NOT be mutated. Callers that need a +// customized transport should assert to *http.Transport and Clone() it. +// Using a shared base is required so persistConn readLoop/writeLoop +// goroutines are reused; cloning per call leaks them until IdleConnTimeout +// (~90s) fires. +func SharedTransport() http.RoundTripper { if os.Getenv(EnvNoProxy) != "" { - t.Proxy = nil + return noProxyTransport() } - return t + return http.DefaultTransport } -// fallbackTransport is a lazily-initialized singleton used by transport -// decorators when their Base field is nil, preserving connection pooling. -var fallbackTransport = sync.OnceValue(func() *http.Transport { - return NewBaseTransport() -}) - -// FallbackTransport returns a shared *http.Transport singleton suitable for -// use as a fallback when a transport decorator's Base is nil. -// Unlike NewBaseTransport (which clones per call), this reuses a single -// instance so that TCP connections and TLS sessions are pooled. +// FallbackTransport returns a shared *http.Transport singleton. It is a +// thin wrapper over SharedTransport retained so modules that were already +// on the leak-free singleton path (internal/auth, internal/cmdutil +// transport decorators) do not have to migrate. New code should prefer +// SharedTransport and treat the base as an http.RoundTripper. func FallbackTransport() *http.Transport { - return fallbackTransport() + if t, ok := SharedTransport().(*http.Transport); ok { + return t + } + return noProxyTransport() } diff --git a/internal/util/proxy_test.go b/internal/util/proxy_test.go index daf1b7ab4..f78720963 100644 --- a/internal/util/proxy_test.go +++ b/internal/util/proxy_test.go @@ -28,19 +28,65 @@ func TestDetectProxyEnv(t *testing.T) { } } -func TestNewBaseTransport_Default(t *testing.T) { +func TestSharedTransport_DefaultReturnsStdlibSingleton(t *testing.T) { t.Setenv(EnvNoProxy, "") - tr := NewBaseTransport() - if tr.Proxy == nil { - t.Error("expected proxy func to be set when LARK_CLI_NO_PROXY is not set") + tr := SharedTransport() + if tr != http.DefaultTransport { + t.Error("SharedTransport should return http.DefaultTransport when LARK_CLI_NO_PROXY is unset") } } -func TestNewBaseTransport_NoProxy(t *testing.T) { +func TestSharedTransport_NoProxyReturnsClone(t *testing.T) { t.Setenv(EnvNoProxy, "1") - tr := NewBaseTransport() - if tr.Proxy != nil { - t.Error("expected proxy func to be nil when LARK_CLI_NO_PROXY=1") + tr := SharedTransport() + if tr == http.DefaultTransport { + t.Fatal("SharedTransport should return a clone, not DefaultTransport, when LARK_CLI_NO_PROXY is set") + } + ht, ok := tr.(*http.Transport) + if !ok { + t.Fatalf("expected *http.Transport, got %T", tr) + } + if ht.Proxy != nil { + t.Error("no-proxy transport should have Proxy == nil") + } +} + +func TestSharedTransport_NoProxyIsCachedSingleton(t *testing.T) { + t.Setenv(EnvNoProxy, "1") + a := SharedTransport() + b := SharedTransport() + if a != b { + t.Error("repeated SharedTransport calls with LARK_CLI_NO_PROXY set must return the same instance") + } +} + +func TestSharedTransport_EnvUnsetAfterSetFallsBackToDefault(t *testing.T) { + // Simulate a process that first runs with LARK_CLI_NO_PROXY=1 (populating + // the no-proxy singleton), then unsets it. Subsequent calls must return + // http.DefaultTransport, NOT the cached no-proxy clone. + t.Setenv(EnvNoProxy, "1") + noProxy := SharedTransport() + if noProxy == http.DefaultTransport { + t.Fatal("precondition: first call with env set should not return DefaultTransport") + } + + t.Setenv(EnvNoProxy, "") + after := SharedTransport() + if after != http.DefaultTransport { + t.Errorf("after unsetting LARK_CLI_NO_PROXY, SharedTransport must return http.DefaultTransport, got %T (%p)", after, after) + } +} + +func TestSharedTransport_NoProxyOverridesSystemProxy(t *testing.T) { + t.Setenv("HTTPS_PROXY", "http://should-be-ignored:8888") + t.Setenv(EnvNoProxy, "1") + + ht, ok := SharedTransport().(*http.Transport) + if !ok { + t.Fatalf("expected *http.Transport, got %T", SharedTransport()) + } + if ht.Proxy != nil { + t.Error("LARK_CLI_NO_PROXY should override system proxy settings") } } @@ -156,35 +202,3 @@ func TestWarnIfProxied_RedactsCredentials(t *testing.T) { t.Errorf("warning should contain redacted proxy URL, got: %s", out) } } - -func TestNewBaseTransport_IsHTTPTransport(t *testing.T) { - t.Setenv(EnvNoProxy, "") - tr := NewBaseTransport() - - // Should be a valid *http.Transport that can be used - var rt http.RoundTripper = tr - _ = rt - - // Verify it's not the same pointer as DefaultTransport (should be a clone) - if tr == http.DefaultTransport { - t.Error("NewBaseTransport should return a clone, not DefaultTransport itself") - } -} - -func TestNewBaseTransport_RespectsNoProxyEnv(t *testing.T) { - // Simulate: user sets both system proxy and our disable flag - t.Setenv("HTTPS_PROXY", "http://should-be-ignored:8888") - t.Setenv(EnvNoProxy, "1") - - tr := NewBaseTransport() - if tr.Proxy != nil { - t.Error("LARK_CLI_NO_PROXY should override system proxy settings") - } - - // Clean up and verify proxy is restored - t.Setenv(EnvNoProxy, "") - tr2 := NewBaseTransport() - if tr2.Proxy == nil { - t.Error("proxy should be enabled when LARK_CLI_NO_PROXY is unset") - } -} From 06f552af2c3a39cabb6738f1a6fc7cb227623c78 Mon Sep 17 00:00:00 2001 From: tuxedomm <273098272+tuxedomm@users.noreply.github.com> Date: Tue, 21 Apr 2026 14:40:14 +0800 Subject: [PATCH 7/7] fix: unblock Build() zero-opts path and sidecar demo build MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two regressions surfaced on refactor/build-execute-split: 1. cmd.Build(ctx, inv) without WithIO panicked at rootCmd.SetIn/Out/Err because cfg.streams stayed nil — NewDefault normalized internally but cmd/build.go never saw the normalized value. Default cfg.streams to cmdutil.SystemIO() before the root command wires them, and add a TestBuild_NoOptions regression guard. 2. sidecar/server-demo/main.go still called cmdutil.NewDefault(inv), so `go build -tags authsidecar_demo ./sidecar/server-demo` failed with "not enough arguments". Pass nil for the new streams parameter to preserve the prior behavior (NewDefault substitutes SystemIO). Change-Id: I20227b2355cde7d19e22eba3eb841c6d8611e8a7 --- cmd/build.go | 11 +++++++---- cmd/build_api_test.go | 13 +++++++++++++ sidecar/server-demo/main.go | 2 +- 3 files changed, 21 insertions(+), 5 deletions(-) diff --git a/cmd/build.go b/cmd/build.go index 92443f7a8..3f7ff05ea 100644 --- a/cmd/build.go +++ b/cmd/build.go @@ -68,10 +68,6 @@ func Build(ctx context.Context, inv cmdutil.InvocationContext, opts ...BuildOpti // buildInternal is a pure assembly function: it wires the command tree from // inv and BuildOptions alone. Any state-dependent decision (disk, network, // env) belongs in the caller and must be threaded in via BuildOption. -// -// Callers must supply WithIO; buildInternal intentionally does not default -// the streams so tests and alternative entry points can't silently inherit -// os.Std*. func buildInternal(ctx context.Context, inv cmdutil.InvocationContext, opts ...BuildOption) (*cmdutil.Factory, *cobra.Command) { // cfg.globals.Profile is left zero here; it's bound to the --profile // flag in RegisterGlobalFlags and filled by cobra's parse step. @@ -81,6 +77,13 @@ func buildInternal(ctx context.Context, inv cmdutil.InvocationContext, opts ...B o(cfg) } } + // Default streams when WithIO is not supplied so the root command's + // SetIn/Out/Err calls below don't deref nil. NewDefault also normalizes + // partial streams internally; keep both in sync so cfg.streams reflects + // the same values the Factory ends up using. + if cfg.streams == nil { + cfg.streams = cmdutil.SystemIO() + } f := cmdutil.NewDefault(cfg.streams, inv) if cfg.keychain != nil { diff --git a/cmd/build_api_test.go b/cmd/build_api_test.go index fa490c294..5735f1ea2 100644 --- a/cmd/build_api_test.go +++ b/cmd/build_api_test.go @@ -48,3 +48,16 @@ func TestBuild_ExternalAPI(t *testing.T) { t.Error("Build produced a root command with no subcommands") } } + +// TestBuild_NoOptions guards against regression of the nil-streams panic: +// calling Build without WithIO must fall back to SystemIO rather than +// deref nil at rootCmd.SetIn/Out/Err. +func TestBuild_NoOptions(t *testing.T) { + rootCmd := Build(context.Background(), cmdutil.InvocationContext{}) + if rootCmd == nil { + t.Fatal("Build returned nil root command") + } + if rootCmd.Use != "lark-cli" { + t.Errorf("rootCmd.Use = %q, want %q", rootCmd.Use, "lark-cli") + } +} diff --git a/sidecar/server-demo/main.go b/sidecar/server-demo/main.go index 1197b5145..7292daa26 100644 --- a/sidecar/server-demo/main.go +++ b/sidecar/server-demo/main.go @@ -99,7 +99,7 @@ func run(ctx context.Context, listen, keyFile, logFile, profile string) error { // Reuse the lark-cli credential pipeline. A production implementation // would likely source credentials from a secrets manager instead. - factory := cmdutil.NewDefault(cmdutil.InvocationContext{Profile: profile}) + factory := cmdutil.NewDefault(nil, cmdutil.InvocationContext{Profile: profile}) cfg, err := factory.Config() if err != nil { return fmt.Errorf("failed to load config: %v", err)