Skip to content

Commit 6db656f

Browse files
committed
fix(jwk): allow import key from cryptography native key types
1 parent 7b8e373 commit 6db656f

File tree

5 files changed

+34
-14
lines changed

5 files changed

+34
-14
lines changed

src/joserfc/_rfc7518/ec_key.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -147,15 +147,20 @@ def curve_key_size(self) -> int:
147147
@classmethod
148148
def import_key(
149149
cls: t.Any,
150-
value: AnyKey,
150+
value: AnyKey | EllipticCurvePrivateKey | EllipticCurvePublicKey,
151151
parameters: KeyParameters | None = None,
152152
password: t.Any = None,
153153
) -> "ECKey":
154-
return super(ECKey, cls).import_key(value, parameters, password)
154+
key: ECKey
155+
if isinstance(value, (EllipticCurvePrivateKey, EllipticCurvePublicKey)):
156+
key = cls(value, value, parameters)
157+
else:
158+
key = super(ECKey, cls).import_key(value, parameters, password)
159+
return key
155160

156161
@classmethod
157162
def generate_key(
158-
cls,
163+
cls: t.Type["ECKey"],
159164
crv: str | None = "P-256",
160165
parameters: KeyParameters | None = None,
161166
private: bool = True,

src/joserfc/_rfc7518/rsa_key.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22
import warnings
3-
from typing import TypedDict, Any
3+
import typing as t
44
from functools import cached_property
55
from cryptography.hazmat.primitives.asymmetric.rsa import (
66
generate_private_key,
@@ -22,7 +22,7 @@
2222
from ..util import int_to_base64, base64_to_int
2323

2424

25-
RSADictKey = TypedDict(
25+
RSADictKey = t.TypedDict(
2626
"RSADictKey",
2727
{
2828
"n": str,
@@ -136,20 +136,24 @@ def private_key(self) -> RSAPrivateKey | None:
136136

137137
@classmethod
138138
def import_key(
139-
cls: Any,
140-
value: AnyKey,
139+
cls: t.Any,
140+
value: AnyKey | RSAPrivateKey | RSAPublicKey,
141141
parameters: KeyParameters | None = None,
142-
password: Any = None,
142+
password: t.Any = None,
143143
) -> "RSAKey":
144-
key: RSAKey = super(RSAKey, cls).import_key(value, parameters, password)
144+
key: RSAKey
145+
if isinstance(value, (RSAPrivateKey, RSAPublicKey)):
146+
key = cls(value, value, parameters)
147+
else:
148+
key = super(RSAKey, cls).import_key(value, parameters, password)
145149
if key.raw_value.key_size < 2048:
146150
# https://csrc.nist.gov/publications/detail/sp/800-131a/rev-2/final
147151
warnings.warn("Key size should be >= 2048 bits", SecurityWarning)
148152
return key
149153

150154
@classmethod
151155
def generate_key(
152-
cls,
156+
cls: t.Type["RSAKey"],
153157
key_size: int | None = 2048,
154158
parameters: KeyParameters | None = None,
155159
private: bool = True,

src/joserfc/_rfc8037/okp_key.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -150,11 +150,12 @@ def curve_name(self) -> LiteralCurves:
150150

151151
@classmethod
152152
def import_key(
153-
cls: t.Type[OKPKey],
153+
cls: t.Any,
154154
value: AnyKey | PrivateOKPKey | PublicOKPKey,
155155
parameters: KeyParameters | None = None,
156156
password: t.Any = None,
157157
) -> "OKPKey":
158+
key: OKPKey
158159
if isinstance(
159160
value,
160161
(
@@ -168,12 +169,14 @@ def import_key(
168169
X448PublicKey,
169170
),
170171
):
171-
return cls(value, value, parameters)
172-
return super(OKPKey, cls).import_key(value, parameters, password)
172+
key = cls(value, value, parameters)
173+
else:
174+
key = super(OKPKey, cls).import_key(value, parameters, password)
175+
return key
173176

174177
@classmethod
175178
def generate_key(
176-
cls: t.Type[OKPKey],
179+
cls: t.Type["OKPKey"],
177180
crv: LiteralCurves | None = "Ed25519",
178181
parameters: KeyParameters | None = None,
179182
private: bool = True,

tests/jwk/test_ec_key.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,10 @@ def test_import_p512_key(self):
4545
def test_import_secp256k1_key(self):
4646
self.run_import_key("secp256k1")
4747

48+
def test_import_from_native_keys(self):
49+
key = ECKey.generate_key()
50+
self.assertEqual(key, ECKey.import_key(key.private_key))
51+
4852
def test_generate_key(self):
4953
self.assertRaises(ValueError, ECKey.generate_key, "Invalid")
5054

tests/jwk/test_rsa_key.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,10 @@ def test_import_key_from_openssl(self):
9292
key: RSAKey = RSAKey.import_key(private_pem)
9393
self.assertTrue(key.is_private)
9494

95+
def test_import_from_native_keys(self):
96+
key = RSAKey.generate_key()
97+
self.assertEqual(key, RSAKey.import_key(key.private_key))
98+
9599
def test_output_as_methods(self):
96100
private_pem = read_key("rsa-openssl-private.pem")
97101
key: RSAKey = RSAKey.import_key(private_pem)

0 commit comments

Comments
 (0)