|
1 | 1 | from __future__ import annotations |
2 | 2 | import warnings |
3 | | -from typing import TypedDict, Any |
| 3 | +import typing as t |
4 | 4 | from functools import cached_property |
5 | 5 | from cryptography.hazmat.primitives.asymmetric.rsa import ( |
6 | 6 | generate_private_key, |
|
22 | 22 | from ..util import int_to_base64, base64_to_int |
23 | 23 |
|
24 | 24 |
|
25 | | -RSADictKey = TypedDict( |
| 25 | +RSADictKey = t.TypedDict( |
26 | 26 | "RSADictKey", |
27 | 27 | { |
28 | 28 | "n": str, |
@@ -136,20 +136,24 @@ def private_key(self) -> RSAPrivateKey | None: |
136 | 136 |
|
137 | 137 | @classmethod |
138 | 138 | def import_key( |
139 | | - cls: Any, |
140 | | - value: AnyKey, |
| 139 | + cls: t.Any, |
| 140 | + value: AnyKey | RSAPrivateKey | RSAPublicKey, |
141 | 141 | parameters: KeyParameters | None = None, |
142 | | - password: Any = None, |
| 142 | + password: t.Any = None, |
143 | 143 | ) -> "RSAKey": |
144 | | - key: RSAKey = super(RSAKey, cls).import_key(value, parameters, password) |
| 144 | + key: RSAKey |
| 145 | + if isinstance(value, (RSAPrivateKey, RSAPublicKey)): |
| 146 | + key = cls(value, value, parameters) |
| 147 | + else: |
| 148 | + key = super(RSAKey, cls).import_key(value, parameters, password) |
145 | 149 | if key.raw_value.key_size < 2048: |
146 | 150 | # https://csrc.nist.gov/publications/detail/sp/800-131a/rev-2/final |
147 | 151 | warnings.warn("Key size should be >= 2048 bits", SecurityWarning) |
148 | 152 | return key |
149 | 153 |
|
150 | 154 | @classmethod |
151 | 155 | def generate_key( |
152 | | - cls, |
| 156 | + cls: t.Type["RSAKey"], |
153 | 157 | key_size: int | None = 2048, |
154 | 158 | parameters: KeyParameters | None = None, |
155 | 159 | private: bool = True, |
|
0 commit comments