Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion sdks/python/apache_beam/ml/ts/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,19 @@ def __init__(
self._data = data
self._interval = interval
self._repeat = repeat
self._duration = len(self._data) * interval

# In `ImpulseSeqGenRestrictionProvider`, the total number of counts
# (i.e. total_outputs) is computed by ceil((end - start) / interval),
# where end is start + duration.
# Due to precision error of arithmetic operations, even if duration is set
# to len(self._data) * interval, (end - start) / interval could be a little
# bit smaller or bigger than len(self._data).
# In case of being bigger, total_outputs would be len(self._data) + 1,
# as the ceil() operation is used.
# Assuming that the precision error is no bigger than 1%, by subtracting
# a small amount, we ensure that the result after ceil is stable even if
# the precision error is present.
self._duration = len(self._data) * interval - 0.01 * interval
self._max_duration = max_duration if max_duration is not None else float(
"inf")

Expand Down
11 changes: 11 additions & 0 deletions sdks/python/apache_beam/ml/ts/util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,17 @@ def test_timestamped_value(self):
self.assertGreaterEqual(end - start, 3)
self.assertLessEqual(end - start, 7)

def test_stable_output(self):
options = PipelineOptions()
data = [(Timestamp(1), 1), (Timestamp(2), 2), (Timestamp(3), 3),
(Timestamp(6), 6), (Timestamp(4), 4), (Timestamp(5), 5),
(Timestamp(7), 7), (Timestamp(8), 8), (Timestamp(9), 9),
(Timestamp(10), 10)]
expected = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
with beam.Pipeline(options=options) as p:
ret = (p | PeriodicStream(data, interval=0.0001))
assert_that(ret, equal_to(expected))


if __name__ == '__main__':
logging.getLogger().setLevel(logging.WARNING)
Expand Down
Loading