Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

from apache_beam.portability.api.beam_runner_api_pb2 import TestStreamPayload
from apache_beam.utils import timestamp
from apache_beam.utils.timestamp import Timestamp


class StreamingCache(object):
Expand Down Expand Up @@ -53,12 +52,9 @@ def __init__(self, readers):
self._headers = {r.header().tag: r.header() for r in readers}
self._readers = {r.header().tag: r.read() for r in readers}

# The watermarks per tag. Useful for introspection in the stream.
self._watermarks = {tag: timestamp.MIN_TIMESTAMP for tag in self._headers}

# The most recently read timestamp per tag.
self._stream_times = {
tag: timestamp.MIN_TIMESTAMP
tag: timestamp.Timestamp(seconds=0)
for tag in self._headers
}

Expand All @@ -79,23 +75,21 @@ def _test_stream_events_before_target(self, target_timestamp):
if self._stream_times[tag] >= target_timestamp:
continue
try:
record = next(r)
records.append((tag, record))
self._stream_times[tag] = Timestamp.from_proto(record.processing_time)
record = next(r).recorded_event
if record.HasField('processing_time_event'):
self._stream_times[tag] += timestamp.Duration(
micros=record.processing_time_event.advance_duration)
records.append((tag, record, self._stream_times[tag]))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

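This would be much more readable if we used `typing.NamedTuple` for these (tag, record, timestamp) entries instead of indexing into a bare tuple (e.g. `x[2]`, `events[-1][2]`).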

except StopIteration:
pass
return records

def _merge_sort(self, previous_events, new_events):
return sorted(
previous_events + new_events,
key=lambda x: Timestamp.from_proto(x[1].processing_time),
reverse=True)
previous_events + new_events, key=lambda x: x[2], reverse=True)

def _min_timestamp_of(self, events):
return (
Timestamp.from_proto(events[-1][1].processing_time)
if events else timestamp.MAX_TIMESTAMP)
return events[-1][2] if events else timestamp.MAX_TIMESTAMP

def _event_stream_caught_up_to_target(self, events, target_timestamp):
empty_events = not events
Expand All @@ -107,7 +101,7 @@ def read(self):
"""

# The largest timestamp read from the different streams.
target_timestamp = timestamp.Timestamp.of(0)
target_timestamp = timestamp.MAX_TIMESTAMP

# The events from last iteration that are past the target timestamp.
unsent_events = []
Expand All @@ -130,19 +124,20 @@ def read(self):
# Loop through the elements with the correct timestamp.
while not self._event_stream_caught_up_to_target(events_to_send,
target_timestamp):
tag, r = events_to_send.pop()

# First advance the clock to match the time of the stream. This has
# a side-effect of also advancing this cache's clock.
curr_timestamp = Timestamp.from_proto(r.processing_time)
tag, r, curr_timestamp = events_to_send.pop()
if curr_timestamp > self._monotonic_clock:
yield self._advance_processing_time(curr_timestamp)

# Then, send either a new element or watermark.
if r.HasField('element'):
yield self._add_element(r.element, tag)
elif r.HasField('watermark'):
yield self._advance_watermark(r.watermark, tag)
if r.HasField('element_event'):
r.element_event.tag = tag
yield r
elif r.HasField('watermark_event'):
r.watermark_event.tag = tag
yield r
unsent_events = events_to_send
target_timestamp = self._min_timestamp_of(unsent_events)

Expand All @@ -163,14 +158,5 @@ def _advance_processing_time(self, new_timestamp):
self._monotonic_clock = new_timestamp
return e

def _advance_watermark(self, watermark, tag):
"""Advances the watermark for tag and returns AdvanceWatermark event.
"""
self._watermarks[tag] = Timestamp.from_proto(watermark)
e = TestStreamPayload.Event(
watermark_event=TestStreamPayload.Event.AdvanceWatermark(
new_watermark=self._watermarks[tag].micros, tag=tag))
return e

def reader(self):
return StreamingCache.Reader(self._readers)
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from apache_beam.portability.api.beam_interactive_api_pb2 import TestStreamFileRecord
from apache_beam.portability.api.beam_runner_api_pb2 import TestStreamPayload
from apache_beam.runners.interactive.caching.streaming_cache import StreamingCache
from apache_beam.utils.timestamp import Duration
from apache_beam.utils.timestamp import Timestamp

# Nose automatically detects tests if they match a regex. Here, it mistakes
Expand All @@ -42,19 +43,28 @@ def __init__(self, tag=None):
self._records = []
self._coder = coders.FastPrimitivesCoder()

def add_element(self, element, event_time, processing_time):
def add_element(self, element, event_time):
element_payload = TestStreamPayload.TimestampedElement(
encoded_element=self._coder.encode(element),
timestamp=Timestamp.of(event_time).micros)
record = TestStreamFileRecord(
element=element_payload,
processing_time=Timestamp.of(processing_time).to_proto())
recorded_event=TestStreamPayload.Event(
element_event=TestStreamPayload.Event.AddElements(
elements=[element_payload])))
self._records.append(record)

def advance_watermark(self, watermark):
record = TestStreamFileRecord(
recorded_event=TestStreamPayload.Event(
watermark_event=TestStreamPayload.Event.AdvanceWatermark(
new_watermark=Timestamp.of(watermark).micros)))
self._records.append(record)

def advance_watermark(self, watermark, processing_time):
def advance_processing_time(self, processing_time_delta):
record = TestStreamFileRecord(
watermark=Timestamp.of(watermark).to_proto(),
processing_time=Timestamp.of(processing_time).to_proto())
recorded_event=TestStreamPayload.Event(
processing_time_event=TestStreamPayload.Event.AdvanceProcessingTime(
advance_duration=Duration.of(processing_time_delta).micros)))
self._records.append(record)

def header(self):
Expand All @@ -80,9 +90,11 @@ def test_single_reader(self):
"""Tests that all of the expected TestStreamPayloads are emitted correctly.
"""
in_memory_reader = InMemoryReader()
in_memory_reader.add_element(element=0, event_time=0, processing_time=0)
in_memory_reader.add_element(element=1, event_time=1, processing_time=1)
in_memory_reader.add_element(element=2, event_time=2, processing_time=2)
in_memory_reader.add_element(element=0, event_time=0)
in_memory_reader.advance_processing_time(1)
in_memory_reader.add_element(element=1, event_time=1)
in_memory_reader.advance_processing_time(1)
in_memory_reader.add_element(element=2, event_time=2)
cache = StreamingCache([in_memory_reader])
reader = cache.reader()
coder = coders.FastPrimitivesCoder()
Expand Down Expand Up @@ -120,18 +132,24 @@ def test_multiple_readers(self):
"""Tests that the service advances the clock with multiple outputs."""

letters = InMemoryReader('letters')
letters.advance_watermark(0, 1)
letters.add_element(element='a', event_time=0, processing_time=1)
letters.advance_watermark(10, 11)
letters.add_element(element='b', event_time=10, processing_time=11)
letters.advance_processing_time(1)
letters.advance_watermark(0)
letters.add_element(element='a', event_time=0)
letters.advance_processing_time(10)
letters.advance_watermark(10)
letters.add_element(element='b', event_time=10)

numbers = InMemoryReader('numbers')
numbers.add_element(element=1, event_time=0, processing_time=2)
numbers.add_element(element=2, event_time=0, processing_time=3)
numbers.add_element(element=2, event_time=0, processing_time=4)
numbers.advance_processing_time(2)
numbers.add_element(element=1, event_time=0)
numbers.advance_processing_time(1)
numbers.add_element(element=2, event_time=0)
numbers.advance_processing_time(1)
numbers.add_element(element=2, event_time=0)

late = InMemoryReader('late')
late.add_element(element='late', event_time=0, processing_time=101)
late.advance_processing_time(101)
late.add_element(element='late', event_time=0)

cache = StreamingCache([letters, numbers, late])
reader = cache.reader()
Expand Down