Skip to content

Commit 212ea38

Browse files
committed
fix: guess correct key with "use" and "alg"
ref: authlib/authlib#771
1 parent b48ea52 commit 212ea38

6 files changed

Lines changed: 109 additions & 25 deletions

File tree

src/joserfc/_keys.py

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -140,21 +140,32 @@ def as_dict(self, private: bool | None = None, **params: t.Any) -> KeySetSeriali
140140
keys.append(key.as_dict(private=private, **params))
141141
return {"keys": keys}
142142

143-
def get_by_kid(self, kid: str | None = None) -> Key:
143+
def get_by_kid(self, kid: str | None = None, parameters: KeyParameters | None = None) -> Key:
144144
if kid is None and len(self.keys) == 1:
145145
return self.keys[0]
146146

147-
for key in self.keys:
148-
if key.kid == kid:
149-
return key
150-
raise InvalidKeyIdError(f"No key for kid: '{kid}'")
147+
keys = [key for key in self.keys if key.kid == kid]
148+
if parameters:
149+
keys = list(_filter_keys_by_parameters(keys, parameters))
150+
151+
if len(keys) == 1:
152+
return keys[0]
153+
154+
elif len(keys) == 0:
155+
raise InvalidKeyIdError(f"No key for kid: '{kid}'")
156+
else:
157+
raise InvalidKeyIdError(f"Multiple keys for kid: '{kid}'")
151158

