Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sdks/python/apache_beam/coders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from __future__ import absolute_import

from apache_beam.coders.coders import *
from apache_beam.coders.typecoders import registry
4 changes: 2 additions & 2 deletions sdks/python/apache_beam/coders/coder_impl.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,11 @@ cdef class DeterministicFastPrimitivesCoderImpl(CoderImpl):

cdef object NoneType
cdef char UNKNOWN_TYPE, NONE_TYPE, INT_TYPE, FLOAT_TYPE, BOOL_TYPE
cdef char STR_TYPE, UNICODE_TYPE, LIST_TYPE, TUPLE_TYPE, DICT_TYPE, SET_TYPE
cdef char BYTES_TYPE, UNICODE_TYPE, LIST_TYPE, TUPLE_TYPE, DICT_TYPE, SET_TYPE

cdef class FastPrimitivesCoderImpl(StreamCoderImpl):
cdef CoderImpl fallback_coder_impl
@cython.locals(unicode_value=unicode, dict_value=dict)
@cython.locals(dict_value=dict)
cpdef encode_to_stream(self, value, OutputStream stream, bint nested)


Expand Down
63 changes: 38 additions & 25 deletions sdks/python/apache_beam/coders/coder_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,17 @@
This module may be optionally compiled with Cython, using the corresponding
coder_impl.pxd file for type hints.

Py2/3 porting: Native range is used on both python versions instead of
future.builtins.range to avoid performance regression in Cython compiled code.

