From dadaad6e0b9629adaa768b2a5766f2c5018cdc31 Mon Sep 17 00:00:00 2001 From: Simo Lin Date: Mon, 30 Jun 2025 14:30:34 -0700 Subject: [PATCH] [misc] add aws auth for accesskey, ip, assumerole, wi, etc --- go.mod | 12 + go.sum | 27 ++ pkg/auth/aws/config.go | 100 +++++ pkg/auth/aws/credentials.go | 175 +++++++++ pkg/auth/aws/credentials_test.go | 611 +++++++++++++++++++++++++++++++ pkg/auth/aws/factory.go | 386 +++++++++++++++++++ pkg/auth/aws/factory_test.go | 449 +++++++++++++++++++++++ pkg/auth/interfaces.go | 2 + 8 files changed, 1762 insertions(+) create mode 100644 pkg/auth/aws/config.go create mode 100644 pkg/auth/aws/credentials.go create mode 100644 pkg/auth/aws/credentials_test.go create mode 100644 pkg/auth/aws/factory.go create mode 100644 pkg/auth/aws/factory_test.go diff --git a/go.mod b/go.mod index 1a008533..963ca5a0 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,10 @@ go 1.24.1 require ( fortio.org/progressbar v1.1.0 + github.com/aws/aws-sdk-go-v2 v1.21.0 + github.com/aws/aws-sdk-go-v2/config v1.18.42 + github.com/aws/aws-sdk-go-v2/credentials v1.13.40 + github.com/aws/aws-sdk-go-v2/service/sts v1.22.0 github.com/fsnotify/fsnotify v1.9.0 github.com/gin-gonic/gin v1.10.0 github.com/go-logr/logr v1.4.2 @@ -64,6 +68,14 @@ require ( github.com/NYTimes/gziphandler v1.1.1 // indirect github.com/antlr4-go/antlr/v4 v4.13.0 // indirect github.com/antonmedv/expr v1.15.3 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.13.11 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.41 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.35 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.3.43 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.35 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.14.1 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.17.1 // indirect + github.com/aws/smithy-go v1.14.2 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/blang/semver/v4 v4.0.0 // indirect github.com/blendle/zapdriver v1.3.1 // indirect diff --git a/go.sum b/go.sum index cab24b93..b0fc920b 100644 --- a/go.sum +++ b/go.sum @@ -58,6 +58,30 @@ github.com/antlr4-go/antlr/v4 v4.13.0 h1:lxCg3LAv+EUK6t1i0y1V6/SLeUi0eKEKdhQAlS8 github.com/antlr4-go/antlr/v4 v4.13.0/go.mod h1:pfChB/xh/Unjila75QW7+VU4TSnWnnk9UTnmpPaOR2g= github.com/antonmedv/expr v1.15.3 h1:q3hOJZNvLvhqE8OHBs1cFRdbXFNKuA+bHmRaI+AmRmI= github.com/antonmedv/expr v1.15.3/go.mod h1:0E/6TxnOlRNp81GMzX9QfDPAmHo2Phg00y4JUv1ihsE= +github.com/aws/aws-sdk-go-v2 v1.21.0 h1:gMT0IW+03wtYJhRqTVYn0wLzwdnK9sRMcxmtfGzRdJc= +github.com/aws/aws-sdk-go-v2 v1.21.0/go.mod h1:/RfNgGmRxI+iFOB1OeJUyxiU+9s88k3pfHvDagGEp0M= +github.com/aws/aws-sdk-go-v2/config v1.18.42 h1:28jHROB27xZwU0CB88giDSjz7M1Sba3olb5JBGwina8= +github.com/aws/aws-sdk-go-v2/config v1.18.42/go.mod h1:4AZM3nMMxwlG+eZlxvBKqwVbkDLlnN2a4UGTL6HjaZI= +github.com/aws/aws-sdk-go-v2/credentials v1.13.40 h1:s8yOkDh+5b1jUDhMBtngF6zKWLDs84chUk2Vk0c38Og= +github.com/aws/aws-sdk-go-v2/credentials v1.13.40/go.mod h1:VtEHVAAqDWASwdOqj/1huyT6uHbs5s8FUHfDQdky/Rs= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.13.11 h1:uDZJF1hu0EVT/4bogChk8DyjSF6fof6uL/0Y26Ma7Fg= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.13.11/go.mod h1:TEPP4tENqBGO99KwVpV9MlOX4NSrSLP8u3KRy2CDwA8= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.41 h1:22dGT7PneFMx4+b3pz7lMTRyN8ZKH7M2cW4GP9yUS2g= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.41/go.mod h1:CrObHAuPneJBlfEJ5T3szXOUkLEThaGfvnhTf33buas= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.35 h1:SijA0mgjV8E+8G45ltVHs0fvKpTj8xmZJ3VwhGKtUSI= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.35/go.mod h1:SJC1nEVVva1g3pHAIdCp7QsRIkMmLAgoDquQ9Rr8kYw= +github.com/aws/aws-sdk-go-v2/internal/ini v1.3.43 h1:g+qlObJH4Kn4n21g69DjspU0hKTjWtq7naZ9OLCv0ew= +github.com/aws/aws-sdk-go-v2/internal/ini v1.3.43/go.mod h1:rzfdUlfA+jdgLDmPKjd3Chq9V7LVLYo1Nz++Wb91aRo= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.35 h1:CdzPW9kKitgIiLV1+MHobfR5Xg25iYnyzWZhyQuSlDI= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.35/go.mod h1:QGF2Rs33W5MaN9gYdEQOBBFPLwTZkEhRwI33f7KIG0o= +github.com/aws/aws-sdk-go-v2/service/sso v1.14.1 h1:YkNzx1RLS0F5qdf9v1Q8Cuv9NXCL2TkosOxhzlUPV64= +github.com/aws/aws-sdk-go-v2/service/sso v1.14.1/go.mod h1:fIAwKQKBFu90pBxx07BFOMJLpRUGu8VOzLJakeY+0K4= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.17.1 h1:8lKOidPkmSmfUtiTgtdXWgaKItCZ/g75/jEk6Ql6GsA= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.17.1/go.mod h1:yygr8ACQRY2PrEcy3xsUI357stq2AxnFM6DIsR9lij4= +github.com/aws/aws-sdk-go-v2/service/sts v1.22.0 h1:s4bioTgjSFRwOoyEFzAVCmFmoowBgjTR8gkrF/sQ4wk= +github.com/aws/aws-sdk-go-v2/service/sts v1.22.0/go.mod h1:VC7JDqsqiwXukYEDjoHh9U0fOJtNWh04FPQz4ct4GGU= +github.com/aws/smithy-go v1.14.2 h1:MJU9hqBGbvWZdApzpvoF2WAIJDbtjK2NDJSiJP7HblQ= +github.com/aws/smithy-go v1.14.2/go.mod h1:Tg+OJXh4MB2R/uN61Ko2f6hTZwB/ZYGOtib8J3gBHzA= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= @@ -289,6 +313,8 @@ github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2 github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/jarcoal/httpmock v1.2.0 h1:gSvTxxFR/MEMfsGrvRbdfpRUMBStovlSRLw0Ep1bwwc= github.com/jarcoal/httpmock v1.2.0/go.mod h1:oCoTsnAz4+UoOUIf5lJOWV2QQIW5UoeUI6aM2YnWAZk= +github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= +github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U= github.com/jonboulle/clockwork v0.4.0 h1:p4Cf1aMWXnXAUh8lVfewRBx1zaTSYKrKMF2g3ST4RZ4= github.com/jonboulle/clockwork v0.4.0/go.mod h1:xgRqUGwRcjKCO1vbZUEtSLrqKoPSsUpK7fnezOII0kc= github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= @@ -895,6 +921,7 @@ gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.3/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.5/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= diff --git a/pkg/auth/aws/config.go b/pkg/auth/aws/config.go new file mode 100644 index 00000000..d161e954 --- /dev/null +++ b/pkg/auth/aws/config.go @@ -0,0 +1,100 @@ +package aws + +import ( + "fmt" + "time" +) + +// AccessKeyConfig represents AWS access key configuration +type AccessKeyConfig struct { + AccessKeyID string `mapstructure:"access_key_id" json:"access_key_id"` + SecretAccessKey string `mapstructure:"secret_access_key" json:"secret_access_key"` + SessionToken string `mapstructure:"session_token" json:"session_token,omitempty"` +} + +// Validate validates the access key configuration +func (c *AccessKeyConfig) Validate() error { + if c.AccessKeyID == "" { + return fmt.Errorf("access_key_id is required") + } + if c.SecretAccessKey == "" { + return fmt.Errorf("secret_access_key is required") + } + return nil +} + +// AssumeRoleConfig represents AWS assume role configuration +type AssumeRoleConfig struct { + RoleARN string `mapstructure:"role_arn" json:"role_arn"` + RoleSessionName string `mapstructure:"role_session_name" json:"role_session_name,omitempty"` + ExternalID string `mapstructure:"external_id" json:"external_id,omitempty"` + Duration time.Duration `mapstructure:"duration" json:"duration,omitempty"` + Tags map[string]string `mapstructure:"tags" json:"tags,omitempty"` +} + +// Validate validates the assume role configuration +func (c *AssumeRoleConfig) Validate() error { + if c.RoleARN == "" { + return fmt.Errorf("role_arn is required") + } + return nil +} + +// WebIdentityConfig represents AWS web identity configuration +type WebIdentityConfig struct { + RoleARN string `mapstructure:"role_arn" json:"role_arn"` + TokenFile string `mapstructure:"token_file" json:"token_file"` + RoleSessionName string `mapstructure:"role_session_name" json:"role_session_name,omitempty"` +} + +// Validate validates the web identity configuration +func (c *WebIdentityConfig) Validate() error { + if c.RoleARN == "" { + return fmt.Errorf("role_arn is required for web identity") + } + if c.TokenFile == "" { + return fmt.Errorf("token_file is required for web identity") + } + return nil +} + +// ECSTaskRoleConfig represents ECS task role configuration +type ECSTaskRoleConfig struct { + // RelativeURI is the relative URI to the ECS credentials endpoint + // If not specified, it will be read from AWS_CONTAINER_CREDENTIALS_RELATIVE_URI + RelativeURI string `mapstructure:"relative_uri" json:"relative_uri,omitempty"` + + // FullURI is the full URI to the ECS credentials endpoint + // If not specified, it will be read from AWS_CONTAINER_CREDENTIALS_FULL_URI + FullURI string `mapstructure:"full_uri" json:"full_uri,omitempty"` + + // AuthorizationToken is used for authentication with the ECS credentials endpoint + // If not specified, it will be read from AWS_CONTAINER_AUTHORIZATION_TOKEN + AuthorizationToken string `mapstructure:"authorization_token" json:"authorization_token,omitempty"` +} + +// Validate validates the ECS task role configuration +func (c *ECSTaskRoleConfig) Validate() error { + // Either RelativeURI or FullURI must be specified + if c.RelativeURI == "" && c.FullURI == "" { + return fmt.Errorf("either relative_uri or full_uri must be specified for ECS task role") + } + return nil +} + +// ProcessConfig represents process credentials provider configuration +type ProcessConfig struct { + // Command is the command to execute to retrieve credentials + Command string `mapstructure:"command" json:"command"` + + // Timeout is the maximum time to wait for the process to complete + Timeout time.Duration `mapstructure:"timeout" json:"timeout,omitempty"` +} + +// Validate validates the process configuration +func (c *ProcessConfig) Validate() error { + if c.Command == "" { + return fmt.Errorf("command is required for process credentials provider") + } + return nil +} diff --git a/pkg/auth/aws/credentials.go b/pkg/auth/aws/credentials.go new file mode 100644 index 00000000..33856565 --- /dev/null +++ b/pkg/auth/aws/credentials.go @@ -0,0 +1,175 @@ +package aws + +import ( + "context" + "fmt" + "net/http" + "strings" + "sync" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4" + "github.com/sgl-project/ome/pkg/auth" + "github.com/sgl-project/ome/pkg/logging" +) + +// AWSCredentials implements auth.Credentials for AWS +type AWSCredentials struct { + credProvider aws.CredentialsProvider + authType auth.AuthType + region string + logger logging.Interface + + // Mutex protects cached credentials + mu sync.RWMutex + cachedCreds *aws.Credentials + cacheExpiry time.Time +} + +// Provider returns the provider type +func (c *AWSCredentials) Provider() auth.Provider { + return auth.ProviderAWS +} + +// Type returns the authentication type +func (c *AWSCredentials) Type() auth.AuthType { + return c.authType +} + +// Token retrieves the AWS credentials as a token string +func (c *AWSCredentials) Token(ctx context.Context) (string, error) { + creds, err := c.getCredentials(ctx) + if err != nil { + return "", err + } + + // Return formatted token (access key for identification) + return creds.AccessKeyID, nil +} + +// SignRequest signs an HTTP request with AWS v4 signature +func (c *AWSCredentials) SignRequest(ctx context.Context, req *http.Request) error { + creds, err := c.getCredentials(ctx) + if err != nil { + return fmt.Errorf("failed to get credentials: %w", err) + } + + // Create a signer + signer := v4.NewSigner() + + // Determine service from host + service := extractServiceFromHost(req.Host) + + // Calculate payload hash (empty for GET requests, unsigned for others) + payloadHash := "UNSIGNED-PAYLOAD" + if req.Method == http.MethodGet || req.Method == http.MethodHead { + payloadHash = "" + } + + // Sign the request + err = signer.SignHTTP(ctx, *creds, req, payloadHash, service, c.region, time.Now()) + if err != nil { + return fmt.Errorf("failed to sign request: %w", err) + } + + return nil +} + +// Refresh refreshes the credentials +func (c *AWSCredentials) Refresh(ctx context.Context) error { + // Clear cache to force refresh + c.mu.Lock() + c.cachedCreds = nil + c.cacheExpiry = time.Time{} + c.mu.Unlock() + + // Try to get new credentials + _, err := c.getCredentials(ctx) + return err +} + +// IsExpired checks if the credentials are expired +func (c *AWSCredentials) IsExpired() bool { + c.mu.RLock() + defer c.mu.RUnlock() + + if c.cachedCreds == nil { + return true + } + return time.Now().After(c.cacheExpiry) +} + +// GetRegion returns the AWS region +func (c *AWSCredentials) GetRegion() string { + return c.region +} + +// GetCredentialsProvider returns the underlying AWS credentials provider +func (c *AWSCredentials) GetCredentialsProvider() aws.CredentialsProvider { + return c.credProvider +} + +// getCredentials retrieves and caches AWS credentials +func (c *AWSCredentials) getCredentials(ctx context.Context) (*aws.Credentials, error) { + // Check cache with read lock + c.mu.RLock() + if c.cachedCreds != nil && time.Now().Before(c.cacheExpiry) { + creds := *c.cachedCreds + c.mu.RUnlock() + return &creds, nil + } + c.mu.RUnlock() + + // Need to refresh - acquire write lock + c.mu.Lock() + defer c.mu.Unlock() + + // Double-check after acquiring write lock + if c.cachedCreds != nil && time.Now().Before(c.cacheExpiry) { + return c.cachedCreds, nil + } + + // Retrieve new credentials + creds, err := c.credProvider.Retrieve(ctx) + if err != nil { + return nil, fmt.Errorf("failed to retrieve credentials: %w", err) + } + + // Cache credentials + c.cachedCreds = &creds + if creds.Expires.IsZero() { + // If no expiry, cache for 1 hour + c.cacheExpiry = time.Now().Add(1 * time.Hour) + } else { + // Cache until 5 minutes before expiry + c.cacheExpiry = creds.Expires.Add(-5 * time.Minute) + } + + return &creds, nil +} + +// extractServiceFromHost extracts the AWS service name from the host +func extractServiceFromHost(host string) string { + // Remove port if present + if idx := strings.LastIndex(host, ":"); idx != -1 { + host = host[:idx] + } + + // Extract service from standard AWS domain pattern + // Examples: s3.amazonaws.com, dynamodb.us-east-1.amazonaws.com + parts := strings.Split(host, ".") + if len(parts) >= 2 { + // Check for service.region.amazonaws.com pattern + if len(parts) >= 3 && parts[len(parts)-2] == "amazonaws" { + return parts[0] + } + // Check for service.amazonaws.com pattern + if parts[1] == "amazonaws" { + return parts[0] + } + } + + // Default to s3 for unknown patterns + return "s3" +} diff --git a/pkg/auth/aws/credentials_test.go b/pkg/auth/aws/credentials_test.go new file mode 100644 index 00000000..ee78c478 --- /dev/null +++ b/pkg/auth/aws/credentials_test.go @@ -0,0 +1,611 @@ +package aws + +import ( + "context" + "errors" + "net/http" + "testing" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/sgl-project/ome/pkg/auth" + "github.com/sgl-project/ome/pkg/logging" + "go.uber.org/zap/zaptest" +) + +// mockCredentialsProvider implements aws.CredentialsProvider for testing +type mockCredentialsProvider struct { + creds aws.Credentials + retrieveErr error +} + +func (m *mockCredentialsProvider) Retrieve(ctx context.Context) (aws.Credentials, error) { + if m.retrieveErr != nil { + return aws.Credentials{}, m.retrieveErr + } + return m.creds, nil +} + +// createStaticCredentialsProvider creates a static credentials provider for testing +func createStaticCredentialsProvider(config AccessKeyConfig) aws.CredentialsProvider { + return credentials.NewStaticCredentialsProvider( + config.AccessKeyID, + config.SecretAccessKey, + config.SessionToken, + ) +} + +func TestAWSCredentials_Provider(t *testing.T) { + creds := &AWSCredentials{ + authType: auth.AWSAccessKey, + region: "us-east-1", + logger: logging.ForZap(zaptest.NewLogger(t)), + } + + if provider := creds.Provider(); provider != auth.ProviderAWS { + t.Errorf("Expected provider %s, got %s", auth.ProviderAWS, provider) + } +} + +func TestAWSCredentials_Type(t *testing.T) { + tests := []struct { + name string + authType auth.AuthType + }{ + { + name: "Access Key", + authType: auth.AWSAccessKey, + }, + { + name: "Assume Role", + authType: auth.AWSAssumeRole, + }, + { + name: "Instance Profile", + authType: auth.AWSInstanceProfile, + }, + { + name: "Web Identity", + authType: auth.AWSWebIdentity, + }, + { + name: "Default", + authType: auth.AWSDefault, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + creds := &AWSCredentials{ + authType: tt.authType, + } + if typ := creds.Type(); typ != tt.authType { + t.Errorf("Expected type %s, got %s", tt.authType, typ) + } + }) + } +} + +func TestAWSCredentials_GetRegion(t *testing.T) { + creds := &AWSCredentials{ + region: "us-west-2", + } + + if region := creds.GetRegion(); region != "us-west-2" { + t.Errorf("Expected region us-west-2, got %s", region) + } +} + +func TestAWSCredentials_IsExpired(t *testing.T) { + tests := []struct { + name string + creds *AWSCredentials + expected bool + }{ + { + name: "No cached credentials", + creds: &AWSCredentials{}, + expected: true, + }, + { + name: "Valid credentials", + creds: &AWSCredentials{ + cachedCreds: &aws.Credentials{ + AccessKeyID: "test", + SecretAccessKey: "test", + }, + cacheExpiry: time.Now().Add(1 * time.Hour), + }, + expected: false, + }, + { + name: "Expired credentials", + creds: &AWSCredentials{ + cachedCreds: &aws.Credentials{ + AccessKeyID: "test", + SecretAccessKey: "test", + }, + cacheExpiry: time.Now().Add(-1 * time.Hour), + }, + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.creds.IsExpired() + if result != tt.expected { + t.Errorf("Expected %v, got %v", tt.expected, result) + } + }) + } +} + +func TestAWSCredentials_SignRequest(t *testing.T) { + creds := &AWSCredentials{ + credProvider: createStaticCredentialsProvider(AccessKeyConfig{ + AccessKeyID: "AKIAIOSFODNN7EXAMPLE", + SecretAccessKey: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + }), + region: "us-east-1", + logger: logging.ForZap(zaptest.NewLogger(t)), + } + + req, _ := http.NewRequest("GET", "https://s3.amazonaws.com/test-bucket/test-key", nil) + ctx := context.Background() + + err := creds.SignRequest(ctx, req) + if err != nil { + t.Errorf("Failed to sign request: %v", err) + } + + // Check that Authorization header was added + if req.Header.Get("Authorization") == "" { + t.Error("Expected Authorization header to be set") + } +} + +func TestAWSCredentials_SignRequest_Error(t *testing.T) { + ctx := context.Background() + mockProvider := &mockCredentialsProvider{ + retrieveErr: errors.New("test error"), + } + + creds := &AWSCredentials{ + credProvider: mockProvider, + region: "us-west-2", + logger: logging.ForZap(zaptest.NewLogger(t)), + } + + req, _ := http.NewRequest("GET", "https://s3.amazonaws.com/test-bucket/test-object", nil) + + err := creds.SignRequest(ctx, req) + if err == nil { + t.Error("Expected error but got none") + } +} + +func TestAWSCredentials_Token(t *testing.T) { + ctx := context.Background() + + tests := []struct { + name string + provider *mockCredentialsProvider + wantToken string + wantError bool + }{ + { + name: "Valid credentials", + provider: &mockCredentialsProvider{ + creds: aws.Credentials{ + AccessKeyID: "test-key", + SecretAccessKey: "test-secret", + }, + }, + wantToken: "test-key", + wantError: false, + }, + { + name: "Error retrieving credentials", + provider: &mockCredentialsProvider{ + retrieveErr: errors.New("retrieve error"), + }, + wantError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + creds := &AWSCredentials{ + credProvider: tt.provider, + logger: logging.ForZap(zaptest.NewLogger(t)), + } + + token, err := creds.Token(ctx) + + if tt.wantError { + if err == nil { + t.Error("Expected error but got none") + } + } else { + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if token != tt.wantToken { + t.Errorf("Expected token %s, got %s", tt.wantToken, token) + } + } + }) + } +} + +func TestAWSCredentials_Refresh(t *testing.T) { + ctx := context.Background() + + mockProvider := &mockCredentialsProvider{ + creds: aws.Credentials{ + AccessKeyID: "new-key", + SecretAccessKey: "new-secret", + }, + } + + creds := &AWSCredentials{ + credProvider: mockProvider, + logger: logging.ForZap(zaptest.NewLogger(t)), + cachedCreds: &aws.Credentials{ + AccessKeyID: "old-key", + }, + cacheExpiry: time.Now().Add(1 * time.Hour), + } + + err := creds.Refresh(ctx) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + // After refresh, getCredentials is called which sets new cache + // So we should check that the cached creds are the new ones + if creds.cachedCreds == nil { + t.Error("Expected cached credentials to be set after refresh") + } else if creds.cachedCreds.AccessKeyID != "new-key" { + t.Errorf("Expected new credentials to be cached, got %s", creds.cachedCreds.AccessKeyID) + } +} + +func TestAWSCredentials_Refresh_Error(t *testing.T) { + ctx := context.Background() + + mockProvider := &mockCredentialsProvider{ + retrieveErr: errors.New("refresh error"), + } + + creds := &AWSCredentials{ + credProvider: mockProvider, + logger: logging.ForZap(zaptest.NewLogger(t)), + } + + err := creds.Refresh(ctx) + if err == nil { + t.Error("Expected error but got none") + } +} + +func TestAWSCredentials_GetCredentialsProvider(t *testing.T) { + mockProvider := &mockCredentialsProvider{} + creds := &AWSCredentials{ + credProvider: mockProvider, + } + + provider := creds.GetCredentialsProvider() + if provider != mockProvider { + t.Error("Expected to get the same credentials provider") + } +} + +func TestAWSCredentials_getCredentials_Caching(t *testing.T) { + ctx := context.Background() + callCount := 0 + + // Create a provider that counts calls + countingProvider := aws.CredentialsProviderFunc(func(ctx context.Context) (aws.Credentials, error) { + callCount++ + return aws.Credentials{ + AccessKeyID: "test-key", + SecretAccessKey: "test-secret", + Expires: time.Now().Add(2 * time.Hour), + }, nil + }) + + creds := &AWSCredentials{ + credProvider: countingProvider, + logger: logging.ForZap(zaptest.NewLogger(t)), + } + + // First call should retrieve credentials + _, err := creds.getCredentials(ctx) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + if callCount != 1 { + t.Errorf("Expected 1 call to provider, got %d", callCount) + } + + // Second call should use cache + _, err = creds.getCredentials(ctx) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + if callCount != 1 { + t.Errorf("Expected 1 call to provider (cached), got %d", callCount) + } +} + +func TestAWSCredentials_getCredentials_NoExpiry(t *testing.T) { + ctx := context.Background() + + provider := &mockCredentialsProvider{ + creds: aws.Credentials{ + AccessKeyID: "test-key", + SecretAccessKey: "test-secret", + // No Expires field - should cache for 1 hour + }, + } + + creds := &AWSCredentials{ + credProvider: provider, + logger: logging.ForZap(zaptest.NewLogger(t)), + } + + _, err := creds.getCredentials(ctx) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + // Check that cache expiry was set to ~1 hour from now + expectedExpiry := time.Now().Add(1 * time.Hour) + if creds.cacheExpiry.Before(expectedExpiry.Add(-5*time.Minute)) || + creds.cacheExpiry.After(expectedExpiry.Add(5*time.Minute)) { + t.Errorf("Expected cache expiry around %v, got %v", expectedExpiry, creds.cacheExpiry) + } +} + +func TestAccessKeyConfig_Validate(t *testing.T) { + tests := []struct { + name string + config AccessKeyConfig + wantError bool + }{ + { + name: "Valid config", + config: AccessKeyConfig{ + AccessKeyID: "AKIAIOSFODNN7EXAMPLE", + SecretAccessKey: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + }, + wantError: false, + }, + { + name: "Valid config with session token", + config: AccessKeyConfig{ + AccessKeyID: "AKIAIOSFODNN7EXAMPLE", + SecretAccessKey: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + SessionToken: "FwoGZXIvYXdzEBYaDExampleSessionToken", + }, + wantError: false, + }, + { + name: "Missing access key ID", + config: AccessKeyConfig{ + SecretAccessKey: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + }, + wantError: true, + }, + { + name: "Missing secret access key", + config: AccessKeyConfig{ + AccessKeyID: "AKIAIOSFODNN7EXAMPLE", + }, + wantError: true, + }, + { + name: "Empty config", + config: AccessKeyConfig{}, + wantError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.config.Validate() + if (err != nil) != tt.wantError { + t.Errorf("Validate() error = %v, wantError %v", err, tt.wantError) + } + }) + } +} + +func TestAssumeRoleConfig_Validate(t *testing.T) { + tests := []struct { + name string + config AssumeRoleConfig + wantError bool + }{ + { + name: "Valid minimal config", + config: AssumeRoleConfig{ + RoleARN: "arn:aws:iam::123456789012:role/MyRole", + }, + wantError: false, + }, + { + name: "Valid full config", + config: AssumeRoleConfig{ + RoleARN: "arn:aws:iam::123456789012:role/MyRole", + RoleSessionName: "my-session", + ExternalID: "external-123", + Duration: 30 * time.Minute, + }, + wantError: false, + }, + { + name: "Missing role ARN", + config: AssumeRoleConfig{ + RoleSessionName: "my-session", + }, + wantError: true, + }, + { + name: "Empty config", + config: AssumeRoleConfig{}, + wantError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.config.Validate() + if (err != nil) != tt.wantError { + t.Errorf("Validate() error = %v, wantError %v", err, tt.wantError) + } + }) + } +} + +func TestWebIdentityConfig_Validate(t *testing.T) { + tests := []struct { + name string + config WebIdentityConfig + wantError bool + }{ + { + name: "Valid config", + config: WebIdentityConfig{ + RoleARN: "arn:aws:iam::123456789012:role/TestRole", + TokenFile: "/tmp/token", + }, + wantError: false, + }, + { + name: "Valid config with session name", + config: WebIdentityConfig{ + RoleARN: "arn:aws:iam::123456789012:role/TestRole", + TokenFile: "/tmp/token", + RoleSessionName: "test-session", + }, + wantError: false, + }, + { + name: "Missing role ARN", + config: WebIdentityConfig{ + TokenFile: "/tmp/token", + }, + wantError: true, + }, + { + name: "Missing token file", + config: WebIdentityConfig{ + RoleARN: "arn:aws:iam::123456789012:role/TestRole", + }, + wantError: true, + }, + { + name: "Empty config", + config: WebIdentityConfig{}, + wantError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.config.Validate() + if (err != nil) != tt.wantError { + t.Errorf("Validate() error = %v, wantError %v", err, tt.wantError) + } + }) + } +} + +func TestCreateStaticCredentialsProvider(t *testing.T) { + config := AccessKeyConfig{ + AccessKeyID: "test-key", + SecretAccessKey: "test-secret", + SessionToken: "test-token", + } + + provider := createStaticCredentialsProvider(config) + + // Retrieve credentials + ctx := context.Background() + creds, err := provider.Retrieve(ctx) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if creds.AccessKeyID != config.AccessKeyID { + t.Errorf("Expected access key ID %s, got %s", config.AccessKeyID, creds.AccessKeyID) + } + if creds.SecretAccessKey != config.SecretAccessKey { + t.Errorf("Expected secret access key %s, got %s", config.SecretAccessKey, creds.SecretAccessKey) + } + if creds.SessionToken != config.SessionToken { + t.Errorf("Expected session token %s, got %s", config.SessionToken, creds.SessionToken) + } +} + +func TestExtractServiceFromHost(t *testing.T) { + tests := []struct { + name string + host string + expected string + }{ + { + name: "S3 global endpoint", + host: "s3.amazonaws.com", + expected: "s3", + }, + { + name: "S3 regional endpoint", + host: "s3.us-east-1.amazonaws.com", + expected: "s3", + }, + { + name: "DynamoDB regional endpoint", + host: "dynamodb.us-west-2.amazonaws.com", + expected: "dynamodb", + }, + { + name: "EC2 endpoint", + host: "ec2.eu-west-1.amazonaws.com", + expected: "ec2", + }, + { + name: "STS endpoint", + host: "sts.amazonaws.com", + expected: "sts", + }, + { + name: "Host with port", + host: "s3.amazonaws.com:443", + expected: "s3", + }, + { + name: "Custom domain", + host: "my-bucket.example.com", + expected: "s3", + }, + { + name: "Localhost", + host: "localhost:9000", + expected: "s3", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := extractServiceFromHost(tt.host) + if result != tt.expected { + t.Errorf("Expected service %s for host %s, got %s", tt.expected, tt.host, result) + } + }) + } +} diff --git a/pkg/auth/aws/factory.go b/pkg/auth/aws/factory.go new file mode 100644 index 00000000..4b4b3cae --- /dev/null +++ b/pkg/auth/aws/factory.go @@ -0,0 +1,386 @@ +package aws + +import ( + "context" + "fmt" + "os" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + awsconfig "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/credentials/ec2rolecreds" + "github.com/aws/aws-sdk-go-v2/credentials/endpointcreds" + "github.com/aws/aws-sdk-go-v2/credentials/processcreds" + "github.com/aws/aws-sdk-go-v2/credentials/stscreds" + "github.com/aws/aws-sdk-go-v2/service/sts" + "github.com/sgl-project/ome/pkg/auth" + "github.com/sgl-project/ome/pkg/logging" +) + +// Factory creates AWS credentials +type Factory struct { + logger logging.Interface +} + +// NewFactory creates a new AWS auth factory +func NewFactory(logger logging.Interface) *Factory { + return &Factory{ + logger: logger, + } +} + +// Create creates AWS credentials based on config +func (f *Factory) Create(ctx context.Context, config auth.Config) (auth.Credentials, error) { + if config.Provider != auth.ProviderAWS { + return nil, fmt.Errorf("invalid provider: expected %s, got %s", auth.ProviderAWS, config.Provider) + } + + var credProvider aws.CredentialsProvider + var err error + + switch config.AuthType { + case auth.AWSAccessKey: + credProvider, err = f.createAccessKeyProvider(config) + case auth.AWSAssumeRole: + credProvider, err = f.createAssumeRoleProvider(ctx, config) + case auth.AWSInstanceProfile: + credProvider, err = f.createInstanceProfileProvider(ctx, config) + case auth.AWSWebIdentity: + credProvider, err = f.createWebIdentityProvider(ctx, config) + case auth.AWSECSTaskRole: + credProvider, err = f.createECSTaskRoleProvider(ctx, config) + case auth.AWSProcess: + credProvider, err = f.createProcessProvider(config) + case auth.AWSDefault: + credProvider, err = f.createDefaultProvider(ctx, config) + default: + return nil, fmt.Errorf("unsupported AWS auth type: %s", config.AuthType) + } + + if err != nil { + return nil, fmt.Errorf("failed to create AWS credentials provider: %w", err) + } + + return &AWSCredentials{ + credProvider: credProvider, + authType: config.AuthType, + region: config.Region, + logger: f.logger, + }, nil +} + +// SupportedAuthTypes returns supported AWS auth types +func (f *Factory) SupportedAuthTypes() []auth.AuthType { + return []auth.AuthType{ + auth.AWSAccessKey, + auth.AWSAssumeRole, + auth.AWSInstanceProfile, + auth.AWSWebIdentity, + auth.AWSECSTaskRole, + auth.AWSProcess, + auth.AWSDefault, + } +} + +// createAccessKeyProvider creates an access key credentials provider +func (f *Factory) createAccessKeyProvider(config auth.Config) (aws.CredentialsProvider, error) { + // Extract access key config + akConfig := AccessKeyConfig{} + + if config.Extra != nil { + if ak, ok := config.Extra["access_key"].(map[string]interface{}); ok { + if accessKeyID, ok := ak["access_key_id"].(string); ok { + akConfig.AccessKeyID = accessKeyID + } + if secretAccessKey, ok := ak["secret_access_key"].(string); ok { + akConfig.SecretAccessKey = secretAccessKey + } + if sessionToken, ok := ak["session_token"].(string); ok { + akConfig.SessionToken = sessionToken + } + } + } + + // Check environment variables + if akConfig.AccessKeyID == "" { + akConfig.AccessKeyID = os.Getenv("AWS_ACCESS_KEY_ID") + } + if akConfig.SecretAccessKey == "" { + akConfig.SecretAccessKey = os.Getenv("AWS_SECRET_ACCESS_KEY") + } + if akConfig.SessionToken == "" { + akConfig.SessionToken = os.Getenv("AWS_SESSION_TOKEN") + } + + // Validate + if err := akConfig.Validate(); err != nil { + return nil, err + } + + return credentials.NewStaticCredentialsProvider( + akConfig.AccessKeyID, + akConfig.SecretAccessKey, + akConfig.SessionToken, + ), nil +} + +// createAssumeRoleProvider creates an assume role credentials provider +func (f *Factory) createAssumeRoleProvider(ctx context.Context, config auth.Config) (aws.CredentialsProvider, error) { + // Extract assume role config + arConfig := AssumeRoleConfig{} + + if config.Extra != nil { + if ar, ok := config.Extra["assume_role"].(map[string]interface{}); ok { + if roleARN, ok := ar["role_arn"].(string); ok { + arConfig.RoleARN = roleARN + } + if roleSessionName, ok := ar["role_session_name"].(string); ok { + arConfig.RoleSessionName = roleSessionName + } + if externalID, ok := ar["external_id"].(string); ok { + arConfig.ExternalID = externalID + } + } + } + + // Check environment variables + if arConfig.RoleARN == "" { + arConfig.RoleARN = os.Getenv("AWS_ROLE_ARN") + } + if arConfig.RoleSessionName == "" { + arConfig.RoleSessionName = os.Getenv("AWS_ROLE_SESSION_NAME") + if arConfig.RoleSessionName == "" { + arConfig.RoleSessionName = "ome-storage-session" + } + } + + // Validate + if err := arConfig.Validate(); err != nil { + return nil, err + } + + // Load base config with region if specified + configOpts := []func(*awsconfig.LoadOptions) error{} + if config.Region != "" { + configOpts = append(configOpts, awsconfig.WithRegion(config.Region)) + } + + cfg, err := awsconfig.LoadDefaultConfig(ctx, configOpts...) + if err != nil { + return nil, fmt.Errorf("failed to load AWS config: %w", err) + } + + // Create STS client + stsClient := sts.NewFromConfig(cfg) + + // Create assume role provider + provider := stscreds.NewAssumeRoleProvider(stsClient, arConfig.RoleARN, func(o *stscreds.AssumeRoleOptions) { + if arConfig.RoleSessionName != "" { + o.RoleSessionName = arConfig.RoleSessionName + } + if arConfig.ExternalID != "" { + o.ExternalID = &arConfig.ExternalID + } + if arConfig.Duration > 0 { + o.Duration = arConfig.Duration + } + }) + + return provider, nil +} + +// createInstanceProfileProvider creates an EC2 instance profile credentials provider +func (f *Factory) createInstanceProfileProvider(ctx context.Context, config auth.Config) (aws.CredentialsProvider, error) { + // Create EC2 role credentials provider + provider := ec2rolecreds.New() + + return provider, nil +} + +// createDefaultProvider creates a default credentials provider chain +func (f *Factory) createDefaultProvider(ctx context.Context, config auth.Config) (aws.CredentialsProvider, error) { + // Load default config with region if specified + configOpts := []func(*awsconfig.LoadOptions) error{} + if config.Region != "" { + configOpts = append(configOpts, awsconfig.WithRegion(config.Region)) + } + + cfg, err := awsconfig.LoadDefaultConfig(ctx, configOpts...) + if err != nil { + return nil, fmt.Errorf("failed to load AWS config: %w", err) + } + + return cfg.Credentials, nil +} + +// createWebIdentityProvider creates a web identity credentials provider +func (f *Factory) createWebIdentityProvider(ctx context.Context, config auth.Config) (aws.CredentialsProvider, error) { + // Extract web identity config + wiConfig := WebIdentityConfig{} + + if config.Extra != nil { + if wi, ok := config.Extra["web_identity"].(map[string]interface{}); ok { + if roleArn, ok := wi["role_arn"].(string); ok { + wiConfig.RoleARN = roleArn + } + if tokenFile, ok := wi["token_file"].(string); ok { + wiConfig.TokenFile = tokenFile + } + if sessionName, ok := wi["role_session_name"].(string); ok { + wiConfig.RoleSessionName = sessionName + } + } + } + + // Check environment variables as fallback + if wiConfig.RoleARN == "" { + wiConfig.RoleARN = os.Getenv("AWS_ROLE_ARN") + } + if wiConfig.TokenFile == "" { + wiConfig.TokenFile = os.Getenv("AWS_WEB_IDENTITY_TOKEN_FILE") + } + if wiConfig.RoleSessionName == "" { + wiConfig.RoleSessionName = os.Getenv("AWS_ROLE_SESSION_NAME") + if wiConfig.RoleSessionName == "" { + wiConfig.RoleSessionName = fmt.Sprintf("aws-web-identity-%d", time.Now().Unix()) + } + } + + // Validate + if err := wiConfig.Validate(); err != nil { + return nil, err + } + + // Load base config with region if specified + configOpts := []func(*awsconfig.LoadOptions) error{} + if config.Region != "" { + configOpts = append(configOpts, awsconfig.WithRegion(config.Region)) + } + + cfg, err := awsconfig.LoadDefaultConfig(ctx, configOpts...) + if err != nil { + return nil, fmt.Errorf("failed to load AWS config: %w", err) + } + + // Create STS client + stsClient := sts.NewFromConfig(cfg) + + // Create web identity role provider + provider := stscreds.NewWebIdentityRoleProvider( + stsClient, + wiConfig.RoleARN, + stscreds.IdentityTokenFile(wiConfig.TokenFile), + func(o *stscreds.WebIdentityRoleOptions) { + o.RoleSessionName = wiConfig.RoleSessionName + }, + ) + + return provider, nil +} + +// createECSTaskRoleProvider creates an ECS task role credentials provider +func (f *Factory) createECSTaskRoleProvider(ctx context.Context, config auth.Config) (aws.CredentialsProvider, error) { + // Extract ECS task role config + ecsConfig := ECSTaskRoleConfig{} + + if config.Extra != nil { + if ecs, ok := config.Extra["ecs_task_role"].(map[string]interface{}); ok { + if relativeURI, ok := ecs["relative_uri"].(string); ok { + ecsConfig.RelativeURI = relativeURI + } + if fullURI, ok := ecs["full_uri"].(string); ok { + ecsConfig.FullURI = fullURI + } + if authToken, ok := ecs["authorization_token"].(string); ok { + ecsConfig.AuthorizationToken = authToken + } + } + } + + // Check environment variables as fallback + if ecsConfig.RelativeURI == "" { + ecsConfig.RelativeURI = os.Getenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") + } + if ecsConfig.FullURI == "" { + ecsConfig.FullURI = os.Getenv("AWS_CONTAINER_CREDENTIALS_FULL_URI") + } + if ecsConfig.AuthorizationToken == "" { + ecsConfig.AuthorizationToken = os.Getenv("AWS_CONTAINER_AUTHORIZATION_TOKEN") + } + + // Validate + if err := ecsConfig.Validate(); err != nil { + return nil, err + } + + // Create endpoint credentials provider based on whether we have relative or full URI + if ecsConfig.FullURI != "" { + // Use full URI endpoint + var options []func(*endpointcreds.Options) + if ecsConfig.AuthorizationToken != "" { + options = append(options, func(o *endpointcreds.Options) { + o.AuthorizationToken = ecsConfig.AuthorizationToken + }) + } + return endpointcreds.New(ecsConfig.FullURI, options...), nil + } + + // Use relative URI with the ECS credentials endpoint base + ecsEndpoint := fmt.Sprintf("http://169.254.170.2%s", ecsConfig.RelativeURI) + return endpointcreds.New(ecsEndpoint), nil +} + +// createProcessProvider creates a process credentials provider +func (f *Factory) createProcessProvider(config auth.Config) (aws.CredentialsProvider, error) { + // Extract process config + procConfig := ProcessConfig{} + + if config.Extra != nil { + if proc, ok := config.Extra["process"].(map[string]interface{}); ok { + if command, ok := proc["command"].(string); ok { + procConfig.Command = command + } + // Handle timeout - could be string, int64, or float64 from JSON/YAML + if timeoutVal, ok := proc["timeout"]; ok { + switch v := timeoutVal.(type) { + case string: + if duration, err := time.ParseDuration(v); err == nil { + procConfig.Timeout = duration + } + case float64: + procConfig.Timeout = time.Duration(v) * time.Second + case int64: + procConfig.Timeout = time.Duration(v) * time.Second + case int: + procConfig.Timeout = time.Duration(v) * time.Second + } + } + } + } + + // Check environment variable as fallback + if procConfig.Command == "" { + procConfig.Command = os.Getenv("AWS_CREDENTIAL_PROCESS") + } + + // Set default timeout if not specified + if procConfig.Timeout == 0 { + procConfig.Timeout = 1 * time.Minute + } + + // Validate + if err := procConfig.Validate(); err != nil { + return nil, err + } + + // Create process credentials provider with timeout if specified + var options []func(*processcreds.Options) + if procConfig.Timeout > 0 { + options = append(options, func(o *processcreds.Options) { + o.Timeout = procConfig.Timeout + }) + } + + return processcreds.NewProvider(procConfig.Command, options...), nil +} diff --git a/pkg/auth/aws/factory_test.go b/pkg/auth/aws/factory_test.go new file mode 100644 index 00000000..449d9469 --- /dev/null +++ b/pkg/auth/aws/factory_test.go @@ -0,0 +1,449 @@ +package aws + +import ( + "context" + "os" + "testing" + + "github.com/sgl-project/ome/pkg/auth" + "github.com/sgl-project/ome/pkg/logging" + "go.uber.org/zap/zaptest" +) + +func TestFactory_SupportedAuthTypes(t *testing.T) { + logger := logging.ForZap(zaptest.NewLogger(t)) + factory := NewFactory(logger) + + authTypes := factory.SupportedAuthTypes() + expected := []auth.AuthType{ + auth.AWSAccessKey, + auth.AWSAssumeRole, + auth.AWSInstanceProfile, + auth.AWSWebIdentity, + auth.AWSECSTaskRole, + auth.AWSProcess, + auth.AWSDefault, + } + + if len(authTypes) != len(expected) { + t.Errorf("Expected %d auth types, got %d", len(expected), len(authTypes)) + } + + typeMap := make(map[auth.AuthType]bool) + for _, at := range authTypes { + typeMap[at] = true + } + + for _, e := range expected { + if !typeMap[e] { + t.Errorf("Missing expected auth type: %s", e) + } + } +} + +func TestFactory_Create_InvalidProvider(t *testing.T) { + logger := logging.ForZap(zaptest.NewLogger(t)) + factory := NewFactory(logger) + ctx := context.Background() + + config := auth.Config{ + Provider: auth.ProviderOCI, // Wrong provider + AuthType: auth.AWSAccessKey, + } + + _, err := factory.Create(ctx, config) + if err == nil { + t.Error("Expected error for invalid provider") + } +} + +func TestFactory_Create_UnsupportedAuthType(t *testing.T) { + logger := logging.ForZap(zaptest.NewLogger(t)) + factory := NewFactory(logger) + ctx := context.Background() + + config := auth.Config{ + Provider: auth.ProviderAWS, + AuthType: auth.OCIInstancePrincipal, // Wrong auth type for AWS + } + + _, err := factory.Create(ctx, config) + if err == nil { + t.Error("Expected error for unsupported auth type") + } +} + +func TestFactory_AccessKeyConfig_Validate(t *testing.T) { + tests := []struct { + name string + config AccessKeyConfig + wantError bool + }{ + { + name: "Valid config", + config: AccessKeyConfig{ + AccessKeyID: "AKIAIOSFODNN7EXAMPLE", + SecretAccessKey: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + }, + wantError: false, + }, + { + name: "Missing access key ID", + config: AccessKeyConfig{ + SecretAccessKey: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + }, + wantError: true, + }, + { + name: "Missing secret access key", + config: AccessKeyConfig{ + AccessKeyID: "AKIAIOSFODNN7EXAMPLE", + }, + wantError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.config.Validate() + if (err != nil) != tt.wantError { + t.Errorf("Validate() error = %v, wantError %v", err, tt.wantError) + } + }) + } +} + +func TestFactory_AssumeRoleConfig_Validate(t *testing.T) { + tests := []struct { + name string + config AssumeRoleConfig + wantError bool + }{ + { + name: "Valid config", + config: AssumeRoleConfig{ + RoleARN: "arn:aws:iam::123456789012:role/MyRole", + }, + wantError: false, + }, + { + name: "Missing role ARN", + config: AssumeRoleConfig{ + RoleSessionName: "my-session", + }, + wantError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.config.Validate() + if (err != nil) != tt.wantError { + t.Errorf("Validate() error = %v, wantError %v", err, tt.wantError) + } + }) + } +} + +func TestFactory_Create_AccessKey(t *testing.T) { + logger := logging.ForZap(zaptest.NewLogger(t)) + factory := NewFactory(logger) + ctx := context.Background() + + // Test with config values + config := auth.Config{ + Provider: auth.ProviderAWS, + AuthType: auth.AWSAccessKey, + Extra: map[string]interface{}{ + "access_key": map[string]interface{}{ + "access_key_id": "test-key", + "secret_access_key": "test-secret", + "session_token": "test-token", + }, + }, + } + + creds, err := factory.Create(ctx, config) + if err != nil { + t.Fatalf("Failed to create credentials: %v", err) + } + + if creds.Provider() != auth.ProviderAWS { + t.Errorf("Expected provider %s, got %s", auth.ProviderAWS, creds.Provider()) + } + if creds.Type() != auth.AWSAccessKey { + t.Errorf("Expected auth type %s, got %s", auth.AWSAccessKey, creds.Type()) + } +} + +func TestFactory_Create_AccessKey_FromEnvironment(t *testing.T) { + // Set environment variables + os.Setenv("AWS_ACCESS_KEY_ID", "env-key") + os.Setenv("AWS_SECRET_ACCESS_KEY", "env-secret") + defer os.Unsetenv("AWS_ACCESS_KEY_ID") + defer os.Unsetenv("AWS_SECRET_ACCESS_KEY") + + logger := logging.ForZap(zaptest.NewLogger(t)) + factory := NewFactory(logger) + ctx := context.Background() + + config := auth.Config{ + Provider: auth.ProviderAWS, + AuthType: auth.AWSAccessKey, + Extra: map[string]interface{}{}, + } + + creds, err := factory.Create(ctx, config) + if err != nil { + t.Fatalf("Failed to create credentials from environment: %v", err) + } + + if creds == nil { + t.Fatal("Expected credentials but got nil") + } +} + +func TestFactory_Create_AccessKey_MissingRequired(t *testing.T) { + logger := logging.ForZap(zaptest.NewLogger(t)) + factory := NewFactory(logger) + ctx := context.Background() + + config := auth.Config{ + Provider: auth.ProviderAWS, + AuthType: auth.AWSAccessKey, + Extra: map[string]interface{}{ + "access_key": map[string]interface{}{ + "access_key_id": "test-key", + // Missing secret_access_key + }, + }, + } + + _, err := factory.Create(ctx, config) + if err == nil { + t.Error("Expected error for missing secret_access_key") + } +} + +func TestFactory_Create_AssumeRole(t *testing.T) { + logger := logging.ForZap(zaptest.NewLogger(t)) + factory := NewFactory(logger) + ctx := context.Background() + + config := auth.Config{ + Provider: auth.ProviderAWS, + AuthType: auth.AWSAssumeRole, + Extra: map[string]interface{}{ + "assume_role": map[string]interface{}{ + "role_arn": "arn:aws:iam::123456789012:role/TestRole", + "role_session_name": "test-session", + "external_id": "external-123", + }, + }, + } + + // This will likely fail in unit tests due to missing AWS credentials + // but we can test that it attempts to create the provider + _, err := factory.Create(ctx, config) + // We expect an error here because we don't have real AWS creds + if err == nil { + t.Log("Unexpected success - normally would fail without real AWS credentials") + } +} + +func TestFactory_Create_WebIdentity(t *testing.T) { + logger := logging.ForZap(zaptest.NewLogger(t)) + factory := NewFactory(logger) + ctx := context.Background() + + config := auth.Config{ + Provider: auth.ProviderAWS, + AuthType: auth.AWSWebIdentity, + Extra: map[string]interface{}{ + "web_identity": map[string]interface{}{ + "role_arn": "arn:aws:iam::123456789012:role/TestRole", + "token_file": "/tmp/token", + "role_session_name": "test-session", + }, + }, + } + + // This will likely fail in unit tests due to missing AWS credentials + // but we can test that it attempts to create the provider + _, err := factory.Create(ctx, config) + // We expect an error here because we don't have real AWS creds + if err == nil { + t.Log("Unexpected success - normally would fail without real AWS credentials") + } +} + +func TestFactory_Create_InstanceProfile(t *testing.T) { + logger := logging.ForZap(zaptest.NewLogger(t)) + factory := NewFactory(logger) + ctx := context.Background() + + config := auth.Config{ + Provider: auth.ProviderAWS, + AuthType: auth.AWSInstanceProfile, + } + + // This will create an EC2 role credentials provider + creds, err := factory.Create(ctx, config) + if err != nil { + t.Fatalf("Failed to create instance profile credentials: %v", err) + } + + if creds.Provider() != auth.ProviderAWS { + t.Errorf("Expected provider %s, got %s", auth.ProviderAWS, creds.Provider()) + } + if creds.Type() != auth.AWSInstanceProfile { + t.Errorf("Expected auth type %s, got %s", auth.AWSInstanceProfile, creds.Type()) + } +} + +func TestFactory_Create_Default(t *testing.T) { + logger := logging.ForZap(zaptest.NewLogger(t)) + factory := NewFactory(logger) + ctx := context.Background() + + config := auth.Config{ + Provider: auth.ProviderAWS, + AuthType: auth.AWSDefault, + Region: "us-west-2", + } + + // This will likely fail in unit tests due to missing AWS credentials + // but we can test that it attempts to create the provider + _, err := factory.Create(ctx, config) + // We might get an error here because we don't have real AWS creds + if err == nil { + t.Log("Unexpected success - normally would fail without real AWS credentials") + } +} + +func TestFactory_Create_Process(t *testing.T) { + logger := logging.ForZap(zaptest.NewLogger(t)) + factory := NewFactory(logger) + ctx := context.Background() + + tests := []struct { + name string + config auth.Config + wantErr bool + }{ + { + name: "Valid process config with string timeout", + config: auth.Config{ + Provider: auth.ProviderAWS, + AuthType: auth.AWSProcess, + Extra: map[string]interface{}{ + "process": map[string]interface{}{ + "command": "/usr/local/bin/aws-credential-process", + "timeout": "30s", + }, + }, + }, + wantErr: false, + }, + { + name: "Valid process config with numeric timeout", + config: auth.Config{ + Provider: auth.ProviderAWS, + AuthType: auth.AWSProcess, + Extra: map[string]interface{}{ + "process": map[string]interface{}{ + "command": "/usr/local/bin/aws-credential-process", + "timeout": 30.0, // seconds as float64 + }, + }, + }, + wantErr: false, + }, + { + name: "Missing command", + config: auth.Config{ + Provider: auth.ProviderAWS, + AuthType: auth.AWSProcess, + Extra: map[string]interface{}{ + "process": map[string]interface{}{ + "timeout": "30s", + }, + }, + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + creds, err := factory.Create(ctx, tt.config) + if (err != nil) != tt.wantErr { + t.Errorf("Create() error = %v, wantErr %v", err, tt.wantErr) + } + if err == nil && creds == nil { + t.Error("Expected credentials to be created") + } + }) + } +} + +func TestFactory_Create_ECSTaskRole(t *testing.T) { + logger := logging.ForZap(zaptest.NewLogger(t)) + factory := NewFactory(logger) + ctx := context.Background() + + tests := []struct { + name string + config auth.Config + wantErr bool + }{ + { + name: "Valid ECS config with relative URI", + config: auth.Config{ + Provider: auth.ProviderAWS, + AuthType: auth.AWSECSTaskRole, + Extra: map[string]interface{}{ + "ecs_task_role": map[string]interface{}{ + "relative_uri": "/v2/credentials/12345", + }, + }, + }, + wantErr: false, + }, + { + name: "Valid ECS config with full URI", + config: auth.Config{ + Provider: auth.ProviderAWS, + AuthType: auth.AWSECSTaskRole, + Extra: map[string]interface{}{ + "ecs_task_role": map[string]interface{}{ + "full_uri": "http://localhost:8080/credentials", + "authorization_token": "secret-token", + }, + }, + }, + wantErr: false, + }, + { + name: "Missing both URIs", + config: auth.Config{ + Provider: auth.ProviderAWS, + AuthType: auth.AWSECSTaskRole, + Extra: map[string]interface{}{}, + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + creds, err := factory.Create(ctx, tt.config) + if (err != nil) != tt.wantErr { + t.Errorf("Create() error = %v, wantErr %v", err, tt.wantErr) + } + if err == nil && creds == nil { + t.Error("Expected credentials to be created") + } + }) + } +} diff --git a/pkg/auth/interfaces.go b/pkg/auth/interfaces.go index 6b221e31..f0007a1c 100644 --- a/pkg/auth/interfaces.go +++ b/pkg/auth/interfaces.go @@ -32,6 +32,8 @@ const ( AWSInstanceProfile AuthType = "AWSInstanceProfile" AWSAssumeRole AuthType = "AWSAssumeRole" AWSWebIdentity AuthType = "AWSWebIdentity" + AWSECSTaskRole AuthType = "AWSECSTaskRole" + AWSProcess AuthType = "AWSProcess" AWSDefault AuthType = "AWSDefault" // GCP auth types