Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 90 additions & 0 deletions internal/sso/profile.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"os"
"path/filepath"
"strings"
"time"
)

func (c *RealSSOClient) ConfigureSSOProfile(profile, region, accountID, role, ssoStartUrl, ssoSession string) error {
Expand Down Expand Up @@ -160,3 +161,92 @@ func (c *RealSSOClient) promptProfileDetails(ssoRegion string) (string, string,
}
return profileName, region, nil
}

func (c *RealSSOClient) setProfileAsDefault(profile string) error {
sessionName, err := c.ConfigureGet("sso_session", profile)
if err != nil {
return fmt.Errorf("failed to get sso_session: %w", err)
}

ssoStartURL, err := c.ConfigureGet("sso_start_url", profile)
if err != nil {
return fmt.Errorf("failed to get sso_start_url: %w", err)
}

ssoRegion, err := c.ConfigureGet("sso_region", profile)
if err != nil {
return fmt.Errorf("failed to get sso_region: %w", err)
}

accountID, err := c.ConfigureGet("sso_account_id", profile)
if err != nil {
return fmt.Errorf("failed to get account ID: %w", err)
}

roleName, err := c.ConfigureGet("sso_role_name", profile)
if err != nil {
return fmt.Errorf("failed to get role name: %w", err)
}

region, err := c.ConfigureGet("region", profile)
if err != nil {
region = ssoRegion
}

if err := c.configureAWSProfile("default", sessionName, ssoRegion, ssoStartURL, accountID, roleName, region); err != nil {
return fmt.Errorf("failed to configure AWS default profile: %w", err)
}
fmt.Println("Successfully set this profile as default!")
return nil
}

func (c *RealSSOClient) printProfileSummary(profile string) error {
sessionName, err := c.ConfigureGet("sso_session", profile)
if err != nil {
return fmt.Errorf("failed to get sso_session: %w", err)
}

ssoStartURL, err := c.ConfigureGet("sso_start_url", profile)
if err != nil {
return fmt.Errorf("failed to get sso_start_url: %w", err)
}

ssoRegion, err := c.ConfigureGet("sso_region", profile)
if err != nil {
return fmt.Errorf("failed to get sso_region: %w", err)
}

accountID, err := c.ConfigureGet("sso_account_id", profile)
if err != nil {
return fmt.Errorf("failed to get account ID: %w", err)
}

roleName, err := c.ConfigureGet("sso_role_name", profile)
if err != nil {
return fmt.Errorf("failed to get role name: %w", err)
}

roleARN, err := c.AwsSTSGetCallerIdentity(profile)
if err != nil {
return fmt.Errorf("failed to get role ARN: %w", err)
}

accountName, err := c.GetSSOAccountName(accountID, profile)
if err != nil {
accountName = "Unknown"
fmt.Printf("Warning: Failed to get account name: %v\n", err)
}

_, expiry, err := c.GetCachedSsoAccessToken(profile)
if err != nil {
return fmt.Errorf("failed to get token expiry: %w", err)
}

var expiration string
if !expiry.IsZero() {
expiration = expiry.Format(time.RFC3339)
}

printSummary(profile, sessionName, ssoStartURL, ssoRegion, accountID, roleName, accountName, roleARN, expiration)
return nil
}
94 changes: 5 additions & 89 deletions internal/sso/sso.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"errors"
"fmt"
"slices"
"time"

promptUtils "github.com/BerryBytes/awsctl/utils/prompt"
)
Expand Down Expand Up @@ -99,48 +98,8 @@ func (c *RealSSOClient) InitSSO(refresh, noBrowser bool) error {
}

if awsProfile != "default" {
setDefault, err := c.Prompter.PromptYesNo("Set this as the default profile? [Y/n]", true)
if err != nil {
if errors.Is(err, promptUtils.ErrInterrupted) {
return nil
}
return fmt.Errorf("failed to prompt for default profile: %w", err)
}
if setDefault {
sessionName, err := c.ConfigureGet("sso_session", awsProfile)
if err != nil {
return fmt.Errorf("failed to get sso_session: %w", err)
}

ssoStartURL, err := c.ConfigureGet("sso_start_url", awsProfile)
if err != nil {
return fmt.Errorf("failed to get sso_start_url: %w", err)
}

ssoRegion, err := c.ConfigureGet("sso_region", awsProfile)
if err != nil {
return fmt.Errorf("failed to get sso_region: %w", err)
}

accountID, err := c.ConfigureGet("sso_account_id", awsProfile)
if err != nil {
return fmt.Errorf("failed to get account ID: %w", err)
}

roleName, err := c.ConfigureGet("sso_role_name", awsProfile)
if err != nil {
return fmt.Errorf("failed to get role name: %w", err)
}

region, err := c.ConfigureGet("region", awsProfile)
if err != nil {
region = ssoRegion
}

if err := c.configureAWSProfile("default", sessionName, ssoRegion, ssoStartURL, accountID, roleName, region); err != nil {
return fmt.Errorf("failed to configure AWS default profile: %w", err)
}
fmt.Println("Successfully set this profile as default!")
if err := c.setProfileAsDefault(awsProfile); err != nil {
return err
}
}
}
Expand All @@ -149,59 +108,16 @@ func (c *RealSSOClient) InitSSO(refresh, noBrowser bool) error {
return fmt.Errorf("invalid profile: %s", awsProfile)
}

var expiration string
_, expiry, err := c.GetCachedSsoAccessToken(awsProfile)
if err != nil {
if _, _, err := c.GetCachedSsoAccessToken(awsProfile); err != nil {
fmt.Printf("SSO token expired or missing for profile %s. Logging in...\n", awsProfile)
if err := c.SSOLogin(awsProfile, refresh, noBrowser); err != nil {
return fmt.Errorf("failed to login: %w", err)
}
_, expiry, err = c.GetCachedSsoAccessToken(awsProfile)
if err != nil {
if _, _, err = c.GetCachedSsoAccessToken(awsProfile); err != nil {
return fmt.Errorf("failed to get SSO token after login: %w", err)
}
}
if !expiry.IsZero() {
expiration = expiry.Format(time.RFC3339)
}
fmt.Printf("SSO token validated for profile %s\n", awsProfile)

sessionName, err := c.ConfigureGet("sso_session", awsProfile)
if err != nil {
return fmt.Errorf("failed to get sso_session: %w", err)
}

ssoStartURL, err := c.ConfigureGet("sso_start_url", awsProfile)
if err != nil {
return fmt.Errorf("failed to get sso_start_url: %w", err)
}

ssoRegion, err := c.ConfigureGet("sso_region", awsProfile)
if err != nil {
return fmt.Errorf("failed to get sso_region: %w", err)
}

accountID, err := c.ConfigureGet("sso_account_id", awsProfile)
if err != nil {
return fmt.Errorf("failed to get account ID: %w", err)
}

roleName, err := c.ConfigureGet("sso_role_name", awsProfile)
if err != nil {
return fmt.Errorf("failed to get role name: %w", err)
}

roleARN, err := c.AwsSTSGetCallerIdentity(awsProfile)
if err != nil {
return fmt.Errorf("failed to get role ARN: %w", err)
}

accountName, err := c.GetSSOAccountName(accountID, awsProfile)
if err != nil {
accountName = "Unknown"
fmt.Printf("Warning: Failed to get account name: %v\n", err)
}

printSummary(awsProfile, sessionName, ssoStartURL, ssoRegion, accountID, roleName, accountName, roleARN, expiration)
return nil
return c.printProfileSummary(awsProfile)
}
Loading