Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions sdks/go/pkg/beam/runners/prism/internal/handlepardo.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ func (h *pardo) PrepareTransform(tid string, t *pipepb.PTransform, comps *pipepb
ckvERSID: coder(urns.CoderKV, ckvERID, cSID),
}

// PCollections only have two new ones.
// There are only two new PCollections.
// INPUT -> same as ordinary DoFn
// PWR, uses ckvER
// SPLITnSIZED, uses ckvERS
Expand All @@ -201,23 +201,27 @@ func (h *pardo) PrepareTransform(tid string, t *pipepb.PTransform, comps *pipepb
nSPLITnSIZEDID: pcol(nSPLITnSIZEDID, ckvERSID),
}

// PTransforms have 3 new ones, with process sized elements and restrictions
// There are 3 new PTransforms, with process sized elements and restrictions
// taking the brunt of the complexity, consuming the inputs

ePWRID := "e" + tid + "_pwr"
eSPLITnSIZEDID := "e" + tid + "_splitnsize"
eProcessID := "e" + tid + "_processandsplit"

tform := func(name, urn, in, out string) *pipepb.PTransform {
// Apparently we also send side inputs to PairWithRestriction
// and SplitAndSize. We should consider whether we could simply
// drop the side inputs from the ParDo payload instead, which
// could lead to an additional fusion opportunity.
newInputs := maps.Clone(t.GetInputs())
newInputs[inputLocalID] = in
return &pipepb.PTransform{
UniqueName: name,
Spec: &pipepb.FunctionSpec{
Urn: urn,
Payload: pardoPayload,
},
Inputs: map[string]string{
inputLocalID: in,
},
Inputs: newInputs,
Outputs: map[string]string{
"i0": out,
},
Expand Down
16 changes: 14 additions & 2 deletions sdks/go/pkg/beam/runners/prism/internal/preprocess.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package internal
import (
"fmt"
"sort"
"strings"

"github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/pipelinex"
pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1"
Expand Down Expand Up @@ -438,7 +439,8 @@ func finalizeStage(stg *stage, comps *pipepb.Components, pipelineFacts *fusionFa
t := comps.GetTransforms()[link.Transform]

var sis map[string]*pipepb.SideInput
if t.GetSpec().GetUrn() == urns.TransformParDo {
switch t.GetSpec().GetUrn() {
case urns.TransformParDo, urns.TransformProcessSizedElements, urns.TransformPairWithRestriction, urns.TransformSplitAndSize, urns.TransformTruncate:
pardo := &pipepb.ParDoPayload{}
if err := (proto.UnmarshalOptions{}).Unmarshal(t.GetSpec().GetPayload(), pardo); err != nil {
return fmt.Errorf("unable to decode ParDoPayload for %v", link.Transform)
Expand Down Expand Up @@ -485,7 +487,17 @@ func finalizeStage(stg *stage, comps *pipepb.Components, pipelineFacts *fusionFa
// Quick check that this is led by a flatten node, and that it's handled runner side.
t := comps.GetTransforms()[stg.transforms[0]]
if !(t.GetSpec().GetUrn() == urns.TransformFlatten && t.GetEnvironmentId() == "") {
return fmt.Errorf("expected runner flatten node, but wasn't: %v -- %v", stg.transforms, mainInputs)
formatMap := func(in map[string]string) string {
var b strings.Builder
for k, v := range in {
b.WriteString(k)
b.WriteString(" : ")
b.WriteString(v)
b.WriteString("\n\t")
}
return b.String()
}
return fmt.Errorf("stage requires multiple parallel inputs but wasn't a flatten:\n\ttransforms\n\t%v\n\tmain inputs\n\t%v\n\tsidinputs\n\t%v", strings.Join(stg.transforms, "\n\t\t"), formatMap(mainInputs), sideInputs)
}
}
return nil
Expand Down
22 changes: 14 additions & 8 deletions sdks/go/pkg/beam/runners/prism/internal/stage.go
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,11 @@ progress:
}

func getSideInputs(t *pipepb.PTransform) (map[string]*pipepb.SideInput, error) {
if t.GetSpec().GetUrn() != urns.TransformParDo {
switch t.GetSpec().GetUrn() {
case urns.TransformParDo, urns.TransformProcessSizedElements, urns.TransformPairWithRestriction, urns.TransformSplitAndSize, urns.TransformTruncate:
// Intentionally empty since these are permitted to have side inputs.
default:
// Nothing else is allowed to have side inputs.
return nil, nil
}
// TODO, memoize this, so we don't need to repeatedly unmarshal.
Expand Down Expand Up @@ -334,6 +338,7 @@ func buildDescriptor(stg *stage, comps *pipepb.Components, wk *worker.W, em *eng
return col
}

// Update coders for Stateful transforms.
for _, tid := range stg.transforms {
t := comps.GetTransforms()[tid]

Expand Down Expand Up @@ -461,10 +466,11 @@ func buildDescriptor(stg *stage, comps *pipepb.Components, wk *worker.W, em *eng
}
// Update side inputs to point to new PCollection with any replaced coders.
transforms[si.Transform].GetInputs()[si.Local] = newGlobal
// TODO: replace si.Global with newGlobal?
}
prepSide, err := handleSideInput(si, comps, coders, em)
prepSide, err := handleSideInput(si, comps, transforms, pcollections, coders, em)
if err != nil {
slog.Error("buildDescriptor: handleSideInputs", err, slog.String("transformID", si.Transform))
slog.Error("buildDescriptor: handleSideInputs", "error", err, slog.String("transformID", si.Transform))
return err
}
prepareSides = append(prepareSides, prepSide)
Expand Down Expand Up @@ -556,8 +562,8 @@ func buildDescriptor(stg *stage, comps *pipepb.Components, wk *worker.W, em *eng
}

// handleSideInput returns a closure that will look up the data for a side input appropriate for the given watermark.
func handleSideInput(link engine.LinkID, comps *pipepb.Components, coders map[string]*pipepb.Coder, em *engine.ElementManager) (func(b *worker.B, watermark mtime.Time), error) {
t := comps.GetTransforms()[link.Transform]
func handleSideInput(link engine.LinkID, comps *pipepb.Components, transforms map[string]*pipepb.PTransform, pcols map[string]*pipepb.PCollection, coders map[string]*pipepb.Coder, em *engine.ElementManager) (func(b *worker.B, watermark mtime.Time), error) {
t := transforms[link.Transform]
sis, err := getSideInputs(t)
if err != nil {
return nil, err
Expand All @@ -570,7 +576,7 @@ func handleSideInput(link engine.LinkID, comps *pipepb.Components, coders map[st
slog.String("local", link.Local),
slog.String("global", link.Global))

col := comps.GetPcollections()[link.Global]
col := pcols[link.Global]
// The returned coders are unused here, but they add the side input coders
// to the stage components for use SDK side.

Expand All @@ -594,7 +600,7 @@ func handleSideInput(link engine.LinkID, comps *pipepb.Components, coders map[st
slog.String("sourceTransform", t.GetUniqueName()),
slog.String("local", link.Local),
slog.String("global", link.Global))
col := comps.GetPcollections()[link.Global]
col := pcols[link.Global]

kvc := comps.GetCoders()[col.GetCoderId()]
if kvc.GetSpec().GetUrn() != urns.CoderKV {
Expand Down Expand Up @@ -633,7 +639,7 @@ func handleSideInput(link engine.LinkID, comps *pipepb.Components, coders map[st
}] = windowed
}, nil
default:
return nil, fmt.Errorf("local input %v (global %v) uses accesspattern %v", link.Local, link.Global, si.GetAccessPattern().GetUrn())
return nil, fmt.Errorf("local input %v (global %v) uses accesspattern %v", link.Local, link.Global, prototext.Format(si.GetAccessPattern()))
}
}

Expand Down
2 changes: 1 addition & 1 deletion sdks/go/pkg/beam/runners/prism/internal/worker/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,7 @@ func (wk *W) State(state fnpb.BeamFnState_StateServer) error {
// TODO: move data handling to be pcollection based.

key := req.GetStateKey()
slog.Debug("StateRequest_Get", prototext.Format(req), "bundle", b)
slog.Debug("StateRequest_Get", "request", prototext.Format(req), "bundle", b)
var data [][]byte
switch key.GetType().(type) {
case *fnpb.StateKey_IterableSideInput_:
Expand Down