diff --git a/db/mongo.go b/db/mongo.go index feeb198..39b28ca 100644 --- a/db/mongo.go +++ b/db/mongo.go @@ -181,9 +181,23 @@ func (c *mongoCollection) DeleteMany(ctx context.Context, filter interface{}) (i // watch allows getting notified whenever a change happens to a document // in the collection -func (c *mongoCollection) Watch(ctx context.Context, cb WatchCallbackfn) error { +// allow provisiong for a filter to be passed on, where the callback +// function to receive only conditional notifications of the events +// listener is interested about +func (c *mongoCollection) Watch(ctx context.Context, filter interface{}, cb WatchCallbackfn) error { + if filter == nil { + // if passed filter is nil, initialize it to empty pipeline object + filter = mongo.Pipeline{} + } + switch v := filter.(type) { + case mongo.Pipeline: + // we are ok to proceed further + break + default: + return errors.Wrapf(errors.InvalidArgument, "Invalid watch filter pipeline type specified, %v", v) + } // start watching on the collection with passed context - stream, err := c.col.Watch(ctx, mongo.Pipeline{}) + stream, err := c.col.Watch(ctx, filter) if err != nil { return err } diff --git a/db/mongo_test.go b/db/mongo_test.go index ab1d5ce..a5ebc04 100644 --- a/db/mongo_test.go +++ b/db/mongo_test.go @@ -8,6 +8,9 @@ import ( "reflect" "testing" "time" + + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/mongo" ) type MyKey struct { @@ -140,8 +143,9 @@ func Test_ClientConnection(t *testing.T) { } var ( - mongoTestAddUpOps int - mongoTestDeleteOps int + mongoTestAddUpOps int + mongoTestDeleteOps int + myMongoTestDeleteOps int ) func myKeyWatcher(op string, wKey interface{}) { @@ -154,6 +158,11 @@ func myKeyWatcher(op string, wKey interface{}) { } } +func myDeleteWatcher(op string, wKey interface{}) { + _ = wKey.(*MyKey) + myMongoTestDeleteOps += 1 +} + func Test_CollectionWatch(t *testing.T) { t.Run("WatchTest", func(t *testing.T) { config := &MongoConfig{ @@ -200,11 +209,17 @@ func Test_CollectionWatch(t *testing.T) { if mongoTestDeleteOps != 2 { t.Errorf("Delete Notify: Got %d, expected 2", mongoTestDeleteOps) } + if myMongoTestDeleteOps != 2 { + t.Errorf("expected delete count %d, but got %d", 2, myMongoTestDeleteOps) + } }() // reset counters mongoTestAddUpOps = 0 mongoTestDeleteOps = 0 - col.Watch(watchCtx, myKeyWatcher) + myMongoTestDeleteOps = 0 + col.Watch(watchCtx, nil, myKeyWatcher) + matchDeleteStage := mongo.Pipeline{bson.D{{Key: "$match", Value: bson.D{{Key: "operationType", Value: "delete"}}}}} + col.Watch(watchCtx, matchDeleteStage, myDeleteWatcher) key := &MyKey{ Name: "test-key", diff --git a/db/store.go b/db/store.go index ea16fb3..f39e42f 100644 --- a/db/store.go +++ b/db/store.go @@ -50,7 +50,10 @@ type StoreCollection interface { // watch allows getting notified whenever a change happens to a document // in the collection - Watch(ctx context.Context, cb WatchCallbackfn) error + // allow provisiong for a filter to be passed on, where the callback + // function to receive only conditional notifications of the events + // listener is interested about + Watch(ctx context.Context, filter interface{}, cb WatchCallbackfn) error } // interface definition for a store, responsible for holding group diff --git a/errors/errors.go b/errors/errors.go index 13d6c3a..98681b4 100644 --- a/errors/errors.go +++ b/errors/errors.go @@ -5,6 +5,7 @@ package errors import ( base "errors" + "fmt" ) func Is(err error, target error) bool { @@ -47,6 +48,15 @@ func Wrap(code ErrCode, msg string) error { } } +// Wraps the error msg with recognized error codes +// using specified message format +func Wrapf(code ErrCode, format string, v ...any) error { + return &Error{ + code: code, + msg: fmt.Sprintf(format, v...), + } +} + // IsNotFound returns true if err // item isn't found in the space func IsNotFound(err error) bool {