diff --git a/auth/jwt/middleware.go b/auth/jwt/middleware.go index b5ccf0b4c..cce469506 100644 --- a/auth/jwt/middleware.go +++ b/auth/jwt/middleware.go @@ -44,17 +44,14 @@ var ( ErrUnexpectedSigningMethod = errors.New("unexpected signing method") ) -// Claims is a map of arbitrary claim data. -type Claims map[string]interface{} - // NewSigner creates a new JWT token generating middleware, specifying key ID, // signing string, signing method and the claims you would like it to contain. // Tokens are signed with a Key ID header (kid) which is useful for determining // the key to use for parsing. Particularly useful for clients. -func NewSigner(kid string, key []byte, method jwt.SigningMethod, claims Claims) endpoint.Middleware { +func NewSigner(kid string, key []byte, method jwt.SigningMethod, claims jwt.Claims) endpoint.Middleware { return func(next endpoint.Endpoint) endpoint.Endpoint { return func(ctx context.Context, request interface{}) (response interface{}, err error) { - token := jwt.NewWithClaims(method, jwt.MapClaims(claims)) + token := jwt.NewWithClaims(method, claims) token.Header["kid"] = kid // Sign and get the complete encoded token as a string using the secret @@ -70,10 +67,10 @@ func NewSigner(kid string, key []byte, method jwt.SigningMethod, claims Claims) } // NewParser creates a new JWT token parsing middleware, specifying a -// jwt.Keyfunc interface and the signing method. NewParser adds the resulting -// claims to endpoint context or returns error on invalid token. Particularly -// useful for servers. -func NewParser(keyFunc jwt.Keyfunc, method jwt.SigningMethod) endpoint.Middleware { +// jwt.Keyfunc interface, the signing method and the claims type to be used. NewParser +// adds the resulting claims to endpoint context or returns error on invalid token. +// Particularly useful for servers. +func NewParser(keyFunc jwt.Keyfunc, method jwt.SigningMethod, claims jwt.Claims) endpoint.Middleware { return func(next endpoint.Endpoint) endpoint.Endpoint { return func(ctx context.Context, request interface{}) (response interface{}, err error) { // tokenString is stored in the context from the transport handlers. @@ -88,7 +85,7 @@ func NewParser(keyFunc jwt.Keyfunc, method jwt.SigningMethod) endpoint.Middlewar // of the token to identify which key to use, but the parsed token // (head and claims) is provided to the callback, providing // flexibility. - token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { + token, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (interface{}, error) { // Don't forget to validate the alg is what you expect: if token.Method != method { return nil, ErrUnexpectedSigningMethod @@ -119,9 +116,7 @@ func NewParser(keyFunc jwt.Keyfunc, method jwt.SigningMethod) endpoint.Middlewar return nil, ErrTokenInvalid } - if claims, ok := token.Claims.(jwt.MapClaims); ok { - ctx = context.WithValue(ctx, JWTClaimsContextKey, Claims(claims)) - } + ctx = context.WithValue(ctx, JWTClaimsContextKey, token.Claims) return next(ctx, request) } diff --git a/auth/jwt/middleware_test.go b/auth/jwt/middleware_test.go index 99b943c59..76889d6f4 100644 --- a/auth/jwt/middleware_test.go +++ b/auth/jwt/middleware_test.go @@ -4,24 +4,38 @@ import ( "context" "testing" + "crypto/subtle" + jwt "github.com/dgrijalva/jwt-go" + "github.com/go-kit/kit/endpoint" ) +type customClaims struct { + MyProperty string `json:"my_property"` + jwt.StandardClaims +} + +func (c customClaims) VerifyMyProperty(p string) bool { + return subtle.ConstantTimeCompare([]byte(c.MyProperty), []byte(p)) != 0 +} + var ( - kid = "kid" - key = []byte("test_signing_key") - method = jwt.SigningMethodHS256 - invalidMethod = jwt.SigningMethodRS256 - claims = Claims{"user": "go-kit"} + kid = "kid" + key = []byte("test_signing_key") + myProperty = "some value" + method = jwt.SigningMethodHS256 + invalidMethod = jwt.SigningMethodRS256 + mapClaims = jwt.MapClaims{"user": "go-kit"} + standardClaims = jwt.StandardClaims{Audience: "go-kit"} + myCustomClaims = customClaims{MyProperty: myProperty, StandardClaims: standardClaims} // Signed tokens generated at https://jwt.io/ - signedKey = "eyJhbGciOiJIUzI1NiIsImtpZCI6ImtpZCIsInR5cCI6IkpXVCJ9.eyJ1c2VyIjoiZ28ta2l0In0.14M2VmYyApdSlV_LZ88ajjwuaLeIFplB8JpyNy0A19E" - invalidKey = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.e30.vKVCKto-Wn6rgz3vBdaZaCBGfCBDTXOENSo_X2Gq7qA" + signedKey = "eyJhbGciOiJIUzI1NiIsImtpZCI6ImtpZCIsInR5cCI6IkpXVCJ9.eyJ1c2VyIjoiZ28ta2l0In0.14M2VmYyApdSlV_LZ88ajjwuaLeIFplB8JpyNy0A19E" + standardSignedKey = "eyJhbGciOiJIUzI1NiIsImtpZCI6ImtpZCIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJnby1raXQifQ.L5ypIJjCOOv3jJ8G5SelaHvR04UJuxmcBN5QW3m_aoY" + customSignedKey = "eyJhbGciOiJIUzI1NiIsImtpZCI6ImtpZCIsInR5cCI6IkpXVCJ9.eyJteV9wcm9wZXJ0eSI6InNvbWUgdmFsdWUiLCJhdWQiOiJnby1raXQifQ.s8F-IDrV4WPJUsqr7qfDi-3GRlcKR0SRnkTeUT_U-i0" + invalidKey = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.e30.vKVCKto-Wn6rgz3vBdaZaCBGfCBDTXOENSo_X2Gq7qA" ) -func TestSigner(t *testing.T) { - e := func(ctx context.Context, i interface{}) (interface{}, error) { return ctx, nil } - - signer := NewSigner(kid, key, method, claims)(e) +func signingValidator(t *testing.T, signer endpoint.Endpoint, expectedKey string) { ctx, err := signer(context.Background(), struct{}{}) if err != nil { t.Fatalf("Signer returned error: %s", err) @@ -32,11 +46,24 @@ func TestSigner(t *testing.T) { t.Fatal("Token did not exist in context") } - if token != signedKey { - t.Fatalf("JWT tokens did not match: expecting %s got %s", signedKey, token) + if token != expectedKey { + t.Fatalf("JWT tokens did not match: expecting %s got %s", expectedKey, token) } } +func TestNewSigner(t *testing.T) { + e := func(ctx context.Context, i interface{}) (interface{}, error) { return ctx, nil } + + signer := NewSigner(kid, key, method, mapClaims)(e) + signingValidator(t, signer, signedKey) + + signer = NewSigner(kid, key, method, standardClaims)(e) + signingValidator(t, signer, standardSignedKey) + + signer = NewSigner(kid, key, method, myCustomClaims)(e) + signingValidator(t, signer, customSignedKey) +} + func TestJWTParser(t *testing.T) { e := func(ctx context.Context, i interface{}) (interface{}, error) { return ctx, nil } @@ -44,7 +71,7 @@ func TestJWTParser(t *testing.T) { return key, nil } - parser := NewParser(keys, method)(e) + parser := NewParser(keys, method, jwt.MapClaims{})(e) // No Token is passed into the parser _, err := parser(context.Background(), struct{}{}) @@ -64,7 +91,7 @@ func TestJWTParser(t *testing.T) { } // Invalid Method is used in the parser - badParser := NewParser(keys, invalidMethod)(e) + badParser := NewParser(keys, invalidMethod, jwt.MapClaims{})(e) ctx = context.WithValue(context.Background(), JWTTokenContextKey, signedKey) _, err = badParser(ctx, struct{}{}) if err == nil { @@ -80,7 +107,7 @@ func TestJWTParser(t *testing.T) { return []byte("bad"), nil } - badParser = NewParser(invalidKeys, method)(e) + badParser = NewParser(invalidKeys, method, jwt.MapClaims{})(e) ctx = context.WithValue(context.Background(), JWTTokenContextKey, signedKey) _, err = badParser(ctx, struct{}{}) if err == nil { @@ -94,12 +121,43 @@ func TestJWTParser(t *testing.T) { t.Fatalf("Parser returned error: %s", err) } - cl, ok := ctx1.(context.Context).Value(JWTClaimsContextKey).(Claims) + cl, ok := ctx1.(context.Context).Value(JWTClaimsContextKey).(jwt.MapClaims) + if !ok { + t.Fatal("Claims were not passed into context correctly") + } + + if cl["user"] != mapClaims["user"] { + t.Fatalf("JWT Claims.user did not match: expecting %s got %s", mapClaims["user"], cl["user"]) + } + + parser = NewParser(keys, method, &jwt.StandardClaims{})(e) + ctx = context.WithValue(context.Background(), JWTTokenContextKey, standardSignedKey) + ctx1, err = parser(ctx, struct{}{}) + if err != nil { + t.Fatalf("Parser returned error: %s", err) + } + stdCl, ok := ctx1.(context.Context).Value(JWTClaimsContextKey).(*jwt.StandardClaims) if !ok { t.Fatal("Claims were not passed into context correctly") } + if !stdCl.VerifyAudience("go-kit", true) { + t.Fatalf("JWT jwt.StandardClaims.Audience did not match: expecting %s got %s", standardClaims.Audience, stdCl.Audience) + } - if cl["user"] != claims["user"] { - t.Fatalf("JWT Claims.user did not match: expecting %s got %s", claims["user"], cl["user"]) + parser = NewParser(keys, method, &customClaims{})(e) + ctx = context.WithValue(context.Background(), JWTTokenContextKey, customSignedKey) + ctx1, err = parser(ctx, struct{}{}) + if err != nil { + t.Fatalf("Parser returned error: %s", err) + } + custCl, ok := ctx1.(context.Context).Value(JWTClaimsContextKey).(*customClaims) + if !ok { + t.Fatal("Claims were not passed into context correctly") + } + if !custCl.VerifyAudience("go-kit", true) { + t.Fatalf("JWT customClaims.Audience did not match: expecting %s got %s", standardClaims.Audience, custCl.Audience) + } + if !custCl.VerifyMyProperty(myProperty) { + t.Fatalf("JWT customClaims.MyProperty did not match: expecting %s got %s", myProperty, custCl.MyProperty) } }