Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions auth/aws/implementation.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
Copyright 2025 The Flux authors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package aws

import (
"context"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/ecr"
"github.com/aws/aws-sdk-go-v2/service/sts"
)

// Implementation provides the required methods of the AWS libraries.
type Implementation interface {
LoadDefaultConfig(ctx context.Context, optFns ...func(*config.LoadOptions) error) (aws.Config, error)
AssumeRoleWithWebIdentity(ctx context.Context, params *sts.AssumeRoleWithWebIdentityInput, options sts.Options) (*sts.AssumeRoleWithWebIdentityOutput, error)
GetAuthorizationToken(ctx context.Context, cfg aws.Config) (*ecr.GetAuthorizationTokenOutput, error)
}

type implementation struct{}

func (implementation) LoadDefaultConfig(ctx context.Context, optFns ...func(*config.LoadOptions) error) (aws.Config, error) {
return config.LoadDefaultConfig(ctx, optFns...)
}

func (implementation) AssumeRoleWithWebIdentity(ctx context.Context, params *sts.AssumeRoleWithWebIdentityInput, options sts.Options) (*sts.AssumeRoleWithWebIdentityOutput, error) {
return sts.New(options).AssumeRoleWithWebIdentity(ctx, params)
}

func (implementation) GetAuthorizationToken(ctx context.Context, cfg aws.Config) (*ecr.GetAuthorizationTokenOutput, error) {
return ecr.NewFromConfig(cfg).GetAuthorizationToken(ctx, &ecr.GetAuthorizationTokenInput{})
}
116 changes: 116 additions & 0 deletions auth/aws/implementation_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
/*
Copyright 2025 The Flux authors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package aws_test

import (
"context"
"net/http"
"net/url"
"testing"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/ecr"
ecrtypes "github.com/aws/aws-sdk-go-v2/service/ecr/types"
"github.com/aws/aws-sdk-go-v2/service/sts"
ststypes "github.com/aws/aws-sdk-go-v2/service/sts/types"
. "github.com/onsi/gomega"
)

type mockImplementation struct {
t *testing.T

argRoleARN string
argRoleSessionName string
argOIDCToken string
argRegion string
argSTSEndpoint string
argProxyURL *url.URL
argCredsProvider aws.CredentialsProvider
}

type mockCredentialsProvider struct{}

func (m *mockImplementation) LoadDefaultConfig(ctx context.Context, optFns ...func(*config.LoadOptions) error) (aws.Config, error) {
m.t.Helper()
g := NewWithT(m.t)
var o config.LoadOptions
for _, optFn := range optFns {
optFn(&o)
}
g.Expect(o.Region).To(Equal(m.argRegion))
g.Expect(o.BaseEndpoint).To(Equal(m.argSTSEndpoint))
g.Expect(o.HTTPClient).NotTo(BeNil())
g.Expect(o.HTTPClient.(*http.Client)).NotTo(BeNil())
g.Expect(o.HTTPClient.(*http.Client).Transport).NotTo(BeNil())
g.Expect(o.HTTPClient.(*http.Client).Transport.(*http.Transport)).NotTo(BeNil())
g.Expect(o.HTTPClient.(*http.Client).Transport.(*http.Transport).Proxy).NotTo(BeNil())
proxyURL, err := o.HTTPClient.(*http.Client).Transport.(*http.Transport).Proxy(nil)
g.Expect(err).NotTo(HaveOccurred())
g.Expect(proxyURL).To(Equal(m.argProxyURL))
return aws.Config{Credentials: mockCredentialsProvider{}}, nil
}

func (m *mockImplementation) AssumeRoleWithWebIdentity(ctx context.Context, params *sts.AssumeRoleWithWebIdentityInput, options sts.Options) (*sts.AssumeRoleWithWebIdentityOutput, error) {
m.t.Helper()
g := NewWithT(m.t)
g.Expect(params).NotTo(BeNil())
g.Expect(params.RoleArn).NotTo(BeNil())
g.Expect(*params.RoleArn).To(Equal(m.argRoleARN))
g.Expect(params.RoleSessionName).NotTo(BeNil())
g.Expect(*params.RoleSessionName).To(Equal(m.argRoleSessionName))
g.Expect(params.WebIdentityToken).NotTo(BeNil())
g.Expect(*params.WebIdentityToken).To(Equal(m.argOIDCToken))
g.Expect(options.Region).To(Equal(m.argRegion))
g.Expect(options.BaseEndpoint).NotTo(BeNil())
g.Expect(*options.BaseEndpoint).To(Equal(m.argSTSEndpoint))
g.Expect(options.HTTPClient).NotTo(BeNil())
g.Expect(options.HTTPClient.(*http.Client)).NotTo(BeNil())
g.Expect(options.HTTPClient.(*http.Client).Transport).NotTo(BeNil())
g.Expect(options.HTTPClient.(*http.Client).Transport.(*http.Transport)).NotTo(BeNil())
g.Expect(options.HTTPClient.(*http.Client).Transport.(*http.Transport).Proxy).NotTo(BeNil())
proxyURL, err := options.HTTPClient.(*http.Client).Transport.(*http.Transport).Proxy(nil)
g.Expect(err).NotTo(HaveOccurred())
g.Expect(proxyURL).To(Equal(m.argProxyURL))
return &sts.AssumeRoleWithWebIdentityOutput{
Credentials: &ststypes.Credentials{},
}, nil
}

func (m *mockImplementation) GetAuthorizationToken(ctx context.Context, cfg aws.Config) (*ecr.GetAuthorizationTokenOutput, error) {
m.t.Helper()
g := NewWithT(m.t)
g.Expect(cfg.Region).To(Equal(m.argRegion))
g.Expect(cfg.Credentials).To(Equal(m.argCredsProvider))
g.Expect(cfg.HTTPClient).NotTo(BeNil())
g.Expect(cfg.HTTPClient.(*http.Client)).NotTo(BeNil())
g.Expect(cfg.HTTPClient.(*http.Client).Transport).NotTo(BeNil())
g.Expect(cfg.HTTPClient.(*http.Client).Transport.(*http.Transport)).NotTo(BeNil())
g.Expect(cfg.HTTPClient.(*http.Client).Transport.(*http.Transport).Proxy).NotTo(BeNil())
proxyURL, err := cfg.HTTPClient.(*http.Client).Transport.(*http.Transport).Proxy(nil)
g.Expect(err).NotTo(HaveOccurred())
g.Expect(proxyURL).To(Equal(m.argProxyURL))
return &ecr.GetAuthorizationTokenOutput{
AuthorizationData: []ecrtypes.AuthorizationData{{
AuthorizationToken: aws.String("dXNlcm5hbWU6cGFzc3dvcmQ="),
}},
}, nil
}

func (mockCredentialsProvider) Retrieve(ctx context.Context) (aws.Credentials, error) {
return aws.Credentials{}, nil
}
35 changes: 30 additions & 5 deletions auth/aws/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,36 @@ import (
corev1 "k8s.io/api/core/v1"
)

func getRegion() string {
func getSTSRegion() (string, error) {
// The AWS_REGION is usually automatically set in EKS clusters.
// If not set users can set it manually (e.g. Fargate).
return os.Getenv("AWS_REGION")
// If not, users can set it manually (e.g. Fargate).
region := os.Getenv("AWS_REGION")
if region == "" {
return "", fmt.Errorf("AWS_REGION environment variable is not set in the Flux controller")
}
return region, nil
}

const stsEndpointPattern = `^https://(.+\.)?sts(-fips)?(\.[^.]+)?(\.vpce)?\.amazonaws\.com$`

var stsEndpointRegex = regexp.MustCompile(stsEndpointPattern)

// ValidateSTSEndpoint checks if the provided STS endpoint is valid.
//
// Global and regional endpoints:
//
// https://docs.aws.amazon.com/general/latest/gr/sts.html
//
// VPC endpoint examples:
//
// https://vpce-002b7cc8966426bc6-njisq19r.sts.us-east-1.vpce.amazonaws.com
// https://vpce-002b7cc8966426bc6-njisq19r-us-east-1a.sts.us-east-1.vpce.amazonaws.com
func ValidateSTSEndpoint(endpoint string) error {
if !stsEndpointRegex.MatchString(endpoint) {
return fmt.Errorf("invalid STS endpoint: '%s'. must match %s",
endpoint, stsEndpointPattern)
}
return nil
}

const roleARNPattern = `^arn:aws:iam::[0-9]{1,30}:role/.{1,200}$`
Expand All @@ -43,10 +69,9 @@ func getRoleARN(serviceAccount corev1.ServiceAccount) (string, error) {
return arn, nil
}

func getRoleSessionName(serviceAccount corev1.ServiceAccount) string {
func getRoleSessionName(serviceAccount corev1.ServiceAccount, region string) string {
name := serviceAccount.Name
namespace := serviceAccount.Namespace
region := getRegion()
return fmt.Sprintf("%s.%s.%s.fluxcd.io", name, namespace, region)
}

Expand Down
129 changes: 129 additions & 0 deletions auth/aws/options_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,135 @@ import (
"github.com/fluxcd/pkg/auth/aws"
)

func TestValidateSTSEndpoint(t *testing.T) {
for _, tt := range []struct {
name string
stsEndpoint string
valid bool
}{
// valid endpoints
{
name: "global endpoint",
stsEndpoint: "https://sts.amazonaws.com",
valid: true,
},
{
name: "sts.us-east-2.amazonaws.com",
stsEndpoint: "https://sts.us-east-2.amazonaws.com",
valid: true,
},
{
name: "sts-fips.us-east-2.amazonaws.com",
stsEndpoint: "https://sts-fips.us-east-2.amazonaws.com",
valid: true,
},
{
name: "sts.us-east-1.amazonaws.com",
stsEndpoint: "https://sts.us-east-1.amazonaws.com",
valid: true,
},
{
name: "sts-fips.us-east-1.amazonaws.com",
stsEndpoint: "https://sts-fips.us-east-1.amazonaws.com",
valid: true,
},
{
name: "sts.us-west-1.amazonaws.com",
stsEndpoint: "https://sts.us-west-1.amazonaws.com",
valid: true,
},
{
name: "sts-fips.us-west-1.amazonaws.com",
stsEndpoint: "https://sts-fips.us-west-1.amazonaws.com",
valid: true,
},
{
name: "sts.us-west-2.amazonaws.com",
stsEndpoint: "https://sts.us-west-2.amazonaws.com",
valid: true,
},
{
name: "sts-fips.us-west-2.amazonaws.com",
stsEndpoint: "https://sts-fips.us-west-2.amazonaws.com",
valid: true,
},
{
name: "sts.il-central-1.amazonaws.com",
stsEndpoint: "https://sts.il-central-1.amazonaws.com",
valid: true,
},
{
name: "sts.mx-central-1.amazonaws.com",
stsEndpoint: "https://sts.mx-central-1.amazonaws.com",
valid: true,
},
{
name: "sts.me-south-1.amazonaws.com",
stsEndpoint: "https://sts.me-south-1.amazonaws.com",
valid: true,
},
{
name: "sts.me-central-1.amazonaws.com",
stsEndpoint: "https://sts.me-central-1.amazonaws.com",
valid: true,
},
{
name: "sts.sa-east-1.amazonaws.com",
stsEndpoint: "https://sts.sa-east-1.amazonaws.com",
valid: true,
},
{
name: "sts.us-gov-east-1.amazonaws.com",
stsEndpoint: "https://sts.us-gov-east-1.amazonaws.com",
valid: true,
},
{
name: "sts.us-gov-west-1.amazonaws.com",
stsEndpoint: "https://sts.us-gov-west-1.amazonaws.com",
valid: true,
},
{
name: "vpce-002b7cc8966426bc6-njisq19r.sts.us-east-1.vpce.amazonaws.com",
stsEndpoint: "https://vpce-002b7cc8966426bc6-njisq19r.sts.us-east-1.vpce.amazonaws.com",
valid: true,
},
{
name: "vpce-002b7cc8966426bc6-njisq19r-us-east-1a.sts.us-east-1.vpce.amazonaws.com",
stsEndpoint: "https://vpce-002b7cc8966426bc6-njisq19r-us-east-1a.sts.us-east-1.vpce.amazonaws.com",
valid: true,
},
// invalid endpoints
{
name: "non sts endpoint",
stsEndpoint: "https://stss.amazonaws.com",
valid: false,
},
{
name: "non aws endpoint",
stsEndpoint: "https://sts.amazonaws.example.com",
valid: false,
},
{
name: "http endpoint",
stsEndpoint: "http://sts.amazonaws.com",
valid: false,
},
{
name: "no scheme",
stsEndpoint: "sts.amazonaws.com",
valid: false,
},
} {
t.Run(tt.name, func(t *testing.T) {
g := NewWithT(t)

err := aws.ValidateSTSEndpoint(tt.stsEndpoint)

g.Expect(err == nil).To(Equal(tt.valid))
})
}
}

func TestParseRegistry(t *testing.T) {
tests := []struct {
registry string
Expand Down
Loading
Loading