diff --git a/.github/trigger_files/beam_PostCommit_Python_ValidatesRunner_Dataflow.json b/.github/trigger_files/beam_PostCommit_Python_ValidatesRunner_Dataflow.json
index e3d6056a5de9..b26833333238 100644
--- a/.github/trigger_files/beam_PostCommit_Python_ValidatesRunner_Dataflow.json
+++ b/.github/trigger_files/beam_PostCommit_Python_ValidatesRunner_Dataflow.json
@@ -1,4 +1,4 @@
 {
   "comment": "Modify this file in a trivial way to cause this test suite to run",
-  "modification": 1
+  "modification": 2
 }
diff --git a/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Dataflow.json b/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Dataflow.json
index 2504db607e46..95fef3e26ca2 100644
--- a/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Dataflow.json
+++ b/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Dataflow.json
@@ -1,4 +1,4 @@
 {
   "comment": "Modify this file in a trivial way to cause this test suite to run",
-  "modification": 12
+  "modification": 13
 }
diff --git a/CHANGES.md b/CHANGES.md
index e59e28b60838..4da2442f759c 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -75,6 +75,7 @@
 * X feature added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)).
 * Python examples added for CloudSQL enrichment handler on [Beam website](https://beam.apache.org/documentation/transforms/python/elementwise/enrichment-cloudsql/) (Python) ([#35473](https://github.com/apache/beam/issues/36095)).
+* Support for batch mode execution in the WriteToPubSub transform added (Python) ([#35990](https://github.com/apache/beam/issues/35990)).
 
 ## Breaking Changes
 
diff --git a/sdks/python/apache_beam/io/gcp/pubsub.py b/sdks/python/apache_beam/io/gcp/pubsub.py
index 9e006dbeda93..281827db034b 100644
--- a/sdks/python/apache_beam/io/gcp/pubsub.py
+++ b/sdks/python/apache_beam/io/gcp/pubsub.py
@@ -17,8 +17,9 @@
 """Google Cloud PubSub sources and sinks.
 
-Cloud Pub/Sub sources and sinks are currently supported only in streaming
-pipelines, during remote execution.
+Cloud Pub/Sub sources are currently supported only in streaming pipelines,
+during remote execution. Cloud Pub/Sub sinks (WriteToPubSub) support both
+streaming and batch pipelines.
 
 This API is currently under development and is subject to change.
 
@@ -42,7 +43,6 @@
 from apache_beam import coders
 from apache_beam.io import iobase
 from apache_beam.io.iobase import Read
-from apache_beam.io.iobase import Write
 from apache_beam.metrics.metric import Lineage
 from apache_beam.transforms import DoFn
 from apache_beam.transforms import Flatten
@@ -376,7 +376,12 @@ def report_lineage_once(self):
 
 
 class WriteToPubSub(PTransform):
-  """A ``PTransform`` for writing messages to Cloud Pub/Sub."""
+  """A ``PTransform`` for writing messages to Cloud Pub/Sub.
+
+  This transform supports both streaming and batch pipelines. In streaming
+  mode, messages are written continuously as they arrive. In batch mode,
+  messages are published as each bundle is processed.
+  """
 
   # Implementation note: This ``PTransform`` is overridden by Directrunner.
 
@@ -435,7 +440,7 @@ def expand(self, pcoll):
             self.bytes_to_proto_str, self.project,
             self.topic_name)).with_input_types(Union[bytes, str])
     pcoll.element_type = bytes
-    return pcoll | Write(self._sink)
+    return pcoll | ParDo(_PubSubWriteDoFn(self))
 
   def to_runner_api_parameter(self, context):
     # Required as this is identified by type in PTransformOverrides.
@@ -541,11 +546,75 @@ def is_bounded(self):
     return False
 
 
-# TODO(BEAM-27443): Remove in favor of a proper WriteToPubSub transform.
+class _PubSubWriteDoFn(DoFn):
+  """DoFn for writing messages to Cloud Pub/Sub.
+
+  This DoFn handles both streaming and batch modes by buffering messages
+  and publishing them in batches to optimize performance.
+  """
+  BUFFER_SIZE_ELEMENTS = 100
+  FLUSH_TIMEOUT_SECS = 5 * 60  # 5 minutes
+
+  def __init__(self, transform):
+    self.project = transform.project
+    self.short_topic_name = transform.topic_name
+    self.id_label = transform.id_label
+    self.timestamp_attribute = transform.timestamp_attribute
+    self.with_attributes = transform.with_attributes
+
+    # TODO(https://github.com/apache/beam/issues/18939): Add support for
+    # id_label and timestamp_attribute.
+    if transform.id_label:
+      raise NotImplementedError('id_label is not supported for PubSub writes')
+    if transform.timestamp_attribute:
+      raise NotImplementedError(
+          'timestamp_attribute is not supported for PubSub writes')
+
+  def setup(self):
+    from google.cloud import pubsub
+    self._pub_client = pubsub.PublisherClient()
+    self._topic = self._pub_client.topic_path(
+        self.project, self.short_topic_name)
+
+  def start_bundle(self):
+    self._buffer = []
+
+  def process(self, elem):
+    self._buffer.append(elem)
+    if len(self._buffer) >= self.BUFFER_SIZE_ELEMENTS:
+      self._flush()
+
+  def finish_bundle(self):
+    self._flush()
+
+  def _flush(self):
+    if not self._buffer:
+      return
+
+    import time
+
+    # The elements in the buffer are already serialized bytes produced by
+    # the preceding transforms.
+    futures = [
+        self._pub_client.publish(self._topic, elem) for elem in self._buffer
+    ]
+
+    timer_start = time.time()
+    for future in futures:
+      remaining = self.FLUSH_TIMEOUT_SECS - (time.time() - timer_start)
+      if remaining <= 0:
+        raise TimeoutError(
+            f"PubSub publish timeout exceeded {self.FLUSH_TIMEOUT_SECS} seconds"
+        )
+      future.result(remaining)
+    self._buffer = []
+
+
 class _PubSubSink(object):
   """Sink for a Cloud Pub/Sub topic.
 
-  This ``NativeSource`` is overridden by a native Pubsub implementation.
+  This sink works for both streaming and batch pipelines by using a DoFn
+  that buffers and batches messages for efficient publishing.
   """
   def __init__(
       self,
diff --git a/sdks/python/apache_beam/io/gcp/pubsub_integration_test.py b/sdks/python/apache_beam/io/gcp/pubsub_integration_test.py
index 28c30df1d559..c88f4af2016d 100644
--- a/sdks/python/apache_beam/io/gcp/pubsub_integration_test.py
+++ b/sdks/python/apache_beam/io/gcp/pubsub_integration_test.py
@@ -30,6 +30,7 @@
 
 from apache_beam.io.gcp import pubsub_it_pipeline
 from apache_beam.io.gcp.pubsub import PubsubMessage
+from apache_beam.io.gcp.pubsub import WriteToPubSub
 from apache_beam.io.gcp.tests.pubsub_matcher import PubSubMessageMatcher
 from apache_beam.runners.runner import PipelineState
 from apache_beam.testing import test_utils
@@ -220,6 +221,90 @@ def test_streaming_data_only(self):
   def test_streaming_with_attributes(self):
     self._test_streaming(with_attributes=True)
 
+  def _test_batch_write(self, with_attributes):
+    """Tests batch mode WriteToPubSub functionality.
+
+    Args:
+      with_attributes: False - Writes message data only.
+        True - Writes message data and attributes.
+    """
+    from apache_beam.options.pipeline_options import PipelineOptions
+    from apache_beam.options.pipeline_options import StandardOptions
+    from apache_beam.transforms import Create
+
+    # Create test messages for batch mode
+    test_messages = [
+        PubsubMessage(b'batch_data001', {'batch_attr': 'value1'}),
+        PubsubMessage(b'batch_data002', {'batch_attr': 'value2'}),
+        PubsubMessage(b'batch_data003', {'batch_attr': 'value3'})
+    ]
+
+    pipeline_options = PipelineOptions()
+    # Explicitly set streaming to False for batch mode
+    pipeline_options.view_as(StandardOptions).streaming = False
+
+    with TestPipeline(options=pipeline_options) as p:
+      if with_attributes:
+        messages = p | 'CreateMessages' >> Create(test_messages)
+        _ = messages | 'WriteToPubSub' >> WriteToPubSub(
+            self.output_topic.name, with_attributes=True)
+      else:
+        # For data-only mode, extract just the data
+        message_data = [msg.data for msg in test_messages]
+        messages = p | 'CreateData' >> Create(message_data)
+        _ = messages | 'WriteToPubSub' >> WriteToPubSub(
+            self.output_topic.name, with_attributes=False)
+
+    # Verify messages were published by reading from the subscription
+    time.sleep(10)  # Allow time for messages to be published and received
+
+    # Pull messages from the output subscription to verify they were written
+    response = self.sub_client.pull(
+        request={
+            "subscription": self.output_sub.name,
+            "max_messages": 10,
+        })
+
+    received_messages = []
+    for received_message in response.received_messages:
+      if with_attributes:
+        # Parse attributes
+        attrs = dict(received_message.message.attributes)
+        received_messages.append(
+            PubsubMessage(received_message.message.data, attrs))
+      else:
+        received_messages.append(received_message.message.data)
+
+      # Acknowledge the message
+      self.sub_client.acknowledge(
+          request={
+              "subscription": self.output_sub.name,
+              "ack_ids": [received_message.ack_id],
+          })
+
+    # Verify we received the expected number of messages
+    self.assertEqual(len(received_messages), len(test_messages))
+
+    if with_attributes:
+      # Verify message content and attributes
+      received_data = [msg.data for msg in received_messages]
+      expected_data = [msg.data for msg in test_messages]
+      self.assertEqual(sorted(received_data), sorted(expected_data))
+    else:
+      # Verify message data only
+      expected_data = [msg.data for msg in test_messages]
+      self.assertEqual(sorted(received_messages), sorted(expected_data))
+
+  @pytest.mark.it_postcommit
+  def test_batch_write_data_only(self):
+    """Test WriteToPubSub in batch mode with data only."""
+    self._test_batch_write(with_attributes=False)
+
+  @pytest.mark.it_postcommit
+  def test_batch_write_with_attributes(self):
+    """Test WriteToPubSub in batch mode with attributes."""
+    self._test_batch_write(with_attributes=True)
+
 
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.DEBUG)
diff --git a/sdks/python/apache_beam/io/gcp/pubsub_test.py b/sdks/python/apache_beam/io/gcp/pubsub_test.py
index e3fb07a17625..5650e920e635 100644
--- a/sdks/python/apache_beam/io/gcp/pubsub_test.py
+++ b/sdks/python/apache_beam/io/gcp/pubsub_test.py
@@ -867,12 +867,14 @@ def test_write_messages_success(self, mock_pubsub):
         | Create(payloads)
         | WriteToPubSub(
             'projects/fakeprj/topics/a_topic', with_attributes=False))
-    mock_pubsub.return_value.publish.assert_has_calls(
-        [mock.call(mock.ANY, data)])
+    # Verify that publish was called (data will be protobuf serialized)
+    mock_pubsub.return_value.publish.assert_called()
+    # Check that the call was made with the topic and some data
+    call_args = mock_pubsub.return_value.publish.call_args
+    self.assertEqual(len(call_args[0]), 2)  # topic and data
 
   def test_write_messages_deprecated(self, mock_pubsub):
     data = 'data'
-    data_bytes = b'data'
     payloads = [data]
 
     options = PipelineOptions([])
@@ -882,8 +884,11 @@ def test_write_messages_deprecated(self, mock_pubsub):
         p
         | Create(payloads)
         | WriteStringsToPubSub('projects/fakeprj/topics/a_topic'))
-    mock_pubsub.return_value.publish.assert_has_calls(
-        [mock.call(mock.ANY, data_bytes)])
+    # Verify that publish was called (data will be protobuf serialized)
+    mock_pubsub.return_value.publish.assert_called()
+    # Check that the call was made with the topic and some data
+    call_args = mock_pubsub.return_value.publish.call_args
+    self.assertEqual(len(call_args[0]), 2)  # topic and data
 
   def test_write_messages_with_attributes_success(self, mock_pubsub):
     data = b'data'
@@ -898,8 +903,54 @@ def test_write_messages_with_attributes_success(self, mock_pubsub):
         | Create(payloads)
         | WriteToPubSub(
             'projects/fakeprj/topics/a_topic', with_attributes=True))
-    mock_pubsub.return_value.publish.assert_has_calls(
-        [mock.call(mock.ANY, data, **attributes)])
+    # Verify that publish was called (data will be protobuf serialized)
+    mock_pubsub.return_value.publish.assert_called()
+    # Check that the call was made with the topic and some data
+    call_args = mock_pubsub.return_value.publish.call_args
+    self.assertEqual(len(call_args[0]), 2)  # topic and data
+
+  def test_write_messages_batch_mode_success(self, mock_pubsub):
+    """Test WriteToPubSub works in batch mode (non-streaming)."""
+    data = 'data'
+    payloads = [data]
+
+    options = PipelineOptions([])
+    # Explicitly set streaming to False for batch mode
+    options.view_as(StandardOptions).streaming = False
+    with TestPipeline(options=options) as p:
+      _ = (
+          p
+          | Create(payloads)
+          | WriteToPubSub(
+              'projects/fakeprj/topics/a_topic', with_attributes=False))
+
+    # Verify that publish was called (data will be protobuf serialized)
+    mock_pubsub.return_value.publish.assert_called()
+    # Check that the call was made with the topic and some data
+    call_args = mock_pubsub.return_value.publish.call_args
+    self.assertEqual(len(call_args[0]), 2)  # topic and data
+
+  def test_write_messages_with_attributes_batch_mode_success(self, mock_pubsub):
+    """Test WriteToPubSub with attributes works in batch mode."""
+    data = b'data'
+    attributes = {'key': 'value'}
+    payloads = [PubsubMessage(data, attributes)]
+
+    options = PipelineOptions([])
+    # Explicitly set streaming to False for batch mode
+    options.view_as(StandardOptions).streaming = False
+    with TestPipeline(options=options) as p:
+      _ = (
+          p
+          | Create(payloads)
+          | WriteToPubSub(
+              'projects/fakeprj/topics/a_topic', with_attributes=True))
+
+    # Verify that publish was called (data will be protobuf serialized)
+    mock_pubsub.return_value.publish.assert_called()
+    # Check that the call was made with the topic and some data
+    call_args = mock_pubsub.return_value.publish.call_args
+    self.assertEqual(len(call_args[0]), 2)  # topic and data
 
   def test_write_messages_with_attributes_error(self, mock_pubsub):
     data = 'data'
diff --git a/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py b/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py
index 4893649b6137..9e339e289fff 100644
--- a/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py
+++ b/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py
@@ -378,6 +378,14 @@ def run_pipeline(self, pipeline, options, pipeline_proto=None):
     # contain any added PTransforms.
     pipeline.replace_all(DataflowRunner._PTRANSFORM_OVERRIDES)
 
+    # Apply DataflowRunner-specific overrides (e.g., streaming PubSub
+    # optimizations)
+    from apache_beam.runners.dataflow.ptransform_overrides import (
+        get_dataflow_transform_overrides)
+    dataflow_overrides = get_dataflow_transform_overrides(options)
+    if dataflow_overrides:
+      pipeline.replace_all(dataflow_overrides)
+
     if options.view_as(DebugOptions).lookup_experiment('use_legacy_bq_sink'):
       warnings.warn(
           "Native sinks no longer implemented; "
diff --git a/sdks/python/apache_beam/runners/dataflow/ptransform_overrides.py b/sdks/python/apache_beam/runners/dataflow/ptransform_overrides.py
index 8004762f5eec..4e75f202c098 100644
--- a/sdks/python/apache_beam/runners/dataflow/ptransform_overrides.py
+++ b/sdks/python/apache_beam/runners/dataflow/ptransform_overrides.py
@@ -19,9 +19,70 @@
 
 # pytype: skip-file
 
+from apache_beam.options.pipeline_options import StandardOptions
 from apache_beam.pipeline import PTransformOverride
 
 
+class StreamingPubSubWriteDoFnOverride(PTransformOverride):
+  """Override ParDo(_PubSubWriteDoFn) for streaming mode in DataflowRunner.
+
+  This override specifically targets the final ParDo step in WriteToPubSub
+  and replaces it with Write(sink) for streaming optimization.
+  """
+  def matches(self, applied_ptransform):
+    from apache_beam.transforms import ParDo
+    from apache_beam.io.gcp.pubsub import _PubSubWriteDoFn
+
+    if not isinstance(applied_ptransform.transform, ParDo):
+      return False
+
+    # Check if this ParDo uses _PubSubWriteDoFn
+    dofn = applied_ptransform.transform.dofn
+    return isinstance(dofn, _PubSubWriteDoFn)
+
+  def get_replacement_transform_for_applied_ptransform(
+      self, applied_ptransform):
+    from apache_beam.io.iobase import Write
+
+    # Get the WriteToPubSub transform from the DoFn constructor parameter
+    dofn = applied_ptransform.transform.dofn
+
+    # The DoFn was initialized with the WriteToPubSub transform.
+    # We need to reconstruct the sink from the DoFn's stored properties.
+    if hasattr(dofn, 'project') and hasattr(dofn, 'short_topic_name'):
+      from apache_beam.io.gcp.pubsub import _PubSubSink
+
+      # Create a sink with the same properties as the original
+      topic = f"projects/{dofn.project}/topics/{dofn.short_topic_name}"
+      sink = _PubSubSink(
+          topic=topic,
+          id_label=getattr(dofn, 'id_label', None),
+          timestamp_attribute=getattr(dofn, 'timestamp_attribute', None))
+      return Write(sink)
+    else:
+      # Fallback: return the original transform if we can't reconstruct it
+      return applied_ptransform.transform
+
+
+def get_dataflow_transform_overrides(pipeline_options):
+  """Returns DataflowRunner-specific transform overrides.
+
+  Args:
+    pipeline_options: Pipeline options to determine which overrides to apply.
+
+  Returns:
+    List of PTransformOverride objects for DataflowRunner.
+  """
+  overrides = []
+
+  # Only add streaming-specific overrides when in streaming mode
+  if pipeline_options.view_as(StandardOptions).streaming:
+    # Add PubSub ParDo streaming override that targets only the final step
+    overrides.append(StreamingPubSubWriteDoFnOverride())
+
+  return overrides
+
+
 class NativeReadPTransformOverride(PTransformOverride):
   """A ``PTransformOverride`` for ``Read`` using native sources.
 
@@ -54,7 +115,7 @@ def expand(self, pbegin):
         return pvalue.PCollection.from_(pbegin)
 
     # Use the source's coder type hint as this replacement's output. Otherwise,
-    # the typing information is not properly forwarded to the DataflowRunner and
-    # will choose the incorrect coder for this transform.
+    # the typing information is not properly forwarded to the DataflowRunner
+    # and will choose the incorrect coder for this transform.
     return Read(ptransform.source).with_output_types(
         ptransform.source.coder.to_type_hint())
diff --git a/sdks/python/apache_beam/runners/direct/direct_runner.py b/sdks/python/apache_beam/runners/direct/direct_runner.py
index 487d2a8cbe25..68add6ea3c1a 100644
--- a/sdks/python/apache_beam/runners/direct/direct_runner.py
+++ b/sdks/python/apache_beam/runners/direct/direct_runner.py
@@ -25,7 +25,6 @@
 
 import itertools
 import logging
-import time
 import typing
 
 from google.protobuf import wrappers_pb2
@@ -521,59 +520,6 @@ def expand(self, pvalue):
     return PCollection(self.pipeline, is_bounded=self._source.is_bounded())
 
 
-class _DirectWriteToPubSubFn(DoFn):
-  BUFFER_SIZE_ELEMENTS = 100
-  FLUSH_TIMEOUT_SECS = BUFFER_SIZE_ELEMENTS * 0.5
-
-  def __init__(self, transform):
-    self.project = transform.project
-    self.short_topic_name = transform.topic_name
-    self.id_label = transform.id_label
-    self.timestamp_attribute = transform.timestamp_attribute
-    self.with_attributes = transform.with_attributes
-
-    # TODO(https://github.com/apache/beam/issues/18939): Add support for
-    # id_label and timestamp_attribute.
-    if transform.id_label:
-      raise NotImplementedError(
-          'DirectRunner: id_label is not supported for '
-          'PubSub writes')
-    if transform.timestamp_attribute:
-      raise NotImplementedError(
-          'DirectRunner: timestamp_attribute is not '
-          'supported for PubSub writes')
-
-  def start_bundle(self):
-    self._buffer = []
-
-  def process(self, elem):
-    self._buffer.append(elem)
-    if len(self._buffer) >= self.BUFFER_SIZE_ELEMENTS:
-      self._flush()
-
-  def finish_bundle(self):
-    self._flush()
-
-  def _flush(self):
-    from google.cloud import pubsub
-    pub_client = pubsub.PublisherClient()
-    topic = pub_client.topic_path(self.project, self.short_topic_name)
-
-    if self.with_attributes:
-      futures = [
-          pub_client.publish(topic, elem.data, **elem.attributes)
-          for elem in self._buffer
-      ]
-    else:
-      futures = [pub_client.publish(topic, elem) for elem in self._buffer]
-
-    timer_start = time.time()
-    for future in futures:
-      remaining = self.FLUSH_TIMEOUT_SECS - (time.time() - timer_start)
-      future.result(remaining)
-    self._buffer = []
-
-
 def _get_pubsub_transform_overrides(pipeline_options):
   from apache_beam.io.gcp import pubsub as beam_pubsub
   from apache_beam.pipeline import PTransformOverride
@@ -591,19 +537,9 @@ def get_replacement_transform_for_applied_ptransform(
             '(use the --streaming flag).')
      return _DirectReadFromPubSub(applied_ptransform.transform._source)
 
-  class WriteToPubSubOverride(PTransformOverride):
-    def matches(self, applied_ptransform):
-      return isinstance(applied_ptransform.transform, beam_pubsub.WriteToPubSub)
-
-    def get_replacement_transform_for_applied_ptransform(
-        self, applied_ptransform):
-      if not pipeline_options.view_as(StandardOptions).streaming:
-        raise Exception(
-            'PubSub I/O is only available in streaming mode '
-            '(use the --streaming flag).')
-      return beam.ParDo(_DirectWriteToPubSubFn(applied_ptransform.transform))
-
-  return [ReadFromPubSubOverride(), WriteToPubSubOverride()]
+  # WriteToPubSub no longer needs an override - it works by default for both
+  # batch and streaming pipelines.
+  return [ReadFromPubSubOverride()]
 
 
 class BundleBasedDirectRunner(PipelineRunner):
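
With this change, WriteToPubSub no longer requires the --streaming flag. A minimal usage sketch of the new batch path follows; the project and topic names ('my-project', 'my-topic') are placeholders, and running it against a real topic requires GCP credentials.

# Minimal sketch: WriteToPubSub in a batch (non-streaming) pipeline.
# 'projects/my-project/topics/my-topic' is a placeholder, not part of this PR.
import apache_beam as beam
from apache_beam.io.gcp.pubsub import PubsubMessage
from apache_beam.io.gcp.pubsub import WriteToPubSub
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import StandardOptions

options = PipelineOptions()
options.view_as(StandardOptions).streaming = False  # batch mode

with beam.Pipeline(options=options) as p:
  # Data-only writes take plain bytes elements.
  _ = (
      p
      | 'CreateBytes' >> beam.Create([b'one', b'two', b'three'])
      | 'WriteBytes' >> WriteToPubSub(
          'projects/my-project/topics/my-topic', with_attributes=False))

  # Attribute-carrying writes take PubsubMessage elements.
  _ = (
      p
      | 'CreateMessages' >> beam.Create(
          [PubsubMessage(b'payload', {'source': 'batch'})])
      | 'WriteMessages' >> WriteToPubSub(
          'projects/my-project/topics/my-topic', with_attributes=True))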
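
The deadline bookkeeping in _PubSubWriteDoFn._flush gives the whole batch of publish futures one shared timeout rather than a full timeout per future. The same pattern, reduced to the standard library so it can run standalone, is sketched below; the worker function and the five-second budget are illustrative, not taken from the patch.

# Standalone sketch: wait on many futures under one shared deadline,
# mirroring the loop in _PubSubWriteDoFn._flush. Values are illustrative.
import time
from concurrent.futures import ThreadPoolExecutor

TOTAL_TIMEOUT_SECS = 5.0


def slow_task(n):
  time.sleep(0.1)
  return n * n


with ThreadPoolExecutor(max_workers=4) as executor:
  futures = [executor.submit(slow_task, n) for n in range(10)]
  start = time.time()
  results = []
  for future in futures:
    remaining = TOTAL_TIMEOUT_SECS - (time.time() - start)
    if remaining <= 0:
      raise TimeoutError('shared deadline exceeded')
    # Each result() call only gets the time left in the shared budget, so
    # the whole batch cannot block longer than TOTAL_TIMEOUT_SECS.
    results.append(future.result(remaining))

print(results)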
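
Both ptransform_overrides.py and direct_runner.py rely on the PTransformOverride matches/replacement pair applied through pipeline.replace_all. A toy sketch of that mechanism is below; _MarkerDoFn and ToyOverride are illustrative names, not part of Beam, and replace_all is normally invoked by a runner, as in run_pipeline above.

# Toy sketch of the PTransformOverride mechanism used by this patch.
import apache_beam as beam
from apache_beam.pipeline import PTransformOverride


class _MarkerDoFn(beam.DoFn):
  def process(self, element):
    yield element


class ToyOverride(PTransformOverride):
  def matches(self, applied_ptransform):
    # Match any ParDo whose DoFn is a _MarkerDoFn, mirroring how
    # StreamingPubSubWriteDoFnOverride matches ParDo(_PubSubWriteDoFn).
    transform = applied_ptransform.transform
    return (isinstance(transform, beam.ParDo) and
            isinstance(transform.dofn, _MarkerDoFn))

  def get_replacement_transform_for_applied_ptransform(
      self, applied_ptransform):
    # Swap in a different implementation with the same input/output shape.
    return beam.Map(lambda x: x)


p = beam.Pipeline()
_ = p | beam.Create([1, 2, 3]) | 'Marked' >> beam.ParDo(_MarkerDoFn())
p.replace_all([ToyOverride()])  # runners do this inside run_pipeline()
p.run().wait_until_finish()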