diff --git a/jose/jws.py b/jose/jws.py index 56f94a55..dc6ad5c5 100644 --- a/jose/jws.py +++ b/jose/jws.py @@ -222,8 +222,8 @@ def _get_keys(key): elif (isinstance(key, Iterable) and not (isinstance(key, six.string_types) or isinstance(key, Mapping))): return key - else: # Scalar value, wrap in list. - return [key] + else: # Scalar value, wrap in tuple. + return (key,) def _verify_signature(signing_input, header, signature, key='', algorithms=None): diff --git a/jose/jwt.py b/jose/jwt.py index 4479901a..2128c851 100644 --- a/jose/jwt.py +++ b/jose/jwt.py @@ -72,9 +72,9 @@ def decode(token, key, algorithms=None, options=None, audience=None, audience (str): The intended audience of the token. If the "aud" claim is included in the claim set, then the audience must be included and must equal the provided claim. - issuer (str): The issuer of the token. If the "iss" claim is - included in the claim set, then the issuer must be included and must equal - the provided claim. + issuer (str or iterable): Acceptable value(s) for the issuer of the token. + If the "iss" claim is included in the claim set, then the issuer must be + given and the claim in the token must be among the acceptable values. subject (str): The subject of the token. If the "sub" claim is included in the claim set, then the subject must be included and must equal the provided claim. @@ -345,11 +345,14 @@ def _validate_iss(claims, issuer=None): Args: claims (dict): The claims dictionary to validate. - issuer (str): The issuer that sent the token. + issuer (str or iterable): Acceptable value(s) for the issuer that + signed the token. """ if issuer is not None: - if claims.get('iss') != issuer: + if isinstance(issuer, string_types): + issuer = (issuer,) + if claims.get('iss') not in issuer: raise JWTClaimsError('Invalid issuer') diff --git a/tests/test_jws.py b/tests/test_jws.py index 2f914c41..a57c7f4d 100644 --- a/tests/test_jws.py +++ b/tests/test_jws.py @@ -201,13 +201,13 @@ def jwk_set(): class TestGetKeys(object): def test_dict(self): - assert [{}] == jws._get_keys({}) + assert ({},) == jws._get_keys({}) def test_custom_object(self): class MyDict(dict): pass mydict = MyDict() - assert [mydict] == jws._get_keys(mydict) + assert (mydict,) == jws._get_keys(mydict) def test_RFC7517_string(self): key = '{"keys": [{}, {}]}' @@ -218,7 +218,7 @@ def test_RFC7517_mapping(self): assert [{}, {}] == jws._get_keys(key) def test_string(self): - assert ['test'] == jws._get_keys('test') + assert ('test',) == jws._get_keys('test') def test_tuple(self): assert ('test', 'key') == jws._get_keys(('test', 'key')) diff --git a/tests/test_jwt.py b/tests/test_jwt.py index 710351c0..485fff52 100644 --- a/tests/test_jwt.py +++ b/tests/test_jwt.py @@ -347,6 +347,28 @@ def test_iss_string(self, key): token = jwt.encode(claims, key) jwt.decode(token, key, issuer=iss) + def test_iss_list(self, key): + + iss = 'issuer' + + claims = { + 'iss': iss + } + + token = jwt.encode(claims, key) + jwt.decode(token, key, issuer=['https://issuer', 'issuer']) + + def test_iss_tuple(self, key): + + iss = 'issuer' + + claims = { + 'iss': iss + } + + token = jwt.encode(claims, key) + jwt.decode(token, key, issuer=('https://issuer', 'issuer')) + def test_iss_invalid(self, key): iss = 'issuer'