Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions awsutil/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# AWSUTIL - Go library for generating aws credentials

*NOTE*: This is version 2 of the library. The `v0` branch contains version 0,
which may be needed for legacy applications or while transitioning to version 2.

## Usage

Following is an example usage of generating AWS credentials with static user credentials

```go

// AWS access keys for an IAM user can be used as your AWS credentials.
// This is an example of an access key and secret key
var accessKey = "AKIAIOSFODNN7EXAMPLE"
var secretKey = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"

// Access key IDs beginning with AKIA are long-term access keys. A long-term
// access key should be supplied when generating static credentials.
config, err := awsutil.NewCredentialsConfig(
awsutil.WithAccessKey(accessKey),
awsutil.WithSecretKey(secretKey),
)
if err != nil {
return err
}

s3Client := s3.NewFromConfig(config)

```

## Contributing to v0
Comment thread
ddebko marked this conversation as resolved.

To push a bug fix or feature for awsutil `v0`, branch out from the [awsutil/v0](https://github.com/hashicorp/go-secure-stdlib/tree/awsutil/v0) branch.
Commit the code changes you want to this new branch and open a PR. Make sure the PR
is configured so that the base branch is set to `awsutil/v0` and not `main`. Once the PR
is reviewed, feel free to merge it into the `awsutil/v0` branch. When creating a new
release, validate that the `Target` branch is `awsutil/v0` and the tag is `awsutil/v0.x.x`.
80 changes: 44 additions & 36 deletions awsutil/clients.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,90 +4,98 @@
package awsutil

import (
"errors"
"context"
"fmt"

"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/iam"
"github.com/aws/aws-sdk-go/service/iam/iamiface"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/aws/aws-sdk-go/service/sts/stsiface"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/iam"
"github.com/aws/aws-sdk-go-v2/service/sts"
)

// IAMAPIFunc is a factory function for returning an IAM interface,
// useful for supplying mock interfaces for testing IAM. The session
// is passed into the function in the same way as done with the
// standard iam.New() constructor.
type IAMAPIFunc func(sess *session.Session) (iamiface.IAMAPI, error)
// useful for supplying mock interfaces for testing IAM.
type IAMAPIFunc func(awsConfig *aws.Config) (IAMClient, error)

// IAMClient represents an iam.Client
type IAMClient interface {
CreateAccessKey(context.Context, *iam.CreateAccessKeyInput, ...func(*iam.Options)) (*iam.CreateAccessKeyOutput, error)
DeleteAccessKey(context.Context, *iam.DeleteAccessKeyInput, ...func(*iam.Options)) (*iam.DeleteAccessKeyOutput, error)
ListAccessKeys(context.Context, *iam.ListAccessKeysInput, ...func(*iam.Options)) (*iam.ListAccessKeysOutput, error)
GetUser(context.Context, *iam.GetUserInput, ...func(*iam.Options)) (*iam.GetUserOutput, error)
}

// STSAPIFunc is a factory function for returning a STS interface,
// useful for supplying mock interfaces for testing STS. The session
// is passed into the function in the same way as done with the
// standard sts.New() constructor.
type STSAPIFunc func(sess *session.Session) (stsiface.STSAPI, error)
// useful for supplying mock interfaces for testing STS.
type STSAPIFunc func(awsConfig *aws.Config) (STSClient, error)

// STSClient represents an sts.Client
type STSClient interface {
AssumeRole(context.Context, *sts.AssumeRoleInput, ...func(*sts.Options)) (*sts.AssumeRoleOutput, error)
GetCallerIdentity(context.Context, *sts.GetCallerIdentityInput, ...func(*sts.Options)) (*sts.GetCallerIdentityOutput, error)
}

// IAMClient returns an IAM client.
//
// Supported options: WithSession, WithIAMAPIFunc.
// Supported options: WithAwsConfig, WithIAMAPIFunc, WithIamEndpointResolver.
//
// If WithIAMAPIFunc is supplied, the included function is used as
// the IAM client constructor instead. This can be used for Mocking
// the IAM API.
func (c *CredentialsConfig) IAMClient(opt ...Option) (iamiface.IAMAPI, error) {
func (c *CredentialsConfig) IAMClient(ctx context.Context, opt ...Option) (IAMClient, error) {
opts, err := getOpts(opt...)
if err != nil {
return nil, fmt.Errorf("error reading options: %w", err)
}

sess := opts.withAwsSession
if sess == nil {
sess, err = c.GetSession(opt...)
cfg := opts.withAwsConfig
if cfg == nil {
cfg, err = c.GenerateCredentialChain(ctx, opt...)
if err != nil {
return nil, fmt.Errorf("error calling GetSession: %w", err)
return nil, fmt.Errorf("error calling GenerateCredentialChain: %w", err)
}
}

if opts.withIAMAPIFunc != nil {
return opts.withIAMAPIFunc(sess)
return opts.withIAMAPIFunc(cfg)
}

client := iam.New(sess)
if client == nil {
return nil, errors.New("could not obtain iam client from session")
var iamOpts []func(*iam.Options)
if c.IAMEndpointResolver != nil {
iamOpts = append(iamOpts, iam.WithEndpointResolverV2(c.IAMEndpointResolver))
}

return client, nil
return iam.NewFromConfig(*cfg, iamOpts...), nil
}

// STSClient returns a STS client.
//
// Supported options: WithSession, WithSTSAPIFunc.
// Supported options: WithAwsConfig, WithSTSAPIFunc, WithStsEndpointResolver.
//
// If WithSTSAPIFunc is supplied, the included function is used as
// the STS client constructor instead. This can be used for Mocking
// the STS API.
func (c *CredentialsConfig) STSClient(opt ...Option) (stsiface.STSAPI, error) {
func (c *CredentialsConfig) STSClient(ctx context.Context, opt ...Option) (STSClient, error) {
opts, err := getOpts(opt...)
if err != nil {
return nil, fmt.Errorf("error reading options: %w", err)
}

sess := opts.withAwsSession
if sess == nil {
sess, err = c.GetSession(opt...)
cfg := opts.withAwsConfig
if cfg == nil {
cfg, err = c.GenerateCredentialChain(ctx, opt...)
if err != nil {
return nil, fmt.Errorf("error calling GetSession: %w", err)
return nil, fmt.Errorf("error calling GenerateCredentialChain: %w", err)
}
}

if opts.withSTSAPIFunc != nil {
return opts.withSTSAPIFunc(sess)
return opts.withSTSAPIFunc(cfg)
}

client := sts.New(sess)
if client == nil {
return nil, errors.New("could not obtain sts client from session")
var stsOpts []func(*sts.Options)
if c.STSEndpointResolver != nil {
stsOpts = append(stsOpts, sts.WithEndpointResolverV2(c.STSEndpointResolver))
}

return client, nil
return sts.NewFromConfig(*cfg, stsOpts...), nil
}
45 changes: 13 additions & 32 deletions awsutil/clients_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,31 +4,24 @@
package awsutil

import (
"context"
"errors"
"fmt"
"testing"

"github.com/aws/aws-sdk-go/service/iam"
"github.com/aws/aws-sdk-go/service/iam/iamiface"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/aws/aws-sdk-go/service/sts/stsiface"
"github.com/aws/aws-sdk-go-v2/service/iam"
"github.com/aws/aws-sdk-go-v2/service/sts"
"github.com/stretchr/testify/require"
)

const testOptionErr = "test option error"
const testBadClientType = "badclienttype"

func testWithBadClientType(o *options) error {
o.withClientType = testBadClientType
return nil
}

func TestCredentialsConfigIAMClient(t *testing.T) {
cases := []struct {
name string
credentialsConfig *CredentialsConfig
opts []Option
require func(t *testing.T, actual iamiface.IAMAPI)
require func(t *testing.T, actual IAMClient)
requireErr string
}{
{
Expand All @@ -37,17 +30,11 @@ func TestCredentialsConfigIAMClient(t *testing.T) {
opts: []Option{MockOptionErr(errors.New(testOptionErr))},
requireErr: fmt.Sprintf("error reading options: %s", testOptionErr),
},
{
name: "session error",
credentialsConfig: &CredentialsConfig{},
opts: []Option{testWithBadClientType},
requireErr: fmt.Sprintf("error calling GetSession: unknown client type %q in GetSession", testBadClientType),
},
{
name: "with mock IAM session",
credentialsConfig: &CredentialsConfig{},
opts: []Option{WithIAMAPIFunc(NewMockIAM())},
require: func(t *testing.T, actual iamiface.IAMAPI) {
require: func(t *testing.T, actual IAMClient) {
t.Helper()
require := require.New(t)
require.Equal(&MockIAM{}, actual)
Expand All @@ -57,10 +44,10 @@ func TestCredentialsConfigIAMClient(t *testing.T) {
name: "no mock client",
credentialsConfig: &CredentialsConfig{},
opts: []Option{},
require: func(t *testing.T, actual iamiface.IAMAPI) {
require: func(t *testing.T, actual IAMClient) {
t.Helper()
require := require.New(t)
require.IsType(&iam.IAM{}, actual)
require.IsType(&iam.Client{}, actual)
},
},
}
Expand All @@ -69,7 +56,7 @@ func TestCredentialsConfigIAMClient(t *testing.T) {
tc := tc
t.Run(tc.name, func(t *testing.T) {
require := require.New(t)
actual, err := tc.credentialsConfig.IAMClient(tc.opts...)
actual, err := tc.credentialsConfig.IAMClient(context.TODO(), tc.opts...)
if tc.requireErr != "" {
require.EqualError(err, tc.requireErr)
return
Expand All @@ -86,7 +73,7 @@ func TestCredentialsConfigSTSClient(t *testing.T) {
name string
credentialsConfig *CredentialsConfig
opts []Option
require func(t *testing.T, actual stsiface.STSAPI)
require func(t *testing.T, actual STSClient)
requireErr string
}{
{
Expand All @@ -95,17 +82,11 @@ func TestCredentialsConfigSTSClient(t *testing.T) {
opts: []Option{MockOptionErr(errors.New(testOptionErr))},
requireErr: fmt.Sprintf("error reading options: %s", testOptionErr),
},
{
name: "session error",
credentialsConfig: &CredentialsConfig{},
opts: []Option{testWithBadClientType},
requireErr: fmt.Sprintf("error calling GetSession: unknown client type %q in GetSession", testBadClientType),
},
{
name: "with mock STS session",
credentialsConfig: &CredentialsConfig{},
opts: []Option{WithSTSAPIFunc(NewMockSTS())},
require: func(t *testing.T, actual stsiface.STSAPI) {
require: func(t *testing.T, actual STSClient) {
t.Helper()
require := require.New(t)
require.Equal(&MockSTS{}, actual)
Expand All @@ -115,10 +96,10 @@ func TestCredentialsConfigSTSClient(t *testing.T) {
name: "no mock client",
credentialsConfig: &CredentialsConfig{},
opts: []Option{},
require: func(t *testing.T, actual stsiface.STSAPI) {
require: func(t *testing.T, actual STSClient) {
t.Helper()
require := require.New(t)
require.IsType(&sts.STS{}, actual)
require.IsType(&sts.Client{}, actual)
},
},
}
Expand All @@ -127,7 +108,7 @@ func TestCredentialsConfigSTSClient(t *testing.T) {
tc := tc
t.Run(tc.name, func(t *testing.T) {
require := require.New(t)
actual, err := tc.credentialsConfig.STSClient(tc.opts...)
actual, err := tc.credentialsConfig.STSClient(context.TODO(), tc.opts...)
if tc.requireErr != "" {
require.EqualError(err, tc.requireErr)
return
Expand Down
10 changes: 5 additions & 5 deletions awsutil/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ package awsutil
import (
"errors"

awsRequest "github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go-v2/aws/retry"
multierror "github.com/hashicorp/go-multierror"
)

Expand All @@ -15,10 +15,10 @@ var ErrUpstreamRateLimited = errors.New("upstream rate limited")
// CheckAWSError will examine an error and convert to a logical error if
// appropriate. If no appropriate error is found, return nil
func CheckAWSError(err error) error {
// IsErrorThrottle will check if the error returned is one that matches
// known request limiting errors:
// https://github.com/aws/aws-sdk-go/blob/488d634b5a699b9118ac2befb5135922b4a77210/aws/request/retryer.go#L35
if awsRequest.IsErrorThrottle(err) {
retryErr := retry.ThrottleErrorCode{
Codes: retry.DefaultThrottleErrorCodes,
}
if retryErr.IsErrorThrottle(err).Bool() {
return ErrUpstreamRateLimited
}
return nil
Expand Down
18 changes: 11 additions & 7 deletions awsutil/error_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
"fmt"
"testing"

"github.com/aws/aws-sdk-go/aws/awserr"
awserr "github.com/aws/smithy-go"
multierror "github.com/hashicorp/go-multierror"
)

Expand All @@ -23,12 +23,16 @@ func Test_CheckAWSError(t *testing.T) {
},
{
Name: "Upstream throttle error",
Err: awserr.New("Throttling", "", nil),
Err: MockAWSThrottleErr(),
Expected: ErrUpstreamRateLimited,
},
{
Name: "Upstream RequestLimitExceeded",
Err: awserr.New("RequestLimitExceeded", "Request rate limited", nil),
Name: "Upstream RequestLimitExceeded",
Err: &MockAWSErr{
Code: "RequestLimitExceeded",
Message: "Request rate limited",
Fault: awserr.FaultServer,
},
Expected: ErrUpstreamRateLimited,
},
}
Expand All @@ -50,7 +54,7 @@ func Test_CheckAWSError(t *testing.T) {
}

func Test_AppendRateLimitedError(t *testing.T) {
awsErr := awserr.New("Throttling", "", nil)
throttleErr := MockAWSThrottleErr()
testCases := []struct {
Name string
Err error
Expand All @@ -63,8 +67,8 @@ func Test_AppendRateLimitedError(t *testing.T) {
},
{
Name: "Upstream throttle error",
Err: awsErr,
Expected: multierror.Append(awsErr, ErrUpstreamRateLimited),
Err: throttleErr,
Expected: multierror.Append(throttleErr, ErrUpstreamRateLimited),
},
{
Name: "Nil",
Expand Down
Loading