diff --git a/.github/trigger_files/beam_PostCommit_Python_ValidatesRunner_Dataflow.json b/.github/trigger_files/beam_PostCommit_Python_ValidatesRunner_Dataflow.json index 19ebbfb9ad92..613d20725f9b 100644 --- a/.github/trigger_files/beam_PostCommit_Python_ValidatesRunner_Dataflow.json +++ b/.github/trigger_files/beam_PostCommit_Python_ValidatesRunner_Dataflow.json @@ -1,3 +1,4 @@ { - "https://github.com/apache/beam/pull/35951": "triggering sideinput test" -} + "comment": "Trigger file for PostCommit Python ValidatesRunner Dataflow tests", + "modification": 1 +} \ No newline at end of file diff --git a/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Dataflow.json b/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Dataflow.json index 2504db607e46..95fef3e26ca2 100644 --- a/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Dataflow.json +++ b/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Dataflow.json @@ -1,4 +1,4 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run", - "modification": 12 + "modification": 13 } diff --git a/sdks/python/apache_beam/io/gcp/pubsub.py b/sdks/python/apache_beam/io/gcp/pubsub.py index 9e006dbeda93..a29cecbea7bd 100644 --- a/sdks/python/apache_beam/io/gcp/pubsub.py +++ b/sdks/python/apache_beam/io/gcp/pubsub.py @@ -32,6 +32,7 @@ # pytype: skip-file import re +import time from typing import Any from typing import List from typing import NamedTuple @@ -482,6 +483,69 @@ def parse_subscription(full_subscription): return project, subscription_name +class _DirectWriteToPubSubFn(DoFn): + """DirectRunner implementation for WriteToPubSub. + + This DoFn handles writing messages to PubSub in the batch + mode. It buffers messages and flushes them in batches for efficiency. + """ + BUFFER_SIZE_ELEMENTS = 100 + FLUSH_TIMEOUT_SECS = BUFFER_SIZE_ELEMENTS * 0.5 + + def __init__(self, transform): + self.project = transform.project + self.short_topic_name = transform.topic_name + self.id_label = transform.id_label + self.timestamp_attribute = transform.timestamp_attribute + self.with_attributes = transform.with_attributes + + # TODO(https://github.com/apache/beam/issues/18939): Add support for + # id_label and timestamp_attribute. + if transform.id_label: + raise NotImplementedError( + 'DirectRunner: id_label is not supported for ' + 'PubSub writes') + if transform.timestamp_attribute: + raise NotImplementedError( + 'DirectRunner: timestamp_attribute is not ' + 'supported for PubSub writes') + + def start_bundle(self): + self._buffer = [] + + def process(self, elem): + self._buffer.append(elem) + if len(self._buffer) >= self.BUFFER_SIZE_ELEMENTS: + self._flush() + + def finish_bundle(self): + self._flush() + + def _flush(self): + if not self._buffer: + return + + if pubsub is None: + raise ImportError('Google Cloud PubSub is not available') + + pub_client = pubsub.PublisherClient() + topic = pub_client.topic_path(self.project, self.short_topic_name) + + if self.with_attributes: + futures = [ + pub_client.publish(topic, elem.data, **elem.attributes) + for elem in self._buffer + ] + else: + futures = [pub_client.publish(topic, elem) for elem in self._buffer] + + timer_start = time.time() + for future in futures: + remaining = self.FLUSH_TIMEOUT_SECS - (time.time() - timer_start) + future.result(remaining) + self._buffer = [] + + # TODO(BEAM-27443): Remove (or repurpose as a proper PTransform). class _PubSubSource(iobase.SourceBase): """Source for a Cloud Pub/Sub topic or subscription. 
diff --git a/sdks/python/apache_beam/io/gcp/pubsub_batch_integration_test.py b/sdks/python/apache_beam/io/gcp/pubsub_batch_integration_test.py new file mode 100644 index 000000000000..00d2d7e2b3b0 --- /dev/null +++ b/sdks/python/apache_beam/io/gcp/pubsub_batch_integration_test.py @@ -0,0 +1,156 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Integration test for Google Cloud Pub/Sub WriteToPubSub in batch mode. +""" +# pytype: skip-file + +import logging +import time +import unittest +import uuid + +import pytest + +import apache_beam as beam +from apache_beam.io.gcp.pubsub import PubsubMessage +from apache_beam.io.gcp.pubsub import WriteToPubSub +from apache_beam.io.gcp.tests.pubsub_matcher import PubSubMessageMatcher +from apache_beam.options.pipeline_options import PipelineOptions +from apache_beam.options.pipeline_options import StandardOptions +from apache_beam.testing import test_utils +from apache_beam.testing.test_pipeline import TestPipeline +from apache_beam.transforms.core import Create + +OUTPUT_TOPIC = 'psit_batch_topic_output' +OUTPUT_SUB = 'psit_batch_subscription_output' + +# How long TestDataflowRunner will wait for batch pipeline to complete +TEST_PIPELINE_DURATION_MS = 10 * 60 * 1000 +# How long PubSubMessageMatcher will wait for the correct set of messages +# to appear +MESSAGE_MATCHER_TIMEOUT_S = 5 * 60 + + +class PubSubBatchIntegrationTest(unittest.TestCase): + """Integration test for WriteToPubSub in batch mode with DataflowRunner.""" + + # Test data for batch processing + INPUT_MESSAGES = [ + b'batch_data001', + b'batch_data002', + b'batch_data003\xab\xac', + b'batch_data004\xab\xac' + ] + + EXPECTED_OUTPUT_MESSAGES = [ + PubsubMessage(b'batch_data001-processed', {'batch_job': 'true'}), + PubsubMessage(b'batch_data002-processed', {'batch_job': 'true'}), + PubsubMessage(b'batch_data003\xab\xac-processed', {'batch_job': 'true'}), + PubsubMessage(b'batch_data004\xab\xac-processed', {'batch_job': 'true'}) + ] + + def setUp(self): + self.test_pipeline = TestPipeline(is_integration_test=True) + self.runner_name = type(self.test_pipeline.runner).__name__ + self.project = self.test_pipeline.get_option('project') + self.uuid = str(uuid.uuid4()) + + # Set up PubSub environment. + from google.cloud import pubsub + self.pub_client = pubsub.PublisherClient() + self.output_topic = self.pub_client.create_topic( + name=self.pub_client.topic_path(self.project, OUTPUT_TOPIC + self.uuid)) + + self.sub_client = pubsub.SubscriberClient() + self.output_sub = self.sub_client.create_subscription( + name=self.sub_client.subscription_path( + self.project, OUTPUT_SUB + self.uuid), + topic=self.output_topic.name) + # Add a 30 second sleep after resource creation to ensure subscriptions + # will receive messages. 
+ time.sleep(30) + + def tearDown(self): + test_utils.cleanup_subscriptions(self.sub_client, [self.output_sub]) + test_utils.cleanup_topics(self.pub_client, [self.output_topic]) + + def _test_batch_write(self, with_attributes): + """Runs batch IT pipeline with WriteToPubSub. + + Args: + with_attributes: False - Writes message data only. + True - Writes message data and attributes. + """ + # Set up pipeline options for batch mode + pipeline_options = PipelineOptions( + self.test_pipeline.get_full_options_as_args()) + pipeline_options.view_as(StandardOptions).streaming = False # Batch mode + + expected_messages = self.EXPECTED_OUTPUT_MESSAGES + if not with_attributes: + expected_messages = [pubsub_msg.data for pubsub_msg in expected_messages] + + pubsub_msg_verifier = PubSubMessageMatcher( + self.project, + self.output_sub.name, + expected_messages, + timeout=MESSAGE_MATCHER_TIMEOUT_S, + with_attributes=with_attributes) + + with beam.Pipeline(options=pipeline_options) as p: + # Create input data + input_data = p | 'CreateInput' >> Create(self.INPUT_MESSAGES) + + # Process data + if with_attributes: + + def add_batch_attributes(data): + return PubsubMessage(data + b'-processed', {'batch_job': 'true'}) + + processed_data = ( + input_data | 'AddAttributes' >> beam.Map(add_batch_attributes)) + else: + processed_data = ( + input_data | 'ProcessData' >> beam.Map(lambda x: x + b'-processed')) + + # Write to PubSub using WriteToPubSub in batch mode + _ = processed_data | 'WriteToPubSub' >> WriteToPubSub( + self.output_topic.name, with_attributes=with_attributes) + + # Verify the results + pubsub_msg_verifier.verify() + + @pytest.mark.it_postcommit + def test_batch_write_data_only(self): + """Test WriteToPubSub in batch mode with data only.""" + if self.runner_name != 'TestDataflowRunner': + self.skipTest('This test is specifically for DataflowRunner batch mode') + self._test_batch_write(with_attributes=False) + + @pytest.mark.it_postcommit + def test_batch_write_with_attributes(self): + """Test WriteToPubSub in batch mode with attributes.""" + if self.runner_name != 'TestDataflowRunner': + self.skipTest('This test is specifically for DataflowRunner batch mode') + self._test_batch_write(with_attributes=True) + + +if __name__ == '__main__': + logging.getLogger().setLevel(logging.DEBUG) + unittest.main() diff --git a/sdks/python/apache_beam/io/gcp/pubsub_test.py b/sdks/python/apache_beam/io/gcp/pubsub_test.py index e3fb07a17625..582245bdcb52 100644 --- a/sdks/python/apache_beam/io/gcp/pubsub_test.py +++ b/sdks/python/apache_beam/io/gcp/pubsub_test.py @@ -1047,6 +1047,51 @@ def test_write_to_pubsub_with_attributes_no_overwrite(self, unused_mock): Lineage.query(p.result.metrics(), Lineage.SINK), set(["pubsub:topic:fakeprj.a_topic"])) + def test_write_messages_batch_mode_success(self, mock_pubsub): + """Test that WriteToPubSub works in batch mode with DirectRunner.""" + data = 'data' + payloads = [data] + + options = PipelineOptions([]) + # Explicitly set streaming=False to test batch mode + options.view_as(StandardOptions).streaming = False + with TestPipeline(options=options) as p: + _ = ( + p + | Create(payloads) + | WriteToPubSub( + 'projects/fakeprj/topics/a_topic', with_attributes=False)) + + # Apply the necessary PTransformOverrides for DirectRunner + overrides = _get_transform_overrides(options) + p.replace_all(overrides) + + mock_pubsub.return_value.publish.assert_has_calls( + [mock.call(mock.ANY, data)]) + + def test_write_messages_with_attributes_batch_mode_success(self, mock_pubsub): + """Test 
that WriteToPubSub with attributes works in batch mode.""" + data = b'data' + attributes = {'key': 'value'} + payloads = [PubsubMessage(data, attributes)] + + options = PipelineOptions([]) + # Explicitly set streaming=False to test batch mode + options.view_as(StandardOptions).streaming = False + with TestPipeline(options=options) as p: + _ = ( + p + | Create(payloads) + | WriteToPubSub( + 'projects/fakeprj/topics/a_topic', with_attributes=True)) + + # Apply the necessary PTransformOverrides for DirectRunner + overrides = _get_transform_overrides(options) + p.replace_all(overrides) + + mock_pubsub.return_value.publish.assert_has_calls( + [mock.call(mock.ANY, data, **attributes)]) + if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO) diff --git a/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py b/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py index 4893649b6137..b9963654ed8c 100644 --- a/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py +++ b/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py @@ -376,7 +376,14 @@ def run_pipeline(self, pipeline, options, pipeline_proto=None): # Performing configured PTransform overrides. Note that this is currently # done before Runner API serialization, since the new proto needs to # contain any added PTransforms. - pipeline.replace_all(DataflowRunner._PTRANSFORM_OVERRIDES) + overrides = list(DataflowRunner._PTRANSFORM_OVERRIDES) + + # Add WriteToPubSub batch mode override if not in streaming mode + if not options.view_as(StandardOptions).streaming: + from apache_beam.runners.dataflow.ptransform_overrides import WriteToPubSubBatchOverride + overrides.append(WriteToPubSubBatchOverride(options)) + + pipeline.replace_all(overrides) if options.view_as(DebugOptions).lookup_experiment('use_legacy_bq_sink'): warnings.warn( diff --git a/sdks/python/apache_beam/runners/dataflow/ptransform_overrides.py b/sdks/python/apache_beam/runners/dataflow/ptransform_overrides.py index 8004762f5eec..82e774b5801b 100644 --- a/sdks/python/apache_beam/runners/dataflow/ptransform_overrides.py +++ b/sdks/python/apache_beam/runners/dataflow/ptransform_overrides.py @@ -19,6 +19,7 @@ # pytype: skip-file +import apache_beam as beam from apache_beam.pipeline import PTransformOverride @@ -58,3 +59,31 @@ def expand(self, pbegin): # will choose the incorrect coder for this transform. return Read(ptransform.source).with_output_types( ptransform.source.coder.to_type_hint()) + + +class WriteToPubSubBatchOverride(PTransformOverride): + """A ``PTransformOverride`` for ``WriteToPubSub`` in batch mode on Dataflow. + + This override enables WriteToPubSub to work in batch mode on DataflowRunner + by using the DirectRunner implementation which supports both streaming and + batch modes. + """ + def __init__(self, pipeline_options): + self.pipeline_options = pipeline_options + + def matches(self, applied_ptransform): + # Imported here to avoid circular dependencies. + from apache_beam.io.gcp import pubsub as beam_pubsub + from apache_beam.options.pipeline_options import StandardOptions + + # Only override WriteToPubSub in batch mode (non-streaming) + return ( + isinstance(applied_ptransform.transform, beam_pubsub.WriteToPubSub) and + not self.pipeline_options.view_as(StandardOptions).streaming) + + def get_replacement_transform(self, ptransform): + # Imported here to avoid circular dependencies. 
+ from apache_beam.io.gcp import pubsub as beam_pubsub + + # Use the DirectRunner implementation which supports batch mode + return beam.ParDo(beam_pubsub._DirectWriteToPubSubFn(ptransform)) diff --git a/sdks/python/apache_beam/runners/direct/direct_runner.py b/sdks/python/apache_beam/runners/direct/direct_runner.py index 0af0ca8d3175..fbbeb57958e4 100644 --- a/sdks/python/apache_beam/runners/direct/direct_runner.py +++ b/sdks/python/apache_beam/runners/direct/direct_runner.py @@ -25,7 +25,6 @@ import itertools import logging -import time import typing from google.protobuf import wrappers_pb2 @@ -202,9 +201,17 @@ def visit_transform(self, applied_ptransform): # Use BundleBasedDirectRunner if other runners are missing needed features. runner = BundleBasedDirectRunner() + # Check if transform overrides are needed - if so, + # use BundleBasedDirectRunner + # since Prism does not support transform overrides + transform_overrides = _get_transform_overrides(options) + if transform_overrides: + _LOGGER.info( + 'Transform overrides detected, falling back to DirectRunner.') + runner = BundleBasedDirectRunner() # Check whether all transforms used in the pipeline are supported by the # PrismRunner - if _PrismRunnerSupportVisitor().accept(pipeline, self._is_interactive): + elif _PrismRunnerSupportVisitor().accept(pipeline, self._is_interactive): _LOGGER.info('Running pipeline with PrismRunner.') from apache_beam.runners.portability import prism_runner runner = prism_runner.PrismRunner() @@ -519,59 +526,6 @@ def expand(self, pvalue): return PCollection(self.pipeline, is_bounded=self._source.is_bounded()) -class _DirectWriteToPubSubFn(DoFn): - BUFFER_SIZE_ELEMENTS = 100 - FLUSH_TIMEOUT_SECS = BUFFER_SIZE_ELEMENTS * 0.5 - - def __init__(self, transform): - self.project = transform.project - self.short_topic_name = transform.topic_name - self.id_label = transform.id_label - self.timestamp_attribute = transform.timestamp_attribute - self.with_attributes = transform.with_attributes - - # TODO(https://github.com/apache/beam/issues/18939): Add support for - # id_label and timestamp_attribute. 
- if transform.id_label: - raise NotImplementedError( - 'DirectRunner: id_label is not supported for ' - 'PubSub writes') - if transform.timestamp_attribute: - raise NotImplementedError( - 'DirectRunner: timestamp_attribute is not ' - 'supported for PubSub writes') - - def start_bundle(self): - self._buffer = [] - - def process(self, elem): - self._buffer.append(elem) - if len(self._buffer) >= self.BUFFER_SIZE_ELEMENTS: - self._flush() - - def finish_bundle(self): - self._flush() - - def _flush(self): - from google.cloud import pubsub - pub_client = pubsub.PublisherClient() - topic = pub_client.topic_path(self.project, self.short_topic_name) - - if self.with_attributes: - futures = [ - pub_client.publish(topic, elem.data, **elem.attributes) - for elem in self._buffer - ] - else: - futures = [pub_client.publish(topic, elem) for elem in self._buffer] - - timer_start = time.time() - for future in futures: - remaining = self.FLUSH_TIMEOUT_SECS - (time.time() - timer_start) - future.result(remaining) - self._buffer = [] - - def _get_pubsub_transform_overrides(pipeline_options): from apache_beam.io.gcp import pubsub as beam_pubsub from apache_beam.pipeline import PTransformOverride @@ -595,11 +549,8 @@ def matches(self, applied_ptransform): def get_replacement_transform_for_applied_ptransform( self, applied_ptransform): - if not pipeline_options.view_as(StandardOptions).streaming: - raise Exception( - 'PubSub I/O is only available in streaming mode ' - '(use the --streaming flag).') - return beam.ParDo(_DirectWriteToPubSubFn(applied_ptransform.transform)) + return beam.ParDo( + beam_pubsub._DirectWriteToPubSubFn(applied_ptransform.transform)) return [ReadFromPubSubOverride(), WriteToPubSubOverride()]
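# --------------------------------------------------------------------------
# Illustrative sketch (not part of the patch above): how a batch pipeline is
# expected to use WriteToPubSub once _DirectWriteToPubSubFn lives in
# apache_beam/io/gcp/pubsub.py and WriteToPubSubBatchOverride is registered.
# It mirrors the new pubsub_batch_integration_test.py; the project and topic
# names are placeholders, and running it assumes the patch is applied and
# Google Cloud credentials are available.
import apache_beam as beam
from apache_beam.io.gcp.pubsub import PubsubMessage
from apache_beam.io.gcp.pubsub import WriteToPubSub
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import StandardOptions

options = PipelineOptions()
# No --streaming flag: before this change, the DirectRunner override raised
# 'PubSub I/O is only available in streaming mode' for non-streaming writes.
options.view_as(StandardOptions).streaming = False

with beam.Pipeline(options=options) as p:
  _ = (
      p
      | 'Create' >> beam.Create([b'record-1', b'record-2'])
      | 'ToMessage' >> beam.Map(
          lambda data: PubsubMessage(data, {'source': 'batch'}))
      # Placeholder topic path.
      | 'Write' >> WriteToPubSub(
          'projects/my-project/topics/my-topic', with_attributes=True))

# The replacement transform is beam.ParDo(_DirectWriteToPubSubFn(...)), which
# buffers up to BUFFER_SIZE_ELEMENTS (100) messages per bundle and publishes
# each flush with a shared FLUSH_TIMEOUT_SECS (50 s) budget across the
# publish futures.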