Skip to content

Commit 2ebe3cc

Browse files
committed
feat: add __eq__ for Key and KeySet
1 parent 94a651d commit 2ebe3cc

7 files changed

Lines changed: 61 additions & 1 deletion

File tree

src/joserfc/_keys.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,10 @@ def __iter__(self) -> t.Iterator[Key]:
120120
def __bool__(self) -> bool:
121121
return bool(self.keys)
122122

123+
def __eq__(self, other: t.Any) -> bool:
124+
assert isinstance(other, KeySet)
125+
return self.keys == other.keys
126+
123127
def as_dict(self, private: bool | None = None, **params: t.Any) -> KeySetSerialization:
124128
keys: list[DictKey] = []
125129

src/joserfc/rfc7517/models.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,11 @@ def __init__(
9999
self.validate_dict_key(data)
100100
self._dict_value = data
101101

102+
def __eq__(self, other: t.Any) -> bool:
103+
if not isinstance(other, self.__class__):
104+
return False
105+
return self.dict_value == other.dict_value
106+
102107
def keys(self) -> KeysView[str]:
103108
return self.dict_value.keys()
104109

tests/jwk/test_ec_key.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from unittest import TestCase
2-
from joserfc.jwk import ECKey
2+
from joserfc.jwk import ECKey, OctKey
33
from joserfc.errors import InvalidExchangeKeyError
44
from tests.keys import read_key
55

@@ -78,3 +78,15 @@ def test_output_with_password(self):
7878
)
7979
key2 = ECKey.import_key(pem, password="secret")
8080
self.assertEqual(key.as_dict(), key2.as_dict())
81+
82+
def test_key_eq(self):
83+
key1 = ECKey.generate_key()
84+
key2 = ECKey.import_key(key1.as_dict())
85+
self.assertEqual(key1, key2)
86+
key3 = ECKey.generate_key()
87+
self.assertNotEqual(key1, key3)
88+
89+
def test_key_eq_with_different_types(self):
90+
key1 = ECKey.generate_key()
91+
key2 = OctKey.generate_key()
92+
self.assertNotEqual(key1, key2)

tests/jwk/test_jwk_set.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,3 +76,18 @@ def test_key_set_iter(self):
7676
key_set = KeySet.generate_key_set('RSA', 2048)
7777
for k in key_set:
7878
self.assertEqual(k.key_type, "RSA")
79+
80+
def test_key_eq_with_same_keys(self):
81+
key_set1 = KeySet.generate_key_set('RSA', 2048)
82+
key_set2 = KeySet(key_set1.keys)
83+
self.assertIsNot(key_set1, key_set2)
84+
self.assertEqual(key_set1, key_set2)
85+
86+
def test_key_eq_with_new_keys(self):
87+
key_set1 = KeySet.generate_key_set('RSA', 2048)
88+
key_set2 = KeySet([
89+
RSAKey.import_key(k.as_dict())
90+
for k in key_set1
91+
])
92+
self.assertIsNot(key_set1, key_set2)
93+
self.assertEqual(key_set1, key_set2)

tests/jwk/test_oct_key.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,3 +92,11 @@ def test_generate_key(self):
9292

9393
key = OctKey.generate_key(auto_kid=True)
9494
self.assertIsNotNone(key.kid)
95+
96+
def test_key_eq(self):
97+
key1 = OctKey.generate_key()
98+
key2 = OctKey.import_key(key1.as_dict())
99+
self.assertIsNot(key1, key2)
100+
self.assertEqual(key1, key2)
101+
key3 = OctKey.generate_key()
102+
self.assertNotEqual(key1, key3)

tests/jwk/test_okp_key.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,3 +100,11 @@ def test_output_with_password(self):
100100
)
101101
key2 = OKPKey.import_key(pem, password="secret")
102102
self.assertEqual(key.as_pem(), key2.as_pem())
103+
104+
def test_key_eq(self):
105+
key1 = OKPKey.generate_key()
106+
key2 = OKPKey.import_key(key1.as_dict())
107+
self.assertIsNot(key1, key2)
108+
self.assertEqual(key1, key2)
109+
key3 = OKPKey.generate_key()
110+
self.assertNotEqual(key1, key3)

tests/jwk/test_rsa_key.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,3 +150,11 @@ def test_output_with_password(self):
150150
)
151151
key2 = RSAKey.import_key(pem, password="secret")
152152
self.assertEqual(key.as_dict(), key2.as_dict())
153+
154+
def test_key_eq(self):
155+
key1 = RSAKey.generate_key()
156+
key2 = RSAKey.import_key(key1.as_dict())
157+
self.assertIsNot(key1, key2)
158+
self.assertEqual(key1, key2)
159+
key3 = RSAKey.generate_key()
160+
self.assertNotEqual(key1, key3)

0 commit comments

Comments
 (0)