Skip to content

Commit 775d53b

Browse files
committed
feat: filter_algorithms supports KeySet objects
1 parent 70deabb commit 775d53b

File tree

3 files changed

+41
-3
lines changed

3 files changed

+41
-3
lines changed

docs/changelog.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ Changelog
1919

2020
- ``filter_algorithms`` ``names`` defaults to all algorithms. :pr:`79`.
2121
- Replace ``JWSRegistry.guess_alg`` with ``JWSRegistry.guess_algorithm``.
22+
- ``filter_algorithms`` and ``guess_alg`` supports ``KeySet`` objects. :pr:`81`
2223

2324
1.5.0
2425
-----

src/joserfc/_rfc7515/registry.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
check_crit_header,
1818
check_supported_header,
1919
)
20+
from .._keys import KeySet
2021

2122
__all__ = [
2223
"JWSRegistry",
@@ -115,7 +116,7 @@ def validate_signature_size(self, signature: bytes) -> None:
115116
def guess_algorithm(cls, key: Any, strategy: Strategy) -> JWSAlgModel | None:
116117
"""Guess the JWS algorithm for a given key.
117118
118-
:param key: key instance
119+
:param key: key instance or a KeySet
119120
:param strategy: the strategy for guessing the JWS algorithm
120121
"""
121122
if strategy == cls.Strategy.RECOMMENDED:
@@ -145,12 +146,19 @@ def guess_alg(cls, key: Any, strategy: Strategy) -> str | None: # pragma: no co
145146
def filter_algorithms(cls, key: Any, names: list[str] | None = None) -> list[JWSAlgModel]:
146147
"""Filter JWS algorithms based on the given algorithm names.
147148
148-
:param key: key instance
149+
:param key: a key instance or a KeySet
149150
:param names: list of algorithm names
150151
"""
151152
if names is None:
152153
names = list(cls.algorithms.keys())
153154
rv: list[JWSAlgModel] = []
155+
if isinstance(key, KeySet):
156+
for k in key.keys:
157+
for alg in cls.filter_algorithms(k, names):
158+
if alg not in rv:
159+
rv.append(alg)
160+
return rv
161+
154162
for name in names:
155163
alg = cls.algorithms[name]
156164
try:

tests/jws/test_registry.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import unittest
22

33
from joserfc.jws import JWSRegistry
4-
from joserfc.jwk import OctKey, RSAKey, ECKey, OKPKey
4+
from joserfc.jwk import OctKey, RSAKey, ECKey, OKPKey, KeySet
55

66

77
class JWSRegistryTest(unittest.TestCase):
@@ -63,3 +63,32 @@ def test_filter_algorithms_ed448(self):
6363
self.assertIn("EdDSA", names)
6464
self.assertIn("Ed448", names)
6565
self.assertNotIn("Ed25519", names)
66+
67+
def test_filter_algorithms_with_key_set(self):
68+
"""filter_algorithms should support KeySet and combine algorithms from all keys."""
69+
rsa_key1 = RSAKey.generate_key()
70+
rsa_key2 = RSAKey.generate_key()
71+
ec_key = ECKey.generate_key("P-256")
72+
key_set = KeySet([rsa_key1, rsa_key2, ec_key])
73+
74+
algs = JWSRegistry.filter_algorithms(key_set, JWSRegistry.algorithms.keys())
75+
names = [alg.name for alg in algs]
76+
77+
self.assertIn("RS256", names)
78+
self.assertIn("ES256", names)
79+
self.assertNotIn("ES384", names)
80+
self.assertEqual(names.count("RS256"), 1)
81+
82+
def test_guess_algorithm_with_key_set(self):
83+
"""guess_algorithm should find the best algorithm across all keys in the KeySet."""
84+
rsa_key = RSAKey.generate_key()
85+
ec_key = ECKey.generate_key("P-256")
86+
key_set = KeySet([rsa_key, ec_key])
87+
88+
# RS256 comes before ES256 in the recommended list
89+
alg = JWSRegistry.guess_algorithm(key_set, JWSRegistry.Strategy.RECOMMENDED)
90+
self.assertEqual(alg.name, "RS256")
91+
92+
# RS512 has the highest algorithm_security (512) among available algorithms
93+
alg = JWSRegistry.guess_algorithm(key_set, JWSRegistry.Strategy.SECURITY)
94+
self.assertEqual(alg.name, "RS512")

0 commit comments

Comments
 (0)