diff --git a/jose/jwt.py b/jose/jwt.py index bffed256..9a0396dc 100644 --- a/jose/jwt.py +++ b/jose/jwt.py @@ -12,7 +12,7 @@ from .utils import timedelta_total_seconds -def encode(claims, key, algorithm=None): +def encode(claims, key, algorithm=None, headers=None): """Encodes a claims set and returns a JWT string. JWTs are JWS signed objects with a few reserved claims. @@ -20,11 +20,11 @@ def encode(claims, key, algorithm=None): Args: claims (dict): A claims set to sign key (str): The key to use for signing the claim set + algorithm (str, optional): The algorithm to use for signing the + the claims. Defaults to HS256. headers (dict, optional): A set of headers that will be added to the default headers. Any headers that are added as additional headers will override the default headers. - algorithm (str, optional): The algorithm to use for signing the - the claims. Defaults to HS256. Returns: str: The string representation of the header, claims, and signature. @@ -46,9 +46,9 @@ def encode(claims, key, algorithm=None): claims[time_claim] = timegm(claims[time_claim].utctimetuple()) if algorithm: - return jws.sign(claims, key, algorithm=algorithm) + return jws.sign(claims, key, headers=headers, algorithm=algorithm) - return jws.sign(claims, key) + return jws.sign(claims, key, headers=headers) def decode(token, key, algorithms=None, options=None, audience=None, issuer=None): diff --git a/tests/test_jwt.py b/tests/test_jwt.py index acc64210..05926b58 100644 --- a/tests/test_jwt.py +++ b/tests/test_jwt.py @@ -20,6 +20,12 @@ def claims(): def key(): return 'secret' +@pytest.fixture +def headers(): + headers = { + 'kid': 'my-key-id', + } + return headers class TestJWT: @@ -28,6 +34,19 @@ def test_non_default_alg(self, claims, key): decoded = jwt.decode(encoded, key, algorithms='HS384') assert claims == decoded + def test_non_default_alg_positional_bwcompat(self, claims, key): + encoded = jwt.encode(claims, key, 'HS384') + decoded = jwt.decode(encoded, key, 'HS384') + assert claims == decoded + + def test_non_default_headers(self, claims, key, headers): + encoded = jwt.encode(claims, key, headers=headers) + decoded = jwt.decode(encoded, key) + assert claims == decoded + all_headers = jwt.get_unverified_headers(encoded) + for k, v in headers.items(): + assert all_headers[k] == v + def test_encode(self, claims, key): expected = (