Skip to content

Commit ce9a47b

Browse files
committed
fix: allow jwt.encode passing encoder_cls and jwt.decode passing decoder_cls
1 parent 4bc5bf3 commit ce9a47b

6 files changed

Lines changed: 45 additions & 20 deletions

File tree

src/joserfc/jwt.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22
import json
3+
from json import JSONEncoder, JSONDecoder
34
from .rfc7519.claims import convert_claims
45
from .rfc7519.claims import Claims as Claims
56
from .rfc7519.claims import check_sensitive_data as check_sensitive_data
@@ -49,18 +50,20 @@ def encode(
4950
claims: Claims,
5051
key: KeyFlexible,
5152
algorithms: list[str] | None = None,
52-
registry: JWSRegistry | JWERegistry | None = None) -> str:
53+
registry: JWSRegistry | JWERegistry | None = None,
54+
encoder_cls: JSONEncoder | None = None) -> str:
5355
"""Encode a JSON Web Token with the given header, and claims.
5456
5557
:param header: A dict of the JWT header
5658
:param claims: A dict of the JWT claims to be encoded
5759
:param key: key used to sign the signature
5860
:param algorithms: a list of allowed algorithms
5961
:param registry: a ``JWSRegistry`` or ``JWERegistry`` to use
62+
:param encoder_cls: A JSONEncoder subclass to use
6063
"""
6164
# add ``typ`` in header
6265
_header = {"typ": "JWT", **header}
63-
payload = convert_claims(claims)
66+
payload = convert_claims(claims, encoder_cls)
6467
if isinstance(registry, JWERegistry):
6568
return encrypt_compact(_header, payload, key, algorithms, registry)
6669
else:
@@ -71,14 +74,16 @@ def decode(
7174
value: bytes | str,
7275
key: KeyFlexible,
7376
algorithms: list[str] | None = None,
74-
registry: JWSRegistry | JWERegistry | None = None) -> Token:
77+
registry: JWSRegistry | JWERegistry | None = None,
78+
decoder_cls: JSONDecoder | None = None) -> Token:
7579
"""Decode the JSON Web Token string with the given key, and validate
7680
it with the claims requests.
7781
7882
:param value: text of the JWT
7983
:param key: key used to verify the signature
8084
:param algorithms: a list of allowed algorithms
8185
:param registry: a ``JWSRegistry`` or ``JWERegistry`` to use
86+
:param decoder_cls: A JSONDecoder subclass to use
8287
:raise: BadSignatureError
8388
"""
8489
_value = to_bytes(value)
@@ -90,7 +95,7 @@ def decode(
9095
header, payload = _decode_jws(_value, key, algorithms, registry)
9196

9297
try:
93-
claims: Claims = json.loads(payload)
98+
claims: Claims = json.loads(payload, cls=decoder_cls)
9499
except (TypeError, ValueError):
95100
raise InvalidPayloadError()
96101

src/joserfc/rfc7516/message.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def perform_encrypt(obj: EncryptionData, registry: JWERegistry) -> None:
5656
plaintext = obj.plaintext
5757

5858
# Step 13, Compute the Encoded Protected Header value BASE64URL(UTF8(JWE Protected Header)).
59-
aad = json_b64encode(obj.protected, "ascii")
59+
aad = json_b64encode(obj.protected)
6060

6161
# Step 14, Let the Additional Authenticated Data encryption parameter be
6262
# ASCII(Encoded Protected Header). However, if a JWE AAD value is
@@ -117,7 +117,7 @@ def _perform_decrypt(obj: EncryptionData, registry: JWERegistry) -> None:
117117
if len(cek) * 8 != enc.cek_size: # pragma: no cover
118118
raise InvalidCEKLengthError(f"A key of size {enc.cek_size} bits MUST be used")
119119

120-
aad = json_b64encode(obj.protected, "ascii")
120+
aad = json_b64encode(obj.protected)
121121
if isinstance(obj, BaseJSONEncryption) and obj.aad:
122122
aad = aad + b"." + urlsafe_b64encode(obj.aad)
123123

src/joserfc/rfc7519/claims.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
1+
from __future__ import annotations
2+
13
import re
4+
import json
25
import datetime
36
import calendar
7+
from json import JSONEncoder
48
from typing import Dict, Any
5-
from ..util import to_bytes, json_dumps
9+
from ..util import to_bytes
610
from ..errors import InsecureClaimError
711

812

@@ -22,13 +26,15 @@
2226
Claims = Dict[str, Any]
2327

2428

25-
def convert_claims(claims: Claims) -> bytes:
29+
def convert_claims(claims: Claims, encoder_cls: JSONEncoder | None = None) -> bytes:
2630
"""Turn claims into bytes payload."""
2731
for k in ["exp", "iat", "nbf"]:
2832
claim = claims.get(k)
2933
if isinstance(claim, datetime.datetime):
3034
claims[k] = calendar.timegm(claim.utctimetuple())
31-
return to_bytes(json_dumps(claims))
35+
36+
content = json.dumps(claims, ensure_ascii=False, separators=(",", ":"), cls=encoder_cls)
37+
return to_bytes(content)
3238

3339

3440
def check_sensitive_data(claims: Claims) -> None:

src/joserfc/rfc7638/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import typing as t
2+
import json
23
import hashlib
34
from collections import OrderedDict
4-
from ..util import to_bytes, json_dumps, urlsafe_b64encode
5+
from ..util import to_bytes, urlsafe_b64encode
56

67

78
def thumbprint(
@@ -15,7 +16,7 @@ def thumbprint(
1516
for k in sorted_fields:
1617
data[k] = dict_value[k]
1718

18-
json_data = json_dumps(data)
19+
json_data = json.dumps(data, ensure_ascii=True, separators=(",", ":"))
1920
hash_value = hashlib.new(digest_method, to_bytes(json_data))
2021
digest_data = hash_value.digest()
2122
return urlsafe_b64encode(digest_data).decode("utf-8")

src/joserfc/util.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,6 @@ def to_str(x: bytes | str, charset: str = "utf-8") -> str:
2222
return x
2323

2424

25-
def json_dumps(data: Any, ensure_ascii: bool = False) -> str:
26-
return json.dumps(data, ensure_ascii=ensure_ascii, separators=(",", ":"))
27-
28-
2925
def urlsafe_b64decode(s: bytes) -> bytes:
3026
if b"+" in s or b"/" in s:
3127
raise binascii.Error
@@ -51,11 +47,11 @@ def int_to_base64(num: int) -> str:
5147
return urlsafe_b64encode(s).decode("utf-8", "strict")
5248

5349

54-
def json_b64encode(text: Any, charset: str = "utf-8") -> bytes:
50+
def json_b64encode(text: Any) -> bytes:
5551
if isinstance(text, dict):
56-
text = json_dumps(text)
57-
return urlsafe_b64encode(to_bytes(text, charset))
52+
text = json.dumps(text, ensure_ascii=True, separators=(",", ":"))
53+
return urlsafe_b64encode(to_bytes(text, "ascii"))
5854

5955

60-
def json_b64decode(text: Any, charset: str = "utf-8") -> Any:
61-
return json.loads(urlsafe_b64decode(to_bytes(text, charset)))
56+
def json_b64decode(text: Any) -> Any:
57+
return json.loads(urlsafe_b64decode(to_bytes(text, "ascii")))

tests/jwt/test_claims.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import time
22
import datetime
3+
import json
4+
import uuid
35
from unittest import TestCase
46
from joserfc import jwt
57
from joserfc.jwk import OctKey
@@ -12,6 +14,13 @@
1214
)
1315

1416

17+
class UUIDEncoder(json.JSONEncoder):
18+
def default(self, o):
19+
if isinstance(o, uuid.UUID):
20+
return str(o)
21+
return super().default(o)
22+
23+
1524
class TestJWTClaims(TestCase):
1625
def test_check_sensitive_data(self):
1726
jwt.check_sensitive_data({})
@@ -135,3 +144,11 @@ def test_validate_nbf(self):
135144
claims_requests = jwt.JWTClaimsRegistry(now=now, leeway=500)
136145
claims_requests.validate({"nbf": now})
137146
self.assertRaises(InvalidTokenError, claims_requests.validate, {"nbf": now + 1000})
147+
148+
def test_claims_with_uuid_field(self):
149+
value = uuid.uuid4()
150+
claims = {"uuid": value}
151+
key = OctKey.import_key("secret")
152+
encoded_text = jwt.encode({"alg": "HS256"}, claims, key)
153+
decoded_data = jwt.decode(encoded_text, key)
154+
self.assertEqual(decoded_data, {"uuid": str(value)})

0 commit comments

Comments
 (0)