diff --git a/sdks/go/pkg/beam/runners/prism/internal/handlerunner.go b/sdks/go/pkg/beam/runners/prism/internal/handlerunner.go index e58bb8f180ed..ba950de34697 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/handlerunner.go +++ b/sdks/go/pkg/beam/runners/prism/internal/handlerunner.go @@ -21,6 +21,7 @@ import ( "io" "reflect" "sort" + "strings" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/coder" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/mtime" @@ -108,12 +109,40 @@ func (h *runner) handleFlatten(tid string, t *pipepb.PTransform, comps *pipepb.C // they're written out to the runner in the same fashion. // This may stop being necessary once Flatten Unzipping happens in the optimizer. outPCol := comps.GetPcollections()[outColID] + outCoderID := outPCol.CoderId + outCoder := comps.GetCoders()[outCoderID] + coderSubs := map[string]*pipepb.Coder{} pcollSubs := map[string]*pipepb.PCollection{} + + if !strings.HasPrefix(outCoderID, "cf_") { + // Create a new coder id for the flatten output PCollection and use + // this coder id for all input PCollections + outCoderID = "cf_" + outColID + outCoder = proto.Clone(outCoder).(*pipepb.Coder) + coderSubs[outCoderID] = outCoder + + pcollSubs[outColID] = proto.Clone(outPCol).(*pipepb.PCollection) + pcollSubs[outColID].CoderId = outCoderID + + outPCol = pcollSubs[outColID] + } + for _, p := range t.GetInputs() { inPCol := comps.GetPcollections()[p] if inPCol.CoderId != outPCol.CoderId { - pcollSubs[p] = proto.Clone(inPCol).(*pipepb.PCollection) - pcollSubs[p].CoderId = outPCol.CoderId + if strings.HasPrefix(inPCol.CoderId, "cf_") { + // The input pcollection is the output of another flatten: + // e.g. [[a, b] | Flatten], c] | Flatten + // In this case, we just point the input coder id to the new flatten + // output coder, so any upstream input pcollections will use the new + // output coder. + coderSubs[inPCol.CoderId] = outCoder + } else { + // Create a substitute PCollection for this input with the flatten + // output coder id + pcollSubs[p] = proto.Clone(inPCol).(*pipepb.PCollection) + pcollSubs[p].CoderId = outPCol.CoderId + } } } @@ -125,6 +154,7 @@ func (h *runner) handleFlatten(tid string, t *pipepb.PTransform, comps *pipepb.C tid: t, }, Pcollections: pcollSubs, + Coders: coderSubs, }, RemovedLeaves: nil, ForcedRoots: forcedRoots,