diff --git a/src/base64io/__init__.py b/src/base64io/__init__.py index a952212..bd7acec 100644 --- a/src/base64io/__init__.py +++ b/src/base64io/__init__.py @@ -71,11 +71,12 @@ class Base64IO(io.IOBase): avoid data loss. If used as a context manager, we take care of that for you. :param wrapped: Stream to wrap + :param ignore_whitespace: If set, removes whitespace characters in the stream before decoding """ closed = False - def __init__(self, wrapped): + def __init__(self, wrapped, ignore_whitespace=True): # type: (Base64IO, IO) -> None """Check for required methods on wrapped stream and set up read buffer. @@ -85,6 +86,7 @@ def __init__(self, wrapped): if not all(hasattr(wrapped, attr) for attr in required_attrs): raise TypeError("Base64IO wrapped object must have attributes: %s" % (repr(sorted(required_attrs)),)) super(Base64IO, self).__init__() + self.ignore_whitespace = ignore_whitespace self.__wrapped = wrapped self.__read_buffer = b"" self.__write_buffer = b"" @@ -226,8 +228,8 @@ def _read_additional_data_removing_whitespace(self, data, total_bytes_to_read): # case the base64 module happily removes any whitespace. return data - _data_buffer = io.BytesIO() - _data_buffer.write(b"".join(data.split())) + _data_buffer = io.BytesIO() if isinstance(data, bytes) else io.StringIO() + _data_buffer.write(type(data)().join(data.split())) _remaining_bytes_to_read = total_bytes_to_read - _data_buffer.tell() while _remaining_bytes_to_read > 0: @@ -236,7 +238,7 @@ def _read_additional_data_removing_whitespace(self, data, total_bytes_to_read): # No more data to read from wrapped stream. break - _data_buffer.write(b"".join(_raw_additional_data.split())) + _data_buffer.write(type(data)().join(_raw_additional_data.split())) _remaining_bytes_to_read = total_bytes_to_read - _data_buffer.tell() return _data_buffer.getvalue() @@ -271,10 +273,15 @@ def read(self, b=-1): # Read encoded bytes from wrapped stream. data = self.__wrapped.read(_bytes_to_read) - # Remove whitespace from read data and attempt to read more data to get the desired - # number of bytes. - if any([char.encode("utf-8") in data for char in string.whitespace]): - data = self._read_additional_data_removing_whitespace(data, _bytes_to_read) + if self.ignore_whitespace: + # Remove whitespace from read data and attempt to read more data to get the desired + # number of bytes. + if isinstance(data, bytes): + if any([char.encode("utf-8") in data for char in string.whitespace]): + data = self._read_additional_data_removing_whitespace(data, _bytes_to_read) + else: + if any([char in data for char in string.whitespace]): + data = self._read_additional_data_removing_whitespace(data, _bytes_to_read) results = io.BytesIO() # First, load any stashed bytes diff --git a/test/unit/test_base64_stream.py b/test/unit/test_base64_stream.py index de8410c..9f59152 100644 --- a/test/unit/test_base64_stream.py +++ b/test/unit/test_base64_stream.py @@ -14,6 +14,7 @@ from __future__ import division import base64 +import binascii import functools import io import math @@ -154,6 +155,22 @@ def test_base64io_decode(bytes_to_generate, bytes_per_round, number_of_rounds, t assert test == plaintext_source[:total_bytes_to_expect] +@pytest.mark.parametrize( + "bytes_to_generate, bytes_per_round, number_of_rounds, total_bytes_to_expect", build_test_cases() +) +def test_base64io_decode_str(bytes_to_generate, bytes_per_round, number_of_rounds, total_bytes_to_expect): + plaintext_source = os.urandom(bytes_to_generate) + plaintext_b64 = io.StringIO(base64.b64encode(plaintext_source).decode('ascii')) + plaintext_wrapped = Base64IO(plaintext_b64) + + test = b"" + for _round in range(number_of_rounds): + test += plaintext_wrapped.read(bytes_per_round) + + assert len(test) == total_bytes_to_expect + assert test == plaintext_source[:total_bytes_to_expect] + + @pytest.mark.parametrize( "bytes_to_generate, bytes_per_round, number_of_rounds, total_bytes_to_expect", build_test_cases() ) @@ -297,6 +314,32 @@ def test_base64io_decode_with_whitespace(plaintext_source, b64_plaintext_with_wh assert test == plaintext_source[:read_bytes] +@pytest.mark.parametrize("plaintext_source, b64_plaintext_with_whitespace, read_bytes", build_whitespace_testcases()) +def test_base64io_decode_with_whitespace_ignored(plaintext_source, b64_plaintext_with_whitespace, read_bytes): + try: + with Base64IO(io.BytesIO(b64_plaintext_with_whitespace), ignore_whitespace=False) as decoder: + test = decoder.read(read_bytes) + except binascii.Error as e: + assert e.args[0] == 'Incorrect padding' + + +@pytest.mark.parametrize("plaintext_source, b64_plaintext_with_whitespace, read_bytes", build_whitespace_testcases()) +def test_base64io_decode_with_whitespace_str(plaintext_source, b64_plaintext_with_whitespace, read_bytes): + with Base64IO(io.StringIO(b64_plaintext_with_whitespace.decode('ascii'))) as decoder: + test = decoder.read(read_bytes) + + assert test == plaintext_source[:read_bytes] + + +@pytest.mark.parametrize("plaintext_source, b64_plaintext_with_whitespace, read_bytes", build_whitespace_testcases()) +def test_base64io_decode_with_whitespace_ignored_str(plaintext_source, b64_plaintext_with_whitespace, read_bytes): + try: + with Base64IO(io.StringIO(b64_plaintext_with_whitespace.decode('ascii')), ignore_whitespace=False) as decoder: + test = decoder.read(read_bytes) + except binascii.Error as e: + assert e.args[0] == 'Incorrect padding' + + @pytest.mark.parametrize( "plaintext_source, b64_plaintext_with_whitespace, read_bytes", ((b"\x00\x00\x00", b"AAAA", 3),) )