Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 13 additions & 9 deletions src/base64io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

try: # Python 3.5.0 and 3.5.1 have incompatible typing modules
from types import TracebackType # noqa pylint: disable=unused-import
from typing import IO, Iterable, List, Type, Optional # noqa pylint: disable=unused-import
from typing import Union, IO, Iterable, List, Type, Optional, AnyStr # noqa pylint: disable=unused-import
except ImportError: # pragma: no cover
# We only actually need these imports when running the mypy checks
pass
Expand Down Expand Up @@ -209,7 +209,7 @@ def writelines(self, lines):
self.write(line)

def _read_additional_data_removing_whitespace(self, data, total_bytes_to_read):
# type: (bytes, int) -> bytes
# type: (AnyStr, int) -> AnyStr
"""Read additional data from wrapped stream until we reach the desired number of bytes.

.. note::
Expand All @@ -226,19 +226,20 @@ def _read_additional_data_removing_whitespace(self, data, total_bytes_to_read):
# case the base64 module happily removes any whitespace.
return data

_data_buffer = io.BytesIO()
_data_buffer.write(b"".join(data.split()))
_remaining_bytes_to_read = total_bytes_to_read - _data_buffer.tell()
_data_buffer = io.BytesIO() if isinstance(data, bytes) else io.StringIO()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Method return type annotation needs to change to AnyStr.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. The mypy tests are failing but I'm not sure why - I haven't used mypy previously. It looks like a false positive - maybe mypy is not able to handle the branching logic that handles bytes and str differently?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The initial problem is just that you need to add an import for AnyStr in the types imports at the top. I'm not sure if the other issues it's running into are just a side effect of missing that or not.

join_char = b'' if isinstance(data, bytes) else u''
_data_buffer.write(join_char.join(data.split())) # type: ignore
_remaining_bytes_to_read = total_bytes_to_read - _data_buffer.tell() # type: ignore

while _remaining_bytes_to_read > 0:
_raw_additional_data = self.__wrapped.read(_remaining_bytes_to_read)
if not _raw_additional_data:
# No more data to read from wrapped stream.
break

_data_buffer.write(b"".join(_raw_additional_data.split()))
_remaining_bytes_to_read = total_bytes_to_read - _data_buffer.tell()
return _data_buffer.getvalue()
_data_buffer.write(join_char.join(_raw_additional_data.split())) # type: ignore
_remaining_bytes_to_read = total_bytes_to_read - _data_buffer.tell() # type: ignore
return _data_buffer.getvalue() # type: ignore

def read(self, b=-1):
# type: (int) -> bytes
Expand Down Expand Up @@ -273,7 +274,10 @@ def read(self, b=-1):
data = self.__wrapped.read(_bytes_to_read)
# Remove whitespace from read data and attempt to read more data to get the desired
# number of bytes.
if any([char.encode("utf-8") in data for char in string.whitespace]):
whitespace = string.whitespace.encode("utf-8") if isinstance(data, bytes) \
else string.whitespace # type: Union[bytes, str]

if any([char in data for char in whitespace]):
data = self._read_additional_data_removing_whitespace(data, _bytes_to_read)

results = io.BytesIO()
Expand Down
24 changes: 24 additions & 0 deletions test/unit/test_base64_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,22 @@ def test_base64io_decode(bytes_to_generate, bytes_per_round, number_of_rounds, t
assert test == plaintext_source[:total_bytes_to_expect]


@pytest.mark.parametrize(
    "bytes_to_generate, bytes_per_round, number_of_rounds, total_bytes_to_expect", build_test_cases()
)
def test_base64io_decode_str(bytes_to_generate, bytes_per_round, number_of_rounds, total_bytes_to_expect):
    """Decode through Base64IO when the wrapped stream is a text (str) stream.

    Random bytes are base64-encoded, decoded to an ASCII string and placed in an
    ``io.StringIO``. The wrapper is then read in ``number_of_rounds`` chunks of
    ``bytes_per_round`` and the concatenated output is compared with the
    matching prefix of the original random bytes.
    """
    raw_source = os.urandom(bytes_to_generate)
    encoded_text = base64.b64encode(raw_source).decode("ascii")
    decoder = Base64IO(io.StringIO(encoded_text))

    # Collect each round's output, then stitch the pieces together in one pass.
    chunks = [decoder.read(bytes_per_round) for _ in range(number_of_rounds)]
    decoded = b"".join(chunks)

    assert len(decoded) == total_bytes_to_expect
    assert decoded == raw_source[:total_bytes_to_expect]


@pytest.mark.parametrize(
"bytes_to_generate, bytes_per_round, number_of_rounds, total_bytes_to_expect", build_test_cases()
)
Expand Down Expand Up @@ -297,6 +313,14 @@ def test_base64io_decode_with_whitespace(plaintext_source, b64_plaintext_with_wh
assert test == plaintext_source[:read_bytes]


@pytest.mark.parametrize("plaintext_source, b64_plaintext_with_whitespace, read_bytes", build_whitespace_testcases())
def test_base64io_decode_with_whitespace_str(plaintext_source, b64_plaintext_with_whitespace, read_bytes):
    """Decode whitespace-laden base64 supplied through a text (str) stream.

    The parametrized base64 bytes are decoded to an ASCII string and wrapped in
    an ``io.StringIO``. A single read of ``read_bytes`` must equal the first
    ``read_bytes`` bytes of ``plaintext_source``.
    """
    text_stream = io.StringIO(b64_plaintext_with_whitespace.decode("ascii"))
    expected = plaintext_source[:read_bytes]

    with Base64IO(text_stream) as decoder:
        decoded = decoder.read(read_bytes)

    assert decoded == expected


@pytest.mark.parametrize(
"plaintext_source, b64_plaintext_with_whitespace, read_bytes", ((b"\x00\x00\x00", b"AAAA", 3),)
)
Expand Down