Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion jose/constants.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@

import hashlib

class ALGORITHMS(object):
NONE = 'none'
Expand All @@ -19,3 +19,15 @@ class ALGORITHMS(object):
SUPPORTED = HMAC + RSA + EC

ALL = SUPPORTED + (NONE, )

HASHES = {
HS256: hashlib.sha256,
HS384: hashlib.sha384,
HS512: hashlib.sha512,
RS256: hashlib.sha256,
RS384: hashlib.sha384,
RS512: hashlib.sha512,
ES256: hashlib.sha256,
ES384: hashlib.sha384,
ES512: hashlib.sha512,
}
69 changes: 61 additions & 8 deletions jose/jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@
from .exceptions import JWTClaimsError
from .exceptions import JWTError
from .exceptions import ExpiredSignatureError
from .utils import timedelta_total_seconds
from .constants import ALGORITHMS
from .utils import timedelta_total_seconds, calculate_at_hash


def encode(claims, key, algorithm=None, headers=None):
def encode(claims, key, algorithm=ALGORITHMS.HS256, headers=None, access_token=None):
"""Encodes a claims set and returns a JWT string.

JWTs are JWS signed objects with a few reserved claims.
Expand All @@ -30,6 +31,9 @@ def encode(claims, key, algorithm=None, headers=None):
headers (dict, optional): A set of headers that will be added to
the default headers. Any headers that are added as additional
headers will override the default headers.
access_token (str, optional): If present, the 'at_hash' claim will
be calculated and added to the claims present in the 'claims'
parameter.

Returns:
str: The string representation of the header, claims, and signature.
Expand All @@ -50,13 +54,15 @@ def encode(claims, key, algorithm=None, headers=None):
if isinstance(claims.get(time_claim), datetime):
claims[time_claim] = timegm(claims[time_claim].utctimetuple())

if algorithm:
return jws.sign(claims, key, headers=headers, algorithm=algorithm)
if access_token:
claims['at_hash'] = calculate_at_hash(access_token,
ALGORITHMS.HASHES[algorithm])

return jws.sign(claims, key, headers=headers)
return jws.sign(claims, key, headers=headers, algorithm=algorithm)


def decode(token, key, algorithms=None, options=None, audience=None, issuer=None, subject=None):
def decode(token, key, algorithms=None, options=None, audience=None,
issuer=None, subject=None, access_token=None):
"""Verifies a JWT string's signature and validates reserved claims.

