From 99578766b8ba3d1eaa1053d6f1998ef9f0627be6 Mon Sep 17 00:00:00 2001 From: sawnjordan Date: Thu, 3 Jul 2025 15:03:14 +0545 Subject: [PATCH] feat: removed prompt to set profile as default on sso init --- internal/sso/profile.go | 90 +++++++++++++++++++++++++++++++++++++++ internal/sso/sso.go | 94 +++-------------------------------------- 2 files changed, 95 insertions(+), 89 deletions(-) diff --git a/internal/sso/profile.go b/internal/sso/profile.go index 8be414f..bbd2be5 100644 --- a/internal/sso/profile.go +++ b/internal/sso/profile.go @@ -5,6 +5,7 @@ import ( "os" "path/filepath" "strings" + "time" ) func (c *RealSSOClient) ConfigureSSOProfile(profile, region, accountID, role, ssoStartUrl, ssoSession string) error { @@ -160,3 +161,92 @@ func (c *RealSSOClient) promptProfileDetails(ssoRegion string) (string, string, } return profileName, region, nil } + +func (c *RealSSOClient) setProfileAsDefault(profile string) error { + sessionName, err := c.ConfigureGet("sso_session", profile) + if err != nil { + return fmt.Errorf("failed to get sso_session: %w", err) + } + + ssoStartURL, err := c.ConfigureGet("sso_start_url", profile) + if err != nil { + return fmt.Errorf("failed to get sso_start_url: %w", err) + } + + ssoRegion, err := c.ConfigureGet("sso_region", profile) + if err != nil { + return fmt.Errorf("failed to get sso_region: %w", err) + } + + accountID, err := c.ConfigureGet("sso_account_id", profile) + if err != nil { + return fmt.Errorf("failed to get account ID: %w", err) + } + + roleName, err := c.ConfigureGet("sso_role_name", profile) + if err != nil { + return fmt.Errorf("failed to get role name: %w", err) + } + + region, err := c.ConfigureGet("region", profile) + if err != nil { + region = ssoRegion + } + + if err := c.configureAWSProfile("default", sessionName, ssoRegion, ssoStartURL, accountID, roleName, region); err != nil { + return fmt.Errorf("failed to configure AWS default profile: %w", err) + } + fmt.Println("Successfully set this profile as default!") + return nil +} + +func (c *RealSSOClient) printProfileSummary(profile string) error { + sessionName, err := c.ConfigureGet("sso_session", profile) + if err != nil { + return fmt.Errorf("failed to get sso_session: %w", err) + } + + ssoStartURL, err := c.ConfigureGet("sso_start_url", profile) + if err != nil { + return fmt.Errorf("failed to get sso_start_url: %w", err) + } + + ssoRegion, err := c.ConfigureGet("sso_region", profile) + if err != nil { + return fmt.Errorf("failed to get sso_region: %w", err) + } + + accountID, err := c.ConfigureGet("sso_account_id", profile) + if err != nil { + return fmt.Errorf("failed to get account ID: %w", err) + } + + roleName, err := c.ConfigureGet("sso_role_name", profile) + if err != nil { + return fmt.Errorf("failed to get role name: %w", err) + } + + roleARN, err := c.AwsSTSGetCallerIdentity(profile) + if err != nil { + return fmt.Errorf("failed to get role ARN: %w", err) + } + + accountName, err := c.GetSSOAccountName(accountID, profile) + if err != nil { + accountName = "Unknown" + fmt.Printf("Warning: Failed to get account name: %v\n", err) + } + + _, expiry, err := c.GetCachedSsoAccessToken(profile) + if err != nil { + return fmt.Errorf("failed to get token expiry: %w", err) + } + + var expiration string + if !expiry.IsZero() { + expiration = expiry.Format(time.RFC3339) + } + + printSummary(profile, sessionName, ssoStartURL, ssoRegion, accountID, roleName, accountName, roleARN, expiration) + return nil +} diff --git a/internal/sso/sso.go b/internal/sso/sso.go index 2809a66..63c1def 100644 --- a/internal/sso/sso.go +++ b/internal/sso/sso.go @@ -4,7 +4,6 @@ import ( "errors" "fmt" "slices" - "time" promptUtils "github.com/BerryBytes/awsctl/utils/prompt" ) @@ -99,48 +98,8 @@ func (c *RealSSOClient) InitSSO(refresh, noBrowser bool) error { } if awsProfile != "default" { - setDefault, err := c.Prompter.PromptYesNo("Set this as the default profile? [Y/n]", true) - if err != nil { - if errors.Is(err, promptUtils.ErrInterrupted) { - return nil - } - return fmt.Errorf("failed to prompt for default profile: %w", err) - } - if setDefault { - sessionName, err := c.ConfigureGet("sso_session", awsProfile) - if err != nil { - return fmt.Errorf("failed to get sso_session: %w", err) - } - - ssoStartURL, err := c.ConfigureGet("sso_start_url", awsProfile) - if err != nil { - return fmt.Errorf("failed to get sso_start_url: %w", err) - } - - ssoRegion, err := c.ConfigureGet("sso_region", awsProfile) - if err != nil { - return fmt.Errorf("failed to get sso_region: %w", err) - } - - accountID, err := c.ConfigureGet("sso_account_id", awsProfile) - if err != nil { - return fmt.Errorf("failed to get account ID: %w", err) - } - - roleName, err := c.ConfigureGet("sso_role_name", awsProfile) - if err != nil { - return fmt.Errorf("failed to get role name: %w", err) - } - - region, err := c.ConfigureGet("region", awsProfile) - if err != nil { - region = ssoRegion - } - - if err := c.configureAWSProfile("default", sessionName, ssoRegion, ssoStartURL, accountID, roleName, region); err != nil { - return fmt.Errorf("failed to configure AWS default profile: %w", err) - } - fmt.Println("Successfully set this profile as default!") + if err := c.setProfileAsDefault(awsProfile); err != nil { + return err } } } @@ -149,59 +108,16 @@ func (c *RealSSOClient) InitSSO(refresh, noBrowser bool) error { return fmt.Errorf("invalid profile: %s", awsProfile) } - var expiration string - _, expiry, err := c.GetCachedSsoAccessToken(awsProfile) - if err != nil { + if _, _, err := c.GetCachedSsoAccessToken(awsProfile); err != nil { fmt.Printf("SSO token expired or missing for profile %s. Logging in...\n", awsProfile) if err := c.SSOLogin(awsProfile, refresh, noBrowser); err != nil { return fmt.Errorf("failed to login: %w", err) } - _, expiry, err = c.GetCachedSsoAccessToken(awsProfile) - if err != nil { + if _, _, err = c.GetCachedSsoAccessToken(awsProfile); err != nil { return fmt.Errorf("failed to get SSO token after login: %w", err) } } - if !expiry.IsZero() { - expiration = expiry.Format(time.RFC3339) - } fmt.Printf("SSO token validated for profile %s\n", awsProfile) - sessionName, err := c.ConfigureGet("sso_session", awsProfile) - if err != nil { - return fmt.Errorf("failed to get sso_session: %w", err) - } - - ssoStartURL, err := c.ConfigureGet("sso_start_url", awsProfile) - if err != nil { - return fmt.Errorf("failed to get sso_start_url: %w", err) - } - - ssoRegion, err := c.ConfigureGet("sso_region", awsProfile) - if err != nil { - return fmt.Errorf("failed to get sso_region: %w", err) - } - - accountID, err := c.ConfigureGet("sso_account_id", awsProfile) - if err != nil { - return fmt.Errorf("failed to get account ID: %w", err) - } - - roleName, err := c.ConfigureGet("sso_role_name", awsProfile) - if err != nil { - return fmt.Errorf("failed to get role name: %w", err) - } - - roleARN, err := c.AwsSTSGetCallerIdentity(awsProfile) - if err != nil { - return fmt.Errorf("failed to get role ARN: %w", err) - } - - accountName, err := c.GetSSOAccountName(accountID, awsProfile) - if err != nil { - accountName = "Unknown" - fmt.Printf("Warning: Failed to get account name: %v\n", err) - } - - printSummary(awsProfile, sessionName, ssoStartURL, ssoRegion, accountID, roleName, accountName, roleARN, expiration) - return nil + return c.printProfileSummary(awsProfile) }