diff --git a/client.go b/client.go index 3a22b99..86bb294 100644 --- a/client.go +++ b/client.go @@ -552,6 +552,9 @@ func DefaultBackoff(min, max time.Duration, attemptNum int, resp *http.Response) if resp != nil { if resp.StatusCode == http.StatusTooManyRequests || resp.StatusCode == http.StatusServiceUnavailable { if sleep, ok := parseRetryAfterHeader(resp.Header["Retry-After"]); ok { + if sleep > max { + return max + } return sleep } } diff --git a/client_test.go b/client_test.go index 0e8ca60..94e62be 100644 --- a/client_test.go +++ b/client_test.go @@ -896,6 +896,51 @@ func TestClient_DefaultBackoff(t *testing.T) { } } +// TestClient_DefaultBackoff_RetryAfterExceedsMax verifies that DefaultBackoff +// caps the sleep duration from a Retry-After header at RetryWaitMax so that +// a server cannot force an arbitrarily long wait. +func TestClient_DefaultBackoff_RetryAfterExceedsMax(t *testing.T) { + testStaticTime(t) + const retryWaitMax = 10 * time.Second + + tests := []struct { + name string + code int + retryHeader string + }{ + {"http_429_seconds_exceeds_max", http.StatusTooManyRequests, "3600"}, + {"http_503_seconds_exceeds_max", http.StatusServiceUnavailable, "3600"}, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Retry-After", test.retryHeader) + http.Error(w, fmt.Sprintf("test_%d_body", test.code), test.code) + })) + defer ts.Close() + + client := NewClient() + client.RetryWaitMax = retryWaitMax + + var got time.Duration + client.CheckRetry = func(_ context.Context, resp *http.Response, err error) (bool, error) { + got = DefaultBackoff(client.RetryWaitMin, client.RetryWaitMax, 1, resp) + return false, nil + } + + _, err := client.Get(ts.URL) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + if got != retryWaitMax { + t.Fatalf("Retry-After of 3600s should be capped to RetryWaitMax=%s, got %s", retryWaitMax, got) + } + }) + } +} + func TestClient_DefaultRetryPolicy_TLS(t *testing.T) { ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(200)