diff --git a/sdks/python/apache_beam/ml/ts/util.py b/sdks/python/apache_beam/ml/ts/util.py index 4005f57e0047..a50f2667621b 100644 --- a/sdks/python/apache_beam/ml/ts/util.py +++ b/sdks/python/apache_beam/ml/ts/util.py @@ -159,7 +159,19 @@ def __init__( self._data = data self._interval = interval self._repeat = repeat - self._duration = len(self._data) * interval + + # In `ImpulseSeqGenRestrictionProvider`, the total number of counts + # (i.e. total_outputs) is computed by ceil((end - start) / interval), + # where end is start + duration. + # Due to precision error of arithmetic operations, even if duration is set + # to len(self._data) * interval, (end - start) / interval could be a little + # bit smaller or bigger than len(self._data). + # In case of being bigger, total_outputs would be len(self._data) + 1, + # as the ceil() operation is used. + # Assuming that the precision error is no bigger than 1%, by subtracting + # a small amount, we ensure that the result after ceil is stable even if + # the precision error is present. + self._duration = len(self._data) * interval - 0.01 * interval self._max_duration = max_duration if max_duration is not None else float( "inf") diff --git a/sdks/python/apache_beam/ml/ts/util_test.py b/sdks/python/apache_beam/ml/ts/util_test.py index ac2bc6ea701f..5a2a8a79ce89 100644 --- a/sdks/python/apache_beam/ml/ts/util_test.py +++ b/sdks/python/apache_beam/ml/ts/util_test.py @@ -76,6 +76,17 @@ def test_timestamped_value(self): self.assertGreaterEqual(end - start, 3) self.assertLessEqual(end - start, 7) + def test_stable_output(self): + options = PipelineOptions() + data = [(Timestamp(1), 1), (Timestamp(2), 2), (Timestamp(3), 3), + (Timestamp(6), 6), (Timestamp(4), 4), (Timestamp(5), 5), + (Timestamp(7), 7), (Timestamp(8), 8), (Timestamp(9), 9), + (Timestamp(10), 10)] + expected = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + with beam.Pipeline(options=options) as p: + ret = (p | PeriodicStream(data, interval=0.0001)) + assert_that(ret, equal_to(expected)) + if __name__ == '__main__': logging.getLogger().setLevel(logging.WARNING)