diff --git a/CHANGES.md b/CHANGES.md index db88b8c79807..acefb2c9f503 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -97,6 +97,8 @@ * (Python) Fixed Java YAML provider fails on Windows ([#35617](https://github.com/apache/beam/issues/35617)). * Fixed BigQueryIO creating temporary datasets in wrong project when temp_dataset is specified with a different project than the pipeline project. For some jobs, temporary datasets will now be created in the correct project (Python) ([#35813](https://github.com/apache/beam/issues/35813)). +* (Go) Fix duplicates due to reads after blind writes to Bag State ([#35869](https://github.com/apache/beam/issues/35869)). + * Earlier Go SDK versions can avoid the issue by not reading in the same call after a blind write. ## Known Issues diff --git a/sdks/go/pkg/beam/core/runtime/exec/userstate.go b/sdks/go/pkg/beam/core/runtime/exec/userstate.go index f83aee4bf741..ea723b18e3a7 100644 --- a/sdks/go/pkg/beam/core/runtime/exec/userstate.go +++ b/sdks/go/pkg/beam/core/runtime/exec/userstate.go @@ -35,17 +35,18 @@ type stateProvider struct { elementKey []byte window []byte - transactionsByKey map[string][]state.Transaction - initialValueByKey map[string]any - initialBagByKey map[string][]any - initialMapValuesByKey map[string]map[string]any - initialMapKeysByKey map[string][]any - readersByKey map[string]io.ReadCloser - appendersByKey map[string]io.Writer - clearersByKey map[string]io.Writer - codersByKey map[string]*coder.Coder - keyCodersByID map[string]*coder.Coder - combineFnsByKey map[string]*graph.CombineFn + transactionsByKey map[string][]state.Transaction + initialValueByKey map[string]any + initialBagByKey map[string][]any + blindBagWriteCountsByKey map[string]int // Tracks blind writes to bags before a read. + initialMapValuesByKey map[string]map[string]any + initialMapKeysByKey map[string][]any + readersByKey map[string]io.ReadCloser + appendersByKey map[string]io.Writer + clearersByKey map[string]io.Writer + codersByKey map[string]*coder.Coder + keyCodersByID map[string]*coder.Coder + combineFnsByKey map[string]*graph.CombineFn } // ReadValueState reads a value state from the State API @@ -148,6 +149,12 @@ func (s *stateProvider) ReadBagState(userStateID string) ([]any, []state.Transac if !ok { transactions = []state.Transaction{} } + // If there were blind writes before this read, trim the transactions. + // These don't need to be reset, unless a clear happens. + if s.blindBagWriteCountsByKey[userStateID] > 0 { + // Trim blind writes from the transaction queue, to avoid re-applying them. + transactions = transactions[s.blindBagWriteCountsByKey[userStateID]:] + } return initialValue, transactions, nil } @@ -165,12 +172,17 @@ func (s *stateProvider) ClearBagState(val state.Transaction) error { // Any transactions before a clear don't matter s.transactionsByKey[val.Key] = []state.Transaction{val} + s.blindBagWriteCountsByKey[val.Key] = 1 // To account for the clear. return nil } // WriteBagState writes a bag state to the State API func (s *stateProvider) WriteBagState(val state.Transaction) error { + _, ok := s.initialBagByKey[val.Key] + if !ok { + s.blindBagWriteCountsByKey[val.Key]++ + } ap, err := s.getBagAppender(val.Key) if err != nil { return err @@ -510,22 +522,23 @@ func (s *userStateAdapter) NewStateProvider(ctx context.Context, reader StateRea return stateProvider{}, err } sp := stateProvider{ - ctx: ctx, - sr: reader, - SID: s.sid, - elementKey: elementKey, - window: win, - transactionsByKey: make(map[string][]state.Transaction), - initialValueByKey: make(map[string]any), - initialBagByKey: make(map[string][]any), - initialMapValuesByKey: make(map[string]map[string]any), - initialMapKeysByKey: make(map[string][]any), - readersByKey: make(map[string]io.ReadCloser), - appendersByKey: make(map[string]io.Writer), - clearersByKey: make(map[string]io.Writer), - combineFnsByKey: s.stateIDToCombineFn, - codersByKey: s.stateIDToCoder, - keyCodersByID: s.stateIDToKeyCoder, + ctx: ctx, + sr: reader, + SID: s.sid, + elementKey: elementKey, + window: win, + transactionsByKey: make(map[string][]state.Transaction), + initialValueByKey: make(map[string]any), + initialBagByKey: make(map[string][]any), + blindBagWriteCountsByKey: make(map[string]int), + initialMapValuesByKey: make(map[string]map[string]any), + initialMapKeysByKey: make(map[string][]any), + readersByKey: make(map[string]io.ReadCloser), + appendersByKey: make(map[string]io.Writer), + clearersByKey: make(map[string]io.Writer), + combineFnsByKey: s.stateIDToCombineFn, + codersByKey: s.stateIDToCoder, + keyCodersByID: s.stateIDToKeyCoder, } return sp, nil diff --git a/sdks/go/test/integration/primitives/state.go b/sdks/go/test/integration/primitives/state.go index acf1bf8fa665..6b672acc27bd 100644 --- a/sdks/go/test/integration/primitives/state.go +++ b/sdks/go/test/integration/primitives/state.go @@ -34,6 +34,7 @@ func init() { register.DoFn3x1[state.Provider, string, int, string](&valueStateClearFn{}) register.DoFn3x1[state.Provider, string, int, string](&bagStateFn{}) register.DoFn3x1[state.Provider, string, int, string](&bagStateClearFn{}) + register.DoFn3x1[state.Provider, string, int, string](&bagStateBlindWriteFn{}) register.DoFn3x1[state.Provider, string, int, string](&combiningStateFn{}) register.DoFn3x1[state.Provider, string, int, string](&mapStateFn{}) register.DoFn3x1[state.Provider, string, int, string](&mapStateClearFn{}) @@ -211,6 +212,45 @@ func BagStateParDoClear(s beam.Scope) { passert.Equals(s, counts, "apple: 0", "pear: 0", "apple: 1", "apple: 2", "pear: 1", "apple: 3", "apple: 0", "pear: 2", "pear: 3", "pear: 0", "apple: 1", "pear: 1") } +type bagStateBlindWriteFn struct { + State1 state.Bag[int] +} + +func (f *bagStateBlindWriteFn) ProcessElement(s state.Provider, w string, c int) string { + err := f.State1.Add(s, 1) + if err != nil { + panic(err) + } + i, ok, err := f.State1.Read(s) + if err != nil { + panic(err) + } + if !ok { + i = []int{} + } + sum := 0 + for _, val := range i { + sum += val + } + + // Bonus "non-blind" write + err = f.State1.Add(s, 1) + if err != nil { + panic(err) + } + + return fmt.Sprintf("%s: %v", w, sum) +} + +// BagStateBlindWriteParDo tests a DoFn that uses bag state, but performs a +// blind write to the state before reading. +func BagStateBlindWriteParDo(s beam.Scope) { + in := beam.Create(s, "apple", "pear", "peach", "apple", "apple", "pear") + keyed := beam.ParDo(s, pairWithOne, in) + counts := beam.ParDo(s, &bagStateBlindWriteFn{}, keyed) + passert.Equals(s, counts, "apple: 1", "pear: 1", "peach: 1", "apple: 3", "apple: 5", "pear: 3") +} + type combiningStateFn struct { State0 state.Combining[int, int, int] State1 state.Combining[int, int, int] diff --git a/sdks/go/test/integration/primitives/state_test.go b/sdks/go/test/integration/primitives/state_test.go index 79cb8c1839fc..1d1d4860e8f9 100644 --- a/sdks/go/test/integration/primitives/state_test.go +++ b/sdks/go/test/integration/primitives/state_test.go @@ -47,6 +47,11 @@ func TestBagStateClear(t *testing.T) { ptest.BuildAndRun(t, BagStateParDoClear) } +func TestBagStateBlindWrite(t *testing.T) { + integration.CheckFilters(t) + ptest.BuildAndRun(t, BagStateBlindWriteParDo) +} + func TestCombiningState(t *testing.T) { integration.CheckFilters(t) ptest.BuildAndRun(t, CombiningStateParDo)