diff --git a/cmd/api/api.go b/cmd/api/api.go index 1fe651d03..b6383707f 100644 --- a/cmd/api/api.go +++ b/cmd/api/api.go @@ -57,6 +57,10 @@ func normalisePath(raw string) string { // NewCmdApi creates the api command. If runF is non-nil it is called instead of apiRun (test hook). func NewCmdApi(f *cmdutil.Factory, runF func(*APIOptions) error) *cobra.Command { + return NewCmdApiWithContext(context.Background(), f, runF) +} + +func NewCmdApiWithContext(ctx context.Context, f *cmdutil.Factory, runF func(*APIOptions) error) *cobra.Command { opts := &APIOptions{Factory: f} var asStr string @@ -79,7 +83,7 @@ func NewCmdApi(f *cmdutil.Factory, runF func(*APIOptions) error) *cobra.Command cmd.Flags().StringVar(&opts.Params, "params", "", "query parameters JSON (supports - for stdin)") cmd.Flags().StringVar(&opts.Data, "data", "", "request body JSON (supports - for stdin)") - cmd.Flags().StringVar(&asStr, "as", "auto", "identity type: user | bot | auto (default)") + cmdutil.AddAPIIdentityFlag(ctx, cmd, f, &asStr) cmd.Flags().StringVarP(&opts.Output, "output", "o", "", "output file path for binary responses") cmd.Flags().BoolVar(&opts.PageAll, "page-all", false, "automatically paginate through all pages") cmd.Flags().IntVar(&opts.PageSize, "page-size", 0, "page size (0 = use API default)") @@ -96,9 +100,6 @@ func NewCmdApi(f *cmdutil.Factory, runF func(*APIOptions) error) *cobra.Command } return nil, cobra.ShellCompDirectiveNoFileComp } - _ = cmd.RegisterFlagCompletionFunc("as", func(_ *cobra.Command, _ []string, _ string) ([]string, cobra.ShellCompDirective) { - return []string{"user", "bot"}, cobra.ShellCompDirectiveNoFileComp - }) _ = cmd.RegisterFlagCompletionFunc("format", func(_ *cobra.Command, _ []string, _ string) ([]string, cobra.ShellCompDirective) { return []string{"json", "ndjson", "table", "csv"}, cobra.ShellCompDirectiveNoFileComp }) diff --git a/cmd/api/api_test.go b/cmd/api/api_test.go index 9f5be7022..e81cc984f 100644 --- a/cmd/api/api_test.go +++ b/cmd/api/api_test.go @@ -180,6 +180,24 @@ func TestApiValidArgsFunction(t *testing.T) { } } +func TestNewCmdApi_StrictModeHidesAsFlag(t *testing.T) { + f, _, _, _ := cmdutil.TestFactory(t, &core.CliConfig{ + AppID: "test-app", AppSecret: "test-secret", Brand: core.BrandFeishu, SupportedIdentities: 2, + }) + + cmd := NewCmdApi(f, nil) + flag := cmd.Flags().Lookup("as") + if flag == nil { + t.Fatal("expected --as flag to be registered") + } + if !flag.Hidden { + t.Fatal("expected --as flag to be hidden in strict mode") + } + if got := flag.DefValue; got != "bot" { + t.Fatalf("default value = %q, want %q", got, "bot") + } +} + func TestApiCmd_PageLimitDefault(t *testing.T) { f, _, _, _ := cmdutil.TestFactory(t, &core.CliConfig{ AppID: "test-app", AppSecret: "test-secret", Brand: core.BrandFeishu, diff --git a/cmd/build.go b/cmd/build.go new file mode 100644 index 000000000..3f7ff05ea --- /dev/null +++ b/cmd/build.go @@ -0,0 +1,129 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package cmd + +import ( + "context" + "io" + + "github.com/larksuite/cli/cmd/api" + "github.com/larksuite/cli/cmd/auth" + "github.com/larksuite/cli/cmd/completion" + cmdconfig "github.com/larksuite/cli/cmd/config" + "github.com/larksuite/cli/cmd/doctor" + "github.com/larksuite/cli/cmd/profile" + "github.com/larksuite/cli/cmd/schema" + "github.com/larksuite/cli/cmd/service" + cmdupdate "github.com/larksuite/cli/cmd/update" + "github.com/larksuite/cli/internal/build" + "github.com/larksuite/cli/internal/cmdutil" + "github.com/larksuite/cli/internal/keychain" + "github.com/larksuite/cli/shortcuts" + "github.com/spf13/cobra" +) + +// BuildOption configures optional aspects of the command tree construction. +type BuildOption func(*buildConfig) + +type buildConfig struct { + streams *cmdutil.IOStreams + keychain keychain.KeychainAccess + globals GlobalOptions +} + +// WithIO sets the IO streams for the CLI by wrapping raw reader/writers. +// Terminal detection is delegated to cmdutil.NewIOStreams. +func WithIO(in io.Reader, out, errOut io.Writer) BuildOption { + return func(c *buildConfig) { + c.streams = cmdutil.NewIOStreams(in, out, errOut) + } +} + +// WithKeychain sets the secret storage backend. If not provided, the platform keychain is used. +func WithKeychain(kc keychain.KeychainAccess) BuildOption { + return func(c *buildConfig) { + c.keychain = kc + } +} + +// HideProfile sets the visibility policy for the root-level --profile flag. +// When hide is true the flag stays registered (so existing invocations still +// parse) but is omitted from help and shell completion. Typically called as +// HideProfile(isSingleAppMode()). +func HideProfile(hide bool) BuildOption { + return func(c *buildConfig) { + c.globals.HideProfile = hide + } +} + +// Build constructs the full command tree without executing. +// Returns only the cobra.Command; Factory is internal. +// Use Execute for the standard production entry point. +func Build(ctx context.Context, inv cmdutil.InvocationContext, opts ...BuildOption) *cobra.Command { + _, rootCmd := buildInternal(ctx, inv, opts...) + return rootCmd +} + +// buildInternal is a pure assembly function: it wires the command tree from +// inv and BuildOptions alone. Any state-dependent decision (disk, network, +// env) belongs in the caller and must be threaded in via BuildOption. +func buildInternal(ctx context.Context, inv cmdutil.InvocationContext, opts ...BuildOption) (*cmdutil.Factory, *cobra.Command) { + // cfg.globals.Profile is left zero here; it's bound to the --profile + // flag in RegisterGlobalFlags and filled by cobra's parse step. + cfg := &buildConfig{} + for _, o := range opts { + if o != nil { + o(cfg) + } + } + // Default streams when WithIO is not supplied so the root command's + // SetIn/Out/Err calls below don't deref nil. NewDefault also normalizes + // partial streams internally; keep both in sync so cfg.streams reflects + // the same values the Factory ends up using. + if cfg.streams == nil { + cfg.streams = cmdutil.SystemIO() + } + + f := cmdutil.NewDefault(cfg.streams, inv) + if cfg.keychain != nil { + f.Keychain = cfg.keychain + } + rootCmd := &cobra.Command{ + Use: "lark-cli", + Short: "Lark/Feishu CLI — OAuth authorization, UAT management, API calls", + Long: rootLong, + Version: build.Version, + } + + rootCmd.SetContext(ctx) + rootCmd.SetIn(cfg.streams.In) + rootCmd.SetOut(cfg.streams.Out) + rootCmd.SetErr(cfg.streams.ErrOut) + + installTipsHelpFunc(rootCmd) + rootCmd.SilenceErrors = true + + RegisterGlobalFlags(rootCmd.PersistentFlags(), &cfg.globals) + rootCmd.PersistentPreRun = func(cmd *cobra.Command, args []string) { + cmd.SilenceUsage = true + } + + rootCmd.AddCommand(cmdconfig.NewCmdConfig(f)) + rootCmd.AddCommand(auth.NewCmdAuth(f)) + rootCmd.AddCommand(profile.NewCmdProfile(f)) + rootCmd.AddCommand(doctor.NewCmdDoctor(f)) + rootCmd.AddCommand(api.NewCmdApiWithContext(ctx, f, nil)) + rootCmd.AddCommand(schema.NewCmdSchema(f, nil)) + rootCmd.AddCommand(completion.NewCmdCompletion(f)) + rootCmd.AddCommand(cmdupdate.NewCmdUpdate(f)) + service.RegisterServiceCommandsWithContext(ctx, rootCmd, f) + shortcuts.RegisterShortcutsWithContext(ctx, rootCmd, f) + + // Prune commands incompatible with strict mode. + if mode := f.ResolveStrictMode(ctx); mode.IsActive() { + pruneForStrictMode(rootCmd, mode) + } + + return f, rootCmd +} diff --git a/cmd/build_api_test.go b/cmd/build_api_test.go new file mode 100644 index 000000000..5735f1ea2 --- /dev/null +++ b/cmd/build_api_test.go @@ -0,0 +1,63 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package cmd + +import ( + "bytes" + "context" + "testing" + + "github.com/larksuite/cli/internal/cmdutil" + "github.com/larksuite/cli/internal/vfs" +) + +// noopKeychain is a zero-side-effect KeychainAccess for exercising +// WithKeychain without touching the platform keychain. +type noopKeychain struct{} + +func (noopKeychain) Get(service, account string) (string, error) { return "", nil } +func (noopKeychain) Set(service, account, value string) error { return nil } +func (noopKeychain) Remove(service, account string) error { return nil } + +// TestBuild_ExternalAPI asserts the library surface that external consumers +// (e.g. cli-server) depend on: Build composes a root command from an +// InvocationContext plus BuildOptions (WithIO, WithKeychain, HideProfile), +// and SetDefaultFS swaps the global VFS. This test is the contract guard. +func TestBuild_ExternalAPI(t *testing.T) { + // Exercise SetDefaultFS both directions. Passing nil restores the OS FS. + SetDefaultFS(vfs.OsFs{}) + SetDefaultFS(nil) + + var in, out, errOut bytes.Buffer + rootCmd := Build( + context.Background(), + cmdutil.InvocationContext{}, + WithIO(&in, &out, &errOut), + WithKeychain(noopKeychain{}), + HideProfile(true), + ) + + if rootCmd == nil { + t.Fatal("Build returned nil root command") + } + if rootCmd.Use != "lark-cli" { + t.Errorf("rootCmd.Use = %q, want %q", rootCmd.Use, "lark-cli") + } + if len(rootCmd.Commands()) == 0 { + t.Error("Build produced a root command with no subcommands") + } +} + +// TestBuild_NoOptions guards against regression of the nil-streams panic: +// calling Build without WithIO must fall back to SystemIO rather than +// deref nil at rootCmd.SetIn/Out/Err. +func TestBuild_NoOptions(t *testing.T) { + rootCmd := Build(context.Background(), cmdutil.InvocationContext{}) + if rootCmd == nil { + t.Fatal("Build returned nil root command") + } + if rootCmd.Use != "lark-cli" { + t.Errorf("rootCmd.Use = %q, want %q", rootCmd.Use, "lark-cli") + } +} diff --git a/cmd/global_flags.go b/cmd/global_flags.go index d634cc4fd..b77e8f189 100644 --- a/cmd/global_flags.go +++ b/cmd/global_flags.go @@ -3,15 +3,38 @@ package cmd -import "github.com/spf13/pflag" +import ( + "github.com/larksuite/cli/internal/core" + "github.com/spf13/pflag" +) // GlobalOptions are the root-level flags shared by bootstrap parsing and the -// actual Cobra command tree. +// actual Cobra command tree. Profile is the parsed --profile value; HideProfile +// is a build-time policy — when true, --profile stays parseable but is marked +// hidden from help and shell completion. type GlobalOptions struct { - Profile string + Profile string + HideProfile bool } -// RegisterGlobalFlags registers the root-level persistent flags. +// RegisterGlobalFlags registers the root-level persistent flags on fs and +// applies any visibility policy encoded in opts. Pure function: no disk, +// network, or environment reads — the caller decides HideProfile. func RegisterGlobalFlags(fs *pflag.FlagSet, opts *GlobalOptions) { fs.StringVar(&opts.Profile, "profile", "", "use a specific profile") + if opts.HideProfile { + _ = fs.MarkHidden("profile") + } +} + +// isSingleAppMode reports whether the on-disk config has at most one app. +// Missing configs are treated as single-app since --profile is meaningless +// until at least two profiles exist. Intended for the Execute entry point — +// buildInternal must not call this directly to stay state-free. +func isSingleAppMode() bool { + raw, err := core.LoadMultiAppConfig() + if err != nil || raw == nil { + return true + } + return len(raw.Apps) <= 1 } diff --git a/cmd/global_flags_test.go b/cmd/global_flags_test.go new file mode 100644 index 000000000..c24d1573a --- /dev/null +++ b/cmd/global_flags_test.go @@ -0,0 +1,110 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package cmd + +import ( + "context" + "os" + "testing" + + "github.com/larksuite/cli/internal/cmdutil" + "github.com/larksuite/cli/internal/core" + "github.com/spf13/pflag" +) + +func testStreams() BuildOption { return WithIO(os.Stdin, os.Stdout, os.Stderr) } + +func TestRegisterGlobalFlags_PolicyVisible(t *testing.T) { + fs := pflag.NewFlagSet("test", pflag.ContinueOnError) + opts := &GlobalOptions{} + RegisterGlobalFlags(fs, opts) + + flag := fs.Lookup("profile") + if flag == nil { + t.Fatal("profile flag should be registered") + } + if flag.Hidden { + t.Fatal("profile flag should be visible when HideProfile is false") + } +} + +func TestRegisterGlobalFlags_PolicyHidden(t *testing.T) { + fs := pflag.NewFlagSet("test", pflag.ContinueOnError) + opts := &GlobalOptions{HideProfile: true} + RegisterGlobalFlags(fs, opts) + + flag := fs.Lookup("profile") + if flag == nil { + t.Fatal("profile flag should be registered") + } + if !flag.Hidden { + t.Fatal("profile flag should be hidden when HideProfile is true") + } + if err := fs.Parse([]string{"--profile", "x"}); err != nil { + t.Fatalf("Parse() error = %v; hidden flag should still parse", err) + } + if opts.Profile != "x" { + t.Fatalf("opts.Profile = %q, want %q", opts.Profile, "x") + } +} + +func TestIsSingleAppMode_NoConfig(t *testing.T) { + t.Setenv("LARKSUITE_CLI_CONFIG_DIR", t.TempDir()) + if !isSingleAppMode() { + t.Fatal("isSingleAppMode() = false, want true when no config exists") + } +} + +func TestIsSingleAppMode_SingleApp(t *testing.T) { + t.Setenv("LARKSUITE_CLI_CONFIG_DIR", t.TempDir()) + saveAppsForTest(t, []core.AppConfig{ + {Name: "default", AppId: "cli_a", AppSecret: core.PlainSecret("x"), Brand: core.BrandFeishu}, + }) + if !isSingleAppMode() { + t.Fatal("isSingleAppMode() = false, want true for single-app config") + } +} + +func TestIsSingleAppMode_MultiApp(t *testing.T) { + t.Setenv("LARKSUITE_CLI_CONFIG_DIR", t.TempDir()) + saveAppsForTest(t, []core.AppConfig{ + {Name: "a", AppId: "cli_a", AppSecret: core.PlainSecret("x"), Brand: core.BrandFeishu}, + {Name: "b", AppId: "cli_b", AppSecret: core.PlainSecret("y"), Brand: core.BrandFeishu}, + }) + if isSingleAppMode() { + t.Fatal("isSingleAppMode() = true, want false for multi-app config") + } +} + +func TestBuildInternal_HideProfileOption(t *testing.T) { + _, root := buildInternal(context.Background(), cmdutil.InvocationContext{}, testStreams(), HideProfile(true)) + + flag := root.PersistentFlags().Lookup("profile") + if flag == nil { + t.Fatal("profile flag should be registered") + } + if !flag.Hidden { + t.Fatal("profile flag should be hidden when HideProfile(true) is applied") + } +} + +func TestBuildInternal_DefaultShowsProfileFlag(t *testing.T) { + _, root := buildInternal(context.Background(), cmdutil.InvocationContext{}, testStreams()) + + flag := root.PersistentFlags().Lookup("profile") + if flag == nil { + t.Fatal("profile flag should be registered by default") + } + if flag.Hidden { + t.Fatal("profile flag should be visible by default") + } +} + +func saveAppsForTest(t *testing.T, apps []core.AppConfig) { + t.Helper() + multi := &core.MultiAppConfig{CurrentApp: apps[0].Name, Apps: apps} + if err := core.SaveMultiAppConfig(multi); err != nil { + t.Fatalf("SaveMultiAppConfig() error = %v", err) + } +} diff --git a/cmd/init.go b/cmd/init.go new file mode 100644 index 000000000..d9093eabf --- /dev/null +++ b/cmd/init.go @@ -0,0 +1,18 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package cmd + +import "github.com/larksuite/cli/internal/vfs" + +// SetDefaultFS replaces the global filesystem implementation used by internal +// packages. The provided fs must implement the vfs.FS interface. If fs is nil, +// the default OS filesystem is restored. +// +// Call this before Build or Execute to take effect. +func SetDefaultFS(fs vfs.FS) { + if fs == nil { + fs = vfs.OsFs{} + } + vfs.DefaultFS = fs +} diff --git a/cmd/root.go b/cmd/root.go index dca93f7cb..57c91d8b1 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -14,15 +14,6 @@ import ( "os" "strconv" - "github.com/larksuite/cli/cmd/api" - "github.com/larksuite/cli/cmd/auth" - "github.com/larksuite/cli/cmd/completion" - cmdconfig "github.com/larksuite/cli/cmd/config" - "github.com/larksuite/cli/cmd/doctor" - "github.com/larksuite/cli/cmd/profile" - "github.com/larksuite/cli/cmd/schema" - "github.com/larksuite/cli/cmd/service" - cmdupdate "github.com/larksuite/cli/cmd/update" internalauth "github.com/larksuite/cli/internal/auth" "github.com/larksuite/cli/internal/build" "github.com/larksuite/cli/internal/cmdutil" @@ -30,7 +21,6 @@ import ( "github.com/larksuite/cli/internal/output" "github.com/larksuite/cli/internal/registry" "github.com/larksuite/cli/internal/update" - "github.com/larksuite/cli/shortcuts" "github.com/spf13/cobra" ) @@ -95,38 +85,11 @@ func Execute() int { fmt.Fprintln(os.Stderr, "Error:", err) return 1 } - f := cmdutil.NewDefault(inv) - - globals := &GlobalOptions{Profile: inv.Profile} - rootCmd := &cobra.Command{ - Use: "lark-cli", - Short: "Lark/Feishu CLI — OAuth authorization, UAT management, API calls", - Long: rootLong, - Version: build.Version, - } - installTipsHelpFunc(rootCmd) - rootCmd.SilenceErrors = true - - RegisterGlobalFlags(rootCmd.PersistentFlags(), globals) - rootCmd.PersistentPreRun = func(cmd *cobra.Command, args []string) { - cmd.SilenceUsage = true - } - - rootCmd.AddCommand(cmdconfig.NewCmdConfig(f)) - rootCmd.AddCommand(auth.NewCmdAuth(f)) - rootCmd.AddCommand(profile.NewCmdProfile(f)) - rootCmd.AddCommand(doctor.NewCmdDoctor(f)) - rootCmd.AddCommand(api.NewCmdApi(f, nil)) - rootCmd.AddCommand(schema.NewCmdSchema(f, nil)) - rootCmd.AddCommand(completion.NewCmdCompletion(f)) - rootCmd.AddCommand(cmdupdate.NewCmdUpdate(f)) - service.RegisterServiceCommands(rootCmd, f) - shortcuts.RegisterShortcuts(rootCmd, f) - - // Prune commands incompatible with strict mode. - if mode := f.ResolveStrictMode(context.Background()); mode.IsActive() { - pruneForStrictMode(rootCmd, mode) - } + f, rootCmd := buildInternal( + context.Background(), inv, + WithIO(os.Stdin, os.Stdout, os.Stderr), + HideProfile(isSingleAppMode()), + ) // --- Update check (non-blocking) --- if !isCompletionCommand(os.Args) { @@ -277,10 +240,19 @@ func writeSecurityPolicyError(w io.Writer, spErr *internalauth.SecurityPolicyErr } // installTipsHelpFunc wraps the default help function to append a TIPS section -// when a command has tips set via cmdutil.SetTips. +// when a command has tips set via cmdutil.SetTips. It also force-shows global +// flags that are normally hidden in single-app mode (currently --profile) +// when rendering the root command's own help, so users discovering the CLI +// still see them at `lark-cli --help`. func installTipsHelpFunc(root *cobra.Command) { defaultHelp := root.HelpFunc() root.SetHelpFunc(func(cmd *cobra.Command, args []string) { + if cmd == root { + if f := root.PersistentFlags().Lookup("profile"); f != nil && f.Hidden { + f.Hidden = false + defer func() { f.Hidden = true }() + } + } defaultHelp(cmd, args) tips := cmdutil.GetTips(cmd) if len(tips) == 0 { diff --git a/cmd/root_integration_test.go b/cmd/root_integration_test.go index 555018738..e2521365d 100644 --- a/cmd/root_integration_test.go +++ b/cmd/root_integration_test.go @@ -135,10 +135,12 @@ func newStrictModeDefaultFactory(t *testing.T, profile string, mode core.StrictM t.Fatalf("SaveMultiAppConfig() error = %v", err) } - f := cmdutil.NewDefault(cmdutil.InvocationContext{Profile: profile}) stdout := &bytes.Buffer{} stderr := &bytes.Buffer{} - f.IOStreams = &cmdutil.IOStreams{In: nil, Out: stdout, ErrOut: stderr} + f := cmdutil.NewDefault( + cmdutil.NewIOStreams(&bytes.Buffer{}, stdout, stderr), + cmdutil.InvocationContext{Profile: profile}, + ) return f, stdout, stderr } diff --git a/cmd/schema/schema.go b/cmd/schema/schema.go index f45d6e816..152f37d24 100644 --- a/cmd/schema/schema.go +++ b/cmd/schema/schema.go @@ -4,12 +4,14 @@ package schema import ( + "context" "fmt" "io" "sort" "strings" "github.com/larksuite/cli/internal/cmdutil" + "github.com/larksuite/cli/internal/core" "github.com/larksuite/cli/internal/output" "github.com/larksuite/cli/internal/registry" "github.com/larksuite/cli/internal/util" @@ -19,6 +21,7 @@ import ( // SchemaOptions holds all inputs for the schema command. type SchemaOptions struct { Factory *cmdutil.Factory + Ctx context.Context // Positional args Path string @@ -41,7 +44,7 @@ func printServices(w io.Writer) { fmt.Fprintf(w, "\n%sUsage: lark-cli schema ..%s\n", output.Dim, output.Reset) } -func printResourceList(w io.Writer, spec map[string]interface{}) { +func printResourceList(w io.Writer, spec map[string]interface{}, mode core.StrictMode) { name := registry.GetStrFromMap(spec, "name") version := registry.GetStrFromMap(spec, "version") title := registry.GetStrFromMap(spec, "title") @@ -55,9 +58,13 @@ func printResourceList(w io.Writer, spec map[string]interface{}) { resources, _ := spec["resources"].(map[string]interface{}) for _, resName := range sortedKeys(resources) { - fmt.Fprintf(w, " %s%s%s\n", output.Cyan, resName, output.Reset) resMap, _ := resources[resName].(map[string]interface{}) methods, _ := resMap["methods"].(map[string]interface{}) + methods = filterMethodsByStrictMode(methods, mode) + if len(methods) == 0 { + continue + } + fmt.Fprintf(w, " %s%s%s\n", output.Cyan, resName, output.Reset) for _, methodName := range sortedKeys(methods) { m, _ := methods[methodName].(map[string]interface{}) httpMethod := registry.GetStrFromMap(m, "httpMethod") @@ -359,6 +366,7 @@ func NewCmdSchema(f *cmdutil.Factory, runF func(*SchemaOptions) error) *cobra.Co if len(args) > 0 { opts.Path = args[0] } + opts.Ctx = cmd.Context() if runF != nil { return runF(opts) } @@ -367,7 +375,7 @@ func NewCmdSchema(f *cmdutil.Factory, runF func(*SchemaOptions) error) *cobra.Co } cmdutil.DisableAuthCheck(cmd) - cmd.ValidArgsFunction = completeSchemaPath + cmd.ValidArgsFunction = completeSchemaPath(f) cmd.Flags().StringVar(&opts.Format, "format", "json", "output format: json (default) | pretty") _ = cmd.RegisterFlagCompletionFunc("format", func(_ *cobra.Command, _ []string, _ string) ([]string, cobra.ShellCompDirective) { return []string{"json", "pretty"}, cobra.ShellCompDirectiveNoFileComp @@ -379,78 +387,86 @@ func NewCmdSchema(f *cmdutil.Factory, runF func(*SchemaOptions) error) *cobra.Co // completeSchemaPath provides tab-completion for the schema path argument. // It handles dotted resource names (e.g. app.table.fields) by iterating all // resources and classifying each as a prefix-match or fully-matched. -func completeSchemaPath(_ *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { - if len(args) > 0 { - return nil, cobra.ShellCompDirectiveNoFileComp - } +func completeSchemaPath(f *cmdutil.Factory) func(*cobra.Command, []string, string) ([]string, cobra.ShellCompDirective) { + return func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { + if len(args) > 0 { + return nil, cobra.ShellCompDirectiveNoFileComp + } - parts := strings.Split(toComplete, ".") + parts := strings.Split(toComplete, ".") - // Level 1: complete service names - if len(parts) <= 1 { - var completions []string - for _, s := range registry.ListFromMetaProjects() { - if strings.HasPrefix(s, toComplete) { - completions = append(completions, s+".") + // Level 1: complete service names + if len(parts) <= 1 { + var completions []string + for _, s := range registry.ListFromMetaProjects() { + if strings.HasPrefix(s, toComplete) { + completions = append(completions, s+".") + } } + return completions, cobra.ShellCompDirectiveNoFileComp | cobra.ShellCompDirectiveNoSpace } - return completions, cobra.ShellCompDirectiveNoFileComp | cobra.ShellCompDirectiveNoSpace - } - serviceName := parts[0] - spec := registry.LoadFromMeta(serviceName) - if spec == nil { - return nil, cobra.ShellCompDirectiveNoFileComp - } - resources, _ := spec["resources"].(map[string]interface{}) - if resources == nil { - return nil, cobra.ShellCompDirectiveNoFileComp - } + serviceName := parts[0] + spec := registry.LoadFromMeta(serviceName) + if spec == nil { + return nil, cobra.ShellCompDirectiveNoFileComp + } + mode := f.ResolveStrictMode(cmd.Context()) + spec = filterSpecByStrictMode(spec, mode) + resources, _ := spec["resources"].(map[string]interface{}) + if resources == nil { + return nil, cobra.ShellCompDirectiveNoFileComp + } + + afterService := strings.Join(parts[1:], ".") + completions := completeSchemaPathForSpec(serviceName, resources, afterService) - // afterService = everything user typed after "serviceName." - afterService := strings.Join(parts[1:], ".") + allTrailingDot := len(completions) > 0 + for _, c := range completions { + if !strings.HasSuffix(c, ".") { + allTrailingDot = false + break + } + } + directive := cobra.ShellCompDirectiveNoFileComp + if allTrailingDot { + directive |= cobra.ShellCompDirectiveNoSpace + } + return completions, directive + } +} +func completeSchemaPathForSpec(serviceName string, resources map[string]interface{}, afterService string) []string { var completions []string for resName, resVal := range resources { if strings.HasPrefix(resName, afterService) { - // afterService is a prefix of this resource name → resource candidate completions = append(completions, serviceName+"."+resName+".") - } else if strings.HasPrefix(afterService, resName+".") { - // This resource is fully matched; remainder is method prefix - methodPrefix := afterService[len(resName)+1:] - resMap, _ := resVal.(map[string]interface{}) - if resMap == nil { - continue - } - methods, _ := resMap["methods"].(map[string]interface{}) - for methodName := range methods { - if strings.HasPrefix(methodName, methodPrefix) { - completions = append(completions, serviceName+"."+resName+"."+methodName) - } + continue + } + if !strings.HasPrefix(afterService, resName+".") { + continue + } + methodPrefix := afterService[len(resName)+1:] + resMap, _ := resVal.(map[string]interface{}) + if resMap == nil { + continue + } + methods, _ := resMap["methods"].(map[string]interface{}) + for methodName := range methods { + if strings.HasPrefix(methodName, methodPrefix) { + completions = append(completions, serviceName+"."+resName+"."+methodName) } } } sort.Strings(completions) - - // If all completions end with ".", user is still navigating resources → NoSpace - allTrailingDot := len(completions) > 0 - for _, c := range completions { - if !strings.HasSuffix(c, ".") { - allTrailingDot = false - break - } - } - directive := cobra.ShellCompDirectiveNoFileComp - if allTrailingDot { - directive |= cobra.ShellCompDirectiveNoSpace - } - return completions, directive + return completions } func schemaRun(opts *SchemaOptions) error { out := opts.Factory.IOStreams.Out + mode := opts.Factory.ResolveStrictMode(opts.Ctx) if opts.Path == "" { printServices(out) @@ -469,9 +485,9 @@ func schemaRun(opts *SchemaOptions) error { if len(parts) == 1 { if opts.Format == "pretty" { - printResourceList(out, spec) + printResourceList(out, spec, mode) } else { - output.PrintJson(out, spec) + output.PrintJson(out, filterSpecByStrictMode(spec, mode)) } return nil } @@ -492,6 +508,7 @@ func schemaRun(opts *SchemaOptions) error { if opts.Format == "pretty" { fmt.Fprintf(out, "%s%s.%s%s\n\n", output.Bold, serviceName, resName, output.Reset) methods, _ := resource["methods"].(map[string]interface{}) + methods = filterMethodsByStrictMode(methods, mode) for _, mName := range sortedKeys(methods) { m, _ := methods[mName].(map[string]interface{}) httpMethod := registry.GetStrFromMap(m, "httpMethod") @@ -500,13 +517,26 @@ func schemaRun(opts *SchemaOptions) error { } fmt.Fprintf(out, "\n%sUsage: lark-cli schema %s.%s.%s\n", output.Dim, serviceName, resName, output.Reset) } else { - output.PrintJson(out, resource) + // For JSON output, filter methods in a copy to avoid mutating the registry. + if mode.IsActive() { + filtered := make(map[string]interface{}) + for k, v := range resource { + filtered[k] = v + } + if methods, ok := resource["methods"].(map[string]interface{}); ok { + filtered["methods"] = filterMethodsByStrictMode(methods, mode) + } + output.PrintJson(out, filtered) + } else { + output.PrintJson(out, resource) + } } return nil } methodName := remaining[0] methods, _ := resource["methods"].(map[string]interface{}) + methods = filterMethodsByStrictMode(methods, mode) method, ok := methods[methodName].(map[string]interface{}) if !ok { var mNames []string @@ -525,3 +555,67 @@ func schemaRun(opts *SchemaOptions) error { } return nil } + +// filterSpecByStrictMode returns a shallow copy of spec with each resource's methods +// filtered by strict mode. Returns the original spec when strict mode is off. +func filterSpecByStrictMode(spec map[string]interface{}, mode core.StrictMode) map[string]interface{} { + if !mode.IsActive() { + return spec + } + result := make(map[string]interface{}, len(spec)) + for k, v := range spec { + result[k] = v + } + resources, _ := spec["resources"].(map[string]interface{}) + if resources == nil { + return result + } + filteredRes := make(map[string]interface{}, len(resources)) + for resName, resVal := range resources { + resMap, ok := resVal.(map[string]interface{}) + if !ok { + continue + } + methods, _ := resMap["methods"].(map[string]interface{}) + filtered := filterMethodsByStrictMode(methods, mode) + if len(filtered) == 0 { + continue + } + resCopy := make(map[string]interface{}, len(resMap)) + for k, v := range resMap { + resCopy[k] = v + } + resCopy["methods"] = filtered + filteredRes[resName] = resCopy + } + result["resources"] = filteredRes + return result +} + +// filterMethodsByStrictMode removes methods incompatible with the active strict mode. +// Returns the original map unmodified when strict mode is off. +func filterMethodsByStrictMode(methods map[string]interface{}, mode core.StrictMode) map[string]interface{} { + if !mode.IsActive() || methods == nil { + return methods + } + token := registry.IdentityToAccessToken(string(mode.ForcedIdentity())) + filtered := make(map[string]interface{}, len(methods)) + for name, val := range methods { + m, ok := val.(map[string]interface{}) + if !ok { + continue + } + tokens, _ := m["accessTokens"].([]interface{}) + if tokens == nil { + filtered[name] = val + continue + } + for _, t := range tokens { + if ts, ok := t.(string); ok && ts == token { + filtered[name] = val + break + } + } + } + return filtered +} diff --git a/cmd/schema/schema_test.go b/cmd/schema/schema_test.go index 639822acc..da4129302 100644 --- a/cmd/schema/schema_test.go +++ b/cmd/schema/schema_test.go @@ -182,3 +182,49 @@ func TestHasFileFields(t *testing.T) { }) } } + +func TestCompleteSchemaPathForSpec(t *testing.T) { + resources := map[string]interface{}{ + "records": map[string]interface{}{ + "methods": map[string]interface{}{ + "create": map[string]interface{}{}, + "list": map[string]interface{}{}, + }, + }, + "record_permissions": map[string]interface{}{ + "methods": map[string]interface{}{ + "get": map[string]interface{}{}, + }, + }, + } + + got := completeSchemaPathForSpec("base", resources, "records.cr") + if len(got) != 1 || got[0] != "base.records.create" { + t.Fatalf("completions = %v, want [base.records.create]", got) + } + + got = completeSchemaPathForSpec("base", resources, "record") + if len(got) != 2 || got[0] != "base.record_permissions." || got[1] != "base.records." { + t.Fatalf("resource completions = %v", got) + } +} + +func TestFilterSpecByStrictMode_RemovesIncompatibleMethodsFromCompletionSource(t *testing.T) { + spec := map[string]interface{}{ + "resources": map[string]interface{}{ + "records": map[string]interface{}{ + "methods": map[string]interface{}{ + "list": map[string]interface{}{"accessTokens": []interface{}{"tenant"}}, + "create": map[string]interface{}{"accessTokens": []interface{}{"user"}}, + }, + }, + }, + } + + filtered := filterSpecByStrictMode(spec, core.StrictModeBot) + resources, _ := filtered["resources"].(map[string]interface{}) + got := completeSchemaPathForSpec("base", resources, "records.") + if len(got) != 1 || got[0] != "base.records.list" { + t.Fatalf("filtered completions = %v, want [base.records.list]", got) + } +} diff --git a/cmd/service/service.go b/cmd/service/service.go index 63b6fc6b7..808c80077 100644 --- a/cmd/service/service.go +++ b/cmd/service/service.go @@ -24,6 +24,10 @@ import ( // RegisterServiceCommands registers all service commands from from_meta specs. func RegisterServiceCommands(parent *cobra.Command, f *cmdutil.Factory) { + RegisterServiceCommandsWithContext(context.Background(), parent, f) +} + +func RegisterServiceCommandsWithContext(ctx context.Context, parent *cobra.Command, f *cmdutil.Factory) { for _, project := range registry.ListFromMetaProjects() { spec := registry.LoadFromMeta(project) if spec == nil { @@ -38,11 +42,15 @@ func RegisterServiceCommands(parent *cobra.Command, f *cmdutil.Factory) { if resources == nil { continue } - registerService(parent, spec, resources, f) + registerServiceWithContext(ctx, parent, spec, resources, f) } } func registerService(parent *cobra.Command, spec map[string]interface{}, resources map[string]interface{}, f *cmdutil.Factory) { + registerServiceWithContext(context.Background(), parent, spec, resources, f) +} + +func registerServiceWithContext(ctx context.Context, parent *cobra.Command, spec map[string]interface{}, resources map[string]interface{}, f *cmdutil.Factory) { specName := registry.GetStrFromMap(spec, "name") specDesc := registry.GetServiceDescription(specName, "en") if specDesc == "" { @@ -70,11 +78,11 @@ func registerService(parent *cobra.Command, spec map[string]interface{}, resourc if resMap == nil { continue } - registerResource(svc, spec, resName, resMap, f) + registerResourceWithContext(ctx, svc, spec, resName, resMap, f) } } -func registerResource(parent *cobra.Command, spec map[string]interface{}, name string, resource map[string]interface{}, f *cmdutil.Factory) { +func registerResourceWithContext(ctx context.Context, parent *cobra.Command, spec map[string]interface{}, name string, resource map[string]interface{}, f *cmdutil.Factory) { res := &cobra.Command{ Use: name, Short: name + " operations", @@ -87,7 +95,7 @@ func registerResource(parent *cobra.Command, spec map[string]interface{}, name s if methodMap == nil { continue } - registerMethod(res, spec, methodMap, methodName, name, f) + registerMethodWithContext(ctx, res, spec, methodMap, methodName, name, f) } } @@ -120,12 +128,16 @@ func detectFileFields(method map[string]interface{}) []string { return cmdutil.DetectFileFields(method) } -func registerMethod(parent *cobra.Command, spec map[string]interface{}, method map[string]interface{}, name string, resName string, f *cmdutil.Factory) { - parent.AddCommand(NewCmdServiceMethod(f, spec, method, name, resName, nil)) +func registerMethodWithContext(ctx context.Context, parent *cobra.Command, spec map[string]interface{}, method map[string]interface{}, name string, resName string, f *cmdutil.Factory) { + parent.AddCommand(NewCmdServiceMethodWithContext(ctx, f, spec, method, name, resName, nil)) } // NewCmdServiceMethod creates a command for a dynamically registered service method. func NewCmdServiceMethod(f *cmdutil.Factory, spec, method map[string]interface{}, name, resName string, runF func(*ServiceMethodOptions) error) *cobra.Command { + return NewCmdServiceMethodWithContext(context.Background(), f, spec, method, name, resName, runF) +} + +func NewCmdServiceMethodWithContext(ctx context.Context, f *cmdutil.Factory, spec, method map[string]interface{}, name, resName string, runF func(*ServiceMethodOptions) error) *cobra.Command { desc := registry.GetStrFromMap(method, "description") httpMethod := registry.GetStrFromMap(method, "httpMethod") specName := registry.GetStrFromMap(spec, "name") @@ -159,7 +171,7 @@ func NewCmdServiceMethod(f *cmdutil.Factory, spec, method map[string]interface{} case "POST", "PUT", "PATCH", "DELETE": cmd.Flags().StringVar(&opts.Data, "data", "", "request body JSON (supports - for stdin)") } - cmd.Flags().StringVar(&asStr, "as", "auto", "identity type: user | bot | auto (default)") + cmdutil.AddAPIIdentityFlag(ctx, cmd, f, &asStr) cmd.Flags().StringVarP(&opts.Output, "output", "o", "", "output file path for binary responses") cmd.Flags().BoolVar(&opts.PageAll, "page-all", false, "automatically paginate through all pages") cmd.Flags().IntVar(&opts.PageLimit, "page-limit", 10, "max pages to fetch with --page-all (0 = unlimited)") @@ -177,10 +189,6 @@ func NewCmdServiceMethod(f *cmdutil.Factory, spec, method map[string]interface{} cmd.Flags().StringVar(&opts.File, "file", "", "file to upload ([field=]path, supports - for stdin)") } } - - _ = cmd.RegisterFlagCompletionFunc("as", func(_ *cobra.Command, _ []string, _ string) ([]string, cobra.ShellCompDirective) { - return []string{"user", "bot"}, cobra.ShellCompDirectiveNoFileComp - }) _ = cmd.RegisterFlagCompletionFunc("format", func(_ *cobra.Command, _ []string, _ string) ([]string, cobra.ShellCompDirective) { return []string{"json", "ndjson", "table", "csv"}, cobra.ShellCompDirectiveNoFileComp }) diff --git a/cmd/service/service_test.go b/cmd/service/service_test.go index f9adc5d47..3a2e5a7be 100644 --- a/cmd/service/service_test.go +++ b/cmd/service/service_test.go @@ -121,6 +121,24 @@ func TestRegisterService_MergesExistingCommand(t *testing.T) { } } +func TestNewCmdServiceMethod_StrictModeHidesAsFlag(t *testing.T) { + f, _, _, _ := cmdutil.TestFactory(t, &core.CliConfig{ + AppID: "test-app", AppSecret: "test-secret", Brand: core.BrandFeishu, SupportedIdentities: 2, + }) + + cmd := NewCmdServiceMethod(f, driveSpec(), driveMethod("GET", nil), "copy", "files", nil) + flag := cmd.Flags().Lookup("as") + if flag == nil { + t.Fatal("expected --as flag to be registered") + } + if !flag.Hidden { + t.Fatal("expected --as flag to be hidden in strict mode") + } + if got := flag.DefValue; got != "bot" { + t.Fatalf("default value = %q, want %q", got, "bot") + } +} + // ── NewCmdServiceMethod flags ── func TestNewCmdServiceMethod_GETHasNoDataFlag(t *testing.T) { diff --git a/extension/transport/errors.go b/extension/transport/errors.go new file mode 100644 index 000000000..9ebe907e7 --- /dev/null +++ b/extension/transport/errors.go @@ -0,0 +1,51 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package transport + +import ( + "errors" + "fmt" +) + +// ErrAborted is a sentinel matched by errors.Is on any extension-triggered +// round-trip abort. Callers that only need to know whether an error was +// caused by an extension interception should use: +// +// if errors.Is(err, transport.ErrAborted) { ... } +var ErrAborted = errors.New("round trip aborted by extension") + +// AbortError is returned by the built-in middleware when an AbortableInterceptor +// short-circuits a request via PreRoundTripE. It wraps the extension's original +// reason and carries the extension's Provider.Name() for traceability. +// +// Use errors.As to recover the typed error: +// +// var aErr *transport.AbortError +// if errors.As(err, &aErr) { +// log.Printf("blocked by %s: %v", aErr.Extension, aErr.Reason) +// } +// +// errors.Is(err, transport.ErrAborted) also works, and errors.Is against the +// inner reason still works via Unwrap. +type AbortError struct { + // Extension is the name of the Provider whose interceptor aborted the + // request (from Provider.Name()). May be empty if the provider did not + // supply a name. + Extension string + // Reason is the original non-nil error returned by PreRoundTripE. + Reason error +} + +func (e *AbortError) Error() string { + if e.Extension != "" { + return fmt.Sprintf("extension %q aborted round trip: %v", e.Extension, e.Reason) + } + return fmt.Sprintf("extension aborted round trip: %v", e.Reason) +} + +// Unwrap lets errors.Is / errors.As traverse to the underlying Reason. +func (e *AbortError) Unwrap() error { return e.Reason } + +// Is enables errors.Is(err, ErrAborted) at any nesting depth. +func (e *AbortError) Is(target error) bool { return target == ErrAborted } diff --git a/extension/transport/errors_test.go b/extension/transport/errors_test.go new file mode 100644 index 000000000..31932e45f --- /dev/null +++ b/extension/transport/errors_test.go @@ -0,0 +1,103 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package transport + +import ( + "errors" + "fmt" + "testing" +) + +func TestAbortError_Error(t *testing.T) { + tests := []struct { + name string + err *AbortError + want string + }{ + { + name: "with extension name", + err: &AbortError{Extension: "audit", Reason: errors.New("bad")}, + want: `extension "audit" aborted round trip: bad`, + }, + { + name: "without extension name", + err: &AbortError{Reason: errors.New("bad")}, + want: "extension aborted round trip: bad", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.err.Error(); got != tt.want { + t.Fatalf("Error() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestAbortError_Unwrap(t *testing.T) { + reason := errors.New("bad") + e := &AbortError{Reason: reason} + if got := e.Unwrap(); got != reason { + t.Fatalf("Unwrap() = %v, want %v", got, reason) + } +} + +func TestAbortError_IsErrAborted(t *testing.T) { + e := &AbortError{Reason: errors.New("bad")} + if !errors.Is(e, ErrAborted) { + t.Fatal("errors.Is(e, ErrAborted) = false, want true") + } + // Sanity: not matched by unrelated sentinels. + if errors.Is(e, errors.New("other")) { + t.Fatal("errors.Is matched unrelated sentinel") + } +} + +func TestAbortError_UnwrapReachesInnerSentinel(t *testing.T) { + // Extensions often return typed/sentinel errors; callers should still be + // able to errors.Is against those after the middleware wraps them. + innerSentinel := errors.New("policy-deny-42") + e := &AbortError{Reason: fmt.Errorf("wrapped: %w", innerSentinel)} + if !errors.Is(e, innerSentinel) { + t.Fatal("errors.Is(e, innerSentinel) = false, want true (Unwrap chain broken)") + } +} + +func TestAbortError_As(t *testing.T) { + reason := errors.New("bad") + base := &AbortError{Extension: "audit", Reason: reason} + + // Direct As. + var aErr *AbortError + if !errors.As(base, &aErr) { + t.Fatal("errors.As(base, *AbortError) = false") + } + if aErr.Extension != "audit" || aErr.Reason != reason { + t.Fatalf("aErr = %+v, want {audit, bad}", aErr) + } + + // Nested As: even when the *AbortError is wrapped in another error, + // errors.As must still find it via Unwrap chain. + wrapped := fmt.Errorf("outer: %w", base) + var aErr2 *AbortError + if !errors.As(wrapped, &aErr2) { + t.Fatal("errors.As(wrapped, *AbortError) = false") + } + if aErr2 != base { + t.Fatalf("aErr2 = %p, want %p", aErr2, base) + } + + // errors.Is still matches the sentinel through the outer wrapper. + if !errors.Is(wrapped, ErrAborted) { + t.Fatal("errors.Is(wrapped, ErrAborted) = false via nested wrap") + } +} + +func TestErrAborted_IsItselfSentinel(t *testing.T) { + // Guard against accidental re-assignment of ErrAborted: a bare ErrAborted + // value should still satisfy errors.Is(err, ErrAborted) for symmetry. + if !errors.Is(ErrAborted, ErrAborted) { + t.Fatal("errors.Is(ErrAborted, ErrAborted) = false") + } +} diff --git a/extension/transport/types.go b/extension/transport/types.go index e60c4018d..c74e36866 100644 --- a/extension/transport/types.go +++ b/extension/transport/types.go @@ -27,6 +27,31 @@ type Provider interface { // // The returned function (if non-nil) is called after the built-in chain // completes. Use it for logging, ending trace spans, or recording metrics. +// +// Body note: the middleware Clones the caller's request before invoking the +// interceptor, which copies headers/URL/etc. but shares the underlying +// io.ReadCloser. Extensions that read req.Body are responsible for restoring +// a replayable body (e.g. via req.GetBody) before returning, otherwise the +// built-in chain will see an exhausted stream. type Interceptor interface { PreRoundTrip(req *http.Request) func(resp *http.Response, err error) } + +// AbortableInterceptor is an optional extension of Interceptor that lets an +// extension reject a request before the built-in chain runs. Extensions that +// implement this interface are detected by the built-in middleware via a +// type assertion; both methods must be present, but when an extension +// implements PreRoundTripE the middleware will NOT call PreRoundTrip. +// +// Returning a non-nil error from PreRoundTripE aborts the request: the +// built-in chain is not executed and the middleware returns an *AbortError +// wrapping the reason. The returned post function (if non-nil) is still +// invoked with (nil, reason) so that extensions can unwind any state they +// created in the pre hook (spans, metrics, audit records). +// +// Extensions that only care about the abortable variant can provide a no-op +// PreRoundTrip method alongside PreRoundTripE to satisfy Interceptor. +type AbortableInterceptor interface { + Interceptor + PreRoundTripE(req *http.Request) (post func(resp *http.Response, err error), err error) +} diff --git a/internal/cmdutil/factory_default.go b/internal/cmdutil/factory_default.go index c9b4e92cf..c1dc10817 100644 --- a/internal/cmdutil/factory_default.go +++ b/internal/cmdutil/factory_default.go @@ -8,13 +8,11 @@ import ( "fmt" "io" "net/http" - "os" "sync" "time" lark "github.com/larksuite/oapi-sdk-go/v3" larkcore "github.com/larksuite/oapi-sdk-go/v3/core" - "golang.org/x/term" extcred "github.com/larksuite/cli/extension/credential" "github.com/larksuite/cli/extension/fileio" @@ -34,27 +32,24 @@ import ( // Phase 2: Credential (sole data source for account info) // Phase 3: Config derived from Credential // Phase 4: LarkClient derived from Credential -func NewDefault(inv InvocationContext) *Factory { +func NewDefault(streams *IOStreams, inv InvocationContext) *Factory { + streams = normalizeStreams(streams) f := &Factory{ Keychain: keychain.Default(), Invocation: inv, - } - f.IOStreams = &IOStreams{ - In: os.Stdin, - Out: os.Stdout, - ErrOut: os.Stderr, - IsTerminal: term.IsTerminal(int(os.Stdin.Fd())), + IOStreams: streams, } // Phase 0: FileIO provider (no dependency) f.FileIOProvider = fileio.GetProvider() // Phase 1: HttpClient (no credential dependency) - f.HttpClient = cachedHttpClientFunc() + f.HttpClient = cachedHttpClientFunc(f) // Phase 2: Credential (sole data source) + // Keychain is read via closure so callers can replace f.Keychain after construction. f.Credential = buildCredentialProvider(credentialDeps{ - Keychain: f.Keychain, + Keychain: func() keychain.KeychainAccess { return f.Keychain }, Profile: inv.Profile, HttpClient: f.HttpClient, ErrOut: f.IOStreams.ErrOut, @@ -93,11 +88,11 @@ func safeRedirectPolicy(req *http.Request, via []*http.Request) error { return nil } -func cachedHttpClientFunc() func() (*http.Client, error) { +func cachedHttpClientFunc(f *Factory) func() (*http.Client, error) { return sync.OnceValues(func() (*http.Client, error) { - util.WarnIfProxied(os.Stderr) + util.WarnIfProxied(f.IOStreams.ErrOut) - var transport http.RoundTripper = util.NewBaseTransport() + var transport http.RoundTripper = util.SharedTransport() transport = &RetryTransport{Base: transport} transport = &SecurityHeaderTransport{Base: transport} transport = &auth.SecurityPolicyTransport{Base: transport} // Add our global response interceptor @@ -122,7 +117,7 @@ func cachedLarkClientFunc(f *Factory) func() (*lark.Client, error) { lark.WithLogLevel(larkcore.LogLevelError), lark.WithHeaders(BaseSecurityHeaders()), } - util.WarnIfProxied(os.Stderr) + util.WarnIfProxied(f.IOStreams.ErrOut) opts = append(opts, lark.WithHttpClient(&http.Client{ Transport: buildSDKTransport(), CheckRedirect: safeRedirectPolicy, @@ -134,7 +129,7 @@ func cachedLarkClientFunc(f *Factory) func() (*lark.Client, error) { } func buildSDKTransport() http.RoundTripper { - var sdkTransport http.RoundTripper = util.NewBaseTransport() + var sdkTransport http.RoundTripper = util.SharedTransport() sdkTransport = &RetryTransport{Base: sdkTransport} sdkTransport = &UserAgentTransport{Base: sdkTransport} sdkTransport = &auth.SecurityPolicyTransport{Base: sdkTransport} @@ -142,7 +137,7 @@ func buildSDKTransport() http.RoundTripper { } type credentialDeps struct { - Keychain keychain.KeychainAccess + Keychain func() keychain.KeychainAccess Profile string HttpClient func() (*http.Client, error) ErrOut io.Writer diff --git a/internal/cmdutil/factory_default_test.go b/internal/cmdutil/factory_default_test.go index fe91ead77..9ec8e7ec3 100644 --- a/internal/cmdutil/factory_default_test.go +++ b/internal/cmdutil/factory_default_test.go @@ -6,14 +6,10 @@ package cmdutil import ( "context" "errors" - "net/http" - "net/http/httptest" "testing" _ "github.com/larksuite/cli/extension/credential/env" "github.com/larksuite/cli/extension/fileio" - exttransport "github.com/larksuite/cli/extension/transport" - internalauth "github.com/larksuite/cli/internal/auth" "github.com/larksuite/cli/internal/core" "github.com/larksuite/cli/internal/credential" "github.com/larksuite/cli/internal/envvars" @@ -63,7 +59,7 @@ func TestNewDefault_InvocationProfileUsedByStrictModeAndConfig(t *testing.T) { t.Fatalf("SaveMultiAppConfig() error = %v", err) } - f := NewDefault(InvocationContext{Profile: "target"}) + f := NewDefault(nil, InvocationContext{Profile: "target"}) if got := f.ResolveStrictMode(context.Background()); got != core.StrictModeBot { t.Fatalf("ResolveStrictMode() = %q, want %q", got, core.StrictModeBot) } @@ -103,7 +99,7 @@ func TestNewDefault_InvocationProfileMissingSticksAcrossEarlyStrictMode(t *testi t.Fatalf("SaveMultiAppConfig() error = %v", err) } - f := NewDefault(InvocationContext{Profile: "missing"}) + f := NewDefault(nil, InvocationContext{Profile: "missing"}) if got := f.ResolveStrictMode(context.Background()); got != core.StrictModeOff { t.Fatalf("ResolveStrictMode() = %q, want %q", got, core.StrictModeOff) } @@ -120,22 +116,6 @@ func TestNewDefault_InvocationProfileMissingSticksAcrossEarlyStrictMode(t *testi } } -func TestBuildSDKTransport_IncludesRetryTransport(t *testing.T) { - transport := buildSDKTransport() - - sec, ok := transport.(*internalauth.SecurityPolicyTransport) - if !ok { - t.Fatalf("outer transport type = %T, want *auth.SecurityPolicyTransport", transport) - } - ua, ok := sec.Base.(*UserAgentTransport) - if !ok { - t.Fatalf("middle transport type = %T, want *UserAgentTransport", sec.Base) - } - if _, ok := ua.Base.(*RetryTransport); !ok { - t.Fatalf("inner transport type = %T, want *RetryTransport", ua.Base) - } -} - func TestNewDefault_ResolveAs_UsesDefaultAsFromEnvAccount(t *testing.T) { t.Setenv(envvars.CliAppID, "env-app") t.Setenv(envvars.CliAppSecret, "env-secret") @@ -144,7 +124,7 @@ func TestNewDefault_ResolveAs_UsesDefaultAsFromEnvAccount(t *testing.T) { t.Setenv(envvars.CliTenantAccessToken, "") t.Setenv("LARKSUITE_CLI_CONFIG_DIR", t.TempDir()) - f := NewDefault(InvocationContext{}) + f := NewDefault(nil, InvocationContext{}) cmd := newCmdWithAsFlag("auto", false) got := f.ResolveAs(context.Background(), cmd, "auto") @@ -164,7 +144,7 @@ func TestNewDefault_ConfigReturnsCliConfigCopyOfCredentialAccount(t *testing.T) t.Setenv(envvars.CliTenantAccessToken, "") t.Setenv("LARKSUITE_CLI_CONFIG_DIR", t.TempDir()) - f := NewDefault(InvocationContext{}) + f := NewDefault(nil, InvocationContext{}) acct, err := f.Credential.ResolveAccount(context.Background()) if err != nil { @@ -189,7 +169,7 @@ func TestNewDefault_ConfigUsesRuntimePlaceholderForTokenOnlyEnvAccount(t *testin t.Setenv(envvars.CliTenantAccessToken, "") t.Setenv("LARKSUITE_CLI_CONFIG_DIR", t.TempDir()) - f := NewDefault(InvocationContext{}) + f := NewDefault(nil, InvocationContext{}) acct, err := f.Credential.ResolveAccount(context.Background()) if err != nil { @@ -217,7 +197,7 @@ func TestNewDefault_FileIOProviderDoesNotResolveDuringInitialization(t *testing. fileio.Register(provider) t.Cleanup(func() { fileio.Register(prev) }) - f := NewDefault(InvocationContext{}) + f := NewDefault(nil, InvocationContext{}) if f.FileIOProvider != provider { t.Fatalf("NewDefault() provider = %T, want %T", f.FileIOProvider, provider) } @@ -232,170 +212,3 @@ func TestNewDefault_FileIOProviderDoesNotResolveDuringInitialization(t *testing. t.Fatalf("ResolveFileIO() calls after explicit resolve = %d, want 1", provider.resolveCalls) } } - -type stubTransportProvider struct { - interceptor exttransport.Interceptor -} - -func (s *stubTransportProvider) Name() string { return "stub" } -func (s *stubTransportProvider) ResolveInterceptor(context.Context) exttransport.Interceptor { - if s.interceptor != nil { - return s.interceptor - } - return &stubTransportImpl{} -} - -type stubTransportImpl struct{} - -func (s *stubTransportImpl) PreRoundTrip(req *http.Request) func(*http.Response, error) { - return nil -} - -// headerCapturingInterceptor sets custom headers in PreRoundTrip and records -// whether PostRoundTrip was called, to verify execution order. -type headerCapturingInterceptor struct { - preCalled bool - postCalled bool -} - -func (h *headerCapturingInterceptor) PreRoundTrip(req *http.Request) func(*http.Response, error) { - h.preCalled = true - // Set a custom header that should survive (no built-in override) - req.Header.Set("X-Custom-Trace", "ext-trace-123") - // Try to override a security header — should be overwritten by SecurityHeaderTransport - req.Header.Set(HeaderSource, "ext-tampered") - return func(resp *http.Response, err error) { - h.postCalled = true - } -} - -func TestExtensionInterceptor_ExecutionOrder(t *testing.T) { - var receivedHeaders http.Header - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - receivedHeaders = r.Header.Clone() - w.WriteHeader(http.StatusOK) - })) - defer srv.Close() - - ic := &headerCapturingInterceptor{} - exttransport.Register(&stubTransportProvider{interceptor: ic}) - t.Cleanup(func() { exttransport.Register(nil) }) - - // Use HTTP transport chain (has SecurityHeaderTransport) - var base http.RoundTripper = http.DefaultTransport - base = &RetryTransport{Base: base} - base = &SecurityHeaderTransport{Base: base} - transport := wrapWithExtension(base) - client := &http.Client{Transport: transport} - - req, _ := http.NewRequest("GET", srv.URL, nil) - resp, err := client.Do(req) - if err != nil { - t.Fatalf("request failed: %v", err) - } - resp.Body.Close() - - // PreRoundTrip was called - if !ic.preCalled { - t.Fatal("PreRoundTrip was not called") - } - // PostRoundTrip (closure) was called - if !ic.postCalled { - t.Fatal("PostRoundTrip closure was not called") - } - // Custom header set by extension survives (no built-in override) - if got := receivedHeaders.Get("X-Custom-Trace"); got != "ext-trace-123" { - t.Fatalf("X-Custom-Trace = %q, want %q", got, "ext-trace-123") - } - // Security header overridden by extension is restored by SecurityHeaderTransport - if got := receivedHeaders.Get(HeaderSource); got != SourceValue { - t.Fatalf("%s = %q, want %q (built-in should override extension)", HeaderSource, got, SourceValue) - } -} - -func TestExtensionInterceptor_ContextTamperPrevented(t *testing.T) { - type ctxKeyType string - const testKey ctxKeyType = "original" - - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - })) - defer srv.Close() - - var ctxValue any - - // Use a custom transport that captures the context value seen by the built-in chain - capturer := roundTripFunc(func(req *http.Request) (*http.Response, error) { - ctxValue = req.Context().Value(testKey) - return http.DefaultTransport.RoundTrip(req) - }) - - // Interceptor that tries to tamper with context - tamperIC := interceptorFunc(func(req *http.Request) func(*http.Response, error) { - // Try to replace context with a new one - *req = *req.WithContext(context.WithValue(req.Context(), testKey, "tampered")) - return nil - }) - - mid := &extensionMiddleware{Base: capturer, Ext: tamperIC} - - origCtx := context.WithValue(context.Background(), testKey, "original") - req, _ := http.NewRequestWithContext(origCtx, "GET", srv.URL, nil) - resp, err := mid.RoundTrip(req) - if err != nil { - t.Fatalf("request failed: %v", err) - } - resp.Body.Close() - - // Built-in chain should see original context, not tampered - if ctxValue != "original" { - t.Fatalf("built-in chain saw context value %q, want %q", ctxValue, "original") - } -} - -// interceptorFunc adapts a function to exttransport.Interceptor. -type interceptorFunc func(*http.Request) func(*http.Response, error) - -func (f interceptorFunc) PreRoundTrip(req *http.Request) func(*http.Response, error) { return f(req) } - -func TestBuildSDKTransport_WithExtension(t *testing.T) { - exttransport.Register(&stubTransportProvider{}) - t.Cleanup(func() { exttransport.Register(nil) }) - - transport := buildSDKTransport() - - // Chain: extensionMiddleware → SecurityPolicy → UserAgent → Retry → Base - mid, ok := transport.(*extensionMiddleware) - if !ok { - t.Fatalf("outer transport type = %T, want *extensionMiddleware", transport) - } - sec, ok := mid.Base.(*internalauth.SecurityPolicyTransport) - if !ok { - t.Fatalf("transport type = %T, want *auth.SecurityPolicyTransport", mid.Base) - } - ua, ok := sec.Base.(*UserAgentTransport) - if !ok { - t.Fatalf("transport type = %T, want *UserAgentTransport", sec.Base) - } - if _, ok := ua.Base.(*RetryTransport); !ok { - t.Fatalf("innermost transport type = %T, want *RetryTransport", ua.Base) - } -} - -func TestBuildSDKTransport_WithoutExtension(t *testing.T) { - exttransport.Register(nil) - - transport := buildSDKTransport() - - sec, ok := transport.(*internalauth.SecurityPolicyTransport) - if !ok { - t.Fatalf("outer transport type = %T, want *auth.SecurityPolicyTransport", transport) - } - ua, ok := sec.Base.(*UserAgentTransport) - if !ok { - t.Fatalf("middle transport type = %T, want *UserAgentTransport", sec.Base) - } - if _, ok := ua.Base.(*RetryTransport); !ok { - t.Fatalf("inner transport type = %T, want *RetryTransport", ua.Base) - } -} diff --git a/internal/cmdutil/factory_http_test.go b/internal/cmdutil/factory_http_test.go index c27e9e696..a2b50e823 100644 --- a/internal/cmdutil/factory_http_test.go +++ b/internal/cmdutil/factory_http_test.go @@ -4,11 +4,12 @@ package cmdutil import ( + "io" "testing" ) func TestCachedHttpClientFunc_ReturnsSameInstance(t *testing.T) { - fn := cachedHttpClientFunc() + fn := cachedHttpClientFunc(&Factory{IOStreams: &IOStreams{ErrOut: io.Discard}}) c1, err := fn() if err != nil { @@ -28,7 +29,7 @@ func TestCachedHttpClientFunc_ReturnsSameInstance(t *testing.T) { } func TestCachedHttpClientFunc_HasTimeout(t *testing.T) { - fn := cachedHttpClientFunc() + fn := cachedHttpClientFunc(&Factory{IOStreams: &IOStreams{ErrOut: io.Discard}}) c, _ := fn() if c.Timeout == 0 { t.Error("expected non-zero timeout") @@ -36,7 +37,7 @@ func TestCachedHttpClientFunc_HasTimeout(t *testing.T) { } func TestCachedHttpClientFunc_HasRedirectPolicy(t *testing.T) { - fn := cachedHttpClientFunc() + fn := cachedHttpClientFunc(&Factory{IOStreams: &IOStreams{ErrOut: io.Discard}}) c, _ := fn() if c.CheckRedirect == nil { t.Error("expected CheckRedirect to be set (safeRedirectPolicy)") diff --git a/internal/cmdutil/identity_flag.go b/internal/cmdutil/identity_flag.go new file mode 100644 index 000000000..c99d5c628 --- /dev/null +++ b/internal/cmdutil/identity_flag.go @@ -0,0 +1,68 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package cmdutil + +import ( + "context" + "fmt" + "strings" + + "github.com/spf13/cobra" +) + +// AddAPIIdentityFlag registers the standard --as flag shape used by api/service commands. +func AddAPIIdentityFlag(ctx context.Context, cmd *cobra.Command, f *Factory, target *string) { + addIdentityFlag(ctx, cmd, f, target, identityFlagConfig{ + defaultValue: "auto", + usage: "identity type: user | bot | auto (default)", + completionValues: []string{"user", "bot"}, + }) +} + +// AddShortcutIdentityFlag registers the standard --as flag shape used by shortcuts. +func AddShortcutIdentityFlag(ctx context.Context, cmd *cobra.Command, f *Factory, authTypes []string) { + if len(authTypes) == 0 { + authTypes = []string{"user"} + } + addIdentityFlag(ctx, cmd, f, nil, identityFlagConfig{ + defaultValue: authTypes[0], + usage: "identity type: " + strings.Join(authTypes, " | "), + completionValues: authTypes, + }) +} + +type identityFlagConfig struct { + defaultValue string + usage string + completionValues []string +} + +// addIdentityFlag centralizes --as registration and strict-mode UX. +// When strict mode is active, the flag is still accepted for compatibility +// but hidden from help/completion and locked to the forced identity by default. +func addIdentityFlag(ctx context.Context, cmd *cobra.Command, f *Factory, target *string, cfg identityFlagConfig) { + if forced := f.ResolveStrictMode(ctx).ForcedIdentity(); forced != "" { + // Keep registering --as in strict mode even though it is hidden. + // This preserves parser compatibility for existing invocations that still pass + // --as, and keeps downstream GetString("as") / ResolveAs paths stable. + // The usage text below is effectively placeholder text because the flag is hidden. + registerIdentityFlag(cmd, target, string(forced), + fmt.Sprintf("identity locked to %s by strict mode (admin-managed)", forced)) + _ = cmd.Flags().MarkHidden("as") + return + } + + registerIdentityFlag(cmd, target, cfg.defaultValue, cfg.usage) + _ = cmd.RegisterFlagCompletionFunc("as", func(_ *cobra.Command, _ []string, _ string) ([]string, cobra.ShellCompDirective) { + return cfg.completionValues, cobra.ShellCompDirectiveNoFileComp + }) +} + +func registerIdentityFlag(cmd *cobra.Command, target *string, defaultValue, usage string) { + if target != nil { + cmd.Flags().StringVar(target, "as", defaultValue, usage) + return + } + cmd.Flags().String("as", defaultValue, usage) +} diff --git a/internal/cmdutil/identity_flag_test.go b/internal/cmdutil/identity_flag_test.go new file mode 100644 index 000000000..fa93d7263 --- /dev/null +++ b/internal/cmdutil/identity_flag_test.go @@ -0,0 +1,68 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package cmdutil + +import ( + "context" + "testing" + + "github.com/larksuite/cli/internal/core" + "github.com/spf13/cobra" +) + +func TestAddAPIIdentityFlag_NonStrictMode(t *testing.T) { + f, _, _, _ := TestFactory(t, &core.CliConfig{AppID: "a", AppSecret: "s"}) + cmd := &cobra.Command{Use: "test"} + + AddAPIIdentityFlag(context.Background(), cmd, f, nil) + + flag := cmd.Flags().Lookup("as") + if flag == nil { + t.Fatal("expected --as flag to be registered") + } + if flag.Hidden { + t.Fatal("expected --as flag to be visible outside strict mode") + } + if got := flag.DefValue; got != "auto" { + t.Fatalf("default value = %q, want %q", got, "auto") + } +} + +func TestAddAPIIdentityFlag_StrictModeHidesFlagAndLocksDefault(t *testing.T) { + f, _, _, _ := TestFactory(t, &core.CliConfig{ + AppID: "a", AppSecret: "s", SupportedIdentities: 2, + }) + cmd := &cobra.Command{Use: "test"} + + AddAPIIdentityFlag(context.Background(), cmd, f, nil) + + flag := cmd.Flags().Lookup("as") + if flag == nil { + t.Fatal("expected --as flag to be registered") + } + if !flag.Hidden { + t.Fatal("expected --as flag to be hidden in strict mode") + } + if got := flag.DefValue; got != "bot" { + t.Fatalf("default value = %q, want %q", got, "bot") + } +} + +func TestAddShortcutIdentityFlag_UsesAuthTypes(t *testing.T) { + f, _, _, _ := TestFactory(t, &core.CliConfig{AppID: "a", AppSecret: "s"}) + cmd := &cobra.Command{Use: "test"} + + AddShortcutIdentityFlag(context.Background(), cmd, f, []string{"bot"}) + + flag := cmd.Flags().Lookup("as") + if flag == nil { + t.Fatal("expected --as flag to be registered") + } + if flag.Hidden { + t.Fatal("expected --as flag to be visible outside strict mode") + } + if got := flag.DefValue; got != "bot" { + t.Fatalf("default value = %q, want %q", got, "bot") + } +} diff --git a/internal/cmdutil/iostreams.go b/internal/cmdutil/iostreams.go index 76068e057..864300f52 100644 --- a/internal/cmdutil/iostreams.go +++ b/internal/cmdutil/iostreams.go @@ -3,7 +3,12 @@ package cmdutil -import "io" +import ( + "io" + "os" + + "golang.org/x/term" +) // IOStreams provides the standard input/output/error streams. // Commands should use these instead of os.Stdin/Stdout/Stderr @@ -14,3 +19,45 @@ type IOStreams struct { ErrOut io.Writer IsTerminal bool } + +// NewIOStreams builds an IOStreams from arbitrary readers/writers. +// IsTerminal is derived from in's underlying *os.File, if any; non-file +// readers (bytes.Buffer, strings.Reader, …) yield IsTerminal=false. +func NewIOStreams(in io.Reader, out, errOut io.Writer) *IOStreams { + isTerminal := false + if f, ok := in.(*os.File); ok { + isTerminal = term.IsTerminal(int(f.Fd())) + } + return &IOStreams{In: in, Out: out, ErrOut: errOut, IsTerminal: isTerminal} +} + +// SystemIO creates an IOStreams wired to the process's standard file descriptors. +// +//nolint:forbidigo // entry point for real stdio +func SystemIO() *IOStreams { + return NewIOStreams(os.Stdin, os.Stdout, os.Stderr) +} + +// normalizeStreams returns a fresh IOStreams with any nil field filled from +// SystemIO(). Callers constructing a partial struct like &IOStreams{Out: buf} +// get a usable result without nil writers leaking into RoundTripper warnings, +// Cobra I/O, or credential-provider error paths. +func normalizeStreams(s *IOStreams) *IOStreams { + if s == nil { + return SystemIO() + } + out := *s + if out.In == nil || out.Out == nil || out.ErrOut == nil { + sys := SystemIO() + if out.In == nil { + out.In = sys.In + } + if out.Out == nil { + out.Out = sys.Out + } + if out.ErrOut == nil { + out.ErrOut = sys.ErrOut + } + } + return &out +} diff --git a/internal/cmdutil/retry_transport_test.go b/internal/cmdutil/retry_transport_test.go deleted file mode 100644 index a10782acf..000000000 --- a/internal/cmdutil/retry_transport_test.go +++ /dev/null @@ -1,81 +0,0 @@ -// Copyright (c) 2026 Lark Technologies Pte. Ltd. -// SPDX-License-Identifier: MIT - -package cmdutil - -import ( - "io" - "net/http" - "strings" - "testing" - "time" -) - -type roundTripFunc func(*http.Request) (*http.Response, error) - -func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { - return f(req) -} - -func TestRetryTransport_NoRetry(t *testing.T) { - calls := 0 - base := roundTripFunc(func(req *http.Request) (*http.Response, error) { - calls++ - return &http.Response{StatusCode: 200, Body: io.NopCloser(strings.NewReader("ok"))}, nil - }) - rt := &RetryTransport{Base: base, MaxRetries: 0} - req, _ := http.NewRequest("GET", "http://example.com/test", nil) - resp, err := rt.RoundTrip(req) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if resp.StatusCode != 200 { - t.Errorf("expected 200, got %d", resp.StatusCode) - } - if calls != 1 { - t.Errorf("expected 1 call, got %d", calls) - } -} - -func TestRetryTransport_RetryOn500(t *testing.T) { - calls := 0 - base := roundTripFunc(func(req *http.Request) (*http.Response, error) { - calls++ - if calls < 3 { - return &http.Response{StatusCode: 500, Body: io.NopCloser(strings.NewReader("error"))}, nil - } - return &http.Response{StatusCode: 200, Body: io.NopCloser(strings.NewReader("ok"))}, nil - }) - rt := &RetryTransport{Base: base, MaxRetries: 3, Delay: 1 * time.Millisecond} - req, _ := http.NewRequest("GET", "http://example.com/test", nil) - resp, err := rt.RoundTrip(req) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if resp.StatusCode != 200 { - t.Errorf("expected 200 after retries, got %d", resp.StatusCode) - } - if calls != 3 { - t.Errorf("expected 3 calls, got %d", calls) - } -} - -func TestRetryTransport_DefaultNoRetry(t *testing.T) { - calls := 0 - base := roundTripFunc(func(req *http.Request) (*http.Response, error) { - calls++ - return &http.Response{StatusCode: 500, Body: io.NopCloser(strings.NewReader("error"))}, nil - }) - rt := &RetryTransport{Base: base} // default MaxRetries=0 - req, _ := http.NewRequest("GET", "http://example.com/test", nil) - resp, err := rt.RoundTrip(req) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if resp.StatusCode != 500 { - t.Errorf("expected 500 with no retries, got %d", resp.StatusCode) - } - if calls != 1 { - t.Errorf("expected 1 call with default config, got %d", calls) - } -} diff --git a/internal/cmdutil/transport.go b/internal/cmdutil/transport.go index b922adf11..627755ea6 100644 --- a/internal/cmdutil/transport.go +++ b/internal/cmdutil/transport.go @@ -104,20 +104,47 @@ func (t *SecurityHeaderTransport) RoundTrip(req *http.Request) (*http.Response, } // extensionMiddleware wraps the built-in transport chain with pre/post hooks. -// The built-in chain always executes and cannot be skipped or overridden. -// The original request context is restored after PreRoundTrip to prevent +// The built-in chain always executes unless the extension is an +// exttransport.AbortableInterceptor and its PreRoundTripE returns a non-nil +// error; it cannot otherwise be skipped or overridden. +// +// The original request context is restored after the pre hook to prevent // extensions from tampering with cancellation, deadlines, or built-in values. +// Cloning the request isolates header/URL/etc. mutations from the caller's +// request object; req.Body is intentionally shared — extensions that consume +// it are responsible for rewinding (see Interceptor doc). type extensionMiddleware struct { - Base http.RoundTripper - Ext exttransport.Interceptor + Base http.RoundTripper + Ext exttransport.Interceptor + ExtName string // Provider.Name(), captured at wrap time for *AbortError.Extension } -// RoundTrip calls PreRoundTrip, restores the original context, executes -// the built-in chain, then calls the post hook if non-nil. +// RoundTrip invokes the interceptor pre hook, restores the original context, +// executes the built-in chain (unless aborted), then calls the post hook if +// non-nil. When the extension implements AbortableInterceptor and returns a +// non-nil error from PreRoundTripE, the built-in chain is skipped and an +// *exttransport.AbortError is returned; the post hook is still invoked with +// (nil, reason) so extensions can unwind resources. func (m *extensionMiddleware) RoundTrip(req *http.Request) (*http.Response, error) { origCtx := req.Context() - req = req.Clone(origCtx) // isolate caller's request before extension mutations - post := m.Ext.PreRoundTrip(req) + req = req.Clone(origCtx) + + var ( + post func(*http.Response, error) + abortEr error + ) + if a, ok := m.Ext.(exttransport.AbortableInterceptor); ok { + post, abortEr = a.PreRoundTripE(req) + } else { + post = m.Ext.PreRoundTrip(req) + } + if abortEr != nil { + if post != nil { + post(nil, abortEr) + } + return nil, &exttransport.AbortError{Extension: m.ExtName, Reason: abortEr} + } + req = req.WithContext(origCtx) // restore original context resp, err := m.Base.RoundTrip(req) if post != nil { @@ -137,5 +164,5 @@ func wrapWithExtension(transport http.RoundTripper) http.RoundTripper { if tr == nil { return transport } - return &extensionMiddleware{Base: transport, Ext: tr} + return &extensionMiddleware{Base: transport, Ext: tr, ExtName: p.Name()} } diff --git a/internal/cmdutil/transport_test.go b/internal/cmdutil/transport_test.go new file mode 100644 index 000000000..ced9e1e7e --- /dev/null +++ b/internal/cmdutil/transport_test.go @@ -0,0 +1,408 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package cmdutil + +import ( + "context" + "errors" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + exttransport "github.com/larksuite/cli/extension/transport" + internalauth "github.com/larksuite/cli/internal/auth" +) + +type roundTripFunc func(*http.Request) (*http.Response, error) + +func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +} + +// --------------------------------------------------------------------------- +// RetryTransport +// --------------------------------------------------------------------------- + +func TestRetryTransport_NoRetry(t *testing.T) { + calls := 0 + base := roundTripFunc(func(req *http.Request) (*http.Response, error) { + calls++ + return &http.Response{StatusCode: 200, Body: io.NopCloser(strings.NewReader("ok"))}, nil + }) + rt := &RetryTransport{Base: base, MaxRetries: 0} + req, _ := http.NewRequest("GET", "http://example.com/test", nil) + resp, err := rt.RoundTrip(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.StatusCode != 200 { + t.Errorf("expected 200, got %d", resp.StatusCode) + } + if calls != 1 { + t.Errorf("expected 1 call, got %d", calls) + } +} + +func TestRetryTransport_RetryOn500(t *testing.T) { + calls := 0 + base := roundTripFunc(func(req *http.Request) (*http.Response, error) { + calls++ + if calls < 3 { + return &http.Response{StatusCode: 500, Body: io.NopCloser(strings.NewReader("error"))}, nil + } + return &http.Response{StatusCode: 200, Body: io.NopCloser(strings.NewReader("ok"))}, nil + }) + rt := &RetryTransport{Base: base, MaxRetries: 3, Delay: 1 * time.Millisecond} + req, _ := http.NewRequest("GET", "http://example.com/test", nil) + resp, err := rt.RoundTrip(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.StatusCode != 200 { + t.Errorf("expected 200 after retries, got %d", resp.StatusCode) + } + if calls != 3 { + t.Errorf("expected 3 calls, got %d", calls) + } +} + +func TestRetryTransport_DefaultNoRetry(t *testing.T) { + calls := 0 + base := roundTripFunc(func(req *http.Request) (*http.Response, error) { + calls++ + return &http.Response{StatusCode: 500, Body: io.NopCloser(strings.NewReader("error"))}, nil + }) + rt := &RetryTransport{Base: base} // default MaxRetries=0 + req, _ := http.NewRequest("GET", "http://example.com/test", nil) + resp, err := rt.RoundTrip(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.StatusCode != 500 { + t.Errorf("expected 500 with no retries, got %d", resp.StatusCode) + } + if calls != 1 { + t.Errorf("expected 1 call with default config, got %d", calls) + } +} + +// --------------------------------------------------------------------------- +// buildSDKTransport chain composition +// --------------------------------------------------------------------------- + +func TestBuildSDKTransport_IncludesRetryTransport(t *testing.T) { + transport := buildSDKTransport() + + sec, ok := transport.(*internalauth.SecurityPolicyTransport) + if !ok { + t.Fatalf("outer transport type = %T, want *auth.SecurityPolicyTransport", transport) + } + ua, ok := sec.Base.(*UserAgentTransport) + if !ok { + t.Fatalf("middle transport type = %T, want *UserAgentTransport", sec.Base) + } + if _, ok := ua.Base.(*RetryTransport); !ok { + t.Fatalf("inner transport type = %T, want *RetryTransport", ua.Base) + } +} + +func TestBuildSDKTransport_WithExtension(t *testing.T) { + exttransport.Register(&stubTransportProvider{}) + t.Cleanup(func() { exttransport.Register(nil) }) + + transport := buildSDKTransport() + + // Chain: extensionMiddleware → SecurityPolicy → UserAgent → Retry → Base + mid, ok := transport.(*extensionMiddleware) + if !ok { + t.Fatalf("outer transport type = %T, want *extensionMiddleware", transport) + } + sec, ok := mid.Base.(*internalauth.SecurityPolicyTransport) + if !ok { + t.Fatalf("transport type = %T, want *auth.SecurityPolicyTransport", mid.Base) + } + ua, ok := sec.Base.(*UserAgentTransport) + if !ok { + t.Fatalf("transport type = %T, want *UserAgentTransport", sec.Base) + } + if _, ok := ua.Base.(*RetryTransport); !ok { + t.Fatalf("innermost transport type = %T, want *RetryTransport", ua.Base) + } +} + +func TestBuildSDKTransport_WithoutExtension(t *testing.T) { + exttransport.Register(nil) + + transport := buildSDKTransport() + + sec, ok := transport.(*internalauth.SecurityPolicyTransport) + if !ok { + t.Fatalf("outer transport type = %T, want *auth.SecurityPolicyTransport", transport) + } + ua, ok := sec.Base.(*UserAgentTransport) + if !ok { + t.Fatalf("middle transport type = %T, want *UserAgentTransport", sec.Base) + } + if _, ok := ua.Base.(*RetryTransport); !ok { + t.Fatalf("inner transport type = %T, want *RetryTransport", ua.Base) + } +} + +// --------------------------------------------------------------------------- +// extensionMiddleware — legacy Interceptor path +// --------------------------------------------------------------------------- + +type stubTransportProvider struct { + interceptor exttransport.Interceptor +} + +func (s *stubTransportProvider) Name() string { return "stub" } +func (s *stubTransportProvider) ResolveInterceptor(context.Context) exttransport.Interceptor { + if s.interceptor != nil { + return s.interceptor + } + return &stubTransportImpl{} +} + +type stubTransportImpl struct{} + +func (s *stubTransportImpl) PreRoundTrip(req *http.Request) func(*http.Response, error) { + return nil +} + +// headerCapturingInterceptor sets custom headers in PreRoundTrip and records +// whether PostRoundTrip was called, to verify execution order. +type headerCapturingInterceptor struct { + preCalled bool + postCalled bool +} + +func (h *headerCapturingInterceptor) PreRoundTrip(req *http.Request) func(*http.Response, error) { + h.preCalled = true + // Set a custom header that should survive (no built-in override) + req.Header.Set("X-Custom-Trace", "ext-trace-123") + // Try to override a security header — should be overwritten by SecurityHeaderTransport + req.Header.Set(HeaderSource, "ext-tampered") + return func(resp *http.Response, err error) { + h.postCalled = true + } +} + +func TestExtensionInterceptor_ExecutionOrder(t *testing.T) { + var receivedHeaders http.Header + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedHeaders = r.Header.Clone() + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + ic := &headerCapturingInterceptor{} + exttransport.Register(&stubTransportProvider{interceptor: ic}) + t.Cleanup(func() { exttransport.Register(nil) }) + + // Use HTTP transport chain (has SecurityHeaderTransport) + var base http.RoundTripper = http.DefaultTransport + base = &RetryTransport{Base: base} + base = &SecurityHeaderTransport{Base: base} + transport := wrapWithExtension(base) + client := &http.Client{Transport: transport} + + req, _ := http.NewRequest("GET", srv.URL, nil) + resp, err := client.Do(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + resp.Body.Close() + + // PreRoundTrip was called + if !ic.preCalled { + t.Fatal("PreRoundTrip was not called") + } + // PostRoundTrip (closure) was called + if !ic.postCalled { + t.Fatal("PostRoundTrip closure was not called") + } + // Custom header set by extension survives (no built-in override) + if got := receivedHeaders.Get("X-Custom-Trace"); got != "ext-trace-123" { + t.Fatalf("X-Custom-Trace = %q, want %q", got, "ext-trace-123") + } + // Security header overridden by extension is restored by SecurityHeaderTransport + if got := receivedHeaders.Get(HeaderSource); got != SourceValue { + t.Fatalf("%s = %q, want %q (built-in should override extension)", HeaderSource, got, SourceValue) + } +} + +// interceptorFunc adapts a function to exttransport.Interceptor. +type interceptorFunc func(*http.Request) func(*http.Response, error) + +func (f interceptorFunc) PreRoundTrip(req *http.Request) func(*http.Response, error) { return f(req) } + +func TestExtensionInterceptor_ContextTamperPrevented(t *testing.T) { + type ctxKeyType string + const testKey ctxKeyType = "original" + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + var ctxValue any + + // Use a custom transport that captures the context value seen by the built-in chain + capturer := roundTripFunc(func(req *http.Request) (*http.Response, error) { + ctxValue = req.Context().Value(testKey) + return http.DefaultTransport.RoundTrip(req) + }) + + // Interceptor that tries to tamper with context + tamperIC := interceptorFunc(func(req *http.Request) func(*http.Response, error) { + // Try to replace context with a new one + *req = *req.WithContext(context.WithValue(req.Context(), testKey, "tampered")) + return nil + }) + + mid := &extensionMiddleware{Base: capturer, Ext: tamperIC} + + origCtx := context.WithValue(context.Background(), testKey, "original") + req, _ := http.NewRequestWithContext(origCtx, "GET", srv.URL, nil) + resp, err := mid.RoundTrip(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + resp.Body.Close() + + // Built-in chain should see original context, not tampered + if ctxValue != "original" { + t.Fatalf("built-in chain saw context value %q, want %q", ctxValue, "original") + } +} + +// --------------------------------------------------------------------------- +// extensionMiddleware — PreRoundTripE abort path +// --------------------------------------------------------------------------- + +// abortingInterceptor implements exttransport.AbortableInterceptor and +// records invocation of the pre and post hooks. These middleware tests only +// assert middleware-level integration; pure *AbortError behavior +// (Error/Unwrap/Is/As) is covered in extension/transport/errors_test.go. +type abortingInterceptor struct { + reason error // if non-nil, PreRoundTripE returns this to abort + nilPost bool // if true, PreRoundTripE returns a nil post func + preECalled bool + postCalled bool + postResp *http.Response + postErr error +} + +// PreRoundTrip is a no-op that satisfies the legacy Interceptor method; the +// middleware never calls it when PreRoundTripE is present. +func (*abortingInterceptor) PreRoundTrip(*http.Request) func(*http.Response, error) { + return nil +} + +func (a *abortingInterceptor) PreRoundTripE(req *http.Request) (func(*http.Response, error), error) { + a.preECalled = true + if a.nilPost { + return nil, a.reason + } + return func(resp *http.Response, err error) { + a.postCalled = true + a.postResp = resp + a.postErr = err + }, a.reason +} + +func TestExtensionMiddleware_PreRoundTripEAbort(t *testing.T) { + innerErr := errors.New("denied by policy") + + t.Run("skips base and wires AbortError fields", func(t *testing.T) { + ic := &abortingInterceptor{reason: innerErr} + baseCalls := 0 + base := roundTripFunc(func(*http.Request) (*http.Response, error) { + baseCalls++ + return &http.Response{StatusCode: http.StatusOK, Body: http.NoBody}, nil + }) + + mid := &extensionMiddleware{Base: base, Ext: ic, ExtName: "stub"} + req, _ := http.NewRequest("GET", "http://example.invalid/", nil) + resp, err := mid.RoundTrip(req) + + if resp != nil { + t.Fatalf("resp = %v, want nil on abort", resp) + } + if baseCalls != 0 { + t.Fatalf("base RoundTrip called %d times on abort, want 0", baseCalls) + } + if !ic.preECalled { + t.Fatal("PreRoundTripE was not called") + } + + var aErr *exttransport.AbortError + if !errors.As(err, &aErr) { + t.Fatalf("errors.As(*AbortError) = false, err = %v (%T)", err, err) + } + if aErr.Extension != "stub" || aErr.Reason != innerErr { + t.Fatalf("AbortError = %+v, want {Extension:stub Reason:%v}", aErr, innerErr) + } + + // Post must see the original inner err, not the *AbortError wrapper. + if !ic.postCalled { + t.Fatal("post hook was not called on abort") + } + if ic.postResp != nil { + t.Fatalf("post resp = %v, want nil", ic.postResp) + } + if ic.postErr != innerErr { + t.Fatalf("post err = %v, want original inner err %v", ic.postErr, innerErr) + } + }) + + t.Run("nil post still returns AbortError", func(t *testing.T) { + ic := &abortingInterceptor{reason: innerErr, nilPost: true} + base := roundTripFunc(func(*http.Request) (*http.Response, error) { + t.Fatal("base must not be called on abort") + return nil, nil + }) + + mid := &extensionMiddleware{Base: base, Ext: ic, ExtName: "stub"} + req, _ := http.NewRequest("GET", "http://example.invalid/", nil) + _, err := mid.RoundTrip(req) + + var aErr *exttransport.AbortError + if !errors.As(err, &aErr) { + t.Fatalf("errors.As(*AbortError) = false, err = %v", err) + } + }) +} + +func TestExtensionMiddleware_PreRoundTripEHappyPath(t *testing.T) { + ic := &abortingInterceptor{} // reason == nil → no abort + baseCalls := 0 + base := roundTripFunc(func(*http.Request) (*http.Response, error) { + baseCalls++ + return &http.Response{StatusCode: http.StatusOK, Body: http.NoBody}, nil + }) + + mid := &extensionMiddleware{Base: base, Ext: ic, ExtName: "stub"} + req, _ := http.NewRequest("GET", "http://example.invalid/", nil) + resp, err := mid.RoundTrip(req) + if err != nil { + t.Fatalf("happy path returned err: %v", err) + } + if resp == nil || resp.StatusCode != http.StatusOK { + t.Fatalf("resp = %v, want 200", resp) + } + if baseCalls != 1 { + t.Fatalf("base RoundTrip called %d times, want 1", baseCalls) + } + if !ic.preECalled { + t.Fatal("PreRoundTripE was not called") + } + if !ic.postCalled || ic.postErr != nil { + t.Fatalf("post hook not called or err != nil: called=%v err=%v", ic.postCalled, ic.postErr) + } +} diff --git a/internal/credential/default_provider.go b/internal/credential/default_provider.go index bedad7b86..a3ebb90c9 100644 --- a/internal/credential/default_provider.go +++ b/internal/credential/default_provider.go @@ -21,11 +21,14 @@ import ( // DefaultAccountProvider resolves account from config.json via keychain. type DefaultAccountProvider struct { - keychain keychain.KeychainAccess + keychain func() keychain.KeychainAccess profile string } -func NewDefaultAccountProvider(kc keychain.KeychainAccess, profile string) *DefaultAccountProvider { +func NewDefaultAccountProvider(kc func() keychain.KeychainAccess, profile string) *DefaultAccountProvider { + if kc == nil { + kc = keychain.Default + } return &DefaultAccountProvider{keychain: kc, profile: profile} } @@ -36,7 +39,7 @@ func (p *DefaultAccountProvider) ResolveAccount(ctx context.Context) (*Account, return nil, &core.ConfigError{Code: 2, Type: "config", Message: "not configured", Hint: "run `lark-cli config init --new` in the background. It blocks and outputs a verification URL — retrieve the URL and open it in a browser to complete setup."} } - cfg, err := core.ResolveConfigFromMulti(multi, p.keychain, p.profile) + cfg, err := core.ResolveConfigFromMulti(multi, p.keychain(), p.profile) if err != nil { return nil, err } diff --git a/internal/credential/integration_test.go b/internal/credential/integration_test.go index 46a3485ff..7daef1d6a 100644 --- a/internal/credential/integration_test.go +++ b/internal/credential/integration_test.go @@ -12,6 +12,7 @@ import ( "github.com/larksuite/cli/internal/core" "github.com/larksuite/cli/internal/credential" "github.com/larksuite/cli/internal/envvars" + "github.com/larksuite/cli/internal/keychain" ) type noopKC struct{} @@ -99,7 +100,7 @@ func TestFullChain_ConfigStrictMode(t *testing.T) { } ep := &envprovider.Provider{} - defaultAcct := credential.NewDefaultAccountProvider(&noopKC{}, "") + defaultAcct := credential.NewDefaultAccountProvider(func() keychain.KeychainAccess { return &noopKC{} }, "") cp := credential.NewCredentialProvider( []extcred.Provider{ep}, diff --git a/internal/update/update.go b/internal/update/update.go index c87bb59fc..39052d9ef 100644 --- a/internal/update/update.go +++ b/internal/update/update.go @@ -61,7 +61,7 @@ func httpClient() *http.Client { } return &http.Client{ Timeout: fetchTimeout, - Transport: util.NewBaseTransport(), + Transport: util.SharedTransport(), } } diff --git a/internal/util/proxy.go b/internal/util/proxy.go index 64308da85..d9e251859 100644 --- a/internal/util/proxy.go +++ b/internal/util/proxy.go @@ -72,31 +72,47 @@ func WarnIfProxied(w io.Writer) { }) } -// NewBaseTransport creates an *http.Transport cloned from http.DefaultTransport. -// If LARK_CLI_NO_PROXY is set, proxy support is disabled. -// Each call returns a new instance; use FallbackTransport for a shared singleton. -func NewBaseTransport() *http.Transport { +// noProxyTransport is a proxy-disabled clone of http.DefaultTransport, +// lazily built the first time LARK_CLI_NO_PROXY is observed set. +var noProxyTransport = sync.OnceValue(func() *http.Transport { def, ok := http.DefaultTransport.(*http.Transport) if !ok { return &http.Transport{} } t := def.Clone() + t.Proxy = nil + return t +}) + +// SharedTransport returns the base http.RoundTripper for CLI HTTP clients. +// +// By default it returns http.DefaultTransport — the stdlib-provided +// process-wide singleton — so every HTTP client in the process shares one +// TCP connection pool, TLS session cache, and HTTP/2 state. When +// LARK_CLI_NO_PROXY is set it returns a separate proxy-disabled singleton +// clone; LARK_CLI_NO_PROXY is checked on every call, but the clone is built +// at most once. +// +// The returned RoundTripper MUST NOT be mutated. Callers that need a +// customized transport should assert to *http.Transport and Clone() it. +// Using a shared base is required so persistConn readLoop/writeLoop +// goroutines are reused; cloning per call leaks them until IdleConnTimeout +// (~90s) fires. +func SharedTransport() http.RoundTripper { if os.Getenv(EnvNoProxy) != "" { - t.Proxy = nil + return noProxyTransport() } - return t + return http.DefaultTransport } -// fallbackTransport is a lazily-initialized singleton used by transport -// decorators when their Base field is nil, preserving connection pooling. -var fallbackTransport = sync.OnceValue(func() *http.Transport { - return NewBaseTransport() -}) - -// FallbackTransport returns a shared *http.Transport singleton suitable for -// use as a fallback when a transport decorator's Base is nil. -// Unlike NewBaseTransport (which clones per call), this reuses a single -// instance so that TCP connections and TLS sessions are pooled. +// FallbackTransport returns a shared *http.Transport singleton. It is a +// thin wrapper over SharedTransport retained so modules that were already +// on the leak-free singleton path (internal/auth, internal/cmdutil +// transport decorators) do not have to migrate. New code should prefer +// SharedTransport and treat the base as an http.RoundTripper. func FallbackTransport() *http.Transport { - return fallbackTransport() + if t, ok := SharedTransport().(*http.Transport); ok { + return t + } + return noProxyTransport() } diff --git a/internal/util/proxy_test.go b/internal/util/proxy_test.go index daf1b7ab4..f78720963 100644 --- a/internal/util/proxy_test.go +++ b/internal/util/proxy_test.go @@ -28,19 +28,65 @@ func TestDetectProxyEnv(t *testing.T) { } } -func TestNewBaseTransport_Default(t *testing.T) { +func TestSharedTransport_DefaultReturnsStdlibSingleton(t *testing.T) { t.Setenv(EnvNoProxy, "") - tr := NewBaseTransport() - if tr.Proxy == nil { - t.Error("expected proxy func to be set when LARK_CLI_NO_PROXY is not set") + tr := SharedTransport() + if tr != http.DefaultTransport { + t.Error("SharedTransport should return http.DefaultTransport when LARK_CLI_NO_PROXY is unset") } } -func TestNewBaseTransport_NoProxy(t *testing.T) { +func TestSharedTransport_NoProxyReturnsClone(t *testing.T) { t.Setenv(EnvNoProxy, "1") - tr := NewBaseTransport() - if tr.Proxy != nil { - t.Error("expected proxy func to be nil when LARK_CLI_NO_PROXY=1") + tr := SharedTransport() + if tr == http.DefaultTransport { + t.Fatal("SharedTransport should return a clone, not DefaultTransport, when LARK_CLI_NO_PROXY is set") + } + ht, ok := tr.(*http.Transport) + if !ok { + t.Fatalf("expected *http.Transport, got %T", tr) + } + if ht.Proxy != nil { + t.Error("no-proxy transport should have Proxy == nil") + } +} + +func TestSharedTransport_NoProxyIsCachedSingleton(t *testing.T) { + t.Setenv(EnvNoProxy, "1") + a := SharedTransport() + b := SharedTransport() + if a != b { + t.Error("repeated SharedTransport calls with LARK_CLI_NO_PROXY set must return the same instance") + } +} + +func TestSharedTransport_EnvUnsetAfterSetFallsBackToDefault(t *testing.T) { + // Simulate a process that first runs with LARK_CLI_NO_PROXY=1 (populating + // the no-proxy singleton), then unsets it. Subsequent calls must return + // http.DefaultTransport, NOT the cached no-proxy clone. + t.Setenv(EnvNoProxy, "1") + noProxy := SharedTransport() + if noProxy == http.DefaultTransport { + t.Fatal("precondition: first call with env set should not return DefaultTransport") + } + + t.Setenv(EnvNoProxy, "") + after := SharedTransport() + if after != http.DefaultTransport { + t.Errorf("after unsetting LARK_CLI_NO_PROXY, SharedTransport must return http.DefaultTransport, got %T (%p)", after, after) + } +} + +func TestSharedTransport_NoProxyOverridesSystemProxy(t *testing.T) { + t.Setenv("HTTPS_PROXY", "http://should-be-ignored:8888") + t.Setenv(EnvNoProxy, "1") + + ht, ok := SharedTransport().(*http.Transport) + if !ok { + t.Fatalf("expected *http.Transport, got %T", SharedTransport()) + } + if ht.Proxy != nil { + t.Error("LARK_CLI_NO_PROXY should override system proxy settings") } } @@ -156,35 +202,3 @@ func TestWarnIfProxied_RedactsCredentials(t *testing.T) { t.Errorf("warning should contain redacted proxy URL, got: %s", out) } } - -func TestNewBaseTransport_IsHTTPTransport(t *testing.T) { - t.Setenv(EnvNoProxy, "") - tr := NewBaseTransport() - - // Should be a valid *http.Transport that can be used - var rt http.RoundTripper = tr - _ = rt - - // Verify it's not the same pointer as DefaultTransport (should be a clone) - if tr == http.DefaultTransport { - t.Error("NewBaseTransport should return a clone, not DefaultTransport itself") - } -} - -func TestNewBaseTransport_RespectsNoProxyEnv(t *testing.T) { - // Simulate: user sets both system proxy and our disable flag - t.Setenv("HTTPS_PROXY", "http://should-be-ignored:8888") - t.Setenv(EnvNoProxy, "1") - - tr := NewBaseTransport() - if tr.Proxy != nil { - t.Error("LARK_CLI_NO_PROXY should override system proxy settings") - } - - // Clean up and verify proxy is restored - t.Setenv(EnvNoProxy, "") - tr2 := NewBaseTransport() - if tr2.Proxy == nil { - t.Error("proxy should be enabled when LARK_CLI_NO_PROXY is unset") - } -} diff --git a/shortcuts/common/runner.go b/shortcuts/common/runner.go index 7d9685a13..6a0009b9c 100644 --- a/shortcuts/common/runner.go +++ b/shortcuts/common/runner.go @@ -571,12 +571,16 @@ func enhancePermissionError(err error, requiredScopes []string) error { // Mount registers the shortcut on a parent command. func (s Shortcut) Mount(parent *cobra.Command, f *cmdutil.Factory) { + s.MountWithContext(context.Background(), parent, f) +} + +func (s Shortcut) MountWithContext(ctx context.Context, parent *cobra.Command, f *cmdutil.Factory) { if s.Execute != nil { - s.mountDeclarative(parent, f) + s.mountDeclarative(ctx, parent, f) } } -func (s Shortcut) mountDeclarative(parent *cobra.Command, f *cmdutil.Factory) { +func (s Shortcut) mountDeclarative(ctx context.Context, parent *cobra.Command, f *cmdutil.Factory) { shortcut := s if len(shortcut.AuthTypes) == 0 { shortcut.AuthTypes = []string{"user"} @@ -592,7 +596,7 @@ func (s Shortcut) mountDeclarative(parent *cobra.Command, f *cmdutil.Factory) { }, } cmdutil.SetSupportedIdentities(cmd, shortcut.AuthTypes) - registerShortcutFlags(cmd, &shortcut) + registerShortcutFlagsWithContext(ctx, cmd, f, &shortcut) cmdutil.SetTips(cmd, shortcut.Tips) parent.AddCommand(cmd) } @@ -823,7 +827,11 @@ func rejectPositionalArgs() cobra.PositionalArgs { } } -func registerShortcutFlags(cmd *cobra.Command, s *Shortcut) { +func registerShortcutFlags(cmd *cobra.Command, f *cmdutil.Factory, s *Shortcut) { + registerShortcutFlagsWithContext(context.Background(), cmd, f, s) +} + +func registerShortcutFlagsWithContext(ctx context.Context, cmd *cobra.Command, f *cmdutil.Factory, s *Shortcut) { for _, fl := range s.Flags { desc := fl.Desc if len(fl.Enum) > 0 { @@ -874,11 +882,7 @@ func registerShortcutFlags(cmd *cobra.Command, s *Shortcut) { cmd.Flags().Bool("yes", false, "confirm high-risk operation") } cmd.Flags().StringP("jq", "q", "", "jq expression to filter JSON output") - cmd.Flags().String("as", s.AuthTypes[0], "identity type: user | bot") - - _ = cmd.RegisterFlagCompletionFunc("as", func(_ *cobra.Command, _ []string, _ string) ([]string, cobra.ShellCompDirective) { - return s.AuthTypes, cobra.ShellCompDirectiveNoFileComp - }) + cmdutil.AddShortcutIdentityFlag(ctx, cmd, f, s.AuthTypes) if s.HasFormat { _ = cmd.RegisterFlagCompletionFunc("format", func(_ *cobra.Command, _ []string, _ string) ([]string, cobra.ShellCompDirective) { return []string{"json", "pretty", "table", "ndjson", "csv"}, cobra.ShellCompDirectiveNoFileComp diff --git a/shortcuts/common/runner_identity_flag_test.go b/shortcuts/common/runner_identity_flag_test.go new file mode 100644 index 000000000..a6ed1020b --- /dev/null +++ b/shortcuts/common/runner_identity_flag_test.go @@ -0,0 +1,45 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package common + +import ( + "context" + "testing" + + "github.com/larksuite/cli/internal/cmdutil" + "github.com/larksuite/cli/internal/core" + "github.com/spf13/cobra" +) + +func TestShortcutMount_StrictModeHidesAsFlag(t *testing.T) { + f, _, _, _ := cmdutil.TestFactory(t, &core.CliConfig{ + AppID: "test-app", AppSecret: "test-secret", Brand: core.BrandFeishu, SupportedIdentities: 2, + }) + parent := &cobra.Command{Use: "root"} + shortcut := Shortcut{ + Service: "docs", + Command: "+fetch", + Description: "fetch doc", + AuthTypes: []string{"user", "bot"}, + Execute: func(context.Context, *RuntimeContext) error { + return nil + }, + } + + shortcut.Mount(parent, f) + cmd, _, err := parent.Find([]string{"+fetch"}) + if err != nil { + t.Fatalf("Find() error = %v", err) + } + flag := cmd.Flags().Lookup("as") + if flag == nil { + t.Fatal("expected --as flag to be registered") + } + if !flag.Hidden { + t.Fatal("expected --as flag to be hidden in strict mode") + } + if got := flag.DefValue; got != "bot" { + t.Fatalf("default value = %q, want %q", got, "bot") + } +} diff --git a/shortcuts/common/runner_jq_test.go b/shortcuts/common/runner_jq_test.go index 17af83dc7..949002672 100644 --- a/shortcuts/common/runner_jq_test.go +++ b/shortcuts/common/runner_jq_test.go @@ -145,10 +145,10 @@ func TestRuntimeContext_FileIO_UsesExecutionContext(t *testing.T) { } } -func newTestShortcutCmd(s *Shortcut) *cobra.Command { +func newTestShortcutCmd(s *Shortcut, f *cmdutil.Factory) *cobra.Command { cmd := &cobra.Command{Use: "test-shortcut"} cmd.SetContext(context.Background()) - registerShortcutFlags(cmd, s) + registerShortcutFlags(cmd, f, s) return cmd } @@ -177,7 +177,7 @@ func TestRunShortcut_JqAndFormatConflict(t *testing.T) { return nil }, } - cmd := newTestShortcutCmd(s) + cmd := newTestShortcutCmd(s, newTestFactory()) cmd.Flags().Set("jq", ".data") cmd.Flags().Set("format", "table") cmd.Flags().Set("as", "bot") @@ -200,7 +200,7 @@ func TestRunShortcut_JqInvalidExpression(t *testing.T) { return nil }, } - cmd := newTestShortcutCmd(s) + cmd := newTestShortcutCmd(s, newTestFactory()) cmd.Flags().Set("jq", "invalid[") cmd.Flags().Set("as", "bot") @@ -223,7 +223,7 @@ func TestRunShortcut_JqRuntimeError_PropagatesError(t *testing.T) { return nil }, } - cmd := newTestShortcutCmd(s) + cmd := newTestShortcutCmd(s, newTestFactory()) cmd.Flags().Set("jq", ".foo | invalid_func_xyz") cmd.Flags().Set("as", "bot") diff --git a/shortcuts/register.go b/shortcuts/register.go index f5b085689..534163ec2 100644 --- a/shortcuts/register.go +++ b/shortcuts/register.go @@ -4,6 +4,8 @@ package shortcuts import ( + "context" + "github.com/larksuite/cli/shortcuts/okr" "github.com/spf13/cobra" @@ -58,6 +60,10 @@ func AllShortcuts() []common.Shortcut { // RegisterShortcuts registers all +shortcut commands on the program. func RegisterShortcuts(program *cobra.Command, f *cmdutil.Factory) { + RegisterShortcutsWithContext(context.Background(), program, f) +} + +func RegisterShortcutsWithContext(ctx context.Context, program *cobra.Command, f *cmdutil.Factory) { // Group by service byService := make(map[string][]common.Shortcut) for _, s := range allShortcuts { @@ -86,7 +92,7 @@ func RegisterShortcuts(program *cobra.Command, f *cmdutil.Factory) { } for _, shortcut := range shortcuts { - shortcut.Mount(svc, f) + shortcut.MountWithContext(ctx, svc, f) } } } diff --git a/sidecar/server-demo/main.go b/sidecar/server-demo/main.go index 1197b5145..7292daa26 100644 --- a/sidecar/server-demo/main.go +++ b/sidecar/server-demo/main.go @@ -99,7 +99,7 @@ func run(ctx context.Context, listen, keyFile, logFile, profile string) error { // Reuse the lark-cli credential pipeline. A production implementation // would likely source credentials from a secrets manager instead. - factory := cmdutil.NewDefault(cmdutil.InvocationContext{Profile: profile}) + factory := cmdutil.NewDefault(nil, cmdutil.InvocationContext{Profile: profile}) cfg, err := factory.Config() if err != nil { return fmt.Errorf("failed to load config: %v", err)