diff --git a/sdks/python/apache_beam/coders/coders.py b/sdks/python/apache_beam/coders/coders.py index d7074a9d1972..2691857bf0a6 100644 --- a/sdks/python/apache_beam/coders/coders.py +++ b/sdks/python/apache_beam/coders/coders.py @@ -59,6 +59,7 @@ from apache_beam.coders import coder_impl from apache_beam.coders.avro_record import AvroRecord +from apache_beam.internal import cloudpickle_pickler from apache_beam.portability import common_urns from apache_beam.portability import python_urns from apache_beam.portability.api import beam_runner_api_pb2 @@ -93,6 +94,7 @@ 'AvroGenericCoder', 'BooleanCoder', 'BytesCoder', + 'CloudpickleCoder', 'DillCoder', 'FastPrimitivesCoder', 'FloatCoder', @@ -902,6 +904,13 @@ def _create_impl(self): return coder_impl.CallbackCoderImpl(maybe_dill_dumps, maybe_dill_loads) +class CloudpickleCoder(_PickleCoderBase): + """Coder using Apache Beam's vendored Cloudpickle pickler.""" + def _create_impl(self): + return coder_impl.CallbackCoderImpl( + cloudpickle_pickler.dumps, cloudpickle_pickler.loads) + + class DeterministicFastPrimitivesCoder(FastCoder): """Throws runtime errors when encoding non-deterministic values.""" def __init__(self, coder, step_label): diff --git a/sdks/python/apache_beam/coders/coders_test_common.py b/sdks/python/apache_beam/coders/coders_test_common.py index 3ef95ede4e86..5ba16997d27b 100644 --- a/sdks/python/apache_beam/coders/coders_test_common.py +++ b/sdks/python/apache_beam/coders/coders_test_common.py @@ -215,6 +215,13 @@ def test_pickle_coder(self): coder = coders.PickleCoder() self.check_coder(coder, *self.test_values) + def test_cloudpickle_pickle_coder(self): + cell_value = (lambda x: lambda: x)(0).__closure__[0] + self.check_coder(coders.CloudpickleCoder(), 'a', 1, cell_value) + self.check_coder( + coders.TupleCoder((coders.VarIntCoder(), coders.CloudpickleCoder())), + (1, cell_value)) + def test_memoizing_pickle_coder(self): coder = coders._MemoizingPickleCoder() self.check_coder(coder, *self.test_values) diff --git a/sdks/python/apache_beam/ml/anomaly/transforms.py b/sdks/python/apache_beam/ml/anomaly/transforms.py index 5870878ec69c..ef5501b33786 100644 --- a/sdks/python/apache_beam/ml/anomaly/transforms.py +++ b/sdks/python/apache_beam/ml/anomaly/transforms.py @@ -25,7 +25,7 @@ from typing import Union import apache_beam as beam -from apache_beam.coders import DillCoder +from apache_beam.coders import CloudpickleCoder from apache_beam.ml.anomaly import aggregations from apache_beam.ml.anomaly.base import AggregationFn from apache_beam.ml.anomaly.base import AnomalyDetector @@ -57,7 +57,8 @@ class _ScoreAndLearnDoFn(beam.DoFn): then updates the model with the same data. It maintains the model state using Beam's state management. """ - MODEL_STATE_INDEX = ReadModifyWriteStateSpec('saved_model', DillCoder()) + MODEL_STATE_INDEX = ReadModifyWriteStateSpec( + 'saved_model', CloudpickleCoder()) def __init__(self, detector_spec: Spec): self._detector_spec = detector_spec @@ -227,7 +228,8 @@ class _StatefulThresholdDoFn(_BaseThresholdDoFn): AssertionError: If the provided `threshold_fn_spec` leads to the creation of a stateless `ThresholdFn`. """ - THRESHOLD_STATE_INDEX = ReadModifyWriteStateSpec('saved_tracker', DillCoder()) + THRESHOLD_STATE_INDEX = ReadModifyWriteStateSpec( + 'saved_tracker', CloudpickleCoder()) def __init__(self, threshold_fn_spec: Spec): assert isinstance(threshold_fn_spec.config, dict)