diff --git a/.github/workflows/releaser-internal.yaml b/.github/workflows/releaser-internal.yaml index 75e3ebe..ed5861d 100644 --- a/.github/workflows/releaser-internal.yaml +++ b/.github/workflows/releaser-internal.yaml @@ -56,7 +56,7 @@ jobs: VERSION=$(echo "$CURRENT_TAG" | sed 's/^v//' | sed 's/-beta-internal//') NEW_VERSION=$(semver "$VERSION" -i patch) - NEW_TAG="v${NEW_VERSION}" + NEW_TAG="v${NEW_VERSION}-beta-internal" echo "new_tag=$NEW_TAG" >> $GITHUB_OUTPUT echo "New tag: $NEW_TAG" diff --git a/.goreleaser-internal.yaml b/.goreleaser-internal.yaml index 950ea51..db5154f 100644 --- a/.goreleaser-internal.yaml +++ b/.goreleaser-internal.yaml @@ -24,7 +24,6 @@ release: github: owner: berrybytes name: awsctl - name_template: "{{.Version}}-beta-internal" extra_files: - glob: ./LICENSE - glob: ./README.md diff --git a/README.md b/README.md index 5637132..2b7f9ea 100644 --- a/README.md +++ b/README.md @@ -146,14 +146,14 @@ ssoSessions: The following table summarizes the available `awsctl` commands: -| Command | Description | -| ------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `awsctl sso setup` | Creates or updates an AWS SSO profile. If a config file is available at `~/.config/awsctl/`, it will be used; otherwise, you will be prompted to enter the SSO Start URL and Region. The selected profile is then set as the default and authenticated. | -| `awsctl sso init` | Starts SSO authentication by allowing you to select from existing AWS SSO profiles (created via `awsctl sso setup`). Useful for switching between multiple configured SSO profiles. | -| `awsctl bastion` | Manages SSH/SSM connections, SOCKS proxy, or port forwarding to bastion hosts or EC2 instances. | -| `awsctl rds` | Connects to RDS databases directly or via SSH/SSM tunnels. | -| `awsctl eks` | Updates kubeconfig for accessing Amazon EKS clusters. | -| `awsctl ecr` | Authenticates to Amazon ECR for container image operations. | +| Command | Description | +| ------------------ | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `awsctl sso setup` | Creates/updates AWS SSO profiles. Supports flags: `--name`, `--start-url`, `--region` for non-interactive setup. Uses `~/.config/awsctl/config.yml` if available; otherwise, you will be prompted to enter the SSO Start URL, Region and SSO Name. The selected profile is then set as the default and authenticated. | +| `awsctl sso init` | Starts SSO authentication by allowing you to select from existing AWS SSO profiles (created via `awsctl sso setup`). Useful for switching between multiple configured SSO profiles. | +| `awsctl bastion` | Manages SSH/SSM connections, SOCKS proxy, or port forwarding to bastion hosts or EC2 instances. | +| `awsctl rds` | Connects to RDS databases directly or via SSH/SSM tunnels. | +| `awsctl eks` | Updates kubeconfig for accessing Amazon EKS clusters. | +| `awsctl ecr` | Authenticates to Amazon ECR for container image operations. | #### For detailed CLI command usage, see [Command Usage Documentation](docs/usage/commands.md). diff --git a/cmd/sso/setup.go b/cmd/sso/setup.go index eb942c2..21554f3 100644 --- a/cmd/sso/setup.go +++ b/cmd/sso/setup.go @@ -3,19 +3,45 @@ package sso import ( "errors" "fmt" + "strings" "github.com/BerryBytes/awsctl/internal/sso" + generalutils "github.com/BerryBytes/awsctl/utils/general" promptUtils "github.com/BerryBytes/awsctl/utils/prompt" "github.com/spf13/cobra" ) func SetupCmd(ssoClient sso.SSOClient) *cobra.Command { - return &cobra.Command{ + var startURL string + var region string + var name string + + cmd := &cobra.Command{ Use: "setup", Short: "Setup AWS SSO configuration", RunE: func(cmd *cobra.Command, args []string) error { - err := ssoClient.SetupSSO() + if startURL != "" && !strings.HasPrefix(startURL, "https://") { + return fmt.Errorf("invalid start URL: must begin with https://") + } + + if region != "" { + if !generalutils.IsRegionValid(region) { + return fmt.Errorf("invalid AWS region: %s", region) + } + } + + if name != "" && !generalutils.IsValidSessionName(name) { + return fmt.Errorf("invalid session name: must only contain letters, numbers, dashes, or underscores, and cannot start or end with a dash/underscore") + } + + opts := sso.SSOFlagOptions{ + StartURL: startURL, + Region: region, + Name: name, + } + + err := ssoClient.SetupSSO(opts) if err != nil { if errors.Is(err, promptUtils.ErrInterrupted) { return nil @@ -25,4 +51,10 @@ func SetupCmd(ssoClient sso.SSOClient) *cobra.Command { return nil }, } + + cmd.Flags().StringVar(&name, "name", "", "SSO session name") + cmd.Flags().StringVar(&startURL, "start-url", "", "AWS SSO Start URL") + cmd.Flags().StringVar(®ion, "region", "", "AWS SSO Region") + + return cmd } diff --git a/cmd/sso/setup_test.go b/cmd/sso/setup_test.go index 8bc39b3..203acb1 100644 --- a/cmd/sso/setup_test.go +++ b/cmd/sso/setup_test.go @@ -4,9 +4,9 @@ import ( "errors" "testing" + "github.com/BerryBytes/awsctl/internal/sso" mock_sso "github.com/BerryBytes/awsctl/tests/mock/sso" promptUtils "github.com/BerryBytes/awsctl/utils/prompt" - "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" ) @@ -18,30 +18,86 @@ func TestSetupCmd(t *testing.T) { mockSSOClient := mock_sso.NewMockSSOClient(ctrl) tests := []struct { - name string - mockSetup func() - expectedError string + name string + args []string + mockSetup func() + expectedError string + expectedOutput string }{ { - name: "successful setup", + name: "successful setup with no flags", + args: []string{}, + mockSetup: func() { + mockSSOClient.EXPECT().SetupSSO(sso.SSOFlagOptions{}).Return(nil) + }, + }, + { + name: "successful setup with all flags", + args: []string{"--name=test-session", "--start-url=https://test.awsapps.com/start", "--region=us-east-1"}, mockSetup: func() { - mockSSOClient.EXPECT().SetupSSO().Return(nil) + mockSSOClient.EXPECT().SetupSSO(sso.SSOFlagOptions{ + Name: "test-session", + StartURL: "https://test.awsapps.com/start", + Region: "us-east-1", + }).Return(nil) }, - expectedError: "", }, { name: "error during setup", + args: []string{}, mockSetup: func() { - mockSSOClient.EXPECT().SetupSSO().Return(errors.New("setup error")) + mockSSOClient.EXPECT().SetupSSO(sso.SSOFlagOptions{}).Return(errors.New("setup error")) }, - expectedError: "setup error", + expectedError: "SSO initialization failed: setup error", }, { name: "interrupted by user", + args: []string{}, mockSetup: func() { - mockSSOClient.EXPECT().SetupSSO().Return(promptUtils.ErrInterrupted) + mockSSOClient.EXPECT().SetupSSO(sso.SSOFlagOptions{}).Return(promptUtils.ErrInterrupted) + }, + }, + { + name: "invalid start URL format", + args: []string{"--start-url=invalid-url"}, + mockSetup: func() { + + }, + expectedError: "invalid start URL: must begin with https://", + }, + { + name: "invalid region", + args: []string{"--region=invalid-region"}, + mockSetup: func() { + + }, + expectedError: "invalid AWS region: invalid-region", + }, + { + name: "invalid session name", + args: []string{"--name=invalid-name-"}, + mockSetup: func() { + + }, + expectedError: "invalid session name: must only contain letters, numbers, dashes, or underscores, and cannot start or end with a dash/underscore", + }, + { + name: "partial flags - only name provided", + args: []string{"--name=valid-name"}, + mockSetup: func() { + mockSSOClient.EXPECT().SetupSSO(sso.SSOFlagOptions{ + Name: "valid-name", + }).Return(nil) + }, + }, + { + name: "partial flags - only region provided", + args: []string{"--region=us-west-2"}, + mockSetup: func() { + mockSSOClient.EXPECT().SetupSSO(sso.SSOFlagOptions{ + Region: "us-west-2", + }).Return(nil) }, - expectedError: "", }, } @@ -50,15 +106,91 @@ func TestSetupCmd(t *testing.T) { tt.mockSetup() cmd := SetupCmd(mockSSOClient) - cmd.SetArgs([]string{}) + cmd.SetArgs(tt.args) err := cmd.Execute() - if tt.expectedError == "" { - assert.NoError(t, err) + if tt.expectedError != "" { + assert.Error(t, err) + assert.Contains(t, err.Error(), tt.expectedError) } else { + assert.NoError(t, err) + } + }) + } +} + +func TestSetupCmd_FlagValidation(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockSSOClient := mock_sso.NewMockSSOClient(ctrl) + + tests := []struct { + name string + args []string + expectCall bool + expectedError string + }{ + { + name: "valid start URL", + args: []string{"--start-url=https://valid.awsapps.com/start"}, + expectCall: true, + }, + { + name: "invalid start URL missing https", + args: []string{"--start-url=http://invalid.awsapps.com/start"}, + expectCall: false, + expectedError: "invalid start URL: must begin with https://", + }, + { + name: "invalid start URL format", + args: []string{"--start-url=invalid-format"}, + expectCall: false, + expectedError: "invalid start URL: must begin with https://", + }, + { + name: "valid region", + args: []string{"--region=eu-west-1"}, + expectCall: true, + }, + { + name: "invalid region", + args: []string{"--region=invalid-region"}, + expectCall: false, + expectedError: "invalid AWS region: invalid-region", + }, + { + name: "valid session name", + args: []string{"--name=valid_name-123"}, + expectCall: true, + }, + { + name: "invalid session name starts with dash", + args: []string{"--name=-invalid"}, + expectCall: false, + expectedError: "invalid session name: must only contain letters, numbers, dashes, or underscores, and cannot start or end with a dash/underscore", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.expectCall { + mockSSOClient.EXPECT().SetupSSO(gomock.Any()).Return(nil) + } + + cmd := SetupCmd(mockSSOClient) + cmd.SetArgs(tt.args) + cmd.SilenceErrors = true + cmd.SilenceUsage = true + + err := cmd.Execute() + + if tt.expectedError != "" { assert.Error(t, err) assert.Contains(t, err.Error(), tt.expectedError) + } else { + assert.NoError(t, err) } }) } diff --git a/docs/usage/commands.md b/docs/usage/commands.md index 08d651b..c851a8b 100644 --- a/docs/usage/commands.md +++ b/docs/usage/commands.md @@ -6,11 +6,61 @@ Creates or updates AWS SSO profiles. -- Prompts for: - - SSO Start URL - - AWS Region -- Uses defaults from `~/.config/awsctl/config.yml` if available. -- Automatically sets the profile as default and authenticates it after setup. +#### Basic Usage + +```bash +awsctl sso setup [flags] +``` + +#### Flags + +| Flag | Description | Required | Example | +| ------------- | ---------------------------------------------- | -------- | ---------------------------------------------- | +| `--name` | SSO session name | No | `--name my-sso-session` | +| `--start-url` | AWS SSO start URL (must begin with `https://`) | No | `--start-url https://my-sso.awsapps.com/start` | +| `--region` | AWS region for the SSO session | No | `--region us-east-1` | + +#### Behavior + +- **Interactive Mode** (default when no flags): + + - Prompts for: + - SSO Start URL + - AWS Region + - Session name (default: "default-sso") + - Uses defaults from `~/.config/awsctl/config.yml` if available + - Validates all inputs before creating session + +- **Non-interactive Mode** (when all flags provided): + + - Creates session immediately without prompts + - Validates: + - Start URL format (`https://`) + - Valid AWS region + - Proper session name format + +#### Examples + +1. Fully interactive: + +```bash +awsctl sso setup +``` + +2. Fully non-interactive: + +```bash +awsctl sso setup --name dev-session --start-url https://dev.awsapps.com/start --region us-east-1 +``` + +#### Validation Rules + +- `--start-url`: Must begin with `https://` +- `--region`: Must be valid AWS region code +- `--name`: + - Alphanumeric with dashes/underscores + - Cannot start/end with special chars + - 3-64 characters --- diff --git a/internal/sso/client.go b/internal/sso/client.go index a4fa885..c80984a 100644 --- a/internal/sso/client.go +++ b/internal/sso/client.go @@ -23,6 +23,11 @@ type RealSSOClient struct { Prompter Prompter Executor common.CommandExecutor } +type SSOFlagOptions struct { + StartURL string + Region string + Name string +} func NewSSOClient(prompter Prompter, executor common.CommandExecutor) (SSOClient, error) { if prompter == nil { @@ -42,7 +47,7 @@ func NewSSOClient(prompter Prompter, executor common.CommandExecutor) (SSOClient }, nil } -func (c *RealSSOClient) getSsoAccessTokenFromCache(profile string) (*models.SSOCache, time.Time, error) { +func (c *RealSSOClient) GetSsoAccessTokenFromCache(profile string) (*models.SSOCache, time.Time, error) { startURL, err := c.ConfigureGet("sso_start_url", profile) if err != nil { return nil, time.Time{}, fmt.Errorf("failed to get sso_start_url for profile %s: %v", profile, err) @@ -112,7 +117,7 @@ func (c *RealSSOClient) getSsoAccessTokenFromCache(profile string) (*models.SSOC if err != nil { return nil, time.Time{}, fmt.Errorf("SSO login failed: %v", err) } - selectedCache, expiryTime, err = c.getSsoAccessTokenFromCache(profile) + selectedCache, expiryTime, err = c.GetSsoAccessTokenFromCache(profile) if err != nil { return nil, time.Time{}, fmt.Errorf("failed to get token after re-login: %v", err) } @@ -130,7 +135,7 @@ func (c *RealSSOClient) GetCachedSsoAccessToken(profile string) (string, time.Ti return c.TokenCache.AccessToken, c.TokenCache.Expiry, nil } - cachedSSO, expiry, err := c.getSsoAccessTokenFromCache(profile) + cachedSSO, expiry, err := c.GetSsoAccessTokenFromCache(profile) if err != nil { return "", time.Time{}, err } diff --git a/internal/sso/client_test.go b/internal/sso/client_test.go index f600e63..eb7cbfc 100644 --- a/internal/sso/client_test.go +++ b/internal/sso/client_test.go @@ -1,4 +1,4 @@ -package sso +package sso_test import ( "context" @@ -9,6 +9,7 @@ import ( "testing" "time" + "github.com/BerryBytes/awsctl/internal/sso" "github.com/BerryBytes/awsctl/models" mock_awsctl "github.com/BerryBytes/awsctl/tests/mock" mock_sso "github.com/BerryBytes/awsctl/tests/mock/sso" @@ -27,13 +28,13 @@ func TestNewSSOClient(t *testing.T) { mockPrompter := mock_sso.NewMockPrompter(ctrl) mockExecutor := mock_awsctl.NewMockCommandExecutor(ctrl) - client, err := NewSSOClient(mockPrompter, mockExecutor) + client, err := sso.NewSSOClient(mockPrompter, mockExecutor) assert.NoError(t, err) assert.NotNil(t, client) }) t.Run("nil prompter returns error", func(t *testing.T) { - client, err := NewSSOClient(nil, nil) + client, err := sso.NewSSOClient(nil, nil) assert.Error(t, err) assert.Nil(t, client) assert.Equal(t, "prompter cannot be nil", err.Error()) @@ -46,10 +47,10 @@ func TestNewSSOClient_NilExecutor(t *testing.T) { mockPrompter := mock_sso.NewMockPrompter(ctrl) - client, err := NewSSOClient(mockPrompter, nil) + client, err := sso.NewSSOClient(mockPrompter, nil) assert.NoError(t, err) assert.NotNil(t, client) - assert.IsType(t, &common.RealCommandExecutor{}, client.(*RealSSOClient).Executor) + assert.IsType(t, &common.RealCommandExecutor{}, client.(*sso.RealSSOClient).Executor) }) } @@ -61,7 +62,7 @@ func TestGetCachedSsoAccessToken(t *testing.T) { mockExecutor := mock_awsctl.NewMockCommandExecutor(ctrl) t.Run("returns cached token if valid", func(t *testing.T) { - client := &RealSSOClient{ + client := &sso.RealSSOClient{ Prompter: mockPrompter, Executor: mockExecutor, TokenCache: models.TokenCache{ @@ -96,7 +97,7 @@ func TestGetCachedSsoAccessToken(t *testing.T) { mockExecutor.EXPECT().RunCommand("aws", "configure", "get", "sso_start_url", "--profile", "test-profile"). Return([]byte("https://example.awsapps.com/start"), nil) - client := &RealSSOClient{ + client := &sso.RealSSOClient{ Prompter: mockPrompter, Executor: mockExecutor, TokenCache: models.TokenCache{ @@ -127,7 +128,7 @@ func TestAwsSTSGetCallerIdentity(t *testing.T) { mockExecutor := mock_awsctl.NewMockCommandExecutor(ctrl) t.Run("successful identity retrieval", func(t *testing.T) { - client := &RealSSOClient{ + client := &sso.RealSSOClient{ Prompter: mockPrompter, Executor: mockExecutor, } @@ -147,7 +148,7 @@ func TestAwsSTSGetCallerIdentity(t *testing.T) { }) t.Run("command failure", func(t *testing.T) { - client := &RealSSOClient{ + client := &sso.RealSSOClient{ Prompter: mockPrompter, Executor: mockExecutor, } @@ -218,7 +219,7 @@ func TestSSOLogin(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { tt.setup() - client := &RealSSOClient{ + client := &sso.RealSSOClient{ Prompter: mockPrompter, Executor: mockExecutor, } @@ -242,7 +243,7 @@ func TestGetSSOAccountName(t *testing.T) { tests := []struct { name string - setup func(client *RealSSOClient) + setup func(client *sso.RealSSOClient) accountID string profile string expectedName string @@ -251,7 +252,7 @@ func TestGetSSOAccountName(t *testing.T) { }{ { name: "successful account name retrieval with cache hit", - setup: func(client *RealSSOClient) { + setup: func(client *sso.RealSSOClient) { client.TokenCache.AccessToken = "test-token" client.TokenCache.Expiry = time.Now().Add(1 * time.Hour) @@ -272,7 +273,7 @@ func TestGetSSOAccountName(t *testing.T) { }, { name: "successful account name retrieval with cache miss", - setup: func(client *RealSSOClient) { + setup: func(client *sso.RealSSOClient) { tempDir := t.TempDir() cacheDir := filepath.Join(tempDir, ".aws", "sso", "cache") require.NoError(t, os.MkdirAll(cacheDir, 0755)) @@ -315,7 +316,7 @@ func TestGetSSOAccountName(t *testing.T) { }, { name: "account not found", - setup: func(client *RealSSOClient) { + setup: func(client *sso.RealSSOClient) { tempDir := t.TempDir() cacheDir := filepath.Join(tempDir, ".aws", "sso", "cache") require.NoError(t, os.MkdirAll(cacheDir, 0755)) @@ -358,7 +359,7 @@ func TestGetSSOAccountName(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - client := &RealSSOClient{ + client := &sso.RealSSOClient{ Executor: mockExecutor, Prompter: mockPrompter, } @@ -387,7 +388,7 @@ func TestGetRoleCredentials(t *testing.T) { mockExecutor := mock_awsctl.NewMockCommandExecutor(ctrl) t.Run("successful credentials retrieval", func(t *testing.T) { - client := &RealSSOClient{ + client := &sso.RealSSOClient{ Prompter: mockPrompter, Executor: mockExecutor, } @@ -423,7 +424,7 @@ func TestGetRoleCredentials(t *testing.T) { }) t.Run("command failure", func(t *testing.T) { - client := &RealSSOClient{ + client := &sso.RealSSOClient{ Prompter: mockPrompter, Executor: mockExecutor, } @@ -440,7 +441,7 @@ func TestGetRoleCredentials(t *testing.T) { }) t.Run("invalid JSON response", func(t *testing.T) { - client := &RealSSOClient{ + client := &sso.RealSSOClient{ Prompter: mockPrompter, Executor: mockExecutor, } @@ -469,7 +470,7 @@ func TestGetCachedSsoAccessToken_ErrorCases(t *testing.T) { RunCommand("aws", "configure", "get", "sso_start_url", "--profile", "bad-profile"). Return(nil, errors.New("config error")) - client := &RealSSOClient{ + client := &sso.RealSSOClient{ Prompter: mockPrompter, Executor: mockExecutor, } @@ -505,7 +506,7 @@ func TestGetCachedSsoAccessToken_ErrorCases(t *testing.T) { _ = os.Setenv("HOME", oldHome) }) - client := &RealSSOClient{ + client := &sso.RealSSOClient{ Prompter: mockPrompter, Executor: mockExecutor, } @@ -524,7 +525,7 @@ func TestGetSSOAccountName_ErrorCases(t *testing.T) { mockExecutor := mock_awsctl.NewMockCommandExecutor(ctrl) t.Run("error getting access token", func(t *testing.T) { - client := &RealSSOClient{ + client := &sso.RealSSOClient{ Prompter: mockPrompter, Executor: mockExecutor, } @@ -539,7 +540,7 @@ func TestGetSSOAccountName_ErrorCases(t *testing.T) { }) t.Run("error listing accounts", func(t *testing.T) { - client := &RealSSOClient{ + client := &sso.RealSSOClient{ Prompter: mockPrompter, Executor: mockExecutor, TokenCache: models.TokenCache{ @@ -558,7 +559,7 @@ func TestGetSSOAccountName_ErrorCases(t *testing.T) { }) t.Run("invalid accounts JSON", func(t *testing.T) { - client := &RealSSOClient{ + client := &sso.RealSSOClient{ Prompter: mockPrompter, Executor: mockExecutor, TokenCache: models.TokenCache{ @@ -586,14 +587,14 @@ func TestGetSsoAccessTokenFromCache_ErrorCases(t *testing.T) { tests := []struct { name string - setup func(*RealSSOClient) + setup func(*sso.RealSSOClient) profile string expectedError string expectError bool }{ { name: "error getting start URL", - setup: func(client *RealSSOClient) { + setup: func(client *sso.RealSSOClient) { mockExecutor.EXPECT(). RunCommand("aws", "configure", "get", "sso_start_url", "--profile", "bad-profile"). Return(nil, errors.New("config error")) @@ -604,7 +605,7 @@ func TestGetSsoAccessTokenFromCache_ErrorCases(t *testing.T) { }, { name: "error getting home directory", - setup: func(client *RealSSOClient) { + setup: func(client *sso.RealSSOClient) { mockExecutor.EXPECT(). RunCommand("aws", "configure", "get", "sso_start_url", "--profile", "test-profile"). Return([]byte("https://example.com"), nil) @@ -623,7 +624,7 @@ func TestGetSsoAccessTokenFromCache_ErrorCases(t *testing.T) { }, { name: "cache directory not exists", - setup: func(client *RealSSOClient) { + setup: func(client *sso.RealSSOClient) { mockExecutor.EXPECT(). RunCommand("aws", "configure", "get", "sso_start_url", "--profile", "test-profile"). Return([]byte("https://example.com"), nil) @@ -643,7 +644,7 @@ func TestGetSsoAccessTokenFromCache_ErrorCases(t *testing.T) { }, { name: "no matching cache file", - setup: func(client *RealSSOClient) { + setup: func(client *sso.RealSSOClient) { mockExecutor.EXPECT(). RunCommand("aws", "configure", "get", "sso_start_url", "--profile", "test-profile"). Return([]byte("https://example.com"), nil) @@ -677,13 +678,13 @@ func TestGetSsoAccessTokenFromCache_ErrorCases(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - client := &RealSSOClient{ + client := &sso.RealSSOClient{ Prompter: mockPrompter, Executor: mockExecutor, } tt.setup(client) - cache, _, err := client.getSsoAccessTokenFromCache(tt.profile) + cache, _, err := client.GetSsoAccessTokenFromCache(tt.profile) if tt.expectError { require.Error(t, err) diff --git a/internal/sso/interface.go b/internal/sso/interface.go index 8e031de..2684a72 100644 --- a/internal/sso/interface.go +++ b/internal/sso/interface.go @@ -7,7 +7,7 @@ import ( ) type SSOClient interface { - SetupSSO() error + SetupSSO(opts SSOFlagOptions) error InitSSO(refresh, noBrowser bool) error ConfigureSet(key, value, profile string) error ConfigureGet(key, profile string) (string, error) diff --git a/internal/sso/profile.go b/internal/sso/profile.go index bbd2be5..2e763f9 100644 --- a/internal/sso/profile.go +++ b/internal/sso/profile.go @@ -28,12 +28,12 @@ func (c *RealSSOClient) ConfigureSSOProfile(profile, region, accountID, role, ss return nil } -func (c *RealSSOClient) configureAWSProfile(profileName, sessionName, ssoRegion, ssoStartURL, accountID, roleName, region string) error { +func (c *RealSSOClient) ConfigureAWSProfile(profileName, sessionName, ssoRegion, ssoStartURL, accountID, roleName, region string) error { ssoStartURL = strings.TrimSuffix(ssoStartURL, "#") - if err := validateStartURL(ssoStartURL); err != nil { + if err := ValidateStartURL(ssoStartURL); err != nil { return fmt.Errorf("invalid start URL: %w", err) } - if err := validateAccountID(accountID); err != nil { + if err := ValidateAccountID(accountID); err != nil { return fmt.Errorf("invalid account ID: %w", err) } @@ -150,7 +150,7 @@ func (c *RealSSOClient) configureAWSProfile(profileName, sessionName, ssoRegion, return nil } -func (c *RealSSOClient) promptProfileDetails(ssoRegion string) (string, string, error) { +func (c *RealSSOClient) PromptProfileDetails(ssoRegion string) (string, string, error) { profileName, err := c.Prompter.PromptWithDefault("Enter profile name to configure", "sso-profile") if err != nil { return "", "", fmt.Errorf("failed to prompt for profile name: %w", err) @@ -193,7 +193,7 @@ func (c *RealSSOClient) setProfileAsDefault(profile string) error { region = ssoRegion } - if err := c.configureAWSProfile("default", sessionName, ssoRegion, ssoStartURL, accountID, roleName, region); err != nil { + if err := c.ConfigureAWSProfile("default", sessionName, ssoRegion, ssoStartURL, accountID, roleName, region); err != nil { return fmt.Errorf("failed to configure AWS default profile: %w", err) } fmt.Println("Successfully set this profile as default!") @@ -247,6 +247,6 @@ func (c *RealSSOClient) printProfileSummary(profile string) error { expiration = expiry.Format(time.RFC3339) } - printSummary(profile, sessionName, ssoStartURL, ssoRegion, accountID, roleName, accountName, roleARN, expiration) + PrintSummary(profile, sessionName, ssoStartURL, ssoRegion, accountID, roleName, accountName, roleARN, expiration) return nil } diff --git a/internal/sso/profile_test.go b/internal/sso/profile_test.go index cdd8fd4..c528a39 100644 --- a/internal/sso/profile_test.go +++ b/internal/sso/profile_test.go @@ -1,4 +1,4 @@ -package sso +package sso_test import ( "errors" @@ -6,6 +6,7 @@ import ( "path/filepath" "testing" + "github.com/BerryBytes/awsctl/internal/sso" mock_awsctl "github.com/BerryBytes/awsctl/tests/mock" mock_sso "github.com/BerryBytes/awsctl/tests/mock/sso" "github.com/golang/mock/gomock" @@ -29,7 +30,7 @@ func TestConfigureSSOProfile(t *testing.T) { mockExecutor.EXPECT().RunCommand("aws", "configure", "set", "region", "us-west-2", "--profile", "test-profile").Return(nil, nil) mockExecutor.EXPECT().RunCommand("aws", "configure", "set", "output", "json", "--profile", "test-profile").Return(nil, nil) - client := &RealSSOClient{ + client := &sso.RealSSOClient{ Prompter: mockPrompter, Executor: mockExecutor, } @@ -45,7 +46,7 @@ func TestConfigureSSOProfile(t *testing.T) { mockExecutor.EXPECT().RunCommand(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() - client := &RealSSOClient{ + client := &sso.RealSSOClient{ Prompter: mockPrompter, Executor: mockExecutor, } @@ -180,12 +181,12 @@ func TestConfigureAWSProfile(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { tt.setup() - client := &RealSSOClient{ + client := &sso.RealSSOClient{ Prompter: mockPrompter, Executor: mockExecutor, } - err := client.configureAWSProfile(tt.profileName, tt.sessionName, tt.ssoRegion, tt.ssoStartURL, tt.accountID, tt.roleName, tt.region) + err := client.ConfigureAWSProfile(tt.profileName, tt.sessionName, tt.ssoRegion, tt.ssoStartURL, tt.accountID, tt.roleName, tt.region) if tt.expectError { require.Error(t, err) @@ -257,11 +258,11 @@ func TestPromptProfileDetails(t *testing.T) { Return(tt.mockResponses[2], tt.mockResponses[3]) } - client := &RealSSOClient{ + client := &sso.RealSSOClient{ Prompter: mockPrompter, } - name, region, err := client.promptProfileDetails(tt.ssoRegion) + name, region, err := client.PromptProfileDetails(tt.ssoRegion) if tt.expectError { assert.Error(t, err) diff --git a/internal/sso/prompter.go b/internal/sso/prompter.go index a8c9e58..3eb5ec7 100644 --- a/internal/sso/prompter.go +++ b/internal/sso/prompter.go @@ -30,7 +30,7 @@ func (r *RealPromptRunner) RunSelect(label string, items []string) (string, erro return result, err } -var validateStartURLFunc = func(input string) error { +var ValidateStartURLFunc = func(input string) error { if !strings.HasPrefix(input, "http://") && !strings.HasPrefix(input, "https://") { return fmt.Errorf("invalid URL format") } @@ -39,7 +39,7 @@ var validateStartURLFunc = func(input string) error { type PromptUI struct { Prompt promptUtils.Prompter - runner PromptRunner + Runner PromptRunner } func handlePromptError(err error) error { @@ -60,7 +60,7 @@ func (p *PromptUI) PromptWithDefault(label, defaultValue string) (string, error) } return nil } - result, err := p.runner.RunPrompt(label, defaultValue, validate) + result, err := p.Runner.RunPrompt(label, defaultValue, validate) err = handlePromptError(err) if err != nil { return "", err @@ -89,9 +89,9 @@ func (p *PromptUI) PromptRequired(label string) (string, error) { if strings.TrimSpace(input) == "" { return fmt.Errorf("input is required") } - return validateStartURLFunc(input) + return ValidateStartURLFunc(input) } - result, err := p.runner.RunPrompt(label, "", validate) + result, err := p.Runner.RunPrompt(label, "", validate) err = handlePromptError(err) if err != nil { return "", err @@ -100,7 +100,7 @@ func (p *PromptUI) PromptRequired(label string) (string, error) { } func (p *PromptUI) SelectFromList(label string, items []string) (string, error) { - result, err := p.runner.RunSelect(label, items) + result, err := p.Runner.RunSelect(label, items) err = handlePromptError(err) if err != nil { return "", err @@ -120,7 +120,7 @@ func (p *PromptUI) PromptYesNo(label string, defaultValue bool) (bool, error) { } return nil } - result, err := p.runner.RunPrompt(label, defaultStr, validate) + result, err := p.Runner.RunPrompt(label, defaultStr, validate) err = handlePromptError(err) if err != nil { return false, err @@ -133,5 +133,5 @@ func (p *PromptUI) PromptYesNo(label string, defaultValue bool) (bool, error) { } func NewPrompter() Prompter { - return &PromptUI{runner: &RealPromptRunner{}, Prompt: promptUtils.NewPrompt()} + return &PromptUI{Runner: &RealPromptRunner{}, Prompt: promptUtils.NewPrompt()} } diff --git a/internal/sso/prompter_test.go b/internal/sso/prompter_test.go index 12ea4df..5ccb549 100644 --- a/internal/sso/prompter_test.go +++ b/internal/sso/prompter_test.go @@ -1,4 +1,4 @@ -package sso +package sso_test import ( "errors" @@ -6,6 +6,7 @@ import ( "strings" "testing" + "github.com/BerryBytes/awsctl/internal/sso" mock_awsctl "github.com/BerryBytes/awsctl/tests/mock" mock_sso "github.com/BerryBytes/awsctl/tests/mock/sso" promptUtils "github.com/BerryBytes/awsctl/utils/prompt" @@ -101,7 +102,7 @@ func TestPromptUI_PromptWithDefault(t *testing.T) { RunPrompt(tt.label, tt.defaultValue, gomock.Any()). Return(tt.input, tt.inputErr) - p := &PromptUI{runner: mockRunner} + p := &sso.PromptUI{Runner: mockRunner} result, err := p.PromptWithDefault(tt.label, tt.defaultValue) if tt.wantErr { @@ -173,7 +174,7 @@ func TestPromptUI_SelectFromList(t *testing.T) { RunSelect(tt.label, tt.items). Return(tt.input, tt.inputErr) - p := &PromptUI{runner: mockRunner} + p := &sso.PromptUI{Runner: mockRunner} result, err := p.SelectFromList(tt.label, tt.items) if tt.wantErr { @@ -282,7 +283,7 @@ func TestPromptUI_PromptYesNo(t *testing.T) { RunPrompt(tt.label, defaultStr, gomock.Any()). Return(tt.input, tt.inputErr) - p := &PromptUI{runner: mockRunner} + p := &sso.PromptUI{Runner: mockRunner} result, err := p.PromptYesNo(tt.label, tt.defaultValue) if tt.wantErr { @@ -300,9 +301,9 @@ func TestPromptUI_PromptYesNo(t *testing.T) { } func TestPromptUI_PromptRequired(t *testing.T) { - originalValidateStartURL := validateStartURLFunc - validateStartURLFunc = mockValidateStartURL - defer func() { validateStartURLFunc = originalValidateStartURL }() + originalValidateStartURL := sso.ValidateStartURLFunc + sso.ValidateStartURLFunc = mockValidateStartURL + defer func() { sso.ValidateStartURLFunc = originalValidateStartURL }() tests := []struct { name string @@ -369,7 +370,7 @@ func TestPromptUI_PromptRequired(t *testing.T) { RunPrompt(tt.label, "", gomock.Any()). Return(tt.input, tt.inputErr) - p := &PromptUI{runner: mockRunner} + p := &sso.PromptUI{Runner: mockRunner} result, err := p.PromptRequired(tt.label) if tt.wantErr { @@ -466,7 +467,7 @@ func TestPromptUI_PromptForRegion(t *testing.T) { Return(tt.input, tt.inputErr) } - p := &PromptUI{Prompt: mockPrompter} + p := &sso.PromptUI{Prompt: mockPrompter} result, err := p.PromptForRegion(tt.defaultRegion) if tt.wantErr { @@ -522,7 +523,7 @@ func TestValidateStartURLFunc(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := validateStartURLFunc(tt.input) + err := sso.ValidateStartURLFunc(tt.input) if tt.wantErr { assert.Error(t, err) if tt.errContains != "" { @@ -536,8 +537,8 @@ func TestValidateStartURLFunc(t *testing.T) { } func TestNewPrompter(t *testing.T) { - prompter := NewPrompter() + prompter := sso.NewPrompter() assert.NotNil(t, prompter) - _, ok := prompter.(*PromptUI) - assert.True(t, ok, "NewPrompter should return a *PromptUI") + _, ok := prompter.(*sso.PromptUI) + assert.True(t, ok, "NewPrompter should return a *sso.PromptUI") } diff --git a/internal/sso/session.go b/internal/sso/session.go index 4c472b8..b309e3a 100644 --- a/internal/sso/session.go +++ b/internal/sso/session.go @@ -15,73 +15,79 @@ import ( promptUtils "github.com/BerryBytes/awsctl/utils/prompt" ) -func (c *RealSSOClient) loadOrCreateSession() (string, *models.SSOSession, error) { +func (c *RealSSOClient) LoadOrCreateSession(name, startURL, region string) (string, *models.SSOSession, error) { configPath, err := config.FindConfigFile(&c.Config) if err != nil && !errors.Is(err, config.ErrNoConfigFile) { return "", nil, fmt.Errorf("failed to check config file: %w", err) } - var ssoSession *models.SSOSession - if configPath != "" { - fileInfo, err := os.Stat(configPath) - if err != nil { - return "", nil, fmt.Errorf("failed to stat config file: %w", err) + if name != "" && startURL != "" && region != "" { + ssoSession := &models.SSOSession{ + Name: name, + StartURL: strings.TrimSuffix(startURL, "#"), + Region: region, + Scopes: "sso:account:access", } - if fileInfo.Size() > 0 && len(c.Config.RawCustomConfig.SSOSessions) > 0 { - fmt.Printf("Loaded existing configuration from '%s'\n", configPath) - if len(c.Config.RawCustomConfig.SSOSessions) == 1 { - ssoSession = &c.Config.RawCustomConfig.SSOSessions[0] - ssoSession.StartURL = strings.TrimSuffix(ssoSession.StartURL, "#") - if ssoSession.Scopes == "" { - ssoSession.Scopes = "sso:account:access" - } - fmt.Printf("Using SSO session: %s (Start URL: %s, Region: %s)\n", - ssoSession.Name, ssoSession.StartURL, ssoSession.Region) - } else { - ssoSession, err = c.selectSSOSession() - - if err != nil { - if errors.Is(err, promptUtils.ErrInterrupted) { - return "", nil, promptUtils.ErrInterrupted - } - return "", nil, fmt.Errorf("failed to select SSO session: %w", err) - } + + c.Config.RawCustomConfig.SSOSessions = append(c.Config.RawCustomConfig.SSOSessions, *ssoSession) + return configPath, ssoSession, nil + } + + if configPath != "" && len(c.Config.RawCustomConfig.SSOSessions) > 0 { + fmt.Printf("Loaded existing configuration from '%s'\n", configPath) + + if len(c.Config.RawCustomConfig.SSOSessions) == 1 && name == "" && startURL == "" && region == "" { + ssoSession := &c.Config.RawCustomConfig.SSOSessions[0] + ssoSession.StartURL = strings.TrimSuffix(ssoSession.StartURL, "#") + if ssoSession.Scopes == "" { + ssoSession.Scopes = "sso:account:access" } + fmt.Printf("Using SSO session: %s (Start URL: %s, Region: %s)\n", + ssoSession.Name, ssoSession.StartURL, ssoSession.Region) + return configPath, ssoSession, nil } - } - if ssoSession == nil || ssoSession.Name == "" || ssoSession.StartURL == "" || ssoSession.Region == "" { - fmt.Println("Setting up a new AWS SSO configuration...") - name, err := c.Prompter.PromptWithDefault("SSO session name", "default-sso") + ssoSession, err := c.SelectSSOSession() if err != nil { - return "", nil, fmt.Errorf("failed to prompt for SSO session name: %w", err) + if errors.Is(err, promptUtils.ErrInterrupted) { + return "", nil, promptUtils.ErrInterrupted + } + return "", nil, fmt.Errorf("failed to select SSO session: %w", err) } - startURL, err := c.Prompter.PromptRequired("SSO start URL (e.g., https://my-sso-portal.awsapps.com/start)") - if err != nil { - return "", nil, fmt.Errorf("failed to prompt for SSO start URL: %w", err) + if ssoSession != nil { + return configPath, ssoSession, nil } + } - region, err := c.Prompter.PromptForRegion("us-east-1") - if err != nil { - return "", nil, fmt.Errorf("failed to prompt for SSO region: %w", err) - } + fmt.Println("Setting up a new AWS SSO configuration...") - scopes := "sso:account:access" + name, err = c.Prompter.PromptWithDefault("SSO session name", "default-sso") + if err != nil { + return "", nil, fmt.Errorf("failed to prompt for SSO session name: %w", err) + } - ssoSession = &models.SSOSession{ - Name: name, - StartURL: strings.TrimSuffix(startURL, "#"), - Region: region, - Scopes: scopes, - } + startURL, err = c.Prompter.PromptRequired("SSO start URL (e.g., https://my-sso-portal.awsapps.com/start)") + if err != nil { + return "", nil, fmt.Errorf("failed to prompt for SSO start URL: %w", err) + } - c.Config.RawCustomConfig.SSOSessions = append(c.Config.RawCustomConfig.SSOSessions, *ssoSession) + region, err = c.Prompter.PromptForRegion("us-east-1") + if err != nil { + return "", nil, fmt.Errorf("failed to prompt for SSO region: %w", err) + } + + ssoSession := &models.SSOSession{ + Name: name, + StartURL: strings.TrimSuffix(startURL, "#"), + Region: region, + Scopes: "sso:account:access", } + c.Config.RawCustomConfig.SSOSessions = append(c.Config.RawCustomConfig.SSOSessions, *ssoSession) return configPath, ssoSession, nil } -func (c *RealSSOClient) selectSSOSession() (*models.SSOSession, error) { +func (c *RealSSOClient) SelectSSOSession() (*models.SSOSession, error) { options := make([]string, 0, len(c.Config.RawCustomConfig.SSOSessions)+1) sessionMap := make(map[string]*models.SSOSession) for i := range c.Config.RawCustomConfig.SSOSessions { @@ -123,7 +129,7 @@ func (c *RealSSOClient) selectSSOSession() (*models.SSOSession, error) { return ssoSession, nil } -func (c *RealSSOClient) configureSSOSession(sessionName, startURL, region, scopes string) error { +func (c *RealSSOClient) ConfigureSSOSession(sessionName, startURL, region, scopes string) error { fmt.Println("\nConfiguring AWS SSO session in ~/.aws/config...") startURL = strings.TrimSuffix(startURL, "#") @@ -225,7 +231,7 @@ func (c *RealSSOClient) configureSSOSession(sessionName, startURL, region, scope return nil } -func (c *RealSSOClient) runSSOLogin(sessionName string) error { +func (c *RealSSOClient) RunSSOLogin(sessionName string) error { if err := c.validateAWSConfig(sessionName); err != nil { return fmt.Errorf("invalid SSO configuration: %w", err) } @@ -259,7 +265,7 @@ func (c *RealSSOClient) validateAWSConfig(sessionName string) error { return nil } -func (c *RealSSOClient) getAccessToken(startURL string) (string, error) { +func (c *RealSSOClient) GetAccessToken(startURL string) (string, error) { startURL = strings.TrimSuffix(startURL, "#") homeDir, err := os.UserHomeDir() if err != nil { diff --git a/internal/sso/session_test.go b/internal/sso/session_test.go index 95cbc7c..c97bc2a 100644 --- a/internal/sso/session_test.go +++ b/internal/sso/session_test.go @@ -1,4 +1,4 @@ -package sso +package sso_test import ( "errors" @@ -8,6 +8,7 @@ import ( "testing" "time" + "github.com/BerryBytes/awsctl/internal/sso" "github.com/BerryBytes/awsctl/internal/sso/config" "github.com/BerryBytes/awsctl/models" mock_awsctl "github.com/BerryBytes/awsctl/tests/mock" @@ -29,6 +30,9 @@ func TestLoadOrCreateSession(t *testing.T) { tests := []struct { name string initialConfig *models.Config + nameParam string + startURLParam string + regionParam string mockPrompts []mockPrompt wantSession *models.SSOSession wantConfigPath string @@ -36,7 +40,22 @@ func TestLoadOrCreateSession(t *testing.T) { errContains string }{ { - name: "Create new session successfully", + name: "Create new session with parameters", + initialConfig: &models.Config{ + SSOSessions: []models.SSOSession{}, + }, + nameParam: "test-session", + startURLParam: "https://test.awsapps.com/start", + regionParam: "us-west-2", + wantSession: &models.SSOSession{ + Name: "test-session", + StartURL: "https://test.awsapps.com/start", + Region: "us-west-2", + Scopes: "sso:account:access", + }, + }, + { + name: "Create new session interactively", initialConfig: &models.Config{ SSOSessions: []models.SSOSession{}, }, @@ -51,9 +70,7 @@ func TestLoadOrCreateSession(t *testing.T) { Region: "us-west-2", Scopes: "sso:account:access", }, - wantConfigPath: "", }, - { name: "Region prompt error", initialConfig: &models.Config{ @@ -76,7 +93,6 @@ func TestLoadOrCreateSession(t *testing.T) { mockPrompter := mock_sso.NewMockPrompter(ctrl) - // Set up expected mock calls for _, mp := range tt.mockPrompts { switch mp.method { case "PromptWithDefault": @@ -94,14 +110,14 @@ func TestLoadOrCreateSession(t *testing.T) { } } - client := &RealSSOClient{ + client := &sso.RealSSOClient{ Prompter: mockPrompter, Config: config.Config{ RawCustomConfig: tt.initialConfig, }, } - configPath, session, err := client.loadOrCreateSession() + configPath, session, err := client.LoadOrCreateSession(tt.nameParam, tt.startURLParam, tt.regionParam) if tt.wantErr { assert.Error(t, err) @@ -124,7 +140,6 @@ func TestLoadOrCreateSession(t *testing.T) { }) } } - func TestSelectSSOSession(t *testing.T) { tests := []struct { name string @@ -223,14 +238,14 @@ func TestSelectSSOSession(t *testing.T) { } } - client := &RealSSOClient{ + client := &sso.RealSSOClient{ Prompter: mockPrompter, Config: config.Config{ RawCustomConfig: tt.initialConfig, }, } - session, err := client.selectSSOSession() + session, err := client.SelectSSOSession() if tt.wantErr { assert.Error(t, err) @@ -358,11 +373,11 @@ sso_region = us-west-2 tt.mockExec(mockExecutor) } - client := &RealSSOClient{ + client := &sso.RealSSOClient{ Executor: mockExecutor, } - err := client.runSSOLogin(tt.sessionName) + err := client.RunSSOLogin(tt.sessionName) if tt.wantErr { assert.Error(t, err) @@ -460,9 +475,9 @@ func TestGetAccessToken(t *testing.T) { } }() - client := &RealSSOClient{} + client := &sso.RealSSOClient{} - token, err := client.getAccessToken(tt.startURL) + token, err := client.GetAccessToken(tt.startURL) if tt.wantErr { assert.Error(t, err) @@ -567,9 +582,9 @@ sso_registration_scopes = sso:account:access } }() - client := &RealSSOClient{} + client := &sso.RealSSOClient{} - err := client.configureSSOSession(tt.sessionName, tt.startURL, tt.region, tt.scopes) + err := client.ConfigureSSOSession(tt.sessionName, tt.startURL, tt.region, tt.scopes) if tt.wantErr { assert.Error(t, err) diff --git a/internal/sso/sso.go b/internal/sso/sso.go index 63c1def..f1b708b 100644 --- a/internal/sso/sso.go +++ b/internal/sso/sso.go @@ -8,11 +8,11 @@ import ( promptUtils "github.com/BerryBytes/awsctl/utils/prompt" ) -func (c *RealSSOClient) SetupSSO() error { +func (c *RealSSOClient) SetupSSO(opts SSOFlagOptions) error { fmt.Println("AWS SSO Configuration Tool") fmt.Println("-------------------------") - _, ssoSession, err := c.loadOrCreateSession() + _, ssoSession, err := c.LoadOrCreateSession(opts.Name, opts.StartURL, opts.Region) if err != nil { if errors.Is(err, promptUtils.ErrInterrupted) { return nil @@ -20,11 +20,11 @@ func (c *RealSSOClient) SetupSSO() error { return fmt.Errorf("failed to load or create session: %w", err) } - if err := c.configureSSOSession(ssoSession.Name, ssoSession.StartURL, ssoSession.Region, ssoSession.Scopes); err != nil { + if err := c.ConfigureSSOSession(ssoSession.Name, ssoSession.StartURL, ssoSession.Region, ssoSession.Scopes); err != nil { return fmt.Errorf("failed to configure SSO session: %w", err) } - if err := c.runSSOLogin(ssoSession.Name); err != nil { + if err := c.RunSSOLogin(ssoSession.Name); err != nil { return fmt.Errorf("failed to run SSO login: %w", err) } @@ -46,20 +46,20 @@ func (c *RealSSOClient) SetupSSO() error { profileName := ssoSession.Name + "-profile" - if err := c.configureAWSProfile(profileName, ssoSession.Name, ssoSession.Region, ssoSession.StartURL, accountID, role, ssoSession.Region); err != nil { + if err := c.ConfigureAWSProfile(profileName, ssoSession.Name, ssoSession.Region, ssoSession.StartURL, accountID, role, ssoSession.Region); err != nil { return fmt.Errorf("failed to configure AWS profile: %w", err) } defaultConfigured := profileName == "default" if !defaultConfigured { - if err := c.configureAWSProfile("default", ssoSession.Name, ssoSession.Region, ssoSession.StartURL, accountID, role, ssoSession.Region); err != nil { + if err := c.ConfigureAWSProfile("default", ssoSession.Name, ssoSession.Region, ssoSession.StartURL, accountID, role, ssoSession.Region); err != nil { return fmt.Errorf("failed to configure AWS default profile: %w", err) } defaultConfigured = true } - printSummary(profileName, ssoSession.Name, ssoSession.StartURL, ssoSession.Region, accountID, role, "", "", "") + PrintSummary(profileName, ssoSession.Name, ssoSession.StartURL, ssoSession.Region, accountID, role, "", "", "") fmt.Printf("\nSuccessfully configured AWS profile '%s'!\n", profileName) if defaultConfigured { @@ -82,14 +82,11 @@ func (c *RealSSOClient) InitSSO(refresh, noBrowser bool) error { awsProfile := c.Config.AWSProfile if awsProfile == "" { if len(profiles) == 0 { - fmt.Println("No profiles found. Configuring SSO...") - if err := c.SetupSSO(); err != nil { - if errors.Is(err, promptUtils.ErrInterrupted) { - return promptUtils.ErrInterrupted - } - return fmt.Errorf("failed to set up SSO: %w", err) - } - return nil + fmt.Println("No AWS SSO profiles found.") + fmt.Println("Run `awsctl sso setup` to create a new profile.") + fmt.Println("Run `awsctl sso setup -h` for help.") + var ErrNoProfiles = errors.New("no AWS SSO profiles found") + return ErrNoProfiles } awsProfile, err = c.Prompter.SelectFromList("Select AWS profile", profiles) diff --git a/internal/sso/utils.go b/internal/sso/utils.go index 98a0337..96b5a45 100644 --- a/internal/sso/utils.go +++ b/internal/sso/utils.go @@ -7,21 +7,21 @@ import ( "strings" ) -func validateAccountID(accountID string) error { +func ValidateAccountID(accountID string) error { if len(accountID) != 12 || !regexp.MustCompile(`^\d{12}$`).MatchString(accountID) { return fmt.Errorf("invalid account ID: %s (must be 12 digits)", accountID) } return nil } -func validateStartURL(startURL string) error { +func ValidateStartURL(startURL string) error { if !strings.HasPrefix(startURL, "https://") { return fmt.Errorf("invalid start URL: %s (must start with https://)", startURL) } return nil } -func printSummary(profileName, sessionName, ssoStartURL, ssoRegion, accountID, roleName, accountName, roleARN, expiration string) { +func PrintSummary(profileName, sessionName, ssoStartURL, ssoRegion, accountID, roleName, accountName, roleARN, expiration string) { fmt.Println("\nAWS SSO Configuration Summary:") fmt.Printf("Profile Name: %s\n", profileName) fmt.Printf("SSO Session: %s\n", sessionName) @@ -41,7 +41,7 @@ func printSummary(profileName, sessionName, ssoStartURL, ssoRegion, accountID, r } func (c *RealSSOClient) listSSOAccounts(region, startURL string) ([]string, error) { - accessToken, err := c.getAccessToken(startURL) + accessToken, err := c.GetAccessToken(startURL) if err != nil { return nil, fmt.Errorf("failed to get access token: %w", err) } @@ -91,7 +91,7 @@ func (c *RealSSOClient) listSSOAccounts(region, startURL string) ([]string, erro } func (c *RealSSOClient) listSSORoles(region, startURL, accountID string) ([]string, error) { - accessToken, err := c.getAccessToken(startURL) + accessToken, err := c.GetAccessToken(startURL) if err != nil { return nil, fmt.Errorf("failed to get access token: %w", err) } @@ -146,7 +146,7 @@ func (c *RealSSOClient) selectAccount(region, startURL string) (string, error) { } accountID := strings.SplitN(selectedAccount, " ", 2)[0] - if err := validateAccountID(accountID); err != nil { + if err := ValidateAccountID(accountID); err != nil { return "", err } return accountID, nil diff --git a/internal/sso/utils_test.go b/internal/sso/utils_test.go index 3d39def..dee6e92 100644 --- a/internal/sso/utils_test.go +++ b/internal/sso/utils_test.go @@ -1,8 +1,9 @@ -package sso +package sso_test import ( "testing" + "github.com/BerryBytes/awsctl/internal/sso" "github.com/stretchr/testify/assert" ) @@ -40,7 +41,7 @@ func TestValidateAccountID(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := validateAccountID(tt.accountID) + err := sso.ValidateAccountID(tt.accountID) if tt.expectError { assert.Error(t, err) if tt.errorContains != "" { @@ -87,7 +88,7 @@ func TestValidateStartURL(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := validateStartURL(tt.startURL) + err := sso.ValidateStartURL(tt.startURL) if tt.expectError { assert.Error(t, err) if tt.errorContains != "" { @@ -102,7 +103,7 @@ func TestValidateStartURL(t *testing.T) { func TestPrintSummary(t *testing.T) { t.Run("prints complete summary", func(t *testing.T) { - printSummary( + sso.PrintSummary( "my-profile", "my-session", "https://example.awsapps.com/start", @@ -116,7 +117,7 @@ func TestPrintSummary(t *testing.T) { }) t.Run("prints minimal summary", func(t *testing.T) { - printSummary( + sso.PrintSummary( "my-profile", "my-session", "https://example.awsapps.com/start", diff --git a/tests/mock/sso/sso.go b/tests/mock/sso/sso.go index bf48387..3be720d 100644 --- a/tests/mock/sso/sso.go +++ b/tests/mock/sso/sso.go @@ -8,6 +8,7 @@ import ( reflect "reflect" time "time" + sso "github.com/BerryBytes/awsctl/internal/sso" models "github.com/BerryBytes/awsctl/models" gomock "github.com/golang/mock/gomock" ) @@ -212,17 +213,17 @@ func (mr *MockSSOClientMockRecorder) SSOLogin(awsProfile, refresh, noBrowser int } // SetupSSO mocks base method. -func (m *MockSSOClient) SetupSSO() error { +func (m *MockSSOClient) SetupSSO(opts sso.SSOFlagOptions) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SetupSSO") + ret := m.ctrl.Call(m, "SetupSSO", opts) ret0, _ := ret[0].(error) return ret0 } // SetupSSO indicates an expected call of SetupSSO. -func (mr *MockSSOClientMockRecorder) SetupSSO() *gomock.Call { +func (mr *MockSSOClientMockRecorder) SetupSSO(opts interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetupSSO", reflect.TypeOf((*MockSSOClient)(nil).SetupSSO)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetupSSO", reflect.TypeOf((*MockSSOClient)(nil).SetupSSO), opts) } // ValidProfiles mocks base method. diff --git a/utils/general/general.go b/utils/general/general.go index 30192f9..29be5dd 100644 --- a/utils/general/general.go +++ b/utils/general/general.go @@ -115,3 +115,9 @@ func IsRegionValid(region string) bool { return isValidRegionFormat(region) } + +var validNameRegex = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9-_]{0,126}[a-zA-Z0-9]$`) + +func IsValidSessionName(name string) bool { + return validNameRegex.MatchString(name) +}