diff --git a/sdks/go/pkg/beam/core/runtime/harness/statecache/statecache.go b/sdks/go/pkg/beam/core/runtime/harness/statecache/statecache.go new file mode 100644 index 000000000000..5496d8b81252 --- /dev/null +++ b/sdks/go/pkg/beam/core/runtime/harness/statecache/statecache.go @@ -0,0 +1,215 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package statecache implements the state caching feature described by the +// Beam Fn API +// +// The Beam State API and the intended caching behavior are described here: +// https://docs.google.com/document/d/1BOozW0bzBuz4oHJEuZNDOHdzaV5Y56ix58Ozrqm2jFg/edit#heading=h.7ghoih5aig5m +package statecache + +import ( + "sync" + + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/exec" + "github.com/apache/beam/sdks/v2/go/pkg/beam/internal/errors" + fnpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/fnexecution_v1" +) + +type token string + +// SideInputCache stores a cache of reusable inputs for the purposes of +// eliminating redundant calls to the runner during execution of ParDos +// using side inputs. +// +// A SideInputCache should be initialized when the SDK harness is initialized, +// creating storage for side input caching. On each ProcessBundleRequest, +// the cache will process the list of tokens for cacheable side inputs and +// be queried when side inputs are requested in bundle execution. Once a +// new bundle request comes in the valid tokens will be updated and the cache +// will be re-used. In the event that the cache reaches capacity, a random, +// currently invalid cached object will be evicted. +type SideInputCache struct { + capacity int + mu sync.Mutex + cache map[token]exec.ReusableInput + idsToTokens map[string]token + validTokens map[token]int8 // Maps tokens to active bundle counts + metrics CacheMetrics +} + +type CacheMetrics struct { + Hits int64 + Misses int64 + Evictions int64 + InUseEvictions int64 +} + +// Init makes the cache map and the map of IDs to cache tokens for the +// SideInputCache. Should only be called once. Returns an error for +// non-positive capacities. +func (c *SideInputCache) Init(cap int) error { + if cap <= 0 { + return errors.Errorf("capacity must be a positive integer, got %v", cap) + } + c.mu.Lock() + defer c.mu.Unlock() + c.cache = make(map[token]exec.ReusableInput, cap) + c.idsToTokens = make(map[string]token) + c.validTokens = make(map[token]int8) + c.capacity = cap + return nil +} + +// SetValidTokens clears the list of valid tokens then sets new ones, also updating the mapping of +// transform and side input IDs to cache tokens in the process. Should be called at the start of every +// new ProcessBundleRequest. If the runner does not support caching, the passed cache token values +// should be empty and all get/set requests will silently be no-ops. +func (c *SideInputCache) SetValidTokens(cacheTokens ...fnpb.ProcessBundleRequest_CacheToken) { + c.mu.Lock() + defer c.mu.Unlock() + for _, tok := range cacheTokens { + // User State caching is currently not supported, so these tokens are ignored + if tok.GetUserState() != nil { + continue + } + s := tok.GetSideInput() + transformID := s.GetTransformId() + sideInputID := s.GetSideInputId() + t := token(tok.GetToken()) + c.setValidToken(transformID, sideInputID, t) + } +} + +// setValidToken adds a new valid token for a request into the SideInputCache struct +// by mapping the transform ID and side input ID pairing to the cache token. +func (c *SideInputCache) setValidToken(transformID, sideInputID string, tok token) { + idKey := transformID + sideInputID + c.idsToTokens[idKey] = tok + count, ok := c.validTokens[tok] + if !ok { + c.validTokens[tok] = 1 + } else { + c.validTokens[tok] = count + 1 + } +} + +// CompleteBundle takes the cache tokens passed to set the valid tokens and decrements their +// usage count for the purposes of maintaining a valid count of whether or not a value is +// still in use. Should be called once ProcessBundle has completed. +func (c *SideInputCache) CompleteBundle(cacheTokens ...fnpb.ProcessBundleRequest_CacheToken) { + c.mu.Lock() + defer c.mu.Unlock() + for _, tok := range cacheTokens { + // User State caching is currently not supported, so these tokens are ignored + if tok.GetUserState() != nil { + continue + } + t := token(tok.GetToken()) + c.decrementTokenCount(t) + } +} + +// decrementTokenCount decrements the validTokens entry for +// a given token by 1. Should only be called when completing +// a bundle. +func (c *SideInputCache) decrementTokenCount(tok token) { + count := c.validTokens[tok] + if count == 1 { + delete(c.validTokens, tok) + } else { + c.validTokens[tok] = count - 1 + } +} + +func (c *SideInputCache) makeAndValidateToken(transformID, sideInputID string) (token, bool) { + idKey := transformID + sideInputID + // Check if it's a known token + tok, ok := c.idsToTokens[idKey] + if !ok { + return "", false + } + return tok, c.isValid(tok) +} + +// QueryCache takes a transform ID and side input ID and checking if a corresponding side +// input has been cached. A query having a bad token (e.g. one that doesn't make a known +// token or one that makes a known but currently invalid token) is treated the same as a +// cache miss. +func (c *SideInputCache) QueryCache(transformID, sideInputID string) exec.ReusableInput { + c.mu.Lock() + defer c.mu.Unlock() + tok, ok := c.makeAndValidateToken(transformID, sideInputID) + if !ok { + return nil + } + // Check to see if cached + input, ok := c.cache[tok] + if !ok { + c.metrics.Misses++ + return nil + } + + c.metrics.Hits++ + return input +} + +// SetCache allows a user to place a ReusableInput materialized from the reader into the SideInputCache +// with its corresponding transform ID and side input ID. If the IDs do not pair with a known, valid token +// then we silently do not cache the input, as this is an indication that the runner is treating that input +// as uncacheable. +func (c *SideInputCache) SetCache(transformID, sideInputID string, input exec.ReusableInput) { + c.mu.Lock() + defer c.mu.Unlock() + tok, ok := c.makeAndValidateToken(transformID, sideInputID) + if !ok { + return + } + if len(c.cache) >= c.capacity { + c.evictElement() + } + c.cache[tok] = input +} + +func (c *SideInputCache) isValid(tok token) bool { + count, ok := c.validTokens[tok] + // If the token is not known or not in use, return false + return ok && count > 0 +} + +// evictElement randomly evicts a ReusableInput that is not currently valid from the cache. +// It should only be called by a goroutine that obtained the lock in SetCache. +func (c *SideInputCache) evictElement() { + deleted := false + // Select a key from the cache at random + for k := range c.cache { + // Do not evict an element if it's currently valid + if !c.isValid(k) { + delete(c.cache, k) + c.metrics.Evictions++ + deleted = true + break + } + } + // Nothing is deleted if every side input is still valid. Clear + // out a random entry and record the in-use eviction + if !deleted { + for k := range c.cache { + delete(c.cache, k) + c.metrics.InUseEvictions++ + break + } + } +} diff --git a/sdks/go/pkg/beam/core/runtime/harness/statecache/statecache_test.go b/sdks/go/pkg/beam/core/runtime/harness/statecache/statecache_test.go new file mode 100644 index 000000000000..b9970c398154 --- /dev/null +++ b/sdks/go/pkg/beam/core/runtime/harness/statecache/statecache_test.go @@ -0,0 +1,290 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package statecache + +import ( + "testing" + + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/exec" + fnpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/fnexecution_v1" +) + +// TestReusableInput implements the ReusableInput interface for the purposes +// of testing. +type TestReusableInput struct { + transformID string + sideInputID string + value interface{} +} + +func makeTestReusableInput(transformID, sideInputID string, value interface{}) exec.ReusableInput { + return &TestReusableInput{transformID: transformID, sideInputID: sideInputID, value: value} +} + +// Init is a ReusableInput interface method, this is a no-op. +func (r *TestReusableInput) Init() error { + return nil +} + +// Value returns the stored value in the TestReusableInput. +func (r *TestReusableInput) Value() interface{} { + return r.value +} + +// Reset clears the value in the TestReusableInput. +func (r *TestReusableInput) Reset() error { + r.value = nil + return nil +} + +func TestInit(t *testing.T) { + var s SideInputCache + err := s.Init(5) + if err != nil { + t.Errorf("SideInputCache failed but should have succeeded, got %v", err) + } +} + +func TestInit_Bad(t *testing.T) { + var s SideInputCache + err := s.Init(0) + if err == nil { + t.Error("SideInputCache init succeeded but should have failed") + } +} + +func TestQueryCache_EmptyCase(t *testing.T) { + var s SideInputCache + err := s.Init(1) + if err != nil { + t.Fatalf("cache init failed, got %v", err) + } + output := s.QueryCache("side1", "transform1") + if output != nil { + t.Errorf("Cache hit when it should have missed, got %v", output) + } +} + +func TestSetCache_UncacheableCase(t *testing.T) { + var s SideInputCache + err := s.Init(1) + if err != nil { + t.Fatalf("cache init failed, got %v", err) + } + input := makeTestReusableInput("t1", "s1", 10) + s.SetCache("t1", "s1", input) + output := s.QueryCache("t1", "s1") + if output != nil { + t.Errorf("Cache hit when should have missed, got %v", output) + } +} + +func TestSetCache_CacheableCase(t *testing.T) { + var s SideInputCache + err := s.Init(1) + if err != nil { + t.Fatalf("cache init failed, got %v", err) + } + transID := "t1" + sideID := "s1" + tok := token("tok1") + s.setValidToken(transID, sideID, tok) + input := makeTestReusableInput(transID, sideID, 10) + s.SetCache(transID, sideID, input) + output := s.QueryCache(transID, sideID) + if output == nil { + t.Fatalf("call to query cache missed when should have hit") + } + val, ok := output.Value().(int) + if !ok { + t.Errorf("failed to convert value to integer, got %v", output.Value()) + } + if val != 10 { + t.Errorf("element mismatch, expected 10, got %v", val) + } +} + +func makeRequest(transformID, sideInputID string, t token) fnpb.ProcessBundleRequest_CacheToken { + var tok fnpb.ProcessBundleRequest_CacheToken + var wrap fnpb.ProcessBundleRequest_CacheToken_SideInput_ + var side fnpb.ProcessBundleRequest_CacheToken_SideInput + side.TransformId = transformID + side.SideInputId = sideInputID + wrap.SideInput = &side + tok.Type = &wrap + tok.Token = []byte(t) + return tok +} + +func TestSetValidTokens(t *testing.T) { + inputs := []struct { + transformID string + sideInputID string + tok token + }{ + { + "t1", + "s1", + "tok1", + }, + { + "t2", + "s2", + "tok2", + }, + { + "t3", + "s3", + "tok3", + }, + } + + var s SideInputCache + err := s.Init(3) + if err != nil { + t.Fatalf("cache init failed, got %v", err) + } + + var tokens []fnpb.ProcessBundleRequest_CacheToken + for _, input := range inputs { + t := makeRequest(input.transformID, input.sideInputID, input.tok) + tokens = append(tokens, t) + } + + s.SetValidTokens(tokens...) + if len(s.idsToTokens) != len(inputs) { + t.Errorf("Missing tokens, expected %v, got %v", len(inputs), len(s.idsToTokens)) + } + + for i, input := range inputs { + // Check that the token is in the valid list + if !s.isValid(input.tok) { + t.Errorf("error in input %v, token %v is not valid", i, input.tok) + } + // Check that the mapping of IDs to tokens is correct + mapped := s.idsToTokens[input.transformID+input.sideInputID] + if mapped != input.tok { + t.Errorf("token mismatch for input %v, expected %v, got %v", i, input.tok, mapped) + } + } +} + +func TestSetValidTokens_ClearingBetween(t *testing.T) { + inputs := []struct { + transformID string + sideInputID string + tk token + }{ + { + "t1", + "s1", + "tok1", + }, + { + "t2", + "s2", + "tok2", + }, + { + "t3", + "s3", + "tok3", + }, + } + + var s SideInputCache + err := s.Init(1) + if err != nil { + t.Fatalf("cache init failed, got %v", err) + } + + for i, input := range inputs { + tok := makeRequest(input.transformID, input.sideInputID, input.tk) + + s.SetValidTokens(tok) + + // Check that the token is in the valid list + if !s.isValid(input.tk) { + t.Errorf("error in input %v, token %v is not valid", i, input.tk) + } + // Check that the mapping of IDs to tokens is correct + mapped := s.idsToTokens[input.transformID+input.sideInputID] + if mapped != input.tk { + t.Errorf("token mismatch for input %v, expected %v, got %v", i, input.tk, mapped) + } + + s.CompleteBundle(tok) + } + + for k, _ := range s.validTokens { + if s.validTokens[k] != 0 { + t.Errorf("token count mismatch for token %v, expected 0, got %v", k, s.validTokens[k]) + } + } +} + +func TestSetCache_Eviction(t *testing.T) { + var s SideInputCache + err := s.Init(1) + if err != nil { + t.Fatalf("cache init failed, got %v", err) + } + + tokOne := makeRequest("t1", "s1", "tok1") + inOne := makeTestReusableInput("t1", "s1", 10) + s.SetValidTokens(tokOne) + s.SetCache("t1", "s1", inOne) + // Mark bundle as complete, drop count for tokOne to 0 + s.CompleteBundle(tokOne) + + tokTwo := makeRequest("t2", "s2", "tok2") + inTwo := makeTestReusableInput("t2", "s2", 20) + s.SetValidTokens(tokTwo) + s.SetCache("t2", "s2", inTwo) + + if len(s.cache) != 1 { + t.Errorf("cache size incorrect, expected 1, got %v", len(s.cache)) + } + if s.metrics.Evictions != 1 { + t.Errorf("number evictions incorrect, expected 1, got %v", s.metrics.Evictions) + } +} + +func TestSetCache_EvictionFailure(t *testing.T) { + var s SideInputCache + err := s.Init(1) + if err != nil { + t.Fatalf("cache init failed, got %v", err) + } + + tokOne := makeRequest("t1", "s1", "tok1") + inOne := makeTestReusableInput("t1", "s1", 10) + + tokTwo := makeRequest("t2", "s2", "tok2") + inTwo := makeTestReusableInput("t2", "s2", 20) + + s.SetValidTokens(tokOne, tokTwo) + s.SetCache("t1", "s1", inOne) + // Should fail to evict because the first token is still valid + s.SetCache("t2", "s2", inTwo) + // Cache should not exceed size 1 + if len(s.cache) != 1 { + t.Errorf("cache size incorrect, expected 1, got %v", len(s.cache)) + } + if s.metrics.InUseEvictions != 1 { + t.Errorf("number of failed evicition calls incorrect, expected 1, got %v", s.metrics.InUseEvictions) + } +}