diff --git a/asherah/asherah.py b/asherah/asherah.py index b66d2bb..3f70e1c 100644 --- a/asherah/asherah.py +++ b/asherah/asherah.py @@ -1,4 +1,5 @@ """Main Asherah class, for encrypting and decrypting of data""" + # pylint: disable=line-too-long from __future__ import annotations @@ -6,17 +7,16 @@ import json import os from typing import ByteString, Union - from cobhan import Cobhan - from . import exceptions, types class Asherah: """The main class for providing encryption and decryption functionality""" - JSON_OVERHEAD = 256 - KEY_SIZE = 64 + ENCRYPTION_OVERHEAD = 48 + ENVELOPE_OVERHEAD = 185 + BASE64_OVERHEAD = 1.34 def __init__(self): self.__cobhan = Cobhan() @@ -35,6 +35,7 @@ def __init__(self): def setup(self, config: types.AsherahConfig) -> None: """Set up/initialize the underlying encryption library.""" + self.ik_overhead = len(config.service_name) + len(config.product_id) config_json = json.dumps(config.to_json()) config_buf = self.__cobhan.str_to_buf(config_json) result = self.__libasherah.SetupJson(config_buf) @@ -55,7 +56,13 @@ def encrypt(self, partition_id: str, data: Union[ByteString, str]): partition_id_buf = self.__cobhan.str_to_buf(partition_id) data_buf = self.__cobhan.bytearray_to_buf(data) # Outputs - json_buf = self.__cobhan.allocate_buf(len(data_buf) + self.JSON_OVERHEAD) + buffer_estimate = int( + self.ENVELOPE_OVERHEAD + + self.ik_overhead + + len(partition_id_buf) + + ((len(data_buf) + self.ENCRYPTION_OVERHEAD) * self.BASE64_OVERHEAD) + ) + json_buf = self.__cobhan.allocate_buf(buffer_estimate) result = self.__libasherah.EncryptToJson(partition_id_buf, data_buf, json_buf) if result < 0: diff --git a/tests/test_asherah.py b/tests/test_asherah.py index 78321c2..6aab69c 100644 --- a/tests/test_asherah.py +++ b/tests/test_asherah.py @@ -35,3 +35,9 @@ def test_decrypted_data_equals_original_data(self): encrypted = self.asherah.encrypt("partition", data) decrypted = self.asherah.decrypt("partition", encrypted) self.assertEqual(decrypted, data) + + def test_encrypt_decrypt_large_data(self): + data = b"a" * 1024 * 1024 + encrypted = self.asherah.encrypt("partition", data) + decrypted = self.asherah.decrypt("partition", encrypted) + self.assertEqual(decrypted, data)