Skip to content

Commit ab682f8

Browse files
committed
fix: improve type hints for jwk module
1 parent aef4796 commit ab682f8

2 files changed

Lines changed: 68 additions & 9 deletions

File tree

src/joserfc/jwk.py

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,25 @@ def guess_key(key: KeyFlexible, obj: GuestProtocol, use_random: bool = False) ->
8888
raise ValueError("Invalid key")
8989

9090

91-
def import_key(data: AnyKey, key_type: str | None = None, parameters: KeyParameters | None = None) -> Key:
91+
@t.overload
92+
def import_key(data: AnyKey, key_type: t.Literal["oct"], parameters: KeyParameters | None = None) -> OctKey: ...
93+
94+
95+
@t.overload
96+
def import_key(data: AnyKey, key_type: t.Literal["RSA"], parameters: KeyParameters | None = None) -> RSAKey: ...
97+
98+
99+
@t.overload
100+
def import_key(data: AnyKey, key_type: t.Literal["EC"], parameters: KeyParameters | None = None) -> ECKey: ...
101+
102+
103+
@t.overload
104+
def import_key(data: AnyKey, key_type: t.Literal["OKP"], parameters: KeyParameters | None = None) -> OKPKey: ...
105+
106+
107+
def import_key(
108+
data: AnyKey, key_type: t.Literal["oct", "RSA", "EC", "OKP"] | None = None, parameters: KeyParameters | None = None
109+
) -> Key:
92110
"""Importing a key from bytes, string, and dict. When ``value`` is a dict,
93111
this method can tell the key type automatically, otherwise, developers
94112
SHOULD pass the ``key_type`` themselves.
@@ -101,8 +119,48 @@ def import_key(data: AnyKey, key_type: str | None = None, parameters: KeyParamet
101119
return JWKRegistry.import_key(data, key_type, parameters)
102120

103121

122+
@t.overload
123+
def generate_key(
124+
key_type: t.Literal["oct"],
125+
crv_or_size: str | int | None = None,
126+
parameters: KeyParameters | None = None,
127+
private: bool = True,
128+
auto_kid: bool = False,
129+
) -> OctKey: ...
130+
131+
132+
@t.overload
133+
def generate_key(
134+
key_type: t.Literal["RSA"],
135+
crv_or_size: str | int | None = None,
136+
parameters: KeyParameters | None = None,
137+
private: bool = True,
138+
auto_kid: bool = False,
139+
) -> RSAKey: ...
140+
141+
142+
@t.overload
143+
def generate_key(
144+
key_type: t.Literal["EC"],
145+
crv_or_size: str | int | None = None,
146+
parameters: KeyParameters | None = None,
147+
private: bool = True,
148+
auto_kid: bool = False,
149+
) -> ECKey: ...
150+
151+
152+
@t.overload
153+
def generate_key(
154+
key_type: t.Literal["OKP"],
155+
crv_or_size: str | int | None = None,
156+
parameters: KeyParameters | None = None,
157+
private: bool = True,
158+
auto_kid: bool = False,
159+
) -> OKPKey: ...
160+
161+
104162
def generate_key(
105-
key_type: str,
163+
key_type: t.Literal["oct", "RSA", "EC", "OKP"],
106164
crv_or_size: str | int | None = None,
107165
parameters: KeyParameters | None = None,
108166
private: bool = True,

tests/jwt/test_jwt.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
DecodeError,
99
)
1010

11+
1112
def use_embedded_jwk(obj: GuestProtocol) -> Key:
1213
headers = obj.headers()
1314
return import_key(headers["jwk"])
@@ -82,13 +83,13 @@ def test_using_registry(self):
8283

8384
def test_with_embedded_jwk(self):
8485
value = (
85-
'eyJqd2siOnsiY3J2IjoiUC0yNTYiLCJ4IjoiVU05ZzVuS25aWFlvdldBbE'
86-
'03NmNMejl2VG96UmpfX0NIVV9kT2wtZ09vRSIsInkiOiJkczhhZVF3MWwy'
87-
'Y0RDQTdiQ2tPTnZ3REtwWEFidFhqdnFDbGVZSDhXc19VIiwia3R5IjoiRU'
88-
'MifSwiYWxnIjoiRVMyNTYifQ.eyJpc3MiOiJ1cm46ZXhhbXBsZTppc3N1Z'
89-
'XIiLCJhdWQiOiJ1cm46ZXhhbXBsZTphdWRpZW5jZSIsImlhdCI6MTYwNDU'
90-
'4MDc5NH0.60boak3_dErnW47ZPty1C0nrjeVq86EN_eK0GOq6K8w2OA0th'
91-
'KoBxFK4j-NuU9yZ_A9UKGxPT_G87DladBaV9g'
86+
"eyJqd2siOnsiY3J2IjoiUC0yNTYiLCJ4IjoiVU05ZzVuS25aWFlvdldBbE"
87+
"03NmNMejl2VG96UmpfX0NIVV9kT2wtZ09vRSIsInkiOiJkczhhZVF3MWwy"
88+
"Y0RDQTdiQ2tPTnZ3REtwWEFidFhqdnFDbGVZSDhXc19VIiwia3R5IjoiRU"
89+
"MifSwiYWxnIjoiRVMyNTYifQ.eyJpc3MiOiJ1cm46ZXhhbXBsZTppc3N1Z"
90+
"XIiLCJhdWQiOiJ1cm46ZXhhbXBsZTphdWRpZW5jZSIsImlhdCI6MTYwNDU"
91+
"4MDc5NH0.60boak3_dErnW47ZPty1C0nrjeVq86EN_eK0GOq6K8w2OA0th"
92+
"KoBxFK4j-NuU9yZ_A9UKGxPT_G87DladBaV9g"
9293
)
9394
token = jwt.decode(value, use_embedded_jwk)
9495
self.assertEqual(token.claims["iss"], "urn:example:issuer")

0 commit comments

Comments
 (0)