diff --git a/builtin/logical/aws/backend_test.go b/builtin/logical/aws/backend_test.go index 5c372e513f2..19dad90889b 100644 --- a/builtin/logical/aws/backend_test.go +++ b/builtin/logical/aws/backend_test.go @@ -207,7 +207,11 @@ func getAccountID() (string, error) { Region: aws.String("us-east-1"), HTTPClient: cleanhttp.DefaultClient(), } - svc := sts.New(session.New(awsConfig)) + sess, err := session.NewSession(awsConfig) + if err != nil { + return "", err + } + svc := sts.New(sess) params := &sts.GetCallerIdentityInput{} res, err := svc.GetCallerIdentity(params) @@ -240,7 +244,11 @@ func createRole(t *testing.T, roleName, awsAccountID string, policyARNs []string Region: aws.String("us-east-1"), HTTPClient: cleanhttp.DefaultClient(), } - svc := iam.New(session.New(awsConfig)) + sess, err := session.NewSession(awsConfig) + if err != nil { + t.Fatal(err) + } + svc := iam.New(sess) trustPolicy := fmt.Sprintf(testRoleAssumePolicy, awsAccountID) params := &iam.CreateRoleInput{ @@ -250,9 +258,7 @@ func createRole(t *testing.T, roleName, awsAccountID string, policyARNs []string } log.Printf("[INFO] AWS CreateRole: %s", roleName) - _, err := svc.CreateRole(params) - - if err != nil { + if _, err := svc.CreateRole(params); err != nil { t.Fatalf("AWS CreateRole failed: %v", err) } @@ -303,14 +309,16 @@ func createUser(t *testing.T, userName string, accessKey *awsAccessKey) { Region: aws.String("us-east-1"), HTTPClient: cleanhttp.DefaultClient(), } - svc := iam.New(session.New(awsConfig)) - + sess, err := session.NewSession(awsConfig) + if err != nil { + t.Fatal(err) + } + svc := iam.New(sess) createUserInput := &iam.CreateUserInput{ UserName: aws.String(userName), } log.Printf("[INFO] AWS CreateUser: %s", userName) - _, err := svc.CreateUser(createUserInput) - if err != nil { + if _, err := svc.CreateUser(createUserInput); err != nil { t.Fatalf("AWS CreateUser failed: %v", err) } @@ -354,8 +362,11 @@ func deleteTestRole(roleName string) error { Region: aws.String("us-east-1"), HTTPClient: cleanhttp.DefaultClient(), } - svc := iam.New(session.New(awsConfig)) - + sess, err := session.NewSession(awsConfig) + if err != nil { + return err + } + svc := iam.New(sess) listAttachmentsInput := &iam.ListAttachedRolePoliciesInput{ RoleName: aws.String(roleName), } @@ -372,8 +383,7 @@ func deleteTestRole(roleName string) error { } return true } - err := svc.ListAttachedRolePoliciesPages(listAttachmentsInput, detacher) - if err != nil { + if err := svc.ListAttachedRolePoliciesPages(listAttachmentsInput, detacher); err != nil { log.Printf("[WARN] AWS DetachRolePolicy failed: %v", err) } @@ -396,14 +406,16 @@ func deleteTestUser(accessKey *awsAccessKey, userName string) error { Region: aws.String("us-east-1"), HTTPClient: cleanhttp.DefaultClient(), } - svc := iam.New(session.New(awsConfig)) - + sess, err := session.NewSession(awsConfig) + if err != nil { + return err + } + svc := iam.New(sess) userDetachment := &iam.DetachUserPolicyInput{ PolicyArn: aws.String("arn:aws:iam::aws:policy/AdministratorAccess"), UserName: aws.String(userName), } - _, err := svc.DetachUserPolicy(userDetachment) - if err != nil { + if _, err := svc.DetachUserPolicy(userDetachment); err != nil { log.Printf("[WARN] AWS DetachUserPolicy failed: %v", err) return err } @@ -490,10 +502,13 @@ func testAccStepRotateRoot(oldAccessKey *awsAccessKey) logicaltest.TestStep { oldAccessKey.AccessKeyID = newAccessKeyID log.Println("[WARN] Sleeping for 10 seconds waiting for AWS...") time.Sleep(10 * time.Second) - svc := sts.New(session.New(awsConfig)) + sess, err := session.NewSession(awsConfig) + if err != nil { + return err + } + svc := sts.New(sess) params := &sts.GetCallerIdentityInput{} - _, err := svc.GetCallerIdentity(params) - if err == nil { + if _, err := svc.GetCallerIdentity(params); err == nil { return fmt.Errorf("bad: old credentials succeeded after rotate") } if aerr, ok := err.(awserr.Error); ok { @@ -556,7 +571,11 @@ func describeInstancesTest(accessKey, secretKey, token string) error { Region: aws.String("us-east-1"), HTTPClient: cleanhttp.DefaultClient(), } - client := ec2.New(session.New(awsConfig)) + sess, err := session.NewSession(awsConfig) + if err != nil { + return err + } + client := ec2.New(sess) log.Printf("[WARN] Verifying that the generated credentials work with ec2:DescribeInstances...") return retryUntilSuccess(func() error { _, err := client.DescribeInstances(&ec2.DescribeInstancesInput{}) @@ -571,7 +590,11 @@ func describeAzsTestUnauthorized(accessKey, secretKey, token string) error { Region: aws.String("us-east-1"), HTTPClient: cleanhttp.DefaultClient(), } - client := ec2.New(session.New(awsConfig)) + sess, err := session.NewSession(awsConfig) + if err != nil { + return err + } + client := ec2.New(sess) log.Printf("[WARN] Verifying that the generated credentials don't work with ec2:DescribeAvailabilityZones...") return retryUntilSuccess(func() error { _, err := client.DescribeAvailabilityZones(&ec2.DescribeAvailabilityZonesInput{}) @@ -595,7 +618,11 @@ func assertCreatedIAMUser(accessKey, secretKey, token string) error { Region: aws.String("us-east-1"), HTTPClient: cleanhttp.DefaultClient(), } - client := iam.New(session.New(awsConfig)) + sess, err := session.NewSession(awsConfig) + if err != nil { + return err + } + client := iam.New(sess) log.Printf("[WARN] Checking if IAM User is created properly...") userOutput, err := client.GetUser(&iam.GetUserInput{}) if err != nil { @@ -616,7 +643,11 @@ func listIamUsersTest(accessKey, secretKey, token string) error { Region: aws.String("us-east-1"), HTTPClient: cleanhttp.DefaultClient(), } - client := iam.New(session.New(awsConfig)) + sess, err := session.NewSession(awsConfig) + if err != nil { + return err + } + client := iam.New(sess) log.Printf("[WARN] Verifying that the generated credentials work with iam:ListUsers...") return retryUntilSuccess(func() error { _, err := client.ListUsers(&iam.ListUsersInput{}) @@ -631,7 +662,11 @@ func listDynamoTablesTest(accessKey, secretKey, token string) error { Region: aws.String("us-east-1"), HTTPClient: cleanhttp.DefaultClient(), } - client := dynamodb.New(session.New(awsConfig)) + sess, err := session.NewSession(awsConfig) + if err != nil { + return err + } + client := dynamodb.New(sess) log.Printf("[WARN] Verifying that the generated credentials work with dynamodb:ListTables...") return retryUntilSuccess(func() error { _, err := client.ListTables(&dynamodb.ListTablesInput{}) diff --git a/builtin/logical/aws/client.go b/builtin/logical/aws/client.go index 2575289653f..f37ce100823 100644 --- a/builtin/logical/aws/client.go +++ b/builtin/logical/aws/client.go @@ -74,9 +74,11 @@ func nonCachedClientIAM(ctx context.Context, s logical.Storage) (*iam.IAM, error if err != nil { return nil, err } - - client := iam.New(session.New(awsConfig)) - + sess, err := session.NewSession(awsConfig) + if err != nil { + return nil, err + } + client := iam.New(sess) if client == nil { return nil, fmt.Errorf("could not obtain iam client") } @@ -88,8 +90,11 @@ func nonCachedClientSTS(ctx context.Context, s logical.Storage) (*sts.STS, error if err != nil { return nil, err } - client := sts.New(session.New(awsConfig)) - + sess, err := session.NewSession(awsConfig) + if err != nil { + return nil, err + } + client := sts.New(sess) if client == nil { return nil, fmt.Errorf("could not obtain sts client") } diff --git a/physical/s3/s3.go b/physical/s3/s3.go index b131678d47b..c7a6058dbb5 100644 --- a/physical/s3/s3.go +++ b/physical/s3/s3.go @@ -111,7 +111,7 @@ func NewS3Backend(conf map[string]string, logger log.Logger) (physical.Backend, pooledTransport := cleanhttp.DefaultPooledTransport() pooledTransport.MaxIdleConnsPerHost = consts.ExpirationRestoreWorkerCount - s3conn := s3.New(session.New(&aws.Config{ + sess, err := session.NewSession(&aws.Config{ Credentials: creds, HTTPClient: &http.Client{ Transport: pooledTransport, @@ -120,7 +120,11 @@ func NewS3Backend(conf map[string]string, logger log.Logger) (physical.Backend, Region: aws.String(region), S3ForcePathStyle: aws.Bool(s3ForcePathStyleBool), DisableSSL: aws.Bool(disableSSLBool), - })) + }) + if err != nil { + return nil, err + } + s3conn := s3.New(sess) _, err = s3conn.ListObjects(&s3.ListObjectsInput{Bucket: &bucket}) if err != nil { diff --git a/physical/s3/s3_test.go b/physical/s3/s3_test.go index a7a72455419..f29a6ea2891 100644 --- a/physical/s3/s3_test.go +++ b/physical/s3/s3_test.go @@ -7,14 +7,13 @@ import ( "testing" "time" + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/s3" log "github.com/hashicorp/go-hclog" "github.com/hashicorp/vault/helper/awsutil" "github.com/hashicorp/vault/sdk/helper/logging" "github.com/hashicorp/vault/sdk/physical" - - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/s3" ) func TestDefaultS3Backend(t *testing.T) { @@ -51,11 +50,15 @@ func DoS3BackendTest(t *testing.T, kmsKeyId string) { region = "us-east-1" } - s3conn := s3.New(session.New(&aws.Config{ + sess, err := session.NewSession(&aws.Config{ Credentials: credsChain, Endpoint: aws.String(endpoint), Region: aws.String(region), - })) + }) + if err != nil { + t.Fatal(err) + } + s3conn := s3.New(sess) var randInt = rand.New(rand.NewSource(time.Now().UnixNano())).Int() bucket := fmt.Sprintf("vault-s3-testacc-%d", randInt)