diff --git a/sdks/python/apache_beam/coders/coders.py b/sdks/python/apache_beam/coders/coders.py index cb23e3967e33..e7bf9e02b02a 100644 --- a/sdks/python/apache_beam/coders/coders.py +++ b/sdks/python/apache_beam/coders/coders.py @@ -64,6 +64,7 @@ from apache_beam.portability.api import beam_runner_api_pb2 from apache_beam.typehints import typehints from apache_beam.utils import proto_utils +from apache_beam.utils import windowed_value if TYPE_CHECKING: from apache_beam.coders.typecoders import CoderRegistry @@ -113,7 +114,8 @@ 'WindowedValueCoder', 'ParamWindowedValueCoder', 'BigIntegerCoder', - 'DecimalCoder' + 'DecimalCoder', + 'PaneInfoCoder' ] T = TypeVar('T') @@ -1753,6 +1755,24 @@ def __hash__(self): return hash(type(self)) +class PaneInfoCoder(FastCoder): + def _create_impl(self): + return coder_impl.PaneInfoCoderImpl() + + def is_deterministic(self): + # type: () -> bool + return True + + def to_type_hint(self): + return windowed_value.PaneInfo + + def __eq__(self, other): + return type(self) == type(other) + + def __hash__(self): + return hash(type(self)) + + class DecimalCoder(FastCoder): def _create_impl(self): return coder_impl.DecimalCoderImpl() diff --git a/sdks/python/apache_beam/coders/coders_test_common.py b/sdks/python/apache_beam/coders/coders_test_common.py index bed93cbc5545..21de0e70d800 100644 --- a/sdks/python/apache_beam/coders/coders_test_common.py +++ b/sdks/python/apache_beam/coders/coders_test_common.py @@ -362,6 +362,18 @@ def test_interval_window_coder(self): coders.TupleCoder((coders.IntervalWindowCoder(), )), (window.IntervalWindow(0, 10), )) + def test_paneinfo_window_coder(self): + self.check_coder( + coders.PaneInfoCoder(), + *[ + windowed_value.PaneInfo( + is_first=y == 0, + is_last=y == 9, + timing=windowed_value.PaneInfoTiming.EARLY, + index=y, + nonspeculative_index=-1) for y in range(0, 10) + ]) + def test_timestamp_coder(self): self.check_coder( coders.TimestampCoder(), @@ -539,6 +551,7 @@ def test_windowed_value_coder(self): def test_param_windowed_value_coder(self): from apache_beam.transforms.window import IntervalWindow from apache_beam.utils.windowed_value import PaneInfo + # pylint: disable=too-many-function-args wv = windowed_value.create( b'', # Milliseconds to microseconds diff --git a/sdks/python/apache_beam/coders/typecoders.py b/sdks/python/apache_beam/coders/typecoders.py index 892f508d0136..19300c675596 100644 --- a/sdks/python/apache_beam/coders/typecoders.py +++ b/sdks/python/apache_beam/coders/typecoders.py @@ -73,6 +73,7 @@ def MakeXyzs(v): from apache_beam.coders import coders from apache_beam.typehints import typehints +from apache_beam.utils import windowed_value __all__ = ['registry'] @@ -92,6 +93,7 @@ def register_standard_coders(self, fallback_coder): self._register_coder_internal(bytes, coders.BytesCoder) self._register_coder_internal(bool, coders.BooleanCoder) self._register_coder_internal(str, coders.StrUtf8Coder) + self._register_coder_internal(windowed_value.PaneInfo, coders.PaneInfoCoder) self._register_coder_internal(typehints.TupleConstraint, coders.TupleCoder) self._register_coder_internal(typehints.DictConstraint, coders.MapCoder) self._register_coder_internal( diff --git a/sdks/python/apache_beam/coders/typecoders_test.py b/sdks/python/apache_beam/coders/typecoders_test.py index 3adc8255409d..3c59cff68651 100644 --- a/sdks/python/apache_beam/coders/typecoders_test.py +++ b/sdks/python/apache_beam/coders/typecoders_test.py @@ -24,6 +24,7 @@ from apache_beam.coders import typecoders from apache_beam.internal import pickler from apache_beam.typehints import typehints +from apache_beam.utils import windowed_value class CustomClass(object): @@ -141,6 +142,24 @@ def test_nullable_coder(self): self.assertEqual(expected_coder.encode(None), real_coder.encode(None)) self.assertEqual(expected_coder.encode(b'abc'), real_coder.encode(b'abc')) + def test_paneinfo_coder(self): + expected_coder = coders.PaneInfoCoder() + real_coder = typecoders.registry.get_coder(windowed_value.PaneInfo) + self.assertEqual(expected_coder, real_coder) + for i in range(10): + pane_info = windowed_value.PaneInfo( + is_first=i==0, + is_last=i==9, + timing=windowed_value.PaneInfoTiming.EARLY, # 0 + index=i, + nonspeculative_index=-1 + ) + + encoded = real_coder.encode(pane_info) + + self.assertEqual(expected_coder.encode(pane_info), encoded) + self.assertEqual(pane_info, real_coder.decode(encoded)) + if __name__ == '__main__': unittest.main()