diff --git a/table/cached_generic.go b/table/cached_generic.go index cacb5af..a14e31c 100644 --- a/table/cached_generic.go +++ b/table/cached_generic.go @@ -7,6 +7,7 @@ import ( "context" "log" "reflect" + "sync" "github.com/go-core-stack/core/db" "github.com/go-core-stack/core/errors" @@ -25,8 +26,9 @@ import ( // E: Entry type (must NOT be a pointer type) type CachedTable[K comparable, E any] struct { reconciler.ManagerImpl - cache map[K]*E - col db.StoreCollection + cacheMu sync.RWMutex + cache map[K]*E + col db.StoreCollection } // Initialize sets up the Table with the provided db.StoreCollection. @@ -40,6 +42,10 @@ func (t *CachedTable[K, E]) Initialize(col db.StoreCollection) error { return errors.Wrapf(errors.AlreadyExists, "Table is already initialized") } + if t.cache == nil { + t.cache = map[K]*E{} + } + var e E if reflect.TypeOf(e).Kind() == reflect.Pointer { return errors.Wrapf(errors.InvalidArgument, "Table entry type must not be a pointer") @@ -68,7 +74,27 @@ func (t *CachedTable[K, E]) Initialize(col db.StoreCollection) error { } t.col = col - t.cache = map[K]*E{} + + list := []keyOnly[K]{} + err = t.col.FindMany(context.Background(), nil, &list) + if err != nil { + log.Panicf("got error while fetching all keys %s", err) + } + for _, k := range list { + entry, err := t.DBFind(context.Background(), &k.Key) + if err != nil { + // this should not happen in regular scenarios + // log and return from here + log.Printf("failed to find an entry, got error: %s", err) + } else { + func() { + t.cacheMu.Lock() + defer t.cacheMu.Unlock() + t.cache[k.Key] = entry + }() + } + } + return nil } @@ -88,7 +114,11 @@ func (t *CachedTable[K, E]) callback(op string, wKey any) { log.Printf("failed to find an entry, got error: %s", err) } } else { - t.cache[*key] = entry + func() { + t.cacheMu.Lock() + defer t.cacheMu.Unlock() + t.cache[*key] = entry + }() } } t.NotifyCallback(wKey) @@ -139,6 +169,8 @@ func (t *CachedTable[K, E]) Update(ctx context.Context, key *K, entry *E) error // Find retrieves an entry by key from the Cache // Returns the entry and error if not found or if the table is not initialized. func (t *CachedTable[K, E]) Find(ctx context.Context, key *K) (*E, error) { + t.cacheMu.RLock() + defer t.cacheMu.RUnlock() entry, ok := t.cache[*key] if !ok { return nil, errors.Wrapf(errors.NotFound, "failed to find entry with key %v", key) diff --git a/table/cached_generic_test.go b/table/cached_generic_test.go index 78b23ca..7a25ea0 100644 --- a/table/cached_generic_test.go +++ b/table/cached_generic_test.go @@ -32,6 +32,8 @@ type MyTable struct { var ( myTable *MyTable + + myTable2 *MyTable ) func clientInit() { @@ -68,6 +70,40 @@ func clientInit() { } } +func clientInitTable2() { + if myTable2 != nil { + return + } + myTable2 = &MyTable{} + + config := &db.MongoConfig{ + Host: "localhost", + Port: "27017", + Username: "root", + Password: "password", + } + + client, err := db.NewMongoClient(config) + + if err != nil { + log.Panicf("failed to connect to mongo DB Error: %s", err) + } + + err = client.HealthCheck(context.Background()) + if err != nil { + log.Panicf("failed to perform Health check with DB Error: %s", err) + } + + s := client.GetDataStore("test") + + col := s.GetCollection("my-cached-table") + + err = myTable2.Initialize(col) + if err != nil { + log.Panicf("failed to initialize cached table") + } +} + func Test_CachedClient(t *testing.T) { clientInit() t.Run("push_and_find_entries", func(t *testing.T) { @@ -132,6 +168,17 @@ func Test_CachedClient(t *testing.T) { } } + clientInitTable2() + time.Sleep(1 * time.Second) + entry, err = myTable2.Find(ctx, key2) + if err != nil { + t.Errorf("failed to find the inserted entry from the table, got error: %s", err) + } else { + if entry.Desc != "sample-description-2" { + t.Errorf("expected sample-description-2, but got %s", entry.Desc) + } + } + count, err := myTable.col.DeleteMany(ctx, bson.D{}) if err != nil { t.Errorf("failed to delete the entries from table, got error %s", err)