Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 60 additions & 25 deletions builtin/logical/aws/backend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,11 @@ func getAccountID() (string, error) {
Region: aws.String("us-east-1"),
HTTPClient: cleanhttp.DefaultClient(),
}
svc := sts.New(session.New(awsConfig))
sess, err := session.NewSession(awsConfig)
if err != nil {
return "", err
}
svc := sts.New(sess)

params := &sts.GetCallerIdentityInput{}
res, err := svc.GetCallerIdentity(params)
Expand Down Expand Up @@ -240,7 +244,11 @@ func createRole(t *testing.T, roleName, awsAccountID string, policyARNs []string
Region: aws.String("us-east-1"),
HTTPClient: cleanhttp.DefaultClient(),
}
svc := iam.New(session.New(awsConfig))
sess, err := session.NewSession(awsConfig)
if err != nil {
t.Fatal(err)
}
svc := iam.New(sess)
trustPolicy := fmt.Sprintf(testRoleAssumePolicy, awsAccountID)

params := &iam.CreateRoleInput{
Expand All @@ -250,9 +258,7 @@ func createRole(t *testing.T, roleName, awsAccountID string, policyARNs []string
}

log.Printf("[INFO] AWS CreateRole: %s", roleName)
_, err := svc.CreateRole(params)

if err != nil {
if _, err := svc.CreateRole(params); err != nil {
t.Fatalf("AWS CreateRole failed: %v", err)
}

Expand Down Expand Up @@ -303,14 +309,16 @@ func createUser(t *testing.T, userName string, accessKey *awsAccessKey) {
Region: aws.String("us-east-1"),
HTTPClient: cleanhttp.DefaultClient(),
}
svc := iam.New(session.New(awsConfig))

sess, err := session.NewSession(awsConfig)
if err != nil {
t.Fatal(err)
}
svc := iam.New(sess)
createUserInput := &iam.CreateUserInput{
UserName: aws.String(userName),
}
log.Printf("[INFO] AWS CreateUser: %s", userName)
_, err := svc.CreateUser(createUserInput)
if err != nil {
if _, err := svc.CreateUser(createUserInput); err != nil {
t.Fatalf("AWS CreateUser failed: %v", err)
}

Expand Down Expand Up @@ -354,8 +362,11 @@ func deleteTestRole(roleName string) error {
Region: aws.String("us-east-1"),
HTTPClient: cleanhttp.DefaultClient(),
}
svc := iam.New(session.New(awsConfig))

sess, err := session.NewSession(awsConfig)
if err != nil {
return err
}
svc := iam.New(sess)
listAttachmentsInput := &iam.ListAttachedRolePoliciesInput{
RoleName: aws.String(roleName),
}
Expand All @@ -372,8 +383,7 @@ func deleteTestRole(roleName string) error {
}
return true
}
err := svc.ListAttachedRolePoliciesPages(listAttachmentsInput, detacher)
if err != nil {
if err := svc.ListAttachedRolePoliciesPages(listAttachmentsInput, detacher); err != nil {
log.Printf("[WARN] AWS DetachRolePolicy failed: %v", err)
}

Expand All @@ -396,14 +406,16 @@ func deleteTestUser(accessKey *awsAccessKey, userName string) error {
Region: aws.String("us-east-1"),
HTTPClient: cleanhttp.DefaultClient(),
}
svc := iam.New(session.New(awsConfig))

sess, err := session.NewSession(awsConfig)
if err != nil {
return err
}
svc := iam.New(sess)
userDetachment := &iam.DetachUserPolicyInput{
PolicyArn: aws.String("arn:aws:iam::aws:policy/AdministratorAccess"),
UserName: aws.String(userName),
}
_, err := svc.DetachUserPolicy(userDetachment)
if err != nil {
if _, err := svc.DetachUserPolicy(userDetachment); err != nil {
log.Printf("[WARN] AWS DetachUserPolicy failed: %v", err)
return err
}
Expand Down Expand Up @@ -490,10 +502,13 @@ func testAccStepRotateRoot(oldAccessKey *awsAccessKey) logicaltest.TestStep {
oldAccessKey.AccessKeyID = newAccessKeyID
log.Println("[WARN] Sleeping for 10 seconds waiting for AWS...")
time.Sleep(10 * time.Second)
svc := sts.New(session.New(awsConfig))
sess, err := session.NewSession(awsConfig)
if err != nil {
return err
}
svc := sts.New(sess)
params := &sts.GetCallerIdentityInput{}
_, err := svc.GetCallerIdentity(params)
if err == nil {
if _, err := svc.GetCallerIdentity(params); err == nil {
return fmt.Errorf("bad: old credentials succeeded after rotate")
}
if aerr, ok := err.(awserr.Error); ok {
Expand Down Expand Up @@ -556,7 +571,11 @@ func describeInstancesTest(accessKey, secretKey, token string) error {
Region: aws.String("us-east-1"),
HTTPClient: cleanhttp.DefaultClient(),
}
client := ec2.New(session.New(awsConfig))
sess, err := session.NewSession(awsConfig)
if err != nil {
return err
}
client := ec2.New(sess)
log.Printf("[WARN] Verifying that the generated credentials work with ec2:DescribeInstances...")
return retryUntilSuccess(func() error {
_, err := client.DescribeInstances(&ec2.DescribeInstancesInput{})
Expand All @@ -571,7 +590,11 @@ func describeAzsTestUnauthorized(accessKey, secretKey, token string) error {
Region: aws.String("us-east-1"),
HTTPClient: cleanhttp.DefaultClient(),
}
client := ec2.New(session.New(awsConfig))
sess, err := session.NewSession(awsConfig)
if err != nil {
return err
}
client := ec2.New(sess)
log.Printf("[WARN] Verifying that the generated credentials don't work with ec2:DescribeAvailabilityZones...")
return retryUntilSuccess(func() error {
_, err := client.DescribeAvailabilityZones(&ec2.DescribeAvailabilityZonesInput{})
Expand All @@ -595,7 +618,11 @@ func assertCreatedIAMUser(accessKey, secretKey, token string) error {
Region: aws.String("us-east-1"),
HTTPClient: cleanhttp.DefaultClient(),
}
client := iam.New(session.New(awsConfig))
sess, err := session.NewSession(awsConfig)
if err != nil {
return err
}
client := iam.New(sess)
log.Printf("[WARN] Checking if IAM User is created properly...")
userOutput, err := client.GetUser(&iam.GetUserInput{})
if err != nil {
Expand All @@ -616,7 +643,11 @@ func listIamUsersTest(accessKey, secretKey, token string) error {
Region: aws.String("us-east-1"),
HTTPClient: cleanhttp.DefaultClient(),
}
client := iam.New(session.New(awsConfig))
sess, err := session.NewSession(awsConfig)
if err != nil {
return err
}
client := iam.New(sess)
log.Printf("[WARN] Verifying that the generated credentials work with iam:ListUsers...")
return retryUntilSuccess(func() error {
_, err := client.ListUsers(&iam.ListUsersInput{})
Expand All @@ -631,7 +662,11 @@ func listDynamoTablesTest(accessKey, secretKey, token string) error {
Region: aws.String("us-east-1"),
HTTPClient: cleanhttp.DefaultClient(),
}
client := dynamodb.New(session.New(awsConfig))
sess, err := session.NewSession(awsConfig)
if err != nil {
return err
}
client := dynamodb.New(sess)
log.Printf("[WARN] Verifying that the generated credentials work with dynamodb:ListTables...")
return retryUntilSuccess(func() error {
_, err := client.ListTables(&dynamodb.ListTablesInput{})
Expand Down
15 changes: 10 additions & 5 deletions builtin/logical/aws/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,11 @@ func nonCachedClientIAM(ctx context.Context, s logical.Storage) (*iam.IAM, error
if err != nil {
return nil, err
}

client := iam.New(session.New(awsConfig))

sess, err := session.NewSession(awsConfig)
if err != nil {
return nil, err
}
client := iam.New(sess)
if client == nil {
return nil, fmt.Errorf("could not obtain iam client")
}
Expand All @@ -88,8 +90,11 @@ func nonCachedClientSTS(ctx context.Context, s logical.Storage) (*sts.STS, error
if err != nil {
return nil, err
}
client := sts.New(session.New(awsConfig))

sess, err := session.NewSession(awsConfig)
if err != nil {
return nil, err
}
client := sts.New(sess)
if client == nil {
return nil, fmt.Errorf("could not obtain sts client")
}
Expand Down
8 changes: 6 additions & 2 deletions physical/s3/s3.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ func NewS3Backend(conf map[string]string, logger log.Logger) (physical.Backend,
pooledTransport := cleanhttp.DefaultPooledTransport()
pooledTransport.MaxIdleConnsPerHost = consts.ExpirationRestoreWorkerCount

s3conn := s3.New(session.New(&aws.Config{
sess, err := session.NewSession(&aws.Config{
Credentials: creds,
HTTPClient: &http.Client{
Transport: pooledTransport,
Expand All @@ -120,7 +120,11 @@ func NewS3Backend(conf map[string]string, logger log.Logger) (physical.Backend,
Region: aws.String(region),
S3ForcePathStyle: aws.Bool(s3ForcePathStyleBool),
DisableSSL: aws.Bool(disableSSLBool),
}))
})
if err != nil {
return nil, err
}
s3conn := s3.New(sess)

_, err = s3conn.ListObjects(&s3.ListObjectsInput{Bucket: &bucket})
if err != nil {
Expand Down
15 changes: 9 additions & 6 deletions physical/s3/s3_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,13 @@ import (
"testing"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/s3"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/helper/awsutil"
"github.com/hashicorp/vault/sdk/helper/logging"
"github.com/hashicorp/vault/sdk/physical"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/s3"
)

func TestDefaultS3Backend(t *testing.T) {
Expand Down Expand Up @@ -51,11 +50,15 @@ func DoS3BackendTest(t *testing.T, kmsKeyId string) {
region = "us-east-1"
}

s3conn := s3.New(session.New(&aws.Config{
sess, err := session.NewSession(&aws.Config{
Credentials: credsChain,
Endpoint: aws.String(endpoint),
Region: aws.String(region),
}))
})
if err != nil {
t.Fatal(err)
}
s3conn := s3.New(sess)

var randInt = rand.New(rand.NewSource(time.Now().UnixNano())).Int()
bucket := fmt.Sprintf("vault-s3-testacc-%d", randInt)
Expand Down