diff --git a/.github/trigger_files/beam_PostCommit_Python.json b/.github/trigger_files/beam_PostCommit_Python.json
index 8675e9535061..1fa29a890c2f 100644
--- a/.github/trigger_files/beam_PostCommit_Python.json
+++ b/.github/trigger_files/beam_PostCommit_Python.json
@@ -1,5 +1,5 @@
 {
   "comment": "Modify this file in a trivial way to cause this test suite to run.",
-  "modification": 28
+  "modification": 29
 }
diff --git a/sdks/python/apache_beam/io/gcp/pubsub.py b/sdks/python/apache_beam/io/gcp/pubsub.py
index 281827db034b..59eadee5538e 100644
--- a/sdks/python/apache_beam/io/gcp/pubsub.py
+++ b/sdks/python/apache_beam/io/gcp/pubsub.py
@@ -414,6 +414,7 @@ def __init__(
     self.project, self.topic_name = parse_topic(topic)
     self.full_topic = topic
     self._sink = _PubSubSink(topic, id_label, timestamp_attribute)
+    self.pipeline_options = None  # Will be set during expand()
 
   @staticmethod
   def message_to_proto_str(element: PubsubMessage) -> bytes:
@@ -429,6 +430,9 @@ def bytes_to_proto_str(element: Union[bytes, str]) -> bytes:
     return msg._to_proto_str(for_publish=True)
 
   def expand(self, pcoll):
+    # Store pipeline options for use in DoFn
+    self.pipeline_options = pcoll.pipeline.options if pcoll.pipeline else None
+
     if self.with_attributes:
       pcoll = pcoll | 'ToProtobufX' >> ParDo(
           _AddMetricsAndMap(
@@ -564,11 +568,65 @@ def __init__(self, transform):
 
     # TODO(https://github.com/apache/beam/issues/18939): Add support for
     # id_label and timestamp_attribute.
-    if transform.id_label:
-      raise NotImplementedError('id_label is not supported for PubSub writes')
-    if transform.timestamp_attribute:
-      raise NotImplementedError(
-          'timestamp_attribute is not supported for PubSub writes')
+    # Only raise errors for DirectRunner or batch pipelines
+    pipeline_options = transform.pipeline_options
+    output_labels_supported = True
+
+    if pipeline_options:
+      from apache_beam.options.pipeline_options import StandardOptions
+
+      # Check if using DirectRunner
+      try:
+        # Get runner from pipeline options
+        all_options = pipeline_options.get_all_options()
+        runner_name = all_options.get('runner', StandardOptions.DEFAULT_RUNNER)
+
+        # Check if it's a DirectRunner variant
+        if (runner_name is None or
+            (runner_name in StandardOptions.LOCAL_RUNNERS or 'DirectRunner'
+             in str(runner_name) or 'TestDirectRunner' in str(runner_name))):
+          output_labels_supported = False
+      except Exception:
+        # If we can't determine runner, assume DirectRunner for safety
+        output_labels_supported = False
+
+      # Check if in batch mode (not streaming)
+      standard_options = pipeline_options.view_as(StandardOptions)
+      if not standard_options.streaming:
+        output_labels_supported = False
+    else:
+      # If no pipeline options available, fall back to original behavior
+      output_labels_supported = False
+
+    # Log debug information for troubleshooting
+    import logging
+    runner_info = getattr(
+        pipeline_options, 'runner',
+        'None') if pipeline_options else 'No options'
+    streaming_info = 'Unknown'
+    if pipeline_options:
+      try:
+        standard_options = pipeline_options.view_as(StandardOptions)
+        streaming_info = 'streaming=%s' % standard_options.streaming
+      except Exception:
+        streaming_info = 'streaming=unknown'
+
+    logging.debug(
+        'PubSub unsupported feature check: runner=%s, %s',
+        runner_info,
+        streaming_info)
+
+    if not output_labels_supported:
+
+      if transform.id_label:
+        raise NotImplementedError(
+            f'id_label is not supported for PubSub writes with DirectRunner '
+            f'or in batch mode (runner={runner_info}, {streaming_info})')
+      if transform.timestamp_attribute:
+        raise NotImplementedError(
+            f'timestamp_attribute is not supported for PubSub writes with '
+            f'DirectRunner or in batch mode '
+            f'(runner={runner_info}, {streaming_info})')
 
   def setup(self):
     from google.cloud import pubsub
@@ -593,11 +651,21 @@ def _flush(self):
 
     import time
 
-    # The elements in buffer are already serialized bytes from the previous
-    # transforms
-    futures = [
-        self._pub_client.publish(self._topic, elem) for elem in self._buffer
-    ]
+    # The elements in buffer are serialized protobuf bytes from the previous
+    # transforms. We need to deserialize them to extract data and attributes.
+    futures = []
+    for elem in self._buffer:
+      # Deserialize the protobuf to get the original PubsubMessage
+      pubsub_msg = PubsubMessage._from_proto_str(elem)
+
+      # Publish with the correct data and attributes
+      if self.with_attributes and pubsub_msg.attributes:
+        future = self._pub_client.publish(
+            self._topic, pubsub_msg.data, **pubsub_msg.attributes)
+      else:
+        future = self._pub_client.publish(self._topic, pubsub_msg.data)
+
+      futures.append(future)
 
     timer_start = time.time()
     for future in futures:
diff --git a/sdks/python/apache_beam/io/gcp/pubsub_integration_test.py b/sdks/python/apache_beam/io/gcp/pubsub_integration_test.py
index c88f4af2016d..8387fe734fc1 100644
--- a/sdks/python/apache_beam/io/gcp/pubsub_integration_test.py
+++ b/sdks/python/apache_beam/io/gcp/pubsub_integration_test.py
@@ -44,10 +44,10 @@
 # How long TestXXXRunner will wait for pubsub_it_pipeline to run before
 # cancelling it.
-TEST_PIPELINE_DURATION_MS = 8 * 60 * 1000
+TEST_PIPELINE_DURATION_MS = 10 * 60 * 1000
 
 # How long PubSubMessageMatcher will wait for the correct set of messages to
 # appear.
-MESSAGE_MATCHER_TIMEOUT_S = 5 * 60
+MESSAGE_MATCHER_TIMEOUT_S = 10 * 60
 
 
 class PubSubIntegrationTest(unittest.TestCase):
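Note for reviewers: a minimal sketch (not part of the patch) of the runner/streaming detection the new check performs, using only the public PipelineOptions APIs the diff already relies on (get_all_options and view_as); the flag values below are illustrative:

    from apache_beam.options.pipeline_options import PipelineOptions, StandardOptions

    # The check in _PubSubWriteDoFn consults these two values to decide
    # whether id_label/timestamp_attribute can be honored.
    opts = PipelineOptions(['--runner=DataflowRunner', '--streaming'])
    print(opts.get_all_options().get('runner'))     # 'DataflowRunner'
    print(opts.view_as(StandardOptions).streaming)  # True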
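And a sketch of the serialize/deserialize round-trip the rewritten _flush depends on, assuming Beam's internal PubsubMessage helpers behave as they are used in the diff (_to_proto_str and _from_proto_str are private APIs); the payload and attributes are made up:

    from apache_beam.io.gcp.pubsub import PubsubMessage

    msg = PubsubMessage(b'payload', {'k': 'v'})
    proto_str = msg._to_proto_str(for_publish=True)      # what ToProtobufX emits
    restored = PubsubMessage._from_proto_str(proto_str)  # what _flush recovers
    assert restored.data == b'payload'
    assert restored.attributes == {'k': 'v'}
    # _flush then publishes: client.publish(topic, restored.data, **restored.attributes)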