diff --git a/cmd/init.go b/cmd/init.go index 820925693..9773cfa7f 100644 --- a/cmd/init.go +++ b/cmd/init.go @@ -2,7 +2,7 @@ package cmd import ( "github.com/commitdev/zero/internal/config/projectconfig" - "github.com/commitdev/zero/internal/context" + initPrompts "github.com/commitdev/zero/internal/init" "github.com/spf13/cobra" ) @@ -14,7 +14,7 @@ var initCmd = &cobra.Command{ Use: "init", Short: "Create new project with provided name and initialize configuration based on user input.", Run: func(cmd *cobra.Command, args []string) { - projectContext := context.Init(projectconfig.RootDir) + projectContext := initPrompts.Init(projectconfig.RootDir) projectconfig.Init(projectconfig.RootDir, projectContext.Name, projectContext) }, } diff --git a/internal/init/debug.test b/internal/init/debug.test new file mode 100755 index 000000000..bd6f9e665 Binary files /dev/null and b/internal/init/debug.test differ diff --git a/internal/context/init.go b/internal/init/init.go similarity index 85% rename from internal/context/init.go rename to internal/init/init.go index 7fa8beecf..9e72b62b8 100644 --- a/internal/context/init.go +++ b/internal/init/init.go @@ -1,4 +1,4 @@ -package context +package init import ( "fmt" @@ -171,33 +171,78 @@ func getProjectPrompts(projectName string, modules map[string]moduleconfig.Modul return handlers } -func getCredentialPrompts(projectCredentials globalconfig.ProjectCredential, moduleConfigs map[string]moduleconfig.ModuleConfig) map[string][]PromptHandler { +func getCredentialPrompts(projectCredentials globalconfig.ProjectCredential, moduleConfigs map[string]moduleconfig.ModuleConfig) []CredentialPrompts { var uniqueVendors []string for _, module := range moduleConfigs { uniqueVendors = appendToSet(uniqueVendors, module.RequiredCredentials) } + // map is to keep track of which vendor they belong to, to fill them back into the projectConfig - prompts := map[string][]PromptHandler{} - for _, vendor := range uniqueVendors { - prompts[vendor] = mapVendorToPrompts(projectCredentials, vendor) + prompts := []CredentialPrompts{} + for _, vendor := range AvailableVendorOrders { + if itemInSlice(uniqueVendors, vendor) { + vendorPrompts := CredentialPrompts{vendor, mapVendorToPrompts(projectCredentials, vendor)} + prompts = append(prompts, vendorPrompts) + } } return prompts } func mapVendorToPrompts(projectCred globalconfig.ProjectCredential, vendor string) []PromptHandler { var prompts []PromptHandler + profiles, err := project.GetAWSProfiles() + if err != nil { + profiles = []string{} + } + + // if no profiles available, dont prompt use to pick profile + customAwsPickProfileCondition := func(param map[string]string) bool { + if len(profiles) == 0 { + flog.Infof(":warning: No AWS profiles found, please manually input AWS credentials") + return false + } else { + return true + } + } + + // condition for prompting manual AWS credentials input + customAwsMustInputCondition := func(param map[string]string) bool { + toPickProfile := awsPickProfile + if val, ok := param["use_aws_profile"]; ok && val != toPickProfile { + return true + } + return false + } switch vendor { case "aws": awsPrompts := []PromptHandler{ + { + moduleconfig.Parameter{ + Field: "use_aws_profile", + Label: "Use credentials from existing AWS profiles?", + Options: []string{awsPickProfile, awsManualInputCredentials}, + }, + customAwsPickProfileCondition, + NoValidation, + }, + { + moduleconfig.Parameter{ + Field: "aws_profile", + Label: "Select AWS Profile", + Options: profiles, + }, + KeyMatchCondition("use_aws_profile", awsPickProfile), + NoValidation, + }, { moduleconfig.Parameter{ Field: "accessKeyId", Label: "AWS Access Key ID", Default: projectCred.AWSResourceConfig.AccessKeyId, }, - NoCondition, - NoValidation, + CustomCondition(customAwsMustInputCondition), + project.ValidateAKID, }, { moduleconfig.Parameter{ @@ -205,8 +250,8 @@ func mapVendorToPrompts(projectCred globalconfig.ProjectCredential, vendor strin Label: "AWS Secret access key", Default: projectCred.AWSResourceConfig.SecretAccessKey, }, - NoCondition, - NoValidation, + CustomCondition(customAwsMustInputCondition), + project.ValidateSAK, }, } prompts = append(prompts, awsPrompts...) diff --git a/internal/context/prompts.go b/internal/init/prompts.go similarity index 74% rename from internal/context/prompts.go rename to internal/init/prompts.go index 02c890ae9..1a46b7cbe 100644 --- a/internal/context/prompts.go +++ b/internal/init/prompts.go @@ -1,4 +1,4 @@ -package context +package init import ( "fmt" @@ -10,27 +10,48 @@ import ( "github.com/commitdev/zero/internal/config/globalconfig" "github.com/commitdev/zero/internal/config/moduleconfig" + "github.com/commitdev/zero/pkg/credentials" "github.com/commitdev/zero/pkg/util/exit" "github.com/manifoldco/promptui" "gopkg.in/yaml.v2" ) +// Constant to maintain prompt orders so users can have the same flow, +// modules get downloaded asynchronously therefore its easier to just hardcode an order +var AvailableVendorOrders = []string{"aws", "github", "circleci"} + +const awsPickProfile = "Existing AWS Profiles" +const awsManualInputCredentials = "Enter my own AWS credentials" + type PromptHandler struct { moduleconfig.Parameter - Condition func(map[string]string) bool + Condition CustomConditionSignature Validate func(string) error } +type CredentialPrompts struct { + Vendor string + Prompts []PromptHandler +} + +type CustomConditionSignature func(map[string]string) bool + func NoCondition(map[string]string) bool { return true } -func KeyMatchCondition(key string, value string) func(map[string]string) bool { +func KeyMatchCondition(key string, value string) CustomConditionSignature { return func(param map[string]string) bool { return param[key] == value } } +func CustomCondition(fn CustomConditionSignature) CustomConditionSignature { + return func(param map[string]string) bool { + return fn(param) + } +} + func NoValidation(string) error { return nil } @@ -150,15 +171,16 @@ func PromptModuleParams(moduleConfig moduleconfig.ModuleConfig, parameters map[s return parameters, nil } -func promptCredentialsAndFillProjectCreds(credentialPrompts map[string][]PromptHandler, credentials globalconfig.ProjectCredential) globalconfig.ProjectCredential { +func promptCredentialsAndFillProjectCreds(credentialPrompts []CredentialPrompts, creds globalconfig.ProjectCredential) globalconfig.ProjectCredential { promptsValues := map[string]map[string]string{} - for vendor, prompts := range credentialPrompts { + for _, prompts := range credentialPrompts { + vendor := prompts.Vendor vendorPromptValues := map[string]string{} // vendors like AWS have multiple prompts (accessKeyId and secretAccessKey) - for _, prompt := range prompts { - vendorPromptValues[prompt.Field] = prompt.GetParam(map[string]string{}) + for _, prompt := range prompts.Prompts { + vendorPromptValues[prompt.Field] = prompt.GetParam(vendorPromptValues) } promptsValues[vendor] = vendorPromptValues } @@ -166,8 +188,15 @@ func promptCredentialsAndFillProjectCreds(credentialPrompts map[string][]PromptH // FIXME: what is a good way to dynamically modify partial data of a struct // current just marashing to yaml, then unmarshaling into the base struct yamlContent, _ := yaml.Marshal(promptsValues) - yaml.Unmarshal(yamlContent, &credentials) - return credentials + yaml.Unmarshal(yamlContent, &creds) + + // Fill AWS credentials based on profile from ~/.aws/credentials + if val, ok := promptsValues["aws"]; ok { + if val["use_aws_profile"] == awsPickProfile { + creds = credentials.GetAWSProfileProjectCredentials(val["aws_profile"], creds) + } + } + return creds } func appendToSet(set []string, toAppend []string) []string { diff --git a/internal/context/prompts_test.go b/internal/init/prompts_test.go similarity index 72% rename from internal/context/prompts_test.go rename to internal/init/prompts_test.go index e002979bc..434403a09 100644 --- a/internal/context/prompts_test.go +++ b/internal/init/prompts_test.go @@ -1,10 +1,11 @@ -package context_test +package init_test import ( "testing" "github.com/commitdev/zero/internal/config/moduleconfig" - "github.com/commitdev/zero/internal/context" + // init is a reserved word + initPrompts "github.com/commitdev/zero/internal/init" "github.com/stretchr/testify/assert" ) @@ -17,10 +18,10 @@ func TestGetParam(t *testing.T) { Execute: "echo \"my-acconut-id\"", } - prompt := context.PromptHandler{ + prompt := initPrompts.PromptHandler{ param, - context.NoCondition, - context.NoValidation, + initPrompts.NoCondition, + initPrompts.NoValidation, } result := prompt.GetParam(projectParams) @@ -33,10 +34,10 @@ func TestGetParam(t *testing.T) { Execute: "echo $INJECTEDENV", } - prompt := context.PromptHandler{ + prompt := initPrompts.PromptHandler{ param, - context.NoCondition, - context.NoValidation, + initPrompts.NoCondition, + initPrompts.NoValidation, } result := prompt.GetParam(map[string]string{ @@ -51,10 +52,10 @@ func TestGetParam(t *testing.T) { Value: "lorem-ipsum", } - prompt := context.PromptHandler{ + prompt := initPrompts.PromptHandler{ param, - context.NoCondition, - context.NoValidation, + initPrompts.NoCondition, + initPrompts.NoValidation, } result := prompt.GetParam(projectParams) diff --git a/pkg/credentials/credentials.go b/pkg/credentials/credentials.go index 060b4b44a..9389865f0 100644 --- a/pkg/credentials/credentials.go +++ b/pkg/credentials/credentials.go @@ -11,9 +11,9 @@ import ( "regexp" "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/commitdev/zero/internal/config/globalconfig" "github.com/commitdev/zero/internal/config/projectconfig" "github.com/manifoldco/promptui" - "gopkg.in/yaml.v2" ) // Secrets - AWS prompted credentials @@ -37,65 +37,27 @@ func MakeAwsEnvars(cfg *projectconfig.ZeroProjectConfig, awsSecrets Secrets) []s return env } -func GetSecrets(baseDir string) Secrets { - - secretsFile := filepath.Join(baseDir, "secrets.yaml") - if fileExists(secretsFile) { - log.Println("secrets.yaml exists ...") - return readSecrets(secretsFile) - } else { - // Get the user's home dir - usr, err := user.Current() - if err != nil { - log.Fatal(err) - } - credsFile := filepath.Join(usr.HomeDir, ".aws/credentials") - - var secrets Secrets - - // Load the credentials file to look for profiles - profiles, err := GetAWSProfiles() - if err == nil { - profilePrompt := promptui.Select{ - Label: "Select AWS Profile", - Items: profiles, - } - - _, profileResult, _ := profilePrompt.Run() - - creds, err := credentials.NewSharedCredentials(credsFile, profileResult).Get() - if err == nil { - secrets = Secrets{ - AWS: AWS{ - AccessKeyID: creds.AccessKeyID, - SecretAccessKey: creds.SecretAccessKey, - }, - } - } - } - - // We couldn't load the credentials file, get the user to just paste them - if secrets.AWS == (AWS{}) { - promptAWSCredentials(&secrets) - } - - if secrets.CircleCIKey == "" || secrets.GithubToken == "" { - ciPrompt := promptui.Select{ - Label: "Which Continuous integration provider do you want to use?", - Items: []string{"CircleCI", "GitHub Actions"}, - } - - _, ciResult, _ := ciPrompt.Run() +func AwsCredsPath() string { + usr, err := user.Current() + if err != nil { + log.Fatal(err) + } + return filepath.Join(usr.HomeDir, ".aws/credentials") +} - if ciResult == "CircleCI" { - promptCircleCICredentials(&secrets) - } else if ciResult == "GitHub Actions" { - promptGitHubCredentials(&secrets) - } - } +func GetAWSProfileProjectCredentials(profileName string, creds globalconfig.ProjectCredential) globalconfig.ProjectCredential { + awsPath := AwsCredsPath() + return GetAWSProfileCredentials(awsPath, profileName, creds) +} - return secrets +func GetAWSProfileCredentials(credsPath string, profileName string, creds globalconfig.ProjectCredential) globalconfig.ProjectCredential { + awsCreds, err := credentials.NewSharedCredentials(credsPath, profileName).Get() + if err != nil { + log.Fatal(err) } + creds.AWSResourceConfig.AccessKeyId = awsCreds.AccessKeyID + creds.AWSResourceConfig.SecretAccessKey = awsCreds.SecretAccessKey + return creds } // GetAWSProfiles returns a list of AWS forprofiles set up on the user's sytem @@ -121,61 +83,28 @@ func GetAWSProfiles() ([]string, error) { return profiles, nil } -func readSecrets(secretsFile string) Secrets { - data, err := ioutil.ReadFile(secretsFile) - if err != nil { - log.Fatalln(err) +func ValidateAKID(input string) error { + // 20 uppercase alphanumeric characters + var awsAccessKeyIDPat = regexp.MustCompile(`^[A-Z0-9]{20}$`) + if !awsAccessKeyIDPat.MatchString(input) { + return errors.New("Invalid aws_access_key_id") } - - awsSecrets := Secrets{} - - err = yaml.Unmarshal(data, &awsSecrets) - if err != nil { - log.Fatalln(err) - } - - return awsSecrets + return nil } -func writeSecrets(secretsFile string, s Secrets) { - secretsYaml, err := yaml.Marshal(&s) - - if err != nil { - log.Fatalf("error: %v", err) - panic(err) - } - - err = ioutil.WriteFile(secretsFile, []byte(secretsYaml), 0644) - - if err != nil { - log.Fatalf("error: %v", err) - panic(err) +func ValidateSAK(input string) error { + // 40 base64 characters + var awsSecretAccessKeyPat = regexp.MustCompile(`^[A-Za-z0-9/+=]{40}$`) + if !awsSecretAccessKeyPat.MatchString(input) { + return errors.New("Invalid aws_secret_access_key") } + return nil } func promptAWSCredentials(secrets *Secrets) { - - validateAKID := func(input string) error { - // 20 uppercase alphanumeric characters - var awsAccessKeyIDPat = regexp.MustCompile(`^[A-Z0-9]{20}$`) - if !awsAccessKeyIDPat.MatchString(input) { - return errors.New("Invalid aws_access_key_id") - } - return nil - } - - validateSAK := func(input string) error { - // 40 base64 characters - var awsSecretAccessKeyPat = regexp.MustCompile(`^[A-Za-z0-9/+=]{40}$`) - if !awsSecretAccessKeyPat.MatchString(input) { - return errors.New("Invalid aws_secret_access_key") - } - return nil - } - accessKeyIDPrompt := promptui.Prompt{ Label: "Aws Access Key ID ", - Validate: validateAKID, + Validate: ValidateAKID, } accessKeyIDResult, err := accessKeyIDPrompt.Run() @@ -187,7 +116,7 @@ func promptAWSCredentials(secrets *Secrets) { secretAccessKeyPrompt := promptui.Prompt{ Label: "Aws Secret Access Key ", - Validate: validateSAK, + Validate: ValidateSAK, Mask: '*', } diff --git a/pkg/credentials/credentials_test.go b/pkg/credentials/credentials_test.go new file mode 100644 index 000000000..9386c993a --- /dev/null +++ b/pkg/credentials/credentials_test.go @@ -0,0 +1,26 @@ +package credentials_test + +import ( + "testing" + + "github.com/commitdev/zero/internal/config/globalconfig" + "github.com/commitdev/zero/pkg/credentials" + "github.com/stretchr/testify/assert" +) + +func TestFillAWSProfileCredentials(t *testing.T) { + mockAwsCredentialFilePath := "../../tests/test_data/aws/mock_credentials.yml" + t.Run("fills project credentials", func(t *testing.T) { + projectCreds := globalconfig.ProjectCredential{} + projectCreds = credentials.GetAWSProfileCredentials(mockAwsCredentialFilePath, "default", projectCreds) + assert.Equal(t, "MOCK1_ACCESS_KEY", projectCreds.AWSResourceConfig.AccessKeyId) + assert.Equal(t, "MOCK1_SECRET_ACCESS_KEY", projectCreds.AWSResourceConfig.SecretAccessKey) + }) + + t.Run("supports non-default profiles", func(t *testing.T) { + projectCreds := globalconfig.ProjectCredential{} + projectCreds = credentials.GetAWSProfileCredentials(mockAwsCredentialFilePath, "foobar", projectCreds) + assert.Equal(t, "MOCK2_ACCESS_KEY", projectCreds.AWSResourceConfig.AccessKeyId) + assert.Equal(t, "MOCK2_SECRET_ACCESS_KEY", projectCreds.AWSResourceConfig.SecretAccessKey) + }) +} diff --git a/tests/test_data/aws/mock_credentials.yml b/tests/test_data/aws/mock_credentials.yml new file mode 100644 index 000000000..b2642dce1 --- /dev/null +++ b/tests/test_data/aws/mock_credentials.yml @@ -0,0 +1,7 @@ +[default] +aws_access_key_id=MOCK1_ACCESS_KEY +aws_secret_access_key=MOCK1_SECRET_ACCESS_KEY + +[foobar] +aws_access_key_id=MOCK2_ACCESS_KEY +aws_secret_access_key=MOCK2_SECRET_ACCESS_KEY