|
1 | 1 | from unittest import TestCase |
2 | 2 |
|
3 | | -# trigger register_key_set |
4 | | -import joserfc.jws # noqa: F401 |
| 3 | +from joserfc import jws |
5 | 4 | from joserfc.jwk import guess_key, import_key, generate_key, thumbprint_uri |
6 | 5 | from joserfc.jwk import KeySet, OctKey, RSAKey, ECKey, OKPKey |
7 | 6 | from joserfc.errors import ( |
|
16 | 15 |
|
17 | 16 | class Guest: |
18 | 17 | def __init__(self): |
19 | | - self._headers = {} |
| 18 | + self._headers = {"alg": "HS256"} |
20 | 19 |
|
21 | 20 | def headers(self): |
22 | 21 | return self._headers |
@@ -138,3 +137,41 @@ def test_thumbprint_uri(self): |
138 | 137 | ) |
139 | 138 | expected = "urn:ietf:params:oauth:jwk-thumbprint:sha-256:w9eYdC6_s_tLQ8lH6PUpc0mddazaqtPgeC2IgWDiqY8" |
140 | 139 | self.assertEqual(value, expected) |
| 140 | + |
| 141 | + def test_find_correct_key_with_use(self): |
| 142 | + key = OctKey.generate_key() |
| 143 | + dict_key = key.as_dict() |
| 144 | + |
| 145 | + key1: OctKey = OctKey.import_key(dict_key, {"use": "enc"}) |
| 146 | + key2: OctKey = OctKey.import_key(dict_key, {"use": "sig"}) |
| 147 | + self.assertEqual(key1.kid, key2.kid) |
| 148 | + |
| 149 | + key_set = KeySet([key1, key2]) |
| 150 | + # pick randomly |
| 151 | + jws.serialize_compact({"alg": "HS256"}, "foo", key_set) |
| 152 | + # get by kid |
| 153 | + jws.serialize_compact({"alg": "HS256", "kid": key2.kid}, "foo", key_set) |
| 154 | + |
| 155 | + key_set = KeySet([key1, key2, key2]) |
| 156 | + self.assertRaises( |
| 157 | + InvalidKeyIdError, |
| 158 | + jws.serialize_compact, |
| 159 | + {"alg": "HS256", "kid": key2.kid}, |
| 160 | + "foo", |
| 161 | + key_set, |
| 162 | + ) |
| 163 | + |
| 164 | + def test_find_correct_key_with_alg(self): |
| 165 | + key = OctKey.generate_key() |
| 166 | + dict_key = key.as_dict() |
| 167 | + |
| 168 | + key1: OctKey = OctKey.import_key(dict_key, {"alg": "HS256"}) |
| 169 | + key2: OctKey = OctKey.import_key(dict_key, {"alg": "dir"}) |
| 170 | + |
| 171 | + self.assertEqual(key1.kid, key2.kid) |
| 172 | + |
| 173 | + key_set = KeySet([key1, key2]) |
| 174 | + # pick randomly |
| 175 | + jws.serialize_compact({"alg": "HS256"}, "foo", key_set) |
| 176 | + # get by kid |
| 177 | + jws.serialize_compact({"alg": "HS256", "kid": key2.kid}, "foo", key_set) |
0 commit comments