diff --git a/cmd/auth/login.go b/cmd/auth/login.go index ebd744f28..092e5020a 100644 --- a/cmd/auth/login.go +++ b/cmd/auth/login.go @@ -7,6 +7,7 @@ import ( "context" "encoding/json" "fmt" + "io" "sort" "strings" "time" @@ -90,7 +91,7 @@ func completeDomain(toComplete string) []string { return completions } -// authLoginRun executes the login command logic. +// authLoginRun executes the interactive or direct auth login flow. func authLoginRun(opts *LoginOptions) error { f := opts.Factory @@ -299,9 +300,11 @@ func authLoginRun(opts *LoginOptions) error { Scope: result.Token.Scope, GrantedAt: now, } - if err := larkauth.SetStoredToken(storedToken); err != nil { + usedFallback, err := larkauth.SetStoredToken(storedToken) + if err != nil { return output.Errorf(output.ExitInternal, "internal", "failed to save token: %v", err) } + warnIfEncryptedTokenFallback(f.IOStreams.ErrOut, usedFallback) // Step 8: Update config — overwrite Users to single user, clean old tokens multi, _ := core.LoadMultiAppConfig() @@ -379,9 +382,11 @@ func authLoginPollDeviceCode(opts *LoginOptions, config *core.CliConfig, msg *lo Scope: result.Token.Scope, GrantedAt: now, } - if err := larkauth.SetStoredToken(storedToken); err != nil { + usedFallback, err := larkauth.SetStoredToken(storedToken) + if err != nil { return output.Errorf(output.ExitInternal, "internal", "failed to save token: %v", err) } + warnIfEncryptedTokenFallback(f.IOStreams.ErrOut, usedFallback) // Update config — overwrite Users to single user, clean old tokens multi, _ := core.LoadMultiAppConfig() @@ -402,6 +407,14 @@ func authLoginPollDeviceCode(opts *LoginOptions, config *core.CliConfig, msg *lo return nil } +// warnIfEncryptedTokenFallback explains when auth token persistence downgraded to the encrypted file fallback. +func warnIfEncryptedTokenFallback(w io.Writer, usedFallback bool) { + if !usedFallback { + return + } + fmt.Fprintln(w, "warning: keychain unavailable, auth token stored in a local file protected by filesystem permissions (0600) managed by lark-cli") +} + // collectScopesForDomains collects API scopes (from from_meta projects) and // shortcut scopes for the given domain names. func collectScopesForDomains(domains []string, identity string) []string { diff --git a/cmd/auth/login_test.go b/cmd/auth/login_test.go index 06f9717e8..c08ea28fc 100644 --- a/cmd/auth/login_test.go +++ b/cmd/auth/login_test.go @@ -4,6 +4,7 @@ package auth import ( + "bytes" "context" "sort" "strings" @@ -15,6 +16,7 @@ import ( "github.com/larksuite/cli/shortcuts/common" ) +// TestSuggestDomain_PrefixMatch verifies prefix matching returns the expected shortcut domain. func TestSuggestDomain_PrefixMatch(t *testing.T) { known := map[string]bool{ "calendar": true, @@ -34,6 +36,7 @@ func TestSuggestDomain_PrefixMatch(t *testing.T) { } } +// TestSuggestDomain_NoMatch verifies unknown prefixes do not produce suggestions. func TestSuggestDomain_NoMatch(t *testing.T) { known := map[string]bool{ "calendar": true, @@ -45,6 +48,7 @@ func TestSuggestDomain_NoMatch(t *testing.T) { } } +// TestSuggestDomain_ExactMatch verifies exact matches resolve without modification. func TestSuggestDomain_ExactMatch(t *testing.T) { known := map[string]bool{ "calendar": true, @@ -56,6 +60,7 @@ func TestSuggestDomain_ExactMatch(t *testing.T) { } } +// TestShortcutSupportsIdentity_DefaultUser verifies shortcuts default to user identity when unspecified. func TestShortcutSupportsIdentity_DefaultUser(t *testing.T) { // Empty AuthTypes defaults to ["user"] sc := common.Shortcut{AuthTypes: nil} @@ -67,6 +72,7 @@ func TestShortcutSupportsIdentity_DefaultUser(t *testing.T) { } } +// TestShortcutSupportsIdentity_ExplicitTypes verifies shortcuts honor explicit identity declarations. func TestShortcutSupportsIdentity_ExplicitTypes(t *testing.T) { sc := common.Shortcut{AuthTypes: []string{"user", "bot"}} if !shortcutSupportsIdentity(sc, "user") { @@ -80,6 +86,7 @@ func TestShortcutSupportsIdentity_ExplicitTypes(t *testing.T) { } } +// TestShortcutSupportsIdentity_BotOnly verifies user identity is rejected for bot-only shortcuts. func TestShortcutSupportsIdentity_BotOnly(t *testing.T) { sc := common.Shortcut{AuthTypes: []string{"bot"}} if shortcutSupportsIdentity(sc, "user") { @@ -90,6 +97,7 @@ func TestShortcutSupportsIdentity_BotOnly(t *testing.T) { } } +// TestCompleteDomain verifies shell completion returns matching domain candidates. func TestCompleteDomain(t *testing.T) { projects := registry.ListFromMetaProjects() if len(projects) == 0 { @@ -115,6 +123,7 @@ func TestCompleteDomain(t *testing.T) { } } +// TestCompleteDomain_CommaSeparated verifies completion uses the last comma-separated domain fragment. func TestCompleteDomain_CommaSeparated(t *testing.T) { projects := registry.ListFromMetaProjects() if len(projects) == 0 { @@ -130,6 +139,7 @@ func TestCompleteDomain_CommaSeparated(t *testing.T) { } } +// TestAllKnownDomains verifies the known-domain list includes registry and shortcut-only domains. func TestAllKnownDomains(t *testing.T) { domains := allKnownDomains() if len(domains) == 0 { @@ -144,6 +154,7 @@ func TestAllKnownDomains(t *testing.T) { } } +// TestSortedKnownDomains verifies known domains are returned in sorted order. func TestSortedKnownDomains(t *testing.T) { sorted := sortedKnownDomains() if len(sorted) == 0 { @@ -161,6 +172,7 @@ func TestSortedKnownDomains(t *testing.T) { } } +// TestGetShortcutOnlyDomainNames_HaveDescriptions verifies shortcut-only domains retain descriptions. func TestGetShortcutOnlyDomainNames_HaveDescriptions(t *testing.T) { for _, name := range getShortcutOnlyDomainNames() { zhDesc := registry.GetServiceDescription(name, "zh") @@ -174,6 +186,7 @@ func TestGetShortcutOnlyDomainNames_HaveDescriptions(t *testing.T) { } } +// TestCollectScopesForDomains verifies domain selection expands to the expected scopes. func TestCollectScopesForDomains(t *testing.T) { projects := registry.ListFromMetaProjects() if len(projects) == 0 { @@ -206,6 +219,7 @@ func TestCollectScopesForDomains(t *testing.T) { } } +// TestCollectScopesForDomains_NonexistentDomain verifies unknown domains are ignored safely. func TestCollectScopesForDomains_NonexistentDomain(t *testing.T) { scopes := collectScopesForDomains([]string{"nonexistent_domain_xyz"}, "user") if len(scopes) != 0 { @@ -213,6 +227,7 @@ func TestCollectScopesForDomains_NonexistentDomain(t *testing.T) { } } +// TestGetDomainMetadata_IncludesFromMeta verifies registry-backed domains appear in metadata output. func TestGetDomainMetadata_IncludesFromMeta(t *testing.T) { domains := getDomainMetadata("zh") nameSet := make(map[string]bool) @@ -228,6 +243,7 @@ func TestGetDomainMetadata_IncludesFromMeta(t *testing.T) { } } +// TestGetDomainMetadata_IncludesShortcutOnlyDomains verifies shortcut-only domains appear in metadata output. func TestGetDomainMetadata_IncludesShortcutOnlyDomains(t *testing.T) { domains := getDomainMetadata("zh") nameSet := make(map[string]bool) @@ -242,6 +258,7 @@ func TestGetDomainMetadata_IncludesShortcutOnlyDomains(t *testing.T) { } } +// TestGetDomainMetadata_Sorted verifies domain metadata is sorted predictably. func TestGetDomainMetadata_Sorted(t *testing.T) { domains := getDomainMetadata("zh") for i := 1; i < len(domains); i++ { @@ -251,6 +268,7 @@ func TestGetDomainMetadata_Sorted(t *testing.T) { } } +// TestGetDomainMetadata_HasTitleAndDescription verifies metadata entries are populated with user-facing fields. func TestGetDomainMetadata_HasTitleAndDescription(t *testing.T) { domains := getDomainMetadata("zh") for _, dm := range domains { @@ -260,6 +278,19 @@ func TestGetDomainMetadata_HasTitleAndDescription(t *testing.T) { } } +// TestWarnIfEncryptedTokenFallback verifies the fallback warning explains local-file protection. +func TestWarnIfEncryptedTokenFallback(t *testing.T) { + var stderr bytes.Buffer + + warnIfEncryptedTokenFallback(&stderr, true) + + got := stderr.String() + if !strings.Contains(got, "filesystem permissions") { + t.Fatalf("expected warning to explain filesystem-permission protection, got %q", got) + } +} + +// TestAuthLoginRun_NonTerminal_NoFlags_RejectsWithHint verifies login rejects non-interactive runs without the required flags. func TestAuthLoginRun_NonTerminal_NoFlags_RejectsWithHint(t *testing.T) { f, _, stderr, _ := cmdutil.TestFactory(t, &core.CliConfig{ AppID: "cli_test", AppSecret: "secret", Brand: core.BrandFeishu, @@ -282,6 +313,7 @@ func TestAuthLoginRun_NonTerminal_NoFlags_RejectsWithHint(t *testing.T) { } } +// TestGetDomainMetadata_ExcludesEvent verifies event is hidden from the interactive login domain list. func TestGetDomainMetadata_ExcludesEvent(t *testing.T) { domains := getDomainMetadata("zh") for _, dm := range domains { diff --git a/cmd/config/init.go b/cmd/config/init.go index 8ddff7613..e216ede74 100644 --- a/cmd/config/init.go +++ b/cmd/config/init.go @@ -94,6 +94,69 @@ func saveAsOnlyApp(appId string, secret core.SecretInput, brand core.LarkBrand, return core.SaveMultiAppConfig(config) } +// warnIfEncryptedSecretFallback explains when app-secret persistence downgraded to the encrypted file fallback. +func warnIfEncryptedSecretFallback(w io.Writer, secret core.SecretInput, original core.SecretInput) { + if !original.IsPlain() || !secret.IsSecretRef() || secret.Ref.Source != "encrypted_file" { + return + } + fmt.Fprintln(w, "warning: keychain unavailable, app secret stored in a local file protected by filesystem permissions (0600) managed by lark-cli") +} + +// validateSecretReuse rejects managed secret references that no longer match the selected app ID. +func validateSecretReuse(appID string, secret core.SecretInput) error { + if !secret.IsSecretRef() { + return nil + } + if secret.Ref.Source == "file" { + return nil + } + expectedID := "appsecret:" + appID + if secret.Ref.ID == expectedID { + return nil + } + return output.ErrValidation("App Secret must be re-entered when App ID changes") +} + +// cleanupReplacedCurrentAppSecret removes the previous current-app secret when the backend changes in place. +func cleanupReplacedCurrentAppSecret(existing *core.MultiAppConfig, f *cmdutil.Factory, appID string, newSecret core.SecretInput) { + if existing == nil || len(existing.Apps) == 0 { + return + } + current := existing.Apps[0] + if current.AppId != appID || !current.AppSecret.IsSecretRef() || !newSecret.IsSecretRef() { + return + } + if current.AppSecret.Ref.Source == newSecret.Ref.Source && current.AppSecret.Ref.ID == newSecret.Ref.ID { + return + } + core.RemoveSecretStore(current.AppSecret, f.Keychain) +} + +// storeAndSaveOnlyApp persists a single-app config while keeping secret storage, rollback, and cleanup in sync. +func storeAndSaveOnlyApp(existing *core.MultiAppConfig, f *cmdutil.Factory, appID string, plainSecret core.SecretInput, brand core.LarkBrand, lang string) error { + // Keep the secret persistence pipeline centralized here. + // New config-init branches should call this helper instead of reimplementing + // store -> save -> cleanup -> fallback-warning sequencing inline. + if err := validateSecretReuse(appID, plainSecret); err != nil { + return err + } + secret, err := core.ForStorageWithEncryptedFallback(appID, plainSecret, f.Keychain) + if err != nil { + return output.Errorf(output.ExitInternal, "internal", "%v", err) + } + if err := saveAsOnlyApp(appID, secret, brand, lang); err != nil { + if plainSecret.IsPlain() { + core.RemoveSecretStore(secret, f.Keychain) + } + return output.Errorf(output.ExitInternal, "internal", "failed to save config: %v", err) + } + cleanupReplacedCurrentAppSecret(existing, f, appID, secret) + cleanupOldConfig(existing, f, appID) + warnIfEncryptedSecretFallback(f.IOStreams.ErrOut, secret, plainSecret) + return nil +} + +// configInitRun initializes or updates the local CLI app configuration. func configInitRun(opts *ConfigInitOptions) error { f := opts.Factory @@ -120,13 +183,9 @@ func configInitRun(opts *ConfigInitOptions) error { // Mode 1: Non-interactive if opts.AppID != "" && opts.appSecret != "" { brand := parseBrand(opts.Brand) - secret, err := core.ForStorage(opts.AppID, core.PlainSecret(opts.appSecret), f.Keychain) - if err != nil { - return output.Errorf(output.ExitInternal, "internal", "%v", err) - } - cleanupOldConfig(existing, f, opts.AppID) - if err := saveAsOnlyApp(opts.AppID, secret, brand, opts.Lang); err != nil { - return output.Errorf(output.ExitInternal, "internal", "failed to save config: %v", err) + plainSecret := core.PlainSecret(opts.appSecret) + if err := storeAndSaveOnlyApp(existing, f, opts.AppID, plainSecret, brand, opts.Lang); err != nil { + return err } output.PrintSuccess(f.IOStreams.ErrOut, fmt.Sprintf("Configuration saved to %s", core.GetConfigPath())) output.PrintJson(f.IOStreams.Out, map[string]interface{}{"appId": opts.AppID, "appSecret": "****", "brand": brand}) @@ -161,13 +220,9 @@ func configInitRun(opts *ConfigInitOptions) error { return output.ErrValidation("app creation returned no result") } existing, _ := core.LoadMultiAppConfig() - secret, err := core.ForStorage(result.AppID, core.PlainSecret(result.AppSecret), f.Keychain) - if err != nil { - return output.Errorf(output.ExitInternal, "internal", "%v", err) - } - cleanupOldConfig(existing, f, result.AppID) - if err := saveAsOnlyApp(result.AppID, secret, result.Brand, opts.Lang); err != nil { - return output.Errorf(output.ExitInternal, "internal", "failed to save config: %v", err) + plainSecret := core.PlainSecret(result.AppSecret) + if err := storeAndSaveOnlyApp(existing, f, result.AppID, plainSecret, result.Brand, opts.Lang); err != nil { + return err } output.PrintJson(f.IOStreams.Out, map[string]interface{}{"appId": result.AppID, "appSecret": "****", "brand": result.Brand}) return nil @@ -187,17 +242,16 @@ func configInitRun(opts *ConfigInitOptions) error { if result.AppSecret != "" { // New secret provided (either from "create" or "existing" with input) - secret, err := core.ForStorage(result.AppID, core.PlainSecret(result.AppSecret), f.Keychain) - if err != nil { - return output.Errorf(output.ExitInternal, "internal", "%v", err) - } - cleanupOldConfig(existing, f, result.AppID) - if err := saveAsOnlyApp(result.AppID, secret, result.Brand, opts.Lang); err != nil { - return output.Errorf(output.ExitInternal, "internal", "failed to save config: %v", err) + plainSecret := core.PlainSecret(result.AppSecret) + if err := storeAndSaveOnlyApp(existing, f, result.AppID, plainSecret, result.Brand, opts.Lang); err != nil { + return err } } else if result.Mode == "existing" && result.AppID != "" { // Existing app with unchanged secret — update app ID and brand only if existing != nil && len(existing.Apps) > 0 { + if err := validateSecretReuse(result.AppID, existing.Apps[0].AppSecret); err != nil { + return err + } existing.Apps[0].AppId = result.AppID existing.Apps[0].Brand = result.Brand existing.Apps[0].Lang = opts.Lang @@ -292,13 +346,8 @@ func configInitRun(opts *ConfigInitOptions) error { return output.ErrValidation("App ID and App Secret cannot be empty") } - storedSecret, err := core.ForStorage(resolvedAppId, resolvedSecret, f.Keychain) - if err != nil { - return output.Errorf(output.ExitInternal, "internal", "%v", err) - } - cleanupOldConfig(existing, f, resolvedAppId) - if err := saveAsOnlyApp(resolvedAppId, storedSecret, parseBrand(resolvedBrand), opts.Lang); err != nil { - return output.Errorf(output.ExitInternal, "internal", "failed to save config: %v", err) + if err := storeAndSaveOnlyApp(existing, f, resolvedAppId, resolvedSecret, parseBrand(resolvedBrand), opts.Lang); err != nil { + return err } output.PrintSuccess(f.IOStreams.ErrOut, fmt.Sprintf("Configuration saved to %s", core.GetConfigPath())) return nil diff --git a/cmd/config/config_test.go b/cmd/config/init_command_test.go similarity index 84% rename from cmd/config/config_test.go rename to cmd/config/init_command_test.go index 65642781f..a3fc9ef1f 100644 --- a/cmd/config/config_test.go +++ b/cmd/config/init_command_test.go @@ -3,6 +3,9 @@ package config +// Command/flag/wiring tests for config init live here. +// New storage/fallback/rollback behavior tests should go to init_storage_test.go. + import ( "context" "strings" @@ -12,6 +15,7 @@ import ( "github.com/larksuite/cli/internal/core" ) +// TestConfigInitCmd_FlagParsing verifies config init binds the expected flags. func TestConfigInitCmd_FlagParsing(t *testing.T) { f, _, _, _ := cmdutil.TestFactory(t, nil) f.IOStreams.In = strings.NewReader("secret123\n") @@ -37,6 +41,7 @@ func TestConfigInitCmd_FlagParsing(t *testing.T) { } } +// TestConfigShowCmd_FlagParsing verifies config show exposes its expected flags. func TestConfigShowCmd_FlagParsing(t *testing.T) { f, _, _, _ := cmdutil.TestFactory(t, &core.CliConfig{ AppID: "test-app", AppSecret: "test-secret", Brand: core.BrandFeishu, @@ -56,6 +61,7 @@ func TestConfigShowCmd_FlagParsing(t *testing.T) { } } +// TestConfigInitCmd_LangFlag verifies the lang flag is wired through config init. func TestConfigInitCmd_LangFlag(t *testing.T) { f, _, _, _ := cmdutil.TestFactory(t, nil) @@ -77,6 +83,7 @@ func TestConfigInitCmd_LangFlag(t *testing.T) { } } +// TestConfigInitCmd_LangDefault verifies config init keeps the default language when unspecified. func TestConfigInitCmd_LangDefault(t *testing.T) { f, _, _, _ := cmdutil.TestFactory(t, nil) @@ -98,6 +105,7 @@ func TestConfigInitCmd_LangDefault(t *testing.T) { } } +// TestHasAnyNonInteractiveFlag verifies non-interactive detection across supported flags. func TestHasAnyNonInteractiveFlag(t *testing.T) { tests := []struct { name string @@ -121,9 +129,9 @@ func TestHasAnyNonInteractiveFlag(t *testing.T) { } } +// TestConfigInitRun_NonTerminal_NoFlags_RejectsWithHint verifies non-interactive invocation fails with guidance. func TestConfigInitRun_NonTerminal_NoFlags_RejectsWithHint(t *testing.T) { f, _, _, _ := cmdutil.TestFactory(t, nil) - // TestFactory has IsTerminal=false by default opts := &ConfigInitOptions{Factory: f, Ctx: context.Background(), Lang: "zh"} err := configInitRun(opts) if err == nil { @@ -138,6 +146,7 @@ func TestConfigInitRun_NonTerminal_NoFlags_RejectsWithHint(t *testing.T) { } } +// TestConfigRemoveCmd_FlagParsing verifies config remove exposes its expected flags. func TestConfigRemoveCmd_FlagParsing(t *testing.T) { f, _, _, _ := cmdutil.TestFactory(t, nil) diff --git a/cmd/config/init_storage_test.go b/cmd/config/init_storage_test.go new file mode 100644 index 000000000..17c6a10cb --- /dev/null +++ b/cmd/config/init_storage_test.go @@ -0,0 +1,301 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package config + +// Storage/fallback/rollback behavior tests for config init live here. +// New command/flag/wiring tests should go to init_command_test.go. + +import ( + "context" + "errors" + "os" + "strings" + "testing" + + "github.com/larksuite/cli/internal/cmdutil" + "github.com/larksuite/cli/internal/core" + "github.com/larksuite/cli/internal/keychain" +) + +// unavailableSetKeychain simulates a keychain that only fails on writes with an unavailable error. +type unavailableSetKeychain struct{} + +// Get satisfies the keychain interface for read paths used by these tests. +func (f *unavailableSetKeychain) Get(service, account string) (string, error) { return "", nil } + +// Set satisfies the keychain interface and forces encrypted fallback in storage tests. +func (f *unavailableSetKeychain) Set(service, account, value string) error { + return keychain.WrapUnavailable(errors.New("sandbox denied")) +} + +// Remove satisfies the keychain interface for cleanup paths used by these tests. +func (f *unavailableSetKeychain) Remove(service, account string) error { return nil } + +// trackingKeychain records remove calls and optionally injects write behavior for config-init tests. +type trackingKeychain struct { + setFunc func(service, account, value string) error + removeCalls []string +} + +// Get satisfies the keychain interface for read paths used by these tests. +func (t *trackingKeychain) Get(service, account string) (string, error) { return "", nil } + +// Set satisfies the keychain interface and delegates to the configured test hook. +func (t *trackingKeychain) Set(service, account, value string) error { + if t.setFunc != nil { + return t.setFunc(service, account, value) + } + return nil +} + +// Remove satisfies the keychain interface while recording the removed account IDs. +func (t *trackingKeychain) Remove(service, account string) error { + t.removeCalls = append(t.removeCalls, account) + return nil +} + +// TestConfigInitRun_FallsBackToEncryptedSecretWhenKeychainUnavailable verifies config init persists app secrets via encrypted fallback on keychain failure. +func TestConfigInitRun_FallsBackToEncryptedSecretWhenKeychainUnavailable(t *testing.T) { + configDir := t.TempDir() + t.Setenv("LARKSUITE_CLI_CONFIG_DIR", configDir) + + f, _, stderr, _ := cmdutil.TestFactory(t, nil) + f.Keychain = &unavailableSetKeychain{} + + opts := &ConfigInitOptions{ + Factory: f, + Ctx: context.Background(), + AppID: "cli_test", + appSecret: "secret123", + Brand: "feishu", + Lang: "zh", + } + + if err := configInitRun(opts); err != nil { + t.Fatalf("configInitRun returned error: %v", err) + } + + cfg, err := core.LoadMultiAppConfig() + if err != nil { + t.Fatalf("LoadMultiAppConfig: %v", err) + } + if len(cfg.Apps) != 1 { + t.Fatalf("expected 1 app, got %d", len(cfg.Apps)) + } + ref := cfg.Apps[0].AppSecret.Ref + if ref == nil { + t.Fatal("expected app secret to be stored as an encrypted fallback reference") + } + if ref.Source != "encrypted_file" { + t.Fatalf("expected encrypted_file secret, got %q", ref.Source) + } + resolved, err := core.ResolveSecretInput(cfg.Apps[0].AppSecret, f.Keychain) + if err != nil { + t.Fatalf("ResolveSecretInput: %v", err) + } + if resolved != "secret123" { + t.Fatalf("resolved secret = %q, want %q", resolved, "secret123") + } + if got := stderr.String(); got == "" || !strings.Contains(got, "filesystem permissions") { + t.Fatalf("expected fallback warning in stderr, got %q", got) + } +} + +// TestConfigRemoveRun_RemovesEncryptedFallbackSecret verifies config removal cleans up encrypted fallback secrets. +func TestConfigRemoveRun_RemovesEncryptedFallbackSecret(t *testing.T) { + t.Setenv("LARKSUITE_CLI_CONFIG_DIR", t.TempDir()) + + if err := keychain.SetFallback(keychain.LarkCliService, "appsecret:cli_test", "secret123"); err != nil { + t.Fatalf("SetFallback: %v", err) + } + + config := &core.MultiAppConfig{ + Apps: []core.AppConfig{{ + AppId: "cli_test", + AppSecret: core.SecretInput{ + Ref: &core.SecretRef{Source: "encrypted_file", ID: "appsecret:cli_test"}, + }, + Brand: core.BrandFeishu, + Users: []core.AppUser{}, + }}, + } + if err := core.SaveMultiAppConfig(config); err != nil { + t.Fatalf("SaveMultiAppConfig: %v", err) + } + + f, _, _, _ := cmdutil.TestFactory(t, nil) + opts := &ConfigRemoveOptions{Factory: f} + + if err := configRemoveRun(opts); err != nil { + t.Fatalf("configRemoveRun returned error: %v", err) + } + if got := keychain.GetFallback(keychain.LarkCliService, "appsecret:cli_test"); got != "" { + t.Fatalf("expected encrypted fallback secret to be removed, got %q", got) + } +} + +// TestConfigInitRun_SaveFailureDoesNotCleanupExistingSecrets verifies save failures do not delete the currently active secret. +func TestConfigInitRun_SaveFailureDoesNotCleanupExistingSecrets(t *testing.T) { + configDir := t.TempDir() + t.Setenv("LARKSUITE_CLI_CONFIG_DIR", configDir) + + existing := &core.MultiAppConfig{ + Apps: []core.AppConfig{{ + AppId: "old-app", + AppSecret: core.SecretInput{Ref: &core.SecretRef{Source: "keychain", ID: "appsecret:old-app"}}, + Brand: core.BrandFeishu, + Users: []core.AppUser{}, + }}, + } + if err := core.SaveMultiAppConfig(existing); err != nil { + t.Fatalf("SaveMultiAppConfig: %v", err) + } + + kc := &trackingKeychain{ + setFunc: func(service, account, value string) error { + return os.Chmod(configDir, 0500) + }, + } + t.Cleanup(func() { _ = os.Chmod(configDir, 0700) }) + + f, _, _, _ := cmdutil.TestFactory(t, nil) + f.Keychain = kc + + opts := &ConfigInitOptions{ + Factory: f, + Ctx: context.Background(), + AppID: "new-app", + appSecret: "secret123", + Brand: "feishu", + Lang: "zh", + } + + err := configInitRun(opts) + if err == nil { + t.Fatal("expected configInitRun to fail when config save fails") + } + + if len(kc.removeCalls) != 1 || kc.removeCalls[0] != "appsecret:new-app" { + t.Fatalf("expected only newly stored secret to be rolled back, got remove calls %v", kc.removeCalls) + } + + cfg, loadErr := core.LoadMultiAppConfig() + if loadErr != nil { + t.Fatalf("LoadMultiAppConfig: %v", loadErr) + } + if len(cfg.Apps) != 1 || cfg.Apps[0].AppId != "old-app" { + t.Fatalf("expected existing config to stay unchanged, got %#v", cfg.Apps) + } +} + +// TestStoreAndSaveOnlyApp_RejectsSecretRefReuseAcrossAppIDChange verifies managed secret refs cannot be reused across app IDs. +func TestStoreAndSaveOnlyApp_RejectsSecretRefReuseAcrossAppIDChange(t *testing.T) { + t.Setenv("LARKSUITE_CLI_CONFIG_DIR", t.TempDir()) + + existing := &core.MultiAppConfig{ + Apps: []core.AppConfig{{ + AppId: "old-app", + AppSecret: core.SecretInput{Ref: &core.SecretRef{Source: "keychain", ID: "appsecret:old-app"}}, + Brand: core.BrandFeishu, + Lang: "zh", + Users: []core.AppUser{{ + UserOpenId: "ou_old_user", + UserName: "old user", + }}, + }}, + } + if err := core.SaveMultiAppConfig(existing); err != nil { + t.Fatalf("SaveMultiAppConfig: %v", err) + } + + kc := &trackingKeychain{} + f, _, _, _ := cmdutil.TestFactory(t, nil) + f.Keychain = kc + + err := storeAndSaveOnlyApp(existing, f, "new-app", existing.Apps[0].AppSecret, core.BrandFeishu, "zh") + if err == nil { + t.Fatal("expected reusing a secret ref with a different app id to fail") + } + + if len(kc.removeCalls) != 0 { + t.Fatalf("expected no secret cleanup on rejected app id change, got %v", kc.removeCalls) + } + + cfg, loadErr := core.LoadMultiAppConfig() + if loadErr != nil { + t.Fatalf("LoadMultiAppConfig: %v", loadErr) + } + if len(cfg.Apps) != 1 || cfg.Apps[0].AppId != "old-app" { + t.Fatalf("expected config to stay unchanged, got %#v", cfg.Apps) + } +} + +// TestValidateSecretReuse_RequiresNewSecretWhenAppIDChanges verifies app ID changes force a fresh managed secret. +func TestValidateSecretReuse_RequiresNewSecretWhenAppIDChanges(t *testing.T) { + err := validateSecretReuse("new-app", core.SecretInput{ + Ref: &core.SecretRef{Source: "keychain", ID: "appsecret:old-app"}, + }) + if err == nil { + t.Fatal("expected app id change with existing secret ref to be rejected") + } + + if err := validateSecretReuse("old-app", core.SecretInput{ + Ref: &core.SecretRef{Source: "keychain", ID: "appsecret:old-app"}, + }); err != nil { + t.Fatalf("expected same-app secret ref reuse to remain allowed, got %v", err) + } +} + +// TestValidateSecretReuse_AllowsFileSecretRefAcrossAppIDChange verifies external file refs remain reusable across app IDs. +func TestValidateSecretReuse_AllowsFileSecretRefAcrossAppIDChange(t *testing.T) { + err := validateSecretReuse("new-app", core.SecretInput{ + Ref: &core.SecretRef{Source: "file", ID: "/tmp/app-secret.txt"}, + }) + if err != nil { + t.Fatalf("expected file-based secret ref reuse to remain allowed, got %v", err) + } +} + +// TestStoreAndSaveOnlyApp_SameAppUpgradeToKeychainRemovesOldFallbackSecret verifies same-app backend upgrades clean up replaced fallback secrets. +func TestStoreAndSaveOnlyApp_SameAppUpgradeToKeychainRemovesOldFallbackSecret(t *testing.T) { + t.Setenv("LARKSUITE_CLI_CONFIG_DIR", t.TempDir()) + + if err := keychain.SetFallback(keychain.LarkCliService, "appsecret:cli_test", "old-secret"); err != nil { + t.Fatalf("SetFallback: %v", err) + } + + existing := &core.MultiAppConfig{ + Apps: []core.AppConfig{{ + AppId: "cli_test", + AppSecret: core.SecretInput{ + Ref: &core.SecretRef{Source: "encrypted_file", ID: "appsecret:cli_test"}, + }, + Brand: core.BrandFeishu, + Lang: "zh", + Users: []core.AppUser{}, + }}, + } + if err := core.SaveMultiAppConfig(existing); err != nil { + t.Fatalf("SaveMultiAppConfig: %v", err) + } + + kc := &trackingKeychain{} + f, _, _, _ := cmdutil.TestFactory(t, nil) + f.Keychain = kc + + if err := storeAndSaveOnlyApp(existing, f, "cli_test", core.PlainSecret("new-secret"), core.BrandFeishu, "zh"); err != nil { + t.Fatalf("storeAndSaveOnlyApp: %v", err) + } + + cfg, err := core.LoadMultiAppConfig() + if err != nil { + t.Fatalf("LoadMultiAppConfig: %v", err) + } + if got := cfg.Apps[0].AppSecret.Ref; got == nil || got.Source != "keychain" { + t.Fatalf("expected config to switch to keychain ref, got %#v", got) + } + if fallback, err := keychain.GetFallbackWithError(keychain.LarkCliService, "appsecret:cli_test"); err == nil && fallback != "" { + t.Fatalf("expected old fallback secret to be removed, got %q", fallback) + } +} diff --git a/cmd/doctor/doctor.go b/cmd/doctor/doctor.go index 67d959894..eedbbcef8 100644 --- a/cmd/doctor/doctor.go +++ b/cmd/doctor/doctor.go @@ -49,27 +49,32 @@ func NewCmdDoctor(f *cmdutil.Factory) *cobra.Command { // checkResult represents one diagnostic check. type checkResult struct { Name string `json:"name"` - Status string `json:"status"` // "pass", "fail", "skip" + Status string `json:"status"` // "pass", "fail", "warn", "skip" Message string `json:"message"` Hint string `json:"hint,omitempty"` } +// pass builds a successful doctor check result. func pass(name, msg string) checkResult { return checkResult{Name: name, Status: "pass", Message: msg} } +// fail builds a failed doctor check result. func fail(name, msg, hint string) checkResult { return checkResult{Name: name, Status: "fail", Message: msg, Hint: hint} } +// warn builds a warning doctor check result. func warn(name, msg, hint string) checkResult { return checkResult{Name: name, Status: "warn", Message: msg, Hint: hint} } +// skip builds a skipped doctor check result. func skip(name, msg string) checkResult { return checkResult{Name: name, Status: "skip", Message: msg} } +// doctorRun executes local configuration, credential, and network diagnostics. func doctorRun(opts *DoctorOptions) error { f := opts.Factory var checks []checkResult @@ -111,7 +116,7 @@ func doctorRun(opts *DoctorOptions) error { } stored := larkauth.GetStoredToken(cfg.AppID, cfg.UserOpenId) if stored == nil { - checks = append(checks, fail("token_exists", "no token in keychain for "+cfg.UserOpenId, "run: lark-cli auth login --help")) + checks = append(checks, fail("token_exists", "no token in local credential store for "+cfg.UserOpenId, "run: lark-cli auth login --help")) checks = append(checks, networkChecks(opts.Ctx, opts, ep)...) return finishDoctor(f, checks) } @@ -243,6 +248,7 @@ func checkCLIUpdate() []checkResult { return []checkResult{pass("cli_update", latest+" (up to date)")} } +// finishDoctor renders the aggregated doctor result and maps failures to a non-zero exit status. func finishDoctor(f *cmdutil.Factory, checks []checkResult) error { allOK := true for _, c := range checks { diff --git a/internal/auth/token_store.go b/internal/auth/token_store.go index 80883a649..2da48e0b1 100644 --- a/internal/auth/token_store.go +++ b/internal/auth/token_store.go @@ -5,9 +5,13 @@ package auth import ( "encoding/json" + "errors" "fmt" + "os" + "path/filepath" "time" + "github.com/larksuite/cli/internal/configdir" "github.com/larksuite/cli/internal/keychain" ) @@ -25,10 +29,38 @@ type StoredUAToken struct { const refreshAheadMs = 5 * 60 * 1000 // 5 minutes +var tokenKeychain = keychain.Default() + +// accountKey builds the logical credential key for a user token. func accountKey(appId, userOpenId string) string { return fmt.Sprintf("%s:%s", appId, userOpenId) } +// tokenConfigDir returns the config directory used by token-store compatibility paths. +func tokenConfigDir() string { + // Keep config dir resolution centralized in internal/configdir. + // New code should reuse configdir.Get() instead of duplicating env/home logic. + return configdir.Get() +} + +// legacyManagedTokenFilePath returns the old plaintext managed-token file path kept for migration reads and cleanup. +func legacyManagedTokenFilePath(appId, userOpenId string) string { + return filepath.Join(tokenConfigDir(), "tokens", sanitizeID(accountKey(appId, userOpenId))+".json") +} + +// readLegacyManagedToken loads a token from the legacy plaintext fallback file if it still exists. +func readLegacyManagedToken(appId, userOpenId string) *StoredUAToken { + data, err := os.ReadFile(legacyManagedTokenFilePath(appId, userOpenId)) + if err != nil { + return nil + } + var token StoredUAToken + if err := json.Unmarshal(data, &token); err != nil { + return nil + } + return &token +} + // MaskToken masks a token for safe logging. func MaskToken(token string) string { if len(token) <= 8 { @@ -39,9 +71,19 @@ func MaskToken(token string) string { // GetStoredToken reads the stored UAT for a given (appId, userOpenId) pair. func GetStoredToken(appId, userOpenId string) *StoredUAToken { - jsonStr := keychain.Get(keychain.LarkCliService, accountKey(appId, userOpenId)) - if jsonStr == "" { - return nil + jsonStr, err := tokenKeychain.Get(keychain.LarkCliService, accountKey(appId, userOpenId)) + if err == nil && jsonStr != "" { + var token StoredUAToken + if err := json.Unmarshal([]byte(jsonStr), &token); err == nil { + return &token + } + } + jsonStr, err = keychain.GetFallbackWithError(keychain.LarkCliService, accountKey(appId, userOpenId)) + if err != nil { + if !os.IsNotExist(err) { + return nil + } + return readLegacyManagedToken(appId, userOpenId) } var token StoredUAToken if err := json.Unmarshal([]byte(jsonStr), &token); err != nil { @@ -51,18 +93,40 @@ func GetStoredToken(appId, userOpenId string) *StoredUAToken { } // SetStoredToken persists a UAT. -func SetStoredToken(token *StoredUAToken) error { +func SetStoredToken(token *StoredUAToken) (bool, error) { key := accountKey(token.AppId, token.UserOpenId) data, err := json.Marshal(token) if err != nil { - return err + return false, err } - return keychain.Set(keychain.LarkCliService, key, string(data)) + if err := tokenKeychain.Set(keychain.LarkCliService, key, string(data)); err == nil { + _ = keychain.RemoveFallback(keychain.LarkCliService, key) + _ = os.Remove(legacyManagedTokenFilePath(token.AppId, token.UserOpenId)) + return false, nil + } else if !keychain.ShouldUseFallback(err) { + return false, fmt.Errorf("store token in keychain: %w", err) + } + if err := keychain.SetFallback(keychain.LarkCliService, key, string(data)); err != nil { + return false, fmt.Errorf("store token encrypted fallback: %w", err) + } + _ = os.Remove(legacyManagedTokenFilePath(token.AppId, token.UserOpenId)) + return true, nil } // RemoveStoredToken removes a stored UAT. func RemoveStoredToken(appId, userOpenId string) error { - return keychain.Remove(keychain.LarkCliService, accountKey(appId, userOpenId)) + key := accountKey(appId, userOpenId) + var errs []error + if err := keychain.RemoveFallback(keychain.LarkCliService, key); err != nil { + errs = append(errs, err) + } + if err := tokenKeychain.Remove(keychain.LarkCliService, key); err != nil { + errs = append(errs, err) + } + if err := os.Remove(legacyManagedTokenFilePath(appId, userOpenId)); err != nil && !os.IsNotExist(err) { + errs = append(errs, err) + } + return errors.Join(errs...) } // TokenStatus determines the freshness of a stored token. diff --git a/internal/auth/token_store_test.go b/internal/auth/token_store_test.go new file mode 100644 index 000000000..82829ea2e --- /dev/null +++ b/internal/auth/token_store_test.go @@ -0,0 +1,302 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package auth + +import ( + "encoding/json" + "errors" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/larksuite/cli/internal/keychain" +) + +// failingTokenKeychain simulates unavailable keychain writes that should trigger encrypted fallback. +type failingTokenKeychain struct{} + +// Get satisfies the keychain interface for read paths used by these tests. +func (f *failingTokenKeychain) Get(service, account string) (string, error) { return "", nil } + +// Set satisfies the keychain interface and forces fallback-eligible write errors. +func (f *failingTokenKeychain) Set(service, account, value string) error { + return keychain.WrapUnavailable(errors.New("sandbox denied")) +} + +// Remove satisfies the keychain interface for cleanup paths used by these tests. +func (f *failingTokenKeychain) Remove(service, account string) error { return nil } + +// genericFailingTokenKeychain simulates non-fallback-eligible keychain write failures. +type genericFailingTokenKeychain struct{} + +// Get satisfies the keychain interface for read paths used by these tests. +func (f *genericFailingTokenKeychain) Get(service, account string) (string, error) { return "", nil } + +// Set satisfies the keychain interface and returns a generic write failure. +func (f *genericFailingTokenKeychain) Set(service, account, value string) error { + return errors.New("boom") +} + +// Remove satisfies the keychain interface for cleanup paths used by these tests. +func (f *genericFailingTokenKeychain) Remove(service, account string) error { return nil } + +// removeFailingTokenKeychain simulates keychain delete failures during cleanup. +type removeFailingTokenKeychain struct{} + +// Get satisfies the keychain interface for read paths used by these tests. +func (f *removeFailingTokenKeychain) Get(service, account string) (string, error) { return "", nil } + +// Set satisfies the keychain interface for write paths used by these tests. +func (f *removeFailingTokenKeychain) Set(service, account, value string) error { return nil } + +// Remove satisfies the keychain interface and forces delete failures. +func (f *removeFailingTokenKeychain) Remove(service, account string) error { + return errors.New("remove failed") +} + +// TestSetStoredToken_FallsBackToManagedFileWhenKeychainUnavailable verifies encrypted fallback is used when keychain writes are unavailable. +func TestSetStoredToken_FallsBackToManagedFileWhenKeychainUnavailable(t *testing.T) { + configDir := t.TempDir() + t.Setenv("LARKSUITE_CLI_CONFIG_DIR", configDir) + + prev := tokenKeychain + tokenKeychain = &failingTokenKeychain{} + t.Cleanup(func() { tokenKeychain = prev }) + + token := &StoredUAToken{ + UserOpenId: "ou_test_user", + AppId: "cli_test", + AccessToken: "access-token", + RefreshToken: "refresh-token", + ExpiresAt: 1710000000000, + RefreshExpiresAt: 1710003600000, + Scope: "offline_access", + GrantedAt: 1709996400000, + } + + usedFallback, err := SetStoredToken(token) + if err != nil { + t.Fatalf("SetStoredToken returned error: %v", err) + } + if !usedFallback { + t.Fatal("expected SetStoredToken to report fallback usage") + } + + if fallback, err := keychain.GetFallbackWithError(keychain.LarkCliService, accountKey(token.AppId, token.UserOpenId)); err != nil || fallback == "" { + t.Fatal("expected encrypted fallback token to be stored") + } + + stored := GetStoredToken(token.AppId, token.UserOpenId) + if stored == nil { + t.Fatal("expected GetStoredToken to read file-backed token") + } + if stored.AccessToken != token.AccessToken { + t.Fatalf("stored access token = %q, want %q", stored.AccessToken, token.AccessToken) + } + if stored.RefreshToken != token.RefreshToken { + t.Fatalf("stored refresh token = %q, want %q", stored.RefreshToken, token.RefreshToken) + } +} + +// TestRemoveStoredToken_RemovesManagedFileFallback verifies token cleanup removes encrypted fallback files. +func TestRemoveStoredToken_RemovesManagedFileFallback(t *testing.T) { + configDir := t.TempDir() + t.Setenv("LARKSUITE_CLI_CONFIG_DIR", configDir) + + prev := tokenKeychain + tokenKeychain = &failingTokenKeychain{} + t.Cleanup(func() { tokenKeychain = prev }) + + token := &StoredUAToken{ + UserOpenId: "ou_test_user", + AppId: "cli_test", + AccessToken: "access-token", + RefreshToken: "refresh-token", + ExpiresAt: 1710000000000, + RefreshExpiresAt: 1710003600000, + GrantedAt: 1709996400000, + } + + if _, err := SetStoredToken(token); err != nil { + t.Fatalf("SetStoredToken returned error: %v", err) + } + + if fallback, err := keychain.GetFallbackWithError(keychain.LarkCliService, accountKey(token.AppId, token.UserOpenId)); err != nil || fallback == "" { + t.Fatal("expected encrypted fallback token to exist before removal") + } + + if err := RemoveStoredToken(token.AppId, token.UserOpenId); err != nil { + t.Fatalf("RemoveStoredToken returned error: %v", err) + } + + if fallback, err := keychain.GetFallbackWithError(keychain.LarkCliService, accountKey(token.AppId, token.UserOpenId)); err == nil && fallback != "" { + t.Fatalf("expected encrypted fallback token to be removed, got %q", fallback) + } +} + +// TestGetStoredToken_UsesEncryptedFallbackWhenPrimaryStoreMisses verifies encrypted fallback reads are used after primary-store misses. +func TestGetStoredToken_UsesEncryptedFallbackWhenPrimaryStoreMisses(t *testing.T) { + t.Setenv("LARKSUITE_CLI_CONFIG_DIR", t.TempDir()) + + prev := tokenKeychain + tokenKeychain = &failingTokenKeychain{} + t.Cleanup(func() { tokenKeychain = prev }) + + token := &StoredUAToken{ + UserOpenId: "ou_test_user", + AppId: "cli_test", + AccessToken: "access-token", + RefreshToken: "refresh-token", + ExpiresAt: 1710000000000, + RefreshExpiresAt: 1710003600000, + GrantedAt: 1709996400000, + } + if _, err := SetStoredToken(token); err != nil { + t.Fatalf("SetStoredToken returned error: %v", err) + } + + stored := GetStoredToken(token.AppId, token.UserOpenId) + if stored == nil { + t.Fatal("expected GetStoredToken to read encrypted fallback token") + } + if stored.AccessToken != token.AccessToken { + t.Fatalf("stored access token = %q, want %q", stored.AccessToken, token.AccessToken) + } +} + +// TestGetStoredToken_ReadsLegacyManagedTokenFile verifies the legacy plaintext token file remains readable for compatibility. +func TestGetStoredToken_ReadsLegacyManagedTokenFile(t *testing.T) { + configDir := t.TempDir() + t.Setenv("LARKSUITE_CLI_CONFIG_DIR", configDir) + + prev := tokenKeychain + tokenKeychain = &failingTokenKeychain{} + t.Cleanup(func() { tokenKeychain = prev }) + + token := &StoredUAToken{ + UserOpenId: "ou_test_user", + AppId: "cli_test", + AccessToken: "access-token", + RefreshToken: "refresh-token", + ExpiresAt: 1710000000000, + RefreshExpiresAt: 1710003600000, + GrantedAt: 1709996400000, + } + data, err := json.Marshal(token) + if err != nil { + t.Fatalf("Marshal: %v", err) + } + if err := os.MkdirAll(filepath.Dir(legacyManagedTokenFilePath(token.AppId, token.UserOpenId)), 0700); err != nil { + t.Fatalf("MkdirAll: %v", err) + } + if err := os.WriteFile(legacyManagedTokenFilePath(token.AppId, token.UserOpenId), data, 0600); err != nil { + t.Fatalf("WriteFile: %v", err) + } + + stored := GetStoredToken(token.AppId, token.UserOpenId) + if stored == nil { + t.Fatal("expected GetStoredToken to read legacy managed token file") + } + if stored.AccessToken != token.AccessToken { + t.Fatalf("stored access token = %q, want %q", stored.AccessToken, token.AccessToken) + } +} + +// TestRemoveStoredToken_ReturnsKeychainErrorWhenFallbackIsAbsent verifies fallback misses do not hide keychain delete errors. +func TestRemoveStoredToken_ReturnsKeychainErrorWhenFallbackIsAbsent(t *testing.T) { + configDir := t.TempDir() + t.Setenv("LARKSUITE_CLI_CONFIG_DIR", configDir) + + prev := tokenKeychain + tokenKeychain = &removeFailingTokenKeychain{} + t.Cleanup(func() { tokenKeychain = prev }) + + err := RemoveStoredToken("cli_test", "ou_test_user") + if err == nil { + t.Fatal("expected RemoveStoredToken to return keychain removal error") + } + if !strings.Contains(err.Error(), "remove failed") { + t.Fatalf("expected keychain removal error, got %v", err) + } +} + +// TestSetStoredToken_DoesNotFallbackOnGenericKeychainError verifies generic write errors do not silently downgrade to fallback storage. +func TestSetStoredToken_DoesNotFallbackOnGenericKeychainError(t *testing.T) { + t.Setenv("LARKSUITE_CLI_CONFIG_DIR", t.TempDir()) + + prev := tokenKeychain + tokenKeychain = &genericFailingTokenKeychain{} + t.Cleanup(func() { tokenKeychain = prev }) + + token := &StoredUAToken{ + UserOpenId: "ou_test_user", + AppId: "cli_test", + AccessToken: "access-token", + RefreshToken: "refresh-token", + ExpiresAt: 1710000000000, + RefreshExpiresAt: 1710003600000, + GrantedAt: 1709996400000, + } + + usedFallback, err := SetStoredToken(token) + if err == nil { + t.Fatal("expected SetStoredToken to return the keychain write error") + } + if usedFallback { + t.Fatal("expected generic keychain error to avoid fallback") + } + if !strings.Contains(err.Error(), "boom") { + t.Fatalf("expected original keychain write error, got %v", err) + } + if fallback, readErr := keychain.GetFallbackWithError(keychain.LarkCliService, accountKey(token.AppId, token.UserOpenId)); readErr == nil && fallback != "" { + t.Fatalf("expected no encrypted fallback token for generic keychain error, got %q", fallback) + } +} + +// TestGetStoredToken_DoesNotFallBackToLegacyWhenEncryptedFallbackIsCorrupt verifies corrupt encrypted fallback data blocks legacy fallback reads. +func TestGetStoredToken_DoesNotFallBackToLegacyWhenEncryptedFallbackIsCorrupt(t *testing.T) { + configDir := t.TempDir() + t.Setenv("LARKSUITE_CLI_CONFIG_DIR", configDir) + + prev := tokenKeychain + tokenKeychain = &failingTokenKeychain{} + t.Cleanup(func() { tokenKeychain = prev }) + + token := &StoredUAToken{ + UserOpenId: "ou_test_user", + AppId: "cli_test", + AccessToken: "legacy-access-token", + RefreshToken: "legacy-refresh-token", + ExpiresAt: 1710000000000, + RefreshExpiresAt: 1710003600000, + GrantedAt: 1709996400000, + } + data, err := json.Marshal(token) + if err != nil { + t.Fatalf("Marshal: %v", err) + } + if err := os.MkdirAll(filepath.Dir(legacyManagedTokenFilePath(token.AppId, token.UserOpenId)), 0700); err != nil { + t.Fatalf("MkdirAll: %v", err) + } + if err := os.WriteFile(legacyManagedTokenFilePath(token.AppId, token.UserOpenId), data, 0600); err != nil { + t.Fatalf("WriteFile(legacy): %v", err) + } + + serviceDir := filepath.Join(configDir, "keychain", keychain.LarkCliService) + if err := os.MkdirAll(serviceDir, 0700); err != nil { + t.Fatalf("MkdirAll(serviceDir): %v", err) + } + if err := os.WriteFile(filepath.Join(serviceDir, "master.key"), []byte("12345678901234567890123456789012"), 0600); err != nil { + t.Fatalf("WriteFile(master.key): %v", err) + } + if err := os.WriteFile(filepath.Join(serviceDir, "Y2xpX3Rlc3Q6b3VfdGVzdF91c2Vy.enc"), []byte("corrupt"), 0600); err != nil { + t.Fatalf("WriteFile(ciphertext): %v", err) + } + + stored := GetStoredToken(token.AppId, token.UserOpenId) + if stored != nil { + t.Fatalf("expected corrupt encrypted fallback to stop legacy fallback, got %#v", stored) + } +} diff --git a/internal/auth/uat_client.go b/internal/auth/uat_client.go index 133c9c7e1..950306dc0 100644 --- a/internal/auth/uat_client.go +++ b/internal/auth/uat_client.go @@ -23,6 +23,7 @@ import ( var safeIDChars = regexp.MustCompile(`[^a-zA-Z0-9._-]`) +// sanitizeID rewrites IDs into a filesystem-safe form for lock-file names. func sanitizeID(id string) string { return safeIDChars.ReplaceAllString(id, "_") } @@ -98,6 +99,7 @@ func GetValidAccessToken(httpClient *http.Client, opts UATCallOptions) (string, return "", &NeedAuthorizationError{UserOpenId: opts.UserOpenId} } +// refreshWithLock refreshes a token while serializing concurrent refresh attempts across goroutines and processes. func refreshWithLock(httpClient *http.Client, opts UATCallOptions, stored *StoredUAToken) (*StoredUAToken, error) { key := fmt.Sprintf("%s:%s", opts.AppId, opts.UserOpenId) @@ -165,6 +167,7 @@ func refreshWithLock(httpClient *http.Client, opts UATCallOptions, stored *Store return doRefreshToken(httpClient, opts, stored) } +// doRefreshToken calls the OAuth refresh endpoint and stores the replacement token state. func doRefreshToken(httpClient *http.Client, opts UATCallOptions, stored *StoredUAToken) (*StoredUAToken, error) { errOut := opts.ErrOut if errOut == nil { @@ -298,8 +301,12 @@ func doRefreshToken(httpClient *http.Client, opts UATCallOptions, stored *Stored GrantedAt: stored.GrantedAt, } - if err := SetStoredToken(updated); err != nil { + usedFallback, err := SetStoredToken(updated) + if err != nil { return nil, err } + if usedFallback { + fmt.Fprintln(errOut, "warning: keychain unavailable, auth token stored in a local file protected by filesystem permissions (0600) managed by lark-cli") + } return updated, nil } diff --git a/internal/configdir/config_dir.go b/internal/configdir/config_dir.go new file mode 100644 index 000000000..5db7dd816 --- /dev/null +++ b/internal/configdir/config_dir.go @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package configdir + +import ( + "fmt" + "os" + "path/filepath" +) + +// Get returns the CLI config directory. +func Get() string { + if dir := os.Getenv("LARKSUITE_CLI_CONFIG_DIR"); dir != "" { + return dir + } + home, err := os.UserHomeDir() + if err != nil || home == "" { + fmt.Fprintf(os.Stderr, "warning: unable to determine home directory: %v\n", err) + } + return filepath.Join(home, ".lark-cli") +} diff --git a/internal/core/config.go b/internal/core/config.go index ced1e27b4..69735a1d4 100644 --- a/internal/core/config.go +++ b/internal/core/config.go @@ -9,6 +9,7 @@ import ( "os" "path/filepath" + "github.com/larksuite/cli/internal/configdir" "github.com/larksuite/cli/internal/keychain" "github.com/larksuite/cli/internal/validate" ) @@ -58,15 +59,11 @@ type CliConfig struct { // GetConfigDir returns the config directory path. // If the home directory cannot be determined, it falls back to a relative path // and prints a warning to stderr. +// +// Kept as the legacy entrypoint for callers already depending on core. +// New code should prefer configdir.Get(); see internal/configdir/config_dir.go. func GetConfigDir() string { - if dir := os.Getenv("LARKSUITE_CLI_CONFIG_DIR"); dir != "" { - return dir - } - home, err := os.UserHomeDir() - if err != nil || home == "" { - fmt.Fprintf(os.Stderr, "warning: unable to determine home directory: %v\n", err) - } - return filepath.Join(home, ".lark-cli") + return configdir.Get() } // GetConfigPath returns the config file path. diff --git a/internal/core/secret.go b/internal/core/secret.go index a488e5dcf..526032188 100644 --- a/internal/core/secret.go +++ b/internal/core/secret.go @@ -14,7 +14,7 @@ import ( // SecretRef references a secret stored externally. type SecretRef struct { - Source string `json:"source"` // "file" | "keychain" + Source string `json:"source"` // "file" | "keychain" | "encrypted_file" Provider string `json:"provider,omitempty"` // optional, reserved ID string `json:"id"` // env var name / file path / command / keychain key } @@ -78,9 +78,10 @@ func (s *SecretInput) UnmarshalJSON(data []byte) error { // ValidSecretSources is the set of recognized SecretRef sources. var ValidSecretSources = map[string]bool{ - "file": true, "keychain": true, + "file": true, "keychain": true, "encrypted_file": true, } +// isValidSource reports whether a secret reference source is accepted by config decoding. func isValidSource(source string) bool { return ValidSecretSources[source] } diff --git a/internal/core/secret_resolve.go b/internal/core/secret_resolve.go index 6e7921d3f..9296457fa 100644 --- a/internal/core/secret_resolve.go +++ b/internal/core/secret_resolve.go @@ -4,6 +4,7 @@ package core import ( + "errors" "fmt" "os" "strings" @@ -13,12 +14,13 @@ import ( const secretKeyPrefix = "appsecret:" +// secretAccountKey builds the managed storage key for an app secret. func secretAccountKey(appId string) string { return secretKeyPrefix + appId } // ResolveSecretInput resolves a SecretInput to a plain string. -// SecretRef objects are resolved by source (file / keychain). +// SecretRef objects are resolved by source (file / keychain / encrypted_file). func ResolveSecretInput(s SecretInput, kc keychain.KeychainAccess) (string, error) { if s.Ref == nil { return s.Plain, nil @@ -30,6 +32,15 @@ func ResolveSecretInput(s SecretInput, kc keychain.KeychainAccess) (string, erro return "", fmt.Errorf("failed to read secret file %s: %w", s.Ref.ID, err) } return strings.TrimSpace(string(data)), nil + case "encrypted_file": + value, err := keychain.GetFallbackWithError(keychain.LarkCliService, s.Ref.ID) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return "", fmt.Errorf("encrypted fallback secret %s not found", s.Ref.ID) + } + return "", fmt.Errorf("failed to decrypt encrypted fallback secret %s: %w", s.Ref.ID, err) + } + return value, nil case "keychain": return kc.Get(keychain.LarkCliService, s.Ref.ID) default: @@ -47,15 +58,41 @@ func ForStorage(appId string, input SecretInput, kc keychain.KeychainAccess) (Se } key := secretAccountKey(appId) if err := kc.Set(keychain.LarkCliService, key, input.Plain); err != nil { - return SecretInput{}, fmt.Errorf("keychain unavailable: %w\nhint: use file: reference in config to bypass keychain", err) + return SecretInput{}, fmt.Errorf("store secret in keychain: %w", err) } return SecretInput{Ref: &SecretRef{Source: "keychain", ID: key}}, nil } +// ForStorageWithEncryptedFallback stores a plain secret in keychain when available, +// or falls back to the shared encrypted file store. +func ForStorageWithEncryptedFallback(appId string, input SecretInput, kc keychain.KeychainAccess) (SecretInput, error) { + if !input.IsPlain() { + return input, nil + } + key := secretAccountKey(appId) + if err := kc.Set(keychain.LarkCliService, key, input.Plain); err == nil { + return SecretInput{Ref: &SecretRef{Source: "keychain", ID: key}}, nil + } else if !keychain.ShouldUseFallback(err) { + return SecretInput{}, fmt.Errorf("store secret in keychain: %w", err) + } + if err := keychain.SetFallback(keychain.LarkCliService, key, input.Plain); err != nil { + return SecretInput{}, fmt.Errorf("store secret encrypted fallback: %w", err) + } + return SecretInput{Ref: &SecretRef{Source: "encrypted_file", ID: key}}, nil +} + // RemoveSecretStore cleans up keychain entries when an app is removed. // Errors are intentionally ignored — cleanup is best-effort. func RemoveSecretStore(input SecretInput, kc keychain.KeychainAccess) { - if input.IsSecretRef() && input.Ref.Source == "keychain" { + if !input.IsSecretRef() { + return + } + switch input.Ref.Source { + case "file": + return + case "keychain": _ = kc.Remove(keychain.LarkCliService, input.Ref.ID) + case "encrypted_file": + _ = keychain.RemoveFallback(keychain.LarkCliService, input.Ref.ID) } } diff --git a/internal/core/secret_resolve_test.go b/internal/core/secret_resolve_test.go new file mode 100644 index 000000000..6368731c7 --- /dev/null +++ b/internal/core/secret_resolve_test.go @@ -0,0 +1,113 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package core + +import ( + "encoding/json" + "errors" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/larksuite/cli/internal/keychain" +) + +// erroringSetKeychain simulates configurable keychain write failures in secret-storage tests. +type erroringSetKeychain struct { + err error +} + +// Get satisfies the keychain interface for read paths used by these tests. +func (e *erroringSetKeychain) Get(service, account string) (string, error) { return "", nil } + +// Set satisfies the keychain interface and returns the configured write error. +func (e *erroringSetKeychain) Set(service, account, value string) error { return e.err } + +// Remove satisfies the keychain interface for cleanup paths used by these tests. +func (e *erroringSetKeychain) Remove(service, account string) error { return nil } + +// TestForStorageWithEncryptedFallback_DoesNotFallbackOnGenericSetError verifies generic keychain errors do not trigger encrypted fallback. +func TestForStorageWithEncryptedFallback_DoesNotFallbackOnGenericSetError(t *testing.T) { + t.Setenv("LARKSUITE_CLI_CONFIG_DIR", t.TempDir()) + + _, err := ForStorageWithEncryptedFallback( + "cli_test", + PlainSecret("secret123"), + &erroringSetKeychain{err: errors.New("boom")}, + ) + if err == nil { + t.Fatal("expected ForStorageWithEncryptedFallback to return the keychain write error") + } + if !strings.Contains(err.Error(), "boom") { + t.Fatalf("expected original keychain write error, got %v", err) + } + if got := keychain.GetFallback(keychain.LarkCliService, "appsecret:cli_test"); got != "" { + t.Fatalf("expected no encrypted fallback to be written for generic keychain error, got %q", got) + } +} + +// TestSecretInput_UnmarshalAcceptsFileSource verifies config decoding still accepts legacy file-backed secret refs. +func TestSecretInput_UnmarshalAcceptsFileSource(t *testing.T) { + var input SecretInput + data := []byte(`{"source":"file","id":"/tmp/app-secret.txt"}`) + + if err := json.Unmarshal(data, &input); err != nil { + t.Fatalf("Unmarshal: %v", err) + } + if input.Ref == nil { + t.Fatal("expected secret ref") + } + if input.Ref.Source != "file" { + t.Fatalf("source = %q, want file", input.Ref.Source) + } + if input.Ref.ID != "/tmp/app-secret.txt" { + t.Fatalf("id = %q, want /tmp/app-secret.txt", input.Ref.ID) + } +} + +// TestResolveSecretInput_FileSourceReadsSecretFile verifies file-backed secret refs read and trim the external file content. +func TestResolveSecretInput_FileSourceReadsSecretFile(t *testing.T) { + secretFile := filepath.Join(t.TempDir(), "app-secret.txt") + if err := os.WriteFile(secretFile, []byte("secret123\n"), 0600); err != nil { + t.Fatalf("WriteFile: %v", err) + } + + secret, err := ResolveSecretInput(SecretInput{ + Ref: &SecretRef{Source: "file", ID: secretFile}, + }, &erroringSetKeychain{}) + if err != nil { + t.Fatalf("ResolveSecretInput: %v", err) + } + if secret != "secret123" { + t.Fatalf("secret = %q, want secret123", secret) + } +} + +// TestResolveSecretInput_EncryptedFallbackIncludesUnderlyingError verifies decrypt failures surface a diagnostic error. +func TestResolveSecretInput_EncryptedFallbackIncludesUnderlyingError(t *testing.T) { + configDir := t.TempDir() + t.Setenv("LARKSUITE_CLI_CONFIG_DIR", configDir) + + serviceDir := filepath.Join(configDir, "keychain", keychain.LarkCliService) + if err := os.MkdirAll(serviceDir, 0700); err != nil { + t.Fatalf("MkdirAll: %v", err) + } + if err := os.WriteFile(filepath.Join(serviceDir, "master.key"), []byte("12345678901234567890123456789012"), 0600); err != nil { + t.Fatalf("WriteFile(master.key): %v", err) + } + if err := os.WriteFile(filepath.Join(serviceDir, "YXBwc2VjcmV0OmNsaV90ZXN0.enc"), []byte("corrupt"), 0600); err != nil { + t.Fatalf("WriteFile(ciphertext): %v", err) + } + + _, err := ResolveSecretInput(SecretInput{ + Ref: &SecretRef{Source: "encrypted_file", ID: "appsecret:cli_test"}, + }, &erroringSetKeychain{}) + if err == nil { + t.Fatal("expected ResolveSecretInput to report fallback decrypt failure") + } + if !strings.Contains(err.Error(), "failed to decrypt encrypted fallback secret") { + t.Fatalf("expected decrypt-specific error, got %v", err) + } +} diff --git a/internal/keychain/errors.go b/internal/keychain/errors.go new file mode 100644 index 000000000..d7f6fca18 --- /dev/null +++ b/internal/keychain/errors.go @@ -0,0 +1,35 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package keychain + +import ( + "errors" +) + +// ErrUnavailable marks failures where the platform keychain cannot be used and +// callers may degrade to the CLI-managed encrypted fallback store. +var ErrUnavailable = errors.New("keychain unavailable") + +// unavailableError preserves both the fallback-eligibility marker and the underlying platform error. +type unavailableError struct { + cause error +} + +// Error implements error. +func (e *unavailableError) Error() string { + return ErrUnavailable.Error() + ": " + e.cause.Error() +} + +// Unwrap exposes both ErrUnavailable and the underlying cause. +func (e *unavailableError) Unwrap() []error { + return []error{ErrUnavailable, e.cause} +} + +// WrapUnavailable annotates a lower-level keychain error as fallback-eligible. +func WrapUnavailable(err error) error { + if err == nil || errors.Is(err, ErrUnavailable) { + return err + } + return &unavailableError{cause: err} +} diff --git a/internal/keychain/fallback_policy.go b/internal/keychain/fallback_policy.go new file mode 100644 index 000000000..a15608b73 --- /dev/null +++ b/internal/keychain/fallback_policy.go @@ -0,0 +1,19 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package keychain + +import ( + "errors" +) + +// ShouldUseFallback reports whether a keychain write failure should degrade to +// the CLI-managed encrypted fallback store. +// New fallback-eligible errors should be expressed via ErrUnavailable / +// WrapUnavailable so callers share one typed contract instead of matching text. +func ShouldUseFallback(err error) bool { + if err == nil { + return false + } + return errors.Is(err, ErrUnavailable) +} diff --git a/internal/keychain/fallback_policy_test.go b/internal/keychain/fallback_policy_test.go new file mode 100644 index 000000000..af73dce82 --- /dev/null +++ b/internal/keychain/fallback_policy_test.go @@ -0,0 +1,32 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package keychain + +import ( + "context" + "errors" + "testing" +) + +// TestShouldUseFallback_RequiresTypedUnavailableError verifies fallback only triggers on the typed unavailability contract. +func TestShouldUseFallback_RequiresTypedUnavailableError(t *testing.T) { + if ShouldUseFallback(errors.New("exit status 155")) { + t.Fatal("expected raw error strings to stop triggering fallback") + } + + if !ShouldUseFallback(WrapUnavailable(errors.New("sandbox denied"))) { + t.Fatal("expected wrapped unavailable errors to trigger fallback") + } +} + +// TestWrapUnavailable_PreservesUnderlyingCause verifies unavailable wrapping keeps the original cause in the error chain. +func TestWrapUnavailable_PreservesUnderlyingCause(t *testing.T) { + err := WrapUnavailable(context.DeadlineExceeded) + if !errors.Is(err, ErrUnavailable) { + t.Fatal("expected ErrUnavailable to remain detectable") + } + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatal("expected underlying cause to remain detectable") + } +} diff --git a/internal/keychain/file_encrypted_store.go b/internal/keychain/file_encrypted_store.go new file mode 100644 index 000000000..127ea24d1 --- /dev/null +++ b/internal/keychain/file_encrypted_store.go @@ -0,0 +1,233 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package keychain + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "encoding/base64" + "fmt" + "os" + "path/filepath" + + "github.com/larksuite/cli/internal/configdir" + "github.com/larksuite/cli/internal/validate" +) + +const masterKeyBytes = 32 +const ivBytes = 12 +const tagBytes = 16 + +// safeFileName maps an account key to a collision-free filesystem name. +func safeFileName(account string) string { + return base64.RawURLEncoding.EncodeToString([]byte(account)) + ".enc" +} + +// fallbackStorageDir returns the per-service directory used by the encrypted file fallback store. +func fallbackStorageDir(service string) string { + // Keep config dir resolution centralized in internal/configdir. + // New code should reuse configdir.Get() instead of duplicating env/home logic. + return filepath.Join(configdir.Get(), "keychain", service) +} + +// loadMasterKeyFile reads an existing fallback master key without creating new storage. +func loadMasterKeyFile(dir string) ([]byte, error) { + key, err := os.ReadFile(filepath.Join(dir, "master.key")) + if err != nil { + return nil, err + } + if len(key) != masterKeyBytes { + return nil, os.ErrInvalid + } + return key, nil +} + +// loadOrCreateMasterKeyFile returns the fallback master key, creating it only when absent. +func loadOrCreateMasterKeyFile(dir string) ([]byte, error) { + keyPath := filepath.Join(dir, "master.key") + + key, err := loadMasterKeyFile(dir) + if err == nil { + return key, nil + } + if !os.IsNotExist(err) { + return nil, err + } + + if err := os.MkdirAll(dir, 0700); err != nil { + return nil, err + } + + key = make([]byte, masterKeyBytes) + if _, err := rand.Read(key); err != nil { + return nil, err + } + if err := createMasterKeyFile(keyPath, key); err != nil { + if os.IsExist(err) { + return loadMasterKeyFile(dir) + } + if existingKey, readErr := loadMasterKeyFile(dir); readErr == nil { + return existingKey, nil + } + return nil, err + } + return key, nil +} + +// createMasterKeyFile creates master.key with no-replace semantics. +func createMasterKeyFile(path string, key []byte) error { + file, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0600) + if err != nil { + return err + } + success := false + defer func() { + if success { + return + } + file.Close() + os.Remove(path) + }() + + if _, err := file.Write(key); err != nil { + return err + } + if err := file.Sync(); err != nil { + return err + } + if err := file.Close(); err != nil { + return err + } + success = true + return nil +} + +// encryptData seals plaintext with AES-GCM using the provided master key. +func encryptData(plaintext string, key []byte) ([]byte, error) { + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + aesGCM, err := cipher.NewGCM(block) + if err != nil { + return nil, err + } + + iv := make([]byte, ivBytes) + if _, err := rand.Read(iv); err != nil { + return nil, err + } + + ciphertext := aesGCM.Seal(nil, iv, []byte(plaintext), nil) + result := make([]byte, 0, ivBytes+len(ciphertext)) + result = append(result, iv...) + result = append(result, ciphertext...) + return result, nil +} + +// decryptData opens AES-GCM ciphertext produced by encryptData. +func decryptData(data []byte, key []byte) (string, error) { + if len(data) < ivBytes+tagBytes { + return "", os.ErrInvalid + } + block, err := aes.NewCipher(key) + if err != nil { + return "", err + } + aesGCM, err := cipher.NewGCM(block) + if err != nil { + return "", err + } + + iv := data[:ivBytes] + ciphertext := data[ivBytes:] + plaintext, err := aesGCM.Open(nil, iv, ciphertext, nil) + if err != nil { + return "", err + } + return string(plaintext), nil +} + +// readEncryptedFile reads an encrypted fallback file and suppresses read or decrypt errors. +func readEncryptedFile(dir, account string, key []byte) string { + plaintext, err := readEncryptedFileWithError(dir, account, key) + if err != nil { + return "" + } + return plaintext +} + +// readEncryptedFileWithError reads and decrypts an encrypted fallback file while preserving error detail. +func readEncryptedFileWithError(dir, account string, key []byte) (string, error) { + path := filepath.Join(dir, safeFileName(account)) + data, err := os.ReadFile(path) + if err != nil { + return "", err + } + plaintext, err := decryptData(data, key) + if err != nil { + return "", fmt.Errorf("decrypt fallback file %s: %w", path, err) + } + return plaintext, nil +} + +// writeEncryptedFile encrypts and atomically writes fallback data for an account key. +func writeEncryptedFile(dir, account, data string, key []byte) error { + if err := os.MkdirAll(dir, 0700); err != nil { + return err + } + encrypted, err := encryptData(data, key) + if err != nil { + return err + } + return validate.AtomicWrite(filepath.Join(dir, safeFileName(account)), encrypted, 0600) +} + +// removeEncryptedFile deletes the encrypted fallback file for an account key. +func removeEncryptedFile(dir, account string) error { + err := os.Remove(filepath.Join(dir, safeFileName(account))) + if err != nil && !os.IsNotExist(err) { + return err + } + return nil +} + +// GetFallback reads fallback data and returns an empty string when the value is absent or unreadable. +func GetFallback(service, account string) string { + value, err := GetFallbackWithError(service, account) + if err != nil { + return "" + } + return value +} + +// GetFallbackWithError reads fallback data and preserves not-found or decrypt errors for callers that need diagnostics. +func GetFallbackWithError(service, account string) (string, error) { + dir := fallbackStorageDir(service) + path := filepath.Join(dir, safeFileName(account)) + if _, err := os.Stat(path); err != nil { + return "", err + } + key, err := loadMasterKeyFile(dir) + if err != nil { + return "", fmt.Errorf("load fallback master key for %s: %w", path, err) + } + return readEncryptedFileWithError(dir, account, key) +} + +// SetFallback stores fallback data for a service/account pair. +func SetFallback(service, account, data string) error { + dir := fallbackStorageDir(service) + key, err := loadOrCreateMasterKeyFile(dir) + if err != nil { + return err + } + return writeEncryptedFile(dir, account, data, key) +} + +// RemoveFallback removes fallback data for a service/account pair. +func RemoveFallback(service, account string) error { + return removeEncryptedFile(fallbackStorageDir(service), account) +} diff --git a/internal/keychain/file_encrypted_store_test.go b/internal/keychain/file_encrypted_store_test.go new file mode 100644 index 000000000..11a96c71a --- /dev/null +++ b/internal/keychain/file_encrypted_store_test.go @@ -0,0 +1,161 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +package keychain + +import ( + "bytes" + "encoding/base64" + "errors" + "os" + "path/filepath" + "strings" + "testing" +) + +// TestFallbackStore_EncryptsAndRemovesData verifies fallback values are encrypted at rest and removable. +func TestFallbackStore_EncryptsAndRemovesData(t *testing.T) { + configDir := t.TempDir() + t.Setenv("LARKSUITE_CLI_CONFIG_DIR", configDir) + + service := LarkCliService + account := "appsecret:cli_test" + plaintext := "secret123" + + if err := SetFallback(service, account, plaintext); err != nil { + t.Fatalf("SetFallback: %v", err) + } + + if got := GetFallback(service, account); got != plaintext { + t.Fatalf("GetFallback = %q, want %q", got, plaintext) + } + + encryptedPath := filepath.Join(fallbackStorageDir(service), safeFileName(account)) + data, err := os.ReadFile(encryptedPath) + if err != nil { + t.Fatalf("ReadFile(%s): %v", encryptedPath, err) + } + if strings.Contains(string(data), plaintext) { + t.Fatalf("encrypted fallback file unexpectedly contains plaintext %q", plaintext) + } + + masterKeyPath := filepath.Join(fallbackStorageDir(service), "master.key") + if info, err := os.Stat(masterKeyPath); err != nil { + t.Fatalf("master key file missing: %v", err) + } else if info.Mode().Perm() != 0600 { + t.Fatalf("master key perm = %v, want 0600", info.Mode().Perm()) + } + + if err := RemoveFallback(service, account); err != nil { + t.Fatalf("RemoveFallback: %v", err) + } + if got := GetFallback(service, account); got != "" { + t.Fatalf("expected fallback secret to be removed, got %q", got) + } +} + +// TestGetFallback_MissDoesNotCreateStorageArtifacts verifies read misses stay side-effect free. +func TestGetFallback_MissDoesNotCreateStorageArtifacts(t *testing.T) { + configDir := t.TempDir() + t.Setenv("LARKSUITE_CLI_CONFIG_DIR", configDir) + + service := LarkCliService + account := "missing:account" + + if got := GetFallback(service, account); got != "" { + t.Fatalf("GetFallback = %q, want empty string for missing account", got) + } + + fallbackDir := fallbackStorageDir(service) + if _, err := os.Stat(fallbackDir); !os.IsNotExist(err) { + t.Fatalf("expected fallback dir to stay absent on read miss, stat err = %v", err) + } +} + +// TestCreateMasterKeyFile_DoesNotReplaceExistingFile verifies master.key creation uses no-replace semantics. +func TestCreateMasterKeyFile_DoesNotReplaceExistingFile(t *testing.T) { + dir := t.TempDir() + keyPath := filepath.Join(dir, "master.key") + existingKey := bytes.Repeat([]byte{1}, masterKeyBytes) + replacementKey := bytes.Repeat([]byte{2}, masterKeyBytes) + + if err := os.WriteFile(keyPath, existingKey, 0600); err != nil { + t.Fatalf("WriteFile(%s): %v", keyPath, err) + } + + err := createMasterKeyFile(keyPath, replacementKey) + if !errors.Is(err, os.ErrExist) { + t.Fatalf("createMasterKeyFile error = %v, want %v", err, os.ErrExist) + } + + got, readErr := os.ReadFile(keyPath) + if readErr != nil { + t.Fatalf("ReadFile(%s): %v", keyPath, readErr) + } + if !bytes.Equal(got, existingKey) { + t.Fatalf("master key was replaced: got %x want %x", got, existingKey) + } +} + +// TestLoadOrCreateMasterKeyFile_ReusesExistingKey verifies existing master keys are reused rather than replaced. +func TestLoadOrCreateMasterKeyFile_ReusesExistingKey(t *testing.T) { + dir := t.TempDir() + existingKey := bytes.Repeat([]byte{7}, masterKeyBytes) + + if err := os.WriteFile(filepath.Join(dir, "master.key"), existingKey, 0600); err != nil { + t.Fatalf("WriteFile(master.key): %v", err) + } + + got, err := loadOrCreateMasterKeyFile(dir) + if err != nil { + t.Fatalf("loadOrCreateMasterKeyFile: %v", err) + } + if !bytes.Equal(got, existingKey) { + t.Fatalf("loadOrCreateMasterKeyFile returned %x, want %x", got, existingKey) + } +} + +// TestSafeFileName_EncodesFullAccountWithoutCollision verifies account keys map to collision-free filenames. +func TestSafeFileName_EncodesFullAccountWithoutCollision(t *testing.T) { + accountA := "appsecret:cli_test" + accountB := "appsecret/cli_test" + + gotA := safeFileName(accountA) + gotB := safeFileName(accountB) + if gotA == gotB { + t.Fatalf("safeFileName collision: %q and %q both mapped to %q", accountA, accountB, gotA) + } + if want := base64.RawURLEncoding.EncodeToString([]byte(accountA)) + ".enc"; gotA != want { + t.Fatalf("safeFileName(%q) = %q, want %q", accountA, gotA, want) + } +} + +// TestGetFallbackWithError_ReturnsDecryptFailure verifies corrupt ciphertext is surfaced as a decrypt error. +func TestGetFallbackWithError_ReturnsDecryptFailure(t *testing.T) { + configDir := t.TempDir() + t.Setenv("LARKSUITE_CLI_CONFIG_DIR", configDir) + + service := LarkCliService + account := "appsecret:cli_test" + dir := fallbackStorageDir(service) + if err := os.MkdirAll(dir, 0700); err != nil { + t.Fatalf("MkdirAll: %v", err) + } + if err := os.WriteFile(filepath.Join(dir, "master.key"), bytes.Repeat([]byte{1}, masterKeyBytes), 0600); err != nil { + t.Fatalf("WriteFile(master.key): %v", err) + } + if err := os.WriteFile(filepath.Join(dir, safeFileName(account)), []byte("corrupt"), 0600); err != nil { + t.Fatalf("WriteFile(ciphertext): %v", err) + } + + got, err := GetFallbackWithError(service, account) + if err == nil { + t.Fatal("expected GetFallbackWithError to report decrypt failure") + } + if got != "" { + t.Fatalf("GetFallbackWithError returned %q, want empty string on error", got) + } + if !strings.Contains(err.Error(), "decrypt") { + t.Fatalf("expected decrypt context in error, got %v", err) + } +} diff --git a/internal/keychain/keychain_darwin.go b/internal/keychain/keychain_darwin.go index fe71583d6..5525e5883 100644 --- a/internal/keychain/keychain_darwin.go +++ b/internal/keychain/keychain_darwin.go @@ -7,23 +7,18 @@ package keychain import ( "context" - "crypto/aes" - "crypto/cipher" "crypto/rand" "encoding/base64" + "errors" + "fmt" "os" "path/filepath" - "regexp" "time" - "github.com/google/uuid" "github.com/zalando/go-keyring" ) const keychainTimeout = 5 * time.Second -const masterKeyBytes = 32 -const ivBytes = 12 -const tagBytes = 16 // StorageDir returns the storage directory for a given service name on macOS. func StorageDir(service string) string { @@ -34,12 +29,38 @@ func StorageDir(service string) string { return filepath.Join(home, "Library", "Application Support", service) } -var safeFileNameRe = regexp.MustCompile(`[^a-zA-Z0-9._-]`) - -func safeFileName(account string) string { - return safeFileNameRe.ReplaceAllString(account, "_") + ".enc" +// resolveMasterKey loads, validates, or creates the base64-encoded master key stored in macOS keychain. +func resolveMasterKey( + getFn func() (string, error), + setFn func(string) error, + generateFn func() ([]byte, error), +) ([]byte, error) { + encodedKey, err := getFn() + switch { + case err == nil: + key, decodeErr := base64.StdEncoding.DecodeString(encodedKey) + if decodeErr != nil { + return nil, WrapUnavailable(fmt.Errorf("decode master key: %w", decodeErr)) + } + if len(key) != masterKeyBytes { + return nil, WrapUnavailable(fmt.Errorf("invalid master key length: %d", len(key))) + } + return key, nil + case errors.Is(err, keyring.ErrNotFound): + key, genErr := generateFn() + if genErr != nil { + return nil, genErr + } + if setErr := setFn(base64.StdEncoding.EncodeToString(key)); setErr != nil { + return nil, WrapUnavailable(setErr) + } + return key, nil + default: + return nil, WrapUnavailable(err) + } } +// getMasterKey resolves the per-service master key with a timeout around keychain access. func getMasterKey(service string) ([]byte, error) { ctx, cancel := context.WithTimeout(context.Background(), keychainTimeout) defer cancel() @@ -51,129 +72,53 @@ func getMasterKey(service string) ([]byte, error) { resCh := make(chan result, 1) go func() { defer func() { recover() }() - - encodedKey, err := keyring.Get(service, "master.key") - if err == nil { - key, decodeErr := base64.StdEncoding.DecodeString(encodedKey) - if decodeErr == nil && len(key) == masterKeyBytes { - resCh <- result{key: key, err: nil} - return - } - } - - // Generate new master key if not found or invalid - key := make([]byte, masterKeyBytes) - if _, randErr := rand.Read(key); randErr != nil { - resCh <- result{key: nil, err: randErr} - return - } - - encodedKey = base64.StdEncoding.EncodeToString(key) - setErr := keyring.Set(service, "master.key", encodedKey) - resCh <- result{key: key, err: setErr} + key, err := resolveMasterKey( + func() (string, error) { return keyring.Get(service, "master.key") }, + func(encoded string) error { return keyring.Set(service, "master.key", encoded) }, + func() ([]byte, error) { + key := make([]byte, masterKeyBytes) + if _, randErr := rand.Read(key); randErr != nil { + return nil, randErr + } + return key, nil + }, + ) + resCh <- result{key: key, err: err} }() select { case res := <-resCh: return res.key, res.err case <-ctx.Done(): - return nil, ctx.Err() - } -} - -func encryptData(plaintext string, key []byte) ([]byte, error) { - block, err := aes.NewCipher(key) - if err != nil { - return nil, err - } - aesGCM, err := cipher.NewGCM(block) - if err != nil { - return nil, err - } - - iv := make([]byte, ivBytes) - if _, err := rand.Read(iv); err != nil { - return nil, err - } - - ciphertext := aesGCM.Seal(nil, iv, []byte(plaintext), nil) - result := make([]byte, 0, ivBytes+len(ciphertext)) - result = append(result, iv...) - result = append(result, ciphertext...) - return result, nil -} - -func decryptData(data []byte, key []byte) (string, error) { - if len(data) < ivBytes+tagBytes { - return "", os.ErrInvalid - } - block, err := aes.NewCipher(key) - if err != nil { - return "", err + return nil, WrapUnavailable(ctx.Err()) } - aesGCM, err := cipher.NewGCM(block) - if err != nil { - return "", err - } - - iv := data[:ivBytes] - ciphertext := data[ivBytes:] - plaintext, err := aesGCM.Open(nil, iv, ciphertext, nil) - if err != nil { - return "", err - } - return string(plaintext), nil } +// platformGet reads a service/account secret from the macOS encrypted-file backend. func platformGet(service, account string) string { key, err := getMasterKey(service) if err != nil { return "" } - data, err := os.ReadFile(filepath.Join(StorageDir(service), safeFileName(account))) - if err != nil { - return "" - } - plaintext, err := decryptData(data, key) - if err != nil { - return "" - } - return plaintext + // Shared encrypted-file read semantics live in file_encrypted_store.go. + // New code should reuse that helper layer instead of reimplementing file I/O here. + return readEncryptedFile(StorageDir(service), account, key) } +// platformSet writes a service/account secret through the macOS encrypted-file backend. func platformSet(service, account, data string) error { key, err := getMasterKey(service) if err != nil { return err } - dir := StorageDir(service) - if err := os.MkdirAll(dir, 0700); err != nil { - return err - } - encrypted, err := encryptData(data, key) - if err != nil { - return err - } - - targetPath := filepath.Join(dir, safeFileName(account)) - tmpPath := filepath.Join(dir, safeFileName(account)+"."+uuid.New().String()+".tmp") - defer os.Remove(tmpPath) - - if err := os.WriteFile(tmpPath, encrypted, 0600); err != nil { - return err - } - - // Atomic rename to prevent file corruption during multi-process writes - if err := os.Rename(tmpPath, targetPath); err != nil { - return err - } - return nil + // Shared encrypted-file write semantics live in file_encrypted_store.go. + // New code should reuse that helper layer instead of reimplementing file I/O here. + return writeEncryptedFile(StorageDir(service), account, data, key) } +// platformRemove deletes a service/account secret from the macOS encrypted-file backend. func platformRemove(service, account string) error { - err := os.Remove(filepath.Join(StorageDir(service), safeFileName(account))) - if err != nil && !os.IsNotExist(err) { - return err - } - return nil + // Shared encrypted-file cleanup semantics live in file_encrypted_store.go. + // New code should reuse that helper layer instead of reimplementing file I/O here. + return removeEncryptedFile(StorageDir(service), account) } diff --git a/internal/keychain/keychain_darwin_test.go b/internal/keychain/keychain_darwin_test.go new file mode 100644 index 000000000..844a066cb --- /dev/null +++ b/internal/keychain/keychain_darwin_test.go @@ -0,0 +1,64 @@ +// Copyright (c) 2026 Lark Technologies Pte. Ltd. +// SPDX-License-Identifier: MIT + +//go:build darwin + +package keychain + +import ( + "context" + "errors" + "testing" + + "github.com/zalando/go-keyring" +) + +// TestResolveMasterKey_OnlyCreatesOnNotFound verifies existing keychain errors do not rotate the master key. +func TestResolveMasterKey_OnlyCreatesOnNotFound(t *testing.T) { + expected := []byte("12345678901234567890123456789012") + setCalled := false + + _, err := resolveMasterKey( + func() (string, error) { return "", context.DeadlineExceeded }, + func(string) error { + setCalled = true + return nil + }, + func() ([]byte, error) { return expected, nil }, + ) + if err == nil { + t.Fatal("expected resolveMasterKey to return unavailable error") + } + if setCalled { + t.Fatal("expected non-ErrNotFound failure to avoid rotating master key") + } + if !errors.Is(err, ErrUnavailable) { + t.Fatalf("expected ErrUnavailable, got %v", err) + } + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("expected original cause in error chain, got %v", err) + } +} + +// TestResolveMasterKey_CreatesOnNotFound verifies missing keychain entries initialize a new master key. +func TestResolveMasterKey_CreatesOnNotFound(t *testing.T) { + setCalled := false + + key, err := resolveMasterKey( + func() (string, error) { return "", keyring.ErrNotFound }, + func(encoded string) error { + setCalled = true + return nil + }, + func() ([]byte, error) { return []byte("12345678901234567890123456789012"), nil }, + ) + if err != nil { + t.Fatalf("resolveMasterKey: %v", err) + } + if !setCalled { + t.Fatal("expected keyring.Set path on ErrNotFound") + } + if len(key) != masterKeyBytes { + t.Fatalf("master key len = %d, want %d", len(key), masterKeyBytes) + } +} diff --git a/internal/keychain/keychain_other.go b/internal/keychain/keychain_other.go index 631a9fb0b..718224bd7 100644 --- a/internal/keychain/keychain_other.go +++ b/internal/keychain/keychain_other.go @@ -6,21 +6,11 @@ package keychain import ( - "crypto/aes" - "crypto/cipher" - "crypto/rand" "fmt" "os" "path/filepath" - "regexp" - - "github.com/google/uuid" ) -const masterKeyBytes = 32 -const ivBytes = 12 -const tagBytes = 16 - // StorageDir returns the storage directory for a given service name. // Each service gets its own directory for physical isolation. func StorageDir(service string) string { @@ -34,143 +24,38 @@ func StorageDir(service string) string { return filepath.Join(xdgData, service) } -var safeFileNameRe = regexp.MustCompile(`[^a-zA-Z0-9._-]`) - -func safeFileName(account string) string { - return safeFileNameRe.ReplaceAllString(account, "_") + ".enc" -} - +// getMasterKey loads or initializes the per-service master key on Linux. func getMasterKey(service string) ([]byte, error) { - dir := StorageDir(service) - keyPath := filepath.Join(dir, "master.key") - - key, err := os.ReadFile(keyPath) - if err == nil && len(key) == masterKeyBytes { - return key, nil - } - - if err := os.MkdirAll(dir, 0700); err != nil { - return nil, err - } - - key = make([]byte, masterKeyBytes) - if _, err := rand.Read(key); err != nil { - return nil, err - } - - tmpKeyPath := filepath.Join(dir, "master.key."+uuid.New().String()+".tmp") - defer os.Remove(tmpKeyPath) - - if err := os.WriteFile(tmpKeyPath, key, 0600); err != nil { - return nil, err - } - - // Atomic rename to prevent multi-process master key initialization collision - if err := os.Rename(tmpKeyPath, keyPath); err != nil { - // If rename fails, another process might have created it. Try reading again. - existingKey, readErr := os.ReadFile(keyPath) - if readErr == nil && len(existingKey) == masterKeyBytes { - return existingKey, nil - } - return nil, err - } - - return key, nil -} - -func encryptData(plaintext string, key []byte) ([]byte, error) { - block, err := aes.NewCipher(key) - if err != nil { - return nil, err - } - aesGCM, err := cipher.NewGCM(block) - if err != nil { - return nil, err - } - - iv := make([]byte, ivBytes) - if _, err := rand.Read(iv); err != nil { - return nil, err - } - - ciphertext := aesGCM.Seal(nil, iv, []byte(plaintext), nil) - result := make([]byte, 0, ivBytes+len(ciphertext)) - result = append(result, iv...) - result = append(result, ciphertext...) - return result, nil -} - -func decryptData(data []byte, key []byte) (string, error) { - if len(data) < ivBytes+tagBytes { - return "", os.ErrInvalid - } - block, err := aes.NewCipher(key) - if err != nil { - return "", err - } - aesGCM, err := cipher.NewGCM(block) - if err != nil { - return "", err - } - - iv := data[:ivBytes] - ciphertext := data[ivBytes:] - plaintext, err := aesGCM.Open(nil, iv, ciphertext, nil) - if err != nil { - return "", err - } - return string(plaintext), nil + // Shared master-key file handling lives in file_encrypted_store.go. + // New code should reuse that helper layer instead of reimplementing key-file setup here. + return loadOrCreateMasterKeyFile(StorageDir(service)) } +// platformGet reads a service/account secret from the Linux encrypted-file backend. func platformGet(service, account string) string { key, err := getMasterKey(service) if err != nil { return "" } - data, err := os.ReadFile(filepath.Join(StorageDir(service), safeFileName(account))) - if err != nil { - return "" - } - plaintext, err := decryptData(data, key) - if err != nil { - return "" - } - return plaintext + // Shared encrypted-file read semantics live in file_encrypted_store.go. + // New code should reuse that helper layer instead of reimplementing file I/O here. + return readEncryptedFile(StorageDir(service), account, key) } +// platformSet writes a service/account secret through the Linux encrypted-file backend. func platformSet(service, account, data string) error { key, err := getMasterKey(service) if err != nil { return err } - dir := StorageDir(service) - if err := os.MkdirAll(dir, 0700); err != nil { - return err - } - encrypted, err := encryptData(data, key) - if err != nil { - return err - } - - targetPath := filepath.Join(dir, safeFileName(account)) - tmpPath := filepath.Join(dir, safeFileName(account)+"."+uuid.New().String()+".tmp") - defer os.Remove(tmpPath) - - if err := os.WriteFile(tmpPath, encrypted, 0600); err != nil { - return err - } - - // Atomic rename to prevent file corruption during multi-process writes - if err := os.Rename(tmpPath, targetPath); err != nil { - return err - } - return nil + // Shared encrypted-file write semantics live in file_encrypted_store.go. + // New code should reuse that helper layer instead of reimplementing file I/O here. + return writeEncryptedFile(StorageDir(service), account, data, key) } +// platformRemove deletes a service/account secret from the Linux encrypted-file backend. func platformRemove(service, account string) error { - err := os.Remove(filepath.Join(StorageDir(service), safeFileName(account))) - if err != nil && !os.IsNotExist(err) { - return err - } - return nil + // Shared encrypted-file cleanup semantics live in file_encrypted_store.go. + // New code should reuse that helper layer instead of reimplementing file I/O here. + return removeEncryptedFile(StorageDir(service), account) }