Skip to content

Commit 04f596f

Browse files
committed
feat(jwk): add a derive_key method for OKPKey
#82
1 parent e52094b commit 04f596f

File tree

2 files changed

+152
-7
lines changed

2 files changed

+152
-7
lines changed

src/joserfc/_rfc8037/okp_key.py

Lines changed: 123 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
import typing as t
44
from functools import cached_property
5+
from cryptography.hazmat.primitives import hashes
6+
from cryptography.hazmat.backends import default_backend
57
from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PublicKey, Ed25519PrivateKey
68
from cryptography.hazmat.primitives.asymmetric.ed448 import Ed448PublicKey, Ed448PrivateKey
79
from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PublicKey, X25519PrivateKey
@@ -12,6 +14,8 @@
1214
PrivateFormat,
1315
NoEncryption,
1416
)
17+
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
18+
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
1519
from .._rfc7517.models import CurveKey
1620
from .._rfc7517.types import KeyParameters, AnyKey
1721
from .._rfc7517.pem import CryptographyBinding
@@ -44,6 +48,12 @@
4448
"X25519": X25519PrivateKey,
4549
"X448": X448PrivateKey,
4650
}
51+
OKP_SEED_SIZES: dict[LiteralCurves, int] = {
52+
"Ed25519": 32,
53+
"Ed448": 57,
54+
"X25519": 32,
55+
"X448": 56,
56+
}
4757
PrivateKeyTypes = (Ed25519PrivateKey, Ed448PrivateKey, X25519PrivateKey, X448PrivateKey)
4858

4959

