Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 53 additions & 16 deletions cacheaside.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,20 @@ import (
"golang.org/x/sync/errgroup"
)

type lockEntry struct {
ctx context.Context
cancel context.CancelFunc
}

type CacheAside struct {
client rueidis.Client
locks syncx.Map[string, chan struct{}]
locks syncx.Map[string, *lockEntry]
lockTTL time.Duration
}

type CacheAsideOption struct {
// LockTTL is the maximum time a lock can be held, and also the timeout for waiting
// on locks when handling lost Redis invalidation messages. Defaults to 10 seconds.
LockTTL time.Duration
ClientBuilder func(option rueidis.ClientOption) (rueidis.Client, error)
}
Expand Down Expand Up @@ -57,9 +64,9 @@ func (rca *CacheAside) Client() rueidis.Client {
func (rca *CacheAside) onInvalidate(messages []rueidis.RedisMessage) {
for _, m := range messages {
key, _ := m.ToString()
ch, loaded := rca.locks.LoadAndDelete(key)
entry, loaded := rca.locks.LoadAndDelete(key)
if loaded {
close(ch)
entry.cancel() // Cancel context, which closes the channel
}
}
}
Expand All @@ -72,8 +79,37 @@ var (
)

func (rca *CacheAside) register(key string) <-chan struct{} {
ch, _ := rca.locks.LoadOrStore(key, make(chan struct{}))
return ch
// Try to load existing entry first
if entry, loaded := rca.locks.Load(key); loaded {
// Check if the context is still active (not cancelled/timed out)
select {
case <-entry.ctx.Done():
// Context is done - clean it up and create a new one
rca.locks.Delete(key)
default:
// Context is still active - use it
return entry.ctx.Done()
}
}

// Create new entry with context that auto-cancels after lockTTL
ctx, cancel := context.WithTimeout(context.Background(), rca.lockTTL)

entry := &lockEntry{
ctx: ctx,
cancel: cancel,
}

// Store or get existing entry atomically
actual, _ := rca.locks.LoadOrStore(key, entry)

// If another goroutine stored first, cancel our context and use theirs
if actual != entry {
cancel()
return actual.ctx.Done()
}

return ctx.Done()
}

func (rca *CacheAside) Get(
Expand All @@ -82,8 +118,6 @@ func (rca *CacheAside) Get(
key string,
fn func(ctx context.Context, key string) (val string, err error),
) (string, error) {
ctx, cancel := context.WithTimeout(ctx, ttl)
defer cancel()
retry:
wait := rca.register(key)
val, err := rca.tryGet(ctx, ttl, key)
Expand All @@ -105,10 +139,12 @@ retry:
}

if val == "" {
// Wait for lock release (channel auto-closes after lockTTL or on invalidation)
select {
case <-wait:
goto retry
case <-ctx.Done():
// Parent context cancelled
return "", ctx.Err()
}
}
Expand Down Expand Up @@ -159,7 +195,8 @@ func (rca *CacheAside) trySetKeyFunc(ctx context.Context, ttl time.Duration, key
if !setVal {
toCtx, cancel := context.WithTimeout(context.Background(), rca.lockTTL)
defer cancel()
rca.unlock(toCtx, key, lockVal)
// Best effort unlock - errors are non-fatal as lock will expire
_ = rca.unlock(toCtx, key, lockVal)
Copy link

Copilot AI Oct 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The error from unlock is being silently discarded. Consider logging the error for debugging purposes while maintaining the non-fatal behavior.

Copilot uses AI. Check for mistakes.
}
}()
if val, err = fn(ctx, key); err == nil {
Expand Down Expand Up @@ -199,8 +236,8 @@ func (rca *CacheAside) setWithLock(ctx context.Context, ttl time.Duration, key s
return valLock.val, nil
}

func (rca *CacheAside) unlock(ctx context.Context, key string, lock string) {
delKeyLua.Exec(ctx, rca.client, []string{key}, []string{lock})
func (rca *CacheAside) unlock(ctx context.Context, key string, lock string) error {
return delKeyLua.Exec(ctx, rca.client, []string{key}, []string{lock}).Error()
}

func (rca *CacheAside) GetMulti(
Expand All @@ -210,9 +247,6 @@ func (rca *CacheAside) GetMulti(
fn func(ctx context.Context, key []string) (val map[string]string, err error),
) (map[string]string, error) {

ctx, cancel := context.WithTimeout(ctx, ttl)
defer cancel()

res := make(map[string]string, len(keys))

waitLock := make(map[string]<-chan struct{}, len(keys))
Expand Down Expand Up @@ -245,9 +279,11 @@ retry:
}

if len(waitLock) > 0 {
// Wait for lock releases (channels auto-close after lockTTL or on invalidation)
err = syncx.WaitForAll(ctx, maps.Values(waitLock), len(waitLock))
if err != nil {
return nil, err
// Parent context cancelled or deadline exceeded
return nil, ctx.Err()
}
goto retry
}
Expand Down Expand Up @@ -413,7 +449,7 @@ func (rca *CacheAside) setMultiWithLock(ctx context.Context, ttl time.Duration,
}
continue
}
keyByStmt[ii] = append(out, kos.keyOrder[j])
keyByStmt[ii] = append(keyByStmt[ii], kos.keyOrder[j])
}
return nil
})
Expand Down Expand Up @@ -445,7 +481,8 @@ func (rca *CacheAside) unlockMulti(ctx context.Context, lockVals map[string]stri
wg.Add(1)
go func() {
defer wg.Done()
delKeyLua.ExecMulti(ctx, rca.client, stmts...)
// Best effort unlock - errors are non-fatal as locks will expire
_ = delKeyLua.ExecMulti(ctx, rca.client, stmts...)
Copy link

Copilot AI Oct 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar to the single unlock, the error from unlockMulti is being silently discarded. Consider logging the error for debugging purposes while maintaining the non-fatal behavior.

Copilot uses AI. Check for mistakes.
}()
}
wg.Wait()
Expand Down
30 changes: 30 additions & 0 deletions cacheaside_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -621,3 +621,33 @@ func TestCacheAside_DelMulti(t *testing.T) {
require.True(t, rueidis.IsRedisNil(err))
}
}

func TestCacheAside_GetParentContextCancellation(t *testing.T) {
client := makeClient(t, addr)
defer client.Client().Close()

ctx, cancel := context.WithCancel(context.Background())
key := "key:" + uuid.New().String()
val := "val:" + uuid.New().String()

// Set a lock on the key so Get will wait
innerClient := client.Client()
lockVal := "redcache:" + uuid.New().String()
err := innerClient.Do(context.Background(), innerClient.B().Set().Key(key).Value(lockVal).Nx().Get().Px(time.Second*30).Build()).Error()
require.True(t, rueidis.IsRedisNil(err))

// Cancel the parent context after a short delay
go func() {
time.Sleep(100 * time.Millisecond)
cancel()
}()

cb := func(ctx context.Context, key string) (string, error) {
return val, nil
}

// Should get parent context cancelled error, not a timeout
_, err = client.Get(ctx, time.Second*10, key, cb)
require.Error(t, err)
require.ErrorIs(t, err, context.Canceled)
}
Loading