From ad3e1648ae05314a6c9a5b6cd84e31016b03a71c Mon Sep 17 00:00:00 2001 From: Cosmos Nicolaou Date: Thu, 19 Nov 2020 18:14:33 -0800 Subject: [PATCH 1/3] x/ref/runtime/internal: use aws imds v1 first and then v2 --- x/ref/runtime/internal/cloudvm.go | 5 +- x/ref/runtime/internal/cloudvm/aws.go | 60 +++++++++++++------ x/ref/runtime/internal/cloudvm/aws_test.go | 19 ++++-- .../internal/cloudvm/cloudvmtest/aws_mock.go | 24 +++++--- 4 files changed, 77 insertions(+), 31 deletions(-) diff --git a/x/ref/runtime/internal/cloudvm.go b/x/ref/runtime/internal/cloudvm.go index f25f3ab9d..aac4d3592 100644 --- a/x/ref/runtime/internal/cloudvm.go +++ b/x/ref/runtime/internal/cloudvm.go @@ -60,6 +60,9 @@ type asyncChooser struct { func (ac *asyncChooser) ChooseAddresses(protocol string, candidates []net.Addr) ([]net.Addr, error) { select { case <-ac.ch: + if cvmErr != nil { + return nil, cvmErr + } return cvm.ChooseAddresses(protocol, candidates) case <-ac.ctx.Done(): return nil, ac.ctx.Err() @@ -115,7 +118,7 @@ func newCloudVM(ctx context.Context, logger logging.Logger, fl *flags.Virtualize switch fl.VirtualizationProvider.Get().(flags.VirtualizationProvider) { case flags.AWS: - if !cloudvm.OnAWS(ctx, time.Second) { + if !cloudvm.OnAWS(ctx, cvm.logger, time.Second) { if fl.DissallowNativeFallback { return nil, fmt.Errorf("this process is not running on AWS even though its command line says it is") } diff --git a/x/ref/runtime/internal/cloudvm/aws.go b/x/ref/runtime/internal/cloudvm/aws.go index e44346f89..81febac2c 100644 --- a/x/ref/runtime/internal/cloudvm/aws.go +++ b/x/ref/runtime/internal/cloudvm/aws.go @@ -17,6 +17,7 @@ import ( "sync" "time" + "v.io/v23/logging" "v.io/x/ref/lib/stats" "v.io/x/ref/runtime/internal/cloudvm/cloudpaths" ) @@ -58,46 +59,56 @@ const ( var ( onceAWS sync.Once onAWS bool + imdsv2 bool ) // OnAWS returns true if this process is running on Amazon Web Services. // If true, the the stats variables AWSAccountIDStatName and GCPRegionStatName // are set. -func OnAWS(ctx context.Context, timeout time.Duration) bool { +func OnAWS(ctx context.Context, logger logging.Logger, timeout time.Duration) bool { onceAWS.Do(func() { - onAWS = awsInit(ctx, timeout) + onAWS, imdsv2 = awsInit(ctx, logger, timeout) + logger.VI(1).Infof("OnAWS: onAWS: %v, imdsv2: %v", onAWS, imdsv2) }) return onAWS } // AWSPublicAddrs returns the current public IP of this AWS instance. +// Must be called after OnAWS. func AWSPublicAddrs(ctx context.Context, timeout time.Duration) ([]net.Addr, error) { - return awsGetAddr(ctx, awsExternalURL(), timeout) + return awsGetAddr(ctx, imdsv2, awsExternalURL(), timeout) } // AWSPrivateAddrs returns the current private Addrs of this AWS instance. +// Must be called after OnAWS. func AWSPrivateAddrs(ctx context.Context, timeout time.Duration) ([]net.Addr, error) { - return awsGetAddr(ctx, awsInternalURL(), timeout) + return awsGetAddr(ctx, imdsv2, awsInternalURL(), timeout) } -func awsGet(ctx context.Context, url string, timeout time.Duration) ([]byte, error) { +func awsGet(ctx context.Context, imdsv2 bool, url string, timeout time.Duration) ([]byte, error) { client := &http.Client{Timeout: timeout} - token, err := awsSetIMDSv2Token(ctx, awsTokenURL(), timeout) - if err != nil { - return nil, err + var token string + var err error + if imdsv2 { + token, err = awsSetIMDSv2Token(ctx, awsTokenURL(), timeout) + if err != nil { + return nil, err + } } req, err := http.NewRequestWithContext(ctx, "GET", url, nil) - req.Header.Add("X-aws-ec2-metadata-token", token) if err != nil { return nil, err } + if len(token) > 0 { + req.Header.Add("X-aws-ec2-metadata-token", token) + } resp, err := client.Do(req) if err != nil { return nil, err } defer resp.Body.Close() if resp.StatusCode != 200 { - return nil, err + return nil, fmt.Errorf("HTTP Error: %v %v", url, resp.StatusCode) } if server := resp.Header["Server"]; len(server) != 1 || server[0] != "EC2ws" { return nil, fmt.Errorf("wrong headers") @@ -105,16 +116,29 @@ func awsGet(ctx context.Context, url string, timeout time.Duration) ([]byte, err return ioutil.ReadAll(resp.Body) } -// awsInit returns true if it can access AWS project metadata. It also +// awsInit returns true if it can access AWS project metadata and the version +// of the metadata service it was able to access. It also // creates two stats variables with the account ID and zone. -func awsInit(ctx context.Context, timeout time.Duration) bool { - body, err := awsGet(ctx, awsIdentityDocURL(), timeout) +func awsInit(ctx context.Context, logger logging.Logger, timeout time.Duration) (bool, bool) { + v2 := false + // Try the v1 service first since it should always work unless v2 + // is specifically configured (and hence v1 is disabled), in which + // case the expectation is that it fails fast with a 4xx HTTP error. + body, err := awsGet(ctx, false, awsIdentityDocURL(), timeout) if err != nil { - return false + logger.VI(1).Infof("failed to access v1 metadata service: %v", err) + // can't access v1, try v2. + body, err = awsGet(ctx, true, awsIdentityDocURL(), timeout) + if err != nil { + logger.VI(1).Infof("failed to access v2 metadata service: %v", err) + return false, false + } + v2 = true } doc := map[string]interface{}{} if err := json.Unmarshal(body, &doc); err != nil { - return false + logger.VI(1).Infof("failed to unmarshal metadata service response: %s: %v", body, err) + return false, false } found := 0 for _, v := range []struct { @@ -130,11 +154,11 @@ func awsInit(ctx context.Context, timeout time.Duration) bool { } } } - return found == 2 + return found == 2, v2 } -func awsGetAddr(ctx context.Context, url string, timeout time.Duration) ([]net.Addr, error) { - body, err := awsGet(ctx, url, timeout) +func awsGetAddr(ctx context.Context, imdsv2 bool, url string, timeout time.Duration) ([]net.Addr, error) { + body, err := awsGet(ctx, imdsv2, url, timeout) if err != nil { return nil, err } diff --git a/x/ref/runtime/internal/cloudvm/aws_test.go b/x/ref/runtime/internal/cloudvm/aws_test.go index c2541dfc4..014f37644 100644 --- a/x/ref/runtime/internal/cloudvm/aws_test.go +++ b/x/ref/runtime/internal/cloudvm/aws_test.go @@ -9,22 +9,30 @@ import ( "testing" "time" + "v.io/x/ref/internal/logger" "v.io/x/ref/runtime/internal/cloudvm/cloudpaths" "v.io/x/ref/runtime/internal/cloudvm/cloudvmtest" ) -func startAWSMetadataServer(t *testing.T) (string, func()) { - host, close := cloudvmtest.StartAWSMetadataServer(t) +func startAWSMetadataServer(t *testing.T, imdsv2Only bool) (string, func()) { + host, close := cloudvmtest.StartAWSMetadataServer(t, imdsv2Only) SetAWSMetadataHost(host) return host, close } func TestAWS(t *testing.T) { + testAWSIDMSVersion(t, false) + testAWSIDMSVersion(t, true) +} + +func testAWSIDMSVersion(t *testing.T, imdsv2Only bool) { ctx := context.Background() - host, stop := startAWSMetadataServer(t) + host, stop := startAWSMetadataServer(t, false) defer stop() - if got, want := OnAWS(ctx, time.Second), true; got != want { + logger := logger.NewLogger("test") + + if got, want := OnAWS(ctx, logger, time.Second), true; got != want { t.Errorf("got %v, want %v", got, want) } @@ -45,8 +53,9 @@ func TestAWS(t *testing.T) { if got, want := pub[0].String(), cloudvmtest.WellKnownPublicIP; got != want { t.Errorf("got %v, want %v", got, want) } + externalURL := host + cloudpaths.AWSPublicIPPath + "/noip" - noip, err := awsGetAddr(ctx, externalURL, time.Second) + noip, err := awsGetAddr(ctx, false, externalURL, time.Second) if err != nil { t.Fatal(err) } diff --git a/x/ref/runtime/internal/cloudvm/cloudvmtest/aws_mock.go b/x/ref/runtime/internal/cloudvm/cloudvmtest/aws_mock.go index fb8666edb..92fde87c4 100644 --- a/x/ref/runtime/internal/cloudvm/cloudvmtest/aws_mock.go +++ b/x/ref/runtime/internal/cloudvm/cloudvmtest/aws_mock.go @@ -15,13 +15,14 @@ import ( "v.io/x/ref/runtime/internal/cloudvm/cloudpaths" ) -func StartAWSMetadataServer(t *testing.T) (string, func()) { +func StartAWSMetadataServer(t *testing.T, imdsv2Only bool) (string, func()) { l, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatal(err) } var token string - http.HandleFunc(cloudpaths.AWSTokenPath, func(w http.ResponseWriter, req *http.Request) { + mux := &http.ServeMux{} + mux.HandleFunc(cloudpaths.AWSTokenPath, func(w http.ResponseWriter, req *http.Request) { token = time.Now().String() w.Header().Add("Server", "EC2ws") fmt.Fprint(w, token) @@ -32,7 +33,13 @@ func StartAWSMetadataServer(t *testing.T) (string, func()) { return requestToken == token } - http.HandleFunc(cloudpaths.AWSIdentityDocPath, func(w http.ResponseWriter, r *http.Request) { + mux.HandleFunc(cloudpaths.AWSIdentityDocPath, func(w http.ResponseWriter, r *http.Request) { + if imdsv2Only { + if len(r.Header.Get("X-aws-ec2-metadata-token")) == 0 { + w.WriteHeader(http.StatusUnauthorized) + return + } + } if !validSession(r) { w.WriteHeader(http.StatusForbidden) return @@ -58,19 +65,22 @@ func StartAWSMetadataServer(t *testing.T) (string, func()) { fmt.Fprintf(w, format, args...) } - http.HandleFunc(cloudpaths.AWSPrivateIPPath, + mux.HandleFunc(cloudpaths.AWSPrivateIPPath, func(w http.ResponseWriter, r *http.Request) { respond(w, r, WellKnownPrivateIP) }) - http.HandleFunc(cloudpaths.AWSPublicIPPath, + mux.HandleFunc(cloudpaths.AWSPublicIPPath, func(w http.ResponseWriter, r *http.Request) { respond(w, r, WellKnownPublicIP) }) - http.HandleFunc(cloudpaths.AWSPublicIPPath+"/noip", + mux.HandleFunc(cloudpaths.AWSPublicIPPath+"/noip", func(w http.ResponseWriter, r *http.Request) { respond(w, r, "") }) - go http.Serve(l, nil) + srv := http.Server{ + Handler: mux, + } + go srv.Serve(l) return "http://" + l.Addr().String(), func() { l.Close() } } From 91706665a05ed09f8f6441a14256dff59d161013 Mon Sep 17 00:00:00 2001 From: Cosmos Nicolaou Date: Thu, 19 Nov 2020 18:18:16 -0800 Subject: [PATCH 2/3] fix test --- x/ref/runtime/internal/cloudvm_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x/ref/runtime/internal/cloudvm_test.go b/x/ref/runtime/internal/cloudvm_test.go index 8174665b7..fcecd7d36 100644 --- a/x/ref/runtime/internal/cloudvm_test.go +++ b/x/ref/runtime/internal/cloudvm_test.go @@ -40,7 +40,7 @@ func hasAddr(addrs []net.Addr, host string) bool { } func TestCloudVMProviders(t *testing.T) { - awsHost, awsClose := cloudvmtest.StartAWSMetadataServer(t) + awsHost, awsClose := cloudvmtest.StartAWSMetadataServer(t, true) defer awsClose() cloudvm.SetAWSMetadataHost(awsHost) From cbf10a299049648452da625c483d7adfaf3311bc Mon Sep 17 00:00:00 2001 From: Cosmos Nicolaou Date: Fri, 20 Nov 2020 12:42:34 -0800 Subject: [PATCH 3/3] address review comments --- x/ref/runtime/internal/cloudvm/aws.go | 14 +++++++------- x/ref/runtime/internal/cloudvm/aws_test.go | 4 ++-- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/x/ref/runtime/internal/cloudvm/aws.go b/x/ref/runtime/internal/cloudvm/aws.go index 81febac2c..3cbd9c649 100644 --- a/x/ref/runtime/internal/cloudvm/aws.go +++ b/x/ref/runtime/internal/cloudvm/aws.go @@ -57,9 +57,9 @@ const ( ) var ( - onceAWS sync.Once - onAWS bool - imdsv2 bool + onceAWS sync.Once + onAWS bool + onIMDSv2 bool ) // OnAWS returns true if this process is running on Amazon Web Services. @@ -67,8 +67,8 @@ var ( // are set. func OnAWS(ctx context.Context, logger logging.Logger, timeout time.Duration) bool { onceAWS.Do(func() { - onAWS, imdsv2 = awsInit(ctx, logger, timeout) - logger.VI(1).Infof("OnAWS: onAWS: %v, imdsv2: %v", onAWS, imdsv2) + onAWS, onIMDSv2 = awsInit(ctx, logger, timeout) + logger.VI(1).Infof("OnAWS: onAWS: %v, onIMDSv2: %v", onAWS, onIMDSv2) }) return onAWS } @@ -76,13 +76,13 @@ func OnAWS(ctx context.Context, logger logging.Logger, timeout time.Duration) bo // AWSPublicAddrs returns the current public IP of this AWS instance. // Must be called after OnAWS. func AWSPublicAddrs(ctx context.Context, timeout time.Duration) ([]net.Addr, error) { - return awsGetAddr(ctx, imdsv2, awsExternalURL(), timeout) + return awsGetAddr(ctx, onIMDSv2, awsExternalURL(), timeout) } // AWSPrivateAddrs returns the current private Addrs of this AWS instance. // Must be called after OnAWS. func AWSPrivateAddrs(ctx context.Context, timeout time.Duration) ([]net.Addr, error) { - return awsGetAddr(ctx, imdsv2, awsInternalURL(), timeout) + return awsGetAddr(ctx, onIMDSv2, awsInternalURL(), timeout) } func awsGet(ctx context.Context, imdsv2 bool, url string, timeout time.Duration) ([]byte, error) { diff --git a/x/ref/runtime/internal/cloudvm/aws_test.go b/x/ref/runtime/internal/cloudvm/aws_test.go index 014f37644..6793ecaee 100644 --- a/x/ref/runtime/internal/cloudvm/aws_test.go +++ b/x/ref/runtime/internal/cloudvm/aws_test.go @@ -27,7 +27,7 @@ func TestAWS(t *testing.T) { func testAWSIDMSVersion(t *testing.T, imdsv2Only bool) { ctx := context.Background() - host, stop := startAWSMetadataServer(t, false) + host, stop := startAWSMetadataServer(t, imdsv2Only) defer stop() logger := logger.NewLogger("test") @@ -55,7 +55,7 @@ func testAWSIDMSVersion(t *testing.T, imdsv2Only bool) { } externalURL := host + cloudpaths.AWSPublicIPPath + "/noip" - noip, err := awsGetAddr(ctx, false, externalURL, time.Second) + noip, err := awsGetAddr(ctx, imdsv2Only, externalURL, time.Second) if err != nil { t.Fatal(err) }