For internal use only; no backwards-compatibility guarantees.
"""
from __future__ import absolute_import
from __future__ import division

from types import NoneType

import six
import sys
from builtins import chr
from builtins import object

from apache_beam.coders import observable
from apache_beam.utils import windowed_value
Expand All @@ -54,10 +58,12 @@
from .slow_stream import get_varint_size
# pylint: enable=wrong-import-order, wrong-import-position, ungrouped-imports

try:
long # Python 2
except NameError:
long = int # Python 3
try: # Python 2
long # pylint: disable=long-builtin
unicode # pylint: disable=unicode-builtin
except NameError: # Python 3
long = int
unicode = str


class CoderImpl(object):
Expand Down Expand Up @@ -199,7 +205,7 @@ def __init__(self, coder, step_label):
self._step_label = step_label

def _check_safe(self, value):
if isinstance(value, (str, six.text_type, long, int, float)):
if isinstance(value, (bytes, unicode, long, int, float)):
pass
elif value is None:
pass
Expand Down Expand Up @@ -253,7 +259,7 @@ def decode(self, encoded):
NONE_TYPE = 0
INT_TYPE = 1
FLOAT_TYPE = 2
STR_TYPE = 3
BYTES_TYPE = 3
UNICODE_TYPE = 4
BOOL_TYPE = 9
LIST_TYPE = 5
Expand All @@ -279,21 +285,21 @@ def get_estimated_size_and_observables(self, value, nested=False):

def encode_to_stream(self, value, stream, nested):
t = type(value)
if t is NoneType:
if value is None:
stream.write_byte(NONE_TYPE)
elif t is int:
stream.write_byte(INT_TYPE)
stream.write_var_int64(value)
elif t is float:
stream.write_byte(FLOAT_TYPE)
stream.write_bigendian_double(value)
elif t is str:
stream.write_byte(STR_TYPE)
elif t is bytes:
stream.write_byte(BYTES_TYPE)
stream.write(value, nested)
elif t is six.text_type:
unicode_value = value # for typing
elif t is unicode:
text_value = value # for typing
stream.write_byte(UNICODE_TYPE)
stream.write(unicode_value.encode('utf-8'), nested)
stream.write(text_value.encode('utf-8'), nested)
elif t is list or t is tuple or t is set:
stream.write_byte(
LIST_TYPE if t is list else TUPLE_TYPE if t is tuple else SET_TYPE)
Expand All @@ -304,7 +310,13 @@ def encode_to_stream(self, value, stream, nested):
dict_value = value # for typing
stream.write_byte(DICT_TYPE)
stream.write_var_int64(len(dict_value))
for k, v in dict_value.iteritems():
# Use iteritems() on Python 2 instead of future.builtins.iteritems to
# avoid performance regression in Cython compiled code.
if sys.version_info[0] == 2:
items = dict_value.iteritems() # pylint: disable=dict-iter-method
else:
items = dict_value.items()
for k, v in items:
self.encode_to_stream(k, stream, True)
self.encode_to_stream(v, stream, True)
elif t is bool:
Expand All @@ -322,7 +334,7 @@ def decode_from_stream(self, stream, nested):
return stream.read_var_int64()
elif t == FLOAT_TYPE:
return stream.read_bigendian_double()
elif t == STR_TYPE:
elif t == BYTES_TYPE:
return stream.read_all(nested)
elif t == UNICODE_TYPE:
return stream.read_all(nested).decode('utf-8')
Expand Down Expand Up @@ -394,8 +406,9 @@ def _from_normal_time(self, value):

def encode_to_stream(self, value, out, nested):
span_micros = value.end.micros - value.start.micros
out.write_bigendian_uint64(self._from_normal_time(value.end.micros / 1000))
out.write_var_int64(span_micros / 1000)
out.write_bigendian_uint64(
self._from_normal_time(value.end.micros // 1000))
out.write_var_int64(span_micros // 1000)

def decode_from_stream(self, in_, nested):
end_millis = self._to_normal_time(in_.read_bigendian_uint64())
Expand All @@ -409,7 +422,7 @@ def estimate_size(self, value, nested=False):
# An IntervalWindow is context-insensitive, with a timestamp (8 bytes)
# and a varint timespan.
span = value.end.micros - value.start.micros
return 8 + get_varint_size(span / 1000)
return 8 + get_varint_size(span // 1000)


class TimestampCoderImpl(StreamCoderImpl):
Expand All @@ -427,7 +440,7 @@ def estimate_size(self, unused_value, nested=False):
return 8


small_ints = [chr(_) for _ in range(128)]
small_ints = [chr(_).encode('latin-1') for _ in range(128)]


class VarIntCoderImpl(StreamCoderImpl):
Expand Down Expand Up @@ -474,7 +487,7 @@ def decode_from_stream(self, stream, nested):
return self._value

def encode(self, value):
b = '' # avoid byte vs str vs unicode error
b = b'' # avoid byte vs str vs unicode error
return b

def decode(self, encoded):
Expand Down Expand Up @@ -783,7 +796,7 @@ def encode_to_stream(self, value, out, nested):
# TODO(BEAM-1524): Clean this up once we have a BEAM wide consensus on
# precision of timestamps.
self._from_normal_time(
restore_sign * (abs(wv.timestamp_micros) / 1000)))
restore_sign * (abs(wv.timestamp_micros) // 1000)))
self._windows_coder.encode_to_stream(wv.windows, out, True)
# Default PaneInfo encoded byte representing NO_FIRING.
self._pane_info_coder.encode_to_stream(wv.pane_info, out, True)
Expand All @@ -797,9 +810,9 @@ def decode_from_stream(self, in_stream, nested):
# were indeed MIN/MAX timestamps.
# TODO(BEAM-1524): Clean this up once we have a BEAM wide consensus on
# precision of timestamps.
if timestamp == -(abs(MIN_TIMESTAMP.micros) / 1000):
if timestamp == -(abs(MIN_TIMESTAMP.micros) // 1000):
timestamp = MIN_TIMESTAMP.micros
elif timestamp == (MAX_TIMESTAMP.micros / 1000):
elif timestamp == (MAX_TIMESTAMP.micros // 1000):
timestamp = MAX_TIMESTAMP.micros
else:
timestamp *= 1000
Expand Down
18 changes: 14 additions & 4 deletions sdks/python/apache_beam/coders/coders.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from __future__ import absolute_import

import base64
import cPickle as pickle
from builtins import object

import google.protobuf
from google.protobuf import wrappers_pb2
Expand All @@ -33,6 +33,12 @@
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.utils import proto_utils

# This is for py2/3 compatibility. cPickle was renamed pickle in python 3.
try:
import cPickle as pickle # Python 2
except ImportError:
import pickle # Python 3

# pylint: disable=wrong-import-order, wrong-import-position, ungrouped-imports
try:
from .stream import get_varint_size
Expand Down Expand Up @@ -210,11 +216,15 @@ def as_cloud_object(self):
def __repr__(self):
return self.__class__.__name__

# pylint: disable=protected-access
def __eq__(self, other):
# pylint: disable=protected-access
return (self.__class__ == other.__class__
and self._dict_without_impl() == other._dict_without_impl())
# pylint: enable=protected-access

def __hash__(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this change required by the Python 3 migration, or are we just fixing an omission, given that __hash__ was not previously defined while __eq__ was?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On Python 2, the default __hash__ gives a different value for each instance. On Python 3, __hash__ is implicitly set to None when a class implements __eq__ without also defining __hash__. By implementing __hash__ explicitly, we get consistent behavior on both versions.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Revisiting this in light of other PRs: I think it would be safer to compute the hash here based on the object type, since that guarantees the contract that the hash does not change for the same object. I sent #5390 for this.
Another possibility to guarantee consistent behavior between Python 2 and 3 would be to set __hash__ = None if we can infer that a class is obviously non-hashable.

Copy link
Contributor Author

@RobbeSneyders RobbeSneyders May 16, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can also use the id which is guaranteed to stay the same for an object:
hash(id(self))
The default Python 2 hash also relies on id.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's true, although that wouldn't honor the contract between __eq__ and __hash__: objects that compare equal must also have equal hashes.

return hash((self.__class__,) +
tuple(sorted(self._dict_without_impl().items())))
# pylint: enable=protected-access

_known_urns = {}

Expand Down Expand Up @@ -312,7 +322,7 @@ class ToStringCoder(Coder):

def encode(self, value):
try: # Python 2
if isinstance(value, unicode):
if isinstance(value, unicode): # pylint: disable=unicode-builtin
return value.encode('utf-8')
except NameError: # Python 3
pass
Expand Down
16 changes: 7 additions & 9 deletions sdks/python/apache_beam/coders/coders_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#

from __future__ import absolute_import

import base64
import logging
import unittest
from builtins import object

from apache_beam.coders import proto2_coder_test_messages_pb2 as test_message
from apache_beam.coders import coders
Expand Down Expand Up @@ -47,17 +48,11 @@ def test_equality(self):
class CodersTest(unittest.TestCase):

def test_str_utf8_coder(self):
real_coder = coders_registry.get_coder(str)
expected_coder = coders.BytesCoder()
self.assertEqual(
real_coder.encode('abc'), expected_coder.encode('abc'))
self.assertEqual('abc', real_coder.decode(real_coder.encode('abc')))

real_coder = coders_registry.get_coder(bytes)
expected_coder = coders.BytesCoder()
self.assertEqual(
real_coder.encode('abc'), expected_coder.encode('abc'))
self.assertEqual('abc', real_coder.decode(real_coder.encode('abc')))
real_coder.encode(b'abc'), expected_coder.encode(b'abc'))
self.assertEqual(b'abc', real_coder.decode(real_coder.encode(b'abc')))


# The test proto message file was generated by running the following:
Expand Down Expand Up @@ -99,6 +94,9 @@ def __eq__(self, other):
return True
return False

def __hash__(self):
return hash(type(self))


class FallbackCoderTest(unittest.TestCase):

Expand Down
39 changes: 20 additions & 19 deletions sdks/python/apache_beam/coders/coders_test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import logging
import math
import unittest
from builtins import range

import dill

Expand Down Expand Up @@ -103,7 +104,7 @@ def test_custom_coder(self):

self.check_coder(CustomCoder(), 1, -10, 5)
self.check_coder(coders.TupleCoder((CustomCoder(), coders.BytesCoder())),
(1, 'a'), (-10, 'b'), (5, 'c'))
(1, b'a'), (-10, b'b'), (5, b'c'))

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably not critical, but it looks like 'a' is not replaced with b'a' here - are these changes done by some tool or manually?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess this is aimed at the 'a' in line 109?
The marking of the strings as bytes literals is done manually. I've only marked strings as bytes literals when it's clear that they're meant to represent bytes (when testing BytesCoder, when the contents of the string are clearly bytes, ...). When a string is not marked, it represents str on both versions, which seems ok for the 'a' at line 109 for example.

def test_pickle_coder(self):
self.check_coder(coders.PickleCoder(), 'a', 1, 1.5, (1, 2, 3))
Expand All @@ -129,7 +130,7 @@ def test_dill_coder(self):

def test_fast_primitives_coder(self):
coder = coders.FastPrimitivesCoder(coders.SingletonCoder(len))
self.check_coder(coder, None, 1, -1, 1.5, 'str\0str', u'unicode\0\u0101')
self.check_coder(coder, None, 1, -1, 1.5, b'str\0str', u'unicode\0\u0101')
self.check_coder(coder, (), (1, 2, 3))
self.check_coder(coder, [], [1, 2, 3])
self.check_coder(coder, dict(), {'a': 'b'}, {0: dict(), 1: len})
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also here 'a' and 'b' are not bytestrings.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See comment above.
The unmarked 'a' and 'b' here represent str on both versions, which seems ok for this test.

Expand All @@ -139,7 +140,7 @@ def test_fast_primitives_coder(self):
self.check_coder(coders.TupleCoder((coder,)), ('a',), (1,))

def test_bytes_coder(self):
self.check_coder(coders.BytesCoder(), 'a', '\0', 'z' * 1000)
self.check_coder(coders.BytesCoder(), b'a', b'\0', b'z' * 1000)

def test_varint_coder(self):
# Small ints.
Expand Down Expand Up @@ -190,7 +191,7 @@ def test_timestamp_coder(self):
timestamp.Timestamp(micros=1234567890123456789))
self.check_coder(
coders.TupleCoder((coders.TimestampCoder(), coders.BytesCoder())),
(timestamp.Timestamp.of(27), 'abc'))
(timestamp.Timestamp.of(27), b'abc'))

def test_tuple_coder(self):
kv_coder = coders.TupleCoder((coders.VarIntCoder(), coders.BytesCoder()))
Expand All @@ -206,14 +207,14 @@ def test_tuple_coder(self):
kv_coder.as_cloud_object())
# Test binary representation
self.assertEqual(
'\x04abc',
kv_coder.encode((4, 'abc')))
b'\x04abc',
kv_coder.encode((4, b'abc')))
# Test unnested
self.check_coder(
kv_coder,
(1, 'a'),
(-2, 'a' * 100),
(300, 'abc\0' * 5))
(1, b'a'),
(-2, b'a' * 100),
(300, b'abc\0' * 5))
# Test nested
self.check_coder(
coders.TupleCoder(
Expand Down Expand Up @@ -322,12 +323,12 @@ def test_windowed_value_coder(self):
},
coder.as_cloud_object())
# Test binary representation
self.assertEqual('\x7f\xdf;dZ\x1c\xac\t\x00\x00\x00\x01\x0f\x01',
self.assertEqual(b'\x7f\xdf;dZ\x1c\xac\t\x00\x00\x00\x01\x0f\x01',
coder.encode(window.GlobalWindows.windowed_value(1)))

# Test decoding large timestamp
self.assertEqual(
coder.decode('\x7f\xdf;dZ\x1c\xac\x08\x00\x00\x00\x01\x0f\x00'),
coder.decode(b'\x7f\xdf;dZ\x1c\xac\x08\x00\x00\x00\x01\x0f\x00'),
windowed_value.create(0, MIN_TIMESTAMP.micros, (GlobalWindow(),)))

# Test unnested
Expand Down Expand Up @@ -364,7 +365,7 @@ def test_proto_coder(self):
proto_coder = coders.ProtoCoder(ma.__class__)
self.check_coder(proto_coder, ma)
self.check_coder(coders.TupleCoder((proto_coder, coders.BytesCoder())),
(ma, 'a'), (mb, 'b'))
(ma, b'a'), (mb, b'b'))

def test_global_window_coder(self):
coder = coders.GlobalWindowCoder()
Expand All @@ -391,16 +392,16 @@ def test_length_prefix_coder(self):
},
coder.as_cloud_object())
# Test binary representation
self.assertEqual('\x00', coder.encode(''))
self.assertEqual('\x01a', coder.encode('a'))
self.assertEqual('\x02bc', coder.encode('bc'))
self.assertEqual('\xff\x7f' + 'z' * 16383, coder.encode('z' * 16383))
self.assertEqual(b'\x00', coder.encode(b''))
self.assertEqual(b'\x01a', coder.encode(b'a'))
self.assertEqual(b'\x02bc', coder.encode(b'bc'))
self.assertEqual(b'\xff\x7f' + b'z' * 16383, coder.encode(b'z' * 16383))
# Test unnested
self.check_coder(coder, '', 'a', 'bc', 'def')
self.check_coder(coder, b'', b'a', b'bc', b'def')
# Test nested
self.check_coder(coders.TupleCoder((coder, coder)),
('', 'a'),
('bc', 'def'))
(b'', b'a'),
(b'bc', b'def'))

def test_nested_observables(self):
class FakeObservableIterator(observable.ObservableMixin):
Expand Down
Loading