Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 46 additions & 8 deletions sdks/go/pkg/beam/core/funcx/fn.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ const (
FnMultiMap FnParamKind = 0x200
// FnPane indicates a function input parameter that is a PaneInfo
FnPane FnParamKind = 0x400
// FnBundleFinalization indicates a function input parameter that implements typex.BundleFinalization.
FnBundleFinalization FnParamKind = 0x800
)

func (k FnParamKind) String() string {
Expand All @@ -104,6 +106,8 @@ func (k FnParamKind) String() string {
return "MultiMap"
case FnPane:
return "Pane"
case FnBundleFinalization:
return "BundleFinalization"
default:
return fmt.Sprintf("%v", int(k))
}
Expand Down Expand Up @@ -267,6 +271,17 @@ func (u *Fn) RTracker() (pos int, exists bool) {
return -1, false
}

// BundleFinalization scans the function's parameters in order and reports
// the index of the first one whose kind is FnBundleFinalization. It returns
// (index, true) when such a parameter is present and (-1, false) otherwise.
func (u *Fn) BundleFinalization() (pos int, exists bool) {
	for idx := range u.Param {
		// Skip every parameter that is not the bundle finalization one.
		if u.Param[idx].Kind != FnBundleFinalization {
			continue
		}
		return idx, true
	}
	return -1, false
}

// Error returns (index, true) iff the function returns an error.
func (u *Fn) Error() (pos int, exists bool) {
for i, p := range u.Ret {
Expand Down Expand Up @@ -329,6 +344,8 @@ func New(fn reflectx.Func) (*Fn, error) {
kind = FnEventTime
case t.Implements(typex.WindowType):
kind = FnWindow
case t == typex.BundleFinalizationType:
kind = FnBundleFinalization
case t == reflectx.Type:
kind = FnType
case t.Implements(reflect.TypeOf((*sdf.RTracker)(nil)).Elem()):
Expand Down Expand Up @@ -415,7 +432,7 @@ func SubReturns(list []ReturnParam, indices ...int) []ReturnParam {
}

// The order of present parameters and return values must be as follows:
// func(FnContext?, FnPane?, FnWindow?, FnEventTime?, FnType?, FnRTracker?, (FnValue, SideInput*)?, FnEmit*) (RetEventTime?, RetOutput?, RetError?)
// func(FnContext?, FnPane?, FnWindow?, FnEventTime?, FnType?, FnBundleFinalization?, FnRTracker?, (FnValue, SideInput*)?, FnEmit*) (RetEventTime?, RetOutput?, RetError?)
// where ? indicates 0 or 1, and * indicates any number.
// and a SideInput is one of FnValue or FnIter or FnReIter
// Note: Fns with inputs must have at least one FnValue as the main input.
Expand All @@ -439,13 +456,14 @@ func validateOrder(u *Fn) error {
}

var (
errContextParam = errors.New("may only have a single context.Context parameter and it must be the first parameter")
errPaneParamPrecedence = errors.New("may only have a single PaneInfo parameter and it must precede the WindowParam, EventTime and main input parameter")
errWindowParamPrecedence = errors.New("may only have a single Window parameter and it must precede the EventTime and main input parameter")
errEventTimeParamPrecedence = errors.New("may only have a single beam.EventTime parameter and it must precede the main input parameter")
errReflectTypePrecedence = errors.New("may only have a single reflect.Type parameter and it must precede the main input parameter")
errRTrackerPrecedence = errors.New("may only have a single sdf.RTracker parameter and it must precede the main input parameter")
errInputPrecedence = errors.New("inputs parameters must precede emit function parameters")
errContextParam = errors.New("may only have a single context.Context parameter and it must be the first parameter")
errPaneParamPrecedence = errors.New("may only have a single PaneInfo parameter and it must precede the WindowParam, EventTime and main input parameter")
errWindowParamPrecedence = errors.New("may only have a single Window parameter and it must precede the EventTime and main input parameter")
errEventTimeParamPrecedence = errors.New("may only have a single beam.EventTime parameter and it must precede the main input parameter")
errReflectTypePrecedence = errors.New("may only have a single reflect.Type parameter and it must precede the main input parameter")
errRTrackerPrecedence = errors.New("may only have a single sdf.RTracker parameter and it must precede the main input parameter")
errBundleFinalizationPrecedence = errors.New("may only have a single BundleFinalization parameter and it must precede the main input parameter")
errInputPrecedence = errors.New("inputs parameters must precede emit function parameters")
)

type paramState int
Expand All @@ -460,6 +478,7 @@ const (
psInput
psOutput
psRTracker
psBundleFinalization
)

func nextParamState(cur paramState, transition FnParamKind) (paramState, error) {
Expand All @@ -476,6 +495,8 @@ func nextParamState(cur paramState, transition FnParamKind) (paramState, error)
return psEventTime, nil
case FnType:
return psType, nil
case FnBundleFinalization:
return psBundleFinalization, nil
case FnRTracker:
return psRTracker, nil
}
Expand All @@ -489,6 +510,8 @@ func nextParamState(cur paramState, transition FnParamKind) (paramState, error)
return psEventTime, nil
case FnType:
return psType, nil
case FnBundleFinalization:
return psBundleFinalization, nil
case FnRTracker:
return psRTracker, nil
}
Expand All @@ -500,6 +523,8 @@ func nextParamState(cur paramState, transition FnParamKind) (paramState, error)
return psEventTime, nil
case FnType:
return psType, nil
case FnBundleFinalization:
return psBundleFinalization, nil
case FnRTracker:
return psRTracker, nil
}
Expand All @@ -509,17 +534,28 @@ func nextParamState(cur paramState, transition FnParamKind) (paramState, error)
return psEventTime, nil
case FnType:
return psType, nil
case FnBundleFinalization:
return psBundleFinalization, nil
case FnRTracker:
return psRTracker, nil
}
case psEventTime:
switch transition {
case FnType:
return psType, nil
case FnBundleFinalization:
return psBundleFinalization, nil
case FnRTracker:
return psRTracker, nil
}
case psType:
switch transition {
case FnBundleFinalization:
return psBundleFinalization, nil
case FnRTracker:
return psRTracker, nil
}
case psBundleFinalization:
switch transition {
case FnRTracker:
return psRTracker, nil
Expand Down Expand Up @@ -549,6 +585,8 @@ func nextParamState(cur paramState, transition FnParamKind) (paramState, error)
return -1, errEventTimeParamPrecedence
case FnType:
return -1, errReflectTypePrecedence
case FnBundleFinalization:
return -1, errBundleFinalizationPrecedence
case FnRTracker:
return -1, errRTrackerPrecedence
case FnIter, FnReIter, FnValue, FnMultiMap:
Expand Down
59 changes: 59 additions & 0 deletions sdks/go/pkg/beam/core/funcx/fn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,11 @@ func TestNew(t *testing.T) {
Fn: func(typex.PaneInfo, typex.Window, typex.EventTime, reflect.Type, []byte) {},
Param: []FnParamKind{FnPane, FnWindow, FnEventTime, FnType, FnValue},
},
{
Name: "good8",
Fn: func(typex.PaneInfo, typex.Window, typex.EventTime, reflect.Type, typex.BundleFinalization, []byte) {},
Param: []FnParamKind{FnPane, FnWindow, FnEventTime, FnType, FnBundleFinalization, FnValue},
},
{
Name: "good-method",
Fn: foo{1}.Do,
Expand Down Expand Up @@ -172,6 +177,11 @@ func TestNew(t *testing.T) {
},
Err: errReflectTypePrecedence,
},
{
Name: "errReflectTypePrecedence: after bundle finalizer",
Fn: func(typex.PaneInfo, typex.Window, typex.EventTime, typex.BundleFinalization, reflect.Type, []byte) {},
Err: errReflectTypePrecedence,
},
{
Name: "errInputPrecedence- Iter before after output",
Fn: func(int, func(int), func(*int) bool, func(*int, *string) bool) {},
Expand Down Expand Up @@ -201,6 +211,11 @@ func TestNew(t *testing.T) {
},
Err: errErrorPrecedence,
},
{
Name: "errBundleFinalizationPrecedence",
Fn: func(typex.PaneInfo, typex.Window, typex.EventTime, reflect.Type, []byte, typex.BundleFinalization) {},
Err: errBundleFinalizationPrecedence,
},
{
Name: "errEventTimeRetPrecedence",
Fn: func() (string, typex.EventTime) {
Expand Down Expand Up @@ -437,6 +452,50 @@ func TestWindow(t *testing.T) {
}
}

// TestBundleFinalization checks that Fn.BundleFinalization reports the index
// of the FnBundleFinalization parameter when one is present, and (-1, false)
// when the parameter list does not contain one.
func TestBundleFinalization(t *testing.T) {
	type bfCase struct {
		Name   string
		Params []FnParamKind
		Pos    int
		Exists bool
	}

	cases := []bfCase{
		{
			Name:   "bundleFinalization input",
			Params: []FnParamKind{FnContext, FnBundleFinalization},
			Pos:    1,
			Exists: true,
		},
		{
			Name:   "no bundleFinalization input",
			Params: []FnParamKind{FnContext, FnEventTime},
			Pos:    -1,
			Exists: false,
		},
	}

	for _, tc := range cases {
		tc := tc
		t.Run(tc.Name, func(t *testing.T) {
			// Build a Fn whose parameter list carries only the kinds under
			// test; the parameter types are irrelevant here and stay nil.
			params := make([]FnParam, 0, len(tc.Params))
			for _, kind := range tc.Params {
				params = append(params, FnParam{Kind: kind, T: nil})
			}
			fn := &Fn{Param: params}

			// Compare the reported position and presence flag against the
			// expectations for the bundle finalization lookup.
			gotPos, gotExists := fn.BundleFinalization()
			if gotExists != tc.Exists {
				t.Errorf("BundleFinalization(%v) - exists: got %v, want %v", params, gotExists, tc.Exists)
			}
			if gotPos != tc.Pos {
				t.Errorf("BundleFinalization(%v) - pos: got %v, want %v", params, gotPos, tc.Pos)
			}
		})
	}
}

func TestInputs(t *testing.T) {
tests := []struct {
Name string
Expand Down
8 changes: 3 additions & 5 deletions sdks/go/pkg/beam/core/runtime/exec/fn.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,11 +124,9 @@ func newInvoker(fn *funcx.Fn) *invoker {
if n.outErrIdx, ok = fn.Error(); !ok {
n.outErrIdx = -1
}
// TODO(BEAM-10976) - add this back in once BundleFinalization is implemented
// if n.bfIdx, ok = fn.BundleFinalization(); !ok {
// n.bfIdx = -1
// }
n.bfIdx = -1
if n.bfIdx, ok = fn.BundleFinalization(); !ok {
n.bfIdx = -1
}

n.initCall()

Expand Down
24 changes: 24 additions & 0 deletions sdks/go/pkg/beam/core/runtime/exec/fn_arity.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion sdks/go/pkg/beam/core/runtime/exec/fn_arity.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import (
func (n *invoker) initCall() {
switch fn := n.fn.Fn.(type) {
{{range $out := upto 4}}
{{range $in := upto 8}}
{{range $in := upto 9}}
case reflectx.Func{{$in}}x{{$out}}:
n.call = func(pn typex.PaneInfo, ws []typex.Window, ts typex.EventTime) (*FullValue, error) {
{{if $out}}{{mktuplef $out "r%v"}} := {{end}}fn.Call{{$in}}x{{$out}}({{mktuplef $in "n.args[%v]"}})
Expand Down
4 changes: 4 additions & 0 deletions sdks/go/pkg/beam/core/runtime/graphx/serialize.go
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,8 @@ func tryEncodeSpecial(t reflect.Type) (v1pb.Type_Special, bool) {
return v1pb.Type_EVENTTIME, true
case typex.WindowType:
return v1pb.Type_WINDOW, true
case typex.BundleFinalizationType:
return v1pb.Type_BUNDLEFINALIZATION, true
case typex.KVType:
return v1pb.Type_KV, true
case typex.CoGBKType:
Expand Down Expand Up @@ -677,6 +679,8 @@ func decodeSpecial(s v1pb.Type_Special) (reflect.Type, error) {
return typex.EventTimeType, nil
case v1pb.Type_WINDOW:
return typex.WindowType, nil
case v1pb.Type_BUNDLEFINALIZATION:
return typex.BundleFinalizationType, nil
case v1pb.Type_KV:
return typex.KVType, nil
case v1pb.Type_COGBK:
Expand Down
8 changes: 7 additions & 1 deletion sdks/go/pkg/beam/core/runtime/graphx/translate.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package graphx
import (
"context"
"fmt"
"sort"

"github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/coder"
Expand Down Expand Up @@ -65,7 +66,8 @@ const (
URNLegacyProgressReporting = "beam:protocol:progress_reporting:v0"
URNMultiCore = "beam:protocol:multi_core_bundle_processing:v1"

URNRequiresSplittableDoFn = "beam:requirement:pardo:splittable_dofn:v1"
URNRequiresSplittableDoFn = "beam:requirement:pardo:splittable_dofn:v1"
URNRequiresBundleFinalization = "beam:requirement:pardo:finalization:v1"

// Deprecated: Determine worker binary based on GoWorkerBinary Role instead.
URNArtifactGoWorker = "beam:artifact:type:go_worker_binary:v1"
Expand Down Expand Up @@ -221,6 +223,7 @@ func (m *marshaller) getRequirements() []string {
reqs = append(reqs, req)
}
}
sort.Strings(reqs)
return reqs
}

Expand Down Expand Up @@ -445,6 +448,9 @@ func (m *marshaller) addMultiEdge(edge NamedEdge) ([]string, error) {
payload.RestrictionCoderId = coderId
m.requirements[URNRequiresSplittableDoFn] = true
}
if _, ok := edge.Edge.DoFn.ProcessElementFn().BundleFinalization(); ok {
m.requirements[URNRequiresBundleFinalization] = true
}
spec = &pipepb.FunctionSpec{Urn: URNParDo, Payload: protox.MustEncode(payload)}
annotations = edge.Edge.DoFn.Annotations()

Expand Down
Loading