diff --git a/github/github.go b/github/github.go index 424ed1ddc90..c8ec97c6e09 100644 --- a/github/github.go +++ b/github/github.go @@ -86,8 +86,9 @@ type Client struct { // User agent used when communicating with the GitHub API. UserAgent string - rateMu sync.Mutex - rate Rate // Rate limit for the client as determined by the most recent API call. + rateMu sync.Mutex + rateLimits [categories]Rate // Rate limits for the client as determined by the most recent API calls. + mostRecent rateLimitCategory // Services used for talking to different parts of the GitHub API. Activity *ActivityService @@ -319,11 +320,13 @@ func parseRate(r *http.Response) Rate { // Rate specifies the current rate limit for the client as determined by the // most recent API call. If the client is used in a multi-user application, -// this rate may not always be up-to-date. Call RateLimits() to check the -// current rate. +// this rate may not always be up-to-date. +// +// Deprecated: Use the Response.Rate returned from most recent API call instead. +// Call RateLimits() to check the current rate. func (c *Client) Rate() Rate { c.rateMu.Lock() - rate := c.rate + rate := c.rateLimits[c.mostRecent] c.rateMu.Unlock() return rate } @@ -332,8 +335,16 @@ func (c *Client) Rate() Rate { // JSON decoded and stored in the value pointed to by v, or returned as an // error if an API error has occurred. If v implements the io.Writer // interface, the raw response body will be written to v, without attempting to -// first decode it. +// first decode it. If rate limit is exceeded and reset time is in the future, +// Do returns *RateLimitError immediately without making a network API call. func (c *Client) Do(req *http.Request, v interface{}) (*Response, error) { + rateLimitCategory := category(req.URL.Path) + + // If we've hit rate limit, don't make further requests before Reset time. + if err := c.checkRateLimitBeforeDo(req, rateLimitCategory); err != nil { + return nil, err + } + resp, err := c.client.Do(req) if err != nil { return nil, err @@ -348,7 +359,8 @@ func (c *Client) Do(req *http.Request, v interface{}) (*Response, error) { response := newResponse(resp) c.rateMu.Lock() - c.rate = response.Rate + c.rateLimits[rateLimitCategory] = response.Rate + c.mostRecent = rateLimitCategory c.rateMu.Unlock() err = CheckResponse(resp) @@ -372,6 +384,33 @@ func (c *Client) Do(req *http.Request, v interface{}) (*Response, error) { return response, err } +// checkRateLimitBeforeDo does not make any network calls, but uses existing knowledge from +// current client state in order to quickly check if *RateLimitError can be immediately returned +// from Client.Do, and if so, returns it so that Client.Do can skip making a network API call unneccessarily. +// Otherwise it returns nil, and Client.Do should proceed normally. +func (c *Client) checkRateLimitBeforeDo(req *http.Request, rateLimitCategory rateLimitCategory) error { + c.rateMu.Lock() + rate := c.rateLimits[rateLimitCategory] + c.rateMu.Unlock() + if !rate.Reset.Time.IsZero() && rate.Remaining == 0 && time.Now().Before(rate.Reset.Time) { + // Create a fake response. + resp := &http.Response{ + Status: http.StatusText(http.StatusForbidden), + StatusCode: http.StatusForbidden, + Request: req, + Header: make(http.Header), + Body: ioutil.NopCloser(strings.NewReader("")), + } + return &RateLimitError{ + Rate: rate, + Response: resp, + Message: fmt.Sprintf("API rate limit of %v still exceeded until %v, not making remote request.", rate.Limit, rate.Reset.Time), + } + } + + return nil +} + /* An ErrorResponse reports one or more errors caused by an API request. @@ -528,6 +567,8 @@ type RateLimits struct { // The rate limit for non-search API requests. Unauthenticated // requests are limited to 60 per hour. Authenticated requests are // limited to 5,000 per hour. + // + // GitHub API docs: https://developer.github.com/v3/#rate-limiting Core *Rate `json:"core"` // The rate limit for search API requests. Unauthenticated requests @@ -542,6 +583,25 @@ func (r RateLimits) String() string { return Stringify(r) } +type rateLimitCategory uint8 + +const ( + coreCategory rateLimitCategory = iota + searchCategory + + categories // An array of this length will be able to contain all rate limit categories. +) + +// category returns the rate limit category of the endpoint, determined by Request.URL.Path. +func category(path string) rateLimitCategory { + switch { + default: + return coreCategory + case strings.HasPrefix(path, "/search/"): + return searchCategory + } +} + // Deprecated: RateLimit is deprecated, use RateLimits instead. func (c *Client) RateLimit() (*Rate, *Response, error) { limits, resp, err := c.RateLimits() @@ -567,6 +627,17 @@ func (c *Client) RateLimits() (*RateLimits, *Response, error) { return nil, nil, err } + if response.Resources != nil { + c.rateMu.Lock() + if response.Resources.Core != nil { + c.rateLimits[coreCategory] = *response.Resources.Core + } + if response.Resources.Search != nil { + c.rateLimits[searchCategory] = *response.Resources.Search + } + c.rateMu.Unlock() + } + return response.Resources, resp, err } diff --git a/github/github_test.go b/github/github_test.go index 06abcc5d29d..521f12a126b 100644 --- a/github/github_test.go +++ b/github/github_test.go @@ -163,6 +163,13 @@ func TestNewClient(t *testing.T) { } } +// Ensure that length of Client.rateLimits is the same as number of fields in RateLimits struct. +func TestClient_rateLimits(t *testing.T) { + if got, want := len(Client{}.rateLimits), reflect.TypeOf(RateLimits{}).NumField(); got != want { + t.Errorf("len(Client{}.rateLimits) is %v, want %v", got, want) + } +} + func TestNewRequest(t *testing.T) { c := NewClient(nil) @@ -478,6 +485,60 @@ func TestDo_rateLimit_rateLimitError(t *testing.T) { } } +// Ensure a network call is not made when it's known that API rate limit is still exceeded. +func TestDo_rateLimit_noNetworkCall(t *testing.T) { + setup() + defer teardown() + + reset := time.Now().UTC().Round(time.Second).Add(time.Minute) // Rate reset is a minute from now, with 1 second precision. + + mux.HandleFunc("/first", func(w http.ResponseWriter, r *http.Request) { + w.Header().Add(headerRateLimit, "60") + w.Header().Add(headerRateRemaining, "0") + w.Header().Add(headerRateReset, fmt.Sprint(reset.Unix())) + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(http.StatusForbidden) + fmt.Fprintln(w, `{ + "message": "API rate limit exceeded for xxx.xxx.xxx.xxx. (But here's the good news: Authenticated requests get a higher rate limit. Check out the documentation for more details.)", + "documentation_url": "https://developer.github.com/v3/#rate-limiting" +}`) + }) + + madeNetworkCall := false + mux.HandleFunc("/second", func(w http.ResponseWriter, r *http.Request) { + madeNetworkCall = true + }) + + // First request is made, and it makes the client aware of rate reset time being in the future. + req, _ := client.NewRequest("GET", "/first", nil) + client.Do(req, nil) + + // Second request should not cause a network call to be made, since client can predict a rate limit error. + req, _ = client.NewRequest("GET", "/second", nil) + _, err := client.Do(req, nil) + + if madeNetworkCall { + t.Fatal("Network call was made, even though rate limit is known to still be exceeded.") + } + + if err == nil { + t.Error("Expected error to be returned.") + } + rateLimitErr, ok := err.(*RateLimitError) + if !ok { + t.Fatalf("Expected a *RateLimitError error; got %#v.", err) + } + if got, want := rateLimitErr.Rate.Limit, 60; got != want { + t.Errorf("rateLimitErr rate limit = %v, want %v", got, want) + } + if got, want := rateLimitErr.Rate.Remaining, 0; got != want { + t.Errorf("rateLimitErr rate remaining = %v, want %v", got, want) + } + if rateLimitErr.Rate.Reset.UTC() != reset { + t.Errorf("rateLimitErr rate reset = %v, want %v", rateLimitErr.Rate.Reset.UTC(), reset) + } +} + func TestDo_noContent(t *testing.T) { setup() defer teardown() @@ -628,7 +689,6 @@ func TestRateLimit(t *testing.T) { if m := "GET"; m != r.Method { t.Errorf("Request method = %v, want %v", r.Method, m) } - //fmt.Fprint(w, `{"resources":{"core": {"limit":2,"remaining":1,"reset":1372700873}}}`) fmt.Fprint(w, `{"resources":{ "core": {"limit":2,"remaining":1,"reset":1372700873}, "search": {"limit":3,"remaining":2,"reset":1372700874} @@ -684,6 +744,13 @@ func TestRateLimits(t *testing.T) { if !reflect.DeepEqual(rate, want) { t.Errorf("RateLimits returned %+v, want %+v", rate, want) } + + if got, want := client.rateLimits[coreCategory], *want.Core; got != want { + t.Errorf("client.rateLimits[coreCategory] is %+v, want %+v", got, want) + } + if got, want := client.rateLimits[searchCategory], *want.Search; got != want { + t.Errorf("client.rateLimits[searchCategory] is %+v, want %+v", got, want) + } } func TestUnauthenticatedRateLimitedTransport(t *testing.T) {