diff --git a/sdks/python/apache_beam/runners/interactive/caching/streaming_cache.py b/sdks/python/apache_beam/runners/interactive/caching/streaming_cache.py
index 0aac3dbf7e80..0aabda3e1e0a 100644
--- a/sdks/python/apache_beam/runners/interactive/caching/streaming_cache.py
+++ b/sdks/python/apache_beam/runners/interactive/caching/streaming_cache.py
@@ -21,7 +21,6 @@
 
 from apache_beam.portability.api.beam_runner_api_pb2 import TestStreamPayload
 from apache_beam.utils import timestamp
-from apache_beam.utils.timestamp import Timestamp
 
 
 class StreamingCache(object):
@@ -53,12 +52,9 @@ def __init__(self, readers):
       self._headers = {r.header().tag: r.header() for r in readers}
       self._readers = {r.header().tag: r.read() for r in readers}
 
-      # The watermarks per tag. Useful for introspection in the stream.
-      self._watermarks = {tag: timestamp.MIN_TIMESTAMP for tag in self._headers}
-
       # The most recently read timestamp per tag.
       self._stream_times = {
-          tag: timestamp.MIN_TIMESTAMP
+          tag: timestamp.Timestamp(seconds=0)
           for tag in self._headers
       }
 
@@ -79,23 +75,21 @@ def _test_stream_events_before_target(self, target_timestamp):
         if self._stream_times[tag] >= target_timestamp:
           continue
         try:
-          record = next(r)
-          records.append((tag, record))
-          self._stream_times[tag] = Timestamp.from_proto(record.processing_time)
+          record = next(r).recorded_event
+          if record.HasField('processing_time_event'):
+            self._stream_times[tag] += timestamp.Duration(
+                micros=record.processing_time_event.advance_duration)
+          records.append((tag, record, self._stream_times[tag]))
         except StopIteration:
           pass
       return records
 
     def _merge_sort(self, previous_events, new_events):
       return sorted(
-          previous_events + new_events,
-          key=lambda x: Timestamp.from_proto(x[1].processing_time),
-          reverse=True)
+          previous_events + new_events, key=lambda x: x[2], reverse=True)
 
     def _min_timestamp_of(self, events):
-      return (
-          Timestamp.from_proto(events[-1][1].processing_time)
-          if events else timestamp.MAX_TIMESTAMP)
+      return events[-1][2] if events else timestamp.MAX_TIMESTAMP
 
     def _event_stream_caught_up_to_target(self, events, target_timestamp):
       empty_events = not events
