From 3976c50177be24d486ab0812faf87f48bfdd7acc Mon Sep 17 00:00:00 2001 From: Matheus Pimenta Date: Fri, 2 May 2025 13:50:47 +0100 Subject: [PATCH] [RFC-0010] Add tests for auth providers Signed-off-by: Matheus Pimenta --- auth/aws/implementation.go | 47 ++++++++ auth/aws/implementation_test.go | 116 ++++++++++++++++++++ auth/aws/options.go | 35 +++++- auth/aws/options_test.go | 129 ++++++++++++++++++++++ auth/aws/provider.go | 57 ++++++---- auth/aws/provider_test.go | 171 ++++++++++++++++++++++++++++++ auth/azure/implementation.go | 46 ++++++++ auth/azure/implementation_test.go | 102 ++++++++++++++++++ auth/azure/provider.go | 22 ++-- auth/azure/provider_test.go | 121 +++++++++++++++++++++ auth/gcp/implementation.go | 41 +++++++ auth/gcp/implementation_test.go | 70 ++++++++++++ auth/gcp/provider.go | 20 ++-- auth/gcp/provider_test.go | 158 +++++++++++++++++++++++++++ auth/gcp/token_supplier.go | 5 +- auth/get_token_test.go | 50 +++++++-- auth/go.mod | 2 +- git/go.mod | 4 +- git/gogit/go.mod | 6 +- git/internal/e2e/go.mod | 6 +- oci/tests/integration/go.mod | 6 +- 21 files changed, 1152 insertions(+), 62 deletions(-) create mode 100644 auth/aws/implementation.go create mode 100644 auth/aws/implementation_test.go create mode 100644 auth/aws/provider_test.go create mode 100644 auth/azure/implementation.go create mode 100644 auth/azure/implementation_test.go create mode 100644 auth/azure/provider_test.go create mode 100644 auth/gcp/implementation.go create mode 100644 auth/gcp/implementation_test.go create mode 100644 auth/gcp/provider_test.go diff --git a/auth/aws/implementation.go b/auth/aws/implementation.go new file mode 100644 index 000000000..933a83ef1 --- /dev/null +++ b/auth/aws/implementation.go @@ -0,0 +1,47 @@ +/* +Copyright 2025 The Flux authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package aws + +import ( + "context" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/ecr" + "github.com/aws/aws-sdk-go-v2/service/sts" +) + +// Implementation provides the required methods of the AWS libraries. +type Implementation interface { + LoadDefaultConfig(ctx context.Context, optFns ...func(*config.LoadOptions) error) (aws.Config, error) + AssumeRoleWithWebIdentity(ctx context.Context, params *sts.AssumeRoleWithWebIdentityInput, options sts.Options) (*sts.AssumeRoleWithWebIdentityOutput, error) + GetAuthorizationToken(ctx context.Context, cfg aws.Config) (*ecr.GetAuthorizationTokenOutput, error) +} + +type implementation struct{} + +func (implementation) LoadDefaultConfig(ctx context.Context, optFns ...func(*config.LoadOptions) error) (aws.Config, error) { + return config.LoadDefaultConfig(ctx, optFns...) +} + +func (implementation) AssumeRoleWithWebIdentity(ctx context.Context, params *sts.AssumeRoleWithWebIdentityInput, options sts.Options) (*sts.AssumeRoleWithWebIdentityOutput, error) { + return sts.New(options).AssumeRoleWithWebIdentity(ctx, params) +} + +func (implementation) GetAuthorizationToken(ctx context.Context, cfg aws.Config) (*ecr.GetAuthorizationTokenOutput, error) { + return ecr.NewFromConfig(cfg).GetAuthorizationToken(ctx, &ecr.GetAuthorizationTokenInput{}) +} diff --git a/auth/aws/implementation_test.go b/auth/aws/implementation_test.go new file mode 100644 index 000000000..e5e6c5741 --- /dev/null +++ b/auth/aws/implementation_test.go @@ -0,0 +1,116 @@ +/* +Copyright 2025 The Flux authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package aws_test + +import ( + "context" + "net/http" + "net/url" + "testing" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/ecr" + ecrtypes "github.com/aws/aws-sdk-go-v2/service/ecr/types" + "github.com/aws/aws-sdk-go-v2/service/sts" + ststypes "github.com/aws/aws-sdk-go-v2/service/sts/types" + . "github.com/onsi/gomega" +) + +type mockImplementation struct { + t *testing.T + + argRoleARN string + argRoleSessionName string + argOIDCToken string + argRegion string + argSTSEndpoint string + argProxyURL *url.URL + argCredsProvider aws.CredentialsProvider +} + +type mockCredentialsProvider struct{} + +func (m *mockImplementation) LoadDefaultConfig(ctx context.Context, optFns ...func(*config.LoadOptions) error) (aws.Config, error) { + m.t.Helper() + g := NewWithT(m.t) + var o config.LoadOptions + for _, optFn := range optFns { + optFn(&o) + } + g.Expect(o.Region).To(Equal(m.argRegion)) + g.Expect(o.BaseEndpoint).To(Equal(m.argSTSEndpoint)) + g.Expect(o.HTTPClient).NotTo(BeNil()) + g.Expect(o.HTTPClient.(*http.Client)).NotTo(BeNil()) + g.Expect(o.HTTPClient.(*http.Client).Transport).NotTo(BeNil()) + g.Expect(o.HTTPClient.(*http.Client).Transport.(*http.Transport)).NotTo(BeNil()) + g.Expect(o.HTTPClient.(*http.Client).Transport.(*http.Transport).Proxy).NotTo(BeNil()) + proxyURL, err := o.HTTPClient.(*http.Client).Transport.(*http.Transport).Proxy(nil) + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(proxyURL).To(Equal(m.argProxyURL)) + return aws.Config{Credentials: mockCredentialsProvider{}}, nil +} + +func (m *mockImplementation) AssumeRoleWithWebIdentity(ctx context.Context, params *sts.AssumeRoleWithWebIdentityInput, options sts.Options) (*sts.AssumeRoleWithWebIdentityOutput, error) { + m.t.Helper() + g := NewWithT(m.t) + g.Expect(params).NotTo(BeNil()) + g.Expect(params.RoleArn).NotTo(BeNil()) + g.Expect(*params.RoleArn).To(Equal(m.argRoleARN)) + g.Expect(params.RoleSessionName).NotTo(BeNil()) + g.Expect(*params.RoleSessionName).To(Equal(m.argRoleSessionName)) + g.Expect(params.WebIdentityToken).NotTo(BeNil()) + g.Expect(*params.WebIdentityToken).To(Equal(m.argOIDCToken)) + g.Expect(options.Region).To(Equal(m.argRegion)) + g.Expect(options.BaseEndpoint).NotTo(BeNil()) + g.Expect(*options.BaseEndpoint).To(Equal(m.argSTSEndpoint)) + g.Expect(options.HTTPClient).NotTo(BeNil()) + g.Expect(options.HTTPClient.(*http.Client)).NotTo(BeNil()) + g.Expect(options.HTTPClient.(*http.Client).Transport).NotTo(BeNil()) + g.Expect(options.HTTPClient.(*http.Client).Transport.(*http.Transport)).NotTo(BeNil()) + g.Expect(options.HTTPClient.(*http.Client).Transport.(*http.Transport).Proxy).NotTo(BeNil()) + proxyURL, err := options.HTTPClient.(*http.Client).Transport.(*http.Transport).Proxy(nil) + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(proxyURL).To(Equal(m.argProxyURL)) + return &sts.AssumeRoleWithWebIdentityOutput{ + Credentials: &ststypes.Credentials{}, + }, nil +} + +func (m *mockImplementation) GetAuthorizationToken(ctx context.Context, cfg aws.Config) (*ecr.GetAuthorizationTokenOutput, error) { + m.t.Helper() + g := NewWithT(m.t) + g.Expect(cfg.Region).To(Equal(m.argRegion)) + g.Expect(cfg.Credentials).To(Equal(m.argCredsProvider)) + g.Expect(cfg.HTTPClient).NotTo(BeNil()) + g.Expect(cfg.HTTPClient.(*http.Client)).NotTo(BeNil()) + g.Expect(cfg.HTTPClient.(*http.Client).Transport).NotTo(BeNil()) + g.Expect(cfg.HTTPClient.(*http.Client).Transport.(*http.Transport)).NotTo(BeNil()) + g.Expect(cfg.HTTPClient.(*http.Client).Transport.(*http.Transport).Proxy).NotTo(BeNil()) + proxyURL, err := cfg.HTTPClient.(*http.Client).Transport.(*http.Transport).Proxy(nil) + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(proxyURL).To(Equal(m.argProxyURL)) + return &ecr.GetAuthorizationTokenOutput{ + AuthorizationData: []ecrtypes.AuthorizationData{{ + AuthorizationToken: aws.String("dXNlcm5hbWU6cGFzc3dvcmQ="), + }}, + }, nil +} + +func (mockCredentialsProvider) Retrieve(ctx context.Context) (aws.Credentials, error) { + return aws.Credentials{}, nil +} diff --git a/auth/aws/options.go b/auth/aws/options.go index 1be1143cc..09ff65f7e 100644 --- a/auth/aws/options.go +++ b/auth/aws/options.go @@ -24,10 +24,36 @@ import ( corev1 "k8s.io/api/core/v1" ) -func getRegion() string { +func getSTSRegion() (string, error) { // The AWS_REGION is usually automatically set in EKS clusters. - // If not set users can set it manually (e.g. Fargate). - return os.Getenv("AWS_REGION") + // If not, users can set it manually (e.g. Fargate). + region := os.Getenv("AWS_REGION") + if region == "" { + return "", fmt.Errorf("AWS_REGION environment variable is not set in the Flux controller") + } + return region, nil +} + +const stsEndpointPattern = `^https://(.+\.)?sts(-fips)?(\.[^.]+)?(\.vpce)?\.amazonaws\.com$` + +var stsEndpointRegex = regexp.MustCompile(stsEndpointPattern) + +// ValidateSTSEndpoint checks if the provided STS endpoint is valid. +// +// Global and regional endpoints: +// +// https://docs.aws.amazon.com/general/latest/gr/sts.html +// +// VPC endpoint examples: +// +// https://vpce-002b7cc8966426bc6-njisq19r.sts.us-east-1.vpce.amazonaws.com +// https://vpce-002b7cc8966426bc6-njisq19r-us-east-1a.sts.us-east-1.vpce.amazonaws.com +func ValidateSTSEndpoint(endpoint string) error { + if !stsEndpointRegex.MatchString(endpoint) { + return fmt.Errorf("invalid STS endpoint: '%s'. must match %s", + endpoint, stsEndpointPattern) + } + return nil } const roleARNPattern = `^arn:aws:iam::[0-9]{1,30}:role/.{1,200}$` @@ -43,10 +69,9 @@ func getRoleARN(serviceAccount corev1.ServiceAccount) (string, error) { return arn, nil } -func getRoleSessionName(serviceAccount corev1.ServiceAccount) string { +func getRoleSessionName(serviceAccount corev1.ServiceAccount, region string) string { name := serviceAccount.Name namespace := serviceAccount.Namespace - region := getRegion() return fmt.Sprintf("%s.%s.%s.fluxcd.io", name, namespace, region) } diff --git a/auth/aws/options_test.go b/auth/aws/options_test.go index a1bfa56a2..67f8fa8b7 100644 --- a/auth/aws/options_test.go +++ b/auth/aws/options_test.go @@ -24,6 +24,135 @@ import ( "github.com/fluxcd/pkg/auth/aws" ) +func TestValidateSTSEndpoint(t *testing.T) { + for _, tt := range []struct { + name string + stsEndpoint string + valid bool + }{ + // valid endpoints + { + name: "global endpoint", + stsEndpoint: "https://sts.amazonaws.com", + valid: true, + }, + { + name: "sts.us-east-2.amazonaws.com", + stsEndpoint: "https://sts.us-east-2.amazonaws.com", + valid: true, + }, + { + name: "sts-fips.us-east-2.amazonaws.com", + stsEndpoint: "https://sts-fips.us-east-2.amazonaws.com", + valid: true, + }, + { + name: "sts.us-east-1.amazonaws.com", + stsEndpoint: "https://sts.us-east-1.amazonaws.com", + valid: true, + }, + { + name: "sts-fips.us-east-1.amazonaws.com", + stsEndpoint: "https://sts-fips.us-east-1.amazonaws.com", + valid: true, + }, + { + name: "sts.us-west-1.amazonaws.com", + stsEndpoint: "https://sts.us-west-1.amazonaws.com", + valid: true, + }, + { + name: "sts-fips.us-west-1.amazonaws.com", + stsEndpoint: "https://sts-fips.us-west-1.amazonaws.com", + valid: true, + }, + { + name: "sts.us-west-2.amazonaws.com", + stsEndpoint: "https://sts.us-west-2.amazonaws.com", + valid: true, + }, + { + name: "sts-fips.us-west-2.amazonaws.com", + stsEndpoint: "https://sts-fips.us-west-2.amazonaws.com", + valid: true, + }, + { + name: "sts.il-central-1.amazonaws.com", + stsEndpoint: "https://sts.il-central-1.amazonaws.com", + valid: true, + }, + { + name: "sts.mx-central-1.amazonaws.com", + stsEndpoint: "https://sts.mx-central-1.amazonaws.com", + valid: true, + }, + { + name: "sts.me-south-1.amazonaws.com", + stsEndpoint: "https://sts.me-south-1.amazonaws.com", + valid: true, + }, + { + name: "sts.me-central-1.amazonaws.com", + stsEndpoint: "https://sts.me-central-1.amazonaws.com", + valid: true, + }, + { + name: "sts.sa-east-1.amazonaws.com", + stsEndpoint: "https://sts.sa-east-1.amazonaws.com", + valid: true, + }, + { + name: "sts.us-gov-east-1.amazonaws.com", + stsEndpoint: "https://sts.us-gov-east-1.amazonaws.com", + valid: true, + }, + { + name: "sts.us-gov-west-1.amazonaws.com", + stsEndpoint: "https://sts.us-gov-west-1.amazonaws.com", + valid: true, + }, + { + name: "vpce-002b7cc8966426bc6-njisq19r.sts.us-east-1.vpce.amazonaws.com", + stsEndpoint: "https://vpce-002b7cc8966426bc6-njisq19r.sts.us-east-1.vpce.amazonaws.com", + valid: true, + }, + { + name: "vpce-002b7cc8966426bc6-njisq19r-us-east-1a.sts.us-east-1.vpce.amazonaws.com", + stsEndpoint: "https://vpce-002b7cc8966426bc6-njisq19r-us-east-1a.sts.us-east-1.vpce.amazonaws.com", + valid: true, + }, + // invalid endpoints + { + name: "non sts endpoint", + stsEndpoint: "https://stss.amazonaws.com", + valid: false, + }, + { + name: "non aws endpoint", + stsEndpoint: "https://sts.amazonaws.example.com", + valid: false, + }, + { + name: "http endpoint", + stsEndpoint: "http://sts.amazonaws.com", + valid: false, + }, + { + name: "no scheme", + stsEndpoint: "sts.amazonaws.com", + valid: false, + }, + } { + t.Run(tt.name, func(t *testing.T) { + g := NewWithT(t) + + err := aws.ValidateSTSEndpoint(tt.stsEndpoint) + + g.Expect(err == nil).To(Equal(tt.valid)) + }) + } +} + func TestParseRegistry(t *testing.T) { tests := []struct { registry string diff --git a/auth/aws/provider.go b/auth/aws/provider.go index feda04550..c3a154b72 100644 --- a/auth/aws/provider.go +++ b/auth/aws/provider.go @@ -26,7 +26,6 @@ import ( "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/config" - "github.com/aws/aws-sdk-go-v2/service/ecr" "github.com/aws/aws-sdk-go-v2/service/sts" corev1 "k8s.io/api/core/v1" @@ -37,7 +36,7 @@ import ( const ProviderName = "aws" // Provider implements the auth.Provider interface for AWS authentication. -type Provider struct{} +type Provider struct{ Implementation } // GetName implements auth.Provider. func (Provider) GetName() string { @@ -45,16 +44,22 @@ func (Provider) GetName() string { } // NewDefaultToken implements auth.Provider. -func (Provider) NewDefaultToken(ctx context.Context, opts ...auth.Option) (auth.Token, error) { +func (p Provider) NewDefaultToken(ctx context.Context, opts ...auth.Option) (auth.Token, error) { var o auth.Options o.Apply(opts...) var awsOpts []func(*config.LoadOptions) error - region := getRegion() - awsOpts = append(awsOpts, config.WithRegion(region)) + stsRegion, err := getSTSRegion() + if err != nil { + return nil, err + } + awsOpts = append(awsOpts, config.WithRegion(stsRegion)) if e := o.STSEndpoint; e != "" { + if err := ValidateSTSEndpoint(e); err != nil { + return nil, err + } awsOpts = append(awsOpts, config.WithBaseEndpoint(e)) } @@ -62,7 +67,7 @@ func (Provider) NewDefaultToken(ctx context.Context, opts ...auth.Option) (auth. awsOpts = append(awsOpts, config.WithHTTPClient(hc)) } - conf, err := config.LoadDefaultConfig(ctx, awsOpts...) + conf, err := p.impl().LoadDefaultConfig(ctx, awsOpts...) if err != nil { return nil, err } @@ -89,25 +94,32 @@ func (Provider) GetIdentity(serviceAccount corev1.ServiceAccount) (string, error } // NewTokenForServiceAccount implements auth.Provider. -func (Provider) NewTokenForServiceAccount(ctx context.Context, oidcToken string, +func (p Provider) NewTokenForServiceAccount(ctx context.Context, oidcToken string, serviceAccount corev1.ServiceAccount, opts ...auth.Option) (auth.Token, error) { var o auth.Options o.Apply(opts...) - roleARN, err := getRoleARN(serviceAccount) + stsRegion, err := getSTSRegion() if err != nil { return nil, err } - roleSessionName := getRoleSessionName(serviceAccount) + roleARN, err := getRoleARN(serviceAccount) + if err != nil { + return nil, err + } - var awsOpts sts.Options + roleSessionName := getRoleSessionName(serviceAccount, stsRegion) - region := getRegion() - awsOpts.Region = region + awsOpts := sts.Options{ + Region: stsRegion, + } if e := o.STSEndpoint; e != "" { + if err := ValidateSTSEndpoint(e); err != nil { + return nil, err + } awsOpts.BaseEndpoint = &e } @@ -123,7 +135,7 @@ func (Provider) NewTokenForServiceAccount(ctx context.Context, oidcToken string, RoleSessionName: &roleSessionName, WebIdentityToken: &oidcToken, } - resp, err := sts.New(awsOpts).AssumeRoleWithWebIdentity(ctx, req) + resp, err := p.impl().AssumeRoleWithWebIdentity(ctx, req, awsOpts) if err != nil { return nil, err } @@ -141,20 +153,20 @@ func (Provider) NewTokenForServiceAccount(ctx context.Context, oidcToken string, // GetArtifactCacheKey implements auth.Provider. func (Provider) GetArtifactCacheKey(artifactRepository string) string { - if _, region, ok := ParseRegistry(artifactRepository); ok { - return region + if _, ecrRegion, ok := ParseRegistry(artifactRepository); ok { + return ecrRegion } return "" } // NewArtifactRegistryToken implements auth.Provider. -func (Provider) NewArtifactRegistryToken(ctx context.Context, artifactRepository string, +func (p Provider) NewArtifactRegistryToken(ctx context.Context, artifactRepository string, accessToken auth.Token, opts ...auth.Option) (auth.Token, error) { var o auth.Options o.Apply(opts...) - _, region, ok := ParseRegistry(artifactRepository) + _, ecrRegion, ok := ParseRegistry(artifactRepository) if !ok { return nil, fmt.Errorf("invalid ecr repository: '%s'", artifactRepository) } @@ -162,7 +174,7 @@ func (Provider) NewArtifactRegistryToken(ctx context.Context, artifactRepository credsProvider := accessToken.(*Token).CredentialsProvider() conf := aws.Config{ - Region: region, + Region: ecrRegion, Credentials: credsProvider, } @@ -170,7 +182,7 @@ func (Provider) NewArtifactRegistryToken(ctx context.Context, artifactRepository conf.HTTPClient = hc } - resp, err := ecr.NewFromConfig(conf).GetAuthorizationToken(ctx, nil) + resp, err := p.impl().GetAuthorizationToken(ctx, conf) if err != nil { return nil, err } @@ -202,3 +214,10 @@ func (Provider) NewArtifactRegistryToken(ctx context.Context, artifactRepository ExpiresAt: expiresAt, }, nil } + +func (p Provider) impl() Implementation { + if p.Implementation == nil { + return implementation{} + } + return p.Implementation +} diff --git a/auth/aws/provider_test.go b/auth/aws/provider_test.go new file mode 100644 index 000000000..735cff970 --- /dev/null +++ b/auth/aws/provider_test.go @@ -0,0 +1,171 @@ +/* +Copyright 2025 The Flux authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package aws_test + +import ( + "context" + "net/url" + "testing" + + awssdk "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/sts/types" + . "github.com/onsi/gomega" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + "github.com/fluxcd/pkg/auth" + "github.com/fluxcd/pkg/auth/aws" +) + +func TestProvider_NewDefaultToken_Options(t *testing.T) { + t.Setenv("AWS_REGION", "us-east-1") + + impl := &mockImplementation{ + t: t, + argRegion: "us-east-1", + argProxyURL: &url.URL{Scheme: "http", Host: "proxy.example.com"}, + argSTSEndpoint: "https://sts.amazonaws.com", + } + + for _, tt := range []struct { + name string + stsEndpoint string + err string + }{ + { + name: "valid", + stsEndpoint: "https://sts.amazonaws.com", + }, + { + name: "invalid sts endpoint", + stsEndpoint: "https://something.amazonaws.com", + err: `invalid STS endpoint: 'https://something.amazonaws.com'. must match ^https://(.+\.)?sts(-fips)?(\.[^.]+)?(\.vpce)?\.amazonaws\.com$`, + }, + } { + t.Run(tt.name, func(t *testing.T) { + g := NewWithT(t) + + opts := []auth.Option{ + auth.WithProxyURL(url.URL{Scheme: "http", Host: "proxy.example.com"}), + auth.WithSTSEndpoint(tt.stsEndpoint), + } + + provider := aws.Provider{Implementation: impl} + token, err := provider.NewDefaultToken(context.Background(), opts...) + + if tt.err == "" { + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(token).NotTo(BeNil()) + } else { + g.Expect(err).To(HaveOccurred()) + g.Expect(err.Error()).To(Equal(tt.err)) + g.Expect(token).To(BeNil()) + } + }) + } +} + +func TestProvider_NewTokenForServiceAccount_Options(t *testing.T) { + t.Setenv("AWS_REGION", "us-east-1") + + impl := &mockImplementation{ + t: t, + argRegion: "us-east-1", + argRoleARN: "arn:aws:iam::1234567890:role/some-role", + argRoleSessionName: "test-sa.test-ns.us-east-1.fluxcd.io", + argOIDCToken: "oidc-token", + argProxyURL: &url.URL{Scheme: "http", Host: "proxy.example.com"}, + argSTSEndpoint: "https://sts.amazonaws.com", + } + + oidcToken := "oidc-token" + serviceAccount := corev1.ServiceAccount{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-sa", + Namespace: "test-ns", + Annotations: map[string]string{ + "eks.amazonaws.com/role-arn": "arn:aws:iam::1234567890:role/some-role", + }, + }, + } + + for _, tt := range []struct { + name string + stsEndpoint string + err string + }{ + { + name: "valid", + stsEndpoint: "https://sts.amazonaws.com", + }, + { + name: "invalid sts endpoint", + stsEndpoint: "https://something.amazonaws.com", + err: `invalid STS endpoint: 'https://something.amazonaws.com'. must match ^https://(.+\.)?sts(-fips)?(\.[^.]+)?(\.vpce)?\.amazonaws\.com$`, + }, + } { + t.Run(tt.name, func(t *testing.T) { + g := NewWithT(t) + + opts := []auth.Option{ + auth.WithProxyURL(url.URL{Scheme: "http", Host: "proxy.example.com"}), + auth.WithSTSEndpoint(tt.stsEndpoint), + } + + provider := aws.Provider{Implementation: impl} + token, err := provider.NewTokenForServiceAccount(context.Background(), oidcToken, serviceAccount, opts...) + + if tt.err == "" { + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(token).NotTo(BeNil()) + } else { + g.Expect(err).To(HaveOccurred()) + g.Expect(err.Error()).To(Equal(tt.err)) + g.Expect(token).To(BeNil()) + } + }) + } +} + +func TestProvider_NewArtifactRegistryToken_Options(t *testing.T) { + g := NewWithT(t) + + impl := &mockImplementation{ + t: t, + argRegion: "us-east-1", + argProxyURL: &url.URL{Scheme: "http", Host: "proxy.example.com"}, + argCredsProvider: credentials.NewStaticCredentialsProvider("access-key-id", "secret-access-key", "session-token"), + } + + artifactRepository := "012345678901.dkr.ecr.us-east-1.amazonaws.com/foo" + accessToken := &aws.Token{ + Credentials: types.Credentials{ + AccessKeyId: awssdk.String("access-key-id"), + SecretAccessKey: awssdk.String("secret-access-key"), + SessionToken: awssdk.String("session-token"), + }, + } + opts := []auth.Option{ + auth.WithProxyURL(url.URL{Scheme: "http", Host: "proxy.example.com"}), + } + + provider := aws.Provider{Implementation: impl} + token, err := provider.NewArtifactRegistryToken(context.Background(), artifactRepository, accessToken, opts...) + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(token).NotTo(BeNil()) +} diff --git a/auth/azure/implementation.go b/auth/azure/implementation.go new file mode 100644 index 000000000..7ef8f3b2e --- /dev/null +++ b/auth/azure/implementation.go @@ -0,0 +1,46 @@ +/* +Copyright 2025 The Flux authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package azure + +import ( + "context" + "net/http" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azidentity" +) + +// Implementation provides the required methods of the Azure libraries. +type Implementation interface { + NewDefaultAzureCredential(options azidentity.DefaultAzureCredentialOptions) (azcore.TokenCredential, error) + NewClientAssertionCredential(tenantID string, clientID string, getAssertion func(context.Context) (string, error), options *azidentity.ClientAssertionCredentialOptions) (azcore.TokenCredential, error) + SendRequest(req *http.Request, client *http.Client) (*http.Response, error) +} + +type implementation struct{} + +func (implementation) NewDefaultAzureCredential(options azidentity.DefaultAzureCredentialOptions) (azcore.TokenCredential, error) { + return newDefaultAzureCredential(options) +} + +func (implementation) NewClientAssertionCredential(tenantID string, clientID string, getAssertion func(context.Context) (string, error), options *azidentity.ClientAssertionCredentialOptions) (azcore.TokenCredential, error) { + return azidentity.NewClientAssertionCredential(tenantID, clientID, getAssertion, options) +} + +func (implementation) SendRequest(req *http.Request, client *http.Client) (*http.Response, error) { + return client.Do(req) +} diff --git a/auth/azure/implementation_test.go b/auth/azure/implementation_test.go new file mode 100644 index 000000000..5258741b6 --- /dev/null +++ b/auth/azure/implementation_test.go @@ -0,0 +1,102 @@ +/* +Copyright 2025 The Flux authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package azure_test + +import ( + "context" + "net/http" + "net/url" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azidentity" + . "github.com/onsi/gomega" +) + +type mockImplementation struct { + t *testing.T + + argTenantID string + argClientID string + argOIDCToken string + argProxyURL *url.URL + argScopes []string + + returnResp *http.Response +} + +type mockTokenCredential struct { + t *testing.T + + argScopes []string +} + +func (m *mockImplementation) NewDefaultAzureCredential(options azidentity.DefaultAzureCredentialOptions) (azcore.TokenCredential, error) { + m.t.Helper() + g := NewWithT(m.t) + g.Expect(options.Transport).NotTo(BeNil()) + g.Expect(options.Transport.(*http.Client)).NotTo(BeNil()) + g.Expect(options.Transport.(*http.Client).Transport).NotTo(BeNil()) + g.Expect(options.Transport.(*http.Client).Transport.(*http.Transport)).NotTo(BeNil()) + g.Expect(options.Transport.(*http.Client).Transport.(*http.Transport).Proxy).NotTo(BeNil()) + proxyURL, err := options.Transport.(*http.Client).Transport.(*http.Transport).Proxy(nil) + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(proxyURL).To(Equal(m.argProxyURL)) + return &mockTokenCredential{t: m.t, argScopes: m.argScopes}, nil +} + +func (m *mockImplementation) NewClientAssertionCredential(tenantID string, clientID string, getAssertion func(context.Context) (string, error), options *azidentity.ClientAssertionCredentialOptions) (azcore.TokenCredential, error) { + m.t.Helper() + g := NewWithT(m.t) + g.Expect(tenantID).To(Equal(m.argTenantID)) + g.Expect(clientID).To(Equal(m.argClientID)) + g.Expect(getAssertion).NotTo(BeNil()) + oidcToken, err := getAssertion(context.Background()) + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(oidcToken).To(Equal(m.argOIDCToken)) + g.Expect(options).NotTo(BeNil()) + g.Expect(options.Transport).NotTo(BeNil()) + g.Expect(options.Transport.(*http.Client)).NotTo(BeNil()) + g.Expect(options.Transport.(*http.Client).Transport).NotTo(BeNil()) + g.Expect(options.Transport.(*http.Client).Transport.(*http.Transport)).NotTo(BeNil()) + g.Expect(options.Transport.(*http.Client).Transport.(*http.Transport).Proxy).NotTo(BeNil()) + proxyURL, err := options.Transport.(*http.Client).Transport.(*http.Transport).Proxy(nil) + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(proxyURL).To(Equal(m.argProxyURL)) + return &mockTokenCredential{t: m.t, argScopes: m.argScopes}, nil +} + +func (m *mockImplementation) SendRequest(req *http.Request, client *http.Client) (*http.Response, error) { + m.t.Helper() + g := NewWithT(m.t) + g.Expect(client).NotTo(BeNil()) + g.Expect(client.Transport).NotTo(BeNil()) + g.Expect(client.Transport.(*http.Transport)).NotTo(BeNil()) + g.Expect(client.Transport.(*http.Transport).Proxy).NotTo(BeNil()) + proxyURL, err := client.Transport.(*http.Transport).Proxy(nil) + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(proxyURL).To(Equal(m.argProxyURL)) + return m.returnResp, nil +} + +func (m *mockTokenCredential) GetToken(ctx context.Context, options policy.TokenRequestOptions) (azcore.AccessToken, error) { + m.t.Helper() + g := NewWithT(m.t) + g.Expect(options.Scopes).To(Equal(m.argScopes)) + return azcore.AccessToken{}, nil +} diff --git a/auth/azure/provider.go b/auth/azure/provider.go index 3f0354885..fbe116581 100644 --- a/auth/azure/provider.go +++ b/auth/azure/provider.go @@ -37,7 +37,7 @@ import ( const ProviderName = "azure" // Provider implements the auth.Provider interface for Azure authentication. -type Provider struct{} +type Provider struct{ Implementation } // GetName implements auth.Provider. func (Provider) GetName() string { @@ -45,7 +45,8 @@ func (Provider) GetName() string { } // NewDefaultToken implements auth.Provider. -func (Provider) NewDefaultToken(ctx context.Context, opts ...auth.Option) (auth.Token, error) { +func (p Provider) NewDefaultToken(ctx context.Context, opts ...auth.Option) (auth.Token, error) { + var o auth.Options o.Apply(opts...) @@ -55,7 +56,7 @@ func (Provider) NewDefaultToken(ctx context.Context, opts ...auth.Option) (auth. azOpts.Transport = hc } - cred, err := newDefaultAzureCredential(azOpts) + cred, err := p.impl().NewDefaultAzureCredential(azOpts) if err != nil { return nil, err } @@ -80,7 +81,7 @@ func (Provider) GetIdentity(serviceAccount corev1.ServiceAccount) (string, error } // NewTokenForServiceAccount implements auth.Provider. -func (Provider) NewTokenForServiceAccount(ctx context.Context, oidcToken string, +func (p Provider) NewTokenForServiceAccount(ctx context.Context, oidcToken string, serviceAccount corev1.ServiceAccount, opts ...auth.Option) (auth.Token, error) { var o auth.Options @@ -99,7 +100,7 @@ func (Provider) NewTokenForServiceAccount(ctx context.Context, oidcToken string, azOpts.Transport = hc } - cred, err := azidentity.NewClientAssertionCredential(tenantID, clientID, func(context.Context) (string, error) { + cred, err := p.impl().NewClientAssertionCredential(tenantID, clientID, func(context.Context) (string, error) { return oidcToken, nil }, azOpts) if err != nil { @@ -121,7 +122,7 @@ func (Provider) GetArtifactCacheKey(artifactRepository string) string { } // NewArtifactRegistryToken implements auth.Provider. -func (Provider) NewArtifactRegistryToken(ctx context.Context, artifactRepository string, +func (p Provider) NewArtifactRegistryToken(ctx context.Context, artifactRepository string, accessToken auth.Token, opts ...auth.Option) (auth.Token, error) { t := accessToken.(*Token) @@ -155,7 +156,7 @@ func (Provider) NewArtifactRegistryToken(ctx context.Context, artifactRepository if hc := o.GetHTTPClient(); hc != nil { httpClient = hc } - resp, err := httpClient.Do(req) + resp, err := p.impl().SendRequest(req, httpClient) if err != nil { return nil, err } @@ -187,3 +188,10 @@ func (Provider) NewArtifactRegistryToken(ctx context.Context, artifactRepository ExpiresAt: expiry.Time, }, nil } + +func (p Provider) impl() Implementation { + if p.Implementation == nil { + return implementation{} + } + return p.Implementation +} diff --git a/auth/azure/provider_test.go b/auth/azure/provider_test.go new file mode 100644 index 000000000..9b036eaae --- /dev/null +++ b/auth/azure/provider_test.go @@ -0,0 +1,121 @@ +/* +Copyright 2025 The Flux authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package azure_test + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "testing" + "time" + + "github.com/golang-jwt/jwt/v5" + . "github.com/onsi/gomega" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + "github.com/fluxcd/pkg/auth" + "github.com/fluxcd/pkg/auth/azure" +) + +func TestProvider_NewDefaultToken_Options(t *testing.T) { + g := NewWithT(t) + + impl := &mockImplementation{ + t: t, + argProxyURL: &url.URL{Scheme: "http", Host: "proxy.example.com"}, + argScopes: []string{"scope1", "scope2"}, + } + + opts := []auth.Option{ + auth.WithProxyURL(url.URL{Scheme: "http", Host: "proxy.example.com"}), + auth.WithScopes("scope1", "scope2"), + } + + provider := azure.Provider{Implementation: impl} + token, err := provider.NewDefaultToken(context.Background(), opts...) + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(token).NotTo(BeNil()) +} + +func TestProvider_NewTokenForServiceAccount_Options(t *testing.T) { + g := NewWithT(t) + + impl := &mockImplementation{ + t: t, + argTenantID: "tenant-id", + argClientID: "client-id", + argOIDCToken: "oidc-token", + argProxyURL: &url.URL{Scheme: "http", Host: "proxy.example.com"}, + argScopes: []string{"scope1", "scope2"}, + } + + oidcToken := "oidc-token" + serviceAccount := corev1.ServiceAccount{ + ObjectMeta: metav1.ObjectMeta{ + Annotations: map[string]string{ + "azure.workload.identity/tenant-id": "tenant-id", + "azure.workload.identity/client-id": "client-id", + }, + }, + } + opts := []auth.Option{ + auth.WithProxyURL(url.URL{Scheme: "http", Host: "proxy.example.com"}), + auth.WithScopes("scope1", "scope2"), + } + + provider := azure.Provider{Implementation: impl} + token, err := provider.NewTokenForServiceAccount(context.Background(), oidcToken, serviceAccount, opts...) + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(token).NotTo(BeNil()) +} + +func TestProvider_NewArtifactRegistryToken_Options(t *testing.T) { + g := NewWithT(t) + + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + g.Expect(err).NotTo(HaveOccurred()) + refreshToken, err := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{ + "exp": time.Now().Add(time.Hour).Unix(), + }).SignedString(privateKey) + g.Expect(err).NotTo(HaveOccurred()) + + impl := &mockImplementation{ + t: t, + argProxyURL: &url.URL{Scheme: "http", Host: "proxy.example.com"}, + returnResp: &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(fmt.Sprintf(`{"refresh_token":"%s"}`, refreshToken))), + }, + } + + artifactRepository := "acr-repo" + accessToken := &azure.Token{} + opts := []auth.Option{ + auth.WithProxyURL(url.URL{Scheme: "http", Host: "proxy.example.com"}), + } + + provider := azure.Provider{Implementation: impl} + token, err := provider.NewArtifactRegistryToken(context.Background(), artifactRepository, accessToken, opts...) + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(token).NotTo(BeNil()) +} diff --git a/auth/gcp/implementation.go b/auth/gcp/implementation.go new file mode 100644 index 000000000..fba765f98 --- /dev/null +++ b/auth/gcp/implementation.go @@ -0,0 +1,41 @@ +/* +Copyright 2025 The Flux authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package gcp + +import ( + "context" + + "golang.org/x/oauth2" + "golang.org/x/oauth2/google" + "golang.org/x/oauth2/google/externalaccount" +) + +// Implementation provides the required methods of the GCP libraries. +type Implementation interface { + DefaultTokenSource(ctx context.Context, scope ...string) (oauth2.TokenSource, error) + NewTokenSource(ctx context.Context, conf externalaccount.Config) (oauth2.TokenSource, error) +} + +type implementation struct{} + +func (implementation) DefaultTokenSource(ctx context.Context, scope ...string) (oauth2.TokenSource, error) { + return google.DefaultTokenSource(ctx, scope...) +} + +func (implementation) NewTokenSource(ctx context.Context, conf externalaccount.Config) (oauth2.TokenSource, error) { + return externalaccount.NewTokenSource(ctx, conf) +} diff --git a/auth/gcp/implementation_test.go b/auth/gcp/implementation_test.go new file mode 100644 index 000000000..98019a689 --- /dev/null +++ b/auth/gcp/implementation_test.go @@ -0,0 +1,70 @@ +/* +Copyright 2025 The Flux authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package gcp_test + +import ( + "context" + "net/http" + "net/url" + "testing" + + . "github.com/onsi/gomega" + "golang.org/x/oauth2" + "golang.org/x/oauth2/google/externalaccount" +) + +type mockImplementation struct { + t *testing.T + + argConfig externalaccount.Config + argProxyURL *url.URL +} + +func (m *mockImplementation) DefaultTokenSource(ctx context.Context, scope ...string) (oauth2.TokenSource, error) { + m.t.Helper() + g := NewWithT(m.t) + g.Expect(ctx).NotTo(BeNil()) + g.Expect(ctx.Value(oauth2.HTTPClient)).NotTo(BeNil()) + g.Expect(ctx.Value(oauth2.HTTPClient).(*http.Client)).NotTo(BeNil()) + g.Expect(ctx.Value(oauth2.HTTPClient).(*http.Client).Transport).NotTo(BeNil()) + g.Expect(ctx.Value(oauth2.HTTPClient).(*http.Client).Transport.(*http.Transport)).NotTo(BeNil()) + g.Expect(ctx.Value(oauth2.HTTPClient).(*http.Client).Transport.(*http.Transport).Proxy).NotTo(BeNil()) + proxyURL, err := ctx.Value(oauth2.HTTPClient).(*http.Client).Transport.(*http.Transport).Proxy(nil) + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(proxyURL).To(Equal(m.argProxyURL)) + g.Expect(scope).To(Equal([]string{ + "https://www.googleapis.com/auth/cloud-platform", + "https://www.googleapis.com/auth/userinfo.email", + })) + return oauth2.StaticTokenSource(&oauth2.Token{}), nil +} + +func (m *mockImplementation) NewTokenSource(ctx context.Context, conf externalaccount.Config) (oauth2.TokenSource, error) { + m.t.Helper() + g := NewWithT(m.t) + g.Expect(ctx).NotTo(BeNil()) + g.Expect(ctx.Value(oauth2.HTTPClient)).NotTo(BeNil()) + g.Expect(ctx.Value(oauth2.HTTPClient).(*http.Client)).NotTo(BeNil()) + g.Expect(ctx.Value(oauth2.HTTPClient).(*http.Client).Transport).NotTo(BeNil()) + g.Expect(ctx.Value(oauth2.HTTPClient).(*http.Client).Transport.(*http.Transport)).NotTo(BeNil()) + g.Expect(ctx.Value(oauth2.HTTPClient).(*http.Client).Transport.(*http.Transport).Proxy).NotTo(BeNil()) + proxyURL, err := ctx.Value(oauth2.HTTPClient).(*http.Client).Transport.(*http.Transport).Proxy(nil) + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(proxyURL).To(Equal(m.argProxyURL)) + g.Expect(conf).To(Equal(m.argConfig)) + return oauth2.StaticTokenSource(&oauth2.Token{}), nil +} diff --git a/auth/gcp/provider.go b/auth/gcp/provider.go index 53989ec15..f51b88923 100644 --- a/auth/gcp/provider.go +++ b/auth/gcp/provider.go @@ -21,7 +21,6 @@ import ( "fmt" "golang.org/x/oauth2" - "golang.org/x/oauth2/google" "golang.org/x/oauth2/google/externalaccount" corev1 "k8s.io/api/core/v1" @@ -37,7 +36,7 @@ var scopes = []string{ } // Provider implements the auth.Provider interface for GCP authentication. -type Provider struct{} +type Provider struct{ Implementation } // GetName implements auth.Provider. func (Provider) GetName() string { @@ -45,7 +44,7 @@ func (Provider) GetName() string { } // NewDefaultToken implements auth.Provider. -func (Provider) NewDefaultToken(ctx context.Context, opts ...auth.Option) (auth.Token, error) { +func (p Provider) NewDefaultToken(ctx context.Context, opts ...auth.Option) (auth.Token, error) { var o auth.Options o.Apply(opts...) @@ -53,7 +52,7 @@ func (Provider) NewDefaultToken(ctx context.Context, opts ...auth.Option) (auth. ctx = context.WithValue(ctx, oauth2.HTTPClient, hc) } - src, err := google.DefaultTokenSource(ctx, scopes...) + src, err := p.impl().DefaultTokenSource(ctx, scopes...) if err != nil { return nil, err } @@ -80,7 +79,7 @@ func (Provider) GetIdentity(serviceAccount corev1.ServiceAccount) (string, error } // NewTokenForServiceAccount implements auth.Provider. -func (Provider) NewTokenForServiceAccount(ctx context.Context, oidcToken string, +func (p Provider) NewTokenForServiceAccount(ctx context.Context, oidcToken string, serviceAccount corev1.ServiceAccount, opts ...auth.Option) (auth.Token, error) { var o auth.Options @@ -96,7 +95,7 @@ func (Provider) NewTokenForServiceAccount(ctx context.Context, oidcToken string, Audience: audience, SubjectTokenType: "urn:ietf:params:oauth:token-type:jwt", TokenURL: "https://sts.googleapis.com/v1/token", - SubjectTokenSupplier: tokenSupplier(oidcToken), + SubjectTokenSupplier: TokenSupplier(oidcToken), Scopes: scopes, } @@ -117,7 +116,7 @@ func (Provider) NewTokenForServiceAccount(ctx context.Context, oidcToken string, ctx = context.WithValue(ctx, oauth2.HTTPClient, hc) } - src, err := externalaccount.NewTokenSource(ctx, conf) + src, err := p.impl().NewTokenSource(ctx, conf) if err != nil { return nil, err } @@ -148,3 +147,10 @@ func (Provider) NewArtifactRegistryToken(ctx context.Context, artifactRepository ExpiresAt: t.Expiry, }, nil } + +func (p Provider) impl() Implementation { + if p.Implementation == nil { + return implementation{} + } + return p.Implementation +} diff --git a/auth/gcp/provider_test.go b/auth/gcp/provider_test.go new file mode 100644 index 000000000..3f9082539 --- /dev/null +++ b/auth/gcp/provider_test.go @@ -0,0 +1,158 @@ +/* +Copyright 2025 The Flux authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package gcp_test + +import ( + "context" + "errors" + "fmt" + "net" + "net/http" + "net/url" + "strings" + "testing" + "time" + + . "github.com/onsi/gomega" + "golang.org/x/oauth2/google/externalaccount" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + "github.com/fluxcd/pkg/auth" + "github.com/fluxcd/pkg/auth/gcp" +) + +func TestProvider_NewDefaultToken_Options(t *testing.T) { + g := NewWithT(t) + + impl := &mockImplementation{ + t: t, + argProxyURL: &url.URL{Scheme: "http", Host: "proxy.example.com"}, + } + + opts := []auth.Option{ + auth.WithProxyURL(url.URL{Scheme: "http", Host: "proxy.example.com"}), + } + + provider := gcp.Provider{Implementation: impl} + token, err := provider.NewDefaultToken(context.Background(), opts...) + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(token).NotTo(BeNil()) +} + +func TestProvider_NewTokenForServiceAccount_Options(t *testing.T) { + g := NewWithT(t) + + // Start GKE metadata server. + lis, err := net.Listen("tcp", ":0") + g.Expect(err).NotTo(HaveOccurred()) + gkeMetadataServer := &http.Server{ + Addr: lis.Addr().String(), + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/computeMetadata/v1/project/project-id": + fmt.Fprintf(w, "%s", "project-id") + case "/computeMetadata/v1/instance/attributes/cluster-location": + fmt.Fprintf(w, "%s", "cluster-location") + case "/computeMetadata/v1/instance/attributes/cluster-name": + fmt.Fprintf(w, "%s", "cluster-name") + } + }), + } + go func() { + err := gkeMetadataServer.Serve(lis) + if err != nil && !errors.Is(err, http.ErrServerClosed) { + g.Expect(err).NotTo(HaveOccurred()) + } + }() + t.Cleanup(func() { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + err := gkeMetadataServer.Shutdown(ctx) + g.Expect(err).NotTo(HaveOccurred()) + }) + gceMetadataHost := strings.TrimPrefix(lis.Addr().String(), "http://") + t.Setenv("GCE_METADATA_HOST", gceMetadataHost) + + for _, tt := range []struct { + name string + conf externalaccount.Config + saAnnotations map[string]string + }{ + { + name: "direct access", + conf: externalaccount.Config{ + Audience: "identitynamespace:project-id.svc.id.goog:https://container.googleapis.com/v1/projects/project-id/locations/cluster-location/clusters/cluster-name", + SubjectTokenType: "urn:ietf:params:oauth:token-type:jwt", + TokenURL: "https://sts.googleapis.com/v1/token", + TokenInfoURL: "https://sts.googleapis.com/v1/introspect", + Scopes: []string{ + "https://www.googleapis.com/auth/cloud-platform", + "https://www.googleapis.com/auth/userinfo.email", + }, + SubjectTokenSupplier: gcp.TokenSupplier("oidc-token"), + UniverseDomain: "googleapis.com", + }, + }, + { + name: "impersonation", + conf: externalaccount.Config{ + Audience: "identitynamespace:project-id.svc.id.goog:https://container.googleapis.com/v1/projects/project-id/locations/cluster-location/clusters/cluster-name", + SubjectTokenType: "urn:ietf:params:oauth:token-type:jwt", + TokenURL: "https://sts.googleapis.com/v1/token", + ServiceAccountImpersonationURL: "https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/test-sa@project-id.iam.gserviceaccount.com:generateAccessToken", + Scopes: []string{ + "https://www.googleapis.com/auth/cloud-platform", + "https://www.googleapis.com/auth/userinfo.email", + }, + SubjectTokenSupplier: gcp.TokenSupplier("oidc-token"), + UniverseDomain: "googleapis.com", + }, + saAnnotations: map[string]string{ + "iam.gke.io/gcp-service-account": "test-sa@project-id.iam.gserviceaccount.com", + }, + }, + } { + t.Run(tt.name, func(t *testing.T) { + g := NewWithT(t) + + impl := &mockImplementation{ + t: t, + argConfig: tt.conf, + argProxyURL: &url.URL{Scheme: "http", Host: "proxy.example.com"}, + } + + oidcToken := "oidc-token" + serviceAccount := corev1.ServiceAccount{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-sa", + Namespace: "test-ns", + Annotations: tt.saAnnotations, + }, + } + opts := []auth.Option{ + auth.WithProxyURL(url.URL{Scheme: "http", Host: "proxy.example.com"}), + auth.WithSTSEndpoint("https://sts.example.com"), + } + + provider := gcp.Provider{Implementation: impl} + token, err := provider.NewTokenForServiceAccount(context.Background(), oidcToken, serviceAccount, opts...) + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(token).NotTo(BeNil()) + }) + } +} diff --git a/auth/gcp/token_supplier.go b/auth/gcp/token_supplier.go index d427f4bac..5b91940e4 100644 --- a/auth/gcp/token_supplier.go +++ b/auth/gcp/token_supplier.go @@ -22,9 +22,10 @@ import ( "golang.org/x/oauth2/google/externalaccount" ) -type tokenSupplier string +// TokenSupplier provides a static OIDC token. +type TokenSupplier string // SubjectToken implements externalaccount.SubjectTokenSupplier. -func (s tokenSupplier) SubjectToken(context.Context, externalaccount.SupplierOptions) (string, error) { +func (s TokenSupplier) SubjectToken(context.Context, externalaccount.SupplierOptions) (string, error) { return string(s), nil } diff --git a/auth/get_token_test.go b/auth/get_token_test.go index d145f60b0..9a8008639 100644 --- a/auth/get_token_test.go +++ b/auth/get_token_test.go @@ -48,7 +48,7 @@ func (m *mockToken) GetDuration() time.Duration { } type mockProvider struct { - *testing.T + t *testing.T returnName string returnAudience string @@ -69,6 +69,7 @@ func (m *mockProvider) GetName() string { } func (m *mockProvider) NewDefaultToken(ctx context.Context, opts ...auth.Option) (auth.Token, error) { + checkOptions(m.t, opts...) return m.returnDefaultToken, nil } @@ -77,8 +78,8 @@ func (m *mockProvider) GetAudience(ctx context.Context) (string, error) { } func (m *mockProvider) GetIdentity(serviceAccount corev1.ServiceAccount) (string, error) { - m.Helper() - g := NewWithT(m) + m.t.Helper() + g := NewWithT(m.t) g.Expect(serviceAccount).To(Equal(m.paramServiceAccount)) if m.returnIdentityErr != "" { return "", errors.New(m.returnIdentityErr) @@ -89,8 +90,8 @@ func (m *mockProvider) GetIdentity(serviceAccount corev1.ServiceAccount) (string func (m *mockProvider) NewTokenForServiceAccount(ctx context.Context, oidcToken string, serviceAccount corev1.ServiceAccount, opts ...auth.Option) (auth.Token, error) { - m.Helper() - g := NewWithT(m) + m.t.Helper() + g := NewWithT(m.t) // Verify the OIDC token. g.Expect(m.returnAudience).NotTo(BeEmpty()) @@ -108,25 +109,40 @@ func (m *mockProvider) NewTokenForServiceAccount(ctx context.Context, oidcToken g.Expect(serviceAccount).To(Equal(m.paramServiceAccount)) + checkOptions(m.t, opts...) + return m.returnAccessToken, nil } func (m *mockProvider) GetArtifactCacheKey(artifactRepository string) string { - m.Helper() - g := NewWithT(m) + m.t.Helper() + g := NewWithT(m.t) g.Expect(artifactRepository).To(Equal(m.paramArtifactRepository)) return m.returnArtifactCacheKey } func (m *mockProvider) NewArtifactRegistryToken(ctx context.Context, artifactRepository string, accessToken auth.Token, opts ...auth.Option) (auth.Token, error) { - m.Helper() - g := NewWithT(m) + m.t.Helper() + g := NewWithT(m.t) g.Expect(artifactRepository).To(Equal(m.paramArtifactRepository)) g.Expect(accessToken).To(Equal(m.paramAccessToken)) + checkOptions(m.t, opts...) return m.returnRegistryToken, nil } +func checkOptions(t *testing.T, opts ...auth.Option) { + t.Helper() + g := NewWithT(t) + + var o auth.Options + o.Apply(opts...) + + g.Expect(o.Scopes).To(Equal([]string{"scope1", "scope2"})) + g.Expect(o.STSEndpoint).To(Equal("https://sts.some-cloud.io")) + g.Expect(o.ProxyURL).To(Equal(&url.URL{Scheme: "http", Host: "proxy.io:8080"})) +} + func TestGetToken(t *testing.T) { g := NewWithT(t) @@ -197,6 +213,11 @@ func TestGetToken(t *testing.T) { provider: &mockProvider{ returnDefaultToken: &mockToken{token: "mock-default-token"}, }, + opts: []auth.Option{ + auth.WithScopes("scope1", "scope2"), + auth.WithSTSEndpoint("https://sts.some-cloud.io"), + auth.WithProxyURL(url.URL{Scheme: "http", Host: "proxy.io:8080"}), + }, expectedToken: &mockToken{token: "mock-default-token"}, }, { @@ -209,6 +230,9 @@ func TestGetToken(t *testing.T) { }, opts: []auth.Option{ auth.WithArtifactRepository("some-registry.io/some/artifact"), + auth.WithScopes("scope1", "scope2"), + auth.WithSTSEndpoint("https://sts.some-cloud.io"), + auth.WithProxyURL(url.URL{Scheme: "http", Host: "proxy.io:8080"}), }, expectedToken: &mockToken{token: "mock-registry-token"}, }, @@ -223,6 +247,9 @@ func TestGetToken(t *testing.T) { }, opts: []auth.Option{ auth.WithServiceAccount(saRef, kubeClient), + auth.WithScopes("scope1", "scope2"), + auth.WithSTSEndpoint("https://sts.some-cloud.io"), + auth.WithProxyURL(url.URL{Scheme: "http", Host: "proxy.io:8080"}), // Exercise the code path where a cache is set but no token is // available in the cache. func(o *auth.Options) { @@ -248,6 +275,9 @@ func TestGetToken(t *testing.T) { opts: []auth.Option{ auth.WithServiceAccount(saRef, kubeClient), auth.WithArtifactRepository("some-registry.io/some/artifact"), + auth.WithScopes("scope1", "scope2"), + auth.WithSTSEndpoint("https://sts.some-cloud.io"), + auth.WithProxyURL(url.URL{Scheme: "http", Host: "proxy.io:8080"}), }, expectedToken: &mockToken{token: "mock-registry-token"}, }, @@ -315,7 +345,7 @@ func TestGetToken(t *testing.T) { t.Run(tt.name, func(t *testing.T) { g := NewWithT(t) - tt.provider.T = t + tt.provider.t = t token, err := auth.GetToken(ctx, tt.provider, tt.opts...) diff --git a/auth/go.mod b/auth/go.mod index be4e1518e..73c59a84d 100644 --- a/auth/go.mod +++ b/auth/go.mod @@ -17,7 +17,7 @@ require ( github.com/aws/aws-sdk-go-v2/service/ecr v1.43.3 github.com/aws/aws-sdk-go-v2/service/sts v1.33.19 github.com/coreos/go-oidc/v3 v3.14.1 - github.com/fluxcd/pkg/cache v0.8.0 + github.com/fluxcd/pkg/cache v0.9.0 github.com/golang-jwt/jwt/v5 v5.2.2 github.com/onsi/gomega v1.37.0 golang.org/x/oauth2 v0.28.0 diff --git a/git/go.mod b/git/go.mod index 1fcd355b0..efbe4c344 100644 --- a/git/go.mod +++ b/git/go.mod @@ -13,8 +13,8 @@ require ( github.com/ProtonMail/go-crypto v1.2.0 github.com/bradleyfalzon/ghinstallation/v2 v2.15.0 github.com/cyphar/filepath-securejoin v0.4.1 - github.com/fluxcd/pkg/auth v0.10.0 - github.com/fluxcd/pkg/cache v0.8.0 + github.com/fluxcd/pkg/auth v0.11.0 + github.com/fluxcd/pkg/cache v0.9.0 github.com/fluxcd/pkg/ssh v0.18.0 github.com/onsi/gomega v1.37.0 golang.org/x/net v0.39.0 diff --git a/git/gogit/go.mod b/git/gogit/go.mod index 0e0c10b22..d3aa66836 100644 --- a/git/gogit/go.mod +++ b/git/gogit/go.mod @@ -17,9 +17,9 @@ require ( github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 github.com/elazarl/goproxy v1.7.2 github.com/fluxcd/gitkit v0.6.0 - github.com/fluxcd/pkg/auth v0.10.0 - github.com/fluxcd/pkg/cache v0.8.0 - github.com/fluxcd/pkg/git v0.27.0 + github.com/fluxcd/pkg/auth v0.11.0 + github.com/fluxcd/pkg/cache v0.9.0 + github.com/fluxcd/pkg/git v0.28.0 github.com/fluxcd/pkg/gittestserver v0.17.0 github.com/fluxcd/pkg/ssh v0.18.0 github.com/fluxcd/pkg/version v0.7.0 diff --git a/git/internal/e2e/go.mod b/git/internal/e2e/go.mod index 9dd4cd786..8645c5f70 100644 --- a/git/internal/e2e/go.mod +++ b/git/internal/e2e/go.mod @@ -14,7 +14,7 @@ replace ( require ( github.com/fluxcd/go-git-providers v0.22.0 - github.com/fluxcd/pkg/git v0.27.0 + github.com/fluxcd/pkg/git v0.28.0 github.com/fluxcd/pkg/git/gogit v0.23.0 github.com/fluxcd/pkg/gittestserver v0.17.0 github.com/fluxcd/pkg/ssh v0.18.0 @@ -44,8 +44,8 @@ require ( github.com/emirpasic/gods v1.18.1 // indirect github.com/evanphx/json-patch/v5 v5.9.11 // indirect github.com/fluxcd/gitkit v0.6.0 // indirect - github.com/fluxcd/pkg/auth v0.10.0 // indirect - github.com/fluxcd/pkg/cache v0.8.0 // indirect + github.com/fluxcd/pkg/auth v0.11.0 // indirect + github.com/fluxcd/pkg/cache v0.9.0 // indirect github.com/fluxcd/pkg/version v0.7.0 // indirect github.com/fxamacker/cbor/v2 v2.7.0 // indirect github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 // indirect diff --git a/oci/tests/integration/go.mod b/oci/tests/integration/go.mod index 49a0117d8..aae149427 100644 --- a/oci/tests/integration/go.mod +++ b/oci/tests/integration/go.mod @@ -11,8 +11,8 @@ replace ( ) require ( - github.com/fluxcd/pkg/auth v0.10.0 - github.com/fluxcd/pkg/git v0.27.0 + github.com/fluxcd/pkg/auth v0.11.0 + github.com/fluxcd/pkg/git v0.28.0 github.com/fluxcd/pkg/git/gogit v0.23.0 github.com/fluxcd/pkg/oci v0.43.1 github.com/fluxcd/test-infra/tftestenv v0.0.0-20240903092121-c783b14801d1 @@ -66,7 +66,7 @@ require ( github.com/emirpasic/gods v1.18.1 // indirect github.com/evanphx/json-patch v5.7.0+incompatible // indirect github.com/evanphx/json-patch/v5 v5.9.11 // indirect - github.com/fluxcd/pkg/cache v0.8.0 // indirect + github.com/fluxcd/pkg/cache v0.9.0 // indirect github.com/fluxcd/pkg/ssh v0.18.0 // indirect github.com/fluxcd/pkg/version v0.7.0 // indirect github.com/fsnotify/fsnotify v1.7.0 // indirect