Args:
Expand All @@ -72,6 +78,10 @@ def decode(token, key, algorithms=None, options=None, audience=None, issuer=None
subject (str): The subject of the token. If the "sub" claim is
included in the claim set, then the subject must be included and must equal
the provided claim.
access_token (str): An access token returned alongside the id_token during
the authorization grant flow. If the "at_hash" claim is included in the
claim set, then the access_token must be included, and it must match
the "at_hash" claim.
options (dict): A dictionary of options for skipping validation steps.

defaults = {
Expand Down Expand Up @@ -109,6 +119,7 @@ def decode(token, key, algorithms=None, options=None, audience=None, issuer=None
'verify_iss': True,
'verify_sub': True,
'verify_jti': True,
'verify_at_hash': True,
'leeway': 0,
}

Expand All @@ -122,6 +133,9 @@ def decode(token, key, algorithms=None, options=None, audience=None, issuer=None
except JWSError as e:
raise JWTError(e)

# Needed for at_hash verification
algorithm = jws.get_unverified_header(token)['alg']

try:
claims = json.loads(payload.decode('utf-8'))
except ValueError as e:
Expand All @@ -130,7 +144,10 @@ def decode(token, key, algorithms=None, options=None, audience=None, issuer=None
if not isinstance(claims, Mapping):
raise JWTError('Invalid payload string: must be a json object')

_validate_claims(claims, audience=audience, issuer=issuer, subject=subject, options=defaults)
_validate_claims(claims, audience=audience, issuer=issuer,
subject=subject, algorithm=algorithm,
access_token=access_token,
options=defaults)

return claims

Expand Down Expand Up @@ -384,7 +401,40 @@ def _validate_jti(claims):
raise JWTClaimsError('JWT ID must be a string.')


def _validate_claims(claims, audience=None, issuer=None, subject=None, options=None):
def _validate_at_hash(claims, access_token, algorithm):
"""
Validates that the 'at_hash' parameter included in the claims matches
with the access_token returned alongside the id token as part of
the authorization_code flow.

Args:
claims (dict): The claims dictionary to validate.
access_token (str): The access token returned by the OpenID Provider.
algorithm (str): The algorithm used to sign the JWT, as specified by
the token headers.
"""
if 'at_hash' not in claims and not access_token:
return
elif 'at_hash' in claims and not access_token:
msg = 'No access_token provided to compare against at_hash claim.'
raise JWTClaimsError(msg)
elif access_token and 'at_hash' not in claims:
msg = 'at_hash claim missing from token.'
raise JWTClaimsError(msg)

try:
expected_hash = calculate_at_hash(access_token,
ALGORITHMS.HASHES[algorithm])
except (TypeError, ValueError):
msg = 'Unable to calculate at_hash to verify against token claims.'
raise JWTClaimsError(msg)

if claims['at_hash'] != expected_hash:
raise JWTClaimsError('at_hash claim does not match access_token.')


def _validate_claims(claims, audience=None, issuer=None, subject=None,
algorithm=None, access_token=None, options=None):

leeway = options.get('leeway', 0)

Expand Down Expand Up @@ -414,3 +464,6 @@ def _validate_claims(claims, audience=None, issuer=None, subject=None, options=N

if options.get('verify_jti'):
_validate_jti(claims)

if options.get('verify_at_hash'):
_validate_at_hash(claims, access_token, algorithm)
23 changes: 23 additions & 0 deletions jose/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,29 @@
import base64


def calculate_at_hash(access_token, hash_alg):
"""Helper method for calculating an access token
hash, as described in http://openid.net/specs/openid-connect-core-1_0.html#CodeIDToken

Its value is the base64url encoding of the left-most half of the hash of the octets
of the ASCII representation of the access_token value, where the hash algorithm
used is the hash algorithm used in the alg Header Parameter of the ID Token's JOSE
Header. For instance, if the alg is RS256, hash the access_token value with SHA-256,
then take the left-most 128 bits and base64url encode them. The at_hash value is a
case sensitive string.

Args:
access_token (str): An access token string.
hash_alg (callable): A callable returning a hash object, e.g. hashlib.sha256

"""
hash_digest = hash_alg(access_token.encode('utf-8')).digest()
cut_at = int(len(hash_digest) / 2)
truncated = hash_digest[:cut_at]
at_hash = base64url_encode(truncated)
return at_hash.decode('utf-8')


def base64url_decode(input):
"""Helper method to base64url_decode a string.

Expand Down
26 changes: 26 additions & 0 deletions tests/test_jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,32 @@ def test_jti_invalid(self, key):
with pytest.raises(JWTError):
jwt.decode(token, key)

def test_at_hash(self, claims, key):
access_token = '<ACCESS_TOKEN>'
token = jwt.encode(claims, key, access_token=access_token)
payload = jwt.decode(token, key, access_token=access_token)
assert 'at_hash' in payload

def test_at_hash_invalid(self, claims, key):
token = jwt.encode(claims, key, access_token='<ACCESS_TOKEN>')
with pytest.raises(JWTError):
jwt.decode(token, key, access_token='<OTHER_TOKEN>')

def test_at_hash_missing_access_token(self, claims, key):
token = jwt.encode(claims, key, access_token='<ACCESS_TOKEN>')
with pytest.raises(JWTError):
jwt.decode(token, key)

def test_at_hash_missing_claim(self, claims, key):
token = jwt.encode(claims, key)
with pytest.raises(JWTError):
jwt.decode(token, key, access_token='<ACCESS_TOKEN>')

def test_at_hash_unable_to_calculate(self, claims, key):
token = jwt.encode(claims, key, access_token='<ACCESS_TOKEN>')
with pytest.raises(JWTError):
jwt.decode(token, key, access_token='\xe2')

def test_unverified_claims_string(self):
token = 'eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.aW52YWxpZCBjbGFpbQ.iOJ5SiNfaNO_pa2J4Umtb3b3zmk5C18-mhTCVNsjnck'
with pytest.raises(JWTError):
Expand Down