diff --git a/jose/jws.py b/jose/jws.py index 8d1ab472..56f94a55 100644 --- a/jose/jws.py +++ b/jose/jws.py @@ -3,7 +3,7 @@ import json import six -from collections import Mapping +from collections import Mapping, Iterable from jose import jwk from jose.constants import ALGORITHMS @@ -205,6 +205,27 @@ def _load(jwt): return (header, payload, signing_input, signature) +def _sig_matches_keys(keys, signing_input, signature, alg): + for key in keys: + key = jwk.construct(key, alg) + if key.verify(signing_input, signature): + return True + return False + + +def _get_keys(key): + if 'keys' in key: # JWK Set per RFC 7517 + if not isinstance(key, Mapping): # Caller didn't JSON-decode + key = json.loads(key) + return key['keys'] + # Iterable but not text or mapping => list- or tuple-like + elif (isinstance(key, Iterable) and + not (isinstance(key, six.string_types) or isinstance(key, Mapping))): + return key + else: # Scalar value, wrap in list. + return [key] + + def _verify_signature(signing_input, header, signature, key='', algorithms=None): alg = header.get('alg') @@ -214,12 +235,10 @@ def _verify_signature(signing_input, header, signature, key='', algorithms=None) if algorithms is not None and alg not in algorithms: raise JWSError('The specified alg value is not allowed') + keys = _get_keys(key) try: - key = jwk.construct(key, alg) - - if not key.verify(signing_input, signature): + if not _sig_matches_keys(keys, signing_input, signature, alg): raise JWSSignatureError() - except JWSSignatureError: raise JWSError('Signature verification failed.') except JWSError: diff --git a/tests/test_jws.py b/tests/test_jws.py index ad77dc5a..2f914c41 100644 --- a/tests/test_jws.py +++ b/tests/test_jws.py @@ -1,3 +1,4 @@ +import json from jose import jws from jose.constants import ALGORITHMS @@ -12,7 +13,7 @@ def payload(): return payload -class TestJWS: +class TestJWS(object): def test_unicode_token(self): token = u'eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhIjoiYiJ9.jiMyrsmD8AoHWeQgmxZ5yq8z0lXS67_QGs52AzC8Ru8' @@ -48,7 +49,7 @@ def test_invalid_key(self, payload): jws.sign(payload, 'secret', algorithm='RS256') -class TestHMAC: +class TestHMAC(object): def testHMAC256(self, payload): token = jws.sign(payload, 'secret', algorithm=ALGORITHMS.HS256) @@ -160,8 +161,85 @@ def test_add_headers(self, payload): Ks3IHH7tVltM6NsRk3jNdVMCAwEAAQ== -----END PUBLIC KEY-----""" - -class TestRSA: +@pytest.fixture +def jwk_set(): + return {u'keys': [{u'alg': u'RS256', + u'e': u'AQAB', + u'kid': u'40aa42edac0614d7ca3f57f97ee866cdfba3b61a', + u'kty': u'RSA', + u'n': u'6lm9AEGLPFpVqnfeVFuTIZsj7vz_kxla6uW1WWtosM_MtIjXkyyiSolxiSOs3bzG66iVm71023QyOzKYFbio0hI-yZauG3g9nH-zb_AHScsjAKagHtrHmTdtq0JcNkQnAaaUwxVbjwMlYAcOh87W5jWj_MAcPvc-qjy8-WJ81UgoOUZNiKByuF4-9igxKZeskGRXuTPX64kWGBmKl-tM7VnCGMKoK3m92NPrktfBoNN_EGGthNfQsKFUdQFJFtpMuiXp9Gib7dcMGabxcG2GUl-PU086kPUyUdUYiMN2auKSOxSUZgDjT7DcI8Sn8kdQ0-tImaHi54JNa1PNNdKRpw', + u'use': u'sig'}, + {u'alg': u'RS256', + u'e': u'AQAB', + u'kid': u'8fbbeea40332d2c0d27e37e1904af29b64594e57', + u'kty': u'RSA', + u'n': u'z7h6_rt35-j6NV2iQvYIuR3xvsxmEImgMl8dc8CFl4SzEWrry3QILajKxQZA9YYYfXIcZUG_6R6AghVMJetNIl2AhCoEr3RQjjNsm9PE6h5p2kQ-zIveFeb__4oIkVihYtxtoYBSdVj69nXLUAJP2bxPfU8RDp5X7hT62pKR05H8QLxH8siIQ5qR2LGFw_dJcitAVRRQofuaj_9u0CLZBfinqyRkBc7a0zi7pBxtEiIbn9sRr8Kkb_Boap6BHbnLS-YFBVarcgFBbifRf7NlK5dqE9z4OUb-dx8wCMRIPVAx_hV4Qx2anTgp1sDA6V4vd4NaCOZX-mSctNZqQmKtNw', + u'use': u'sig'}, + {u'alg': u'RS256', + u'e': u'AQAB', + u'kid': u'6758b0b8eb341e90454860432d6a1648bf4de03b', + u'kty': u'RSA', + u'n': u'5K0rYaA7xtqSe1nFn_nCA10uUXY81NcohMeFsYLbBlx_NdpsmbpgtXJ6ektYR7rUdtMMLu2IONlNhkWlx-lge91okyacUrWHP88PycilUE-RnyVjbPEm3seR0VefgALfN4y_e77ljq2F7W2_kbUkTvDzriDIWvQT0WwVF5FIOBydfDDs92S-queaKgLBwt50SXJCZryLew5ODrwVsFGI4Et6MLqjS-cgWpCNwzcRqjBRsse6DXnex_zSRII4ODzKIfX4qdFBKZHO_BkTsK9DNkUayrr9cz8rFRK6TEH6XTVabgsyd6LP6PTxhpiII_pTYRSWk7CGMnm2nO0dKxzaFQ', + u'use': u'sig'}]} + +google_id_token = ( + 'eyJhbGciOiJSUzI1NiIsImtpZCI6IjhmYmJlZWE0MDMzMmQyYzBkMjdlMzdlMTkwN' + 'GFmMjliNjQ1OTRlNTcifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5' + 'jb20iLCJhdF9oYXNoIjoiUUY5RnRjcHlmbUFBanJuMHVyeUQ5dyIsImF1ZCI6IjQw' + 'NzQwODcxODE5Mi5hcHBzLmdvb2dsZXVzZXJjb250ZW50LmNvbSIsInN1YiI6IjEwN' + 'zkzMjQxNjk2NTIwMzIzNDA3NiIsImF6cCI6IjQwNzQwODcxODE5Mi5hcHBzLmdvb2' + 'dsZXVzZXJjb250ZW50LmNvbSIsImlhdCI6MTQ2ODYyMjQ4MCwiZXhwIjoxNDY4NjI' + '2MDgwfQ.Nz6VREh7smvfVRWNHlbKZ6W_DX57akRUGrDTcns06ndAwrslwUlBeFsWY' + 'RLon_tDw0QCeQCGvw7l1AT440UQBRP-mtqK_2Yny2JmIQ7Ll6UAIHRhXOD1uj9w5v' + 'X0jyI1MbjDtODeDWWn_9EDJRBd4xmwKhAONuWodTgSi7qGe1UVmzseFNNkKdoo54d' + 'XhCJiyiRAMnWB_FQDveRJghche131pd9O_E4Wj6hf_zCcMTaDaLDOmElcQe-WsKWA' + 'A3YwHFEWOLO_7x6u4uGmhItPGH7zsOTzYxPYhZMSZusgVg9fbE1kSlHVSyQrcp_rR' + 'WNz7vOIbvIlBR9Jrq5MIqbkkg' +) + + +class TestGetKeys(object): + + def test_dict(self): + assert [{}] == jws._get_keys({}) + + def test_custom_object(self): + class MyDict(dict): + pass + mydict = MyDict() + assert [mydict] == jws._get_keys(mydict) + + def test_RFC7517_string(self): + key = '{"keys": [{}, {}]}' + assert [{}, {}] == jws._get_keys(key) + + def test_RFC7517_mapping(self): + key = {"keys": [{}, {}]} + assert [{}, {}] == jws._get_keys(key) + + def test_string(self): + assert ['test'] == jws._get_keys('test') + + def test_tuple(self): + assert ('test', 'key') == jws._get_keys(('test', 'key')) + + def test_list(self): + assert ['test', 'key'] == jws._get_keys(['test', 'key']) + + +class TestRSA(object): + + def test_jwk_set(self, jwk_set): + # Would raise a JWSError if validation failed. + payload = jws.verify(google_id_token, jwk_set, ALGORITHMS.RS256) + iss = json.loads(payload.decode('utf-8'))['iss'] + assert iss == "https://accounts.google.com" + + def test_jwk_set_failure(self, jwk_set): + # Remove the key that was used to sign this token. + del jwk_set['keys'][1] + with pytest.raises(JWSError): + payload = jws.verify(google_id_token, jwk_set, ALGORITHMS.RS256) def test_RSA256(self, payload): token = jws.sign(payload, rsa_private_key, algorithm=ALGORITHMS.RS256) @@ -201,7 +279,7 @@ def test_wrong_key(self, payload): -----END PUBLIC KEY-----""" -class TestEC: +class TestEC(object): def test_EC256(self, payload): token = jws.sign(payload, ec_private_key, algorithm=ALGORITHMS.ES256) @@ -221,7 +299,7 @@ def test_wrong_alg(self, payload): jws.verify(token, rsa_public_key, ALGORITHMS.ES384) -class TestLoad: +class TestLoad(object): def test_header_not_mapping(self): token = 'WyJ0ZXN0Il0.eyJhIjoiYiJ9.jiMyrsmD8AoHWeQgmxZ5yq8z0lXS67_QGs52AzC8Ru8'