diff --git a/github/github.go b/github/github.go index 91af5aa665f..8355f395985 100644 --- a/github/github.go +++ b/github/github.go @@ -373,6 +373,25 @@ func (c *Client) WithAuthToken(token string) *Client { return transport.RoundTrip(req) }, ) + // Prevent the bearer token from being forwarded to a different host on + // redirect. The Transport above re-injects the Authorization header on + // every RoundTrip call, including the intermediate calls that http.Client + // makes when following redirects. Go's http.Client strips Authorization + // for cross-host redirects, but the Transport immediately adds it back, + // so the token would be sent to the redirect destination. Returning + // http.ErrUseLastResponse surfaces the 3xx to the caller instead. + callerRedirect := c2.client.CheckRedirect + c2.client.CheckRedirect = func(req *http.Request, via []*http.Request) error { + if callerRedirect != nil { + if err := callerRedirect(req, via); err != nil { + return err + } + } + if len(via) > 0 && req.URL.Host != via[0].URL.Host { + return http.ErrUseLastResponse + } + return nil + } return c2 } diff --git a/github/github_test.go b/github/github_test.go index b41ebd8f849..280cc601104 100644 --- a/github/github_test.go +++ b/github/github_test.go @@ -539,6 +539,67 @@ func TestWithAuthToken(t *testing.T) { t.Fatal("WithAuthToken reset Marketplace.Stubbed; want true") } }) + + t.Run("cross-host redirect does not leak token", func(t *testing.T) { + t.Parallel() + + tokenReceived := false + victim := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Authorization") != "" { + tokenReceived = true + } + w.WriteHeader(http.StatusOK) + })) + defer victim.Close() + + redirector := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, victim.URL+"/stolen", http.StatusFound) + })) + defer redirector.Close() + + c := new(Client).WithAuthToken("SECRET_TOKEN") + resp, err := c.Client().Get(redirector.URL + "/api") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusFound { + t.Errorf("got status %d; want %d (cross-host redirect should not be followed)", resp.StatusCode, http.StatusFound) + } + if tokenReceived { + t.Error("token was forwarded to redirect target; cross-host redirect should have been stopped") + } + }) + + t.Run("same-host redirect is followed", func(t *testing.T) { + t.Parallel() + + called := false + mux := http.NewServeMux() + mux.HandleFunc("/redirectme", func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, "/final", http.StatusFound) + }) + mux.HandleFunc("/final", func(w http.ResponseWriter, r *http.Request) { + called = true + w.WriteHeader(http.StatusOK) + }) + server := httptest.NewServer(mux) + defer server.Close() + + c := new(Client).WithAuthToken("MY_TOKEN") + resp, err := c.Client().Get(server.URL + "/redirectme") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer resp.Body.Close() + if !called { + t.Error("same-host redirect was not followed") + } + if resp.StatusCode != http.StatusOK { + t.Errorf("got status %d; want 200", resp.StatusCode) + } + }) } func TestWithEnterpriseURLs(t *testing.T) {