diff --git a/cacheaside.go b/cacheaside.go index d94ddc0..56ea442 100644 --- a/cacheaside.go +++ b/cacheaside.go @@ -18,13 +18,20 @@ import ( "golang.org/x/sync/errgroup" ) +type lockEntry struct { + ctx context.Context + cancel context.CancelFunc +} + type CacheAside struct { client rueidis.Client - locks syncx.Map[string, chan struct{}] + locks syncx.Map[string, *lockEntry] lockTTL time.Duration } type CacheAsideOption struct { + // LockTTL is the maximum time a lock can be held, and also the timeout for waiting + // on locks when handling lost Redis invalidation messages. Defaults to 10 seconds. LockTTL time.Duration ClientBuilder func(option rueidis.ClientOption) (rueidis.Client, error) } @@ -57,9 +64,9 @@ func (rca *CacheAside) Client() rueidis.Client { func (rca *CacheAside) onInvalidate(messages []rueidis.RedisMessage) { for _, m := range messages { key, _ := m.ToString() - ch, loaded := rca.locks.LoadAndDelete(key) + entry, loaded := rca.locks.LoadAndDelete(key) if loaded { - close(ch) + entry.cancel() // Cancel context, which closes the channel } } } @@ -72,8 +79,37 @@ var ( ) func (rca *CacheAside) register(key string) <-chan struct{} { - ch, _ := rca.locks.LoadOrStore(key, make(chan struct{})) - return ch + // Try to load existing entry first + if entry, loaded := rca.locks.Load(key); loaded { + // Check if the context is still active (not cancelled/timed out) + select { + case <-entry.ctx.Done(): + // Context is done - clean it up and create a new one + rca.locks.Delete(key) + default: + // Context is still active - use it + return entry.ctx.Done() + } + } + + // Create new entry with context that auto-cancels after lockTTL + ctx, cancel := context.WithTimeout(context.Background(), rca.lockTTL) + + entry := &lockEntry{ + ctx: ctx, + cancel: cancel, + } + + // Store or get existing entry atomically + actual, _ := rca.locks.LoadOrStore(key, entry) + + // If another goroutine stored first, cancel our context and use theirs + if actual != entry { + cancel() + return actual.ctx.Done() + } + + return ctx.Done() } func (rca *CacheAside) Get( @@ -82,8 +118,6 @@ func (rca *CacheAside) Get( key string, fn func(ctx context.Context, key string) (val string, err error), ) (string, error) { - ctx, cancel := context.WithTimeout(ctx, ttl) - defer cancel() retry: wait := rca.register(key) val, err := rca.tryGet(ctx, ttl, key) @@ -105,10 +139,12 @@ retry: } if val == "" { + // Wait for lock release (channel auto-closes after lockTTL or on invalidation) select { case <-wait: goto retry case <-ctx.Done(): + // Parent context cancelled return "", ctx.Err() } } @@ -159,7 +195,8 @@ func (rca *CacheAside) trySetKeyFunc(ctx context.Context, ttl time.Duration, key if !setVal { toCtx, cancel := context.WithTimeout(context.Background(), rca.lockTTL) defer cancel() - rca.unlock(toCtx, key, lockVal) + // Best effort unlock - errors are non-fatal as lock will expire + _ = rca.unlock(toCtx, key, lockVal) } }() if val, err = fn(ctx, key); err == nil { @@ -199,8 +236,8 @@ func (rca *CacheAside) setWithLock(ctx context.Context, ttl time.Duration, key s return valLock.val, nil } -func (rca *CacheAside) unlock(ctx context.Context, key string, lock string) { - delKeyLua.Exec(ctx, rca.client, []string{key}, []string{lock}) +func (rca *CacheAside) unlock(ctx context.Context, key string, lock string) error { + return delKeyLua.Exec(ctx, rca.client, []string{key}, []string{lock}).Error() } func (rca *CacheAside) GetMulti( @@ -210,9 +247,6 @@ func (rca *CacheAside) GetMulti( fn func(ctx context.Context, key []string) (val map[string]string, err error), ) (map[string]string, error) { - ctx, cancel := context.WithTimeout(ctx, ttl) - defer cancel() - res := make(map[string]string, len(keys)) waitLock := make(map[string]<-chan struct{}, len(keys)) @@ -245,9 +279,11 @@ retry: } if len(waitLock) > 0 { + // Wait for lock releases (channels auto-close after lockTTL or on invalidation) err = syncx.WaitForAll(ctx, maps.Values(waitLock), len(waitLock)) if err != nil { - return nil, err + // Parent context cancelled or deadline exceeded + return nil, ctx.Err() } goto retry } @@ -413,7 +449,7 @@ func (rca *CacheAside) setMultiWithLock(ctx context.Context, ttl time.Duration, } continue } - keyByStmt[ii] = append(out, kos.keyOrder[j]) + keyByStmt[ii] = append(keyByStmt[ii], kos.keyOrder[j]) } return nil }) @@ -445,7 +481,8 @@ func (rca *CacheAside) unlockMulti(ctx context.Context, lockVals map[string]stri wg.Add(1) go func() { defer wg.Done() - delKeyLua.ExecMulti(ctx, rca.client, stmts...) + // Best effort unlock - errors are non-fatal as locks will expire + _ = delKeyLua.ExecMulti(ctx, rca.client, stmts...) }() } wg.Wait() diff --git a/cacheaside_test.go b/cacheaside_test.go index f96ffac..cf04b56 100644 --- a/cacheaside_test.go +++ b/cacheaside_test.go @@ -621,3 +621,33 @@ func TestCacheAside_DelMulti(t *testing.T) { require.True(t, rueidis.IsRedisNil(err)) } } + +func TestCacheAside_GetParentContextCancellation(t *testing.T) { + client := makeClient(t, addr) + defer client.Client().Close() + + ctx, cancel := context.WithCancel(context.Background()) + key := "key:" + uuid.New().String() + val := "val:" + uuid.New().String() + + // Set a lock on the key so Get will wait + innerClient := client.Client() + lockVal := "redcache:" + uuid.New().String() + err := innerClient.Do(context.Background(), innerClient.B().Set().Key(key).Value(lockVal).Nx().Get().Px(time.Second*30).Build()).Error() + require.True(t, rueidis.IsRedisNil(err)) + + // Cancel the parent context after a short delay + go func() { + time.Sleep(100 * time.Millisecond) + cancel() + }() + + cb := func(ctx context.Context, key string) (string, error) { + return val, nil + } + + // Should get parent context cancelled error, not a timeout + _, err = client.Get(ctx, time.Second*10, key, cb) + require.Error(t, err) + require.ErrorIs(t, err, context.Canceled) +}