diff --git a/internal/github/client.go b/internal/github/client.go index cacbc5e..5b64107 100644 --- a/internal/github/client.go +++ b/internal/github/client.go @@ -37,10 +37,11 @@ type ListIssuesOptions struct { } type RequestError struct { - Method string - URL string - Status int - Body string + Method string + URL string + Status int + Body string + Headers http.Header } func (e *RequestError) Error() string { @@ -63,13 +64,12 @@ func New(options Options) *Client { if userAgent == "" { userAgent = "gitcrawl" } - pageDelay := options.PageDelay return &Client{ httpClient: httpClient, baseURL: baseURL, token: options.Token, userAgent: userAgent, - pageDelay: pageDelay, + pageDelay: options.PageDelay, } } @@ -185,6 +185,26 @@ func (c *Client) doJSON(ctx context.Context, method, path string, body io.Reader } func (c *Client) do(ctx context.Context, method, path string, body io.Reader, reporter Reporter) (*http.Response, error) { + resp, err := c.doOnce(ctx, method, path, body, reporter) + if err == nil { + return resp, nil + } + wait, ok := rateLimitWait(err) + if !ok { + return nil, err + } + reporter.Printf("[github] rate-limit retry wait=%s", wait.Round(time.Second)) + timer := time.NewTimer(wait) + defer timer.Stop() + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-timer.C: + } + return c.doOnce(ctx, method, path, body, reporter) +} + +func (c *Client) doOnce(ctx context.Context, method, path string, body io.Reader, reporter Reporter) (*http.Response, error) { fullURL := c.baseURL + path req, err := http.NewRequestWithContext(ctx, method, fullURL, body) if err != nil { @@ -206,7 +226,39 @@ func (c *Client) do(ctx context.Context, method, path string, body io.Reader, re } defer resp.Body.Close() data, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) - return nil, &RequestError{Method: method, URL: path, Status: resp.StatusCode, Body: strings.TrimSpace(string(data))} + return nil, &RequestError{ + Method: method, + URL: path, + Status: resp.StatusCode, + Body: strings.TrimSpace(string(data)), + Headers: resp.Header, + } +} + +func rateLimitWait(err error) (time.Duration, bool) { + reqErr, ok := err.(*RequestError) + if !ok { + return 0, false + } + if reqErr.Status != http.StatusForbidden && reqErr.Status != http.StatusTooManyRequests { + return 0, false + } + if v := strings.TrimSpace(reqErr.Headers.Get("Retry-After")); v != "" { + if secs, err := strconv.Atoi(v); err == nil && secs > 0 { + return time.Duration(secs) * time.Second, true + } + } + if reqErr.Headers.Get("X-RateLimit-Remaining") != "0" { + return 0, false + } + secs, err := strconv.ParseInt(strings.TrimSpace(reqErr.Headers.Get("X-RateLimit-Reset")), 10, 64) + if err != nil { + return 0, false + } + if wait := time.Until(time.Unix(secs, 0)); wait > 0 { + return wait, true + } + return time.Second, true } func nextPage(linkHeader string) string { diff --git a/internal/github/client_test.go b/internal/github/client_test.go index 5824ae5..75f84ba 100644 --- a/internal/github/client_test.go +++ b/internal/github/client_test.go @@ -3,10 +3,14 @@ package github import ( "context" "encoding/json" + "errors" + "fmt" "net/http" "net/http/httptest" "strings" + "sync/atomic" "testing" + "time" ) func TestListRepositoryIssuesPaginatesAndLimits(t *testing.T) { @@ -173,6 +177,91 @@ func TestClientErrorAndHelperBranches(t *testing.T) { } } +func TestRateLimitRetriesOn403WithRemainingZero(t *testing.T) { + var calls int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if atomic.AddInt32(&calls, 1) == 1 { + w.Header().Set("X-RateLimit-Remaining", "0") + w.Header().Set("X-RateLimit-Reset", fmt.Sprintf("%d", time.Now().Unix())) + http.Error(w, "rate limited", http.StatusForbidden) + return + } + _ = json.NewEncoder(w).Encode(map[string]any{"id": 1}) + })) + defer server.Close() + + client := New(Options{BaseURL: server.URL, PageDelay: -1}) + row, err := client.GetRepo(context.Background(), "openclaw", "gitcrawl", nil) + if err != nil { + t.Fatalf("get repo: %v", err) + } + if intValue(row["id"]) != 1 { + t.Fatalf("row = %#v", row) + } + if got := atomic.LoadInt32(&calls); got != 2 { + t.Fatalf("calls = %d want 2", got) + } +} + +func TestRateLimitRetriesOn429WithRetryAfter(t *testing.T) { + var calls int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if atomic.AddInt32(&calls, 1) == 1 { + w.Header().Set("Retry-After", "1") + http.Error(w, "slow down", http.StatusTooManyRequests) + return + } + _ = json.NewEncoder(w).Encode(map[string]any{"id": 2}) + })) + defer server.Close() + + client := New(Options{BaseURL: server.URL, PageDelay: -1}) + row, err := client.GetRepo(context.Background(), "openclaw", "gitcrawl", nil) + if err != nil { + t.Fatalf("get repo: %v", err) + } + if intValue(row["id"]) != 2 { + t.Fatalf("row = %#v", row) + } +} + +func TestRateLimitRespectsContextCancellation(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-RateLimit-Remaining", "0") + w.Header().Set("X-RateLimit-Reset", fmt.Sprintf("%d", time.Now().Add(time.Hour).Unix())) + http.Error(w, "rate limited", http.StatusForbidden) + })) + defer server.Close() + + client := New(Options{BaseURL: server.URL, PageDelay: -1}) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + _, err := client.GetRepo(ctx, "openclaw", "gitcrawl", nil) + if err == nil { + t.Fatal("expected error") + } + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("err = %v", err) + } +} + +func TestNonRateLimit403IsNotRetried(t *testing.T) { + var calls int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&calls, 1) + http.Error(w, "forbidden", http.StatusForbidden) + })) + defer server.Close() + + client := New(Options{BaseURL: server.URL, PageDelay: -1}) + if _, err := client.GetRepo(context.Background(), "openclaw", "gitcrawl", nil); err == nil { + t.Fatal("expected error") + } + if got := atomic.LoadInt32(&calls); got != 1 { + t.Fatalf("calls = %d want 1", got) + } +} + func serverURL(r *http.Request) string { scheme := "http" if r.TLS != nil {