Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions sdks/python/apache_beam/coders/coders.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@

from apache_beam.coders import coder_impl
from apache_beam.coders.avro_record import AvroRecord
from apache_beam.internal import cloudpickle_pickler
from apache_beam.portability import common_urns
from apache_beam.portability import python_urns
from apache_beam.portability.api import beam_runner_api_pb2
Expand Down Expand Up @@ -93,6 +94,7 @@
'AvroGenericCoder',
'BooleanCoder',
'BytesCoder',
'CloudpickleCoder',
'DillCoder',
'FastPrimitivesCoder',
'FloatCoder',
Expand Down Expand Up @@ -902,6 +904,13 @@ def _create_impl(self):
return coder_impl.CallbackCoderImpl(maybe_dill_dumps, maybe_dill_loads)


class CloudpickleCoder(_PickleCoderBase):
  """Pickle-style coder backed by Apache Beam's vendored Cloudpickle pickler.

  Encoding and decoding are delegated to ``cloudpickle_pickler.dumps`` and
  ``cloudpickle_pickler.loads`` respectively.
  """
  def _create_impl(self):
    """Return a callback-based coder impl wired to the cloudpickle pickler."""
    encode_fn = cloudpickle_pickler.dumps
    decode_fn = cloudpickle_pickler.loads
    return coder_impl.CallbackCoderImpl(encode_fn, decode_fn)


class DeterministicFastPrimitivesCoder(FastCoder):
"""Throws runtime errors when encoding non-deterministic values."""
def __init__(self, coder, step_label):
Expand Down
7 changes: 7 additions & 0 deletions sdks/python/apache_beam/coders/coders_test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,13 @@ def test_pickle_coder(self):
coder = coders.PickleCoder()
self.check_coder(coder, *self.test_values)

def test_cloudpickle_pickle_coder(self):
  """Check CloudpickleCoder alone and nested inside a TupleCoder.

  The checked values include a closure cell object taken from a nested
  function that captures the value 0.
  """
  def _capture(x):
    return lambda: x

  closure_cell = _capture(0).__closure__[0]

  # Stand-alone coder: a string, an int and the closure cell.
  standalone_coder = coders.CloudpickleCoder()
  self.check_coder(standalone_coder, 'a', 1, closure_cell)

  # Same coder used as a component of a tuple coder.
  component_coders = (coders.VarIntCoder(), coders.CloudpickleCoder())
  tuple_coder = coders.TupleCoder(component_coders)
  self.check_coder(tuple_coder, (1, closure_cell))

def test_memoizing_pickle_coder(self):
coder = coders._MemoizingPickleCoder()
self.check_coder(coder, *self.test_values)
Expand Down
8 changes: 5 additions & 3 deletions sdks/python/apache_beam/ml/anomaly/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from typing import Union

import apache_beam as beam
from apache_beam.coders import DillCoder
from apache_beam.coders import CloudpickleCoder
from apache_beam.ml.anomaly import aggregations
from apache_beam.ml.anomaly.base import AggregationFn
from apache_beam.ml.anomaly.base import AnomalyDetector
Expand Down Expand Up @@ -57,7 +57,8 @@ class _ScoreAndLearnDoFn(beam.DoFn):
then updates the model with the same data. It maintains the model state
using Beam's state management.
"""
MODEL_STATE_INDEX = ReadModifyWriteStateSpec('saved_model', DillCoder())
MODEL_STATE_INDEX = ReadModifyWriteStateSpec(
'saved_model', CloudpickleCoder())

def __init__(self, detector_spec: Spec):
self._detector_spec = detector_spec
Expand Down Expand Up @@ -227,7 +228,8 @@ class _StatefulThresholdDoFn(_BaseThresholdDoFn):
AssertionError: If the provided `threshold_fn_spec` leads to the
creation of a stateless `ThresholdFn`.
"""
THRESHOLD_STATE_INDEX = ReadModifyWriteStateSpec('saved_tracker', DillCoder())
THRESHOLD_STATE_INDEX = ReadModifyWriteStateSpec(
'saved_tracker', CloudpickleCoder())

def __init__(self, threshold_fn_spec: Spec):
assert isinstance(threshold_fn_spec.config, dict)
Expand Down
Loading