From d4c28f0969554b39c34d21ae2ff9b0721be62999 Mon Sep 17 00:00:00 2001 From: lostluck <13907733+lostluck@users.noreply.github.com> Date: Mon, 13 Nov 2023 15:51:26 -0800 Subject: [PATCH] Garbage collect side inputs. --- .../pkg/beam/core/runtime/exec/translate.go | 2 +- sdks/go/pkg/beam/core/runtime/graphx/coder.go | 14 +- .../prism/internal/engine/elementmanager.go | 82 ++++++++-- .../internal/engine/elementmanager_test.go | 24 +-- .../beam/runners/prism/internal/execute.go | 16 +- .../runners/prism/internal/execute_test.go | 1 + .../beam/runners/prism/internal/preprocess.go | 5 +- .../pkg/beam/runners/prism/internal/stage.go | 153 +++++++----------- .../runners/prism/internal/worker/bundle.go | 10 +- .../runners/prism/internal/worker/worker.go | 33 +--- .../prism/internal/worker/worker_test.go | 10 +- 11 files changed, 174 insertions(+), 176 deletions(-) diff --git a/sdks/go/pkg/beam/core/runtime/exec/translate.go b/sdks/go/pkg/beam/core/runtime/exec/translate.go index 4f078092a310..6b3e3e457229 100644 --- a/sdks/go/pkg/beam/core/runtime/exec/translate.go +++ b/sdks/go/pkg/beam/core/runtime/exec/translate.go @@ -365,7 +365,7 @@ func (b *builder) makeCoderForPCollection(id string) (*coder.Coder, *coder.Windo } wc, err := b.coders.WindowCoder(ws.GetWindowCoderId()) if err != nil { - return nil, nil, err + return nil, nil, errors.Errorf("could not unmarshal window coder for pcollection %v: %w", id, err) } return c, wc, nil } diff --git a/sdks/go/pkg/beam/core/runtime/graphx/coder.go b/sdks/go/pkg/beam/core/runtime/graphx/coder.go index 34b44dd85920..87b3771e5756 100644 --- a/sdks/go/pkg/beam/core/runtime/graphx/coder.go +++ b/sdks/go/pkg/beam/core/runtime/graphx/coder.go @@ -162,7 +162,7 @@ func (b *CoderUnmarshaller) WindowCoder(id string) (*coder.WindowCoder, error) { c, err := b.peek(id) if err != nil { - return nil, err + return nil, errors.Errorf("could not unmarshal window coder: %w", err) } w, err := urnToWindowCoder(c.GetSpec().GetUrn()) @@ -218,7 +218,7 @@ func (b *CoderUnmarshaller) makeCoder(id string, c *pipepb.Coder) (*coder.Coder, id := components[1] elm, err := b.peek(id) if err != nil { - return nil, err + return nil, errors.Errorf("could not unmarshal kv coder value component: %w", err) } switch elm.GetSpec().GetUrn() { @@ -261,7 +261,7 @@ func (b *CoderUnmarshaller) makeCoder(id string, c *pipepb.Coder) (*coder.Coder, sub, err := b.peek(components[0]) if err != nil { - return nil, err + return nil, errors.Errorf("could not unmarshal length prefix coder component: %w", err) } // No payload means this coder was length prefixed by the runner @@ -307,7 +307,7 @@ func (b *CoderUnmarshaller) makeCoder(id string, c *pipepb.Coder) (*coder.Coder, } w, err := b.WindowCoder(components[1]) if err != nil { - return nil, err + return nil, errors.Errorf("could not unmarshal window coder: %w", err) } t := typex.New(typex.WindowedValueType, elm.T) wvc := &coder.Coder{Kind: coder.WindowedValue, T: t, Components: []*coder.Coder{elm}, Window: w} @@ -356,7 +356,7 @@ func (b *CoderUnmarshaller) makeCoder(id string, c *pipepb.Coder) (*coder.Coder, } w, err := b.WindowCoder(components[1]) if err != nil { - return nil, err + return nil, errors.Errorf("could not unmarshal window coder for timer: %w", err) } return coder.NewT(elm, w), nil case urnRowCoder: @@ -389,7 +389,7 @@ func (b *CoderUnmarshaller) makeCoder(id string, c *pipepb.Coder) (*coder.Coder, case urnGlobalWindow: w, err := b.WindowCoder(id) if err != nil { - return nil, err + return nil, errors.Errorf("could not unmarshal global window 
coder: %w", err) } return &coder.Coder{Kind: coder.Window, T: typex.New(reflect.TypeOf((*struct{})(nil)).Elem()), Window: w}, nil default: @@ -400,7 +400,7 @@ func (b *CoderUnmarshaller) makeCoder(id string, c *pipepb.Coder) (*coder.Coder, func (b *CoderUnmarshaller) peek(id string) (*pipepb.Coder, error) { c, ok := b.models[id] if !ok { - return nil, errors.Errorf("coder with id %v not found", id) + return nil, errors.Errorf("(peek) coder with id %v not found", id) } return c, nil } diff --git a/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go b/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go index df53bce8ac57..b6b8bae6031f 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go +++ b/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go @@ -120,7 +120,7 @@ type ElementManager struct { stages map[string]*stageState // The state for each stage. consumers map[string][]string // Map from pcollectionID to stageIDs that consumes them as primary input. - sideConsumers map[string][]string // Map from pcollectionID to stageIDs that consumes them as side input. + sideConsumers map[string][]LinkID // Map from pcollectionID to the stage+transform+input that consumes them as side input. pcolParents map[string]string // Map from pcollectionID to stageIDs that produce the pcollection. @@ -131,12 +131,17 @@ type ElementManager struct { pendingElements sync.WaitGroup // pendingElements counts all unprocessed elements in a job. Jobs with no pending elements terminate successfully. } +// LinkID represents a fully qualified input or output. +type LinkID struct { + Transform, Local, Global string +} + func NewElementManager(config Config) *ElementManager { return &ElementManager{ config: config, stages: map[string]*stageState{}, consumers: map[string][]string{}, - sideConsumers: map[string][]string{}, + sideConsumers: map[string][]LinkID{}, pcolParents: map[string]string{}, watermarkRefreshes: set[string]{}, inprogressBundles: set[string]{}, @@ -146,9 +151,9 @@ func NewElementManager(config Config) *ElementManager { // AddStage adds a stage to this element manager, connecting it's PCollections and // nodes to the watermark propagation graph. -func (em *ElementManager) AddStage(ID string, inputIDs, sides, outputIDs []string) { +func (em *ElementManager) AddStage(ID string, inputIDs, outputIDs []string, sides []LinkID) { slog.Debug("AddStage", slog.String("ID", ID), slog.Any("inputs", inputIDs), slog.Any("sides", sides), slog.Any("outputs", outputIDs)) - ss := makeStageState(ID, inputIDs, sides, outputIDs) + ss := makeStageState(ID, inputIDs, outputIDs, sides) em.stages[ss.ID] = ss for _, outputIDs := range ss.outputIDs { @@ -158,7 +163,9 @@ func (em *ElementManager) AddStage(ID string, inputIDs, sides, outputIDs []strin em.consumers[input] = append(em.consumers[input], ss.ID) } for _, side := range ss.sides { - em.sideConsumers[side] = append(em.sideConsumers[side], ss.ID) + // Note that we use the StageID as the global ID in the value since we need + // to be able to look up the consuming stage, from the global PCollectionID. 
+ em.sideConsumers[side.Global] = append(em.sideConsumers[side.Global], LinkID{Global: ss.ID, Local: side.Local, Transform: side.Transform}) } } @@ -363,6 +370,11 @@ func (em *ElementManager) PersistBundle(rb RunBundle, col2Coders map[string]PCol consumer := em.stages[sID] consumer.AddPending(newPending) } + sideConsumers := em.sideConsumers[output] + for _, link := range sideConsumers { + consumer := em.stages[link.Global] + consumer.AddPendingSide(newPending, link.Transform, link.Local) + } } // Return unprocessed to this stage's pending @@ -489,7 +501,7 @@ type stageState struct { ID string inputID string // PCollection ID of the parallel input outputIDs []string // PCollection IDs of outputs to update consumers. - sides []string // PCollection IDs of side inputs that can block execution. + sides []LinkID // PCollection IDs of side inputs that can block execution. // Special handling bits aggregate bool // whether this state needs to block for aggregation. @@ -501,12 +513,13 @@ type stageState struct { output mtime.Time // Output watermark for the whole stage estimatedOutput mtime.Time // Estimated watermark output from DoFns - pending elementHeap // pending input elements for this stage that are to be processesd - inprogress map[string]elements // inprogress elements by active bundles, keyed by bundle + pending elementHeap // pending input elements for this stage that are to be processesd + inprogress map[string]elements // inprogress elements by active bundles, keyed by bundle + sideInputs map[LinkID]map[typex.Window][][]byte // side input data for this stage, from {tid, inputID} -> window } // makeStageState produces an initialized stageState. -func makeStageState(ID string, inputIDs, sides, outputIDs []string) *stageState { +func makeStageState(ID string, inputIDs, outputIDs []string, sides []LinkID) *stageState { ss := &stageState{ ID: ID, outputIDs: outputIDs, @@ -536,6 +549,42 @@ func (ss *stageState) AddPending(newPending []element) { heap.Init(&ss.pending) } +// AddPendingSide adds elements to be consumed as side inputs. +func (ss *stageState) AddPendingSide(newPending []element, tID, inputID string) { + ss.mu.Lock() + defer ss.mu.Unlock() + if ss.sideInputs == nil { + ss.sideInputs = map[LinkID]map[typex.Window][][]byte{} + } + key := LinkID{Transform: tID, Local: inputID} + in, ok := ss.sideInputs[key] + if !ok { + in = map[typex.Window][][]byte{} + ss.sideInputs[key] = in + } + for _, e := range newPending { + in[e.window] = append(in[e.window], e.elmBytes) + } +} + +func (ss *stageState) GetSideData(tID, inputID string, watermark mtime.Time) map[typex.Window][][]byte { + ss.mu.Lock() + defer ss.mu.Unlock() + + d := ss.sideInputs[LinkID{Transform: tID, Local: inputID}] + ret := map[typex.Window][][]byte{} + for win, ds := range d { + if win.MaxTimestamp() <= watermark { + ret[win] = ds + } + } + return ret +} + +func (em *ElementManager) GetSideData(sID, tID, inputID string, watermark mtime.Time) map[typex.Window][][]byte { + return em.stages[sID].GetSideData(tID, inputID, watermark) +} + // updateUpstreamWatermark is for the parent of the input pcollection // to call, to update downstream stages with it's current watermark. // This avoids downstream stages inverting lock orderings from @@ -699,7 +748,18 @@ func (ss *stageState) updateWatermarks(minPending, minStateHold mtime.Time, em * } // Inform side input consumers, but don't update the upstream watermark. 
for _, sID := range em.sideConsumers[outputCol] { - refreshes.insert(sID) + refreshes.insert(sID.Global) + } + } + // Garbage collect state, timers and side inputs, for all windows + // that are before the new output watermark. + // They'll never be read in again. + for _, wins := range ss.sideInputs { + for win := range wins { + // Clear out anything we've already used. + if win.MaxTimestamp() < newOut { + delete(wins, win) + } } } } @@ -725,7 +785,7 @@ func (ss *stageState) bundleReady(em *ElementManager) (mtime.Time, bool) { } ready := true for _, side := range ss.sides { - pID, ok := em.pcolParents[side] + pID, ok := em.pcolParents[side.Global] if !ok { panic(fmt.Sprintf("stage[%v] no parent ID for side input %v", ss.ID, side)) } diff --git a/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager_test.go b/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager_test.go index ddfdd5b8816c..0005ca8ed881 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager_test.go +++ b/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager_test.go @@ -61,7 +61,7 @@ func TestElementHeap(t *testing.T) { func TestStageState_minPendingTimestamp(t *testing.T) { newState := func() *stageState { - return makeStageState("test", []string{"testInput"}, nil, []string{"testOutput"}) + return makeStageState("test", []string{"testInput"}, []string{"testOutput"}, nil) } t.Run("noElements", func(t *testing.T) { ss := newState() @@ -188,13 +188,13 @@ func TestStageState_minPendingTimestamp(t *testing.T) { } func TestStageState_UpstreamWatermark(t *testing.T) { - impulse := makeStageState("impulse", nil, nil, []string{"output"}) + impulse := makeStageState("impulse", nil, []string{"output"}, nil) _, up := impulse.UpstreamWatermark() if got, want := up, mtime.MaxTimestamp; got != want { t.Errorf("impulse.UpstreamWatermark() = %v, want %v", got, want) } - dofn := makeStageState("dofn", []string{"input"}, nil, []string{"output"}) + dofn := makeStageState("dofn", []string{"input"}, []string{"output"}, nil) dofn.updateUpstreamWatermark("input", 42) _, up = dofn.UpstreamWatermark() @@ -202,7 +202,7 @@ func TestStageState_UpstreamWatermark(t *testing.T) { t.Errorf("dofn.UpstreamWatermark() = %v, want %v", got, want) } - flatten := makeStageState("flatten", []string{"a", "b", "c"}, nil, []string{"output"}) + flatten := makeStageState("flatten", []string{"a", "b", "c"}, []string{"output"}, nil) flatten.updateUpstreamWatermark("a", 50) flatten.updateUpstreamWatermark("b", 42) flatten.updateUpstreamWatermark("c", 101) @@ -216,7 +216,7 @@ func TestStageState_updateWatermarks(t *testing.T) { inputCol := "testInput" outputCol := "testOutput" newState := func() (*stageState, *stageState, *ElementManager) { - underTest := makeStageState("underTest", []string{inputCol}, nil, []string{outputCol}) + underTest := makeStageState("underTest", []string{inputCol}, []string{outputCol}, nil) outStage := makeStageState("outStage", []string{outputCol}, nil, nil) em := &ElementManager{ consumers: map[string][]string{ @@ -315,7 +315,7 @@ func TestStageState_updateWatermarks(t *testing.T) { func TestElementManager(t *testing.T) { t.Run("impulse", func(t *testing.T) { em := NewElementManager(Config{}) - em.AddStage("impulse", nil, nil, []string{"output"}) + em.AddStage("impulse", nil, []string{"output"}, nil) em.AddStage("dofn", []string{"output"}, nil, nil) em.Impulse("impulse") @@ -370,8 +370,8 @@ func TestElementManager(t *testing.T) { t.Run("dofn", func(t *testing.T) { em := NewElementManager(Config{}) - 
em.AddStage("impulse", nil, nil, []string{"input"}) - em.AddStage("dofn1", []string{"input"}, nil, []string{"output"}) + em.AddStage("impulse", nil, []string{"input"}, nil) + em.AddStage("dofn1", []string{"input"}, []string{"output"}, nil) em.AddStage("dofn2", []string{"output"}, nil, nil) em.Impulse("impulse") @@ -421,9 +421,9 @@ func TestElementManager(t *testing.T) { t.Run("side", func(t *testing.T) { em := NewElementManager(Config{}) - em.AddStage("impulse", nil, nil, []string{"input"}) - em.AddStage("dofn1", []string{"input"}, nil, []string{"output"}) - em.AddStage("dofn2", []string{"input"}, []string{"output"}, nil) + em.AddStage("impulse", nil, []string{"input"}, nil) + em.AddStage("dofn1", []string{"input"}, []string{"output"}, nil) + em.AddStage("dofn2", []string{"input"}, nil, []LinkID{{Transform: "dofn2", Global: "output", Local: "local"}}) em.Impulse("impulse") var i int @@ -472,7 +472,7 @@ func TestElementManager(t *testing.T) { }) t.Run("residual", func(t *testing.T) { em := NewElementManager(Config{}) - em.AddStage("impulse", nil, nil, []string{"input"}) + em.AddStage("impulse", nil, []string{"input"}, nil) em.AddStage("dofn", []string{"input"}, nil, nil) em.Impulse("impulse") diff --git a/sdks/go/pkg/beam/runners/prism/internal/execute.go b/sdks/go/pkg/beam/runners/prism/internal/execute.go index 5e07e161dd5c..6b5ec872de17 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/execute.go +++ b/sdks/go/pkg/beam/runners/prism/internal/execute.go @@ -155,10 +155,6 @@ func executePipeline(ctx context.Context, wks map[string]*worker.W, j *jobservic stages := map[string]*stage{} var impulses []string - // Inialize the "dataservice cache" to support side inputs. - // TODO(https://github.com/apache/beam/issues/28543), remove this concept. - ds := &worker.DataService{} - for i, stage := range topo { tid := stage.transforms[0] t := ts[tid] @@ -206,7 +202,7 @@ func executePipeline(ctx context.Context, wks map[string]*worker.W, j *jobservic switch urn { case urns.TransformGBK: - em.AddStage(stage.ID, []string{getOnlyValue(t.GetInputs())}, nil, []string{getOnlyValue(t.GetOutputs())}) + em.AddStage(stage.ID, []string{getOnlyValue(t.GetInputs())}, []string{getOnlyValue(t.GetOutputs())}, nil) for _, global := range t.GetInputs() { col := comps.GetPcollections()[global] ed := collectionPullDecoder(col.GetCoderId(), coders, comps) @@ -221,22 +217,22 @@ func executePipeline(ctx context.Context, wks map[string]*worker.W, j *jobservic em.StageAggregates(stage.ID) case urns.TransformImpulse: impulses = append(impulses, stage.ID) - em.AddStage(stage.ID, nil, nil, []string{getOnlyValue(t.GetOutputs())}) + em.AddStage(stage.ID, nil, []string{getOnlyValue(t.GetOutputs())}, nil) case urns.TransformFlatten: inputs := maps.Values(t.GetInputs()) sort.Strings(inputs) - em.AddStage(stage.ID, inputs, nil, []string{getOnlyValue(t.GetOutputs())}) + em.AddStage(stage.ID, inputs, []string{getOnlyValue(t.GetOutputs())}, nil) } stages[stage.ID] = stage case wk.Env: - if err := buildDescriptor(stage, comps, wk, ds); err != nil { + if err := buildDescriptor(stage, comps, wk, em); err != nil { return fmt.Errorf("prism error building stage %v: \n%w", stage.ID, err) } stages[stage.ID] = stage slog.Debug("pipelineBuild", slog.Group("stage", slog.String("ID", stage.ID), slog.String("transformName", t.GetUniqueName()))) outputs := maps.Keys(stage.OutputsToCoders) sort.Strings(outputs) - em.AddStage(stage.ID, []string{stage.primaryInput}, stage.sides, outputs) + em.AddStage(stage.ID, []string{stage.primaryInput}, outputs, 
stage.sideInputs) default: err := fmt.Errorf("unknown environment[%v]", t.GetEnvironmentId()) slog.Error("Execute", err) @@ -273,7 +269,7 @@ func executePipeline(ctx context.Context, wks map[string]*worker.W, j *jobservic defer func() { <-maxParallelism }() s := stages[rb.StageID] wk := wks[s.envID] - if err := s.Execute(ctx, j, wk, ds, comps, em, rb); err != nil { + if err := s.Execute(ctx, j, wk, comps, em, rb); err != nil { // Ensure we clean up on bundle failure em.FailBundle(rb) bundleFailed <- err diff --git a/sdks/go/pkg/beam/runners/prism/internal/execute_test.go b/sdks/go/pkg/beam/runners/prism/internal/execute_test.go index ce821bef8985..fe3da83c67e2 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/execute_test.go +++ b/sdks/go/pkg/beam/runners/prism/internal/execute_test.go @@ -65,6 +65,7 @@ func execute(ctx context.Context, p *beam.Pipeline) (beam.PipelineResult, error) } func executeWithT(ctx context.Context, t testing.TB, p *beam.Pipeline) (beam.PipelineResult, error) { + t.Helper() t.Log("startingTest - ", t.Name()) s1 := rand.NewSource(time.Now().UnixNano()) r1 := rand.New(s1) diff --git a/sdks/go/pkg/beam/runners/prism/internal/preprocess.go b/sdks/go/pkg/beam/runners/prism/internal/preprocess.go index fb244cb4fbbb..7d7fcf93a9d0 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/preprocess.go +++ b/sdks/go/pkg/beam/runners/prism/internal/preprocess.go @@ -21,6 +21,7 @@ import ( "github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/pipelinex" pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1" + "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/engine" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/urns" "golang.org/x/exp/maps" "golang.org/x/exp/slices" @@ -392,7 +393,7 @@ func prepareStage(stg *stage, comps *pipepb.Components, pipelineFacts fusionFact // Now we can see which consumers (inputs) aren't covered by the producers (outputs). mainInputs := map[string]string{} - var sideInputs []link + var sideInputs []engine.LinkID inputs := map[string]bool{} for pid, plinks := range stageFacts.pcolConsumers { // Check if this PCollection is generated in this bundle. @@ -406,7 +407,7 @@ func prepareStage(stg *stage, comps *pipepb.Components, pipelineFacts fusionFact t := comps.GetTransforms()[link.transform] sis, _ := getSideInputs(t) if _, ok := sis[link.local]; ok { - sideInputs = append(sideInputs, link) + sideInputs = append(sideInputs, engine.LinkID{Transform: link.transform, Global: link.global, Local: link.local}) } else { mainInputs[link.global] = link.global } diff --git a/sdks/go/pkg/beam/runners/prism/internal/stage.go b/sdks/go/pkg/beam/runners/prism/internal/stage.go index 4925405bb4ef..c8439ad3bdad 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/stage.go +++ b/sdks/go/pkg/beam/runners/prism/internal/stage.go @@ -19,11 +19,9 @@ import ( "bytes" "context" "fmt" - "io" "time" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/mtime" - "github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/exec" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex" fnpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/fnexecution_v1" pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1" @@ -58,24 +56,23 @@ type link struct { type stage struct { ID string transforms []string - primaryInput string // PCollection used as the parallel input. - outputs []link // PCollections that must escape this stage. 
- sideInputs []link // Non-parallel input PCollections and their consumers - internalCols []string // PCollections that escape. Used for precise coder sending. + primaryInput string // PCollection used as the parallel input. + outputs []link // PCollections that must escape this stage. + sideInputs []engine.LinkID // Non-parallel input PCollections and their consumers + internalCols []string // PCollections that escape. Used for precise coder sending. envID string exe transformExecuter inputTransformID string inputInfo engine.PColInfo desc *fnpb.ProcessBundleDescriptor - sides []string - prepareSides func(b *worker.B, tid string, watermark mtime.Time) + prepareSides func(b *worker.B, watermark mtime.Time) SinkToPCollection map[string]string OutputsToCoders map[string]engine.PColInfo } -func (s *stage) Execute(ctx context.Context, j *jobservices.Job, wk *worker.W, ds *worker.DataService, comps *pipepb.Components, em *engine.ElementManager, rb engine.RunBundle) error { +func (s *stage) Execute(ctx context.Context, j *jobservices.Job, wk *worker.W, comps *pipepb.Components, em *engine.ElementManager, rb engine.RunBundle) error { slog.Debug("Execute: starting bundle", "bundle", rb) var b *worker.B @@ -113,7 +110,7 @@ func (s *stage) Execute(ctx context.Context, j *jobservices.Job, wk *worker.W, d } b.Init() - s.prepareSides(b, s.transforms[0], rb.Watermark) + s.prepareSides(b, rb.Watermark) slog.Debug("Execute: processing", "bundle", rb) defer b.Cleanup(wk) @@ -215,8 +212,6 @@ progress: md := wk.MonitoringMetadata(ctx, unknownIDs) j.AddMetricShortIDs(md) } - // TODO(https://github.com/apache/beam/issues/28543) handle side input data properly. - ds.Commit(b.OutputData) var residualData [][]byte var minOutputWatermark map[string]mtime.Time for _, rr := range resp.GetResidualRoots() { @@ -281,7 +276,7 @@ func portFor(wInCid string, wk *worker.W) []byte { // It assumes that the side inputs are not sourced from PCollections generated by any transform in this stage. // // Because we need the local ids for routing the sources/sinks information. -func buildDescriptor(stg *stage, comps *pipepb.Components, wk *worker.W, ds *worker.DataService) error { +func buildDescriptor(stg *stage, comps *pipepb.Components, wk *worker.W, em *engine.ElementManager) error { // Assume stage has an indicated primary input coders := map[string]*pipepb.Coder{} @@ -316,21 +311,17 @@ func buildDescriptor(stg *stage, comps *pipepb.Components, wk *worker.W, ds *wor transforms[sinkID] = sinkTransform(sinkID, portFor(wOutCid, wk), o.global) } - // Then lets do Side Inputs, since they are also uniform. - var sides []string var prepareSides []func(b *worker.B, watermark mtime.Time) for _, si := range stg.sideInputs { - col := comps.GetPcollections()[si.global] + col := comps.GetPcollections()[si.Global] oCID := col.GetCoderId() nCID, err := lpUnknownCoders(oCID, coders, comps.GetCoders()) if err != nil { - return fmt.Errorf("buildDescriptor: failed to handle coder on stage %v for side input %+v, pcol %q %v:\n%w", stg.ID, si, si.global, prototext.Format(col), err) + return fmt.Errorf("buildDescriptor: failed to handle coder on stage %v for side input %+v, pcol %q %v:\n%w", stg.ID, si, si.Global, prototext.Format(col), err) } - - sides = append(sides, si.global) if oCID != nCID { // Add a synthetic PCollection set with the new coder. 
- newGlobal := si.global + "_prismside" + newGlobal := si.Global + "_prismside" comps.GetPcollections()[newGlobal] = &pipepb.PCollection{ DisplayData: col.GetDisplayData(), UniqueName: col.GetUniqueName(), @@ -339,11 +330,11 @@ func buildDescriptor(stg *stage, comps *pipepb.Components, wk *worker.W, ds *wor WindowingStrategyId: col.WindowingStrategyId, } // Update side inputs to point to new PCollection with any replaced coders. - transforms[si.transform].GetInputs()[si.local] = newGlobal + transforms[si.Transform].GetInputs()[si.Local] = newGlobal } - prepSide, err := handleSideInput(si.transform, si.local, si.global, comps, coders, ds) + prepSide, err := handleSideInput(si, comps, coders, em) if err != nil { - slog.Error("buildDescriptor: handleSideInputs", err, slog.String("transformID", si.transform)) + slog.Error("buildDescriptor: handleSideInputs", err, slog.String("transformID", si.Transform)) return err } prepareSides = append(prepareSides, prepSide) @@ -391,12 +382,11 @@ func buildDescriptor(stg *stage, comps *pipepb.Components, wk *worker.W, ds *wor } stg.desc = desc - stg.prepareSides = func(b *worker.B, _ string, watermark mtime.Time) { + stg.prepareSides = func(b *worker.B, watermark mtime.Time) { for _, prep := range prepareSides { prep(b, watermark) } } - stg.sides = sides // List of the global pcollection IDs this stage needs to wait on for side inputs. stg.SinkToPCollection = sink2Col stg.OutputsToCoders = col2Coders stg.inputInfo = inputInfo @@ -406,48 +396,45 @@ func buildDescriptor(stg *stage, comps *pipepb.Components, wk *worker.W, ds *wor } // handleSideInput returns a closure that will look up the data for a side input appropriate for the given watermark. -func handleSideInput(tid, local, global string, comps *pipepb.Components, coders map[string]*pipepb.Coder, ds *worker.DataService) (func(b *worker.B, watermark mtime.Time), error) { - t := comps.GetTransforms()[tid] +func handleSideInput(link engine.LinkID, comps *pipepb.Components, coders map[string]*pipepb.Coder, em *engine.ElementManager) (func(b *worker.B, watermark mtime.Time), error) { + t := comps.GetTransforms()[link.Transform] sis, err := getSideInputs(t) if err != nil { return nil, err } - switch si := sis[local]; si.GetAccessPattern().GetUrn() { + switch si := sis[link.Local]; si.GetAccessPattern().GetUrn() { case urns.SideInputIterable: slog.Debug("urnSideInputIterable", slog.String("sourceTransform", t.GetUniqueName()), - slog.String("local", local), - slog.String("global", global)) - col := comps.GetPcollections()[global] - ed := collectionPullDecoder(col.GetCoderId(), coders, comps) - wDec, wEnc := getWindowValueCoders(comps, col, coders) - // May be of zero length, but that's OK. Side inputs can be empty. + slog.String("local", link.Local), + slog.String("global", link.Global)) - global, local := global, local - return func(b *worker.B, watermark mtime.Time) { - data := ds.GetAllData(global) + col := comps.GetPcollections()[link.Global] + // The returned coders are unused here, but they add the side input coders + // to the stage components for use SDK side. + collectionPullDecoder(col.GetCoderId(), coders, comps) + getWindowValueCoders(comps, col, coders) + // May be of zero length, but that's OK. Side inputs can be emp + return func(b *worker.B, watermark mtime.Time) { + // May be of zero length, but that's OK. Side inputs can be empty. 
+ data := em.GetSideData(b.PBDID, link.Transform, link.Local, watermark) if b.IterableSideInputData == nil { - b.IterableSideInputData = map[string]map[string]map[typex.Window][][]byte{} + b.IterableSideInputData = map[worker.SideInputKey]map[typex.Window][][]byte{} } - if _, ok := b.IterableSideInputData[tid]; !ok { - b.IterableSideInputData[tid] = map[string]map[typex.Window][][]byte{} - } - b.IterableSideInputData[tid][local] = collateByWindows(data, watermark, wDec, wEnc, - func(r io.Reader) [][]byte { - return [][]byte{ed(r)} - }, func(a, b [][]byte) [][]byte { - return append(a, b...) - }) + b.IterableSideInputData[worker.SideInputKey{ + TransformID: link.Transform, + Local: link.Local, + }] = data }, nil case urns.SideInputMultiMap: slog.Debug("urnSideInputMultiMap", slog.String("sourceTransform", t.GetUniqueName()), - slog.String("local", local), - slog.String("global", global)) - col := comps.GetPcollections()[global] + slog.String("local", link.Local), + slog.String("global", link.Global)) + col := comps.GetPcollections()[link.Global] kvc := comps.GetCoders()[col.GetCoderId()] if kvc.GetSpec().GetUrn() != urns.CoderKV { @@ -456,36 +443,37 @@ func handleSideInput(tid, local, global string, comps *pipepb.Components, coders kd := collectionPullDecoder(kvc.GetComponentCoderIds()[0], coders, comps) vd := collectionPullDecoder(kvc.GetComponentCoderIds()[1], coders, comps) - wDec, wEnc := getWindowValueCoders(comps, col, coders) - global, local := global, local + // The returned coders are unused here, but they add the side input coders + // to the stage components for use SDK side. + getWindowValueCoders(comps, col, coders) return func(b *worker.B, watermark mtime.Time) { // May be of zero length, but that's OK. Side inputs can be empty. - data := ds.GetAllData(global) + data := em.GetSideData(b.PBDID, link.Transform, link.Local, watermark) if b.MultiMapSideInputData == nil { - b.MultiMapSideInputData = map[string]map[string]map[typex.Window]map[string][][]byte{} - } - if _, ok := b.MultiMapSideInputData[tid]; !ok { - b.MultiMapSideInputData[tid] = map[string]map[typex.Window]map[string][][]byte{} + b.MultiMapSideInputData = map[worker.SideInputKey]map[typex.Window]map[string][][]byte{} } - b.MultiMapSideInputData[tid][local] = collateByWindows(data, watermark, wDec, wEnc, - func(r io.Reader) map[string][][]byte { + + windowed := map[typex.Window]map[string][][]byte{} + for win, ds := range data { + if len(ds) == 0 { + continue + } + byKey := map[string][][]byte{} + for _, datum := range ds { + r := bytes.NewBuffer(datum) kb := kd(r) - return map[string][][]byte{ - string(kb): {vd(r)}, - } - }, func(a, b map[string][][]byte) map[string][][]byte { - if len(a) == 0 { - return b - } - for k, vs := range b { - a[k] = append(a[k], vs...) - } - return a - }) + byKey[string(kb)] = append(byKey[string(kb)], vd(r)) + } + windowed[win] = byKey + } + b.MultiMapSideInputData[worker.SideInputKey{ + TransformID: link.Transform, + Local: link.Local, + }] = windowed }, nil default: - return nil, fmt.Errorf("local input %v (global %v) uses accesspattern %v", local, global, si.GetAccessPattern().GetUrn()) + return nil, fmt.Errorf("local input %v (global %v) uses accesspattern %v", link.Local, link.Global, si.GetAccessPattern().GetUrn()) } } @@ -516,24 +504,3 @@ func sinkTransform(sinkID string, sinkPortBytes []byte, inPID string) *pipepb.PT } return source } - -// collateByWindows takes the data and collates them into window keyed maps. -// Uses generics to consolidate the repetitive window loops. 
-func collateByWindows[T any](data [][]byte, watermark mtime.Time, wDec exec.WindowDecoder, wEnc exec.WindowEncoder, ed func(io.Reader) T, join func(T, T) T) map[typex.Window]T { - windowed := map[typex.Window]T{} - for _, datum := range data { - inBuf := bytes.NewBuffer(datum) - for { - ws, _, _, err := exec.DecodeWindowedValueHeader(wDec, inBuf) - if err == io.EOF { - break - } - // Get the element out, and window them properly. - e := ed(inBuf) - for _, w := range ws { - windowed[w] = join(windowed[w], e) - } - } - } - return windowed -} diff --git a/sdks/go/pkg/beam/runners/prism/internal/worker/bundle.go b/sdks/go/pkg/beam/runners/prism/internal/worker/bundle.go index fab8cbc141f0..97250092940d 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/worker/bundle.go +++ b/sdks/go/pkg/beam/runners/prism/internal/worker/bundle.go @@ -27,6 +27,11 @@ import ( "golang.org/x/exp/slog" ) +// SideInputKey is for data lookups for a given bundle. +type SideInputKey struct { + TransformID, Local string +} + // B represents an extant ProcessBundle instruction sent to an SDK worker. // Generally manipulated by another package to interact with a worker. type B struct { @@ -37,11 +42,10 @@ type B struct { InputTransformID string InputData [][]byte // Data specifically for this bundle. - // TODO change to a single map[tid] -> map[input] -> map[window] -> struct { Iter data, MultiMap data } instead of all maps. // IterableSideInputData is a map from transformID, to inputID, to window, to data. - IterableSideInputData map[string]map[string]map[typex.Window][][]byte + IterableSideInputData map[SideInputKey]map[typex.Window][][]byte // MultiMapSideInputData is a map from transformID, to inputID, to window, to data key, to data values. - MultiMapSideInputData map[string]map[string]map[typex.Window]map[string][][]byte + MultiMapSideInputData map[SideInputKey]map[typex.Window]map[string][][]byte // OutputCount is the number of data or timer outputs this bundle has. // We need to see this many closed data channels before the bundle is complete. 
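For reference, the storage model the hunks above introduce works like this: side input bytes are held per stage, keyed by the consuming transform and local input name (engine.LinkID) and by window; GetSideData only hands a bundle the windows whose max timestamp is at or below the bundle's watermark; and once the stage's output watermark passes a window's max timestamp, that window's data is deleted because it can never be read again. The standalone sketch below illustrates that add / read / garbage-collect lifecycle. It deliberately uses simplified stand-in types (window, linkID, sideState, int64 millisecond timestamps) rather than the actual prism internals (typex.Window, engine.LinkID, stageState, mtime.Time), so it is an approximation of the behavior, not the patch's code.

package main

import "fmt"

// Simplified stand-ins for the patch's types: timestamps are int64 millis,
// a window is an interval with a MaxTimestamp, and linkID mirrors the
// engine.LinkID introduced above (Transform, Local, Global).
type window struct{ maxTimestamp int64 }

func (w window) MaxTimestamp() int64 { return w.maxTimestamp }

type linkID struct{ Transform, Local, Global string }

// sideState mirrors the role of stageState.sideInputs: data per consuming
// transform+input, per window.
type sideState struct {
	sideInputs map[linkID]map[window][][]byte
}

// addPendingSide mirrors AddPendingSide: append element bytes under the
// element's window for the consuming transform+input.
func (s *sideState) addPendingSide(elms map[window][][]byte, tID, inputID string) {
	if s.sideInputs == nil {
		s.sideInputs = map[linkID]map[window][][]byte{}
	}
	key := linkID{Transform: tID, Local: inputID}
	in, ok := s.sideInputs[key]
	if !ok {
		in = map[window][][]byte{}
		s.sideInputs[key] = in
	}
	for w, ds := range elms {
		in[w] = append(in[w], ds...)
	}
}

// getSideData mirrors GetSideData: only windows that have closed relative to
// the requesting bundle's watermark are visible to the bundle.
func (s *sideState) getSideData(tID, inputID string, watermark int64) map[window][][]byte {
	ret := map[window][][]byte{}
	for w, ds := range s.sideInputs[linkID{Transform: tID, Local: inputID}] {
		if w.MaxTimestamp() <= watermark {
			ret[w] = ds
		}
	}
	return ret
}

// gc mirrors the new loop in updateWatermarks: once the output watermark has
// passed a window's max timestamp, its side input data is dropped.
func (s *sideState) gc(newOut int64) {
	for _, wins := range s.sideInputs {
		for w := range wins {
			if w.MaxTimestamp() < newOut {
				delete(wins, w)
			}
		}
	}
}

func main() {
	ss := &sideState{}
	w1, w2 := window{maxTimestamp: 100}, window{maxTimestamp: 200}
	ss.addPendingSide(map[window][][]byte{w1: {[]byte("a")}, w2: {[]byte("b")}}, "dofn2", "local")

	// Only w1 has closed at watermark 150, so only its data is served.
	fmt.Println(len(ss.getSideData("dofn2", "local", 150))) // 1

	// Advancing the output watermark past w1 garbage collects it; w2 remains.
	ss.gc(150)
	fmt.Println(len(ss.sideInputs[linkID{Transform: "dofn2", Local: "local"}])) // 1
}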
diff --git a/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go b/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go index 4968c9eb433e..beee5e896ffc 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go +++ b/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go @@ -36,7 +36,6 @@ import ( "github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex" fnpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/fnexecution_v1" pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1" - "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/engine" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/urns" "golang.org/x/exp/slog" "google.golang.org/grpc" @@ -442,7 +441,7 @@ func (wk *W) State(state fnpb.BeamFnState_StateServer) error { panic(fmt.Sprintf("error decoding iterable side input window key %v: %v", wKey, err)) } } - winMap := b.IterableSideInputData[ikey.GetTransformId()][ikey.GetSideInputId()] + winMap := b.IterableSideInputData[SideInputKey{TransformID: ikey.GetTransformId(), Local: ikey.GetSideInputId()}] var wins []typex.Window for w := range winMap { wins = append(wins, w) @@ -463,7 +462,7 @@ func (wk *W) State(state fnpb.BeamFnState_StateServer) error { } } dKey := mmkey.GetKey() - winMap := b.MultiMapSideInputData[mmkey.GetTransformId()][mmkey.GetSideInputId()] + winMap := b.MultiMapSideInputData[SideInputKey{TransformID: mmkey.GetTransformId(), Local: mmkey.GetSideInputId()}] var wins []typex.Window for w := range winMap { wins = append(wins, w) @@ -562,31 +561,3 @@ func (wk *W) MonitoringMetadata(ctx context.Context, unknownIDs []string) *fnpb. }, }).GetMonitoringInfos() } - -// DataService is slated to be deleted in favour of stage based state -// management for side inputs. -// TODO(https://github.com/apache/beam/issues/28543), remove this concept. -type DataService struct { - mu sync.Mutex - // TODO actually quick process the data to windows here as well. - raw map[string][][]byte -} - -// Commit tentative data to the datastore. -func (d *DataService) Commit(tent engine.TentativeData) { - d.mu.Lock() - defer d.mu.Unlock() - if d.raw == nil { - d.raw = map[string][][]byte{} - } - for colID, data := range tent.Raw { - d.raw[colID] = append(d.raw[colID], data...) - } -} - -// GetAllData is a hack for Side Inputs until watermarks are sorted out. -func (d *DataService) GetAllData(colID string) [][]byte { - d.mu.Lock() - defer d.mu.Unlock() - return d.raw[colID] -} diff --git a/sdks/go/pkg/beam/runners/prism/internal/worker/worker_test.go b/sdks/go/pkg/beam/runners/prism/internal/worker/worker_test.go index 6a90b463c45d..c45d33016832 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/worker/worker_test.go +++ b/sdks/go/pkg/beam/runners/prism/internal/worker/worker_test.go @@ -233,12 +233,10 @@ func TestWorker_State_Iterable(t *testing.T) { instID := wk.NextInst() wk.activeInstructions[instID] = &B{ - IterableSideInputData: map[string]map[string]map[typex.Window][][]byte{ - "transformID": { - "i1": { - window.GlobalWindow{}: [][]byte{ - {42}, - }, + IterableSideInputData: map[SideInputKey]map[typex.Window][][]byte{ + {TransformID: "transformID", Local: "i1"}: { + window.GlobalWindow{}: [][]byte{ + {42}, }, }, },