Skip to content

Commit 2f24f44

Browse files
committed
oauthdevice: add RefreshToken field and Refresh method
1 parent 2ac282a commit 2f24f44

File tree

2 files changed

+94
-0
lines changed

2 files changed

+94
-0
lines changed

internal/oauthdevice/device_flow.go

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ type Client interface {
6868
Discover(ctx context.Context, endpoint string) (*OIDCConfiguration, error)
6969
Start(ctx context.Context, endpoint string, scopes []string) (*DeviceAuthResponse, error)
7070
Poll(ctx context.Context, endpoint, deviceCode string, interval time.Duration, expiresIn int) (*TokenResponse, error)
71+
Refresh(ctx context.Context, endpoint, refreshToken string) (*TokenResponse, error)
7172
}
7273

7374
type httpClient struct {
@@ -307,3 +308,55 @@ func (c *httpClient) pollOnce(ctx context.Context, tokenEndpoint, deviceCode str
307308

308309
return &tokenResp, nil
309310
}
311+
312+
// Refresh exchanges a refresh token for a new access token.
313+
func (c *httpClient) Refresh(ctx context.Context, endpoint, refreshToken string) (*TokenResponse, error) {
314+
endpoint = strings.TrimRight(endpoint, "/")
315+
316+
config, err := c.Discover(ctx, endpoint)
317+
if err != nil {
318+
return nil, errors.Wrap(err, "OIDC discovery failed")
319+
}
320+
321+
if config.TokenEndpoint == "" {
322+
return nil, errors.New("token endpoint not found in OIDC configuration")
323+
}
324+
325+
data := url.Values{}
326+
data.Set("client_id", c.clientID)
327+
data.Set("grant_type", "refresh_token")
328+
data.Set("refresh_token", refreshToken)
329+
330+
req, err := http.NewRequestWithContext(ctx, "POST", config.TokenEndpoint, strings.NewReader(data.Encode()))
331+
if err != nil {
332+
return nil, errors.Wrap(err, "creating refresh token request")
333+
}
334+
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
335+
req.Header.Set("Accept", "application/json")
336+
337+
resp, err := c.client.Do(req)
338+
if err != nil {
339+
return nil, errors.Wrap(err, "refresh token request failed")
340+
}
341+
defer resp.Body.Close()
342+
343+
body, err := io.ReadAll(resp.Body)
344+
if err != nil {
345+
return nil, errors.Wrap(err, "reading refresh token response")
346+
}
347+
348+
if resp.StatusCode != http.StatusOK {
349+
var errResp ErrorResponse
350+
if err := json.Unmarshal(body, &errResp); err == nil && errResp.Error != "" {
351+
return nil, errors.Newf("refresh token failed: %s: %s", errResp.Error, errResp.ErrorDescription)
352+
}
353+
return nil, errors.Newf("refresh token failed with status %d: %s", resp.StatusCode, string(body))
354+
}
355+
356+
var tokenResp TokenResponse
357+
if err := json.Unmarshal(body, &tokenResp); err != nil {
358+
return nil, errors.Wrap(err, "parsing refresh token response")
359+
}
360+
361+
return &tokenResp, nil
362+
}

internal/oauthdevice/device_flow_test.go

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -507,3 +507,44 @@ func TestPoll_ContextCancellation(t *testing.T) {
507507
t.Errorf("error = %v, want context.Canceled or wrapped context canceled error", err)
508508
}
509509
}
510+
511+
func TestRefresh_Success(t *testing.T) {
512+
server := newTestServer(t, testServerOptions{
513+
handlers: map[string]http.HandlerFunc{
514+
testTokenPath: func(w http.ResponseWriter, r *http.Request) {
515+
if err := r.ParseForm(); err != nil {
516+
http.Error(w, "bad request", http.StatusBadRequest)
517+
return
518+
}
519+
if got := r.FormValue("grant_type"); got != "refresh_token" {
520+
t.Errorf("grant_type = %q, want %q", got, "refresh_token")
521+
}
522+
if got := r.FormValue("refresh_token"); got != "test-refresh-token" {
523+
t.Errorf("refresh_token = %q, want %q", got, "test-refresh-token")
524+
}
525+
526+
w.Header().Set("Content-Type", "application/json")
527+
json.NewEncoder(w).Encode(TokenResponse{
528+
AccessToken: "new-access-token",
529+
RefreshToken: "new-refresh-token",
530+
TokenType: "Bearer",
531+
ExpiresIn: 3600,
532+
})
533+
},
534+
},
535+
})
536+
defer server.Close()
537+
538+
client := NewClient(DefaultClientID)
539+
resp, err := client.Refresh(context.Background(), server.URL, "test-refresh-token")
540+
if err != nil {
541+
t.Fatalf("Refresh() error = %v", err)
542+
}
543+
544+
if resp.AccessToken != "new-access-token" {
545+
t.Errorf("AccessToken = %q, want %q", resp.AccessToken, "new-access-token")
546+
}
547+
if resp.RefreshToken != "new-refresh-token" {
548+
t.Errorf("RefreshToken = %q, want %q", resp.RefreshToken, "new-refresh-token")
549+
}
550+
}

0 commit comments

Comments
 (0)