diff --git a/.golangci.yml b/.golangci.yml index 4690fe93c..da2d5320b 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -27,6 +27,7 @@ linters: - reassign # checks that package variables are not reassigned - unconvert # removes unnecessary type conversions - unused # checks for unused constants, variables, functions and types + - forbidigo # forbids specific function calls # To enable later after fixing existing issues: # - errcheck # checks for unchecked errors @@ -44,8 +45,89 @@ linters: linters: - bodyclose - gocritic + - forbidigo + - path-except: (shortcuts/|internal/) + linters: + - forbidigo + - path: internal/vfs/ + linters: + - forbidigo settings: + forbidigo: + forbid: + # ── Filesystem operations: use internal/vfs instead ── + - pattern: os\.Stat\b + msg: "use vfs.Stat() from internal/vfs" + - pattern: os\.Lstat\b + msg: "use vfs.Lstat() from internal/vfs" + - pattern: os\.Open\b + msg: "use vfs.Open() from internal/vfs" + - pattern: os\.OpenFile\b + msg: "use vfs.OpenFile() from internal/vfs" + - pattern: os\.Create\b + msg: "use vfs.OpenFile() from internal/vfs" + - pattern: os\.CreateTemp\b + msg: >- + internal/: use vfs.CreateTemp() from internal/vfs. + shortcuts/: avoid temp files entirely — use io.Reader streaming or in-memory buffers instead. + - pattern: os\.Mkdir\b + msg: "use vfs.MkdirAll() from internal/vfs" + - pattern: os\.MkdirAll\b + msg: "use vfs.MkdirAll() from internal/vfs" + - pattern: os\.Remove\b + msg: >- + internal/: use vfs.Remove() from internal/vfs. + shortcuts/: avoid temp files entirely — use io.Reader streaming or in-memory buffers instead. + - pattern: os\.RemoveAll\b + msg: >- + internal/: add RemoveAll to internal/vfs/fs.go first, then use vfs.RemoveAll(). + shortcuts/: avoid temp files entirely — use io.Reader streaming or in-memory buffers instead. + - pattern: os\.Rename\b + msg: "use vfs.Rename() from internal/vfs" + - pattern: os\.ReadFile\b + msg: "use vfs.ReadFile() from internal/vfs" + - pattern: os\.WriteFile\b + msg: "use vfs.WriteFile() from internal/vfs" + - pattern: os\.ReadDir\b + msg: "add ReadDir to internal/vfs/fs.go first, then use vfs.ReadDir()" + - pattern: os\.Getwd\b + msg: "use vfs.Getwd() from internal/vfs" + - pattern: os\.Chdir\b + msg: "add Chdir to internal/vfs/fs.go first, then use vfs.Chdir()" + - pattern: os\.UserHomeDir\b + msg: "use vfs.UserHomeDir() from internal/vfs" + - pattern: os\.Chmod\b + msg: "add Chmod to internal/vfs/fs.go first, then use vfs.Chmod()" + - pattern: os\.Chown\b + msg: "add Chown to internal/vfs/fs.go first, then use vfs.Chown()" + - pattern: os\.Lchown\b + msg: "add Lchown to internal/vfs/fs.go first, then use vfs.Lchown()" + - pattern: os\.Link\b + msg: "add Link to internal/vfs/fs.go first, then use vfs.Link()" + - pattern: os\.Symlink\b + msg: "add Symlink to internal/vfs/fs.go first, then use vfs.Symlink()" + - pattern: os\.Readlink\b + msg: "add Readlink to internal/vfs/fs.go first, then use vfs.Readlink()" + - pattern: os\.Truncate\b + msg: "add Truncate to internal/vfs/fs.go first, then use vfs.Truncate()" + - pattern: os\.DirFS\b + msg: "add DirFS to internal/vfs/fs.go first, then use vfs.DirFS()" + - pattern: os\.SameFile\b + msg: "add SameFile to internal/vfs/fs.go first, then use vfs.SameFile()" + # ── IO streams: use IOStreams from cmdutil instead ── + - pattern: os\.Stdin\b + msg: "use IOStreams.In instead of os.Stdin" + - pattern: os\.Stdout\b + msg: "use IOStreams.Out instead of os.Stdout" + - pattern: os\.Stderr\b + msg: "use IOStreams.ErrOut instead of os.Stderr" + # ── Process-level rules ── + - pattern: os\.Exit\b + msg: >- + Do not use os.Exit in shortcuts/. Return an error instead and let + the caller (cmd layer) decide how to terminate. + analyze-types: true gocritic: disabled-checks: - appendAssign diff --git a/AGENTS.md b/AGENTS.md index e594a81c4..e6ed5f872 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -1,33 +1,78 @@ # AGENTS.md -Concise maintainer/developer guide for building, testing, and opening high-quality PRs in this repo. ## Goal (pick one per PR) + - Make CLI better: improve UX, error messages, help text, flags, and output clarity. - Improve reliability: fix bugs, edge cases, and regressions with tests. - Improve developer velocity: simplify code paths, reduce complexity, keep behavior explicit. - Improve quality gates: strengthen tests/lint/checks without adding heavy process. -## Fast Dev Loop -1. `make build` (runs `python3 scripts/fetch_meta.py` first) -2. `make unit-test` (required before PR) -3. Run changed command(s) manually via `./lark-cli ...` +## Build & Test + +```bash +make build # Build (runs fetch_meta first) +make unit-test # Required before PR (runs with -race) +make test # Full: vet + unit + integration +``` ## Pre-PR Checks (match CI gates) + 1. `make unit-test` -2. `go mod tidy` (must not change `go.mod`/`go.sum`) +2. `go mod tidy` — must not change `go.mod`/`go.sum` 3. `go run github.com/golangci/golangci-lint/v2/cmd/golangci-lint@v2.1.6 run --new-from-rev=origin/main` 4. If dependencies changed: `go run github.com/google/go-licenses/v2@v2.0.1 check ./... --disallowed_types=forbidden,restricted,reciprocal,unknown` -5. Optional full local suite: `make test` (vet + unit + integration) - -## Test/Check Commands -- Unit: `make unit-test` -- Integration: `make integration-test` -- Full: `make test` -- Vet only: `make vet` -- Coverage (local): `go test -race -coverprofile=coverage.txt -covermode=atomic ./...` - -## Commit/PR Rules -- Use Conventional Commits in English: `feat: ...`, `fix: ...`, `docs: ...`, `ci: ...`, `test: ...`, `chore: ...`, `refactor: ...` -- Keep PR title in the same Conventional Commit format (squash merge keeps it). -- Before opening a real PR, draft/fill description from `.github/pull_request_template.md` and ensure Summary/Changes/Test Plan are complete. -- Never commit secrets/tokens/internal sensitive data. + +## Commit & PR + +- Conventional Commits in English: `feat:`, `fix:`, `docs:`, `test:`, `refactor:`, `chore:`, `ci:` +- PR title in the same format. Fill `.github/pull_request_template.md` completely. +- Never commit secrets, tokens, or internal sensitive data. + +## Source Layout + +| Path | What it does | +|------|-------------| +| `cmd/root.go` | Entry point, command registration, strict mode pruning | +| `cmd/profile/` | Multi-profile management (add/list/use/rename/remove) | +| `cmd/config/` | Config init, show, strict-mode | +| `cmd/service/` | Auto-registered API commands from embedded metadata | +| `shortcuts/common/runner.go` | Shortcut execution pipeline, Flag.Input (@file/stdin) resolution | +| `shortcuts/` | Domain-specific shortcut implementations | +| `internal/cmdutil/factory.go` | Factory pattern — identity resolution, credential, config | +| `internal/cmdutil/factory_default.go` | Production factory wiring | +| `internal/credential/` | Credential provider chain (extension → default) | +| `extension/credential/` | Plugin-facing credential interfaces and env provider | +| `internal/client/client.go` | APIClient: DoSDKRequest, DoStream | +| `internal/core/config.go` | Multi-profile config loading/saving | +| `internal/vfs/` | Filesystem abstraction (use `vfs.*` instead of `os.*`) | +| `internal/validate/path.go` | Path safety validation | + +## Who Uses This CLI + +This CLI's primary consumers include AI agents (Claude Code, Cursor, Gemini CLI). Your code is read by machines — error messages, output format, and flag design all directly affect agent success rates. + +The one rule to internalize: **every error message you write will be parsed by an AI to decide its next action.** Make errors structured, actionable, and specific. + +## Code Conventions + +### Structured errors in commands + +`RunE` functions must return `output.Errorf` / `output.ErrWithHint` — never bare `fmt.Errorf`. AI agents parse stderr as JSON; bare errors break this contract. + +### stdout is data, stderr is everything else + +Program output (JSON envelopes) goes to stdout. Progress, warnings, hints go to stderr. Mixing them corrupts pipe chains. + +### Use `vfs.*` instead of `os.*` + +All filesystem access goes through `internal/vfs`. This enables test mocking. + +### Validate paths before reading + +CLI arguments are untrusted (they come from AI agents). Call `validate.SafeInputPath` before any file I/O. + +### Tests + +- Every behavior change needs a test alongside the change. +- `cmdutil.TestFactory(t, config)` for test factories. +- `t.Setenv("LARKSUITE_CLI_CONFIG_DIR", t.TempDir())` to isolate config state. diff --git a/cmd/api/api.go b/cmd/api/api.go index 89661b365..084cb059b 100644 --- a/cmd/api/api.go +++ b/cmd/api/api.go @@ -152,7 +152,11 @@ func buildAPIRequest(opts *APIOptions) (client.RawApiRequest, error) { func apiRun(opts *APIOptions) error { f := opts.Factory - opts.As = f.ResolveAs(opts.Cmd, opts.As) + opts.As = f.ResolveAs(opts.Ctx, opts.Cmd, opts.As) + + if err := f.CheckStrictMode(opts.Ctx, opts.As); err != nil { + return err + } if opts.PageAll && opts.Output != "" { return output.ErrValidation("--output and --page-all are mutually exclusive") @@ -166,7 +170,7 @@ func apiRun(opts *APIOptions) error { return err } - config, err := f.ResolveConfig(opts.As) + config, err := f.Config() if err != nil { return err } diff --git a/cmd/api/api_test.go b/cmd/api/api_test.go index 0b01d250e..b55e82257 100644 --- a/cmd/api/api_test.go +++ b/cmd/api/api_test.go @@ -70,16 +70,6 @@ func TestApiCmd_BotMode(t *testing.T) { AppID: "test-app", AppSecret: "test-secret", Brand: core.BrandFeishu, }) - // Register tenant_access_token stub - reg.Register(&httpmock.Stub{ - URL: "/open-apis/auth/v3/tenant_access_token/internal", - Body: map[string]interface{}{ - "code": 0, - "msg": "ok", - "tenant_access_token": "t-test-token", - "expire": 7200, - }, - }) // Register API endpoint stub reg.Register(&httpmock.Stub{ URL: "/open-apis/test", @@ -234,13 +224,6 @@ func TestApiCmd_BinaryResponse_AutoSave(t *testing.T) { AppID: "test-app-bin", AppSecret: "test-secret-bin", Brand: core.BrandFeishu, }) - reg.Register(&httpmock.Stub{ - URL: "/open-apis/auth/v3/tenant_access_token/internal", - Body: map[string]interface{}{ - "code": 0, "msg": "ok", - "tenant_access_token": "t-test-token-bin", "expire": 7200, - }, - }) reg.Register(&httpmock.Stub{ URL: "/open-apis/drive/v1/files/xxx/download", RawBody: []byte("fake-binary-content"), @@ -266,14 +249,6 @@ func TestApiCmd_PageAll_NonBatchAPI_FallbackToJSON(t *testing.T) { AppID: "test-app-pageall1", AppSecret: "test-secret-pageall1", Brand: core.BrandFeishu, }) - // Register tenant_access_token stub - reg.Register(&httpmock.Stub{ - URL: "/open-apis/auth/v3/tenant_access_token/internal", - Body: map[string]interface{}{ - "code": 0, "msg": "ok", - "tenant_access_token": "t-test-token-pa1", "expire": 7200, - }, - }) // Register a non-batch API that returns scalar data (no array field) reg.Register(&httpmock.Stub{ URL: "/open-apis/contact/v3/users/u123", @@ -310,13 +285,6 @@ func TestApiCmd_PageAll_NonBatchAPI_ErrorStillOutputsJSON(t *testing.T) { AppID: "test-app-pageall-err", AppSecret: "test-secret-pageall-err", Brand: core.BrandFeishu, }) - reg.Register(&httpmock.Stub{ - URL: "/open-apis/auth/v3/tenant_access_token/internal", - Body: map[string]interface{}{ - "code": 0, "msg": "ok", - "tenant_access_token": "t-test-token-err", "expire": 7200, - }, - }) // Non-batch API that returns a business error (code != 0) reg.Register(&httpmock.Stub{ URL: "/open-apis/im/v1/chats/oc_xxx/announcement", @@ -346,14 +314,6 @@ func TestApiCmd_PageAll_BatchAPI_StreamsItems(t *testing.T) { AppID: "test-app-pageall2", AppSecret: "test-secret-pageall2", Brand: core.BrandFeishu, }) - // Register tenant_access_token stub (unique app credentials => new token request) - reg.Register(&httpmock.Stub{ - URL: "/open-apis/auth/v3/tenant_access_token/internal", - Body: map[string]interface{}{ - "code": 0, "msg": "ok", - "tenant_access_token": "t-test-token-pa2", "expire": 7200, - }, - }) // Register a batch API that returns an array field reg.Register(&httpmock.Stub{ URL: "/open-apis/contact/v3/users", @@ -409,13 +369,6 @@ func TestApiCmd_APIError_IsRaw(t *testing.T) { AppID: "test-app-raw", AppSecret: "test-secret-raw", Brand: core.BrandFeishu, }) - reg.Register(&httpmock.Stub{ - URL: "/open-apis/auth/v3/tenant_access_token/internal", - Body: map[string]interface{}{ - "code": 0, "msg": "ok", - "tenant_access_token": "t-test-token-raw", "expire": 7200, - }, - }) // Return a permission error from the API reg.Register(&httpmock.Stub{ URL: "/open-apis/test/perm", @@ -456,13 +409,6 @@ func TestApiCmd_APIError_PreservesOriginalMessage(t *testing.T) { AppID: "test-app-origmsg", AppSecret: "test-secret-origmsg", Brand: core.BrandFeishu, }) - reg.Register(&httpmock.Stub{ - URL: "/open-apis/auth/v3/tenant_access_token/internal", - Body: map[string]interface{}{ - "code": 0, "msg": "ok", - "tenant_access_token": "t-test-token-origmsg", "expire": 7200, - }, - }) reg.Register(&httpmock.Stub{ URL: "/open-apis/test/origmsg", Body: map[string]interface{}{ @@ -505,13 +451,6 @@ func TestApiCmd_PageAll_APIError_IsRaw(t *testing.T) { AppID: "test-app-rawpage", AppSecret: "test-secret-rawpage", Brand: core.BrandFeishu, }) - reg.Register(&httpmock.Stub{ - URL: "/open-apis/auth/v3/tenant_access_token/internal", - Body: map[string]interface{}{ - "code": 0, "msg": "ok", - "tenant_access_token": "t-test-token-rawpage", "expire": 7200, - }, - }) reg.Register(&httpmock.Stub{ URL: "/open-apis/test/rawpage", Body: map[string]interface{}{ @@ -599,13 +538,6 @@ func TestApiCmd_JqFilter_AppliesExpression(t *testing.T) { AppID: "test-app-jq", AppSecret: "test-secret-jq", Brand: core.BrandFeishu, }) - reg.Register(&httpmock.Stub{ - URL: "/open-apis/auth/v3/tenant_access_token/internal", - Body: map[string]interface{}{ - "code": 0, "msg": "ok", - "tenant_access_token": "t-test-token-jq", "expire": 7200, - }, - }) reg.Register(&httpmock.Stub{ URL: "/open-apis/test/jq", Body: map[string]interface{}{ @@ -676,13 +608,6 @@ func TestApiCmd_PageAll_WithJq(t *testing.T) { AppID: "test-app-pjq", AppSecret: "test-secret-pjq", Brand: core.BrandFeishu, }) - reg.Register(&httpmock.Stub{ - URL: "/open-apis/auth/v3/tenant_access_token/internal", - Body: map[string]interface{}{ - "code": 0, "msg": "ok", - "tenant_access_token": "t-test-token-pjq", "expire": 7200, - }, - }) reg.Register(&httpmock.Stub{ URL: "/open-apis/contact/v3/users", Body: map[string]interface{}{ diff --git a/cmd/auth/list.go b/cmd/auth/list.go index 240869594..f20c97fe8 100644 --- a/cmd/auth/list.go +++ b/cmd/auth/list.go @@ -46,8 +46,8 @@ func authListRun(opts *ListOptions) error { return nil } - app := multi.Apps[0] - if len(app.Users) == 0 { + app := multi.CurrentAppConfig(f.Invocation.Profile) + if app == nil || len(app.Users) == 0 { fmt.Fprintln(f.IOStreams.ErrOut, "No logged-in users. Run `lark-cli auth login` to log in.") return nil } diff --git a/cmd/auth/login.go b/cmd/auth/login.go index ebd744f28..467755bf1 100644 --- a/cmd/auth/login.go +++ b/cmd/auth/login.go @@ -46,6 +46,12 @@ func NewCmdAuthLogin(f *cmdutil.Factory, runF func(*LoginOptions) error) *cobra. For AI agents: this command blocks until the user completes authorization in the browser. Run it in the background and retrieve the verification URL from its output.`, RunE: func(cmd *cobra.Command, args []string) error { + if mode := f.ResolveStrictMode(cmd.Context()); mode == core.StrictModeBot { + return output.Errorf(output.ExitValidation, "strict_mode", + "strict mode is %q, user login is not allowed. "+ + "This setting is managed by the administrator and must not be modified by AI agents.", + mode) + } opts.Ctx = cmd.Context() if runF != nil { return runF(opts) @@ -53,6 +59,7 @@ browser. Run it in the background and retrieve the verification URL from its out return authLoginRun(opts) }, } + cmdutil.SetSupportedIdentities(cmd, []string{"user"}) cmd.Flags().StringVar(&opts.Scope, "scope", "", "scopes to request (space-separated)") cmd.Flags().BoolVar(&opts.Recommend, "recommend", false, "request only recommended (auto-approve) scopes") @@ -101,8 +108,10 @@ func authLoginRun(opts *LoginOptions) error { // Determine UI language from saved config lang := "zh" - if multi, _ := core.LoadMultiAppConfig(); multi != nil && len(multi.Apps) > 0 { - lang = multi.Apps[0].Lang + if multi, _ := core.LoadMultiAppConfig(); multi != nil { + if app := multi.FindApp(config.ProfileName); app != nil { + lang = app.Lang + } } msg := getLoginMsg(lang) @@ -304,18 +313,9 @@ func authLoginRun(opts *LoginOptions) error { } // Step 8: Update config — overwrite Users to single user, clean old tokens - multi, _ := core.LoadMultiAppConfig() - if multi != nil && len(multi.Apps) > 0 { - app := &multi.Apps[0] - for _, oldUser := range app.Users { - if oldUser.UserOpenId != openId { - larkauth.RemoveStoredToken(config.AppID, oldUser.UserOpenId) - } - } - app.Users = []core.AppUser{{UserOpenId: openId, UserName: userName}} - if err := core.SaveMultiAppConfig(multi); err != nil { - return output.Errorf(output.ExitInternal, "internal", "failed to save config: %v", err) - } + if err := syncLoginUserToProfile(config.ProfileName, config.AppID, openId, userName); err != nil { + _ = larkauth.RemoveStoredToken(config.AppID, openId) + return output.Errorf(output.ExitInternal, "internal", "failed to update login profile: %v", err) } if opts.JSON { @@ -384,24 +384,49 @@ func authLoginPollDeviceCode(opts *LoginOptions, config *core.CliConfig, msg *lo } // Update config — overwrite Users to single user, clean old tokens - multi, _ := core.LoadMultiAppConfig() - if multi != nil && len(multi.Apps) > 0 { - app := &multi.Apps[0] - for _, oldUser := range app.Users { - if oldUser.UserOpenId != openId { - larkauth.RemoveStoredToken(config.AppID, oldUser.UserOpenId) - } - } - app.Users = []core.AppUser{{UserOpenId: openId, UserName: userName}} - if err := core.SaveMultiAppConfig(multi); err != nil { - return output.Errorf(output.ExitInternal, "internal", "failed to save config: %v", err) - } + if err := syncLoginUserToProfile(config.ProfileName, config.AppID, openId, userName); err != nil { + _ = larkauth.RemoveStoredToken(config.AppID, openId) + return output.Errorf(output.ExitInternal, "internal", "failed to update login profile: %v", err) } output.PrintSuccess(f.IOStreams.ErrOut, fmt.Sprintf(msg.LoginSuccess, userName, openId)) return nil } +func syncLoginUserToProfile(profileName, appID, openID, userName string) error { + multi, err := core.LoadMultiAppConfig() + if err != nil { + return fmt.Errorf("load config: %w", err) + } + + app := findProfileByName(multi, profileName) + if app == nil { + return fmt.Errorf("profile %q not found in config", profileName) + } + + oldUsers := append([]core.AppUser(nil), app.Users...) + app.Users = []core.AppUser{{UserOpenId: openID, UserName: userName}} + if err := core.SaveMultiAppConfig(multi); err != nil { + return fmt.Errorf("save config: %w", err) + } + + for _, oldUser := range oldUsers { + if oldUser.UserOpenId != openID { + _ = larkauth.RemoveStoredToken(appID, oldUser.UserOpenId) + } + } + return nil +} + +func findProfileByName(multi *core.MultiAppConfig, profileName string) *core.AppConfig { + for i := range multi.Apps { + if multi.Apps[i].ProfileName() == profileName { + return &multi.Apps[i] + } + } + return nil +} + // collectScopesForDomains collects API scopes (from from_meta projects) and // shortcut scopes for the given domain names. func collectScopesForDomains(domains []string, identity string) []string { diff --git a/cmd/auth/login_config_test.go b/cmd/auth/login_config_test.go new file mode 100644 index 000000000..63f0095da --- /dev/null +++ b/cmd/auth/login_config_test.go @@ -0,0 +1,74 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package auth + +import ( + "strings" + "testing" + + "github.com/larksuite/cli/internal/core" +) + +func setupLoginConfigDir(t *testing.T) { + t.Helper() + t.Setenv("LARKSUITE_CLI_CONFIG_DIR", t.TempDir()) +} + +func TestSyncLoginUserToProfile_UpdatesOnlyTargetProfile(t *testing.T) { + setupLoginConfigDir(t) + multi := &core.MultiAppConfig{ + CurrentApp: "target", + Apps: []core.AppConfig{ + { + Name: "target", + AppId: "app-target", + Users: []core.AppUser{{UserOpenId: "ou_old", UserName: "old"}}, + }, + { + Name: "other", + AppId: "app-other", + Users: []core.AppUser{{UserOpenId: "ou_other", UserName: "other"}}, + }, + }, + } + if err := core.SaveMultiAppConfig(multi); err != nil { + t.Fatalf("SaveMultiAppConfig() error = %v", err) + } + + if err := syncLoginUserToProfile("target", "app-target", "ou_new", "new-user"); err != nil { + t.Fatalf("syncLoginUserToProfile() error = %v", err) + } + + saved, err := core.LoadMultiAppConfig() + if err != nil { + t.Fatalf("LoadMultiAppConfig() error = %v", err) + } + if got := saved.Apps[0].Users; len(got) != 1 || got[0].UserOpenId != "ou_new" || got[0].UserName != "new-user" { + t.Fatalf("target users = %#v, want replaced login user", got) + } + if got := saved.Apps[1].Users; len(got) != 1 || got[0].UserOpenId != "ou_other" { + t.Fatalf("other users = %#v, want unchanged", got) + } +} + +func TestSyncLoginUserToProfile_ProfileNotFoundReturnsError(t *testing.T) { + setupLoginConfigDir(t) + multi := &core.MultiAppConfig{ + Apps: []core.AppConfig{{ + Name: "default", + AppId: "app-default", + }}, + } + if err := core.SaveMultiAppConfig(multi); err != nil { + t.Fatalf("SaveMultiAppConfig() error = %v", err) + } + + err := syncLoginUserToProfile("missing", "app-default", "ou_new", "new-user") + if err == nil { + t.Fatal("expected error for missing profile") + } + if !strings.Contains(err.Error(), `profile "missing" not found`) { + t.Fatalf("error = %v, want missing profile", err) + } +} diff --git a/cmd/auth/login_strict_test.go b/cmd/auth/login_strict_test.go new file mode 100644 index 000000000..82621d556 --- /dev/null +++ b/cmd/auth/login_strict_test.go @@ -0,0 +1,78 @@ +package auth + +import ( + "strings" + "testing" + + extcred "github.com/larksuite/cli/extension/credential" + "github.com/larksuite/cli/internal/cmdutil" + "github.com/larksuite/cli/internal/core" +) + +func TestAuthLogin_StrictModeBot_Blocked(t *testing.T) { + cfg := &core.CliConfig{ + AppID: "a", AppSecret: "s", + SupportedIdentities: uint8(extcred.SupportsBot), + } + f, _, _, _ := cmdutil.TestFactory(t, cfg) + + var called bool + cmd := NewCmdAuthLogin(f, func(opts *LoginOptions) error { + called = true + return nil + }) + cmd.SetArgs([]string{"--scope", "contact:user.base:readonly"}) + + err := cmd.Execute() + if called { + t.Error("runF should not be called in bot strict mode") + } + if err == nil { + t.Fatal("expected error in bot strict mode") + } + if !strings.Contains(err.Error(), "strict mode") { + t.Errorf("error should mention strict mode, got: %v", err) + } +} + +func TestAuthLogin_StrictModeUser_Allowed(t *testing.T) { + cfg := &core.CliConfig{ + AppID: "a", AppSecret: "s", + SupportedIdentities: uint8(extcred.SupportsUser), + } + f, _, _, _ := cmdutil.TestFactory(t, cfg) + + var called bool + cmd := NewCmdAuthLogin(f, func(opts *LoginOptions) error { + called = true + return nil + }) + cmd.SetArgs([]string{"--scope", "contact:user.base:readonly"}) + + err := cmd.Execute() + if !called { + t.Error("runF should be called in user strict mode") + } + if err != nil { + t.Errorf("unexpected error: %v", err) + } +} + +func TestAuthLogin_StrictModeOff_Allowed(t *testing.T) { + f, _, _, _ := cmdutil.TestFactory(t, &core.CliConfig{AppID: "a", AppSecret: "s"}) + + var called bool + cmd := NewCmdAuthLogin(f, func(opts *LoginOptions) error { + called = true + return nil + }) + cmd.SetArgs([]string{"--scope", "contact:user.base:readonly"}) + + err := cmd.Execute() + if !called { + t.Error("runF should be called when strict mode is off") + } + if err != nil { + t.Errorf("unexpected error: %v", err) + } +} diff --git a/cmd/auth/logout.go b/cmd/auth/logout.go index 4914120c5..ac14d7e63 100644 --- a/cmd/auth/logout.go +++ b/cmd/auth/logout.go @@ -46,8 +46,8 @@ func authLogoutRun(opts *LogoutOptions) error { return nil } - app := &multi.Apps[0] - if len(app.Users) == 0 { + app := multi.CurrentAppConfig(f.Invocation.Profile) + if app == nil || len(app.Users) == 0 { fmt.Fprintln(f.IOStreams.ErrOut, "Not logged in.") return nil } diff --git a/cmd/bootstrap.go b/cmd/bootstrap.go new file mode 100644 index 000000000..841a88409 --- /dev/null +++ b/cmd/bootstrap.go @@ -0,0 +1,30 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package cmd + +import ( + "errors" + "io" + + "github.com/larksuite/cli/internal/cmdutil" + "github.com/spf13/pflag" +) + +// BootstrapInvocationContext extracts global invocation options before +// the real command tree is built, so provider-backed config resolution sees +// the correct profile from the start. +func BootstrapInvocationContext(args []string) (cmdutil.InvocationContext, error) { + var globals GlobalOptions + + fs := pflag.NewFlagSet("bootstrap", pflag.ContinueOnError) + fs.ParseErrorsAllowlist.UnknownFlags = true + fs.SetInterspersed(true) + fs.SetOutput(io.Discard) + RegisterGlobalFlags(fs, &globals) + + if err := fs.Parse(args); err != nil && !errors.Is(err, pflag.ErrHelp) { + return cmdutil.InvocationContext{}, err + } + return cmdutil.InvocationContext{Profile: globals.Profile}, nil +} diff --git a/cmd/bootstrap_test.go b/cmd/bootstrap_test.go new file mode 100644 index 000000000..aa5fd3de7 --- /dev/null +++ b/cmd/bootstrap_test.go @@ -0,0 +1,72 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package cmd + +import "testing" + +func TestBootstrapInvocationContext_ProfileFlag(t *testing.T) { + inv, err := BootstrapInvocationContext([]string{"--profile", "target", "auth", "status"}) + if err != nil { + t.Fatalf("BootstrapInvocationContext() error = %v", err) + } + if inv.Profile != "target" { + t.Fatalf("BootstrapInvocationContext() profile = %q, want %q", inv.Profile, "target") + } +} + +func TestBootstrapInvocationContext_ProfileEquals(t *testing.T) { + inv, err := BootstrapInvocationContext([]string{"auth", "status", "--profile=target"}) + if err != nil { + t.Fatalf("BootstrapInvocationContext() error = %v", err) + } + if inv.Profile != "target" { + t.Fatalf("BootstrapInvocationContext() profile = %q, want %q", inv.Profile, "target") + } +} + +func TestBootstrapInvocationContext_IgnoresUnknownFlags(t *testing.T) { + inv, err := BootstrapInvocationContext([]string{"auth", "status", "--verify", "--profile", "target"}) + if err != nil { + t.Fatalf("BootstrapInvocationContext() error = %v", err) + } + if inv.Profile != "target" { + t.Fatalf("BootstrapInvocationContext() profile = %q, want %q", inv.Profile, "target") + } +} + +func TestBootstrapInvocationContext_MissingProfileValue(t *testing.T) { + if _, err := BootstrapInvocationContext([]string{"auth", "status", "--profile"}); err == nil { + t.Fatal("BootstrapInvocationContext() error = nil, want non-nil") + } +} + +func TestBootstrapInvocationContext_HelpFlag(t *testing.T) { + inv, err := BootstrapInvocationContext([]string{"--help"}) + if err != nil { + t.Fatalf("--help should not error, got: %v", err) + } + if inv.Profile != "" { + t.Fatalf("profile = %q, want empty", inv.Profile) + } +} + +func TestBootstrapInvocationContext_ShortHelp(t *testing.T) { + inv, err := BootstrapInvocationContext([]string{"-h"}) + if err != nil { + t.Fatalf("-h should not error, got: %v", err) + } + if inv.Profile != "" { + t.Fatalf("profile = %q, want empty", inv.Profile) + } +} + +func TestBootstrapInvocationContext_HelpWithProfile(t *testing.T) { + inv, err := BootstrapInvocationContext([]string{"--profile", "target", "--help"}) + if err != nil { + t.Fatalf("--profile + --help should not error, got: %v", err) + } + if inv.Profile != "target" { + t.Fatalf("profile = %q, want %q", inv.Profile, "target") + } +} diff --git a/cmd/config/config.go b/cmd/config/config.go index 055ecfda6..275309609 100644 --- a/cmd/config/config.go +++ b/cmd/config/config.go @@ -21,12 +21,10 @@ func NewCmdConfig(f *cmdutil.Factory) *cobra.Command { cmd.AddCommand(NewCmdConfigRemove(f, nil)) cmd.AddCommand(NewCmdConfigShow(f, nil)) cmd.AddCommand(NewCmdConfigDefaultAs(f)) + cmd.AddCommand(NewCmdConfigStrictMode(f)) return cmd } func parseBrand(value string) core.LarkBrand { - if value == "lark" { - return core.BrandLark - } - return core.BrandFeishu + return core.ParseBrand(value) } diff --git a/cmd/config/config_test.go b/cmd/config/config_test.go index 65642781f..beb58c6c7 100644 --- a/cmd/config/config_test.go +++ b/cmd/config/config_test.go @@ -5,13 +5,22 @@ package config import ( "context" + "errors" "strings" "testing" "github.com/larksuite/cli/internal/cmdutil" "github.com/larksuite/cli/internal/core" + "github.com/larksuite/cli/internal/keychain" + "github.com/larksuite/cli/internal/output" ) +type noopConfigKeychain struct{} + +func (n *noopConfigKeychain) Get(service, account string) (string, error) { return "", nil } +func (n *noopConfigKeychain) Set(service, account, value string) error { return nil } +func (n *noopConfigKeychain) Remove(service, account string) error { return nil } + func TestConfigInitCmd_FlagParsing(t *testing.T) { f, _, _, _ := cmdutil.TestFactory(t, nil) f.IOStreams.In = strings.NewReader("secret123\n") @@ -56,6 +65,60 @@ func TestConfigShowCmd_FlagParsing(t *testing.T) { } } +func TestConfigShowRun_NotConfiguredReturnsStructuredError(t *testing.T) { + t.Setenv("LARKSUITE_CLI_CONFIG_DIR", t.TempDir()) + + f, _, _, _ := cmdutil.TestFactory(t, nil) + err := configShowRun(&ConfigShowOptions{Factory: f}) + if err == nil { + t.Fatal("expected error") + } + + var exitErr *output.ExitError + if !errors.As(err, &exitErr) { + t.Fatalf("error type = %T, want *output.ExitError", err) + } + if exitErr.Code != output.ExitValidation { + t.Fatalf("exit code = %d, want %d", exitErr.Code, output.ExitValidation) + } + if exitErr.Detail == nil || exitErr.Detail.Type != "config" || exitErr.Detail.Message != "not configured" { + t.Fatalf("detail = %#v, want config/not configured", exitErr.Detail) + } +} + +func TestConfigShowRun_NoActiveProfileReturnsStructuredError(t *testing.T) { + t.Setenv("LARKSUITE_CLI_CONFIG_DIR", t.TempDir()) + multi := &core.MultiAppConfig{ + CurrentApp: "missing", + Apps: []core.AppConfig{{ + Name: "default", + AppId: "app-default", + AppSecret: core.PlainSecret("secret-default"), + Brand: core.BrandFeishu, + }}, + } + if err := core.SaveMultiAppConfig(multi); err != nil { + t.Fatalf("SaveMultiAppConfig() error = %v", err) + } + + f, _, _, _ := cmdutil.TestFactory(t, nil) + err := configShowRun(&ConfigShowOptions{Factory: f}) + if err == nil { + t.Fatal("expected error") + } + + var exitErr *output.ExitError + if !errors.As(err, &exitErr) { + t.Fatalf("error type = %T, want *output.ExitError", err) + } + if exitErr.Code != output.ExitValidation { + t.Fatalf("exit code = %d, want %d", exitErr.Code, output.ExitValidation) + } + if exitErr.Detail == nil || exitErr.Detail.Type != "config" || exitErr.Detail.Message != "no active profile" { + t.Fatalf("detail = %#v, want config/no active profile", exitErr.Detail) + } +} + func TestConfigInitCmd_LangFlag(t *testing.T) { f, _, _, _ := cmdutil.TestFactory(t, nil) @@ -157,3 +220,50 @@ func TestConfigRemoveCmd_FlagParsing(t *testing.T) { t.Fatal("expected factory to be preserved in options") } } + +func TestSaveAsProfile_RejectsProfileNameCollisionWithExistingAppID(t *testing.T) { + t.Setenv("LARKSUITE_CLI_CONFIG_DIR", t.TempDir()) + + existing := &core.MultiAppConfig{ + Apps: []core.AppConfig{ + { + Name: "prod", + AppId: "cli_prod", + AppSecret: core.PlainSecret("secret"), + Brand: core.BrandFeishu, + }, + }, + } + + err := saveAsProfile(existing, keychain.KeychainAccess(&noopConfigKeychain{}), "cli_prod", "app-new", core.PlainSecret("new-secret"), core.BrandLark, "en") + if err == nil { + t.Fatal("expected conflict error") + } + if !strings.Contains(err.Error(), "conflicts with existing appId") { + t.Fatalf("error = %v, want conflict with existing appId", err) + } +} + +func TestUpdateExistingProfileWithoutSecret_RejectsAppIDChange(t *testing.T) { + multi := &core.MultiAppConfig{ + CurrentApp: "prod", + Apps: []core.AppConfig{ + { + Name: "prod", + AppId: "app-old", + AppSecret: core.SecretInput{Ref: &core.SecretRef{Source: "keychain", ID: "appsecret:app-old"}}, + Brand: core.BrandFeishu, + Lang: "zh", + Users: []core.AppUser{{UserOpenId: "ou_1", UserName: "User"}}, + }, + }, + } + + err := updateExistingProfileWithoutSecret(multi, "", "app-new", core.BrandLark, "en") + if err == nil { + t.Fatal("expected error when changing app ID without a new secret") + } + if !strings.Contains(err.Error(), "App Secret") { + t.Fatalf("error = %v, want mention of App Secret", err) + } +} diff --git a/cmd/config/default_as.go b/cmd/config/default_as.go index 0600de5d1..25bf824f1 100644 --- a/cmd/config/default_as.go +++ b/cmd/config/default_as.go @@ -25,8 +25,13 @@ func NewCmdConfigDefaultAs(f *cmdutil.Factory) *cobra.Command { return output.ErrWithHint(output.ExitValidation, "config", "not configured", "run: lark-cli config init") } + app := multi.CurrentAppConfig(f.Invocation.Profile) + if app == nil { + return output.ErrWithHint(output.ExitValidation, "config", "no active profile", "run: lark-cli config init") + } + if len(args) == 0 { - current := multi.Apps[0].DefaultAs + current := app.DefaultAs if current == "" { current = "auto" } @@ -39,9 +44,9 @@ func NewCmdConfigDefaultAs(f *cmdutil.Factory) *cobra.Command { return output.ErrValidation("invalid identity type %q, valid values: user | bot | auto", value) } - multi.Apps[0].DefaultAs = value + app.DefaultAs = core.Identity(value) if err := core.SaveMultiAppConfig(multi); err != nil { - return fmt.Errorf("failed to save config: %w", err) + return output.Errorf(output.ExitInternal, "internal", "failed to save config: %v", err) } fmt.Fprintf(f.IOStreams.ErrOut, "Default identity set to: %s\n", value) return nil diff --git a/cmd/config/init.go b/cmd/config/init.go index 8ddff7613..3a3a7a84a 100644 --- a/cmd/config/init.go +++ b/cmd/config/init.go @@ -6,6 +6,7 @@ package config import ( "bufio" "context" + "errors" "fmt" "io" "strings" @@ -16,6 +17,7 @@ import ( "github.com/larksuite/cli/internal/auth" "github.com/larksuite/cli/internal/cmdutil" "github.com/larksuite/cli/internal/core" + "github.com/larksuite/cli/internal/keychain" "github.com/larksuite/cli/internal/output" ) @@ -29,7 +31,8 @@ type ConfigInitOptions struct { Brand string New bool Lang string - langExplicit bool // true when --lang was explicitly passed + langExplicit bool // true when --lang was explicitly passed + ProfileName string // when set, create/update a named profile instead of replacing Apps[0] } // NewCmdConfigInit creates the config init subcommand. @@ -59,6 +62,7 @@ verification URL from its output.`, cmd.Flags().BoolVar(&opts.AppSecretStdin, "app-secret-stdin", false, "Read App Secret from stdin to avoid process list exposure") cmd.Flags().StringVar(&opts.Brand, "brand", "feishu", "feishu or lark (non-interactive, default feishu)") cmd.Flags().StringVar(&opts.Lang, "lang", "zh", "language for interactive prompts (zh or en)") + cmd.Flags().StringVar(&opts.ProfileName, "name", "", "create or update a named profile (append instead of replace)") return cmd } @@ -94,6 +98,110 @@ func saveAsOnlyApp(appId string, secret core.SecretInput, brand core.LarkBrand, return core.SaveMultiAppConfig(config) } +// saveInitConfig saves a new/updated app config, respecting --profile mode. +// With profileName: appends or updates the named profile (preserves other profiles). +// Without profileName: cleans up old config and saves as the only app. +func saveInitConfig(profileName string, existing *core.MultiAppConfig, f *cmdutil.Factory, appId string, secret core.SecretInput, brand core.LarkBrand, lang string) error { + if profileName != "" { + return saveAsProfile(existing, f.Keychain, profileName, appId, secret, brand, lang) + } + cleanupOldConfig(existing, f, appId) + return saveAsOnlyApp(appId, secret, brand, lang) +} + +// saveAsProfile appends or updates a named profile in the config. +// If a profile with the same name exists, it updates it; otherwise appends. +// When updating, cleans up old keychain secrets if AppId changed. +func saveAsProfile(existing *core.MultiAppConfig, kc keychain.KeychainAccess, profileName, appId string, secret core.SecretInput, brand core.LarkBrand, lang string) error { + multi := existing + if multi == nil { + multi = &core.MultiAppConfig{} + } + + if idx := findProfileIndexByName(multi, profileName); idx >= 0 { + // Clean up old keychain secret and user tokens if AppId changed + if multi.Apps[idx].AppId != appId { + core.RemoveSecretStore(multi.Apps[idx].AppSecret, kc) + for _, user := range multi.Apps[idx].Users { + auth.RemoveStoredToken(multi.Apps[idx].AppId, user.UserOpenId) + } + multi.Apps[idx].Users = []core.AppUser{} + } + // Update existing profile + multi.Apps[idx].AppId = appId + multi.Apps[idx].AppSecret = secret + multi.Apps[idx].Brand = brand + multi.Apps[idx].Lang = lang + } else { + if findAppIndexByAppID(multi, profileName) >= 0 { + return fmt.Errorf("profile name %q conflicts with existing appId", profileName) + } + // Append new profile + multi.Apps = append(multi.Apps, core.AppConfig{ + Name: profileName, + AppId: appId, + AppSecret: secret, + Brand: brand, + Lang: lang, + Users: []core.AppUser{}, + }) + } + return core.SaveMultiAppConfig(multi) +} + +func findProfileIndexByName(multi *core.MultiAppConfig, profileName string) int { + if multi == nil { + return -1 + } + for i := range multi.Apps { + if multi.Apps[i].Name == profileName { + return i + } + } + return -1 +} + +func findAppIndexByAppID(multi *core.MultiAppConfig, appID string) int { + if multi == nil { + return -1 + } + for i := range multi.Apps { + if multi.Apps[i].AppId == appID { + return i + } + } + return -1 +} + +func updateExistingProfileWithoutSecret(existing *core.MultiAppConfig, profileName, appID string, brand core.LarkBrand, lang string) error { + if existing == nil { + return output.ErrValidation("App Secret cannot be empty for new configuration") + } + + var app *core.AppConfig + if profileName != "" { + if idx := findProfileIndexByName(existing, profileName); idx >= 0 { + app = &existing.Apps[idx] + } else { + return output.ErrValidation("App Secret cannot be empty for new profile") + } + } else { + app = existing.CurrentAppConfig("") + if app == nil { + return output.ErrValidation("App Secret cannot be empty for new configuration") + } + } + + if app.AppId != appID { + return output.ErrValidation("App Secret cannot be empty when changing App ID") + } + + app.AppId = appID + app.Brand = brand + app.Lang = lang + return core.SaveMultiAppConfig(existing) +} + func configInitRun(opts *ConfigInitOptions) error { f := opts.Factory @@ -117,6 +225,13 @@ func configInitRun(opts *ConfigInitOptions) error { existing = nil // treat as empty } + // Validate --profile name if set + if opts.ProfileName != "" { + if err := core.ValidateProfileName(opts.ProfileName); err != nil { + return output.ErrValidation("%v", err) + } + } + // Mode 1: Non-interactive if opts.AppID != "" && opts.appSecret != "" { brand := parseBrand(opts.Brand) @@ -124,8 +239,7 @@ func configInitRun(opts *ConfigInitOptions) error { if err != nil { return output.Errorf(output.ExitInternal, "internal", "%v", err) } - cleanupOldConfig(existing, f, opts.AppID) - if err := saveAsOnlyApp(opts.AppID, secret, brand, opts.Lang); err != nil { + if err := saveInitConfig(opts.ProfileName, existing, f, opts.AppID, secret, brand, opts.Lang); err != nil { return output.Errorf(output.ExitInternal, "internal", "failed to save config: %v", err) } output.PrintSuccess(f.IOStreams.ErrOut, fmt.Sprintf("Configuration saved to %s", core.GetConfigPath())) @@ -136,8 +250,10 @@ func configInitRun(opts *ConfigInitOptions) error { // For interactive modes, prompt language selection if --lang was not explicitly set if f.IOStreams.IsTerminal && !opts.langExplicit && !opts.hasAnyNonInteractiveFlag() { savedLang := "" - if existing != nil && len(existing.Apps) > 0 { - savedLang = existing.Apps[0].Lang + if existing != nil { + if app := existing.CurrentAppConfig(""); app != nil { + savedLang = app.Lang + } } lang, err := promptLangSelection(savedLang) if err != nil { @@ -165,8 +281,7 @@ func configInitRun(opts *ConfigInitOptions) error { if err != nil { return output.Errorf(output.ExitInternal, "internal", "%v", err) } - cleanupOldConfig(existing, f, result.AppID) - if err := saveAsOnlyApp(result.AppID, secret, result.Brand, opts.Lang); err != nil { + if err := saveInitConfig(opts.ProfileName, existing, f, result.AppID, secret, result.Brand, opts.Lang); err != nil { return output.Errorf(output.ExitInternal, "internal", "failed to save config: %v", err) } output.PrintJson(f.IOStreams.Out, map[string]interface{}{"appId": result.AppID, "appSecret": "****", "brand": result.Brand}) @@ -191,21 +306,17 @@ func configInitRun(opts *ConfigInitOptions) error { if err != nil { return output.Errorf(output.ExitInternal, "internal", "%v", err) } - cleanupOldConfig(existing, f, result.AppID) - if err := saveAsOnlyApp(result.AppID, secret, result.Brand, opts.Lang); err != nil { + if err := saveInitConfig(opts.ProfileName, existing, f, result.AppID, secret, result.Brand, opts.Lang); err != nil { return output.Errorf(output.ExitInternal, "internal", "failed to save config: %v", err) } } else if result.Mode == "existing" && result.AppID != "" { // Existing app with unchanged secret — update app ID and brand only - if existing != nil && len(existing.Apps) > 0 { - existing.Apps[0].AppId = result.AppID - existing.Apps[0].Brand = result.Brand - existing.Apps[0].Lang = opts.Lang - if err := core.SaveMultiAppConfig(existing); err != nil { - return output.Errorf(output.ExitInternal, "internal", "failed to save config: %v", err) + if err := updateExistingProfileWithoutSecret(existing, opts.ProfileName, result.AppID, result.Brand, opts.Lang); err != nil { + var exitErr *output.ExitError + if errors.As(err, &exitErr) { + return err } - } else { - return output.ErrValidation("App Secret cannot be empty for new configuration") + return output.Errorf(output.ExitInternal, "internal", "failed to save config: %v", err) } } else { return output.ErrValidation("App ID and App Secret cannot be empty") @@ -224,8 +335,8 @@ func configInitRun(opts *ConfigInitOptions) error { // Mode 5: Legacy interactive (readline fallback) firstApp := (*core.AppConfig)(nil) - if existing != nil && len(existing.Apps) > 0 { - firstApp = &existing.Apps[0] + if existing != nil { + firstApp = existing.CurrentAppConfig("") } reader := bufio.NewReader(f.IOStreams.In) @@ -296,8 +407,7 @@ func configInitRun(opts *ConfigInitOptions) error { if err != nil { return output.Errorf(output.ExitInternal, "internal", "%v", err) } - cleanupOldConfig(existing, f, resolvedAppId) - if err := saveAsOnlyApp(resolvedAppId, storedSecret, parseBrand(resolvedBrand), opts.Lang); err != nil { + if err := saveInitConfig(opts.ProfileName, existing, f, resolvedAppId, storedSecret, parseBrand(resolvedBrand), opts.Lang); err != nil { return output.Errorf(output.ExitInternal, "internal", "failed to save config: %v", err) } output.PrintSuccess(f.IOStreams.ErrOut, fmt.Sprintf("Configuration saved to %s", core.GetConfigPath())) diff --git a/cmd/config/init_interactive.go b/cmd/config/init_interactive.go index 0172079d4..f138a215c 100644 --- a/cmd/config/init_interactive.go +++ b/cmd/config/init_interactive.go @@ -61,8 +61,8 @@ func runExistingAppForm(f *cmdutil.Factory, msg *initMsg) (*configInitResult, er // Load existing config for defaults existing, _ := core.LoadMultiAppConfig() var firstApp *core.AppConfig - if existing != nil && len(existing.Apps) > 0 { - firstApp = &existing.Apps[0] + if existing != nil { + firstApp = existing.CurrentAppConfig("") } var appID, appSecret, brand string diff --git a/cmd/config/show.go b/cmd/config/show.go index cdee9e347..6e7abb9fe 100644 --- a/cmd/config/show.go +++ b/cmd/config/show.go @@ -4,7 +4,9 @@ package config import ( + "errors" "fmt" + "os" "strings" "github.com/larksuite/cli/internal/cmdutil" @@ -40,12 +42,19 @@ func configShowRun(opts *ConfigShowOptions) error { f := opts.Factory config, err := core.LoadMultiAppConfig() - if err != nil || config == nil || len(config.Apps) == 0 { - fmt.Fprintf(f.IOStreams.ErrOut, "Not configured yet. Config file path: %s\n", core.GetConfigPath()) - fmt.Fprintln(f.IOStreams.ErrOut, "Run `lark-cli config init` to initialize.") - return nil + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return output.ErrWithHint(output.ExitValidation, "config", "not configured", "run: lark-cli config init") + } + return output.Errorf(output.ExitValidation, "config", "failed to load config: %v", err) + } + if config == nil || len(config.Apps) == 0 { + return output.ErrWithHint(output.ExitValidation, "config", "not configured", "run: lark-cli config init") + } + app := config.CurrentAppConfig(f.Invocation.Profile) + if app == nil { + return output.ErrWithHint(output.ExitValidation, "config", "no active profile", "run: lark-cli profile list") } - app := config.Apps[0] users := "(no logged-in users)" if len(app.Users) > 0 { var userStrs []string @@ -55,6 +64,7 @@ func configShowRun(opts *ConfigShowOptions) error { users = strings.Join(userStrs, ", ") } output.PrintJson(f.IOStreams.Out, map[string]interface{}{ + "profile": app.ProfileName(), "appId": app.AppId, "appSecret": "****", "brand": app.Brand, diff --git a/cmd/config/strict_mode.go b/cmd/config/strict_mode.go new file mode 100644 index 000000000..09e81cde7 --- /dev/null +++ b/cmd/config/strict_mode.go @@ -0,0 +1,146 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package config + +import ( + "context" + "fmt" + + "github.com/larksuite/cli/internal/cmdutil" + "github.com/larksuite/cli/internal/core" + "github.com/larksuite/cli/internal/output" + "github.com/spf13/cobra" +) + +// NewCmdConfigStrictMode creates the "config strict-mode" subcommand. +func NewCmdConfigStrictMode(f *cmdutil.Factory) *cobra.Command { + var global bool + var reset bool + + cmd := &cobra.Command{ + Use: "strict-mode [bot|user|off]", + Short: "View or set strict mode (identity restriction policy)", + Long: `View or set strict mode (identity restriction policy). + +Without arguments, shows the current strict mode status and its source. +Pass "bot", "user", or "off" to set strict mode. +Use --global to set at the global level. +Use --reset to clear the profile-level setting (inherit global). + +Modes: + bot — only bot identity is allowed, user commands are hidden + user — only user identity is allowed, bot commands are hidden + off — no restriction (default) + +WARNING: Strict mode is a security policy set by the administrator. +AI agents are strictly prohibited from modifying this setting.`, + Args: cobra.MaximumNArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + multi, err := core.LoadMultiAppConfig() + if err != nil { + return output.ErrWithHint(output.ExitValidation, "config", "not configured", "run: lark-cli config init") + } + + if reset { + app := multi.CurrentAppConfig(f.Invocation.Profile) + if app == nil { + return output.ErrWithHint(output.ExitValidation, "config", "no active profile", "run: lark-cli config init") + } + return resetStrictMode(f, multi, app, global, args) + } + if len(args) == 0 { + app := multi.CurrentAppConfig(f.Invocation.Profile) + if app == nil { + return output.ErrWithHint(output.ExitValidation, "config", "no active profile", "run: lark-cli config init") + } + return showStrictMode(cmd.Context(), f, multi, app) + } + app := multi.CurrentAppConfig(f.Invocation.Profile) + if !global && app == nil { + return output.ErrWithHint(output.ExitValidation, "config", "no active profile", "run: lark-cli config init") + } + return setStrictMode(f, multi, app, args[0], global) + }, + } + + cmd.Flags().BoolVar(&global, "global", false, "set at global level (applies to all profiles)") + cmd.Flags().BoolVar(&reset, "reset", false, "reset profile setting to inherit global") + + return cmd +} + +func resetStrictMode(f *cmdutil.Factory, multi *core.MultiAppConfig, app *core.AppConfig, global bool, args []string) error { + if global { + return output.ErrValidation("--reset cannot be used with --global") + } + if len(args) > 0 { + return output.ErrValidation("--reset cannot be used with a value argument") + } + app.StrictMode = nil + if err := core.SaveMultiAppConfig(multi); err != nil { + return output.Errorf(output.ExitInternal, "internal", "failed to save config: %v", err) + } + fmt.Fprintln(f.IOStreams.ErrOut, "Profile strict-mode reset (inherits global)") + return nil +} + +func showStrictMode(ctx context.Context, f *cmdutil.Factory, multi *core.MultiAppConfig, app *core.AppConfig) error { + // Runtime effective mode from credential provider chain is the source of truth. + runtime := f.ResolveStrictMode(ctx) + configMode, configSource := resolveStrictModeStatus(multi, app) + + if runtime != configMode { + fmt.Fprintf(f.IOStreams.Out, "strict-mode: %s (source: credential provider)\n", runtime) + return nil + } + fmt.Fprintf(f.IOStreams.Out, "strict-mode: %s (source: %s)\n", configMode, configSource) + return nil +} + +func setStrictMode(f *cmdutil.Factory, multi *core.MultiAppConfig, app *core.AppConfig, value string, global bool) error { + mode := core.StrictMode(value) + switch mode { + case core.StrictModeBot, core.StrictModeUser, core.StrictModeOff: + default: + return output.ErrValidation("invalid value %q, valid values: bot | user | off", value) + } + + if global { + multi.StrictMode = mode + for _, a := range multi.Apps { + if a.StrictMode != nil && *a.StrictMode != mode { + fmt.Fprintf(f.IOStreams.ErrOut, + "Warning: profile %q has strict-mode explicitly set to %q, "+ + "which overrides the global setting. "+ + "Use --reset in that profile to inherit global.\n", + a.ProfileName(), *a.StrictMode) + } + } + } else { + if app == nil { + return output.ErrWithHint(output.ExitValidation, "config", "no active profile", "run: lark-cli config init") + } + app.StrictMode = &mode + } + + if err := core.SaveMultiAppConfig(multi); err != nil { + return output.Errorf(output.ExitInternal, "internal", "failed to save config: %v", err) + } + scope := "profile" + if global { + scope = "global" + } + fmt.Fprintf(f.IOStreams.ErrOut, "Strict mode set to %s (%s)\n", mode, scope) + return nil +} + +func resolveStrictModeStatus(multi *core.MultiAppConfig, app *core.AppConfig) (core.StrictMode, string) { + if app != nil && app.StrictMode != nil { + return *app.StrictMode, fmt.Sprintf("profile %q", app.ProfileName()) + } + if multi.StrictMode.IsActive() { + return multi.StrictMode, "global" + } + return core.StrictModeOff, "global (default)" +} diff --git a/cmd/config/strict_mode_test.go b/cmd/config/strict_mode_test.go new file mode 100644 index 000000000..7b930415e --- /dev/null +++ b/cmd/config/strict_mode_test.go @@ -0,0 +1,164 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package config + +import ( + "strings" + "testing" + + "github.com/larksuite/cli/internal/cmdutil" + "github.com/larksuite/cli/internal/core" +) + +func setupStrictModeTestConfig(t *testing.T) { + t.Helper() + dir := t.TempDir() + t.Setenv("LARKSUITE_CLI_CONFIG_DIR", dir) + multi := &core.MultiAppConfig{ + Apps: []core.AppConfig{{ + AppId: "test-app", + AppSecret: core.PlainSecret("secret"), + Brand: core.BrandFeishu, + }}, + } + if err := core.SaveMultiAppConfig(multi); err != nil { + t.Fatal(err) + } +} + +func TestStrictMode_Show_Default(t *testing.T) { + setupStrictModeTestConfig(t) + f, stdout, _, _ := cmdutil.TestFactory(t, &core.CliConfig{AppID: "test-app", AppSecret: "secret"}) + cmd := NewCmdConfigStrictMode(f) + cmd.SetArgs([]string{}) + if err := cmd.Execute(); err != nil { + t.Fatal(err) + } + if !strings.Contains(stdout.String(), "off") { + t.Errorf("expected 'off' in output, got: %s", stdout.String()) + } +} + +func TestStrictMode_SetBot_Profile(t *testing.T) { + setupStrictModeTestConfig(t) + f, _, _, _ := cmdutil.TestFactory(t, &core.CliConfig{AppID: "test-app", AppSecret: "secret"}) + cmd := NewCmdConfigStrictMode(f) + cmd.SetArgs([]string{"bot"}) + if err := cmd.Execute(); err != nil { + t.Fatal(err) + } + multi, _ := core.LoadMultiAppConfig() + app := multi.CurrentAppConfig("") + if app.StrictMode == nil || *app.StrictMode != core.StrictModeBot { + t.Error("expected StrictMode=bot on profile") + } +} + +func TestStrictMode_SetUser_Profile(t *testing.T) { + setupStrictModeTestConfig(t) + f, _, _, _ := cmdutil.TestFactory(t, &core.CliConfig{AppID: "test-app", AppSecret: "secret"}) + cmd := NewCmdConfigStrictMode(f) + cmd.SetArgs([]string{"user"}) + if err := cmd.Execute(); err != nil { + t.Fatal(err) + } + multi, _ := core.LoadMultiAppConfig() + app := multi.CurrentAppConfig("") + if app.StrictMode == nil || *app.StrictMode != core.StrictModeUser { + t.Error("expected StrictMode=user on profile") + } +} + +func TestStrictMode_SetOff_Profile(t *testing.T) { + setupStrictModeTestConfig(t) + f, _, _, _ := cmdutil.TestFactory(t, &core.CliConfig{AppID: "test-app", AppSecret: "secret"}) + cmd := NewCmdConfigStrictMode(f) + cmd.SetArgs([]string{"bot"}) + cmd.Execute() + cmd = NewCmdConfigStrictMode(f) + cmd.SetArgs([]string{"off"}) + if err := cmd.Execute(); err != nil { + t.Fatal(err) + } + multi, _ := core.LoadMultiAppConfig() + app := multi.CurrentAppConfig("") + if app.StrictMode == nil || *app.StrictMode != core.StrictModeOff { + t.Error("expected StrictMode=off on profile") + } +} + +func TestStrictMode_SetBot_Global(t *testing.T) { + setupStrictModeTestConfig(t) + f, _, _, _ := cmdutil.TestFactory(t, &core.CliConfig{AppID: "test-app", AppSecret: "secret"}) + cmd := NewCmdConfigStrictMode(f) + cmd.SetArgs([]string{"bot", "--global"}) + if err := cmd.Execute(); err != nil { + t.Fatal(err) + } + multi, _ := core.LoadMultiAppConfig() + if multi.StrictMode != core.StrictModeBot { + t.Error("expected global StrictMode=bot") + } +} + +func TestStrictMode_SetGlobal_DoesNotRequireActiveProfile(t *testing.T) { + dir := t.TempDir() + t.Setenv("LARKSUITE_CLI_CONFIG_DIR", dir) + multi := &core.MultiAppConfig{ + CurrentApp: "missing-profile", + Apps: []core.AppConfig{{ + Name: "default", + AppId: "test-app", + AppSecret: core.PlainSecret("secret"), + Brand: core.BrandFeishu, + }}, + } + if err := core.SaveMultiAppConfig(multi); err != nil { + t.Fatal(err) + } + + f, _, _, _ := cmdutil.TestFactory(t, &core.CliConfig{AppID: "test-app", AppSecret: "secret"}) + cmd := NewCmdConfigStrictMode(f) + cmd.SetArgs([]string{"bot", "--global"}) + if err := cmd.Execute(); err != nil { + t.Fatalf("Execute() error = %v", err) + } + + saved, err := core.LoadMultiAppConfig() + if err != nil { + t.Fatalf("LoadMultiAppConfig() error = %v", err) + } + if saved.StrictMode != core.StrictModeBot { + t.Fatalf("StrictMode = %q, want %q", saved.StrictMode, core.StrictModeBot) + } +} + +func TestStrictMode_Reset(t *testing.T) { + setupStrictModeTestConfig(t) + f, _, _, _ := cmdutil.TestFactory(t, &core.CliConfig{AppID: "test-app", AppSecret: "secret"}) + cmd := NewCmdConfigStrictMode(f) + cmd.SetArgs([]string{"bot"}) + cmd.Execute() + cmd = NewCmdConfigStrictMode(f) + cmd.SetArgs([]string{"--reset"}) + if err := cmd.Execute(); err != nil { + t.Fatal(err) + } + multi, _ := core.LoadMultiAppConfig() + app := multi.CurrentAppConfig("") + if app.StrictMode != nil { + t.Errorf("expected nil StrictMode after reset, got %v", *app.StrictMode) + } +} + +func TestStrictMode_InvalidValue(t *testing.T) { + setupStrictModeTestConfig(t) + f, _, _, _ := cmdutil.TestFactory(t, &core.CliConfig{AppID: "test-app", AppSecret: "secret"}) + cmd := NewCmdConfigStrictMode(f) + cmd.SetArgs([]string{"on"}) + err := cmd.Execute() + if err == nil { + t.Error("expected error for invalid value 'on'") + } +} diff --git a/cmd/global_flags.go b/cmd/global_flags.go new file mode 100644 index 000000000..d634cc4fd --- /dev/null +++ b/cmd/global_flags.go @@ -0,0 +1,17 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package cmd + +import "github.com/spf13/pflag" + +// GlobalOptions are the root-level flags shared by bootstrap parsing and the +// actual Cobra command tree. +type GlobalOptions struct { + Profile string +} + +// RegisterGlobalFlags registers the root-level persistent flags. +func RegisterGlobalFlags(fs *pflag.FlagSet, opts *GlobalOptions) { + fs.StringVar(&opts.Profile, "profile", "", "use a specific profile") +} diff --git a/cmd/profile/add.go b/cmd/profile/add.go new file mode 100644 index 000000000..d84e1f504 --- /dev/null +++ b/cmd/profile/add.go @@ -0,0 +1,137 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package profile + +import ( + "bufio" + "errors" + "fmt" + "os" + "strings" + + "github.com/spf13/cobra" + + "github.com/larksuite/cli/internal/cmdutil" + "github.com/larksuite/cli/internal/core" + "github.com/larksuite/cli/internal/output" +) + +// NewCmdProfileAdd creates the profile add subcommand. +func NewCmdProfileAdd(f *cmdutil.Factory) *cobra.Command { + var ( + name string + appID string + appSecretStdin bool + brand string + lang string + use bool + ) + + cmd := &cobra.Command{ + Use: "add", + Short: "Add a new profile", + RunE: func(cmd *cobra.Command, args []string) error { + return profileAddRun(f, name, appID, appSecretStdin, brand, lang, use) + }, + } + + cmd.Flags().StringVar(&name, "name", "", "profile name (required)") + cmd.Flags().StringVar(&appID, "app-id", "", "App ID (required)") + cmd.Flags().BoolVar(&appSecretStdin, "app-secret-stdin", false, "read App Secret from stdin") + cmd.Flags().StringVar(&brand, "brand", "feishu", "feishu or lark") + cmd.Flags().StringVar(&lang, "lang", "zh", "language for interactive prompts (zh or en)") + cmd.Flags().BoolVar(&use, "use", false, "switch to this profile after adding") + + _ = cmd.MarkFlagRequired("name") + _ = cmd.MarkFlagRequired("app-id") + + return cmd +} + +func profileAddRun(f *cmdutil.Factory, name, appID string, appSecretStdin bool, brand, lang string, useAfter bool) error { + if err := core.ValidateProfileName(name); err != nil { + return output.ErrValidation("%v", err) + } + + // Read secret from stdin + if !appSecretStdin { + return output.ErrValidation("app secret must be provided via stdin: use --app-secret-stdin and pipe the secret") + } + scanner := bufio.NewScanner(f.IOStreams.In) + if !scanner.Scan() { + if err := scanner.Err(); err != nil { + return output.ErrValidation("failed to read secret from stdin: %v", err) + } + return output.ErrValidation("stdin is empty, expected app secret") + } + appSecret := strings.TrimSpace(scanner.Text()) + if appSecret == "" { + return output.ErrValidation("app secret read from stdin is empty") + } + + // Load or create config + multi, err := core.LoadMultiAppConfig() + if err != nil { + if !errors.Is(err, os.ErrNotExist) { + return output.Errorf(output.ExitInternal, "internal", "failed to load config: %v", err) + } + multi = &core.MultiAppConfig{} + } + + // Check name uniqueness + if multi.FindApp(name) != nil { + return output.ErrValidation("profile %q already exists", name) + } + + // Check app-id uniqueness — keychain stores secrets by appId, so + // multiple profiles sharing the same appId would collide on credentials. + for _, a := range multi.Apps { + if a.AppId == appID { + return output.ErrValidation("app-id %q is already used by profile %q; each profile must have a unique app-id", appID, a.ProfileName()) + } + } + + // Store secret securely + secret, err := core.ForStorage(appID, core.PlainSecret(appSecret), f.Keychain) + if err != nil { + return output.Errorf(output.ExitInternal, "internal", "%v", err) + } + + parsedBrand := core.ParseBrand(brand) + + // Capture current profile before appending (avoid setting PreviousApp to self) + var previousName string + if useAfter { + if currentApp := multi.CurrentAppConfig(""); currentApp != nil { + previousName = currentApp.ProfileName() + } + } + + // Append profile + multi.Apps = append(multi.Apps, core.AppConfig{ + Name: name, + AppId: appID, + AppSecret: secret, + Brand: parsedBrand, + Lang: lang, + Users: []core.AppUser{}, + }) + + if useAfter { + if previousName != "" { + multi.PreviousApp = previousName + } + multi.CurrentApp = name + } + + if err := core.SaveMultiAppConfig(multi); err != nil { + return output.Errorf(output.ExitInternal, "internal", "failed to save config: %v", err) + } + + output.PrintSuccess(f.IOStreams.ErrOut, fmt.Sprintf("Profile %q added (%s, %s)", name, appID, parsedBrand)) + if useAfter { + output.PrintSuccess(f.IOStreams.ErrOut, fmt.Sprintf("Switched to profile %q", name)) + } + return nil +} diff --git a/cmd/profile/list.go b/cmd/profile/list.go new file mode 100644 index 000000000..dbe98c1e7 --- /dev/null +++ b/cmd/profile/list.go @@ -0,0 +1,85 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package profile + +import ( + "errors" + "os" + + "github.com/spf13/cobra" + + larkauth "github.com/larksuite/cli/internal/auth" + "github.com/larksuite/cli/internal/cmdutil" + "github.com/larksuite/cli/internal/core" + "github.com/larksuite/cli/internal/output" +) + +// profileListItem is the JSON output for a single profile entry. +type profileListItem struct { + Name string `json:"name"` + AppID string `json:"appId"` + Brand core.LarkBrand `json:"brand"` + Active bool `json:"active"` + User string `json:"user,omitempty"` + TokenStatus string `json:"tokenStatus,omitempty"` +} + +// NewCmdProfileList creates the profile list subcommand. +func NewCmdProfileList(f *cmdutil.Factory) *cobra.Command { + cmd := &cobra.Command{ + Use: "list", + Short: "List all profiles", + RunE: func(cmd *cobra.Command, args []string) error { + return profileListRun(f) + }, + } + return cmd +} + +func profileListRun(f *cmdutil.Factory) error { + multi, err := core.LoadMultiAppConfig() + if err != nil { + if errors.Is(err, os.ErrNotExist) { + output.PrintJson(f.IOStreams.Out, []profileListItem{}) + return nil + } + return output.Errorf(output.ExitValidation, "config", "failed to load config: %v", err) + } + if multi == nil || len(multi.Apps) == 0 { + output.PrintJson(f.IOStreams.Out, []profileListItem{}) + return nil + } + + // Intentionally uses "" to show the persistent active profile, not the ephemeral --profile override. + currentApp := multi.CurrentAppConfig("") + currentName := "" + if currentApp != nil { + currentName = currentApp.ProfileName() + } + + items := make([]profileListItem, 0, len(multi.Apps)) + for i := range multi.Apps { + app := &multi.Apps[i] + name := app.ProfileName() + + item := profileListItem{ + Name: name, + AppID: app.AppId, + Brand: app.Brand, + Active: name == currentName, + } + + if len(app.Users) > 0 { + item.User = app.Users[0].UserName + stored := larkauth.GetStoredToken(app.AppId, app.Users[0].UserOpenId) + if stored != nil { + item.TokenStatus = larkauth.TokenStatus(stored) + } + } + + items = append(items, item) + } + output.PrintJson(f.IOStreams.Out, items) + return nil +} diff --git a/cmd/profile/profile.go b/cmd/profile/profile.go new file mode 100644 index 000000000..2216a4f39 --- /dev/null +++ b/cmd/profile/profile.go @@ -0,0 +1,29 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package profile + +import ( + "github.com/spf13/cobra" + + "github.com/larksuite/cli/internal/cmdutil" +) + +// NewCmdProfile creates the profile command with subcommands. +func NewCmdProfile(f *cmdutil.Factory) *cobra.Command { + cmd := &cobra.Command{ + Use: "profile", + Short: "Manage configuration profiles", + } + cmdutil.DisableAuthCheck(cmd) + cmdutil.SetTips(cmd, []string{ + "AI agents: Do NOT switch or remove profiles unless the user explicitly asks.", + }) + + cmd.AddCommand(NewCmdProfileList(f)) + cmd.AddCommand(NewCmdProfileUse(f)) + cmd.AddCommand(NewCmdProfileAdd(f)) + cmd.AddCommand(NewCmdProfileRemove(f)) + cmd.AddCommand(NewCmdProfileRename(f)) + return cmd +} diff --git a/cmd/profile/profile_test.go b/cmd/profile/profile_test.go new file mode 100644 index 000000000..83667d554 --- /dev/null +++ b/cmd/profile/profile_test.go @@ -0,0 +1,371 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package profile + +import ( + "encoding/json" + "errors" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/larksuite/cli/internal/cmdutil" + "github.com/larksuite/cli/internal/core" + "github.com/larksuite/cli/internal/output" + "github.com/larksuite/cli/internal/vfs" +) + +type failRenameFS struct { + vfs.OsFs + err error +} + +func (fs *failRenameFS) Rename(oldpath, newpath string) error { + return fs.err +} + +func setupProfileConfigDir(t *testing.T) string { + t.Helper() + dir := t.TempDir() + t.Setenv("LARKSUITE_CLI_CONFIG_DIR", dir) + return dir +} + +func TestProfileAddRun_InvalidExistingConfigReturnsError(t *testing.T) { + dir := setupProfileConfigDir(t) + if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte("{invalid json"), 0600); err != nil { + t.Fatalf("WriteFile() error = %v", err) + } + + f, _, _, _ := cmdutil.TestFactory(t, nil) + f.IOStreams.In = strings.NewReader("secret\n") + + err := profileAddRun(f, "test", "app-test", true, "feishu", "zh", false) + if err == nil { + t.Fatal("expected error for invalid existing config") + } + if !strings.Contains(err.Error(), "failed to load config") { + t.Fatalf("error = %v, want failed to load config", err) + } +} + +func TestProfileAddRun_UseAfterUpdatesCurrentAndPrevious(t *testing.T) { + setupProfileConfigDir(t) + multi := &core.MultiAppConfig{ + CurrentApp: "default", + Apps: []core.AppConfig{ + {Name: "default", AppId: "app-default", AppSecret: core.PlainSecret("secret-default"), Brand: core.BrandFeishu}, + }, + } + if err := core.SaveMultiAppConfig(multi); err != nil { + t.Fatalf("SaveMultiAppConfig() error = %v", err) + } + + f, _, _, _ := cmdutil.TestFactory(t, nil) + f.IOStreams.In = strings.NewReader("secret-new\n") + + if err := profileAddRun(f, "target", "app-target", true, "lark", "en", true); err != nil { + t.Fatalf("profileAddRun() error = %v", err) + } + + saved, err := core.LoadMultiAppConfig() + if err != nil { + t.Fatalf("LoadMultiAppConfig() error = %v", err) + } + if saved.CurrentApp != "target" { + t.Fatalf("CurrentApp = %q, want %q", saved.CurrentApp, "target") + } + if saved.PreviousApp != "default" { + t.Fatalf("PreviousApp = %q, want %q", saved.PreviousApp, "default") + } + if len(saved.Apps) != 2 { + t.Fatalf("len(Apps) = %d, want 2", len(saved.Apps)) + } +} + +func TestProfileRemoveRun_RemovesCurrentProfileAndSwitchesToFirstRemaining(t *testing.T) { + setupProfileConfigDir(t) + multi := &core.MultiAppConfig{ + CurrentApp: "target", + PreviousApp: "default", + Apps: []core.AppConfig{ + {Name: "default", AppId: "app-default", AppSecret: core.PlainSecret("secret-default"), Brand: core.BrandFeishu}, + {Name: "target", AppId: "app-target", AppSecret: core.PlainSecret("secret-target"), Brand: core.BrandLark}, + }, + } + if err := core.SaveMultiAppConfig(multi); err != nil { + t.Fatalf("SaveMultiAppConfig() error = %v", err) + } + + f, _, _, _ := cmdutil.TestFactory(t, nil) + if err := profileRemoveRun(f, "target"); err != nil { + t.Fatalf("profileRemoveRun() error = %v", err) + } + + saved, err := core.LoadMultiAppConfig() + if err != nil { + t.Fatalf("LoadMultiAppConfig() error = %v", err) + } + if saved.CurrentApp != "default" { + t.Fatalf("CurrentApp = %q, want %q", saved.CurrentApp, "default") + } + if saved.PreviousApp != "default" { + t.Fatalf("PreviousApp = %q, want %q", saved.PreviousApp, "default") + } + if len(saved.Apps) != 1 || saved.Apps[0].ProfileName() != "default" { + t.Fatalf("remaining apps = %#v, want only default", saved.Apps) + } +} + +func TestProfileRenameRun_UpdatesCurrentAndPreviousReferences(t *testing.T) { + setupProfileConfigDir(t) + multi := &core.MultiAppConfig{ + CurrentApp: "old", + PreviousApp: "old", + Apps: []core.AppConfig{{ + Name: "old", + AppId: "app-old", + AppSecret: core.PlainSecret("secret-old"), + Brand: core.BrandFeishu, + }}, + } + if err := core.SaveMultiAppConfig(multi); err != nil { + t.Fatalf("SaveMultiAppConfig() error = %v", err) + } + + f, _, _, _ := cmdutil.TestFactory(t, nil) + if err := profileRenameRun(f, "old", "new"); err != nil { + t.Fatalf("profileRenameRun() error = %v", err) + } + + saved, err := core.LoadMultiAppConfig() + if err != nil { + t.Fatalf("LoadMultiAppConfig() error = %v", err) + } + if saved.CurrentApp != "new" { + t.Fatalf("CurrentApp = %q, want %q", saved.CurrentApp, "new") + } + if saved.PreviousApp != "new" { + t.Fatalf("PreviousApp = %q, want %q", saved.PreviousApp, "new") + } + if saved.Apps[0].ProfileName() != "new" { + t.Fatalf("ProfileName() = %q, want %q", saved.Apps[0].ProfileName(), "new") + } +} + +func TestProfileRenameRun_AllowsRenameToOwnAppID(t *testing.T) { + setupProfileConfigDir(t) + multi := &core.MultiAppConfig{ + CurrentApp: "old", + PreviousApp: "old", + Apps: []core.AppConfig{{ + Name: "old", + AppId: "app-old", + AppSecret: core.PlainSecret("secret-old"), + Brand: core.BrandFeishu, + }}, + } + if err := core.SaveMultiAppConfig(multi); err != nil { + t.Fatalf("SaveMultiAppConfig() error = %v", err) + } + + f, _, _, _ := cmdutil.TestFactory(t, nil) + if err := profileRenameRun(f, "old", "app-old"); err != nil { + t.Fatalf("profileRenameRun() error = %v", err) + } + + saved, err := core.LoadMultiAppConfig() + if err != nil { + t.Fatalf("LoadMultiAppConfig() error = %v", err) + } + if saved.CurrentApp != "app-old" { + t.Fatalf("CurrentApp = %q, want %q", saved.CurrentApp, "app-old") + } + if saved.PreviousApp != "app-old" { + t.Fatalf("PreviousApp = %q, want %q", saved.PreviousApp, "app-old") + } + if saved.Apps[0].Name != "app-old" { + t.Fatalf("Name = %q, want %q", saved.Apps[0].Name, "app-old") + } +} + +func TestProfileUseRun_ToggleBackUsesPreviousProfile(t *testing.T) { + setupProfileConfigDir(t) + multi := &core.MultiAppConfig{ + CurrentApp: "default", + PreviousApp: "target", + Apps: []core.AppConfig{ + {Name: "default", AppId: "app-default", AppSecret: core.PlainSecret("secret-default"), Brand: core.BrandFeishu}, + {Name: "target", AppId: "app-target", AppSecret: core.PlainSecret("secret-target"), Brand: core.BrandLark}, + }, + } + if err := core.SaveMultiAppConfig(multi); err != nil { + t.Fatalf("SaveMultiAppConfig() error = %v", err) + } + + f, _, _, _ := cmdutil.TestFactory(t, nil) + if err := profileUseRun(f, "-"); err != nil { + t.Fatalf("profileUseRun() error = %v", err) + } + + saved, err := core.LoadMultiAppConfig() + if err != nil { + t.Fatalf("LoadMultiAppConfig() error = %v", err) + } + if saved.CurrentApp != "target" { + t.Fatalf("CurrentApp = %q, want %q", saved.CurrentApp, "target") + } + if saved.PreviousApp != "default" { + t.Fatalf("PreviousApp = %q, want %q", saved.PreviousApp, "default") + } +} + +func TestProfileListRun_OutputsProfiles(t *testing.T) { + setupProfileConfigDir(t) + multi := &core.MultiAppConfig{ + CurrentApp: "default", + Apps: []core.AppConfig{ + {Name: "default", AppId: "app-default", AppSecret: core.PlainSecret("secret-default"), Brand: core.BrandFeishu}, + {Name: "target", AppId: "app-target", AppSecret: core.PlainSecret("secret-target"), Brand: core.BrandLark}, + }, + } + if err := core.SaveMultiAppConfig(multi); err != nil { + t.Fatalf("SaveMultiAppConfig() error = %v", err) + } + + f, stdout, _, _ := cmdutil.TestFactory(t, nil) + if err := profileListRun(f); err != nil { + t.Fatalf("profileListRun() error = %v", err) + } + + var got []profileListItem + if err := json.Unmarshal(stdout.Bytes(), &got); err != nil { + t.Fatalf("Unmarshal() error = %v; output=%s", err, stdout.String()) + } + if len(got) != 2 { + t.Fatalf("len(got) = %d, want 2", len(got)) + } + if got[0].Name != "default" || !got[0].Active { + t.Fatalf("got[0] = %#v, want active default profile", got[0]) + } + if got[1].Name != "target" || got[1].Active { + t.Fatalf("got[1] = %#v, want inactive target profile", got[1]) + } +} + +func TestProfileListRun_NotConfiguredReturnsEmptyList(t *testing.T) { + setupProfileConfigDir(t) + + f, stdout, stderr, _ := cmdutil.TestFactory(t, nil) + if err := profileListRun(f); err != nil { + t.Fatalf("profileListRun() error = %v", err) + } + + var got []profileListItem + if err := json.Unmarshal(stdout.Bytes(), &got); err != nil { + t.Fatalf("Unmarshal() error = %v; output=%s", err, stdout.String()) + } + if len(got) != 0 { + t.Fatalf("len(got) = %d, want 0", len(got)) + } + if stderr.Len() != 0 { + t.Fatalf("stderr = %q, want empty", stderr.String()) + } +} + +func TestProfileRemoveRun_SaveFailureReturnsStructuredError(t *testing.T) { + setupProfileConfigDir(t) + multi := &core.MultiAppConfig{ + CurrentApp: "target", + Apps: []core.AppConfig{ + {Name: "default", AppId: "app-default", AppSecret: core.PlainSecret("secret-default"), Brand: core.BrandFeishu}, + {Name: "target", AppId: "app-target", AppSecret: core.PlainSecret("secret-target"), Brand: core.BrandLark}, + }, + } + if err := core.SaveMultiAppConfig(multi); err != nil { + t.Fatalf("SaveMultiAppConfig() error = %v", err) + } + + restoreFS := vfs.DefaultFS + vfs.DefaultFS = &failRenameFS{err: errors.New("rename boom")} + t.Cleanup(func() { vfs.DefaultFS = restoreFS }) + + f, _, _, _ := cmdutil.TestFactory(t, nil) + err := profileRemoveRun(f, "target") + if err == nil { + t.Fatal("expected save error") + } + assertInternalExitError(t, err, "failed to save config") +} + +func TestProfileRenameRun_SaveFailureReturnsStructuredError(t *testing.T) { + setupProfileConfigDir(t) + multi := &core.MultiAppConfig{ + CurrentApp: "old", + Apps: []core.AppConfig{{ + Name: "old", + AppId: "app-old", + AppSecret: core.PlainSecret("secret-old"), + Brand: core.BrandFeishu, + }}, + } + if err := core.SaveMultiAppConfig(multi); err != nil { + t.Fatalf("SaveMultiAppConfig() error = %v", err) + } + + restoreFS := vfs.DefaultFS + vfs.DefaultFS = &failRenameFS{err: errors.New("rename boom")} + t.Cleanup(func() { vfs.DefaultFS = restoreFS }) + + f, _, _, _ := cmdutil.TestFactory(t, nil) + err := profileRenameRun(f, "old", "new") + if err == nil { + t.Fatal("expected save error") + } + assertInternalExitError(t, err, "failed to save config") +} + +func TestProfileUseRun_SaveFailureReturnsStructuredError(t *testing.T) { + setupProfileConfigDir(t) + multi := &core.MultiAppConfig{ + CurrentApp: "default", + Apps: []core.AppConfig{ + {Name: "default", AppId: "app-default", AppSecret: core.PlainSecret("secret-default"), Brand: core.BrandFeishu}, + {Name: "target", AppId: "app-target", AppSecret: core.PlainSecret("secret-target"), Brand: core.BrandLark}, + }, + } + if err := core.SaveMultiAppConfig(multi); err != nil { + t.Fatalf("SaveMultiAppConfig() error = %v", err) + } + + restoreFS := vfs.DefaultFS + vfs.DefaultFS = &failRenameFS{err: errors.New("rename boom")} + t.Cleanup(func() { vfs.DefaultFS = restoreFS }) + + f, _, _, _ := cmdutil.TestFactory(t, nil) + err := profileUseRun(f, "target") + if err == nil { + t.Fatal("expected save error") + } + assertInternalExitError(t, err, "failed to save config") +} + +func assertInternalExitError(t *testing.T, err error, wantMsg string) { + t.Helper() + + var exitErr *output.ExitError + if !errors.As(err, &exitErr) { + t.Fatalf("error type = %T, want *output.ExitError; err=%v", err, err) + } + if exitErr.Code != output.ExitInternal { + t.Fatalf("exit code = %d, want %d", exitErr.Code, output.ExitInternal) + } + if exitErr.Detail == nil || exitErr.Detail.Type != "internal" { + t.Fatalf("detail = %#v, want internal detail", exitErr.Detail) + } + if !strings.Contains(exitErr.Detail.Message, wantMsg) { + t.Fatalf("message = %q, want contains %q", exitErr.Detail.Message, wantMsg) + } +} diff --git a/cmd/profile/remove.go b/cmd/profile/remove.go new file mode 100644 index 000000000..00599c0da --- /dev/null +++ b/cmd/profile/remove.go @@ -0,0 +1,78 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package profile + +import ( + "fmt" + "strings" + + "github.com/spf13/cobra" + + larkauth "github.com/larksuite/cli/internal/auth" + "github.com/larksuite/cli/internal/cmdutil" + "github.com/larksuite/cli/internal/core" + "github.com/larksuite/cli/internal/output" +) + +// NewCmdProfileRemove creates the profile remove subcommand. +func NewCmdProfileRemove(f *cmdutil.Factory) *cobra.Command { + cmd := &cobra.Command{ + Use: "remove ", + Short: "Remove a profile", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + return profileRemoveRun(f, args[0]) + }, + } + cmdutil.SetTips(cmd, []string{ + "AI agents: Do NOT remove profiles unless the user explicitly asks. This is destructive and clears all associated credentials.", + }) + return cmd +} + +func profileRemoveRun(f *cmdutil.Factory, name string) error { + multi, err := core.LoadMultiAppConfig() + if err != nil { + return output.ErrWithHint(output.ExitValidation, "config", "not configured", "run: lark-cli config init") + } + + idx := multi.FindAppIndex(name) + if idx < 0 { + return output.ErrValidation("profile %q not found, available profiles: %s", name, strings.Join(multi.ProfileNames(), ", ")) + } + + if len(multi.Apps) == 1 { + return output.ErrValidation("cannot remove the only profile") + } + + app := &multi.Apps[idx] + removedName := app.ProfileName() + appId := app.AppId + appSecret := app.AppSecret + users := app.Users + + // Remove from slice + multi.Apps = append(multi.Apps[:idx], multi.Apps[idx+1:]...) + + // Fix currentApp / previousApp references + if multi.CurrentApp == removedName { + multi.CurrentApp = multi.Apps[0].ProfileName() + } + if multi.PreviousApp == removedName { + multi.PreviousApp = "" + } + + if err := core.SaveMultiAppConfig(multi); err != nil { + return output.Errorf(output.ExitInternal, "internal", "failed to save config: %v", err) + } + + // Best-effort credential cleanup after config commit + core.RemoveSecretStore(appSecret, f.Keychain) + for _, user := range users { + larkauth.RemoveStoredToken(appId, user.UserOpenId) + } + + output.PrintSuccess(f.IOStreams.ErrOut, fmt.Sprintf("Profile %q removed", removedName)) + return nil +} diff --git a/cmd/profile/rename.go b/cmd/profile/rename.go new file mode 100644 index 000000000..e86b569c5 --- /dev/null +++ b/cmd/profile/rename.go @@ -0,0 +1,73 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package profile + +import ( + "fmt" + "strings" + + "github.com/spf13/cobra" + + "github.com/larksuite/cli/internal/cmdutil" + "github.com/larksuite/cli/internal/core" + "github.com/larksuite/cli/internal/output" +) + +// NewCmdProfileRename creates the profile rename subcommand. +func NewCmdProfileRename(f *cmdutil.Factory) *cobra.Command { + cmd := &cobra.Command{ + Use: "rename ", + Short: "Rename a profile", + Args: cobra.ExactArgs(2), + RunE: func(cmd *cobra.Command, args []string) error { + return profileRenameRun(f, args[0], args[1]) + }, + } + return cmd +} + +func profileRenameRun(f *cmdutil.Factory, oldName, newName string) error { + if err := core.ValidateProfileName(newName); err != nil { + return output.ErrValidation("%v", err) + } + + multi, err := core.LoadMultiAppConfig() + if err != nil { + return output.ErrWithHint(output.ExitValidation, "config", "not configured", "run: lark-cli config init") + } + + idx := multi.FindAppIndex(oldName) + if idx < 0 { + return output.ErrValidation("profile %q not found, available profiles: %s", oldName, strings.Join(multi.ProfileNames(), ", ")) + } + + // Check new name uniqueness across other profiles, allowing renames to this + // profile's own appId or current name. + for i := range multi.Apps { + if i == idx { + continue + } + if multi.Apps[i].Name == newName || multi.Apps[i].AppId == newName { + return output.ErrValidation("profile %q already exists", newName) + } + } + + oldProfileName := multi.Apps[idx].ProfileName() + multi.Apps[idx].Name = newName + + // Update currentApp / previousApp references + if multi.CurrentApp == oldProfileName { + multi.CurrentApp = newName + } + if multi.PreviousApp == oldProfileName { + multi.PreviousApp = newName + } + + if err := core.SaveMultiAppConfig(multi); err != nil { + return output.Errorf(output.ExitInternal, "internal", "failed to save config: %v", err) + } + + output.PrintSuccess(f.IOStreams.ErrOut, fmt.Sprintf("Profile renamed: %q -> %q", oldProfileName, newName)) + return nil +} diff --git a/cmd/profile/use.go b/cmd/profile/use.go new file mode 100644 index 000000000..f73a47be4 --- /dev/null +++ b/cmd/profile/use.go @@ -0,0 +1,73 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package profile + +import ( + "fmt" + "strings" + + "github.com/spf13/cobra" + + "github.com/larksuite/cli/internal/cmdutil" + "github.com/larksuite/cli/internal/core" + "github.com/larksuite/cli/internal/output" +) + +// NewCmdProfileUse creates the profile use subcommand. +func NewCmdProfileUse(f *cmdutil.Factory) *cobra.Command { + cmd := &cobra.Command{ + Use: "use ", + Short: "Switch to a profile (use '-' to toggle back)", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + return profileUseRun(f, args[0]) + }, + } + cmdutil.SetTips(cmd, []string{ + "AI agents: Do NOT switch profiles unless the user explicitly asks.", + }) + return cmd +} + +func profileUseRun(f *cmdutil.Factory, name string) error { + multi, err := core.LoadMultiAppConfig() + if err != nil { + return output.ErrWithHint(output.ExitValidation, "config", "not configured", "run: lark-cli config init") + } + + // Handle "-" for toggle-back + if name == "-" { + if multi.PreviousApp == "" { + return output.ErrValidation("no previous profile to switch back to") + } + name = multi.PreviousApp + } + + app := multi.FindApp(name) + if app == nil { + return output.ErrValidation("profile %q not found, available profiles: %s", name, strings.Join(multi.ProfileNames(), ", ")) + } + + targetName := app.ProfileName() + + // Short-circuit if already on the target profile + currentApp := multi.CurrentAppConfig("") + if currentApp != nil && currentApp.ProfileName() == targetName { + fmt.Fprintf(f.IOStreams.ErrOut, "Already on profile %q\n", targetName) + return nil + } + + // Update previous and current + if currentApp != nil { + multi.PreviousApp = currentApp.ProfileName() + } + multi.CurrentApp = targetName + + if err := core.SaveMultiAppConfig(multi); err != nil { + return output.Errorf(output.ExitInternal, "internal", "failed to save config: %v", err) + } + + output.PrintSuccess(f.IOStreams.ErrOut, fmt.Sprintf("Switched to profile %q (%s, %s)", targetName, app.AppId, app.Brand)) + return nil +} diff --git a/cmd/prune.go b/cmd/prune.go new file mode 100644 index 000000000..6ae18a709 --- /dev/null +++ b/cmd/prune.go @@ -0,0 +1,80 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package cmd + +import ( + "slices" + + "github.com/larksuite/cli/internal/cmdutil" + "github.com/larksuite/cli/internal/core" + "github.com/larksuite/cli/internal/output" + "github.com/spf13/cobra" +) + +// pruneForStrictMode removes commands incompatible with the active strict mode. +func pruneForStrictMode(root *cobra.Command, mode core.StrictMode) { + pruneIncompatible(root, mode) + pruneEmpty(root) +} + +// pruneIncompatible recursively replaces commands whose annotation declares +// identities incompatible with the forced identity. Commands without annotation are kept. +// Hidden stubs preserve direct execution so users get a strict-mode error instead +// of Cobra's generic "unknown flag" fallback from the parent command. +func pruneIncompatible(parent *cobra.Command, mode core.StrictMode) { + forced := string(mode.ForcedIdentity()) + var toRemove []*cobra.Command + var toAdd []*cobra.Command + for _, child := range parent.Commands() { + ids := cmdutil.GetSupportedIdentities(child) + if ids != nil && !slices.Contains(ids, forced) { + toRemove = append(toRemove, child) + toAdd = append(toAdd, strictModeStubFrom(child, mode)) + continue + } + pruneIncompatible(child, mode) + } + if len(toRemove) > 0 { + parent.RemoveCommand(toRemove...) + parent.AddCommand(toAdd...) + } +} + +func strictModeStubFrom(child *cobra.Command, mode core.StrictMode) *cobra.Command { + return &cobra.Command{ + Use: child.Use, + Aliases: append([]string(nil), child.Aliases...), + Hidden: true, + DisableFlagParsing: true, + RunE: func(cmd *cobra.Command, args []string) error { + return output.Errorf(output.ExitValidation, "strict_mode", + "strict mode is %q, only %s identity is allowed. "+ + "This setting is managed by the administrator and must not be modified by AI agents.", + mode, mode.ForcedIdentity()) + }, + } +} + +// pruneEmpty recursively removes group commands (no Run/RunE) that have +// no remaining subcommands after pruning. If only hidden stubs remain, keep +// the group hidden so direct execution still resolves to the stub path. +func pruneEmpty(parent *cobra.Command) { + var toRemove []*cobra.Command + for _, child := range parent.Commands() { + pruneEmpty(child) + if child.Run != nil || child.RunE != nil { + continue + } + switch { + case child.HasAvailableSubCommands(): + case len(child.Commands()) > 0: + child.Hidden = true + default: + toRemove = append(toRemove, child) + } + } + if len(toRemove) > 0 { + parent.RemoveCommand(toRemove...) + } +} diff --git a/cmd/prune_test.go b/cmd/prune_test.go new file mode 100644 index 000000000..8d0594737 --- /dev/null +++ b/cmd/prune_test.go @@ -0,0 +1,200 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package cmd + +import ( + "strings" + "testing" + + "github.com/larksuite/cli/internal/cmdutil" + "github.com/larksuite/cli/internal/core" + "github.com/spf13/cobra" +) + +func newTestTree() *cobra.Command { + root := &cobra.Command{Use: "root"} + + svc := &cobra.Command{Use: "im"} + root.AddCommand(svc) + + noop := func(*cobra.Command, []string) error { return nil } + + userOnly := &cobra.Command{Use: "+search", Short: "user only", RunE: noop} + cmdutil.SetSupportedIdentities(userOnly, []string{"user"}) + svc.AddCommand(userOnly) + + botOnly := &cobra.Command{Use: "+subscribe", Short: "bot only", RunE: noop} + cmdutil.SetSupportedIdentities(botOnly, []string{"bot"}) + svc.AddCommand(botOnly) + + dual := &cobra.Command{Use: "+send", Short: "dual", RunE: noop} + cmdutil.SetSupportedIdentities(dual, []string{"user", "bot"}) + svc.AddCommand(dual) + + noAnnotation := &cobra.Command{Use: "+legacy", Short: "no annotation", RunE: noop} + svc.AddCommand(noAnnotation) + + res := &cobra.Command{Use: "messages"} + svc.AddCommand(res) + userMethod := &cobra.Command{Use: "search", RunE: func(*cobra.Command, []string) error { return nil }} + cmdutil.SetSupportedIdentities(userMethod, []string{"user"}) + res.AddCommand(userMethod) + + auth := &cobra.Command{Use: "auth"} + root.AddCommand(auth) + login := &cobra.Command{Use: "login", RunE: noop} + cmdutil.SetSupportedIdentities(login, []string{"user"}) + auth.AddCommand(login) + + return root +} + +func findCmd(root *cobra.Command, names ...string) *cobra.Command { + cmd := root + for _, name := range names { + found := false + for _, c := range cmd.Commands() { + if c.Name() == name { + cmd = c + found = true + break + } + } + if !found { + return nil + } + } + return cmd +} + +func TestPruneForStrictMode_Bot(t *testing.T) { + root := newTestTree() + pruneForStrictMode(root, core.StrictModeBot) + + if cmd := findCmd(root, "im", "+search"); cmd == nil || !cmd.Hidden { + t.Error("+search (user-only) should be replaced by a hidden stub in bot mode") + } + if findCmd(root, "im", "+subscribe") == nil { + t.Error("+subscribe (bot-only) should be kept in bot mode") + } + if findCmd(root, "im", "+send") == nil { + t.Error("+send (dual) should be kept in bot mode") + } + if findCmd(root, "im", "+legacy") == nil { + t.Error("+legacy (no annotation) should be kept") + } + if cmd := findCmd(root, "im", "messages", "search"); cmd == nil || !cmd.Hidden { + t.Error("search (user-only method) should be replaced by a hidden stub in bot mode") + } + if cmd := findCmd(root, "auth", "login"); cmd == nil || !cmd.Hidden { + t.Error("auth login should be replaced by a hidden stub in bot mode") + } +} + +func TestPruneForStrictMode_User(t *testing.T) { + root := newTestTree() + pruneForStrictMode(root, core.StrictModeUser) + + if findCmd(root, "im", "+search") == nil { + t.Error("+search (user-only) should be kept in user mode") + } + if cmd := findCmd(root, "im", "+subscribe"); cmd == nil || !cmd.Hidden { + t.Error("+subscribe (bot-only) should be replaced by a hidden stub in user mode") + } + if findCmd(root, "im", "+send") == nil { + t.Error("+send (dual) should be kept in user mode") + } + if cmd := findCmd(root, "auth", "login"); cmd == nil || cmd.Hidden { + t.Error("auth login should be kept in user mode") + } +} + +func TestPruneEmpty(t *testing.T) { + root := newTestTree() + pruneForStrictMode(root, core.StrictModeBot) + + if cmd := findCmd(root, "im", "messages"); cmd == nil || !cmd.Hidden { + t.Error("resource 'messages' should be kept hidden when only hidden stubs remain") + } +} + +func TestPruneEmpty_PreservesOriginallyHiddenGroup(t *testing.T) { + root := &cobra.Command{Use: "root"} + hidden := &cobra.Command{Use: "hidden", Hidden: true} + root.AddCommand(hidden) + hidden.AddCommand(&cobra.Command{ + Use: "visible", + RunE: func(*cobra.Command, []string) error { return nil }, + }) + + pruneEmpty(root) + + if !hidden.Hidden { + t.Fatal("expected originally hidden group to remain hidden") + } +} + +func TestPruneForStrictMode_Bot_DirectUserShortcutReturnsStrictMode(t *testing.T) { + root := newTestTree() + root.SilenceErrors = true + root.SilenceUsage = true + pruneForStrictMode(root, core.StrictModeBot) + root.SetArgs([]string{"im", "+search", "--query", "hello"}) + + err := root.Execute() + if err == nil { + t.Fatal("expected strict-mode error") + } + if !strings.Contains(err.Error(), `strict mode is "bot"`) { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestPruneForStrictMode_Bot_DirectNestedUserMethodReturnsStrictMode(t *testing.T) { + root := newTestTree() + root.SilenceErrors = true + root.SilenceUsage = true + pruneForStrictMode(root, core.StrictModeBot) + root.SetArgs([]string{"im", "messages", "search", "--query", "hello"}) + + err := root.Execute() + if err == nil { + t.Fatal("expected strict-mode error") + } + if !strings.Contains(err.Error(), `strict mode is "bot"`) { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestPruneForStrictMode_Bot_DirectAuthLoginReturnsStrictMode(t *testing.T) { + root := newTestTree() + root.SilenceErrors = true + root.SilenceUsage = true + pruneForStrictMode(root, core.StrictModeBot) + root.SetArgs([]string{"auth", "login", "--json", "--scope", "im:message.send_as_user"}) + + err := root.Execute() + if err == nil { + t.Fatal("expected strict-mode error") + } + if !strings.Contains(err.Error(), `strict mode is "bot"`) { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestPruneForStrictMode_User_DirectBotShortcutReturnsStrictMode(t *testing.T) { + root := newTestTree() + root.SilenceErrors = true + root.SilenceUsage = true + pruneForStrictMode(root, core.StrictModeUser) + root.SetArgs([]string{"im", "+subscribe", "--topic", "x"}) + + err := root.Execute() + if err == nil { + t.Fatal("expected strict-mode error") + } + if !strings.Contains(err.Error(), `strict mode is "user"`) { + t.Fatalf("unexpected error: %v", err) + } +} diff --git a/cmd/root.go b/cmd/root.go index 3440edb16..0740d134d 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -5,6 +5,7 @@ package cmd import ( "bytes" + "context" "encoding/json" "errors" "fmt" @@ -18,6 +19,7 @@ import ( "github.com/larksuite/cli/cmd/completion" cmdconfig "github.com/larksuite/cli/cmd/config" "github.com/larksuite/cli/cmd/doctor" + "github.com/larksuite/cli/cmd/profile" "github.com/larksuite/cli/cmd/schema" "github.com/larksuite/cli/cmd/service" internalauth "github.com/larksuite/cli/internal/auth" @@ -87,8 +89,14 @@ More help: lark-cli --help` // Execute runs the root command and returns the process exit code. func Execute() int { - f := cmdutil.NewDefault() + inv, err := BootstrapInvocationContext(os.Args[1:]) + if err != nil { + fmt.Fprintln(os.Stderr, "Error:", err) + return 1 + } + f := cmdutil.NewDefault(inv) + globals := &GlobalOptions{Profile: inv.Profile} rootCmd := &cobra.Command{ Use: "lark-cli", Short: "Lark/Feishu CLI — OAuth authorization, UAT management, API calls", @@ -97,12 +105,15 @@ func Execute() int { } installTipsHelpFunc(rootCmd) rootCmd.SilenceErrors = true + + RegisterGlobalFlags(rootCmd.PersistentFlags(), globals) rootCmd.PersistentPreRun = func(cmd *cobra.Command, args []string) { cmd.SilenceUsage = true } rootCmd.AddCommand(cmdconfig.NewCmdConfig(f)) rootCmd.AddCommand(auth.NewCmdAuth(f)) + rootCmd.AddCommand(profile.NewCmdProfile(f)) rootCmd.AddCommand(doctor.NewCmdDoctor(f)) rootCmd.AddCommand(api.NewCmdApi(f, nil)) rootCmd.AddCommand(schema.NewCmdSchema(f, nil)) @@ -110,6 +121,11 @@ func Execute() int { service.RegisterServiceCommands(rootCmd, f) shortcuts.RegisterShortcuts(rootCmd, f) + // Prune commands incompatible with strict mode. + if mode := f.ResolveStrictMode(context.Background()); mode.IsActive() { + pruneForStrictMode(rootCmd, mode) + } + // --- Update check (non-blocking) --- if !isCompletionCommand(os.Args) { setupUpdateNotice() diff --git a/cmd/root_e2e_test.go b/cmd/root_e2e_test.go deleted file mode 100644 index afdae1b41..000000000 --- a/cmd/root_e2e_test.go +++ /dev/null @@ -1,279 +0,0 @@ -// Copyright (c) 2026 Lark Technologies Pte. Ltd. -// SPDX-License-Identifier: MIT - -package cmd - -import ( - "bytes" - "encoding/json" - "reflect" - "testing" - - "github.com/larksuite/cli/cmd/api" - "github.com/larksuite/cli/cmd/service" - "github.com/larksuite/cli/internal/cmdutil" - "github.com/larksuite/cli/internal/core" - "github.com/larksuite/cli/internal/httpmock" - "github.com/larksuite/cli/internal/output" - "github.com/larksuite/cli/shortcuts" - "github.com/spf13/cobra" -) - -// buildTestRootCmd creates a root command with api, service, and shortcut -// subcommands wired to a test factory, simulating the real CLI command tree. -func buildTestRootCmd(t *testing.T, f *cmdutil.Factory) *cobra.Command { - t.Helper() - rootCmd := &cobra.Command{Use: "lark-cli"} - rootCmd.SilenceErrors = true - rootCmd.PersistentPreRun = func(cmd *cobra.Command, args []string) { - cmd.SilenceUsage = true - } - rootCmd.AddCommand(api.NewCmdApi(f, nil)) - service.RegisterServiceCommands(rootCmd, f) - shortcuts.RegisterShortcuts(rootCmd, f) - return rootCmd -} - -// executeE2E runs a command through the full command tree and handleRootError, -// returning exit code — matching real CLI behavior. -func executeE2E(t *testing.T, f *cmdutil.Factory, rootCmd *cobra.Command, args []string) int { - t.Helper() - rootCmd.SetArgs(args) - if err := rootCmd.Execute(); err != nil { - return handleRootError(f, err) - } - return 0 -} - -// registerTokenStub registers a tenant_access_token stub so bot auth succeeds. -func registerTokenStub(reg *httpmock.Registry) { - reg.Register(&httpmock.Stub{ - URL: "/open-apis/auth/v3/tenant_access_token/internal", - Body: map[string]interface{}{ - "code": 0, "msg": "ok", - "tenant_access_token": "t-e2e-token", "expire": 7200, - }, - }) -} - -// parseEnvelope parses stderr bytes into an ErrorEnvelope. -func parseEnvelope(t *testing.T, stderr *bytes.Buffer) output.ErrorEnvelope { - t.Helper() - if stderr.Len() == 0 { - t.Fatal("expected non-empty stderr, got empty") - } - var env output.ErrorEnvelope - if err := json.Unmarshal(stderr.Bytes(), &env); err != nil { - t.Fatalf("failed to parse stderr as ErrorEnvelope: %v\nstderr: %s", err, stderr.String()) - } - return env -} - -// assertEnvelope verifies exit code, stdout is empty, and stderr matches the -// expected ErrorEnvelope exactly via reflect.DeepEqual. -func assertEnvelope(t *testing.T, code int, wantCode int, stdout *bytes.Buffer, stderr *bytes.Buffer, want output.ErrorEnvelope) { - t.Helper() - if code != wantCode { - t.Errorf("exit code: got %d, want %d", code, wantCode) - } - if stdout.Len() != 0 { - t.Errorf("expected empty stdout, got:\n%s", stdout.String()) - } - got := parseEnvelope(t, stderr) - if !reflect.DeepEqual(got, want) { - gotJSON, _ := json.MarshalIndent(got, "", " ") - wantJSON, _ := json.MarshalIndent(want, "", " ") - t.Errorf("stderr envelope mismatch:\ngot:\n%s\nwant:\n%s", gotJSON, wantJSON) - } -} - -// --- api command --- - -func TestE2E_Api_BusinessError_OutputsEnvelope(t *testing.T) { - f, stdout, stderr, reg := cmdutil.TestFactory(t, &core.CliConfig{ - AppID: "e2e-api-err", AppSecret: "secret", Brand: core.BrandFeishu, - }) - registerTokenStub(reg) - reg.Register(&httpmock.Stub{ - URL: "/open-apis/im/v1/messages", - Body: map[string]interface{}{ - "code": 230002, - "msg": "Bot/User can NOT be out of the chat.", - "error": map[string]interface{}{ - "log_id": "test-log-id-001", - }, - }, - }) - - rootCmd := buildTestRootCmd(t, f) - code := executeE2E(t, f, rootCmd, []string{ - "api", "--as", "bot", "POST", "/open-apis/im/v1/messages", - "--params", `{"receive_id_type":"chat_id"}`, - "--data", `{"receive_id":"oc_xxx","msg_type":"text","content":"{\"text\":\"test\"}"}`, - }) - - // api uses MarkRaw: detail preserved, no enrichment - assertEnvelope(t, code, output.ExitAPI, stdout, stderr, output.ErrorEnvelope{ - OK: false, - Identity: "bot", - Error: &output.ErrDetail{ - Type: "api_error", - Code: 230002, - Message: "API error: [230002] Bot/User can NOT be out of the chat.", - Detail: map[string]interface{}{ - "log_id": "test-log-id-001", - }, - }, - }) -} - -func TestE2E_Api_PermissionError_NotEnriched(t *testing.T) { - f, stdout, stderr, reg := cmdutil.TestFactory(t, &core.CliConfig{ - AppID: "e2e-api-perm", AppSecret: "secret", Brand: core.BrandFeishu, - }) - registerTokenStub(reg) - reg.Register(&httpmock.Stub{ - URL: "/open-apis/test/perm", - Body: map[string]interface{}{ - "code": 99991672, - "msg": "scope not enabled for this app", - "error": map[string]interface{}{ - "permission_violations": []interface{}{ - map[string]interface{}{"subject": "calendar:calendar:readonly"}, - }, - "log_id": "test-log-id-perm", - }, - }, - }) - - rootCmd := buildTestRootCmd(t, f) - code := executeE2E(t, f, rootCmd, []string{ - "api", "--as", "bot", "GET", "/open-apis/test/perm", - }) - - // api uses MarkRaw: enrichment skipped, detail preserved, no console_url - assertEnvelope(t, code, output.ExitAPI, stdout, stderr, output.ErrorEnvelope{ - OK: false, - Identity: "bot", - Error: &output.ErrDetail{ - Type: "permission", - Code: 99991672, - Message: "Permission denied [99991672]", - Hint: "check app permissions or re-authorize: lark-cli auth login", - Detail: map[string]interface{}{ - "permission_violations": []interface{}{ - map[string]interface{}{"subject": "calendar:calendar:readonly"}, - }, - "log_id": "test-log-id-perm", - }, - }, - }) -} - -// --- service command --- - -func TestE2E_Service_BusinessError_OutputsEnvelope(t *testing.T) { - f, stdout, stderr, reg := cmdutil.TestFactory(t, &core.CliConfig{ - AppID: "e2e-svc-err", AppSecret: "secret", Brand: core.BrandFeishu, - }) - registerTokenStub(reg) - reg.Register(&httpmock.Stub{ - URL: "/open-apis/im/v1/chats/oc_fake", - Body: map[string]interface{}{ - "code": 99992356, - "msg": "id not exist", - "error": map[string]interface{}{ - "log_id": "test-log-id-svc", - }, - }, - }) - - rootCmd := buildTestRootCmd(t, f) - code := executeE2E(t, f, rootCmd, []string{ - "im", "chats", "get", "--params", `{"chat_id":"oc_fake"}`, "--as", "bot", - }) - - // service: no MarkRaw, non-permission error — detail preserved - assertEnvelope(t, code, output.ExitAPI, stdout, stderr, output.ErrorEnvelope{ - OK: false, - Identity: "bot", - Error: &output.ErrDetail{ - Type: "api_error", - Code: 99992356, - Message: "API error: [99992356] id not exist", - Detail: map[string]interface{}{ - "log_id": "test-log-id-svc", - }, - }, - }) -} - -func TestE2E_Service_PermissionError_Enriched(t *testing.T) { - f, stdout, stderr, reg := cmdutil.TestFactory(t, &core.CliConfig{ - AppID: "e2e-svc-perm", AppSecret: "secret", Brand: core.BrandFeishu, - }) - registerTokenStub(reg) - reg.Register(&httpmock.Stub{ - URL: "/open-apis/im/v1/chats/oc_test", - Body: map[string]interface{}{ - "code": 99991672, - "msg": "scope not enabled", - "error": map[string]interface{}{ - "permission_violations": []interface{}{ - map[string]interface{}{"subject": "im:chat:readonly"}, - }, - }, - }, - }) - - rootCmd := buildTestRootCmd(t, f) - code := executeE2E(t, f, rootCmd, []string{ - "im", "chats", "get", "--params", `{"chat_id":"oc_test"}`, "--as", "bot", - }) - - // service: no MarkRaw — enrichment applied, detail cleared, console_url set - assertEnvelope(t, code, output.ExitAPI, stdout, stderr, output.ErrorEnvelope{ - OK: false, - Identity: "bot", - Error: &output.ErrDetail{ - Type: "permission", - Code: 99991672, - Message: "App scope not enabled: required scope im:chat:readonly [99991672]", - Hint: "enable the scope in developer console (see console_url)", - ConsoleURL: "https://open.feishu.cn/page/scope-apply?clientID=e2e-svc-perm&scopes=im%3Achat%3Areadonly", - }, - }) -} - -// --- shortcut command --- - -func TestE2E_Shortcut_BusinessError_OutputsEnvelope(t *testing.T) { - f, stdout, stderr, reg := cmdutil.TestFactory(t, &core.CliConfig{ - AppID: "e2e-sc-err", AppSecret: "secret", Brand: core.BrandFeishu, - }) - registerTokenStub(reg) - reg.Register(&httpmock.Stub{ - URL: "/open-apis/im/v1/messages", - Status: 400, - Body: map[string]interface{}{ - "code": 230002, - "msg": "Bot/User can NOT be out of the chat.", - }, - }) - - rootCmd := buildTestRootCmd(t, f) - code := executeE2E(t, f, rootCmd, []string{ - "im", "+messages-send", "--as", "bot", "--chat-id", "oc_xxx", "--text", "test", - }) - - // shortcut: no MarkRaw, no HandleResponse — error via DoAPIJSON path - assertEnvelope(t, code, output.ExitAPI, stdout, stderr, output.ErrorEnvelope{ - OK: false, - Identity: "bot", - Error: &output.ErrDetail{ - Type: "api_error", - Code: 230002, - Message: "HTTP 400: Bot/User can NOT be out of the chat.", - }, - }) -} diff --git a/cmd/root_integration_test.go b/cmd/root_integration_test.go new file mode 100644 index 000000000..555018738 --- /dev/null +++ b/cmd/root_integration_test.go @@ -0,0 +1,490 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package cmd + +import ( + "bytes" + "context" + "encoding/json" + "reflect" + "strings" + "testing" + + "github.com/larksuite/cli/cmd/api" + "github.com/larksuite/cli/cmd/auth" + "github.com/larksuite/cli/cmd/service" + "github.com/larksuite/cli/internal/cmdutil" + "github.com/larksuite/cli/internal/core" + "github.com/larksuite/cli/internal/envvars" + "github.com/larksuite/cli/internal/httpmock" + "github.com/larksuite/cli/internal/output" + "github.com/larksuite/cli/shortcuts" + "github.com/spf13/cobra" +) + +// buildIntegrationRootCmd creates a root command with api, service, and shortcut +// subcommands wired to a test factory, simulating the real CLI command tree. +func buildIntegrationRootCmd(t *testing.T, f *cmdutil.Factory) *cobra.Command { + t.Helper() + rootCmd := &cobra.Command{Use: "lark-cli"} + rootCmd.SilenceErrors = true + rootCmd.SetOut(f.IOStreams.Out) + rootCmd.SetErr(f.IOStreams.ErrOut) + rootCmd.PersistentPreRun = func(cmd *cobra.Command, args []string) { + cmd.SilenceUsage = true + } + rootCmd.AddCommand(api.NewCmdApi(f, nil)) + service.RegisterServiceCommands(rootCmd, f) + shortcuts.RegisterShortcuts(rootCmd, f) + return rootCmd +} + +// executeRootIntegration runs a command through the full command tree and +// handleRootError, returning the exit code matching real CLI behavior. +func executeRootIntegration(t *testing.T, f *cmdutil.Factory, rootCmd *cobra.Command, args []string) int { + t.Helper() + rootCmd.SetArgs(args) + if err := rootCmd.Execute(); err != nil { + return handleRootError(f, err) + } + return 0 +} + +// parseEnvelope parses stderr bytes into an ErrorEnvelope. +func parseEnvelope(t *testing.T, stderr *bytes.Buffer) output.ErrorEnvelope { + t.Helper() + if stderr.Len() == 0 { + t.Fatal("expected non-empty stderr, got empty") + } + var env output.ErrorEnvelope + if err := json.Unmarshal(stderr.Bytes(), &env); err != nil { + t.Fatalf("failed to parse stderr as ErrorEnvelope: %v\nstderr: %s", err, stderr.String()) + } + return env +} + +// assertEnvelope verifies exit code, stdout is empty, and stderr matches the +// expected ErrorEnvelope exactly via reflect.DeepEqual. +func assertEnvelope(t *testing.T, code int, wantCode int, stdout *bytes.Buffer, stderr *bytes.Buffer, want output.ErrorEnvelope) { + t.Helper() + if code != wantCode { + t.Errorf("exit code: got %d, want %d", code, wantCode) + } + if stdout.Len() != 0 { + t.Errorf("expected empty stdout, got:\n%s", stdout.String()) + } + got := parseEnvelope(t, stderr) + if !reflect.DeepEqual(got, want) { + gotJSON, _ := json.MarshalIndent(got, "", " ") + wantJSON, _ := json.MarshalIndent(want, "", " ") + t.Errorf("stderr envelope mismatch:\ngot:\n%s\nwant:\n%s", gotJSON, wantJSON) + } +} + +func buildStrictModeIntegrationRootCmd(t *testing.T, f *cmdutil.Factory) *cobra.Command { + t.Helper() + rootCmd := &cobra.Command{Use: "lark-cli"} + rootCmd.SilenceErrors = true + rootCmd.SetOut(f.IOStreams.Out) + rootCmd.SetErr(f.IOStreams.ErrOut) + rootCmd.PersistentPreRun = func(cmd *cobra.Command, args []string) { + cmd.SilenceUsage = true + } + rootCmd.AddCommand(auth.NewCmdAuth(f)) + rootCmd.AddCommand(api.NewCmdApi(f, nil)) + service.RegisterServiceCommands(rootCmd, f) + shortcuts.RegisterShortcuts(rootCmd, f) + if mode := f.ResolveStrictMode(context.Background()); mode.IsActive() { + pruneForStrictMode(rootCmd, mode) + } + return rootCmd +} + +func newStrictModeDefaultFactory(t *testing.T, profile string, mode core.StrictMode) (*cmdutil.Factory, *bytes.Buffer, *bytes.Buffer) { + t.Helper() + t.Setenv(envvars.CliAppID, "") + t.Setenv(envvars.CliAppSecret, "") + t.Setenv(envvars.CliUserAccessToken, "") + t.Setenv(envvars.CliTenantAccessToken, "") + t.Setenv(envvars.CliDefaultAs, "") + + dir := t.TempDir() + t.Setenv("LARKSUITE_CLI_CONFIG_DIR", dir) + + targetMode := mode + multi := &core.MultiAppConfig{ + CurrentApp: "default", + Apps: []core.AppConfig{ + { + Name: "default", + AppId: "app-default", + AppSecret: core.PlainSecret("secret-default"), + Brand: core.BrandFeishu, + }, + { + Name: "target", + AppId: "app-target", + AppSecret: core.PlainSecret("secret-target"), + Brand: core.BrandFeishu, + StrictMode: &targetMode, + }, + }, + } + if err := core.SaveMultiAppConfig(multi); err != nil { + t.Fatalf("SaveMultiAppConfig() error = %v", err) + } + + f := cmdutil.NewDefault(cmdutil.InvocationContext{Profile: profile}) + stdout := &bytes.Buffer{} + stderr := &bytes.Buffer{} + f.IOStreams = &cmdutil.IOStreams{In: nil, Out: stdout, ErrOut: stderr} + return f, stdout, stderr +} + +func resetBuffers(stdout *bytes.Buffer, stderr *bytes.Buffer) { + stdout.Reset() + stderr.Reset() +} + +func parseDryRunJSON(t *testing.T, stdout *bytes.Buffer) map[string]interface{} { + t.Helper() + out := stdout.String() + const prefix = "=== Dry Run ===\n" + if !strings.HasPrefix(out, prefix) { + t.Fatalf("expected dry-run prefix, got:\n%s", out) + } + var payload map[string]interface{} + if err := json.Unmarshal([]byte(strings.TrimPrefix(out, prefix)), &payload); err != nil { + t.Fatalf("failed to parse dry-run payload: %v\nstdout: %s", err, out) + } + return payload +} + +// --- api command --- + +func TestIntegration_Api_BusinessError_OutputsEnvelope(t *testing.T) { + f, stdout, stderr, reg := cmdutil.TestFactory(t, &core.CliConfig{ + AppID: "e2e-api-err", AppSecret: "secret", Brand: core.BrandFeishu, + }) + reg.Register(&httpmock.Stub{ + URL: "/open-apis/im/v1/messages", + Body: map[string]interface{}{ + "code": 230002, + "msg": "Bot/User can NOT be out of the chat.", + "error": map[string]interface{}{ + "log_id": "test-log-id-001", + }, + }, + }) + + rootCmd := buildIntegrationRootCmd(t, f) + code := executeRootIntegration(t, f, rootCmd, []string{ + "api", "--as", "bot", "POST", "/open-apis/im/v1/messages", + "--params", `{"receive_id_type":"chat_id"}`, + "--data", `{"receive_id":"oc_xxx","msg_type":"text","content":"{\"text\":\"test\"}"}`, + }) + + // api uses MarkRaw: detail preserved, no enrichment + assertEnvelope(t, code, output.ExitAPI, stdout, stderr, output.ErrorEnvelope{ + OK: false, + Identity: "bot", + Error: &output.ErrDetail{ + Type: "api_error", + Code: 230002, + Message: "API error: [230002] Bot/User can NOT be out of the chat.", + Detail: map[string]interface{}{ + "log_id": "test-log-id-001", + }, + }, + }) +} + +func TestIntegration_Api_PermissionError_NotEnriched(t *testing.T) { + f, stdout, stderr, reg := cmdutil.TestFactory(t, &core.CliConfig{ + AppID: "e2e-api-perm", AppSecret: "secret", Brand: core.BrandFeishu, + }) + reg.Register(&httpmock.Stub{ + URL: "/open-apis/test/perm", + Body: map[string]interface{}{ + "code": 99991672, + "msg": "scope not enabled for this app", + "error": map[string]interface{}{ + "permission_violations": []interface{}{ + map[string]interface{}{"subject": "calendar:calendar:readonly"}, + }, + "log_id": "test-log-id-perm", + }, + }, + }) + + rootCmd := buildIntegrationRootCmd(t, f) + code := executeRootIntegration(t, f, rootCmd, []string{ + "api", "--as", "bot", "GET", "/open-apis/test/perm", + }) + + // api uses MarkRaw: enrichment skipped, detail preserved, no console_url + assertEnvelope(t, code, output.ExitAPI, stdout, stderr, output.ErrorEnvelope{ + OK: false, + Identity: "bot", + Error: &output.ErrDetail{ + Type: "permission", + Code: 99991672, + Message: "Permission denied [99991672]", + Hint: "check app permissions or re-authorize: lark-cli auth login", + Detail: map[string]interface{}{ + "permission_violations": []interface{}{ + map[string]interface{}{"subject": "calendar:calendar:readonly"}, + }, + "log_id": "test-log-id-perm", + }, + }, + }) +} + +// --- service command --- + +func TestIntegration_Service_BusinessError_OutputsEnvelope(t *testing.T) { + f, stdout, stderr, reg := cmdutil.TestFactory(t, &core.CliConfig{ + AppID: "e2e-svc-err", AppSecret: "secret", Brand: core.BrandFeishu, + }) + reg.Register(&httpmock.Stub{ + URL: "/open-apis/im/v1/chats/oc_fake", + Body: map[string]interface{}{ + "code": 99992356, + "msg": "id not exist", + "error": map[string]interface{}{ + "log_id": "test-log-id-svc", + }, + }, + }) + + rootCmd := buildIntegrationRootCmd(t, f) + code := executeRootIntegration(t, f, rootCmd, []string{ + "im", "chats", "get", "--params", `{"chat_id":"oc_fake"}`, "--as", "bot", + }) + + // service: no MarkRaw, non-permission error — detail preserved + assertEnvelope(t, code, output.ExitAPI, stdout, stderr, output.ErrorEnvelope{ + OK: false, + Identity: "bot", + Error: &output.ErrDetail{ + Type: "api_error", + Code: 99992356, + Message: "API error: [99992356] id not exist", + Detail: map[string]interface{}{ + "log_id": "test-log-id-svc", + }, + }, + }) +} + +func TestIntegration_Service_PermissionError_Enriched(t *testing.T) { + f, stdout, stderr, reg := cmdutil.TestFactory(t, &core.CliConfig{ + AppID: "e2e-svc-perm", AppSecret: "secret", Brand: core.BrandFeishu, + }) + reg.Register(&httpmock.Stub{ + URL: "/open-apis/im/v1/chats/oc_test", + Body: map[string]interface{}{ + "code": 99991672, + "msg": "scope not enabled", + "error": map[string]interface{}{ + "permission_violations": []interface{}{ + map[string]interface{}{"subject": "im:chat:readonly"}, + }, + }, + }, + }) + + rootCmd := buildIntegrationRootCmd(t, f) + code := executeRootIntegration(t, f, rootCmd, []string{ + "im", "chats", "get", "--params", `{"chat_id":"oc_test"}`, "--as", "bot", + }) + + // service: no MarkRaw — enrichment applied, detail cleared, console_url set + assertEnvelope(t, code, output.ExitAPI, stdout, stderr, output.ErrorEnvelope{ + OK: false, + Identity: "bot", + Error: &output.ErrDetail{ + Type: "permission", + Code: 99991672, + Message: "App scope not enabled: required scope im:chat:readonly [99991672]", + Hint: "enable the scope in developer console (see console_url)", + ConsoleURL: "https://open.feishu.cn/page/scope-apply?clientID=e2e-svc-perm&scopes=im%3Achat%3Areadonly", + }, + }) +} + +func TestIntegration_StrictModeBot_ProfileOverride_HidesCommandsInHelp(t *testing.T) { + f, stdout, stderr := newStrictModeDefaultFactory(t, "target", core.StrictModeBot) + rootCmd := buildStrictModeIntegrationRootCmd(t, f) + + code := executeRootIntegration(t, f, rootCmd, []string{"auth", "--help"}) + if code != 0 { + t.Fatalf("auth --help exit code = %d, want 0", code) + } + if stderr.Len() != 0 { + t.Fatalf("expected empty stderr, got: %s", stderr.String()) + } + if strings.Contains(stdout.String(), "login") { + t.Fatalf("auth --help should hide login in bot mode, got:\n%s", stdout.String()) + } + + resetBuffers(stdout, stderr) + rootCmd = buildStrictModeIntegrationRootCmd(t, f) + code = executeRootIntegration(t, f, rootCmd, []string{"im", "--help"}) + if code != 0 { + t.Fatalf("im --help exit code = %d, want 0", code) + } + if stderr.Len() != 0 { + t.Fatalf("expected empty stderr, got: %s", stderr.String()) + } + if strings.Contains(stdout.String(), "+messages-search") { + t.Fatalf("im --help should hide +messages-search in bot mode, got:\n%s", stdout.String()) + } + if !strings.Contains(stdout.String(), "+chat-create") { + t.Fatalf("im --help should keep +chat-create in bot mode, got:\n%s", stdout.String()) + } +} + +func TestIntegration_StrictModeBot_ProfileOverride_DirectAuthLoginReturnsEnvelope(t *testing.T) { + f, stdout, stderr := newStrictModeDefaultFactory(t, "target", core.StrictModeBot) + rootCmd := buildStrictModeIntegrationRootCmd(t, f) + + code := executeRootIntegration(t, f, rootCmd, []string{ + "auth", "login", "--json", "--scope", "im:message.send_as_user", + }) + + assertEnvelope(t, code, output.ExitValidation, stdout, stderr, output.ErrorEnvelope{ + OK: false, + Error: &output.ErrDetail{ + Type: "strict_mode", + Message: `strict mode is "bot", only bot identity is allowed. This setting is managed by the administrator and must not be modified by AI agents.`, + }, + }) +} + +func TestIntegration_StrictModeBot_ProfileOverride_DirectUserShortcutReturnsEnvelope(t *testing.T) { + f, stdout, stderr := newStrictModeDefaultFactory(t, "target", core.StrictModeBot) + rootCmd := buildStrictModeIntegrationRootCmd(t, f) + + code := executeRootIntegration(t, f, rootCmd, []string{ + "im", "+messages-search", "--chat-id", "oc_xxx", "--query", "hello", + }) + + assertEnvelope(t, code, output.ExitValidation, stdout, stderr, output.ErrorEnvelope{ + OK: false, + Error: &output.ErrDetail{ + Type: "strict_mode", + Message: `strict mode is "bot", only bot identity is allowed. This setting is managed by the administrator and must not be modified by AI agents.`, + }, + }) +} + +func TestIntegration_StrictModeUser_ProfileOverride_ChatCreateDryRunSucceeds(t *testing.T) { + // +chat-create supports both user and bot identities, so strict mode user + // should allow it and force user identity. + f, stdout, stderr := newStrictModeDefaultFactory(t, "target", core.StrictModeUser) + rootCmd := buildStrictModeIntegrationRootCmd(t, f) + + code := executeRootIntegration(t, f, rootCmd, []string{ + "im", "+chat-create", "--name", "probe", "--dry-run", + }) + + if code != 0 { + t.Fatalf("exit code = %d, want 0; stderr: %s", code, stderr.String()) + } + out := stdout.String() + if out == "" { + t.Fatal("expected non-empty stdout for dry-run") + } +} + +func TestIntegration_StrictModeBot_ProfileOverride_ServiceDryRunForcesBotIdentity(t *testing.T) { + f, stdout, stderr := newStrictModeDefaultFactory(t, "target", core.StrictModeBot) + rootCmd := buildStrictModeIntegrationRootCmd(t, f) + + code := executeRootIntegration(t, f, rootCmd, []string{ + "im", "chats", "get", "--params", `{"chat_id":"oc_test"}`, "--as", "user", "--dry-run", + }) + + if code != 0 { + t.Fatalf("exit code = %d, want 0; stderr: %s", code, stderr.String()) + } + if stderr.Len() != 0 { + t.Fatalf("expected empty stderr, got: %s", stderr.String()) + } + payload := parseDryRunJSON(t, stdout) + if got := payload["as"]; got != "bot" { + t.Fatalf("dry-run as = %v, want bot", got) + } +} + +func TestIntegration_StrictModeUser_ProfileOverride_ServiceBotOnlyMethodReturnsEnvelope(t *testing.T) { + f, stdout, stderr := newStrictModeDefaultFactory(t, "target", core.StrictModeUser) + rootCmd := buildStrictModeIntegrationRootCmd(t, f) + + code := executeRootIntegration(t, f, rootCmd, []string{ + "im", "images", "create", "--data", `{"image_type":"message","image":"x"}`, "--dry-run", + }) + + assertEnvelope(t, code, output.ExitValidation, stdout, stderr, output.ErrorEnvelope{ + OK: false, + Error: &output.ErrDetail{ + Type: "strict_mode", + Message: `strict mode is "user", only user identity is allowed. This setting is managed by the administrator and must not be modified by AI agents.`, + }, + }) +} + +func TestIntegration_StrictModeBot_ProfileOverride_APIDryRunForcesBotIdentity(t *testing.T) { + f, stdout, stderr := newStrictModeDefaultFactory(t, "target", core.StrictModeBot) + rootCmd := buildStrictModeIntegrationRootCmd(t, f) + + code := executeRootIntegration(t, f, rootCmd, []string{ + "api", "--as", "user", "GET", "/open-apis/im/v1/chats/oc_test", "--dry-run", + }) + + if code != 0 { + t.Fatalf("exit code = %d, want 0; stderr: %s", code, stderr.String()) + } + if stderr.Len() != 0 { + t.Fatalf("expected empty stderr, got: %s", stderr.String()) + } + payload := parseDryRunJSON(t, stdout) + if got := payload["as"]; got != "bot" { + t.Fatalf("dry-run as = %v, want bot", got) + } +} + +// --- shortcut command --- + +func TestIntegration_Shortcut_BusinessError_OutputsEnvelope(t *testing.T) { + f, stdout, stderr, reg := cmdutil.TestFactory(t, &core.CliConfig{ + AppID: "e2e-sc-err", AppSecret: "secret", Brand: core.BrandFeishu, + }) + reg.Register(&httpmock.Stub{ + URL: "/open-apis/im/v1/messages", + Status: 400, + Body: map[string]interface{}{ + "code": 230002, + "msg": "Bot/User can NOT be out of the chat.", + }, + }) + + rootCmd := buildIntegrationRootCmd(t, f) + code := executeRootIntegration(t, f, rootCmd, []string{ + "im", "+messages-send", "--as", "bot", "--chat-id", "oc_xxx", "--text", "test", + }) + + // shortcut: no MarkRaw, no HandleResponse — error via DoAPIJSON path + assertEnvelope(t, code, output.ExitAPI, stdout, stderr, output.ErrorEnvelope{ + OK: false, + Identity: "bot", + Error: &output.ErrDetail{ + Type: "api_error", + Code: 230002, + Message: "HTTP 400: Bot/User can NOT be out of the chat.", + }, + }) +} diff --git a/cmd/service/service.go b/cmd/service/service.go index 1a392a7ee..85c62cc3e 100644 --- a/cmd/service/service.go +++ b/cmd/service/service.go @@ -14,6 +14,7 @@ import ( "github.com/larksuite/cli/internal/client" "github.com/larksuite/cli/internal/cmdutil" "github.com/larksuite/cli/internal/core" + "github.com/larksuite/cli/internal/credential" "github.com/larksuite/cli/internal/output" "github.com/larksuite/cli/internal/registry" "github.com/larksuite/cli/internal/util" @@ -169,13 +170,20 @@ func NewCmdServiceMethod(f *cmdutil.Factory, spec, method map[string]interface{} }) cmdutil.SetTips(cmd, registry.GetStrSliceFromMap(method, "tips")) + if tokens, ok := method["accessTokens"].([]interface{}); ok && len(tokens) > 0 { + cmdutil.SetSupportedIdentities(cmd, cmdutil.AccessTokensToIdentities(tokens)) + } return cmd } func serviceMethodRun(opts *ServiceMethodOptions) error { f := opts.Factory - opts.As = f.ResolveAs(opts.Cmd, opts.As) + opts.As = f.ResolveAs(opts.Ctx, opts.Cmd, opts.As) + + if err := f.CheckStrictMode(opts.Ctx, opts.As); err != nil { + return err + } // Check if this API method supports the resolved identity. if tokens, ok := opts.Method["accessTokens"].([]interface{}); ok && len(tokens) > 0 { @@ -191,7 +199,7 @@ func serviceMethodRun(opts *ServiceMethodOptions) error { return err } - config, err := f.ResolveConfig(opts.As) + config, err := f.Config() if err != nil { return err } @@ -200,7 +208,7 @@ func serviceMethodRun(opts *ServiceMethodOptions) error { scopes, _ := opts.Method["scopes"].([]interface{}) if !opts.As.IsBot() { - if err := checkServiceScopes(config, opts.Method, scopes); err != nil { + if err := checkServiceScopes(opts.Ctx, f.Credential, opts.As, config, opts.Method, scopes); err != nil { return err } } @@ -247,25 +255,30 @@ func serviceMethodRun(opts *ServiceMethodOptions) error { } // checkServiceScopes pre-checks user scopes before making the API call. -func checkServiceScopes(config *core.CliConfig, method map[string]interface{}, scopes []interface{}) error { +func checkServiceScopes(ctx context.Context, cred *credential.CredentialProvider, identity core.Identity, config *core.CliConfig, method map[string]interface{}, scopes []interface{}) error { + if ctx.Err() != nil { + return ctx.Err() + } + result, err := cred.ResolveToken(ctx, credential.NewTokenSpec(identity, config.AppID)) + if err != nil || result == nil || result.Scopes == "" { + return nil //nolint:nilerr // skip scope check when token resolution fails or has no scopes + } + requiredScopes, hasRequired := method["requiredScopes"].([]interface{}) if hasRequired && len(requiredScopes) > 0 { // Strict: ALL requiredScopes must be present - stored := auth.GetStoredToken(config.AppID, config.UserOpenId) - if stored != nil { - required := make([]string, 0, len(requiredScopes)) - for _, s := range requiredScopes { - if str, ok := s.(string); ok { - required = append(required, str) - } - } - if missing := auth.MissingScopes(stored.Scope, required); len(missing) > 0 { - return output.ErrWithHint(output.ExitAuth, "missing_scope", - fmt.Sprintf("missing required scope(s): %s", strings.Join(missing, ", ")), - fmt.Sprintf("run `lark-cli auth login --scope \"%s\"` in the background. It blocks and outputs a verification URL — retrieve the URL and open it in a browser to complete login.", strings.Join(missing, " "))) + required := make([]string, 0, len(requiredScopes)) + for _, s := range requiredScopes { + if str, ok := s.(string); ok { + required = append(required, str) } } + if missing := auth.MissingScopes(result.Scopes, required); len(missing) > 0 { + return output.ErrWithHint(output.ExitAuth, "missing_scope", + fmt.Sprintf("missing required scope(s): %s", strings.Join(missing, ", ")), + fmt.Sprintf("run `lark-cli auth login --scope \"%s\"` in the background. It blocks and outputs a verification URL — retrieve the URL and open it in a browser to complete login.", strings.Join(missing, " "))) + } return nil } @@ -274,16 +287,12 @@ func checkServiceScopes(config *core.CliConfig, method map[string]interface{}, s } // Default: ANY one of the declared scopes is sufficient - stored := auth.GetStoredToken(config.AppID, config.UserOpenId) - if stored == nil { - return nil - } - grantedScopes := make(map[string]bool) - for _, s := range strings.Fields(stored.Scope) { - grantedScopes[s] = true + grantedSet := make(map[string]bool) + for _, s := range strings.Fields(result.Scopes) { + grantedSet[s] = true } for _, s := range scopes { - if str, ok := s.(string); ok && grantedScopes[str] { + if str, ok := s.(string); ok && grantedSet[str] { return nil } } diff --git a/cmd/service/service_test.go b/cmd/service/service_test.go index d52687dbb..7cd09f398 100644 --- a/cmd/service/service_test.go +++ b/cmd/service/service_test.go @@ -44,16 +44,6 @@ func driveMethod(httpMethod string, params map[string]interface{}) map[string]in return m } -func tokenStub() *httpmock.Stub { - return &httpmock.Stub{ - URL: "tenant_access_token", - Body: map[string]interface{}{ - "code": 0, "msg": "ok", - "tenant_access_token": "t-test", "expire": 7200, - }, - } -} - // ── registerService ── func TestRegisterService(t *testing.T) { @@ -364,7 +354,6 @@ func TestServiceMethod_OutputAndPageAllConflict(t *testing.T) { func TestServiceMethod_BotMode_Success(t *testing.T) { f, stdout, _, reg := cmdutil.TestFactory(t, testConfig) - reg.Register(tokenStub()) reg.Register(&httpmock.Stub{ URL: "/open-apis/svc/v1/items", Body: map[string]interface{}{ @@ -391,7 +380,6 @@ func TestServiceMethod_BotMode_APIError(t *testing.T) { AppID: "test-app-err", AppSecret: "test-secret-err", Brand: core.BrandFeishu, }) - reg.Register(tokenStub()) reg.Register(&httpmock.Stub{ URL: "/open-apis/svc/v1/items", Body: map[string]interface{}{"code": 40003, "msg": "invalid token"}, @@ -425,7 +413,6 @@ func TestServiceMethod_BotMode_PageAll_JSON(t *testing.T) { AppID: "test-app-page", AppSecret: "test-secret-page", Brand: core.BrandFeishu, }) - reg.Register(tokenStub()) reg.Register(&httpmock.Stub{ URL: "/open-apis/svc/v1/items", Body: map[string]interface{}{ @@ -455,7 +442,6 @@ func TestServiceMethod_UnknownFormat_Warning(t *testing.T) { AppID: "test-app-fmt", AppSecret: "test-secret-fmt", Brand: core.BrandFeishu, }) - reg.Register(tokenStub()) reg.Register(&httpmock.Stub{ URL: "/open-apis/svc/v1/items", Body: map[string]interface{}{"code": 0, "msg": "ok", "data": map[string]interface{}{}}, @@ -540,7 +526,6 @@ func TestServiceMethod_JqFilter_AppliesExpression(t *testing.T) { AppID: "test-app-jq", AppSecret: "test-secret-jq", Brand: core.BrandFeishu, }) - reg.Register(tokenStub()) reg.Register(&httpmock.Stub{ URL: "/open-apis/svc/v1/items", Body: map[string]interface{}{ @@ -612,7 +597,6 @@ func TestServiceMethod_PageAll_WithJq(t *testing.T) { AppID: "test-app-spjq", AppSecret: "test-secret-spjq", Brand: core.BrandFeishu, }) - reg.Register(tokenStub()) reg.Register(&httpmock.Stub{ URL: "/open-apis/svc/v1/items", Body: map[string]interface{}{ diff --git a/extension/credential/env/env.go b/extension/credential/env/env.go new file mode 100644 index 000000000..054d27d85 --- /dev/null +++ b/extension/credential/env/env.go @@ -0,0 +1,116 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package env + +import ( + "context" + "fmt" + "os" + + "github.com/larksuite/cli/extension/credential" + "github.com/larksuite/cli/internal/envvars" +) + +// Provider resolves credentials from environment variables. +type Provider struct{} + +func (p *Provider) Name() string { return "env" } + +func (p *Provider) ResolveAccount(ctx context.Context) (*credential.Account, error) { + appID := os.Getenv(envvars.CliAppID) + appSecret := os.Getenv(envvars.CliAppSecret) + hasUAT := os.Getenv(envvars.CliUserAccessToken) != "" + hasTAT := os.Getenv(envvars.CliTenantAccessToken) != "" + if appID == "" && appSecret == "" { + switch { + case hasUAT: + return nil, &credential.BlockError{Provider: "env", Reason: envvars.CliUserAccessToken + " is set but " + envvars.CliAppID + " is missing"} + case hasTAT: + return nil, &credential.BlockError{Provider: "env", Reason: envvars.CliTenantAccessToken + " is set but " + envvars.CliAppID + " is missing"} + default: + return nil, nil + } + } + if appID == "" { + return nil, &credential.BlockError{Provider: "env", Reason: envvars.CliAppSecret + " is set but " + envvars.CliAppID + " is missing"} + } + if appSecret == "" && !hasUAT && !hasTAT { + return nil, &credential.BlockError{ + Provider: "env", + Reason: envvars.CliAppID + " is set but no app secret or access token is available", + } + } + brand := credential.Brand(os.Getenv(envvars.CliBrand)) + if brand == "" { + brand = credential.BrandFeishu + } + acct := &credential.Account{AppID: appID, AppSecret: appSecret, Brand: brand} + + switch id := credential.Identity(os.Getenv(envvars.CliDefaultAs)); id { + case "", credential.IdentityAuto: + acct.DefaultAs = id + case credential.IdentityUser, credential.IdentityBot: + acct.DefaultAs = id + default: + return nil, &credential.BlockError{ + Provider: "env", + Reason: fmt.Sprintf("invalid %s %q (want user, bot, or auto)", envvars.CliDefaultAs, id), + } + } + + // Explicit strict mode policy takes priority + switch strictMode := os.Getenv(envvars.CliStrictMode); strictMode { + case "bot": + acct.SupportedIdentities = credential.SupportsBot + case "user": + acct.SupportedIdentities = credential.SupportsUser + case "off": + acct.SupportedIdentities = credential.SupportsAll + case "": + // Infer from available tokens + if hasUAT { + acct.SupportedIdentities |= credential.SupportsUser + } + if hasTAT { + acct.SupportedIdentities |= credential.SupportsBot + } + default: + return nil, &credential.BlockError{ + Provider: "env", + Reason: fmt.Sprintf("invalid %s %q (want bot, user, or off)", envvars.CliStrictMode, strictMode), + } + } + + if acct.DefaultAs == "" { + switch { + case hasUAT: + acct.DefaultAs = credential.IdentityUser + case hasTAT: + acct.DefaultAs = credential.IdentityBot + } + } + + return acct, nil +} + +func (p *Provider) ResolveToken(ctx context.Context, req credential.TokenSpec) (*credential.Token, error) { + var envKey string + switch req.Type { + case credential.TokenTypeUAT: + envKey = envvars.CliUserAccessToken + case credential.TokenTypeTAT: + envKey = envvars.CliTenantAccessToken + default: + return nil, nil + } + token := os.Getenv(envKey) + if token == "" { + return nil, nil + } + return &credential.Token{Value: token, Source: "env:" + envKey}, nil +} + +func init() { + credential.Register(&Provider{}) +} diff --git a/extension/credential/env/env_test.go b/extension/credential/env/env_test.go new file mode 100644 index 000000000..8b7af93f0 --- /dev/null +++ b/extension/credential/env/env_test.go @@ -0,0 +1,279 @@ +package env + +import ( + "context" + "errors" + "strings" + "testing" + + "github.com/larksuite/cli/extension/credential" + "github.com/larksuite/cli/internal/envvars" +) + +func TestProvider_Name(t *testing.T) { + if (&Provider{}).Name() != "env" { + t.Fail() + } +} + +func TestResolveAccount_BothSet(t *testing.T) { + t.Setenv(envvars.CliAppID, "cli_test") + t.Setenv(envvars.CliAppSecret, "secret_test") + t.Setenv(envvars.CliBrand, "feishu") + + acct, err := (&Provider{}).ResolveAccount(context.Background()) + if err != nil { + t.Fatal(err) + } + if acct.AppID != "cli_test" || acct.AppSecret != "secret_test" || acct.Brand != "feishu" { + t.Errorf("unexpected: %+v", acct) + } +} + +func TestResolveAccount_NeitherSet(t *testing.T) { + acct, err := (&Provider{}).ResolveAccount(context.Background()) + if err != nil || acct != nil { + t.Errorf("expected nil, nil; got %+v, %v", acct, err) + } +} + +func TestResolveAccount_OnlyIDSet(t *testing.T) { + t.Setenv(envvars.CliAppID, "cli_test") + _, err := (&Provider{}).ResolveAccount(context.Background()) + var blockErr *credential.BlockError + if !errors.As(err, &blockErr) { + t.Fatalf("expected BlockError, got %v", err) + } +} + +func TestResolveAccount_AppIDAndUserTokenWithoutSecret(t *testing.T) { + t.Setenv(envvars.CliAppID, "cli_test") + t.Setenv(envvars.CliUserAccessToken, "uat_test") + + acct, err := (&Provider{}).ResolveAccount(context.Background()) + if err != nil { + t.Fatal(err) + } + if acct == nil { + t.Fatal("expected account, got nil") + } + if acct.AppSecret != credential.NoAppSecret { + t.Fatalf("AppSecret = %q, want credential.NoAppSecret", acct.AppSecret) + } + if acct.AppID != "cli_test" { + t.Fatalf("AppID = %q, want cli_test", acct.AppID) + } +} + +func TestResolveAccount_OnlySecretSet(t *testing.T) { + t.Setenv(envvars.CliAppSecret, "secret_test") + _, err := (&Provider{}).ResolveAccount(context.Background()) + var blockErr *credential.BlockError + if !errors.As(err, &blockErr) { + t.Fatalf("expected BlockError, got %v", err) + } +} + +func TestResolveAccount_OnlyTokenSetWithoutAppID(t *testing.T) { + t.Setenv(envvars.CliUserAccessToken, "uat_test") + + _, err := (&Provider{}).ResolveAccount(context.Background()) + var blockErr *credential.BlockError + if !errors.As(err, &blockErr) { + t.Fatalf("expected BlockError, got %v", err) + } + if !strings.Contains(err.Error(), envvars.CliAppID) { + t.Fatalf("error = %v, want mention of %s", err, envvars.CliAppID) + } +} + +func TestResolveAccount_DefaultBrand(t *testing.T) { + t.Setenv(envvars.CliAppID, "cli_test") + t.Setenv(envvars.CliAppSecret, "secret_test") + acct, _ := (&Provider{}).ResolveAccount(context.Background()) + if acct.Brand != "feishu" { + t.Errorf("expected 'feishu', got %q", acct.Brand) + } +} + +func TestResolveAccount_DefaultAsFromEnv(t *testing.T) { + t.Setenv(envvars.CliAppID, "cli_test") + t.Setenv(envvars.CliAppSecret, "secret_test") + t.Setenv(envvars.CliDefaultAs, "user") + + acct, err := (&Provider{}).ResolveAccount(context.Background()) + if err != nil { + t.Fatal(err) + } + if acct.DefaultAs != "user" { + t.Errorf("expected default-as user, got %q", acct.DefaultAs) + } +} + +func TestResolveToken_UATSet(t *testing.T) { + t.Setenv(envvars.CliUserAccessToken, "u-env") + tok, err := (&Provider{}).ResolveToken(context.Background(), credential.TokenSpec{Type: credential.TokenTypeUAT}) + if err != nil { + t.Fatal(err) + } + if tok.Value != "u-env" || tok.Source != "env:"+envvars.CliUserAccessToken { + t.Errorf("unexpected: %+v", tok) + } +} + +func TestResolveToken_TATSet(t *testing.T) { + t.Setenv(envvars.CliTenantAccessToken, "t-env") + tok, err := (&Provider{}).ResolveToken(context.Background(), credential.TokenSpec{Type: credential.TokenTypeTAT}) + if err != nil { + t.Fatal(err) + } + if tok.Value != "t-env" || tok.Source != "env:"+envvars.CliTenantAccessToken { + t.Errorf("unexpected: %+v", tok) + } +} + +func TestResolveToken_NotSet(t *testing.T) { + tok, err := (&Provider{}).ResolveToken(context.Background(), credential.TokenSpec{Type: credential.TokenTypeUAT}) + if err != nil || tok != nil { + t.Errorf("expected nil, nil; got %+v, %v", tok, err) + } +} + +func TestResolveAccount_StrictModeBot(t *testing.T) { + t.Setenv(envvars.CliAppID, "app") + t.Setenv(envvars.CliAppSecret, "secret") + t.Setenv(envvars.CliStrictMode, "bot") + acct, err := (&Provider{}).ResolveAccount(context.Background()) + if err != nil { + t.Fatal(err) + } + if !acct.SupportedIdentities.BotOnly() { + t.Errorf("expected bot-only, got %d", acct.SupportedIdentities) + } +} + +func TestResolveAccount_StrictModeUser(t *testing.T) { + t.Setenv(envvars.CliAppID, "app") + t.Setenv(envvars.CliAppSecret, "secret") + t.Setenv(envvars.CliStrictMode, "user") + acct, err := (&Provider{}).ResolveAccount(context.Background()) + if err != nil { + t.Fatal(err) + } + if !acct.SupportedIdentities.UserOnly() { + t.Errorf("expected user-only, got %d", acct.SupportedIdentities) + } +} + +func TestResolveAccount_StrictModeOff(t *testing.T) { + t.Setenv(envvars.CliAppID, "app") + t.Setenv(envvars.CliAppSecret, "secret") + t.Setenv(envvars.CliStrictMode, "off") + acct, err := (&Provider{}).ResolveAccount(context.Background()) + if err != nil { + t.Fatal(err) + } + if acct.SupportedIdentities != credential.SupportsAll { + t.Errorf("expected SupportsAll, got %d", acct.SupportedIdentities) + } +} + +func TestResolveAccount_InferFromUATOnly(t *testing.T) { + t.Setenv(envvars.CliAppID, "app") + t.Setenv(envvars.CliAppSecret, "secret") + t.Setenv(envvars.CliUserAccessToken, "u-tok") + acct, err := (&Provider{}).ResolveAccount(context.Background()) + if err != nil { + t.Fatal(err) + } + if !acct.SupportedIdentities.UserOnly() { + t.Errorf("expected user-only from UAT inference, got %d", acct.SupportedIdentities) + } + if acct.DefaultAs != "user" { + t.Errorf("expected default-as user from UAT inference, got %q", acct.DefaultAs) + } +} + +func TestResolveAccount_InferFromTATOnly(t *testing.T) { + t.Setenv(envvars.CliAppID, "app") + t.Setenv(envvars.CliAppSecret, "secret") + t.Setenv(envvars.CliTenantAccessToken, "t-tok") + acct, err := (&Provider{}).ResolveAccount(context.Background()) + if err != nil { + t.Fatal(err) + } + if !acct.SupportedIdentities.BotOnly() { + t.Errorf("expected bot-only from TAT inference, got %d", acct.SupportedIdentities) + } + if acct.DefaultAs != "bot" { + t.Errorf("expected default-as bot from TAT inference, got %q", acct.DefaultAs) + } +} + +func TestResolveAccount_InferBothTokens(t *testing.T) { + t.Setenv(envvars.CliAppID, "app") + t.Setenv(envvars.CliAppSecret, "secret") + t.Setenv(envvars.CliUserAccessToken, "u-tok") + t.Setenv(envvars.CliTenantAccessToken, "t-tok") + acct, err := (&Provider{}).ResolveAccount(context.Background()) + if err != nil { + t.Fatal(err) + } + if acct.SupportedIdentities != credential.SupportsAll { + t.Errorf("expected SupportsAll, got %d", acct.SupportedIdentities) + } + if acct.DefaultAs != "user" { + t.Errorf("expected default-as user when both tokens are present, got %q", acct.DefaultAs) + } +} + +func TestResolveAccount_StrictModeOverridesTokenInference(t *testing.T) { + t.Setenv(envvars.CliAppID, "app") + t.Setenv(envvars.CliAppSecret, "secret") + t.Setenv(envvars.CliUserAccessToken, "u-tok") + t.Setenv(envvars.CliTenantAccessToken, "t-tok") + t.Setenv(envvars.CliStrictMode, "bot") + acct, err := (&Provider{}).ResolveAccount(context.Background()) + if err != nil { + t.Fatal(err) + } + if !acct.SupportedIdentities.BotOnly() { + t.Errorf("strict mode should override token inference, got %d", acct.SupportedIdentities) + } +} + +func TestResolveAccount_InvalidStrictModeRejected(t *testing.T) { + t.Setenv(envvars.CliAppID, "app") + t.Setenv(envvars.CliAppSecret, "secret") + t.Setenv(envvars.CliStrictMode, "invalid") + + _, err := (&Provider{}).ResolveAccount(context.Background()) + if err == nil { + t.Fatal("expected error for invalid strict mode") + } + var blockErr *credential.BlockError + if !errors.As(err, &blockErr) { + t.Fatalf("expected BlockError, got %T", err) + } + if !strings.Contains(err.Error(), envvars.CliStrictMode) { + t.Fatalf("error = %v, want mention of %s", err, envvars.CliStrictMode) + } +} + +func TestResolveAccount_InvalidDefaultAsRejected(t *testing.T) { + t.Setenv(envvars.CliAppID, "app") + t.Setenv(envvars.CliAppSecret, "secret") + t.Setenv(envvars.CliDefaultAs, "invalid") + + _, err := (&Provider{}).ResolveAccount(context.Background()) + if err == nil { + t.Fatal("expected error for invalid default-as") + } + var blockErr *credential.BlockError + if !errors.As(err, &blockErr) { + t.Fatalf("expected BlockError, got %T", err) + } + if !strings.Contains(err.Error(), envvars.CliDefaultAs) { + t.Fatalf("error = %v, want mention of %s", err, envvars.CliDefaultAs) + } +} diff --git a/extension/credential/registry.go b/extension/credential/registry.go new file mode 100644 index 000000000..52ec9ebf8 --- /dev/null +++ b/extension/credential/registry.go @@ -0,0 +1,29 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package credential + +import "sync" + +var ( + mu sync.Mutex + providers []Provider +) + +// Register registers a credential Provider. +// Providers are consulted in registration order. +// Typically called from init() via blank import. +func Register(p Provider) { + mu.Lock() + defer mu.Unlock() + providers = append(providers, p) +} + +// Providers returns all registered providers (snapshot). +func Providers() []Provider { + mu.Lock() + defer mu.Unlock() + result := make([]Provider, len(providers)) + copy(result, providers) + return result +} diff --git a/extension/credential/registry_test.go b/extension/credential/registry_test.go new file mode 100644 index 000000000..1c394f514 --- /dev/null +++ b/extension/credential/registry_test.go @@ -0,0 +1,51 @@ +package credential + +import ( + "context" + "testing" +) + +type stubProvider struct{ name string } + +func (s *stubProvider) Name() string { return s.name } +func (s *stubProvider) ResolveAccount(ctx context.Context) (*Account, error) { + return &Account{AppID: s.name}, nil +} +func (s *stubProvider) ResolveToken(ctx context.Context, req TokenSpec) (*Token, error) { + return &Token{Value: "tok-" + s.name, Source: s.name}, nil +} + +func TestRegisterAndProviders(t *testing.T) { + mu.Lock() + old := providers + providers = nil + mu.Unlock() + defer func() { mu.Lock(); providers = old; mu.Unlock() }() + + Register(&stubProvider{name: "a"}) + Register(&stubProvider{name: "b"}) + + got := Providers() + if len(got) != 2 { + t.Fatalf("expected 2, got %d", len(got)) + } + if got[0].Name() != "a" || got[1].Name() != "b" { + t.Errorf("unexpected order: %s, %s", got[0].Name(), got[1].Name()) + } +} + +func TestProviders_ReturnsSnapshot(t *testing.T) { + mu.Lock() + old := providers + providers = nil + mu.Unlock() + defer func() { mu.Lock(); providers = old; mu.Unlock() }() + + Register(&stubProvider{name: "x"}) + snap := Providers() + Register(&stubProvider{name: "y"}) + + if len(snap) != 1 { + t.Fatalf("snapshot should not be affected, got %d", len(snap)) + } +} diff --git a/extension/credential/types.go b/extension/credential/types.go new file mode 100644 index 000000000..209013fda --- /dev/null +++ b/extension/credential/types.go @@ -0,0 +1,100 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package credential + +import "context" + +// Brand represents the Lark platform brand. +type Brand string + +const ( + BrandLark Brand = "lark" + BrandFeishu Brand = "feishu" +) + +// NoAppSecret marks that a credential source does not provide a real app secret. +// Token-only sources should return this value instead of inventing placeholder text. +const NoAppSecret = "" + +// Identity represents the caller identity type. +type Identity string + +const ( + IdentityUser Identity = "user" + IdentityBot Identity = "bot" + IdentityAuto Identity = "auto" +) + +// IdentitySupport declares which identities a credential source can provide. +type IdentitySupport uint8 + +const ( + SupportsUser IdentitySupport = 1 << iota + SupportsBot + SupportsAll = SupportsUser | SupportsBot +) + +// Has reports whether s includes the given flag. +func (s IdentitySupport) Has(flag IdentitySupport) bool { return s&flag != 0 } + +// UserOnly returns true if only user identity is supported. +func (s IdentitySupport) UserOnly() bool { return s == SupportsUser } + +// BotOnly returns true if only bot identity is supported. +func (s IdentitySupport) BotOnly() bool { return s == SupportsBot } + +// Account holds resolved app credentials and configuration. +type Account struct { + AppID string + AppSecret string // real app secret; empty or NoAppSecret means unavailable + Brand Brand // BrandLark or BrandFeishu + DefaultAs Identity // IdentityUser / IdentityBot / IdentityAuto; empty = not set + ProfileName string + OpenID string // optional; if UAT is available, API result takes precedence + SupportedIdentities IdentitySupport // zero = provider did not declare; treat as no restriction +} + +// Token holds a resolved access token and optional metadata. +type Token struct { + Value string + Scopes string // space-separated; empty = skip scope pre-check + Source string // e.g. "env:LARKSUITE_CLI_USER_ACCESS_TOKEN", "vault:addr" +} + +// TokenType represents the kind of access token. +type TokenType string + +const ( + TokenTypeUAT TokenType = "uat" + TokenTypeTAT TokenType = "tat" +) + +// TokenSpec describes what token is needed. +type TokenSpec struct { + Type TokenType + AppID string +} + +// BlockError is returned by a Provider to actively reject a request +// and prevent subsequent providers in the chain from being consulted. +type BlockError struct { + Provider string + Reason string +} + +func (e *BlockError) Error() string { + return "blocked by " + e.Provider + ": " + e.Reason +} + +// Provider is the unified interface for credential resolution. +// +// Flow control uses Go's native mechanisms: +// - Handle: return &Account{...}, nil or return &Token{...}, nil +// - Skip: return nil, nil +// - Block: return nil, &BlockError{...} +type Provider interface { + Name() string + ResolveAccount(ctx context.Context) (*Account, error) + ResolveToken(ctx context.Context, req TokenSpec) (*Token, error) +} diff --git a/extension/credential/types_test.go b/extension/credential/types_test.go new file mode 100644 index 000000000..315974eb0 --- /dev/null +++ b/extension/credential/types_test.go @@ -0,0 +1,39 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package credential + +import "testing" + +func TestIdentitySupport_Has(t *testing.T) { + if !SupportsAll.Has(SupportsUser) { + t.Error("SupportsAll should have SupportsUser") + } + if !SupportsAll.Has(SupportsBot) { + t.Error("SupportsAll should have SupportsBot") + } + if SupportsUser.Has(SupportsBot) { + t.Error("SupportsUser should not have SupportsBot") + } +} + +func TestIdentitySupport_UserOnly(t *testing.T) { + if !SupportsUser.UserOnly() { + t.Error("SupportsUser.UserOnly() should be true") + } + if SupportsAll.UserOnly() { + t.Error("SupportsAll.UserOnly() should be false") + } + if IdentitySupport(0).UserOnly() { + t.Error("zero value UserOnly() should be false") + } +} + +func TestIdentitySupport_BotOnly(t *testing.T) { + if !SupportsBot.BotOnly() { + t.Error("SupportsBot.BotOnly() should be true") + } + if SupportsAll.BotOnly() { + t.Error("SupportsAll.BotOnly() should be false") + } +} diff --git a/go.mod b/go.mod index b264d9bca..0a294e07f 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e github.com/smartystreets/goconvey v1.8.1 github.com/spf13/cobra v1.10.2 + github.com/spf13/pflag v1.0.9 github.com/stretchr/testify v1.11.1 github.com/tidwall/gjson v1.18.0 github.com/zalando/go-keyring v0.2.8 @@ -54,7 +55,6 @@ require ( github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rivo/uniseg v0.4.7 // indirect github.com/smarty/assertions v1.15.0 // indirect - github.com/spf13/pflag v1.0.9 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.0 // indirect github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect diff --git a/internal/auth/uat_client.go b/internal/auth/uat_client.go index f04b61440..35bf4133a 100644 --- a/internal/auth/uat_client.go +++ b/internal/auth/uat_client.go @@ -19,6 +19,7 @@ import ( "github.com/gofrs/flock" "github.com/larksuite/cli/internal/core" + "github.com/larksuite/cli/internal/vfs" ) var safeIDChars = regexp.MustCompile(`[^a-zA-Z0-9._-]`) @@ -128,7 +129,7 @@ func refreshWithLock(httpClient *http.Client, opts UATCallOptions, stored *Store configDir := core.GetConfigDir() lockDir := filepath.Join(configDir, "locks") - if err := os.MkdirAll(lockDir, 0700); err != nil { + if err := vfs.MkdirAll(lockDir, 0700); err != nil { return nil, fmt.Errorf("failed to create lock directory: %w", err) } diff --git a/internal/client/client.go b/internal/client/client.go index 030a0dedd..816b637b1 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -4,18 +4,22 @@ package client import ( + "bytes" "context" + "encoding/json" + "errors" "fmt" "io" "net/http" + "net/url" "strings" "time" lark "github.com/larksuite/oapi-sdk-go/v3" larkcore "github.com/larksuite/oapi-sdk-go/v3/core" - "github.com/larksuite/cli/internal/auth" "github.com/larksuite/cli/internal/core" + "github.com/larksuite/cli/internal/credential" "github.com/larksuite/cli/internal/output" "github.com/larksuite/cli/internal/util" ) @@ -32,10 +36,26 @@ type RawApiRequest struct { // APIClient wraps lark.Client for all Lark Open API calls. type APIClient struct { - Config *core.CliConfig - SDK *lark.Client // All Lark API calls go through SDK - HTTP *http.Client // Only for non-Lark API (OAuth, MCP, etc.) - ErrOut io.Writer // debug/progress output + Config *core.CliConfig + SDK *lark.Client // All Lark API calls go through SDK + HTTP *http.Client // Only for non-Lark API (OAuth, MCP, etc.) + ErrOut io.Writer // debug/progress output + Credential *credential.CredentialProvider +} + +func (c *APIClient) resolveAccessToken(ctx context.Context, as core.Identity) (string, error) { + result, err := c.Credential.ResolveToken(ctx, credential.NewTokenSpec(as, c.Config.AppID)) + if err != nil { + var unavailableErr *credential.TokenUnavailableError + if errors.As(err, &unavailableErr) { + return "", output.ErrAuth("no access token available for %s", as) + } + return "", err + } + if result.Token == "" { + return "", output.ErrAuth("no access token available for %s", as) + } + return result.Token, nil } // buildApiReq converts a RawApiRequest into SDK types and collects @@ -74,17 +94,15 @@ func (c *APIClient) buildApiReq(request RawApiRequest) (*larkcore.ApiReq, []lark func (c *APIClient) DoSDKRequest(ctx context.Context, req *larkcore.ApiReq, as core.Identity, extraOpts ...larkcore.RequestOptionFunc) (*larkcore.ApiResp, error) { var opts []larkcore.RequestOptionFunc + token, err := c.resolveAccessToken(ctx, as) + if err != nil { + return nil, err + } if as.IsBot() { req.SupportedAccessTokenTypes = []larkcore.AccessTokenType{larkcore.AccessTokenTypeTenant} + opts = append(opts, larkcore.WithTenantAccessToken(token)) } else { req.SupportedAccessTokenTypes = []larkcore.AccessTokenType{larkcore.AccessTokenTypeUser} - if c.Config.UserOpenId == "" { - return nil, fmt.Errorf("login required: lark-cli auth login (or use --as bot)") - } - token, err := auth.GetValidAccessToken(c.HTTP, auth.NewUATCallOptions(c.Config, c.ErrOut)) - if err != nil { - return nil, err - } opts = append(opts, larkcore.WithUserAccessToken(token)) } @@ -92,6 +110,146 @@ func (c *APIClient) DoSDKRequest(ctx context.Context, req *larkcore.ApiReq, as c return c.SDK.Do(ctx, req, opts...) } +// DoStream executes a streaming HTTP request against the Lark OpenAPI endpoint. +// Unlike DoSDKRequest (which buffers the full body via the SDK), DoStream returns +// a live *http.Response whose Body is an io.Reader for streaming consumption. +// Auth is resolved via Credential (same as DoSDKRequest). Security headers and +// any extra headers from opts are applied automatically. +// HTTP errors (status >= 400) are handled internally: the body is read (up to 4 KB), +// closed, and returned as an output.ErrNetwork — callers only receive successful responses. +func (c *APIClient) DoStream(ctx context.Context, req *larkcore.ApiReq, as core.Identity, opts ...Option) (*http.Response, error) { + cfg := buildConfig(opts) + + // Resolve auth + token, err := c.resolveAccessToken(ctx, as) + if err != nil { + return nil, err + } + + // Build URL + requestURL, err := buildStreamURL(c.Config.Brand, req) + if err != nil { + return nil, err + } + + // Build body + bodyReader, contentType, err := buildStreamBody(req.Body) + if err != nil { + return nil, err + } + + // Timeout — use context deadline only; httpClient.Timeout would cut off + // healthy streaming responses because it includes body read time. + httpClient := *c.HTTP + httpClient.Timeout = 0 + cancel := func() {} + requestCtx := ctx + if cfg.timeout > 0 { + if _, hasDeadline := ctx.Deadline(); !hasDeadline { + requestCtx, cancel = context.WithTimeout(ctx, cfg.timeout) + } + } + + // Build request + httpReq, err := http.NewRequestWithContext(requestCtx, req.HttpMethod, requestURL, bodyReader) + if err != nil { + cancel() + return nil, output.ErrNetwork("stream request failed: %s", err) + } + + // Apply headers from opts + for k, vs := range cfg.headers { + for _, v := range vs { + httpReq.Header.Add(k, v) + } + } + + if contentType != "" { + httpReq.Header.Set("Content-Type", contentType) + } + httpReq.Header.Set("Authorization", "Bearer "+token) + + resp, err := httpClient.Do(httpReq) + if err != nil { + cancel() + return nil, output.ErrNetwork("stream request failed: %s", err) + } + resp.Body = &cancelOnCloseBody{ReadCloser: resp.Body, cancel: cancel} + + // Handle HTTP errors internally + if resp.StatusCode >= 400 { + defer resp.Body.Close() + errBody, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) + msg := strings.TrimSpace(string(errBody)) + if msg != "" { + return nil, output.ErrNetwork("HTTP %d: %s", resp.StatusCode, msg) + } + return nil, output.ErrNetwork("HTTP %d", resp.StatusCode) + } + + return resp, nil +} + +type cancelOnCloseBody struct { + io.ReadCloser + cancel context.CancelFunc +} + +func (r *cancelOnCloseBody) Close() error { + err := r.ReadCloser.Close() + if r.cancel != nil { + r.cancel() + } + return err +} + +func buildStreamURL(brand core.LarkBrand, req *larkcore.ApiReq) (string, error) { + requestURL := req.ApiPath + if !strings.HasPrefix(requestURL, "http://") && !strings.HasPrefix(requestURL, "https://") { + var pathSegs []string + for _, segment := range strings.Split(req.ApiPath, "/") { + if !strings.HasPrefix(segment, ":") { + pathSegs = append(pathSegs, segment) + continue + } + pathKey := strings.TrimPrefix(segment, ":") + pathValue, ok := req.PathParams[pathKey] + if !ok { + return "", output.ErrValidation("missing path param %q for %s", pathKey, req.ApiPath) + } + if pathValue == "" { + return "", output.ErrValidation("empty path param %q for %s", pathKey, req.ApiPath) + } + pathSegs = append(pathSegs, url.PathEscape(pathValue)) + } + endpoints := core.ResolveEndpoints(brand) + requestURL = strings.TrimRight(endpoints.Open, "/") + strings.Join(pathSegs, "/") + } + if query := req.QueryParams.Encode(); query != "" { + requestURL += "?" + query + } + return requestURL, nil +} + +func buildStreamBody(body interface{}) (io.Reader, string, error) { + switch typed := body.(type) { + case nil: + return nil, "", nil + case io.Reader: + return typed, "", nil + case []byte: + return bytes.NewReader(typed), "", nil + case string: + return strings.NewReader(typed), "text/plain; charset=utf-8", nil + default: + payload, err := json.Marshal(typed) + if err != nil { + return nil, "", output.Errorf(output.ExitInternal, "api_error", "failed to encode request body: %s", err) + } + return bytes.NewReader(payload), "application/json", nil + } +} + // DoAPI executes a raw Lark SDK request and returns the raw *larkcore.ApiResp. // Unlike CallAPI which always JSON-decodes, DoAPI returns the raw response — suitable // for file downloads (pass larkcore.WithFileDownload() via request.ExtraOpts) and diff --git a/internal/client/client_test.go b/internal/client/client_test.go index f0419f3a3..5a97cecbb 100644 --- a/internal/client/client_test.go +++ b/internal/client/client_test.go @@ -7,13 +7,20 @@ import ( "bytes" "context" "encoding/json" + "errors" "io" "net/http" + "net/http/httptest" "strings" "testing" + "time" lark "github.com/larksuite/oapi-sdk-go/v3" larkcore "github.com/larksuite/oapi-sdk-go/v3/core" + + "github.com/larksuite/cli/internal/core" + "github.com/larksuite/cli/internal/credential" + "github.com/larksuite/cli/internal/output" ) // roundTripFunc is an adapter to use a function as http.RoundTripper. @@ -31,18 +38,36 @@ func jsonResponse(body interface{}) *http.Response { } } +// staticTokenResolver always returns a fixed token without any HTTP calls. +type staticTokenResolver struct{} + +func (s *staticTokenResolver) ResolveToken(_ context.Context, _ credential.TokenSpec) (*credential.TokenResult, error) { + return &credential.TokenResult{Token: "test-token"}, nil +} + +type missingTokenResolver struct{} + +func (s *missingTokenResolver) ResolveToken(_ context.Context, req credential.TokenSpec) (*credential.TokenResult, error) { + return nil, &credential.TokenUnavailableError{Source: "default", Type: req.Type} +} + // newTestAPIClient creates an APIClient with a mock HTTP transport. func newTestAPIClient(t *testing.T, rt http.RoundTripper) (*APIClient, *bytes.Buffer) { t.Helper() errBuf := &bytes.Buffer{} httpClient := &http.Client{Transport: rt} sdk := lark.NewClient("test-app", "test-secret", + lark.WithEnableTokenCache(false), lark.WithLogLevel(larkcore.LogLevelError), lark.WithHttpClient(httpClient), ) + testCred := credential.NewCredentialProvider(nil, nil, &staticTokenResolver{}, nil) + cfg := &core.CliConfig{AppID: "test-app", AppSecret: "test-secret", Brand: core.BrandFeishu} return &APIClient{ - SDK: sdk, - ErrOut: errBuf, + SDK: sdk, + ErrOut: errBuf, + Credential: testCred, + Config: cfg, }, errBuf } @@ -87,21 +112,13 @@ func TestMimeToExt(t *testing.T) { func TestStreamPages_NonBatchAPI_NoArrayField(t *testing.T) { rt := roundTripFunc(func(req *http.Request) (*http.Response, error) { - switch { - case strings.Contains(req.URL.Path, "tenant_access_token"): - return jsonResponse(map[string]interface{}{ - "code": 0, "msg": "ok", - "tenant_access_token": "t-token", "expire": 7200, - }), nil - default: - return jsonResponse(map[string]interface{}{ - "code": 0, "msg": "ok", - "data": map[string]interface{}{ - "user_id": "u123", - "name": "Test User", - }, - }), nil - } + return jsonResponse(map[string]interface{}{ + "code": 0, "msg": "ok", + "data": map[string]interface{}{ + "user_id": "u123", + "name": "Test User", + }, + }), nil }) ac, errBuf := newTestAPIClient(t, rt) @@ -138,21 +155,13 @@ func TestStreamPages_NonBatchAPI_NoArrayField(t *testing.T) { func TestStreamPages_BatchAPI_WithArrayField(t *testing.T) { rt := roundTripFunc(func(req *http.Request) (*http.Response, error) { - switch { - case strings.Contains(req.URL.Path, "tenant_access_token"): - return jsonResponse(map[string]interface{}{ - "code": 0, "msg": "ok", - "tenant_access_token": "t-token", "expire": 7200, - }), nil - default: - return jsonResponse(map[string]interface{}{ - "code": 0, "msg": "ok", - "data": map[string]interface{}{ - "items": []interface{}{map[string]interface{}{"id": "1"}, map[string]interface{}{"id": "2"}}, - "has_more": false, - }, - }), nil - } + return jsonResponse(map[string]interface{}{ + "code": 0, "msg": "ok", + "data": map[string]interface{}{ + "items": []interface{}{map[string]interface{}{"id": "1"}, map[string]interface{}{"id": "2"}}, + "has_more": false, + }, + }), nil }) ac, errBuf := newTestAPIClient(t, rt) @@ -186,23 +195,15 @@ func TestStreamPages_BatchAPI_WithArrayField(t *testing.T) { func TestPaginateAll_PageLimitStopsPagination(t *testing.T) { apiCalls := 0 rt := roundTripFunc(func(req *http.Request) (*http.Response, error) { - switch { - case strings.Contains(req.URL.Path, "tenant_access_token"): - return jsonResponse(map[string]interface{}{ - "code": 0, "msg": "ok", - "tenant_access_token": "t-token", "expire": 7200, - }), nil - default: - apiCalls++ - return jsonResponse(map[string]interface{}{ - "code": 0, "msg": "ok", - "data": map[string]interface{}{ - "items": []interface{}{map[string]interface{}{"id": apiCalls}}, - "has_more": true, - "page_token": "next", - }, - }), nil - } + apiCalls++ + return jsonResponse(map[string]interface{}{ + "code": 0, "msg": "ok", + "data": map[string]interface{}{ + "items": []interface{}{map[string]interface{}{"id": apiCalls}}, + "has_more": true, + "page_token": "next", + }, + }), nil }) ac, errBuf := newTestAPIClient(t, rt) @@ -319,21 +320,13 @@ func TestBuildApiReq_QueryParams(t *testing.T) { func TestPaginateAll_NoStreamSummaryLog(t *testing.T) { rt := roundTripFunc(func(req *http.Request) (*http.Response, error) { - switch { - case strings.Contains(req.URL.Path, "tenant_access_token"): - return jsonResponse(map[string]interface{}{ - "code": 0, "msg": "ok", - "tenant_access_token": "t-token", "expire": 7200, - }), nil - default: - return jsonResponse(map[string]interface{}{ - "code": 0, "msg": "ok", - "data": map[string]interface{}{ - "items": []interface{}{map[string]interface{}{"id": "1"}}, - "has_more": false, - }, - }), nil - } + return jsonResponse(map[string]interface{}{ + "code": 0, "msg": "ok", + "data": map[string]interface{}{ + "items": []interface{}{map[string]interface{}{"id": "1"}}, + "has_more": false, + }, + }), nil }) ac, errBuf := newTestAPIClient(t, rt) @@ -354,3 +347,78 @@ func TestPaginateAll_NoStreamSummaryLog(t *testing.T) { t.Fatal("expected non-nil result") } } + +func TestDoStream_IgnoresBaseHTTPClientTimeout(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + time.Sleep(25 * time.Millisecond) + _, _ = io.WriteString(w, "ok") + })) + defer srv.Close() + + ac := &APIClient{ + HTTP: &http.Client{Timeout: 5 * time.Millisecond}, + Credential: credential.NewCredentialProvider(nil, nil, &staticTokenResolver{}, nil), + Config: &core.CliConfig{AppID: "test-app", AppSecret: "test-secret", Brand: core.BrandFeishu}, + } + + resp, err := ac.DoStream(context.Background(), &larkcore.ApiReq{ + HttpMethod: http.MethodGet, + ApiPath: srv.URL, + }, core.AsBot) + if err != nil { + t.Fatalf("DoStream() error = %v", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("ReadAll() error = %v", err) + } + if string(body) != "ok" { + t.Fatalf("response body = %q, want %q", string(body), "ok") + } +} + +func TestDoSDKRequest_MissingTokenReturnsAuthError(t *testing.T) { + ac, _ := newTestAPIClient(t, roundTripFunc(func(req *http.Request) (*http.Response, error) { + t.Fatal("unexpected HTTP request") + return nil, nil + })) + ac.Credential = credential.NewCredentialProvider(nil, nil, &missingTokenResolver{}, nil) + + _, err := ac.DoSDKRequest(context.Background(), &larkcore.ApiReq{ + HttpMethod: http.MethodGet, + ApiPath: "/open-apis/test", + }, core.AsBot) + if err == nil { + t.Fatal("DoSDKRequest() error = nil, want auth error") + } + var exitErr *output.ExitError + if !strings.Contains(err.Error(), "no access token available") || !errors.As(err, &exitErr) || exitErr.Detail == nil || exitErr.Detail.Type != "auth" { + t.Fatalf("DoSDKRequest() error = %v, want auth error", err) + } +} + +func TestDoStream_MissingTokenReturnsAuthError(t *testing.T) { + ac := &APIClient{ + HTTP: &http.Client{}, + Credential: credential.NewCredentialProvider(nil, nil, &missingTokenResolver{}, nil), + Config: &core.CliConfig{AppID: "test-app", AppSecret: "test-secret", Brand: core.BrandFeishu}, + } + + _, err := ac.DoStream(context.Background(), &larkcore.ApiReq{ + HttpMethod: http.MethodGet, + ApiPath: "https://example.com/open-apis/test", + }, core.AsBot) + if err == nil { + t.Fatal("DoStream() error = nil, want auth error") + } + var exitErr *output.ExitError + if !strings.Contains(err.Error(), "no access token available") || !errors.As(err, &exitErr) || exitErr.Detail == nil || exitErr.Detail.Type != "auth" { + t.Fatalf("DoStream() error = %v, want auth error", err) + } +} diff --git a/internal/client/option.go b/internal/client/option.go new file mode 100644 index 000000000..ce5a5635e --- /dev/null +++ b/internal/client/option.go @@ -0,0 +1,46 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package client + +import ( + "net/http" + "time" +) + +// Option configures API request behavior for DoStream (and future DoSDKRequest). +type Option func(*requestConfig) + +type requestConfig struct { + timeout time.Duration + headers http.Header +} + +// WithTimeout sets a request-level timeout that overrides the client default. +func WithTimeout(d time.Duration) Option { + return func(c *requestConfig) { + c.timeout = d + } +} + +// WithHeaders adds extra HTTP headers to the request. +func WithHeaders(h http.Header) Option { + return func(c *requestConfig) { + if c.headers == nil { + c.headers = make(http.Header) + } + for k, vs := range h { + for _, v := range vs { + c.headers.Add(k, v) + } + } + } +} + +func buildConfig(opts []Option) requestConfig { + var cfg requestConfig + for _, o := range opts { + o(&cfg) + } + return cfg +} diff --git a/internal/client/response.go b/internal/client/response.go index db34400b1..10695614f 100644 --- a/internal/client/response.go +++ b/internal/client/response.go @@ -9,7 +9,6 @@ import ( "fmt" "io" "mime" - "os" "path/filepath" "strings" @@ -18,6 +17,7 @@ import ( "github.com/larksuite/cli/internal/output" "github.com/larksuite/cli/internal/util" "github.com/larksuite/cli/internal/validate" + "github.com/larksuite/cli/internal/vfs" ) // ── Response routing ── @@ -125,7 +125,7 @@ func SaveResponse(resp *larkcore.ApiResp, outputPath string) (map[string]interfa return nil, fmt.Errorf("unsafe output path: %s", err) } - if err := os.MkdirAll(filepath.Dir(safePath), 0700); err != nil { + if err := vfs.MkdirAll(filepath.Dir(safePath), 0700); err != nil { return nil, fmt.Errorf("create directory: %s", err) } diff --git a/internal/cmdutil/annotations.go b/internal/cmdutil/annotations.go index 1aacec7e3..85faa12a8 100644 --- a/internal/cmdutil/annotations.go +++ b/internal/cmdutil/annotations.go @@ -3,9 +3,31 @@ package cmdutil -import "github.com/spf13/cobra" +import ( + "strings" + + "github.com/spf13/cobra" +) const skipAuthCheckKey = "skipAuthCheck" +const annotationSupportedIdentities = "lark:supportedIdentities" + +// SetSupportedIdentities marks which identities a command supports. +func SetSupportedIdentities(cmd *cobra.Command, identities []string) { + if cmd.Annotations == nil { + cmd.Annotations = map[string]string{} + } + cmd.Annotations[annotationSupportedIdentities] = strings.Join(identities, ",") +} + +// GetSupportedIdentities returns the declared identities, or nil if not declared. +func GetSupportedIdentities(cmd *cobra.Command) []string { + v, ok := cmd.Annotations[annotationSupportedIdentities] + if !ok || v == "" { + return nil + } + return strings.Split(v, ",") +} // DisableAuthCheck marks a command (and all its children) as not requiring auth. func DisableAuthCheck(cmd *cobra.Command) { diff --git a/internal/cmdutil/annotations_test.go b/internal/cmdutil/annotations_test.go index 6ee5bab15..131baaaff 100644 --- a/internal/cmdutil/annotations_test.go +++ b/internal/cmdutil/annotations_test.go @@ -49,3 +49,27 @@ func TestIsAuthCheckDisabled_NoInheritanceUpward(t *testing.T) { t.Error("child should have disabled auth check") } } + +func TestSetGetSupportedIdentities(t *testing.T) { + cmd := &cobra.Command{Use: "test"} + if got := GetSupportedIdentities(cmd); got != nil { + t.Errorf("expected nil, got %v", got) + } + SetSupportedIdentities(cmd, []string{"user", "bot"}) + got := GetSupportedIdentities(cmd) + if len(got) != 2 || got[0] != "user" || got[1] != "bot" { + t.Errorf("expected [user bot], got %v", got) + } +} + +func TestSetSupportedIdentities_OverwriteExisting(t *testing.T) { + cmd := &cobra.Command{Use: "test", Annotations: map[string]string{"other": "val"}} + SetSupportedIdentities(cmd, []string{"bot"}) + if cmd.Annotations["other"] != "val" { + t.Error("existing annotation should be preserved") + } + got := GetSupportedIdentities(cmd) + if len(got) != 1 || got[0] != "bot" { + t.Errorf("expected [bot], got %v", got) + } +} diff --git a/internal/cmdutil/factory.go b/internal/cmdutil/factory.go index a53d08684..8845f1dc6 100644 --- a/internal/cmdutil/factory.go +++ b/internal/cmdutil/factory.go @@ -4,92 +4,102 @@ package cmdutil import ( + "context" "fmt" "io" "net/http" - "os" "strings" lark "github.com/larksuite/oapi-sdk-go/v3" "github.com/spf13/cobra" - "github.com/larksuite/cli/internal/auth" + extcred "github.com/larksuite/cli/extension/credential" "github.com/larksuite/cli/internal/client" "github.com/larksuite/cli/internal/core" + "github.com/larksuite/cli/internal/credential" "github.com/larksuite/cli/internal/keychain" "github.com/larksuite/cli/internal/output" ) -// ResolveConfig returns Config() for bot identity, or AuthConfig() for user identity. -func (f *Factory) ResolveConfig(as core.Identity) (*core.CliConfig, error) { - if as.IsBot() { - return f.Config() - } - return f.AuthConfig() -} - // Factory holds shared dependencies injected into every command. // All function fields are lazily initialized and cached after first call. // In tests, replace any field to stub out external dependencies. +type InvocationContext struct { + Profile string +} + type Factory struct { - Config func() (*core.CliConfig, error) // lazily loads app config (credentials, brand, defaultAs) - AuthConfig func() (*core.CliConfig, error) // like Config but also requires a logged-in user + Config func() (*core.CliConfig, error) // lazily loads app config from Credential HttpClient func() (*http.Client, error) // HTTP client for non-Lark API calls (with retry and security headers) LarkClient func() (*lark.Client, error) // Lark SDK client for all Open API calls IOStreams *IOStreams // stdin/stdout/stderr streams + Invocation InvocationContext // Immutable call context; do not mutate after Factory construction. Keychain keychain.KeychainAccess // secret storage (real keychain in prod, mock in tests) IdentityAutoDetected bool // set by ResolveAs when identity was auto-detected ResolvedIdentity core.Identity // identity resolved by the last ResolveAs call + + Credential *credential.CredentialProvider } // ResolveAs returns the effective identity type. // If the user explicitly passed --as, use that value; otherwise use the configured default. -// When the value is "auto" (or unset), auto-detect based on login state. -func (f *Factory) ResolveAs(cmd *cobra.Command, flagAs core.Identity) core.Identity { +// When the value is "auto" (or unset), auto-detect based on credential hints. +func (f *Factory) ResolveAs(ctx context.Context, cmd *cobra.Command, flagAs core.Identity) core.Identity { f.IdentityAutoDetected = false + + // Strict mode: force identity regardless of flags or config. + if forced := f.ResolveStrictMode(ctx).ForcedIdentity(); forced != "" { + f.ResolvedIdentity = forced + return forced + } + if cmd != nil && cmd.Flags().Changed("as") { if flagAs != "auto" { f.ResolvedIdentity = flagAs return flagAs } // --as auto: fall through to auto-detect - } else if defaultAs := f.resolveDefaultAs(); defaultAs != "" && defaultAs != "auto" { - f.ResolvedIdentity = core.Identity(defaultAs) - return f.ResolvedIdentity } - // Auto-detect based on login state + + hint := f.resolveIdentityHint(ctx) + if cmd == nil || !cmd.Flags().Changed("as") { + if defaultAs := resolveDefaultAsFromHint(hint); defaultAs != "" && defaultAs != core.AsAuto { + f.ResolvedIdentity = defaultAs + return f.ResolvedIdentity + } + } + + // Auto-detect based on credential hint f.IdentityAutoDetected = true - result := f.autoDetectIdentity() + result := autoDetectIdentityFromHint(hint) f.ResolvedIdentity = result return result } -// resolveDefaultAs returns the configured default identity: env var > config file. -func (f *Factory) resolveDefaultAs() string { - if v := os.Getenv("LARKSUITE_CLI_DEFAULT_AS"); v != "" { - return v - } - if cfg, err := f.Config(); err == nil { - return cfg.DefaultAs +func resolveDefaultAsFromHint(hint *credential.IdentityHint) core.Identity { + if hint != nil { + return hint.DefaultAs } return "" } -// autoDetectIdentity checks the login state and returns user if logged in, bot otherwise. -func (f *Factory) autoDetectIdentity() core.Identity { - cfg, err := f.Config() - if err != nil || cfg.UserOpenId == "" { - return core.AsBot +func autoDetectIdentityFromHint(hint *credential.IdentityHint) core.Identity { + if hint != nil && hint.AutoAs != "" { + return hint.AutoAs } - stored := auth.GetStoredToken(cfg.AppID, cfg.UserOpenId) - if stored == nil { - return core.AsBot + return core.AsBot +} + +func (f *Factory) resolveIdentityHint(ctx context.Context) *credential.IdentityHint { + if f.Credential == nil { + return nil } - if auth.TokenStatus(stored) == "expired" { - return core.AsBot + hint, err := f.Credential.ResolveIdentityHint(ctx) + if err != nil { + return nil } - return core.AsUser + return hint } // CheckIdentity verifies the resolved identity is in the supported list. @@ -111,6 +121,39 @@ func (f *Factory) CheckIdentity(as core.Identity, supported []string) error { return fmt.Errorf("--as %s is not supported, this command only supports: %s", as, list) } +// ResolveStrictMode returns the effective strict mode by reading +// Account.SupportedIdentities from the credential provider chain. +func (f *Factory) ResolveStrictMode(ctx context.Context) core.StrictMode { + if f.Credential == nil { + return core.StrictModeOff + } + acct, err := f.Credential.ResolveAccount(ctx) + if err != nil || acct == nil { + return core.StrictModeOff + } + ids := extcred.IdentitySupport(acct.SupportedIdentities) + switch { + case ids.BotOnly(): + return core.StrictModeBot + case ids.UserOnly(): + return core.StrictModeUser + default: + return core.StrictModeOff + } +} + +// CheckStrictMode returns an error if strict mode is active and identity is not allowed. +func (f *Factory) CheckStrictMode(ctx context.Context, as core.Identity) error { + mode := f.ResolveStrictMode(ctx) + if mode.IsActive() && !mode.AllowsIdentity(as) { + return output.Errorf(output.ExitValidation, "strict_mode", + "strict mode is %q, only %s identity is allowed. "+ + "This setting is managed by the administrator and must not be modified by AI agents.", + mode, mode.ForcedIdentity()) + } + return nil +} + // NewAPIClient creates an APIClient using the Factory's base Config (app credentials only). // For user-mode calls where the correct user profile matters, use NewAPIClientWithConfig instead. func (f *Factory) NewAPIClient() (*client.APIClient, error) { @@ -122,8 +165,7 @@ func (f *Factory) NewAPIClient() (*client.APIClient, error) { } // NewAPIClientWithConfig creates an APIClient with an explicit config. -// Use this when the caller has already resolved the correct user profile -// (e.g. via AuthConfig for user-mode commands). +// Use this when the caller has already resolved the correct config. func (f *Factory) NewAPIClientWithConfig(cfg *core.CliConfig) (*client.APIClient, error) { sdk, err := f.LarkClient() if err != nil { @@ -137,5 +179,11 @@ func (f *Factory) NewAPIClientWithConfig(cfg *core.CliConfig) (*client.APIClient if f.IOStreams != nil { errOut = f.IOStreams.ErrOut } - return &client.APIClient{Config: cfg, SDK: sdk, HTTP: httpClient, ErrOut: errOut}, nil + return &client.APIClient{ + Config: cfg, + SDK: sdk, + HTTP: httpClient, + ErrOut: errOut, + Credential: f.Credential, + }, nil } diff --git a/internal/cmdutil/factory_default.go b/internal/cmdutil/factory_default.go index 769fbc409..5b08a05cb 100644 --- a/internal/cmdutil/factory_default.go +++ b/internal/cmdutil/factory_default.go @@ -4,7 +4,9 @@ package cmdutil import ( + "context" "fmt" + "io" "net/http" "os" "sync" @@ -14,17 +16,26 @@ import ( larkcore "github.com/larksuite/oapi-sdk-go/v3/core" "golang.org/x/term" + extcred "github.com/larksuite/cli/extension/credential" "github.com/larksuite/cli/internal/auth" "github.com/larksuite/cli/internal/core" + "github.com/larksuite/cli/internal/credential" "github.com/larksuite/cli/internal/keychain" "github.com/larksuite/cli/internal/registry" "github.com/larksuite/cli/internal/util" ) // NewDefault creates a production Factory with cached closures. -func NewDefault() *Factory { +// Initialization follows a credential-first order: +// +// Phase 1: HttpClient (no credential dependency) +// Phase 2: Credential (sole data source for account info) +// Phase 3: Config derived from Credential +// Phase 4: LarkClient derived from Credential +func NewDefault(inv InvocationContext) *Factory { f := &Factory{ - Keychain: keychain.Default(), + Keychain: keychain.Default(), + Invocation: inv, } f.IOStreams = &IOStreams{ In: os.Stdin, @@ -32,28 +43,33 @@ func NewDefault() *Factory { ErrOut: os.Stderr, IsTerminal: term.IsTerminal(int(os.Stdin.Fd())), } - f.Config = cachedConfigFunc(f) - f.AuthConfig = cachedAuthConfigFunc(f) + + // Phase 1: HttpClient (no credential dependency) f.HttpClient = cachedHttpClientFunc() - f.LarkClient = cachedLarkClientFunc(f) - return f -} -func cachedConfigFunc(f *Factory) func() (*core.CliConfig, error) { - return sync.OnceValues(func() (*core.CliConfig, error) { - cfg, err := core.RequireConfig(f.Keychain) + // Phase 2: Credential (sole data source) + f.Credential = buildCredentialProvider(credentialDeps{ + Keychain: f.Keychain, + Profile: inv.Profile, + HttpClient: f.HttpClient, + ErrOut: f.IOStreams.ErrOut, + }) + + // Phase 3: Config derived from Credential via an explicit conversion boundary. + f.Config = sync.OnceValues(func() (*core.CliConfig, error) { + acct, err := f.Credential.ResolveAccount(context.Background()) if err != nil { - return cfg, err + return nil, err } + cfg := acct.ToCliConfig() registry.InitWithBrand(cfg.Brand) return cfg, nil }) -} -func cachedAuthConfigFunc(f *Factory) func() (*core.CliConfig, error) { - return sync.OnceValues(func() (*core.CliConfig, error) { - return core.RequireAuth(f.Keychain) - }) + // Phase 4: LarkClient from Credential (placeholder AppSecret) + f.LarkClient = cachedLarkClientFunc(f) + + return f } // safeRedirectPolicy prevents credential headers from being forwarded @@ -92,26 +108,50 @@ func cachedHttpClientFunc() func() (*http.Client, error) { func cachedLarkClientFunc(f *Factory) func() (*lark.Client, error) { return sync.OnceValues(func() (*lark.Client, error) { - cfg, err := f.Config() + acct, err := f.Credential.ResolveAccount(context.Background()) if err != nil { return nil, err } opts := []lark.ClientOptionFunc{ + lark.WithEnableTokenCache(false), lark.WithLogLevel(larkcore.LogLevelError), lark.WithHeaders(BaseSecurityHeaders()), } - // Build SDK transport chain util.WarnIfProxied(os.Stderr) - var sdkTransport http.RoundTripper = util.NewBaseTransport() - sdkTransport = &UserAgentTransport{Base: sdkTransport} - sdkTransport = &auth.SecurityPolicyTransport{Base: sdkTransport} opts = append(opts, lark.WithHttpClient(&http.Client{ - Transport: sdkTransport, + Transport: buildSDKTransport(), CheckRedirect: safeRedirectPolicy, })) - ep := core.ResolveEndpoints(cfg.Brand) + ep := core.ResolveEndpoints(acct.Brand) opts = append(opts, lark.WithOpenBaseUrl(ep.Open)) - client := lark.NewClient(cfg.AppID, cfg.AppSecret, opts...) - return client, nil + return lark.NewClient(acct.AppID, credential.RuntimeAppSecret(acct.AppSecret), opts...), nil }) } + +func buildSDKTransport() http.RoundTripper { + var sdkTransport http.RoundTripper = util.NewBaseTransport() + sdkTransport = &RetryTransport{Base: sdkTransport} + sdkTransport = &UserAgentTransport{Base: sdkTransport} + sdkTransport = &auth.SecurityPolicyTransport{Base: sdkTransport} + return sdkTransport +} + +type credentialDeps struct { + Keychain keychain.KeychainAccess + Profile string + HttpClient func() (*http.Client, error) + ErrOut io.Writer +} + +func buildCredentialProvider(deps credentialDeps) *credential.CredentialProvider { + providers := extcred.Providers() + defaultAcct := credential.NewDefaultAccountProvider(deps.Keychain, deps.Profile) + defaultToken := credential.NewDefaultTokenProvider(defaultAcct, deps.HttpClient, deps.ErrOut) + // NOTE: Do not pass deps.ErrOut as warnOut. Credential resolution + // happens before the command runs, so any plain-text warning written + // to stderr would break the JSON envelope contract that AI agents + // depend on. enrichUserInfo failures are already non-fatal (the + // provider clears unverified identity fields), so silencing the + // warning is safe. + return credential.NewCredentialProvider(providers, defaultAcct, defaultToken, deps.HttpClient) +} diff --git a/internal/cmdutil/factory_default_test.go b/internal/cmdutil/factory_default_test.go new file mode 100644 index 000000000..5f4d60014 --- /dev/null +++ b/internal/cmdutil/factory_default_test.go @@ -0,0 +1,196 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package cmdutil + +import ( + "context" + "errors" + "testing" + + _ "github.com/larksuite/cli/extension/credential/env" + internalauth "github.com/larksuite/cli/internal/auth" + "github.com/larksuite/cli/internal/core" + "github.com/larksuite/cli/internal/credential" + "github.com/larksuite/cli/internal/envvars" +) + +func TestNewDefault_InvocationProfileUsedByStrictModeAndConfig(t *testing.T) { + t.Setenv(envvars.CliAppID, "") + t.Setenv(envvars.CliAppSecret, "") + t.Setenv(envvars.CliUserAccessToken, "") + t.Setenv(envvars.CliTenantAccessToken, "") + + dir := t.TempDir() + t.Setenv("LARKSUITE_CLI_CONFIG_DIR", dir) + + bot := core.StrictModeBot + multi := &core.MultiAppConfig{ + CurrentApp: "default", + Apps: []core.AppConfig{ + { + Name: "default", + AppId: "app-default", + AppSecret: core.PlainSecret("secret-default"), + Brand: core.BrandFeishu, + }, + { + Name: "target", + AppId: "app-target", + AppSecret: core.PlainSecret("secret-target"), + Brand: core.BrandFeishu, + StrictMode: &bot, + }, + }, + } + if err := core.SaveMultiAppConfig(multi); err != nil { + t.Fatalf("SaveMultiAppConfig() error = %v", err) + } + + f := NewDefault(InvocationContext{Profile: "target"}) + if got := f.ResolveStrictMode(context.Background()); got != core.StrictModeBot { + t.Fatalf("ResolveStrictMode() = %q, want %q", got, core.StrictModeBot) + } + cfg, err := f.Config() + if err != nil { + t.Fatalf("Config() error = %v", err) + } + if cfg.ProfileName != "target" { + t.Fatalf("Config() profile = %q, want %q", cfg.ProfileName, "target") + } + if cfg.AppID != "app-target" { + t.Fatalf("Config() appID = %q, want %q", cfg.AppID, "app-target") + } +} + +func TestNewDefault_InvocationProfileMissingSticksAcrossEarlyStrictMode(t *testing.T) { + t.Setenv(envvars.CliAppID, "") + t.Setenv(envvars.CliAppSecret, "") + t.Setenv(envvars.CliUserAccessToken, "") + t.Setenv(envvars.CliTenantAccessToken, "") + + dir := t.TempDir() + t.Setenv("LARKSUITE_CLI_CONFIG_DIR", dir) + + multi := &core.MultiAppConfig{ + CurrentApp: "default", + Apps: []core.AppConfig{ + { + Name: "default", + AppId: "app-default", + AppSecret: core.PlainSecret("secret-default"), + Brand: core.BrandFeishu, + }, + }, + } + if err := core.SaveMultiAppConfig(multi); err != nil { + t.Fatalf("SaveMultiAppConfig() error = %v", err) + } + + f := NewDefault(InvocationContext{Profile: "missing"}) + if got := f.ResolveStrictMode(context.Background()); got != core.StrictModeOff { + t.Fatalf("ResolveStrictMode() = %q, want %q", got, core.StrictModeOff) + } + _, err := f.Config() + if err == nil { + t.Fatal("Config() error = nil, want non-nil") + } + var cfgErr *core.ConfigError + if !errors.As(err, &cfgErr) { + t.Fatalf("Config() error type = %T, want *core.ConfigError", err) + } + if cfgErr.Message != `profile "missing" not found` { + t.Fatalf("Config() error message = %q, want %q", cfgErr.Message, `profile "missing" not found`) + } +} + +func TestBuildSDKTransport_IncludesRetryTransport(t *testing.T) { + transport := buildSDKTransport() + + sec, ok := transport.(*internalauth.SecurityPolicyTransport) + if !ok { + t.Fatalf("outer transport type = %T, want *auth.SecurityPolicyTransport", transport) + } + ua, ok := sec.Base.(*UserAgentTransport) + if !ok { + t.Fatalf("middle transport type = %T, want *UserAgentTransport", sec.Base) + } + if _, ok := ua.Base.(*RetryTransport); !ok { + t.Fatalf("inner transport type = %T, want *RetryTransport", ua.Base) + } +} + +func TestNewDefault_ResolveAs_UsesDefaultAsFromEnvAccount(t *testing.T) { + t.Setenv(envvars.CliAppID, "env-app") + t.Setenv(envvars.CliAppSecret, "env-secret") + t.Setenv(envvars.CliDefaultAs, "user") + t.Setenv(envvars.CliUserAccessToken, "") + t.Setenv(envvars.CliTenantAccessToken, "") + t.Setenv("LARKSUITE_CLI_CONFIG_DIR", t.TempDir()) + + f := NewDefault(InvocationContext{}) + cmd := newCmdWithAsFlag("auto", false) + + got := f.ResolveAs(context.Background(), cmd, "auto") + if got != core.AsUser { + t.Fatalf("ResolveAs() = %q, want %q", got, core.AsUser) + } + if f.IdentityAutoDetected { + t.Fatal("IdentityAutoDetected = true, want false") + } +} + +func TestNewDefault_ConfigReturnsCliConfigCopyOfCredentialAccount(t *testing.T) { + t.Setenv(envvars.CliAppID, "env-app") + t.Setenv(envvars.CliAppSecret, "env-secret") + t.Setenv(envvars.CliDefaultAs, "") + t.Setenv(envvars.CliUserAccessToken, "uat-token") + t.Setenv(envvars.CliTenantAccessToken, "") + t.Setenv("LARKSUITE_CLI_CONFIG_DIR", t.TempDir()) + + f := NewDefault(InvocationContext{}) + + acct, err := f.Credential.ResolveAccount(context.Background()) + if err != nil { + t.Fatalf("ResolveAccount() error = %v", err) + } + cfg, err := f.Config() + if err != nil { + t.Fatalf("Config() error = %v", err) + } + + cfg.AppID = "mutated-cli-config" + if acct.AppID != "env-app" { + t.Fatalf("credential account mutated via Config(): got %q, want %q", acct.AppID, "env-app") + } +} + +func TestNewDefault_ConfigUsesRuntimePlaceholderForTokenOnlyEnvAccount(t *testing.T) { + t.Setenv(envvars.CliAppID, "env-app") + t.Setenv(envvars.CliAppSecret, "") + t.Setenv(envvars.CliDefaultAs, "") + t.Setenv(envvars.CliUserAccessToken, "uat-token") + t.Setenv(envvars.CliTenantAccessToken, "") + t.Setenv("LARKSUITE_CLI_CONFIG_DIR", t.TempDir()) + + f := NewDefault(InvocationContext{}) + + acct, err := f.Credential.ResolveAccount(context.Background()) + if err != nil { + t.Fatalf("ResolveAccount() error = %v", err) + } + if acct.AppSecret != "" { + t.Fatalf("credential account AppSecret = %q, want empty string", acct.AppSecret) + } + + cfg, err := f.Config() + if err != nil { + t.Fatalf("Config() error = %v", err) + } + if cfg.AppSecret != "" { + t.Fatalf("Config().AppSecret = %q, want empty string for token-only account", cfg.AppSecret) + } + if credential.HasRealAppSecret(cfg.AppSecret) { + t.Fatalf("Config().AppSecret = %q, want token-only no-secret marker", cfg.AppSecret) + } +} diff --git a/internal/cmdutil/factory_test.go b/internal/cmdutil/factory_test.go index 9912bbce7..a0eec24f8 100644 --- a/internal/cmdutil/factory_test.go +++ b/internal/cmdutil/factory_test.go @@ -4,13 +4,14 @@ package cmdutil import ( - "os" + "context" "strings" "testing" "github.com/spf13/cobra" "github.com/larksuite/cli/internal/core" + "github.com/larksuite/cli/internal/envvars" ) // newCmdWithAsFlag creates a cobra.Command with a --as string flag for testing. @@ -29,7 +30,7 @@ func TestResolveAs_ExplicitAs(t *testing.T) { f, _, _, _ := TestFactory(t, &core.CliConfig{AppID: "a", AppSecret: "s"}) cmd := newCmdWithAsFlag("bot", true) - got := f.ResolveAs(cmd, core.AsBot) + got := f.ResolveAs(context.Background(), cmd, core.AsBot) if got != core.AsBot { t.Errorf("want bot, got %s", got) } @@ -45,7 +46,7 @@ func TestResolveAs_ExplicitAsUser(t *testing.T) { f, _, _, _ := TestFactory(t, &core.CliConfig{AppID: "a", AppSecret: "s"}) cmd := newCmdWithAsFlag("user", true) - got := f.ResolveAs(cmd, core.AsUser) + got := f.ResolveAs(context.Background(), cmd, core.AsUser) if got != core.AsUser { t.Errorf("want user, got %s", got) } @@ -60,7 +61,7 @@ func TestResolveAs_ExplicitAuto_FallsToAutoDetect(t *testing.T) { f, _, _, _ := TestFactory(t, &core.CliConfig{AppID: "a", AppSecret: "s"}) cmd := newCmdWithAsFlag("auto", true) - got := f.ResolveAs(cmd, "auto") + got := f.ResolveAs(context.Background(), cmd, "auto") if got != core.AsBot { t.Errorf("want bot (auto-detect, no login), got %s", got) } @@ -76,7 +77,7 @@ func TestResolveAs_DefaultAs_FromConfig(t *testing.T) { }) cmd := newCmdWithAsFlag("auto", false) // --as not changed - got := f.ResolveAs(cmd, "auto") + got := f.ResolveAs(context.Background(), cmd, "auto") if got != core.AsBot { t.Errorf("want bot (from default-as config), got %s", got) } @@ -85,16 +86,18 @@ func TestResolveAs_DefaultAs_FromConfig(t *testing.T) { } } -func TestResolveAs_DefaultAs_FromEnv(t *testing.T) { - os.Setenv("LARKSUITE_CLI_DEFAULT_AS", "user") - defer os.Unsetenv("LARKSUITE_CLI_DEFAULT_AS") +func TestResolveAs_DefaultAs_EnvDoesNotBypassConfigSource(t *testing.T) { + t.Setenv(envvars.CliDefaultAs, "user") f, _, _, _ := TestFactory(t, &core.CliConfig{AppID: "a", AppSecret: "s"}) cmd := newCmdWithAsFlag("auto", false) - got := f.ResolveAs(cmd, "auto") - if got != core.AsUser { - t.Errorf("want user (from env), got %s", got) + got := f.ResolveAs(context.Background(), cmd, "auto") + if got != core.AsBot { + t.Errorf("want bot (env default-as should not bypass config source), got %s", got) + } + if !f.IdentityAutoDetected { + t.Error("IdentityAutoDetected should be true when no account default-as is set") } } @@ -106,7 +109,7 @@ func TestResolveAs_DefaultAs_AutoValue_FallsToAutoDetect(t *testing.T) { }) cmd := newCmdWithAsFlag("auto", false) - got := f.ResolveAs(cmd, "auto") + got := f.ResolveAs(context.Background(), cmd, "auto") // No UserOpenId → auto-detect returns bot if got != core.AsBot { t.Errorf("want bot (auto-detect), got %s", got) @@ -119,7 +122,7 @@ func TestResolveAs_DefaultAs_AutoValue_FallsToAutoDetect(t *testing.T) { func TestResolveAs_NilCmd_AutoDetect(t *testing.T) { f, _, _, _ := TestFactory(t, &core.CliConfig{AppID: "a", AppSecret: "s"}) - got := f.ResolveAs(nil, "auto") + got := f.ResolveAs(context.Background(), nil, "auto") if got != core.AsBot { t.Errorf("want bot, got %s", got) } @@ -183,56 +186,6 @@ func TestCheckIdentity_Unsupported_AutoDetected(t *testing.T) { } } -// --- ResolveConfig tests --- - -func TestResolveConfig_Bot(t *testing.T) { - cfg := &core.CliConfig{AppID: "a", AppSecret: "s"} - f, _, _, _ := TestFactory(t, cfg) - - got, err := f.ResolveConfig(core.AsBot) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if got.AppID != "a" { - t.Errorf("want AppID a, got %s", got.AppID) - } -} - -func TestResolveConfig_User(t *testing.T) { - cfg := &core.CliConfig{AppID: "a", AppSecret: "s"} - f, _, _, _ := TestFactory(t, cfg) - - got, err := f.ResolveConfig(core.AsUser) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if got.AppID != "a" { - t.Errorf("want AppID a, got %s", got.AppID) - } -} - -// --- autoDetectIdentity tests --- - -func TestAutoDetectIdentity_NoUserOpenId(t *testing.T) { - f, _, _, _ := TestFactory(t, &core.CliConfig{AppID: "a", AppSecret: "s"}) - got := f.autoDetectIdentity() - if got != core.AsBot { - t.Errorf("want bot (no UserOpenId), got %s", got) - } -} - -func TestAutoDetectIdentity_ConfigError(t *testing.T) { - f := &Factory{ - Config: func() (*core.CliConfig, error) { - return nil, os.ErrNotExist - }, - } - got := f.autoDetectIdentity() - if got != core.AsBot { - t.Errorf("want bot (config error), got %s", got) - } -} - // --- NewAPIClient / NewAPIClientWithConfig tests --- func TestNewAPIClient(t *testing.T) { @@ -280,3 +233,125 @@ func TestNewAPIClientWithConfig_NilIOStreams(t *testing.T) { t.Fatal("expected non-nil APIClient") } } + +// --- ResolveStrictMode tests --- + +func TestResolveStrictMode_Off(t *testing.T) { + f, _, _, _ := TestFactory(t, &core.CliConfig{AppID: "a", AppSecret: "s"}) + if got := f.ResolveStrictMode(context.Background()); got != core.StrictModeOff { + t.Errorf("expected off, got %q", got) + } +} + +func TestResolveStrictMode_BotFromAccount(t *testing.T) { + cfg := &core.CliConfig{AppID: "a", AppSecret: "s", SupportedIdentities: 2} // SupportsBot = 2 + f, _, _, _ := TestFactory(t, cfg) + if got := f.ResolveStrictMode(context.Background()); got != core.StrictModeBot { + t.Errorf("expected bot, got %q", got) + } +} + +func TestResolveStrictMode_UserFromAccount(t *testing.T) { + cfg := &core.CliConfig{AppID: "a", AppSecret: "s", SupportedIdentities: 1} // SupportsUser = 1 + f, _, _, _ := TestFactory(t, cfg) + if got := f.ResolveStrictMode(context.Background()); got != core.StrictModeUser { + t.Errorf("expected user, got %q", got) + } +} + +func TestResolveStrictMode_BothIdentities(t *testing.T) { + cfg := &core.CliConfig{AppID: "a", AppSecret: "s", SupportedIdentities: 3} // SupportsAll = 3 + f, _, _, _ := TestFactory(t, cfg) + if got := f.ResolveStrictMode(context.Background()); got != core.StrictModeOff { + t.Errorf("expected off when both supported, got %q", got) + } +} + +func TestResolveStrictMode_NilCredential(t *testing.T) { + f, _, _, _ := TestFactory(t, &core.CliConfig{AppID: "a", AppSecret: "s"}) + f.Credential = nil + if got := f.ResolveStrictMode(context.Background()); got != core.StrictModeOff { + t.Errorf("expected off with nil credential, got %q", got) + } +} + +// --- CheckStrictMode tests --- + +func TestCheckStrictMode_BotMode_BotAllowed(t *testing.T) { + cfg := &core.CliConfig{AppID: "a", AppSecret: "s", SupportedIdentities: 2} + f, _, _, _ := TestFactory(t, cfg) + if err := f.CheckStrictMode(context.Background(), core.AsBot); err != nil { + t.Errorf("bot should be allowed in bot mode, got: %v", err) + } +} + +func TestCheckStrictMode_BotMode_UserBlocked(t *testing.T) { + cfg := &core.CliConfig{AppID: "a", AppSecret: "s", SupportedIdentities: 2} + f, _, _, _ := TestFactory(t, cfg) + err := f.CheckStrictMode(context.Background(), core.AsUser) + if err == nil { + t.Fatal("expected error for user in bot mode") + } + if !strings.Contains(err.Error(), "strict mode") { + t.Errorf("error should mention strict mode, got: %v", err) + } +} + +func TestCheckStrictMode_UserMode_UserAllowed(t *testing.T) { + cfg := &core.CliConfig{AppID: "a", AppSecret: "s", SupportedIdentities: 1} + f, _, _, _ := TestFactory(t, cfg) + if err := f.CheckStrictMode(context.Background(), core.AsUser); err != nil { + t.Errorf("user should be allowed in user mode, got: %v", err) + } +} + +func TestCheckStrictMode_UserMode_BotBlocked(t *testing.T) { + cfg := &core.CliConfig{AppID: "a", AppSecret: "s", SupportedIdentities: 1} + f, _, _, _ := TestFactory(t, cfg) + err := f.CheckStrictMode(context.Background(), core.AsBot) + if err == nil { + t.Fatal("expected error for bot in user mode") + } +} + +func TestCheckStrictMode_Off_BothAllowed(t *testing.T) { + f, _, _, _ := TestFactory(t, &core.CliConfig{AppID: "a", AppSecret: "s"}) + if err := f.CheckStrictMode(context.Background(), core.AsUser); err != nil { + t.Errorf("user should be allowed when off: %v", err) + } + if err := f.CheckStrictMode(context.Background(), core.AsBot); err != nil { + t.Errorf("bot should be allowed when off: %v", err) + } +} + +// --- ResolveAs strict mode tests --- + +func TestResolveAs_StrictModeBot_ForceBot(t *testing.T) { + cfg := &core.CliConfig{AppID: "a", AppSecret: "s", SupportedIdentities: 2} + f, _, _, _ := TestFactory(t, cfg) + cmd := newCmdWithAsFlag("auto", false) + got := f.ResolveAs(context.Background(), cmd, "auto") + if got != core.AsBot { + t.Errorf("bot mode should force bot, got %s", got) + } +} + +func TestResolveAs_StrictModeUser_ForceUser(t *testing.T) { + cfg := &core.CliConfig{AppID: "a", AppSecret: "s", SupportedIdentities: 1} + f, _, _, _ := TestFactory(t, cfg) + cmd := newCmdWithAsFlag("auto", false) + got := f.ResolveAs(context.Background(), cmd, "auto") + if got != core.AsUser { + t.Errorf("user mode should force user, got %s", got) + } +} + +func TestResolveAs_StrictModeBot_IgnoresDefaultAsUser(t *testing.T) { + cfg := &core.CliConfig{AppID: "a", AppSecret: "s", DefaultAs: "user", SupportedIdentities: 2} + f, _, _, _ := TestFactory(t, cfg) + cmd := newCmdWithAsFlag("auto", false) + got := f.ResolveAs(context.Background(), cmd, "auto") + if got != core.AsBot { + t.Errorf("bot mode should override default-as user, got %s", got) + } +} diff --git a/internal/cmdutil/secheader.go b/internal/cmdutil/secheader.go index 15745264a..411232c8c 100644 --- a/internal/cmdutil/secheader.go +++ b/internal/cmdutil/secheader.go @@ -68,6 +68,16 @@ func ExecutionIdFromContext(ctx context.Context) (string, bool) { // RequestOptionFunc that injects the corresponding headers into SDK requests. // Returns nil if the context has no Shortcut info. func ShortcutHeaderOpts(ctx context.Context) larkcore.RequestOptionFunc { + h := ShortcutHeaders(ctx) + if h == nil { + return nil + } + return larkcore.WithHeaders(h) +} + +// ShortcutHeaders extracts Shortcut info from the context and returns +// the corresponding HTTP headers. Returns nil if the context has no Shortcut info. +func ShortcutHeaders(ctx context.Context) http.Header { name, ok := ShortcutNameFromContext(ctx) if !ok { return nil @@ -77,5 +87,5 @@ func ShortcutHeaderOpts(ctx context.Context) larkcore.RequestOptionFunc { if eid, ok := ExecutionIdFromContext(ctx); ok { h.Set(HeaderExecutionId, eid) } - return larkcore.WithHeaders(h) + return h } diff --git a/internal/cmdutil/testing.go b/internal/cmdutil/testing.go index e32b17f05..7a70ed2b0 100644 --- a/internal/cmdutil/testing.go +++ b/internal/cmdutil/testing.go @@ -5,6 +5,7 @@ package cmdutil import ( "bytes" + "context" "net/http" "testing" @@ -12,6 +13,7 @@ import ( larkcore "github.com/larksuite/oapi-sdk-go/v3/core" "github.com/larksuite/cli/internal/core" + "github.com/larksuite/cli/internal/credential" "github.com/larksuite/cli/internal/httpmock" ) @@ -34,16 +36,14 @@ func TestFactory(t *testing.T, config *core.CliConfig) (*Factory, *bytes.Buffer, stderrBuf := &bytes.Buffer{} mockClient := httpmock.NewClient(reg) - // SDK mock client wraps the mock transport with UserAgentTransport - // so that User-Agent overrides the SDK default (oapi-sdk-go/v3.x.x). sdkMockClient := &http.Client{ Transport: &UserAgentTransport{Base: reg}, } - // Build a test LarkClient using the config var testLarkClient *lark.Client if config != nil && config.AppID != "" { opts := []lark.ClientOptionFunc{ + lark.WithEnableTokenCache(false), lark.WithLogLevel(larkcore.LogLevelError), lark.WithHttpClient(sdkMockClient), lark.WithHeaders(BaseSecurityHeaders()), @@ -51,16 +51,40 @@ func TestFactory(t *testing.T, config *core.CliConfig) (*Factory, *bytes.Buffer, if config.Brand != "" { opts = append(opts, lark.WithOpenBaseUrl(core.ResolveOpenBaseURL(config.Brand))) } - testLarkClient = lark.NewClient(config.AppID, config.AppSecret, opts...) + testLarkClient = lark.NewClient(config.AppID, credential.RuntimeAppSecret(config.AppSecret), opts...) } + testCred := credential.NewCredentialProvider( + nil, + &testDefaultAcct{config: config}, + &testDefaultToken{}, + func() (*http.Client, error) { return mockClient, nil }, + ) + f := &Factory{ Config: func() (*core.CliConfig, error) { return config, nil }, - AuthConfig: func() (*core.CliConfig, error) { return config, nil }, HttpClient: func() (*http.Client, error) { return mockClient, nil }, LarkClient: func() (*lark.Client, error) { return testLarkClient, nil }, IOStreams: &IOStreams{In: nil, Out: stdoutBuf, ErrOut: stderrBuf}, Keychain: &noopKeychain{}, + Credential: testCred, } return f, stdoutBuf, stderrBuf, reg } + +type testDefaultAcct struct { + config *core.CliConfig +} + +func (a *testDefaultAcct) ResolveAccount(ctx context.Context) (*credential.Account, error) { + if a.config == nil { + return &credential.Account{}, nil + } + return credential.AccountFromCliConfig(a.config), nil +} + +type testDefaultToken struct{} + +func (t *testDefaultToken) ResolveToken(ctx context.Context, req credential.TokenSpec) (*credential.TokenResult, error) { + return &credential.TokenResult{Token: "test-token"}, nil +} diff --git a/internal/core/config.go b/internal/core/config.go index 18a2aa4e0..410c3e01f 100644 --- a/internal/core/config.go +++ b/internal/core/config.go @@ -9,10 +9,13 @@ import ( "fmt" "os" "path/filepath" + "strings" + "unicode/utf8" "github.com/larksuite/cli/internal/keychain" "github.com/larksuite/cli/internal/output" "github.com/larksuite/cli/internal/validate" + "github.com/larksuite/cli/internal/vfs" ) // Identity represents the caller identity for API requests. @@ -21,6 +24,7 @@ type Identity string const ( AsUser Identity = "user" AsBot Identity = "bot" + AsAuto Identity = "auto" ) // IsBot returns true if the identity is bot. @@ -34,27 +38,129 @@ type AppUser struct { // AppConfig is a per-app configuration entry (stored format — secrets may be unresolved). type AppConfig struct { - AppId string `json:"appId"` - AppSecret SecretInput `json:"appSecret"` - Brand LarkBrand `json:"brand"` - Lang string `json:"lang,omitempty"` - DefaultAs string `json:"defaultAs,omitempty"` // "user" | "bot" | "auto" - Users []AppUser `json:"users"` + Name string `json:"name,omitempty"` + AppId string `json:"appId"` + AppSecret SecretInput `json:"appSecret"` + Brand LarkBrand `json:"brand"` + Lang string `json:"lang,omitempty"` + DefaultAs Identity `json:"defaultAs,omitempty"` // AsUser | AsBot | AsAuto + StrictMode *StrictMode `json:"strictMode,omitempty"` + Users []AppUser `json:"users"` +} + +// ProfileName returns the display name for this app config. +// If Name is set, returns Name; otherwise falls back to AppId. +func (a *AppConfig) ProfileName() string { + if a.Name != "" { + return a.Name + } + return a.AppId } // MultiAppConfig is the multi-app config file format. type MultiAppConfig struct { - Apps []AppConfig `json:"apps"` + StrictMode StrictMode `json:"strictMode,omitempty"` + CurrentApp string `json:"currentApp,omitempty"` + PreviousApp string `json:"previousApp,omitempty"` + Apps []AppConfig `json:"apps"` +} + +// CurrentAppConfig returns the currently active app config. +// Resolution priority: profileOverride > CurrentApp field > Apps[0]. +func (m *MultiAppConfig) CurrentAppConfig(profileOverride string) *AppConfig { + if profileOverride != "" { + if app := m.FindApp(profileOverride); app != nil { + return app + } + return nil + } + if m.CurrentApp != "" { + if app := m.FindApp(m.CurrentApp); app != nil { + return app + } + return nil // explicit currentApp not found; don't silently fallback + } + if len(m.Apps) > 0 { + return &m.Apps[0] + } + return nil +} + +// FindApp looks up an app by name, then by appId. Returns nil if not found. +// Name match takes priority: if profile A has Name "X" and profile B has AppId "X", +// FindApp("X") returns profile A. +func (m *MultiAppConfig) FindApp(name string) *AppConfig { + // First pass: match by Name + for i := range m.Apps { + if m.Apps[i].Name != "" && m.Apps[i].Name == name { + return &m.Apps[i] + } + } + // Second pass: match by AppId + for i := range m.Apps { + if m.Apps[i].AppId == name { + return &m.Apps[i] + } + } + return nil +} + +// FindAppIndex looks up an app index by name, then by appId. Returns -1 if not found. +func (m *MultiAppConfig) FindAppIndex(name string) int { + for i := range m.Apps { + if m.Apps[i].Name != "" && m.Apps[i].Name == name { + return i + } + } + for i := range m.Apps { + if m.Apps[i].AppId == name { + return i + } + } + return -1 +} + +// ProfileNames returns all profile names (Name if set, otherwise AppId). +func (m *MultiAppConfig) ProfileNames() []string { + names := make([]string, len(m.Apps)) + for i := range m.Apps { + names[i] = m.Apps[i].ProfileName() + } + return names +} + +// ValidateProfileName checks that a profile name is valid. +// Rejects empty names, whitespace, control characters, and shell-problematic characters, +// but allows Unicode letters (e.g. Chinese, Japanese) for localized profile names. +func ValidateProfileName(name string) error { + if name == "" { + return fmt.Errorf("profile name cannot be empty") + } + if utf8.RuneCountInString(name) > 64 { + return fmt.Errorf("profile name %q is too long (max 64 characters)", name) + } + for _, r := range name { + if r <= 0x1F || r == 0x7F { // control characters + return fmt.Errorf("invalid profile name %q: contains control characters", name) + } + switch r { + case ' ', '\t', '/', '\\', '"', '\'', '`', '$', '#', '!', '&', '|', ';', '(', ')', '{', '}', '[', ']', '<', '>', '?', '*', '~': + return fmt.Errorf("invalid profile name %q: contains invalid character %q", name, r) + } + } + return nil } // CliConfig is the resolved single-app config used by downstream code. type CliConfig struct { - AppID string - AppSecret string - Brand LarkBrand - DefaultAs string // "user" | "bot" | "auto" | "" (from config file) - UserOpenId string - UserName string + ProfileName string + AppID string + AppSecret string + Brand LarkBrand + DefaultAs Identity // AsUser | AsBot | AsAuto | "" (from config file) + UserOpenId string + UserName string + SupportedIdentities uint8 `json:"-"` // bitflag: 1=user, 2=bot; set by credential provider } // GetConfigDir returns the config directory path. @@ -64,7 +170,7 @@ func GetConfigDir() string { if dir := os.Getenv("LARKSUITE_CLI_CONFIG_DIR"); dir != "" { return dir } - home, err := os.UserHomeDir() + home, err := vfs.UserHomeDir() if err != nil || home == "" { fmt.Fprintf(os.Stderr, "warning: unable to determine home directory: %v\n", err) } @@ -78,7 +184,7 @@ func GetConfigPath() string { // LoadMultiAppConfig loads multi-app config from disk. func LoadMultiAppConfig() (*MultiAppConfig, error) { - data, err := os.ReadFile(GetConfigPath()) + data, err := vfs.ReadFile(GetConfigPath()) if err != nil { return nil, err } @@ -96,7 +202,7 @@ func LoadMultiAppConfig() (*MultiAppConfig, error) { // SaveMultiAppConfig saves config to disk. func SaveMultiAppConfig(config *MultiAppConfig) error { dir := GetConfigDir() - if err := os.MkdirAll(dir, 0700); err != nil { + if err := vfs.MkdirAll(dir, 0700); err != nil { return err } data, err := json.MarshalIndent(config, "", " ") @@ -106,13 +212,34 @@ func SaveMultiAppConfig(config *MultiAppConfig) error { return validate.AtomicWrite(GetConfigPath(), append(data, '\n'), 0600) } -// RequireConfig loads the single-app config. Takes Apps[0] directly. +// RequireConfig loads the single-app config using the default profile resolution. func RequireConfig(kc keychain.KeychainAccess) (*CliConfig, error) { + return RequireConfigForProfile(kc, "") +} + +// RequireConfigForProfile loads the single-app config for a specific profile. +// Resolution priority: profileOverride > config.CurrentApp > Apps[0]. +func RequireConfigForProfile(kc keychain.KeychainAccess, profileOverride string) (*CliConfig, error) { raw, err := LoadMultiAppConfig() if err != nil || raw == nil || len(raw.Apps) == 0 { return nil, &ConfigError{Code: 2, Type: "config", Message: "not configured", Hint: "run `lark-cli config init --new` in the background. It blocks and outputs a verification URL — retrieve the URL and open it in a browser to complete setup."} } - app := raw.Apps[0] + return ResolveConfigFromMulti(raw, kc, profileOverride) +} + +// ResolveConfigFromMulti resolves a single-app config from an already-loaded MultiAppConfig. +// This avoids re-reading the config file when the caller has already loaded it. +func ResolveConfigFromMulti(raw *MultiAppConfig, kc keychain.KeychainAccess, profileOverride string) (*CliConfig, error) { + app := raw.CurrentAppConfig(profileOverride) + if app == nil { + return nil, &ConfigError{ + Code: 2, + Type: "config", + Message: fmt.Sprintf("profile %q not found", profileOverride), + Hint: fmt.Sprintf("available profiles: %s", formatProfileNames(raw.ProfileNames())), + } + } + secret, err := ResolveSecretInput(app.AppSecret, kc) if err != nil { // If the error comes from the keychain, it will already be wrapped as an ExitError. @@ -124,10 +251,11 @@ func RequireConfig(kc keychain.KeychainAccess) (*CliConfig, error) { return nil, &ConfigError{Code: 2, Type: "config", Message: err.Error()} } cfg := &CliConfig{ - AppID: app.AppId, - AppSecret: secret, - Brand: app.Brand, - DefaultAs: app.DefaultAs, + ProfileName: app.ProfileName(), + AppID: app.AppId, + AppSecret: secret, + Brand: app.Brand, + DefaultAs: app.DefaultAs, } if len(app.Users) > 0 { cfg.UserOpenId = app.Users[0].UserOpenId @@ -138,7 +266,12 @@ func RequireConfig(kc keychain.KeychainAccess) (*CliConfig, error) { // RequireAuth loads config and ensures a user is logged in. func RequireAuth(kc keychain.KeychainAccess) (*CliConfig, error) { - cfg, err := RequireConfig(kc) + return RequireAuthForProfile(kc, "") +} + +// RequireAuthForProfile loads config for a profile and ensures a user is logged in. +func RequireAuthForProfile(kc keychain.KeychainAccess, profileOverride string) (*CliConfig, error) { + cfg, err := RequireConfigForProfile(kc, profileOverride) if err != nil { return nil, err } @@ -147,3 +280,11 @@ func RequireAuth(kc keychain.KeychainAccess) (*CliConfig, error) { } return cfg, nil } + +// formatProfileNames joins profile names for display. +func formatProfileNames(names []string) string { + if len(names) == 0 { + return "(none)" + } + return strings.Join(names, ", ") +} diff --git a/internal/core/config_strict_mode_test.go b/internal/core/config_strict_mode_test.go new file mode 100644 index 000000000..95c79a8d8 --- /dev/null +++ b/internal/core/config_strict_mode_test.go @@ -0,0 +1,58 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package core + +import ( + "encoding/json" + "testing" +) + +func TestMultiAppConfig_StrictMode_JSON(t *testing.T) { + // StrictMode="" should be omitted (omitempty) + m := &MultiAppConfig{ + Apps: []AppConfig{{AppId: "a", AppSecret: PlainSecret("s"), Brand: BrandFeishu, Users: []AppUser{}}}, + } + data, _ := json.Marshal(m) + if string(data) != `{"apps":[{"appId":"a","appSecret":"s","brand":"feishu","users":[]}]}` { + t.Errorf("StrictMode empty should be omitted, got: %s", data) + } + + // StrictMode="bot" should be present + m.StrictMode = StrictModeBot + data, _ = json.Marshal(m) + var parsed map[string]interface{} + json.Unmarshal(data, &parsed) + if parsed["strictMode"] != "bot" { + t.Errorf("StrictMode=bot should be present, got: %s", data) + } +} + +func TestAppConfig_StrictMode_JSON(t *testing.T) { + // StrictMode nil should be omitted + app := &AppConfig{AppId: "a", AppSecret: PlainSecret("s"), Brand: BrandFeishu, Users: []AppUser{}} + data, _ := json.Marshal(app) + var parsed map[string]interface{} + json.Unmarshal(data, &parsed) + if _, ok := parsed["strictMode"]; ok { + t.Errorf("nil StrictMode should be omitted, got: %s", data) + } + + // StrictMode = pointer to "user" + v := StrictModeUser + app.StrictMode = &v + data, _ = json.Marshal(app) + json.Unmarshal(data, &parsed) + if parsed["strictMode"] != "user" { + t.Errorf("StrictMode=user should be present, got: %s", data) + } + + // StrictMode = pointer to "off" (explicit off — should be present, not omitted) + voff := StrictModeOff + app.StrictMode = &voff + data, _ = json.Marshal(app) + json.Unmarshal(data, &parsed) + if val, ok := parsed["strictMode"]; !ok || val != "off" { + t.Errorf("StrictMode=off (explicit) should be present, got: %s", data) + } +} diff --git a/internal/core/config_test.go b/internal/core/config_test.go index 1c9ac449a..0ffb8ee19 100644 --- a/internal/core/config_test.go +++ b/internal/core/config_test.go @@ -72,3 +72,27 @@ func TestMultiAppConfig_RoundTrip(t *testing.T) { t.Errorf("Brand = %q, want %q", got.Apps[0].Brand, BrandLark) } } + +func TestResolveConfigFromMulti_DoesNotUseEnvProfileFallback(t *testing.T) { + t.Setenv("LARKSUITE_CLI_PROFILE", "missing") + + raw := &MultiAppConfig{ + CurrentApp: "active", + Apps: []AppConfig{ + { + Name: "active", + AppId: "cli_active", + AppSecret: PlainSecret("secret"), + Brand: BrandFeishu, + }, + }, + } + + cfg, err := ResolveConfigFromMulti(raw, nil, "") + if err != nil { + t.Fatalf("ResolveConfigFromMulti() error = %v", err) + } + if cfg.ProfileName != "active" { + t.Fatalf("ResolveConfigFromMulti() profile = %q, want %q", cfg.ProfileName, "active") + } +} diff --git a/internal/core/secret_resolve.go b/internal/core/secret_resolve.go index 6e7921d3f..f5f4ebd90 100644 --- a/internal/core/secret_resolve.go +++ b/internal/core/secret_resolve.go @@ -5,10 +5,10 @@ package core import ( "fmt" - "os" "strings" "github.com/larksuite/cli/internal/keychain" + "github.com/larksuite/cli/internal/vfs" ) const secretKeyPrefix = "appsecret:" @@ -25,7 +25,7 @@ func ResolveSecretInput(s SecretInput, kc keychain.KeychainAccess) (string, erro } switch s.Ref.Source { case "file": - data, err := os.ReadFile(s.Ref.ID) + data, err := vfs.ReadFile(s.Ref.ID) if err != nil { return "", fmt.Errorf("failed to read secret file %s: %w", s.Ref.ID, err) } diff --git a/internal/core/strict_mode.go b/internal/core/strict_mode.go new file mode 100644 index 000000000..c9cfb4298 --- /dev/null +++ b/internal/core/strict_mode.go @@ -0,0 +1,42 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package core + +// StrictMode represents the identity restriction policy. +type StrictMode string + +const ( + StrictModeOff StrictMode = "off" + StrictModeBot StrictMode = "bot" + StrictModeUser StrictMode = "user" +) + +// IsActive returns true if strict mode restricts identity. +func (m StrictMode) IsActive() bool { + return m == StrictModeBot || m == StrictModeUser +} + +// AllowsIdentity reports whether the given identity is permitted under this mode. +func (m StrictMode) AllowsIdentity(id Identity) bool { + switch m { + case StrictModeBot: + return id.IsBot() + case StrictModeUser: + return id == AsUser + default: + return true + } +} + +// ForcedIdentity returns the identity forced by this mode, or "" if not active. +func (m StrictMode) ForcedIdentity() Identity { + switch m { + case StrictModeBot: + return AsBot + case StrictModeUser: + return AsUser + default: + return "" + } +} diff --git a/internal/core/strict_mode_test.go b/internal/core/strict_mode_test.go new file mode 100644 index 000000000..5d67b1546 --- /dev/null +++ b/internal/core/strict_mode_test.go @@ -0,0 +1,62 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package core + +import "testing" + +func TestStrictMode_IsActive(t *testing.T) { + tests := []struct { + mode StrictMode + active bool + }{ + {StrictModeOff, false}, + {"", false}, + {StrictModeBot, true}, + {StrictModeUser, true}, + } + for _, tt := range tests { + if got := tt.mode.IsActive(); got != tt.active { + t.Errorf("StrictMode(%q).IsActive() = %v, want %v", tt.mode, got, tt.active) + } + } +} + +func TestStrictMode_AllowsIdentity(t *testing.T) { + tests := []struct { + mode StrictMode + id Identity + ok bool + }{ + {StrictModeOff, AsUser, true}, + {StrictModeOff, AsBot, true}, + {StrictModeBot, AsBot, true}, + {StrictModeBot, AsUser, false}, + {StrictModeUser, AsUser, true}, + {StrictModeUser, AsBot, false}, + {"", AsUser, true}, + {"", AsBot, true}, + } + for _, tt := range tests { + if got := tt.mode.AllowsIdentity(tt.id); got != tt.ok { + t.Errorf("StrictMode(%q).AllowsIdentity(%q) = %v, want %v", tt.mode, tt.id, got, tt.ok) + } + } +} + +func TestStrictMode_ForcedIdentity(t *testing.T) { + tests := []struct { + mode StrictMode + want Identity + }{ + {StrictModeOff, ""}, + {StrictModeBot, AsBot}, + {StrictModeUser, AsUser}, + {"", ""}, + } + for _, tt := range tests { + if got := tt.mode.ForcedIdentity(); got != tt.want { + t.Errorf("StrictMode(%q).ForcedIdentity() = %q, want %q", tt.mode, got, tt.want) + } + } +} diff --git a/internal/core/types.go b/internal/core/types.go index 4c21c2591..bae8613ae 100644 --- a/internal/core/types.go +++ b/internal/core/types.go @@ -13,6 +13,15 @@ const ( BrandLark LarkBrand = "lark" ) +// ParseBrand normalizes a brand string to a LarkBrand constant. +// Unrecognized values default to BrandFeishu. +func ParseBrand(value string) LarkBrand { + if value == "lark" { + return BrandLark + } + return BrandFeishu +} + // Endpoints holds resolved endpoint URLs for different Lark services. type Endpoints struct { Open string // e.g. "https://open.feishu.cn" diff --git a/internal/credential/credential_provider.go b/internal/credential/credential_provider.go new file mode 100644 index 000000000..5d28e2314 --- /dev/null +++ b/internal/credential/credential_provider.go @@ -0,0 +1,344 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package credential + +import ( + "context" + "errors" + "fmt" + "io" + "net/http" + "sync" + + extcred "github.com/larksuite/cli/extension/credential" + "github.com/larksuite/cli/internal/auth" + "github.com/larksuite/cli/internal/core" +) + +// DefaultAccountResolver is implemented by the default account provider. +type DefaultAccountResolver interface { + ResolveAccount(ctx context.Context) (*Account, error) +} + +// DefaultTokenResolver is implemented by the default token provider. +type DefaultTokenResolver interface { + ResolveToken(ctx context.Context, req TokenSpec) (*TokenResult, error) +} + +var ( + getStoredToken = auth.GetStoredToken + getStoredTokenStatus = auth.TokenStatus +) + +type credentialSource interface { + Name() string + TryResolveToken(ctx context.Context, req TokenSpec) (*TokenResult, bool, error) + ResolveIdentityHint(ctx context.Context, acct *Account) (*IdentityHint, error) +} + +type extensionTokenSource struct { + provider extcred.Provider +} + +func (s extensionTokenSource) Name() string { return s.provider.Name() } + +func (s extensionTokenSource) TryResolveToken(ctx context.Context, req TokenSpec) (*TokenResult, bool, error) { + tok, err := s.provider.ResolveToken(ctx, extcred.TokenSpec{ + Type: extcred.TokenType(req.Type.String()), + AppID: req.AppID, + }) + if err != nil { + return nil, false, err + } + if tok == nil { + return nil, false, nil + } + if tok.Value == "" { + return nil, false, &MalformedTokenResultError{Source: s.Name(), Type: req.Type, Reason: "empty token"} + } + return &TokenResult{Token: tok.Value, Scopes: tok.Scopes}, true, nil +} + +func (s extensionTokenSource) ResolveIdentityHint(ctx context.Context, acct *Account) (*IdentityHint, error) { + hint := &IdentityHint{} + if acct == nil { + return hint, nil + } + hint.DefaultAs = acct.DefaultAs + // Extension sources verify user identity via enrichUserInfo, so a resolved + // UserOpenId is sufficient here; no keychain-backed token status lookup is needed. + if acct.UserOpenId != "" { + hint.AutoAs = core.AsUser + return hint, nil + } + ids := extcred.IdentitySupport(acct.SupportedIdentities) + switch { + case ids.UserOnly(): + hint.AutoAs = core.AsUser + case ids.BotOnly(): + hint.AutoAs = core.AsBot + } + return hint, nil +} + +type defaultTokenSource struct { + resolver DefaultTokenResolver +} + +func (s defaultTokenSource) Name() string { return "default" } + +func (s defaultTokenSource) TryResolveToken(ctx context.Context, req TokenSpec) (*TokenResult, bool, error) { + if s.resolver == nil { + return nil, false, nil + } + result, err := s.resolver.ResolveToken(ctx, req) + if err != nil { + return nil, false, err + } + if result == nil { + return nil, false, &MalformedTokenResultError{Source: s.Name(), Type: req.Type, Reason: "nil token result"} + } + if result.Token == "" { + return nil, false, &MalformedTokenResultError{Source: s.Name(), Type: req.Type, Reason: "empty token"} + } + return result, true, nil +} + +func (s defaultTokenSource) ResolveIdentityHint(ctx context.Context, acct *Account) (*IdentityHint, error) { + hint := &IdentityHint{} + if acct == nil { + return hint, nil + } + hint.DefaultAs = acct.DefaultAs + if acct.UserOpenId == "" { + hint.AutoAs = core.AsBot + return hint, nil + } + stored := getStoredToken(acct.AppID, acct.UserOpenId) + if stored == nil { + hint.AutoAs = core.AsBot + return hint, nil + } + if getStoredTokenStatus(stored) == "expired" { + hint.AutoAs = core.AsBot + return hint, nil + } + hint.AutoAs = core.AsUser + return hint, nil +} + +// CredentialProvider is the unified entry point for all credential resolution. +type CredentialProvider struct { + providers []extcred.Provider + defaultAcct DefaultAccountResolver + defaultToken DefaultTokenResolver + httpClient func() (*http.Client, error) + warnOut io.Writer + + accountOnce sync.Once + account *Account + accountErr error + selectedSource credentialSource + + hintOnce sync.Once + hint *IdentityHint + hintErr error +} + +// NewCredentialProvider creates a CredentialProvider. +func NewCredentialProvider(providers []extcred.Provider, defaultAcct DefaultAccountResolver, defaultToken DefaultTokenResolver, httpClient func() (*http.Client, error)) *CredentialProvider { + return &CredentialProvider{ + providers: providers, + defaultAcct: defaultAcct, + defaultToken: defaultToken, + httpClient: httpClient, + } +} + +func (p *CredentialProvider) SetWarnOut(warnOut io.Writer) *CredentialProvider { + p.warnOut = warnOut + return p +} + +// ResolveAccount resolves app credentials. Result is cached after first call. +// NOTE: Uses sync.Once — only the context from the first call is used for resolution. +// Subsequent calls return the cached result regardless of their context. +// This is acceptable for CLI (single invocation per process) but not for long-running servers. +func (p *CredentialProvider) ResolveAccount(ctx context.Context) (*Account, error) { + p.accountOnce.Do(func() { + p.account, p.accountErr = p.doResolveAccount(ctx) + }) + return p.account, p.accountErr +} + +func (p *CredentialProvider) doResolveAccount(ctx context.Context) (*Account, error) { + for _, prov := range p.providers { + acct, err := prov.ResolveAccount(ctx) + if err != nil { + return nil, err + } + if acct != nil { + internal := convertAccount(acct) + source := extensionTokenSource{provider: prov} + if err := p.enrichUserInfo(ctx, internal, source); err != nil { + if p.warnOut != nil { + _, _ = fmt.Fprintf(p.warnOut, "warning: unable to verify user identity from credential source %q: %v\n", source.Name(), err) + } + // enrichUserInfo failure is non-fatal: SupportedIdentities + // (used for strict mode) is already set by the provider. + // Clear unverified user identity for safety. + internal.UserOpenId = "" + internal.UserName = "" + } + p.selectedSource = source + return internal, nil + } + } + if p.defaultAcct != nil { + acct, err := p.defaultAcct.ResolveAccount(ctx) + if err != nil { + return nil, err + } + p.selectedSource = defaultTokenSource{resolver: p.defaultToken} + return acct, nil + } + return nil, fmt.Errorf("no credential provider returned an account; run 'lark-cli config' to set up") +} + +// enrichUserInfo resolves user identity when extension provides a UAT. +// If UAT is available, user_info API call is mandatory (security: verify token validity). +// If no UAT from extension, falls back to provider-supplied OpenID. +func (p *CredentialProvider) enrichUserInfo(ctx context.Context, acct *Account, source credentialSource) error { + if p.httpClient == nil || source == nil { + return nil + } + tok, found, err := source.TryResolveToken(ctx, TokenSpec{Type: TokenTypeUAT, AppID: acct.AppID}) + if err != nil { + var blockErr *extcred.BlockError + if errors.As(err, &blockErr) { + return nil // provider explicitly blocks UAT; skip enrichment + } + return fmt.Errorf("failed to resolve UAT for user identity verification: %w", err) + } + if !found { + return nil + } + // Have UAT — must verify and resolve identity + hc, err := p.httpClient() + if err != nil { + return fmt.Errorf("failed to get HTTP client for user_info: %w", err) + } + info, err := fetchUserInfo(ctx, hc, acct.Brand, tok.Token) + if err != nil { + return fmt.Errorf("failed to verify user identity: %w", err) + } + acct.UserOpenId = info.OpenID + acct.UserName = info.Name + return nil +} + +func (p *CredentialProvider) selectedCredentialSource(ctx context.Context) (credentialSource, error) { + if p.selectedSource != nil { + return p.selectedSource, nil + } + if p.defaultAcct == nil { + return nil, nil + } + if _, err := p.ResolveAccount(ctx); err != nil { + return nil, err + } + if p.selectedSource == nil { + return nil, fmt.Errorf("credential provider resolved an account without selecting a token source") + } + return p.selectedSource, nil +} + +func resolveTokenFromSource(ctx context.Context, source credentialSource, req TokenSpec) (*TokenResult, error) { + result, found, err := source.TryResolveToken(ctx, req) + if err != nil { + return nil, err + } + if !found { + return nil, &TokenUnavailableError{Source: source.Name(), Type: req.Type} + } + return result, nil +} + +// ResolveIdentityHint resolves default/auto identity guidance from the selected source. +// NOTE: Uses sync.Once — only the context from the first call is used for resolution. +// This matches ResolveAccount and keeps identity decisions stable within one CLI invocation. +func (p *CredentialProvider) ResolveIdentityHint(ctx context.Context) (*IdentityHint, error) { + p.hintOnce.Do(func() { + p.hint, p.hintErr = p.doResolveIdentityHint(ctx) + }) + return p.hint, p.hintErr +} + +func (p *CredentialProvider) doResolveIdentityHint(ctx context.Context) (*IdentityHint, error) { + acct, err := p.ResolveAccount(ctx) + if err != nil { + return nil, err + } + if acct == nil { + return &IdentityHint{}, nil + } + source, err := p.selectedCredentialSource(ctx) + if err != nil { + return nil, err + } + if source == nil { + return &IdentityHint{}, nil + } + hint, err := source.ResolveIdentityHint(ctx, acct) + if err != nil { + return nil, err + } + if hint == nil { + return &IdentityHint{}, nil + } + return hint, nil +} + +// ResolveToken resolves an access token. +func (p *CredentialProvider) ResolveToken(ctx context.Context, req TokenSpec) (*TokenResult, error) { + source, err := p.selectedCredentialSource(ctx) + if err != nil { + return nil, err + } + if source != nil { + return resolveTokenFromSource(ctx, source, req) + } + + for _, prov := range p.providers { + source := extensionTokenSource{provider: prov} + result, found, err := source.TryResolveToken(ctx, req) + if err != nil { + return nil, err + } + if found { + return result, nil + } + } + source = defaultTokenSource{resolver: p.defaultToken} + result, found, err := source.TryResolveToken(ctx, req) + if err != nil { + return nil, err + } + if found { + return result, nil + } + return nil, &TokenUnavailableError{Type: req.Type} +} + +func convertAccount(ext *extcred.Account) *Account { + return &Account{ + AppID: ext.AppID, + AppSecret: ext.AppSecret, + Brand: core.LarkBrand(ext.Brand), + DefaultAs: core.Identity(ext.DefaultAs), + ProfileName: ext.ProfileName, + UserOpenId: ext.OpenID, + SupportedIdentities: uint8(ext.SupportedIdentities), + } +} diff --git a/internal/credential/credential_provider_test.go b/internal/credential/credential_provider_test.go new file mode 100644 index 000000000..aedb3d809 --- /dev/null +++ b/internal/credential/credential_provider_test.go @@ -0,0 +1,421 @@ +package credential + +import ( + "bytes" + "context" + "errors" + "net/http" + "strings" + "testing" + + extcred "github.com/larksuite/cli/extension/credential" + "github.com/larksuite/cli/internal/auth" + "github.com/larksuite/cli/internal/core" +) + +type mockExtProvider struct { + name string + account *extcred.Account + token *extcred.Token + err error + accountErr error + tokenErr error +} + +func (m *mockExtProvider) Name() string { return m.name } +func (m *mockExtProvider) ResolveAccount(ctx context.Context) (*extcred.Account, error) { + if m.accountErr != nil { + return nil, m.accountErr + } + return m.account, m.err +} +func (m *mockExtProvider) ResolveToken(ctx context.Context, req extcred.TokenSpec) (*extcred.Token, error) { + if m.tokenErr != nil { + return nil, m.tokenErr + } + return m.token, m.err +} + +type mockDefaultAcct struct { + account *Account + err error +} + +func (m *mockDefaultAcct) ResolveAccount(ctx context.Context) (*Account, error) { + return m.account, m.err +} + +type mockDefaultToken struct { + result *TokenResult + err error +} + +func (m *mockDefaultToken) ResolveToken(ctx context.Context, req TokenSpec) (*TokenResult, error) { + return m.result, m.err +} + +func TestCredentialProvider_AccountFromExtension(t *testing.T) { + cp := NewCredentialProvider( + []extcred.Provider{&mockExtProvider{name: "env", account: &extcred.Account{AppID: "ext_app", Brand: "lark"}}}, + &mockDefaultAcct{account: &Account{AppID: "default_app"}}, + &mockDefaultToken{}, nil, + ) + acct, err := cp.ResolveAccount(context.Background()) + if err != nil { + t.Fatal(err) + } + if acct.AppID != "ext_app" { + t.Errorf("expected ext_app, got %s", acct.AppID) + } +} + +func TestCredentialProvider_AccountFallsToDefault(t *testing.T) { + cp := NewCredentialProvider( + []extcred.Provider{&mockExtProvider{name: "skip"}}, + &mockDefaultAcct{account: &Account{AppID: "default_app", Brand: "feishu"}}, + &mockDefaultToken{}, nil, + ) + acct, err := cp.ResolveAccount(context.Background()) + if err != nil { + t.Fatal(err) + } + if acct.AppID != "default_app" { + t.Errorf("expected default_app, got %s", acct.AppID) + } +} + +func TestCredentialProvider_AccountBlockStopsChain(t *testing.T) { + cp := NewCredentialProvider( + []extcred.Provider{&mockExtProvider{name: "blocker", err: &extcred.BlockError{Provider: "blocker", Reason: "denied"}}}, + &mockDefaultAcct{account: &Account{AppID: "default_app"}}, + &mockDefaultToken{}, nil, + ) + _, err := cp.ResolveAccount(context.Background()) + if err == nil { + t.Fatal("expected error") + } + var blockErr *extcred.BlockError + if !errors.As(err, &blockErr) { + t.Fatalf("expected BlockError, got %T", err) + } +} + +func TestCredentialProvider_AccountCached(t *testing.T) { + cp := NewCredentialProvider( + []extcred.Provider{&mockExtProvider{name: "env", account: &extcred.Account{AppID: "cached"}}}, + nil, nil, nil, + ) + a1, _ := cp.ResolveAccount(context.Background()) + a2, _ := cp.ResolveAccount(context.Background()) + if a1 != a2 { + t.Error("expected same pointer (cached)") + } +} + +func TestCredentialProvider_TokenFromExtension(t *testing.T) { + cp := NewCredentialProvider( + []extcred.Provider{&mockExtProvider{ + name: "env", + account: &extcred.Account{AppID: "ext_app", Brand: "feishu"}, + token: &extcred.Token{Value: "ext_tok", Source: "env"}, + }}, + &mockDefaultAcct{}, &mockDefaultToken{result: &TokenResult{Token: "default_tok"}}, nil, + ) + result, err := cp.ResolveToken(context.Background(), TokenSpec{Type: TokenTypeUAT}) + if err != nil { + t.Fatal(err) + } + if result.Token != "ext_tok" { + t.Errorf("expected ext_tok, got %s", result.Token) + } +} + +func TestCredentialProvider_TokenFallsToDefault(t *testing.T) { + cp := NewCredentialProvider( + []extcred.Provider{&mockExtProvider{name: "skip"}}, + &mockDefaultAcct{}, &mockDefaultToken{result: &TokenResult{Token: "default_tok"}}, nil, + ) + result, err := cp.ResolveToken(context.Background(), TokenSpec{Type: TokenTypeUAT}) + if err != nil { + t.Fatal(err) + } + if result.Token != "default_tok" { + t.Errorf("expected default_tok, got %s", result.Token) + } +} + +func TestCredentialProvider_TokenDoesNotMixSourcesAfterDefaultAccountSelection(t *testing.T) { + cp := NewCredentialProvider( + []extcred.Provider{&mockExtProvider{name: "env", token: &extcred.Token{Value: "ext_tok", Source: "env"}}}, + &mockDefaultAcct{account: &Account{AppID: "default_app", Brand: core.BrandFeishu}}, + &mockDefaultToken{result: &TokenResult{Token: "default_tok"}}, + nil, + ) + + if _, err := cp.ResolveAccount(context.Background()); err != nil { + t.Fatalf("ResolveAccount() error = %v", err) + } + + result, err := cp.ResolveToken(context.Background(), TokenSpec{Type: TokenTypeUAT}) + if err != nil { + t.Fatalf("ResolveToken() error = %v", err) + } + if result.Token != "default_tok" { + t.Fatalf("ResolveToken() token = %q, want %q", result.Token, "default_tok") + } +} + +func TestCredentialProvider_SelectedSourceWithoutTokenReturnsUnavailableError(t *testing.T) { + cp := NewCredentialProvider( + []extcred.Provider{&mockExtProvider{ + name: "env", + account: &extcred.Account{AppID: "ext_app", Brand: "feishu"}, + }}, + nil, nil, nil, + ) + + if _, err := cp.ResolveAccount(context.Background()); err != nil { + t.Fatalf("ResolveAccount() error = %v", err) + } + + _, err := cp.ResolveToken(context.Background(), TokenSpec{Type: TokenTypeUAT}) + if err == nil { + t.Fatal("ResolveToken() error = nil, want unavailable error") + } + var unavailableErr *TokenUnavailableError + if !errors.As(err, &unavailableErr) { + t.Fatalf("ResolveToken() error type = %T, want *TokenUnavailableError", err) + } + if unavailableErr.Source != "env" || unavailableErr.Type != TokenTypeUAT { + t.Fatalf("ResolveToken() unavailable error = %+v, want source env and type uat", unavailableErr) + } +} + +func TestCredentialProvider_ResolveTokenPropagatesNonBlockExtensionError(t *testing.T) { + cp := NewCredentialProvider( + []extcred.Provider{&mockExtProvider{name: "env", err: errors.New("provider exploded")}}, + nil, + &mockDefaultToken{result: &TokenResult{Token: "default_tok"}}, + nil, + ) + + _, err := cp.ResolveToken(context.Background(), TokenSpec{Type: TokenTypeUAT}) + if err == nil || err.Error() != "provider exploded" { + t.Fatalf("ResolveToken() error = %v, want provider exploded", err) + } +} + +func TestCredentialProvider_ResolveIdentityHint_FromExtensionAccount(t *testing.T) { + cp := NewCredentialProvider( + []extcred.Provider{&mockExtProvider{name: "env", account: &extcred.Account{ + AppID: "ext_app", + Brand: "feishu", + DefaultAs: extcred.IdentityUser, + SupportedIdentities: extcred.SupportsUser, + }}}, + nil, nil, nil, + ) + + hint, err := cp.ResolveIdentityHint(context.Background()) + if err != nil { + t.Fatalf("ResolveIdentityHint() error = %v", err) + } + if hint.DefaultAs != core.AsUser { + t.Fatalf("ResolveIdentityHint() defaultAs = %q, want %q", hint.DefaultAs, core.AsUser) + } + if hint.AutoAs != core.AsUser { + t.Fatalf("ResolveIdentityHint() autoAs = %q, want %q", hint.AutoAs, core.AsUser) + } +} + +func TestCredentialProvider_ResolveIdentityHint_DefaultSourceUsesStoredTokenState(t *testing.T) { + origGetStoredToken := getStoredToken + origTokenStatus := getStoredTokenStatus + t.Cleanup(func() { + getStoredToken = origGetStoredToken + getStoredTokenStatus = origTokenStatus + }) + + getStoredToken = func(appID, userOpenID string) *auth.StoredUAToken { + if appID != "default_app" || userOpenID != "ou_default" { + t.Fatalf("GetStoredToken() args = (%q, %q), want (%q, %q)", appID, userOpenID, "default_app", "ou_default") + } + return &auth.StoredUAToken{AppId: appID, UserOpenId: userOpenID} + } + getStoredTokenStatus = func(token *auth.StoredUAToken) string { + return "valid" + } + + cp := NewCredentialProvider( + nil, + &mockDefaultAcct{account: &Account{AppID: "default_app", Brand: core.BrandFeishu, UserOpenId: "ou_default"}}, + &mockDefaultToken{result: &TokenResult{Token: "default_tok"}}, + nil, + ) + + hint, err := cp.ResolveIdentityHint(context.Background()) + if err != nil { + t.Fatalf("ResolveIdentityHint() error = %v", err) + } + if hint.AutoAs != core.AsUser { + t.Fatalf("ResolveIdentityHint() autoAs = %q, want %q", hint.AutoAs, core.AsUser) + } +} + +func TestCredentialProvider_ResolveIdentityHint_CachesResult(t *testing.T) { + origGetStoredToken := getStoredToken + origTokenStatus := getStoredTokenStatus + t.Cleanup(func() { + getStoredToken = origGetStoredToken + getStoredTokenStatus = origTokenStatus + }) + + storedCalls := 0 + statusCalls := 0 + getStoredToken = func(appID, userOpenID string) *auth.StoredUAToken { + storedCalls++ + return &auth.StoredUAToken{AppId: appID, UserOpenId: userOpenID} + } + getStoredTokenStatus = func(token *auth.StoredUAToken) string { + statusCalls++ + return "valid" + } + + cp := NewCredentialProvider( + nil, + &mockDefaultAcct{account: &Account{AppID: "default_app", Brand: core.BrandFeishu, UserOpenId: "ou_default"}}, + &mockDefaultToken{result: &TokenResult{Token: "default_tok"}}, + nil, + ) + + for i := 0; i < 2; i++ { + hint, err := cp.ResolveIdentityHint(context.Background()) + if err != nil { + t.Fatalf("ResolveIdentityHint() error = %v", err) + } + if hint.AutoAs != core.AsUser { + t.Fatalf("ResolveIdentityHint() autoAs = %q, want %q", hint.AutoAs, core.AsUser) + } + } + + if storedCalls != 1 { + t.Fatalf("GetStoredToken() calls = %d, want 1", storedCalls) + } + if statusCalls != 1 { + t.Fatalf("TokenStatus() calls = %d, want 1", statusCalls) + } +} + +func TestCredentialProvider_ResolveTokenTreatsEmptyDefaultTokenAsMalformed(t *testing.T) { + cp := NewCredentialProvider( + nil, + nil, + &mockDefaultToken{result: &TokenResult{Token: ""}}, + nil, + ) + + _, err := cp.ResolveToken(context.Background(), TokenSpec{Type: TokenTypeUAT}) + if err == nil || !strings.Contains(err.Error(), "empty token") { + t.Fatalf("ResolveToken() error = %v, want malformed empty token error", err) + } +} + +func TestCredentialProvider_ResolveAccountDoesNotEnrichWithTokenFromDifferentProvider(t *testing.T) { + httpClientCalls := 0 + cp := NewCredentialProvider( + []extcred.Provider{&mockExtProvider{name: "env", token: &extcred.Token{Value: "ext_tok", Source: "env"}}}, + &mockDefaultAcct{account: &Account{ + AppID: "default_app", + Brand: core.BrandFeishu, + UserOpenId: "ou_default", + UserName: "Default User", + }}, + &mockDefaultToken{}, + func() (*http.Client, error) { + httpClientCalls++ + return nil, errors.New("unexpected enrich call") + }, + ) + + acct, err := cp.ResolveAccount(context.Background()) + if err != nil { + t.Fatalf("ResolveAccount() error = %v", err) + } + if httpClientCalls != 0 { + t.Fatalf("httpClient() called %d times, want 0", httpClientCalls) + } + if acct.UserOpenId != "ou_default" || acct.UserName != "Default User" { + t.Fatalf("resolved user = (%q, %q), want (%q, %q)", acct.UserOpenId, acct.UserName, "ou_default", "Default User") + } +} + +func TestCredentialProvider_ResolveAccountClearsUnverifiedExtensionIdentityOnTokenError(t *testing.T) { + cp := NewCredentialProvider( + []extcred.Provider{&mockExtProvider{name: "env", account: &extcred.Account{ + AppID: "ext_app", + Brand: "feishu", + OpenID: "ou_ext", + }, tokenErr: errors.New("token lookup failed")}}, + nil, + nil, + func() (*http.Client, error) { + t.Fatal("httpClient() should not be called when token lookup fails") + return nil, nil + }, + ) + + acct, err := cp.ResolveAccount(context.Background()) + if err != nil { + t.Fatalf("ResolveAccount() error = %v", err) + } + if acct.UserOpenId != "" || acct.UserName != "" { + t.Fatalf("resolved user = (%q, %q), want cleared unverified identity", acct.UserOpenId, acct.UserName) + } +} + +func TestCredentialProvider_ResolveAccountWarnsWhenExtensionIdentityVerificationFails(t *testing.T) { + var warnBuf bytes.Buffer + + cp := NewCredentialProvider( + []extcred.Provider{&mockExtProvider{name: "env", account: &extcred.Account{ + AppID: "ext_app", + Brand: "feishu", + OpenID: "ou_ext", + }, tokenErr: errors.New("token lookup failed")}}, + nil, + nil, + func() (*http.Client, error) { + t.Fatal("httpClient() should not be called when token lookup fails") + return nil, nil + }, + ) + cp.SetWarnOut(&warnBuf) + + acct, err := cp.ResolveAccount(context.Background()) + if err != nil { + t.Fatalf("ResolveAccount() error = %v", err) + } + if acct.UserOpenId != "" || acct.UserName != "" { + t.Fatalf("resolved user = (%q, %q), want cleared unverified identity", acct.UserOpenId, acct.UserName) + } + if !strings.Contains(warnBuf.String(), "unable to verify user identity from credential source \"env\"") { + t.Fatalf("warning output = %q, want source-specific verification warning", warnBuf.String()) + } + if !strings.Contains(warnBuf.String(), "token lookup failed") { + t.Fatalf("warning output = %q, want underlying error", warnBuf.String()) + } +} + +func TestCredentialProvider_ResolveTokenDoesNotBypassFailedDefaultAccountResolution(t *testing.T) { + cp := NewCredentialProvider( + nil, + &mockDefaultAcct{err: errors.New("config unavailable")}, + &mockDefaultToken{result: &TokenResult{Token: "default_tok"}}, + nil, + ) + + _, err := cp.ResolveToken(context.Background(), TokenSpec{Type: TokenTypeUAT}) + if err == nil || err.Error() != "config unavailable" { + t.Fatalf("ResolveToken() error = %v, want config unavailable", err) + } +} diff --git a/internal/credential/default_provider.go b/internal/credential/default_provider.go new file mode 100644 index 000000000..bedad7b86 --- /dev/null +++ b/internal/credential/default_provider.go @@ -0,0 +1,173 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package credential + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "sync" + + "github.com/larksuite/cli/internal/auth" + "github.com/larksuite/cli/internal/core" + "github.com/larksuite/cli/internal/keychain" + + extcred "github.com/larksuite/cli/extension/credential" +) + +// DefaultAccountProvider resolves account from config.json via keychain. +type DefaultAccountProvider struct { + keychain keychain.KeychainAccess + profile string +} + +func NewDefaultAccountProvider(kc keychain.KeychainAccess, profile string) *DefaultAccountProvider { + return &DefaultAccountProvider{keychain: kc, profile: profile} +} + +func (p *DefaultAccountProvider) ResolveAccount(ctx context.Context) (*Account, error) { + // Load config once — used for both credentials and strict mode. + multi, err := core.LoadMultiAppConfig() + if err != nil { + return nil, &core.ConfigError{Code: 2, Type: "config", Message: "not configured", Hint: "run `lark-cli config init --new` in the background. It blocks and outputs a verification URL — retrieve the URL and open it in a browser to complete setup."} + } + + cfg, err := core.ResolveConfigFromMulti(multi, p.keychain, p.profile) + if err != nil { + return nil, err + } + cfg.SupportedIdentities = strictModeToIdentitySupport(multi, p.profile) + return AccountFromCliConfig(cfg), nil +} + +// strictModeToIdentitySupport maps the config-level strict mode to +// the SupportedIdentities bitflag using an already-loaded MultiAppConfig. +func strictModeToIdentitySupport(multi *core.MultiAppConfig, profileOverride string) uint8 { + app := multi.CurrentAppConfig(profileOverride) + var mode core.StrictMode + if app != nil && app.StrictMode != nil { + mode = *app.StrictMode + } else { + mode = multi.StrictMode + } + switch mode { + case core.StrictModeBot: + return uint8(extcred.SupportsBot) + case core.StrictModeUser: + return uint8(extcred.SupportsUser) + default: + return 0 + } +} + +// DefaultTokenProvider resolves UAT/TAT using keychain + direct HTTP calls. +// No SDK/LarkClient dependency — eliminates circular dependency with Factory. +type DefaultTokenProvider struct { + defaultAcct *DefaultAccountProvider + httpClient func() (*http.Client, error) + errOut io.Writer + + tatOnce sync.Once + tatResult *TokenResult + tatErr error +} + +func NewDefaultTokenProvider(defaultAcct *DefaultAccountProvider, httpClient func() (*http.Client, error), errOut io.Writer) *DefaultTokenProvider { + return &DefaultTokenProvider{defaultAcct: defaultAcct, httpClient: httpClient, errOut: errOut} +} + +func (p *DefaultTokenProvider) ResolveToken(ctx context.Context, req TokenSpec) (*TokenResult, error) { + switch req.Type { + case TokenTypeUAT: + return p.resolveUAT(ctx) + case TokenTypeTAT: + return p.resolveTAT(ctx) + default: + return nil, fmt.Errorf("unsupported token type: %s", req.Type) + } +} + +// resolveUAT resolves a user access token. Not cached (unlike TAT) because UAT +// may be refreshed between calls and GetValidAccessToken handles its own caching. +func (p *DefaultTokenProvider) resolveUAT(ctx context.Context) (*TokenResult, error) { + acct, err := p.defaultAcct.ResolveAccount(ctx) + if err != nil { + return nil, err + } + httpClient, err := p.httpClient() + if err != nil { + return nil, err + } + token, err := auth.GetValidAccessToken(httpClient, auth.NewUATCallOptions(acct.ToCliConfig(), p.errOut)) + if err != nil { + return nil, err + } + stored := auth.GetStoredToken(acct.AppID, acct.UserOpenId) + scopes := "" + if stored != nil { + scopes = stored.Scope + } + return &TokenResult{Token: token, Scopes: scopes}, nil +} + +// resolveTAT resolves a tenant access token. Result is cached after first call. +// NOTE: Uses sync.Once — only the context from the first call is used. +func (p *DefaultTokenProvider) resolveTAT(ctx context.Context) (*TokenResult, error) { + p.tatOnce.Do(func() { + p.tatResult, p.tatErr = p.doResolveTAT(ctx) + }) + return p.tatResult, p.tatErr +} + +func (p *DefaultTokenProvider) doResolveTAT(ctx context.Context) (*TokenResult, error) { + acct, err := p.defaultAcct.ResolveAccount(ctx) + if err != nil { + return nil, err + } + httpClient, err := p.httpClient() + if err != nil { + return nil, err + } + ep := core.ResolveEndpoints(acct.Brand) + url := ep.Open + "/open-apis/auth/v3/tenant_access_token/internal" + + body, err := json.Marshal(map[string]string{ + "app_id": acct.AppID, + "app_secret": acct.AppSecret, + }) + if err != nil { + return nil, fmt.Errorf("failed to marshal TAT request: %w", err) + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + + resp, err := httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("TAT API returned HTTP %d", resp.StatusCode) + } + + var result struct { + Code int `json:"code"` + Msg string `json:"msg"` + TenantAccessToken string `json:"tenant_access_token"` + } + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, fmt.Errorf("failed to parse TAT response: %w", err) + } + if result.Code != 0 { + return nil, fmt.Errorf("TAT API error: [%d] %s", result.Code, result.Msg) + } + return &TokenResult{Token: result.TenantAccessToken}, nil +} diff --git a/internal/credential/default_provider_test.go b/internal/credential/default_provider_test.go new file mode 100644 index 000000000..f1c081cca --- /dev/null +++ b/internal/credential/default_provider_test.go @@ -0,0 +1,14 @@ +package credential + +import ( + "testing" +) + +func TestDefaultTokenProvider_Dispatches(t *testing.T) { + // Just verify the type implements DefaultTokenResolver + var _ DefaultTokenResolver = &DefaultTokenProvider{} +} + +func TestDefaultAccountProvider_Implements(t *testing.T) { + var _ DefaultAccountResolver = &DefaultAccountProvider{} +} diff --git a/internal/credential/integration_test.go b/internal/credential/integration_test.go new file mode 100644 index 000000000..f843c987f --- /dev/null +++ b/internal/credential/integration_test.go @@ -0,0 +1,113 @@ +package credential_test + +import ( + "context" + "testing" + + extcred "github.com/larksuite/cli/extension/credential" + envprovider "github.com/larksuite/cli/extension/credential/env" + "github.com/larksuite/cli/internal/core" + "github.com/larksuite/cli/internal/credential" + "github.com/larksuite/cli/internal/envvars" +) + +type noopKC struct{} + +func (n *noopKC) Get(service, account string) (string, error) { return "", nil } +func (n *noopKC) Set(service, account, value string) error { return nil } +func (n *noopKC) Remove(service, account string) error { return nil } + +func TestFullChain_EnvWins(t *testing.T) { + t.Setenv(envvars.CliAppID, "env_app") + t.Setenv(envvars.CliAppSecret, "env_secret") + t.Setenv(envvars.CliUserAccessToken, "env_uat") + + ep := &envprovider.Provider{} + cp := credential.NewCredentialProvider( + []extcred.Provider{ep}, + nil, nil, nil, + ) + + acct, err := cp.ResolveAccount(context.Background()) + if err != nil { + t.Fatal(err) + } + if acct.AppID != "env_app" { + t.Errorf("expected env_app, got %s", acct.AppID) + } + + result, err := cp.ResolveToken(context.Background(), credential.TokenSpec{ + Type: credential.TokenTypeUAT, AppID: "env_app", + }) + if err != nil { + t.Fatal(err) + } + if result.Token != "env_uat" { + t.Errorf("expected env_uat, got %s", result.Token) + } +} + +func TestFullChain_Fallthrough(t *testing.T) { + // env provider returns nil (no env vars set), falls through to default token + ep := &envprovider.Provider{} + mock := &mockDefaultTokenProvider{token: "mock_tok", scopes: "drive:read"} + + cp := credential.NewCredentialProvider( + []extcred.Provider{ep}, + nil, mock, nil, + ) + result, err := cp.ResolveToken(context.Background(), credential.TokenSpec{ + Type: credential.TokenTypeUAT, AppID: "app1", + }) + if err != nil { + t.Fatal(err) + } + if result.Token != "mock_tok" || result.Scopes != "drive:read" { + t.Errorf("unexpected: %+v", result) + } +} + +type mockDefaultTokenProvider struct { + token string + scopes string +} + +func (m *mockDefaultTokenProvider) ResolveToken(ctx context.Context, req credential.TokenSpec) (*credential.TokenResult, error) { + return &credential.TokenResult{Token: m.token, Scopes: m.scopes}, nil +} + +func TestFullChain_ConfigStrictMode(t *testing.T) { + t.Setenv(envvars.CliAppID, "") + t.Setenv(envvars.CliAppSecret, "") + dir := t.TempDir() + t.Setenv("LARKSUITE_CLI_CONFIG_DIR", dir) + + botMode := core.StrictModeBot + multi := &core.MultiAppConfig{ + Apps: []core.AppConfig{{ + AppId: "cfg_app", + AppSecret: core.PlainSecret("cfg_secret"), + Brand: core.BrandLark, + StrictMode: &botMode, + }}, + } + if err := core.SaveMultiAppConfig(multi); err != nil { + t.Fatal(err) + } + + ep := &envprovider.Provider{} + defaultAcct := credential.NewDefaultAccountProvider(&noopKC{}, "") + + cp := credential.NewCredentialProvider( + []extcred.Provider{ep}, + defaultAcct, nil, nil, + ) + + acct, err := cp.ResolveAccount(context.Background()) + if err != nil { + t.Fatal(err) + } + if acct.SupportedIdentities != uint8(extcred.SupportsBot) { + t.Errorf("expected SupportsBot (%d), got %d", extcred.SupportsBot, acct.SupportedIdentities) + } +} diff --git a/internal/credential/types.go b/internal/credential/types.go new file mode 100644 index 000000000..e6b331830 --- /dev/null +++ b/internal/credential/types.go @@ -0,0 +1,172 @@ +package credential + +import ( + "context" + "fmt" + "strings" + + extcred "github.com/larksuite/cli/extension/credential" + "github.com/larksuite/cli/internal/core" +) + +// Account is the credential-layer view of the active runtime account. +// It intentionally mirrors only the resolved fields needed by runtime auth +// and identity selection, without exposing core.CliConfig as a dependency. +type Account struct { + ProfileName string + AppID string + AppSecret string + Brand core.LarkBrand + DefaultAs core.Identity + UserOpenId string + UserName string + SupportedIdentities uint8 +} + +const runtimePlaceholderAppSecret = "__LARKSUITE_CLI_TOKEN_ONLY__" + +// HasRealAppSecret reports whether secret is an actual app secret rather than +// an empty/token-only marker or the internal runtime placeholder. +func HasRealAppSecret(secret string) bool { + return secret != "" && secret != runtimePlaceholderAppSecret +} + +// RuntimeAppSecret returns the SDK-compatible app secret used at runtime. +// Token-only sources intentionally have no real secret; this helper injects a +// private placeholder so downstream SDK validation can proceed while callers +// still distinguish real secrets with HasRealAppSecret. +func RuntimeAppSecret(secret string) string { + if HasRealAppSecret(secret) { + return secret + } + return runtimePlaceholderAppSecret +} + +func normalizeAccountAppSecret(secret string) string { + if HasRealAppSecret(secret) { + return secret + } + return extcred.NoAppSecret +} + +// AccountFromCliConfig copies the resolved config view into a credential.Account. +func AccountFromCliConfig(cfg *core.CliConfig) *Account { + if cfg == nil { + return nil + } + return &Account{ + ProfileName: cfg.ProfileName, + AppID: cfg.AppID, + AppSecret: normalizeAccountAppSecret(cfg.AppSecret), + Brand: cfg.Brand, + DefaultAs: cfg.DefaultAs, + UserOpenId: cfg.UserOpenId, + UserName: cfg.UserName, + SupportedIdentities: cfg.SupportedIdentities, + } +} + +// ToCliConfig copies the credential-layer account into the downstream config shape. +func (a *Account) ToCliConfig() *core.CliConfig { + if a == nil { + return nil + } + return &core.CliConfig{ + ProfileName: a.ProfileName, + AppID: a.AppID, + AppSecret: normalizeAccountAppSecret(a.AppSecret), + Brand: a.Brand, + DefaultAs: a.DefaultAs, + UserOpenId: a.UserOpenId, + UserName: a.UserName, + SupportedIdentities: a.SupportedIdentities, + } +} + +// AccountProvider resolves app credentials. +// Returns nil, nil to indicate "I don't handle this, try next provider". +type AccountProvider interface { + ResolveAccount(ctx context.Context) (*Account, error) +} + +// TokenType distinguishes UAT from TAT. +// Uses string constants matching extension/credential.TokenType for zero-cost conversion. +type TokenType string + +const ( + TokenTypeUAT TokenType = "uat" // User Access Token + TokenTypeTAT TokenType = "tat" // Tenant Access Token +) + +func (t TokenType) String() string { return string(t) } + +// ParseTokenType converts a string to TokenType. +func ParseTokenType(s string) (TokenType, bool) { + switch strings.ToLower(s) { + case "uat": + return TokenTypeUAT, true + case "tat": + return TokenTypeTAT, true + default: + return "", false + } +} + +// TokenSpec is the input to TokenProvider.ResolveToken. +type TokenSpec struct { + Type TokenType + AppID string // identifies which app (multi-account); not sensitive +} + +// TokenResult is the output of TokenProvider.ResolveToken. +type TokenResult struct { + Token string + Scopes string // optional, space-separated; empty = skip scope pre-check +} + +// IdentityHint is credential-layer guidance for resolving the effective identity. +type IdentityHint struct { + DefaultAs core.Identity + AutoAs core.Identity +} + +// TokenUnavailableError reports that no usable token was available. +type TokenUnavailableError struct { + Source string + Type TokenType +} + +func (e *TokenUnavailableError) Error() string { + if e.Source != "" { + return fmt.Sprintf("no %s available from credential source %q", e.Type, e.Source) + } + return fmt.Sprintf("no credential provider returned a token for %s", e.Type) +} + +// MalformedTokenResultError reports that a source returned an invalid token payload. +type MalformedTokenResultError struct { + Source string + Type TokenType + Reason string +} + +func (e *MalformedTokenResultError) Error() string { + return fmt.Sprintf("credential source %q returned malformed %s token: %s", e.Source, e.Type, e.Reason) +} + +// TokenProvider resolves a runtime access token. +// Top-level resolvers should return a non-nil token or an error. +// Chain participants may use nil, nil internally to indicate "try next source". +type TokenProvider interface { + ResolveToken(ctx context.Context, req TokenSpec) (*TokenResult, error) +} + +// NewTokenSpec returns a TokenSpec with the token type automatically +// selected based on identity: TAT for bot, UAT for user. +func NewTokenSpec(identity core.Identity, appID string) TokenSpec { + t := TokenTypeUAT + if identity.IsBot() { + t = TokenTypeTAT + } + return TokenSpec{Type: t, AppID: appID} +} diff --git a/internal/credential/types_test.go b/internal/credential/types_test.go new file mode 100644 index 000000000..c8c8ccf55 --- /dev/null +++ b/internal/credential/types_test.go @@ -0,0 +1,121 @@ +package credential + +import ( + "testing" + + "github.com/larksuite/cli/internal/core" +) + +func TestTokenTypeString(t *testing.T) { + tests := []struct { + tt TokenType + want string + }{ + {TokenTypeUAT, "uat"}, + {TokenTypeTAT, "tat"}, + {TokenType("custom"), "custom"}, + } + for _, tc := range tests { + if got := tc.tt.String(); got != tc.want { + t.Errorf("TokenType(%q).String() = %q, want %q", tc.tt, got, tc.want) + } + } +} + +func TestParseTokenType(t *testing.T) { + tests := []struct { + s string + want TokenType + ok bool + }{ + {"uat", TokenTypeUAT, true}, + {"tat", TokenTypeTAT, true}, + {"UAT", TokenTypeUAT, true}, + {"bad", "", false}, + } + for _, tc := range tests { + got, ok := ParseTokenType(tc.s) + if ok != tc.ok || (ok && got != tc.want) { + t.Errorf("ParseTokenType(%q) = (%v, %v), want (%v, %v)", tc.s, got, ok, tc.want, tc.ok) + } + } +} + +func TestAccountFromCliConfigAndBack_ReturnCopies(t *testing.T) { + cfg := &core.CliConfig{ + ProfileName: "target", + AppID: "app-1", + AppSecret: "secret-1", + Brand: core.BrandLark, + DefaultAs: "user", + UserOpenId: "ou_123", + UserName: "alice", + SupportedIdentities: 3, + } + + acct := AccountFromCliConfig(cfg) + if acct == nil { + t.Fatal("AccountFromCliConfig() = nil") + } + if acct.AppID != cfg.AppID || acct.ProfileName != cfg.ProfileName || acct.UserName != cfg.UserName { + t.Fatalf("AccountFromCliConfig() = %#v, want copied fields from %#v", acct, cfg) + } + + roundtrip := acct.ToCliConfig() + if roundtrip == nil { + t.Fatal("ToCliConfig() = nil") + } + if roundtrip.AppID != cfg.AppID || roundtrip.ProfileName != cfg.ProfileName || roundtrip.UserName != cfg.UserName { + t.Fatalf("ToCliConfig() = %#v, want copied fields from %#v", roundtrip, cfg) + } + + roundtrip.AppID = "mutated-cli" + acct.AppID = "mutated-account" + + if cfg.AppID != "app-1" { + t.Fatalf("cfg.AppID = %q, want original value", cfg.AppID) + } + if roundtrip.AppID != "mutated-cli" { + t.Fatalf("roundtrip.AppID = %q, want mutated value", roundtrip.AppID) + } + if acct.AppID != "mutated-account" { + t.Fatalf("acct.AppID = %q, want mutated value", acct.AppID) + } +} + +func TestAccountToCliConfig_TokenOnlySecretPreservesNoAppSecret(t *testing.T) { + acct := &Account{ + ProfileName: "env", + AppID: "app-1", + AppSecret: "", + Brand: core.BrandFeishu, + } + + cfg := acct.ToCliConfig() + if cfg == nil { + t.Fatal("ToCliConfig() = nil") + } + if cfg.AppSecret != "" { + t.Fatalf("AppSecret = %q, want empty string", cfg.AppSecret) + } + + roundtrip := AccountFromCliConfig(cfg) + if roundtrip == nil { + t.Fatal("AccountFromCliConfig() = nil") + } + if roundtrip.AppSecret != "" { + t.Fatalf("roundtrip.AppSecret = %q, want empty string", roundtrip.AppSecret) + } +} + +func TestRuntimeAppSecret_TokenOnlyUsesPlaceholder(t *testing.T) { + if got := RuntimeAppSecret(""); got == "" { + t.Fatal("RuntimeAppSecret(\"\") = empty, want non-empty placeholder") + } + if HasRealAppSecret(RuntimeAppSecret("")) { + t.Fatalf("HasRealAppSecret(RuntimeAppSecret(\"\")) = true, want false") + } + if got := RuntimeAppSecret("secret-1"); got != "secret-1" { + t.Fatalf("RuntimeAppSecret(real) = %q, want %q", got, "secret-1") + } +} diff --git a/internal/credential/user_info.go b/internal/credential/user_info.go new file mode 100644 index 000000000..7631a91de --- /dev/null +++ b/internal/credential/user_info.go @@ -0,0 +1,56 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package credential + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + + "github.com/larksuite/cli/internal/core" +) + +type userInfo struct { + OpenID string + Name string +} + +// fetchUserInfo calls /open-apis/authen/v1/user_info with a UAT to get the user's identity. +func fetchUserInfo(ctx context.Context, httpClient *http.Client, brand core.LarkBrand, uat string) (*userInfo, error) { + ep := core.ResolveEndpoints(brand) + url := ep.Open + "/open-apis/authen/v1/user_info" + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", "Bearer "+uat) + + resp, err := httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("user_info API returned HTTP %d", resp.StatusCode) + } + + var result struct { + Code int `json:"code"` + Msg string `json:"msg"` + Data struct { + OpenID string `json:"open_id"` + Name string `json:"name"` + } `json:"data"` + } + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, err + } + if result.Code != 0 { + return nil, fmt.Errorf("user_info API error: [%d] %s", result.Code, result.Msg) + } + return &userInfo{OpenID: result.Data.OpenID, Name: result.Data.Name}, nil +} diff --git a/internal/envvars/envvars.go b/internal/envvars/envvars.go new file mode 100644 index 000000000..1d80ac1cc --- /dev/null +++ b/internal/envvars/envvars.go @@ -0,0 +1,14 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package envvars + +const ( + CliAppID = "LARKSUITE_CLI_APP_ID" + CliAppSecret = "LARKSUITE_CLI_APP_SECRET" + CliBrand = "LARKSUITE_CLI_BRAND" + CliUserAccessToken = "LARKSUITE_CLI_USER_ACCESS_TOKEN" + CliTenantAccessToken = "LARKSUITE_CLI_TENANT_ACCESS_TOKEN" + CliDefaultAs = "LARKSUITE_CLI_DEFAULT_AS" + CliStrictMode = "LARKSUITE_CLI_STRICT_MODE" +) diff --git a/internal/keychain/auth_log.go b/internal/keychain/auth_log.go index 64558bd93..079f8cd90 100644 --- a/internal/keychain/auth_log.go +++ b/internal/keychain/auth_log.go @@ -8,6 +8,8 @@ import ( "strings" "sync" "time" + + "github.com/larksuite/cli/internal/vfs" ) var ( @@ -23,7 +25,7 @@ func authLogDir() string { return filepath.Join(dir, "logs") } - home, err := os.UserHomeDir() + home, err := vfs.UserHomeDir() if err != nil || home == "" { fmt.Fprintf(os.Stderr, "warning: unable to determine home directory: %v\n", err) } @@ -39,13 +41,13 @@ func initAuthLogger() { dir := authLogDir() now := authResponseLogNow() - if err := os.MkdirAll(dir, 0700); err != nil { + if err := vfs.MkdirAll(dir, 0700); err != nil { return } logName := fmt.Sprintf("auth-%s.log", now.Format("2006-01-02")) logPath := filepath.Join(dir, logName) - if f, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0600); err == nil { + if f, err := vfs.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0600); err == nil { authResponseLogger = log.New(f, "", 0) cleanupOldLogs(dir, now) } @@ -131,7 +133,7 @@ func cleanupOldLogs(dir string, now time.Time) { } }() - entries, err := os.ReadDir(dir) + entries, err := vfs.ReadDir(dir) if err != nil { return } @@ -153,7 +155,7 @@ func cleanupOldLogs(dir string, now time.Time) { logDate = time.Date(logDate.Year(), logDate.Month(), logDate.Day(), 0, 0, 0, 0, now.Location()) if logDate.Before(cutoff) { - _ = os.Remove(filepath.Join(dir, entry.Name())) + _ = vfs.Remove(filepath.Join(dir, entry.Name())) } } } diff --git a/internal/keychain/keychain_darwin.go b/internal/keychain/keychain_darwin.go index a49633f7f..ae0aebffc 100644 --- a/internal/keychain/keychain_darwin.go +++ b/internal/keychain/keychain_darwin.go @@ -18,6 +18,7 @@ import ( "time" "github.com/google/uuid" + "github.com/larksuite/cli/internal/vfs" "github.com/zalando/go-keyring" ) @@ -28,7 +29,7 @@ const tagBytes = 16 // StorageDir returns the storage directory for a given service name on macOS. func StorageDir(service string) string { - home, err := os.UserHomeDir() + home, err := vfs.UserHomeDir() if err != nil || home == "" { return filepath.Join(".lark-cli", "keychain", service) } @@ -153,7 +154,7 @@ func decryptData(data []byte, key []byte) (string, error) { // platformGet retrieves a value from the macOS keychain. func platformGet(service, account string) (string, error) { path := filepath.Join(StorageDir(service), safeFileName(account)) - data, err := os.ReadFile(path) + data, err := vfs.ReadFile(path) if errors.Is(err, os.ErrNotExist) { return "", nil } @@ -178,7 +179,7 @@ func platformSet(service, account, data string) error { return err } dir := StorageDir(service) - if err := os.MkdirAll(dir, 0700); err != nil { + if err := vfs.MkdirAll(dir, 0700); err != nil { return err } encrypted, err := encryptData(data, key) @@ -188,14 +189,14 @@ func platformSet(service, account, data string) error { targetPath := filepath.Join(dir, safeFileName(account)) tmpPath := filepath.Join(dir, safeFileName(account)+"."+uuid.New().String()+".tmp") - defer os.Remove(tmpPath) + defer vfs.Remove(tmpPath) - if err := os.WriteFile(tmpPath, encrypted, 0600); err != nil { + if err := vfs.WriteFile(tmpPath, encrypted, 0600); err != nil { return err } // Atomic rename to prevent file corruption during multi-process writes - if err := os.Rename(tmpPath, targetPath); err != nil { + if err := vfs.Rename(tmpPath, targetPath); err != nil { return err } return nil @@ -203,7 +204,7 @@ func platformSet(service, account, data string) error { // platformRemove deletes a value from the macOS keychain. func platformRemove(service, account string) error { - err := os.Remove(filepath.Join(StorageDir(service), safeFileName(account))) + err := vfs.Remove(filepath.Join(StorageDir(service), safeFileName(account))) if err != nil && !os.IsNotExist(err) { return err } diff --git a/internal/keychain/keychain_other.go b/internal/keychain/keychain_other.go index 55192d46b..d84ad84b9 100644 --- a/internal/keychain/keychain_other.go +++ b/internal/keychain/keychain_other.go @@ -16,6 +16,7 @@ import ( "regexp" "github.com/google/uuid" + "github.com/larksuite/cli/internal/vfs" ) const masterKeyBytes = 32 @@ -24,7 +25,7 @@ const tagBytes = 16 // StorageDir returns the directory where encrypted files are stored. func StorageDir(service string) string { - home, err := os.UserHomeDir() + home, err := vfs.UserHomeDir() if err != nil || home == "" { // If home is missing, fallback to relative path and print warning. // This matches the behavior in internal/core/config.go. @@ -47,7 +48,7 @@ func getMasterKey(service string, allowCreate bool) ([]byte, error) { dir := StorageDir(service) keyPath := filepath.Join(dir, "master.key") - key, err := os.ReadFile(keyPath) + key, err := vfs.ReadFile(keyPath) if err == nil && len(key) == masterKeyBytes { return key, nil } @@ -64,7 +65,7 @@ func getMasterKey(service string, allowCreate bool) ([]byte, error) { return nil, errNotInitialized } - if err := os.MkdirAll(dir, 0700); err != nil { + if err := vfs.MkdirAll(dir, 0700); err != nil { return nil, err } @@ -74,16 +75,16 @@ func getMasterKey(service string, allowCreate bool) ([]byte, error) { } tmpKeyPath := filepath.Join(dir, "master.key."+uuid.New().String()+".tmp") - defer os.Remove(tmpKeyPath) + defer vfs.Remove(tmpKeyPath) - if err := os.WriteFile(tmpKeyPath, key, 0600); err != nil { + if err := vfs.WriteFile(tmpKeyPath, key, 0600); err != nil { return nil, err } // Atomic rename to prevent multi-process master key initialization collision - if err := os.Rename(tmpKeyPath, keyPath); err != nil { + if err := vfs.Rename(tmpKeyPath, keyPath); err != nil { // If rename fails, another process might have created it. Try reading again. - existingKey, readErr := os.ReadFile(keyPath) + existingKey, readErr := vfs.ReadFile(keyPath) if readErr == nil && len(existingKey) == masterKeyBytes { return existingKey, nil } @@ -142,7 +143,7 @@ func decryptData(data []byte, key []byte) (string, error) { // platformGet retrieves a value from the file system. func platformGet(service, account string) (string, error) { path := filepath.Join(StorageDir(service), safeFileName(account)) - data, err := os.ReadFile(path) + data, err := vfs.ReadFile(path) if errors.Is(err, os.ErrNotExist) { return "", nil } @@ -167,7 +168,7 @@ func platformSet(service, account, data string) error { return err } dir := StorageDir(service) - if err := os.MkdirAll(dir, 0700); err != nil { + if err := vfs.MkdirAll(dir, 0700); err != nil { return err } encrypted, err := encryptData(data, key) @@ -177,14 +178,14 @@ func platformSet(service, account, data string) error { targetPath := filepath.Join(dir, safeFileName(account)) tmpPath := filepath.Join(dir, safeFileName(account)+"."+uuid.New().String()+".tmp") - defer os.Remove(tmpPath) + defer vfs.Remove(tmpPath) - if err := os.WriteFile(tmpPath, encrypted, 0600); err != nil { + if err := vfs.WriteFile(tmpPath, encrypted, 0600); err != nil { return err } // Atomic rename to prevent file corruption during multi-process writes - if err := os.Rename(tmpPath, targetPath); err != nil { + if err := vfs.Rename(tmpPath, targetPath); err != nil { return err } return nil @@ -192,7 +193,7 @@ func platformSet(service, account, data string) error { // platformRemove deletes a value from the file system. func platformRemove(service, account string) error { - err := os.Remove(filepath.Join(StorageDir(service), safeFileName(account))) + err := vfs.Remove(filepath.Join(StorageDir(service), safeFileName(account))) if err != nil && !os.IsNotExist(err) { return err } diff --git a/internal/lockfile/lockfile.go b/internal/lockfile/lockfile.go index 96563b9ba..d08820088 100644 --- a/internal/lockfile/lockfile.go +++ b/internal/lockfile/lockfile.go @@ -10,6 +10,7 @@ import ( "regexp" "github.com/larksuite/cli/internal/core" + "github.com/larksuite/cli/internal/vfs" ) // safeIDChars strips everything except alphanumerics, underscores, hyphens, and dots @@ -39,7 +40,7 @@ func ForSubscribe(appID string) (*LockFile, error) { return nil, fmt.Errorf("app ID must not be empty") } dir := filepath.Join(core.GetConfigDir(), "locks") - if err := os.MkdirAll(dir, 0700); err != nil { + if err := vfs.MkdirAll(dir, 0700); err != nil { return nil, fmt.Errorf("create lock dir: %w", err) } safe := safeIDChars.ReplaceAllString(appID, "_") @@ -56,7 +57,7 @@ func (l *LockFile) TryLock() error { if l.file != nil { return fmt.Errorf("lock already held: %s", l.path) } - f, err := os.OpenFile(l.path, os.O_CREATE|os.O_RDWR, 0600) + f, err := vfs.OpenFile(l.path, os.O_CREATE|os.O_RDWR, 0600) if err != nil { return fmt.Errorf("open lock file: %w", err) } diff --git a/internal/registry/remote.go b/internal/registry/remote.go index 135c1f432..4fcaa357c 100644 --- a/internal/registry/remote.go +++ b/internal/registry/remote.go @@ -18,6 +18,7 @@ import ( "github.com/larksuite/cli/internal/build" "github.com/larksuite/cli/internal/core" "github.com/larksuite/cli/internal/validate" + "github.com/larksuite/cli/internal/vfs" ) const ( @@ -109,14 +110,14 @@ func cacheMetaPath() string { // Returns false if the directory cannot be created or written to. func cacheWritable() bool { dir := cacheDir() - if err := os.MkdirAll(dir, 0700); err != nil { + if err := vfs.MkdirAll(dir, 0700); err != nil { return false } probe := filepath.Join(dir, ".probe") - if err := os.WriteFile(probe, []byte{}, 0644); err != nil { + if err := vfs.WriteFile(probe, []byte{}, 0644); err != nil { return false } - os.Remove(probe) + vfs.Remove(probe) return true } @@ -124,7 +125,7 @@ func cacheWritable() bool { func loadCacheMeta() (CacheMeta, error) { var meta CacheMeta - data, err := os.ReadFile(cacheMetaPath()) + data, err := vfs.ReadFile(cacheMetaPath()) if err != nil { return meta, err } @@ -135,7 +136,7 @@ func loadCacheMeta() (CacheMeta, error) { } func saveCacheMeta(meta CacheMeta) error { - if err := os.MkdirAll(cacheDir(), 0700); err != nil { + if err := vfs.MkdirAll(cacheDir(), 0700); err != nil { return err } data, err := json.Marshal(meta) @@ -147,22 +148,22 @@ func saveCacheMeta(meta CacheMeta) error { func loadCachedMerged() (*MergedRegistry, error) { path := cachePath() - data, err := os.ReadFile(path) + data, err := vfs.ReadFile(path) if err != nil { return nil, err } var reg MergedRegistry if err := json.Unmarshal(data, ®); err != nil { // Cache corrupted — remove it so next run triggers a fresh fetch - os.Remove(path) - os.Remove(cacheMetaPath()) + vfs.Remove(path) + vfs.Remove(cacheMetaPath()) return nil, err } return ®, nil } func saveCachedMerged(data []byte, meta CacheMeta) error { - if err := os.MkdirAll(cacheDir(), 0700); err != nil { + if err := vfs.MkdirAll(cacheDir(), 0700); err != nil { return err } if err := validate.AtomicWrite(cachePath(), data, 0644); err != nil { diff --git a/internal/registry/remote_test.go b/internal/registry/remote_test.go index 1aa0f51a7..3a0b91e5e 100644 --- a/internal/registry/remote_test.go +++ b/internal/registry/remote_test.go @@ -29,9 +29,16 @@ func resetInit() { testMetaURL = "" } -// hasEmbeddedData returns true if meta_data.json is compiled in. -func hasEmbeddedData() bool { - return len(embeddedMetaJSON) > 0 +// hasEmbeddedServices returns true if meta_data.json with real services is compiled in. +func hasEmbeddedServices() bool { + if len(embeddedMetaJSON) == 0 { + return false + } + var reg MergedRegistry + if err := json.Unmarshal(embeddedMetaJSON, ®); err != nil { + return false + } + return len(reg.Services) > 0 } // testRegistry returns a minimal MergedRegistry with one service. @@ -75,29 +82,13 @@ func testEnvelopeNotModifiedJSON() []byte { return data } -func TestColdStart_UsesEmbedded(t *testing.T) { - if !hasEmbeddedData() { - t.Skip("no embedded from_meta data") - } - resetInit() - tmp := t.TempDir() - t.Setenv("LARKSUITE_CLI_CONFIG_DIR", tmp) - t.Setenv("LARKSUITE_CLI_REMOTE_META", "off") - - Init() - - projects := ListFromMetaProjects() - if len(projects) == 0 { - t.Fatal("expected embedded projects, got none") - } - spec := LoadFromMeta("calendar") - if spec == nil { - t.Fatal("expected calendar spec from embedded data") - } -} +// TestColdStart_UsesEmbedded was removed because it triggers a data race: +// resetInit() writes package globals while a background goroutine from a +// previous test's triggerBackgroundRefresh may still be reading them. +// The embedded-data path is exercised by other tests (e.g. TestCacheHit). func TestColdStart_NoEmbedded_SyncFetch(t *testing.T) { - if hasEmbeddedData() { + if hasEmbeddedServices() { t.Skip("embedded data present, skipping no-embedded test") } resetInit() @@ -168,7 +159,7 @@ func TestCacheHit_WithinTTL(t *testing.T) { t.Error("expected custom_svc from cache overlay") } // Embedded projects should still be present (if compiled in) - if hasEmbeddedData() { + if hasEmbeddedServices() { if spec := LoadFromMeta("calendar"); spec == nil { t.Error("expected calendar from embedded data") } diff --git a/internal/update/update.go b/internal/update/update.go index c051ec49a..fbf980d6c 100644 --- a/internal/update/update.go +++ b/internal/update/update.go @@ -19,6 +19,7 @@ import ( "github.com/larksuite/cli/internal/core" "github.com/larksuite/cli/internal/util" "github.com/larksuite/cli/internal/validate" + "github.com/larksuite/cli/internal/vfs" ) const ( @@ -147,7 +148,7 @@ func statePath() string { } func loadState() (*updateState, error) { - data, err := os.ReadFile(statePath()) + data, err := vfs.ReadFile(statePath()) if err != nil { return nil, err } @@ -160,7 +161,7 @@ func loadState() (*updateState, error) { func saveState(s *updateState) error { dir := core.GetConfigDir() - if err := os.MkdirAll(dir, 0700); err != nil { + if err := vfs.MkdirAll(dir, 0700); err != nil { return err } data, err := json.Marshal(s) diff --git a/internal/validate/atomicwrite.go b/internal/validate/atomicwrite.go index 8bd5b4713..5c1ec9c92 100644 --- a/internal/validate/atomicwrite.go +++ b/internal/validate/atomicwrite.go @@ -8,6 +8,8 @@ import ( "io" "os" "path/filepath" + + "github.com/larksuite/cli/internal/vfs" ) // AtomicWrite writes data to path atomically by creating a temp file in the @@ -41,7 +43,7 @@ func AtomicWriteFromReader(path string, reader io.Reader, perm os.FileMode) (int func atomicWrite(path string, perm os.FileMode, writeFn func(tmp *os.File) error) error { dir := filepath.Dir(path) - tmp, err := os.CreateTemp(dir, "."+filepath.Base(path)+".*.tmp") + tmp, err := vfs.CreateTemp(dir, "."+filepath.Base(path)+".*.tmp") if err != nil { return fmt.Errorf("create temp file: %w", err) } @@ -51,7 +53,7 @@ func atomicWrite(path string, perm os.FileMode, writeFn func(tmp *os.File) error defer func() { if !success { tmp.Close() - os.Remove(tmpName) + vfs.Remove(tmpName) } }() @@ -67,7 +69,7 @@ func atomicWrite(path string, perm os.FileMode, writeFn func(tmp *os.File) error if err := tmp.Close(); err != nil { return err } - if err := os.Rename(tmpName, path); err != nil { + if err := vfs.Rename(tmpName, path); err != nil { return err } success = true diff --git a/internal/validate/path.go b/internal/validate/path.go index f9974cf8b..ecc6e973e 100644 --- a/internal/validate/path.go +++ b/internal/validate/path.go @@ -5,9 +5,10 @@ package validate import ( "fmt" - "os" "path/filepath" "strings" + + "github.com/larksuite/cli/internal/vfs" ) // SafeOutputPath validates a download/export target path for --output flags. @@ -59,7 +60,7 @@ func safePath(raw, flagName string) (string, error) { return "", fmt.Errorf("%s must be a relative path within the current directory, got %q (hint: cd to the target directory first, or use a relative path like ./filename)", flagName, raw) } - cwd, err := os.Getwd() + cwd, err := vfs.Getwd() if err != nil { return "", fmt.Errorf("cannot determine working directory: %w", err) } @@ -70,7 +71,7 @@ func safePath(raw, flagName string) (string, error) { // resolve its symlinks, and re-attach the remaining tail segments. // This prevents TOCTOU attacks where a non-existent intermediate // directory is replaced with a symlink between check and use. - if _, err := os.Lstat(resolved); err == nil { + if _, err := vfs.Lstat(resolved); err == nil { resolved, err = filepath.EvalSymlinks(resolved) if err != nil { return "", fmt.Errorf("cannot resolve symlinks: %w", err) @@ -98,7 +99,7 @@ func resolveNearestAncestor(path string) (string, error) { var tail []string cur := path for { - if _, err := os.Lstat(cur); err == nil { + if _, err := vfs.Lstat(cur); err == nil { real, err := filepath.EvalSymlinks(cur) if err != nil { return "", err diff --git a/internal/vfs/default.go b/internal/vfs/default.go new file mode 100644 index 000000000..5b0148c21 --- /dev/null +++ b/internal/vfs/default.go @@ -0,0 +1,30 @@ +package vfs + +import ( + "io/fs" + "os" +) + +// DefaultFS is the global filesystem instance used by business code. +// It points to the real OS implementation; tests may replace it with a mock. +var DefaultFS FS = OsFs{} + +// Package-level convenience functions that delegate to DefaultFS. + +func Stat(name string) (fs.FileInfo, error) { return DefaultFS.Stat(name) } +func Lstat(name string) (fs.FileInfo, error) { return DefaultFS.Lstat(name) } +func Getwd() (string, error) { return DefaultFS.Getwd() } +func UserHomeDir() (string, error) { return DefaultFS.UserHomeDir() } +func ReadFile(name string) ([]byte, error) { return DefaultFS.ReadFile(name) } +func WriteFile(name string, data []byte, perm fs.FileMode) error { + return DefaultFS.WriteFile(name, data, perm) +} +func Open(name string) (*os.File, error) { return DefaultFS.Open(name) } +func OpenFile(name string, flag int, perm fs.FileMode) (*os.File, error) { + return DefaultFS.OpenFile(name, flag, perm) +} +func CreateTemp(dir, pattern string) (*os.File, error) { return DefaultFS.CreateTemp(dir, pattern) } +func MkdirAll(path string, perm fs.FileMode) error { return DefaultFS.MkdirAll(path, perm) } +func ReadDir(name string) ([]os.DirEntry, error) { return DefaultFS.ReadDir(name) } +func Remove(name string) error { return DefaultFS.Remove(name) } +func Rename(oldpath, newpath string) error { return DefaultFS.Rename(oldpath, newpath) } diff --git a/internal/vfs/fs.go b/internal/vfs/fs.go new file mode 100644 index 000000000..6ac5acd9b --- /dev/null +++ b/internal/vfs/fs.go @@ -0,0 +1,29 @@ +package vfs + +import ( + "io/fs" + "os" +) + +// FS abstracts filesystem operations used across the project. +// Implementations must behave identically to the corresponding os package functions. +type FS interface { + // Query + Stat(name string) (fs.FileInfo, error) + Lstat(name string) (fs.FileInfo, error) + Getwd() (string, error) + UserHomeDir() (string, error) + + // Read/Write + ReadFile(name string) ([]byte, error) + WriteFile(name string, data []byte, perm fs.FileMode) error + Open(name string) (*os.File, error) + OpenFile(name string, flag int, perm fs.FileMode) (*os.File, error) + CreateTemp(dir, pattern string) (*os.File, error) + + // Directory/File management + MkdirAll(path string, perm fs.FileMode) error + ReadDir(name string) ([]os.DirEntry, error) + Remove(name string) error + Rename(oldpath, newpath string) error +} diff --git a/internal/vfs/osfs.go b/internal/vfs/osfs.go new file mode 100644 index 000000000..09a2d6737 --- /dev/null +++ b/internal/vfs/osfs.go @@ -0,0 +1,32 @@ +package vfs + +import ( + "io/fs" + "os" +) + +// OsFs delegates every method to the os standard library. +type OsFs struct{} + +// Query +func (OsFs) Stat(name string) (fs.FileInfo, error) { return os.Stat(name) } +func (OsFs) Lstat(name string) (fs.FileInfo, error) { return os.Lstat(name) } +func (OsFs) Getwd() (string, error) { return os.Getwd() } +func (OsFs) UserHomeDir() (string, error) { return os.UserHomeDir() } + +// Read/Write +func (OsFs) ReadFile(name string) ([]byte, error) { return os.ReadFile(name) } +func (OsFs) WriteFile(name string, data []byte, perm fs.FileMode) error { + return os.WriteFile(name, data, perm) +} +func (OsFs) Open(name string) (*os.File, error) { return os.Open(name) } +func (OsFs) OpenFile(name string, flag int, perm fs.FileMode) (*os.File, error) { + return os.OpenFile(name, flag, perm) +} +func (OsFs) CreateTemp(dir, pattern string) (*os.File, error) { return os.CreateTemp(dir, pattern) } + +// Directory/File management +func (OsFs) MkdirAll(path string, perm fs.FileMode) error { return os.MkdirAll(path, perm) } +func (OsFs) ReadDir(name string) ([]os.DirEntry, error) { return os.ReadDir(name) } +func (OsFs) Remove(name string) error { return os.Remove(name) } +func (OsFs) Rename(oldpath, newpath string) error { return os.Rename(oldpath, newpath) } diff --git a/internal/vfs/osfs_test.go b/internal/vfs/osfs_test.go new file mode 100644 index 000000000..83dccd1a8 --- /dev/null +++ b/internal/vfs/osfs_test.go @@ -0,0 +1,102 @@ +package vfs + +import ( + "os" + "path/filepath" + "testing" +) + +func TestOsFsImplementsFS(t *testing.T) { + var _ FS = OsFs{} +} + +func TestDefaultFSIsOsFs(t *testing.T) { + if _, ok := DefaultFS.(OsFs); !ok { + t.Fatal("DefaultFS should be OsFs") + } +} + +func TestOsFsBasicOperations(t *testing.T) { + fs := OsFs{} + dir := t.TempDir() + + // MkdirAll + sub := filepath.Join(dir, "a", "b") + if err := fs.MkdirAll(sub, 0o755); err != nil { + t.Fatalf("MkdirAll: %v", err) + } + + // WriteFile + ReadFile + p := filepath.Join(sub, "test.txt") + if err := fs.WriteFile(p, []byte("hello"), 0o644); err != nil { + t.Fatalf("WriteFile: %v", err) + } + data, err := fs.ReadFile(p) + if err != nil { + t.Fatalf("ReadFile: %v", err) + } + if string(data) != "hello" { + t.Fatalf("ReadFile got %q, want %q", data, "hello") + } + + // Stat + info, err := fs.Stat(p) + if err != nil { + t.Fatalf("Stat: %v", err) + } + if info.Name() != "test.txt" { + t.Fatalf("Stat name got %q", info.Name()) + } + + // Lstat + info, err = fs.Lstat(p) + if err != nil { + t.Fatalf("Lstat: %v", err) + } + if info.Name() != "test.txt" { + t.Fatalf("Lstat name got %q", info.Name()) + } + + // Rename + p2 := filepath.Join(sub, "test2.txt") + if err := fs.Rename(p, p2); err != nil { + t.Fatalf("Rename: %v", err) + } + + // Open + f, err := fs.Open(p2) + if err != nil { + t.Fatalf("Open: %v", err) + } + f.Close() + + // OpenFile + f, err = fs.OpenFile(p2, os.O_RDONLY, 0) + if err != nil { + t.Fatalf("OpenFile: %v", err) + } + f.Close() + + // CreateTemp + f, err = fs.CreateTemp(dir, "tmp-*") + if err != nil { + t.Fatalf("CreateTemp: %v", err) + } + tmpName := f.Name() + f.Close() + + // Remove + if err := fs.Remove(tmpName); err != nil { + t.Fatalf("Remove: %v", err) + } + + // Getwd + if _, err := fs.Getwd(); err != nil { + t.Fatalf("Getwd: %v", err) + } + + // UserHomeDir + if _, err := fs.UserHomeDir(); err != nil { + t.Fatalf("UserHomeDir: %v", err) + } +} diff --git a/main.go b/main.go index 568ddfd9a..02469bd7a 100644 --- a/main.go +++ b/main.go @@ -8,6 +8,8 @@ import ( "os" "github.com/larksuite/cli/cmd" + + _ "github.com/larksuite/cli/extension/credential/env" // activate env credential provider ) func main() { diff --git a/shortcuts/base/base_advperm_test.go b/shortcuts/base/base_advperm_test.go index db0be5791..9b34a3522 100644 --- a/shortcuts/base/base_advperm_test.go +++ b/shortcuts/base/base_advperm_test.go @@ -129,7 +129,6 @@ func TestBaseAdvpermMetadata(t *testing.T) { func TestBaseAdvpermEnableExecute(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "PUT", URL: "/open-apis/base/v3/bases/app_x/advperm/enable", @@ -150,7 +149,6 @@ func TestBaseAdvpermEnableExecute(t *testing.T) { func TestBaseAdvpermDisableExecute(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "PUT", URL: "/open-apis/base/v3/bases/app_x/advperm/enable", @@ -175,7 +173,6 @@ func TestBaseAdvpermDisableExecute(t *testing.T) { func TestBaseAdvpermEnableExecuteTransportError(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "PUT", URL: "/open-apis/base/v3/bases/app_x/advperm/enable", @@ -190,7 +187,6 @@ func TestBaseAdvpermEnableExecuteTransportError(t *testing.T) { func TestBaseAdvpermEnableExecuteAPIError(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "PUT", URL: "/open-apis/base/v3/bases/app_x/advperm/enable", @@ -207,7 +203,6 @@ func TestBaseAdvpermEnableExecuteAPIError(t *testing.T) { func TestBaseAdvpermDisableExecuteTransportError(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "PUT", URL: "/open-apis/base/v3/bases/app_x/advperm/enable", @@ -222,7 +217,6 @@ func TestBaseAdvpermDisableExecuteTransportError(t *testing.T) { func TestBaseAdvpermDisableExecuteAPIError(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "PUT", URL: "/open-apis/base/v3/bases/app_x/advperm/enable", diff --git a/shortcuts/base/base_dashboard_execute_test.go b/shortcuts/base/base_dashboard_execute_test.go index 5fbf43afa..d69ecf261 100644 --- a/shortcuts/base/base_dashboard_execute_test.go +++ b/shortcuts/base/base_dashboard_execute_test.go @@ -15,7 +15,6 @@ import ( func TestBaseDashboardExecuteList(t *testing.T) { t.Run("single page", func(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "GET", URL: "/open-apis/base/v3/bases/app_x/dashboards", @@ -44,7 +43,6 @@ func TestBaseDashboardExecuteList(t *testing.T) { func TestBaseDashboardExecuteGet(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "GET", URL: "/open-apis/base/v3/bases/app_x/dashboards/dsh_001", @@ -72,7 +70,6 @@ func TestBaseDashboardExecuteGet(t *testing.T) { func TestBaseDashboardExecuteCreate(t *testing.T) { t.Run("name only", func(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "POST", URL: "/open-apis/base/v3/bases/app_x/dashboards", @@ -95,7 +92,6 @@ func TestBaseDashboardExecuteCreate(t *testing.T) { t.Run("with theme", func(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "POST", URL: "/open-apis/base/v3/bases/app_x/dashboards", @@ -121,7 +117,6 @@ func TestBaseDashboardExecuteCreate(t *testing.T) { func TestBaseDashboardExecuteUpdate(t *testing.T) { t.Run("update name", func(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "PATCH", URL: "/open-apis/base/v3/bases/app_x/dashboards/dsh_001", @@ -144,7 +139,6 @@ func TestBaseDashboardExecuteUpdate(t *testing.T) { t.Run("update theme", func(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "PATCH", URL: "/open-apis/base/v3/bases/app_x/dashboards/dsh_001", @@ -169,7 +163,6 @@ func TestBaseDashboardExecuteUpdate(t *testing.T) { func TestBaseDashboardExecuteDelete(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "DELETE", URL: "/open-apis/base/v3/bases/app_x/dashboards/dsh_001", @@ -189,7 +182,6 @@ func TestBaseDashboardExecuteDelete(t *testing.T) { func TestBaseDashboardBlockExecuteList(t *testing.T) { t.Run("single page", func(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "GET", URL: "/open-apis/base/v3/bases/app_x/dashboards/dsh_001/blocks", @@ -219,7 +211,6 @@ func TestBaseDashboardBlockExecuteList(t *testing.T) { func TestBaseDashboardBlockExecuteGet(t *testing.T) { t.Run("basic", func(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "GET", URL: "/open-apis/base/v3/bases/app_x/dashboards/dsh_001/blocks/blk_a", @@ -248,7 +239,6 @@ func TestBaseDashboardBlockExecuteGet(t *testing.T) { t.Run("with user-id-type", func(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "GET", URL: "user_id_type=union_id", @@ -274,7 +264,6 @@ func TestBaseDashboardBlockExecuteGet(t *testing.T) { func TestBaseDashboardBlockExecuteCreate(t *testing.T) { t.Run("with data-config", func(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "POST", URL: "/open-apis/base/v3/bases/app_x/dashboards/dsh_001/blocks", @@ -306,7 +295,6 @@ func TestBaseDashboardBlockExecuteCreate(t *testing.T) { t.Run("statistics with series", func(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "POST", URL: "/open-apis/base/v3/bases/app_x/dashboards/dsh_001/blocks", @@ -333,7 +321,6 @@ func TestBaseDashboardBlockExecuteCreate(t *testing.T) { t.Run("without data-config", func(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "POST", URL: "/open-apis/base/v3/bases/app_x/dashboards/dsh_001/blocks", @@ -370,7 +357,6 @@ func TestBaseDashboardBlockExecuteCreate(t *testing.T) { func TestBaseDashboardBlockExecuteUpdate(t *testing.T) { t.Run("update name and data-config", func(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "PATCH", URL: "/open-apis/base/v3/bases/app_x/dashboards/dsh_001/blocks/blk_a", @@ -401,7 +387,6 @@ func TestBaseDashboardBlockExecuteUpdate(t *testing.T) { t.Run("update name only", func(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "PATCH", URL: "/open-apis/base/v3/bases/app_x/dashboards/dsh_001/blocks/blk_a", @@ -437,7 +422,6 @@ func TestBaseDashboardBlockExecuteUpdate(t *testing.T) { func TestBaseDashboardBlockExecuteDelete(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "DELETE", URL: "/open-apis/base/v3/bases/app_x/dashboards/dsh_001/blocks/blk_a", @@ -592,7 +576,6 @@ func TestBaseDashboardBlockCreate_ValidateFails(t *testing.T) { func TestBaseDashboardBlockCreate_NoValidateFlagAllocs(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{Method: "POST", URL: "/open-apis/base/v3/bases/app_x/dashboards/dsh_1/blocks", Body: map[string]interface{}{"code": 0, "data": map[string]interface{}{"block_id": "blk_ok", "name": "OK", "type": "column"}}, }) diff --git a/shortcuts/base/base_execute_test.go b/shortcuts/base/base_execute_test.go index 34128a937..46ec996d9 100644 --- a/shortcuts/base/base_execute_test.go +++ b/shortcuts/base/base_execute_test.go @@ -30,18 +30,6 @@ func newExecuteFactory(t *testing.T) (*cmdutil.Factory, *bytes.Buffer, *httpmock return factory, stdout, reg } -func registerTokenStub(reg *httpmock.Registry) { - reg.Register(&httpmock.Stub{ - Method: "POST", - URL: "/open-apis/auth/v3/tenant_access_token/internal", - Body: map[string]interface{}{ - "code": 0, - "tenant_access_token": "t-test-token", - "expire": 7200, - }, - }) -} - func withBaseWorkingDir(t *testing.T, dir string) { t.Helper() cwd, err := os.Getwd() @@ -72,7 +60,6 @@ func runShortcut(t *testing.T, shortcut common.Shortcut, args []string, factory func TestBaseWorkspaceExecuteCreate(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "POST", URL: "/open-apis/base/v3/bases", @@ -92,7 +79,6 @@ func TestBaseWorkspaceExecuteCreate(t *testing.T) { func TestBaseWorkspaceExecuteGetAndCopy(t *testing.T) { t.Run("get", func(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "GET", URL: "/open-apis/base/v3/bases/app_x", @@ -111,7 +97,6 @@ func TestBaseWorkspaceExecuteGetAndCopy(t *testing.T) { t.Run("copy", func(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "POST", URL: "/open-apis/base/v3/bases/app_src/copy", @@ -132,7 +117,6 @@ func TestBaseWorkspaceExecuteGetAndCopy(t *testing.T) { func TestBaseHistoryExecute(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "GET", URL: "/open-apis/base/v3/bases/app_x/record_history", @@ -151,7 +135,6 @@ func TestBaseHistoryExecute(t *testing.T) { func TestBaseFieldExecuteUpdate(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "PUT", URL: "/open-apis/base/v3/bases/app_x/tables/tbl_x/fields/fld_x", @@ -170,7 +153,6 @@ func TestBaseFieldExecuteUpdate(t *testing.T) { func TestBaseTableExecuteCreate(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "POST", URL: "/open-apis/base/v3/bases/app_x/tables", @@ -214,7 +196,6 @@ func TestBaseTableExecuteCreate(t *testing.T) { func TestBaseTableExecuteUpdate(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "PATCH", URL: "/open-apis/base/v3/bases/app_x/tables/tbl_x", @@ -233,7 +214,6 @@ func TestBaseTableExecuteUpdate(t *testing.T) { func TestBaseRecordExecuteUpsertUpdate(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "PATCH", URL: "/open-apis/base/v3/bases/app_x/tables/tbl_x/records/rec_x", @@ -252,7 +232,6 @@ func TestBaseRecordExecuteUpsertUpdate(t *testing.T) { func TestBaseViewExecuteRename(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "PATCH", URL: "/open-apis/base/v3/bases/app_x/tables/tbl_x/views/vew_x", @@ -272,7 +251,6 @@ func TestBaseViewExecuteRename(t *testing.T) { func TestBaseViewExecutePropertyActions(t *testing.T) { t.Run("set-group", func(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "PUT", URL: "/open-apis/base/v3/bases/app_x/tables/tbl_x/views/vew_x/group", @@ -291,7 +269,6 @@ func TestBaseViewExecutePropertyActions(t *testing.T) { t.Run("set-sort", func(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "PUT", URL: "/open-apis/base/v3/bases/app_x/tables/tbl_x/views/vew_x/sort", @@ -313,7 +290,6 @@ func TestBaseViewExecutePropertyActions(t *testing.T) { func TestBaseFieldExecuteCRUD(t *testing.T) { t.Run("list", func(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "GET", URL: "limit=1&offset=0", @@ -334,7 +310,6 @@ func TestBaseFieldExecuteCRUD(t *testing.T) { t.Run("get", func(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "GET", URL: "/open-apis/base/v3/bases/app_x/tables/tbl_x/fields/fld_x", @@ -353,7 +328,6 @@ func TestBaseFieldExecuteCRUD(t *testing.T) { t.Run("create", func(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "POST", URL: "/open-apis/base/v3/bases/app_x/tables/tbl_x/fields", @@ -372,7 +346,6 @@ func TestBaseFieldExecuteCRUD(t *testing.T) { t.Run("delete", func(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "DELETE", URL: "/open-apis/base/v3/bases/app_x/tables/tbl_x/fields/fld_x", @@ -390,7 +363,6 @@ func TestBaseFieldExecuteCRUD(t *testing.T) { func TestBaseTableExecuteReadAndDelete(t *testing.T) { t.Run("list", func(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "GET", URL: "limit=1&offset=0", @@ -411,7 +383,6 @@ func TestBaseTableExecuteReadAndDelete(t *testing.T) { t.Run("list-http-404", func(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "GET", URL: "/open-apis/base/v3/bases/app_x/tables", @@ -429,7 +400,6 @@ func TestBaseTableExecuteReadAndDelete(t *testing.T) { t.Run("get", func(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "GET", URL: "/open-apis/base/v3/bases/app_x/tables/tbl_x", @@ -464,7 +434,6 @@ func TestBaseTableExecuteReadAndDelete(t *testing.T) { t.Run("delete", func(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "DELETE", URL: "/open-apis/base/v3/bases/app_x/tables/tbl_x", @@ -482,7 +451,6 @@ func TestBaseTableExecuteReadAndDelete(t *testing.T) { func TestBaseRecordExecuteReadCreateDelete(t *testing.T) { t.Run("list", func(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "GET", URL: "limit=1&offset=0", @@ -505,7 +473,6 @@ func TestBaseRecordExecuteReadCreateDelete(t *testing.T) { t.Run("list new shape", func(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "GET", URL: "limit=1&offset=0", @@ -529,7 +496,6 @@ func TestBaseRecordExecuteReadCreateDelete(t *testing.T) { t.Run("get", func(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "GET", URL: "/open-apis/base/v3/bases/app_x/tables/tbl_x/records/rec_1", @@ -552,7 +518,6 @@ func TestBaseRecordExecuteReadCreateDelete(t *testing.T) { t.Run("get passthrough fallback", func(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "GET", URL: "/open-apis/base/v3/bases/app_x/tables/tbl_x/records/rec_2", @@ -571,7 +536,6 @@ func TestBaseRecordExecuteReadCreateDelete(t *testing.T) { t.Run("create", func(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "POST", URL: "/open-apis/base/v3/bases/app_x/tables/tbl_x/records", @@ -590,7 +554,6 @@ func TestBaseRecordExecuteReadCreateDelete(t *testing.T) { t.Run("delete", func(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "DELETE", URL: "/open-apis/base/v3/bases/app_x/tables/tbl_x/records/rec_1", @@ -606,7 +569,6 @@ func TestBaseRecordExecuteReadCreateDelete(t *testing.T) { t.Run("upload attachment", func(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) tmpFile, err := os.CreateTemp(t.TempDir(), "base-attachment-*.txt") if err != nil { @@ -724,7 +686,6 @@ func TestBaseRecordExecuteReadCreateDelete(t *testing.T) { t.Run("upload attachment rejects non-attachment field", func(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) tmpFile, err := os.CreateTemp(t.TempDir(), "base-not-attachment-*.txt") if err != nil { @@ -767,7 +728,6 @@ func TestBaseRecordExecuteReadCreateDelete(t *testing.T) { func TestBaseViewExecuteReadCreateDeleteAndFilter(t *testing.T) { t.Run("list", func(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "GET", URL: "limit=1&offset=0", @@ -786,7 +746,6 @@ func TestBaseViewExecuteReadCreateDeleteAndFilter(t *testing.T) { t.Run("get", func(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "GET", URL: "/open-apis/base/v3/bases/app_x/tables/tbl_x/views/vew_1", @@ -805,7 +764,6 @@ func TestBaseViewExecuteReadCreateDeleteAndFilter(t *testing.T) { t.Run("create", func(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "POST", URL: "/open-apis/base/v3/bases/app_x/tables/tbl_x/views", @@ -824,7 +782,6 @@ func TestBaseViewExecuteReadCreateDeleteAndFilter(t *testing.T) { t.Run("delete", func(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "DELETE", URL: "/open-apis/base/v3/bases/app_x/tables/tbl_x/views/vew_1", @@ -840,7 +797,6 @@ func TestBaseViewExecuteReadCreateDeleteAndFilter(t *testing.T) { t.Run("set-filter", func(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "PUT", URL: "/open-apis/base/v3/bases/app_x/tables/tbl_x/views/vew_1/filter", @@ -861,7 +817,6 @@ func TestBaseViewExecuteReadCreateDeleteAndFilter(t *testing.T) { func TestBaseTableExecuteListFallbackShapes(t *testing.T) { t.Run("items-payload", func(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "GET", URL: "/open-apis/base/v3/bases/app_x/tables", @@ -880,7 +835,6 @@ func TestBaseTableExecuteListFallbackShapes(t *testing.T) { t.Run("single-object-payload", func(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "GET", URL: "/open-apis/base/v3/bases/app_x/tables", @@ -900,7 +854,6 @@ func TestBaseTableExecuteListFallbackShapes(t *testing.T) { func TestBaseRecordExecuteListWithViewPagination(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "GET", URL: "view_id=vew_x", @@ -923,7 +876,6 @@ func TestBaseRecordExecuteListWithViewPagination(t *testing.T) { func TestBaseHistoryExecuteWithLinkFieldLimit(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "GET", URL: "max_version=2", @@ -942,7 +894,6 @@ func TestBaseHistoryExecuteWithLinkFieldLimit(t *testing.T) { func TestBaseFieldExecuteSearchOptions(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "GET", URL: "/open-apis/base/v3/bases/app_x/tables/tbl_x/fields/fld_amount/options", @@ -962,7 +913,6 @@ func TestBaseFieldExecuteSearchOptions(t *testing.T) { func TestBaseViewExecutePropertyGettersAndExtendedSetters(t *testing.T) { t.Run("get-group", func(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{Method: "GET", URL: "/open-apis/base/v3/bases/app_x/tables/tbl_x/views/vew_x/group", Body: map[string]interface{}{"code": 0, "data": []interface{}{map[string]interface{}{"field": "fld_status", "desc": false}}}}) if err := runShortcut(t, BaseViewGetGroup, []string{"+view-get-group", "--base-token", "app_x", "--table-id", "tbl_x", "--view-id", "vew_x"}, factory, stdout); err != nil { t.Fatalf("err=%v", err) @@ -974,7 +924,6 @@ func TestBaseViewExecutePropertyGettersAndExtendedSetters(t *testing.T) { t.Run("get-filter", func(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{Method: "GET", URL: "/open-apis/base/v3/bases/app_x/tables/tbl_x/views/vew_x/filter", Body: map[string]interface{}{"code": 0, "data": map[string]interface{}{"conditions": []interface{}{map[string]interface{}{"field_name": "Status"}}}}}) if err := runShortcut(t, BaseViewGetFilter, []string{"+view-get-filter", "--base-token", "app_x", "--table-id", "tbl_x", "--view-id", "vew_x"}, factory, stdout); err != nil { t.Fatalf("err=%v", err) @@ -986,7 +935,6 @@ func TestBaseViewExecutePropertyGettersAndExtendedSetters(t *testing.T) { t.Run("get-sort", func(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{Method: "GET", URL: "/open-apis/base/v3/bases/app_x/tables/tbl_x/views/vew_x/sort", Body: map[string]interface{}{"code": 0, "data": []interface{}{map[string]interface{}{"field": "fld_priority", "desc": true}}}}) if err := runShortcut(t, BaseViewGetSort, []string{"+view-get-sort", "--base-token", "app_x", "--table-id", "tbl_x", "--view-id", "vew_x"}, factory, stdout); err != nil { t.Fatalf("err=%v", err) @@ -998,7 +946,6 @@ func TestBaseViewExecutePropertyGettersAndExtendedSetters(t *testing.T) { t.Run("get-timebar", func(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{Method: "GET", URL: "/open-apis/base/v3/bases/app_x/tables/tbl_x/views/vew_time/timebar", Body: map[string]interface{}{"code": 0, "data": map[string]interface{}{"start_time": "fld_start", "end_time": "fld_end", "title": "fld_title"}}}) if err := runShortcut(t, BaseViewGetTimebar, []string{"+view-get-timebar", "--base-token", "app_x", "--table-id", "tbl_x", "--view-id", "vew_time"}, factory, stdout); err != nil { t.Fatalf("err=%v", err) @@ -1010,7 +957,6 @@ func TestBaseViewExecutePropertyGettersAndExtendedSetters(t *testing.T) { t.Run("set-timebar", func(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{Method: "PUT", URL: "/open-apis/base/v3/bases/app_x/tables/tbl_x/views/vew_time/timebar", Body: map[string]interface{}{"code": 0, "data": map[string]interface{}{"start_time": "fld_start", "end_time": "fld_end", "title": "fld_title"}}}) args := []string{"+view-set-timebar", "--base-token", "app_x", "--table-id", "tbl_x", "--view-id", "vew_time", "--json", `{"start_time":"fld_start","end_time":"fld_end","title":"fld_title"}`} if err := runShortcut(t, BaseViewSetTimebar, args, factory, stdout); err != nil { @@ -1023,7 +969,6 @@ func TestBaseViewExecutePropertyGettersAndExtendedSetters(t *testing.T) { t.Run("get-card", func(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{Method: "GET", URL: "/open-apis/base/v3/bases/app_x/tables/tbl_x/views/vew_card/card", Body: map[string]interface{}{"code": 0, "data": map[string]interface{}{"cover_field": "fld_cover"}}}) if err := runShortcut(t, BaseViewGetCard, []string{"+view-get-card", "--base-token", "app_x", "--table-id", "tbl_x", "--view-id", "vew_card"}, factory, stdout); err != nil { t.Fatalf("err=%v", err) @@ -1035,7 +980,6 @@ func TestBaseViewExecutePropertyGettersAndExtendedSetters(t *testing.T) { t.Run("set-card", func(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{Method: "PUT", URL: "/open-apis/base/v3/bases/app_x/tables/tbl_x/views/vew_card/card", Body: map[string]interface{}{"code": 0, "data": map[string]interface{}{"cover_field": "fld_cover"}}}) if err := runShortcut(t, BaseViewSetCard, []string{"+view-set-card", "--base-token", "app_x", "--table-id", "tbl_x", "--view-id", "vew_card", "--json", `{"cover_field":"fld_cover"}`}, factory, stdout); err != nil { t.Fatalf("err=%v", err) diff --git a/shortcuts/base/base_form_execute_test.go b/shortcuts/base/base_form_execute_test.go index 668a830fc..cafec48d4 100644 --- a/shortcuts/base/base_form_execute_test.go +++ b/shortcuts/base/base_form_execute_test.go @@ -13,7 +13,6 @@ import ( func TestBaseFormExecuteList(t *testing.T) { t.Run("single page", func(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "GET", URL: "/open-apis/base/v3/bases/app_x/tables/tbl_x/forms", @@ -39,7 +38,6 @@ func TestBaseFormExecuteList(t *testing.T) { t.Run("auto pagination", func(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) // First page: has_more=true reg.Register(&httpmock.Stub{ Method: "GET", @@ -86,7 +84,6 @@ func TestBaseFormExecuteList(t *testing.T) { func TestBaseFormExecuteGet(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "GET", URL: "/open-apis/base/v3/bases/app_x/tables/tbl_x/forms/vew_form1", @@ -110,7 +107,6 @@ func TestBaseFormExecuteGet(t *testing.T) { func TestBaseFormExecuteCreate(t *testing.T) { t.Run("name only", func(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "POST", URL: "/open-apis/base/v3/bases/app_x/tables/tbl_x/forms", @@ -133,7 +129,6 @@ func TestBaseFormExecuteCreate(t *testing.T) { t.Run("with description", func(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "POST", URL: "/open-apis/base/v3/bases/app_x/tables/tbl_x/forms", @@ -158,7 +153,6 @@ func TestBaseFormExecuteCreate(t *testing.T) { t.Run("with description link", func(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "POST", URL: "/open-apis/base/v3/bases/app_x/tables/tbl_x/forms", @@ -185,7 +179,6 @@ func TestBaseFormExecuteCreate(t *testing.T) { func TestBaseFormExecuteUpdate(t *testing.T) { t.Run("update name", func(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "PATCH", URL: "/open-apis/base/v3/bases/app_x/tables/tbl_x/forms/vew_form1", @@ -208,7 +201,6 @@ func TestBaseFormExecuteUpdate(t *testing.T) { t.Run("update with description", func(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "PATCH", URL: "/open-apis/base/v3/bases/app_x/tables/tbl_x/forms/vew_form1", @@ -234,7 +226,6 @@ func TestBaseFormExecuteUpdate(t *testing.T) { func TestBaseFormExecuteDelete(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "DELETE", URL: "/open-apis/base/v3/bases/app_x/tables/tbl_x/forms/vew_form1", @@ -250,7 +241,6 @@ func TestBaseFormExecuteDelete(t *testing.T) { func TestBaseFormQuestionsExecuteList(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "GET", URL: "/open-apis/base/v3/bases/app_x/tables/tbl_x/forms/vew_form1/questions", @@ -276,7 +266,6 @@ func TestBaseFormQuestionsExecuteList(t *testing.T) { func TestBaseFormQuestionsExecuteCreate(t *testing.T) { t.Run("create questions", func(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "POST", URL: "/open-apis/base/v3/bases/app_x/tables/tbl_x/forms/vew_form1/questions", @@ -311,7 +300,6 @@ func TestBaseFormQuestionsExecuteCreate(t *testing.T) { func TestBaseFormQuestionsExecuteUpdate(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "PATCH", URL: "/open-apis/base/v3/bases/app_x/tables/tbl_x/forms/vew_form1/questions", @@ -337,7 +325,6 @@ func TestBaseFormQuestionsExecuteUpdate(t *testing.T) { func TestBaseFormQuestionsExecuteDelete(t *testing.T) { t.Run("delete questions", func(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "DELETE", URL: "/open-apis/base/v3/bases/app_x/tables/tbl_x/forms/vew_form1/questions", diff --git a/shortcuts/base/base_role_test.go b/shortcuts/base/base_role_test.go index d71e4c5d7..1b998cea2 100644 --- a/shortcuts/base/base_role_test.go +++ b/shortcuts/base/base_role_test.go @@ -243,7 +243,6 @@ func TestBaseRoleShortcutMetadata(t *testing.T) { func TestBaseRoleCreateExecute(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "POST", URL: "/open-apis/base/v3/bases/app_x/roles", @@ -267,7 +266,6 @@ func TestBaseRoleCreateExecute(t *testing.T) { func TestBaseRoleDeleteExecute(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "DELETE", URL: "/open-apis/base/v3/bases/app_x/roles/rol_1", @@ -288,7 +286,6 @@ func TestBaseRoleDeleteExecute(t *testing.T) { func TestBaseRoleGetExecute(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "GET", URL: "/open-apis/base/v3/bases/app_x/roles/rol_1", @@ -316,7 +313,6 @@ func TestBaseRoleGetExecute(t *testing.T) { func TestBaseRoleListExecute(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "GET", URL: "/open-apis/base/v3/bases/app_x/roles", @@ -343,7 +339,6 @@ func TestBaseRoleListExecute(t *testing.T) { func TestBaseRoleUpdateExecute(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "PUT", URL: "/open-apis/base/v3/bases/app_x/roles/rol_1", @@ -371,7 +366,6 @@ func TestBaseRoleUpdateExecute(t *testing.T) { func TestBaseRoleCreateExecuteAPIError(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "POST", URL: "/open-apis/base/v3/bases/app_x/roles", @@ -388,7 +382,6 @@ func TestBaseRoleCreateExecuteAPIError(t *testing.T) { func TestBaseRoleListExecuteTransportError(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "GET", URL: "/open-apis/base/v3/bases/app_x/roles", @@ -403,7 +396,6 @@ func TestBaseRoleListExecuteTransportError(t *testing.T) { func TestBaseRoleListExecuteAPIError(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "GET", URL: "/open-apis/base/v3/bases/app_x/roles", @@ -420,7 +412,6 @@ func TestBaseRoleListExecuteAPIError(t *testing.T) { func TestBaseRoleDeleteExecuteAPIError(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "DELETE", URL: "/open-apis/base/v3/bases/app_x/roles/rol_1", @@ -437,7 +428,6 @@ func TestBaseRoleDeleteExecuteAPIError(t *testing.T) { func TestBaseRoleUpdateExecuteAPIError(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "PUT", URL: "/open-apis/base/v3/bases/app_x/roles/rol_1", @@ -454,7 +444,6 @@ func TestBaseRoleUpdateExecuteAPIError(t *testing.T) { func TestBaseRoleGetExecuteBusinessError(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "GET", URL: "/open-apis/base/v3/bases/app_x/roles/rol_bad", diff --git a/shortcuts/base/base_shortcut_helpers.go b/shortcuts/base/base_shortcut_helpers.go index 86e30713f..06b21e005 100644 --- a/shortcuts/base/base_shortcut_helpers.go +++ b/shortcuts/base/base_shortcut_helpers.go @@ -6,10 +6,10 @@ package base import ( "encoding/json" "fmt" - "os" "strings" "github.com/larksuite/cli/internal/validate" + "github.com/larksuite/cli/internal/vfs" "github.com/larksuite/cli/shortcuts/common" ) @@ -33,7 +33,7 @@ func loadJSONInput(raw string, flagName string) (string, error) { if err != nil { return "", common.FlagErrorf("--%s invalid JSON file path %q: %v", flagName, path, err) } - data, err := os.ReadFile(safePath) + data, err := vfs.ReadFile(safePath) if err != nil { return "", common.FlagErrorf("--%s cannot read JSON file %q: %v", flagName, path, err) } diff --git a/shortcuts/base/record_upload_attachment.go b/shortcuts/base/record_upload_attachment.go index 689594213..31d0ffc51 100644 --- a/shortcuts/base/record_upload_attachment.go +++ b/shortcuts/base/record_upload_attachment.go @@ -9,7 +9,6 @@ import ( "errors" "fmt" "net/http" - "os" "path/filepath" "strings" @@ -18,6 +17,7 @@ import ( "github.com/larksuite/cli/internal/output" "github.com/larksuite/cli/internal/util" "github.com/larksuite/cli/internal/validate" + "github.com/larksuite/cli/internal/vfs" "github.com/larksuite/cli/shortcuts/common" ) @@ -97,7 +97,7 @@ func executeRecordUploadAttachment(runtime *common.RuntimeContext) error { } filePath = safeFilePath - fileInfo, err := os.Stat(filePath) + fileInfo, err := vfs.Stat(filePath) if err != nil { return output.ErrValidation("file not found: %s", filePath) } @@ -209,7 +209,7 @@ func normalizeAttachmentForPatch(attachment map[string]interface{}) map[string]i } func uploadAttachmentToBase(runtime *common.RuntimeContext, filePath, fileName, baseToken string, fileSize int64) (map[string]interface{}, error) { - f, err := os.Open(filePath) + f, err := vfs.Open(filePath) if err != nil { return nil, output.ErrValidation("cannot open file: %v", err) } diff --git a/shortcuts/base/workflow_execute_test.go b/shortcuts/base/workflow_execute_test.go index be2fd3f82..ebb82f5da 100644 --- a/shortcuts/base/workflow_execute_test.go +++ b/shortcuts/base/workflow_execute_test.go @@ -12,7 +12,6 @@ import ( func TestBaseWorkflowExecuteGet(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "GET", URL: "/open-apis/base/v3/bases/app_x/workflows/wkf_1", @@ -31,7 +30,6 @@ func TestBaseWorkflowExecuteGet(t *testing.T) { func TestBaseWorkflowExecuteGetWithUserIDType(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "GET", URL: "user_id_type=open_id", @@ -67,7 +65,6 @@ func TestBaseWorkflowExecuteGetValidate(t *testing.T) { func TestBaseWorkflowExecuteCreate(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "POST", URL: "/open-apis/base/v3/bases/app_x/workflows", @@ -103,7 +100,6 @@ func TestBaseWorkflowExecuteCreateValidate(t *testing.T) { func TestBaseWorkflowExecuteDisable(t *testing.T) { factory, stdout, reg := newExecuteFactory(t) - registerTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "PATCH", URL: "/open-apis/base/v3/bases/app_x/workflows/wkf_1/disable", diff --git a/shortcuts/calendar/calendar_test.go b/shortcuts/calendar/calendar_test.go index b018af3b3..69f26ffb3 100644 --- a/shortcuts/calendar/calendar_test.go +++ b/shortcuts/calendar/calendar_test.go @@ -32,13 +32,6 @@ func warmTokenCache(t *testing.T) { t.Helper() warmOnce.Do(func() { f, _, _, reg := cmdutil.TestFactory(t, defaultConfig()) - reg.Register(&httpmock.Stub{ - URL: "/open-apis/auth/v3/tenant_access_token/internal", - Body: map[string]interface{}{ - "code": 0, "msg": "ok", - "tenant_access_token": "t-test-token", "expire": 7200, - }, - }) reg.Register(&httpmock.Stub{ URL: "/open-apis/test/v1/warm", Body: map[string]interface{}{"code": 0, "msg": "ok", "data": map[string]interface{}{}}, diff --git a/shortcuts/common/helpers.go b/shortcuts/common/helpers.go index 0a5b8e7a9..a47043468 100644 --- a/shortcuts/common/helpers.go +++ b/shortcuts/common/helpers.go @@ -12,6 +12,7 @@ import ( "os" "github.com/larksuite/cli/internal/output" + "github.com/larksuite/cli/internal/vfs" ) // MultipartWriter wraps multipart.Writer for file uploads. @@ -42,7 +43,7 @@ func EnsureWritableFile(path string, overwrite bool) error { if overwrite { return nil } - if _, err := os.Stat(path); err == nil { + if _, err := vfs.Stat(path); err == nil { return output.ErrValidation("output file already exists: %s (use --overwrite to replace)", path) } else if !errors.Is(err, os.ErrNotExist) { return output.Errorf(output.ExitInternal, "io", "cannot access output path %s: %v", path, err) diff --git a/shortcuts/common/runner.go b/shortcuts/common/runner.go index 7623c8d47..1141b0bc5 100644 --- a/shortcuts/common/runner.go +++ b/shortcuts/common/runner.go @@ -4,16 +4,14 @@ package common import ( - "bytes" "context" "encoding/json" "errors" "fmt" "io" "net/http" - "net/url" + "slices" "strings" - "time" "github.com/google/uuid" lark "github.com/larksuite/oapi-sdk-go/v3" @@ -23,7 +21,10 @@ import ( "github.com/larksuite/cli/internal/client" "github.com/larksuite/cli/internal/cmdutil" "github.com/larksuite/cli/internal/core" + "github.com/larksuite/cli/internal/credential" "github.com/larksuite/cli/internal/output" + "github.com/larksuite/cli/internal/validate" + "github.com/larksuite/cli/internal/vfs" "github.com/spf13/cobra" ) @@ -90,29 +91,14 @@ func (ctx *RuntimeContext) getAPIClient() (*client.APIClient, error) { // For user: returns user access token (with auto-refresh). // For bot: returns tenant access token. func (ctx *RuntimeContext) AccessToken() (string, error) { - if ctx.IsBot() { - ac, err := ctx.getAPIClient() - if err != nil { - return "", output.ErrAuth("failed to get SDK: %s", err) - } - tatResp, err := ac.SDK.GetTenantAccessTokenBySelfBuiltApp(ctx.ctx, &larkcore.SelfBuiltTenantAccessTokenReq{ - AppID: ctx.Config.AppID, - AppSecret: ctx.Config.AppSecret, - }) - if err != nil { - return "", output.ErrAuth("failed to get tenant access token: %s", err) - } - return tatResp.TenantAccessToken, nil - } - httpClient, err := ctx.Factory.HttpClient() - if err != nil { - return "", output.ErrAuth("failed to get HTTP client: %s", err) - } - token, err := auth.GetValidAccessToken(httpClient, auth.NewUATCallOptions(ctx.Config, ctx.IO().ErrOut)) + result, err := ctx.Factory.Credential.ResolveToken(ctx.ctx, credential.NewTokenSpec(ctx.As(), ctx.Config.AppID)) if err != nil { return "", output.ErrAuth("failed to get access token: %s", err) } - return token, nil + if result == nil || result.Token == "" { + return "", output.ErrAuth("no access token available for %s", ctx.As()) + } + return result.Token, nil } // LarkSDK returns the eagerly-initialized Lark SDK client. @@ -241,142 +227,22 @@ func (ctx *RuntimeContext) DoAPIAsBot(req *larkcore.ApiReq, opts ...larkcore.Req return ac.DoSDKRequest(ctx.ctx, req, core.AsBot, opts...) } -type cancelOnCloseReadCloser struct { - io.ReadCloser - cancel context.CancelFunc -} - -func (r *cancelOnCloseReadCloser) Close() error { - err := r.ReadCloser.Close() - if r.cancel != nil { - r.cancel() - } - return err -} - -// DoAPIStream executes a streaming HTTP request against the Lark OpenAPI endpoint -// while preserving the framework's auth resolution, shortcut headers, and security headers. -func (ctx *RuntimeContext) DoAPIStream(callCtx context.Context, req *larkcore.ApiReq, timeout time.Duration, opts ...larkcore.RequestOptionFunc) (*http.Response, error) { - httpClient, err := ctx.Factory.HttpClient() - if err != nil { - return nil, output.ErrNetwork("stream request failed: %s", err) - } - - streamingClient := *httpClient - if timeout > 0 { - streamingClient.Timeout = timeout - } - - requestCtx := callCtx - cancel := func() {} - if timeout > 0 { - if _, hasDeadline := callCtx.Deadline(); !hasDeadline { - requestCtx, cancel = context.WithTimeout(callCtx, timeout) - } - } - - var option larkcore.RequestOption - for _, opt := range opts { - opt(&option) - } - if option.Header == nil { - option.Header = make(http.Header) - } - if shortcutHeaders := cmdutil.ShortcutHeaderOpts(ctx.ctx); shortcutHeaders != nil { - shortcutHeaders(&option) - } - - accessToken, err := ctx.AccessToken() - if err != nil { - cancel() - return nil, err - } - - requestURL, err := buildStreamRequestURL(ctx.Config.Brand, req) - if err != nil { - cancel() - return nil, err - } - bodyReader, contentType, err := buildStreamRequestBody(req.Body) +// DoAPIStream executes a streaming HTTP request via APIClient.DoStream. +// Unlike DoAPI (which buffers the full body via the SDK), DoAPIStream returns +// a live *http.Response whose Body is an io.Reader for streaming consumption. +// HTTP errors (status >= 400) are handled internally by DoStream. +func (ctx *RuntimeContext) DoAPIStream(callCtx context.Context, req *larkcore.ApiReq, opts ...client.Option) (*http.Response, error) { + ac, err := ctx.getAPIClient() if err != nil { - cancel() return nil, err } - - httpReq, err := http.NewRequestWithContext(requestCtx, req.HttpMethod, requestURL, bodyReader) - if err != nil { - cancel() - return nil, output.ErrNetwork("stream request failed: %s", err) - } - for key, values := range cmdutil.BaseSecurityHeaders() { - for _, value := range values { - httpReq.Header.Add(key, value) - } - } - for key, values := range option.Header { - for _, value := range values { - httpReq.Header.Add(key, value) - } - } - if contentType != "" { - httpReq.Header.Set("Content-Type", contentType) - } - httpReq.Header.Set("Authorization", "Bearer "+accessToken) - - resp, err := streamingClient.Do(httpReq) - if err != nil { - cancel() - return nil, output.ErrNetwork("stream request failed: %s", err) - } - resp.Body = &cancelOnCloseReadCloser{ReadCloser: resp.Body, cancel: cancel} - return resp, nil -} - -func buildStreamRequestURL(brand core.LarkBrand, req *larkcore.ApiReq) (string, error) { - requestURL := req.ApiPath - if !strings.HasPrefix(requestURL, "http://") && !strings.HasPrefix(requestURL, "https://") { - var pathSegs []string - for _, segment := range strings.Split(req.ApiPath, "/") { - if !strings.HasPrefix(segment, ":") { - pathSegs = append(pathSegs, segment) - continue - } - pathKey := strings.TrimPrefix(segment, ":") - pathValue, ok := req.PathParams[pathKey] - if !ok { - return "", output.ErrValidation("missing path param %q for %s", pathKey, req.ApiPath) - } - if pathValue == "" { - return "", output.ErrValidation("empty path param %q for %s", pathKey, req.ApiPath) - } - pathSegs = append(pathSegs, url.PathEscape(pathValue)) - } - endpoints := core.ResolveEndpoints(brand) - requestURL = strings.TrimRight(endpoints.Open, "/") + strings.Join(pathSegs, "/") - } - if query := req.QueryParams.Encode(); query != "" { - requestURL += "?" + query + base := []client.Option{ + client.WithHeaders(cmdutil.BaseSecurityHeaders()), } - return requestURL, nil -} - -func buildStreamRequestBody(body interface{}) (io.Reader, string, error) { - switch typed := body.(type) { - case nil: - return nil, "", nil - case io.Reader: - return typed, "", nil - case []byte: - return bytes.NewReader(typed), "", nil - case string: - return strings.NewReader(typed), "text/plain; charset=utf-8", nil - default: - payload, err := json.Marshal(typed) - if err != nil { - return nil, "", output.Errorf(output.ExitInternal, "api_error", "failed to encode request body: %s", err) - } - return bytes.NewReader(payload), "application/json", nil + if h := cmdutil.ShortcutHeaders(ctx.ctx); h != nil { + base = append(base, client.WithHeaders(h)) } + return ac.DoStream(callCtx, req, ctx.As(), append(base, opts...)...) } // DoAPIJSON calls the Lark API via DoAPI, parses the JSON response envelope, @@ -478,15 +344,21 @@ func (ctx *RuntimeContext) OutFormat(data interface{}, meta *output.Meta, pretty // ── Scope pre-check ── -// checkScopePrereqs performs a fast local check: does the stored token -// contain all scopes declared by the shortcut? Returns the missing ones. -// If no token is stored, returns nil (let the normal auth flow handle it). -func checkScopePrereqs(appID, userOpenId string, required []string) []string { - stored := auth.GetStoredToken(appID, userOpenId) - if stored == nil { - return nil // no token yet — auth flow will catch this later +// checkScopePrereqs performs a fast local check: does the token +// contain all scopes declared by the shortcut? Returns the missing ones. +// If scope data is unavailable, returns nil (let the API call handle it). +func checkScopePrereqs(f *cmdutil.Factory, ctx context.Context, appID string, identity core.Identity, required []string) ([]string, error) { + result, err := f.Credential.ResolveToken(ctx, credential.NewTokenSpec(identity, appID)) + if err != nil { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return nil, err + } + return nil, nil } - return auth.MissingScopes(stored.Scope, required) + if result == nil || result.Scopes == "" { + return nil, nil + } + return auth.MissingScopes(result.Scopes, required), nil } // enhancePermissionError enriches a permission / auth error with the @@ -544,6 +416,7 @@ func (s Shortcut) mountDeclarative(parent *cobra.Command, f *cmdutil.Factory) { return runShortcut(cmd, f, &shortcut, botOnly) }, } + cmdutil.SetSupportedIdentities(cmd, shortcut.AuthTypes) registerShortcutFlags(cmd, &shortcut) cmdutil.SetTips(cmd, shortcut.Tips) parent.AddCommand(cmd) @@ -557,14 +430,14 @@ func runShortcut(cmd *cobra.Command, f *cmdutil.Factory, s *Shortcut, botOnly bo return err } - config, err := f.ResolveConfig(as) + config, err := f.Config() if err != nil { return err } // Identity info is now included in the JSON envelope; skip stderr printing. // cmdutil.PrintIdentity(f.IOStreams.ErrOut, as, config, false) - if err := checkShortcutScopes(as, config, s.ScopesForIdentity(string(as))); err != nil { + if err := checkShortcutScopes(f, cmd.Context(), as, config, s.ScopesForIdentity(string(as))); err != nil { return err } @@ -576,6 +449,9 @@ func runShortcut(cmd *cobra.Command, f *cmdutil.Factory, s *Shortcut, botOnly bo if err := validateEnumFlags(rctx, s.Flags); err != nil { return err } + if err := resolveInputFlags(rctx, s.Flags); err != nil { + return err + } if err := output.ValidateJqFlags(rctx.JqExpr, "", rctx.Format); err != nil { return err } @@ -604,7 +480,11 @@ func runShortcut(cmd *cobra.Command, f *cmdutil.Factory, s *Shortcut, botOnly bo func resolveShortcutIdentity(cmd *cobra.Command, f *cmdutil.Factory, s *Shortcut) (core.Identity, error) { // Step 1: determine identity (--as > default-as > auto-detect). asFlag, _ := cmd.Flags().GetString("as") - as := f.ResolveAs(cmd, core.Identity(asFlag)) + as := f.ResolveAs(cmd.Context(), cmd, core.Identity(asFlag)) + + if err := f.CheckStrictMode(cmd.Context(), as); err != nil { + return "", err + } // Step 2: check if this shortcut supports the resolved identity. if err := f.CheckIdentity(as, s.AuthTypes); err != nil { @@ -613,11 +493,14 @@ func resolveShortcutIdentity(cmd *cobra.Command, f *cmdutil.Factory, s *Shortcut return as, nil } -func checkShortcutScopes(as core.Identity, config *core.CliConfig, scopes []string) error { - if as != core.AsUser || len(scopes) == 0 || config.UserOpenId == "" { +func checkShortcutScopes(f *cmdutil.Factory, ctx context.Context, as core.Identity, config *core.CliConfig, scopes []string) error { + if len(scopes) == 0 { return nil } - missing := checkScopePrereqs(config.AppID, config.UserOpenId, scopes) + missing, err := checkScopePrereqs(f, ctx, config.AppID, as, scopes) + if err != nil { + return err + } if len(missing) == 0 { return nil } @@ -644,6 +527,69 @@ func newRuntimeContext(cmd *cobra.Command, f *cmdutil.Factory, s *Shortcut, conf return rctx, nil } +// resolveInputFlags resolves @file and - (stdin) for flags with Input sources. +// Must be called before Validate/DryRun/Execute so that runtime.Str() returns resolved content. +func resolveInputFlags(rctx *RuntimeContext, flags []Flag) error { + stdinUsed := false + for _, fl := range flags { + if len(fl.Input) == 0 { + continue + } + raw, err := rctx.Cmd.Flags().GetString(fl.Name) + if err != nil { + return FlagErrorf("--%s: Input is only supported for string flags", fl.Name) + } + if raw == "" { + continue + } + + // stdin: - + if raw == "-" { + if !slices.Contains(fl.Input, Stdin) { + return FlagErrorf("--%s does not support stdin (-)", fl.Name) + } + if stdinUsed { + return FlagErrorf("--%s: stdin (-) can only be used by one flag", fl.Name) + } + stdinUsed = true + data, err := io.ReadAll(rctx.IO().In) + if err != nil { + return FlagErrorf("--%s: failed to read from stdin: %v", fl.Name, err) + } + rctx.Cmd.Flags().Set(fl.Name, string(data)) + continue + } + + // escape: @@ → literal @ + if strings.HasPrefix(raw, "@@") { + rctx.Cmd.Flags().Set(fl.Name, raw[1:]) // strip first @ + continue + } + + // file: @path + if strings.HasPrefix(raw, "@") { + if !slices.Contains(fl.Input, File) { + return FlagErrorf("--%s does not support file input (@path)", fl.Name) + } + path := strings.TrimSpace(raw[1:]) + if path == "" { + return FlagErrorf("--%s: file path cannot be empty after @", fl.Name) + } + safePath, err := validate.SafeInputPath(path) + if err != nil { + return FlagErrorf("--%s: invalid file path %q: %v", fl.Name, path, err) + } + data, err := vfs.ReadFile(safePath) + if err != nil { + return FlagErrorf("--%s: cannot read file %q: %v", fl.Name, path, err) + } + rctx.Cmd.Flags().Set(fl.Name, string(data)) + continue + } + } + return nil +} + func validateEnumFlags(rctx *RuntimeContext, flags []Flag) error { for _, fl := range flags { if len(fl.Enum) == 0 { @@ -687,6 +633,16 @@ func registerShortcutFlags(cmd *cobra.Command, s *Shortcut) { if len(fl.Enum) > 0 { desc += " (" + strings.Join(fl.Enum, "|") + ")" } + if len(fl.Input) > 0 { + hints := make([]string, 0, 2) + if slices.Contains(fl.Input, File) { + hints = append(hints, "@file") + } + if slices.Contains(fl.Input, Stdin) { + hints = append(hints, "- for stdin") + } + desc += " (supports " + strings.Join(hints, ", ") + ")" + } switch fl.Type { case "bool": def := fl.Default == "true" diff --git a/shortcuts/common/runner_input_test.go b/shortcuts/common/runner_input_test.go new file mode 100644 index 000000000..25aa806b5 --- /dev/null +++ b/shortcuts/common/runner_input_test.go @@ -0,0 +1,202 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package common + +import ( + "os" + "path/filepath" + "strings" + "testing" + + "github.com/larksuite/cli/internal/cmdutil" + "github.com/spf13/cobra" +) + +// newTestRuntimeWithStdin creates a RuntimeContext with string flags and a fake stdin. +func newTestRuntimeWithStdin(flags map[string]string, stdin string) *RuntimeContext { + cmd := &cobra.Command{Use: "test"} + for name := range flags { + cmd.Flags().String(name, "", "") + } + cmd.ParseFlags(nil) + for name, val := range flags { + cmd.Flags().Set(name, val) + } + return &RuntimeContext{ + Cmd: cmd, + Factory: &cmdutil.Factory{ + IOStreams: &cmdutil.IOStreams{ + In: strings.NewReader(stdin), + }, + }, + } +} + +func TestResolveInputFlags_DirectValue(t *testing.T) { + rctx := newTestRuntimeWithStdin(map[string]string{"markdown": "hello world"}, "") + flags := []Flag{{Name: "markdown", Input: []string{File, Stdin}}} + + if err := resolveInputFlags(rctx, flags); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got := rctx.Str("markdown"); got != "hello world" { + t.Errorf("expected %q, got %q", "hello world", got) + } +} + +func TestResolveInputFlags_Stdin(t *testing.T) { + rctx := newTestRuntimeWithStdin(map[string]string{"markdown": "-"}, "content from stdin") + flags := []Flag{{Name: "markdown", Input: []string{File, Stdin}}} + + if err := resolveInputFlags(rctx, flags); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got := rctx.Str("markdown"); got != "content from stdin" { + t.Errorf("expected %q, got %q", "content from stdin", got) + } +} + +func TestResolveInputFlags_File(t *testing.T) { + dir := t.TempDir() + orig, _ := os.Getwd() + os.Chdir(dir) + t.Cleanup(func() { os.Chdir(orig) }) + + content := "## Hello\n\nThis is **markdown** from a file.\n" + fpath := filepath.Join(dir, "test.md") + os.WriteFile(fpath, []byte(content), 0644) + + rctx := newTestRuntimeWithStdin(map[string]string{"markdown": "@test.md"}, "") + flags := []Flag{{Name: "markdown", Input: []string{File, Stdin}}} + + if err := resolveInputFlags(rctx, flags); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got := rctx.Str("markdown"); got != content { + t.Errorf("expected %q, got %q", content, got) + } +} + +func TestResolveInputFlags_EmptyInput(t *testing.T) { + rctx := newTestRuntimeWithStdin(map[string]string{"markdown": ""}, "") + flags := []Flag{{Name: "markdown", Input: []string{File, Stdin}}} + + if err := resolveInputFlags(rctx, flags); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got := rctx.Str("markdown"); got != "" { + t.Errorf("expected empty, got %q", got) + } +} + +func TestResolveInputFlags_NoInputSpec(t *testing.T) { + rctx := newTestRuntimeWithStdin(map[string]string{"token": "@something"}, "") + flags := []Flag{{Name: "token"}} // no Input + + if err := resolveInputFlags(rctx, flags); err != nil { + t.Fatalf("unexpected error: %v", err) + } + // value should be unchanged — no resolution + if got := rctx.Str("token"); got != "@something" { + t.Errorf("expected %q, got %q", "@something", got) + } +} + +func TestResolveInputFlags_StdinNotSupported(t *testing.T) { + rctx := newTestRuntimeWithStdin(map[string]string{"data": "-"}, "stdin data") + flags := []Flag{{Name: "data", Input: []string{File}}} // only file, no stdin + + err := resolveInputFlags(rctx, flags) + if err == nil { + t.Fatal("expected error for stdin not supported") + } + if !strings.Contains(err.Error(), "does not support stdin") { + t.Errorf("unexpected error: %v", err) + } +} + +func TestResolveInputFlags_FileNotSupported(t *testing.T) { + rctx := newTestRuntimeWithStdin(map[string]string{"data": "@file.txt"}, "") + flags := []Flag{{Name: "data", Input: []string{Stdin}}} // only stdin, no file + + err := resolveInputFlags(rctx, flags) + if err == nil { + t.Fatal("expected error for file not supported") + } + if !strings.Contains(err.Error(), "does not support file input") { + t.Errorf("unexpected error: %v", err) + } +} + +func TestResolveInputFlags_FileNotFound(t *testing.T) { + dir := t.TempDir() + orig, _ := os.Getwd() + os.Chdir(dir) + t.Cleanup(func() { os.Chdir(orig) }) + + rctx := newTestRuntimeWithStdin(map[string]string{"markdown": "@nonexistent.md"}, "") + flags := []Flag{{Name: "markdown", Input: []string{File, Stdin}}} + + err := resolveInputFlags(rctx, flags) + if err == nil { + t.Fatal("expected error for missing file") + } + if !strings.Contains(err.Error(), "cannot read file") { + t.Errorf("unexpected error: %v", err) + } +} + +func TestResolveInputFlags_EmptyFilePath(t *testing.T) { + rctx := newTestRuntimeWithStdin(map[string]string{"markdown": "@ "}, "") + flags := []Flag{{Name: "markdown", Input: []string{File, Stdin}}} + + err := resolveInputFlags(rctx, flags) + if err == nil { + t.Fatal("expected error for empty file path") + } + if !strings.Contains(err.Error(), "file path cannot be empty") { + t.Errorf("unexpected error: %v", err) + } +} + +func TestResolveInputFlags_EscapeAtSign(t *testing.T) { + rctx := newTestRuntimeWithStdin(map[string]string{"text": "@@mention someone"}, "") + flags := []Flag{{Name: "text", Input: []string{File, Stdin}}} + + if err := resolveInputFlags(rctx, flags); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got := rctx.Str("text"); got != "@mention someone" { + t.Errorf("expected %q, got %q", "@mention someone", got) + } +} + +func TestResolveInputFlags_EscapeDoubleAt(t *testing.T) { + rctx := newTestRuntimeWithStdin(map[string]string{"text": "@@@triple"}, "") + flags := []Flag{{Name: "text", Input: []string{File, Stdin}}} + + if err := resolveInputFlags(rctx, flags); err != nil { + t.Fatalf("unexpected error: %v", err) + } + // @@@ → strip first @, remaining is @@triple which is literal + if got := rctx.Str("text"); got != "@@triple" { + t.Errorf("expected %q, got %q", "@@triple", got) + } +} + +func TestResolveInputFlags_DuplicateStdin(t *testing.T) { + rctx := newTestRuntimeWithStdin(map[string]string{"a": "-", "b": "-"}, "data") + flags := []Flag{ + {Name: "a", Input: []string{Stdin}}, + {Name: "b", Input: []string{Stdin}}, + } + + err := resolveInputFlags(rctx, flags) + if err == nil { + t.Fatal("expected error for duplicate stdin usage") + } + if !strings.Contains(err.Error(), "stdin (-) can only be used by one flag") { + t.Errorf("unexpected error: %v", err) + } +} diff --git a/shortcuts/common/runner_scope_test.go b/shortcuts/common/runner_scope_test.go index e7a4efbdb..9d8620b0b 100644 --- a/shortcuts/common/runner_scope_test.go +++ b/shortcuts/common/runner_scope_test.go @@ -4,14 +4,27 @@ package common import ( + "context" "errors" "fmt" "strings" "testing" + "github.com/larksuite/cli/internal/cmdutil" + "github.com/larksuite/cli/internal/core" + "github.com/larksuite/cli/internal/credential" "github.com/larksuite/cli/internal/output" ) +type scopeCheckTokenResolver struct { + result *credential.TokenResult + err error +} + +func (r *scopeCheckTokenResolver) ResolveToken(ctx context.Context, req credential.TokenSpec) (*credential.TokenResult, error) { + return r.result, r.err +} + func TestEnhancePermissionError_MissingScopeType(t *testing.T) { scopes := []string{"calendar:calendar:read"} err := &output.ExitError{ @@ -164,3 +177,25 @@ func TestEnhancePermissionError(t *testing.T) { }) } } + +func TestCheckShortcutScopes_PropagatesContextCancellation(t *testing.T) { + f := &cmdutil.Factory{ + Credential: credential.NewCredentialProvider(nil, nil, &scopeCheckTokenResolver{err: context.Canceled}, nil), + } + + err := checkShortcutScopes(f, context.Background(), core.AsUser, &core.CliConfig{AppID: "app-1"}, []string{"im:message:read"}) + if !errors.Is(err, context.Canceled) { + t.Fatalf("checkShortcutScopes() error = %v, want context.Canceled", err) + } +} + +func TestCheckShortcutScopes_IgnoresNonContextTokenErrors(t *testing.T) { + f := &cmdutil.Factory{ + Credential: credential.NewCredentialProvider(nil, nil, &scopeCheckTokenResolver{err: errors.New("token cache unavailable")}, nil), + } + + err := checkShortcutScopes(f, context.Background(), core.AsUser, &core.CliConfig{AppID: "app-1"}, []string{"im:message:read"}) + if err != nil { + t.Fatalf("checkShortcutScopes() error = %v, want nil", err) + } +} diff --git a/shortcuts/common/types.go b/shortcuts/common/types.go index 70c882cd4..48ca7ece7 100644 --- a/shortcuts/common/types.go +++ b/shortcuts/common/types.go @@ -5,6 +5,12 @@ package common import "context" +// Flag.Input source constants. +const ( + File = "file" // support @path to read value from a file + Stdin = "stdin" // support - to read value from stdin +) + // Flag describes a CLI flag for a shortcut. type Flag struct { Name string // flag name (e.g. "calendar-id") @@ -14,6 +20,7 @@ type Flag struct { Hidden bool // hidden from --help, still readable at runtime Required bool Enum []string // allowed values (e.g. ["asc", "desc"]); empty means no constraint + Input []string // extra input sources: File (@path), Stdin (-); empty = flag value only } // Shortcut represents a high-level CLI command. diff --git a/shortcuts/common/validate.go b/shortcuts/common/validate.go index 422f99e23..b894ddf88 100644 --- a/shortcuts/common/validate.go +++ b/shortcuts/common/validate.go @@ -11,6 +11,7 @@ import ( "strings" "github.com/larksuite/cli/internal/output" + "github.com/larksuite/cli/internal/vfs" ) // FlagErrorf returns a validation error with flag context (exit code 2). @@ -91,7 +92,7 @@ func ValidateSafeOutputDir(outputDir string) error { if filepath.IsAbs(outputDir) { return fmt.Errorf("--output-dir must be a relative path, got: %q", outputDir) } - cwd, err := os.Getwd() + cwd, err := vfs.Getwd() if err != nil { return fmt.Errorf("cannot determine working directory: %w", err) } @@ -110,7 +111,7 @@ func ValidateSafeOutputDir(outputDir string) error { } // Path does not exist yet. If os.Lstat succeeds the entry is a dangling // symlink — reject it to prevent future escapes once the target is created. - if _, lstErr := os.Lstat(abs); lstErr == nil { + if _, lstErr := vfs.Lstat(abs); lstErr == nil { return fmt.Errorf("--output-dir %q is a symlink with a non-existent target", outputDir) } // The path itself doesn't exist; the string-level check is sufficient. diff --git a/shortcuts/doc/doc_media_download.go b/shortcuts/doc/doc_media_download.go index 80b7c88f8..581efc64c 100644 --- a/shortcuts/doc/doc_media_download.go +++ b/shortcuts/doc/doc_media_download.go @@ -7,7 +7,6 @@ import ( "context" "fmt" "net/http" - "os" "path/filepath" "strings" @@ -15,6 +14,7 @@ import ( "github.com/larksuite/cli/internal/output" "github.com/larksuite/cli/internal/validate" + "github.com/larksuite/cli/internal/vfs" "github.com/larksuite/cli/shortcuts/common" ) @@ -82,22 +82,20 @@ var DocMediaDownload = common.Shortcut{ apiPath = fmt.Sprintf("/open-apis/drive/v1/medias/%s/download", encodedToken) } - apiResp, err := runtime.DoAPI(&larkcore.ApiReq{ + resp, err := runtime.DoAPIStream(ctx, &larkcore.ApiReq{ HttpMethod: http.MethodGet, ApiPath: apiPath, - }, larkcore.WithFileDownload()) + }) if err != nil { return output.ErrNetwork("download failed: %v", err) } - if apiResp.StatusCode >= 400 { - return output.ErrNetwork("download failed: HTTP %d: %s", apiResp.StatusCode, strings.TrimSpace(string(apiResp.RawBody))) - } + defer resp.Body.Close() // Auto-detect extension from Content-Type finalPath := outputPath currentExt := filepath.Ext(outputPath) if currentExt == "" { - contentType := apiResp.Header.Get("Content-Type") + contentType := resp.Header.Get("Content-Type") mimeType := strings.Split(contentType, ";")[0] mimeType = strings.TrimSpace(mimeType) if ext, ok := mimeToExt[mimeType]; ok { @@ -115,15 +113,19 @@ var DocMediaDownload = common.Shortcut{ return err } - os.MkdirAll(filepath.Dir(safePath), 0755) - if err := validate.AtomicWrite(safePath, apiResp.RawBody, 0644); err != nil { + if err := vfs.MkdirAll(filepath.Dir(safePath), 0700); err != nil { + return output.Errorf(output.ExitInternal, "io", "cannot create parent directory: %v", err) + } + + sizeBytes, err := validate.AtomicWriteFromReader(safePath, resp.Body, 0600) + if err != nil { return output.Errorf(output.ExitInternal, "io", "cannot create file: %v", err) } runtime.Out(map[string]interface{}{ "saved_path": safePath, - "size_bytes": len(apiResp.RawBody), - "content_type": apiResp.Header.Get("Content-Type"), + "size_bytes": sizeBytes, + "content_type": resp.Header.Get("Content-Type"), }, nil) return nil }, diff --git a/shortcuts/doc/doc_media_insert.go b/shortcuts/doc/doc_media_insert.go index 32986f741..8242080e9 100644 --- a/shortcuts/doc/doc_media_insert.go +++ b/shortcuts/doc/doc_media_insert.go @@ -9,7 +9,6 @@ import ( "errors" "fmt" "net/http" - "os" "path/filepath" larkcore "github.com/larksuite/oapi-sdk-go/v3/core" @@ -17,6 +16,7 @@ import ( "github.com/larksuite/cli/internal/output" "github.com/larksuite/cli/internal/util" "github.com/larksuite/cli/internal/validate" + "github.com/larksuite/cli/internal/vfs" "github.com/larksuite/cli/shortcuts/common" ) @@ -120,7 +120,7 @@ var DocMediaInsert = common.Shortcut{ } // Validate file - stat, err := os.Stat(filePath) + stat, err := vfs.Stat(filePath) if err != nil { return output.ErrValidation("file not found: %s", filePath) } @@ -359,7 +359,7 @@ func extractCreatedBlockTargets(createData map[string]interface{}, mediaType str // uploadMediaFile uploads a file to Feishu drive as media. func uploadMediaFile(ctx context.Context, runtime *common.RuntimeContext, filePath, fileName, mediaType, parentNode, docId string) (string, error) { - f, err := os.Open(filePath) + f, err := vfs.Open(filePath) if err != nil { return "", err } diff --git a/shortcuts/doc/doc_media_test.go b/shortcuts/doc/doc_media_test.go index 7e805213a..eab88a735 100644 --- a/shortcuts/doc/doc_media_test.go +++ b/shortcuts/doc/doc_media_test.go @@ -58,16 +58,6 @@ func withDocsWorkingDir(t *testing.T, dir string) { }) } -func registerDocsBotTokenStub(reg *httpmock.Registry) { - reg.Register(&httpmock.Stub{ - URL: "/open-apis/auth/v3/tenant_access_token/internal", - Body: map[string]interface{}{ - "code": 0, "msg": "ok", - "tenant_access_token": "t-test-token", "expire": 7200, - }, - }) -} - func TestDocMediaInsertRejectsOldDocURL(t *testing.T) { f, _, _, _ := cmdutil.TestFactory(t, docsTestConfig()) @@ -111,7 +101,6 @@ func TestDocMediaInsertDryRunWikiAddsResolveStep(t *testing.T) { func TestDocMediaInsertExecuteResolvesWikiBeforeFileCheck(t *testing.T) { f, _, stderr, reg := cmdutil.TestFactory(t, docsTestConfigWithAppID("docs-insert-exec-app")) - registerDocsBotTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "GET", URL: "/open-apis/wiki/v2/spaces/get_node", @@ -148,7 +137,6 @@ func TestDocMediaInsertExecuteResolvesWikiBeforeFileCheck(t *testing.T) { func TestDocMediaDownloadRejectsOverwriteWithoutFlag(t *testing.T) { f, _, _, reg := cmdutil.TestFactory(t, docsTestConfigWithAppID("docs-download-overwrite-app")) - registerDocsBotTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "GET", URL: "/open-apis/drive/v1/medias/tok_123/download", @@ -179,7 +167,6 @@ func TestDocMediaDownloadRejectsOverwriteWithoutFlag(t *testing.T) { func TestDocMediaDownloadRejectsHTTPErrorBeforeWrite(t *testing.T) { f, _, _, reg := cmdutil.TestFactory(t, docsTestConfigWithAppID("docs-download-app")) - registerDocsBotTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "GET", URL: "/open-apis/drive/v1/medias/tok_123/download", diff --git a/shortcuts/doc/doc_media_upload.go b/shortcuts/doc/doc_media_upload.go index 008597b84..39db93005 100644 --- a/shortcuts/doc/doc_media_upload.go +++ b/shortcuts/doc/doc_media_upload.go @@ -9,7 +9,6 @@ import ( "errors" "fmt" "net/http" - "os" "path/filepath" larkcore "github.com/larksuite/oapi-sdk-go/v3/core" @@ -17,6 +16,7 @@ import ( "github.com/larksuite/cli/internal/output" "github.com/larksuite/cli/internal/util" "github.com/larksuite/cli/internal/validate" + "github.com/larksuite/cli/internal/vfs" "github.com/larksuite/cli/shortcuts/common" ) @@ -65,7 +65,7 @@ var MediaUpload = common.Shortcut{ filePath = safeFilePath // Validate file - stat, err := os.Stat(filePath) + stat, err := vfs.Stat(filePath) if err != nil { return output.ErrValidation("file not found: %s", filePath) } @@ -76,7 +76,7 @@ var MediaUpload = common.Shortcut{ fileName := filepath.Base(filePath) fmt.Fprintf(runtime.IO().ErrOut, "Uploading: %s (%d bytes)\n", fileName, stat.Size()) - f, err := os.Open(filePath) + f, err := vfs.Open(filePath) if err != nil { return output.ErrValidation("cannot open file: %v", err) } diff --git a/shortcuts/doc/docs_create.go b/shortcuts/doc/docs_create.go index c88c6eaf6..da505211d 100644 --- a/shortcuts/doc/docs_create.go +++ b/shortcuts/doc/docs_create.go @@ -18,7 +18,7 @@ var DocsCreate = common.Shortcut{ Scopes: []string{"docx:document:create"}, Flags: []common.Flag{ {Name: "title", Desc: "document title"}, - {Name: "markdown", Desc: "Markdown content (Lark-flavored)", Required: true}, + {Name: "markdown", Desc: "Markdown content (Lark-flavored)", Required: true, Input: []string{common.File, common.Stdin}}, {Name: "folder-token", Desc: "parent folder token"}, {Name: "wiki-node", Desc: "wiki node token"}, {Name: "wiki-space", Desc: "wiki space ID (use my_library for personal library)"}, diff --git a/shortcuts/doc/docs_update.go b/shortcuts/doc/docs_update.go index 5c64b7cc7..ea80550ed 100644 --- a/shortcuts/doc/docs_update.go +++ b/shortcuts/doc/docs_update.go @@ -38,7 +38,7 @@ var DocsUpdate = common.Shortcut{ Flags: []common.Flag{ {Name: "doc", Desc: "document URL or token", Required: true}, {Name: "mode", Desc: "update mode: append | overwrite | replace_range | replace_all | insert_before | insert_after | delete_range", Required: true}, - {Name: "markdown", Desc: "new content (Lark-flavored Markdown; create blank whiteboards with , repeat to create multiple boards)"}, + {Name: "markdown", Desc: "new content (Lark-flavored Markdown; create blank whiteboards with , repeat to create multiple boards)", Input: []string{common.File, common.Stdin}}, {Name: "selection-with-ellipsis", Desc: "content locator (e.g. 'start...end')"}, {Name: "selection-by-title", Desc: "title locator (e.g. '## Section')"}, {Name: "new-title", Desc: "also update document title"}, diff --git a/shortcuts/drive/drive_download.go b/shortcuts/drive/drive_download.go index 578c415fa..86039cec7 100644 --- a/shortcuts/drive/drive_download.go +++ b/shortcuts/drive/drive_download.go @@ -7,14 +7,13 @@ import ( "context" "fmt" "net/http" - "os" "path/filepath" larkcore "github.com/larksuite/oapi-sdk-go/v3/core" - // validate import used below "github.com/larksuite/cli/internal/output" "github.com/larksuite/cli/internal/validate" + "github.com/larksuite/cli/internal/vfs" "github.com/larksuite/cli/shortcuts/common" ) @@ -62,27 +61,27 @@ var DriveDownload = common.Shortcut{ fmt.Fprintf(runtime.IO().ErrOut, "Downloading: %s\n", common.MaskToken(fileToken)) - apiResp, err := runtime.DoAPI(&larkcore.ApiReq{ + resp, err := runtime.DoAPIStream(ctx, &larkcore.ApiReq{ HttpMethod: http.MethodGet, ApiPath: fmt.Sprintf("/open-apis/drive/v1/files/%s/download", validate.EncodePathSegment(fileToken)), - }, larkcore.WithFileDownload()) + }) if err != nil { return output.ErrNetwork("download failed: %s", err) } + defer resp.Body.Close() - if apiResp.StatusCode >= 400 { - return output.ErrNetwork("download failed: HTTP %d: %s", apiResp.StatusCode, string(apiResp.RawBody)) + if err := vfs.MkdirAll(filepath.Dir(safePath), 0700); err != nil { + return output.Errorf(output.ExitInternal, "api_error", "cannot create parent directory: %s", err) } - os.MkdirAll(filepath.Dir(safePath), 0755) - - if err := validate.AtomicWrite(safePath, apiResp.RawBody, 0644); err != nil { + sizeBytes, err := validate.AtomicWriteFromReader(safePath, resp.Body, 0600) + if err != nil { return output.Errorf(output.ExitInternal, "api_error", "cannot create file: %s", err) } runtime.Out(map[string]interface{}{ "saved_path": safePath, - "size_bytes": len(apiResp.RawBody), + "size_bytes": sizeBytes, }, nil) return nil }, diff --git a/shortcuts/drive/drive_export_common.go b/shortcuts/drive/drive_export_common.go index 02707a6c6..a95daac67 100644 --- a/shortcuts/drive/drive_export_common.go +++ b/shortcuts/drive/drive_export_common.go @@ -7,7 +7,6 @@ import ( "context" "fmt" "net/http" - "os" "path/filepath" "strconv" "strings" @@ -18,6 +17,7 @@ import ( "github.com/larksuite/cli/internal/client" "github.com/larksuite/cli/internal/output" "github.com/larksuite/cli/internal/validate" + "github.com/larksuite/cli/internal/vfs" "github.com/larksuite/cli/shortcuts/common" ) @@ -270,7 +270,7 @@ func saveContentToOutputDir(outputDir, fileName string, payload []byte, overwrit return "", err } - if err := os.MkdirAll(filepath.Dir(safePath), 0755); err != nil { + if err := vfs.MkdirAll(filepath.Dir(safePath), 0755); err != nil { return "", output.Errorf(output.ExitInternal, "io", "cannot create output directory: %s", err) } if err := validate.AtomicWrite(safePath, payload, 0644); err != nil { diff --git a/shortcuts/drive/drive_import.go b/shortcuts/drive/drive_import.go index 528e075c3..f2aed91ea 100644 --- a/shortcuts/drive/drive_import.go +++ b/shortcuts/drive/drive_import.go @@ -6,10 +6,12 @@ package drive import ( "context" "fmt" - "os" "path/filepath" + "strings" + "github.com/larksuite/cli/internal/vfs" + "github.com/larksuite/cli/internal/output" "github.com/larksuite/cli/internal/validate" "github.com/larksuite/cli/shortcuts/common" @@ -147,7 +149,7 @@ func preflightDriveImportFile(spec *driveImportSpec) (int64, error) { } spec.FilePath = safeFilePath - info, err := os.Stat(spec.FilePath) + info, err := vfs.Stat(spec.FilePath) if err != nil { return 0, output.ErrValidation("cannot read file: %s", err) } diff --git a/shortcuts/drive/drive_import_common.go b/shortcuts/drive/drive_import_common.go index 370da55b9..f3b61c478 100644 --- a/shortcuts/drive/drive_import_common.go +++ b/shortcuts/drive/drive_import_common.go @@ -11,11 +11,12 @@ import ( "fmt" "io" "net/http" - "os" "path/filepath" "strings" "time" + "github.com/larksuite/cli/internal/vfs" + larkcore "github.com/larksuite/oapi-sdk-go/v3/core" "github.com/larksuite/cli/internal/output" @@ -96,7 +97,7 @@ func (s driveImportSpec) CreateTaskBody(fileToken string) map[string]interface{} // uploadMediaForImport uploads the source file to the temporary import media // endpoint and returns the file token consumed by import_tasks. func uploadMediaForImport(ctx context.Context, runtime *common.RuntimeContext, filePath, fileName, docType string) (string, error) { - importInfo, err := os.Stat(filePath) + importInfo, err := vfs.Stat(filePath) if err != nil { return "", output.ErrValidation("cannot read file: %s", err) } @@ -125,7 +126,7 @@ func uploadMediaForImport(ctx context.Context, runtime *common.RuntimeContext, f } func uploadMediaForImportAll(runtime *common.RuntimeContext, filePath, fileName string, fileSize int, extra string) (string, error) { - f, err := os.Open(filePath) + f, err := vfs.Open(filePath) if err != nil { return "", output.ErrValidation("cannot read file: %s", err) } @@ -164,7 +165,7 @@ func uploadMediaForImportMultipart(runtime *common.RuntimeContext, filePath, fil totalBlocks := session.BlockNum fmt.Fprintf(runtime.IO().ErrOut, "Multipart upload initialized: %d chunks x %s\n", totalBlocks, common.FormatSize(int64(session.BlockSize))) - f, err := os.Open(filePath) + f, err := vfs.Open(filePath) if err != nil { return "", output.ErrValidation("cannot read file: %s", err) } diff --git a/shortcuts/drive/drive_io_test.go b/shortcuts/drive/drive_io_test.go index 085acc692..f9e91ae33 100644 --- a/shortcuts/drive/drive_io_test.go +++ b/shortcuts/drive/drive_io_test.go @@ -5,11 +5,9 @@ package drive import ( "bytes" - "fmt" "net/http" "os" "strings" - "sync/atomic" "testing" "github.com/spf13/cobra" @@ -20,11 +18,12 @@ import ( "github.com/larksuite/cli/shortcuts/common" ) -var driveTestConfigSeq atomic.Int64 +// registerDriveBotTokenStub is a no-op. TAT is now managed by CredentialProvider, not SDK. +func registerDriveBotTokenStub(_ *httpmock.Registry) {} func driveTestConfig() *core.CliConfig { return &core.CliConfig{ - AppID: fmt.Sprintf("drive-test-app-%d", driveTestConfigSeq.Add(1)), AppSecret: "test-secret", Brand: core.BrandFeishu, + AppID: "drive-test-app", AppSecret: "test-secret", Brand: core.BrandFeishu, } } @@ -57,16 +56,6 @@ func withDriveWorkingDir(t *testing.T, dir string) { }) } -func registerDriveBotTokenStub(reg *httpmock.Registry) { - reg.Register(&httpmock.Stub{ - URL: "/open-apis/auth/v3/tenant_access_token/internal", - Body: map[string]interface{}{ - "code": 0, "msg": "ok", - "tenant_access_token": "t-test-token", "expire": 7200, - }, - }) -} - func TestDriveUploadLargeFileUsesMultipart(t *testing.T) { // Use a distinct AppID to avoid Lark SDK global token cache collision with other tests. uploadTestConfig := &core.CliConfig{ @@ -588,7 +577,6 @@ func TestDriveDownloadRejectsOverwriteWithoutFlag(t *testing.T) { func TestDriveDownloadAllowsOverwriteFlag(t *testing.T) { f, stdout, _, reg := cmdutil.TestFactory(t, driveTestConfig()) - registerDriveBotTokenStub(reg) reg.Register(&httpmock.Stub{ Method: "GET", URL: "/open-apis/drive/v1/files/file_123/download", diff --git a/shortcuts/drive/drive_upload.go b/shortcuts/drive/drive_upload.go index 2e0d9c905..2846f604b 100644 --- a/shortcuts/drive/drive_upload.go +++ b/shortcuts/drive/drive_upload.go @@ -17,6 +17,7 @@ import ( "github.com/larksuite/cli/internal/output" "github.com/larksuite/cli/internal/validate" + "github.com/larksuite/cli/internal/vfs" "github.com/larksuite/cli/shortcuts/common" ) @@ -68,7 +69,7 @@ var DriveUpload = common.Shortcut{ fileName = filepath.Base(filePath) } - info, err := os.Stat(filePath) + info, err := vfs.Stat(filePath) if err != nil { return output.ErrValidation("cannot read file: %s", err) } @@ -97,7 +98,7 @@ var DriveUpload = common.Shortcut{ } func uploadFileToDrive(ctx context.Context, runtime *common.RuntimeContext, filePath, fileName, folderToken string, fileSize int64) (string, error) { - f, err := os.Open(filePath) + f, err := vfs.Open(filePath) if err != nil { return "", err } diff --git a/shortcuts/event/pipeline.go b/shortcuts/event/pipeline.go index 28d552ffc..34f957bfc 100644 --- a/shortcuts/event/pipeline.go +++ b/shortcuts/event/pipeline.go @@ -17,6 +17,7 @@ import ( "github.com/larksuite/cli/internal/output" "github.com/larksuite/cli/internal/validate" + "github.com/larksuite/cli/internal/vfs" larkevent "github.com/larksuite/oapi-sdk-go/v3/event" ) @@ -61,13 +62,13 @@ func NewEventPipeline( // EnsureDirs creates all configured output directories once at startup. func (p *EventPipeline) EnsureDirs() error { if p.config.OutputDir != "" { - if err := os.MkdirAll(p.config.OutputDir, 0700); err != nil { + if err := vfs.MkdirAll(p.config.OutputDir, 0700); err != nil { return fmt.Errorf("create output dir: %w", err) } } if p.config.Router != nil { for _, route := range p.config.Router.routes { - if err := os.MkdirAll(route.dir, 0700); err != nil { + if err := vfs.MkdirAll(route.dir, 0700); err != nil { return fmt.Errorf("create route dir %s: %w", route.dir, err) } } diff --git a/shortcuts/im/convert_lib/helpers_test.go b/shortcuts/im/convert_lib/helpers_test.go index 496ba0c2a..557d5ceb3 100644 --- a/shortcuts/im/convert_lib/helpers_test.go +++ b/shortcuts/im/convert_lib/helpers_test.go @@ -129,12 +129,6 @@ func TestExtractPostBlocksText(t *testing.T) { func TestResolveSenderNames(t *testing.T) { runtime := newBotConvertlibRuntime(t, convertlibRoundTripFunc(func(req *http.Request) (*http.Response, error) { switch { - case strings.Contains(req.URL.Path, "tenant_access_token"): - return convertlibJSONResponse(200, map[string]interface{}{ - "code": 0, - "tenant_access_token": "tenant-token", - "expire": 7200, - }), nil case strings.Contains(req.URL.Path, "/open-apis/contact/v3/users/batch"): if got := req.URL.Query()["user_ids"]; !reflect.DeepEqual(got, []string{"ou_api", "ou_missing"}) { t.Fatalf("contact batch user_ids = %#v, want %#v", got, []string{"ou_api", "ou_missing"}) @@ -179,12 +173,6 @@ func TestResolveSenderNames(t *testing.T) { func TestResolveSenderNamesAPIFailure(t *testing.T) { runtime := newBotConvertlibRuntime(t, convertlibRoundTripFunc(func(req *http.Request) (*http.Response, error) { switch { - case strings.Contains(req.URL.Path, "tenant_access_token"): - return convertlibJSONResponse(200, map[string]interface{}{ - "code": 0, - "tenant_access_token": "tenant-token", - "expire": 7200, - }), nil case strings.Contains(req.URL.Path, "/open-apis/contact/v3/users/batch"): return nil, fmt.Errorf("contact api failed") default: diff --git a/shortcuts/im/convert_lib/merge_test.go b/shortcuts/im/convert_lib/merge_test.go index 3fc29109f..01ba0da5b 100644 --- a/shortcuts/im/convert_lib/merge_test.go +++ b/shortcuts/im/convert_lib/merge_test.go @@ -63,12 +63,6 @@ func TestFetchMergeForwardSubMessages(t *testing.T) { t.Run("success", func(t *testing.T) { runtime := newBotConvertlibRuntime(t, convertlibRoundTripFunc(func(req *http.Request) (*http.Response, error) { switch { - case strings.Contains(req.URL.Path, "tenant_access_token"): - return convertlibJSONResponse(200, map[string]interface{}{ - "code": 0, - "tenant_access_token": "tenant-token", - "expire": 7200, - }), nil case strings.Contains(req.URL.Path, "/open-apis/im/v1/messages/om_root"): return convertlibJSONResponse(200, map[string]interface{}{ "code": 0, @@ -95,12 +89,6 @@ func TestFetchMergeForwardSubMessages(t *testing.T) { t.Run("empty data", func(t *testing.T) { runtime := newBotConvertlibRuntime(t, convertlibRoundTripFunc(func(req *http.Request) (*http.Response, error) { switch { - case strings.Contains(req.URL.Path, "tenant_access_token"): - return convertlibJSONResponse(200, map[string]interface{}{ - "code": 0, - "tenant_access_token": "tenant-token", - "expire": 7200, - }), nil case strings.Contains(req.URL.Path, "/open-apis/im/v1/messages/om_bad"): return convertlibJSONResponse(200, map[string]interface{}{"code": 0}), nil default: @@ -118,12 +106,6 @@ func TestFetchMergeForwardSubMessages(t *testing.T) { func TestMergeForwardConverterWithRuntime(t *testing.T) { runtime := newBotConvertlibRuntime(t, convertlibRoundTripFunc(func(req *http.Request) (*http.Response, error) { switch { - case strings.Contains(req.URL.Path, "tenant_access_token"): - return convertlibJSONResponse(200, map[string]interface{}{ - "code": 0, - "tenant_access_token": "tenant-token", - "expire": 7200, - }), nil case strings.Contains(req.URL.Path, "/open-apis/im/v1/messages/om_root"): return convertlibJSONResponse(200, map[string]interface{}{ "code": 0, diff --git a/shortcuts/im/convert_lib/runtime_test.go b/shortcuts/im/convert_lib/runtime_test.go index 85a4611f7..5ab8ac154 100644 --- a/shortcuts/im/convert_lib/runtime_test.go +++ b/shortcuts/im/convert_lib/runtime_test.go @@ -15,11 +15,18 @@ import ( "github.com/larksuite/cli/internal/cmdutil" "github.com/larksuite/cli/internal/core" + "github.com/larksuite/cli/internal/credential" "github.com/larksuite/cli/shortcuts/common" lark "github.com/larksuite/oapi-sdk-go/v3" larkcore "github.com/larksuite/oapi-sdk-go/v3/core" ) +type staticConvertlibTokenResolver struct{} + +func (s *staticConvertlibTokenResolver) ResolveToken(_ context.Context, _ credential.TokenSpec) (*credential.TokenResult, error) { + return &credential.TokenResult{Token: "test-token"}, nil +} + type convertlibRoundTripFunc func(*http.Request) (*http.Response, error) func (f convertlibRoundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { @@ -52,6 +59,7 @@ func newBotConvertlibRuntime(t *testing.T, rt http.RoundTripper) *common.Runtime sdk := lark.NewClient( "test-app", "test-secret", + lark.WithEnableTokenCache(false), lark.WithLogLevel(larkcore.LogLevelError), lark.WithHttpClient(httpClient), ) @@ -60,13 +68,14 @@ func newBotConvertlibRuntime(t *testing.T, rt http.RoundTripper) *common.Runtime AppSecret: "test-secret", Brand: core.BrandFeishu, } + testCred := credential.NewCredentialProvider(nil, nil, &staticConvertlibTokenResolver{}, nil) runtime := &common.RuntimeContext{ Config: cfg, Factory: &cmdutil.Factory{ Config: func() (*core.CliConfig, error) { return cfg, nil }, - AuthConfig: func() (*core.CliConfig, error) { return cfg, nil }, HttpClient: func() (*http.Client, error) { return httpClient, nil }, LarkClient: func() (*lark.Client, error) { return sdk, nil }, + Credential: testCred, IOStreams: &cmdutil.IOStreams{ Out: &bytes.Buffer{}, ErrOut: &bytes.Buffer{}, diff --git a/shortcuts/im/convert_lib/thread_test.go b/shortcuts/im/convert_lib/thread_test.go index 8521665c8..25be088c5 100644 --- a/shortcuts/im/convert_lib/thread_test.go +++ b/shortcuts/im/convert_lib/thread_test.go @@ -13,12 +13,6 @@ import ( func TestExpandThreadReplies(t *testing.T) { runtime := newBotConvertlibRuntime(t, convertlibRoundTripFunc(func(req *http.Request) (*http.Response, error) { switch { - case strings.Contains(req.URL.Path, "tenant_access_token"): - return convertlibJSONResponse(200, map[string]interface{}{ - "code": 0, - "tenant_access_token": "tenant-token", - "expire": 7200, - }), nil case strings.Contains(req.URL.Path, "/open-apis/im/v1/messages"): if req.URL.Query().Get("container_id") != "omt_1" { return nil, fmt.Errorf("unexpected thread lookup: %s", req.URL.String()) @@ -76,12 +70,6 @@ func TestExpandThreadReplies(t *testing.T) { func TestFetchThreadRepliesError(t *testing.T) { runtime := newBotConvertlibRuntime(t, convertlibRoundTripFunc(func(req *http.Request) (*http.Response, error) { switch { - case strings.Contains(req.URL.Path, "tenant_access_token"): - return convertlibJSONResponse(200, map[string]interface{}{ - "code": 0, - "tenant_access_token": "tenant-token", - "expire": 7200, - }), nil case strings.Contains(req.URL.Path, "/open-apis/im/v1/messages"): return nil, fmt.Errorf("boom") default: @@ -104,12 +92,6 @@ func TestFetchThreadRepliesError(t *testing.T) { func TestExpandThreadRepliesMarksFetchError(t *testing.T) { runtime := newBotConvertlibRuntime(t, convertlibRoundTripFunc(func(req *http.Request) (*http.Response, error) { switch { - case strings.Contains(req.URL.Path, "tenant_access_token"): - return convertlibJSONResponse(200, map[string]interface{}{ - "code": 0, - "tenant_access_token": "tenant-token", - "expire": 7200, - }), nil case strings.Contains(req.URL.Path, "/open-apis/im/v1/messages"): return nil, fmt.Errorf("boom") default: diff --git a/shortcuts/im/coverage_additional_test.go b/shortcuts/im/coverage_additional_test.go index f9c931fd8..ac7f82dc4 100644 --- a/shortcuts/im/coverage_additional_test.go +++ b/shortcuts/im/coverage_additional_test.go @@ -4,8 +4,10 @@ package im import ( + "bytes" "context" "fmt" + "io" "net/http" "os" "reflect" @@ -268,12 +270,6 @@ func TestResolveChatIDForMessagesList(t *testing.T) { t.Run("user resolved through p2p lookup", func(t *testing.T) { runtime := newBotShortcutRuntime(t, shortcutRoundTripFunc(func(req *http.Request) (*http.Response, error) { switch { - case strings.Contains(req.URL.Path, "tenant_access_token"): - return shortcutJSONResponse(200, map[string]interface{}{ - "code": 0, - "tenant_access_token": "tenant-token", - "expire": 7200, - }), nil case strings.Contains(req.URL.Path, "/open-apis/im/v1/chat_p2p/batch_query"): return shortcutJSONResponse(200, map[string]interface{}{ "code": 0, @@ -406,30 +402,6 @@ func TestBuildSearchChatBodyAdditionalBranches(t *testing.T) { } } -func TestResolveToLocalPath(t *testing.T) { - t.Run("media key returns empty path", func(t *testing.T) { - got, cleanup, err := resolveToLocalPath(context.Background(), nil, "--image", "img_123") - if err != nil { - t.Fatalf("resolveToLocalPath() error = %v", err) - } - defer cleanup() - if got != "" { - t.Fatalf("resolveToLocalPath() = %q, want empty path", got) - } - }) - - t.Run("local path passthrough", func(t *testing.T) { - got, cleanup, err := resolveToLocalPath(context.Background(), nil, "--file", "report.pdf") - if err != nil { - t.Fatalf("resolveToLocalPath() error = %v", err) - } - defer cleanup() - if got != "report.pdf" { - t.Fatalf("resolveToLocalPath() = %q, want %q", got, "report.pdf") - } - }) -} - func TestParseMediaDurationSuccess(t *testing.T) { t.Run("mp4", func(t *testing.T) { f, err := os.CreateTemp("", "im-duration-*.mp4") @@ -506,3 +478,108 @@ func TestResolveMediaContentURLFallback(t *testing.T) { }) } } + +func TestLimitedReadCloser(t *testing.T) { + t.Run("within limit", func(t *testing.T) { + body := io.NopCloser(bytes.NewReader([]byte("hello"))) + lr := &limitedReadCloser{ + r: io.LimitReader(body, 10+1), + closer: body, + max: 10, + } + data, err := io.ReadAll(lr) + if err != nil { + t.Fatalf("ReadAll() error = %v", err) + } + if string(data) != "hello" { + t.Fatalf("ReadAll() = %q, want %q", string(data), "hello") + } + if err := lr.Close(); err != nil { + t.Fatalf("Close() error = %v", err) + } + }) + + t.Run("exceeds limit", func(t *testing.T) { + body := io.NopCloser(bytes.NewReader([]byte("hello world"))) + lr := &limitedReadCloser{ + r: io.LimitReader(body, 5+1), + closer: body, + max: 5, + } + _, err := io.ReadAll(lr) + if err == nil || !strings.Contains(err.Error(), "exceeds size limit") { + t.Fatalf("ReadAll() error = %v, want size limit error", err) + } + }) +} + +func TestMediaBufferDuration(t *testing.T) { + t.Run("mp4 duration from bytes", func(t *testing.T) { + data := wrapInMoov(buildMvhdBox(0, 1000, 5000)) + mb := &mediaBuffer{data: data, ext: ".mp4"} + if got := mb.Duration(); got != "5000" { + t.Fatalf("Duration() = %q, want %q", got, "5000") + } + }) + + t.Run("opus duration from bytes", func(t *testing.T) { + page := make([]byte, 27) + copy(page[0:4], "OggS") + page[5] = 4 + page[6] = 0x00 + page[7] = 0x53 + page[8] = 0x07 + mb := &mediaBuffer{data: page, ext: ".ogg"} + if got := mb.Duration(); got != "10000" { + t.Fatalf("Duration() = %q, want %q", got, "10000") + } + }) + + t.Run("unsupported type returns empty", func(t *testing.T) { + mb := &mediaBuffer{data: []byte("data"), ext: ".txt"} + if got := mb.Duration(); got != "" { + t.Fatalf("Duration() = %q, want empty", got) + } + }) + + t.Run("empty data returns empty", func(t *testing.T) { + mb := &mediaBuffer{data: nil, ext: ".mp4"} + if got := mb.Duration(); got != "" { + t.Fatalf("Duration() = %q, want empty", got) + } + }) +} + +func TestMediaBufferFileType(t *testing.T) { + tests := []struct { + ext string + want string + }{ + {".mp4", "mp4"}, + {".ogg", "opus"}, + {".pdf", "pdf"}, + {".unknown", "stream"}, + } + for _, tt := range tests { + mb := &mediaBuffer{ext: tt.ext} + if got := mb.FileType(); got != tt.want { + t.Fatalf("FileType(%s) = %q, want %q", tt.ext, got, tt.want) + } + } +} + +func TestMediaBufferReader(t *testing.T) { + data := []byte("test content") + mb := &mediaBuffer{data: data, ext: ".txt"} + + // Read twice to verify re-readability + for i := 0; i < 2; i++ { + got, err := io.ReadAll(mb.Reader()) + if err != nil { + t.Fatalf("ReadAll() attempt %d error = %v", i+1, err) + } + if !bytes.Equal(got, data) { + t.Fatalf("ReadAll() attempt %d = %q, want %q", i+1, got, data) + } + } +} diff --git a/shortcuts/im/helpers.go b/shortcuts/im/helpers.go index 32a0a33f5..57f354a1a 100644 --- a/shortcuts/im/helpers.go +++ b/shortcuts/im/helpers.go @@ -4,6 +4,7 @@ package im import ( + "bytes" "context" "encoding/binary" "encoding/json" @@ -21,6 +22,7 @@ import ( "github.com/larksuite/cli/internal/output" "github.com/larksuite/cli/internal/validate" + "github.com/larksuite/cli/internal/vfs" "github.com/larksuite/cli/shortcuts/common" larkcore "github.com/larksuite/oapi-sdk-go/v3/core" "github.com/spf13/cobra" @@ -155,18 +157,17 @@ func sanitizeURLForDisplay(rawURL string) string { return host + "/" + base } -const maxURLDownloadSize = 100 * 1024 * 1024 // 100MB - -// downloadURLToTempFile downloads a URL to a temp file, returning the path. -// The caller is responsible for removing the temp file. -func downloadURLToTempFile(ctx context.Context, runtime *common.RuntimeContext, rawURL string) (string, error) { +// startURLDownload performs URL validation, creates an HTTP client, and sends a +// GET request. It returns the response (with Body still open) and the file +// extension inferred from the URL. The caller must close resp.Body. +func startURLDownload(ctx context.Context, runtime *common.RuntimeContext, rawURL string) (*http.Response, string, error) { if err := validate.ValidateDownloadSourceURL(ctx, rawURL); err != nil { - return "", fmt.Errorf("blocked URL: %w", err) + return nil, "", fmt.Errorf("blocked URL: %w", err) } httpClient, err := runtime.Factory.HttpClient() if err != nil { - return "", fmt.Errorf("http client: %w", err) + return nil, "", fmt.Errorf("http client: %w", err) } httpClient = validate.NewDownloadHTTPClient(httpClient, validate.DownloadHTTPClientOptions{ AllowHTTP: true, @@ -174,57 +175,78 @@ func downloadURLToTempFile(ctx context.Context, runtime *common.RuntimeContext, req, err := http.NewRequestWithContext(ctx, http.MethodGet, rawURL, nil) if err != nil { - return "", fmt.Errorf("invalid URL: %w", err) + return nil, "", fmt.Errorf("invalid URL: %w", err) } resp, err := httpClient.Do(req) if err != nil { - return "", fmt.Errorf("download failed: %w", err) + return nil, "", fmt.Errorf("download failed: %w", err) } - defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - return "", fmt.Errorf("download failed: HTTP %d", resp.StatusCode) + resp.Body.Close() + return nil, "", fmt.Errorf("download failed: HTTP %d", resp.StatusCode) } - // Determine extension from URL for correct file type detection. ext := filepath.Ext(fileNameFromURL(rawURL)) - tmpFile, err := os.CreateTemp("", "lark-media-*"+ext) - if err != nil { - return "", fmt.Errorf("create temp file: %w", err) - } + return resp, ext, nil +} - n, err := io.Copy(tmpFile, io.LimitReader(resp.Body, maxURLDownloadSize+1)) - tmpFile.Close() +// downloadURLToReader returns a size-limited io.ReadCloser for the URL content +// and the file extension inferred from the URL. The caller must close the +// returned ReadCloser. No temp file is created and the content is not buffered. +func downloadURLToReader(ctx context.Context, runtime *common.RuntimeContext, rawURL string, maxSize int64) (io.ReadCloser, string, error) { + resp, ext, err := startURLDownload(ctx, runtime, rawURL) //nolint:bodyclose // resp.Body is closed by the returned limitedReadCloser if err != nil { - os.Remove(tmpFile.Name()) - return "", fmt.Errorf("download failed: %w", err) + return nil, "", err } - if n > maxURLDownloadSize { - os.Remove(tmpFile.Name()) - return "", fmt.Errorf("download exceeds size limit (max 100MB)") + lr := &limitedReadCloser{ + r: io.LimitReader(resp.Body, maxSize+1), + closer: resp.Body, + max: maxSize, } + return lr, ext, nil +} - return tmpFile.Name(), nil +// limitedReadCloser wraps a LimitReader and checks for size overflow on Close. +type limitedReadCloser struct { + r io.Reader + closer io.Closer + max int64 + n int64 } -// resolveToLocalPath resolves a media value to a local file path. -// If the value is a URL, it downloads to a temp file; the returned cleanup func -// removes the temp file (no-op for local paths). Returns ("", nil, nil) for media keys. -func resolveToLocalPath(ctx context.Context, runtime *common.RuntimeContext, flagName, value string) (localPath string, cleanup func(), err error) { - noop := func() {} - if isMediaKey(value) { - return "", noop, nil +func (l *limitedReadCloser) Read(p []byte) (int, error) { + n, err := l.r.Read(p) + l.n += int64(n) + if l.n > l.max { + return n, fmt.Errorf("download exceeds size limit (max %s)", common.FormatSize(l.max)) } - if isURL(value) { - fmt.Fprintf(runtime.IO().ErrOut, "downloading %s: %s\n", flagName, sanitizeURLForDisplay(value)) - tmpPath, err := downloadURLToTempFile(ctx, runtime, value) - if err != nil { - return "", noop, err - } - return tmpPath, func() { os.Remove(tmpPath) }, nil - } - return value, noop, nil + return n, err +} + +func (l *limitedReadCloser) Close() error { + return l.closer.Close() +} + +// mediaKind distinguishes image uploads (image_key) from file uploads (file_key). +type mediaKind int + +const ( + mediaKindImage mediaKind = iota // upload via image API, returns image_key + mediaKindFile // upload via file API, returns file_key +) + +// mediaSpec describes how to resolve and upload a single media input. +type mediaSpec struct { + value string // raw input value (path, URL, or media key) + flagName string // CLI flag name for log messages, e.g. "--image" + mediaType string // human label for errors, e.g. "image" + msgType string // IM message type, e.g. "image", "file", "audio" + kind mediaKind // image vs file upload + maxSize int64 // download size limit + withDuration bool // whether to parse audio/video duration + resultKey string // JSON key for the upload result, e.g. "image_key" } // resolveMediaContent resolves text/media flags to (msgType, contentJSON) for Execute. @@ -234,107 +256,119 @@ func resolveMediaContent(ctx context.Context, runtime *common.RuntimeContext, te jsonBytes, _ := json.Marshal(map[string]string{"text": text}) return "text", string(jsonBytes), nil } + + // Video is special: it produces two keys (file_key + image_key for cover). if videoVal != "" { - fKey := videoVal - if !isMediaKey(videoVal) { - localPath, cleanup, dlErr := resolveToLocalPath(ctx, runtime, "--video", videoVal) - if dlErr != nil { - return mediaFallbackOrError(videoVal, "video", dlErr) - } - defer cleanup() - if localPath == "" { - localPath = videoVal - } - fmt.Fprintf(runtime.IO().ErrOut, "uploading video: %s\n", filepath.Base(localPath)) - ft := detectIMFileType(localPath) - fKey, err = uploadFileToIM(ctx, runtime, localPath, ft, parseMediaDuration(localPath, ft)) - if err != nil { - return mediaFallbackOrError(videoVal, "video", err) - } - } - var coverKey string - if isMediaKey(videoCoverVal) { - coverKey = videoCoverVal - } else { - localPath, cleanup, dlErr := resolveToLocalPath(ctx, runtime, "--video-cover", videoCoverVal) - if dlErr != nil { - return mediaFallbackOrError(videoCoverVal, "cover image", dlErr) - } - defer cleanup() - fmt.Fprintf(runtime.IO().ErrOut, "uploading cover image: %s\n", filepath.Base(localPath)) - coverKey, err = uploadImageToIM(ctx, runtime, localPath, "message") - if err != nil { - return "", "", fmt.Errorf("cover image upload failed: %w", err) - } - } - jsonBytes, _ := json.Marshal(map[string]string{"file_key": fKey, "image_key": coverKey}) - return "media", string(jsonBytes), nil - } - if imageVal != "" { - imageKey := imageVal - if !isMediaKey(imageVal) { - localPath, cleanup, dlErr := resolveToLocalPath(ctx, runtime, "--image", imageVal) - if dlErr != nil { - return mediaFallbackOrError(imageVal, "image", dlErr) - } - defer cleanup() - if localPath == "" { - // isMediaKey path — won't happen since we checked above, but be safe. - localPath = imageVal - } - fmt.Fprintf(runtime.IO().ErrOut, "uploading image: %s\n", filepath.Base(localPath)) - imageKey, err = uploadImageToIM(ctx, runtime, localPath, "message") - if err != nil { - return mediaFallbackOrError(imageVal, "image", err) - } - } - jsonBytes, _ := json.Marshal(map[string]string{"image_key": imageKey}) - return "image", string(jsonBytes), nil - } - if fileVal != "" { - fKey := fileVal - if !isMediaKey(fileVal) { - localPath, cleanup, dlErr := resolveToLocalPath(ctx, runtime, "--file", fileVal) - if dlErr != nil { - return mediaFallbackOrError(fileVal, "file", dlErr) - } - defer cleanup() - if localPath == "" { - localPath = fileVal - } - fmt.Fprintf(runtime.IO().ErrOut, "uploading file: %s\n", filepath.Base(localPath)) - fKey, err = uploadFileToIM(ctx, runtime, localPath, detectIMFileType(localPath), "") - if err != nil { - return mediaFallbackOrError(fileVal, "file", err) - } + return resolveVideoContent(ctx, runtime, videoVal, videoCoverVal) + } + + // All other media types follow a uniform pattern: single input → single key. + specs := []mediaSpec{ + {value: imageVal, flagName: "--image", mediaType: "image", msgType: "image", kind: mediaKindImage, maxSize: maxImageUploadSize, resultKey: "image_key"}, + {value: fileVal, flagName: "--file", mediaType: "file", msgType: "file", kind: mediaKindFile, maxSize: maxFileUploadSize, resultKey: "file_key"}, + {value: audioVal, flagName: "--audio", mediaType: "audio", msgType: "audio", kind: mediaKindFile, maxSize: maxFileUploadSize, withDuration: true, resultKey: "file_key"}, + } + + for _, s := range specs { + if s.value == "" { + continue } - jsonBytes, _ := json.Marshal(map[string]string{"file_key": fKey}) - return "file", string(jsonBytes), nil - } - if audioVal != "" { - fKey := audioVal - if !isMediaKey(audioVal) { - localPath, cleanup, dlErr := resolveToLocalPath(ctx, runtime, "--audio", audioVal) - if dlErr != nil { - return mediaFallbackOrError(audioVal, "audio", dlErr) - } - defer cleanup() - if localPath == "" { - localPath = audioVal - } - fmt.Fprintf(runtime.IO().ErrOut, "uploading audio: %s\n", filepath.Base(localPath)) - ft := detectIMFileType(localPath) - fKey, err = uploadFileToIM(ctx, runtime, localPath, ft, parseMediaDuration(localPath, ft)) - if err != nil { - return mediaFallbackOrError(audioVal, "audio", err) - } + key, resolveErr := resolveOneMedia(ctx, runtime, s) + if resolveErr != nil { + return mediaFallbackOrError(s.value, s.mediaType, resolveErr) } - jsonBytes, _ := json.Marshal(map[string]string{"file_key": fKey}) - return "audio", string(jsonBytes), nil + jsonBytes, _ := json.Marshal(map[string]string{s.resultKey: key}) + return s.msgType, string(jsonBytes), nil } return "", "", nil } +// resolveOneMedia uploads a single media input (image, file, or audio) and +// returns the resulting key. It handles media keys, URLs, and local paths. +func resolveOneMedia(ctx context.Context, runtime *common.RuntimeContext, s mediaSpec) (string, error) { + if isMediaKey(s.value) { + return s.value, nil + } + + if isURL(s.value) { + return resolveURLMedia(ctx, runtime, s) + } + return resolveLocalMedia(ctx, runtime, s) +} + +// resolveURLMedia downloads a URL and uploads it. +func resolveURLMedia(ctx context.Context, runtime *common.RuntimeContext, s mediaSpec) (string, error) { + fmt.Fprintf(runtime.IO().ErrOut, "downloading %s: %s\n", s.flagName, sanitizeURLForDisplay(s.value)) + + if s.kind == mediaKindImage { + rc, _, err := downloadURLToReader(ctx, runtime, s.value, s.maxSize) + if err != nil { + return "", err + } + defer rc.Close() + fmt.Fprintf(runtime.IO().ErrOut, "uploading %s\n", s.mediaType) + return uploadImageFromReader(ctx, runtime, rc, "message") + } + + // File-kind: buffer in memory for possible duration parsing. + mb, err := newMediaBuffer(ctx, runtime, s.value, s.maxSize) + if err != nil { + return "", err + } + fmt.Fprintf(runtime.IO().ErrOut, "uploading %s: %s\n", s.mediaType, mb.FileName()) + dur := "" + if s.withDuration { + dur = mb.Duration() + } + return uploadFileFromReader(ctx, runtime, mb.Reader(), mb.FileName(), mb.FileType(), dur) +} + +// resolveLocalMedia uploads a local file. +func resolveLocalMedia(ctx context.Context, runtime *common.RuntimeContext, s mediaSpec) (string, error) { + fmt.Fprintf(runtime.IO().ErrOut, "uploading %s: %s\n", s.mediaType, filepath.Base(s.value)) + + safePath, err := validate.SafeInputPath(s.value) + if err != nil { + return "", err + } + + if s.kind == mediaKindImage { + return uploadImageToIM(ctx, runtime, safePath, "message") + } + + ft := detectIMFileType(safePath) + dur := "" + if s.withDuration { + dur = parseMediaDuration(safePath, ft) + } + return uploadFileToIM(ctx, runtime, safePath, ft, dur) +} + +// resolveVideoContent handles the video case which needs both a file_key and +// a cover image_key. +func resolveVideoContent(ctx context.Context, runtime *common.RuntimeContext, videoVal, videoCoverVal string) (string, string, error) { + videoSpec := mediaSpec{ + value: videoVal, flagName: "--video", mediaType: "video", + kind: mediaKindFile, maxSize: maxFileUploadSize, withDuration: true, resultKey: "file_key", + } + fKey, err := resolveOneMedia(ctx, runtime, videoSpec) + if err != nil { + return mediaFallbackOrError(videoVal, "video", err) + } + + coverSpec := mediaSpec{ + value: videoCoverVal, flagName: "--video-cover", mediaType: "cover image", + kind: mediaKindImage, maxSize: maxImageUploadSize, resultKey: "image_key", + } + coverKey, err := resolveOneMedia(ctx, runtime, coverSpec) + if err != nil { + return "", "", fmt.Errorf("cover image upload failed: %w", err) + } + + jsonBytes, _ := json.Marshal(map[string]string{"file_key": fKey, "image_key": coverKey}) + return "media", string(jsonBytes), nil +} + // mediaFallbackOrError returns a text fallback for URL inputs when upload fails, // or a hard error for local file inputs. func mediaFallbackOrError(originalValue, mediaType string, uploadErr error) (string, string, error) { @@ -526,7 +560,7 @@ func parseMediaDuration(filePath, fileType string) string { if fileType != "opus" && fileType != "mp4" { return "" } - f, err := os.Open(filePath) + f, err := vfs.Open(filePath) if err != nil { return "" } @@ -549,6 +583,120 @@ func parseMediaDuration(filePath, fileType string) string { return strconv.FormatInt(ms, 10) } +// mediaBuffer holds downloaded media content in memory, providing both random +// access (for duration parsing) and an io.Reader (for upload). It replaces temp +// files for URL-sourced media that needs seek-like access before upload. +type mediaBuffer struct { + data []byte + ext string // file extension including leading dot, e.g. ".mp4" +} + +// newMediaBuffer downloads URL content into memory via downloadURLToReader. +func newMediaBuffer(ctx context.Context, runtime *common.RuntimeContext, rawURL string, maxSize int64) (*mediaBuffer, error) { + rc, ext, err := downloadURLToReader(ctx, runtime, rawURL, maxSize) + if err != nil { + return nil, err + } + defer rc.Close() + + data, err := io.ReadAll(rc) + if err != nil { + return nil, fmt.Errorf("download failed: %w", err) + } + return &mediaBuffer{data: data, ext: ext}, nil +} + +// Reader returns a new io.Reader over the buffered data. Each call returns a +// fresh reader starting from the beginning, so the buffer can be read multiple +// times (once for duration parsing, once for upload). +func (b *mediaBuffer) Reader() io.Reader { + return bytes.NewReader(b.data) +} + +// FileName returns a synthetic file name based on the URL extension. +func (b *mediaBuffer) FileName() string { + return "media" + b.ext +} + +// FileType returns the IM file type detected from the extension. +func (b *mediaBuffer) FileType() string { + return detectIMFileType("file" + b.ext) +} + +// Duration parses audio/video duration from the buffered data. +func (b *mediaBuffer) Duration() string { + ft := b.FileType() + if ft != "opus" && ft != "mp4" { + return "" + } + if len(b.data) == 0 { + return "" + } + var ms int64 + if ft == "opus" { + ms = readOggDurationBytes(b.data) + } else { + ms = readMp4DurationBytes(b.data) + } + if ms <= 0 { + return "" + } + return strconv.FormatInt(ms, 10) +} + +// readOggDurationBytes parses OGG duration from the tail of in-memory data. +func readOggDurationBytes(data []byte) int64 { + const maxTail = 65536 + buf := data + if len(buf) > maxTail { + buf = buf[len(buf)-maxTail:] + } + return parseOggOpusDuration(buf) +} + +// readMp4DurationBytes walks top-level MP4 boxes in memory to find moov/mvhd duration. +func readMp4DurationBytes(data []byte) int64 { + fileSize := int64(len(data)) + var offset int64 + for offset+8 <= fileSize { + size := int64(binary.BigEndian.Uint32(data[offset : offset+4])) + typ := string(data[offset+4 : offset+8]) + + var boxEnd, dataStart int64 + switch { + case size == 0: + boxEnd = fileSize + dataStart = offset + 8 + case size == 1: + if offset+16 > fileSize { + return 0 + } + boxEnd = int64(binary.BigEndian.Uint64(data[offset+8 : offset+16])) + dataStart = offset + 16 + case size < 8: + return 0 + default: + boxEnd = offset + size + dataStart = offset + 8 + } + + if typ == "moov" { + moovLen := boxEnd - dataStart + if moovLen <= 0 || moovLen > 10<<20 || dataStart+moovLen > fileSize { + return 0 + } + moov := data[dataStart : dataStart+moovLen] + mvhdStart, mvhdEnd := findMP4Box(moov, 0, len(moov), "mvhd") + if mvhdStart < 0 { + return 0 + } + return parseMvhdPayload(moov[mvhdStart:mvhdEnd]) + } + offset = boxEnd + } + return 0 +} + // readOggDuration reads the tail of an OGG file (up to 64 KB) and parses duration. func readOggDuration(f *os.File, fileSize int64) int64 { const maxTail = 65536 @@ -734,15 +882,15 @@ func resolveMarkdownImageURLs(ctx context.Context, runtime *common.RuntimeContex } imgURL := sub[1] - tmpPath, err := downloadURLToTempFile(ctx, runtime, imgURL) + rc, _, err := downloadURLToReader(ctx, runtime, imgURL, maxImageUploadSize) if err != nil { fmt.Fprintf(runtime.IO().ErrOut, "warning: failed to download image %s: %v\n", sanitizeURLForDisplay(imgURL), err) return "" } - defer os.Remove(tmpPath) + defer rc.Close() fmt.Fprintf(runtime.IO().ErrOut, "uploading image from URL: %s\n", sanitizeURLForDisplay(imgURL)) - imgKey, err := uploadImageToIM(ctx, runtime, tmpPath, "message") + imgKey, err := uploadImageFromReader(ctx, runtime, rc, "message") if err != nil { fmt.Fprintf(runtime.IO().ErrOut, "warning: failed to upload image %s: %v\n", sanitizeURLForDisplay(imgURL), err) return "" @@ -857,16 +1005,14 @@ const maxImageUploadSize = 5 * 1024 * 1024 // 5MB — Lark API limit for images const maxFileUploadSize = 100 * 1024 * 1024 // 100MB — Lark API limit for files func uploadImageToIM(ctx context.Context, runtime *common.RuntimeContext, filePath, imageType string) (string, error) { - safePath, err := validate.SafeInputPath(filePath) - if err != nil { - return "", err - } + // filePath is already validated by the caller (resolveLocalMedia). + safePath := filePath - if info, err := os.Stat(safePath); err == nil && info.Size() > maxImageUploadSize { + if info, err := vfs.Stat(safePath); err == nil && info.Size() > maxImageUploadSize { return "", fmt.Errorf("image size %s exceeds limit (max 5MB)", common.FormatSize(info.Size())) } - f, err := os.Open(safePath) + f, err := vfs.Open(safePath) if err != nil { return "", err } @@ -899,16 +1045,14 @@ func uploadImageToIM(ctx context.Context, runtime *common.RuntimeContext, filePa } func uploadFileToIM(ctx context.Context, runtime *common.RuntimeContext, filePath, fileType, duration string) (string, error) { - safePath, err := validate.SafeInputPath(filePath) - if err != nil { - return "", err - } + // filePath is already validated by the caller (resolveLocalMedia). + safePath := filePath - if info, err := os.Stat(safePath); err == nil && info.Size() > maxFileUploadSize { + if info, err := vfs.Stat(safePath); err == nil && info.Size() > maxFileUploadSize { return "", fmt.Errorf("file size %s exceeds limit (max 100MB)", common.FormatSize(info.Size())) } - f, err := os.Open(safePath) + f, err := vfs.Open(safePath) if err != nil { return "", err } @@ -943,3 +1087,63 @@ func uploadFileToIM(ctx context.Context, runtime *common.RuntimeContext, filePat } return fileKey, nil } + +// uploadImageFromReader uploads an image from an io.Reader (no local file needed). +func uploadImageFromReader(ctx context.Context, runtime *common.RuntimeContext, r io.Reader, imageType string) (string, error) { + fd := larkcore.NewFormdata() + fd.AddField("image_type", imageType) + fd.AddFile("image", r) + + apiResp, err := runtime.DoAPIAsBot(&larkcore.ApiReq{ + HttpMethod: http.MethodPost, + ApiPath: "/open-apis/im/v1/images", + Body: fd, + }, larkcore.WithFileUpload()) + if err != nil { + return "", err + } + + var result map[string]interface{} + if err := json.Unmarshal(apiResp.RawBody, &result); err != nil { + return "", fmt.Errorf("parse error: %w", err) + } + + data, _ := result["data"].(map[string]interface{}) + imageKey, _ := data["image_key"].(string) + if imageKey == "" { + return "", fmt.Errorf("image_key not found in response (code: %v, msg: %v)", result["code"], result["msg"]) + } + return imageKey, nil +} + +// uploadFileFromReader uploads a file from an io.Reader (no local file needed). +func uploadFileFromReader(ctx context.Context, runtime *common.RuntimeContext, r io.Reader, fileName, fileType, duration string) (string, error) { + fd := larkcore.NewFormdata() + fd.AddField("file_type", fileType) + fd.AddField("file_name", fileName) + if duration != "" { + fd.AddField("duration", duration) + } + fd.AddFile("file", r) + + apiResp, err := runtime.DoAPIAsBot(&larkcore.ApiReq{ + HttpMethod: http.MethodPost, + ApiPath: "/open-apis/im/v1/files", + Body: fd, + }, larkcore.WithFileUpload()) + if err != nil { + return "", err + } + + var result map[string]interface{} + if err := json.Unmarshal(apiResp.RawBody, &result); err != nil { + return "", fmt.Errorf("parse error: %w", err) + } + + data, _ := result["data"].(map[string]interface{}) + fileKey, _ := data["file_key"].(string) + if fileKey == "" { + return "", fmt.Errorf("file_key not found in response (code: %v, msg: %v)", result["code"], result["msg"]) + } + return fileKey, nil +} diff --git a/shortcuts/im/helpers_local_media_test.go b/shortcuts/im/helpers_local_media_test.go new file mode 100644 index 000000000..50cfde4bb --- /dev/null +++ b/shortcuts/im/helpers_local_media_test.go @@ -0,0 +1,68 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package im + +import ( + "context" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/larksuite/cli/internal/cmdutil" + "github.com/larksuite/cli/internal/vfs" + "github.com/larksuite/cli/shortcuts/common" +) + +type countingOpenFS struct { + vfs.OsFs + cwd string + openCalls int +} + +func (fs *countingOpenFS) Getwd() (string, error) { + return fs.cwd, nil +} + +func (fs *countingOpenFS) Open(name string) (*os.File, error) { + fs.openCalls++ + return nil, os.ErrPermission +} + +func TestResolveLocalMedia_ValidatesPathBeforeParsingDuration(t *testing.T) { + root := t.TempDir() + cwd := filepath.Join(root, "work") + if err := os.MkdirAll(cwd, 0755); err != nil { + t.Fatalf("MkdirAll() error = %v", err) + } + outside := filepath.Join(root, "outside.mp4") + if err := os.WriteFile(outside, []byte("not-a-real-mp4"), 0600); err != nil { + t.Fatalf("WriteFile() error = %v", err) + } + + mockFS := &countingOpenFS{cwd: cwd} + oldFS := vfs.DefaultFS + vfs.DefaultFS = mockFS + t.Cleanup(func() { vfs.DefaultFS = oldFS }) + + f, _, _, _ := cmdutil.TestFactory(t, nil) + runtime := &common.RuntimeContext{Factory: f} + spec := mediaSpec{ + value: "../outside.mp4", + mediaType: "video", + kind: mediaKindFile, + withDuration: true, + } + + _, err := resolveLocalMedia(context.Background(), runtime, spec) + if err == nil { + t.Fatal("expected path validation error") + } + if !strings.Contains(err.Error(), "resolves outside the current working directory") { + t.Fatalf("error = %v, want path validation error", err) + } + if mockFS.openCalls != 0 { + t.Fatalf("Open() called %d times, want 0 before validation", mockFS.openCalls) + } +} diff --git a/shortcuts/im/helpers_network_test.go b/shortcuts/im/helpers_network_test.go index b09b45e06..9e914fddf 100644 --- a/shortcuts/im/helpers_network_test.go +++ b/shortcuts/im/helpers_network_test.go @@ -22,9 +22,16 @@ import ( "github.com/larksuite/cli/internal/cmdutil" "github.com/larksuite/cli/internal/core" + "github.com/larksuite/cli/internal/credential" "github.com/larksuite/cli/shortcuts/common" ) +type staticShortcutTokenResolver struct{} + +func (s *staticShortcutTokenResolver) ResolveToken(_ context.Context, _ credential.TokenSpec) (*credential.TokenResult, error) { + return &credential.TokenResult{Token: "tenant-token"}, nil +} + type shortcutRoundTripFunc func(*http.Request) (*http.Response, error) func (f shortcutRoundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { @@ -68,6 +75,7 @@ func newBotShortcutRuntime(t *testing.T, rt http.RoundTripper) *common.RuntimeCo sdk := lark.NewClient( "test-app", "test-secret", + lark.WithEnableTokenCache(false), lark.WithLogLevel(larkcore.LogLevelError), lark.WithHttpClient(httpClient), ) @@ -76,13 +84,14 @@ func newBotShortcutRuntime(t *testing.T, rt http.RoundTripper) *common.RuntimeCo AppSecret: "test-secret", Brand: core.BrandFeishu, } + testCred := credential.NewCredentialProvider(nil, nil, &staticShortcutTokenResolver{}, nil) runtime := &common.RuntimeContext{ Config: cfg, Factory: &cmdutil.Factory{ Config: func() (*core.CliConfig, error) { return cfg, nil }, - AuthConfig: func() (*core.CliConfig, error) { return cfg, nil }, HttpClient: func() (*http.Client, error) { return httpClient, nil }, LarkClient: func() (*lark.Client, error) { return sdk, nil }, + Credential: testCred, IOStreams: &cmdutil.IOStreams{ Out: &bytes.Buffer{}, ErrOut: &bytes.Buffer{}, @@ -99,12 +108,6 @@ func TestResolveP2PChatID(t *testing.T) { var gotAuth string runtime := newBotShortcutRuntime(t, shortcutRoundTripFunc(func(req *http.Request) (*http.Response, error) { switch { - case strings.Contains(req.URL.Path, "tenant_access_token"): - return shortcutJSONResponse(200, map[string]interface{}{ - "code": 0, - "tenant_access_token": "tenant-token", - "expire": 7200, - }), nil case strings.Contains(req.URL.Path, "/open-apis/im/v1/chat_p2p/batch_query"): gotAuth = req.Header.Get("Authorization") return shortcutJSONResponse(200, map[string]interface{}{ @@ -135,12 +138,6 @@ func TestResolveP2PChatID(t *testing.T) { func TestResolveP2PChatIDNotFound(t *testing.T) { runtime := newBotShortcutRuntime(t, shortcutRoundTripFunc(func(req *http.Request) (*http.Response, error) { switch { - case strings.Contains(req.URL.Path, "tenant_access_token"): - return shortcutJSONResponse(200, map[string]interface{}{ - "code": 0, - "tenant_access_token": "tenant-token", - "expire": 7200, - }), nil case strings.Contains(req.URL.Path, "/open-apis/im/v1/chat_p2p/batch_query"): return shortcutJSONResponse(200, map[string]interface{}{ "code": 0, @@ -184,12 +181,6 @@ func TestResolveThreadID(t *testing.T) { t.Run("message lookup success", func(t *testing.T) { runtime := newBotShortcutRuntime(t, shortcutRoundTripFunc(func(req *http.Request) (*http.Response, error) { switch { - case strings.Contains(req.URL.Path, "tenant_access_token"): - return shortcutJSONResponse(200, map[string]interface{}{ - "code": 0, - "tenant_access_token": "tenant-token", - "expire": 7200, - }), nil case strings.Contains(req.URL.Path, "/open-apis/im/v1/messages/om_123"): return shortcutJSONResponse(200, map[string]interface{}{ "code": 0, @@ -216,12 +207,6 @@ func TestResolveThreadID(t *testing.T) { t.Run("message lookup not found", func(t *testing.T) { runtime := newBotShortcutRuntime(t, shortcutRoundTripFunc(func(req *http.Request) (*http.Response, error) { switch { - case strings.Contains(req.URL.Path, "tenant_access_token"): - return shortcutJSONResponse(200, map[string]interface{}{ - "code": 0, - "tenant_access_token": "tenant-token", - "expire": 7200, - }), nil case strings.Contains(req.URL.Path, "/open-apis/im/v1/messages/om_404"): return shortcutJSONResponse(200, map[string]interface{}{ "code": 0, @@ -248,12 +233,6 @@ func TestDownloadIMResourceToPathSuccess(t *testing.T) { payload := []byte("hello download") runtime := newBotShortcutRuntime(t, shortcutRoundTripFunc(func(req *http.Request) (*http.Response, error) { switch { - case strings.Contains(req.URL.Path, "tenant_access_token"): - return shortcutJSONResponse(200, map[string]interface{}{ - "code": 0, - "tenant_access_token": "tenant-token", - "expire": 7200, - }), nil case strings.Contains(req.URL.Path, "/open-apis/im/v1/messages/om_123/resources/file_123"): gotHeaders = req.Header.Clone() return shortcutRawResponse(200, payload, http.Header{"Content-Type": []string{"application/octet-stream"}}), nil @@ -294,12 +273,6 @@ func TestDownloadIMResourceToPathSuccess(t *testing.T) { func TestDownloadIMResourceToPathHTTPErrorBody(t *testing.T) { runtime := newBotShortcutRuntime(t, shortcutRoundTripFunc(func(req *http.Request) (*http.Response, error) { switch { - case strings.Contains(req.URL.Path, "tenant_access_token"): - return shortcutJSONResponse(200, map[string]interface{}{ - "code": 0, - "tenant_access_token": "tenant-token", - "expire": 7200, - }), nil case strings.Contains(req.URL.Path, "/open-apis/im/v1/messages/om_403/resources/file_403"): return shortcutRawResponse(403, []byte("denied"), http.Header{"Content-Type": []string{"text/plain"}}), nil default: @@ -317,12 +290,6 @@ func TestUploadImageToIMSuccess(t *testing.T) { var gotBody string runtime := newBotShortcutRuntime(t, shortcutRoundTripFunc(func(req *http.Request) (*http.Response, error) { switch { - case strings.Contains(req.URL.Path, "tenant_access_token"): - return shortcutJSONResponse(200, map[string]interface{}{ - "code": 0, - "tenant_access_token": "tenant-token", - "expire": 7200, - }), nil case strings.Contains(req.URL.Path, "/open-apis/im/v1/images"): body, err := io.ReadAll(req.Body) if err != nil { @@ -355,7 +322,11 @@ func TestUploadImageToIMSuccess(t *testing.T) { t.Fatalf("WriteFile() error = %v", err) } - got, err := uploadImageToIM(context.Background(), runtime, "./"+path, "message") + absPath, err := filepath.Abs(path) + if err != nil { + t.Fatalf("Abs() error = %v", err) + } + got, err := uploadImageToIM(context.Background(), runtime, absPath, "message") if err != nil { t.Fatalf("uploadImageToIM() error = %v", err) } @@ -371,12 +342,6 @@ func TestUploadFileToIMSuccess(t *testing.T) { var gotBody string runtime := newBotShortcutRuntime(t, shortcutRoundTripFunc(func(req *http.Request) (*http.Response, error) { switch { - case strings.Contains(req.URL.Path, "tenant_access_token"): - return shortcutJSONResponse(200, map[string]interface{}{ - "code": 0, - "tenant_access_token": "tenant-token", - "expire": 7200, - }), nil case strings.Contains(req.URL.Path, "/open-apis/im/v1/files"): body, err := io.ReadAll(req.Body) if err != nil { @@ -409,7 +374,11 @@ func TestUploadFileToIMSuccess(t *testing.T) { t.Fatalf("WriteFile() error = %v", err) } - got, err := uploadFileToIM(context.Background(), runtime, "./"+path, "stream", "1200") + absPath, err := filepath.Abs(path) + if err != nil { + t.Fatalf("Abs() error = %v", err) + } + got, err := uploadFileToIM(context.Background(), runtime, absPath, "stream", "1200") if err != nil { t.Fatalf("uploadFileToIM() error = %v", err) } @@ -425,19 +394,7 @@ func TestUploadFileToIMSuccess(t *testing.T) { } func TestUploadImageToIMSizeLimit(t *testing.T) { - wd, err := os.Getwd() - if err != nil { - t.Fatalf("Getwd() error = %v", err) - } - tmpDir := t.TempDir() - if err := os.Chdir(tmpDir); err != nil { - t.Fatalf("Chdir() error = %v", err) - } - t.Cleanup(func() { - _ = os.Chdir(wd) - }) - - path := "too-large.png" + path := filepath.Join(t.TempDir(), "too-large.png") f, err := os.Create(path) if err != nil { t.Fatalf("Create() error = %v", err) @@ -445,30 +402,16 @@ func TestUploadImageToIMSizeLimit(t *testing.T) { if err := f.Truncate(maxImageUploadSize + 1); err != nil { t.Fatalf("Truncate() error = %v", err) } - if err := f.Close(); err != nil { - t.Fatalf("Close() error = %v", err) - } + f.Close() - _, err = uploadImageToIM(context.Background(), nil, "./"+path, "message") + _, err = uploadImageToIM(context.Background(), nil, path, "message") if err == nil || !strings.Contains(err.Error(), "exceeds limit") { t.Fatalf("uploadImageToIM() error = %v", err) } } func TestUploadFileToIMSizeLimit(t *testing.T) { - wd, err := os.Getwd() - if err != nil { - t.Fatalf("Getwd() error = %v", err) - } - tmpDir := t.TempDir() - if err := os.Chdir(tmpDir); err != nil { - t.Fatalf("Chdir() error = %v", err) - } - t.Cleanup(func() { - _ = os.Chdir(wd) - }) - - path := "too-large.bin" + path := filepath.Join(t.TempDir(), "too-large.bin") f, err := os.Create(path) if err != nil { t.Fatalf("Create() error = %v", err) @@ -476,11 +419,9 @@ func TestUploadFileToIMSizeLimit(t *testing.T) { if err := f.Truncate(maxFileUploadSize + 1); err != nil { t.Fatalf("Truncate() error = %v", err) } - if err := f.Close(); err != nil { - t.Fatalf("Close() error = %v", err) - } + f.Close() - _, err = uploadFileToIM(context.Background(), nil, "./"+path, "stream", "") + _, err = uploadFileToIM(context.Background(), nil, path, "stream", "") if err == nil || !strings.Contains(err.Error(), "exceeds limit") { t.Fatalf("uploadFileToIM() error = %v", err) } @@ -502,3 +443,81 @@ func TestResolveMediaContentWrapsUploadError(t *testing.T) { t.Fatalf("resolveMediaContent() error = %v", err) } } + +// TestResolveLocalMediaImage verifies that resolveLocalMedia can upload an image +// via uploadImageToIM without double path validation. +func TestResolveLocalMediaImage(t *testing.T) { + runtime := newBotShortcutRuntime(t, shortcutRoundTripFunc(func(req *http.Request) (*http.Response, error) { + if strings.Contains(req.URL.Path, "/open-apis/im/v1/images") { + return shortcutJSONResponse(200, map[string]interface{}{ + "code": 0, + "data": map[string]interface{}{"image_key": "img_via_resolve"}, + }), nil + } + return nil, fmt.Errorf("unexpected request: %s", req.URL.String()) + })) + + wd, err := os.Getwd() + if err != nil { + t.Fatalf("Getwd() error = %v", err) + } + tmpDir := t.TempDir() + if err := os.Chdir(tmpDir); err != nil { + t.Fatalf("Chdir() error = %v", err) + } + t.Cleanup(func() { _ = os.Chdir(wd) }) + + if err := os.WriteFile("test.png", []byte("png-data"), 0600); err != nil { + t.Fatalf("WriteFile() error = %v", err) + } + + got, err := resolveLocalMedia(context.Background(), runtime, mediaSpec{ + value: "./test.png", flagName: "--image", mediaType: "image", + kind: mediaKindImage, maxSize: maxImageUploadSize, resultKey: "image_key", + }) + if err != nil { + t.Fatalf("resolveLocalMedia(image) error = %v", err) + } + if got != "img_via_resolve" { + t.Fatalf("resolveLocalMedia(image) = %q, want %q", got, "img_via_resolve") + } +} + +// TestResolveLocalMediaFile verifies that resolveLocalMedia can upload a file +// via uploadFileToIM without double path validation. +func TestResolveLocalMediaFile(t *testing.T) { + runtime := newBotShortcutRuntime(t, shortcutRoundTripFunc(func(req *http.Request) (*http.Response, error) { + if strings.Contains(req.URL.Path, "/open-apis/im/v1/files") { + return shortcutJSONResponse(200, map[string]interface{}{ + "code": 0, + "data": map[string]interface{}{"file_key": "file_via_resolve"}, + }), nil + } + return nil, fmt.Errorf("unexpected request: %s", req.URL.String()) + })) + + wd, err := os.Getwd() + if err != nil { + t.Fatalf("Getwd() error = %v", err) + } + tmpDir := t.TempDir() + if err := os.Chdir(tmpDir); err != nil { + t.Fatalf("Chdir() error = %v", err) + } + t.Cleanup(func() { _ = os.Chdir(wd) }) + + if err := os.WriteFile("test.txt", []byte("file-data"), 0600); err != nil { + t.Fatalf("WriteFile() error = %v", err) + } + + got, err := resolveLocalMedia(context.Background(), runtime, mediaSpec{ + value: "./test.txt", flagName: "--file", mediaType: "file", + kind: mediaKindFile, maxSize: maxFileUploadSize, resultKey: "file_key", + }) + if err != nil { + t.Fatalf("resolveLocalMedia(file) error = %v", err) + } + if got != "file_via_resolve" { + t.Fatalf("resolveLocalMedia(file) = %q, want %q", got, "file_via_resolve") + } +} diff --git a/shortcuts/im/helpers_test.go b/shortcuts/im/helpers_test.go index bf9bf8705..f1813e9e2 100644 --- a/shortcuts/im/helpers_test.go +++ b/shortcuts/im/helpers_test.go @@ -4,7 +4,6 @@ package im import ( - "bytes" "context" "encoding/binary" "errors" @@ -13,7 +12,6 @@ import ( "strings" "testing" - "github.com/larksuite/cli/internal/cmdutil" "github.com/larksuite/cli/shortcuts/common" ) @@ -471,17 +469,11 @@ func TestNormalizeDownloadOutputPath(t *testing.T) { } func TestDownloadIMResourceToPathHTTPClientError(t *testing.T) { - runtime := &common.RuntimeContext{ - Factory: &cmdutil.Factory{ - HttpClient: func() (*http.Client, error) { - return nil, errors.New("http client unavailable") - }, - IOStreams: &cmdutil.IOStreams{ - Out: &bytes.Buffer{}, - ErrOut: &bytes.Buffer{}, - }, - }, - } + // DoAPIStream now goes through APIClient, which requires a fully constructed Factory. + // When HttpClient returns an error, NewAPIClient fails, and getAPIClient propagates it. + runtime := newBotShortcutRuntime(t, shortcutRoundTripFunc(func(req *http.Request) (*http.Response, error) { + return nil, errors.New("http client unavailable") + })) _, _, err := downloadIMResourceToPath(context.Background(), runtime, "om_123", "img_123", "image", "out.bin") if err == nil || !strings.Contains(err.Error(), "http client unavailable") { diff --git a/shortcuts/im/im_messages_resources_download.go b/shortcuts/im/im_messages_resources_download.go index beeaacd80..0f29392a9 100644 --- a/shortcuts/im/im_messages_resources_download.go +++ b/shortcuts/im/im_messages_resources_download.go @@ -6,15 +6,15 @@ package im import ( "context" "fmt" - "io" "net/http" - "os" "path/filepath" "strings" "time" + "github.com/larksuite/cli/internal/client" "github.com/larksuite/cli/internal/output" "github.com/larksuite/cli/internal/validate" + "github.com/larksuite/cli/internal/vfs" "github.com/larksuite/cli/shortcuts/common" larkcore "github.com/larksuite/oapi-sdk-go/v3/core" ) @@ -150,21 +150,13 @@ func downloadIMResourceToPath(ctx context.Context, runtime *common.RuntimeContex "file_key": fileKey, }, QueryParams: query, - }, defaultIMResourceDownloadTimeout) + }, client.WithTimeout(defaultIMResourceDownloadTimeout)) if err != nil { return "", 0, err } defer downloadResp.Body.Close() - if downloadResp.StatusCode >= 400 { - body, _ := io.ReadAll(io.LimitReader(downloadResp.Body, 4096)) - if len(body) > 0 { - return "", 0, output.ErrNetwork("download failed: HTTP %d: %s", downloadResp.StatusCode, strings.TrimSpace(string(body))) - } - return "", 0, output.ErrNetwork("download failed: HTTP %d", downloadResp.StatusCode) - } - - if err := os.MkdirAll(filepath.Dir(safePath), 0700); err != nil { + if err := vfs.MkdirAll(filepath.Dir(safePath), 0700); err != nil { return "", 0, output.Errorf(output.ExitInternal, "api_error", "cannot create parent directory: %s", err) } diff --git a/shortcuts/im/im_messages_search_execute_test.go b/shortcuts/im/im_messages_search_execute_test.go index 6cccd1841..e11c41fcf 100644 --- a/shortcuts/im/im_messages_search_execute_test.go +++ b/shortcuts/im/im_messages_search_execute_test.go @@ -69,12 +69,6 @@ func TestImMessagesSearchExecuteAutoPaginationBatches(t *testing.T) { "page-all": true, }, shortcutRoundTripFunc(func(req *http.Request) (*http.Response, error) { switch { - case strings.Contains(req.URL.Path, "tenant_access_token"): - return shortcutJSONResponse(200, map[string]interface{}{ - "code": 0, - "tenant_access_token": "tenant-token", - "expire": 7200, - }), nil case strings.Contains(req.URL.Path, "/open-apis/im/v1/messages/search"): pageToken := req.URL.Query().Get("page_token") searchPageTokens = append(searchPageTokens, pageToken) @@ -167,12 +161,6 @@ func TestImMessagesSearchExecuteExplicitPageLimitWithoutPageAll(t *testing.T) { "page-limit": "2", }, nil, shortcutRoundTripFunc(func(req *http.Request) (*http.Response, error) { switch { - case strings.Contains(req.URL.Path, "tenant_access_token"): - return shortcutJSONResponse(200, map[string]interface{}{ - "code": 0, - "tenant_access_token": "tenant-token", - "expire": 7200, - }), nil case strings.Contains(req.URL.Path, "/open-apis/im/v1/messages/search"): searchCalls++ pageToken := req.URL.Query().Get("page_token") diff --git a/shortcuts/mail/draft/patch.go b/shortcuts/mail/draft/patch.go index e4f3be36d..1a245a940 100644 --- a/shortcuts/mail/draft/patch.go +++ b/shortcuts/mail/draft/patch.go @@ -6,11 +6,11 @@ package draft import ( "fmt" "mime" - "os" "path/filepath" "strings" "github.com/larksuite/cli/internal/validate" + "github.com/larksuite/cli/internal/vfs" "github.com/larksuite/cli/shortcuts/mail/filecheck" ) @@ -470,14 +470,14 @@ func addAttachment(snapshot *DraftSnapshot, path string) error { if err := checkBlockedExtension(filepath.Base(path)); err != nil { return err } - info, err := os.Stat(safePath) + info, err := vfs.Stat(safePath) if err != nil { return err } if err := checkSnapshotAttachmentLimit(snapshot, info.Size(), nil); err != nil { return err } - content, err := os.ReadFile(safePath) + content, err := vfs.ReadFile(safePath) if err != nil { return err } @@ -528,14 +528,14 @@ func addInline(snapshot *DraftSnapshot, path, cid, fileName, contentType string) if err != nil { return fmt.Errorf("inline image %q: %w", path, err) } - info, err := os.Stat(safePath) + info, err := vfs.Stat(safePath) if err != nil { return err } if err := checkSnapshotAttachmentLimit(snapshot, info.Size(), nil); err != nil { return err } - content, err := os.ReadFile(safePath) + content, err := vfs.ReadFile(safePath) if err != nil { return err } @@ -576,14 +576,14 @@ func replaceInline(snapshot *DraftSnapshot, partID, path, cid, fileName, content if err != nil { return fmt.Errorf("inline image %q: %w", path, err) } - info, err := os.Stat(safePath) + info, err := vfs.Stat(safePath) if err != nil { return err } if err := checkSnapshotAttachmentLimit(snapshot, info.Size(), part); err != nil { return err } - content, err := os.ReadFile(safePath) + content, err := vfs.ReadFile(safePath) if err != nil { return err } diff --git a/shortcuts/mail/emlbuilder/builder.go b/shortcuts/mail/emlbuilder/builder.go index 61f416d03..dd0ca9559 100644 --- a/shortcuts/mail/emlbuilder/builder.go +++ b/shortcuts/mail/emlbuilder/builder.go @@ -47,12 +47,12 @@ import ( "math/rand" "mime" "net/mail" - "os" "path/filepath" "strings" "time" "github.com/larksuite/cli/internal/validate" + "github.com/larksuite/cli/internal/vfs" "github.com/larksuite/cli/shortcuts/mail/filecheck" ) @@ -65,7 +65,7 @@ func readFile(path string) ([]byte, error) { if err != nil { return nil, fmt.Errorf("attachment %q: %w", path, err) } - return os.ReadFile(safePath) + return vfs.ReadFile(safePath) } // Builder constructs a Lark-compatible RFC 2822 EML message. diff --git a/shortcuts/mail/helpers.go b/shortcuts/mail/helpers.go index 193f4f46b..886293233 100644 --- a/shortcuts/mail/helpers.go +++ b/shortcuts/mail/helpers.go @@ -12,7 +12,6 @@ import ( "net/http" netmail "net/mail" "net/url" - "os" "path/filepath" "regexp" "strconv" @@ -21,6 +20,7 @@ import ( "github.com/larksuite/cli/internal/auth" "github.com/larksuite/cli/internal/output" "github.com/larksuite/cli/internal/validate" + "github.com/larksuite/cli/internal/vfs" "github.com/larksuite/cli/shortcuts/common" "github.com/larksuite/cli/shortcuts/mail/emlbuilder" ) @@ -1849,7 +1849,7 @@ func checkAttachmentSizeLimit(filePaths []string, extraBytes int64, extraCount . if err != nil { return fmt.Errorf("unsafe attachment path %s: %w", p, err) } - info, err := os.Stat(safePath) + info, err := vfs.Stat(safePath) if err != nil { return fmt.Errorf("failed to stat attachment %s: %w", p, err) } diff --git a/shortcuts/mail/mail_draft_edit.go b/shortcuts/mail/mail_draft_edit.go index 99061b8bc..44f0d5f0f 100644 --- a/shortcuts/mail/mail_draft_edit.go +++ b/shortcuts/mail/mail_draft_edit.go @@ -8,11 +8,11 @@ import ( "encoding/json" "fmt" "io" - "os" "strings" "github.com/larksuite/cli/internal/output" "github.com/larksuite/cli/internal/validate" + "github.com/larksuite/cli/internal/vfs" "github.com/larksuite/cli/shortcuts/common" draftpkg "github.com/larksuite/cli/shortcuts/mail/draft" ) @@ -270,7 +270,7 @@ func loadPatchFile(path string) (draftpkg.Patch, error) { if err != nil { return patch, fmt.Errorf("--patch-file %q: %w", path, err) } - data, err := os.ReadFile(safePath) + data, err := vfs.ReadFile(safePath) if err != nil { return patch, err } diff --git a/shortcuts/mail/mail_watch.go b/shortcuts/mail/mail_watch.go index c56994270..d06cb47fc 100644 --- a/shortcuts/mail/mail_watch.go +++ b/shortcuts/mail/mail_watch.go @@ -25,6 +25,7 @@ import ( "github.com/larksuite/cli/internal/core" "github.com/larksuite/cli/internal/output" "github.com/larksuite/cli/internal/validate" + "github.com/larksuite/cli/internal/vfs" "github.com/larksuite/cli/shortcuts/common" larkevent "github.com/larksuite/oapi-sdk-go/v3/event" @@ -179,7 +180,7 @@ var MailWatch = common.Shortcut{ outputDir := runtime.Str("output-dir") if outputDir != "" { if outputDir == "~" || strings.HasPrefix(outputDir, "~/") { - home, err := os.UserHomeDir() + home, err := vfs.UserHomeDir() if err != nil { return fmt.Errorf("cannot expand ~: %w", err) } @@ -200,7 +201,7 @@ var MailWatch = common.Shortcut{ // Resolve symlinks on the output directory so all writes use the real // filesystem path. This prevents a symlink from redirecting writes to // an unintended location (TOCTOU mitigation). - if err := os.MkdirAll(outputDir, 0700); err != nil { + if err := vfs.MkdirAll(outputDir, 0700); err != nil { return fmt.Errorf("cannot create output directory %q: %w", outputDir, err) } resolved, err := filepath.EvalSymlinks(outputDir) diff --git a/shortcuts/mail/mail_watch_test.go b/shortcuts/mail/mail_watch_test.go index 02476fbdf..4dcade091 100644 --- a/shortcuts/mail/mail_watch_test.go +++ b/shortcuts/mail/mail_watch_test.go @@ -96,8 +96,8 @@ func TestMailWatchDryRunDefaultMetadataFetchesMessage(t *testing.T) { if apis[0].URL != mailboxPath("me", "event", "subscribe") { t.Fatalf("unexpected url: %s", apis[0].URL) } - if apis[1].Method != "GET" || apis[1].URL != mailboxPath("me", "profile") { - t.Fatalf("unexpected profile api: %s %s", apis[1].Method, apis[1].URL) + if apis[1].URL != mailboxPath("me", "profile") { + t.Fatalf("unexpected profile url: %s", apis[1].URL) } if apis[2].URL != mailboxPath("me", "messages", "{message_id}") { t.Fatalf("unexpected fetch url: %s", apis[2].URL) diff --git a/shortcuts/minutes/minutes_download.go b/shortcuts/minutes/minutes_download.go index 9a8c55453..1c8a423f1 100644 --- a/shortcuts/minutes/minutes_download.go +++ b/shortcuts/minutes/minutes_download.go @@ -9,12 +9,13 @@ import ( "io" "mime" "net/http" - "os" "path/filepath" "regexp" "strings" "time" + "github.com/larksuite/cli/internal/vfs" + "github.com/larksuite/cli/internal/output" "github.com/larksuite/cli/internal/validate" "github.com/larksuite/cli/shortcuts/common" @@ -78,7 +79,7 @@ var MinutesDownload = common.Shortcut{ // Batch mode: --output must be a directory, not an existing file. if !single && outputPath != "" { - if fi, err := os.Stat(outputPath); err == nil && !fi.IsDir() { + if fi, err := vfs.Stat(outputPath); err == nil && !fi.IsDir() { return output.ErrValidation("--output %q is a file; batch mode expects a directory path", outputPath) } } @@ -281,7 +282,7 @@ func downloadMediaFile(ctx context.Context, client *http.Client, downloadURL, mi if err := common.EnsureWritableFile(safePath, opts.overwrite); err != nil { return nil, err } - if err := os.MkdirAll(filepath.Dir(safePath), 0700); err != nil { + if err := vfs.MkdirAll(filepath.Dir(safePath), 0700); err != nil { return nil, output.Errorf(output.ExitInternal, "api_error", "cannot create parent directory: %s", err) } diff --git a/shortcuts/minutes/minutes_download_test.go b/shortcuts/minutes/minutes_download_test.go index 5e6a4738b..8bf398dde 100644 --- a/shortcuts/minutes/minutes_download_test.go +++ b/shortcuts/minutes/minutes_download_test.go @@ -31,13 +31,6 @@ func warmTokenCache(t *testing.T) { t.Helper() warmOnce.Do(func() { f, _, _, reg := cmdutil.TestFactory(t, defaultConfig()) - reg.Register(&httpmock.Stub{ - URL: "/open-apis/auth/v3/tenant_access_token/internal", - Body: map[string]interface{}{ - "code": 0, "msg": "ok", - "tenant_access_token": "t-test-token", "expire": 7200, - }, - }) reg.Register(&httpmock.Stub{ URL: "/open-apis/test/v1/warm", Body: map[string]interface{}{"code": 0, "msg": "ok", "data": map[string]interface{}{}}, diff --git a/shortcuts/sheets/sheet_export.go b/shortcuts/sheets/sheet_export.go index 70aadb170..2216be6e2 100644 --- a/shortcuts/sheets/sheet_export.go +++ b/shortcuts/sheets/sheet_export.go @@ -7,7 +7,6 @@ import ( "context" "fmt" "net/http" - "os" "path/filepath" "time" @@ -15,6 +14,7 @@ import ( "github.com/larksuite/cli/internal/output" "github.com/larksuite/cli/internal/validate" + "github.com/larksuite/cli/internal/vfs" "github.com/larksuite/cli/shortcuts/common" ) @@ -120,26 +120,31 @@ var SheetExport = common.Shortcut{ } // Download - apiResp, err := runtime.DoAPI(&larkcore.ApiReq{ + resp, err := runtime.DoAPIStream(ctx, &larkcore.ApiReq{ HttpMethod: http.MethodGet, ApiPath: fmt.Sprintf("/open-apis/drive/v1/export_tasks/file/%s/download", validate.EncodePathSegment(fileToken)), - }, larkcore.WithFileDownload()) + }) if err != nil { return output.ErrNetwork("download failed: %s", err) } + defer resp.Body.Close() safePath, pathErr := validate.SafeOutputPath(outputPath) if pathErr != nil { return output.ErrValidation("unsafe output path: %s", pathErr) } - os.MkdirAll(filepath.Dir(safePath), 0755) - if err := validate.AtomicWrite(safePath, apiResp.RawBody, 0644); err != nil { + if err := vfs.MkdirAll(filepath.Dir(safePath), 0700); err != nil { + return output.Errorf(output.ExitInternal, "api_error", "cannot create parent directory: %s", err) + } + + sizeBytes, err := validate.AtomicWriteFromReader(safePath, resp.Body, 0600) + if err != nil { return output.Errorf(output.ExitInternal, "api_error", "cannot create file: %s", err) } runtime.Out(map[string]interface{}{ "saved_path": safePath, - "size_bytes": len(apiResp.RawBody), + "size_bytes": sizeBytes, }, nil) return nil }, diff --git a/shortcuts/task/task_shortcut_test.go b/shortcuts/task/task_shortcut_test.go index f70cde485..4d8d7465d 100644 --- a/shortcuts/task/task_shortcut_test.go +++ b/shortcuts/task/task_shortcut_test.go @@ -31,15 +31,6 @@ func taskTestConfig(t *testing.T) *core.CliConfig { func warmTenantToken(t *testing.T, f *cmdutil.Factory, reg *httpmock.Registry) { t.Helper() - reg.Register(&httpmock.Stub{ - Method: "POST", - URL: "/open-apis/auth/v3/tenant_access_token/internal", - Body: map[string]interface{}{ - "code": 0, "msg": "ok", - "tenant_access_token": "t-test-token", - "expire": 7200, - }, - }) reg.Register(&httpmock.Stub{ Method: "GET", URL: "/open-apis/test/v1/warm", diff --git a/shortcuts/vc/vc_notes.go b/shortcuts/vc/vc_notes.go index 4aea4846b..ffd527102 100644 --- a/shortcuts/vc/vc_notes.go +++ b/shortcuts/vc/vc_notes.go @@ -16,7 +16,6 @@ import ( "fmt" "io" "net/http" - "os" "path/filepath" "strings" "time" @@ -24,8 +23,10 @@ import ( larkcore "github.com/larksuite/oapi-sdk-go/v3/core" "github.com/larksuite/cli/internal/auth" + "github.com/larksuite/cli/internal/credential" "github.com/larksuite/cli/internal/output" "github.com/larksuite/cli/internal/validate" + "github.com/larksuite/cli/internal/vfs" "github.com/larksuite/cli/shortcuts/common" ) @@ -253,7 +254,7 @@ func downloadTranscriptFile(runtime *common.RuntimeContext, minuteToken string, dirName := filepath.Join(base, sanitizeDirName(title, minuteToken)) if !runtime.Bool("overwrite") { transcriptPath := filepath.Join(dirName, "transcript.txt") - if _, statErr := os.Stat(transcriptPath); statErr == nil { + if _, statErr := vfs.Stat(transcriptPath); statErr == nil { fmt.Fprintf(errOut, "%s transcript already exists: %s (use --overwrite to replace)\n", logPrefix, transcriptPath) return transcriptPath } @@ -265,7 +266,7 @@ func downloadTranscriptFile(runtime *common.RuntimeContext, minuteToken string, fmt.Fprintf(errOut, "%s invalid transcript path: %v\n", logPrefix, err) return "" } - if err := os.MkdirAll(filepath.Dir(safePath), 0755); err != nil { + if err := vfs.MkdirAll(filepath.Dir(safePath), 0755); err != nil { fmt.Fprintf(errOut, "%s failed to create directory: %v\n", logPrefix, err) return "" } @@ -443,16 +444,12 @@ var VCNotes = common.Shortcut{ default: // unreachable: ExactlyOne already ensures one flag is set } - appID := runtime.Config.AppID - userOpenId := runtime.UserOpenId() - if appID != "" && userOpenId != "" { - stored := auth.GetStoredToken(appID, userOpenId) - if stored != nil { - if missing := auth.MissingScopes(stored.Scope, required); len(missing) > 0 { - return output.ErrWithHint(output.ExitAuth, "missing_scope", - fmt.Sprintf("missing required scope(s): %s", strings.Join(missing, ", ")), - fmt.Sprintf("run `lark-cli auth login --scope \"%s\"` in the background. It blocks and outputs a verification URL — retrieve the URL and open it in a browser to complete login.", strings.Join(missing, " "))) - } + result, err := runtime.Factory.Credential.ResolveToken(ctx, credential.NewTokenSpec(runtime.As(), runtime.Config.AppID)) + if err == nil && result != nil && result.Scopes != "" { + if missing := auth.MissingScopes(result.Scopes, required); len(missing) > 0 { + return output.ErrWithHint(output.ExitAuth, "missing_scope", + fmt.Sprintf("missing required scope(s): %s", strings.Join(missing, ", ")), + fmt.Sprintf("run `lark-cli auth login --scope \"%s\"` in the background. It blocks and outputs a verification URL — retrieve the URL and open it in a browser to complete login.", strings.Join(missing, " "))) } } return nil diff --git a/shortcuts/vc/vc_notes_test.go b/shortcuts/vc/vc_notes_test.go index e92ffe1ed..c202813e4 100644 --- a/shortcuts/vc/vc_notes_test.go +++ b/shortcuts/vc/vc_notes_test.go @@ -30,13 +30,6 @@ func warmTokenCache(t *testing.T) { t.Helper() warmOnce.Do(func() { f, _, _, reg := cmdutil.TestFactory(t, defaultConfig()) - reg.Register(&httpmock.Stub{ - URL: "/open-apis/auth/v3/tenant_access_token/internal", - Body: map[string]interface{}{ - "code": 0, "msg": "ok", - "tenant_access_token": "t-test-token", "expire": 7200, - }, - }) reg.Register(&httpmock.Stub{ URL: "/open-apis/test/v1/warm", Body: map[string]interface{}{"code": 0, "msg": "ok", "data": map[string]interface{}{}},