@@ -107,7 +101,7 @@ def read(self):
       """
 
       # The largest timestamp read from the different streams.
-      target_timestamp = timestamp.Timestamp.of(0)
+      target_timestamp = timestamp.MAX_TIMESTAMP
 
       # The events from last iteration that are past the target timestamp.
       unsent_events = []
@@ -130,19 +124,20 @@ def read(self):
 
         # Loop through the elements with the correct timestamp.
         while not self._event_stream_caught_up_to_target(events_to_send,
                                                          target_timestamp):
-          tag, r = events_to_send.pop()
           # First advance the clock to match the time of the stream. This has
           # a side-effect of also advancing this cache's clock.
-          curr_timestamp = Timestamp.from_proto(r.processing_time)
+          tag, r, curr_timestamp = events_to_send.pop()
           if curr_timestamp > self._monotonic_clock:
             yield self._advance_processing_time(curr_timestamp)
 
           # Then, send either a new element or watermark.
-          if r.HasField('element'):
-            yield self._add_element(r.element, tag)
-          elif r.HasField('watermark'):
-            yield self._advance_watermark(r.watermark, tag)
+          if r.HasField('element_event'):
+            r.element_event.tag = tag
+            yield r
+          elif r.HasField('watermark_event'):
+            r.watermark_event.tag = tag
+            yield r
 
         unsent_events = events_to_send
         target_timestamp = self._min_timestamp_of(unsent_events)
@@ -163,14 +158,5 @@ def _advance_processing_time(self, new_timestamp):
       self._monotonic_clock = new_timestamp
       return e
 
-    def _advance_watermark(self, watermark, tag):
-      """Advances the watermark for tag and returns AdvanceWatermark event.
-      """
-      self._watermarks[tag] = Timestamp.from_proto(watermark)
-      e = TestStreamPayload.Event(
-          watermark_event=TestStreamPayload.Event.AdvanceWatermark(
-              new_watermark=self._watermarks[tag].micros, tag=tag))
-      return e
-
   def reader(self):
     return StreamingCache.Reader(self._readers)
diff --git a/sdks/python/apache_beam/runners/interactive/caching/streaming_cache_test.py b/sdks/python/apache_beam/runners/interactive/caching/streaming_cache_test.py
index f6dec487ad03..32ac868ec682 100644
--- a/sdks/python/apache_beam/runners/interactive/caching/streaming_cache_test.py
+++ b/sdks/python/apache_beam/runners/interactive/caching/streaming_cache_test.py
@@ -26,6 +26,7 @@
 from apache_beam.portability.api.beam_interactive_api_pb2 import TestStreamFileRecord
 from apache_beam.portability.api.beam_runner_api_pb2 import TestStreamPayload
 from apache_beam.runners.interactive.caching.streaming_cache import StreamingCache
+from apache_beam.utils.timestamp import Duration
 from apache_beam.utils.timestamp import Timestamp
 
 # Nose automatically detects tests if they match a regex. Here, it mistakens
@@ -42,19 +43,28 @@ def __init__(self, tag=None):
     self._records = []
     self._coder = coders.FastPrimitivesCoder()
 
-  def add_element(self, element, event_time, processing_time):
+  def add_element(self, element, event_time):
     element_payload = TestStreamPayload.TimestampedElement(
         encoded_element=self._coder.encode(element),
         timestamp=Timestamp.of(event_time).micros)
     record = TestStreamFileRecord(
-        element=element_payload,
-        processing_time=Timestamp.of(processing_time).to_proto())
+        recorded_event=TestStreamPayload.Event(
+            element_event=TestStreamPayload.Event.AddElements(
+                elements=[element_payload])))
+    self._records.append(record)
+
+  def advance_watermark(self, watermark):
+    record = TestStreamFileRecord(
+        recorded_event=TestStreamPayload.Event(
+            watermark_event=TestStreamPayload.Event.AdvanceWatermark(
+                new_watermark=Timestamp.of(watermark).micros)))
     self._records.append(record)
 
-  def advance_watermark(self, watermark, processing_time):
+  def advance_processing_time(self, processing_time_delta):
     record = TestStreamFileRecord(
-        watermark=Timestamp.of(watermark).to_proto(),
-        processing_time=Timestamp.of(processing_time).to_proto())
+        recorded_event=TestStreamPayload.Event(
+            processing_time_event=TestStreamPayload.Event.AdvanceProcessingTime(
+                advance_duration=Duration.of(processing_time_delta).micros)))
     self._records.append(record)
 
   def header(self):
@@ -80,9 +90,11 @@ def test_single_reader(self):
     """Tests that we expect to see all the correctly emitted TestStreamPayloads.
     """
     in_memory_reader = InMemoryReader()
-    in_memory_reader.add_element(element=0, event_time=0, processing_time=0)
-    in_memory_reader.add_element(element=1, event_time=1, processing_time=1)
-    in_memory_reader.add_element(element=2, event_time=2, processing_time=2)
+    in_memory_reader.add_element(element=0, event_time=0)
+    in_memory_reader.advance_processing_time(1)
+    in_memory_reader.add_element(element=1, event_time=1)
+    in_memory_reader.advance_processing_time(1)
+    in_memory_reader.add_element(element=2, event_time=2)
     cache = StreamingCache([in_memory_reader])
     reader = cache.reader()
     coder = coders.FastPrimitivesCoder()
@@ -120,18 +132,24 @@ def test_multiple_readers(self):
     """Tests that the service advances the clock with multiple outputs."""
 
     letters = InMemoryReader('letters')
-    letters.advance_watermark(0, 1)
-    letters.add_element(element='a', event_time=0, processing_time=1)
-    letters.advance_watermark(10, 11)
-    letters.add_element(element='b', event_time=10, processing_time=11)
+    letters.advance_processing_time(1)
+    letters.advance_watermark(0)
+    letters.add_element(element='a', event_time=0)
+    letters.advance_processing_time(10)
+    letters.advance_watermark(10)
+    letters.add_element(element='b', event_time=10)
 
     numbers = InMemoryReader('numbers')
-    numbers.add_element(element=1, event_time=0, processing_time=2)
-    numbers.add_element(element=2, event_time=0, processing_time=3)
-    numbers.add_element(element=2, event_time=0, processing_time=4)
+    numbers.advance_processing_time(2)
+    numbers.add_element(element=1, event_time=0)
+    numbers.advance_processing_time(1)
+    numbers.add_element(element=2, event_time=0)
+    numbers.advance_processing_time(1)
+    numbers.add_element(element=2, event_time=0)
 
     late = InMemoryReader('late')
-    late.add_element(element='late', event_time=0, processing_time=101)
+    late.advance_processing_time(101)
+    late.add_element(element='late', event_time=0)
 
     cache = StreamingCache([letters, numbers, late])
     reader = cache.reader()