11from __future__ import annotations
22import json
3+ from json import JSONEncoder , JSONDecoder
34from .rfc7519 .claims import convert_claims
45from .rfc7519 .claims import Claims as Claims
56from .rfc7519 .claims import check_sensitive_data as check_sensitive_data
@@ -49,18 +50,20 @@ def encode(
4950 claims : Claims ,
5051 key : KeyFlexible ,
5152 algorithms : list [str ] | None = None ,
52- registry : JWSRegistry | JWERegistry | None = None ) -> str :
53+ registry : JWSRegistry | JWERegistry | None = None ,
54+ encoder_cls : JSONEncoder | None = None ) -> str :
5355 """Encode a JSON Web Token with the given header, and claims.
5456
5557 :param header: A dict of the JWT header
5658 :param claims: A dict of the JWT claims to be encoded
5759 :param key: key used to sign the signature
5860 :param algorithms: a list of allowed algorithms
5961 :param registry: a ``JWSRegistry`` or ``JWERegistry`` to use
62+ :param encoder_cls: A JSONEncoder subclass to use
6063 """
6164 # add ``typ`` in header
6265 _header = {"typ" : "JWT" , ** header }
63- payload = convert_claims (claims )
66+ payload = convert_claims (claims , encoder_cls )
6467 if isinstance (registry , JWERegistry ):
6568 return encrypt_compact (_header , payload , key , algorithms , registry )
6669 else :
@@ -71,14 +74,16 @@ def decode(
7174 value : bytes | str ,
7275 key : KeyFlexible ,
7376 algorithms : list [str ] | None = None ,
74- registry : JWSRegistry | JWERegistry | None = None ) -> Token :
77+ registry : JWSRegistry | JWERegistry | None = None ,
78+ decoder_cls : JSONDecoder | None = None ) -> Token :
7579 """Decode the JSON Web Token string with the given key, and validate
7680 it with the claims requests.
7781
7882 :param value: text of the JWT
7983 :param key: key used to verify the signature
8084 :param algorithms: a list of allowed algorithms
8185 :param registry: a ``JWSRegistry`` or ``JWERegistry`` to use
86+ :param decoder_cls: A JSONDecoder subclass to use
8287 :raise: BadSignatureError
8388 """
8489 _value = to_bytes (value )
@@ -90,7 +95,7 @@ def decode(
9095 header , payload = _decode_jws (_value , key , algorithms , registry )
9196
9297 try :
93- claims : Claims = json .loads (payload )
98+ claims : Claims = json .loads (payload , cls = decoder_cls )
9499 except (TypeError , ValueError ):
95100 raise InvalidPayloadError ()
96101
0 commit comments