diff --git a/src/joserfc/util.py b/src/joserfc/util.py index 4221031e..0394e0a0 100644 --- a/src/joserfc/util.py +++ b/src/joserfc/util.py @@ -4,7 +4,7 @@ import struct import binascii import json - +from .errors import DecodeError def to_bytes(x: Any, charset: str = "utf-8", errors: str = "strict") -> bytes: if isinstance(x, bytes): @@ -21,10 +21,23 @@ def to_str(x: bytes | str, charset: str = "utf-8") -> str: return x.decode(charset) return x +def __is_urlsafe_b64_encoding_non_canonical(s: bytes) -> bool: + # https://github.com/FrancoisCapon/Base64SteganographyTools/blob/main/tools/b64_print_regular_characters.sh + p = len(s) % 4 # padding? + if p == 0: + return False + p = 4 - p # number of padding characters + if p == 2 and s[-1] in b"AQgw": + return False + if p == 1 and s[-1] in b"AEIMQUYcgkosw048": + return False + return True def urlsafe_b64decode(s: bytes) -> bytes: if b"+" in s or b"/" in s: raise binascii.Error + if __is_urlsafe_b64_encoding_non_canonical(s): + raise DecodeError s += b"=" * (-len(s) % 4) return base64.b64decode(s, b"-_", validate=True) diff --git a/tests/jws/test_compact.py b/tests/jws/test_compact.py index 189fb8ea..7875d9a8 100644 --- a/tests/jws/test_compact.py +++ b/tests/jws/test_compact.py @@ -73,3 +73,12 @@ def test_strict_check_header(self): registry = JWSRegistry(strict_check_header=False) serialize_compact(header, b"hi", key, registry=registry) + + def test_non_canonical_signature_encoding(self): + text = "eyJhbGciOiJIUzI1NiJ9.eyJ1c2VyIjoiYWRtaW4ifQ.VI29GgHzuh2xfF0bkRYvZIsSuQnbTXSIvuRyt7RDrwo"[:-1] + "p" + self.assertRaises( + DecodeError, + deserialize_compact, + text, + OctKey.import_key("secret") + ) diff --git a/tests/test_util.py b/tests/test_util.py index 42eff1a9..c834ff88 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -1,7 +1,7 @@ import binascii from unittest import TestCase from joserfc import util - +from joserfc import errors class TestUtil(TestCase): def test_to_bytes(self): @@ -23,3 +23,7 @@ def test_json_b64encode(self): def test_urlsafe_b64decode(self): self.assertEqual(util.urlsafe_b64decode(b"_foo123-"), b"\xfd\xfa(\xd7m\xfe") self.assertRaises(binascii.Error, util.urlsafe_b64decode, b"+foo123/") + for c in "RSTUVWXYZabdef": # A -> QQ== + self.assertRaises(errors.DecodeError, util.urlsafe_b64decode, b"Q" + c.encode()) + for c in "FGH": # AAAAAAAAAAAAAA -> QUFBQUFBQUFBQUFBQUE= + self.assertRaises(errors.DecodeError, util.urlsafe_b64decode, b"QUFBQUFBQUFBQUFBQU" + c.encode())