1+ from __future__ import annotations
12import os
23import typing as t
34from abc import ABCMeta , abstractmethod
45from ..registry import Header , HeaderRegistryDict
56from ..errors import InvalidKeyTypeError , InvalidKeyLengthError
6- from .._keys import Key , ECKey
7+ from .._keys import Key , ECKey , OctKey
78
89KeyType = t .TypeVar ("KeyType" )
910
@@ -12,8 +13,8 @@ class Recipient(t.Generic[KeyType]):
1213 def __init__ (
1314 self ,
1415 parent : t .Union ["CompactEncryption" , "GeneralJSONEncryption" , "FlattenedJSONEncryption" ],
15- header : t . Optional [ Header ] = None ,
16- recipient_key : t . Optional [ KeyType ] = None ):
16+ header : Header | None = None ,
17+ recipient_key : KeyType | None = None ):
1718 self .__parent = parent
1819 self .header = header
1920 self .recipient_key = recipient_key
@@ -30,35 +31,35 @@ def headers(self) -> Header:
3031 rv .update (self .header )
3132 return rv
3233
33- def add_header (self , k : str , v : t .Any ):
34+ def add_header (self , k : str , v : t .Any ) -> None :
3435 if isinstance (self .__parent , CompactEncryption ):
3536 self .__parent .protected .update ({k : v })
3637 elif self .header :
3738 self .header .update ({k : v })
3839 else :
3940 self .header = {k : v }
4041
41- def set_kid (self , kid : str ):
42+ def set_kid (self , kid : str ) -> None :
4243 self .add_header ("kid" , kid )
4344
4445
4546class CompactEncryption :
4647 """An object to represent the JWE Compact Serialization. It is usually returned by
4748 ``decrypt_compact`` method.
4849 """
49- def __init__ (self , protected : Header , plaintext : t . Optional [ bytes ] = None ):
50+ def __init__ (self , protected : Header , plaintext : bytes | None = None ):
5051 #: protected header in dict
5152 self .protected = protected
5253 #: the plaintext in bytes
5354 self .plaintext = plaintext
54- self .recipient : t . Optional [ Recipient ] = None
55+ self .recipient : Recipient [ t . Any ] | None = None
5556 self .bytes_segments : t .Dict [str , bytes ] = {} # store the decoded segments
5657 self .base64_segments : t .Dict [str , bytes ] = {} # store the encoded segments
5758
5859 def headers (self ) -> Header :
5960 return self .protected
6061
61- def attach_recipient (self , key : Key , header : t . Optional [ Header ] = None ):
62+ def attach_recipient (self , key : Key , header : Header | None = None ) -> None :
6263 """Add a recipient to the JWE Compact Serialization. Please add a key that
6364 comply with the given "alg" value.
6465
@@ -71,7 +72,7 @@ def attach_recipient(self, key: Key, header: t.Optional[Header] = None):
7172 self .recipient = recipient
7273
7374 @property
74- def recipients (self ) -> t . List [Recipient ]:
75+ def recipients (self ) -> list [Recipient [ t . Any ] ]:
7576 if self .recipient is not None :
7677 return [self .recipient ]
7778 return []
@@ -89,14 +90,14 @@ class BaseJSONEncryption(metaclass=ABCMeta):
8990 #: an optional additional authenticated data
9091 aad : t .Optional [bytes ]
9192 #: a list of recipients
92- recipients : t .List [Recipient ]
93+ recipients : t .List [Recipient [ t . Any ] ]
9394
9495 def __init__ (
9596 self ,
9697 protected : Header ,
97- plaintext : t . Optional [ bytes ] = None ,
98- unprotected : t . Optional [ Header ] = None ,
99- aad : t . Optional [ bytes ] = None ):
98+ plaintext : bytes | None = None ,
99+ unprotected : Header | None = None ,
100+ aad : bytes | None = None ):
100101 self .protected = protected
101102 self .plaintext = plaintext
102103 self .unprotected = unprotected
@@ -106,7 +107,7 @@ def __init__(
106107 self .base64_segments : t .Dict [str , bytes ] = {} # store the encoded segments
107108
108109 @abstractmethod
109- def add_recipient (self , header : t . Optional [ Header ] = None , key : t . Optional [ Key ] = None ):
110+ def add_recipient (self , header : Header | None = None , key : Key | None = None ) -> None :
110111 """Add a recipient to the JWE JSON Serialization. Please add a key that
111112 comply with the "alg" to this recipient.
112113
@@ -131,7 +132,7 @@ class GeneralJSONEncryption(BaseJSONEncryption):
131132 """
132133 flattened = False
133134
134- def add_recipient (self , header : t . Optional [ Header ] = None , key : t . Optional [ Key ] = None ):
135+ def add_recipient (self , header : Header | None = None , key : Key | None = None ) -> None :
135136 recipient = Recipient (self , header , key )
136137 self .recipients .append (recipient )
137138
@@ -152,7 +153,7 @@ class FlattenedJSONEncryption(BaseJSONEncryption):
152153 """
153154 flattened = True
154155
155- def add_recipient (self , header : t . Optional [ Header ] = None , key : t . Optional [ Key ] = None ):
156+ def add_recipient (self , header : Header | None = None , key : Key | None = None ) -> None :
156157 self .recipients = [Recipient (self , header , key )]
157158
158159
@@ -178,7 +179,7 @@ def check_iv(self, iv: bytes) -> bytes:
178179 return iv
179180
180181 @abstractmethod
181- def encrypt (self , plaintext : bytes , cek : bytes , iv : bytes , aad : bytes ) -> t . Tuple [bytes , bytes ]:
182+ def encrypt (self , plaintext : bytes , cek : bytes , iv : bytes , aad : bytes ) -> tuple [bytes , bytes ]:
182183 pass
183184
184185 @abstractmethod
@@ -216,19 +217,19 @@ class KeyManagement:
216217 def direct_mode (self ) -> bool :
217218 return self .key_size is None
218219
219- def check_key_type (self , key : Key ):
220+ def check_key_type (self , key : Key ) -> None :
220221 if key .key_type not in self .key_types :
221222 raise InvalidKeyTypeError ()
222223
223- def prepare_recipient_header (self , recipient : Recipient ) :
224+ def prepare_recipient_header (self , recipient : Recipient [ t . Any ]) -> None :
224225 raise NotImplementedError ()
225226
226227
227228class JWEDirectEncryption (KeyManagement , metaclass = ABCMeta ):
228229 key_types = ["oct" ]
229230
230231 @abstractmethod
231- def compute_cek (self , size : int , recipient : Recipient ) -> bytes :
232+ def compute_cek (self , size : int , recipient : Recipient [ OctKey ] ) -> bytes :
232233 pass
233234
234235
@@ -238,11 +239,11 @@ def direct_mode(self) -> bool:
238239 return False
239240
240241 @abstractmethod
241- def encrypt_cek (self , cek : bytes , recipient : Recipient ) -> bytes :
242+ def encrypt_cek (self , cek : bytes , recipient : Recipient [ t . Any ] ) -> bytes :
242243 pass
243244
244245 @abstractmethod
245- def decrypt_cek (self , recipient : Recipient ) -> bytes :
246+ def decrypt_cek (self , recipient : Recipient [ t . Any ] ) -> bytes :
246247 pass
247248
248249
@@ -254,7 +255,7 @@ class JWEKeyWrapping(KeyManagement, metaclass=ABCMeta):
254255 def direct_mode (self ) -> bool :
255256 return False
256257
257- def check_op_key (self , op_key : bytes ):
258+ def check_op_key (self , op_key : bytes ) -> None :
258259 if len (op_key ) * 8 != self .key_size :
259260 raise InvalidKeyLengthError (f"A key of size { self .key_size } bits MUST be used" )
260261
@@ -267,11 +268,11 @@ def unwrap_cek(self, ek: bytes, key: bytes) -> bytes:
267268 pass
268269
269270 @abstractmethod
270- def encrypt_cek (self , cek : bytes , recipient : Recipient ) -> bytes :
271+ def encrypt_cek (self , cek : bytes , recipient : Recipient [ OctKey ] ) -> bytes :
271272 pass
272273
273274 @abstractmethod
274- def decrypt_cek (self , recipient : Recipient ) -> bytes :
275+ def decrypt_cek (self , recipient : Recipient [ OctKey ] ) -> bytes :
275276 pass
276277
277278
@@ -280,7 +281,7 @@ class JWEKeyAgreement(KeyManagement, metaclass=ABCMeta):
280281 tag_aware : bool = False
281282 key_wrapping : t .Optional [JWEKeyWrapping ]
282283
283- def prepare_ephemeral_key (self , recipient : Recipient [ECKey ]):
284+ def prepare_ephemeral_key (self , recipient : Recipient [ECKey ]) -> None :
284285 recipient_key = recipient .recipient_key
285286 assert recipient_key is not None
286287 self .check_key_type (recipient_key )
0 commit comments