@@ -193,14 +203,103 @@ def generate_key(
193203
raw_key = cls.binding.generate_private_key("Ed25519")
194204
else:
195205
raw_key = cls.binding.generate_private_key(crv)
196-
if private:
197-
key = cls(raw_key, raw_key, parameters)
206+
return _wrap_key(cls, raw_key, private, auto_kid, parameters)
207+
208+
@classmethod
209+
def derive_key(
210+
cls: t.Type["OKPKey"],
211+
secret: bytes | str,
212+
crv: LiteralCurves = "Ed25519",
213+
parameters: KeyParameters | None = None,
214+
private: bool = True,
215+
auto_kid: bool = False,
216+
kdf_name: t.Literal["HKDF", "PBKDF2"] = "HKDF",
217+
kdf_options: dict[str, t.Any] | None = None,
218+
) -> "OKPKey":
219+
"""
220+
Derives a key from a given input secret using a specified key derivation function
221+
(KDF) and elliptic curve algorithm.
222+
223+
To derive a key using **HKDF**, the ``kdf_options`` may contain the ``algorithm``,
224+
``salt`` and ``info`` values:
225+
226+
.. code-block:: python
227+
228+
from cryptography.hazmat.primitives import hashes
229+
from joserfc.jwk import OKPKey
230+
231+
# default kdf_name is HKDF, algorithm is SHA256
232+
OKPKey.derive_key("secret")
233+
# equivalent to
234+
OKPKey.derive_key(
235+
"secret", "Ed25519",
236+
kdf_name="HKDF",
237+
kdf_options={
238+
"algorithm": hashes.SHA256(),
239+
"salt": b"joserfc:OKP:HKDF:Ed25519",
240+
"info": b"",
241+
}
242+
)
243+
244+
To derive a key using **PBKDF2**, the ``kdf_options`` may contain the ``algorithm``,
245+
``salt`` and ``iterations`` values:
246+
247+
.. code-block:: python
248+
249+
from cryptography.hazmat.primitives import hashes
250+
from joserfc.jwk import OKPKey
251+
252+
OKPKey.derive_key("secret", kdf_name="PBKDF2")
253+
# equivalent to
254+
OKPKey.derive_key(
255+
"secret", "Ed25519",
256+
kdf_name="PBKDF2",
257+
kdf_options={
258+
"algorithm": hashes.SHA256(),
259+
"salt": b"joserfc:OKP:PBKDF2:Ed25519",
260+
"iterations": 100000,
261+
}
262+
)
263+
264+
:param secret: The input secret used for key derivation.
265+
:param crv: OKPKey curve name
266+
:param parameters: extra parameter in JWK
267+
:param private: generate a private key or public key
268+
:param auto_kid: add ``kid`` automatically
269+
:param kdf_name: Key derivation function name
270+
:param kdf_options: Additional options for the KDF
271+
"""
272+
if kdf_options is None:
273+
kdf_options = {}
274+
275+
algorithm = kdf_options.pop("algorithm", None)
276+
if algorithm is None:
277+
algorithm = hashes.SHA256()
278+
279+
kdf_options.setdefault("salt", to_bytes(f"joserfc:OKP:{kdf_name}:{crv}"))
280+
if kdf_name == "HKDF":
281+
kdf_options.setdefault("info", b"")
282+
hkdf = HKDF(
283+
algorithm=algorithm,
284+
length=OKP_SEED_SIZES[crv],
285+
backend=default_backend(),
286+
**kdf_options,
287+
)
288+
seed = hkdf.derive(to_bytes(secret))
289+
elif kdf_name == "PBKDF2":
290+
kdf_options.setdefault("iterations", 100000)
291+
pbkdf2 = PBKDF2HMAC(
292+
algorithm=algorithm,
293+
length=OKP_SEED_SIZES[crv],
294+
backend=default_backend(),
295+
**kdf_options,
296+
)
297+
seed = pbkdf2.derive(to_bytes(secret))
198298
else:
199-
pub_key = raw_key.public_key()
200-
key = cls(pub_key, pub_key, parameters)
201-
if auto_kid:
202-
key.ensure_kid()
203-
return key
299+
raise ValueError(f"Invalid kdf value: '{kdf_name}'")
300+
301+
raw_key = cls.binding.from_private_bytes(crv, seed)
302+
return _wrap_key(cls, raw_key, private, auto_kid, parameters)
204303

205304

206305
def get_key_curve(key: t.Union[PublicOKPKey, PrivateOKPKey]) -> LiteralCurves:
@@ -213,3 +312,20 @@ def get_key_curve(key: t.Union[PublicOKPKey, PrivateOKPKey]) -> LiteralCurves:
213312
elif isinstance(key, (X448PublicKey, X448PrivateKey)):
214313
return "X448"
215314
raise ValueError("Invalid key") # pragma: no cover
315+
316+
317+
def _wrap_key(
318+
cls: t.Type["OKPKey"],
319+
raw_key: PrivateOKPKey,
320+
private: bool,
321+
auto_kid: bool,
322+
parameters: KeyParameters | None = None,
323+
) -> OKPKey:
324+
if private:
325+
key = cls(raw_key, raw_key, parameters)
326+
else:
327+
pub_key = raw_key.public_key()
328+
key = cls(pub_key, pub_key, parameters)
329+
if auto_kid:
330+
key.ensure_kid()
331+
return key

tests/jwk/test_okp_key.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from unittest import TestCase
2+
from cryptography.hazmat.primitives import hashes
23
from joserfc.jwk import OKPKey
34
from joserfc.errors import (
45
InvalidExchangeKeyError,
@@ -118,3 +119,31 @@ def test_key_eq(self):
118119
self.assertEqual(key1, key2)
119120
key3 = OKPKey.generate_key()
120121
self.assertNotEqual(key1, key3)
122+
123+
def test_derive_key_errors(self):
124+
self.assertRaises(KeyError, OKPKey.derive_key, "secret", "invalid")
125+
self.assertRaises(ValueError, OKPKey.derive_key, "secret", "Ed25519", kdf_name="invalid")
126+
127+
def test_derive_key_with_default_kwargs(self):
128+
curves = ["Ed25519", "Ed448", "X25519", "X448"]
129+
for crv in curves:
130+
key1 = OKPKey.derive_key("okp-secret-key", crv)
131+
key2 = OKPKey.derive_key("okp-secret-key", crv)
132+
self.assertEqual(key1, key2)
133+
134+
for crv in curves:
135+
key1 = OKPKey.derive_key("okp-secret-key", crv, kdf_name="PBKDF2")
136+
key2 = OKPKey.derive_key("okp-secret-key", crv, kdf_name="PBKDF2")
137+
self.assertEqual(key1, key2)
138+
139+
def test_derive_key_with_new_salt(self):
140+
curves = ["Ed25519", "Ed448", "X25519", "X448"]
141+
for crv in curves:
142+
key1 = OKPKey.derive_key("okp-secret-key", crv, kdf_options={"salt": b"salt"})
143+
key2 = OKPKey.derive_key("okp-secret-key", crv, kdf_options={"salt": b"salt"})
144+
self.assertEqual(key1, key2)
145+
146+
def test_derive_key_with_different_hash(self):
147+
key1 = OKPKey.derive_key("okp-secret-key", "Ed25519", kdf_options={"algorithm": hashes.SHA256()})
148+
key2 = OKPKey.derive_key("okp-secret-key", "Ed25519", kdf_options={"algorithm": hashes.SHA512()})
149+
self.assertNotEqual(key1, key2)

0 commit comments

Comments
 (0)