diff --git a/auth/jwt/README.md b/auth/jwt/README.md index 1ed825a95..8cf823c3f 100644 --- a/auth/jwt/README.md +++ b/auth/jwt/README.md @@ -7,7 +7,7 @@ through [JSON Web Tokens](https://jwt.io/). NewParser takes a key function and an expected signing method and returns an `endpoint.Middleware`. The middleware will parse a token passed into the -context via the `jwt.JWTTokenContextKey`. If the token is valid, any claims +context via the `jwt.JWTContextKey`. If the token is valid, any claims will be added to the context via the `jwt.JWTClaimsContextKey`. ```go @@ -30,7 +30,7 @@ func main() { NewSigner takes a JWT key ID header, the signing key, signing method, and a claims object. It returns an `endpoint.Middleware`. The middleware will build -the token string and add it to the context via the `jwt.JWTTokenContextKey`. +the token string and add it to the context via the `jwt.JWTContextKey`. ```go import ( diff --git a/auth/jwt/middleware.go b/auth/jwt/middleware.go index 0e29e6d68..b1c75a6d6 100644 --- a/auth/jwt/middleware.go +++ b/auth/jwt/middleware.go @@ -12,9 +12,13 @@ import ( type contextKey string const ( - // JWTTokenContextKey holds the key used to store a JWT Token in the - // context. - JWTTokenContextKey contextKey = "JWTToken" + // JWTContextKey holds the key used to store a JWT in the context. + JWTContextKey contextKey = "JWTToken" + + // JWTTokenContextKey is an alias for JWTContextKey. + // + // Deprecated: prefer JWTContextKey. + JWTTokenContextKey = JWTContextKey // JWTClaimsContextKey holds the key used to store the JWT Claims in the // context. @@ -27,13 +31,13 @@ var ( ErrTokenContextMissing = errors.New("token up for parsing was not passed through the context") // ErrTokenInvalid denotes a token was not able to be validated. - ErrTokenInvalid = errors.New("JWT Token was invalid") + ErrTokenInvalid = errors.New("JWT was invalid") // ErrTokenExpired denotes a token's expire header (exp) has since passed. - ErrTokenExpired = errors.New("JWT Token is expired") + ErrTokenExpired = errors.New("JWT is expired") - // ErrTokenMalformed denotes a token was not formatted as a JWT token. - ErrTokenMalformed = errors.New("JWT Token is malformed") + // ErrTokenMalformed denotes a token was not formatted as a JWT. + ErrTokenMalformed = errors.New("JWT is malformed") // ErrTokenNotActive denotes a token's not before header (nbf) is in the // future. @@ -44,7 +48,7 @@ var ( ErrUnexpectedSigningMethod = errors.New("unexpected signing method") ) -// NewSigner creates a new JWT token generating middleware, specifying key ID, +// NewSigner creates a new JWT generating middleware, specifying key ID, // signing string, signing method and the claims you would like it to contain. // Tokens are signed with a Key ID header (kid) which is useful for determining // the key to use for parsing. Particularly useful for clients. @@ -59,7 +63,7 @@ func NewSigner(kid string, key []byte, method jwt.SigningMethod, claims jwt.Clai if err != nil { return nil, err } - ctx = context.WithValue(ctx, JWTTokenContextKey, tokenString) + ctx = context.WithValue(ctx, JWTContextKey, tokenString) return next(ctx, request) } @@ -82,7 +86,7 @@ func StandardClaimsFactory() jwt.Claims { return &jwt.StandardClaims{} } -// NewParser creates a new JWT token parsing middleware, specifying a +// NewParser creates a new JWT parsing middleware, specifying a // jwt.Keyfunc interface, the signing method and the claims type to be used. NewParser // adds the resulting claims to endpoint context or returns error on invalid token. // Particularly useful for servers. @@ -90,7 +94,7 @@ func NewParser(keyFunc jwt.Keyfunc, method jwt.SigningMethod, newClaims ClaimsFa return func(next endpoint.Endpoint) endpoint.Endpoint { return func(ctx context.Context, request interface{}) (response interface{}, err error) { // tokenString is stored in the context from the transport handlers. - tokenString, ok := ctx.Value(JWTTokenContextKey).(string) + tokenString, ok := ctx.Value(JWTContextKey).(string) if !ok { return nil, ErrTokenContextMissing } diff --git a/auth/jwt/middleware_test.go b/auth/jwt/middleware_test.go index 3278e13a7..fc7032f1e 100644 --- a/auth/jwt/middleware_test.go +++ b/auth/jwt/middleware_test.go @@ -44,13 +44,13 @@ func signingValidator(t *testing.T, signer endpoint.Endpoint, expectedKey string t.Fatalf("Signer returned error: %s", err) } - token, ok := ctx.(context.Context).Value(JWTTokenContextKey).(string) + token, ok := ctx.(context.Context).Value(JWTContextKey).(string) if !ok { t.Fatal("Token did not exist in context") } if token != expectedKey { - t.Fatalf("JWT tokens did not match: expecting %s got %s", expectedKey, token) + t.Fatalf("JWTs did not match: expecting %s got %s", expectedKey, token) } } @@ -87,7 +87,7 @@ func TestJWTParser(t *testing.T) { } // Invalid Token is passed into the parser - ctx := context.WithValue(context.Background(), JWTTokenContextKey, invalidKey) + ctx := context.WithValue(context.Background(), JWTContextKey, invalidKey) _, err = parser(ctx, struct{}{}) if err == nil { t.Error("Parser should have returned an error") @@ -95,7 +95,7 @@ func TestJWTParser(t *testing.T) { // Invalid Method is used in the parser badParser := NewParser(keys, invalidMethod, MapClaimsFactory)(e) - ctx = context.WithValue(context.Background(), JWTTokenContextKey, signedKey) + ctx = context.WithValue(context.Background(), JWTContextKey, signedKey) _, err = badParser(ctx, struct{}{}) if err == nil { t.Error("Parser should have returned an error") @@ -111,14 +111,14 @@ func TestJWTParser(t *testing.T) { } badParser = NewParser(invalidKeys, method, MapClaimsFactory)(e) - ctx = context.WithValue(context.Background(), JWTTokenContextKey, signedKey) + ctx = context.WithValue(context.Background(), JWTContextKey, signedKey) _, err = badParser(ctx, struct{}{}) if err == nil { t.Error("Parser should have returned an error") } // Correct token is passed into the parser - ctx = context.WithValue(context.Background(), JWTTokenContextKey, signedKey) + ctx = context.WithValue(context.Background(), JWTContextKey, signedKey) ctx1, err := parser(ctx, struct{}{}) if err != nil { t.Fatalf("Parser returned error: %s", err) @@ -135,7 +135,7 @@ func TestJWTParser(t *testing.T) { // Test for malformed token error response parser = NewParser(keys, method, StandardClaimsFactory)(e) - ctx = context.WithValue(context.Background(), JWTTokenContextKey, malformedKey) + ctx = context.WithValue(context.Background(), JWTContextKey, malformedKey) ctx1, err = parser(ctx, struct{}{}) if want, have := ErrTokenMalformed, err; want != have { t.Fatalf("Expected %+v, got %+v", want, have) @@ -148,7 +148,7 @@ func TestJWTParser(t *testing.T) { if err != nil { t.Fatalf("Unable to Sign Token: %+v", err) } - ctx = context.WithValue(context.Background(), JWTTokenContextKey, token) + ctx = context.WithValue(context.Background(), JWTContextKey, token) ctx1, err = parser(ctx, struct{}{}) if want, have := ErrTokenExpired, err; want != have { t.Fatalf("Expected %+v, got %+v", want, have) @@ -161,7 +161,7 @@ func TestJWTParser(t *testing.T) { if err != nil { t.Fatalf("Unable to Sign Token: %+v", err) } - ctx = context.WithValue(context.Background(), JWTTokenContextKey, token) + ctx = context.WithValue(context.Background(), JWTContextKey, token) ctx1, err = parser(ctx, struct{}{}) if want, have := ErrTokenNotActive, err; want != have { t.Fatalf("Expected %+v, got %+v", want, have) @@ -169,7 +169,7 @@ func TestJWTParser(t *testing.T) { // test valid standard claims token parser = NewParser(keys, method, StandardClaimsFactory)(e) - ctx = context.WithValue(context.Background(), JWTTokenContextKey, standardSignedKey) + ctx = context.WithValue(context.Background(), JWTContextKey, standardSignedKey) ctx1, err = parser(ctx, struct{}{}) if err != nil { t.Fatalf("Parser returned error: %s", err) @@ -184,7 +184,7 @@ func TestJWTParser(t *testing.T) { // test valid customized claims token parser = NewParser(keys, method, func() jwt.Claims { return &customClaims{} })(e) - ctx = context.WithValue(context.Background(), JWTTokenContextKey, customSignedKey) + ctx = context.WithValue(context.Background(), JWTContextKey, customSignedKey) ctx1, err = parser(ctx, struct{}{}) if err != nil { t.Fatalf("Parser returned error: %s", err) @@ -205,7 +205,7 @@ func TestIssue562(t *testing.T) { var ( kf = func(token *jwt.Token) (interface{}, error) { return []byte("secret"), nil } e = NewParser(kf, jwt.SigningMethodHS256, MapClaimsFactory)(endpoint.Nop) - key = JWTTokenContextKey + key = JWTContextKey val = "eyJhbGciOiJIUzI1NiIsImtpZCI6ImtpZCIsInR5cCI6IkpXVCJ9.eyJ1c2VyIjoiZ28ta2l0In0.14M2VmYyApdSlV_LZ88ajjwuaLeIFplB8JpyNy0A19E" ctx = context.WithValue(context.Background(), key, val) ) diff --git a/auth/jwt/transport.go b/auth/jwt/transport.go index 4b7082a3f..e7d19c180 100644 --- a/auth/jwt/transport.go +++ b/auth/jwt/transport.go @@ -26,7 +26,7 @@ func HTTPToContext() http.RequestFunc { return ctx } - return context.WithValue(ctx, JWTTokenContextKey, token) + return context.WithValue(ctx, JWTContextKey, token) } } @@ -34,7 +34,7 @@ func HTTPToContext() http.RequestFunc { // useful for clients. func ContextToHTTP() http.RequestFunc { return func(ctx context.Context, r *stdhttp.Request) context.Context { - token, ok := ctx.Value(JWTTokenContextKey).(string) + token, ok := ctx.Value(JWTContextKey).(string) if ok { r.Header.Add("Authorization", generateAuthHeaderFromToken(token)) } @@ -54,7 +54,7 @@ func GRPCToContext() grpc.ServerRequestFunc { token, ok := extractTokenFromAuthHeader(authHeader[0]) if ok { - ctx = context.WithValue(ctx, JWTTokenContextKey, token) + ctx = context.WithValue(ctx, JWTContextKey, token) } return ctx @@ -65,7 +65,7 @@ func GRPCToContext() grpc.ServerRequestFunc { // useful for clients. func ContextToGRPC() grpc.ClientRequestFunc { return func(ctx context.Context, md *metadata.MD) context.Context { - token, ok := ctx.Value(JWTTokenContextKey).(string) + token, ok := ctx.Value(JWTContextKey).(string) if ok { // capital "Key" is illegal in HTTP/2. (*md)["authorization"] = []string{generateAuthHeaderFromToken(token)} diff --git a/auth/jwt/transport_test.go b/auth/jwt/transport_test.go index 83d0f17f8..096c6ac31 100644 --- a/auth/jwt/transport_test.go +++ b/auth/jwt/transport_test.go @@ -15,7 +15,7 @@ func TestHTTPToContext(t *testing.T) { // When the header doesn't exist ctx := reqFunc(context.Background(), &http.Request{}) - if ctx.Value(JWTTokenContextKey) != nil { + if ctx.Value(JWTContextKey) != nil { t.Error("Context shouldn't contain the encoded JWT") } @@ -24,7 +24,7 @@ func TestHTTPToContext(t *testing.T) { header.Set("Authorization", "no expected auth header format value") ctx = reqFunc(context.Background(), &http.Request{Header: header}) - if ctx.Value(JWTTokenContextKey) != nil { + if ctx.Value(JWTContextKey) != nil { t.Error("Context shouldn't contain the encoded JWT") } @@ -32,7 +32,7 @@ func TestHTTPToContext(t *testing.T) { header.Set("Authorization", generateAuthHeaderFromToken(signedKey)) ctx = reqFunc(context.Background(), &http.Request{Header: header}) - token := ctx.Value(JWTTokenContextKey).(string) + token := ctx.Value(JWTContextKey).(string) if token != signedKey { t.Errorf("Context doesn't contain the expected encoded token value; expected: %s, got: %s", signedKey, token) } @@ -41,7 +41,7 @@ func TestHTTPToContext(t *testing.T) { func TestContextToHTTP(t *testing.T) { reqFunc := ContextToHTTP() - // No JWT Token is passed in the context + // No JWT is passed in the context ctx := context.Background() r := http.Request{} reqFunc(ctx, &r) @@ -51,8 +51,8 @@ func TestContextToHTTP(t *testing.T) { t.Error("authorization key should not exist in metadata") } - // Correct JWT Token is passed in the context - ctx = context.WithValue(context.Background(), JWTTokenContextKey, signedKey) + // Correct JWT is passed in the context + ctx = context.WithValue(context.Background(), JWTContextKey, signedKey) r = http.Request{Header: http.Header{}} reqFunc(ctx, &r) @@ -60,7 +60,7 @@ func TestContextToHTTP(t *testing.T) { expected := generateAuthHeaderFromToken(signedKey) if token != expected { - t.Errorf("Authorization header does not contain the expected JWT token; expected %s, got %s", expected, token) + t.Errorf("Authorization header does not contain the expected JWT; expected %s, got %s", expected, token) } } @@ -70,36 +70,36 @@ func TestGRPCToContext(t *testing.T) { // No Authorization header is passed ctx := reqFunc(context.Background(), md) - token := ctx.Value(JWTTokenContextKey) + token := ctx.Value(JWTContextKey) if token != nil { - t.Error("Context should not contain a JWT Token") + t.Error("Context should not contain a JWT") } // Invalid Authorization header is passed md["authorization"] = []string{fmt.Sprintf("%s", signedKey)} ctx = reqFunc(context.Background(), md) - token = ctx.Value(JWTTokenContextKey) + token = ctx.Value(JWTContextKey) if token != nil { - t.Error("Context should not contain a JWT Token") + t.Error("Context should not contain a JWT") } // Authorization header is correct md["authorization"] = []string{fmt.Sprintf("Bearer %s", signedKey)} ctx = reqFunc(context.Background(), md) - token, ok := ctx.Value(JWTTokenContextKey).(string) + token, ok := ctx.Value(JWTContextKey).(string) if !ok { - t.Fatal("JWT Token not passed to context correctly") + t.Fatal("JWT not passed to context correctly") } if token != signedKey { - t.Errorf("JWT tokens did not match: expecting %s got %s", signedKey, token) + t.Errorf("JWTs did not match: expecting %s got %s", signedKey, token) } } func TestContextToGRPC(t *testing.T) { reqFunc := ContextToGRPC() - // No JWT Token is passed in the context + // No JWT is passed in the context ctx := context.Background() md := metadata.MD{} reqFunc(ctx, &md) @@ -109,17 +109,17 @@ func TestContextToGRPC(t *testing.T) { t.Error("authorization key should not exist in metadata") } - // Correct JWT Token is passed in the context - ctx = context.WithValue(context.Background(), JWTTokenContextKey, signedKey) + // Correct JWT is passed in the context + ctx = context.WithValue(context.Background(), JWTContextKey, signedKey) md = metadata.MD{} reqFunc(ctx, &md) token, ok := md["authorization"] if !ok { - t.Fatal("JWT Token not passed to metadata correctly") + t.Fatal("JWT not passed to metadata correctly") } if token[0] != generateAuthHeaderFromToken(signedKey) { - t.Errorf("JWT tokens did not match: expecting %s got %s", signedKey, token[0]) + t.Errorf("JWTs did not match: expecting %s got %s", signedKey, token[0]) } }