From 293871bd24448d75e2aa8f8877afe5eca1d90042 Mon Sep 17 00:00:00 2001 From: lostluck <13907733+lostluck@users.noreply.github.com> Date: Wed, 1 Mar 2023 09:11:17 -0800 Subject: [PATCH 1/5] [go-timers] things I'm pretty sure work --- sdks/go/pkg/beam/core/funcx/fn.go | 45 ++++++- sdks/go/pkg/beam/core/graph/edge.go | 1 + sdks/go/pkg/beam/core/graph/fn.go | 80 ++++++++++- sdks/go/pkg/beam/core/runtime/exec/coder.go | 3 + .../pkg/beam/core/runtime/exec/coder_test.go | 26 +++- sdks/go/pkg/beam/core/runtime/exec/fn.go | 90 ++++++++++--- sdks/go/pkg/beam/core/runtime/exec/pardo.go | 15 ++- sdks/go/pkg/beam/core/runtime/exec/timers.go | 125 ++++++++++++++++++ .../pkg/beam/core/runtime/exec/translate.go | 22 +++ sdks/go/pkg/beam/core/runtime/graphx/coder.go | 2 +- .../pkg/beam/core/runtime/graphx/serialize.go | 5 + .../pkg/beam/core/runtime/graphx/translate.go | 15 +++ sdks/go/pkg/beam/core/timers/timers.go | 119 +++++++++++++++++ sdks/go/pkg/beam/core/typex/class.go | 1 + sdks/go/pkg/beam/core/typex/special.go | 3 +- sdks/go/pkg/beam/pardo.go | 12 ++ 16 files changed, 528 insertions(+), 36 deletions(-) create mode 100644 sdks/go/pkg/beam/core/runtime/exec/timers.go create mode 100644 sdks/go/pkg/beam/core/timers/timers.go diff --git a/sdks/go/pkg/beam/core/funcx/fn.go b/sdks/go/pkg/beam/core/funcx/fn.go index b579cb56d521..3aef5a1695ac 100644 --- a/sdks/go/pkg/beam/core/funcx/fn.go +++ b/sdks/go/pkg/beam/core/funcx/fn.go @@ -21,6 +21,7 @@ import ( "github.com/apache/beam/sdks/v2/go/pkg/beam/core/sdf" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/state" + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/timers" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/util/reflectx" "github.com/apache/beam/sdks/v2/go/pkg/beam/internal/errors" @@ -85,6 +86,8 @@ const ( FnWatermarkEstimator FnParamKind = 0x1000 // FnStateProvider indicates a function input parameter that implements state.Provider FnStateProvider FnParamKind = 0x2000 + // FnTimerProvider indicates a function input parameter that implements timer.Provider + FnTimerProvider FnParamKind = 0x4000 ) func (k FnParamKind) String() string { @@ -117,6 +120,8 @@ func (k FnParamKind) String() string { return "WatermarkEstimator" case FnStateProvider: return "StateProvider" + case FnTimerProvider: + return "TimerProvider" default: return fmt.Sprintf("%v", int(k)) } @@ -305,6 +310,15 @@ func (u *Fn) StateProvider() (pos int, exists bool) { return -1, false } +func (u *Fn) TimerProvider() (pos int, exists bool) { + for i, p := range u.Param { + if p.Kind == FnTimerProvider { + return i, true + } + } + return -1, false +} + // WatermarkEstimator returns (index, true) iff the function expects a // parameter that implements sdf.WatermarkEstimator. func (u *Fn) WatermarkEstimator() (pos int, exists bool) { @@ -392,6 +406,8 @@ func New(fn reflectx.Func) (*Fn, error) { kind = FnBundleFinalization case t == state.ProviderType: kind = FnStateProvider + case t == timers.ProviderType: + kind = FnTimerProvider case t == reflectx.Type: kind = FnType case t.Implements(reflect.TypeOf((*sdf.RTracker)(nil)).Elem()): @@ -482,7 +498,7 @@ func SubReturns(list []ReturnParam, indices ...int) []ReturnParam { } // The order of present parameters and return values must be as follows: -// func(FnContext?, FnPane?, FnWindow?, FnEventTime?, FnWatermarkEstimator?, FnType?, FnBundleFinalization?, FnRTracker?, FnStateProvider?, (FnValue, SideInput*)?, FnEmit*) (RetEventTime?, RetOutput?, RetError?) 
+// func(FnContext?, FnPane?, FnWindow?, FnEventTime?, FnWatermarkEstimator?, FnType?, FnBundleFinalization?, FnRTracker?, FnStateProvider?, FnTimerProvider?, (FnValue, SideInput*)?, FnEmit*) (RetEventTime?, RetOutput?, RetError?) // // where ? indicates 0 or 1, and * indicates any number. // and a SideInput is one of FnValue or FnIter or FnReIter @@ -517,6 +533,7 @@ var ( errRTrackerPrecedence = errors.New("may only have a single sdf.RTracker parameter and it must precede the main input parameter") errBundleFinalizationPrecedence = errors.New("may only have a single BundleFinalization parameter and it must precede the main input parameter") errStateProviderPrecedence = errors.New("may only have a single state.Provider parameter and it must precede the main input parameter") + errTimerProviderPrecedence = errors.New("may only have a single timer.Provider parameter and it must precede the main input parameter") errInputPrecedence = errors.New("inputs parameters must precede emit function parameters") ) @@ -535,6 +552,7 @@ const ( psRTracker psBundleFinalization psStateProvider + psTimerProvider ) func nextParamState(cur paramState, transition FnParamKind) (paramState, error) { @@ -559,6 +577,8 @@ func nextParamState(cur paramState, transition FnParamKind) (paramState, error) return psRTracker, nil case FnStateProvider: return psStateProvider, nil + case FnTimerProvider: + return psTimerProvider, nil } case psContext: switch transition { @@ -578,6 +598,8 @@ func nextParamState(cur paramState, transition FnParamKind) (paramState, error) return psRTracker, nil case FnStateProvider: return psStateProvider, nil + case FnTimerProvider: + return psTimerProvider, nil } case psPane: switch transition { @@ -595,6 +617,8 @@ func nextParamState(cur paramState, transition FnParamKind) (paramState, error) return psRTracker, nil case FnStateProvider: return psStateProvider, nil + case FnTimerProvider: + return psTimerProvider, nil } case psWindow: switch transition { @@ -610,6 +634,8 @@ func nextParamState(cur paramState, transition FnParamKind) (paramState, error) return psRTracker, nil case FnStateProvider: return psStateProvider, nil + case FnTimerProvider: + return psTimerProvider, nil } case psEventTime: switch transition { @@ -623,6 +649,8 @@ func nextParamState(cur paramState, transition FnParamKind) (paramState, error) return psRTracker, nil case FnStateProvider: return psStateProvider, nil + case FnTimerProvider: + return psTimerProvider, nil } case psWatermarkEstimator: switch transition { @@ -634,6 +662,8 @@ func nextParamState(cur paramState, transition FnParamKind) (paramState, error) return psRTracker, nil case FnStateProvider: return psStateProvider, nil + case FnTimerProvider: + return psTimerProvider, nil } case psType: switch transition { @@ -643,6 +673,8 @@ func nextParamState(cur paramState, transition FnParamKind) (paramState, error) return psRTracker, nil case FnStateProvider: return psStateProvider, nil + case FnTimerProvider: + return psTimerProvider, nil } case psBundleFinalization: switch transition { @@ -650,13 +682,22 @@ func nextParamState(cur paramState, transition FnParamKind) (paramState, error) return psRTracker, nil case FnStateProvider: return psStateProvider, nil + case FnTimerProvider: + return psTimerProvider, nil } case psRTracker: switch transition { case FnStateProvider: return psStateProvider, nil + case FnTimerProvider: + return psTimerProvider, nil } case psStateProvider: + switch transition { + case FnTimerProvider: + return psTimerProvider, nil + } + case 
psTimerProvider: // Completely handled by the default clause case psInput: switch transition { @@ -689,6 +730,8 @@ func nextParamState(cur paramState, transition FnParamKind) (paramState, error) return -1, errRTrackerPrecedence case FnStateProvider: return -1, errStateProviderPrecedence + case FnTimerProvider: + return -1, errTimerProviderPrecedence case FnIter, FnReIter, FnValue, FnMultiMap: return psInput, nil case FnEmit: diff --git a/sdks/go/pkg/beam/core/graph/edge.go b/sdks/go/pkg/beam/core/graph/edge.go index a9f1c8a092b0..86891114dd0e 100644 --- a/sdks/go/pkg/beam/core/graph/edge.go +++ b/sdks/go/pkg/beam/core/graph/edge.go @@ -156,6 +156,7 @@ type MultiEdge struct { DoFn *DoFn // ParDo RestrictionCoder *coder.Coder // SplittableParDo StateCoders map[string]*coder.Coder // Stateful ParDo + TimerCoders map[string]*coder.Coder // Stateful ParDo CombineFn *CombineFn // Combine AccumCoder *coder.Coder // Combine Value []byte // Impulse diff --git a/sdks/go/pkg/beam/core/graph/fn.go b/sdks/go/pkg/beam/core/graph/fn.go index 54cc02e07b3d..1bb6548e8506 100644 --- a/sdks/go/pkg/beam/core/graph/fn.go +++ b/sdks/go/pkg/beam/core/graph/fn.go @@ -22,6 +22,7 @@ import ( "github.com/apache/beam/sdks/v2/go/pkg/beam/core/funcx" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/sdf" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/state" + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/timers" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/util/reflectx" "github.com/apache/beam/sdks/v2/go/pkg/beam/internal/errors" @@ -167,13 +168,13 @@ const ( initialWatermarkEstimatorStateName = "InitialWatermarkEstimatorState" watermarkEstimatorStateName = "WatermarkEstimatorState" + onTimerName = "OnTimer" + createAccumulatorName = "CreateAccumulator" addInputName = "AddInput" mergeAccumulatorsName = "MergeAccumulators" extractOutputName = "ExtractOutput" compactName = "Compact" - - // TODO: ViewFn, etc. ) var doFnNames = []string{ @@ -304,6 +305,40 @@ func (f *DoFn) PipelineState() []state.PipelineState { return s } +type PipelineTimer interface { + TimerFamily() string + TimerDomain() timers.TimeDomainEnum +} + +var ( + _ PipelineTimer = timers.EventTime{} + _ PipelineTimer = timers.ProcessingTime{} +) + +func (f *DoFn) OnTimerFn() (*funcx.Fn, bool) { + m, ok := f.methods[onTimerName] + return m, ok +} + +func (f *DoFn) PipelineTimers() []PipelineTimer { + var t []PipelineTimer + if f.Recv == nil { + return t + } + + v := reflect.Indirect(reflect.ValueOf(f.Recv)) + + for i := 0; i < v.NumField(); i++ { + f := v.Field(i) + if f.CanInterface() { + if pt, ok := f.Interface().(PipelineTimer); ok { + t = append(t, pt) + } + } + } + return t +} + // SplittableDoFn represents a DoFn implementing SDF methods. 
type SplittableDoFn DoFn @@ -607,6 +642,11 @@ func AsDoFn(fn *Fn, numMainIn mainInputs) (*DoFn, error) { return nil, addContext(err, fn) } + err = validateTimer(doFn) + if err != nil { + return nil, addContext(err, fn) + } + return doFn, nil } @@ -1350,6 +1390,42 @@ func validateState(fn *DoFn, numIn mainInputs) error { return nil } +func validateTimer(fn *DoFn) error { + if fn.Fn == nil { + return nil + } + + pt := fn.PipelineTimers() + + if _, ok := fn.Fn.TimerProvider(); ok { + if len(pt) == 0 { + err := errors.Errorf("ProcessElement uses a TimerProvider, but no timer struct-tags are attached to the DoFn") + return errors.SetTopLevelMsgf(err, "ProcessElement uses a TimerProvider, but no timer struct-tags are attached to the DoFn"+ + ", Ensure that you are including the timer structs you're using to set/clear global state as uppercase member variables") + } + timerKeys := make(map[string]PipelineTimer) + for _, t := range pt { + k := t.TimerFamily() + if timer, ok := timerKeys[k]; ok { + err := errors.Errorf("Duplicate timer key %v", k) + return errors.SetTopLevelMsgf(err, "Duplicate timer key %v used by %v and %v. Ensure that keys are unique per DoFn", k, timer, t) + } else { + timerKeys[k] = t + } + } + } else { + if len(pt) > 0 { + err := errors.Errorf("ProcessElement doesn't use a TimerProvider, but Timer Struct is attached to the DoFn: %v", pt) + return errors.SetTopLevelMsgf(err, "ProcessElement doesn't use a TimerProvider, but Timer Struct is attached to the DoFn: %v"+ + ", Ensure that you are using the TimerProvider to set/clear the timers.", pt) + } + } + + // DO NOT SUBMIT: Require an OnTimer method existing + + return nil +} + // CombineFn represents a CombineFn. type CombineFn Fn diff --git a/sdks/go/pkg/beam/core/runtime/exec/coder.go b/sdks/go/pkg/beam/core/runtime/exec/coder.go index 39248f7f5ac3..4b750a9be98e 100644 --- a/sdks/go/pkg/beam/core/runtime/exec/coder.go +++ b/sdks/go/pkg/beam/core/runtime/exec/coder.go @@ -1295,6 +1295,9 @@ func decodeTimer(dec ElementDecoder, win WindowDecoder, r io.Reader) (typex.Time if err != nil { return tm, errors.WithContext(err, "error decoding timer key") } + + // TODO Change to not type assert once general timers key + // fix is done. 
tm.Key = fv.Elm.(string) s, err := coder.DecodeStringUTF8(r) diff --git a/sdks/go/pkg/beam/core/runtime/exec/coder_test.go b/sdks/go/pkg/beam/core/runtime/exec/coder_test.go index 75d18e533cf1..533dbb458741 100644 --- a/sdks/go/pkg/beam/core/runtime/exec/coder_test.go +++ b/sdks/go/pkg/beam/core/runtime/exec/coder_test.go @@ -21,16 +21,16 @@ import ( "reflect" "testing" - "github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex" - "github.com/apache/beam/sdks/v2/go/pkg/beam/core/util/reflectx" - "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/coder" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/window" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/coderx" + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex" + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/util/reflectx" + "github.com/google/go-cmp/cmp" ) func TestCoders(t *testing.T) { - for _, test := range []struct { + tests := []struct { coder *coder.Coder val *FullValue }{ @@ -92,8 +92,22 @@ func TestCoders(t *testing.T) { }, { coder: coder.NewIntervalWindowCoder(), val: &FullValue{Elm: window.IntervalWindow{Start: 0, End: 100}}, + }, { + coder: coder.NewT(coder.NewString(), coder.NewGlobalWindow()), + val: &FullValue{ + Elm: typex.TimerMap{ + Key: "key", + Tag: "tag", + Windows: []typex.Window{window.GlobalWindow{}}, + Clear: false, + FireTimestamp: 1234, + HoldTimestamp: 5678, + Pane: typex.PaneInfo{IsFirst: true, IsLast: true, Timing: typex.PaneUnknown, Index: 0, NonSpeculativeIndex: 0}, + }, + }, }, - } { + } + for _, test := range tests { t.Run(fmt.Sprintf("%v", test.coder), func(t *testing.T) { var buf bytes.Buffer enc := MakeElementEncoder(test.coder) @@ -132,7 +146,7 @@ func compareFV(t *testing.T, got *FullValue, want *FullValue) { if gotFv, ok := got.Elm.(*FullValue); ok { compareFV(t, gotFv, wantFv) } - } else if got, want := got.Elm, want.Elm; got != want { + } else if got, want := got.Elm, want.Elm; !cmp.Equal(want, got) { t.Errorf("got %v [type: %s], want %v [type %s]", got, reflect.TypeOf(got), wantFv, reflect.TypeOf(wantFv)) } diff --git a/sdks/go/pkg/beam/core/runtime/exec/fn.go b/sdks/go/pkg/beam/core/runtime/exec/fn.go index eaf9df81e4aa..d0fdb8e36305 100644 --- a/sdks/go/pkg/beam/core/runtime/exec/fn.go +++ b/sdks/go/pkg/beam/core/runtime/exec/fn.go @@ -28,6 +28,7 @@ import ( "github.com/apache/beam/sdks/v2/go/pkg/beam/core/sdf" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex" "github.com/apache/beam/sdks/v2/go/pkg/beam/internal/errors" + "github.com/apache/beam/sdks/v2/go/pkg/beam/log" ) //go:generate specialize --input=fn_arity.tmpl @@ -69,16 +70,47 @@ func (bf *bundleFinalizer) RegisterCallback(t time.Duration, cb func() error) { } // Invoke invokes the fn with the given values. The extra values must match the non-main +// side input and emitters. It returns the direct output, if any.\ +// +// Deprecated: prefer InvokeWithOpts +func Invoke(ctx context.Context, pn typex.PaneInfo, ws []typex.Window, ts typex.EventTime, fn *funcx.Fn, opt *MainInput, bf *bundleFinalizer, we sdf.WatermarkEstimator, sa UserStateAdapter, sr StateReader, extra ...any) (*FullValue, error) { + if fn == nil { + return nil, nil // ok: nothing to Invoke + } + inv := newInvoker(fn) + return inv.invokeWithOpts(ctx, pn, ws, ts, InvokeOpts{opt: opt, bf: bf, we: we, sa: sa, sr: sr, extra: extra}) +} + +// InvokeOpts are optional parameters to invoke a Fn. 
+type InvokeOpts struct { + opt *MainInput + bf *bundleFinalizer + we sdf.WatermarkEstimator + sa UserStateAdapter + sr StateReader + ta UserTimerAdapter + tm DataManager + extra []any +} + +// InvokeWithOpts invokes the fn with the given values. The extra values must match the non-main // side input and emitters. It returns the direct output, if any. -func Invoke(ctx context.Context, pn typex.PaneInfo, ws []typex.Window, ts typex.EventTime, fn *funcx.Fn, opt *MainInput, bf *bundleFinalizer, we sdf.WatermarkEstimator, sa UserStateAdapter, reader StateReader, extra ...any) (*FullValue, error) { +func InvokeWithOpts(ctx context.Context, fn *funcx.Fn, pn typex.PaneInfo, ws []typex.Window, ts typex.EventTime, opts InvokeOpts) (*FullValue, error) { if fn == nil { return nil, nil // ok: nothing to Invoke } inv := newInvoker(fn) - return inv.Invoke(ctx, pn, ws, ts, opt, bf, we, sa, reader, extra...) + return inv.invokeWithOpts(ctx, pn, ws, ts, opts) +} + +// InvokeWithOptsWithoutEventTime runs the given function at time 0 in the global window. +func InvokeWithOptsWithoutEventTime(ctx context.Context, fn *funcx.Fn, opts InvokeOpts) (*FullValue, error) { + return InvokeWithOpts(ctx, fn, typex.NoFiringPane(), window.SingleGlobalWindow, mtime.ZeroTimestamp, opts) } // InvokeWithoutEventTime runs the given function at time 0 in the global window. +// +// Deprecated: prefer InvokeWithOptsWithoutEventTime func InvokeWithoutEventTime(ctx context.Context, fn *funcx.Fn, opt *MainInput, bf *bundleFinalizer, we sdf.WatermarkEstimator, sa UserStateAdapter, reader StateReader, extra ...any) (*FullValue, error) { if fn == nil { return nil, nil // ok: nothing to Invoke @@ -93,10 +125,12 @@ type invoker struct { fn *funcx.Fn args []any sp *stateProvider + tp *timerProvider + // TODO(lostluck): 2018/07/06 consider replacing with a slice of functions to run over the args slice, as an improvement. - ctxIdx, pnIdx, wndIdx, etIdx, bfIdx, weIdx, spIdx int // specialized input indexes - outEtIdx, outPcIdx, outErrIdx int // specialized output indexes - in, out []int // general indexes + ctxIdx, pnIdx, wndIdx, etIdx, bfIdx, weIdx, spIdx, tpIdx int // specialized input indexes + outEtIdx, outPcIdx, outErrIdx int // specialized output indexes + in, out []int // general indexes ret FullValue // ret is a cached allocation for passing to the next Unit. Units never modify the passed in FullValue. elmConvert, elm2Convert func(any) any // Cached conversion functions, which assums this invoker is always used with the same parameter types. @@ -129,6 +163,9 @@ func newInvoker(fn *funcx.Fn) *invoker { if n.spIdx, ok = fn.StateProvider(); !ok { n.spIdx = -1 } + if n.tpIdx, ok = fn.TimerProvider(); !ok { + n.tpIdx = -1 + } if n.outEtIdx, ok = fn.OutEventTime(); !ok { n.outEtIdx = -1 } @@ -163,7 +200,11 @@ func (n *invoker) InvokeWithoutEventTime(ctx context.Context, opt *MainInput, bf // Invoke invokes the fn with the given values. The extra values must match the non-main // side input and emitters. It returns the direct output, if any. 
-func (n *invoker) Invoke(ctx context.Context, pn typex.PaneInfo, ws []typex.Window, ts typex.EventTime, opt *MainInput, bf *bundleFinalizer, we sdf.WatermarkEstimator, sa UserStateAdapter, reader StateReader, extra ...any) (*FullValue, error) { +func (n *invoker) Invoke(ctx context.Context, pn typex.PaneInfo, ws []typex.Window, ts typex.EventTime, opt *MainInput, bf *bundleFinalizer, we sdf.WatermarkEstimator, sa UserStateAdapter, sr StateReader, extra ...any) (*FullValue, error) { + return n.invokeWithOpts(ctx, pn, ws, ts, InvokeOpts{opt: opt, bf: bf, we: we, sa: sa, sr: sr, extra: extra}) +} + +func (n *invoker) invokeWithOpts(ctx context.Context, pn typex.PaneInfo, ws []typex.Window, ts typex.EventTime, opts InvokeOpts) (*FullValue, error) { // (1) Populate contexts // extract these to make things easier to read. args := n.args @@ -178,7 +219,7 @@ func (n *invoker) Invoke(ctx context.Context, pn typex.PaneInfo, ws []typex.Wind } if n.wndIdx >= 0 { if len(ws) != 1 { - return nil, errors.Errorf("DoFns that observe windows must be invoked with single window: %v", opt.Key.Windows) + return nil, errors.Errorf("DoFns that observe windows must be invoked with single window: %v", opts.opt.Key.Windows) } args[n.wndIdx] = ws[0] } @@ -186,14 +227,14 @@ func (n *invoker) Invoke(ctx context.Context, pn typex.PaneInfo, ws []typex.Wind args[n.etIdx] = ts } if n.bfIdx >= 0 { - args[n.bfIdx] = bf + args[n.bfIdx] = opts.bf } if n.weIdx >= 0 { - args[n.weIdx] = we + args[n.weIdx] = opts.we } if n.spIdx >= 0 { - sp, err := sa.NewStateProvider(ctx, reader, ws[0], opt) + sp, err := opts.sa.NewStateProvider(ctx, opts.sr, ws[0], opts.opt) if err != nil { return nil, err } @@ -201,29 +242,38 @@ func (n *invoker) Invoke(ctx context.Context, pn typex.PaneInfo, ws []typex.Wind args[n.spIdx] = n.sp } + if n.tpIdx >= 0 { + log.Debugf(ctx, "timercall %+v", opts) + tp, err := opts.ta.NewTimerProvider(ctx, opts.tm, ws, opts.opt) + if err != nil { + return nil, err + } + n.tp = &tp + args[n.tpIdx] = n.tp + } // (2) Main input from value, if any. i := 0 - if opt != nil { - if opt.RTracker != nil { - args[in[i]] = opt.RTracker + if opts.opt != nil { + if opts.opt.RTracker != nil { + args[in[i]] = opts.opt.RTracker i++ } if n.elmConvert == nil { - from := reflect.TypeOf(opt.Key.Elm) + from := reflect.TypeOf(opts.opt.Key.Elm) n.elmConvert = ConvertFn(from, fn.Param[in[i]].T) } - args[in[i]] = n.elmConvert(opt.Key.Elm) + args[in[i]] = n.elmConvert(opts.opt.Key.Elm) i++ - if opt.Key.Elm2 != nil { + if opts.opt.Key.Elm2 != nil { if n.elm2Convert == nil { - from := reflect.TypeOf(opt.Key.Elm2) + from := reflect.TypeOf(opts.opt.Key.Elm2) n.elm2Convert = ConvertFn(from, fn.Param[in[i]].T) } - args[in[i]] = n.elm2Convert(opt.Key.Elm2) + args[in[i]] = n.elm2Convert(opts.opt.Key.Elm2) i++ } - for _, iter := range opt.Values { + for _, iter := range opts.opt.Values { param := fn.Param[in[i]] if param.Kind != funcx.FnIter { @@ -243,7 +293,7 @@ func (n *invoker) Invoke(ctx context.Context, pn typex.PaneInfo, ws []typex.Wind } // (3) Precomputed side input and emitters (or other output). 
- for _, arg := range extra { + for _, arg := range opts.extra { args[in[i]] = arg i++ } diff --git a/sdks/go/pkg/beam/core/runtime/exec/pardo.go b/sdks/go/pkg/beam/core/runtime/exec/pardo.go index 8cb5342ded87..be5c71f2c75a 100644 --- a/sdks/go/pkg/beam/core/runtime/exec/pardo.go +++ b/sdks/go/pkg/beam/core/runtime/exec/pardo.go @@ -48,8 +48,10 @@ type ParDo struct { bf *bundleFinalizer we sdf.WatermarkEstimator - reader StateReader - cache *cacheElm + Timer UserTimerAdapter + timerManager DataManager + reader StateReader + cache *cacheElm status Status err errorx.GuardedError @@ -88,7 +90,7 @@ func (n *ParDo) Up(ctx context.Context) error { // Subsequent bundles might run this same node, and the context here would be // incorrectly refering to the older bundleId. setupCtx := metrics.SetPTransformID(ctx, n.PID) - if _, err := InvokeWithoutEventTime(setupCtx, n.Fn.SetupFn(), nil, nil, nil, nil, nil); err != nil { + if _, err := InvokeWithOptsWithoutEventTime(setupCtx, n.Fn.SetupFn(), InvokeOpts{}); err != nil { return n.fail(err) } @@ -111,6 +113,7 @@ func (n *ParDo) StartBundle(ctx context.Context, id string, data DataContext) er } n.status = Active n.reader = data.State + n.timerManager = data.Data // Allocating contexts all the time is expensive, but we seldom re-write them, // and never accept modified contexts from users, so we will cache them per-bundle // per-unit, to avoid the constant allocation overhead. @@ -236,6 +239,7 @@ func (n *ParDo) FinishBundle(_ context.Context) error { } n.reader = nil n.cache = nil + n.timerManager = nil if err := MultiFinishBundle(n.ctx, n.Out...); err != nil { return n.fail(err) @@ -251,8 +255,9 @@ func (n *ParDo) Down(ctx context.Context) error { n.status = Down n.reader = nil n.cache = nil + n.timerManager = nil - if _, err := InvokeWithoutEventTime(ctx, n.Fn.TeardownFn(), nil, nil, nil, nil, nil); err != nil { + if _, err := InvokeWithOptsWithoutEventTime(ctx, n.Fn.TeardownFn(), InvokeOpts{}); err != nil { n.err.TrySetError(err) } return n.err.Error() @@ -356,7 +361,7 @@ func (n *ParDo) invokeProcessFn(ctx context.Context, pn typex.PaneInfo, ws []typ if err := n.preInvoke(ctx, ws, ts); err != nil { return nil, err } - val, err = n.inv.Invoke(ctx, pn, ws, ts, opt, n.bf, n.we, n.UState, n.reader, n.cache.extra...) + val, err = n.inv.invokeWithOpts(ctx, pn, ws, ts, InvokeOpts{opt: opt, bf: n.bf, we: n.we, sa: n.UState, sr: n.reader, ta: n.Timer, tm: n.timerManager, extra: n.cache.extra}) if err != nil { return nil, err } diff --git a/sdks/go/pkg/beam/core/runtime/exec/timers.go b/sdks/go/pkg/beam/core/runtime/exec/timers.go new file mode 100644 index 000000000000..0ceed0d2ebd6 --- /dev/null +++ b/sdks/go/pkg/beam/core/runtime/exec/timers.go @@ -0,0 +1,125 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package exec + +import ( + "context" + "fmt" + "io" + + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/coder" + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/timers" + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex" +) + +type UserTimerAdapter interface { + NewTimerProvider(ctx context.Context, manager DataManager, w []typex.Window, element *MainInput) (timerProvider, error) +} + +type userTimerAdapter struct { + SID StreamID + wc WindowEncoder + kc ElementEncoder + timerIDToCoder map[string]*coder.Coder + C *coder.Coder +} + +func NewUserTimerAdapter(sID StreamID, c *coder.Coder, timerCoders map[string]*coder.Coder) UserTimerAdapter { + if !coder.IsW(c) { + panic(fmt.Sprintf("expected WV coder for user timer %v: %v", sID, c)) + } + + wc := MakeWindowEncoder(c.Window) + var kc ElementEncoder + if coder.IsKV(coder.SkipW(c)) { + kc = MakeElementEncoder(coder.SkipW(c).Components[0]) + } + + return &userTimerAdapter{SID: sID, wc: wc, kc: kc, C: c, timerIDToCoder: timerCoders} +} + +func (u *userTimerAdapter) NewTimerProvider(ctx context.Context, manager DataManager, w []typex.Window, element *MainInput) (timerProvider, error) { + if u.kc == nil { + return timerProvider{}, fmt.Errorf("cannot make a state provider for an unkeyed input %v", element) + } + elementKey, err := EncodeElement(u.kc, element.Key.Elm) + if err != nil { + return timerProvider{}, err + } + + // win, err := EncodeWindow(u.wc, w[0]) + // if err != nil { + // return timerProvider{}, err + // } + tp := timerProvider{ + ctx: ctx, + tm: manager, + elementKey: elementKey, + SID: u.SID, + window: w, + writersByFamily: make(map[string]io.Writer), + codersByFamily: u.timerIDToCoder, + } + + return tp, nil +} + +type timerProvider struct { + ctx context.Context + tm DataManager + SID StreamID + elementKey []byte + window []typex.Window + + pn typex.PaneInfo + + writersByFamily map[string]io.Writer + codersByFamily map[string]*coder.Coder +} + +func (p *timerProvider) getWriter(family string) (io.Writer, error) { + if w, ok := p.writersByFamily[family]; ok { + return w, nil + } else { + w, err := p.tm.OpenTimerWrite(p.ctx, p.SID, family) + if err != nil { + return nil, err + } + p.writersByFamily[family] = w + return p.writersByFamily[family], nil + } +} + +func (p *timerProvider) Set(t timers.TimerMap) { + w, err := p.getWriter(t.Family) + if err != nil { + panic(err) + } + tm := typex.TimerMap{ + Key: string(p.elementKey), + Tag: t.Tag, + Windows: p.window, + Clear: t.Clear, + FireTimestamp: t.FireTimestamp, + HoldTimestamp: t.HoldTimestamp, + Pane: p.pn, + } + fv := FullValue{Elm: tm} + enc := MakeElementEncoder(coder.SkipW(p.codersByFamily[t.Family])) + if err := enc.Encode(&fv, w); err != nil { + panic(err) + } +} diff --git a/sdks/go/pkg/beam/core/runtime/exec/translate.go b/sdks/go/pkg/beam/core/runtime/exec/translate.go index 78cf0ef65cd6..0403f0ab0abb 100644 --- a/sdks/go/pkg/beam/core/runtime/exec/translate.go +++ b/sdks/go/pkg/beam/core/runtime/exec/translate.go @@ -16,6 +16,7 @@ package exec import ( + "context" "fmt" "math/rand" "strconv" @@ -30,6 +31,7 @@ import ( "github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/util/protox" "github.com/apache/beam/sdks/v2/go/pkg/beam/internal/errors" + "github.com/apache/beam/sdks/v2/go/pkg/beam/log" fnpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/fnexecution_v1" pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1" "github.com/golang/protobuf/proto" @@ -462,6 +464,7 @@ func (b *builder) 
makeLink(from string, id linkID) (Node, error) { var data string var sides map[string]*pipepb.SideInput var userState map[string]*pipepb.StateSpec + var userTimers map[string]*pipepb.TimerFamilySpec switch urn { case graphx.URNParDo, urnPairWithRestriction, @@ -475,6 +478,7 @@ func (b *builder) makeLink(from string, id linkID) (Node, error) { data = string(pardo.GetDoFn().GetPayload()) sides = pardo.GetSideInputs() userState = pardo.GetStateSpecs() + userTimers = pardo.GetTimerFamilySpecs() case urnPerKeyCombinePre, urnPerKeyCombineMerge, urnPerKeyCombineExtract, urnPerKeyCombineConvert: var cmb pipepb.CombinePayload if err := proto.Unmarshal(payload, &cmb); err != nil { @@ -587,6 +591,24 @@ func (b *builder) makeLink(from string, id linkID) (Node, error) { n.UState = NewUserStateAdapter(sid, coder.NewW(ec, wc), stateIDToCoder, stateIDToKeyCoder, stateIDToCombineFn) } } + if len(userTimers) > 0 { + log.Debugf(context.TODO(), "userTimers %+v", userTimers) + timerIDToCoder := make(map[string]*coder.Coder) + for key, spec := range userTimers { + cID := spec.GetTimerFamilyCoderId() + c, err := b.coders.Coder(cID) + if err != nil { + return nil, err + } + timerIDToCoder[key] = c + sID := StreamID{Port: Port{URL: b.desc.GetTimerApiServiceDescriptor().GetUrl()}, PtransformID: id.to} + ec, wc, err := b.makeCoderForPCollection(input[0]) + if err != nil { + return nil, err + } + n.Timer = NewUserTimerAdapter(sID, coder.NewW(ec, wc), timerIDToCoder) + } + } for i := 1; i < len(input); i++ { // TODO(https://github.com/apache/beam/issues/18602) Handle ViewFns for side inputs diff --git a/sdks/go/pkg/beam/core/runtime/graphx/coder.go b/sdks/go/pkg/beam/core/runtime/graphx/coder.go index 498e145f5db4..34b44dd85920 100644 --- a/sdks/go/pkg/beam/core/runtime/graphx/coder.go +++ b/sdks/go/pkg/beam/core/runtime/graphx/coder.go @@ -73,7 +73,7 @@ func knownStandardCoders() []string { urnIntervalWindow, urnRowCoder, urnNullableCoder, - // TODO(https://github.com/apache/beam/issues/20510): Add urnTimerCoder once finalized. 
+	urnTimerCoder,
 	}
 }
 
diff --git a/sdks/go/pkg/beam/core/runtime/graphx/serialize.go b/sdks/go/pkg/beam/core/runtime/graphx/serialize.go
index 65ad1bdd0600..fb62e18e1306 100644
--- a/sdks/go/pkg/beam/core/runtime/graphx/serialize.go
+++ b/sdks/go/pkg/beam/core/runtime/graphx/serialize.go
@@ -28,6 +28,7 @@ import (
 	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime"
 	v1pb "github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/graphx/v1"
 	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/state"
+	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/timers"
 	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex"
 	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/util/jsonx"
 	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/util/reflectx"
@@ -523,6 +524,8 @@ func tryEncodeSpecial(t reflect.Type) (v1pb.Type_Special, bool) {
 		return v1pb.Type_BUNDLEFINALIZATION, true
 	case state.ProviderType:
 		return v1pb.Type_STATEPROVIDER, true
+	case timers.ProviderType:
+		return v1pb.Type_TIMERPROVIDER, true
 	case typex.KVType:
 		return v1pb.Type_KV, true
 	case typex.CoGBKType:
@@ -689,6 +692,8 @@ func decodeSpecial(s v1pb.Type_Special) (reflect.Type, error) {
 		return typex.BundleFinalizationType, nil
 	case v1pb.Type_STATEPROVIDER:
 		return state.ProviderType, nil
+	case v1pb.Type_TIMERPROVIDER:
+		return timers.ProviderType, nil
 	case v1pb.Type_KV:
 		return typex.KVType, nil
 	case v1pb.Type_COGBK:
diff --git a/sdks/go/pkg/beam/core/runtime/graphx/translate.go b/sdks/go/pkg/beam/core/runtime/graphx/translate.go
index 68074ac7eb3a..a427f22f8824 100644
--- a/sdks/go/pkg/beam/core/runtime/graphx/translate.go
+++ b/sdks/go/pkg/beam/core/runtime/graphx/translate.go
@@ -578,6 +578,21 @@ func (m *marshaller) addMultiEdge(edge NamedEdge) ([]string, error) {
 		}
 		payload.StateSpecs = stateSpecs
 	}
+	if _, ok := edge.Edge.DoFn.ProcessElementFn().TimerProvider(); ok {
+		m.requirements[URNRequiresStatefulProcessing] = true
+		timerSpecs := make(map[string]*pipepb.TimerFamilySpec)
+		for _, pt := range edge.Edge.DoFn.PipelineTimers() {
+			coderID, err := m.coders.Add(edge.Edge.TimerCoders[pt.TimerFamily()])
+			if err != nil {
+				return handleErr(err)
+			}
+			timerSpecs[pt.TimerFamily()] = &pipepb.TimerFamilySpec{
+				TimeDomain:         pipepb.TimeDomain_Enum(pt.TimerDomain()),
+				TimerFamilyCoderId: coderID,
+			}
+		}
+		payload.TimerFamilySpecs = timerSpecs
+	}
 
 	spec = &pipepb.FunctionSpec{Urn: URNParDo, Payload: protox.MustEncode(payload)}
 	annotations = edge.Edge.DoFn.Annotations()
diff --git a/sdks/go/pkg/beam/core/timers/timers.go b/sdks/go/pkg/beam/core/timers/timers.go
new file mode 100644
index 000000000000..130564790ca6
--- /dev/null
+++ b/sdks/go/pkg/beam/core/timers/timers.go
@@ -0,0 +1,119 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package timers provides structs for reading and writing timers.
+package timers
+
+import (
+	"context"
+	"reflect"
+	"time"
+
+	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/mtime"
+	"github.com/apache/beam/sdks/v2/go/pkg/beam/log"
+)
+
+var (
+	ProviderType = reflect.TypeOf((*Provider)(nil)).Elem()
+)
+
+type TimeDomainEnum int32
+
+const (
+	TimeDomainUnspecified    TimeDomainEnum = 0
+	TimeDomainEventTime      TimeDomainEnum = 1
+	TimeDomainProcessingTime TimeDomainEnum = 2
+)
+
+type TimerMap struct {
+	Family                       string
+	Tag                          string
+	Clear                        bool
+	FireTimestamp, HoldTimestamp mtime.Time
+}
+
+type Provider interface {
+	Set(t TimerMap)
+}
+
+type EventTime struct {
+	// need to export them otherwise the key comes out empty in execution?
+	Family string
+}
+
+func (t EventTime) Set(p Provider, firingTimestamp time.Time) {
+	fire := mtime.FromTime(firingTimestamp)
+	// Hold timestamp must match firing timestamp if not otherwise set.
+	p.Set(TimerMap{Family: t.Family, FireTimestamp: fire, HoldTimestamp: fire})
+}
+
+type Opts struct {
+	Tag  string
+	Hold time.Time
+}
+
+func (t *EventTime) SetWithOpts(p Provider, firingTimestamp time.Time, opts Opts) {
+	fire := mtime.FromTime(firingTimestamp)
+	// Hold timestamp must match firing timestamp if not otherwise set.
+	tm := TimerMap{Family: t.Family, Tag: opts.Tag, FireTimestamp: fire, HoldTimestamp: fire}
+	if !opts.Hold.IsZero() {
+		tm.HoldTimestamp = mtime.FromTime(opts.Hold)
+	}
+	p.Set(tm)
+}
+
+func (e EventTime) TimerFamily() string {
+	return e.Family
+}
+
+func (e EventTime) TimerDomain() TimeDomainEnum {
+	return TimeDomainEventTime
+}
+
+type ProcessingTime struct {
+	Family string
+}
+
+func (e ProcessingTime) TimerFamily() string {
+	return e.Family
+}
+
+func (e ProcessingTime) TimerDomain() TimeDomainEnum {
+	return TimeDomainProcessingTime
+}
+
+func (t ProcessingTime) Set(p Provider, firingTimestamp time.Time) {
+	log.Infof(context.Background(), "setting timer in core/timer: %+v", t)
+	fire := mtime.FromTime(firingTimestamp)
+	p.Set(TimerMap{Family: t.Family, FireTimestamp: fire, HoldTimestamp: fire})
+}
+
+func (t ProcessingTime) SetWithOpts(p Provider, firingTimestamp time.Time, opts Opts) {
+	fire := mtime.FromTime(firingTimestamp)
+	// Hold timestamp must match firing timestamp if not otherwise set.
+	tm := TimerMap{Family: t.Family, Tag: opts.Tag, FireTimestamp: fire, HoldTimestamp: fire}
+	if !opts.Hold.IsZero() {
+		tm.HoldTimestamp = mtime.FromTime(opts.Hold)
+	}
+	p.Set(tm)
+}
+
+func InEventTime(Key string) EventTime {
+	return EventTime{Family: Key}
+}
+
+func InProcessingTime(Key string) ProcessingTime {
+	return ProcessingTime{Family: Key}
+}
diff --git a/sdks/go/pkg/beam/core/typex/class.go b/sdks/go/pkg/beam/core/typex/class.go
index e112495ee986..63e4543a3e54 100644
--- a/sdks/go/pkg/beam/core/typex/class.go
+++ b/sdks/go/pkg/beam/core/typex/class.go
@@ -120,6 +120,7 @@ func isConcrete(t reflect.Type, visited map[uintptr]bool) error {
 		t == EventTimeType ||
 		t.Implements(WindowType) ||
 		t == PaneInfoType ||
+		t == TimersType ||
 		t == BundleFinalizationType ||
 		t == reflectx.Error ||
 		t == reflectx.Context ||
diff --git a/sdks/go/pkg/beam/core/typex/special.go b/sdks/go/pkg/beam/core/typex/special.go
index 935371225848..4cce19dcfbd3 100644
--- a/sdks/go/pkg/beam/core/typex/special.go
+++ b/sdks/go/pkg/beam/core/typex/special.go
@@ -107,7 +107,8 @@ type Timers struct {
 
 // TimerMap is a placeholder for timer details used in encoding/decoding.
type TimerMap struct { - Key, Tag string + Key string + Tag string Windows []Window // []typex.Window Clear bool FireTimestamp, HoldTimestamp mtime.Time diff --git a/sdks/go/pkg/beam/pardo.go b/sdks/go/pkg/beam/pardo.go index 1314836dfdc2..5de2854c0387 100644 --- a/sdks/go/pkg/beam/pardo.go +++ b/sdks/go/pkg/beam/pardo.go @@ -116,6 +116,18 @@ func TryParDo(s Scope, dofn any, col PCollection, opts ...Option) ([]PCollection } } + wc := inWfn.Coder() + pipelineTimers := fn.PipelineTimers() + if len(pipelineTimers) > 0 { + // TODO(riteshghorse): replace the coder with type of key + c := coder.NewString() + edge.TimerCoders = make(map[string]*coder.Coder) + for _, pt := range pipelineTimers { + tc := coder.NewT(c, wc) + edge.TimerCoders[pt.TimerFamily()] = tc + } + } + var ret []PCollection for _, out := range edge.Output { c := PCollection{out.To} From f2a1b8ae45a4e1d0e3a35bfbe99383cc2d3f6d40 Mon Sep 17 00:00:00 2001 From: lostluck <13907733+lostluck@users.noreply.github.com> Date: Wed, 1 Mar 2023 12:50:04 -0800 Subject: [PATCH 2/5] Timers received! --- sdks/go/examples/streaming_wordcap/wordcap.go | 211 +++++++++- sdks/go/pkg/beam/core/runtime/exec/data.go | 16 +- .../pkg/beam/core/runtime/exec/datasource.go | 161 +++++--- .../beam/core/runtime/exec/datasource_test.go | 9 +- sdks/go/pkg/beam/core/runtime/exec/fn.go | 6 +- sdks/go/pkg/beam/core/runtime/exec/timers.go | 18 +- .../pkg/beam/core/runtime/harness/datamgr.go | 374 +++++++++++------- .../beam/core/runtime/harness/datamgr_test.go | 141 ------- sdks/go/pkg/beam/core/timers/timers.go | 2 +- 9 files changed, 560 insertions(+), 378 deletions(-) diff --git a/sdks/go/examples/streaming_wordcap/wordcap.go b/sdks/go/examples/streaming_wordcap/wordcap.go index ddd9eab4e5f8..441d7f6d3244 100644 --- a/sdks/go/examples/streaming_wordcap/wordcap.go +++ b/sdks/go/examples/streaming_wordcap/wordcap.go @@ -26,16 +26,21 @@ package main import ( "context" "flag" + "fmt" "os" - "strings" + "time" "github.com/apache/beam/sdks/v2/go/pkg/beam" - "github.com/apache/beam/sdks/v2/go/pkg/beam/io/pubsubio" + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/mtime" + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/sdf" + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/state" + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/timers" + "github.com/apache/beam/sdks/v2/go/pkg/beam/io/rtrackers/offsetrange" "github.com/apache/beam/sdks/v2/go/pkg/beam/log" - "github.com/apache/beam/sdks/v2/go/pkg/beam/options/gcpopts" - "github.com/apache/beam/sdks/v2/go/pkg/beam/util/pubsubx" + "github.com/apache/beam/sdks/v2/go/pkg/beam/register" "github.com/apache/beam/sdks/v2/go/pkg/beam/x/beamx" "github.com/apache/beam/sdks/v2/go/pkg/beam/x/debug" + "golang.org/x/exp/slog" ) var ( @@ -50,32 +55,204 @@ var ( } ) +type Stateful struct { + ElementBag state.Bag[string] + TimerTime state.Value[int64] + MinTime state.Combining[int64, int64, int64] + + OutputState timers.ProcessingTime +} + +func NewStateful() *Stateful { + return &Stateful{ + ElementBag: state.MakeBagState[string]("elementBag"), + TimerTime: state.MakeValueState[int64]("timerTime"), + MinTime: state.MakeCombiningState[int64, int64, int64]("minTiInBag", func(a, b int64) int64 { + if a < b { + return a + } + return b + }), + + OutputState: timers.InProcessingTime("outputState"), + } +} + +func (s *Stateful) ProcessElement(ctx context.Context, ts beam.EventTime, sp state.Provider, tp timers.Provider, key, word string, emit func(string, string)) error { + log.Infof(ctx, "stateful dofn invoked key: %v word: %v", 
key, word) + + s.ElementBag.Add(sp, word) + s.MinTime.Add(sp, int64(ts)) + + toFire, ok, err := s.TimerTime.Read(sp) + if err != nil { + return err + } + if !ok { + toFire = int64(mtime.Now().Add(1 * time.Minute)) + } + minTime, _, err := s.MinTime.Read(sp) + if err != nil { + return err + } + + s.OutputState.SetWithOpts(tp, mtime.Time(toFire).ToTime(), timers.Opts{Hold: mtime.Time(minTime).ToTime()}) + s.TimerTime.Write(sp, toFire) + log.Infof(ctx, "stateful dofn key: %v word: %v, timer: %v, minTime: %v", key, word, toFire, minTime) + + // // Get the Value stored in our state + // val, ok, err := s.Val.Read(p) + // if err != nil { + // return err + // } + // log.Infof(ctx, "stateful dofn state read key: %v word: %v val: %v", key, word, val) + // if !ok { + // s.Val.Write(p, 1) + // } else { + // s.Val.Write(p, val+1) + // } + + // if val > 5 { + // log.Infof(ctx, "stateful dofn clearing key: %v word: %v val: %v", key, word, val) + // // Example of clearing and starting again with an empty bag + // s.Val.Clear(p) + // } + // fire := time.Now().Add(10 * time.Second) + + // log.Infof(ctx, "stateful dofn timer family: %v fire: %v now: %v key: %v word: %v", s.Fire.Family, fire, time.Now(), key, word) + // s.Fire.Set(tp, fire) + + // emit(key, word) + + return nil +} + +type eventtimeSDFStream struct { + RestSize, Mod, Fixed int64 + Sleep time.Duration +} + +func (fn *eventtimeSDFStream) Setup() error { + return nil +} + +func (fn *eventtimeSDFStream) CreateInitialRestriction(v beam.T) offsetrange.Restriction { + return offsetrange.Restriction{Start: 0, End: fn.RestSize} +} + +func (fn *eventtimeSDFStream) SplitRestriction(v beam.T, r offsetrange.Restriction) []offsetrange.Restriction { + // No split + return []offsetrange.Restriction{r} +} + +func (fn *eventtimeSDFStream) RestrictionSize(v beam.T, r offsetrange.Restriction) float64 { + return r.Size() +} + +func (fn *eventtimeSDFStream) CreateTracker(r offsetrange.Restriction) *sdf.LockRTracker { + return sdf.NewLockRTracker(offsetrange.NewTracker(r)) +} + +func (fn *eventtimeSDFStream) ProcessElement(ctx context.Context, _ *CWE, rt *sdf.LockRTracker, v beam.T, emit func(beam.EventTime, int64)) sdf.ProcessContinuation { + r := rt.GetRestriction().(offsetrange.Restriction) + i := r.Start + if r.Size() < 1 { + log.Debugf(ctx, "size 0 restriction, stoping to process sentinel", slog.Any("value", v)) + return sdf.StopProcessing() + } + slog.Debug("emitting element to restriction", slog.Any("value", v), slog.Group("restriction", + slog.Any("value", v), + slog.Float64("size", r.Size()), + slog.Int64("pos", i), + )) + if rt.TryClaim(i) { + v := (i % fn.Mod) + fn.Fixed + emit(mtime.Now(), v) + } + return sdf.ResumeProcessingIn(fn.Sleep) +} + +func (fn *eventtimeSDFStream) InitialWatermarkEstimatorState(_ beam.EventTime, _ offsetrange.Restriction, _ beam.T) int64 { + return int64(mtime.MinTimestamp) +} + +func (fn *eventtimeSDFStream) CreateWatermarkEstimator(initialState int64) *CWE { + return &CWE{Watermark: initialState} +} + +func (fn *eventtimeSDFStream) WatermarkEstimatorState(e *CWE) int64 { + return e.Watermark +} + +type CWE struct { + Watermark int64 // uses int64, since the SDK prevent mtime.Time from serialization. 
+} + +func (e *CWE) CurrentWatermark() time.Time { + return mtime.Time(e.Watermark).ToTime() +} + +func (e *CWE) ObserveTimestamp(ts time.Time) { + // We add 10 milliseconds to allow window boundaries to + // progress after emitting + e.Watermark = int64(mtime.FromTime(ts.Add(-90 * time.Millisecond))) +} + +func init() { + register.DoFn7x1[context.Context, beam.EventTime, state.Provider, timers.Provider, string, string, func(string, string), error](&Stateful{}) + register.Emitter2[string, string]() + register.DoFn5x1[context.Context, *CWE, *sdf.LockRTracker, beam.T, func(beam.EventTime, int64), sdf.ProcessContinuation]((*eventtimeSDFStream)(nil)) + register.Emitter2[beam.EventTime, int64]() +} + func main() { flag.Parse() beam.Init() ctx := context.Background() - project := gcpopts.GetProject(ctx) + //project := gcpopts.GetProject(ctx) log.Infof(ctx, "Publishing %v messages to: %v", len(data), *input) - defer pubsubx.CleanupTopic(ctx, project, *input) - sub, err := pubsubx.Publish(ctx, project, *input, data...) - if err != nil { - log.Fatal(ctx, err) - } + // defer pubsubx.CleanupTopic(ctx, project, *input) + // sub, err := pubsubx.Publish(ctx, project, *input, data...) + // if err != nil { + // log.Fatal(ctx, err) + // } - log.Infof(ctx, "Running streaming wordcap with subscription: %v", sub.ID()) + //log.Infof(ctx, "Running streaming wordcap with subscription: %v", sub.ID()) p := beam.NewPipeline() s := p.Root() - col := pubsubio.Read(s, project, *input, &pubsubio.ReadOptions{Subscription: sub.ID()}) - str := beam.ParDo(s, func(b []byte) string { - return (string)(b) - }, col) - cap := beam.ParDo(s, strings.ToUpper, str) - debug.Print(s, cap) + //col := pubsubio.Read(s, project, *input, &pubsubio.ReadOptions{Subscription: sub.ID()}) + // col = beam.WindowInto(s, window.NewFixedWindows(60*time.Second), col) + + // str := beam.ParDo(s, func(b []byte) string { + // return (string)(b) + // }, col) + + imp := beam.Impulse(s) + elms := 100 + out := beam.ParDo(s, &eventtimeSDFStream{ + Sleep: time.Second, + RestSize: int64(elms), + Mod: int64(elms), + Fixed: 1, + }, imp) + // out = beam.WindowInto(s, window.NewFixedWindows(10*time.Second), out) + str := beam.ParDo(s, func(b int64) string { + return fmt.Sprintf("element%03d", b) + }, out) + + keyed := beam.ParDo(s, func(ctx context.Context, ts beam.EventTime, s string) (string, string) { + log.Infof(ctx, "adding key ts: %v now: %v word: %v", ts.ToTime(), time.Now(), s) + return "test", s + }, str) + debug.Printf(s, "pre stateful: %v", keyed) + + timed := beam.ParDo(s, NewStateful(), keyed) + debug.Printf(s, "post stateful: %v", timed) if err := beamx.Run(context.Background(), p); err != nil { log.Exitf(ctx, "Failed to execute job: %v", err) diff --git a/sdks/go/pkg/beam/core/runtime/exec/data.go b/sdks/go/pkg/beam/core/runtime/exec/data.go index fdc1e368a52b..9380bb8902bd 100644 --- a/sdks/go/pkg/beam/core/runtime/exec/data.go +++ b/sdks/go/pkg/beam/core/runtime/exec/data.go @@ -57,10 +57,12 @@ type SideCache interface { // DataManager manages external data byte streams. Each data stream can be // opened by one consumer only. type DataManager interface { - // OpenRead opens a closable byte stream for reading. - OpenRead(ctx context.Context, id StreamID) (io.ReadCloser, error) - // OpenWrite opens a closable byte stream for writing. + // OpenElementChan opens a channel for data and timers. + OpenElementChan(ctx context.Context, id StreamID) (<-chan Elements, error) + // OpenWrite opens a closable byte stream for data writing. 
OpenWrite(ctx context.Context, id StreamID) (io.WriteCloser, error) + // OpenTimerWrite opens a byte stream for writing timers + OpenTimerWrite(ctx context.Context, id StreamID, family string) (io.WriteCloser, error) } // StateReader is the interface for reading side input data. @@ -91,4 +93,10 @@ type StateReader interface { GetSideInputCache() SideCache } -// TODO(herohde) 7/20/2018: user state management +// Elements holds data or timers sent across the data channel. +// If TimerFamilyID is populated, it's a timer, otherwise it's +// data elements. +type Elements struct { + Data, Timers []byte + TimerFamilyID string +} diff --git a/sdks/go/pkg/beam/core/runtime/exec/datasource.go b/sdks/go/pkg/beam/core/runtime/exec/datasource.go index a6347fc8d0e1..12a5faeaf30d 100644 --- a/sdks/go/pkg/beam/core/runtime/exec/datasource.go +++ b/sdks/go/pkg/beam/core/runtime/exec/datasource.go @@ -97,17 +97,60 @@ func (n *DataSource) StartBundle(ctx context.Context, id string, data DataContex n.source = data.Data n.state = data.State n.start = time.Now() - n.index = -1 + n.index = 0 n.splitIdx = math.MaxInt64 n.mu.Unlock() return n.Out.StartBundle(ctx, id, data) } +// process handles converting elements from the data source to timers. +func (n *DataSource) process(ctx context.Context, data func(bcr *byteCountReader) error, timer func(bcr *byteCountReader, timerFamilyID string) error) error { + elms, err := n.source.OpenElementChan(ctx, n.SID) + if err != nil { + return err + } + + n.PCol.resetSize() // initialize the size distribution for this bundle. + var r bytes.Reader + + var byteCount int + bcr := byteCountReader{reader: &r, count: &byteCount} + for { + var err error + select { + case e, ok := <-elms: + // Channel closed, so time to exit + if !ok { + return nil + } + if len(e.Data) > 0 { + r.Reset(e.Data) + log.Debugf(ctx, "%v: received %v", n, e.Data) + err = data(&bcr) + } + if len(e.Timers) > 0 { + r.Reset(e.Timers) + err = timer(&bcr, e.TimerFamilyID) + } + case <-ctx.Done(): + return nil + } + + if err != nil { + if err != io.EOF { + return errors.Wrap(err, "source failed") + } + // io.EOF means the reader successfully drained + // We're ready for a new buffer. + } + } +} + // ByteCountReader is a passthrough reader that counts all the bytes read through it. // It trusts the nested reader to return accurate byte information. type byteCountReader struct { count *int - reader io.ReadCloser + reader io.Reader } func (r *byteCountReader) Read(p []byte) (int, error) { @@ -117,7 +160,10 @@ func (r *byteCountReader) Read(p []byte) (int, error) { } func (r *byteCountReader) Close() error { - return r.reader.Close() + if c, ok := r.reader.(io.Closer); ok { + c.Close() + } + return nil } func (r *byteCountReader) reset() int { @@ -128,15 +174,6 @@ func (r *byteCountReader) reset() int { // Process opens the data source, reads and decodes data, kicking off element processing. func (n *DataSource) Process(ctx context.Context) ([]*Checkpoint, error) { - r, err := n.source.OpenRead(ctx, n.SID) - if err != nil { - return nil, err - } - defer r.Close() - n.PCol.resetSize() // initialize the size distribution for this bundle. 
- var byteCount int - bcr := byteCountReader{reader: r, count: &byteCount} - c := coder.SkipW(n.Coder) wc := MakeWindowDecoder(n.Coder.Window) @@ -155,58 +192,68 @@ func (n *DataSource) Process(ctx context.Context) ([]*Checkpoint, error) { } var checkpoints []*Checkpoint - for { - if n.incrementIndexAndCheckSplit() { - break - } - // TODO(lostluck) 2020/02/22: Should we include window headers or just count the element sizes? - ws, t, pn, err := DecodeWindowedValueHeader(wc, r) - if err != nil { - if err == io.EOF { - break + err := n.process(ctx, func(bcr *byteCountReader) error { + for { + // TODO(lostluck) 2020/02/22: Should we include window headers or just count the element sizes? + ws, t, pn, err := DecodeWindowedValueHeader(wc, bcr.reader) + if err != nil { + return err } - return nil, errors.Wrap(err, "source failed") - } - - // Decode key or parallel element. - pe, err := cp.Decode(&bcr) - if err != nil { - return nil, errors.Wrap(err, "source decode failed") - } - pe.Timestamp = t - pe.Windows = ws - pe.Pane = pn - var valReStreams []ReStream - for _, cv := range cvs { - values, err := n.makeReStream(ctx, cv, &bcr, len(cvs) == 1 && n.singleIterate) + // Decode key or parallel element. + pe, err := cp.Decode(bcr) if err != nil { - return nil, err + return errors.Wrap(err, "source decode failed") } - valReStreams = append(valReStreams, values) - } + pe.Timestamp = t + pe.Windows = ws + pe.Pane = pn - if err := n.Out.ProcessElement(ctx, pe, valReStreams...); err != nil { - return nil, err - } - // Collect the actual size of the element, and reset the bytecounter reader. - n.PCol.addSize(int64(bcr.reset())) - bcr.reader = r - - // Check if there's a continuation and return residuals - // Needs to be done immeadiately after processing to not lose the element. - if c := n.getProcessContinuation(); c != nil { - cp, err := n.checkpointThis(ctx, c) - if err != nil { - // Errors during checkpointing should fail a bundle. - return nil, err + log.Debugf(ctx, "%v: processing %v,%v", n, pe.Elm, pe.Elm2) + + var valReStreams []ReStream + for _, cv := range cvs { + values, err := n.makeReStream(ctx, cv, bcr, len(cvs) == 1 && n.singleIterate) + if err != nil { + return err + } + valReStreams = append(valReStreams, values) + } + + if err := n.Out.ProcessElement(ctx, pe, valReStreams...); err != nil { + return err + } + // Collect the actual size of the element, and reset the bytecounter reader. + n.PCol.addSize(int64(bcr.reset())) + + // Check if there's a continuation and return residuals + // Needs to be done immeadiately after processing to not lose the element. + if c := n.getProcessContinuation(); c != nil { + cp, err := n.checkpointThis(ctx, c) + if err != nil { + // Errors during checkpointing should fail a bundle. + return err + } + if cp != nil { + checkpoints = append(checkpoints, cp) + } } - if cp != nil { - checkpoints = append(checkpoints, cp) + // We've finished processing an element, check if we have finished a split. + if n.incrementIndexAndCheckSplit() { + break } } - } - return checkpoints, nil + // Signal data loop exit. 
+ log.Debugf(ctx, "%v: exiting data loop", n) + return nil + }, + func(bcr *byteCountReader, timerFamilyID string) error { + tmap, err := decodeTimer(cp, wc, bcr) + log.Errorf(ctx, "timer received: %v - %+v err: %v", timerFamilyID, tmap, err) + return nil + }) + + return checkpoints, err } func (n *DataSource) makeReStream(ctx context.Context, cv ElementDecoder, bcr *byteCountReader, onlyStream bool) (ReStream, error) { @@ -313,7 +360,7 @@ func (n *DataSource) makeReStream(ctx context.Context, cv ElementDecoder, bcr *b } } -func readStreamToBuffer(cv ElementDecoder, r io.ReadCloser, size int64, buf []FullValue) ([]FullValue, error) { +func readStreamToBuffer(cv ElementDecoder, r io.Reader, size int64, buf []FullValue) ([]FullValue, error) { for i := int64(0); i < size; i++ { value, err := cv.Decode(r) if err != nil { diff --git a/sdks/go/pkg/beam/core/runtime/exec/datasource_test.go b/sdks/go/pkg/beam/core/runtime/exec/datasource_test.go index 2da3284f016a..14e954d26afe 100644 --- a/sdks/go/pkg/beam/core/runtime/exec/datasource_test.go +++ b/sdks/go/pkg/beam/core/runtime/exec/datasource_test.go @@ -1018,16 +1018,21 @@ func runOnRoots(ctx context.Context, t *testing.T, p *Plan, name string, mthd fu type TestDataManager struct { R io.ReadCloser + C chan Elements } -func (dm *TestDataManager) OpenRead(ctx context.Context, id StreamID) (io.ReadCloser, error) { - return dm.R, nil +func (dm *TestDataManager) OpenElementChan(ctx context.Context, id StreamID) (<-chan Elements, error) { + return dm.C, nil } func (dm *TestDataManager) OpenWrite(ctx context.Context, id StreamID) (io.WriteCloser, error) { return nil, nil } +func (dm *TestDataManager) OpenTimerWrite(ctx context.Context, id StreamID, key string) (io.WriteCloser, error) { + return nil, nil +} + // TestSideInputReader simulates state reads using channels. 
type TestStateReader struct { StateReader diff --git a/sdks/go/pkg/beam/core/runtime/exec/fn.go b/sdks/go/pkg/beam/core/runtime/exec/fn.go index d0fdb8e36305..c108627a52c2 100644 --- a/sdks/go/pkg/beam/core/runtime/exec/fn.go +++ b/sdks/go/pkg/beam/core/runtime/exec/fn.go @@ -28,7 +28,6 @@ import ( "github.com/apache/beam/sdks/v2/go/pkg/beam/core/sdf" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex" "github.com/apache/beam/sdks/v2/go/pkg/beam/internal/errors" - "github.com/apache/beam/sdks/v2/go/pkg/beam/log" ) //go:generate specialize --input=fn_arity.tmpl @@ -243,11 +242,10 @@ func (n *invoker) invokeWithOpts(ctx context.Context, pn typex.PaneInfo, ws []ty } if n.tpIdx >= 0 { - log.Debugf(ctx, "timercall %+v", opts) - tp, err := opts.ta.NewTimerProvider(ctx, opts.tm, ws, opts.opt) + tp, err := opts.ta.NewTimerProvider(ctx, opts.tm, ts, ws, opts.opt) if err != nil { return nil, err - } + } /* */ n.tp = &tp args[n.tpIdx] = n.tp } diff --git a/sdks/go/pkg/beam/core/runtime/exec/timers.go b/sdks/go/pkg/beam/core/runtime/exec/timers.go index 0ceed0d2ebd6..9e1d52d31a1f 100644 --- a/sdks/go/pkg/beam/core/runtime/exec/timers.go +++ b/sdks/go/pkg/beam/core/runtime/exec/timers.go @@ -23,10 +23,11 @@ import ( "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/coder" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/timers" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex" + "github.com/apache/beam/sdks/v2/go/pkg/beam/log" ) type UserTimerAdapter interface { - NewTimerProvider(ctx context.Context, manager DataManager, w []typex.Window, element *MainInput) (timerProvider, error) + NewTimerProvider(ctx context.Context, manager DataManager, inputTimestamp typex.EventTime, windows []typex.Window, element *MainInput) (timerProvider, error) } type userTimerAdapter struct { @@ -51,7 +52,7 @@ func NewUserTimerAdapter(sID StreamID, c *coder.Coder, timerCoders map[string]*c return &userTimerAdapter{SID: sID, wc: wc, kc: kc, C: c, timerIDToCoder: timerCoders} } -func (u *userTimerAdapter) NewTimerProvider(ctx context.Context, manager DataManager, w []typex.Window, element *MainInput) (timerProvider, error) { +func (u *userTimerAdapter) NewTimerProvider(ctx context.Context, manager DataManager, inputTs typex.EventTime, w []typex.Window, element *MainInput) (timerProvider, error) { if u.kc == nil { return timerProvider{}, fmt.Errorf("cannot make a state provider for an unkeyed input %v", element) } @@ -68,6 +69,7 @@ func (u *userTimerAdapter) NewTimerProvider(ctx context.Context, manager DataMan ctx: ctx, tm: manager, elementKey: elementKey, + inputTimestamp: inputTs, SID: u.SID, window: w, writersByFamily: make(map[string]io.Writer), @@ -78,11 +80,12 @@ func (u *userTimerAdapter) NewTimerProvider(ctx context.Context, manager DataMan } type timerProvider struct { - ctx context.Context - tm DataManager - SID StreamID - elementKey []byte - window []typex.Window + ctx context.Context + tm DataManager + SID StreamID + inputTimestamp typex.EventTime + elementKey []byte + window []typex.Window pn typex.PaneInfo @@ -117,6 +120,7 @@ func (p *timerProvider) Set(t timers.TimerMap) { HoldTimestamp: t.HoldTimestamp, Pane: p.pn, } + log.Debugf(p.ctx, "timer set: %+v", tm) fv := FullValue{Elm: tm} enc := MakeElementEncoder(coder.SkipW(p.codersByFamily[t.Family])) if err := enc.Encode(&fv, w); err != nil { diff --git a/sdks/go/pkg/beam/core/runtime/harness/datamgr.go b/sdks/go/pkg/beam/core/runtime/harness/datamgr.go index 5a9b536b2889..f15e71401ebc 100644 --- a/sdks/go/pkg/beam/core/runtime/harness/datamgr.go 
+++ b/sdks/go/pkg/beam/core/runtime/harness/datamgr.go @@ -17,6 +17,7 @@ package harness import ( "context" + "fmt" "io" "sync" "time" @@ -47,22 +48,31 @@ func NewScopedDataManager(mgr *DataChannelManager, instID instructionID) *Scoped return &ScopedDataManager{mgr: mgr, instID: instID} } -// OpenRead opens an io.ReadCloser on the given stream. -func (s *ScopedDataManager) OpenRead(ctx context.Context, id exec.StreamID) (io.ReadCloser, error) { +// OpenWrite opens an io.WriteCloser on the given stream. +func (s *ScopedDataManager) OpenWrite(ctx context.Context, id exec.StreamID) (io.WriteCloser, error) { ch, err := s.open(ctx, id.Port) if err != nil { return nil, err } - return ch.OpenRead(ctx, id.PtransformID, s.instID), nil + return ch.OpenWrite(ctx, id.PtransformID, s.instID), nil } -// OpenWrite opens an io.WriteCloser on the given stream. -func (s *ScopedDataManager) OpenWrite(ctx context.Context, id exec.StreamID) (io.WriteCloser, error) { +// OpenElementChan returns a channel of exec.Elements on the given stream. +func (s *ScopedDataManager) OpenElementChan(ctx context.Context, id exec.StreamID) (<-chan exec.Elements, error) { ch, err := s.open(ctx, id.Port) if err != nil { return nil, err } - return ch.OpenWrite(ctx, id.PtransformID, s.instID), nil + return ch.OpenElementChan(ctx, id.PtransformID, s.instID), nil +} + +// OpenTimerWrite opens an io.WriteCloser on the given stream to write timers +func (s *ScopedDataManager) OpenTimerWrite(ctx context.Context, id exec.StreamID, family string) (io.WriteCloser, error) { + ch, err := s.open(ctx, id.Port) + if err != nil { + return nil, err + } + return ch.OpenTimerWrite(ctx, id.PtransformID, s.instID, family), nil } func (s *ScopedDataManager) open(ctx context.Context, port exec.Port) (*DataChannel, error) { @@ -134,8 +144,9 @@ func (m *DataChannelManager) closeInstruction(instID instructionID) { // clientID identifies a client of a connected channel. type clientID struct { - ptransformID string - instID instructionID + ptransformID string + instID instructionID + timerFamilyID string } // This is a reduced version of the full gRPC interface to help with testing. @@ -155,8 +166,9 @@ type DataChannel struct { id string client dataClient - writers map[instructionID]map[string]*dataWriter - readers map[instructionID]map[string]*dataReader + writers map[instructionID]map[string]*dataWriter + timerWriters map[instructionID]map[string]*timerWriter + channels map[instructionID]map[string]*elementsChan // recently terminated instructions endedInstructions map[instructionID]struct{} @@ -172,6 +184,19 @@ type DataChannel struct { mu sync.Mutex // guards mutable internal data, notably the maps and readErr. 
} +type elementsChan struct { + ch chan exec.Elements + complete bool +} + +func (ec *elementsChan) Close() error { + if !ec.complete { + ec.complete = true + close(ec.ch) + } + return nil +} + func newDataChannel(ctx context.Context, port exec.Port) (*DataChannel, error) { ctx, cancelFn := context.WithCancel(ctx) cc, err := dial(ctx, port.URL, "data", 15*time.Second) @@ -196,7 +221,8 @@ func makeDataChannel(ctx context.Context, id string, client dataClient, cancelFn id: id, client: client, writers: make(map[instructionID]map[string]*dataWriter), - readers: make(map[instructionID]map[string]*dataReader), + timerWriters: make(map[instructionID]map[string]*timerWriter), + channels: make(map[instructionID]map[string]*elementsChan), endedInstructions: make(map[instructionID]struct{}), cancelFn: cancelFn, } @@ -214,25 +240,56 @@ func (c *DataChannel) terminateStreamOnError(err error) { } } -// OpenRead returns an io.ReadCloser of the data elements for the given instruction and ptransform. -func (c *DataChannel) OpenRead(ctx context.Context, ptransformID string, instID instructionID) io.ReadCloser { +// OpenWrite returns an io.WriteCloser of the data elements for the given instruction and ptransform. +func (c *DataChannel) OpenWrite(ctx context.Context, ptransformID string, instID instructionID) io.WriteCloser { + return c.makeWriter(ctx, clientID{ptransformID: ptransformID, instID: instID}) +} + +// OpenElementChan returns a channel of typex.Elements for the given instruction and ptransform. +func (c *DataChannel) OpenElementChan(ctx context.Context, ptransformID string, instID instructionID) <-chan exec.Elements { c.mu.Lock() defer c.mu.Unlock() cid := clientID{ptransformID: ptransformID, instID: instID} if c.readErr != nil { - log.Errorf(ctx, "opening a reader %v on a closed channel", cid) - return &errReader{c.readErr} + panic(fmt.Errorf("opening a reader %v on a closed channel", cid)) } - return c.makeReader(ctx, cid) + return c.makeChannel(ctx, cid).ch } -// OpenWrite returns an io.WriteCloser of the data elements for the given instruction and ptransform. -func (c *DataChannel) OpenWrite(ctx context.Context, ptransformID string, instID instructionID) io.WriteCloser { - return c.makeWriter(ctx, clientID{ptransformID: ptransformID, instID: instID}) +// makeChannel creates a channel of exec.Elements. It expects to be called while c.mu is held. +func (c *DataChannel) makeChannel(ctx context.Context, id clientID) *elementsChan { + var m map[string]*elementsChan + var ok bool + if m, ok = c.channels[id.instID]; !ok { + m = make(map[string]*elementsChan) + c.channels[id.instID] = m + } + + if r, ok := m[id.ptransformID]; ok { + return r + } + + r := &elementsChan{ch: make(chan exec.Elements, 20)} + // Just in case initial data for an instruction arrives *after* an instructon has ended. + // eg. it was blocked by another reader being slow, or the other instruction failed. + // So we provide a pre-completed reader, and do not cache it, as there's no further cleanup for it. + if _, ok := c.endedInstructions[id.instID]; ok { + close(r.ch) + r.complete = true + return r + } + + m[id.ptransformID] = r + return r +} + +// OpenTimerWrite returns io.WriteCloser for the given timerFamilyID, instruction and ptransform. 
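+// Each (ptransform, timer family) pair gets its own writer; the resulting timer
+// payloads are forwarded to the runner as Elements.Timers messages on the same
+// stream that carries element data.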
+func (c *DataChannel) OpenTimerWrite(ctx context.Context, ptransformID string, instID instructionID, family string) io.WriteCloser { + return c.makeTimerWriter(ctx, clientID{timerFamilyID: family, ptransformID: ptransformID, instID: instID}) } func (c *DataChannel) read(ctx context.Context) { - cache := make(map[clientID]*dataReader) + cache := make(map[clientID]*elementsChan) for { msg, err := c.client.Recv() if err != nil { @@ -243,15 +300,11 @@ func (c *DataChannel) read(ctx context.Context) { // close the r.buf channels twice, or send on a closed channel. // Any other approach is racy, and may cause one of the above // panics. - for _, m := range c.readers { - for _, r := range m { - log.Errorf(ctx, "DataChannel.read %v reader %v closing due to error on channel", c.id, r.id) - if !r.completed { - r.completed = true - r.err = err - close(r.buf) - } - delete(cache, r.id) + for instID, m := range c.channels { + for tid, r := range m { + log.Errorf(ctx, "DataChannel.read %v channel inst: %v tid %v closing due to error on channel", c.id, instID, tid) + r.Close() + delete(cache, clientID{ptransformID: tid, instID: instID}) } } c.terminateStreamOnError(err) @@ -274,31 +327,28 @@ func (c *DataChannel) read(ctx context.Context) { for _, elm := range msg.GetData() { id := clientID{ptransformID: elm.TransformId, instID: instructionID(elm.GetInstructionId())} - var r *dataReader + var r *elementsChan if local, ok := cache[id]; ok { r = local } else { c.mu.Lock() - r = c.makeReader(ctx, id) + r = c.makeChannel(ctx, id) c.mu.Unlock() cache[id] = r } + // This send is deliberately blocking, if we exceed the buffering for + // a reader. We can't buffer the entire main input, if some user code + // is slow (or gets stuck). If the local side closes, the reader + // will be marked as completed and further remote data will be ignored. + select { + case r.ch <- exec.Elements{Data: elm.GetData()}: + case <-ctx.Done(): + // Technically, we need to close all the things here... to start. + r.Close() + } if elm.GetIsLast() { - // If this reader hasn't closed yet, do so now. - if !r.completed { - // Use the last segment if any. - if len(elm.GetData()) != 0 { - // In case of local side closing, send with select. - select { - case r.buf <- elm.GetData(): - case <-r.done: - } - } - // Close buffer to signal EOF. - r.completed = true - close(r.buf) - } + r.Close() // Clean up local bookkeeping. We'll never see another message // for it again. We have to be careful not to remove the real @@ -307,12 +357,32 @@ func (c *DataChannel) read(ctx context.Context) { delete(cache, id) continue } + } + for _, tim := range msg.GetTimers() { + id := clientID{ + ptransformID: tim.TransformId, + instID: instructionID(tim.GetInstructionId()), + // timerFamilyID: tim.GetTimerFamilyId(), + } + log.Infof(ctx, "timer received for %v, %v: %v", id, tim.GetTimerFamilyId(), tim.GetTimers()) + var r *elementsChan + if local, ok := cache[id]; ok { + r = local + } else { + c.mu.Lock() + r = c.makeChannel(ctx, id) + c.mu.Unlock() + cache[id] = r + } + if tim.GetIsLast() { + // If this reader hasn't closed yet, do so now. + r.Close() - if r.completed { - // The local reader has closed but the remote is still sending data. - // Just ignore it. We keep the reader config in the cache so we don't - // treat it as a new reader. Eventually the stream will finish and go - // through normal teardown. + // Clean up local bookkeeping. We'll never see another message + // for it again. 
We have to be careful not to remove the real + // one, because readers may be initialized after we've seen + // the full stream. + delete(cache, id) continue } @@ -321,64 +391,15 @@ func (c *DataChannel) read(ctx context.Context) { // is slow (or gets stuck). If the local side closes, the reader // will be marked as completed and further remote data will be ignored. select { - case r.buf <- elm.GetData(): - case <-r.done: - r.completed = true - close(r.buf) + case r.ch <- exec.Elements{Timers: tim.GetTimers(), TimerFamilyID: tim.GetTimerFamilyId()}: + case <-ctx.Done(): + // Technically, we need to close all the things here... to start. + r.Close() } } } } -type errReader struct { - err error -} - -func (r *errReader) Read(_ []byte) (int, error) { - return 0, r.err -} - -func (r *errReader) Close() error { - return r.err -} - -// makeReader creates a dataReader. It expects to be called while c.mu is held. -func (c *DataChannel) makeReader(ctx context.Context, id clientID) *dataReader { - var m map[string]*dataReader - var ok bool - if m, ok = c.readers[id.instID]; !ok { - m = make(map[string]*dataReader) - c.readers[id.instID] = m - } - - if r, ok := m[id.ptransformID]; ok { - return r - } - - r := &dataReader{id: id, buf: make(chan []byte, bufElements), done: make(chan bool, 1), channel: c} - - // Just in case initial data for an instruction arrives *after* an instructon has ended. - // eg. it was blocked by another reader being slow, or the other instruction failed. - // So we provide a pre-completed reader, and do not cache it, as there's no further cleanup for it. - if _, ok := c.endedInstructions[id.instID]; ok { - r.completed = true - close(r.buf) - r.err = io.EOF // In case of any actual data readers, so they terminate without error. - return r - } - - m[id.ptransformID] = r - return r -} - -func (c *DataChannel) removeReader(id clientID) { - c.mu.Lock() - if m, ok := c.readers[id.instID]; ok { - delete(m, id.ptransformID) - } - c.mu.Unlock() -} - const endedInstructionCap = 32 // removeInstruction closes all readers and writers registered for the instruction @@ -395,21 +416,25 @@ func (c *DataChannel) removeInstruction(instID instructionID) { c.endedInstructions[instID] = struct{}{} c.rmQueue = append(c.rmQueue, instID) - rs := c.readers[instID] ws := c.writers[instID] + tws := c.timerWriters[instID] + ecs := c.channels[instID] // Prevent other users while we iterate. - delete(c.readers, instID) delete(c.writers, instID) + delete(c.timerWriters, instID) + delete(c.channels, instID) c.mu.Unlock() - // Close grabs the channel lock, so this must be outside the critical section. - for _, r := range rs { - r.Close() - } for _, w := range ws { w.Close() } + for _, tw := range tws { + tw.Close() + } + for _, ec := range ecs { + ec.Close() + } } func (c *DataChannel) makeWriter(ctx context.Context, id clientID) *dataWriter { @@ -423,7 +448,7 @@ func (c *DataChannel) makeWriter(ctx context.Context, id clientID) *dataWriter { c.writers[id.instID] = m } - if w, ok := m[id.ptransformID]; ok { + if w, ok := m[makeID(id)]; ok { return w } @@ -432,50 +457,40 @@ func (c *DataChannel) makeWriter(ctx context.Context, id clientID) *dataWriter { // runner or user directed. 
w := &dataWriter{ch: c, id: id} - m[id.ptransformID] = w + m[makeID(id)] = w return w } -type dataReader struct { - id clientID - buf chan []byte - done chan bool - cur []byte - channel *DataChannel - completed bool - err error -} +func (c *DataChannel) makeTimerWriter(ctx context.Context, id clientID) *timerWriter { + c.mu.Lock() + defer c.mu.Unlock() -func (r *dataReader) Close() error { - r.done <- true - r.channel.removeReader(r.id) - return nil -} + var m map[string]*timerWriter + var ok bool + if m, ok = c.timerWriters[id.instID]; !ok { + m = make(map[string]*timerWriter) + c.timerWriters[id.instID] = m + } -func (r *dataReader) Read(buf []byte) (int, error) { - if r.cur == nil { - b, ok := <-r.buf - if !ok { - if r.err == nil { - return 0, io.EOF - } - return 0, r.err - } - r.cur = b + if w, ok := m[makeID(id)]; ok { + return w } - // We don't need to check for a 0 length copy from r.cur here, since that's - // checked before buffers are handed to the r.buf channel. - n := copy(buf, r.cur) + // We don't check for ended instructions for writers, as writers + // can only be created if an instruction is in scope, and aren't + // runner or user directed. - switch { - case len(r.cur) == n: - r.cur = nil - default: - r.cur = r.cur[n:] - } + w := &timerWriter{ch: c, id: id} + m[makeID(id)] = w + return w +} - return n, nil +func makeID(id clientID) string { + newID := id.ptransformID + if id.timerFamilyID != "" { + newID += ":" + id.timerFamilyID + } + return newID } type dataWriter struct { @@ -574,3 +589,72 @@ func (w *dataWriter) Write(p []byte) (n int, err error) { w.buf = append(w.buf, p...) return len(p), nil } + +type timerWriter struct { + id clientID + ch *DataChannel +} + +// send requires the ch.mu lock to be held. +func (w *timerWriter) send(msg *fnpb.Elements) error { + recordStreamSend(msg) + if err := w.ch.client.Send(msg); err != nil { + if err == io.EOF { + log.Warnf(context.TODO(), "dataWriter[%v;%v] EOF on send; fetching real error", w.id, w.ch.id) + err = nil + for err == nil { + // Per GRPC stream documentation, if there's an EOF, we must call Recv + // until a non-nil error is returned, to ensure resources are cleaned up. + // https://pkg.go.dev/google.golang.org/grpc#ClientConn.NewStream + _, err = w.ch.client.Recv() + } + } + log.Warnf(context.TODO(), "dataWriter[%v;%v] error on send: %v", w.id, w.ch.id, err) + w.ch.terminateStreamOnError(err) + return err + } + return nil +} + +func (w *timerWriter) Close() error { + w.ch.mu.Lock() + defer w.ch.mu.Unlock() + delete(w.ch.timerWriters[w.id.instID], makeID(w.id)) + var msg *fnpb.Elements + msg = &fnpb.Elements{ + Timers: []*fnpb.Elements_Timers{ + { + InstructionId: string(w.id.instID), + TransformId: w.id.ptransformID, + TimerFamilyId: w.id.timerFamilyID, + IsLast: true, + }, + }, + } + return w.send(msg) +} + +func (w *timerWriter) writeTimers(p []byte) error { + w.ch.mu.Lock() + defer w.ch.mu.Unlock() + + msg := &fnpb.Elements{ + Timers: []*fnpb.Elements_Timers{ + { + InstructionId: string(w.id.instID), + TransformId: w.id.ptransformID, + TimerFamilyId: w.id.timerFamilyID, + Timers: p, + }, + }, + } + return w.send(msg) +} + +func (w *timerWriter) Write(p []byte) (n int, err error) { + // write timers directly without buffering. 
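+	// Unlike dataWriter.Write, which coalesces small writes into a buffer that is
+	// flushed in chunks, every timer write is forwarded to the runner immediately
+	// as its own Elements.Timers message.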
+ if err := w.writeTimers(p); err != nil { + return 0, err + } + return len(p), nil +} diff --git a/sdks/go/pkg/beam/core/runtime/harness/datamgr_test.go b/sdks/go/pkg/beam/core/runtime/harness/datamgr_test.go index f69d9abde49b..3f77d69f1737 100644 --- a/sdks/go/pkg/beam/core/runtime/harness/datamgr_test.go +++ b/sdks/go/pkg/beam/core/runtime/harness/datamgr_test.go @@ -101,147 +101,6 @@ func (f *fakeDataClient) Send(*fnpb.Elements) error { return nil } -func TestDataChannelTerminate_dataReader(t *testing.T) { - // The logging of channels closed is quite noisy for this test - log.SetOutput(io.Discard) - - expectedError := fmt.Errorf("EXPECTED ERROR") - - tests := []struct { - name string - expectedError error - caseFn func(t *testing.T, r io.ReadCloser, client *fakeDataClient, c *DataChannel) - }{ - { - name: "onClose", - expectedError: io.EOF, - caseFn: func(t *testing.T, r io.ReadCloser, client *fakeDataClient, c *DataChannel) { - // We don't read up all the buffered data, but immediately close the reader. - // Previously, since nothing was consuming the incoming gRPC data, the whole - // data channel would get stuck, and the client.Recv() call was eventually - // no longer called. - r.Close() - - // If done is signaled, that means client.Recv() has been called to flush the - // channel, meaning consumer code isn't stuck. - <-client.done - }, - }, { - name: "onSentinel", - expectedError: io.EOF, - caseFn: func(t *testing.T, r io.ReadCloser, client *fakeDataClient, c *DataChannel) { - // fakeDataClient eventually returns a sentinel element. - }, - }, { - name: "onIsLast_withData", - expectedError: io.EOF, - caseFn: func(t *testing.T, r io.ReadCloser, client *fakeDataClient, c *DataChannel) { - // Set the last call with data to use is_last. - client.isLastCall = 2 - }, - }, { - name: "onIsLast_withoutData", - expectedError: io.EOF, - caseFn: func(t *testing.T, r io.ReadCloser, client *fakeDataClient, c *DataChannel) { - // Set the call without data to use is_last. - client.isLastCall = 3 - }, - }, { - name: "onRecvError", - expectedError: expectedError, - caseFn: func(t *testing.T, r io.ReadCloser, client *fakeDataClient, c *DataChannel) { - // The SDK starts reading in a goroutine immeadiately after open. - // Set the 2nd Recv call to have an error. - client.err = expectedError - }, - }, { - name: "onInstructionEnd", - expectedError: io.EOF, - caseFn: func(t *testing.T, r io.ReadCloser, client *fakeDataClient, c *DataChannel) { - c.removeInstruction("inst_ref") - }, - }, - } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - done := make(chan bool, 1) - client := &fakeDataClient{t: t, done: done} - ctx, cancelFn := context.WithCancel(context.Background()) - c := makeDataChannel(ctx, "id", client, cancelFn) - - r := c.OpenRead(ctx, "ptr", "inst_ref") - - n, err := r.Read(make([]byte, 4)) - if err != nil { - t.Errorf("Unexpected error from read: %v, read %d bytes.", err, n) - } - test.caseFn(t, r, client, c) - // Drain the reader. - i := 1 // For the earlier Read. - for err == nil { - read := make([]byte, 4) - _, err = r.Read(read) - i++ - } - - if got, want := err, test.expectedError; got != want { - t.Errorf("Unexpected error from read %d: got %v, want %v", i, got, want) - } - // Verify that new readers return the same error on their reads after client.Recv is done. 
- if n, err := c.OpenRead(ctx, "ptr", "inst_ref").Read(make([]byte, 4)); err != test.expectedError { - t.Errorf("Unexpected error from read: got %v, want, %v read %d bytes.", err, test.expectedError, n) - } - - select { - case <-ctx.Done(): // Assert that the context must have been cancelled on read failures. - return - case <-time.After(time.Second * 5): - t.Fatal("context wasn't cancelled") - } - }) - } -} - -func TestDataChannelRemoveInstruction_dataAfterClose(t *testing.T) { - done := make(chan bool, 1) - client := &fakeDataClient{t: t, done: done} - client.blocked.Lock() - - ctx, cancelFn := context.WithCancel(context.Background()) - c := makeDataChannel(ctx, "id", client, cancelFn) - c.removeInstruction("inst_ref") - - client.blocked.Unlock() - - r := c.OpenRead(ctx, "ptr", "inst_ref") - - dr := r.(*dataReader) - if !dr.completed || dr.err != io.EOF { - t.Errorf("Expected a closed reader, but was still open: completed: %v, err: %v", dr.completed, dr.err) - } - - n, err := r.Read(make([]byte, 4)) - if err != io.EOF { - t.Errorf("Unexpected error from read: %v, read %d bytes.", err, n) - } -} - -func TestDataChannelRemoveInstruction_limitInstructionCap(t *testing.T) { - done := make(chan bool, 1) - client := &fakeDataClient{t: t, done: done} - ctx, cancelFn := context.WithCancel(context.Background()) - c := makeDataChannel(ctx, "id", client, cancelFn) - - for i := 0; i < endedInstructionCap+10; i++ { - instID := instructionID(fmt.Sprintf("inst_ref%d", i)) - c.OpenRead(ctx, "ptr", instID) - c.removeInstruction(instID) - } - if got, want := len(c.endedInstructions), endedInstructionCap; got != want { - t.Errorf("unexpected len(endedInstructions) got %v, want %v,", got, want) - } -} - func TestDataChannelTerminate_Writes(t *testing.T) { // The logging of channels closed is quite noisy for this test log.SetOutput(io.Discard) diff --git a/sdks/go/pkg/beam/core/timers/timers.go b/sdks/go/pkg/beam/core/timers/timers.go index 130564790ca6..afb5ddd98b81 100644 --- a/sdks/go/pkg/beam/core/timers/timers.go +++ b/sdks/go/pkg/beam/core/timers/timers.go @@ -102,7 +102,7 @@ func (t ProcessingTime) Set(p Provider, firingTimestamp time.Time) { func (t ProcessingTime) SetWithOpts(p Provider, firingTimestamp time.Time, opts Opts) { fire := mtime.FromTime(firingTimestamp) - // Hold timestamp must match fireing timestamp if not otherwise set. + // Hold timestamp must match input element timestamp if not otherwise set. 
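+	// An explicit Opts.Hold keeps the output watermark from advancing past that
+	// time until the timer fires, so output the callback emits at or after the
+	// hold is not late.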
tm := TimerMap{Family: t.Family, Tag: opts.Tag, FireTimestamp: fire, HoldTimestamp: fire} if !opts.Hold.IsZero() { tm.HoldTimestamp = mtime.FromTime(opts.Hold) From c4062732f3c8913f78439dc02fd6c579f83dbede Mon Sep 17 00:00:00 2001 From: lostluck <13907733+lostluck@users.noreply.github.com> Date: Wed, 1 Mar 2023 14:59:13 -0800 Subject: [PATCH 3/5] [timers] adjust debugging --- sdks/go/examples/streaming_wordcap/wordcap.go | 7 +++---- .../pkg/beam/core/runtime/exec/datasource.go | 1 - .../pkg/beam/core/runtime/exec/translate.go | 2 ++ .../pkg/beam/core/runtime/harness/datamgr.go | 8 ++++++- .../pkg/beam/core/runtime/harness/statemgr.go | 21 +++++++++---------- 5 files changed, 22 insertions(+), 17 deletions(-) diff --git a/sdks/go/examples/streaming_wordcap/wordcap.go b/sdks/go/examples/streaming_wordcap/wordcap.go index 441d7f6d3244..78a4b73764af 100644 --- a/sdks/go/examples/streaming_wordcap/wordcap.go +++ b/sdks/go/examples/streaming_wordcap/wordcap.go @@ -79,7 +79,7 @@ func NewStateful() *Stateful { } func (s *Stateful) ProcessElement(ctx context.Context, ts beam.EventTime, sp state.Provider, tp timers.Provider, key, word string, emit func(string, string)) error { - log.Infof(ctx, "stateful dofn invoked key: %v word: %v", key, word) + // log.Infof(ctx, "stateful dofn invoked key: %v word: %v", key, word) s.ElementBag.Add(sp, word) s.MinTime.Add(sp, int64(ts)) @@ -98,7 +98,7 @@ func (s *Stateful) ProcessElement(ctx context.Context, ts beam.EventTime, sp sta s.OutputState.SetWithOpts(tp, mtime.Time(toFire).ToTime(), timers.Opts{Hold: mtime.Time(minTime).ToTime()}) s.TimerTime.Write(sp, toFire) - log.Infof(ctx, "stateful dofn key: %v word: %v, timer: %v, minTime: %v", key, word, toFire, minTime) + //log.Infof(ctx, "stateful dofn key: %v word: %v, timer: %v, minTime: %v", key, word, toFire, minTime) // // Get the Value stored in our state // val, ok, err := s.Val.Read(p) @@ -233,7 +233,7 @@ func main() { // }, col) imp := beam.Impulse(s) - elms := 100 + elms := 3 out := beam.ParDo(s, &eventtimeSDFStream{ Sleep: time.Second, RestSize: int64(elms), @@ -249,7 +249,6 @@ func main() { log.Infof(ctx, "adding key ts: %v now: %v word: %v", ts.ToTime(), time.Now(), s) return "test", s }, str) - debug.Printf(s, "pre stateful: %v", keyed) timed := beam.ParDo(s, NewStateful(), keyed) debug.Printf(s, "post stateful: %v", timed) diff --git a/sdks/go/pkg/beam/core/runtime/exec/datasource.go b/sdks/go/pkg/beam/core/runtime/exec/datasource.go index 12a5faeaf30d..8741a99ea110 100644 --- a/sdks/go/pkg/beam/core/runtime/exec/datasource.go +++ b/sdks/go/pkg/beam/core/runtime/exec/datasource.go @@ -125,7 +125,6 @@ func (n *DataSource) process(ctx context.Context, data func(bcr *byteCountReader } if len(e.Data) > 0 { r.Reset(e.Data) - log.Debugf(ctx, "%v: received %v", n, e.Data) err = data(&bcr) } if len(e.Timers) > 0 { diff --git a/sdks/go/pkg/beam/core/runtime/exec/translate.go b/sdks/go/pkg/beam/core/runtime/exec/translate.go index 0403f0ab0abb..10291dce7636 100644 --- a/sdks/go/pkg/beam/core/runtime/exec/translate.go +++ b/sdks/go/pkg/beam/core/runtime/exec/translate.go @@ -35,6 +35,7 @@ import ( fnpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/fnexecution_v1" pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1" "github.com/golang/protobuf/proto" + "google.golang.org/protobuf/encoding/prototext" ) // TODO(lostluck): 2018/05/28 Extract these from the canonical enums in beam_runner_api.proto @@ -53,6 +54,7 @@ const ( // UnmarshalPlan converts a model bundle descriptor into an execution 
Plan. func UnmarshalPlan(desc *fnpb.ProcessBundleDescriptor) (*Plan, error) { + log.Debug(context.TODO(), prototext.Format(desc)) b, err := newBuilder(desc) if err != nil { return nil, err diff --git a/sdks/go/pkg/beam/core/runtime/harness/datamgr.go b/sdks/go/pkg/beam/core/runtime/harness/datamgr.go index f15e71401ebc..f9a52cc60443 100644 --- a/sdks/go/pkg/beam/core/runtime/harness/datamgr.go +++ b/sdks/go/pkg/beam/core/runtime/harness/datamgr.go @@ -269,11 +269,14 @@ func (c *DataChannel) makeChannel(ctx context.Context, id clientID) *elementsCha return r } + log.Infof(ctx, "make data read channel %v", id) + r := &elementsChan{ch: make(chan exec.Elements, 20)} // Just in case initial data for an instruction arrives *after* an instructon has ended. // eg. it was blocked by another reader being slow, or the other instruction failed. // So we provide a pre-completed reader, and do not cache it, as there's no further cleanup for it. if _, ok := c.endedInstructions[id.instID]; ok { + log.Infof(ctx, "data read channel %v already ended", id) close(r.ch) r.complete = true return r @@ -364,7 +367,7 @@ func (c *DataChannel) read(ctx context.Context) { instID: instructionID(tim.GetInstructionId()), // timerFamilyID: tim.GetTimerFamilyId(), } - log.Infof(ctx, "timer received for %v, %v: %v", id, tim.GetTimerFamilyId(), tim.GetTimers()) + var r *elementsChan if local, ok := cache[id]; ok { r = local @@ -385,6 +388,7 @@ func (c *DataChannel) read(ctx context.Context) { delete(cache, id) continue } + log.Infof(ctx, "timer received for %v, %v: %v", id, tim.GetTimerFamilyId(), tim.GetTimers()) // This send is deliberately blocking, if we exceed the buffering for // a reader. We can't buffer the entire main input, if some user code @@ -638,6 +642,8 @@ func (w *timerWriter) writeTimers(p []byte) error { w.ch.mu.Lock() defer w.ch.mu.Unlock() + log.Infof(context.TODO(), "timer write for %+v: %v", w.id, p) + msg := &fnpb.Elements{ Timers: []*fnpb.Elements_Timers{ { diff --git a/sdks/go/pkg/beam/core/runtime/harness/statemgr.go b/sdks/go/pkg/beam/core/runtime/harness/statemgr.go index f10f0d92e84e..feeb51a1951f 100644 --- a/sdks/go/pkg/beam/core/runtime/harness/statemgr.go +++ b/sdks/go/pkg/beam/core/runtime/harness/statemgr.go @@ -463,18 +463,17 @@ func (r *stateKeyReader) Close() error { return nil } -func (r *stateKeyWriter) Write(buf []byte) (int, error) { - r.mu.Lock() - localChannel := r.ch - r.mu.Unlock() - +func (w *stateKeyWriter) Write(buf []byte) (int, error) { + w.mu.Lock() + localChannel := w.ch + w.mu.Unlock() var req *fnpb.StateRequest - switch r.writeType { + switch w.writeType { case writeTypeAppend: req = &fnpb.StateRequest{ // Id: set by StateChannel - InstructionId: string(r.instID), - StateKey: r.key, + InstructionId: string(w.instID), + StateKey: w.key, Request: &fnpb.StateRequest_Append{ Append: &fnpb.StateAppendRequest{ Data: buf, @@ -484,14 +483,14 @@ func (r *stateKeyWriter) Write(buf []byte) (int, error) { case writeTypeClear: req = &fnpb.StateRequest{ // ID: set by StateChannel - InstructionId: string(r.instID), - StateKey: r.key, + InstructionId: string(w.instID), + StateKey: w.key, Request: &fnpb.StateRequest_Clear{ Clear: &fnpb.StateClearRequest{}, }, } default: - return 0, errors.Errorf("Unknown write type %v", r.writeType) + return 0, errors.Errorf("Unknown write type %v", w.writeType) } _, err := localChannel.Send(req) From aeaa0f3aadda542d857a91d8b2b6c849681e325a Mon Sep 17 00:00:00 2001 From: lostluck <13907733+lostluck@users.noreply.github.com> Date: Tue, 21 Mar 2023 
16:07:08 -0700 Subject: [PATCH 4/5] refactor to handle timers --- sdks/go/examples/streaming_wordcap/wordcap.go | 11 +- sdks/go/pkg/beam/core/runtime/exec/data.go | 6 +- .../pkg/beam/core/runtime/exec/datasource.go | 24 +- sdks/go/pkg/beam/core/runtime/exec/pardo.go | 5 + sdks/go/pkg/beam/core/runtime/exec/plan.go | 10 + .../pkg/beam/core/runtime/exec/translate.go | 1 + .../pkg/beam/core/runtime/harness/datamgr.go | 293 ++++++++++-------- 7 files changed, 203 insertions(+), 147 deletions(-) diff --git a/sdks/go/examples/streaming_wordcap/wordcap.go b/sdks/go/examples/streaming_wordcap/wordcap.go index 78a4b73764af..704997412524 100644 --- a/sdks/go/examples/streaming_wordcap/wordcap.go +++ b/sdks/go/examples/streaming_wordcap/wordcap.go @@ -79,7 +79,7 @@ func NewStateful() *Stateful { } func (s *Stateful) ProcessElement(ctx context.Context, ts beam.EventTime, sp state.Provider, tp timers.Provider, key, word string, emit func(string, string)) error { - // log.Infof(ctx, "stateful dofn invoked key: %v word: %v", key, word) + log.Infof(ctx, "stateful dofn invoked key: %q word: %q", key, word) s.ElementBag.Add(sp, word) s.MinTime.Add(sp, int64(ts)) @@ -96,9 +96,12 @@ func (s *Stateful) ProcessElement(ctx context.Context, ts beam.EventTime, sp sta return err } - s.OutputState.SetWithOpts(tp, mtime.Time(toFire).ToTime(), timers.Opts{Hold: mtime.Time(minTime).ToTime()}) + s.OutputState.SetWithOpts(tp, mtime.Time(toFire).ToTime(), timers.Opts{ + Hold: mtime.Time(minTime).ToTime(), + Tag: word, + }) s.TimerTime.Write(sp, toFire) - //log.Infof(ctx, "stateful dofn key: %v word: %v, timer: %v, minTime: %v", key, word, toFire, minTime) + log.Infof(ctx, "stateful dofn key: %v word: %v, timer: %v, minTime: %v", key, word, toFire, minTime) // // Get the Value stored in our state // val, ok, err := s.Val.Read(p) @@ -157,7 +160,7 @@ func (fn *eventtimeSDFStream) ProcessElement(ctx context.Context, _ *CWE, rt *sd r := rt.GetRestriction().(offsetrange.Restriction) i := r.Start if r.Size() < 1 { - log.Debugf(ctx, "size 0 restriction, stoping to process sentinel", slog.Any("value", v)) + log.Debugf(ctx, "size 0 restriction, stoping to process sentinel %v", slog.Any("value", v)) return sdf.StopProcessing() } slog.Debug("emitting element to restriction", slog.Any("value", v), slog.Group("restriction", diff --git a/sdks/go/pkg/beam/core/runtime/exec/data.go b/sdks/go/pkg/beam/core/runtime/exec/data.go index 9380bb8902bd..71954819a748 100644 --- a/sdks/go/pkg/beam/core/runtime/exec/data.go +++ b/sdks/go/pkg/beam/core/runtime/exec/data.go @@ -58,7 +58,7 @@ type SideCache interface { // opened by one consumer only. type DataManager interface { // OpenElementChan opens a channel for data and timers. - OpenElementChan(ctx context.Context, id StreamID) (<-chan Elements, error) + OpenElementChan(ctx context.Context, id StreamID, expectedTimerTransforms []string) (<-chan Elements, error) // OpenWrite opens a closable byte stream for data writing. OpenWrite(ctx context.Context, id StreamID) (io.WriteCloser, error) // OpenTimerWrite opens a byte stream for writing timers @@ -97,6 +97,6 @@ type StateReader interface { // If TimerFamilyID is populated, it's a timer, otherwise it's // data elements. 
type Elements struct { - Data, Timers []byte - TimerFamilyID string + Data, Timers []byte + TimerFamilyID, PtransformID string } diff --git a/sdks/go/pkg/beam/core/runtime/exec/datasource.go b/sdks/go/pkg/beam/core/runtime/exec/datasource.go index 8741a99ea110..5fa1a1f94374 100644 --- a/sdks/go/pkg/beam/core/runtime/exec/datasource.go +++ b/sdks/go/pkg/beam/core/runtime/exec/datasource.go @@ -30,6 +30,7 @@ import ( "github.com/apache/beam/sdks/v2/go/pkg/beam/core/util/ioutilx" "github.com/apache/beam/sdks/v2/go/pkg/beam/internal/errors" "github.com/apache/beam/sdks/v2/go/pkg/beam/log" + "golang.org/x/exp/maps" ) // DataSource is a Root execution unit. @@ -40,6 +41,8 @@ type DataSource struct { Coder *coder.Coder Out Node PCol PCollection // Handles size metrics. Value instead of pointer so it's initialized by default in tests. + // OnTimerTransforms maps PtransformIDs to their execution nodes that handle OnTimer callbacks. + OnTimerTransforms map[string]*ParDo source DataManager state StateReader @@ -104,8 +107,10 @@ func (n *DataSource) StartBundle(ctx context.Context, id string, data DataContex } // process handles converting elements from the data source to timers. -func (n *DataSource) process(ctx context.Context, data func(bcr *byteCountReader) error, timer func(bcr *byteCountReader, timerFamilyID string) error) error { - elms, err := n.source.OpenElementChan(ctx, n.SID) +func (n *DataSource) process(ctx context.Context, data func(bcr *byteCountReader, ptransformID string) error, timer func(bcr *byteCountReader, ptransformID, timerFamilyID string) error) error { + // TODO(riteshghorse): Pass in the PTransformIDs expecting OnTimer calls. + // The SID contains this instruction's expected data processing transform (this one). + elms, err := n.source.OpenElementChan(ctx, n.SID, maps.Keys(n.OnTimerTransforms)) if err != nil { return err } @@ -121,15 +126,18 @@ func (n *DataSource) process(ctx context.Context, data func(bcr *byteCountReader case e, ok := <-elms: // Channel closed, so time to exit if !ok { + log.Infof(ctx, "%v: Data Channel closed", n) return nil } if len(e.Data) > 0 { r.Reset(e.Data) - err = data(&bcr) + err = data(&bcr, e.PtransformID) } if len(e.Timers) > 0 { + // TODO remove this debug log. + log.Infof(ctx, "timer received for %v; %v : %v", e.PtransformID, e.TimerFamilyID, e.Timers) r.Reset(e.Timers) - err = timer(&bcr, e.TimerFamilyID) + err = timer(&bcr, e.PtransformID, e.TimerFamilyID) } case <-ctx.Done(): return nil @@ -191,7 +199,7 @@ func (n *DataSource) Process(ctx context.Context) ([]*Checkpoint, error) { } var checkpoints []*Checkpoint - err := n.process(ctx, func(bcr *byteCountReader) error { + err := n.process(ctx, func(bcr *byteCountReader, ptransformID string) error { for { // TODO(lostluck) 2020/02/22: Should we include window headers or just count the element sizes? 
ws, t, pn, err := DecodeWindowedValueHeader(wc, bcr.reader) @@ -208,7 +216,7 @@ func (n *DataSource) Process(ctx context.Context) ([]*Checkpoint, error) { pe.Windows = ws pe.Pane = pn - log.Debugf(ctx, "%v: processing %v,%v", n, pe.Elm, pe.Elm2) + log.Infof(ctx, "%v[%v]: processing %+v,%v", n, ptransformID, pe.Elm, pe.Elm2) var valReStreams []ReStream for _, cv := range cvs { @@ -246,9 +254,9 @@ func (n *DataSource) Process(ctx context.Context) ([]*Checkpoint, error) { log.Debugf(ctx, "%v: exiting data loop", n) return nil }, - func(bcr *byteCountReader, timerFamilyID string) error { + func(bcr *byteCountReader, ptransformID, timerFamilyID string) error { tmap, err := decodeTimer(cp, wc, bcr) - log.Errorf(ctx, "timer received: %v - %+v err: %v", timerFamilyID, tmap, err) + log.Infof(ctx, "timer received for: %v and %v - %+v err: %v", ptransformID, timerFamilyID, tmap, err) return nil }) diff --git a/sdks/go/pkg/beam/core/runtime/exec/pardo.go b/sdks/go/pkg/beam/core/runtime/exec/pardo.go index be5c71f2c75a..3233cdfd30ef 100644 --- a/sdks/go/pkg/beam/core/runtime/exec/pardo.go +++ b/sdks/go/pkg/beam/core/runtime/exec/pardo.go @@ -76,6 +76,11 @@ func (n *ParDo) ID() UnitID { return n.UID } +// HasOnTimer returns if this ParDo wraps a DoFn that has an OnTimer method. +func (n *ParDo) HasOnTimer() bool { + return n.Timer != nil +} + // Up initializes this ParDo and does one-time DoFn setup. func (n *ParDo) Up(ctx context.Context) error { if n.status != Initializing { diff --git a/sdks/go/pkg/beam/core/runtime/exec/plan.go b/sdks/go/pkg/beam/core/runtime/exec/plan.go index 7958cf382383..fffe8619a573 100644 --- a/sdks/go/pkg/beam/core/runtime/exec/plan.go +++ b/sdks/go/pkg/beam/core/runtime/exec/plan.go @@ -53,6 +53,7 @@ func NewPlan(id string, units []Unit) (*Plan, error) { callbacks: []bundleFinalizationCallback{}, lastValidCallback: time.Now(), } + var onTimers map[string]*ParDo for _, u := range units { if u == nil { @@ -67,6 +68,12 @@ func NewPlan(id string, units []Unit) (*Plan, error) { if p, ok := u.(*PCollection); ok { pcols = append(pcols, p) } + if pd, ok := u.(*ParDo); ok && pd.HasOnTimer() { + if onTimers == nil { + onTimers = map[string]*ParDo{} + } + onTimers[pd.PID] = pd + } if p, ok := u.(needsBundleFinalization); ok { p.AttachFinalizer(&bf) } @@ -74,6 +81,9 @@ func NewPlan(id string, units []Unit) (*Plan, error) { if len(roots) == 0 { return nil, errors.Errorf("no root units") } + if len(onTimers) > 0 { + source.OnTimerTransforms = onTimers + } return &Plan{ id: id, diff --git a/sdks/go/pkg/beam/core/runtime/exec/translate.go b/sdks/go/pkg/beam/core/runtime/exec/translate.go index 10291dce7636..d510a20d24ec 100644 --- a/sdks/go/pkg/beam/core/runtime/exec/translate.go +++ b/sdks/go/pkg/beam/core/runtime/exec/translate.go @@ -608,6 +608,7 @@ func (b *builder) makeLink(from string, id linkID) (Node, error) { if err != nil { return nil, err } + // TODO: ensure this only gets set once or they're always the same. n.Timer = NewUserTimerAdapter(sID, coder.NewW(ec, wc), timerIDToCoder) } } diff --git a/sdks/go/pkg/beam/core/runtime/harness/datamgr.go b/sdks/go/pkg/beam/core/runtime/harness/datamgr.go index f9a52cc60443..ab938750c301 100644 --- a/sdks/go/pkg/beam/core/runtime/harness/datamgr.go +++ b/sdks/go/pkg/beam/core/runtime/harness/datamgr.go @@ -58,12 +58,12 @@ func (s *ScopedDataManager) OpenWrite(ctx context.Context, id exec.StreamID) (io } // OpenElementChan returns a channel of exec.Elements on the given stream. 
-func (s *ScopedDataManager) OpenElementChan(ctx context.Context, id exec.StreamID) (<-chan exec.Elements, error) { +func (s *ScopedDataManager) OpenElementChan(ctx context.Context, id exec.StreamID, expectedTimerTransforms []string) (<-chan exec.Elements, error) { ch, err := s.open(ctx, id.Port) if err != nil { return nil, err } - return ch.OpenElementChan(ctx, id.PtransformID, s.instID), nil + return ch.OpenElementChan(ctx, id.PtransformID, s.instID, expectedTimerTransforms), nil } // OpenTimerWrite opens an io.WriteCloser on the given stream to write timers @@ -144,9 +144,8 @@ func (m *DataChannelManager) closeInstruction(instID instructionID) { // clientID identifies a client of a connected channel. type clientID struct { - ptransformID string - instID instructionID - timerFamilyID string + instID instructionID + ptransformID string } // This is a reduced version of the full gRPC interface to help with testing. @@ -166,9 +165,9 @@ type DataChannel struct { id string client dataClient - writers map[instructionID]map[string]*dataWriter - timerWriters map[instructionID]map[string]*timerWriter - channels map[instructionID]map[string]*elementsChan + writers map[instructionID]map[string]*dataWriter // PTransformID + timerWriters map[instructionID]map[timerKey]*timerWriter + channels map[instructionID]*elementsChan // recently terminated instructions endedInstructions map[instructionID]struct{} @@ -184,14 +183,38 @@ type DataChannel struct { mu sync.Mutex // guards mutable internal data, notably the maps and readErr. } +type timerKey struct { + ptransformID, family string +} + type elementsChan struct { - ch chan exec.Elements - complete bool + ch chan exec.Elements + readingTransforms map[string]bool + complete int // count of "done" streams +} + +// Closed indicates if all expected streams are complete +func (ec *elementsChan) Closed() bool { + return ec.complete == len(ec.readingTransforms) +} + +// Done signals that this PTransform has no more data coming to it. +// Timer consuming PTransforms and DataSources use distinct transformIDs, +// So the channel should only close when all data is completed. +func (ec *elementsChan) Done(ptransformID string) { + if !ec.Closed() { + if ec.readingTransforms[ptransformID] { + ec.readingTransforms[ptransformID] = false + ec.complete++ + } + if ec.Closed() { + close(ec.ch) + } + } } func (ec *elementsChan) Close() error { - if !ec.complete { - ec.complete = true + if !ec.Closed() { close(ec.ch) } return nil @@ -221,8 +244,8 @@ func makeDataChannel(ctx context.Context, id string, client dataClient, cancelFn id: id, client: client, writers: make(map[instructionID]map[string]*dataWriter), - timerWriters: make(map[instructionID]map[string]*timerWriter), - channels: make(map[instructionID]map[string]*elementsChan), + timerWriters: make(map[instructionID]map[timerKey]*timerWriter), + channels: make(map[instructionID]*elementsChan), endedInstructions: make(map[instructionID]struct{}), cancelFn: cancelFn, } @@ -246,53 +269,68 @@ func (c *DataChannel) OpenWrite(ctx context.Context, ptransformID string, instID } // OpenElementChan returns a channel of typex.Elements for the given instruction and ptransform. 
-func (c *DataChannel) OpenElementChan(ctx context.Context, ptransformID string, instID instructionID) <-chan exec.Elements { +func (c *DataChannel) OpenElementChan(ctx context.Context, ptransformID string, instID instructionID, expectedTimerTransforms []string) <-chan exec.Elements { c.mu.Lock() defer c.mu.Unlock() cid := clientID{ptransformID: ptransformID, instID: instID} if c.readErr != nil { panic(fmt.Errorf("opening a reader %v on a closed channel", cid)) } - return c.makeChannel(ctx, cid).ch + log.Infof(ctx, "OpenElementChan %v %v", cid, expectedTimerTransforms) + return c.makeChannel(ctx, cid, expectedTimerTransforms...).ch } // makeChannel creates a channel of exec.Elements. It expects to be called while c.mu is held. -func (c *DataChannel) makeChannel(ctx context.Context, id clientID) *elementsChan { - var m map[string]*elementsChan - var ok bool - if m, ok = c.channels[id.instID]; !ok { - m = make(map[string]*elementsChan) - c.channels[id.instID] = m - } - - if r, ok := m[id.ptransformID]; ok { +func (c *DataChannel) makeChannel(ctx context.Context, id clientID, additionalTransforms ...string) *elementsChan { + if r, ok := c.channels[id.instID]; ok { + if !r.Closed() { + // Ensure new readers are accounted for, and current data stream state is respected. + // That is, we only add an entry if it doesn't exist already. + if _, ok := r.readingTransforms[id.ptransformID]; !ok { + r.readingTransforms[id.ptransformID] = true + } + for _, pid := range additionalTransforms { + if _, ok := r.readingTransforms[pid]; !ok { + r.readingTransforms[pid] = true + } + } + } + log.Infof(ctx, "looked up data read channel %v %v", id, additionalTransforms) return r } - log.Infof(ctx, "make data read channel %v", id) + log.Infof(ctx, "make data read channel %v %v", id, additionalTransforms) - r := &elementsChan{ch: make(chan exec.Elements, 20)} + r := &elementsChan{ + ch: make(chan exec.Elements, 20), + readingTransforms: map[string]bool{id.ptransformID: true}, + } // Just in case initial data for an instruction arrives *after* an instructon has ended. // eg. it was blocked by another reader being slow, or the other instruction failed. // So we provide a pre-completed reader, and do not cache it, as there's no further cleanup for it. if _, ok := c.endedInstructions[id.instID]; ok { log.Infof(ctx, "data read channel %v already ended", id) - close(r.ch) - r.complete = true + r.Done(id.ptransformID) return r } - m[id.ptransformID] = r + // Set after checking for endedInstructions to set them only once. + for _, pid := range additionalTransforms { + r.readingTransforms[pid] = true + } + + c.channels[id.instID] = r return r } // OpenTimerWrite returns io.WriteCloser for the given timerFamilyID, instruction and ptransform. func (c *DataChannel) OpenTimerWrite(ctx context.Context, ptransformID string, instID instructionID, family string) io.WriteCloser { - return c.makeTimerWriter(ctx, clientID{timerFamilyID: family, ptransformID: ptransformID, instID: instID}) + return c.makeTimerWriter(ctx, clientID{ptransformID: ptransformID, instID: instID}, family) } func (c *DataChannel) read(ctx context.Context) { - cache := make(map[clientID]*elementsChan) + cache := make(map[instructionID]*elementsChan) + seenLast := make([]clientID, 5) for { msg, err := c.client.Recv() if err != nil { @@ -303,12 +341,10 @@ func (c *DataChannel) read(ctx context.Context) { // close the r.buf channels twice, or send on a closed channel. // Any other approach is racy, and may cause one of the above // panics. 
- for instID, m := range c.channels { - for tid, r := range m { - log.Errorf(ctx, "DataChannel.read %v channel inst: %v tid %v closing due to error on channel", c.id, instID, tid) - r.Close() - delete(cache, clientID{ptransformID: tid, instID: instID}) - } + for instID, r := range c.channels { + log.Errorf(ctx, "DataChannel.read %v channel inst: %v closing due to error on channel", c.id, instID) + r.Close() + delete(cache, instID) } c.terminateStreamOnError(err) c.mu.Unlock() @@ -326,80 +362,80 @@ func (c *DataChannel) read(ctx context.Context) { // Each message may contain segments for multiple streams, so we // must treat each segment in isolation. We maintain a local cache // to reduce lock contention. - - for _, elm := range msg.GetData() { - id := clientID{ptransformID: elm.TransformId, instID: instructionID(elm.GetInstructionId())} - - var r *elementsChan - if local, ok := cache[id]; ok { - r = local - } else { - c.mu.Lock() - r = c.makeChannel(ctx, id) - c.mu.Unlock() - cache[id] = r + iterateElements(ctx, "timer", c, cache, &seenLast, msg.GetTimers(), + func(elm *fnpb.Elements_Timers) exec.Elements { + return exec.Elements{Timers: elm.GetTimers(), PtransformID: elm.GetTransformId(), TimerFamilyID: elm.GetTimerFamilyId()} + }) + + iterateElements(ctx, "data", c, cache, &seenLast, msg.GetData(), + func(elm *fnpb.Elements_Data) exec.Elements { + return exec.Elements{Data: elm.GetData(), PtransformID: elm.GetTransformId()} + }) + + // Mark all readers that we've seen the last of as done, after queuing their elements. + if len(seenLast) > 0 { + c.mu.Lock() + for _, id := range seenLast { + r, ok := cache[id.instID] + if !ok { + continue // we've already closed this cached reader, skip + } + r.Done(id.ptransformID) + if r.Closed() { + // Clean up local bookkeeping. We'll never see another message + // for it again. We have to be careful not to remove the real + // one, because readers may be initialized after we've seen + // the full stream. + delete(cache, id.instID) + } } + seenLast = seenLast[:0] // reset for re-use + c.mu.Unlock() + } + } +} - // This send is deliberately blocking, if we exceed the buffering for - // a reader. We can't buffer the entire main input, if some user code - // is slow (or gets stuck). If the local side closes, the reader - // will be marked as completed and further remote data will be ignored. - select { - case r.ch <- exec.Elements{Data: elm.GetData()}: - case <-ctx.Done(): - // Technically, we need to close all the things here... to start. - r.Close() - } - if elm.GetIsLast() { - r.Close() +// dataEle is a light interface against the proto Data and Timer Elements. +type dataEle interface { + GetTransformId() string + GetInstructionId() string + GetIsLast() bool +} - // Clean up local bookkeeping. We'll never see another message - // for it again. We have to be careful not to remove the real - // one, because readers may be initialized after we've seen - // the full stream. 
- delete(cache, id) - continue - } +func iterateElements[E dataEle](ctx context.Context, kind string, c *DataChannel, cache map[instructionID]*elementsChan, seenLast *[]clientID, elms []E, wrap func(E) exec.Elements) { + for _, elm := range elms { + id := clientID{ptransformID: elm.GetTransformId(), instID: instructionID(elm.GetInstructionId())} + + var r *elementsChan + if local, ok := cache[id.instID]; ok { + r = local + } else { + c.mu.Lock() + r = c.makeChannel(ctx, id) + c.mu.Unlock() + cache[id.instID] = r } - for _, tim := range msg.GetTimers() { - id := clientID{ - ptransformID: tim.TransformId, - instID: instructionID(tim.GetInstructionId()), - // timerFamilyID: tim.GetTimerFamilyId(), - } - var r *elementsChan - if local, ok := cache[id]; ok { - r = local - } else { - c.mu.Lock() - r = c.makeChannel(ctx, id) - c.mu.Unlock() - cache[id] = r - } - if tim.GetIsLast() { - // If this reader hasn't closed yet, do so now. - r.Close() + if r.Closed() { + log.Infof(ctx, "%s for closed channel %v", kind, id) + continue + } - // Clean up local bookkeeping. We'll never see another message - // for it again. We have to be careful not to remove the real - // one, because readers may be initialized after we've seen - // the full stream. - delete(cache, id) - continue - } - log.Infof(ctx, "timer received for %v, %v: %v", id, tim.GetTimerFamilyId(), tim.GetTimers()) - - // This send is deliberately blocking, if we exceed the buffering for - // a reader. We can't buffer the entire main input, if some user code - // is slow (or gets stuck). If the local side closes, the reader - // will be marked as completed and further remote data will be ignored. - select { - case r.ch <- exec.Elements{Timers: tim.GetTimers(), TimerFamilyID: tim.GetTimerFamilyId()}: - case <-ctx.Done(): - // Technically, we need to close all the things here... to start. - r.Close() - } + // This send is deliberately blocking, if we exceed the buffering for + // a reader. We can't buffer the entire main input, if some user code + // is slow (or gets stuck). If the local side closes, the reader + // will be marked as completed and further remote data will be ignored. + select { + case r.ch <- wrap(elm): + case <-ctx.Done(): + // Technically, we need to close all the things here... to start. + c.mu.Lock() + r.Close() + c.mu.Unlock() + } + if elm.GetIsLast() { + log.Infof(ctx, "done with %s for %v", kind, id) + *seenLast = append(*seenLast, id) } } } @@ -422,7 +458,9 @@ func (c *DataChannel) removeInstruction(instID instructionID) { ws := c.writers[instID] tws := c.timerWriters[instID] - ecs := c.channels[instID] + + // Element channels are per instruction + ec := c.channels[instID] // Prevent other users while we iterate. delete(c.writers, instID) @@ -436,9 +474,7 @@ func (c *DataChannel) removeInstruction(instID instructionID) { for _, tw := range tws { tw.Close() } - for _, ec := range ecs { - ec.Close() - } + ec.Close() } func (c *DataChannel) makeWriter(ctx context.Context, id clientID) *dataWriter { @@ -452,7 +488,7 @@ func (c *DataChannel) makeWriter(ctx context.Context, id clientID) *dataWriter { c.writers[id.instID] = m } - if w, ok := m[makeID(id)]; ok { + if w, ok := m[id.ptransformID]; ok { return w } @@ -461,22 +497,22 @@ func (c *DataChannel) makeWriter(ctx context.Context, id clientID) *dataWriter { // runner or user directed. 
w := &dataWriter{ch: c, id: id} - m[makeID(id)] = w + m[id.ptransformID] = w return w } -func (c *DataChannel) makeTimerWriter(ctx context.Context, id clientID) *timerWriter { +func (c *DataChannel) makeTimerWriter(ctx context.Context, id clientID, family string) *timerWriter { c.mu.Lock() defer c.mu.Unlock() - var m map[string]*timerWriter + var m map[timerKey]*timerWriter var ok bool if m, ok = c.timerWriters[id.instID]; !ok { - m = make(map[string]*timerWriter) + m = make(map[timerKey]*timerWriter) c.timerWriters[id.instID] = m } - - if w, ok := m[makeID(id)]; ok { + tk := timerKey{ptransformID: id.ptransformID, family: family} + if w, ok := m[tk]; ok { return w } @@ -484,19 +520,11 @@ func (c *DataChannel) makeTimerWriter(ctx context.Context, id clientID) *timerWr // can only be created if an instruction is in scope, and aren't // runner or user directed. - w := &timerWriter{ch: c, id: id} - m[makeID(id)] = w + w := &timerWriter{ch: c, id: id, timerFamilyID: family} + m[tk] = w return w } -func makeID(id clientID) string { - newID := id.ptransformID - if id.timerFamilyID != "" { - newID += ":" + id.timerFamilyID - } - return newID -} - type dataWriter struct { buf []byte @@ -595,8 +623,9 @@ func (w *dataWriter) Write(p []byte) (n int, err error) { } type timerWriter struct { - id clientID - ch *DataChannel + id clientID + timerFamilyID string + ch *DataChannel } // send requires the ch.mu lock to be held. @@ -623,14 +652,14 @@ func (w *timerWriter) send(msg *fnpb.Elements) error { func (w *timerWriter) Close() error { w.ch.mu.Lock() defer w.ch.mu.Unlock() - delete(w.ch.timerWriters[w.id.instID], makeID(w.id)) + delete(w.ch.timerWriters[w.id.instID], timerKey{w.id.ptransformID, w.timerFamilyID}) var msg *fnpb.Elements msg = &fnpb.Elements{ Timers: []*fnpb.Elements_Timers{ { InstructionId: string(w.id.instID), TransformId: w.id.ptransformID, - TimerFamilyId: w.id.timerFamilyID, + TimerFamilyId: w.timerFamilyID, IsLast: true, }, }, @@ -649,7 +678,7 @@ func (w *timerWriter) writeTimers(p []byte) error { { InstructionId: string(w.id.instID), TransformId: w.id.ptransformID, - TimerFamilyId: w.id.timerFamilyID, + TimerFamilyId: w.timerFamilyID, Timers: p, }, }, From 5c051522d8eeb510a16fb747356a96f5d2fceb8e Mon Sep 17 00:00:00 2001 From: lostluck <13907733+lostluck@users.noreply.github.com> Date: Fri, 24 Mar 2023 12:42:46 -0700 Subject: [PATCH 5/5] It's all working, and correct, and safe! Exporting to a sepearate PR. 
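The user-facing shape this series targets, going by the streaming_wordcap example touched below, is a DoFn that takes a timers.Provider alongside state.Provider and sets timers from ProcessElement. A minimal sketch of that pattern (timer-field construction, registration, and the on-timer side are elided; field and constructor names are illustrative rather than the final API):

	func (s *Stateful) ProcessElement(ctx context.Context, ts beam.EventTime, sp state.Provider, tp timers.Provider, key, word string, emit func(string, string)) error {
		// s.OutputState is assumed to be a timers.ProcessingTime value configured
		// with this DoFn's timer family.
		// Schedule a processing-time firing, tagged per word, holding the output
		// watermark at the element's timestamp until the timer fires.
		s.OutputState.SetWithOpts(tp, time.Now().Add(10*time.Second), timers.Opts{
			Tag:  word,
			Hold: mtime.Time(ts).ToTime(),
		})
		return nil
	}
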
--- sdks/go/examples/streaming_wordcap/wordcap.go | 2 +- .../pkg/beam/core/runtime/exec/datasource.go | 58 ++- .../beam/core/runtime/exec/datasource_test.go | 114 +++-- .../beam/core/runtime/exec/dynsplit_test.go | 14 +- .../pkg/beam/core/runtime/harness/datamgr.go | 124 ++++-- .../beam/core/runtime/harness/datamgr_test.go | 415 ++++++++++++++++++ .../pkg/beam/core/runtime/harness/harness.go | 8 +- 7 files changed, 621 insertions(+), 114 deletions(-) diff --git a/sdks/go/examples/streaming_wordcap/wordcap.go b/sdks/go/examples/streaming_wordcap/wordcap.go index 704997412524..d86780e0d7d8 100644 --- a/sdks/go/examples/streaming_wordcap/wordcap.go +++ b/sdks/go/examples/streaming_wordcap/wordcap.go @@ -245,7 +245,7 @@ func main() { }, imp) // out = beam.WindowInto(s, window.NewFixedWindows(10*time.Second), out) str := beam.ParDo(s, func(b int64) string { - return fmt.Sprintf("element%03d", b) + return fmt.Sprintf("%03d", b) }, out) keyed := beam.ParDo(s, func(ctx context.Context, ts beam.EventTime, s string) (string, string) { diff --git a/sdks/go/pkg/beam/core/runtime/exec/datasource.go b/sdks/go/pkg/beam/core/runtime/exec/datasource.go index 5fa1a1f94374..54d86a6fa52e 100644 --- a/sdks/go/pkg/beam/core/runtime/exec/datasource.go +++ b/sdks/go/pkg/beam/core/runtime/exec/datasource.go @@ -44,8 +44,9 @@ type DataSource struct { // OnTimerTransforms maps PtransformIDs to their execution nodes that handle OnTimer callbacks. OnTimerTransforms map[string]*ParDo - source DataManager - state StateReader + source DataManager + state StateReader + curInst string index int64 splitIdx int64 @@ -97,6 +98,7 @@ func (n *DataSource) Up(ctx context.Context) error { // StartBundle initializes this datasource for the bundle. func (n *DataSource) StartBundle(ctx context.Context, id string, data DataContext) error { n.mu.Lock() + n.curInst = id n.source = data.Data n.state = data.State n.start = time.Now() @@ -106,9 +108,19 @@ func (n *DataSource) StartBundle(ctx context.Context, id string, data DataContex return n.Out.StartBundle(ctx, id, data) } +// splitSuccess is a marker error to indicate we've reached the split index. +// Akin to io.EOF. +var splitSuccess = errors.New("split index reached") + // process handles converting elements from the data source to timers. +// +// The data and timer callback functions must return an io.EOF if the reader terminates to signal that an additional +// buffer is desired. On successful splits, [splitSuccess] must be returned to indicate that the +// PTransform is done processing data for this instruction. func (n *DataSource) process(ctx context.Context, data func(bcr *byteCountReader, ptransformID string) error, timer func(bcr *byteCountReader, ptransformID, timerFamilyID string) error) error { - // TODO(riteshghorse): Pass in the PTransformIDs expecting OnTimer calls. + defer func() { + log.Infof(ctx, "%v DataSource.process returning", n.curInst) + }() // The SID contains this instruction's expected data processing transform (this one). 
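+	// The channel below stays open until the runner signals is_last for the data
+	// stream and for every timer stream feeding this bundle, so the timer-consuming
+	// transforms in OnTimerTransforms are registered with it up front.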
elms, err := n.source.OpenElementChan(ctx, n.SID, maps.Keys(n.OnTimerTransforms)) if err != nil { @@ -120,35 +132,44 @@ func (n *DataSource) process(ctx context.Context, data func(bcr *byteCountReader var byteCount int bcr := byteCountReader{reader: &r, count: &byteCount} + + splitPrimaryComplete := map[string]bool{} for { var err error select { case e, ok := <-elms: // Channel closed, so time to exit if !ok { - log.Infof(ctx, "%v: Data Channel closed", n) + log.Infof(ctx, "%v Data Channel closed", n.curInst) return nil } + if splitPrimaryComplete[e.PtransformID] { + log.Infof(ctx, "%v skipping elements for %v, previous split", n.curInst, e.PtransformID) + continue + } if len(e.Data) > 0 { r.Reset(e.Data) err = data(&bcr, e.PtransformID) } if len(e.Timers) > 0 { - // TODO remove this debug log. - log.Infof(ctx, "timer received for %v; %v : %v", e.PtransformID, e.TimerFamilyID, e.Timers) + log.Infof(ctx, "%v timer received for %v; %v : %v", n.curInst, e.PtransformID, e.TimerFamilyID, e.Timers) r.Reset(e.Timers) err = timer(&bcr, e.PtransformID, e.TimerFamilyID) } - case <-ctx.Done(): - return nil - } - if err != nil { - if err != io.EOF { + if err == splitSuccess { + log.Infof(ctx, "%v split success received for %v", n.curInst, e.PtransformID) + // Returning splitSuccess means we've split, and aren't consuming the remaining buffer. + // We mark the PTransform done to ignore further data. + splitPrimaryComplete[e.PtransformID] = true + } else if err != nil && err != io.EOF { return errors.Wrap(err, "source failed") } - // io.EOF means the reader successfully drained + // io.EOF means the reader successfully drained. // We're ready for a new buffer. + case <-ctx.Done(): + log.Infof(ctx, "%v context canceled for: %v", n.curInst, ctx.Err()) + return nil } } } @@ -204,6 +225,7 @@ func (n *DataSource) Process(ctx context.Context) ([]*Checkpoint, error) { // TODO(lostluck) 2020/02/22: Should we include window headers or just count the element sizes? ws, t, pn, err := DecodeWindowedValueHeader(wc, bcr.reader) if err != nil { + log.Infof(ctx, "%v decode window error: %v", n.curInst, err) return err } @@ -216,7 +238,7 @@ func (n *DataSource) Process(ctx context.Context) ([]*Checkpoint, error) { pe.Windows = ws pe.Pane = pn - log.Infof(ctx, "%v[%v]: processing %+v,%v", n, ptransformID, pe.Elm, pe.Elm2) + log.Infof(ctx, "%v[%v]: processing %+v,%v", n.curInst, ptransformID, pe.Elm, pe.Elm2) var valReStreams []ReStream for _, cv := range cvs { @@ -234,7 +256,7 @@ func (n *DataSource) Process(ctx context.Context) ([]*Checkpoint, error) { n.PCol.addSize(int64(bcr.reset())) // Check if there's a continuation and return residuals - // Needs to be done immeadiately after processing to not lose the element. + // Needs to be done immediately after processing to not lose the element. if c := n.getProcessContinuation(); c != nil { cp, err := n.checkpointThis(ctx, c) if err != nil { @@ -247,12 +269,10 @@ func (n *DataSource) Process(ctx context.Context) ([]*Checkpoint, error) { } // We've finished processing an element, check if we have finished a split. if n.incrementIndexAndCheckSplit() { - break + log.Infof(ctx, "%v split index reached", n.curInst) + return splitSuccess } } - // Signal data loop exit. 
- log.Debugf(ctx, "%v: exiting data loop", n) - return nil }, func(bcr *byteCountReader, ptransformID, timerFamilyID string) error { tmap, err := decodeTimer(cp, wc, bcr) @@ -526,7 +546,7 @@ func (n *DataSource) checkpointThis(ctx context.Context, pc sdf.ProcessContinuat // The bufSize param specifies the estimated number of elements that will be // sent to this DataSource, and is used to be able to perform accurate splits // even if the DataSource has not yet received all its elements. A bufSize of -// 0 or less indicates that its unknown, and so uses the current known size. +// 0 or less indicates that it's unknown, and so uses the current known size. func (n *DataSource) Split(ctx context.Context, splits []int64, frac float64, bufSize int64) (SplitResult, error) { if n == nil { return SplitResult{}, fmt.Errorf("failed to split at requested splits: {%v}, DataSource not initialized", splits) diff --git a/sdks/go/pkg/beam/core/runtime/exec/datasource_test.go b/sdks/go/pkg/beam/core/runtime/exec/datasource_test.go index 14e954d26afe..ebede8538083 100644 --- a/sdks/go/pkg/beam/core/runtime/exec/datasource_test.go +++ b/sdks/go/pkg/beam/core/runtime/exec/datasource_test.go @@ -16,7 +16,6 @@ package exec import ( - "bytes" "context" "fmt" "io" @@ -43,20 +42,20 @@ func TestDataSource_PerElement(t *testing.T) { name string expected []any Coder *coder.Coder - driver func(*coder.Coder, io.WriteCloser, []any) + driver func(*coder.Coder, *chanWriter, []any) }{ { name: "perElement", expected: []any{int64(1), int64(2), int64(3), int64(4), int64(5)}, Coder: coder.NewW(coder.NewVarInt(), coder.NewGlobalWindow()), - driver: func(c *coder.Coder, pw io.WriteCloser, expected []any) { + driver: func(c *coder.Coder, cw *chanWriter, expected []any) { wc := MakeWindowEncoder(c.Window) ec := MakeElementEncoder(coder.SkipW(c)) for _, v := range expected { - EncodeWindowedValueHeader(wc, window.SingleGlobalWindow, mtime.ZeroTimestamp, typex.NoFiringPane(), pw) - ec.Encode(&FullValue{Elm: v}, pw) + EncodeWindowedValueHeader(wc, window.SingleGlobalWindow, mtime.ZeroTimestamp, typex.NoFiringPane(), cw) + ec.Encode(&FullValue{Elm: v}, cw) } - pw.Close() + cw.Close() }, }, } @@ -70,11 +69,11 @@ func TestDataSource_PerElement(t *testing.T) { Coder: test.Coder, Out: out, } - pr, pw := io.Pipe() - go test.driver(source.Coder, pw, test.expected) + cw := makeChanWriter() + go test.driver(source.Coder, cw, test.expected) constructAndExecutePlanWithContext(t, []Unit{out, source}, DataContext{ - Data: &TestDataManager{R: pr}, + Data: &TestDataManager{Ch: cw.Ch}, }) validateSource(t, out, source, makeValues(test.expected...)) @@ -98,14 +97,14 @@ func TestDataSource_Iterators(t *testing.T) { name string keys, vals []any Coder *coder.Coder - driver func(c *coder.Coder, dmw io.WriteCloser, siwFn func() io.WriteCloser, ks, vs []any) + driver func(c *coder.Coder, dmw *chanWriter, siwFn func() io.WriteCloser, ks, vs []any) }{ { name: "beam:coder:iterable:v1-singleChunk", keys: []any{int64(42), int64(53)}, vals: []any{int64(1), int64(2), int64(3), int64(4), int64(5)}, Coder: coder.NewW(coder.NewCoGBK([]*coder.Coder{coder.NewVarInt(), coder.NewVarInt()}), coder.NewGlobalWindow()), - driver: func(c *coder.Coder, dmw io.WriteCloser, _ func() io.WriteCloser, ks, vs []any) { + driver: func(c *coder.Coder, dmw *chanWriter, _ func() io.WriteCloser, ks, vs []any) { wc, kc, vc := extractCoders(c) for _, k := range ks { EncodeWindowedValueHeader(wc, window.SingleGlobalWindow, mtime.ZeroTimestamp, typex.NoFiringPane(), dmw) @@ -123,7 +122,7 @@ 
func TestDataSource_Iterators(t *testing.T) { keys: []any{int64(42), int64(53)}, vals: []any{int64(1), int64(2), int64(3), int64(4), int64(5)}, Coder: coder.NewW(coder.NewCoGBK([]*coder.Coder{coder.NewVarInt(), coder.NewVarInt()}), coder.NewGlobalWindow()), - driver: func(c *coder.Coder, dmw io.WriteCloser, _ func() io.WriteCloser, ks, vs []any) { + driver: func(c *coder.Coder, dmw *chanWriter, _ func() io.WriteCloser, ks, vs []any) { wc, kc, vc := extractCoders(c) for _, k := range ks { EncodeWindowedValueHeader(wc, window.SingleGlobalWindow, mtime.ZeroTimestamp, typex.NoFiringPane(), dmw) @@ -144,7 +143,7 @@ func TestDataSource_Iterators(t *testing.T) { keys: []any{int64(42), int64(53)}, vals: []any{int64(1), int64(2), int64(3), int64(4), int64(5)}, Coder: coder.NewW(coder.NewCoGBK([]*coder.Coder{coder.NewVarInt(), coder.NewVarInt()}), coder.NewGlobalWindow()), - driver: func(c *coder.Coder, dmw io.WriteCloser, swFn func() io.WriteCloser, ks, vs []any) { + driver: func(c *coder.Coder, dmw *chanWriter, swFn func() io.WriteCloser, ks, vs []any) { wc, kc, vc := extractCoders(c) for _, k := range ks { EncodeWindowedValueHeader(wc, window.SingleGlobalWindow, mtime.ZeroTimestamp, typex.NoFiringPane(), dmw) @@ -155,6 +154,8 @@ func TestDataSource_Iterators(t *testing.T) { token := []byte(tokenString) coder.EncodeVarInt(int64(len(token)), dmw) // token. dmw.Write(token) + dmw.Flush() // Flush here to allow state IO from this goroutine. + // Each state stream needs to be a different writer, so get a new writer. sw := swFn() for _, v := range vs { @@ -170,6 +171,7 @@ func TestDataSource_Iterators(t *testing.T) { for _, singleIterate := range []bool{true, false} { for _, test := range tests { t.Run(test.name, func(t *testing.T) { + fmt.Println(test.name) capture := &IteratorCaptureNode{CaptureNode: CaptureNode{UID: 1}} out := Node(capture) units := []Unit{out} @@ -187,8 +189,7 @@ func TestDataSource_Iterators(t *testing.T) { Out: out, } units = append(units, source) - dmr, dmw := io.Pipe() - + cw := makeChanWriter() // Simulate individual state channels with pipes and a channel. 
sRc := make(chan io.ReadCloser) swFn := func() io.WriteCloser { @@ -196,10 +197,10 @@ func TestDataSource_Iterators(t *testing.T) { sRc <- sr return sw } - go test.driver(source.Coder, dmw, swFn, test.keys, test.vals) + go test.driver(source.Coder, cw, swFn, test.keys, test.vals) constructAndExecutePlanWithContext(t, units, DataContext{ - Data: &TestDataManager{R: dmr}, + Data: &TestDataManager{Ch: cw.Ch}, State: &TestStateReader{Rc: sRc}, }) if len(capture.CapturedInputs) == 0 { @@ -240,7 +241,7 @@ func TestDataSource_Iterators(t *testing.T) { func TestDataSource_Split(t *testing.T) { elements := []any{int64(1), int64(2), int64(3), int64(4), int64(5)} - initSourceTest := func(name string) (*DataSource, *CaptureNode, io.ReadCloser) { + initSourceTest := func(name string) (*DataSource, *CaptureNode, chan Elements) { out := &CaptureNode{UID: 1} c := coder.NewW(coder.NewVarInt(), coder.NewGlobalWindow()) source := &DataSource{ @@ -250,7 +251,7 @@ func TestDataSource_Split(t *testing.T) { Coder: c, Out: out, } - pr, pw := io.Pipe() + cw := makeChanWriter() go func(c *coder.Coder, pw io.WriteCloser, elements []any) { wc := MakeWindowEncoder(c.Window) @@ -260,8 +261,8 @@ func TestDataSource_Split(t *testing.T) { ec.Encode(&FullValue{Elm: v}, pw) } pw.Close() - }(c, pw, elements) - return source, out, pr + }(c, cw, elements) + return source, out, cw.Ch } tests := []struct { @@ -289,12 +290,12 @@ func TestDataSource_Split(t *testing.T) { test.expected = elements[:test.splitIdx] } t.Run(test.name, func(t *testing.T) { - source, out, pr := initSourceTest(test.name) + source, out, ch := initSourceTest(test.name) p, err := NewPlan("a", []Unit{out, source}) if err != nil { t.Fatalf("failed to construct plan: %v", err) } - dc := DataContext{Data: &TestDataManager{R: pr}} + dc := DataContext{Data: &TestDataManager{Ch: ch}} ctx := context.Background() // StartBundle resets the source, so no splits can be actuated before then, @@ -358,7 +359,7 @@ func TestDataSource_Split(t *testing.T) { test.expected = elements[:test.splitIdx] } t.Run(test.name, func(t *testing.T) { - source, out, pr := initSourceTest(test.name) + source, out, ch := initSourceTest(test.name) unblockCh, blockedCh := make(chan struct{}), make(chan struct{}, 1) // Block on the one less than the desired split, // so the desired split is the first valid split. @@ -401,7 +402,7 @@ func TestDataSource_Split(t *testing.T) { }() constructAndExecutePlanWithContext(t, []Unit{out, blocker, source}, DataContext{ - Data: &TestDataManager{R: pr}, + Data: &TestDataManager{Ch: ch}, }) validateSource(t, out, source, makeValues(test.expected...)) @@ -427,12 +428,12 @@ func TestDataSource_Split(t *testing.T) { expected: elements[:3], } - source, out, pr := initSourceTest("bufSize") + source, out, ch := initSourceTest("bufSize") p, err := NewPlan("a", []Unit{out, source}) if err != nil { t.Fatalf("failed to construct plan: %v", err) } - dc := DataContext{Data: &TestDataManager{R: pr}} + dc := DataContext{Data: &TestDataManager{Ch: ch}} ctx := context.Background() // StartBundle resets the source, so no splits can be actuated before then, @@ -490,7 +491,7 @@ func TestDataSource_Split(t *testing.T) { test := test name := fmt.Sprintf("withFraction_%v", test.fraction) t.Run(name, func(t *testing.T) { - source, out, pr := initSourceTest(name) + source, out, ch := initSourceTest(name) unblockCh, blockedCh := make(chan struct{}), make(chan struct{}, 1) // Block on the one less than the desired split, // so the desired split is the first valid split. 
@@ -527,10 +528,10 @@ func TestDataSource_Split(t *testing.T) { t.Errorf("error in Split: got sub-element split = %t, want %t", isSubElm, test.isSubElm) } if isSubElm { - if got, want := splitRes.TId, testTransformId; got != want { + if got, want := splitRes.TId, testTransformID; got != want { t.Errorf("error in Split: got incorrect Transform Id = %v, want %v", got, want) } - if got, want := splitRes.InId, testInputId; got != want { + if got, want := splitRes.InId, testInputID; got != want { t.Errorf("error in Split: got incorrect Input Id = %v, want %v", got, want) } if _, ok := splitRes.OW["output1"]; !ok { @@ -558,7 +559,7 @@ func TestDataSource_Split(t *testing.T) { }() constructAndExecutePlanWithContext(t, []Unit{out, blocker, source}, DataContext{ - Data: &TestDataManager{R: pr}, + Data: &TestDataManager{Ch: ch}, }) validateSource(t, out, source, makeValues(elements[:test.splitIdx]...)) @@ -571,12 +572,12 @@ func TestDataSource_Split(t *testing.T) { // Test expects splitting errors, but for processing to be successful. t.Run("errors", func(t *testing.T) { - source, out, pr := initSourceTest("noSplitsUntilStarted") + source, out, ch := initSourceTest("noSplitsUntilStarted") p, err := NewPlan("a", []Unit{out, source}) if err != nil { t.Fatalf("failed to construct plan: %v", err) } - dc := DataContext{Data: &TestDataManager{R: pr}} + dc := DataContext{Data: &TestDataManager{Ch: ch}} ctx := context.Background() if sr, err := p.Split(ctx, SplitPoints{Splits: []int64{0, 3}, Frac: -1}); err != nil || !sr.Unsuccessful { @@ -620,8 +621,8 @@ func TestDataSource_Split(t *testing.T) { }) } -const testTransformId = "transform_id" -const testInputId = "input_id" +const testTransformID = "transform_id" +const testInputID = "input_id" // TestSplittableUnit is an implementation of the SplittableUnit interface // for DataSource tests. @@ -651,12 +652,12 @@ func (n *TestSplittableUnit) GetProgress() float64 { // GetTransformId returns a constant transform ID that can be tested for. func (n *TestSplittableUnit) GetTransformId() string { - return testTransformId + return testTransformID } // GetInputId returns a constant input ID that can be tested for. func (n *TestSplittableUnit) GetInputId() string { - return testInputId + return testInputID } // GetOutputWatermark gets the current output watermark of the splittable unit @@ -966,20 +967,21 @@ func TestCheckpointing(t *testing.T) { } enc := MakeElementEncoder(wvERSCoder) - var buf bytes.Buffer + cw := makeChanWriter() // We encode the element several times to ensure we don't // drop any residuals, the root of issue #24931. 
wantCount := 3 for i := 0; i < wantCount; i++ { - if err := enc.Encode(value, &buf); err != nil { + if err := enc.Encode(value, cw); err != nil { t.Fatalf("couldn't encode value: %v", err) } } + cw.Close() if err := root.StartBundle(ctx, "testBund", DataContext{ Data: &TestDataManager{ - R: io.NopCloser(&buf), + Ch: cw.Ch, }, }, ); err != nil { @@ -1017,22 +1019,44 @@ func runOnRoots(ctx context.Context, t *testing.T, p *Plan, name string, mthd fu } type TestDataManager struct { - R io.ReadCloser - C chan Elements + Ch chan Elements } -func (dm *TestDataManager) OpenElementChan(ctx context.Context, id StreamID) (<-chan Elements, error) { - return dm.C, nil +func (dm *TestDataManager) OpenElementChan(ctx context.Context, id StreamID, expectedTimerTransforms []string) (<-chan Elements, error) { + return dm.Ch, nil } func (dm *TestDataManager) OpenWrite(ctx context.Context, id StreamID) (io.WriteCloser, error) { return nil, nil } -func (dm *TestDataManager) OpenTimerWrite(ctx context.Context, id StreamID, key string) (io.WriteCloser, error) { +func (dm *TestDataManager) OpenTimerWrite(ctx context.Context, id StreamID, family string) (io.WriteCloser, error) { return nil, nil } +type chanWriter struct { + Ch chan Elements + Buf []byte +} + +func (cw *chanWriter) Write(p []byte) (int, error) { + cw.Buf = append(cw.Buf, p...) + return len(p), nil +} + +func (cw *chanWriter) Close() error { + cw.Flush() + close(cw.Ch) + return nil +} + +func (cw *chanWriter) Flush() { + cw.Ch <- Elements{Data: cw.Buf, PtransformID: "myPTransform"} + cw.Buf = nil +} + +func makeChanWriter() *chanWriter { return &chanWriter{Ch: make(chan Elements, 20)} } + // TestSideInputReader simulates state reads using channels. type TestStateReader struct { StateReader diff --git a/sdks/go/pkg/beam/core/runtime/exec/dynsplit_test.go b/sdks/go/pkg/beam/core/runtime/exec/dynsplit_test.go index 1fa3ae94c866..84c84a8d3164 100644 --- a/sdks/go/pkg/beam/core/runtime/exec/dynsplit_test.go +++ b/sdks/go/pkg/beam/core/runtime/exec/dynsplit_test.go @@ -75,10 +75,10 @@ func TestDynamicSplit(t *testing.T) { plan, out := createSdfPlan(t, t.Name(), dfn, cdr) // Create thread to send element to pipeline. - pr, pw := io.Pipe() + cw := makeChanWriter() elm := createElm() - go writeElm(elm, cdr, pw) - dc := DataContext{Data: &TestDataManager{R: pr}} + go writeElm(elm, cdr, cw) + dc := DataContext{Data: &TestDataManager{Ch: cw.Ch}} // Call driver to coordinate processing & splitting threads. splitRes, procRes := test.driver(context.Background(), plan, dc, sdf) @@ -92,7 +92,7 @@ func TestDynamicSplit(t *testing.T) { RI: 1, PS: nil, RS: nil, - TId: testTransformId, + TId: testTransformID, InId: indexToInputId(0), } if diff := cmp.Diff(splitRes.split, wantSplit, cmpopts.IgnoreFields(SplitResult{}, "PS", "RS")); diff != "" { @@ -263,7 +263,7 @@ func createSplitTestInCoder() *coder.Coder { func createSdfPlan(t *testing.T, name string, fn *graph.DoFn, cdr *coder.Coder) (*Plan, *CaptureNode) { out := &CaptureNode{UID: 0} n := &ParDo{UID: 1, Fn: fn, Out: []Node{out}} - sdf := &ProcessSizedElementsAndRestrictions{PDo: n, TfId: testTransformId} + sdf := &ProcessSizedElementsAndRestrictions{PDo: n, TfId: testTransformID} ds := &DataSource{ UID: 2, SID: StreamID{PtransformID: "DataSource"}, @@ -281,8 +281,8 @@ func createSdfPlan(t *testing.T, name string, fn *graph.DoFn, cdr *coder.Coder) } // writeElm is meant to be the goroutine for feeding an element to the -// DataSourc of the test pipeline. 
-func writeElm(elm *FullValue, cdr *coder.Coder, pw *io.PipeWriter) { +// DataSource of the test pipeline. +func writeElm(elm *FullValue, cdr *coder.Coder, pw io.WriteCloser) { wc := MakeWindowEncoder(cdr.Window) ec := MakeElementEncoder(coder.SkipW(cdr)) if err := EncodeWindowedValueHeader(wc, window.SingleGlobalWindow, mtime.ZeroTimestamp, typex.NoFiringPane(), pw); err != nil { diff --git a/sdks/go/pkg/beam/core/runtime/harness/datamgr.go b/sdks/go/pkg/beam/core/runtime/harness/datamgr.go index ab938750c301..d9fd2e251844 100644 --- a/sdks/go/pkg/beam/core/runtime/harness/datamgr.go +++ b/sdks/go/pkg/beam/core/runtime/harness/datamgr.go @@ -36,8 +36,9 @@ const ( // ScopedDataManager scopes the global gRPC data manager to a single instruction. // The indirection makes it easier to control access. type ScopedDataManager struct { - mgr *DataChannelManager - instID instructionID + mgr *DataChannelManager + instID instructionID + openPorts []exec.Port closed bool mu sync.Mutex @@ -63,7 +64,7 @@ func (s *ScopedDataManager) OpenElementChan(ctx context.Context, id exec.StreamI if err != nil { return nil, err } - return ch.OpenElementChan(ctx, id.PtransformID, s.instID, expectedTimerTransforms), nil + return ch.OpenElementChan(ctx, id.PtransformID, s.instID, expectedTimerTransforms) } // OpenTimerWrite opens an io.WriteCloser on the given stream to write timers @@ -81,6 +82,7 @@ func (s *ScopedDataManager) open(ctx context.Context, port exec.Port) (*DataChan s.mu.Unlock() return nil, errors.Errorf("instruction %v no longer processing", s.instID) } + s.openPorts = append(s.openPorts, port) local := s.mgr s.mu.Unlock() @@ -92,9 +94,10 @@ func (s *ScopedDataManager) Close() error { s.mu.Lock() defer s.mu.Unlock() s.closed = true - s.mgr.closeInstruction(s.instID) + log.Infof(context.TODO(), "ScopedDataManager Closing, %v", string(s.instID)) + err := s.mgr.closeInstruction(s.instID, s.openPorts) s.mgr = nil - return nil + return err } // DataChannelManager manages data channels over the Data API. A fixed number of channels @@ -134,12 +137,21 @@ func (m *DataChannelManager) Open(ctx context.Context, port exec.Port) (*DataCha return ch, nil } -func (m *DataChannelManager) closeInstruction(instID instructionID) { +func (m *DataChannelManager) closeInstruction(instID instructionID, ports []exec.Port) error { m.mu.Lock() defer m.mu.Unlock() - for _, ch := range m.ports { - ch.removeInstruction(instID) + var firstNonNilError error + for _, port := range ports { + ch, ok := m.ports[port.URL] + if !ok { + continue + } + err := ch.removeInstruction(instID) + if err != nil && firstNonNilError == nil { + firstNonNilError = err + } } + return firstNonNilError } // clientID identifies a client of a connected channel. @@ -149,13 +161,13 @@ type clientID struct { } // This is a reduced version of the full gRPC interface to help with testing. -// TODO(wcn): need a compile-time assertion to make sure this stays synced with what's -// in fnpb.BeamFnData_DataClient type dataClient interface { Send(*fnpb.Elements) error Recv() (*fnpb.Elements, error) } +var _ dataClient = (fnpb.BeamFnData_DataClient)(nil) // Assert our interfaces are compatible. + // DataChannel manages a single gRPC stream over the Data API. Data from // multiple bundles can be multiplexed over this stream. 
Data is pushed // over the channel, so data for a reader may arrive before the reader @@ -188,9 +200,17 @@ type timerKey struct { } type elementsChan struct { + instID instructionID ch chan exec.Elements readingTransforms map[string]bool complete int // count of "done" streams + done chan bool +} + +// InstructionEnded signals the read loop to close the channel. +func (ec *elementsChan) InstructionEnded() { + log.Infof(context.TODO(), "EC InstructionEnded: %v %v", ec.readingTransforms, ec.instID) + ec.done <- true } // Closed indicates if all expected streams are complete @@ -202,20 +222,28 @@ func (ec *elementsChan) Closed() bool { // Timer consuming PTransforms and DataSources use distinct transformIDs, // So the channel should only close when all data is completed. func (ec *elementsChan) Done(ptransformID string) { + log.Infof(context.TODO(), "EC Done, (%v) %v", ptransformID, ec.instID) if !ec.Closed() { if ec.readingTransforms[ptransformID] { ec.readingTransforms[ptransformID] = false ec.complete++ } if ec.Closed() { - close(ec.ch) + log.Infof(context.TODO(), "EC last ptransform done, closing channel (%v) %v", ptransformID, ec.instID) + ec.close() } } } +func (ec *elementsChan) close() { + log.Infof(context.TODO(), "closing elementsChan for %v", ec.instID) + ec.complete = len(ec.readingTransforms) + close(ec.ch) +} + func (ec *elementsChan) Close() error { if !ec.Closed() { - close(ec.ch) + ec.close() } return nil } @@ -269,32 +297,32 @@ func (c *DataChannel) OpenWrite(ctx context.Context, ptransformID string, instID } // OpenElementChan returns a channel of typex.Elements for the given instruction and ptransform. -func (c *DataChannel) OpenElementChan(ctx context.Context, ptransformID string, instID instructionID, expectedTimerTransforms []string) <-chan exec.Elements { +func (c *DataChannel) OpenElementChan(ctx context.Context, ptransformID string, instID instructionID, expectedTimerTransforms []string) (<-chan exec.Elements, error) { c.mu.Lock() defer c.mu.Unlock() cid := clientID{ptransformID: ptransformID, instID: instID} if c.readErr != nil { - panic(fmt.Errorf("opening a reader %v on a closed channel", cid)) + return nil, fmt.Errorf("opening a reader %v on a closed channel. Original error: %w", cid, c.readErr) } log.Infof(ctx, "OpenElementChan %v %v", cid, expectedTimerTransforms) - return c.makeChannel(ctx, cid, expectedTimerTransforms...).ch + return c.makeChannel(ctx, true, cid, expectedTimerTransforms...).ch, nil } // makeChannel creates a channel of exec.Elements. It expects to be called while c.mu is held. -func (c *DataChannel) makeChannel(ctx context.Context, id clientID, additionalTransforms ...string) *elementsChan { +func (c *DataChannel) makeChannel(ctx context.Context, fromSource bool, id clientID, additionalTransforms ...string) *elementsChan { if r, ok := c.channels[id.instID]; ok { - if !r.Closed() { - // Ensure new readers are accounted for, and current data stream state is respected. - // That is, we only add an entry if it doesn't exist already. - if _, ok := r.readingTransforms[id.ptransformID]; !ok { - r.readingTransforms[id.ptransformID] = true - } - for _, pid := range additionalTransforms { - if _, ok := r.readingTransforms[pid]; !ok { - r.readingTransforms[pid] = true - } - } - } + // if !r.Closed() { + // // Ensure new readers are accounted for, and current data stream state is respected. + // // That is, we only add an entry if it doesn't exist already. 
+ // if _, ok := r.readingTransforms[id.ptransformID]; !ok { + // r.readingTransforms[id.ptransformID] = true + // } + // for _, pid := range additionalTransforms { + // if _, ok := r.readingTransforms[pid]; !ok { + // r.readingTransforms[pid] = true + // } + // } + // } log.Infof(ctx, "looked up data read channel %v %v", id, additionalTransforms) return r } @@ -302,8 +330,10 @@ func (c *DataChannel) makeChannel(ctx context.Context, id clientID, additionalTr log.Infof(ctx, "make data read channel %v %v", id, additionalTransforms) r := &elementsChan{ + instID: id.instID, ch: make(chan exec.Elements, 20), readingTransforms: map[string]bool{id.ptransformID: true}, + done: make(chan bool, 1), } // Just in case initial data for an instruction arrives *after* an instructon has ended. // eg. it was blocked by another reader being slow, or the other instruction failed. @@ -330,7 +360,7 @@ func (c *DataChannel) OpenTimerWrite(ctx context.Context, ptransformID string, i func (c *DataChannel) read(ctx context.Context) { cache := make(map[instructionID]*elementsChan) - seenLast := make([]clientID, 5) + seenLast := make([]clientID, 0, 5) for { msg, err := c.client.Recv() if err != nil { @@ -338,9 +368,8 @@ func (c *DataChannel) read(ctx context.Context) { c.mu.Lock() c.readErr = err // prevent not yet opened readers from hanging. // Readers must be closed from this goroutine, since we can't - // close the r.buf channels twice, or send on a closed channel. - // Any other approach is racy, and may cause one of the above - // panics. + // close the elementsChan channel twice, or send on those closed channels. + // Any other approach is racy, and may cause one of the above panics. for instID, r := range c.channels { log.Errorf(ctx, "DataChannel.read %v channel inst: %v closing due to error on channel", c.id, instID) r.Close() @@ -369,6 +398,9 @@ func (c *DataChannel) read(ctx context.Context) { iterateElements(ctx, "data", c, cache, &seenLast, msg.GetData(), func(elm *fnpb.Elements_Data) exec.Elements { + if len(elm.GetData()) != 0 { + log.Infof(ctx, "sent data to %v for %v", elm.GetInstructionId(), elm.GetTransformId()) + } return exec.Elements{Data: elm.GetData(), PtransformID: elm.GetTransformId()} }) @@ -376,12 +408,15 @@ func (c *DataChannel) read(ctx context.Context) { if len(seenLast) > 0 { c.mu.Lock() for _, id := range seenLast { + log.Infof(ctx, "is last seen for EC %v", id) r, ok := cache[id.instID] if !ok { + log.Infof(ctx, "cached EC already closed: %v", id) continue // we've already closed this cached reader, skip } r.Done(id.ptransformID) if r.Closed() { + log.Infof(ctx, "removing EC %v from cache", id) // Clean up local bookkeeping. We'll never see another message // for it again. We have to be careful not to remove the real // one, because readers may be initialized after we've seen @@ -411,7 +446,7 @@ func iterateElements[E dataEle](ctx context.Context, kind string, c *DataChannel r = local } else { c.mu.Lock() - r = c.makeChannel(ctx, id) + r = c.makeChannel(ctx, false, id) c.mu.Unlock() cache[id.instID] = r } @@ -421,14 +456,19 @@ func iterateElements[E dataEle](ctx context.Context, kind string, c *DataChannel continue } - // This send is deliberately blocking, if we exceed the buffering for + // This send deliberately blocks if we exceed the buffering for // a reader. We can't buffer the entire main input, if some user code // is slow (or gets stuck). If the local side closes, the reader // will be marked as completed and further remote data will be ignored. 
select { case r.ch <- wrap(elm): + case <-r.done: // In case of out of band cancels. + log.Infof(ctx, "out of band cancel %s for %v", kind, id) + c.mu.Lock() + r.Close() + c.mu.Unlock() case <-ctx.Done(): - // Technically, we need to close all the things here... to start. + log.Infof(ctx, "context cancel %s for %v: %v", kind, id, ctx.Err()) c.mu.Lock() r.Close() c.mu.Unlock() @@ -444,7 +484,7 @@ const endedInstructionCap = 32 // removeInstruction closes all readers and writers registered for the instruction // and deletes this instruction from the channel's reader and writer maps. -func (c *DataChannel) removeInstruction(instID instructionID) { +func (c *DataChannel) removeInstruction(instID instructionID) error { c.mu.Lock() // We don't want to leak memory, so cap the endedInstructions list. @@ -458,14 +498,15 @@ func (c *DataChannel) removeInstruction(instID instructionID) { ws := c.writers[instID] tws := c.timerWriters[instID] - - // Element channels are per instruction ec := c.channels[instID] // Prevent other users while we iterate. delete(c.writers, instID) delete(c.timerWriters, instID) delete(c.channels, instID) + + // Return readErr to defend against data loss via short reads. + err := c.readErr c.mu.Unlock() for _, w := range ws { @@ -474,7 +515,10 @@ func (c *DataChannel) removeInstruction(instID instructionID) { for _, tw := range tws { tw.Close() } - ec.Close() + if ec != nil { + ec.InstructionEnded() + } + return err } func (c *DataChannel) makeWriter(ctx context.Context, id clientID) *dataWriter { @@ -516,7 +560,7 @@ func (c *DataChannel) makeTimerWriter(ctx context.Context, id clientID, family s return w } - // We don't check for ended instructions for writers, as writers + // We don't check for finished instructions for writers, as writers // can only be created if an instruction is in scope, and aren't // runner or user directed. 
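To make the close conditions above concrete: an elementsChan only closes once every expected stream, the data transform plus any expected timer transforms, has reported is_last, or when the instruction ends out of band via InstructionEnded. A minimal sketch of that bookkeeping, written as a would-be test in package harness and assuming Closed reports completion exactly when every expected stream is done, as its doc comment states:

func TestElementsChanCloseCondition(t *testing.T) {
	ec := &elementsChan{
		instID:            "inst_ref",
		ch:                make(chan exec.Elements, 20),
		readingTransforms: map[string]bool{"dataTransform": true, "timerTransform": true},
		done:              make(chan bool, 1),
	}
	ec.Done("dataTransform") // One expected stream saw is_last; the channel stays open.
	if ec.Closed() {
		t.Fatal("elementsChan closed before all expected streams finished")
	}
	ec.Done("timerTransform") // The last expected stream finished; ch is now closed.
	if !ec.Closed() {
		t.Fatal("elementsChan should be closed once every expected stream is done")
	}
}

The TestElementChan cases in datamgr_test.go below exercise the same behavior end to end through a fake data client.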
diff --git a/sdks/go/pkg/beam/core/runtime/harness/datamgr_test.go b/sdks/go/pkg/beam/core/runtime/harness/datamgr_test.go index 3f77d69f1737..dd48f33dd260 100644 --- a/sdks/go/pkg/beam/core/runtime/harness/datamgr_test.go +++ b/sdks/go/pkg/beam/core/runtime/harness/datamgr_test.go @@ -18,6 +18,7 @@ package harness import ( "bytes" "context" + "errors" "fmt" "io" "log" @@ -26,6 +27,7 @@ import ( "testing" "time" + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/exec" fnpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/fnexecution_v1" ) @@ -101,6 +103,419 @@ func (f *fakeDataClient) Send(*fnpb.Elements) error { return nil } +type fakeChanClient struct { + ch chan *fnpb.Elements +} + +func (f *fakeChanClient) Recv() (*fnpb.Elements, error) { + e, ok := <-f.ch + if !ok { + return nil, io.EOF + } + return e, nil +} + +func (f *fakeChanClient) Send(e *fnpb.Elements) error { + f.ch <- e + return nil +} + +func (f *fakeChanClient) Close() error { + close(f.ch) + return nil +} + +func TestElementChan(t *testing.T) { + const instID = "inst_ref" + dataID := "dataTransform" + timerID := "timerTransform" + timerFamily := "timerFamily" + setupClient := func(t *testing.T) (context.Context, *fakeChanClient, *DataChannel) { + fmt.Println("TTTTT ", t.Name()) + t.Helper() + client := &fakeChanClient{ch: make(chan *fnpb.Elements, 20)} + ctx, cancelFn := context.WithCancel(context.Background()) + t.Cleanup(cancelFn) + t.Cleanup(func() { client.Close() }) + + c := makeDataChannel(ctx, "id", client, cancelFn) + return ctx, client, c + } + drainAndSum := func(t *testing.T, elms <-chan exec.Elements) (sum, count int) { + t.Helper() + for e := range elms { // only exits if data channel is closed. + count++ + if len(e.Data) != 0 { + sum += int(e.Data[0]) + } + if len(e.Timers) != 0 { + if e.TimerFamilyID != timerFamily { + t.Errorf("timer received without family set: %v, state= sum %v, count %v", e, sum, count) + } + sum += int(e.Timers[0]) + } + } + return sum, count + } + + timerElm := func(val byte, isLast bool) *fnpb.Elements_Timers { + return &fnpb.Elements_Timers{InstructionId: instID, TransformId: timerID, Timers: []byte{val}, IsLast: isLast, TimerFamilyId: timerFamily} + } + dataElm := func(val byte, isLast bool) *fnpb.Elements_Data { + return &fnpb.Elements_Data{InstructionId: instID, TransformId: dataID, Data: []byte{val}, IsLast: isLast} + } + noTimerElm := func() *fnpb.Elements_Timers { + return &fnpb.Elements_Timers{InstructionId: instID, TransformId: timerID, Timers: []byte{}, IsLast: true} + } + noDataElm := func() *fnpb.Elements_Data { + return &fnpb.Elements_Data{InstructionId: instID, TransformId: dataID, Data: []byte{}, IsLast: true} + } + + // Simple batch HappyPath. 
+ t.Run("readerThenData_singleRecv", func(t *testing.T) { + ctx, client, c := setupClient(t) + + elms, err := c.OpenElementChan(ctx, dataID, instID, nil) + if err != nil { + t.Errorf("Unexpected error from OpenElementChan(%v, %v, nil): %v", dataID, instID, err) + } + + client.Send(&fnpb.Elements{ + Data: []*fnpb.Elements_Data{ + dataElm(1, false), + dataElm(2, false), + dataElm(3, true), + }, + }) + + sum, count := drainAndSum(t, elms) + if wantSum, wantCount := 6, 3; sum != wantSum && count != wantCount { + t.Errorf("got sum %v, count %v, want sum %v, count %v", sum, count, wantSum, wantSum) + } + }) + + t.Run("readerThenData_multipleRecv", func(t *testing.T) { + ctx, client, c := setupClient(t) + + elms, err := c.OpenElementChan(ctx, dataID, instID, nil) + if err != nil { + t.Errorf("Unexpected error from OpenElementChan(%v, %v, nil): %v", dataID, instID, err) + } + client.Send(&fnpb.Elements{Data: []*fnpb.Elements_Data{dataElm(1, false)}}) + client.Send(&fnpb.Elements{Data: []*fnpb.Elements_Data{dataElm(2, false)}}) + client.Send(&fnpb.Elements{Data: []*fnpb.Elements_Data{dataElm(3, true)}}) + + sum, count := drainAndSum(t, elms) + if wantSum, wantCount := 6, 3; sum != wantSum && count != wantCount { + t.Errorf("got sum %v, count %v, want sum %v, count %v", sum, count, wantSum, wantSum) + } + }) + t.Run("readerThenDataAndTimers", func(t *testing.T) { + ctx, client, c := setupClient(t) + + elms, err := c.OpenElementChan(ctx, dataID, instID, []string{timerID}) + if err != nil { + t.Errorf("Unexpected error from OpenElementChan(%v, %v, nil): %v", dataID, instID, err) + } + client.Send(&fnpb.Elements{Data: []*fnpb.Elements_Data{dataElm(1, false)}}) + + client.Send(&fnpb.Elements{Timers: []*fnpb.Elements_Timers{timerElm(2, true)}}) + + client.Send(&fnpb.Elements{Data: []*fnpb.Elements_Data{dataElm(3, true)}}) + + sum, count := drainAndSum(t, elms) + if wantSum, wantCount := 6, 3; sum != wantSum && count != wantCount { + t.Errorf("got sum %v, count %v, want sum %v, count %v", sum, count, wantSum, wantSum) + } + }) + + t.Run("DataThenReaderThenLast", func(t *testing.T) { + ctx, client, c := setupClient(t) + client.Send(&fnpb.Elements{ + Data: []*fnpb.Elements_Data{ + dataElm(1, false), + dataElm(2, false), + }, + }) + elms, err := c.OpenElementChan(ctx, dataID, instID, nil) + if err != nil { + t.Errorf("Unexpected error from OpenElementChan(%v, %v, nil): %v", dataID, instID, err) + } + client.Send(&fnpb.Elements{Data: []*fnpb.Elements_Data{dataElm(3, true)}}) + + sum, count := drainAndSum(t, elms) + if wantSum, wantCount := 6, 3; sum != wantSum && count != wantCount { + t.Errorf("got sum %v, count %v, want sum %v, count %v", sum, count, wantSum, wantSum) + } + }) + + t.Run("AllDataThenReader", func(t *testing.T) { + ctx, client, c := setupClient(t) + + client.Send(&fnpb.Elements{ + Data: []*fnpb.Elements_Data{ + dataElm(1, false), + dataElm(2, false), + dataElm(3, true), + }, + }) + + elms, err := c.OpenElementChan(ctx, dataID, instID, nil) + if err != nil { + t.Errorf("Unexpected error from OpenElementChan(%v, %v, nil): %v", dataID, instID, err) + } + + sum, count := drainAndSum(t, elms) + if wantSum, wantCount := 6, 3; sum != wantSum && count != wantCount { + t.Errorf("got sum %v, count %v, want sum %v, count %v", sum, count, wantSum, wantSum) + } + }) + + t.Run("TimerThenReaderThenDataCloseThenLastTimer", func(t *testing.T) { + ctx, client, c := setupClient(t) + client.Send(&fnpb.Elements{ + Timers: []*fnpb.Elements_Timers{ + timerElm(1, false), + timerElm(2, false), + }, + }) + 
client.Send(&fnpb.Elements{Data: []*fnpb.Elements_Data{noDataElm()}}) + + elms, err := c.OpenElementChan(ctx, dataID, instID, []string{timerID}) + if err != nil { + t.Errorf("Unexpected error from OpenElementChan(%v, %v, nil): %v", dataID, instID, err) + } + var wg sync.WaitGroup + wg.Add(1) + var sum, count int + go func() { + defer wg.Done() + sum, count = drainAndSum(t, elms) + }() + + client.Send(&fnpb.Elements{Timers: []*fnpb.Elements_Timers{timerElm(3, true)}}) + + wg.Wait() + if wantSum, wantCount := 6, 3; sum != wantSum && count != wantCount { + t.Errorf("got sum %v, count %v, want sum %v, count %v", sum, count, wantSum, wantSum) + } + }) + + t.Run("AllTimerThenReaderThenDataClose", func(t *testing.T) { + ctx, client, c := setupClient(t) + client.Send(&fnpb.Elements{ + Timers: []*fnpb.Elements_Timers{ + timerElm(1, false), + timerElm(2, false), + timerElm(3, true), + }, + }) + + elms, err := c.OpenElementChan(ctx, dataID, instID, []string{timerID}) + if err != nil { + t.Errorf("Unexpected error from OpenElementChan(%v, %v, nil): %v", dataID, instID, err) + } + var wg sync.WaitGroup + wg.Add(1) + var sum, count int + go func() { + defer wg.Done() + sum, count = drainAndSum(t, elms) + }() + client.Send(&fnpb.Elements{Data: []*fnpb.Elements_Data{noDataElm()}}) + + wg.Wait() + if wantSum, wantCount := 6, 3; sum != wantSum && count != wantCount { + t.Errorf("got sum %v, count %v, want sum %v, count %v", sum, count, wantSum, wantSum) + } + }) + t.Run("NoTimersThenReaderThenNoData", func(t *testing.T) { + ctx, client, c := setupClient(t) + client.Send(&fnpb.Elements{Timers: []*fnpb.Elements_Timers{noTimerElm()}}) + + elms, err := c.OpenElementChan(ctx, dataID, instID, []string{timerID}) + if err != nil { + t.Errorf("Unexpected error from OpenElementChan(%v, %v, nil): %v", dataID, instID, err) + } + var wg sync.WaitGroup + wg.Add(1) + var sum, count int + go func() { + defer wg.Done() + sum, count = drainAndSum(t, elms) + }() + + client.Send(&fnpb.Elements{Data: []*fnpb.Elements_Data{noDataElm()}}) + + wg.Wait() + if wantSum, wantCount := 0, 0; sum != wantSum && count != wantCount { + t.Errorf("got sum %v, count %v, want sum %v, count %v", sum, count, wantSum, wantSum) + } + }) + t.Run("SomeTimersThenReaderThenAData", func(t *testing.T) { + ctx, client, c := setupClient(t) + client.Send(&fnpb.Elements{Timers: []*fnpb.Elements_Timers{timerElm(1, false), timerElm(2, true)}}) + + elms, err := c.OpenElementChan(ctx, dataID, instID, []string{timerID}) + if err != nil { + t.Errorf("Unexpected error from OpenElementChan(%v, %v, nil): %v", dataID, instID, err) + } + var wg sync.WaitGroup + wg.Add(1) + var sum, count int + go func() { + defer wg.Done() + sum, count = drainAndSum(t, elms) + }() + + client.Send(&fnpb.Elements{Data: []*fnpb.Elements_Data{dataElm(3, true)}}) + + wg.Wait() + if wantSum, wantCount := 6, 3; sum != wantSum && count != wantCount { + t.Errorf("got sum %v, count %v, want sum %v, count %v", sum, count, wantSum, wantSum) + } + }) + + t.Run("SomeTimersThenADataThenReader", func(t *testing.T) { + ctx, client, c := setupClient(t) + client.Send(&fnpb.Elements{Timers: []*fnpb.Elements_Timers{timerElm(1, false), timerElm(2, true)}}) + client.Send(&fnpb.Elements{Data: []*fnpb.Elements_Data{dataElm(3, true)}}) + + elms, err := c.OpenElementChan(ctx, dataID, instID, []string{timerID}) + if err != nil { + t.Errorf("Unexpected error from OpenElementChan(%v, %v, nil): %v", dataID, instID, err) + } + var wg sync.WaitGroup + wg.Add(1) + var sum, count int + go func() { + defer wg.Done() + 
sum, count = drainAndSum(t, elms) + }() + + wg.Wait() + if wantSum, wantCount := 6, 3; sum != wantSum && count != wantCount { + t.Errorf("got sum %v, count %v, want sum %v, count %v", sum, count, wantSum, wantSum) + } + }) +} + +func TestDataChannelTerminate_dataReader(t *testing.T) { + // The logging of channels closed is quite noisy for this test + log.SetOutput(io.Discard) + + expectedError := fmt.Errorf("EXPECTED ERROR") + + tests := []struct { + name string + expectedError error + caseFn func(t *testing.T, ch *elementsChan, client *fakeDataClient, c *DataChannel) + }{ + { + name: "onInstructionEnded", + expectedError: io.EOF, + caseFn: func(t *testing.T, ch *elementsChan, client *fakeDataClient, c *DataChannel) { + ch.InstructionEnded() + }, + }, { + name: "onSentinel", + expectedError: io.EOF, + caseFn: func(t *testing.T, ch *elementsChan, client *fakeDataClient, c *DataChannel) { + // fakeDataClient eventually returns a sentinel element. + }, + }, { + name: "onIsLast_withData", + expectedError: io.EOF, + caseFn: func(t *testing.T, ch *elementsChan, client *fakeDataClient, c *DataChannel) { + // Set the last call with data to use is_last. + client.isLastCall = 2 + }, + }, { + name: "onIsLast_withoutData", + expectedError: io.EOF, + caseFn: func(t *testing.T, ch *elementsChan, client *fakeDataClient, c *DataChannel) { + // Set the call without data to use is_last. + client.isLastCall = 3 + }, + }, { + name: "onRecvError", + expectedError: expectedError, + caseFn: func(t *testing.T, ch *elementsChan, client *fakeDataClient, c *DataChannel) { + // The SDK starts reading in a goroutine immeadiately after open. + // Set the 2nd Recv call to have an error. + client.err = expectedError + }, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + fmt.Println("TTTTT ", test.name) + done := make(chan bool, 1) + client := &fakeDataClient{t: t, done: done} + client.blocked.Lock() + ctx, cancelFn := context.WithCancel(context.Background()) + c := makeDataChannel(ctx, "id", client, cancelFn) + + elms, err := c.OpenElementChan(ctx, "ptr", "inst_ref", nil) + if err != nil { + t.Errorf("Unexpected error from OpenElementChan: %v", err) + } + ech := c.channels["inst_ref"] + + test.caseFn(t, ech, client, c) + client.blocked.Unlock() + // Drain channel + for range elms { + } + + // Verify that new readers return the same error on their reads after client.Recv is done. + if _, err = c.OpenElementChan(ctx, "ptr", "inst_ref", nil); !errors.Is(err, test.expectedError) { + t.Errorf("Unexpected error from read: got %v, want %v", err, test.expectedError) + } + + select { + case <-ctx.Done(): // Assert that the context must have been cancelled on read failures. 
+ return + case <-time.After(time.Second * 5): + t.Fatal("context wasn't cancelled") + } + }) + } +} + +func TestDataChannelRemoveInstruction_dataAfterClose(t *testing.T) { + done := make(chan bool, 1) + client := &fakeDataClient{t: t, done: done} + client.blocked.Lock() + + ctx, cancelFn := context.WithCancel(context.Background()) + c := makeDataChannel(ctx, "id", client, cancelFn) + c.removeInstruction("inst_ref") + + client.blocked.Unlock() + + _, err := c.OpenElementChan(ctx, "ptr", "inst_ref", nil) + if err != nil { + t.Errorf("Unexpected error from read: %v,", err) + } +} + +func TestDataChannelRemoveInstruction_limitInstructionCap(t *testing.T) { + done := make(chan bool, 1) + client := &fakeDataClient{t: t, done: done} + ctx, cancelFn := context.WithCancel(context.Background()) + c := makeDataChannel(ctx, "id", client, cancelFn) + + for i := 0; i < endedInstructionCap+10; i++ { + instID := instructionID(fmt.Sprintf("inst_ref%d", i)) + c.OpenElementChan(ctx, "ptr", instID, nil) + c.removeInstruction(instID) + } + if got, want := len(c.endedInstructions), endedInstructionCap; got != want { + t.Errorf("unexpected len(endedInstructions) got %v, want %v,", got, want) + } +} + func TestDataChannelTerminate_Writes(t *testing.T) { // The logging of channels closed is quite noisy for this test log.SetOutput(io.Discard) diff --git a/sdks/go/pkg/beam/core/runtime/harness/harness.go b/sdks/go/pkg/beam/core/runtime/harness/harness.go index c260a46c80ee..d0025d3d9195 100644 --- a/sdks/go/pkg/beam/core/runtime/harness/harness.go +++ b/sdks/go/pkg/beam/core/runtime/harness/harness.go @@ -408,7 +408,7 @@ func (c *control) handleInstruction(ctx context.Context, req *fnpb.InstructionRe sampler.stop() - data.Close() + dataError := data.Close() state.Close() c.cache.CompleteBundle(tokens...) @@ -422,6 +422,10 @@ func (c *control) handleInstruction(ctx context.Context, req *fnpb.InstructionRe // Mark the instruction as failed. if err != nil { c.failed[instID] = err + } else if dataError != io.EOF && dataError != nil { + // If there was an error on the data channel reads, fail this bundle + // since we may have had a short read. + c.failed[instID] = dataError } else { // Non failure plans should either be moved to the finalized state // or to plans so they can be re-used. @@ -578,7 +582,7 @@ func (c *control) handleInstruction(ctx context.Context, req *fnpb.InstructionRe } // Unsuccessful splits without errors indicate we should return an empty response, - // as processing can confinue. + // as processing can continue. if sr.Unsuccessful { return &fnpb.InstructionResponse{ InstructionId: string(instID),