diff --git a/README.md b/README.md index 94e1cc544ee..386f581968b 100644 --- a/README.md +++ b/README.md @@ -180,6 +180,12 @@ if _, ok := err.(*github.AbuseRateLimitError); ok { } ``` +Alternatively, you can block until the rate limit is reset by using the `context.WithValue` method: + +````go +repos, _, err := client.Repositories.List(context.WithValue(ctx, github.SleepUntilPrimaryRateLimitResetWhenRateLimited, true), "", nil) +``` + You can use [go-github-ratelimit](https://github.com/gofri/go-github-ratelimit) to handle secondary rate limit sleep-and-retry for you. diff --git a/github/github.go b/github/github.go index 03e2e6273cf..ad4269595bf 100644 --- a/github/github.go +++ b/github/github.go @@ -804,6 +804,7 @@ type requestContext uint8 const ( bypassRateLimitCheck requestContext = iota + SleepUntilPrimaryRateLimitResetWhenRateLimited ) // BareDo sends an API request and lets you handle the api response. If an error @@ -889,6 +890,15 @@ func (c *Client) BareDo(ctx context.Context, req *http.Request) (*Response, erro err = aerr } + rateLimitError, ok := err.(*RateLimitError) + if ok && req.Context().Value(SleepUntilPrimaryRateLimitResetWhenRateLimited) != nil { + if err := sleepUntilResetWithBuffer(req.Context(), rateLimitError.Rate.Reset.Time); err != nil { + return response, err + } + // retry the request once when the rate limit has reset + return c.BareDo(context.WithValue(req.Context(), SleepUntilPrimaryRateLimitResetWhenRateLimited, nil), req) + } + // Update the secondary rate limit if we hit it. rerr, ok := err.(*AbuseRateLimitError) if ok && rerr.RetryAfter != nil { @@ -950,6 +960,18 @@ func (c *Client) checkRateLimitBeforeDo(req *http.Request, rateLimitCategory Rat Header: make(http.Header), Body: io.NopCloser(strings.NewReader("")), } + + if req.Context().Value(SleepUntilPrimaryRateLimitResetWhenRateLimited) != nil { + if err := sleepUntilResetWithBuffer(req.Context(), rate.Reset.Time); err == nil { + return nil + } + return &RateLimitError{ + Rate: rate, + Response: resp, + Message: fmt.Sprintf("Context cancelled while waiting for rate limit to reset until %v, not making remote request.", rate.Reset.Time), + } + } + return &RateLimitError{ Rate: rate, Response: resp, @@ -1514,6 +1536,20 @@ func formatRateReset(d time.Duration) string { return fmt.Sprintf("[rate reset in %v]", timeString) } +func sleepUntilResetWithBuffer(ctx context.Context, reset time.Time) error { + buffer := time.Second + timer := time.NewTimer(time.Until(reset) + buffer) + select { + case <-ctx.Done(): + if !timer.Stop() { + <-timer.C + } + return ctx.Err() + case <-timer.C: + } + return nil +} + // When using roundTripWithOptionalFollowRedirect, note that it // is the responsibility of the caller to close the response body. func (c *Client) roundTripWithOptionalFollowRedirect(ctx context.Context, u string, maxRedirects int, opts ...RequestOption) (*http.Response, error) { diff --git a/github/github_test.go b/github/github_test.go index efc42d432fb..8c6c38a84f0 100644 --- a/github/github_test.go +++ b/github/github_test.go @@ -1381,6 +1381,176 @@ func TestDo_rateLimit_ignoredFromCache(t *testing.T) { } } +// Ensure sleeps until the rate limit is reset when the client is rate limited. +func TestDo_rateLimit_sleepUntilResponseResetLimit(t *testing.T) { + client, mux, _, teardown := setup() + defer teardown() + + reset := time.Now().UTC().Add(time.Second) + + var firstRequest = true + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + if firstRequest { + firstRequest = false + w.Header().Set(headerRateLimit, "60") + w.Header().Set(headerRateRemaining, "0") + w.Header().Set(headerRateReset, fmt.Sprint(reset.Unix())) + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(http.StatusForbidden) + fmt.Fprintln(w, `{ + "message": "API rate limit exceeded for xxx.xxx.xxx.xxx. (But here's the good news: Authenticated requests get a higher rate limit. Check out the documentation for more details.)", + "documentation_url": "https://docs.github.com/en/rest/overview/resources-in-the-rest-api#abuse-rate-limits" +}`) + return + } + w.Header().Set(headerRateLimit, "5000") + w.Header().Set(headerRateRemaining, "5000") + w.Header().Set(headerRateReset, fmt.Sprint(reset.Add(time.Hour).Unix())) + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(http.StatusOK) + fmt.Fprintln(w, `{}`) + }) + + req, _ := client.NewRequest("GET", ".", nil) + ctx := context.Background() + resp, err := client.Do(context.WithValue(ctx, SleepUntilPrimaryRateLimitResetWhenRateLimited, true), req, nil) + if err != nil { + t.Errorf("Do returned unexpected error: %v", err) + } + if got, want := resp.StatusCode, http.StatusOK; got != want { + t.Errorf("Response status code = %v, want %v", got, want) + } +} + +// Ensure tries to sleep until the rate limit is reset when the client is rate limited, but only once. +func TestDo_rateLimit_sleepUntilResponseResetLimitRetryOnce(t *testing.T) { + client, mux, _, teardown := setup() + defer teardown() + + reset := time.Now().UTC().Add(time.Second) + + requestCount := 0 + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + requestCount++ + w.Header().Set(headerRateLimit, "60") + w.Header().Set(headerRateRemaining, "0") + w.Header().Set(headerRateReset, fmt.Sprint(reset.Unix())) + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(http.StatusForbidden) + fmt.Fprintln(w, `{ + "message": "API rate limit exceeded for xxx.xxx.xxx.xxx. (But here's the good news: Authenticated requests get a higher rate limit. Check out the documentation for more details.)", + "documentation_url": "https://docs.github.com/en/rest/overview/resources-in-the-rest-api#abuse-rate-limits" +}`) + }) + + req, _ := client.NewRequest("GET", ".", nil) + ctx := context.Background() + _, err := client.Do(context.WithValue(ctx, SleepUntilPrimaryRateLimitResetWhenRateLimited, true), req, nil) + if err == nil { + t.Error("Expected error to be returned.") + } + if got, want := requestCount, 2; got != want { + t.Errorf("Expected 2 requests, got %d", got) + } +} + +// Ensure a network call is not made when it's known that API rate limit is still exceeded. +func TestDo_rateLimit_sleepUntilClientResetLimit(t *testing.T) { + client, mux, _, teardown := setup() + defer teardown() + + reset := time.Now().UTC().Add(time.Second) + client.rateLimits[CoreCategory] = Rate{Limit: 5000, Remaining: 0, Reset: Timestamp{reset}} + requestCount := 0 + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + requestCount++ + w.Header().Set(headerRateLimit, "5000") + w.Header().Set(headerRateRemaining, "5000") + w.Header().Set(headerRateReset, fmt.Sprint(reset.Add(time.Hour).Unix())) + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(http.StatusOK) + fmt.Fprintln(w, `{}`) + }) + req, _ := client.NewRequest("GET", ".", nil) + ctx := context.Background() + resp, err := client.Do(context.WithValue(ctx, SleepUntilPrimaryRateLimitResetWhenRateLimited, true), req, nil) + if err != nil { + t.Errorf("Do returned unexpected error: %v", err) + } + if got, want := resp.StatusCode, http.StatusOK; got != want { + t.Errorf("Response status code = %v, want %v", got, want) + } + if got, want := requestCount, 1; got != want { + t.Errorf("Expected 1 request, got %d", got) + } +} + +// Ensure sleep is aborted when the context is cancelled. +func TestDo_rateLimit_abortSleepContextCancelled(t *testing.T) { + client, mux, _, teardown := setup() + defer teardown() + + // We use a 1 minute reset time to ensure the sleep is not completed. + reset := time.Now().UTC().Add(time.Minute) + requestCount := 0 + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + requestCount++ + w.Header().Set(headerRateLimit, "60") + w.Header().Set(headerRateRemaining, "0") + w.Header().Set(headerRateReset, fmt.Sprint(reset.Unix())) + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(http.StatusForbidden) + fmt.Fprintln(w, `{ + "message": "API rate limit exceeded for xxx.xxx.xxx.xxx. (But here's the good news: Authenticated requests get a higher rate limit. Check out the documentation for more details.)", + "documentation_url": "https://docs.github.com/en/rest/overview/resources-in-the-rest-api#abuse-rate-limits" +}`) + }) + + req, _ := client.NewRequest("GET", ".", nil) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + _, err := client.Do(context.WithValue(ctx, SleepUntilPrimaryRateLimitResetWhenRateLimited, true), req, nil) + if !errors.Is(err, context.DeadlineExceeded) { + t.Error("Expected context deadline exceeded error.") + } + if got, want := requestCount, 1; got != want { + t.Errorf("Expected 1 requests, got %d", got) + } +} + +// Ensure sleep is aborted when the context is cancelled on initial request. +func TestDo_rateLimit_abortSleepContextCancelledClientLimit(t *testing.T) { + client, mux, _, teardown := setup() + defer teardown() + + reset := time.Now().UTC().Add(time.Minute) + client.rateLimits[CoreCategory] = Rate{Limit: 5000, Remaining: 0, Reset: Timestamp{reset}} + requestCount := 0 + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + requestCount++ + w.Header().Set(headerRateLimit, "5000") + w.Header().Set(headerRateRemaining, "5000") + w.Header().Set(headerRateReset, fmt.Sprint(reset.Add(time.Hour).Unix())) + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(http.StatusOK) + fmt.Fprintln(w, `{}`) + }) + req, _ := client.NewRequest("GET", ".", nil) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + _, err := client.Do(context.WithValue(ctx, SleepUntilPrimaryRateLimitResetWhenRateLimited, true), req, nil) + rateLimitError, ok := err.(*RateLimitError) + if !ok { + t.Fatalf("Expected a *rateLimitError error; got %#v.", err) + } + if got, wantSuffix := rateLimitError.Message, "Context cancelled while waiting for rate limit to reset until"; !strings.HasPrefix(got, wantSuffix) { + t.Errorf("Expected request to be prevented because context cancellation, got: %v.", got) + } + if got, want := requestCount, 0; got != want { + t.Errorf("Expected 1 requests, got %d", got) + } +} + // Ensure *AbuseRateLimitError is returned when the response indicates that // the client has triggered an abuse detection mechanism. func TestDo_rateLimit_abuseRateLimitError(t *testing.T) {