152-
def pick_random_key(self, algorithm: str) -> t.Optional[Key]:
159+
def pick_random_key(self, algorithm: str, parameters: KeyParameters | None = None) -> t.Optional[Key]:
153160
key_types = self.algorithm_keys.get(algorithm)
154161
if key_types:
155162
keys = [k for k in self.keys if k.key_type in key_types]
156163
else:
157164
keys = self.keys
165+
166+
if parameters:
167+
keys = list(_filter_keys_by_parameters(keys, parameters))
168+
158169
if keys:
159170
return random.choice(keys)
160171
return None
@@ -186,3 +197,19 @@ def generate_key_set(
186197
keys.append(key)
187198

188199
return cls(keys)
200+
201+
202+
def _filter_keys_by_parameters(keys: list[Key], parameters: KeyParameters) -> t.Generator[Key]:
203+
_use = parameters.get("use")
204+
_alg = parameters.get("alg")
205+
206+
for key in keys:
207+
designed_use = key.get("use")
208+
if designed_use and _use and designed_use != _use:
209+
continue
210+
211+
designed_alg = key.get("alg")
212+
if designed_alg and _alg and designed_alg != _alg:
213+
continue
214+
215+
yield key

src/joserfc/_rfc7517/models.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,15 @@ def as_dict(self, private: t.Optional[bool] = None, **params: t.Any) -> DictKey:
205205
def check_use(self, use: str) -> None:
206206
"""Check if this key supports the given "use".
207207
208+
Values defined by this specification are:
209+
210+
- "sig" (signature)
211+
- "enc" (encryption)
212+
213+
Other values MAY be used. The "use" value is a case-sensitive
214+
string. Use of the "use" member is OPTIONAL, unless the application
215+
requires its presence.
216+
208217
:param use: this key is used for, e.g. "sig", "enc"
209218
:raise: UnsupportedKeyUseError
210219
"""

src/joserfc/jwe.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def encrypt_compact(
109109

110110
obj = CompactEncryption(protected, to_bytes(plaintext))
111111
recipient: Recipient[Key] = Recipient(obj)
112-
key = guess_key(public_key, recipient, True)
112+
key = guess_key(public_key, recipient, True, use="enc")
113113
key.check_use("enc")
114114
recipient.recipient_key = key
115115
if sender_key:
@@ -155,7 +155,7 @@ def decrypt_compact(
155155

156156
recipient = obj.recipient
157157
assert recipient is not None
158-
key = guess_key(private_key, recipient)
158+
key = guess_key(private_key, recipient, use="enc")
159159
key.check_use("enc")
160160
recipient.recipient_key = key
161161
if sender_key:
@@ -228,7 +228,7 @@ def encrypt_json(
228228
recipient.sender_key = _guess_sender_key(recipient, sender_key, True)
229229
if not recipient.recipient_key:
230230
assert public_key is not None
231-
key = guess_key(public_key, recipient, True)
231+
key = guess_key(public_key, recipient, True, use="enc")
232232
key.check_use("enc")
233233
recipient.recipient_key = key
234234

@@ -276,7 +276,7 @@ def _attach_recipient_keys(
276276
recipients: list[Recipient[Key]], private_key: KeyFlexible, sender_key: ECKey | OKPKey | KeySet | None = None
277277
) -> None:
278278
for recipient in recipients:
279-
key = guess_key(private_key, recipient)
279+
key = guess_key(private_key, recipient, use="enc")
280280
key.check_use("enc")
281281
recipient.recipient_key = key
282282
if sender_key:

src/joserfc/jwk.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,12 +56,18 @@ def set_kid(self, kid: str) -> None: ...
5656
KeyFlexible = t.Union[KeyBase, KeyCallable]
5757

5858

59-
def guess_key(key: KeyFlexible, obj: GuestProtocol, use_random: bool = False) -> Key:
59+
def guess_key(
60+
key: KeyFlexible,
61+
obj: GuestProtocol,
62+
random: bool = False,
63+
use: t.Literal["sig", "enc"] | None = None,
64+
) -> Key:
6065
"""Guess key from a various sources.
6166
6267
:param key: a very flexible key
6368
:param obj: a protocol that has ``headers`` and ``set_kid`` methods
64-
:param use_random: pick a random key from key set
69+
:param random: pick a random key from key set
70+
:param use: optional "use" value
6571
"""
6672
resolved_key: KeyBase
6773
if callable(key):
@@ -74,15 +80,20 @@ def guess_key(key: KeyFlexible, obj: GuestProtocol, use_random: bool = False) ->
7480
elif isinstance(resolved_key, KeySet):
7581
headers = obj.headers()
7682
kid: str | None = headers.get("kid")
77-
if not kid and use_random:
83+
84+
parameters: KeyParameters = {"alg": headers["alg"]}
85+
if use:
86+
parameters["use"] = use
87+
88+
if not kid and random:
7889
# choose one key by random
79-
return_key = resolved_key.pick_random_key(headers["alg"])
90+
return_key = resolved_key.pick_random_key(headers["alg"], parameters)
8091
if return_key is None:
8192
raise ValueError("Invalid key")
8293
return_key.ensure_kid()
8394
obj.set_kid(t.cast(str, return_key.kid))
8495
else:
85-
return_key = resolved_key.get_by_kid(kid)
96+
return_key = resolved_key.get_by_kid(kid, parameters)
8697
return return_key
8798
else:
8899
raise ValueError("Invalid key")

src/joserfc/jws.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -130,10 +130,10 @@ def serialize_compact(
130130
if private_key is None:
131131
raise MissingKeyError()
132132

133-
key = guess_key(private_key, obj, True)
133+
key = guess_key(private_key, obj, True, use="sig")
134134
key.check_use("sig")
135135
alg.check_key_type(key)
136-
key.check_alg(protected["alg"])
136+
key.check_alg(alg.name)
137137

138138
if is_rfc7797:
139139
out = sign_rfc7515_compact(obj, alg, key)
@@ -170,7 +170,7 @@ def validate_compact(
170170
if public_key is None:
171171
raise MissingKeyError()
172172

173-
key: Key = guess_key(public_key, obj)
173+
key: Key = guess_key(public_key, obj, use="sig")
174174
key.check_use("sig")
175175
alg.check_key_type(key)
176176
return verify_compact(obj, alg, key)
@@ -265,8 +265,8 @@ def serialize_json(
265265
if registry is None:
266266
registry = construct_registry(algorithms)
267267

268-
def find_key(obj: Any) -> Key:
269-
return guess_key(private_key, obj, True)
268+
def find_key(obj: HeaderMember) -> Key:
269+
return guess_key(private_key, obj, True, use="sig")
270270

271271
_payload = to_bytes(payload)
272272
if isinstance(members, list):
@@ -315,8 +315,8 @@ def deserialize_json(
315315
if registry is None:
316316
registry = construct_registry(algorithms)
317317

318-
def find_key(obj: Any) -> Key:
319-
return guess_key(public_key, obj)
318+
def find_key(obj: HeaderMember) -> Key:
319+
return guess_key(public_key, obj, use="sig")
320320

321321
if "signatures" in value:
322322
general_obj = extract_general_json(value)

tests/jwk/test_key_methods.py

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from unittest import TestCase
22

3-
# trigger register_key_set
4-
import joserfc.jws # noqa: F401
3+
from joserfc import jws
54
from joserfc.jwk import guess_key, import_key, generate_key, thumbprint_uri
65
from joserfc.jwk import KeySet, OctKey, RSAKey, ECKey, OKPKey
76
from joserfc.errors import (
@@ -16,7 +15,7 @@
1615

1716
class Guest:
1817
def __init__(self):
19-
self._headers = {}
18+
self._headers = {"alg": "HS256"}
2019

2120
def headers(self):
2221
return self._headers
@@ -138,3 +137,41 @@ def test_thumbprint_uri(self):
138137
)
139138
expected = "urn:ietf:params:oauth:jwk-thumbprint:sha-256:w9eYdC6_s_tLQ8lH6PUpc0mddazaqtPgeC2IgWDiqY8"
140139
self.assertEqual(value, expected)
140+
141+
def test_find_correct_key_with_use(self):
142+
key = OctKey.generate_key()
143+
dict_key = key.as_dict()
144+
145+
key1: OctKey = OctKey.import_key(dict_key, {"use": "enc"})
146+
key2: OctKey = OctKey.import_key(dict_key, {"use": "sig"})
147+
self.assertEqual(key1.kid, key2.kid)
148+
149+
key_set = KeySet([key1, key2])
150+
# pick randomly
151+
jws.serialize_compact({"alg": "HS256"}, "foo", key_set)
152+
# get by kid
153+
jws.serialize_compact({"alg": "HS256", "kid": key2.kid}, "foo", key_set)
154+
155+
key_set = KeySet([key1, key2, key2])
156+
self.assertRaises(
157+
InvalidKeyIdError,
158+
jws.serialize_compact,
159+
{"alg": "HS256", "kid": key2.kid},
160+
"foo",
161+
key_set,
162+
)
163+
164+
def test_find_correct_key_with_alg(self):
165+
key = OctKey.generate_key()
166+
dict_key = key.as_dict()
167+
168+
key1: OctKey = OctKey.import_key(dict_key, {"alg": "HS256"})
169+
key2: OctKey = OctKey.import_key(dict_key, {"alg": "dir"})
170+
171+
self.assertEqual(key1.kid, key2.kid)
172+
173+
key_set = KeySet([key1, key2])
174+
# pick randomly
175+
jws.serialize_compact({"alg": "HS256"}, "foo", key_set)
176+
# get by kid
177+
jws.serialize_compact({"alg": "HS256", "kid": key2.kid}, "foo", key_set)

0 commit comments

Comments
 (0)