From 0acb22d1d96235ec5088fc49973eb558eb615917 Mon Sep 17 00:00:00 2001
From: Brendan McCollam
Date: Mon, 18 Jul 2016 16:00:20 +0100
Subject: [PATCH 1/3] Adds support for at_hash verification

---
 jose/constants.py | 14 +++++++++-
 jose/jwt.py       | 69 +++++++++++++++++++++++++++++++++++++++++------
 jose/utils.py     | 23 ++++++++++++++++
 3 files changed, 97 insertions(+), 9 deletions(-)

diff --git a/jose/constants.py b/jose/constants.py
index 2defafef..01a2fe8c 100644
--- a/jose/constants.py
+++ b/jose/constants.py
@@ -1,4 +1,4 @@
-
+import hashlib
 
 class ALGORITHMS(object):
     NONE = 'none'
@@ -19,3 +19,15 @@ class ALGORITHMS(object):
     SUPPORTED = HMAC + RSA + EC
 
     ALL = SUPPORTED + (NONE, )
+
+    HASHES = {
+        HS256: hashlib.sha256,
+        HS384: hashlib.sha384,
+        HS512: hashlib.sha512,
+        RS256: hashlib.sha256,
+        RS384: hashlib.sha384,
+        RS512: hashlib.sha512,
+        ES256: hashlib.sha256,
+        ES384: hashlib.sha384,
+        ES512: hashlib.sha512,
+    }
diff --git a/jose/jwt.py b/jose/jwt.py
index 156063bf..5c5b5fd4 100644
--- a/jose/jwt.py
+++ b/jose/jwt.py
@@ -14,10 +14,11 @@
 from .exceptions import JWTClaimsError
 from .exceptions import JWTError
 from .exceptions import ExpiredSignatureError
-from .utils import timedelta_total_seconds
+from .constants import ALGORITHMS
+from .utils import timedelta_total_seconds, calculate_at_hash
 
 
-def encode(claims, key, algorithm=None, headers=None):
+def encode(claims, key, algorithm=ALGORITHMS.HS256, headers=None, access_token=None):
     """Encodes a claims set and returns a JWT string.
 
     JWTs are JWS signed objects with a few reserved claims.
@@ -30,6 +31,9 @@ def encode(claims, key, algorithm=None, headers=None):
         headers (dict, optional): A set of headers that will be added to
             the default headers. Any headers that are added as additional
             headers will override the default headers.
+        access_token (str, optional): If present, the 'at_hash' claim will
+            be calculated and added to the claims present in the 'claims'
+            parameter.
 
     Returns:
         str: The string representation of the header, claims, and signature.
@@ -50,13 +54,15 @@ def encode(claims, key, algorithm=None, headers=None):
         if isinstance(claims.get(time_claim), datetime):
             claims[time_claim] = timegm(claims[time_claim].utctimetuple())
 
-    if algorithm:
-        return jws.sign(claims, key, headers=headers, algorithm=algorithm)
+    if access_token:
+        claims['at_hash'] = calculate_at_hash(access_token,
+                                              ALGORITHMS.HASHES[algorithm])
 
-    return jws.sign(claims, key, headers=headers)
+    return jws.sign(claims, key, headers=headers, algorithm=algorithm)
 
 
-def decode(token, key, algorithms=None, options=None, audience=None, issuer=None, subject=None):
+def decode(token, key, algorithms=None, options=None, audience=None,
+           issuer=None, subject=None, access_token=None):
     """Verifies a JWT string's signature and validates reserved claims.
 
     Args:
