Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions auth/jwt/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,27 @@ func NewSigner(kid string, key []byte, method jwt.SigningMethod, claims jwt.Clai
}
}

// ClaimsFactory is a factory for jwt.Claims.
// Useful in NewParser middleware.
type ClaimsFactory func() jwt.Claims

// MapClaimsFactory is a ClaimsFactory that returns
// an empty jwt.MapClaims.
func MapClaimsFactory() jwt.Claims {
return jwt.MapClaims{}
}

// StandardClaimsFactory is a ClaimsFactory that returns
// an empty jwt.StandardClaims.
func StandardClaimsFactory() jwt.Claims {
return &jwt.StandardClaims{}
}

// NewParser creates a new JWT token parsing middleware, specifying a
// jwt.Keyfunc interface, the signing method and the claims type to be used. NewParser
// adds the resulting claims to endpoint context or returns error on invalid token.
// Particularly useful for servers.
func NewParser(keyFunc jwt.Keyfunc, method jwt.SigningMethod, claims jwt.Claims) endpoint.Middleware {
func NewParser(keyFunc jwt.Keyfunc, method jwt.SigningMethod, newClaims ClaimsFactory) endpoint.Middleware {
return func(next endpoint.Endpoint) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (response interface{}, err error) {
// tokenString is stored in the context from the transport handlers.
Expand All @@ -85,7 +101,7 @@ func NewParser(keyFunc jwt.Keyfunc, method jwt.SigningMethod, claims jwt.Claims)
// of the token to identify which key to use, but the parsed token
// (head and claims) is provided to the callback, providing
// flexibility.
token, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (interface{}, error) {
token, err := jwt.ParseWithClaims(tokenString, newClaims(), func(token *jwt.Token) (interface{}, error) {
// Don't forget to validate the alg is what you expect:
if token.Method != method {
return nil, ErrUnexpectedSigningMethod
Expand Down
36 changes: 28 additions & 8 deletions auth/jwt/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package jwt

import (
"context"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -73,7 +74,7 @@ func TestJWTParser(t *testing.T) {
return key, nil
}

parser := NewParser(keys, method, jwt.MapClaims{})(e)
parser := NewParser(keys, method, MapClaimsFactory)(e)

// No Token is passed into the parser
_, err := parser(context.Background(), struct{}{})
Expand All @@ -93,7 +94,7 @@ func TestJWTParser(t *testing.T) {
}

// Invalid Method is used in the parser
badParser := NewParser(keys, invalidMethod, jwt.MapClaims{})(e)
badParser := NewParser(keys, invalidMethod, MapClaimsFactory)(e)
ctx = context.WithValue(context.Background(), JWTTokenContextKey, signedKey)
_, err = badParser(ctx, struct{}{})
if err == nil {
Expand All @@ -109,7 +110,7 @@ func TestJWTParser(t *testing.T) {
return []byte("bad"), nil
}

badParser = NewParser(invalidKeys, method, jwt.MapClaims{})(e)
badParser = NewParser(invalidKeys, method, MapClaimsFactory)(e)
ctx = context.WithValue(context.Background(), JWTTokenContextKey, signedKey)
_, err = badParser(ctx, struct{}{})
if err == nil {
Expand All @@ -133,15 +134,15 @@ func TestJWTParser(t *testing.T) {
}

// Test for malformed token error response
parser = NewParser(keys, method, &jwt.StandardClaims{})(e)
parser = NewParser(keys, method, StandardClaimsFactory)(e)
ctx = context.WithValue(context.Background(), JWTTokenContextKey, malformedKey)
ctx1, err = parser(ctx, struct{}{})
if want, have := ErrTokenMalformed, err; want != have {
t.Fatalf("Expected %+v, got %+v", want, have)
}

// Test for expired token error response
parser = NewParser(keys, method, &jwt.StandardClaims{})(e)
parser = NewParser(keys, method, StandardClaimsFactory)(e)
expired := jwt.NewWithClaims(method, jwt.StandardClaims{ExpiresAt: time.Now().Unix() - 100})
token, err := expired.SignedString(key)
if err != nil {
Expand All @@ -154,7 +155,7 @@ func TestJWTParser(t *testing.T) {
}

// Test for not activated token error response
parser = NewParser(keys, method, &jwt.StandardClaims{})(e)
parser = NewParser(keys, method, StandardClaimsFactory)(e)
notactive := jwt.NewWithClaims(method, jwt.StandardClaims{NotBefore: time.Now().Unix() + 100})
token, err = notactive.SignedString(key)
if err != nil {
Expand All @@ -167,7 +168,7 @@ func TestJWTParser(t *testing.T) {
}

// test valid standard claims token
parser = NewParser(keys, method, &jwt.StandardClaims{})(e)
parser = NewParser(keys, method, StandardClaimsFactory)(e)
ctx = context.WithValue(context.Background(), JWTTokenContextKey, standardSignedKey)
ctx1, err = parser(ctx, struct{}{})
if err != nil {
Expand All @@ -182,7 +183,7 @@ func TestJWTParser(t *testing.T) {
}

// test valid customized claims token
parser = NewParser(keys, method, &customClaims{})(e)
parser = NewParser(keys, method, func() jwt.Claims { return &customClaims{} })(e)
ctx = context.WithValue(context.Background(), JWTTokenContextKey, customSignedKey)
ctx1, err = parser(ctx, struct{}{})
if err != nil {
Expand All @@ -199,3 +200,22 @@ func TestJWTParser(t *testing.T) {
t.Fatalf("JWT customClaims.MyProperty did not match: expecting %s got %s", myProperty, custCl.MyProperty)
}
}

func TestIssue562(t *testing.T) {
var (
kf = func(token *jwt.Token) (interface{}, error) { return []byte("secret"), nil }
e = NewParser(kf, jwt.SigningMethodHS256, MapClaimsFactory)(endpoint.Nop)
key = JWTTokenContextKey
val = "eyJhbGciOiJIUzI1NiIsImtpZCI6ImtpZCIsInR5cCI6IkpXVCJ9.eyJ1c2VyIjoiZ28ta2l0In0.14M2VmYyApdSlV_LZ88ajjwuaLeIFplB8JpyNy0A19E"
ctx = context.WithValue(context.Background(), key, val)
)
wg := sync.WaitGroup{}
for i := 0; i < 100; i++ {
wg.Add(1)
go func() {
defer wg.Done()
e(ctx, struct{}{}) // fatal error: concurrent map read and map write
}()
}
wg.Wait()
}