Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 15 additions & 8 deletions src/base64io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,12 @@ class Base64IO(io.IOBase):
avoid data loss. If used as a context manager, we take care of that for you.

:param wrapped: Stream to wrap
:param ignore_whitespace: If set, removes whitespace characters in the stream before decoding
"""

closed = False

def __init__(self, wrapped):
def __init__(self, wrapped, ignore_whitespace=True):
# type: (Base64IO, IO) -> None
"""Check for required methods on wrapped stream and set up read buffer.

Expand All @@ -85,6 +86,7 @@ def __init__(self, wrapped):
if not all(hasattr(wrapped, attr) for attr in required_attrs):
raise TypeError("Base64IO wrapped object must have attributes: %s" % (repr(sorted(required_attrs)),))
super(Base64IO, self).__init__()
self.ignore_whitespace = ignore_whitespace
self.__wrapped = wrapped
self.__read_buffer = b""
self.__write_buffer = b""
Expand Down Expand Up @@ -226,8 +228,8 @@ def _read_additional_data_removing_whitespace(self, data, total_bytes_to_read):
# case the base64 module happily removes any whitespace.
return data

_data_buffer = io.BytesIO()
_data_buffer.write(b"".join(data.split()))
_data_buffer = io.BytesIO() if isinstance(data, bytes) else io.StringIO()
_data_buffer.write(type(data)().join(data.split()))
_remaining_bytes_to_read = total_bytes_to_read - _data_buffer.tell()

while _remaining_bytes_to_read > 0:
Expand All @@ -236,7 +238,7 @@ def _read_additional_data_removing_whitespace(self, data, total_bytes_to_read):
# No more data to read from wrapped stream.
break

_data_buffer.write(b"".join(_raw_additional_data.split()))
_data_buffer.write(type(data)().join(_raw_additional_data.split()))
_remaining_bytes_to_read = total_bytes_to_read - _data_buffer.tell()
return _data_buffer.getvalue()

Expand Down Expand Up @@ -271,10 +273,15 @@ def read(self, b=-1):

# Read encoded bytes from wrapped stream.
data = self.__wrapped.read(_bytes_to_read)
# Remove whitespace from read data and attempt to read more data to get the desired
# number of bytes.
if any([char.encode("utf-8") in data for char in string.whitespace]):
data = self._read_additional_data_removing_whitespace(data, _bytes_to_read)
if self.ignore_whitespace:
# Remove whitespace from read data and attempt to read more data to get the desired
# number of bytes.
if isinstance(data, bytes):
if any([char.encode("utf-8") in data for char in string.whitespace]):
data = self._read_additional_data_removing_whitespace(data, _bytes_to_read)
else:
if any([char in data for char in string.whitespace]):
data = self._read_additional_data_removing_whitespace(data, _bytes_to_read)

results = io.BytesIO()
# First, load any stashed bytes
Expand Down
43 changes: 43 additions & 0 deletions test/unit/test_base64_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from __future__ import division

import base64
import binascii
import functools
import io
import math
Expand Down Expand Up @@ -154,6 +155,22 @@ def test_base64io_decode(bytes_to_generate, bytes_per_round, number_of_rounds, t
assert test == plaintext_source[:total_bytes_to_expect]


@pytest.mark.parametrize(
"bytes_to_generate, bytes_per_round, number_of_rounds, total_bytes_to_expect", build_test_cases()
)
def test_base64io_decode_str(bytes_to_generate, bytes_per_round, number_of_rounds, total_bytes_to_expect):
plaintext_source = os.urandom(bytes_to_generate)
plaintext_b64 = io.StringIO(base64.b64encode(plaintext_source).decode('ascii'))
plaintext_wrapped = Base64IO(plaintext_b64)

test = b""
for _round in range(number_of_rounds):
test += plaintext_wrapped.read(bytes_per_round)

assert len(test) == total_bytes_to_expect
assert test == plaintext_source[:total_bytes_to_expect]


@pytest.mark.parametrize(
"bytes_to_generate, bytes_per_round, number_of_rounds, total_bytes_to_expect", build_test_cases()
)
Expand Down Expand Up @@ -297,6 +314,32 @@ def test_base64io_decode_with_whitespace(plaintext_source, b64_plaintext_with_wh
assert test == plaintext_source[:read_bytes]


@pytest.mark.parametrize("plaintext_source, b64_plaintext_with_whitespace, read_bytes", build_whitespace_testcases())
def test_base64io_decode_with_whitespace_ignored(plaintext_source, b64_plaintext_with_whitespace, read_bytes):
try:
with Base64IO(io.BytesIO(b64_plaintext_with_whitespace), ignore_whitespace=False) as decoder:
test = decoder.read(read_bytes)
except binascii.Error as e:
assert e.args[0] == 'Incorrect padding'


@pytest.mark.parametrize("plaintext_source, b64_plaintext_with_whitespace, read_bytes", build_whitespace_testcases())
def test_base64io_decode_with_whitespace_str(plaintext_source, b64_plaintext_with_whitespace, read_bytes):
with Base64IO(io.StringIO(b64_plaintext_with_whitespace.decode('ascii'))) as decoder:
test = decoder.read(read_bytes)

assert test == plaintext_source[:read_bytes]


@pytest.mark.parametrize("plaintext_source, b64_plaintext_with_whitespace, read_bytes", build_whitespace_testcases())
def test_base64io_decode_with_whitespace_ignored_str(plaintext_source, b64_plaintext_with_whitespace, read_bytes):
try:
with Base64IO(io.StringIO(b64_plaintext_with_whitespace.decode('ascii')), ignore_whitespace=False) as decoder:
test = decoder.read(read_bytes)
except binascii.Error as e:
assert e.args[0] == 'Incorrect padding'


@pytest.mark.parametrize(
"plaintext_source, b64_plaintext_with_whitespace, read_bytes", ((b"\x00\x00\x00", b"AAAA", 3),)
)
Expand Down