Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 8 additions & 13 deletions auth/jwt/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,17 +44,14 @@ var (
ErrUnexpectedSigningMethod = errors.New("unexpected signing method")
)

// Claims is a map of arbitrary claim data.
type Claims map[string]interface{}

// NewSigner creates a new JWT token generating middleware, specifying key ID,
// signing string, signing method and the claims you would like it to contain.
// Tokens are signed with a Key ID header (kid) which is useful for determining
// the key to use for parsing. Particularly useful for clients.
func NewSigner(kid string, key []byte, method jwt.SigningMethod, claims Claims) endpoint.Middleware {
func NewSigner(kid string, key []byte, method jwt.SigningMethod, claims jwt.Claims) endpoint.Middleware {
return func(next endpoint.Endpoint) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (response interface{}, err error) {
token := jwt.NewWithClaims(method, jwt.MapClaims(claims))
token := jwt.NewWithClaims(method, claims)
token.Header["kid"] = kid

// Sign and get the complete encoded token as a string using the secret
Expand All @@ -70,10 +67,10 @@ func NewSigner(kid string, key []byte, method jwt.SigningMethod, claims Claims)
}

// NewParser creates a new JWT token parsing middleware, specifying a
// jwt.Keyfunc interface and the signing method. NewParser adds the resulting
// claims to endpoint context or returns error on invalid token. Particularly
// useful for servers.
func NewParser(keyFunc jwt.Keyfunc, method jwt.SigningMethod) endpoint.Middleware {
// jwt.Keyfunc interface, the signing method and the claims type to be used. NewParser
// adds the resulting claims to endpoint context or returns error on invalid token.
// Particularly useful for servers.
func NewParser(keyFunc jwt.Keyfunc, method jwt.SigningMethod, claims jwt.Claims) endpoint.Middleware {
return func(next endpoint.Endpoint) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (response interface{}, err error) {
// tokenString is stored in the context from the transport handlers.
Expand All @@ -88,7 +85,7 @@ func NewParser(keyFunc jwt.Keyfunc, method jwt.SigningMethod) endpoint.Middlewar
// of the token to identify which key to use, but the parsed token
// (head and claims) is provided to the callback, providing
// flexibility.
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
token, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (interface{}, error) {
// Don't forget to validate the alg is what you expect:
if token.Method != method {
return nil, ErrUnexpectedSigningMethod
Expand Down Expand Up @@ -119,9 +116,7 @@ func NewParser(keyFunc jwt.Keyfunc, method jwt.SigningMethod) endpoint.Middlewar
return nil, ErrTokenInvalid
}

if claims, ok := token.Claims.(jwt.MapClaims); ok {
ctx = context.WithValue(ctx, JWTClaimsContextKey, Claims(claims))
}
ctx = context.WithValue(ctx, JWTClaimsContextKey, token.Claims)

return next(ctx, request)
}
Expand Down
96 changes: 77 additions & 19 deletions auth/jwt/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,38 @@ import (
"context"
"testing"

"crypto/subtle"

jwt "github.com/dgrijalva/jwt-go"
"github.com/go-kit/kit/endpoint"
)

type customClaims struct {
MyProperty string `json:"my_property"`
jwt.StandardClaims
}

func (c customClaims) VerifyMyProperty(p string) bool {
return subtle.ConstantTimeCompare([]byte(c.MyProperty), []byte(p)) != 0
}

var (
kid = "kid"
key = []byte("test_signing_key")
method = jwt.SigningMethodHS256
invalidMethod = jwt.SigningMethodRS256
claims = Claims{"user": "go-kit"}
kid = "kid"
key = []byte("test_signing_key")
myProperty = "some value"
method = jwt.SigningMethodHS256
invalidMethod = jwt.SigningMethodRS256
mapClaims = jwt.MapClaims{"user": "go-kit"}
standardClaims = jwt.StandardClaims{Audience: "go-kit"}
myCustomClaims = customClaims{MyProperty: myProperty, StandardClaims: standardClaims}
// Signed tokens generated at https://jwt.io/
signedKey = "eyJhbGciOiJIUzI1NiIsImtpZCI6ImtpZCIsInR5cCI6IkpXVCJ9.eyJ1c2VyIjoiZ28ta2l0In0.14M2VmYyApdSlV_LZ88ajjwuaLeIFplB8JpyNy0A19E"
invalidKey = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.e30.vKVCKto-Wn6rgz3vBdaZaCBGfCBDTXOENSo_X2Gq7qA"
signedKey = "eyJhbGciOiJIUzI1NiIsImtpZCI6ImtpZCIsInR5cCI6IkpXVCJ9.eyJ1c2VyIjoiZ28ta2l0In0.14M2VmYyApdSlV_LZ88ajjwuaLeIFplB8JpyNy0A19E"
standardSignedKey = "eyJhbGciOiJIUzI1NiIsImtpZCI6ImtpZCIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJnby1raXQifQ.L5ypIJjCOOv3jJ8G5SelaHvR04UJuxmcBN5QW3m_aoY"
customSignedKey = "eyJhbGciOiJIUzI1NiIsImtpZCI6ImtpZCIsInR5cCI6IkpXVCJ9.eyJteV9wcm9wZXJ0eSI6InNvbWUgdmFsdWUiLCJhdWQiOiJnby1raXQifQ.s8F-IDrV4WPJUsqr7qfDi-3GRlcKR0SRnkTeUT_U-i0"
invalidKey = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.e30.vKVCKto-Wn6rgz3vBdaZaCBGfCBDTXOENSo_X2Gq7qA"
)

func TestSigner(t *testing.T) {
e := func(ctx context.Context, i interface{}) (interface{}, error) { return ctx, nil }

signer := NewSigner(kid, key, method, claims)(e)
func signingValidator(t *testing.T, signer endpoint.Endpoint, expectedKey string) {
ctx, err := signer(context.Background(), struct{}{})
if err != nil {
t.Fatalf("Signer returned error: %s", err)
Expand All @@ -32,19 +46,32 @@ func TestSigner(t *testing.T) {
t.Fatal("Token did not exist in context")
}

if token != signedKey {
t.Fatalf("JWT tokens did not match: expecting %s got %s", signedKey, token)
if token != expectedKey {
t.Fatalf("JWT tokens did not match: expecting %s got %s", expectedKey, token)
}
}

func TestNewSigner(t *testing.T) {
e := func(ctx context.Context, i interface{}) (interface{}, error) { return ctx, nil }

signer := NewSigner(kid, key, method, mapClaims)(e)
signingValidator(t, signer, signedKey)

signer = NewSigner(kid, key, method, standardClaims)(e)
signingValidator(t, signer, standardSignedKey)

signer = NewSigner(kid, key, method, myCustomClaims)(e)
signingValidator(t, signer, customSignedKey)
}

func TestJWTParser(t *testing.T) {
e := func(ctx context.Context, i interface{}) (interface{}, error) { return ctx, nil }

keys := func(token *jwt.Token) (interface{}, error) {
return key, nil
}

parser := NewParser(keys, method)(e)
parser := NewParser(keys, method, jwt.MapClaims{})(e)

// No Token is passed into the parser
_, err := parser(context.Background(), struct{}{})
Expand All @@ -64,7 +91,7 @@ func TestJWTParser(t *testing.T) {
}

// Invalid Method is used in the parser
badParser := NewParser(keys, invalidMethod)(e)
badParser := NewParser(keys, invalidMethod, jwt.MapClaims{})(e)
ctx = context.WithValue(context.Background(), JWTTokenContextKey, signedKey)
_, err = badParser(ctx, struct{}{})
if err == nil {
Expand All @@ -80,7 +107,7 @@ func TestJWTParser(t *testing.T) {
return []byte("bad"), nil
}

badParser = NewParser(invalidKeys, method)(e)
badParser = NewParser(invalidKeys, method, jwt.MapClaims{})(e)
ctx = context.WithValue(context.Background(), JWTTokenContextKey, signedKey)
_, err = badParser(ctx, struct{}{})
if err == nil {
Expand All @@ -94,12 +121,43 @@ func TestJWTParser(t *testing.T) {
t.Fatalf("Parser returned error: %s", err)
}

cl, ok := ctx1.(context.Context).Value(JWTClaimsContextKey).(Claims)
cl, ok := ctx1.(context.Context).Value(JWTClaimsContextKey).(jwt.MapClaims)
if !ok {
t.Fatal("Claims were not passed into context correctly")
}

if cl["user"] != mapClaims["user"] {
t.Fatalf("JWT Claims.user did not match: expecting %s got %s", mapClaims["user"], cl["user"])
}

parser = NewParser(keys, method, &jwt.StandardClaims{})(e)
ctx = context.WithValue(context.Background(), JWTTokenContextKey, standardSignedKey)
ctx1, err = parser(ctx, struct{}{})
if err != nil {
t.Fatalf("Parser returned error: %s", err)
}
stdCl, ok := ctx1.(context.Context).Value(JWTClaimsContextKey).(*jwt.StandardClaims)
if !ok {
t.Fatal("Claims were not passed into context correctly")
}
if !stdCl.VerifyAudience("go-kit", true) {
t.Fatalf("JWT jwt.StandardClaims.Audience did not match: expecting %s got %s", standardClaims.Audience, stdCl.Audience)
}

if cl["user"] != claims["user"] {
t.Fatalf("JWT Claims.user did not match: expecting %s got %s", claims["user"], cl["user"])
parser = NewParser(keys, method, &customClaims{})(e)
ctx = context.WithValue(context.Background(), JWTTokenContextKey, customSignedKey)
ctx1, err = parser(ctx, struct{}{})
if err != nil {
t.Fatalf("Parser returned error: %s", err)
}
custCl, ok := ctx1.(context.Context).Value(JWTClaimsContextKey).(*customClaims)
if !ok {
t.Fatal("Claims were not passed into context correctly")
}
if !custCl.VerifyAudience("go-kit", true) {
t.Fatalf("JWT customClaims.Audience did not match: expecting %s got %s", standardClaims.Audience, custCl.Audience)
}
if !custCl.VerifyMyProperty(myProperty) {
t.Fatalf("JWT customClaims.MyProperty did not match: expecting %s got %s", myProperty, custCl.MyProperty)
}
}