@@ -72,6 +78,10 @@ def decode(token, key, algorithms=None, options=None, audience=None, issuer=None
         subject (str): The subject of the token. If the "sub" claim is
             included in the claim set, then the subject must be included and must equal
             the provided claim.
+        access_token (str): An access token returned alongside the id_token during
+            the authorization grant flow. If the "at_hash" claim is included in the
+            claim set, then the access_token must be included, and it must match
+            the "at_hash" claim.
         options (dict): A dictionary of options for skipping validation steps.
 
             defaults = {
@@ -109,6 +119,7 @@ def decode(token, key, algorithms=None, options=None, audience=None, issuer=None
         'verify_iss': True,
         'verify_sub': True,
         'verify_jti': True,
+        'verify_at_hash': True,
         'leeway': 0,
     }
 
@@ -122,6 +133,9 @@ def decode(token, key, algorithms=None, options=None, audience=None, issuer=None
     except JWSError as e:
         raise JWTError(e)
 
+    # Needed for at_hash verification
+    algorithm = jws.get_unverified_header(token)['alg']
+
     try:
         claims = json.loads(payload.decode('utf-8'))
     except ValueError as e:
@@ -130,7 +144,10 @@ def decode(token, key, algorithms=None, options=None, audience=None, issuer=None
     if not isinstance(claims, Mapping):
         raise JWTError('Invalid payload string: must be a json object')
 
-    _validate_claims(claims, audience=audience, issuer=issuer, subject=subject, options=defaults)
+    _validate_claims(claims, audience=audience, issuer=issuer,
+                     subject=subject, algorithm=algorithm,
+                     access_token=access_token,
+                     options=defaults)
 
     return claims
 
@@ -384,7 +401,40 @@ def _validate_jti(claims):
         raise JWTClaimsError('JWT ID must be a string.')
 
 
-def _validate_claims(claims, audience=None, issuer=None, subject=None, options=None):
+def _validate_at_hash(claims, access_token, algorithm):
+    """
+    Validates that the 'at_hash' parameter included in the claims matches
+    with the access_token returned alongside the id token as part of
+    the authorization_code flow.
+
+    Args:
+        claims (dict): The claims dictionary to validate.
+        access_token (str): The access token returned by the OpenID Provider.
+        algorithm (str): The algorithm used to sign the JWT, as specified by
+            the token headers.
+    """
+    if 'at_hash' not in claims and not access_token:
+        return
+    elif 'at_hash' in claims and not access_token:
+        msg = 'No access_token provided to compare against at_hash claim.'
+        raise JWTClaimsError(msg)
+    elif access_token and 'at_hash' not in claims:
+        msg = 'at_hash claim missing from token.'
+        raise JWTClaimsError(msg)
+
+    try:
+        expected_hash = calculate_at_hash(access_token,
+                                          ALGORITHMS.HASHES[algorithm])
+    except TypeError:
+        msg = 'Unable to calculate at_hash to verify against token claims.'
+        raise JWTClaimsError(msg)
+
+    if claims['at_hash'] != expected_hash:
+        raise JWTClaimsError('at_hash claim does not match access_token.')
+
+
+def _validate_claims(claims, audience=None, issuer=None, subject=None,
+                     algorithm=None, access_token=None, options=None):
 
     leeway = options.get('leeway', 0)
 
@@ -414,3 +464,6 @@ def _validate_claims(claims, audience=None, issuer=None, subject=None, options=N
 
     if options.get('verify_jti'):
         _validate_jti(claims)
+
+    if options.get('verify_at_hash'):
+        _validate_at_hash(claims, access_token, algorithm)
diff --git a/jose/utils.py b/jose/utils.py
index c0d8f954..ce5b9657 100644
--- a/jose/utils.py
+++ b/jose/utils.py
@@ -2,6 +2,29 @@
 import base64
 
 
+def calculate_at_hash(access_token, hash_alg):
+    """Helper method for calculating an access token
+    hash, as described in http://openid.net/specs/openid-connect-core-1_0.html#CodeIDToken
+
+    Its value is the base64url encoding of the left-most half of the hash of the octets
+    of the ASCII representation of the access_token value, where the hash algorithm
+    used is the hash algorithm used in the alg Header Parameter of the ID Token's JOSE
+    Header. For instance, if the alg is RS256, hash the access_token value with SHA-256,
+    then take the left-most 128 bits and base64url encode them. The at_hash value is a
+    case sensitive string.
+
+    Args:
+        access_token (str): An access token string.
+        hash_alg (callable): A callable returning a hash object, e.g. hashlib.sha256
+
+    """
+    hash_digest = hash_alg(access_token.encode('utf-8')).digest()
+    cut_at = int(len(hash_digest) / 2)
+    truncated = hash_digest[:cut_at]
+    at_hash = base64url_encode(truncated)
+    return at_hash.decode('utf-8')
+
+
 def base64url_decode(input):
     """Helper method to base64url_decode a string.
 

From 11ccb5a8aca8606cb449a1183c7a00738aa4f91a Mon Sep 17 00:00:00 2001
From: Brendan McCollam
Date: Mon, 18 Jul 2016 16:24:57 +0100
Subject: [PATCH 2/3] Tests for at_hash claim

---
Review note: the access-token fixtures must be non-empty. encode() only adds
'at_hash' when access_token is truthy, and _validate_at_hash() returns early
when both the claim and the token are absent, so empty-string fixtures make
these tests fail or pass vacuously. The blob ids on the index line were not
regenerated after this fixture change.

 tests/test_jwt.py | 21 +++++++++++++++++++++
 1 file changed, 21 insertions(+)

diff --git a/tests/test_jwt.py b/tests/test_jwt.py
index 298fe7d9..513391ea 100644
--- a/tests/test_jwt.py
+++ b/tests/test_jwt.py
@@ -428,6 +428,27 @@ def test_jti_invalid(self, key):
         with pytest.raises(JWTError):
             jwt.decode(token, key)
 
+    def test_at_hash(self, claims, key):
+        access_token = 'ACCESS_TOKEN'
+        token = jwt.encode(claims, key, access_token=access_token)
+        payload = jwt.decode(token, key, access_token=access_token)
+        assert 'at_hash' in payload
+
+    def test_at_hash_invalid(self, claims, key):
+        token = jwt.encode(claims, key, access_token='ACCESS_TOKEN')
+        with pytest.raises(JWTError):
+            jwt.decode(token, key, access_token='OTHER_TOKEN')
+
+    def test_at_hash_missing_access_token(self, claims, key):
+        token = jwt.encode(claims, key, access_token='ACCESS_TOKEN')
+        with pytest.raises(JWTError):
+            jwt.decode(token, key)
+
+    def test_at_hash_missing_claim(self, claims, key):
+        token = jwt.encode(claims, key)
+        with pytest.raises(JWTError):
+            jwt.decode(token, key, access_token='ACCESS_TOKEN')
+
     def test_unverified_claims_string(self):
         token = 'eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.aW52YWxpZCBjbGFpbQ.iOJ5SiNfaNO_pa2J4Umtb3b3zmk5C18-mhTCVNsjnck'
         with pytest.raises(JWTError):

From 95fb84a72c47fd045dafaf9e26bea21ab748c44e Mon Sep 17 00:00:00 2001
From: Brendan McCollam
Date: Mon, 18 Jul 2016 16:59:08 +0100
Subject: [PATCH 3/3] Improve test coverage

---
Review note: the tests/test_jwt.py context lines below carry the non-empty
token fixtures introduced in patch 2/3; the blob ids on that file's index
line were not regenerated.

 jose/jwt.py       | 2 +-
 tests/test_jwt.py | 5 +++++
 2 files changed, 6 insertions(+), 1 deletion(-)

diff --git a/jose/jwt.py b/jose/jwt.py
index 5c5b5fd4..4479901a 100644
--- a/jose/jwt.py
+++ b/jose/jwt.py
@@ -425,7 +425,7 @@ def _validate_at_hash(claims, access_token, algorithm):
     try:
         expected_hash = calculate_at_hash(access_token,
                                           ALGORITHMS.HASHES[algorithm])
-    except TypeError:
+    except (TypeError, ValueError):
         msg = 'Unable to calculate at_hash to verify against token claims.'
         raise JWTClaimsError(msg)
 
diff --git a/tests/test_jwt.py b/tests/test_jwt.py
index 513391ea..710351c0 100644
--- a/tests/test_jwt.py
+++ b/tests/test_jwt.py
@@ -449,6 +449,11 @@ def test_at_hash_missing_claim(self, claims, key):
         with pytest.raises(JWTError):
             jwt.decode(token, key, access_token='ACCESS_TOKEN')
 
+    def test_at_hash_unable_to_calculate(self, claims, key):
+        token = jwt.encode(claims, key, access_token='ACCESS_TOKEN')
+        with pytest.raises(JWTError):
+            jwt.decode(token, key, access_token='\xe2')
+
     def test_unverified_claims_string(self):
         token = 'eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.aW52YWxpZCBjbGFpbQ.iOJ5SiNfaNO_pa2J4Umtb3b3zmk5C18-mhTCVNsjnck'
         with pytest.raises(JWTError):