diff --git a/auth/jwt/middleware.go b/auth/jwt/middleware.go index e7dcb9d6a..c07d6e97e 100644 --- a/auth/jwt/middleware.go +++ b/auth/jwt/middleware.go @@ -66,11 +66,27 @@ func NewSigner(kid string, key []byte, method jwt.SigningMethod, claims jwt.Clai } } +// ClaimsFactory is a factory for jwt.Claims. +// Useful in NewParser middleware. +type ClaimsFactory func() jwt.Claims + +// MapClaimsFactory is a ClaimsFactory that returns +// an empty jwt.MapClaims. +func MapClaimsFactory() jwt.Claims { + return jwt.MapClaims{} +} + +// StandardClaimsFactory is a ClaimsFactory that returns +// an empty jwt.StandardClaims. +func StandardClaimsFactory() jwt.Claims { + return &jwt.StandardClaims{} +} + // NewParser creates a new JWT token parsing middleware, specifying a // jwt.Keyfunc interface, the signing method and the claims type to be used. NewParser // adds the resulting claims to endpoint context or returns error on invalid token. // Particularly useful for servers. -func NewParser(keyFunc jwt.Keyfunc, method jwt.SigningMethod, claims jwt.Claims) endpoint.Middleware { +func NewParser(keyFunc jwt.Keyfunc, method jwt.SigningMethod, newClaims ClaimsFactory) endpoint.Middleware { return func(next endpoint.Endpoint) endpoint.Endpoint { return func(ctx context.Context, request interface{}) (response interface{}, err error) { // tokenString is stored in the context from the transport handlers. @@ -85,7 +101,7 @@ func NewParser(keyFunc jwt.Keyfunc, method jwt.SigningMethod, claims jwt.Claims) // of the token to identify which key to use, but the parsed token // (head and claims) is provided to the callback, providing // flexibility. - token, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (interface{}, error) { + token, err := jwt.ParseWithClaims(tokenString, newClaims(), func(token *jwt.Token) (interface{}, error) { // Don't forget to validate the alg is what you expect: if token.Method != method { return nil, ErrUnexpectedSigningMethod diff --git a/auth/jwt/middleware_test.go b/auth/jwt/middleware_test.go index 8a4520154..3278e13a7 100644 --- a/auth/jwt/middleware_test.go +++ b/auth/jwt/middleware_test.go @@ -2,6 +2,7 @@ package jwt import ( "context" + "sync" "testing" "time" @@ -73,7 +74,7 @@ func TestJWTParser(t *testing.T) { return key, nil } - parser := NewParser(keys, method, jwt.MapClaims{})(e) + parser := NewParser(keys, method, MapClaimsFactory)(e) // No Token is passed into the parser _, err := parser(context.Background(), struct{}{}) @@ -93,7 +94,7 @@ func TestJWTParser(t *testing.T) { } // Invalid Method is used in the parser - badParser := NewParser(keys, invalidMethod, jwt.MapClaims{})(e) + badParser := NewParser(keys, invalidMethod, MapClaimsFactory)(e) ctx = context.WithValue(context.Background(), JWTTokenContextKey, signedKey) _, err = badParser(ctx, struct{}{}) if err == nil { @@ -109,7 +110,7 @@ func TestJWTParser(t *testing.T) { return []byte("bad"), nil } - badParser = NewParser(invalidKeys, method, jwt.MapClaims{})(e) + badParser = NewParser(invalidKeys, method, MapClaimsFactory)(e) ctx = context.WithValue(context.Background(), JWTTokenContextKey, signedKey) _, err = badParser(ctx, struct{}{}) if err == nil { @@ -133,7 +134,7 @@ func TestJWTParser(t *testing.T) { } // Test for malformed token error response - parser = NewParser(keys, method, &jwt.StandardClaims{})(e) + parser = NewParser(keys, method, StandardClaimsFactory)(e) ctx = context.WithValue(context.Background(), JWTTokenContextKey, malformedKey) ctx1, err = parser(ctx, struct{}{}) if want, have := ErrTokenMalformed, err; want != have { @@ -141,7 +142,7 @@ func TestJWTParser(t *testing.T) { } // Test for expired token error response - parser = NewParser(keys, method, &jwt.StandardClaims{})(e) + parser = NewParser(keys, method, StandardClaimsFactory)(e) expired := jwt.NewWithClaims(method, jwt.StandardClaims{ExpiresAt: time.Now().Unix() - 100}) token, err := expired.SignedString(key) if err != nil { @@ -154,7 +155,7 @@ func TestJWTParser(t *testing.T) { } // Test for not activated token error response - parser = NewParser(keys, method, &jwt.StandardClaims{})(e) + parser = NewParser(keys, method, StandardClaimsFactory)(e) notactive := jwt.NewWithClaims(method, jwt.StandardClaims{NotBefore: time.Now().Unix() + 100}) token, err = notactive.SignedString(key) if err != nil { @@ -167,7 +168,7 @@ func TestJWTParser(t *testing.T) { } // test valid standard claims token - parser = NewParser(keys, method, &jwt.StandardClaims{})(e) + parser = NewParser(keys, method, StandardClaimsFactory)(e) ctx = context.WithValue(context.Background(), JWTTokenContextKey, standardSignedKey) ctx1, err = parser(ctx, struct{}{}) if err != nil { @@ -182,7 +183,7 @@ func TestJWTParser(t *testing.T) { } // test valid customized claims token - parser = NewParser(keys, method, &customClaims{})(e) + parser = NewParser(keys, method, func() jwt.Claims { return &customClaims{} })(e) ctx = context.WithValue(context.Background(), JWTTokenContextKey, customSignedKey) ctx1, err = parser(ctx, struct{}{}) if err != nil { @@ -199,3 +200,22 @@ func TestJWTParser(t *testing.T) { t.Fatalf("JWT customClaims.MyProperty did not match: expecting %s got %s", myProperty, custCl.MyProperty) } } + +func TestIssue562(t *testing.T) { + var ( + kf = func(token *jwt.Token) (interface{}, error) { return []byte("secret"), nil } + e = NewParser(kf, jwt.SigningMethodHS256, MapClaimsFactory)(endpoint.Nop) + key = JWTTokenContextKey + val = "eyJhbGciOiJIUzI1NiIsImtpZCI6ImtpZCIsInR5cCI6IkpXVCJ9.eyJ1c2VyIjoiZ28ta2l0In0.14M2VmYyApdSlV_LZ88ajjwuaLeIFplB8JpyNy0A19E" + ctx = context.WithValue(context.Background(), key, val) + ) + wg := sync.WaitGroup{} + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() + e(ctx, struct{}{}) // fatal error: concurrent map read and map write + }() + } + wg.Wait() +}