Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 59 additions & 7 deletions internal/github/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,11 @@ type ListIssuesOptions struct {
}

type RequestError struct {
Method string
URL string
Status int
Body string
Method string
URL string
Status int
Body string
Headers http.Header
}

func (e *RequestError) Error() string {
Expand All @@ -63,13 +64,12 @@ func New(options Options) *Client {
if userAgent == "" {
userAgent = "gitcrawl"
}
pageDelay := options.PageDelay
return &Client{
httpClient: httpClient,
baseURL: baseURL,
token: options.Token,
userAgent: userAgent,
pageDelay: pageDelay,
pageDelay: options.PageDelay,
}
}

Expand Down Expand Up @@ -185,6 +185,26 @@ func (c *Client) doJSON(ctx context.Context, method, path string, body io.Reader
}

func (c *Client) do(ctx context.Context, method, path string, body io.Reader, reporter Reporter) (*http.Response, error) {
resp, err := c.doOnce(ctx, method, path, body, reporter)
if err == nil {
return resp, nil
}
wait, ok := rateLimitWait(err)
if !ok {
return nil, err
}
reporter.Printf("[github] rate-limit retry wait=%s", wait.Round(time.Second))
timer := time.NewTimer(wait)
defer timer.Stop()
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-timer.C:
}
return c.doOnce(ctx, method, path, body, reporter)
}

func (c *Client) doOnce(ctx context.Context, method, path string, body io.Reader, reporter Reporter) (*http.Response, error) {
fullURL := c.baseURL + path
req, err := http.NewRequestWithContext(ctx, method, fullURL, body)
if err != nil {
Expand All @@ -206,7 +226,39 @@ func (c *Client) do(ctx context.Context, method, path string, body io.Reader, re
}
defer resp.Body.Close()
data, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
return nil, &RequestError{Method: method, URL: path, Status: resp.StatusCode, Body: strings.TrimSpace(string(data))}
return nil, &RequestError{
Method: method,
URL: path,
Status: resp.StatusCode,
Body: strings.TrimSpace(string(data)),
Headers: resp.Header,
}
}

func rateLimitWait(err error) (time.Duration, bool) {
reqErr, ok := err.(*RequestError)
if !ok {
return 0, false
}
if reqErr.Status != http.StatusForbidden && reqErr.Status != http.StatusTooManyRequests {
return 0, false
}
if v := strings.TrimSpace(reqErr.Headers.Get("Retry-After")); v != "" {
if secs, err := strconv.Atoi(v); err == nil && secs > 0 {
return time.Duration(secs) * time.Second, true
}
}
if reqErr.Headers.Get("X-RateLimit-Remaining") != "0" {
return 0, false
}
secs, err := strconv.ParseInt(strings.TrimSpace(reqErr.Headers.Get("X-RateLimit-Reset")), 10, 64)
if err != nil {
return 0, false
}
if wait := time.Until(time.Unix(secs, 0)); wait > 0 {
return wait, true
}
return time.Second, true
}

func nextPage(linkHeader string) string {
Expand Down
89 changes: 89 additions & 0 deletions internal/github/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,14 @@ package github
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"sync/atomic"
"testing"
"time"
)

func TestListRepositoryIssuesPaginatesAndLimits(t *testing.T) {
Expand Down Expand Up @@ -173,6 +177,91 @@ func TestClientErrorAndHelperBranches(t *testing.T) {
}
}

func TestRateLimitRetriesOn403WithRemainingZero(t *testing.T) {
var calls int32
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if atomic.AddInt32(&calls, 1) == 1 {
w.Header().Set("X-RateLimit-Remaining", "0")
w.Header().Set("X-RateLimit-Reset", fmt.Sprintf("%d", time.Now().Unix()))
http.Error(w, "rate limited", http.StatusForbidden)
return
}
_ = json.NewEncoder(w).Encode(map[string]any{"id": 1})
}))
defer server.Close()

client := New(Options{BaseURL: server.URL, PageDelay: -1})
row, err := client.GetRepo(context.Background(), "openclaw", "gitcrawl", nil)
if err != nil {
t.Fatalf("get repo: %v", err)
}
if intValue(row["id"]) != 1 {
t.Fatalf("row = %#v", row)
}
if got := atomic.LoadInt32(&calls); got != 2 {
t.Fatalf("calls = %d want 2", got)
}
}

func TestRateLimitRetriesOn429WithRetryAfter(t *testing.T) {
var calls int32
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if atomic.AddInt32(&calls, 1) == 1 {
w.Header().Set("Retry-After", "1")
http.Error(w, "slow down", http.StatusTooManyRequests)
return
}
_ = json.NewEncoder(w).Encode(map[string]any{"id": 2})
}))
defer server.Close()

client := New(Options{BaseURL: server.URL, PageDelay: -1})
row, err := client.GetRepo(context.Background(), "openclaw", "gitcrawl", nil)
if err != nil {
t.Fatalf("get repo: %v", err)
}
if intValue(row["id"]) != 2 {
t.Fatalf("row = %#v", row)
}
}

func TestRateLimitRespectsContextCancellation(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-RateLimit-Remaining", "0")
w.Header().Set("X-RateLimit-Reset", fmt.Sprintf("%d", time.Now().Add(time.Hour).Unix()))
http.Error(w, "rate limited", http.StatusForbidden)
}))
defer server.Close()

client := New(Options{BaseURL: server.URL, PageDelay: -1})
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
defer cancel()
_, err := client.GetRepo(ctx, "openclaw", "gitcrawl", nil)
if err == nil {
t.Fatal("expected error")
}
if !errors.Is(err, context.DeadlineExceeded) {
t.Fatalf("err = %v", err)
}
}

func TestNonRateLimit403IsNotRetried(t *testing.T) {
var calls int32
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
atomic.AddInt32(&calls, 1)
http.Error(w, "forbidden", http.StatusForbidden)
}))
defer server.Close()

client := New(Options{BaseURL: server.URL, PageDelay: -1})
if _, err := client.GetRepo(context.Background(), "openclaw", "gitcrawl", nil); err == nil {
t.Fatal("expected error")
}
if got := atomic.LoadInt32(&calls); got != 1 {
t.Fatalf("calls = %d want 1", got)
}
}

func serverURL(r *http.Request) string {
scheme := "http"
if r.TLS != nil {
Expand Down