Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 64 additions & 80 deletions api/auth/aws/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,20 @@ package aws

import (
"context"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"encoding/json"
"fmt"
"os"
"io"
"net/http"
"strings"
"time"

"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/ec2metadata"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-secure-stdlib/awsutil"
"github.com/aws/aws-sdk-go-v2/aws"
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/api"
)
Expand All @@ -32,7 +36,7 @@ type AWSAuth struct {
signatureType string
region string
iamServerIDHeaderValue string
creds *credentials.Credentials
creds aws.CredentialsProvider
nonce string
}

Expand Down Expand Up @@ -95,102 +99,82 @@ func (a *AWSAuth) Login(ctx context.Context, client *api.Client) (*api.Secret, e
loginData := make(map[string]interface{})
switch a.authType {
case ec2Type:
sess, err := session.NewSession()
cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion(a.region))
if err != nil {
return nil, fmt.Errorf("error creating session to probe EC2 metadata: %w", err)
return nil, fmt.Errorf("error loading AWS config: %w", err)
}
metadataSvc := ec2metadata.New(sess)
if !metadataSvc.Available() {
return nil, fmt.Errorf("metadata service not available")
metadataSvc := imds.NewFromConfig(cfg)

var path string
switch a.signatureType {
case pkcs7Type:
path = "/instance-identity/pkcs7"
case identityType:
path = "/instance-identity/document"
case rsa2048Type:
path = "/instance-identity/rsa2048"
default:
return nil, fmt.Errorf("unknown signature type: %s", a.signatureType)
}

if a.signatureType == pkcs7Type {
// fetch PKCS #7 signature
resp, err := metadataSvc.GetDynamicData("/instance-identity/pkcs7")
if err != nil {
return nil, fmt.Errorf("unable to get PKCS 7 data from metadata service: %w", err)
}
pkcs7 := strings.TrimSpace(resp)
loginData["pkcs7"] = pkcs7
} else if a.signatureType == identityType {
// fetch signature from identity document
doc, err := metadataSvc.GetDynamicData("/instance-identity/document")
if err != nil {
return nil, fmt.Errorf("error requesting instance identity doc: %w", err)
}
loginData["identity"] = base64.StdEncoding.EncodeToString([]byte(doc))

signature, err := metadataSvc.GetDynamicData("/instance-identity/signature")
if err != nil {
return nil, fmt.Errorf("error requesting signature: %w", err)
}
loginData["signature"] = signature
} else if a.signatureType == rsa2048Type {
// fetch RSA 2048 signature, which is also a PKCS#7 signature
resp, err := metadataSvc.GetDynamicData("/instance-identity/rsa2048")
if err != nil {
return nil, fmt.Errorf("unable to get PKCS 7 data from metadata service: %w", err)
}
pkcs7 := strings.TrimSpace(resp)
loginData["pkcs7"] = pkcs7
} else {
return nil, fmt.Errorf("unknown signature type: %s", a.signatureType)
resp, err := metadataSvc.GetDynamicData(ctx, &imds.GetDynamicDataInput{Path: path})
if err != nil {
return nil, fmt.Errorf("unable to get identity data: %w", err)
}
defer resp.Content.Close()
body, err := io.ReadAll(resp.Content)
if err != nil {
return nil, fmt.Errorf("error reading identity data: %w", err)
}
pkcs7 := strings.TrimSpace(string(body))
loginData["pkcs7"] = pkcs7

// Add the reauthentication value, if we have one
if a.nonce == "" {
uid, err := uuid.GenerateUUID()
uuid, err := uuid.GenerateUUID()
if err != nil {
return nil, fmt.Errorf("error generating uuid for reauthentication value: %w", err)
return nil, fmt.Errorf("error generating uuid: %w", err)
}
a.nonce = uid
a.nonce = uuid
}
loginData["nonce"] = a.nonce
case iamType:
logger := hclog.Default()
if a.creds == nil {
credsConfig := awsutil.CredentialsConfig{
AccessKey: os.Getenv("AWS_ACCESS_KEY_ID"),
SecretKey: os.Getenv("AWS_SECRET_ACCESS_KEY"),
SessionToken: os.Getenv("AWS_SESSION_TOKEN"),
Logger: logger,
}

// the env vars above will take precedence if they are set, as
// they will be added to the ChainProvider stack first
var hasCredsFile bool
credsFilePath := os.Getenv("AWS_SHARED_CREDENTIALS_FILE")
if credsFilePath != "" {
hasCredsFile = true
credsConfig.Filename = credsFilePath
}

creds, err := credsConfig.GenerateCredentialChain(awsutil.WithSharedCredentials(hasCredsFile))
cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion(a.region))
if err != nil {
return nil, err
}
if creds == nil {
return nil, fmt.Errorf("could not compile valid credential providers from static config, environment, shared, or instance metadata")
return nil, fmt.Errorf("unable to load AWS config: %w", err)
}
a.creds = cfg.Credentials
}

_, err = creds.Get()
if err != nil {
return nil, fmt.Errorf("failed to retrieve credentials from credential chain: %w", err)
}
credsVal, err := a.creds.Retrieve(ctx)
if err != nil {
return nil, fmt.Errorf("failed to retrieve credentials: %w", err)
}

a.creds = creds
const iamBody = "Action=GetCallerIdentity&Version=2011-06-15"
req, err := http.NewRequest("POST", "https://sts.amazonaws.com/", strings.NewReader(iamBody))
if err != nil {
return nil, fmt.Errorf("failed to construct STS request: %w", err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

data, err := awsutil.GenerateLoginData(a.creds, a.iamServerIDHeaderValue, a.region, logger)
hash := sha256.Sum256([]byte(iamBody))
payloadHash := hex.EncodeToString(hash[:])

signer := v4.NewSigner()
err = signer.SignHTTP(ctx, credsVal, req, payloadHash, "sts", a.region, time.Now().UTC())
if err != nil {
return nil, fmt.Errorf("unable to generate login data for AWS auth endpoint: %w", err)
return nil, fmt.Errorf("failed to sign STS request: %w", err)
}
loginData = data

headersData, _ := json.Marshal(req.Header)

loginData["iam_http_request_method"] = "POST"
loginData["iam_request_url"] = base64.StdEncoding.EncodeToString([]byte(req.URL.String()))
loginData["iam_request_body"] = base64.StdEncoding.EncodeToString([]byte(iamBody))
loginData["iam_request_headers"] = base64.StdEncoding.EncodeToString(headersData)
}

// Add role if we have one. If not, Vault will infer the role name based
// on the IAM friendly name (iam auth type) or EC2 instance's
// AMI ID (ec2 auth type).
if a.roleName != "" {
loginData["role"] = a.roleName
}
Expand Down
22 changes: 13 additions & 9 deletions api/auth/aws/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,25 @@ go 1.23.0
toolchain go1.23.7

require (
github.com/aws/aws-sdk-go v1.49.22
github.com/hashicorp/go-hclog v1.6.3
github.com/hashicorp/go-secure-stdlib/awsutil v0.1.6
github.com/aws/aws-sdk-go-v2 v1.26.1
github.com/aws/aws-sdk-go-v2/config v1.26.1
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.14.10
github.com/hashicorp/go-uuid v1.0.2
github.com/hashicorp/vault/api v1.20.0
)

require (
github.com/aws/aws-sdk-go-v2/credentials v1.16.12 // indirect
github.com/aws/aws-sdk-go-v2/internal/configsources v1.2.9 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.5.9 // indirect
github.com/aws/aws-sdk-go-v2/internal/ini v1.7.2 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.10.4 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.10.9 // indirect
github.com/aws/aws-sdk-go-v2/service/sso v1.18.5 // indirect
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.21.5 // indirect
github.com/aws/aws-sdk-go-v2/service/sts v1.26.5 // indirect
github.com/aws/smithy-go v1.20.2 // indirect
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
github.com/fatih/color v1.16.0 // indirect
github.com/go-jose/go-jose/v4 v4.0.5 // indirect
github.com/hashicorp/errwrap v1.1.0 // indirect
github.com/hashicorp/go-cleanhttp v0.5.2 // indirect
Expand All @@ -25,16 +34,11 @@ require (
github.com/hashicorp/go-secure-stdlib/strutil v0.1.2 // indirect
github.com/hashicorp/go-sockaddr v1.0.2 // indirect
github.com/hashicorp/hcl v1.0.1-vault-7 // indirect
github.com/jmespath/go-jmespath v0.4.0 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mitchellh/go-homedir v1.1.0 // indirect
github.com/mitchellh/mapstructure v1.5.0 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/ryanuber/go-glob v1.0.0 // indirect
golang.org/x/crypto v0.36.0 // indirect
golang.org/x/net v0.37.0 // indirect
golang.org/x/sys v0.31.0 // indirect
golang.org/x/text v0.23.0 // indirect
golang.org/x/time v0.0.0-20200416051211-89c76fbcd5d1 // indirect
)
Loading
Loading