diff --git a/go.mod b/go.mod index 1c282fcd9..371885d1f 100644 --- a/go.mod +++ b/go.mod @@ -3,12 +3,14 @@ module github.com/envoyproxy/ratelimit go 1.14 require ( + github.com/alicebob/miniredis v2.5.0+incompatible github.com/alicebob/miniredis/v2 v2.11.4 github.com/bradfitz/gomemcache v0.0.0-20190913173617-a41fca850d0b github.com/cespare/xxhash v1.1.0 // indirect github.com/coocood/freecache v1.1.0 github.com/envoyproxy/go-control-plane v0.9.7 github.com/fsnotify/fsnotify v1.4.7 // indirect + github.com/go-redis/redis v6.15.9+incompatible github.com/golang/mock v1.4.1 github.com/golang/protobuf v1.4.2 github.com/gorilla/mux v1.7.4-0.20191121170500-49c01487a141 @@ -23,6 +25,7 @@ require ( golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553 golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e // indirect golang.org/x/text v0.3.3-0.20191122225017-cbf43d21aaeb // indirect + google.golang.org/appengine v1.4.0 google.golang.org/grpc v1.27.0 google.golang.org/protobuf v1.25.0 // indirect gopkg.in/yaml.v2 v2.3.0 diff --git a/go.sum b/go.sum index 071a59b37..4f3f4a0d7 100644 --- a/go.sum +++ b/go.sum @@ -4,6 +4,8 @@ github.com/OneOfOne/xxhash v1.2.2 h1:KMrpdQIwFcEqXDklaen+P1axHaj9BSKzvpUUfnHldSE github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/alicebob/gopher-json v0.0.0-20180125190556-5a6b3ba71ee6 h1:45bxf7AZMwWcqkLzDAQugVEwedisr5nRJ1r+7LYnv0U= github.com/alicebob/gopher-json v0.0.0-20180125190556-5a6b3ba71ee6/go.mod h1:SGnFV6hVsYE877CKEZ6tDNTjaSXYUk6QqoIK6PrAtcc= +github.com/alicebob/miniredis v2.5.0+incompatible h1:yBHoLpsyjupjz3NL3MhKMVkR41j82Yjf3KFv7ApYzUI= +github.com/alicebob/miniredis v2.5.0+incompatible/go.mod h1:8HZjEj4yU0dwhYHky+DxYx+6BMjkBbe5ONFIF1MXffk= github.com/alicebob/miniredis/v2 v2.11.4 h1:GsuyeunTx7EllZBU3/6Ji3dhMQZDpC9rLf1luJ+6M5M= github.com/alicebob/miniredis/v2 v2.11.4/go.mod h1:VL3UDEfAH59bSa7MuHMuFToxkqyHh69s/WUbYlOAuyg= github.com/bradfitz/gomemcache v0.0.0-20190913173617-a41fca850d0b 
h1:L/QXpzIa3pOvUGt1D1lA5KjYhPBAN/3iWdP7xeFS9F0= @@ -30,6 +32,8 @@ github.com/envoyproxy/protoc-gen-validate v0.1.0 h1:EQciDnbrYxy13PgWoY8AqoxGiPrp github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= github.com/fsnotify/fsnotify v1.4.7 h1:IXs+QLmnXW2CcXuY+8Mzv/fWEsPGWxqefPtCP5CnV9I= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= +github.com/go-redis/redis v6.15.9+incompatible h1:K0pv1D7EQUjfyoMql+r/jZqCLizCGKFlFgcHWWmHQjg= +github.com/go-redis/redis v6.15.9+incompatible/go.mod h1:NAIEuMOZ/fxfXJIrKDQDz8wamY7mA7PouImQ2Jvg6kA= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b h1:VKtxabqXZkF25pY9ekfRL6a582T4P37/31XEstQ5p58= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= @@ -126,6 +130,7 @@ golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= +google.golang.org/appengine v1.4.0 h1:/wp5JvzpHIxhs/dumFmF7BXTf3Z+dd4uXta4kVyO508= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= diff --git a/src/memcached/cache_impl.go b/src/memcached/cache_impl.go index 4b21af331..b5e966744 100644 --- a/src/memcached/cache_impl.go +++ b/src/memcached/cache_impl.go @@ -1,302 +1,27 @@ -// The memcached limiter uses GetMulti() to check keys in 
parallel and then does -// increments asynchronously in the backend, since the memcache interface doesn't -// support multi-increment and it seems worthwhile to minimize the number of -// concurrent or sequential RPCs in the critical path. -// -// Another difference from redis is that memcache doesn't create a key implicitly by -// incrementing a missing entry. Instead, when increment fails an explicit "add" needs -// to be called. The process of increment becomes a bit of a dance since we try to -// limit the number of RPCs. First we call increment, then add if the increment -// failed, then increment again if the add failed (which could happen if there was -// a race to call "add"). -// -// Note that max memcache key length is 250 characters. Attempting to get or increment -// a longer key will return memcache.ErrMalformedKey - package memcached import ( - "context" - "github.com/envoyproxy/ratelimit/src/stats" "math/rand" - "strconv" - "sync" - "time" "github.com/coocood/freecache" - gostats "github.com/lyft/gostats" - - "github.com/bradfitz/gomemcache/memcache" - - logger "github.com/sirupsen/logrus" - - pb "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v3" - - "github.com/envoyproxy/ratelimit/src/config" "github.com/envoyproxy/ratelimit/src/limiter" + "github.com/envoyproxy/ratelimit/src/server" "github.com/envoyproxy/ratelimit/src/settings" - "github.com/envoyproxy/ratelimit/src/srv" + "github.com/envoyproxy/ratelimit/src/stats" "github.com/envoyproxy/ratelimit/src/utils" -) - -type rateLimitMemcacheImpl struct { - client Client - timeSource utils.TimeSource - jitterRand *rand.Rand - expirationJitterMaxSeconds int64 - localCache *freecache.Cache - waitGroup sync.WaitGroup - nearLimitRatio float32 - baseRateLimiter *limiter.BaseRateLimiter -} - -var AutoFlushForIntegrationTests bool = false - -var _ limiter.RateLimitCache = (*rateLimitMemcacheImpl)(nil) - -func (this *rateLimitMemcacheImpl) DoLimit( - ctx context.Context, - request 
*pb.RateLimitRequest, - limits []*config.RateLimit) []*pb.RateLimitResponse_DescriptorStatus { - - logger.Debugf("starting cache lookup") - - // request.HitsAddend could be 0 (default value) if not specified by the caller in the Ratelimit request. - hitsAddend := utils.Max(1, request.HitsAddend) - - // First build a list of all cache keys that we are actually going to hit. - cacheKeys := this.baseRateLimiter.GenerateCacheKeys(request, limits, hitsAddend) - - isOverLimitWithLocalCache := make([]bool, len(request.Descriptors)) - - keysToGet := make([]string, 0, len(request.Descriptors)) - - for i, cacheKey := range cacheKeys { - if cacheKey.Key == "" { - continue - } - - // Check if key is over the limit in local cache. - if this.baseRateLimiter.IsOverLimitWithLocalCache(cacheKey.Key) { - isOverLimitWithLocalCache[i] = true - logger.Debugf("cache key is over the limit: %s", cacheKey.Key) - continue - } - - logger.Debugf("looking up cache key: %s", cacheKey.Key) - keysToGet = append(keysToGet, cacheKey.Key) - } - - // Now fetch from memcache. 
- responseDescriptorStatuses := make([]*pb.RateLimitResponse_DescriptorStatus, - len(request.Descriptors)) - - var memcacheValues map[string]*memcache.Item - var err error - - if len(keysToGet) > 0 { - memcacheValues, err = this.client.GetMulti(keysToGet) - if err != nil { - logger.Errorf("Error multi-getting memcache keys (%s): %s", keysToGet, err) - } - } - - for i, cacheKey := range cacheKeys { - - rawMemcacheValue, ok := memcacheValues[cacheKey.Key] - var limitBeforeIncrease uint32 - if ok { - decoded, err := strconv.ParseInt(string(rawMemcacheValue.Value), 10, 32) - if err != nil { - logger.Errorf("Unexpected non-numeric value in memcached: %v", rawMemcacheValue) - } else { - limitBeforeIncrease = uint32(decoded) - } - - } - - limitAfterIncrease := limitBeforeIncrease + hitsAddend - - limitInfo := limiter.NewRateLimitInfo(limits[i], limitBeforeIncrease, limitAfterIncrease, 0, 0) - - responseDescriptorStatuses[i] = this.baseRateLimiter.GetResponseDescriptorStatus(cacheKey.Key, - limitInfo, isOverLimitWithLocalCache[i], hitsAddend) - } - - this.waitGroup.Add(1) - runAsync(func() { this.increaseAsync(cacheKeys, isOverLimitWithLocalCache, limits, uint64(hitsAddend)) }) - if AutoFlushForIntegrationTests { - this.Flush() - } - return responseDescriptorStatuses -} - -func (this *rateLimitMemcacheImpl) increaseAsync(cacheKeys []limiter.CacheKey, isOverLimitWithLocalCache []bool, - limits []*config.RateLimit, hitsAddend uint64) { - defer this.waitGroup.Done() - for i, cacheKey := range cacheKeys { - if cacheKey.Key == "" || isOverLimitWithLocalCache[i] { - continue - } - - _, err := this.client.Increment(cacheKey.Key, hitsAddend) - if err == memcache.ErrCacheMiss { - expirationSeconds := utils.UnitToDivider(limits[i].Limit.Unit) - if this.expirationJitterMaxSeconds > 0 { - expirationSeconds += this.jitterRand.Int63n(this.expirationJitterMaxSeconds) - } - - // Need to add instead of increment. 
- err = this.client.Add(&memcache.Item{ - Key: cacheKey.Key, - Value: []byte(strconv.FormatUint(hitsAddend, 10)), - Expiration: int32(expirationSeconds), - }) - if err == memcache.ErrNotStored { - // There was a race condition to do this add. We should be able to increment - // now instead. - _, err := this.client.Increment(cacheKey.Key, hitsAddend) - if err != nil { - logger.Errorf("Failed to increment key %s after failing to add: %s", cacheKey.Key, err) - continue - } - } else if err != nil { - logger.Errorf("Failed to add key %s: %s", cacheKey.Key, err) - continue - } - } else if err != nil { - logger.Errorf("Failed to increment key %s: %s", cacheKey.Key, err) - continue - } - } -} - -func (this *rateLimitMemcacheImpl) Flush() { - this.waitGroup.Wait() -} - -func refreshServersPeriodically(serverList memcache.ServerList, srv string, d time.Duration, finish <-chan struct{}) { - t := time.NewTicker(d) - defer t.Stop() - for { - select { - case <-t.C: - err := refreshServers(serverList, srv) - if err != nil { - logger.Warn("failed to refresh memcahce hosts") - } else { - logger.Debug("refreshed memcache hosts") - } - case <-finish: - return - } - } -} - -func refreshServers(serverList memcache.ServerList, srv_ string) error { - servers, err := srv.ServerStringsFromSrv(srv_) - if err != nil { - return err - } - err = serverList.SetServers(servers...) 
- if err != nil { - return err - } - return nil -} - -func newMemcachedFromSrv(srv_ string, d time.Duration) Client { - serverList := new(memcache.ServerList) - err := refreshServers(*serverList, srv_) - if err != nil { - errorText := "Unable to fetch servers from SRV" - logger.Errorf(errorText) - panic(MemcacheError(errorText)) - } - - if d > 0 { - logger.Infof("refreshing memcache hosts every: %v milliseconds", d.Milliseconds()) - finish := make(chan struct{}) - go refreshServersPeriodically(*serverList, srv_, d, finish) - } else { - logger.Debugf("not periodically refreshing memcached hosts") - } - - return memcache.NewFromSelector(serverList) -} - -func newMemcacheFromSettings(s settings.Settings) Client { - if s.MemcacheSrv != "" && len(s.MemcacheHostPort) > 0 { - panic(MemcacheError("Both MEMCADHE_HOST_PORT and MEMCACHE_SRV are set")) - } - if s.MemcacheSrv != "" { - logger.Debugf("Using MEMCACHE_SRV: %v", s.MemcacheSrv) - return newMemcachedFromSrv(s.MemcacheSrv, s.MemcacheSrvRefresh) - } - logger.Debugf("Usng MEMCACHE_HOST_PORT:: %v", s.MemcacheHostPort) - client := memcache.New(s.MemcacheHostPort...) 
- client.MaxIdleConns = s.MemcacheMaxIdleConns - return client -} - -var taskQueue = make(chan func()) - -func runAsync(task func()) { - select { - case taskQueue <- task: - // submitted, everything is ok - - default: - go func() { - // do the given task - task() - - tasksProcessedWithinOnePeriod := 0 - const tickDuration = 10 * time.Second - tick := time.NewTicker(tickDuration) - defer tick.Stop() - - for { - select { - case t := <-taskQueue: - t() - tasksProcessedWithinOnePeriod++ - case <-tick.C: - if tasksProcessedWithinOnePeriod > 0 { - tasksProcessedWithinOnePeriod = 0 - continue - } - return - } - } - }() - } -} - -func NewRateLimitCacheImpl(client Client, timeSource utils.TimeSource, jitterRand *rand.Rand, - expirationJitterMaxSeconds int64, localCache *freecache.Cache, statsManager stats.Manager, nearLimitRatio float32, cacheKeyPrefix string) limiter.RateLimitCache { - return &rateLimitMemcacheImpl{ - client: client, - timeSource: timeSource, - jitterRand: jitterRand, - expirationJitterMaxSeconds: expirationJitterMaxSeconds, - localCache: localCache, - nearLimitRatio: nearLimitRatio, - baseRateLimiter: limiter.NewBaseRateLimit(timeSource, jitterRand, expirationJitterMaxSeconds, localCache, nearLimitRatio, cacheKeyPrefix, statsManager), - } -} + storage_factory "github.com/envoyproxy/ratelimit/src/storage/factory" +) -func NewRateLimitCacheImplFromSettings(s settings.Settings, timeSource utils.TimeSource, jitterRand *rand.Rand, - localCache *freecache.Cache, scope gostats.Scope, statsManager stats.Manager) limiter.RateLimitCache { - return NewRateLimitCacheImpl( - CollectStats(newMemcacheFromSettings(s), scope.Scope("memcache")), +func NewRateLimiterCacheImplFromSettings(s settings.Settings, localCache *freecache.Cache, srv server.Server, timeSource utils.TimeSource, jitterRand *rand.Rand, statsManager stats.Manager) limiter.RateLimitCache { + return NewFixedRateLimitCacheImpl( + storage_factory.NewMemcached(srv.Scope().Scope("memcache"), 
s.MemcacheHostPort, s.MemcacheSrv, s.MemcacheSrvRefresh, s.MemcacheMaxIdleConns), timeSource, jitterRand, - s.ExpirationJitterMaxSeconds, localCache, - statsManager, + s.ExpirationJitterMaxSeconds, s.NearLimitRatio, s.CacheKeyPrefix, + statsManager, ) } diff --git a/src/memcached/client.go b/src/memcached/client.go deleted file mode 100644 index e80902692..000000000 --- a/src/memcached/client.go +++ /dev/null @@ -1,21 +0,0 @@ -package memcached - -import ( - "github.com/bradfitz/gomemcache/memcache" -) - -// Errors that may be raised during config parsing. -type MemcacheError string - -func (e MemcacheError) Error() string { - return string(e) -} - -var _ Client = (*memcache.Client)(nil) - -// Interface for memcached, used for mocking. -type Client interface { - GetMulti(keys []string) (map[string]*memcache.Item, error) - Increment(key string, delta uint64) (newValue uint64, err error) - Add(item *memcache.Item) error -} diff --git a/src/memcached/fixed_cache_impl.go b/src/memcached/fixed_cache_impl.go new file mode 100644 index 000000000..ffaa7fc3e --- /dev/null +++ b/src/memcached/fixed_cache_impl.go @@ -0,0 +1,192 @@ +// The memcached limiter uses GetMulti() to check keys in parallel and then does +// increments asynchronously in the backend, since the memcache interface doesn't +// support multi-increment and it seems worthwhile to minimize the number of +// concurrent or sequential RPCs in the critical path. +// +// Another difference from redis is that memcache doesn't create a key implicitly by +// incrementing a missing entry. Instead, when increment fails an explicit "add" needs +// to be called. The process of increment becomes a bit of a dance since we try to +// limit the number of RPCs. First we call increment, then add if the increment +// failed, then increment again if the add failed (which could happen if there was +// a race to call "add"). +// +// Note that max memcache key length is 250 characters. 
Attempting to get or increment +// a longer key will return memcache.ErrMalformedKey + +package memcached + +import ( + "context" + "math/rand" + "sync" + "time" + + "github.com/bradfitz/gomemcache/memcache" + "github.com/coocood/freecache" + "github.com/envoyproxy/ratelimit/src/config" + "github.com/envoyproxy/ratelimit/src/limiter" + "github.com/envoyproxy/ratelimit/src/stats" + "github.com/envoyproxy/ratelimit/src/utils" + + pb "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v3" + storage_strategy "github.com/envoyproxy/ratelimit/src/storage/strategy" + logger "github.com/sirupsen/logrus" +) + +type rateLimitMemcacheImpl struct { + client storage_strategy.StorageStrategy + jitterRand *rand.Rand + expirationJitterMaxSeconds int64 + baseRateLimiter *limiter.BaseRateLimiter + waitGroup sync.WaitGroup +} + +var AutoFlushForIntegrationTests bool = false + +var _ limiter.RateLimitCache = (*rateLimitMemcacheImpl)(nil) + +func (this *rateLimitMemcacheImpl) DoLimit( + ctx context.Context, + request *pb.RateLimitRequest, + limits []*config.RateLimit) []*pb.RateLimitResponse_DescriptorStatus { + + logger.Debugf("starting cache lookup") + + // request.HitsAddend could be 0 (default value) if not specified by the caller in the RateLimit request. + hitsAddend := utils.Max(1, request.HitsAddend) + + // First build a list of all cache keys that we are actually going to hit. + cacheKeys := this.baseRateLimiter.GenerateCacheKeys(request, limits, hitsAddend) + + isOverLimitWithLocalCache := make([]bool, len(request.Descriptors)) + results := make([]uint64, len(request.Descriptors)) + + // Now, actually setup the pipeline, skipping empty cache keys. + for i, cacheKey := range cacheKeys { + if cacheKey.Key == "" { + continue + } + + // Check if key is over the limit in local cache. 
+ if this.baseRateLimiter.IsOverLimitWithLocalCache(cacheKey.Key) { + isOverLimitWithLocalCache[i] = true + logger.Debugf("cache key is over the limit: %s", cacheKey.Key) + continue + } + + logger.Debugf("looking up cache key: %s", cacheKey.Key) + + // Use the perSecondConn if it is not nil and the cacheKey represents a per second Limit. + value, err := this.client.GetValue(cacheKey.Key) + if err != nil { + logger.Error(err) + } + results[i] = value + + } + + // Now fetch the pipeline. + responseDescriptorStatuses := make([]*pb.RateLimitResponse_DescriptorStatus, + len(request.Descriptors)) + for i, cacheKey := range cacheKeys { + + limitBeforeIncrease := uint32(results[i]) + limitAfterIncrease := limitBeforeIncrease + hitsAddend + + limitInfo := limiter.NewRateLimitInfo(limits[i], limitBeforeIncrease, limitAfterIncrease, 0, 0) + + responseDescriptorStatuses[i] = this.baseRateLimiter.GetResponseDescriptorStatus(cacheKey.Key, + limitInfo, isOverLimitWithLocalCache[i], hitsAddend) + } + + this.waitGroup.Add(1) + runAsync(func() { this.increaseAsync(cacheKeys, isOverLimitWithLocalCache, limits, uint64(hitsAddend)) }) + + return responseDescriptorStatuses +} + +func (this *rateLimitMemcacheImpl) increaseAsync(cacheKeys []limiter.CacheKey, isOverLimitWithLocalCache []bool, + limits []*config.RateLimit, hitsAddend uint64) { + defer this.waitGroup.Done() + for i, cacheKey := range cacheKeys { + if cacheKey.Key == "" || isOverLimitWithLocalCache[i] { + continue + } + + expirationSeconds := utils.UnitToDivider(limits[i].Limit.Unit) + if this.expirationJitterMaxSeconds > 0 { + expirationSeconds += this.jitterRand.Int63n(this.expirationJitterMaxSeconds) + } + + err := this.client.IncrementValue(cacheKey.Key, hitsAddend) + // if key is not found + if err == memcache.ErrCacheMiss { + + // create key + err = this.client.SetValue(cacheKey.Key, hitsAddend, uint64(expirationSeconds)) + if err == memcache.ErrNotStored { + + // increment the key + err := 
this.client.IncrementValue(cacheKey.Key, hitsAddend) + if err != nil { + logger.Errorf("Failed to increment key %s after failing to add: %s", cacheKey.Key, err) + continue + } + } else if err != nil { + logger.Errorf("Failed to add key %s: %s", cacheKey.Key, err) + continue + } + } else if err != nil { + logger.Errorf("Failed to increment key %s: %s", cacheKey.Key, err) + continue + } + } +} + +func (this *rateLimitMemcacheImpl) Flush() { + this.waitGroup.Wait() +} + +var taskQueue = make(chan func()) + +func runAsync(task func()) { + select { + case taskQueue <- task: + // submitted, everything is ok + + default: + go func() { + // do the given task + task() + + tasksProcessedWithinOnePeriod := 0 + const tickDuration = 10 * time.Second + tick := time.NewTicker(tickDuration) + defer tick.Stop() + + for { + select { + case t := <-taskQueue: + t() + tasksProcessedWithinOnePeriod++ + case <-tick.C: + if tasksProcessedWithinOnePeriod > 0 { + tasksProcessedWithinOnePeriod = 0 + continue + } + return + } + } + }() + } +} + +func NewFixedRateLimitCacheImpl(client storage_strategy.StorageStrategy, timeSource utils.TimeSource, jitterRand *rand.Rand, + localCache *freecache.Cache, expirationJitterMaxSeconds int64, nearLimitRatio float32, cacheKeyPrefix string, statsManager stats.Manager) limiter.RateLimitCache { + return &rateLimitMemcacheImpl{ + client: client, + jitterRand: jitterRand, + expirationJitterMaxSeconds: expirationJitterMaxSeconds, + baseRateLimiter: limiter.NewBaseRateLimit(timeSource, jitterRand, expirationJitterMaxSeconds, localCache, nearLimitRatio, cacheKeyPrefix, statsManager), + } +} diff --git a/src/memcached/stats_collecting_client.go b/src/memcached/stats_collecting_client.go deleted file mode 100644 index 12b67bad5..000000000 --- a/src/memcached/stats_collecting_client.go +++ /dev/null @@ -1,80 +0,0 @@ -package memcached - -import ( - "github.com/bradfitz/gomemcache/memcache" - stats "github.com/lyft/gostats" -) - -type statsCollectingClient struct { 
- c Client - - multiGetSuccess stats.Counter - multiGetError stats.Counter - incrementSuccess stats.Counter - incrementMiss stats.Counter - incrementError stats.Counter - addSuccess stats.Counter - addError stats.Counter - addNotStored stats.Counter - keysRequested stats.Counter - keysFound stats.Counter -} - -func CollectStats(c Client, scope stats.Scope) Client { - return statsCollectingClient{ - c: c, - multiGetSuccess: scope.NewCounterWithTags("multiget", map[string]string{"code": "success"}), - multiGetError: scope.NewCounterWithTags("multiget", map[string]string{"code": "error"}), - incrementSuccess: scope.NewCounterWithTags("increment", map[string]string{"code": "success"}), - incrementMiss: scope.NewCounterWithTags("increment", map[string]string{"code": "miss"}), - incrementError: scope.NewCounterWithTags("increment", map[string]string{"code": "error"}), - addSuccess: scope.NewCounterWithTags("add", map[string]string{"code": "success"}), - addError: scope.NewCounterWithTags("add", map[string]string{"code": "error"}), - addNotStored: scope.NewCounterWithTags("add", map[string]string{"code": "not_stored"}), - keysRequested: scope.NewCounter("keys_requested"), - keysFound: scope.NewCounter("keys_found"), - } -} - -func (scc statsCollectingClient) GetMulti(keys []string) (map[string]*memcache.Item, error) { - scc.keysRequested.Add(uint64(len(keys))) - - results, err := scc.c.GetMulti(keys) - - if err != nil { - scc.multiGetError.Inc() - } else { - scc.keysFound.Add(uint64(len(results))) - scc.multiGetSuccess.Inc() - } - - return results, err -} - -func (scc statsCollectingClient) Increment(key string, delta uint64) (newValue uint64, err error) { - newValue, err = scc.c.Increment(key, delta) - switch err { - case memcache.ErrCacheMiss: - scc.incrementMiss.Inc() - case nil: - scc.incrementSuccess.Inc() - default: - scc.incrementError.Inc() - } - return -} - -func (scc statsCollectingClient) Add(item *memcache.Item) error { - err := scc.c.Add(item) - - switch err 
{ - case memcache.ErrNotStored: - scc.addNotStored.Inc() - case nil: - scc.addSuccess.Inc() - default: - scc.addError.Inc() - } - - return err -} diff --git a/src/redis/cache_impl.go b/src/redis/cache_impl.go index 7bf9eafcf..49c0512e2 100644 --- a/src/redis/cache_impl.go +++ b/src/redis/cache_impl.go @@ -1,24 +1,26 @@ package redis import ( - "github.com/envoyproxy/ratelimit/src/stats" "math/rand" "github.com/coocood/freecache" "github.com/envoyproxy/ratelimit/src/limiter" "github.com/envoyproxy/ratelimit/src/server" "github.com/envoyproxy/ratelimit/src/settings" + "github.com/envoyproxy/ratelimit/src/stats" "github.com/envoyproxy/ratelimit/src/utils" + + storage_factory "github.com/envoyproxy/ratelimit/src/storage/factory" + storage_strategy "github.com/envoyproxy/ratelimit/src/storage/strategy" ) -func NewRateLimiterCacheImplFromSettings(s settings.Settings, localCache *freecache.Cache, srv server.Server, timeSource utils.TimeSource, jitterRand *rand.Rand, expirationJitterMaxSeconds int64, statsManager stats.Manager) limiter.RateLimitCache { - var perSecondPool Client +func NewRateLimiterCacheImplFromSettings(s settings.Settings, localCache *freecache.Cache, srv server.Server, timeSource utils.TimeSource, jitterRand *rand.Rand, statsManager stats.Manager) limiter.RateLimitCache { + var perSecondPool storage_strategy.StorageStrategy if s.RedisPerSecond { - perSecondPool = NewClientImpl(srv.Scope().Scope("redis_per_second_pool"), s.RedisPerSecondTls, s.RedisPerSecondAuth, + perSecondPool = storage_factory.NewRedis(srv.Scope().Scope("redis_per_second_pool"), s.RedisPerSecondTls, s.RedisPerSecondAuth, s.RedisPerSecondType, s.RedisPerSecondUrl, s.RedisPerSecondPoolSize, s.RedisPerSecondPipelineWindow, s.RedisPerSecondPipelineLimit) } - var otherPool Client - otherPool = NewClientImpl(srv.Scope().Scope("redis_pool"), s.RedisTls, s.RedisAuth, s.RedisType, s.RedisUrl, s.RedisPoolSize, + otherPool := storage_factory.NewRedis(srv.Scope().Scope("redis_pool"), s.RedisTls, 
s.RedisAuth, s.RedisType, s.RedisUrl, s.RedisPoolSize, s.RedisPipelineWindow, s.RedisPipelineLimit) return NewFixedRateLimitCacheImpl( @@ -26,8 +28,8 @@ func NewRateLimiterCacheImplFromSettings(s settings.Settings, localCache *freeca perSecondPool, timeSource, jitterRand, - expirationJitterMaxSeconds, localCache, + s.ExpirationJitterMaxSeconds, s.NearLimitRatio, s.CacheKeyPrefix, statsManager, diff --git a/src/redis/driver.go b/src/redis/driver.go deleted file mode 100644 index 7ffc0c7b7..000000000 --- a/src/redis/driver.go +++ /dev/null @@ -1,49 +0,0 @@ -package redis - -import "github.com/mediocregopher/radix/v3" - -// Errors that may be raised during config parsing. -type RedisError string - -func (e RedisError) Error() string { - return string(e) -} - -// Interface for a redis client. -type Client interface { - // DoCmd is used to perform a redis command and retrieve a result. - // - // @param rcv supplies receiver for the result. - // @param cmd supplies the command to append. - // @param key supplies the key to append. - // @param args supplies the additional arguments. - DoCmd(rcv interface{}, cmd, key string, args ...interface{}) error - - // PipeAppend append a command onto the pipeline queue. - // - // @param pipeline supplies the queue for pending commands. - // @param rcv supplies receiver for the result. - // @param cmd supplies the command to append. - // @param key supplies the key to append. - // @param args supplies the additional arguments. - PipeAppend(pipeline Pipeline, rcv interface{}, cmd, key string, args ...interface{}) Pipeline - - // PipeDo writes multiple commands to a Conn in - // a single write, then reads their responses in a single read. This reduces - // network delay into a single round-trip. - // - // @param pipeline supplies the queue for pending commands. 
- PipeDo(pipeline Pipeline) error - - // Once Close() is called all future method calls on the Client will return - // an error - Close() error - - // NumActiveConns return number of active connections, used in testing. - NumActiveConns() int - - // ImplicitPipeliningEnabled return true if implicit pipelining is enabled. - ImplicitPipeliningEnabled() bool -} - -type Pipeline []radix.CmdAction diff --git a/src/redis/driver_impl.go b/src/redis/driver_impl.go deleted file mode 100644 index f6449ea52..000000000 --- a/src/redis/driver_impl.go +++ /dev/null @@ -1,160 +0,0 @@ -package redis - -import ( - "crypto/tls" - "fmt" - "strings" - "time" - - "github.com/mediocregopher/radix/v3/trace" - - stats "github.com/lyft/gostats" - "github.com/mediocregopher/radix/v3" - logger "github.com/sirupsen/logrus" -) - -type poolStats struct { - connectionActive stats.Gauge - connectionTotal stats.Counter - connectionClose stats.Counter -} - -func newPoolStats(scope stats.Scope) poolStats { - ret := poolStats{} - ret.connectionActive = scope.NewGauge("cx_active") - ret.connectionTotal = scope.NewCounter("cx_total") - ret.connectionClose = scope.NewCounter("cx_local_close") - return ret -} - -func poolTrace(ps *poolStats) trace.PoolTrace { - return trace.PoolTrace{ - ConnCreated: func(_ trace.PoolConnCreated) { - ps.connectionTotal.Add(1) - ps.connectionActive.Add(1) - }, - ConnClosed: func(_ trace.PoolConnClosed) { - ps.connectionActive.Sub(1) - ps.connectionClose.Add(1) - }, - } -} - -type clientImpl struct { - client radix.Client - stats poolStats - implicitPipelining bool -} - -func checkError(err error) { - if err != nil { - panic(RedisError(err.Error())) - } -} - -func NewClientImpl(scope stats.Scope, useTls bool, auth string, redisType string, url string, poolSize int, - pipelineWindow time.Duration, pipelineLimit int) Client { - logger.Warnf("connecting to redis on %s with pool size %d", url, poolSize) - - df := func(network, addr string) (radix.Conn, error) { - var dialOpts 
[]radix.DialOpt - - if useTls { - dialOpts = append(dialOpts, radix.DialUseTLS(&tls.Config{})) - } - - if auth != "" { - logger.Warnf("enabling authentication to redis on %s", url) - - dialOpts = append(dialOpts, radix.DialAuthPass(auth)) - } - - return radix.Dial(network, addr, dialOpts...) - } - - stats := newPoolStats(scope) - - opts := []radix.PoolOpt{radix.PoolConnFunc(df), radix.PoolWithTrace(poolTrace(&stats))} - - implicitPipelining := true - if pipelineWindow == 0 && pipelineLimit == 0 { - implicitPipelining = false - } else { - opts = append(opts, radix.PoolPipelineWindow(pipelineWindow, pipelineLimit)) - } - logger.Debugf("Implicit pipelining enabled: %v", implicitPipelining) - - poolFunc := func(network, addr string) (radix.Client, error) { - return radix.NewPool(network, addr, poolSize, opts...) - } - - var client radix.Client - var err error - switch strings.ToLower(redisType) { - case "single": - client, err = poolFunc("tcp", url) - case "cluster": - urls := strings.Split(url, ",") - if implicitPipelining == false { - panic(RedisError("Implicit Pipelining must be enabled to work with Redis Cluster Mode. 
Set values for REDIS_PIPELINE_WINDOW or REDIS_PIPELINE_LIMIT to enable implicit pipelining")) - } - logger.Warnf("Creating cluster with urls %v", urls) - client, err = radix.NewCluster(urls, radix.ClusterPoolFunc(poolFunc)) - case "sentinel": - urls := strings.Split(url, ",") - if len(urls) < 2 { - panic(RedisError("Expected master name and a list of urls for the sentinels, in the format: ,,...,")) - } - client, err = radix.NewSentinel(urls[0], urls[1:], radix.SentinelPoolFunc(poolFunc)) - default: - panic(RedisError("Unrecognized redis type " + redisType)) - } - - checkError(err) - - // Check if connection is good - var pingResponse string - checkError(client.Do(radix.Cmd(&pingResponse, "PING"))) - if pingResponse != "PONG" { - checkError(fmt.Errorf("connecting redis error: %s", pingResponse)) - } - - return &clientImpl{ - client: client, - stats: stats, - implicitPipelining: implicitPipelining, - } -} - -func (c *clientImpl) DoCmd(rcv interface{}, cmd, key string, args ...interface{}) error { - return c.client.Do(radix.FlatCmd(rcv, cmd, key, args...)) -} - -func (c *clientImpl) Close() error { - return c.client.Close() -} - -func (c *clientImpl) NumActiveConns() int { - return int(c.stats.connectionActive.Value()) -} - -func (c *clientImpl) PipeAppend(pipeline Pipeline, rcv interface{}, cmd, key string, args ...interface{}) Pipeline { - return append(pipeline, radix.FlatCmd(rcv, cmd, key, args...)) -} - -func (c *clientImpl) PipeDo(pipeline Pipeline) error { - if c.implicitPipelining { - for _, action := range pipeline { - if err := c.client.Do(action); err != nil { - return err - } - } - return nil - } - - return c.client.Do(radix.Pipeline(pipeline...)) -} - -func (c *clientImpl) ImplicitPipeliningEnabled() bool { - return c.implicitPipelining -} diff --git a/src/redis/fixed_cache_impl.go b/src/redis/fixed_cache_impl.go index d364f4ea3..f9a2c55e3 100644 --- a/src/redis/fixed_cache_impl.go +++ b/src/redis/fixed_cache_impl.go @@ -1,31 +1,39 @@ package redis import 
( - "github.com/envoyproxy/ratelimit/src/stats" "math/rand" + "sync" + + "github.com/envoyproxy/ratelimit/src/stats" "github.com/coocood/freecache" - pb "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v3" "github.com/envoyproxy/ratelimit/src/config" "github.com/envoyproxy/ratelimit/src/limiter" "github.com/envoyproxy/ratelimit/src/utils" - logger "github.com/sirupsen/logrus" "golang.org/x/net/context" + + pb "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v3" + storage_strategy "github.com/envoyproxy/ratelimit/src/storage/strategy" + logger "github.com/sirupsen/logrus" ) +type RedisError string + +func (e RedisError) Error() string { + return string(e) +} + type fixedRateLimitCacheImpl struct { - client Client + client storage_strategy.StorageStrategy // Optional Client for a dedicated cache of per second limits. // If this client is nil, then the Cache will use the client for all // limits regardless of unit. If this client is not nil, then it // is used for limits that have a SECOND unit. 
- perSecondClient Client - baseRateLimiter *limiter.BaseRateLimiter -} - -func pipelineAppend(client Client, pipeline *Pipeline, key string, hitsAddend uint32, result *uint32, expirationSeconds int64) { - *pipeline = client.PipeAppend(*pipeline, result, "INCRBY", key, hitsAddend) - *pipeline = client.PipeAppend(*pipeline, nil, "EXPIRE", key, expirationSeconds) + perSecondClient storage_strategy.StorageStrategy + jitterRand *rand.Rand + expirationJitterMaxSeconds int64 + baseRateLimiter *limiter.BaseRateLimiter + waitGroup sync.WaitGroup } func (this *fixedRateLimitCacheImpl) DoLimit( @@ -42,8 +50,7 @@ func (this *fixedRateLimitCacheImpl) DoLimit( cacheKeys := this.baseRateLimiter.GenerateCacheKeys(request, limits, hitsAddend) isOverLimitWithLocalCache := make([]bool, len(request.Descriptors)) - results := make([]uint32, len(request.Descriptors)) - var pipeline, perSecondPipeline Pipeline + results := make([]uint64, len(request.Descriptors)) // Now, actually setup the pipeline, skipping empty cache keys. for i, cacheKey := range cacheKeys { @@ -60,30 +67,22 @@ func (this *fixedRateLimitCacheImpl) DoLimit( logger.Debugf("looking up cache key: %s", cacheKey.Key) - expirationSeconds := utils.UnitToDivider(limits[i].Limit.Unit) - if this.baseRateLimiter.ExpirationJitterMaxSeconds > 0 { - expirationSeconds += this.baseRateLimiter.JitterRand.Int63n(this.baseRateLimiter.ExpirationJitterMaxSeconds) - } - // Use the perSecondConn if it is not nil and the cacheKey represents a per second Limit. 
if this.perSecondClient != nil && cacheKey.PerSecond { - if perSecondPipeline == nil { - perSecondPipeline = Pipeline{} + value, err := this.perSecondClient.GetValue(cacheKey.Key) + if err != nil { + logger.Error(err) } - pipelineAppend(this.perSecondClient, &perSecondPipeline, cacheKey.Key, hitsAddend, &results[i], expirationSeconds) + + results[i] = value } else { - if pipeline == nil { - pipeline = Pipeline{} + value, err := this.client.GetValue(cacheKey.Key) + if err != nil { + logger.Error(err) } - pipelineAppend(this.client, &pipeline, cacheKey.Key, hitsAddend, &results[i], expirationSeconds) - } - } - if pipeline != nil { - checkError(this.client.PipeDo(pipeline)) - } - if perSecondPipeline != nil { - checkError(this.perSecondClient.PipeDo(perSecondPipeline)) + results[i] = value + } } // Now fetch the pipeline. @@ -91,14 +90,44 @@ func (this *fixedRateLimitCacheImpl) DoLimit( len(request.Descriptors)) for i, cacheKey := range cacheKeys { - limitAfterIncrease := results[i] - limitBeforeIncrease := limitAfterIncrease - hitsAddend + limitBeforeIncrease := uint32(results[i]) + limitAfterIncrease := limitBeforeIncrease + hitsAddend limitInfo := limiter.NewRateLimitInfo(limits[i], limitBeforeIncrease, limitAfterIncrease, 0, 0) responseDescriptorStatuses[i] = this.baseRateLimiter.GetResponseDescriptorStatus(cacheKey.Key, limitInfo, isOverLimitWithLocalCache[i], hitsAddend) + if cacheKey.Key == "" || isOverLimitWithLocalCache[i] { + continue + } + + expirationSeconds := utils.UnitToDivider(limits[i].Limit.Unit) + if this.expirationJitterMaxSeconds > 0 { + expirationSeconds += this.jitterRand.Int63n(this.expirationJitterMaxSeconds) + } + + if this.perSecondClient != nil && cacheKey.PerSecond { + err := this.perSecondClient.IncrementValue(cacheKey.Key, uint64(hitsAddend)) + if err != nil { + logger.Error(err) + } + + err = this.perSecondClient.SetExpire(cacheKey.Key, uint64(expirationSeconds)) + if err != nil { + logger.Error(err) + } + } else { + err := 
this.client.IncrementValue(cacheKey.Key, uint64(hitsAddend)) + if err != nil { + logger.Error(err) + } + + err = this.client.SetExpire(cacheKey.Key, uint64(expirationSeconds)) + if err != nil { + logger.Error(err) + } + } } return responseDescriptorStatuses @@ -107,11 +136,12 @@ func (this *fixedRateLimitCacheImpl) DoLimit( // Flush() is a no-op with redis since quota reads and updates happen synchronously. func (this *fixedRateLimitCacheImpl) Flush() {} -func NewFixedRateLimitCacheImpl(client Client, perSecondClient Client, timeSource utils.TimeSource, - jitterRand *rand.Rand, expirationJitterMaxSeconds int64, localCache *freecache.Cache, nearLimitRatio float32, cacheKeyPrefix string, statsManager stats.Manager) limiter.RateLimitCache { +func NewFixedRateLimitCacheImpl(client storage_strategy.StorageStrategy, perSecondClient storage_strategy.StorageStrategy, timeSource utils.TimeSource, jitterRand *rand.Rand, localCache *freecache.Cache, expirationJitterMaxSeconds int64, nearLimitRatio float32, cacheKeyPrefix string, statsManager stats.Manager) limiter.RateLimitCache { return &fixedRateLimitCacheImpl{ - client: client, - perSecondClient: perSecondClient, - baseRateLimiter: limiter.NewBaseRateLimit(timeSource, jitterRand, expirationJitterMaxSeconds, localCache, nearLimitRatio, cacheKeyPrefix, statsManager), + client: client, + perSecondClient: perSecondClient, + jitterRand: jitterRand, + expirationJitterMaxSeconds: expirationJitterMaxSeconds, + baseRateLimiter: limiter.NewBaseRateLimit(timeSource, jitterRand, expirationJitterMaxSeconds, localCache, nearLimitRatio, cacheKeyPrefix, statsManager), } } diff --git a/src/service_cmd/runner/runner.go b/src/service_cmd/runner/runner.go index c8fb45e3c..7f5b6c388 100644 --- a/src/service_cmd/runner/runner.go +++ b/src/service_cmd/runner/runner.go @@ -1,8 +1,6 @@ package runner import ( - "github.com/envoyproxy/ratelimit/src/metrics" - "github.com/envoyproxy/ratelimit/src/stats" "io" "math/rand" "net/http" @@ -10,21 +8,21 @@ 
import ( "sync" "time" - gostats "github.com/lyft/gostats" - "github.com/coocood/freecache" - - pb_legacy "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v2" - pb "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v3" - "github.com/envoyproxy/ratelimit/src/config" "github.com/envoyproxy/ratelimit/src/limiter" "github.com/envoyproxy/ratelimit/src/memcached" + "github.com/envoyproxy/ratelimit/src/metrics" "github.com/envoyproxy/ratelimit/src/redis" "github.com/envoyproxy/ratelimit/src/server" - ratelimit "github.com/envoyproxy/ratelimit/src/service" "github.com/envoyproxy/ratelimit/src/settings" + "github.com/envoyproxy/ratelimit/src/stats" "github.com/envoyproxy/ratelimit/src/utils" + + pb_legacy "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v2" + pb "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v3" + ratelimit "github.com/envoyproxy/ratelimit/src/service" + gostats "github.com/lyft/gostats" logger "github.com/sirupsen/logrus" ) @@ -55,15 +53,14 @@ func createLimiter(srv server.Server, s settings.Settings, localCache *freecache srv, utils.NewTimeSourceImpl(), rand.New(utils.NewLockedSource(time.Now().Unix())), - s.ExpirationJitterMaxSeconds, statsManager) case "memcache": - return memcached.NewRateLimitCacheImplFromSettings( + return memcached.NewRateLimiterCacheImplFromSettings( s, + localCache, + srv, utils.NewTimeSourceImpl(), rand.New(utils.NewLockedSource(time.Now().Unix())), - localCache, - srv.Scope(), statsManager) default: logger.Fatalf("Invalid setting for BackendType: %s", s.BackendType) diff --git a/src/srv/srv.go b/src/srv/srv.go deleted file mode 100644 index 041ceb950..000000000 --- a/src/srv/srv.go +++ /dev/null @@ -1,49 +0,0 @@ -package srv - -import ( - "errors" - "fmt" - "net" - "regexp" - - logger "github.com/sirupsen/logrus" -) - -var srvRegex = regexp.MustCompile(`^_(.+?)\._(.+?)\.(.+)$`) - -func ParseSrv(srv string) (string, string, string, error) { - matches := 
srvRegex.FindStringSubmatch(srv) - if matches == nil { - errorText := fmt.Sprintf("could not parse %s to SRV parts", srv) - logger.Errorf(errorText) - return "", "", "", errors.New(errorText) - } - return matches[1], matches[2], matches[3], nil -} - -func ServerStringsFromSrv(srv string) ([]string, error) { - service, proto, name, err := ParseSrv(srv) - - if err != nil { - logger.Errorf("failed to parse SRV: %s", err) - return nil, err - } - - _, srvs, err := net.LookupSRV(service, proto, name) - - if err != nil { - logger.Errorf("failed to lookup SRV: %s", err) - return nil, err - } - - logger.Debugf("found %v servers(s) from SRV", len(srvs)) - - serversFromSrv := make([]string, len(srvs)) - for i, srv := range srvs { - server := fmt.Sprintf("%s:%v", srv.Target, srv.Port) - logger.Debugf("server from srv[%v]: %s", i, server) - serversFromSrv[i] = fmt.Sprintf("%s:%v", srv.Target, srv.Port) - } - - return serversFromSrv, nil -} diff --git a/src/storage/factory/memcached_factory.go b/src/storage/factory/memcached_factory.go new file mode 100644 index 000000000..41ac4fb35 --- /dev/null +++ b/src/storage/factory/memcached_factory.go @@ -0,0 +1,139 @@ +package factory + +import ( + "errors" + "fmt" + "net" + "regexp" + "time" + + "github.com/bradfitz/gomemcache/memcache" + "github.com/envoyproxy/ratelimit/src/storage/service" + "github.com/envoyproxy/ratelimit/src/storage/strategy" + "github.com/envoyproxy/ratelimit/src/storage/utils" + + stats "github.com/lyft/gostats" + logger "github.com/sirupsen/logrus" +) + +func NewMemcached(scope stats.Scope, hosts []string, srv string, srvRefresh time.Duration, maxIdleConnection int) strategy.StorageStrategy { + var client service.MemcachedClientInterface + + if srv != "" && len(hosts) > 0 { + panic(utils.MemcacheError("Both MEMCADHE_HOST_PORT and MEMCACHE_SRV are set")) + } + + if srv != "" { + client = newMemcachedClientFromSrv(scope, srv, srvRefresh, maxIdleConnection) + } else { + + } + client = newMemcachedClient(scope, 
hosts, maxIdleConnection) + return strategy.MemcachedStrategy{ + Client: client, + } +} + +func newMemcachedClient(scope stats.Scope, hosts []string, maxIdleConnection int) service.MemcachedClientInterface { + client := memcache.New(hosts...) + client.MaxIdleConns = maxIdleConnection + stats := service.NewMemcachedStats(scope) + return &service.MemcachedClient{ + Client: client, + Stats: stats, + } +} + +func newMemcachedClientFromSrv(scope stats.Scope, srv string, srvRefresh time.Duration, maxIdleConnection int) service.MemcachedClientInterface { + serverList := new(memcache.ServerList) + err := refreshServers(*serverList, srv) + if err != nil { + errorText := "Unable to fetch servers from SRV" + logger.Errorf(errorText) + panic(utils.MemcacheError(errorText)) + } + + if srvRefresh > 0 { + logger.Infof("refreshing memcache hosts every: %v milliseconds", srvRefresh.Milliseconds()) + finish := make(chan struct{}) + go refreshServersPeriodically(*serverList, srv, srvRefresh, finish) + } else { + logger.Debugf("not periodically refreshing memcached hosts") + } + + stats := service.NewMemcachedStats(scope) + + return &service.MemcachedClient{ + Client: memcache.NewFromSelector(serverList), + Stats: stats, + } +} + +func refreshServers(serverList memcache.ServerList, srv string) error { + servers, err := serverStringsFromSrv(srv) + if err != nil { + return err + } + err = serverList.SetServers(servers...) 
+ if err != nil { + return err + } + return nil +} + +func refreshServersPeriodically(serverList memcache.ServerList, srv string, d time.Duration, finish <-chan struct{}) { + t := time.NewTicker(d) + defer t.Stop() + for { + select { + case <-t.C: + err := refreshServers(serverList, srv) + if err != nil { + logger.Warn("failed to refresh memcahce hosts") + } else { + logger.Debug("refreshed memcache hosts") + } + case <-finish: + return + } + } +} + +var srvRegex = regexp.MustCompile(`^_(.+?)\._(.+?)\.(.+)$`) + +func serverStringsFromSrv(srv string) ([]string, error) { + service, proto, name, err := parseSrv(srv) + + if err != nil { + logger.Errorf("failed to parse SRV: %s", err) + return nil, err + } + + _, srvs, err := net.LookupSRV(service, proto, name) + + if err != nil { + logger.Errorf("failed to lookup SRV: %s", err) + return nil, err + } + + logger.Debugf("found %v servers(s) from SRV", len(srvs)) + + serversFromSrv := make([]string, len(srvs)) + for i, srv := range srvs { + server := fmt.Sprintf("%s:%v", srv.Target, srv.Port) + logger.Debugf("server from srv[%v]: %s", i, server) + serversFromSrv[i] = fmt.Sprintf("%s:%v", srv.Target, srv.Port) + } + + return serversFromSrv, nil +} + +func parseSrv(srv string) (string, string, string, error) { + matches := srvRegex.FindStringSubmatch(srv) + if matches == nil { + errorText := fmt.Sprintf("could not parse %s to SRV parts", srv) + logger.Errorf(errorText) + return "", "", "", errors.New(errorText) + } + return matches[1], matches[2], matches[3], nil +} diff --git a/src/storage/factory/redis_factory.go b/src/storage/factory/redis_factory.go new file mode 100644 index 000000000..cf4785d38 --- /dev/null +++ b/src/storage/factory/redis_factory.go @@ -0,0 +1,96 @@ +package factory + +import ( + "crypto/tls" + "fmt" + "strings" + "time" + + stats "github.com/lyft/gostats" + logger "github.com/sirupsen/logrus" + + "github.com/envoyproxy/ratelimit/src/storage/service" + 
"github.com/envoyproxy/ratelimit/src/storage/strategy" + "github.com/envoyproxy/ratelimit/src/storage/utils" + "github.com/mediocregopher/radix/v3" +) + +func NewRedis(scope stats.Scope, useTls bool, auth string, redisType string, url string, poolSize int, + pipelineWindow time.Duration, pipelineLimit int) strategy.StorageStrategy { + client := newRedisClient(scope, useTls, auth, redisType, url, poolSize, pipelineWindow, pipelineLimit) + return strategy.RedisStrategy{ + Client: client, + } +} + +func newRedisClient(scope stats.Scope, useTls bool, auth string, redisType string, url string, poolSize int, pipelineWindow time.Duration, pipelineLimit int) service.RedisClientInterface { + logger.Warnf("connecting to redis on %s with pool size %d", url, poolSize) + + df := func(network, addr string) (radix.Conn, error) { + var dialOpts []radix.DialOpt + + if useTls { + dialOpts = append(dialOpts, radix.DialUseTLS(&tls.Config{})) + } + + if auth != "" { + logger.Warnf("enabling authentication to redis on %s", url) + + dialOpts = append(dialOpts, radix.DialAuthPass(auth)) + } + + return radix.Dial(network, addr, dialOpts...) + } + + stats := service.NewRedisStats(scope) + opts := []radix.PoolOpt{radix.PoolConnFunc(df), radix.PoolWithTrace(service.PoolTrace(&stats))} + + implicitPipelining := true + if pipelineWindow == 0 && pipelineLimit == 0 { + implicitPipelining = false + } else { + opts = append(opts, radix.PoolPipelineWindow(pipelineWindow, pipelineLimit)) + } + logger.Debugf("Implicit pipelining enabled: %v", implicitPipelining) + + poolFunc := func(network, addr string) (radix.Client, error) { + return radix.NewPool(network, addr, poolSize, opts...) + } + + var client radix.Client + var err error + switch strings.ToLower(redisType) { + case "single": + client, err = poolFunc("tcp", url) + case "cluster": + urls := strings.Split(url, ",") + if implicitPipelining == false { + panic(utils.RedisError("Implicit Pipelining must be enabled to work with Redis Cluster Mode. 
Set values for REDIS_PIPELINE_WINDOW or REDIS_PIPELINE_LIMIT to enable implicit pipelining")) + } + logger.Warnf("Creating cluster with urls %v", urls) + client, err = radix.NewCluster(urls, radix.ClusterPoolFunc(poolFunc)) + case "sentinel": + urls := strings.Split(url, ",") + if len(urls) < 2 { + panic(utils.RedisError("Expected master name and a list of urls for the sentinels, in the format: ,,...,")) + } + client, err = radix.NewSentinel(urls[0], urls[1:], radix.SentinelPoolFunc(poolFunc)) + default: + panic(utils.RedisError("Unrecognized redis type " + redisType)) + } + + utils.CheckError(err) + + // Check if connection is good + var pingResponse string + utils.CheckError(client.Do(radix.Cmd(&pingResponse, "PING"))) + if pingResponse != "PONG" { + utils.CheckError(fmt.Errorf("connecting redis error: %s", pingResponse)) + } + + return &service.RedisClient{ + Client: client, + Stats: stats, + ImplicitPipelining: implicitPipelining, + } +} diff --git a/src/storage/service/memcached_client.go b/src/storage/service/memcached_client.go new file mode 100644 index 000000000..0346edeb3 --- /dev/null +++ b/src/storage/service/memcached_client.go @@ -0,0 +1,55 @@ +package service + +import ( + "github.com/bradfitz/gomemcache/memcache" +) + +// Client interface for memcached +type MemcachedClientInterface interface { + Get(key string) (*memcache.Item, error) + Set(item *memcache.Item) error + Increment(key string, delta uint64) (uint64, error) +} + +type MemcachedClient struct { + Client *memcache.Client + Stats MemcachedStats +} + +func (m MemcachedClient) Get(key string) (*memcache.Item, error) { + m.Stats.keysRequested.Inc() + items, err := m.Client.Get(key) + if err != nil { + m.Stats.GetError.Inc() + } else { + m.Stats.keysFound.Inc() + m.Stats.GetSuccess.Inc() + } + + return items, err +} + +func (m MemcachedClient) Set(item *memcache.Item) error { + err := m.Client.Set(item) + if err != nil { + m.Stats.SetError.Inc() + } else { + m.Stats.SetSuccess.Inc() + } + + 
return err +} + +func (m MemcachedClient) Increment(key string, delta uint64) (uint64, error) { + newValue, err := m.Client.Increment(key, delta) + switch err { + case memcache.ErrCacheMiss: + m.Stats.IncrementMiss.Inc() + case nil: + m.Stats.IncrementSuccess.Inc() + default: + m.Stats.IncrementError.Inc() + } + + return newValue, err +} diff --git a/src/storage/service/memcached_stats.go b/src/storage/service/memcached_stats.go new file mode 100644 index 000000000..ceb2917c0 --- /dev/null +++ b/src/storage/service/memcached_stats.go @@ -0,0 +1,29 @@ +package service + +import stats "github.com/lyft/gostats" + +type MemcachedStats struct { + GetSuccess stats.Counter + GetError stats.Counter + SetSuccess stats.Counter + SetError stats.Counter + IncrementSuccess stats.Counter + IncrementMiss stats.Counter + IncrementError stats.Counter + keysRequested stats.Counter + keysFound stats.Counter +} + +func NewMemcachedStats(scope stats.Scope) MemcachedStats { + return MemcachedStats{ + GetSuccess: scope.NewCounterWithTags("get", map[string]string{"code": "success"}), + GetError: scope.NewCounterWithTags("get", map[string]string{"code": "error"}), + SetSuccess: scope.NewCounterWithTags("set", map[string]string{"code": "success"}), + SetError: scope.NewCounterWithTags("set", map[string]string{"code": "error"}), + IncrementSuccess: scope.NewCounterWithTags("increment", map[string]string{"code": "success"}), + IncrementMiss: scope.NewCounterWithTags("increment", map[string]string{"code": "miss"}), + IncrementError: scope.NewCounterWithTags("increment", map[string]string{"code": "error"}), + keysRequested: scope.NewCounter("keys_requested"), + keysFound: scope.NewCounter("keys_found"), + } +} diff --git a/src/storage/service/redis_client.go b/src/storage/service/redis_client.go new file mode 100644 index 000000000..b54ab2d6b --- /dev/null +++ b/src/storage/service/redis_client.go @@ -0,0 +1,20 @@ +package service + +import ( + "github.com/mediocregopher/radix/v3" +) + +// 
Client interface for Redis +type RedisClientInterface interface { + Do(radix.Action) error +} + +type RedisClient struct { + Client radix.Client + Stats RedisStats + ImplicitPipelining bool +} + +func (r RedisClient) Do(cmd radix.Action) error { + return r.Client.Do(cmd) +} diff --git a/src/storage/service/redis_stats.go b/src/storage/service/redis_stats.go new file mode 100644 index 000000000..8b5448cd5 --- /dev/null +++ b/src/storage/service/redis_stats.go @@ -0,0 +1,33 @@ +package service + +import ( + stats "github.com/lyft/gostats" + "github.com/mediocregopher/radix/v3/trace" +) + +type RedisStats struct { + connectionActive stats.Gauge + connectionTotal stats.Counter + connectionClose stats.Counter +} + +func PoolTrace(ps *RedisStats) trace.PoolTrace { + return trace.PoolTrace{ + ConnCreated: func(_ trace.PoolConnCreated) { + ps.connectionTotal.Add(1) + ps.connectionActive.Add(1) + }, + ConnClosed: func(_ trace.PoolConnClosed) { + ps.connectionActive.Sub(1) + ps.connectionClose.Add(1) + }, + } +} + +func NewRedisStats(scope stats.Scope) RedisStats { + ret := RedisStats{} + ret.connectionActive = scope.NewGauge("cx_active") + ret.connectionTotal = scope.NewCounter("cx_total") + ret.connectionClose = scope.NewCounter("cx_local_close") + return ret +} diff --git a/src/storage/strategy/memcached_strategy.go b/src/storage/strategy/memcached_strategy.go new file mode 100644 index 000000000..deb4cfe7d --- /dev/null +++ b/src/storage/strategy/memcached_strategy.go @@ -0,0 +1,54 @@ +package strategy + +import ( + "strconv" + + "github.com/bradfitz/gomemcache/memcache" + "github.com/envoyproxy/ratelimit/src/storage/service" +) + +type MemcachedStrategy struct { + Client service.MemcachedClientInterface +} + +func (m MemcachedStrategy) GetValue(key string) (uint64, error) { + item, err := m.Client.Get(key) + if err != nil { + return 0, err + } + + value, err := strconv.ParseUint(string(item.Value), 10, 32) + if err != nil { + return 0, err + } + + return value, nil +} + 
+func (m MemcachedStrategy) SetValue(key string, value uint64, expirationSeconds uint64) error { + item := &memcache.Item{ + Key: key, + Value: []byte(strconv.FormatUint(value, 10)), + Expiration: int32(expirationSeconds), + } + + err := m.Client.Set(item) + if err != nil { + return err + } + + return nil +} + +func (m MemcachedStrategy) IncrementValue(key string, delta uint64) error { + _, err := m.Client.Increment(key, delta) + if err != nil { + return err + } + + return nil +} + +func (m MemcachedStrategy) SetExpire(key string, expirationSeconds uint64) error { + return nil +} diff --git a/src/storage/strategy/redis_strategy.go b/src/storage/strategy/redis_strategy.go new file mode 100644 index 000000000..3c0f6f0f0 --- /dev/null +++ b/src/storage/strategy/redis_strategy.go @@ -0,0 +1,53 @@ +package strategy + +import ( + "github.com/envoyproxy/ratelimit/src/storage/service" + "github.com/mediocregopher/radix/v3" +) + +type RedisStrategy struct { + Client service.RedisClientInterface +} + +func (r RedisStrategy) GetValue(key string) (uint64, error) { + var value uint64 + err := r.Client.Do(radix.Cmd(&value, "GET", key)) + if err != nil { + return value, err + } + + return value, nil +} + +func (r RedisStrategy) SetValue(key string, value uint64, expirationSeconds uint64) error { + + err := r.Client.Do(radix.FlatCmd(nil, "SET", key, value)) + if err != nil { + return err + } + + err = r.Client.Do(radix.FlatCmd(nil, "EXPIRE", key, expirationSeconds)) + if err != nil { + return err + } + + return nil +} + +func (r RedisStrategy) IncrementValue(key string, delta uint64) error { + err := r.Client.Do(radix.FlatCmd(nil, "INCRBY", key, delta)) + if err != nil { + return err + } + + return nil +} + +func (r RedisStrategy) SetExpire(key string, expirationSeconds uint64) error { + err := r.Client.Do(radix.FlatCmd(nil, "EXPIRE", key, expirationSeconds)) + if err != nil { + return err + } + + return nil +} diff --git a/src/storage/strategy/storage_strategy.go 
// StorageStrategy abstracts the underlying key/value storage (memcached,
// redis, ...) behind business-level counter operations, so callers do not
// care how the backend implements them.
// (Comment typos fixed: "bussiness", stray trailing ".\".)
type StorageStrategy interface {
	// GetValue returns the counter value currently stored at key.
	GetValue(key string) (uint64, error)
	// SetValue stores value under key with a TTL of expirationSeconds.
	SetValue(key string, value uint64, expirationSeconds uint64) error
	// IncrementValue atomically adds delta to the counter stored at key.
	IncrementValue(key string, delta uint64) error
	// SetExpire applies a TTL of expirationSeconds to key. Backends that
	// cannot change a TTL independently may implement this as a no-op.
	SetExpire(key string, expirationSeconds uint64) error
}

// RedisError is the error type used for redis-related storage failures.
type RedisError string

// Error implements the error interface.
func (e RedisError) Error() string {
	return string(e)
}

// CheckError panics with a RedisError wrapping err's message when err is
// non-nil, and does nothing otherwise. Intended for startup paths where a
// storage failure is fatal.
func CheckError(err error) {
	if err != nil {
		panic(RedisError(err.Error()))
	}
}

// MemcacheError is the error type used for memcached-related storage failures.
type MemcacheError string

// Error implements the error interface.
func (e MemcacheError) Error() string {
	return string(e)
}
package memcached_test import ( - mockstats "github.com/envoyproxy/ratelimit/test/mocks/stats" "math/rand" - "strconv" "testing" - "github.com/bradfitz/gomemcache/memcache" "github.com/coocood/freecache" - - pb "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v3" "github.com/envoyproxy/ratelimit/src/config" "github.com/envoyproxy/ratelimit/src/limiter" "github.com/envoyproxy/ratelimit/src/memcached" - "github.com/envoyproxy/ratelimit/src/settings" "github.com/envoyproxy/ratelimit/src/utils" - stats "github.com/lyft/gostats" - "github.com/envoyproxy/ratelimit/test/common" - mock_memcached "github.com/envoyproxy/ratelimit/test/mocks/memcached" - mock_utils "github.com/envoyproxy/ratelimit/test/mocks/utils" + "github.com/envoyproxy/ratelimit/test/mocks/stats" "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" + + pb "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v3" + mock_strategy "github.com/envoyproxy/ratelimit/test/mocks/storage/strategy" + mock_utils "github.com/envoyproxy/ratelimit/test/mocks/utils" + gostats "github.com/lyft/gostats" ) func TestMemcached(t *testing.T) { @@ -33,17 +25,16 @@ func TestMemcached(t *testing.T) { controller := gomock.NewController(t) defer controller.Finish() + client := mock_strategy.NewMockStorageStrategy(controller) timeSource := mock_utils.NewMockTimeSource(controller) - client := mock_memcached.NewMockClient(controller) - statsStore := stats.NewStore(stats.NewNullSink(), false) - sm := mockstats.NewMockStatManager(statsStore) - cache := memcached.NewRateLimitCacheImpl(client, timeSource, nil, 0, nil, sm, 0.8, "") + statsStore := gostats.NewStore(gostats.NewNullSink(), false) + sm := stats.NewMockStatManager(statsStore) + cache := memcached.NewFixedRateLimitCacheImpl(client, timeSource, rand.New(rand.NewSource(1)), nil, 0, 0.8, "", sm) timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) - client.EXPECT().GetMulti([]string{"domain_key_value_1234"}).Return( - 
getMultiResult(map[string]int{"domain_key_value_1234": 4}), nil, - ) - client.EXPECT().Increment("domain_key_value_1234", uint64(1)).Return(uint64(5), nil) + client.EXPECT().GetValue("domain_key_value_1234").Return(uint64(4), nil).MaxTimes(1) + client.EXPECT().IncrementValue("domain_key_value_1234", uint64(1)).MaxTimes(1) + client.EXPECT().SetExpire("domain_key_value_1234", uint64(1)).MaxTimes(1) request := common.NewRateLimitRequest("domain", [][][2]string{{{"key", "value"}}}, 1) limits := []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key_value"))} @@ -57,10 +48,9 @@ func TestMemcached(t *testing.T) { assert.Equal(uint64(1), limits[0].Stats.WithinLimit.Value()) timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) - client.EXPECT().GetMulti([]string{"domain_key2_value2_subkey2_subvalue2_1200"}).Return( - getMultiResult(map[string]int{"domain_key2_value2_subkey2_subvalue2_1200": 10}), nil, - ) - client.EXPECT().Increment("domain_key2_value2_subkey2_subvalue2_1200", uint64(1)).Return(uint64(11), nil) + client.EXPECT().GetValue("domain_key2_value2_subkey2_subvalue2_1200").Return(uint64(10), nil).MaxTimes(1) + client.EXPECT().IncrementValue("domain_key2_value2_subkey2_subvalue2_1200", uint64(1)).MaxTimes(1) + client.EXPECT().SetExpire("domain_key2_value2_subkey2_subvalue2_1200", uint64(60)).MaxTimes(1) request = common.NewRateLimitRequest( "domain", @@ -81,17 +71,12 @@ func TestMemcached(t *testing.T) { assert.Equal(uint64(0), limits[1].Stats.WithinLimit.Value()) timeSource.EXPECT().UnixNow().Return(int64(1000000)).MaxTimes(5) - client.EXPECT().GetMulti([]string{ - "domain_key3_value3_997200", - "domain_key3_value3_subkey3_subvalue3_950400", - }).Return( - getMultiResult(map[string]int{ - "domain_key3_value3_997200": 10, - "domain_key3_value3_subkey3_subvalue3_950400": 12}), - nil, - ) - client.EXPECT().Increment("domain_key3_value3_997200", uint64(1)).Return(uint64(11), nil) - 
client.EXPECT().Increment("domain_key3_value3_subkey3_subvalue3_950400", uint64(1)).Return(uint64(13), nil) + client.EXPECT().GetValue("domain_key3_value3_997200").Return(uint64(10), nil).MaxTimes(1) + client.EXPECT().GetValue("domain_key3_value3_subkey3_subvalue3_950400").Return(uint64(12), nil).MaxTimes(1) + client.EXPECT().IncrementValue("domain_key3_value3_997200", uint64(1)).MaxTimes(1) + client.EXPECT().IncrementValue("domain_key3_value3_subkey3_subvalue3_950400", uint64(1)).MaxTimes(1) + client.EXPECT().SetExpire("domain_key3_value3_997200", uint64(3600)).MaxTimes(1) + client.EXPECT().SetExpire("domain_key3_value3_subkey3_subvalue3_950400", uint64(86400)).MaxTimes(1) request = common.NewRateLimitRequest( "domain", @@ -115,60 +100,9 @@ func TestMemcached(t *testing.T) { assert.Equal(uint64(1), limits[0].Stats.OverLimit.Value()) assert.Equal(uint64(0), limits[0].Stats.NearLimit.Value()) assert.Equal(uint64(0), limits[0].Stats.WithinLimit.Value()) - - cache.Flush() } -func TestMemcachedGetError(t *testing.T) { - assert := assert.New(t) - controller := gomock.NewController(t) - defer controller.Finish() - - timeSource := mock_utils.NewMockTimeSource(controller) - client := mock_memcached.NewMockClient(controller) - statsStore := stats.NewStore(stats.NewNullSink(), false) - sm := mockstats.NewMockStatManager(statsStore) - cache := memcached.NewRateLimitCacheImpl(client, timeSource, nil, 0, nil, sm, 0.8, "") - - timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) - client.EXPECT().GetMulti([]string{"domain_key_value_1234"}).Return( - nil, memcache.ErrNoServers, - ) - client.EXPECT().Increment("domain_key_value_1234", uint64(1)).Return(uint64(5), nil) - - request := common.NewRateLimitRequest("domain", [][][2]string{{{"key", "value"}}}, 1) - limits := []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key_value"))} - - assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, 
CurrentLimit: limits[0].Limit, LimitRemaining: 9, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, - cache.DoLimit(nil, request, limits)) - assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) - assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) - assert.Equal(uint64(0), limits[0].Stats.NearLimit.Value()) - assert.Equal(uint64(1), limits[0].Stats.WithinLimit.Value()) - - // No error, but the key is missing - timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) - client.EXPECT().GetMulti([]string{"domain_key_value1_1234"}).Return( - nil, nil, - ) - client.EXPECT().Increment("domain_key_value1_1234", uint64(1)).Return(uint64(5), nil) - - request = common.NewRateLimitRequest("domain", [][][2]string{{{"key", "value1"}}}, 1) - limits = []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key_value1"))} - - assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 9, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, - cache.DoLimit(nil, request, limits)) - assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) - assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) - assert.Equal(uint64(0), limits[0].Stats.NearLimit.Value()) - assert.Equal(uint64(1), limits[0].Stats.WithinLimit.Value()) - - cache.Flush() -} - -func testLocalCacheStats(localCacheStats stats.StatGenerator, statsStore stats.Store, sink *common.TestStatSink, +func testLocalCacheStats(localCacheStats gostats.StatGenerator, statsStore gostats.Store, sink *common.TestStatSink, expectedHitCount int, expectedMissCount int, expectedLookUpCount int, expectedExpiredCount int, expectedEntryCount int) func(*testing.T) { return func(t *testing.T) { @@ -209,24 +143,22 @@ func TestOverLimitWithLocalCache(t *testing.T) { controller := gomock.NewController(t) defer controller.Finish() + client := 
mock_strategy.NewMockStorageStrategy(controller) timeSource := mock_utils.NewMockTimeSource(controller) - client := mock_memcached.NewMockClient(controller) localCache := freecache.NewCache(100) + statsStore := gostats.NewStore(gostats.NewNullSink(), false) + sm := stats.NewMockStatManager(statsStore) + cache := memcached.NewFixedRateLimitCacheImpl(client, timeSource, rand.New(rand.NewSource(1)), localCache, 0, 0.8, "", sm) sink := &common.TestStatSink{} - statsStore := stats.NewStore(sink, true) - sm := mockstats.NewMockStatManager(statsStore) - cache := memcached.NewRateLimitCacheImpl(client, timeSource, nil, 0, localCache, sm, 0.8, "") localCacheStats := limiter.NewLocalCacheStats(localCache, statsStore.Scope("localcache")) // Test Near Limit Stats. Under Near Limit Ratio timeSource.EXPECT().UnixNow().Return(int64(1000000)).MaxTimes(3) - client.EXPECT().GetMulti([]string{"domain_key4_value4_997200"}).Return( - getMultiResult(map[string]int{"domain_key4_value4_997200": 10}), nil, - ) - client.EXPECT().Increment("domain_key4_value4_997200", uint64(1)).Return(uint64(5), nil) + client.EXPECT().GetValue("domain_key4_value4_997200").Return(uint64(10), nil).MaxTimes(1) + client.EXPECT().IncrementValue("domain_key4_value4_997200", uint64(1)).MaxTimes(1) + client.EXPECT().SetExpire("domain_key4_value4_997200", uint64(3600)).MaxTimes(1) request := common.NewRateLimitRequest("domain", [][][2]string{{{"key4", "value4"}}}, 1) - limits := []*config.RateLimit{ config.NewRateLimit(15, pb.RateLimitResponse_RateLimit_HOUR, sm.NewStats("key4_value4"))} @@ -245,10 +177,9 @@ func TestOverLimitWithLocalCache(t *testing.T) { // Test Near Limit Stats. 
At Near Limit Ratio, still OK timeSource.EXPECT().UnixNow().Return(int64(1000000)).MaxTimes(3) - client.EXPECT().GetMulti([]string{"domain_key4_value4_997200"}).Return( - getMultiResult(map[string]int{"domain_key4_value4_997200": 12}), nil, - ) - client.EXPECT().Increment("domain_key4_value4_997200", uint64(1)).Return(uint64(13), nil) + client.EXPECT().GetValue("domain_key4_value4_997200").Return(uint64(12), nil).MaxTimes(1) + client.EXPECT().IncrementValue("domain_key4_value4_997200", uint64(1)).MaxTimes(1) + client.EXPECT().SetExpire("domain_key4_value4_997200", uint64(3600)).MaxTimes(1) assert.Equal( []*pb.RateLimitResponse_DescriptorStatus{ @@ -265,10 +196,9 @@ func TestOverLimitWithLocalCache(t *testing.T) { // Test Over limit stats timeSource.EXPECT().UnixNow().Return(int64(1000000)).MaxTimes(3) - client.EXPECT().GetMulti([]string{"domain_key4_value4_997200"}).Return( - getMultiResult(map[string]int{"domain_key4_value4_997200": 15}), nil, - ) - client.EXPECT().Increment("domain_key4_value4_997200", uint64(1)).Return(uint64(16), nil) + client.EXPECT().GetValue("domain_key4_value4_997200").Return(uint64(15), nil).MaxTimes(1) + client.EXPECT().IncrementValue("domain_key4_value4_997200", uint64(1)).MaxTimes(1) + client.EXPECT().SetExpire("domain_key4_value4_997200", uint64(3600)).MaxTimes(1) assert.Equal( []*pb.RateLimitResponse_DescriptorStatus{ @@ -285,8 +215,9 @@ func TestOverLimitWithLocalCache(t *testing.T) { // Test Over limit stats with local cache timeSource.EXPECT().UnixNow().Return(int64(1000000)).MaxTimes(3) - client.EXPECT().GetMulti([]string{"domain_key4_value4_997200"}).Times(0) - client.EXPECT().Increment("domain_key4_value4_997200", uint64(1)).Times(0) + client.EXPECT().IncrementValue("domain_key4_value4_997200", uint64(1)).MaxTimes(1) + client.EXPECT().SetExpire("domain_key4_value4_997200", uint64(3600)).MaxTimes(1) + assert.Equal( []*pb.RateLimitResponse_DescriptorStatus{ {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, 
LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, @@ -299,8 +230,6 @@ func TestOverLimitWithLocalCache(t *testing.T) { // Check the local cache stats. testLocalCacheStats(localCacheStats, statsStore, sink, 1, 3, 4, 0, 1) - - cache.Flush() } func TestNearLimit(t *testing.T) { @@ -308,18 +237,17 @@ func TestNearLimit(t *testing.T) { controller := gomock.NewController(t) defer controller.Finish() + client := mock_strategy.NewMockStorageStrategy(controller) timeSource := mock_utils.NewMockTimeSource(controller) - client := mock_memcached.NewMockClient(controller) - statsStore := stats.NewStore(stats.NewNullSink(), false) - sm := mockstats.NewMockStatManager(statsStore) - cache := memcached.NewRateLimitCacheImpl(client, timeSource, nil, 0, nil, sm, 0.8, "") + statsStore := gostats.NewStore(gostats.NewNullSink(), false) + sm := stats.NewMockStatManager(statsStore) + cache := memcached.NewFixedRateLimitCacheImpl(client, timeSource, rand.New(rand.NewSource(1)), nil, 0, 0.8, "", sm) // Test Near Limit Stats. Under Near Limit Ratio timeSource.EXPECT().UnixNow().Return(int64(1000000)).MaxTimes(3) - client.EXPECT().GetMulti([]string{"domain_key4_value4_997200"}).Return( - getMultiResult(map[string]int{"domain_key4_value4_997200": 10}), nil, - ) - client.EXPECT().Increment("domain_key4_value4_997200", uint64(1)).Return(uint64(11), nil) + client.EXPECT().GetValue("domain_key4_value4_997200").Return(uint64(10), nil).MaxTimes(1) + client.EXPECT().IncrementValue("domain_key4_value4_997200", uint64(1)).MaxTimes(1) + client.EXPECT().SetExpire("domain_key4_value4_997200", uint64(3600)).MaxTimes(1) request := common.NewRateLimitRequest("domain", [][][2]string{{{"key4", "value4"}}}, 1) @@ -337,10 +265,9 @@ func TestNearLimit(t *testing.T) { // Test Near Limit Stats. 
At Near Limit Ratio, still OK timeSource.EXPECT().UnixNow().Return(int64(1000000)).MaxTimes(3) - client.EXPECT().GetMulti([]string{"domain_key4_value4_997200"}).Return( - getMultiResult(map[string]int{"domain_key4_value4_997200": 12}), nil, - ) - client.EXPECT().Increment("domain_key4_value4_997200", uint64(1)).Return(uint64(13), nil) + client.EXPECT().GetValue("domain_key4_value4_997200").Return(uint64(12), nil).MaxTimes(1) + client.EXPECT().IncrementValue("domain_key4_value4_997200", uint64(1)).MaxTimes(1) + client.EXPECT().SetExpire("domain_key4_value4_997200", uint64(3600)).MaxTimes(1) assert.Equal( []*pb.RateLimitResponse_DescriptorStatus{ @@ -354,10 +281,9 @@ func TestNearLimit(t *testing.T) { // Test Near Limit Stats. We went OVER_LIMIT, but the near_limit counter only increases // when we are near limit, not after we have passed the limit. timeSource.EXPECT().UnixNow().Return(int64(1000000)).MaxTimes(3) - client.EXPECT().GetMulti([]string{"domain_key4_value4_997200"}).Return( - getMultiResult(map[string]int{"domain_key4_value4_997200": 15}), nil, - ) - client.EXPECT().Increment("domain_key4_value4_997200", uint64(1)).Return(uint64(16), nil) + client.EXPECT().GetValue("domain_key4_value4_997200").Return(uint64(15), nil).MaxTimes(1) + client.EXPECT().IncrementValue("domain_key4_value4_997200", uint64(1)).MaxTimes(1) + client.EXPECT().SetExpire("domain_key4_value4_997200", uint64(3600)).MaxTimes(1) assert.Equal( []*pb.RateLimitResponse_DescriptorStatus{ @@ -371,10 +297,9 @@ func TestNearLimit(t *testing.T) { // Now test hitsAddend that is greater than 1 // All of it under limit, under near limit timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) - client.EXPECT().GetMulti([]string{"domain_key5_value5_1234"}).Return( - getMultiResult(map[string]int{"domain_key5_value5_1234": 2}), nil, - ) - client.EXPECT().Increment("domain_key5_value5_1234", uint64(3)).Return(uint64(5), nil) + client.EXPECT().GetValue("domain_key5_value5_1234").Return(uint64(2), 
nil).MaxTimes(1) + client.EXPECT().IncrementValue("domain_key5_value5_1234", uint64(3)).MaxTimes(1) + client.EXPECT().SetExpire("domain_key5_value5_1234", uint64(1)).MaxTimes(1) request = common.NewRateLimitRequest("domain", [][][2]string{{{"key5", "value5"}}}, 3) limits = []*config.RateLimit{config.NewRateLimit(20, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key5_value5"))} @@ -389,10 +314,9 @@ func TestNearLimit(t *testing.T) { // All of it under limit, some over near limit timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) - client.EXPECT().GetMulti([]string{"domain_key6_value6_1234"}).Return( - getMultiResult(map[string]int{"domain_key6_value6_1234": 5}), nil, - ) - client.EXPECT().Increment("domain_key6_value6_1234", uint64(2)).Return(uint64(7), nil) + client.EXPECT().GetValue("domain_key6_value6_1234").Return(uint64(5), nil).MaxTimes(1) + client.EXPECT().IncrementValue("domain_key6_value6_1234", uint64(2)).MaxTimes(1) + client.EXPECT().SetExpire("domain_key6_value6_1234", uint64(1)).MaxTimes(1) request = common.NewRateLimitRequest("domain", [][][2]string{{{"key6", "value6"}}}, 2) limits = []*config.RateLimit{config.NewRateLimit(8, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key6_value6"))} @@ -407,10 +331,9 @@ func TestNearLimit(t *testing.T) { // All of it under limit, all of it over near limit timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) - client.EXPECT().GetMulti([]string{"domain_key7_value7_1234"}).Return( - getMultiResult(map[string]int{"domain_key7_value7_1234": 16}), nil, - ) - client.EXPECT().Increment("domain_key7_value7_1234", uint64(3)).Return(uint64(19), nil) + client.EXPECT().GetValue("domain_key7_value7_1234").Return(uint64(16), nil).MaxTimes(1) + client.EXPECT().IncrementValue("domain_key7_value7_1234", uint64(3)).MaxTimes(1) + client.EXPECT().SetExpire("domain_key7_value7_1234", uint64(1)).MaxTimes(1) request = common.NewRateLimitRequest("domain", [][][2]string{{{"key7", "value7"}}}, 3) limits = 
[]*config.RateLimit{config.NewRateLimit(20, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key7_value7"))} @@ -425,10 +348,9 @@ func TestNearLimit(t *testing.T) { // Some of it over limit, all of it over near limit timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) - client.EXPECT().GetMulti([]string{"domain_key8_value8_1234"}).Return( - getMultiResult(map[string]int{"domain_key8_value8_1234": 19}), nil, - ) - client.EXPECT().Increment("domain_key8_value8_1234", uint64(3)).Return(uint64(22), nil) + client.EXPECT().GetValue("domain_key8_value8_1234").Return(uint64(19), nil).MaxTimes(1) + client.EXPECT().IncrementValue("domain_key8_value8_1234", uint64(3)).MaxTimes(1) + client.EXPECT().SetExpire("domain_key8_value8_1234", uint64(1)).MaxTimes(1) request = common.NewRateLimitRequest("domain", [][][2]string{{{"key8", "value8"}}}, 3) limits = []*config.RateLimit{config.NewRateLimit(20, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key8_value8"))} @@ -443,10 +365,9 @@ func TestNearLimit(t *testing.T) { // Some of it in all three places timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) - client.EXPECT().GetMulti([]string{"domain_key9_value9_1234"}).Return( - getMultiResult(map[string]int{"domain_key9_value9_1234": 15}), nil, - ) - client.EXPECT().Increment("domain_key9_value9_1234", uint64(7)).Return(uint64(22), nil) + client.EXPECT().GetValue("domain_key9_value9_1234").Return(uint64(15), nil).MaxTimes(1) + client.EXPECT().IncrementValue("domain_key9_value9_1234", uint64(7)).MaxTimes(1) + client.EXPECT().SetExpire("domain_key9_value9_1234", uint64(1)).MaxTimes(1) request = common.NewRateLimitRequest("domain", [][][2]string{{{"key9", "value9"}}}, 7) limits = []*config.RateLimit{config.NewRateLimit(20, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key9_value9"))} @@ -461,10 +382,9 @@ func TestNearLimit(t *testing.T) { // all of it over limit timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) - 
client.EXPECT().GetMulti([]string{"domain_key10_value10_1234"}).Return( - getMultiResult(map[string]int{"domain_key10_value10_1234": 27}), nil, - ) - client.EXPECT().Increment("domain_key10_value10_1234", uint64(3)).Return(uint64(30), nil) + client.EXPECT().GetValue("domain_key10_value10_1234").Return(uint64(27), nil).MaxTimes(1) + client.EXPECT().IncrementValue("domain_key10_value10_1234", uint64(3)).MaxTimes(1) + client.EXPECT().SetExpire("domain_key10_value10_1234", uint64(1)).MaxTimes(1) request = common.NewRateLimitRequest("domain", [][][2]string{{{"key10", "value10"}}}, 3) limits = []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key10_value10"))} @@ -476,166 +396,34 @@ func TestNearLimit(t *testing.T) { assert.Equal(uint64(3), limits[0].Stats.OverLimit.Value()) assert.Equal(uint64(0), limits[0].Stats.NearLimit.Value()) assert.Equal(uint64(0), limits[0].Stats.WithinLimit.Value()) - - cache.Flush() } -func TestMemcacheWithJitter(t *testing.T) { +func TestMemcachedWithJitter(t *testing.T) { assert := assert.New(t) controller := gomock.NewController(t) defer controller.Finish() + client := mock_strategy.NewMockStorageStrategy(controller) timeSource := mock_utils.NewMockTimeSource(controller) - client := mock_memcached.NewMockClient(controller) jitterSource := mock_utils.NewMockJitterRandSource(controller) - statsStore := stats.NewStore(stats.NewNullSink(), false) - sm := mockstats.NewMockStatManager(statsStore) - cache := memcached.NewRateLimitCacheImpl(client, timeSource, rand.New(jitterSource), 3600, nil, sm, 0.8, "") - - timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) - jitterSource.EXPECT().Int63().Return(int64(100)) - - // Key is not found in memcache - client.EXPECT().GetMulti([]string{"domain_key_value_1234"}).Return(nil, nil) - // First increment attempt will fail - client.EXPECT().Increment("domain_key_value_1234", uint64(1)).Return( - uint64(0), memcache.ErrCacheMiss) - // Add succeeds - 
client.EXPECT().Add( - &memcache.Item{ - Key: "domain_key_value_1234", - Value: []byte(strconv.FormatUint(1, 10)), - // 1 second + 100 seconds of jitter - Expiration: int32(101), - }, - ).Return(nil) - - request := common.NewRateLimitRequest("domain", [][][2]string{{{"key", "value"}}}, 1) - limits := []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key_value"))} - - assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 9, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, - cache.DoLimit(nil, request, limits)) - assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) - assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) - assert.Equal(uint64(0), limits[0].Stats.NearLimit.Value()) - assert.Equal(uint64(1), limits[0].Stats.WithinLimit.Value()) - - cache.Flush() -} - -func TestMemcacheAdd(t *testing.T) { - assert := assert.New(t) - controller := gomock.NewController(t) - defer controller.Finish() - - timeSource := mock_utils.NewMockTimeSource(controller) - client := mock_memcached.NewMockClient(controller) - statsStore := stats.NewStore(stats.NewNullSink(), false) - sm := mockstats.NewMockStatManager(statsStore) - cache := memcached.NewRateLimitCacheImpl(client, timeSource, nil, 0, nil, sm, 0.8, "") + statsStore := gostats.NewStore(gostats.NewNullSink(), false) + sm := stats.NewMockStatManager(statsStore) + cache := memcached.NewFixedRateLimitCacheImpl(client, timeSource, rand.New(jitterSource), nil, 3600, 0.8, "", sm) - // Test a race condition with the initial add timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) - - client.EXPECT().GetMulti([]string{"domain_key_value_1234"}).Return(nil, nil) - client.EXPECT().Increment("domain_key_value_1234", uint64(1)).Return( - uint64(0), memcache.ErrCacheMiss) - // Add fails, must have been a race condition - client.EXPECT().Add( - &memcache.Item{ - Key: 
"domain_key_value_1234", - Value: []byte(strconv.FormatUint(1, 10)), - Expiration: int32(1), - }, - ).Return(memcache.ErrNotStored) - // Should work the second time, since some other client added the key. - client.EXPECT().Increment("domain_key_value_1234", uint64(1)).Return( - uint64(2), nil) + jitterSource.EXPECT().Int63().Return(int64(100)).MaxTimes(1) + client.EXPECT().GetValue("domain_key_value_1234").Return(uint64(4), nil).MaxTimes(1) + client.EXPECT().IncrementValue("domain_key_value_1234", uint64(1)).MaxTimes(1) + client.EXPECT().SetExpire("domain_key_value_1234", uint64(101)).MaxTimes(1) request := common.NewRateLimitRequest("domain", [][][2]string{{{"key", "value"}}}, 1) limits := []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key_value"))} assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 9, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, - cache.DoLimit(nil, request, limits)) - assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) - assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) - assert.Equal(uint64(0), limits[0].Stats.NearLimit.Value()) - assert.Equal(uint64(1), limits[0].Stats.WithinLimit.Value()) - - // A rate limit with 1-minute window - timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) - client.EXPECT().GetMulti([]string{"domain_key2_value2_1200"}).Return(nil, nil) - client.EXPECT().Increment("domain_key2_value2_1200", uint64(1)).Return( - uint64(0), memcache.ErrCacheMiss) - client.EXPECT().Add( - &memcache.Item{ - Key: "domain_key2_value2_1200", - Value: []byte(strconv.FormatUint(1, 10)), - Expiration: int32(60), - }, - ).Return(nil) - - request = common.NewRateLimitRequest("domain", [][][2]string{{{"key2", "value2"}}}, 1) - limits = []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_MINUTE, sm.NewStats("key2_value2"))} - - assert.Equal( 
- []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 9, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, + []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 5, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, cache.DoLimit(nil, request, limits)) assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) assert.Equal(uint64(0), limits[0].Stats.NearLimit.Value()) assert.Equal(uint64(1), limits[0].Stats.WithinLimit.Value()) - - cache.Flush() -} - -func TestNewRateLimitCacheImplFromSettingsWhenSrvCannotBeResolved(t *testing.T) { - assert := assert.New(t) - controller := gomock.NewController(t) - defer controller.Finish() - - timeSource := mock_utils.NewMockTimeSource(controller) - statsStore := stats.NewStore(stats.NewNullSink(), false) - - var s settings.Settings - s.NearLimitRatio = 0.8 - s.CacheKeyPrefix = "" - s.ExpirationJitterMaxSeconds = 300 - s.MemcacheSrv = "_something._tcp.example.invalid" - - assert.Panics(func() { - memcached.NewRateLimitCacheImplFromSettings(s, timeSource, nil, nil, statsStore, mockstats.NewMockStatManager(statsStore)) - }) -} - -func TestNewRateLimitCacheImplFromSettingsWhenHostAndPortAndSrvAreBothSet(t *testing.T) { - assert := assert.New(t) - controller := gomock.NewController(t) - defer controller.Finish() - - timeSource := mock_utils.NewMockTimeSource(controller) - statsStore := stats.NewStore(stats.NewNullSink(), false) - - var s settings.Settings - s.NearLimitRatio = 0.8 - s.CacheKeyPrefix = "" - s.ExpirationJitterMaxSeconds = 300 - s.MemcacheSrv = "_something._tcp.example.invalid" - s.MemcacheHostPort = []string{"example.org:11211"} - - assert.Panics(func() { - memcached.NewRateLimitCacheImplFromSettings(s, timeSource, nil, nil, statsStore, mockstats.NewMockStatManager(statsStore)) - }) -} 
- -func getMultiResult(vals map[string]int) map[string]*memcache.Item { - result := make(map[string]*memcache.Item, len(vals)) - for k, v := range vals { - result[k] = &memcache.Item{ - Value: []byte(strconv.Itoa(v)), - } - } - return result } diff --git a/test/memcached/stats_collecting_client_test.go b/test/memcached/stats_collecting_client_test.go deleted file mode 100644 index 548b93041..000000000 --- a/test/memcached/stats_collecting_client_test.go +++ /dev/null @@ -1,199 +0,0 @@ -package memcached_test - -import ( - "errors" - "testing" - - "github.com/bradfitz/gomemcache/memcache" - "github.com/envoyproxy/ratelimit/src/memcached" - mock_memcached "github.com/envoyproxy/ratelimit/test/mocks/memcached" - "github.com/golang/mock/gomock" - stats "github.com/lyft/gostats" - "github.com/stretchr/testify/assert" -) - -type fakeSink struct { - values map[string]uint64 -} - -func (fs *fakeSink) FlushCounter(name string, value uint64) { - if _, ok := fs.values[name]; ok { - panic(errors.New("fakeSink wasn't cleared before flushing again")) - } - - fs.values[name] = value -} - -func (fs *fakeSink) FlushGauge(name string, value uint64) {} - -func (fs *fakeSink) FlushTimer(name string, value float64) {} - -func (fs *fakeSink) Flush() {} - -func (fs *fakeSink) Reset() { - fs.values = make(map[string]uint64) -} - -func TestStats_MultiGet(t *testing.T) { - fakeSink := &fakeSink{} - fakeSink.Reset() - - assert := assert.New(t) - controller := gomock.NewController(t) - defer controller.Finish() - client := mock_memcached.NewMockClient(controller) - statsStore := stats.NewStore(fakeSink, false) - - sc := memcached.CollectStats(client, statsStore) - - returnValue := map[string]*memcache.Item{"foo": nil} - arg := []string{"foo"} - - client.EXPECT().GetMulti(arg).Return(returnValue, nil) - actualReturnValue, err := sc.GetMulti(arg) - statsStore.Flush() - - assert.Equal(returnValue, actualReturnValue) - assert.Nil(err) - assert.Equal(map[string]uint64{ - "keys_found": 1, - 
"keys_requested": 1, - "multiget.__code=success": 1, - }, fakeSink.values) - - fakeSink.Reset() - returnValue = map[string]*memcache.Item{"foo": nil, "bar": nil} - client.EXPECT().GetMulti(arg).Return(returnValue, nil) - actualReturnValue, err = sc.GetMulti(arg) - statsStore.Flush() - - assert.Equal(returnValue, actualReturnValue) - assert.Nil(err) - assert.Equal(map[string]uint64{ - "keys_found": 2, - "keys_requested": 1, - "multiget.__code=success": 1, - }, fakeSink.values) - - fakeSink.Reset() - returnValue = map[string]*memcache.Item{} - arg = []string{"foo", "bar"} - - client.EXPECT().GetMulti(arg).Return(returnValue, nil) - actualReturnValue, err = sc.GetMulti(arg) - - statsStore.Flush() - assert.Equal(returnValue, actualReturnValue) - assert.Nil(err) - - assert.Equal(map[string]uint64{ - "keys_requested": 2, - "multiget.__code=success": 1, - }, fakeSink.values) - - fakeSink.Reset() - returnValue = map[string]*memcache.Item{"ignored": nil} - arg = []string{"foo"} - returnedErr := errors.New("Random error") - - client.EXPECT().GetMulti(arg).Return(returnValue, returnedErr) - actualReturnValue, err = sc.GetMulti(arg) - - statsStore.Flush() - assert.Equal(returnValue, actualReturnValue) - assert.Equal(returnedErr, err) - - assert.Equal(map[string]uint64{ - "keys_requested": 1, - "multiget.__code=error": 1, - }, fakeSink.values) -} - -func TestStats_Increment(t *testing.T) { - fakeSink := &fakeSink{} - - assert := assert.New(t) - controller := gomock.NewController(t) - defer controller.Finish() - client := mock_memcached.NewMockClient(controller) - statsStore := stats.NewStore(fakeSink, false) - - sc := memcached.CollectStats(client, statsStore) - - fakeSink.Reset() - client.EXPECT().Increment("foo", uint64(5)).Return(uint64(6), nil) - newValue, err := sc.Increment("foo", 5) - statsStore.Flush() - - assert.Equal(uint64(6), newValue) - assert.Nil(err) - assert.Equal(map[string]uint64{ - "increment.__code=success": 1, - }, fakeSink.values) - - expectedErr := 
errors.New("expectedError") - fakeSink.Reset() - client.EXPECT().Increment("foo", uint64(5)).Return(uint64(0), expectedErr) - newValue, err = sc.Increment("foo", 5) - statsStore.Flush() - - assert.Equal(expectedErr, err) - assert.Equal(map[string]uint64{ - "increment.__code=error": 1, - }, fakeSink.values) - - fakeSink.Reset() - client.EXPECT().Increment("foo", uint64(5)).Return(uint64(0), memcache.ErrCacheMiss) - newValue, err = sc.Increment("foo", 5) - statsStore.Flush() - - assert.Equal(memcache.ErrCacheMiss, err) - assert.Equal(map[string]uint64{ - "increment.__code=miss": 1, - }, fakeSink.values) -} - -func TestStats_Add(t *testing.T) { - fakeSink := &fakeSink{} - - assert := assert.New(t) - controller := gomock.NewController(t) - defer controller.Finish() - client := mock_memcached.NewMockClient(controller) - statsStore := stats.NewStore(fakeSink, false) - - sc := memcached.CollectStats(client, statsStore) - item := &memcache.Item{} - - fakeSink.Reset() - client.EXPECT().Add(item).Return(nil) - err := sc.Add(item) - statsStore.Flush() - - assert.Nil(err) - assert.Equal(map[string]uint64{ - "add.__code=success": 1, - }, fakeSink.values) - - expectedErr := errors.New("expected err") - - fakeSink.Reset() - client.EXPECT().Add(item).Return(expectedErr) - err = sc.Add(item) - statsStore.Flush() - - assert.Equal(expectedErr, err) - assert.Equal(map[string]uint64{ - "add.__code=error": 1, - }, fakeSink.values) - - fakeSink.Reset() - client.EXPECT().Add(item).Return(memcache.ErrNotStored) - err = sc.Add(item) - statsStore.Flush() - - assert.Equal(memcache.ErrNotStored, err) - assert.Equal(map[string]uint64{ - "add.__code=not_stored": 1, - }, fakeSink.values) -} diff --git a/test/mocks/redis/redis.go b/test/mocks/redis/redis.go deleted file mode 100644 index 032b500dc..000000000 --- a/test/mocks/redis/redis.go +++ /dev/null @@ -1,128 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. 
-// Source: github.com/envoyproxy/ratelimit/src/redis (interfaces: Client) - -// Package mock_redis is a generated GoMock package. -package mock_redis - -import ( - redis "github.com/envoyproxy/ratelimit/src/redis" - gomock "github.com/golang/mock/gomock" - reflect "reflect" -) - -// MockClient is a mock of Client interface -type MockClient struct { - ctrl *gomock.Controller - recorder *MockClientMockRecorder -} - -// MockClientMockRecorder is the mock recorder for MockClient -type MockClientMockRecorder struct { - mock *MockClient -} - -// NewMockClient creates a new mock instance -func NewMockClient(ctrl *gomock.Controller) *MockClient { - mock := &MockClient{ctrl: ctrl} - mock.recorder = &MockClientMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use -func (m *MockClient) EXPECT() *MockClientMockRecorder { - return m.recorder -} - -// Close mocks base method -func (m *MockClient) Close() error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Close") - ret0, _ := ret[0].(error) - return ret0 -} - -// Close indicates an expected call of Close -func (mr *MockClientMockRecorder) Close() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockClient)(nil).Close)) -} - -// DoCmd mocks base method -func (m *MockClient) DoCmd(arg0 interface{}, arg1, arg2 string, arg3 ...interface{}) error { - m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1, arg2} - for _, a := range arg3 { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "DoCmd", varargs...) - ret0, _ := ret[0].(error) - return ret0 -} - -// DoCmd indicates an expected call of DoCmd -func (mr *MockClientMockRecorder) DoCmd(arg0, arg1, arg2 interface{}, arg3 ...interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1, arg2}, arg3...) 
- return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DoCmd", reflect.TypeOf((*MockClient)(nil).DoCmd), varargs...) -} - -// ImplicitPipeliningEnabled mocks base method -func (m *MockClient) ImplicitPipeliningEnabled() bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ImplicitPipeliningEnabled") - ret0, _ := ret[0].(bool) - return ret0 -} - -// ImplicitPipeliningEnabled indicates an expected call of ImplicitPipeliningEnabled -func (mr *MockClientMockRecorder) ImplicitPipeliningEnabled() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ImplicitPipeliningEnabled", reflect.TypeOf((*MockClient)(nil).ImplicitPipeliningEnabled)) -} - -// NumActiveConns mocks base method -func (m *MockClient) NumActiveConns() int { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "NumActiveConns") - ret0, _ := ret[0].(int) - return ret0 -} - -// NumActiveConns indicates an expected call of NumActiveConns -func (mr *MockClientMockRecorder) NumActiveConns() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NumActiveConns", reflect.TypeOf((*MockClient)(nil).NumActiveConns)) -} - -// PipeAppend mocks base method -func (m *MockClient) PipeAppend(arg0 redis.Pipeline, arg1 interface{}, arg2, arg3 string, arg4 ...interface{}) redis.Pipeline { - m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1, arg2, arg3} - for _, a := range arg4 { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "PipeAppend", varargs...) - ret0, _ := ret[0].(redis.Pipeline) - return ret0 -} - -// PipeAppend indicates an expected call of PipeAppend -func (mr *MockClientMockRecorder) PipeAppend(arg0, arg1, arg2, arg3 interface{}, arg4 ...interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1, arg2, arg3}, arg4...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PipeAppend", reflect.TypeOf((*MockClient)(nil).PipeAppend), varargs...) 
-} - -// PipeDo mocks base method -func (m *MockClient) PipeDo(arg0 redis.Pipeline) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "PipeDo", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// PipeDo indicates an expected call of PipeDo -func (mr *MockClientMockRecorder) PipeDo(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PipeDo", reflect.TypeOf((*MockClient)(nil).PipeDo), arg0) -} diff --git a/test/mocks/storage/service/memcached_client_mock.go b/test/mocks/storage/service/memcached_client_mock.go new file mode 100644 index 000000000..7ffd6253a --- /dev/null +++ b/test/mocks/storage/service/memcached_client_mock.go @@ -0,0 +1,78 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: ./src/storage/service/memcached_client.go + +// Package mock_service is a generated GoMock package. +package mock_service + +import ( + memcache "github.com/bradfitz/gomemcache/memcache" + gomock "github.com/golang/mock/gomock" + reflect "reflect" +) + +// MockMemcachedClientInterface is a mock of MemcachedClientInterface interface +type MockMemcachedClientInterface struct { + ctrl *gomock.Controller + recorder *MockMemcachedClientInterfaceMockRecorder +} + +// MockMemcachedClientInterfaceMockRecorder is the mock recorder for MockMemcachedClientInterface +type MockMemcachedClientInterfaceMockRecorder struct { + mock *MockMemcachedClientInterface +} + +// NewMockMemcachedClientInterface creates a new mock instance +func NewMockMemcachedClientInterface(ctrl *gomock.Controller) *MockMemcachedClientInterface { + mock := &MockMemcachedClientInterface{ctrl: ctrl} + mock.recorder = &MockMemcachedClientInterfaceMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockMemcachedClientInterface) EXPECT() *MockMemcachedClientInterfaceMockRecorder { + return m.recorder +} + +// Get mocks base method +func (m *MockMemcachedClientInterface) 
Get(key string) (*memcache.Item, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Get", key) + ret0, _ := ret[0].(*memcache.Item) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Get indicates an expected call of Get +func (mr *MockMemcachedClientInterfaceMockRecorder) Get(key interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockMemcachedClientInterface)(nil).Get), key) +} + +// Set mocks base method +func (m *MockMemcachedClientInterface) Set(item *memcache.Item) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Set", item) + ret0, _ := ret[0].(error) + return ret0 +} + +// Set indicates an expected call of Set +func (mr *MockMemcachedClientInterfaceMockRecorder) Set(item interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Set", reflect.TypeOf((*MockMemcachedClientInterface)(nil).Set), item) +} + +// Increment mocks base method +func (m *MockMemcachedClientInterface) Increment(key string, delta uint64) (uint64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Increment", key, delta) + ret0, _ := ret[0].(uint64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Increment indicates an expected call of Increment +func (mr *MockMemcachedClientInterfaceMockRecorder) Increment(key, delta interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Increment", reflect.TypeOf((*MockMemcachedClientInterface)(nil).Increment), key, delta) +} diff --git a/test/mocks/storage/service/redis_client_mock.go b/test/mocks/storage/service/redis_client_mock.go new file mode 100644 index 000000000..e07cba904 --- /dev/null +++ b/test/mocks/storage/service/redis_client_mock.go @@ -0,0 +1,48 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: ./src/storage/service/redis_client.go + +// Package mock_service is a generated GoMock package. 
+package mock_service + +import ( + gomock "github.com/golang/mock/gomock" + radix "github.com/mediocregopher/radix/v3" + reflect "reflect" +) + +// MockRedisClientInterface is a mock of RedisClientInterface interface +type MockRedisClientInterface struct { + ctrl *gomock.Controller + recorder *MockRedisClientInterfaceMockRecorder +} + +// MockRedisClientInterfaceMockRecorder is the mock recorder for MockRedisClientInterface +type MockRedisClientInterfaceMockRecorder struct { + mock *MockRedisClientInterface +} + +// NewMockRedisClientInterface creates a new mock instance +func NewMockRedisClientInterface(ctrl *gomock.Controller) *MockRedisClientInterface { + mock := &MockRedisClientInterface{ctrl: ctrl} + mock.recorder = &MockRedisClientInterfaceMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockRedisClientInterface) EXPECT() *MockRedisClientInterfaceMockRecorder { + return m.recorder +} + +// Do mocks base method +func (m *MockRedisClientInterface) Do(arg0 radix.Action) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Do", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Do indicates an expected call of Do +func (mr *MockRedisClientInterfaceMockRecorder) Do(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Do", reflect.TypeOf((*MockRedisClientInterface)(nil).Do), arg0) +} diff --git a/test/mocks/storage/strategy/storage_strategy_mock.go b/test/mocks/storage/strategy/storage_strategy_mock.go new file mode 100644 index 000000000..d7418059c --- /dev/null +++ b/test/mocks/storage/strategy/storage_strategy_mock.go @@ -0,0 +1,90 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: ./src/storage/strategy/storage_strategy.go + +// Package mock_strategy is a generated GoMock package. 
+package mock_strategy + +import ( + gomock "github.com/golang/mock/gomock" + reflect "reflect" +) + +// MockStorageStrategy is a mock of StorageStrategy interface +type MockStorageStrategy struct { + ctrl *gomock.Controller + recorder *MockStorageStrategyMockRecorder +} + +// MockStorageStrategyMockRecorder is the mock recorder for MockStorageStrategy +type MockStorageStrategyMockRecorder struct { + mock *MockStorageStrategy +} + +// NewMockStorageStrategy creates a new mock instance +func NewMockStorageStrategy(ctrl *gomock.Controller) *MockStorageStrategy { + mock := &MockStorageStrategy{ctrl: ctrl} + mock.recorder = &MockStorageStrategyMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockStorageStrategy) EXPECT() *MockStorageStrategyMockRecorder { + return m.recorder +} + +// GetValue mocks base method +func (m *MockStorageStrategy) GetValue(key string) (uint64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetValue", key) + ret0, _ := ret[0].(uint64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetValue indicates an expected call of GetValue +func (mr *MockStorageStrategyMockRecorder) GetValue(key interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValue", reflect.TypeOf((*MockStorageStrategy)(nil).GetValue), key) +} + +// SetValue mocks base method +func (m *MockStorageStrategy) SetValue(key string, value, expirationSeconds uint64) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetValue", key, value, expirationSeconds) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetValue indicates an expected call of SetValue +func (mr *MockStorageStrategyMockRecorder) SetValue(key, value, expirationSeconds interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetValue", reflect.TypeOf((*MockStorageStrategy)(nil).SetValue), key, value, 
expirationSeconds) +} + +// IncrementValue mocks base method +func (m *MockStorageStrategy) IncrementValue(key string, delta uint64) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IncrementValue", key, delta) + ret0, _ := ret[0].(error) + return ret0 +} + +// IncrementValue indicates an expected call of IncrementValue +func (mr *MockStorageStrategyMockRecorder) IncrementValue(key, delta interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IncrementValue", reflect.TypeOf((*MockStorageStrategy)(nil).IncrementValue), key, delta) +} + +// SetExpire mocks base method +func (m *MockStorageStrategy) SetExpire(key string, expirationSeconds uint64) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetExpire", key, expirationSeconds) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetExpire indicates an expected call of SetExpire +func (mr *MockStorageStrategyMockRecorder) SetExpire(key, expirationSeconds interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetExpire", reflect.TypeOf((*MockStorageStrategy)(nil).SetExpire), key, expirationSeconds) +} diff --git a/test/redis/bench_test.go b/test/redis/bench_test.go deleted file mode 100644 index 37bca1848..000000000 --- a/test/redis/bench_test.go +++ /dev/null @@ -1,96 +0,0 @@ -package redis_test - -import ( - "context" - "github.com/envoyproxy/ratelimit/test/mocks/stats" - "runtime" - "testing" - "time" - - pb "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v3" - "github.com/envoyproxy/ratelimit/src/config" - "github.com/envoyproxy/ratelimit/src/redis" - "github.com/envoyproxy/ratelimit/src/utils" - gostats "github.com/lyft/gostats" - - "math/rand" - - "github.com/envoyproxy/ratelimit/test/common" -) - -func BenchmarkParallelDoLimit(b *testing.B) { - b.Skip("Skip benchmark") - - b.ReportAllocs() - - // See https://github.com/mediocregopher/radix/blob/v3.5.1/bench/bench_test.go#L176 - 
parallel := runtime.GOMAXPROCS(0) - poolSize := parallel * runtime.GOMAXPROCS(0) - - do := func(b *testing.B, fn func() error) { - b.ResetTimer() - b.SetParallelism(parallel) - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - if err := fn(); err != nil { - b.Fatal(err) - } - } - }) - } - - mkDoLimitBench := func(pipelineWindow time.Duration, pipelineLimit int) func(*testing.B) { - return func(b *testing.B) { - statsStore := gostats.NewStore(gostats.NewNullSink(), false) - sm := stats.NewMockStatManager(statsStore) - client := redis.NewClientImpl(statsStore, false, "", "single", "127.0.0.1:6379", poolSize, pipelineWindow, pipelineLimit) - defer client.Close() - - cache := redis.NewFixedRateLimitCacheImpl(client, nil, utils.NewTimeSourceImpl(), rand.New(utils.NewLockedSource(time.Now().Unix())), 10, nil, 0.8, "", sm) - request := common.NewRateLimitRequest("domain", [][][2]string{{{"key", "value"}}}, 1) - limits := []*config.RateLimit{config.NewRateLimit(1000000000, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key_value"))} - - // wait for the pool to fill up - for { - time.Sleep(50 * time.Millisecond) - if client.NumActiveConns() >= poolSize { - break - } - } - - b.ResetTimer() - - do(b, func() error { - cache.DoLimit(context.Background(), request, limits) - return nil - }) - } - } - - b.Run("no pipeline", mkDoLimitBench(0, 0)) - - b.Run("pipeline 35us 1", mkDoLimitBench(35*time.Microsecond, 1)) - b.Run("pipeline 75us 1", mkDoLimitBench(75*time.Microsecond, 1)) - b.Run("pipeline 150us 1", mkDoLimitBench(150*time.Microsecond, 1)) - b.Run("pipeline 300us 1", mkDoLimitBench(300*time.Microsecond, 1)) - - b.Run("pipeline 35us 2", mkDoLimitBench(35*time.Microsecond, 2)) - b.Run("pipeline 75us 2", mkDoLimitBench(75*time.Microsecond, 2)) - b.Run("pipeline 150us 2", mkDoLimitBench(150*time.Microsecond, 2)) - b.Run("pipeline 300us 2", mkDoLimitBench(300*time.Microsecond, 2)) - - b.Run("pipeline 35us 4", mkDoLimitBench(35*time.Microsecond, 4)) - b.Run("pipeline 
75us 4", mkDoLimitBench(75*time.Microsecond, 4)) - b.Run("pipeline 150us 4", mkDoLimitBench(150*time.Microsecond, 4)) - b.Run("pipeline 300us 4", mkDoLimitBench(300*time.Microsecond, 4)) - - b.Run("pipeline 35us 8", mkDoLimitBench(35*time.Microsecond, 8)) - b.Run("pipeline 75us 8", mkDoLimitBench(75*time.Microsecond, 8)) - b.Run("pipeline 150us 8", mkDoLimitBench(150*time.Microsecond, 8)) - b.Run("pipeline 300us 8", mkDoLimitBench(300*time.Microsecond, 8)) - - b.Run("pipeline 35us 16", mkDoLimitBench(35*time.Microsecond, 16)) - b.Run("pipeline 75us 16", mkDoLimitBench(75*time.Microsecond, 16)) - b.Run("pipeline 150us 16", mkDoLimitBench(150*time.Microsecond, 16)) - b.Run("pipeline 300us 16", mkDoLimitBench(300*time.Microsecond, 16)) -} diff --git a/test/redis/driver_impl_test.go b/test/redis/driver_impl_test.go deleted file mode 100644 index ab488e239..000000000 --- a/test/redis/driver_impl_test.go +++ /dev/null @@ -1,205 +0,0 @@ -package redis_test - -import ( - "testing" - "time" - - "github.com/alicebob/miniredis/v2" - "github.com/envoyproxy/ratelimit/src/redis" - "github.com/lyft/gostats" - "github.com/stretchr/testify/assert" -) - -func mustNewRedisServer() *miniredis.Miniredis { - srv, err := miniredis.Run() - if err != nil { - panic(err) - } - - return srv -} - -func expectPanicError(t *testing.T, f assert.PanicTestFunc) (result error) { - t.Helper() - defer func() { - panicResult := recover() - assert.NotNil(t, panicResult, "Expected a panic") - result = panicResult.(error) - }() - f() - return -} - -func testNewClientImpl(t *testing.T, pipelineWindow time.Duration, pipelineLimit int) func(t *testing.T) { - return func(t *testing.T) { - redisAuth := "123" - statsStore := stats.NewStore(stats.NewNullSink(), false) - - mkRedisClient := func(auth, addr string) redis.Client { - return redis.NewClientImpl(statsStore, false, auth, "single", addr, 1, pipelineWindow, pipelineLimit) - } - - t.Run("connection refused", func(t *testing.T) { - // It's possible there is 
a redis server listening on 6379 in ci environment, so - // use a random port. - panicErr := expectPanicError(t, func() { mkRedisClient("", "localhost:12345") }) - assert.Contains(t, panicErr.Error(), "connection refused") - }) - - t.Run("ok", func(t *testing.T) { - redisSrv := mustNewRedisServer() - defer redisSrv.Close() - - var client redis.Client - assert.NotPanics(t, func() { - client = mkRedisClient("", redisSrv.Addr()) - }) - assert.NotNil(t, client) - }) - - t.Run("auth fail", func(t *testing.T) { - redisSrv := mustNewRedisServer() - defer redisSrv.Close() - - redisSrv.RequireAuth(redisAuth) - - assert.PanicsWithError(t, "NOAUTH Authentication required.", func() { - mkRedisClient("", redisSrv.Addr()) - }) - }) - - t.Run("auth pass", func(t *testing.T) { - redisSrv := mustNewRedisServer() - defer redisSrv.Close() - - redisSrv.RequireAuth(redisAuth) - - assert.NotPanics(t, func() { - mkRedisClient(redisAuth, redisSrv.Addr()) - }) - }) - - t.Run("ImplicitPipeliningEnabled() return expected value", func(t *testing.T) { - redisSrv := mustNewRedisServer() - defer redisSrv.Close() - - client := mkRedisClient("", redisSrv.Addr()) - - if pipelineWindow == 0 && pipelineLimit == 0 { - assert.False(t, client.ImplicitPipeliningEnabled()) - } else { - assert.True(t, client.ImplicitPipeliningEnabled()) - } - }) - } -} - -func TestNewClientImpl(t *testing.T) { - t.Run("ImplicitPipeliningEnabled", testNewClientImpl(t, 2*time.Millisecond, 2)) - t.Run("ImplicitPipeliningDisabled", testNewClientImpl(t, 0, 0)) -} - -func TestDoCmd(t *testing.T) { - statsStore := stats.NewStore(stats.NewNullSink(), false) - - mkRedisClient := func(addr string) redis.Client { - return redis.NewClientImpl(statsStore, false, "", "single", addr, 1, 0, 0) - } - - t.Run("SETGET ok", func(t *testing.T) { - redisSrv := mustNewRedisServer() - defer redisSrv.Close() - - client := mkRedisClient(redisSrv.Addr()) - var res string - - assert.Nil(t, client.DoCmd(nil, "SET", "foo", "bar")) - assert.Nil(t, 
client.DoCmd(&res, "GET", "foo")) - assert.Equal(t, "bar", res) - }) - - t.Run("INCRBY ok", func(t *testing.T) { - redisSrv := mustNewRedisServer() - defer redisSrv.Close() - - client := mkRedisClient(redisSrv.Addr()) - var res uint32 - hits := uint32(1) - - assert.Nil(t, client.DoCmd(&res, "INCRBY", "a", hits)) - assert.Equal(t, hits, res) - assert.Nil(t, client.DoCmd(&res, "INCRBY", "a", hits)) - assert.Equal(t, uint32(2), res) - }) - - t.Run("connection broken", func(t *testing.T) { - redisSrv := mustNewRedisServer() - client := mkRedisClient(redisSrv.Addr()) - - assert.Nil(t, client.DoCmd(nil, "SET", "foo", "bar")) - - redisSrv.Close() - assert.EqualError(t, client.DoCmd(nil, "GET", "foo"), "EOF") - }) -} - -func testPipeDo(t *testing.T, pipelineWindow time.Duration, pipelineLimit int) func(t *testing.T) { - return func(t *testing.T) { - statsStore := stats.NewStore(stats.NewNullSink(), false) - - mkRedisClient := func(addr string) redis.Client { - return redis.NewClientImpl(statsStore, false, "", "single", addr, 1, pipelineWindow, pipelineLimit) - } - - t.Run("SETGET ok", func(t *testing.T) { - redisSrv := mustNewRedisServer() - defer redisSrv.Close() - - client := mkRedisClient(redisSrv.Addr()) - var res string - - pipeline := redis.Pipeline{} - pipeline = client.PipeAppend(pipeline, nil, "SET", "foo", "bar") - pipeline = client.PipeAppend(pipeline, &res, "GET", "foo") - - assert.Nil(t, client.PipeDo(pipeline)) - assert.Equal(t, "bar", res) - }) - - t.Run("INCRBY ok", func(t *testing.T) { - redisSrv := mustNewRedisServer() - defer redisSrv.Close() - - client := mkRedisClient(redisSrv.Addr()) - var res uint32 - hits := uint32(1) - - assert.Nil(t, client.PipeDo(client.PipeAppend(redis.Pipeline{}, &res, "INCRBY", "a", hits))) - assert.Equal(t, hits, res) - - assert.Nil(t, client.PipeDo(client.PipeAppend(redis.Pipeline{}, &res, "INCRBY", "a", hits))) - assert.Equal(t, uint32(2), res) - }) - - t.Run("connection broken", func(t *testing.T) { - redisSrv := 
mustNewRedisServer() - client := mkRedisClient(redisSrv.Addr()) - - assert.Nil(t, nil, client.PipeDo(client.PipeAppend(redis.Pipeline{}, nil, "SET", "foo", "bar"))) - - redisSrv.Close() - - expectErrContainEOF := func(t *testing.T, err error) { - assert.NotNil(t, err) - assert.Contains(t, err.Error(), "EOF") - } - - expectErrContainEOF(t, client.PipeDo(client.PipeAppend(redis.Pipeline{}, nil, "GET", "foo"))) - }) - } -} - -func TestPipeDo(t *testing.T) { - t.Run("ImplicitPipeliningEnabled", testPipeDo(t, 10*time.Millisecond, 2)) - t.Run("ImplicitPipeliningDisabled", testPipeDo(t, 0, 0)) -} diff --git a/test/redis/fixed_cache_impl_test.go b/test/redis/fixed_cache_impl_test.go index e07233330..62d71210d 100644 --- a/test/redis/fixed_cache_impl_test.go +++ b/test/redis/fixed_cache_impl_test.go @@ -1,26 +1,23 @@ package redis_test import ( - "github.com/envoyproxy/ratelimit/test/mocks/stats" + "math/rand" "testing" "github.com/coocood/freecache" - "github.com/mediocregopher/radix/v3" - - pb "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v3" "github.com/envoyproxy/ratelimit/src/config" "github.com/envoyproxy/ratelimit/src/limiter" "github.com/envoyproxy/ratelimit/src/redis" "github.com/envoyproxy/ratelimit/src/utils" - gostats "github.com/lyft/gostats" - - "math/rand" - "github.com/envoyproxy/ratelimit/test/common" - mock_redis "github.com/envoyproxy/ratelimit/test/mocks/redis" - mock_utils "github.com/envoyproxy/ratelimit/test/mocks/utils" + "github.com/envoyproxy/ratelimit/test/mocks/stats" "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" + + pb "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v3" + mock_strategy "github.com/envoyproxy/ratelimit/test/mocks/storage/strategy" + mock_utils "github.com/envoyproxy/ratelimit/test/mocks/utils" + gostats "github.com/lyft/gostats" ) func TestRedis(t *testing.T) { @@ -28,10 +25,6 @@ func TestRedis(t *testing.T) { t.Run("WithPerSecondRedis", testRedis(true)) } -func 
pipeAppend(pipeline redis.Pipeline, rcv interface{}, cmd, key string, args ...interface{}) redis.Pipeline { - return append(pipeline, radix.FlatCmd(rcv, cmd, key, args...)) -} - func testRedis(usePerSecondRedis bool) func(*testing.T) { return func(t *testing.T) { assert := assert.New(t) @@ -40,27 +33,27 @@ func testRedis(usePerSecondRedis bool) func(*testing.T) { statsStore := gostats.NewStore(gostats.NewNullSink(), false) sm := stats.NewMockStatManager(statsStore) - client := mock_redis.NewMockClient(controller) - perSecondClient := mock_redis.NewMockClient(controller) + client := mock_strategy.NewMockStorageStrategy(controller) + perSecondClient := mock_strategy.NewMockStorageStrategy(controller) timeSource := mock_utils.NewMockTimeSource(controller) + var cache limiter.RateLimitCache if usePerSecondRedis { - cache = redis.NewFixedRateLimitCacheImpl(client, perSecondClient, timeSource, rand.New(rand.NewSource(1)), 0, nil, 0.8, "", sm) + cache = redis.NewFixedRateLimitCacheImpl(client, perSecondClient, timeSource, rand.New(rand.NewSource(1)), nil, 0, 0.8, "", sm) } else { - cache = redis.NewFixedRateLimitCacheImpl(client, nil, timeSource, rand.New(rand.NewSource(1)), 0, nil, 0.8, "", sm) + cache = redis.NewFixedRateLimitCacheImpl(client, nil, timeSource, rand.New(rand.NewSource(1)), nil, 0, 0.8, "", sm) } - timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) - var clientUsed *mock_redis.MockClient + var clientUsed *mock_strategy.MockStorageStrategy if usePerSecondRedis { clientUsed = perSecondClient } else { clientUsed = client } - - clientUsed.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key_value_1234", uint32(1)).SetArg(1, uint32(5)).DoAndReturn(pipeAppend) - clientUsed.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "EXPIRE", "domain_key_value_1234", int64(1)).DoAndReturn(pipeAppend) - clientUsed.EXPECT().PipeDo(gomock.Any()).Return(nil) + timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) + 
clientUsed.EXPECT().GetValue("domain_key_value_1234").Return(uint64(4), nil).MaxTimes(1) + clientUsed.EXPECT().IncrementValue("domain_key_value_1234", uint64(1)).MaxTimes(1) + clientUsed.EXPECT().SetExpire("domain_key_value_1234", uint64(1)).MaxTimes(1) request := common.NewRateLimitRequest("domain", [][][2]string{{{"key", "value"}}}, 1) limits := []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key_value"))} @@ -75,10 +68,9 @@ func testRedis(usePerSecondRedis bool) func(*testing.T) { clientUsed = client timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) - clientUsed.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key2_value2_subkey2_subvalue2_1200", uint32(1)).SetArg(1, uint32(11)).DoAndReturn(pipeAppend) - clientUsed.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), - "EXPIRE", "domain_key2_value2_subkey2_subvalue2_1200", int64(60)).DoAndReturn(pipeAppend) - clientUsed.EXPECT().PipeDo(gomock.Any()).Return(nil) + clientUsed.EXPECT().GetValue("domain_key2_value2_subkey2_subvalue2_1200").Return(uint64(10), nil).MaxTimes(1) + clientUsed.EXPECT().IncrementValue("domain_key2_value2_subkey2_subvalue2_1200", uint64(1)).MaxTimes(1) + clientUsed.EXPECT().SetExpire("domain_key2_value2_subkey2_subvalue2_1200", uint64(60)).MaxTimes(1) request = common.NewRateLimitRequest( "domain", @@ -100,13 +92,13 @@ func testRedis(usePerSecondRedis bool) func(*testing.T) { clientUsed = client timeSource.EXPECT().UnixNow().Return(int64(1000000)).MaxTimes(5) - clientUsed.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key3_value3_997200", uint32(1)).SetArg(1, uint32(11)).DoAndReturn(pipeAppend) - clientUsed.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), - "EXPIRE", "domain_key3_value3_997200", int64(3600)).DoAndReturn(pipeAppend) - clientUsed.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key3_value3_subkey3_subvalue3_950400", uint32(1)).SetArg(1, 
uint32(13)).DoAndReturn(pipeAppend) - clientUsed.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), - "EXPIRE", "domain_key3_value3_subkey3_subvalue3_950400", int64(86400)).DoAndReturn(pipeAppend) - clientUsed.EXPECT().PipeDo(gomock.Any()).Return(nil) + + clientUsed.EXPECT().GetValue("domain_key3_value3_997200").Return(uint64(10), nil).MaxTimes(1) + clientUsed.EXPECT().GetValue("domain_key3_value3_subkey3_subvalue3_950400").Return(uint64(12), nil).MaxTimes(1) + clientUsed.EXPECT().IncrementValue("domain_key3_value3_997200", uint64(1)).MaxTimes(1) + clientUsed.EXPECT().IncrementValue("domain_key3_value3_subkey3_subvalue3_950400", uint64(1)).MaxTimes(1) + clientUsed.EXPECT().SetExpire("domain_key3_value3_997200", uint64(3600)).MaxTimes(1) + clientUsed.EXPECT().SetExpire("domain_key3_value3_subkey3_subvalue3_950400", uint64(86400)).MaxTimes(1) request = common.NewRateLimitRequest( "domain", @@ -174,24 +166,22 @@ func TestOverLimitWithLocalCache(t *testing.T) { controller := gomock.NewController(t) defer controller.Finish() - client := mock_redis.NewMockClient(controller) + client := mock_strategy.NewMockStorageStrategy(controller) timeSource := mock_utils.NewMockTimeSource(controller) localCache := freecache.NewCache(100) statsStore := gostats.NewStore(gostats.NewNullSink(), false) sm := stats.NewMockStatManager(statsStore) - cache := redis.NewFixedRateLimitCacheImpl(client, nil, timeSource, rand.New(rand.NewSource(1)), 0, localCache, 0.8, "", sm) + cache := redis.NewFixedRateLimitCacheImpl(client, nil, timeSource, rand.New(rand.NewSource(1)), localCache, 0, 0.8, "", sm) sink := &common.TestStatSink{} localCacheStats := limiter.NewLocalCacheStats(localCache, statsStore.Scope("localcache")) // Test Near Limit Stats. 
Under Near Limit Ratio timeSource.EXPECT().UnixNow().Return(int64(1000000)).MaxTimes(3) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key4_value4_997200", uint32(1)).SetArg(1, uint32(11)).DoAndReturn(pipeAppend) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), - "EXPIRE", "domain_key4_value4_997200", int64(3600)).DoAndReturn(pipeAppend) - client.EXPECT().PipeDo(gomock.Any()).Return(nil) + client.EXPECT().GetValue("domain_key4_value4_997200").Return(uint64(10), nil).MaxTimes(1) + client.EXPECT().IncrementValue("domain_key4_value4_997200", uint64(1)).MaxTimes(1) + client.EXPECT().SetExpire("domain_key4_value4_997200", uint64(3600)).MaxTimes(1) request := common.NewRateLimitRequest("domain", [][][2]string{{{"key4", "value4"}}}, 1) - limits := []*config.RateLimit{ config.NewRateLimit(15, pb.RateLimitResponse_RateLimit_HOUR, sm.NewStats("key4_value4"))} @@ -210,10 +200,9 @@ func TestOverLimitWithLocalCache(t *testing.T) { // Test Near Limit Stats. At Near Limit Ratio, still OK timeSource.EXPECT().UnixNow().Return(int64(1000000)).MaxTimes(3) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key4_value4_997200", uint32(1)).SetArg(1, uint32(13)).DoAndReturn(pipeAppend) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), - "EXPIRE", "domain_key4_value4_997200", int64(3600)).DoAndReturn(pipeAppend) - client.EXPECT().PipeDo(gomock.Any()).Return(nil) + client.EXPECT().GetValue("domain_key4_value4_997200").Return(uint64(12), nil).MaxTimes(1) + client.EXPECT().IncrementValue("domain_key4_value4_997200", uint64(1)).MaxTimes(1) + client.EXPECT().SetExpire("domain_key4_value4_997200", uint64(3600)).MaxTimes(1) assert.Equal( []*pb.RateLimitResponse_DescriptorStatus{ @@ -230,10 +219,9 @@ func TestOverLimitWithLocalCache(t *testing.T) { // Test Over limit stats timeSource.EXPECT().UnixNow().Return(int64(1000000)).MaxTimes(3) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key4_value4_997200", 
uint32(1)).SetArg(1, uint32(16)).DoAndReturn(pipeAppend) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), - "EXPIRE", "domain_key4_value4_997200", int64(3600)).DoAndReturn(pipeAppend) - client.EXPECT().PipeDo(gomock.Any()).Return(nil) + client.EXPECT().GetValue("domain_key4_value4_997200").Return(uint64(15), nil).MaxTimes(1) + client.EXPECT().IncrementValue("domain_key4_value4_997200", uint64(1)).MaxTimes(1) + client.EXPECT().SetExpire("domain_key4_value4_997200", uint64(3600)).MaxTimes(1) assert.Equal( []*pb.RateLimitResponse_DescriptorStatus{ @@ -250,9 +238,9 @@ func TestOverLimitWithLocalCache(t *testing.T) { // Test Over limit stats with local cache timeSource.EXPECT().UnixNow().Return(int64(1000000)).MaxTimes(3) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key4_value4_997200", uint32(1)).Times(0) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), - "EXPIRE", "domain_key4_value4_997200", int64(3600)).Times(0) + client.EXPECT().IncrementValue("domain_key4_value4_997200", uint64(1)).MaxTimes(1) + client.EXPECT().SetExpire("domain_key4_value4_997200", uint64(3600)).MaxTimes(1) + assert.Equal( []*pb.RateLimitResponse_DescriptorStatus{ {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, @@ -272,18 +260,17 @@ func TestNearLimit(t *testing.T) { controller := gomock.NewController(t) defer controller.Finish() - client := mock_redis.NewMockClient(controller) + client := mock_strategy.NewMockStorageStrategy(controller) timeSource := mock_utils.NewMockTimeSource(controller) statsStore := gostats.NewStore(gostats.NewNullSink(), false) sm := stats.NewMockStatManager(statsStore) - cache := redis.NewFixedRateLimitCacheImpl(client, nil, timeSource, rand.New(rand.NewSource(1)), 0, nil, 0.8, "", sm) + cache := redis.NewFixedRateLimitCacheImpl(client, nil, timeSource, rand.New(rand.NewSource(1)), nil, 0, 0.8, "", sm) // Test 
Near Limit Stats. Under Near Limit Ratio timeSource.EXPECT().UnixNow().Return(int64(1000000)).MaxTimes(3) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key4_value4_997200", uint32(1)).SetArg(1, uint32(11)).DoAndReturn(pipeAppend) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), - "EXPIRE", "domain_key4_value4_997200", int64(3600)).DoAndReturn(pipeAppend) - client.EXPECT().PipeDo(gomock.Any()).Return(nil) + client.EXPECT().GetValue("domain_key4_value4_997200").Return(uint64(10), nil).MaxTimes(1) + client.EXPECT().IncrementValue("domain_key4_value4_997200", uint64(1)).MaxTimes(1) + client.EXPECT().SetExpire("domain_key4_value4_997200", uint64(3600)).MaxTimes(1) request := common.NewRateLimitRequest("domain", [][][2]string{{{"key4", "value4"}}}, 1) @@ -301,10 +288,9 @@ func TestNearLimit(t *testing.T) { // Test Near Limit Stats. At Near Limit Ratio, still OK timeSource.EXPECT().UnixNow().Return(int64(1000000)).MaxTimes(3) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key4_value4_997200", uint32(1)).SetArg(1, uint32(13)).DoAndReturn(pipeAppend) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), - "EXPIRE", "domain_key4_value4_997200", int64(3600)).DoAndReturn(pipeAppend) - client.EXPECT().PipeDo(gomock.Any()).Return(nil) + client.EXPECT().GetValue("domain_key4_value4_997200").Return(uint64(12), nil).MaxTimes(1) + client.EXPECT().IncrementValue("domain_key4_value4_997200", uint64(1)).MaxTimes(1) + client.EXPECT().SetExpire("domain_key4_value4_997200", uint64(3600)).MaxTimes(1) assert.Equal( []*pb.RateLimitResponse_DescriptorStatus{ @@ -318,10 +304,9 @@ func TestNearLimit(t *testing.T) { // Test Near Limit Stats. We went OVER_LIMIT, but the near_limit counter only increases // when we are near limit, not after we have passed the limit. 
timeSource.EXPECT().UnixNow().Return(int64(1000000)).MaxTimes(3) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key4_value4_997200", uint32(1)).SetArg(1, uint32(16)).DoAndReturn(pipeAppend) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), - "EXPIRE", "domain_key4_value4_997200", int64(3600)).DoAndReturn(pipeAppend) - client.EXPECT().PipeDo(gomock.Any()).Return(nil) + client.EXPECT().GetValue("domain_key4_value4_997200").Return(uint64(15), nil).MaxTimes(1) + client.EXPECT().IncrementValue("domain_key4_value4_997200", uint64(1)).MaxTimes(1) + client.EXPECT().SetExpire("domain_key4_value4_997200", uint64(3600)).MaxTimes(1) assert.Equal( []*pb.RateLimitResponse_DescriptorStatus{ @@ -335,9 +320,9 @@ func TestNearLimit(t *testing.T) { // Now test hitsAddend that is greater than 1 // All of it under limit, under near limit timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key5_value5_1234", uint32(3)).SetArg(1, uint32(5)).DoAndReturn(pipeAppend) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "EXPIRE", "domain_key5_value5_1234", int64(1)).DoAndReturn(pipeAppend) - client.EXPECT().PipeDo(gomock.Any()).Return(nil) + client.EXPECT().GetValue("domain_key5_value5_1234").Return(uint64(2), nil).MaxTimes(1) + client.EXPECT().IncrementValue("domain_key5_value5_1234", uint64(3)).MaxTimes(1) + client.EXPECT().SetExpire("domain_key5_value5_1234", uint64(1)).MaxTimes(1) request = common.NewRateLimitRequest("domain", [][][2]string{{{"key5", "value5"}}}, 3) limits = []*config.RateLimit{config.NewRateLimit(20, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key5_value5"))} @@ -352,9 +337,9 @@ func TestNearLimit(t *testing.T) { // All of it under limit, some over near limit timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key6_value6_1234", uint32(2)).SetArg(1, 
uint32(7)).DoAndReturn(pipeAppend) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "EXPIRE", "domain_key6_value6_1234", int64(1)).DoAndReturn(pipeAppend) - client.EXPECT().PipeDo(gomock.Any()).Return(nil) + client.EXPECT().GetValue("domain_key6_value6_1234").Return(uint64(5), nil).MaxTimes(1) + client.EXPECT().IncrementValue("domain_key6_value6_1234", uint64(2)).MaxTimes(1) + client.EXPECT().SetExpire("domain_key6_value6_1234", uint64(1)).MaxTimes(1) request = common.NewRateLimitRequest("domain", [][][2]string{{{"key6", "value6"}}}, 2) limits = []*config.RateLimit{config.NewRateLimit(8, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key6_value6"))} @@ -369,9 +354,9 @@ func TestNearLimit(t *testing.T) { // All of it under limit, all of it over near limit timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key7_value7_1234", uint32(3)).SetArg(1, uint32(19)).DoAndReturn(pipeAppend) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "EXPIRE", "domain_key7_value7_1234", int64(1)).DoAndReturn(pipeAppend) - client.EXPECT().PipeDo(gomock.Any()).Return(nil) + client.EXPECT().GetValue("domain_key7_value7_1234").Return(uint64(16), nil).MaxTimes(1) + client.EXPECT().IncrementValue("domain_key7_value7_1234", uint64(3)).MaxTimes(1) + client.EXPECT().SetExpire("domain_key7_value7_1234", uint64(1)).MaxTimes(1) request = common.NewRateLimitRequest("domain", [][][2]string{{{"key7", "value7"}}}, 3) limits = []*config.RateLimit{config.NewRateLimit(20, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key7_value7"))} @@ -386,9 +371,9 @@ func TestNearLimit(t *testing.T) { // Some of it over limit, all of it over near limit timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key8_value8_1234", uint32(3)).SetArg(1, uint32(22)).DoAndReturn(pipeAppend) - client.EXPECT().PipeAppend(gomock.Any(), 
gomock.Any(), "EXPIRE", "domain_key8_value8_1234", int64(1)).DoAndReturn(pipeAppend) - client.EXPECT().PipeDo(gomock.Any()).Return(nil) + client.EXPECT().GetValue("domain_key8_value8_1234").Return(uint64(19), nil).MaxTimes(1) + client.EXPECT().IncrementValue("domain_key8_value8_1234", uint64(3)).MaxTimes(1) + client.EXPECT().SetExpire("domain_key8_value8_1234", uint64(1)).MaxTimes(1) request = common.NewRateLimitRequest("domain", [][][2]string{{{"key8", "value8"}}}, 3) limits = []*config.RateLimit{config.NewRateLimit(20, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key8_value8"))} @@ -403,9 +388,9 @@ func TestNearLimit(t *testing.T) { // Some of it in all three places timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key9_value9_1234", uint32(7)).SetArg(1, uint32(22)).DoAndReturn(pipeAppend) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "EXPIRE", "domain_key9_value9_1234", int64(1)).DoAndReturn(pipeAppend) - client.EXPECT().PipeDo(gomock.Any()).Return(nil) + client.EXPECT().GetValue("domain_key9_value9_1234").Return(uint64(15), nil).MaxTimes(1) + client.EXPECT().IncrementValue("domain_key9_value9_1234", uint64(7)).MaxTimes(1) + client.EXPECT().SetExpire("domain_key9_value9_1234", uint64(1)).MaxTimes(1) request = common.NewRateLimitRequest("domain", [][][2]string{{{"key9", "value9"}}}, 7) limits = []*config.RateLimit{config.NewRateLimit(20, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key9_value9"))} @@ -420,9 +405,9 @@ func TestNearLimit(t *testing.T) { // all of it over limit timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key10_value10_1234", uint32(3)).SetArg(1, uint32(30)).DoAndReturn(pipeAppend) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "EXPIRE", "domain_key10_value10_1234", int64(1)).DoAndReturn(pipeAppend) - 
client.EXPECT().PipeDo(gomock.Any()).Return(nil) + client.EXPECT().GetValue("domain_key10_value10_1234").Return(uint64(27), nil).MaxTimes(1) + client.EXPECT().IncrementValue("domain_key10_value10_1234", uint64(3)).MaxTimes(1) + client.EXPECT().SetExpire("domain_key10_value10_1234", uint64(1)).MaxTimes(1) request = common.NewRateLimitRequest("domain", [][][2]string{{{"key10", "value10"}}}, 3) limits = []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key10_value10"))} @@ -441,18 +426,18 @@ func TestRedisWithJitter(t *testing.T) { controller := gomock.NewController(t) defer controller.Finish() - client := mock_redis.NewMockClient(controller) + client := mock_strategy.NewMockStorageStrategy(controller) timeSource := mock_utils.NewMockTimeSource(controller) jitterSource := mock_utils.NewMockJitterRandSource(controller) statsStore := gostats.NewStore(gostats.NewNullSink(), false) sm := stats.NewMockStatManager(statsStore) - cache := redis.NewFixedRateLimitCacheImpl(client, nil, timeSource, rand.New(jitterSource), 3600, nil, 0.8, "", sm) + cache := redis.NewFixedRateLimitCacheImpl(client, nil, timeSource, rand.New(jitterSource), nil, 3600, 0.8, "", sm) timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) - jitterSource.EXPECT().Int63().Return(int64(100)) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key_value_1234", uint32(1)).SetArg(1, uint32(5)).DoAndReturn(pipeAppend) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "EXPIRE", "domain_key_value_1234", int64(101)).DoAndReturn(pipeAppend) - client.EXPECT().PipeDo(gomock.Any()).Return(nil) + jitterSource.EXPECT().Int63().Return(int64(100)).MaxTimes(1) + client.EXPECT().GetValue("domain_key_value_1234").Return(uint64(4), nil).MaxTimes(1) + client.EXPECT().IncrementValue("domain_key_value_1234", uint64(1)).MaxTimes(1) + client.EXPECT().SetExpire("domain_key_value_1234", uint64(101)).MaxTimes(1) request := 
common.NewRateLimitRequest("domain", [][][2]string{{{"key", "value"}}}, 1) limits := []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key_value"))} diff --git a/test/srv/srv_test.go b/test/srv/srv_test.go deleted file mode 100644 index 5e3e8f79f..000000000 --- a/test/srv/srv_test.go +++ /dev/null @@ -1,56 +0,0 @@ -package srv - -import ( - "errors" - "net" - "testing" - - "github.com/envoyproxy/ratelimit/src/srv" - "github.com/stretchr/testify/assert" -) - -func TestParseSrv(t *testing.T) { - service, proto, name, err := srv.ParseSrv("_something._tcp.example.org.") - assert.Equal(t, service, "something") - assert.Equal(t, proto, "tcp") - assert.Equal(t, name, "example.org.") - assert.Nil(t, err) - - service, proto, name, err = srv.ParseSrv("_something-else._udp.example.org") - assert.Equal(t, service, "something-else") - assert.Equal(t, proto, "udp") - assert.Equal(t, name, "example.org") - assert.Nil(t, err) - - _, _, _, err = srv.ParseSrv("example.org") - assert.Equal(t, err, errors.New("could not parse example.org to SRV parts")) -} - -func TestServerStringsFromSrvWhenSrvIsNotWellFormed(t *testing.T) { - _, err := srv.ServerStringsFromSrv("example.org") - assert.Equal(t, err, errors.New("could not parse example.org to SRV parts")) -} - -func TestServerStringsFromSevWhenSrvIsWellFormedButNotLookupable(t *testing.T) { - _, err := srv.ServerStringsFromSrv("_something._tcp.example.invalid") - var e *net.DNSError - if errors.As(err, &e) { - assert.Equal(t, e.Err, "no such host") - assert.Equal(t, e.Name, "_something._tcp.example.invalid") - assert.False(t, e.IsTimeout) - assert.False(t, e.IsTemporary) - assert.True(t, e.IsNotFound) - } else { - t.Fail() - } -} - -func TestServerStrings(t *testing.T) { - // it seems reasonable to think _xmpp-server._tcp.gmail.com will be available for a long time! 
- servers, err := srv.ServerStringsFromSrv("_xmpp-server._tcp.gmail.com.") - assert.True(t, len(servers) > 0) - for _, s := range servers { - assert.Regexp(t, `^.*xmpp-server.*google.com.:\d+$`, s) - } - assert.Nil(t, err) -} diff --git a/test/storage/factory/memcached_factory_test.go b/test/storage/factory/memcached_factory_test.go new file mode 100644 index 000000000..4fb9b3846 --- /dev/null +++ b/test/storage/factory/memcached_factory_test.go @@ -0,0 +1,38 @@ +package factory_test + +import ( + "testing" + + "github.com/envoyproxy/ratelimit/src/storage/factory" + "github.com/envoyproxy/ratelimit/src/storage/strategy" + "github.com/stretchr/testify/assert" + + stats "github.com/lyft/gostats" +) + +func TestNewMemcachedClient(t *testing.T) { + statsStore := stats.NewStore(stats.NewNullSink(), false) + mkMemcachedClient := func(addr []string) strategy.StorageStrategy { + return factory.NewMemcached(statsStore, addr, "", 0, 2) + } + + t.Run("empty server", func(t *testing.T) { + storage := mkMemcachedClient([]string{}) + _, err := storage.GetValue("test") + assert.Error(t, err) + }) +} + +func TestNewRateLimitCacheImplFromSettingsWhenSrvCannotBeResolved(t *testing.T) { + statsStore := stats.NewStore(stats.NewNullSink(), false) + assert.Panics(t, func() { + factory.NewMemcached(statsStore, []string{}, "_something._tcp.example.invalid", 0, 2) + }) +} + +func TestNewRateLimitCacheImplFromSettingsWhenHostAndPortAndSrvAreBothSet(t *testing.T) { + statsStore := stats.NewStore(stats.NewNullSink(), false) + assert.Panics(t, func() { + factory.NewMemcached(statsStore, []string{"example.org:11211"}, "_something._tcp.example.invalid", 0, 2) + }) +} diff --git a/test/storage/factory/redis_factory_test.go b/test/storage/factory/redis_factory_test.go new file mode 100644 index 000000000..7073ab65b --- /dev/null +++ b/test/storage/factory/redis_factory_test.go @@ -0,0 +1,89 @@ +package factory_test + +import ( + "testing" + "time" + + "github.com/alicebob/miniredis" + 
"github.com/envoyproxy/ratelimit/src/storage/factory" + "github.com/envoyproxy/ratelimit/src/storage/strategy" + "github.com/stretchr/testify/assert" + + stats "github.com/lyft/gostats" +) + +func mustNewRedisServer() *miniredis.Miniredis { + srv, err := miniredis.Run() + if err != nil { + panic(err) + } + + return srv +} + +func expectPanicError(t *testing.T, f assert.PanicTestFunc) (result error) { + t.Helper() + defer func() { + panicResult := recover() + assert.NotNil(t, panicResult, "Expected a panic") + result = panicResult.(error) + }() + f() + return +} + +func TestNewRedisClient(t *testing.T) { + t.Run("ImplicitPipeliningEnabled", testNewRedisClient(t, 2*time.Millisecond, 2)) + t.Run("ImplicitPipeliningDisabled", testNewRedisClient(t, 0, 0)) +} + +func testNewRedisClient(t *testing.T, pipelineWindow time.Duration, pipelineLimit int) func(t *testing.T) { + return func(t *testing.T) { + redisAuth := "123" + statsStore := stats.NewStore(stats.NewNullSink(), false) + + mkRedisClient := func(auth, addr string) strategy.StorageStrategy { + return factory.NewRedis(statsStore, false, auth, "single", addr, 1, pipelineWindow, pipelineLimit) + } + + t.Run("connection refused", func(t *testing.T) { + // It's possible there is a redis server listening on 6379 in ci environment, so + // use a random port. 
+ panicErr := expectPanicError(t, func() { mkRedisClient("", "localhost:12345") }) + assert.Contains(t, panicErr.Error(), "connection refused") + }) + + t.Run("ok", func(t *testing.T) { + redisSrv := mustNewRedisServer() + defer redisSrv.Close() + + var client strategy.StorageStrategy + assert.NotPanics(t, func() { + client = mkRedisClient("", redisSrv.Addr()) + }) + assert.NotNil(t, client) + }) + + t.Run("auth fail", func(t *testing.T) { + redisSrv := mustNewRedisServer() + defer redisSrv.Close() + + redisSrv.RequireAuth(redisAuth) + + assert.PanicsWithError(t, "NOAUTH Authentication required.", func() { + mkRedisClient("", redisSrv.Addr()) + }) + }) + + t.Run("auth pass", func(t *testing.T) { + redisSrv := mustNewRedisServer() + defer redisSrv.Close() + + redisSrv.RequireAuth(redisAuth) + + assert.NotPanics(t, func() { + mkRedisClient(redisAuth, redisSrv.Addr()) + }) + }) + } +} diff --git a/test/storage/strategy/memcached_strategy_test.go b/test/storage/strategy/memcached_strategy_test.go new file mode 100644 index 000000000..563f96031 --- /dev/null +++ b/test/storage/strategy/memcached_strategy_test.go @@ -0,0 +1,66 @@ +package strategy_test + +import ( + "strconv" + "testing" + + "github.com/bradfitz/gomemcache/memcache" + "github.com/envoyproxy/ratelimit/src/storage/strategy" + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + + mock_service "github.com/envoyproxy/ratelimit/test/mocks/storage/service" +) + +func TestMemcachedStrategyGetValue(t *testing.T) { + assert := assert.New(t) + controller := gomock.NewController(t) + defer controller.Finish() + + mockMemcachedClient := mock_service.NewMockMemcachedClientInterface(controller) + memcachedStrategy := strategy.MemcachedStrategy{ + Client: mockMemcachedClient, + } + + mockMemcachedClient.EXPECT().Get("key").Return(&memcache.Item{Key: "key", Value: []byte("5")}, nil) + value, err := memcachedStrategy.GetValue("key") + + assert.Equal(value, uint64(5)) + assert.Nil(err) +} + +func 
TestMemcachedStrategySetValue(t *testing.T) { + assert := assert.New(t) + controller := gomock.NewController(t) + defer controller.Finish() + + mockMemcachedClient := mock_service.NewMockMemcachedClientInterface(controller) + memcachedStrategy := strategy.MemcachedStrategy{ + Client: mockMemcachedClient, + } + + mockMemcachedClient.EXPECT().Set(&memcache.Item{ + Key: "key", + Value: []byte(strconv.FormatUint(uint64(5), 10)), + Expiration: int32(5), + }).Return(nil) + + err := memcachedStrategy.SetValue("key", uint64(5), uint64(5)) + assert.Nil(err) +} + +func TestMemcachedStrategyIncrementValue(t *testing.T) { + assert := assert.New(t) + controller := gomock.NewController(t) + defer controller.Finish() + + mockMemcachedClient := mock_service.NewMockMemcachedClientInterface(controller) + memcachedStrategy := strategy.MemcachedStrategy{ + Client: mockMemcachedClient, + } + + mockMemcachedClient.EXPECT().Increment("key", uint64(1)).Return(uint64(1), nil) + + err := memcachedStrategy.IncrementValue("key", uint64(1)) + assert.Nil(err) +} diff --git a/test/storage/strategy/redis_strategy_test.go b/test/storage/strategy/redis_strategy_test.go new file mode 100644 index 000000000..d1ba3d124 --- /dev/null +++ b/test/storage/strategy/redis_strategy_test.go @@ -0,0 +1,72 @@ +package strategy_test + +import ( + "testing" + + "github.com/alicebob/miniredis" + "github.com/envoyproxy/ratelimit/src/storage/strategy" + "github.com/golang/mock/gomock" + "github.com/mediocregopher/radix/v3" + "github.com/stretchr/testify/assert" + + mock_service "github.com/envoyproxy/ratelimit/test/mocks/storage/service" +) + +func mustNewRedisServer() *miniredis.Miniredis { + srv, err := miniredis.Run() + if err != nil { + panic(err) + } + + return srv +} +func TestRedisStrategyGetValue(t *testing.T) { + assert := assert.New(t) + controller := gomock.NewController(t) + defer controller.Finish() + + mockRedisClient := mock_service.NewMockRedisClientInterface(controller) + redisStrategy := 
strategy.RedisStrategy{ + Client: mockRedisClient, + } + + var value uint64 + mockRedisClient.EXPECT().Do(radix.Cmd(&value, "GET", "key")).Return(nil) + value, err := redisStrategy.GetValue("key") + + assert.Equal(value, uint64(0)) + assert.Nil(err) +} + +func TestRedisStrategySetValue(t *testing.T) { + assert := assert.New(t) + controller := gomock.NewController(t) + defer controller.Finish() + + mockRedisClient := mock_service.NewMockRedisClientInterface(controller) + redisStrategy := strategy.RedisStrategy{ + Client: mockRedisClient, + } + + mockRedisClient.EXPECT().Do(radix.FlatCmd(nil, "SET", "key", uint64(5))).Return(nil) + mockRedisClient.EXPECT().Do(radix.FlatCmd(nil, "EXPIRE", "key", uint64(5))).Return(nil) + + err := redisStrategy.SetValue("key", uint64(5), uint64(5)) + assert.Nil(err) +} + +func TestRedisStrategyIncrementValue(t *testing.T) { + assert := assert.New(t) + controller := gomock.NewController(t) + defer controller.Finish() + + mockRedisClient := mock_service.NewMockRedisClientInterface(controller) + redisStrategy := strategy.RedisStrategy{ + Client: mockRedisClient, + } + + mockRedisClient.EXPECT().Do(radix.FlatCmd(nil, "INCRBY", "key", uint64(1))).Return(nil) + + err := redisStrategy.IncrementValue("key", uint64(1)) + assert.Nil(err) +}