Skip to content

Commit b8434cb

Browse files
committed
fix(jwa): improve RFC9864 check key logic
1 parent 59918a1 commit b8434cb

File tree

3 files changed

+30
-16
lines changed

3 files changed

+30
-16
lines changed

src/joserfc/_rfc9864/jws_eddsa.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,34 +2,30 @@
22
from cryptography.exceptions import InvalidSignature
33
from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PublicKey, Ed25519PrivateKey
44
from cryptography.hazmat.primitives.asymmetric.ed448 import Ed448PublicKey, Ed448PrivateKey
5-
from ..errors import InvalidKeyTypeError
5+
from ..errors import InvalidKeyCurveError
66
from .._rfc7515.model import JWSAlgModel
77
from .._rfc8037.okp_key import OKPKey
88

99

10-
_private_key_mapping = {"Ed25519": Ed25519PrivateKey, "Ed448": Ed448PrivateKey}
11-
_public_key_mapping = {"Ed25519": Ed25519PublicKey, "Ed448": Ed448PublicKey}
12-
13-
1410
class EdDSAAlgorithm(JWSAlgModel):
1511
key_type = "OKP"
1612

1713
def __init__(self, curve: t.Literal["Ed25519", "Ed448"]):
1814
self.name = curve
15+
self.curve = curve
1916
self.description = f"EdDSA using the {curve} parameter set"
2017

18+
def check_key(self, key: OKPKey) -> None:
19+
super().check_key(key)
20+
if key.curve_name != self.curve:
21+
raise InvalidKeyCurveError(f"Key for '{self.name}' not supported, only '{self.curve}' allowed")
22+
2123
def sign(self, msg: bytes, key: OKPKey) -> bytes:
2224
op_key = t.cast(t.Union[Ed25519PrivateKey, Ed448PrivateKey], key.get_op_key("sign"))
23-
private_key_cls = _private_key_mapping[self.name]
24-
if not isinstance(op_key, private_key_cls):
25-
raise InvalidKeyTypeError(f"Algorithm '{self.name}' requires '{self.name}' OKP key")
2625
return op_key.sign(msg)
2726

2827
def verify(self, msg: bytes, sig: bytes, key: OKPKey) -> bool:
2928
op_key = t.cast(t.Union[Ed25519PublicKey, Ed448PublicKey], key.get_op_key("verify"))
30-
public_key_cls = _public_key_mapping[self.name]
31-
if not isinstance(op_key, public_key_cls):
32-
raise InvalidKeyTypeError(f"Algorithm '{self.name}' requires '{self.name}' OKP key")
3329
try:
3430
op_key.verify(sig, msg)
3531
return True

tests/jws/test_eddsa.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from unittest import TestCase
33
from joserfc import jwt
44
from joserfc.jwk import OKPKey
5-
from joserfc.errors import InvalidKeyTypeError, BadSignatureError
5+
from joserfc.errors import InvalidKeyTypeError, InvalidKeyCurveError, BadSignatureError
66
from tests.base import load_key
77

88

@@ -23,9 +23,9 @@ def test_Ed25519(self):
2323
encoded_jwt = jwt.encode({"alg": "Ed25519"}, {}, self.ed25519_key, algorithms=algorithms)
2424
jwt.decode(encoded_jwt, self.ed25519_key, algorithms=algorithms)
2525
self.assertRaises(
26-
InvalidKeyTypeError, jwt.encode, {"alg": "Ed25519"}, {}, self.ed448_key, algorithms=algorithms
26+
InvalidKeyCurveError, jwt.encode, {"alg": "Ed25519"}, {}, self.ed448_key, algorithms=algorithms
2727
)
28-
self.assertRaises(InvalidKeyTypeError, jwt.decode, encoded_jwt, self.ed448_key, algorithms=algorithms)
28+
self.assertRaises(InvalidKeyCurveError, jwt.decode, encoded_jwt, self.ed448_key, algorithms=algorithms)
2929
wrong_key = OKPKey.generate_key("Ed25519", private=False)
3030
self.assertRaises(BadSignatureError, jwt.decode, encoded_jwt, wrong_key, algorithms=algorithms)
3131

@@ -34,8 +34,8 @@ def test_Ed448(self):
3434
encoded_jwt = jwt.encode({"alg": "Ed448"}, {}, self.ed448_key, algorithms=algorithms)
3535
jwt.decode(encoded_jwt, self.ed448_key, algorithms=algorithms)
3636
self.assertRaises(
37-
InvalidKeyTypeError, jwt.encode, {"alg": "Ed448"}, {}, self.ed25519_key, algorithms=algorithms
37+
InvalidKeyCurveError, jwt.encode, {"alg": "Ed448"}, {}, self.ed25519_key, algorithms=algorithms
3838
)
39-
self.assertRaises(InvalidKeyTypeError, jwt.decode, encoded_jwt, self.ed25519_key, algorithms=algorithms)
39+
self.assertRaises(InvalidKeyCurveError, jwt.decode, encoded_jwt, self.ed25519_key, algorithms=algorithms)
4040
wrong_key = OKPKey.generate_key("Ed448", private=False)
4141
self.assertRaises(BadSignatureError, jwt.decode, encoded_jwt, wrong_key, algorithms=algorithms)

tests/jws/test_registry.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,3 +45,21 @@ def test_filter_algorithms_default_names(self):
4545
explicit = JWSRegistry.filter_algorithms(self.rsa_key, all_names)
4646
default = JWSRegistry.filter_algorithms(self.rsa_key)
4747
self.assertEqual(explicit, default)
48+
49+
def test_filter_algorithms_ed25519(self):
50+
"""Ed25519 keys should only be compatible with EdDSA and Ed25519, not Ed448."""
51+
ed25519_key = OKPKey.generate_key("Ed25519")
52+
algs = JWSRegistry.filter_algorithms(ed25519_key)
53+
names = [alg.name for alg in algs]
54+
self.assertIn("EdDSA", names)
55+
self.assertIn("Ed25519", names)
56+
self.assertNotIn("Ed448", names)
57+
58+
def test_filter_algorithms_ed448(self):
59+
"""Ed448 keys should only be compatible with EdDSA and Ed448, not Ed25519."""
60+
ed448_key = OKPKey.generate_key("Ed448")
61+
algs = JWSRegistry.filter_algorithms(ed448_key)
62+
names = [alg.name for alg in algs]
63+
self.assertIn("EdDSA", names)
64+
self.assertIn("Ed448", names)
65+
self.assertNotIn("Ed25519", names)

0 commit comments

Comments
 (0)