diff --git a/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go b/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go index bc8449c72b39..c73db507c792 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go +++ b/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go @@ -869,14 +869,20 @@ func (em *ElementManager) triageTimers(d TentativeData, inputInfo PColInfo, stag for tentativeKey, timers := range d.timers { keyToTimers := map[timerKey]element{} for _, t := range timers { - key, tag, elms := decodeTimer(inputInfo.KeyDec, true, t) - for _, e := range elms { - keyToTimers[timerKey{key: string(key), tag: tag, win: e.window}] = e - } - if len(elms) == 0 { - // TODO(lostluck): Determine best way to mark a timer cleared. - continue - } + // TODO: Call in a for:range loop when Beam's minimum Go version hits 1.23.0 + iter := decodeTimerIter(inputInfo.KeyDec, true, t) + iter(func(ret timerRet) bool { + for _, e := range ret.elms { + keyToTimers[timerKey{key: string(ret.keyBytes), tag: ret.tag, win: e.window}] = e + } + if len(ret.elms) == 0 { + for _, w := range ret.windows { + delete(keyToTimers, timerKey{key: string(ret.keyBytes), tag: ret.tag, win: w}) + } + } + // Indicate we'd like to continue iterating. + return true + }) } for _, elm := range keyToTimers { diff --git a/sdks/go/pkg/beam/runners/prism/internal/engine/timers.go b/sdks/go/pkg/beam/runners/prism/internal/engine/timers.go index 787d27858a0e..9a3bd6f9682b 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/engine/timers.go +++ b/sdks/go/pkg/beam/runners/prism/internal/engine/timers.go @@ -31,53 +31,77 @@ import ( "google.golang.org/protobuf/encoding/protowire" ) -// DecodeTimer extracts timers to elements for insertion into their keyed queues. -// Returns the key bytes, tag, window exploded elements, and the hold timestamp. 
+type timerRet struct { + keyBytes []byte + tag string + elms []element + windows []typex.Window +} + +// decodeTimerIter extracts timers to elements for insertion into their keyed queues, +// through a Go iterator function, which the caller invokes with its own processing function. +// +// For each timer, a key, tag, windowed elements, and the window set are returned. +// // If the timer has been cleared, no elements will be returned. Any existing timers -// for the tag *must* be cleared from the pending queue. -func decodeTimer(keyDec func(io.Reader) []byte, usesGlobalWindow bool, raw []byte) ([]byte, string, []element) { - keyBytes := keyDec(bytes.NewBuffer(raw)) - - d := decoder{raw: raw, cursor: len(keyBytes)} - tag := string(d.Bytes()) - - var ws []typex.Window - numWin := d.Fixed32() - if usesGlobalWindow { - for i := 0; i < int(numWin); i++ { - ws = append(ws, window.GlobalWindow{}) - } - } else { - // Assume interval windows here, since we don't understand custom windows yet. - for i := 0; i < int(numWin); i++ { - ws = append(ws, d.IntervalWindow()) - } - } +// for the tag *must* be cleared from the pending queue. The windows associated with +// the clear are provided to be able to delete pending timers. +func decodeTimerIter(keyDec func(io.Reader) []byte, usesGlobalWindow bool, raw []byte) func(func(timerRet) bool) { + return func(yield func(timerRet) bool) { + for len(raw) > 0 { + keyBytes := keyDec(bytes.NewBuffer(raw)) + d := decoder{raw: raw, cursor: len(keyBytes)} + tag := string(d.Bytes()) + + var ws []typex.Window + numWin := d.Fixed32() + if usesGlobalWindow { + for i := 0; i < int(numWin); i++ { + ws = append(ws, window.GlobalWindow{}) + } + } else { + // Assume interval windows here, since we don't understand custom windows yet. 
+ for i := 0; i < int(numWin); i++ { + ws = append(ws, d.IntervalWindow()) + } + } - clear := d.Bool() - hold := mtime.MaxTimestamp - if clear { - return keyBytes, tag, nil - } + clear := d.Bool() + hold := mtime.MaxTimestamp + if clear { + if !yield(timerRet{keyBytes, tag, nil, ws}) { + return // Halt iteration if yield returns false. + } + // Otherwise continue handling the remaining bytes. + raw = d.UnusedBytes() + continue + } - firing := d.Timestamp() - hold = d.Timestamp() - pane := d.Pane() + firing := d.Timestamp() + hold = d.Timestamp() + pane := d.Pane() + + var elms []element + for _, w := range ws { + elms = append(elms, element{ + tag: tag, + elmBytes: nil, // indicates this is a timer. + keyBytes: keyBytes, + window: w, + timestamp: firing, + holdTimestamp: hold, + pane: pane, + sequence: len(elms), + }) + } - var ret []element - for _, w := range ws { - ret = append(ret, element{ - tag: tag, - elmBytes: nil, // indicates this is a timer. - keyBytes: keyBytes, - window: w, - timestamp: firing, - holdTimestamp: hold, - pane: pane, - sequence: len(ret), - }) + if !yield(timerRet{keyBytes, tag, elms, ws}) { + return // Halt iteration if yield returns false. + } + // Otherwise continue handling the remaining bytes. + raw = d.UnusedBytes() + } } - return keyBytes, tag, ret } type decoder struct { @@ -140,6 +164,13 @@ func (d *decoder) Bytes() []byte { return b } +// UnusedBytes returns the remainder of bytes in the buffer that weren't yet used. +// Multiple timers can be provided in a single timers buffer, since multiple dynamic +// timer tags may be set. 
+func (d *decoder) UnusedBytes() []byte { + return d.raw[d.cursor:] +} + func (d *decoder) Bool() bool { if b := d.Byte(); b == 0 { return false diff --git a/sdks/python/apache_beam/runners/portability/prism_runner_test.py b/sdks/python/apache_beam/runners/portability/prism_runner_test.py index 324fe5a17b54..b179156877e4 100644 --- a/sdks/python/apache_beam/runners/portability/prism_runner_test.py +++ b/sdks/python/apache_beam/runners/portability/prism_runner_test.py @@ -40,6 +40,7 @@ from apache_beam.runners.portability import portable_runner_test from apache_beam.testing.util import assert_that from apache_beam.testing.util import equal_to +from apache_beam.transforms import userstate from apache_beam.transforms import window from apache_beam.transforms.sql import SqlTransform from apache_beam.utils import timestamp @@ -200,6 +201,37 @@ def test_windowing(self): assert_that( res, equal_to([('k', [1, 2]), ('k', [100, 101, 102]), ('k', [123])])) + # The fn_runner_test.py version of this test doesn't execute the process + # method for some reason. Overridden here to validate that the cleared + # timer won't re-fire. 
+ def test_pardo_timers_clear(self): + timer_spec = userstate.TimerSpec('timer', userstate.TimeDomain.WATERMARK) + + class TimerDoFn(beam.DoFn): + def process(self, element, timer=beam.DoFn.TimerParam(timer_spec)): + unused_key, ts = element + timer.set(ts) + timer.set(2 * ts) + + @userstate.on_timer(timer_spec) + def process_timer( + self, + ts=beam.DoFn.TimestampParam, + timer=beam.DoFn.TimerParam(timer_spec)): + timer.set(timestamp.Timestamp(micros=2 * ts.micros)) + timer.clear() # Cleared right after being re-set, so it shouldn't fire again + yield 'fired' + + with self.create_pipeline() as p: + actual = ( + p + | beam.Create([('k1', 10), ('k2', 100)]) + | beam.ParDo(TimerDoFn()) + | beam.Map(lambda x, ts=beam.DoFn.TimestampParam: (x, ts))) + + expected = [('fired', ts) for ts in (20, 200)] + assert_that(actual, equal_to(expected)) + # Can't read host files from within docker, read a "local" file there. def test_read(self): print('name:', __name__)