From c45c62e639426d45411a6562f8dbaec2d2229f18 Mon Sep 17 00:00:00 2001 From: zufardhiyaulhaq Date: Sun, 11 Apr 2021 07:17:25 +0200 Subject: [PATCH 1/8] refactor storage impelementation Signed-off-by: zufardhiyaulhaq --- go.mod | 1 + go.sum | 1 + src/memcached/cache_impl.go | 182 +----- src/memcached/client.go | 14 - src/memcached/fixed_cache_impl.go | 170 +++++ src/memcached/stats_collecting_client.go | 80 --- src/redis/cache_impl.go | 11 +- src/redis/driver.go | 49 -- src/redis/driver_impl.go | 164 ----- src/redis/fixed_cache_impl.go | 68 +- src/storage/factory/memcached_factory.go | 21 + src/storage/factory/redis_factory.go | 97 +++ src/storage/service/memcached_client.go | 27 + src/storage/service/redis_client.go | 18 + src/storage/strategy/memcached_strategy.go | 50 ++ src/storage/strategy/redis_strategy.go | 43 ++ src/storage/strategy/storage_strategy.go | 7 + src/storage/utils/utils.go | 13 + test/memcached/cache_impl_test.go | 594 ------------------ .../memcached/stats_collecting_client_test.go | 199 ------ test/mocks/redis/redis.go | 128 ---- test/redis/bench_test.go | 94 --- test/redis/driver_impl_test.go | 205 ------ test/redis/fixed_cache_impl_test.go | 462 -------------- 24 files changed, 501 insertions(+), 2197 deletions(-) delete mode 100644 src/memcached/client.go create mode 100644 src/memcached/fixed_cache_impl.go delete mode 100644 src/memcached/stats_collecting_client.go delete mode 100644 src/redis/driver.go delete mode 100644 src/redis/driver_impl.go create mode 100644 src/storage/factory/memcached_factory.go create mode 100644 src/storage/factory/redis_factory.go create mode 100644 src/storage/service/memcached_client.go create mode 100644 src/storage/service/redis_client.go create mode 100644 src/storage/strategy/memcached_strategy.go create mode 100644 src/storage/strategy/redis_strategy.go create mode 100644 src/storage/strategy/storage_strategy.go create mode 100644 src/storage/utils/utils.go delete mode 100644 
test/memcached/cache_impl_test.go delete mode 100644 test/memcached/stats_collecting_client_test.go delete mode 100644 test/mocks/redis/redis.go delete mode 100644 test/redis/bench_test.go delete mode 100644 test/redis/driver_impl_test.go delete mode 100644 test/redis/fixed_cache_impl_test.go diff --git a/go.mod b/go.mod index 1c282fcd9..b36e9e9f9 100644 --- a/go.mod +++ b/go.mod @@ -23,6 +23,7 @@ require ( golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553 golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e // indirect golang.org/x/text v0.3.3-0.20191122225017-cbf43d21aaeb // indirect + google.golang.org/appengine v1.4.0 google.golang.org/grpc v1.27.0 google.golang.org/protobuf v1.25.0 // indirect gopkg.in/yaml.v2 v2.3.0 diff --git a/go.sum b/go.sum index 071a59b37..7c2ffd741 100644 --- a/go.sum +++ b/go.sum @@ -126,6 +126,7 @@ golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= +google.golang.org/appengine v1.4.0 h1:/wp5JvzpHIxhs/dumFmF7BXTf3Z+dd4uXta4kVyO508= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= diff --git a/src/memcached/cache_impl.go b/src/memcached/cache_impl.go index 1e7b0b69a..90f349a94 100644 --- a/src/memcached/cache_impl.go +++ b/src/memcached/cache_impl.go @@ -1,195 +1,21 @@ -// The memcached limiter uses GetMulti() to check keys in parallel and then does -// increments asynchronously in the backend, since the memcache interface doesn't -// 
support multi-increment and it seems worthwhile to minimize the number of -// concurrent or sequential RPCs in the critical path. -// -// Another difference from redis is that memcache doesn't create a key implicitly by -// incrementing a missing entry. Instead, when increment fails an explicit "add" needs -// to be called. The process of increment becomes a bit of a dance since we try to -// limit the number of RPCs. First we call increment, then add if the increment -// failed, then increment again if the add failed (which could happen if there was -// a race to call "add"). -// -// Note that max memcache key length is 250 characters. Attempting to get or increment -// a longer key will return memcache.ErrMalformedKey - package memcached import ( - "context" "math/rand" - "strconv" - "sync" "github.com/coocood/freecache" - stats "github.com/lyft/gostats" - - "github.com/bradfitz/gomemcache/memcache" - - logger "github.com/sirupsen/logrus" - - pb "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v3" - - "github.com/envoyproxy/ratelimit/src/config" "github.com/envoyproxy/ratelimit/src/limiter" "github.com/envoyproxy/ratelimit/src/settings" "github.com/envoyproxy/ratelimit/src/utils" -) - -type rateLimitMemcacheImpl struct { - client Client - timeSource utils.TimeSource - jitterRand *rand.Rand - expirationJitterMaxSeconds int64 - localCache *freecache.Cache - waitGroup sync.WaitGroup - nearLimitRatio float32 - baseRateLimiter *limiter.BaseRateLimiter -} - -var AutoFlushForIntegrationTests bool = false - -var _ limiter.RateLimitCache = (*rateLimitMemcacheImpl)(nil) - -func (this *rateLimitMemcacheImpl) DoLimit( - ctx context.Context, - request *pb.RateLimitRequest, - limits []*config.RateLimit) []*pb.RateLimitResponse_DescriptorStatus { - - logger.Debugf("starting cache lookup") - - // request.HitsAddend could be 0 (default value) if not specified by the caller in the Ratelimit request. 
- hitsAddend := utils.Max(1, request.HitsAddend) - - // First build a list of all cache keys that we are actually going to hit. - cacheKeys := this.baseRateLimiter.GenerateCacheKeys(request, limits, hitsAddend) - - isOverLimitWithLocalCache := make([]bool, len(request.Descriptors)) - - keysToGet := make([]string, 0, len(request.Descriptors)) - - for i, cacheKey := range cacheKeys { - if cacheKey.Key == "" { - continue - } - - // Check if key is over the limit in local cache. - if this.baseRateLimiter.IsOverLimitWithLocalCache(cacheKey.Key) { - isOverLimitWithLocalCache[i] = true - logger.Debugf("cache key is over the limit: %s", cacheKey.Key) - continue - } - - logger.Debugf("looking up cache key: %s", cacheKey.Key) - keysToGet = append(keysToGet, cacheKey.Key) - } - - // Now fetch from memcache. - responseDescriptorStatuses := make([]*pb.RateLimitResponse_DescriptorStatus, - len(request.Descriptors)) - - var memcacheValues map[string]*memcache.Item - var err error - - if len(keysToGet) > 0 { - memcacheValues, err = this.client.GetMulti(keysToGet) - if err != nil { - logger.Errorf("Error multi-getting memcache keys (%s): %s", keysToGet, err) - } - } - - for i, cacheKey := range cacheKeys { - - rawMemcacheValue, ok := memcacheValues[cacheKey.Key] - var limitBeforeIncrease uint32 - if ok { - decoded, err := strconv.ParseInt(string(rawMemcacheValue.Value), 10, 32) - if err != nil { - logger.Errorf("Unexpected non-numeric value in memcached: %v", rawMemcacheValue) - } else { - limitBeforeIncrease = uint32(decoded) - } - - } - - limitAfterIncrease := limitBeforeIncrease + hitsAddend - - limitInfo := limiter.NewRateLimitInfo(limits[i], limitBeforeIncrease, limitAfterIncrease, 0, 0) - - responseDescriptorStatuses[i] = this.baseRateLimiter.GetResponseDescriptorStatus(cacheKey.Key, - limitInfo, isOverLimitWithLocalCache[i], hitsAddend) - } - - this.waitGroup.Add(1) - go this.increaseAsync(cacheKeys, isOverLimitWithLocalCache, limits, uint64(hitsAddend)) - if 
AutoFlushForIntegrationTests { - this.Flush() - } - - return responseDescriptorStatuses -} - -func (this *rateLimitMemcacheImpl) increaseAsync(cacheKeys []limiter.CacheKey, isOverLimitWithLocalCache []bool, - limits []*config.RateLimit, hitsAddend uint64) { - defer this.waitGroup.Done() - for i, cacheKey := range cacheKeys { - if cacheKey.Key == "" || isOverLimitWithLocalCache[i] { - continue - } - - _, err := this.client.Increment(cacheKey.Key, hitsAddend) - if err == memcache.ErrCacheMiss { - expirationSeconds := utils.UnitToDivider(limits[i].Limit.Unit) - if this.expirationJitterMaxSeconds > 0 { - expirationSeconds += this.jitterRand.Int63n(this.expirationJitterMaxSeconds) - } - - // Need to add instead of increment. - err = this.client.Add(&memcache.Item{ - Key: cacheKey.Key, - Value: []byte(strconv.FormatUint(hitsAddend, 10)), - Expiration: int32(expirationSeconds), - }) - if err == memcache.ErrNotStored { - // There was a race condition to do this add. We should be able to increment - // now instead. 
- _, err := this.client.Increment(cacheKey.Key, hitsAddend) - if err != nil { - logger.Errorf("Failed to increment key %s after failing to add: %s", cacheKey.Key, err) - continue - } - } else if err != nil { - logger.Errorf("Failed to add key %s: %s", cacheKey.Key, err) - continue - } - } else if err != nil { - logger.Errorf("Failed to increment key %s: %s", cacheKey.Key, err) - continue - } - } -} - -func (this *rateLimitMemcacheImpl) Flush() { - this.waitGroup.Wait() -} + stats "github.com/lyft/gostats" -func NewRateLimitCacheImpl(client Client, timeSource utils.TimeSource, jitterRand *rand.Rand, - expirationJitterMaxSeconds int64, localCache *freecache.Cache, scope stats.Scope, nearLimitRatio float32, cacheKeyPrefix string) limiter.RateLimitCache { - return &rateLimitMemcacheImpl{ - client: client, - timeSource: timeSource, - jitterRand: jitterRand, - expirationJitterMaxSeconds: expirationJitterMaxSeconds, - localCache: localCache, - nearLimitRatio: nearLimitRatio, - baseRateLimiter: limiter.NewBaseRateLimit(timeSource, jitterRand, expirationJitterMaxSeconds, localCache, nearLimitRatio, cacheKeyPrefix), - } -} + storage_factory "github.com/envoyproxy/ratelimit/src/storage/factory" +) func NewRateLimitCacheImplFromSettings(s settings.Settings, timeSource utils.TimeSource, jitterRand *rand.Rand, localCache *freecache.Cache, scope stats.Scope) limiter.RateLimitCache { return NewRateLimitCacheImpl( - CollectStats(memcache.New(s.MemcacheHostPort...), scope.Scope("memcache")), + storage_factory.NewMemcached(s.MemcacheHostPort), timeSource, jitterRand, s.ExpirationJitterMaxSeconds, diff --git a/src/memcached/client.go b/src/memcached/client.go deleted file mode 100644 index 55c0ec318..000000000 --- a/src/memcached/client.go +++ /dev/null @@ -1,14 +0,0 @@ -package memcached - -import ( - "github.com/bradfitz/gomemcache/memcache" -) - -var _ Client = (*memcache.Client)(nil) - -// Interface for memcached, used for mocking. 
-type Client interface { - GetMulti(keys []string) (map[string]*memcache.Item, error) - Increment(key string, delta uint64) (newValue uint64, err error) - Add(item *memcache.Item) error -} diff --git a/src/memcached/fixed_cache_impl.go b/src/memcached/fixed_cache_impl.go new file mode 100644 index 000000000..15f7a8240 --- /dev/null +++ b/src/memcached/fixed_cache_impl.go @@ -0,0 +1,170 @@ +// The memcached limiter uses GetMulti() to check keys in parallel and then does +// increments asynchronously in the backend, since the memcache interface doesn't +// support multi-increment and it seems worthwhile to minimize the number of +// concurrent or sequential RPCs in the critical path. +// +// Another difference from redis is that memcache doesn't create a key implicitly by +// incrementing a missing entry. Instead, when increment fails an explicit "add" needs +// to be called. The process of increment becomes a bit of a dance since we try to +// limit the number of RPCs. First we call increment, then add if the increment +// failed, then increment again if the add failed (which could happen if there was +// a race to call "add"). +// +// Note that max memcache key length is 250 characters. 
Attempting to get or increment +// a longer key will return memcache.ErrMalformedKey + +package memcached + +import ( + "context" + "math/rand" + "sync" + + "github.com/coocood/freecache" + stats "github.com/lyft/gostats" + + "github.com/bradfitz/gomemcache/memcache" + + logger "github.com/sirupsen/logrus" + + pb "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v3" + + "github.com/envoyproxy/ratelimit/src/config" + "github.com/envoyproxy/ratelimit/src/limiter" + "github.com/envoyproxy/ratelimit/src/utils" + + storage_strategy "github.com/envoyproxy/ratelimit/src/storage/strategy" +) + +type rateLimitMemcacheImpl struct { + client storage_strategy.StorageStrategy + timeSource utils.TimeSource + jitterRand *rand.Rand + expirationJitterMaxSeconds int64 + localCache *freecache.Cache + waitGroup sync.WaitGroup + nearLimitRatio float32 + baseRateLimiter *limiter.BaseRateLimiter +} + +var AutoFlushForIntegrationTests bool = false + +var _ limiter.RateLimitCache = (*rateLimitMemcacheImpl)(nil) + +func (this *rateLimitMemcacheImpl) DoLimit( + ctx context.Context, + request *pb.RateLimitRequest, + limits []*config.RateLimit) []*pb.RateLimitResponse_DescriptorStatus { + + logger.Debugf("starting cache lookup") + + // request.HitsAddend could be 0 (default value) if not specified by the caller in the RateLimit request. + hitsAddend := utils.Max(1, request.HitsAddend) + + // First build a list of all cache keys that we are actually going to hit. + cacheKeys := this.baseRateLimiter.GenerateCacheKeys(request, limits, hitsAddend) + + isOverLimitWithLocalCache := make([]bool, len(request.Descriptors)) + results := make([]uint64, len(request.Descriptors)) + + // Now, actually setup the pipeline, skipping empty cache keys. + for i, cacheKey := range cacheKeys { + if cacheKey.Key == "" { + continue + } + + // Check if key is over the limit in local cache. 
+ if this.baseRateLimiter.IsOverLimitWithLocalCache(cacheKey.Key) { + isOverLimitWithLocalCache[i] = true + logger.Debugf("cache key is over the limit: %s", cacheKey.Key) + continue + } + + logger.Debugf("looking up cache key: %s", cacheKey.Key) + + expirationSeconds := utils.UnitToDivider(limits[i].Limit.Unit) + if this.baseRateLimiter.ExpirationJitterMaxSeconds > 0 { + expirationSeconds += this.baseRateLimiter.JitterRand.Int63n(this.baseRateLimiter.ExpirationJitterMaxSeconds) + } + + // Use the perSecondConn if it is not nil and the cacheKey represents a per second Limit. + value, err := this.client.GetValue(cacheKey.Key) + if err != nil { + logger.Error(err) + } + results[i] = value + + } + + // Now fetch the pipeline. + responseDescriptorStatuses := make([]*pb.RateLimitResponse_DescriptorStatus, + len(request.Descriptors)) + for i, cacheKey := range cacheKeys { + + limitBeforeIncrease := uint32(results[i]) + limitAfterIncrease := limitBeforeIncrease + hitsAddend + + limitInfo := limiter.NewRateLimitInfo(limits[i], limitBeforeIncrease, limitAfterIncrease, 0, 0) + + responseDescriptorStatuses[i] = this.baseRateLimiter.GetResponseDescriptorStatus(cacheKey.Key, + limitInfo, isOverLimitWithLocalCache[i], hitsAddend) + } + + this.waitGroup.Add(1) + go this.increaseAsync(cacheKeys, isOverLimitWithLocalCache, limits, uint64(hitsAddend)) + + return responseDescriptorStatuses +} + +func (this *rateLimitMemcacheImpl) increaseAsync(cacheKeys []limiter.CacheKey, isOverLimitWithLocalCache []bool, + limits []*config.RateLimit, hitsAddend uint64) { + defer this.waitGroup.Done() + for i, cacheKey := range cacheKeys { + if cacheKey.Key == "" || isOverLimitWithLocalCache[i] { + continue + } + + err := this.client.IncrementValue(cacheKey.Key, hitsAddend) + if err == memcache.ErrCacheMiss { + expirationSeconds := utils.UnitToDivider(limits[i].Limit.Unit) + if this.expirationJitterMaxSeconds > 0 { + expirationSeconds += this.jitterRand.Int63n(this.expirationJitterMaxSeconds) + } + + 
// Need to add instead of increment. + err = this.client.SetValue(cacheKey.Key, hitsAddend, uint64(expirationSeconds)) + if err == memcache.ErrNotStored { + // There was a race condition to do this add. We should be able to increment + // now instead. + err := this.client.IncrementValue(cacheKey.Key, hitsAddend) + if err != nil { + logger.Errorf("Failed to increment key %s after failing to add: %s", cacheKey.Key, err) + continue + } + } else if err != nil { + logger.Errorf("Failed to add key %s: %s", cacheKey.Key, err) + continue + } + } else if err != nil { + logger.Errorf("Failed to increment key %s: %s", cacheKey.Key, err) + continue + } + } +} + +func (this *rateLimitMemcacheImpl) Flush() { + this.waitGroup.Wait() +} + +func NewRateLimitCacheImpl(client storage_strategy.StorageStrategy, timeSource utils.TimeSource, jitterRand *rand.Rand, + expirationJitterMaxSeconds int64, localCache *freecache.Cache, scope stats.Scope, nearLimitRatio float32, cacheKeyPrefix string) limiter.RateLimitCache { + return &rateLimitMemcacheImpl{ + client: client, + timeSource: timeSource, + jitterRand: jitterRand, + expirationJitterMaxSeconds: expirationJitterMaxSeconds, + localCache: localCache, + nearLimitRatio: nearLimitRatio, + baseRateLimiter: limiter.NewBaseRateLimit(timeSource, jitterRand, expirationJitterMaxSeconds, localCache, nearLimitRatio, cacheKeyPrefix), + } +} diff --git a/src/memcached/stats_collecting_client.go b/src/memcached/stats_collecting_client.go deleted file mode 100644 index 12b67bad5..000000000 --- a/src/memcached/stats_collecting_client.go +++ /dev/null @@ -1,80 +0,0 @@ -package memcached - -import ( - "github.com/bradfitz/gomemcache/memcache" - stats "github.com/lyft/gostats" -) - -type statsCollectingClient struct { - c Client - - multiGetSuccess stats.Counter - multiGetError stats.Counter - incrementSuccess stats.Counter - incrementMiss stats.Counter - incrementError stats.Counter - addSuccess stats.Counter - addError stats.Counter - addNotStored 
stats.Counter - keysRequested stats.Counter - keysFound stats.Counter -} - -func CollectStats(c Client, scope stats.Scope) Client { - return statsCollectingClient{ - c: c, - multiGetSuccess: scope.NewCounterWithTags("multiget", map[string]string{"code": "success"}), - multiGetError: scope.NewCounterWithTags("multiget", map[string]string{"code": "error"}), - incrementSuccess: scope.NewCounterWithTags("increment", map[string]string{"code": "success"}), - incrementMiss: scope.NewCounterWithTags("increment", map[string]string{"code": "miss"}), - incrementError: scope.NewCounterWithTags("increment", map[string]string{"code": "error"}), - addSuccess: scope.NewCounterWithTags("add", map[string]string{"code": "success"}), - addError: scope.NewCounterWithTags("add", map[string]string{"code": "error"}), - addNotStored: scope.NewCounterWithTags("add", map[string]string{"code": "not_stored"}), - keysRequested: scope.NewCounter("keys_requested"), - keysFound: scope.NewCounter("keys_found"), - } -} - -func (scc statsCollectingClient) GetMulti(keys []string) (map[string]*memcache.Item, error) { - scc.keysRequested.Add(uint64(len(keys))) - - results, err := scc.c.GetMulti(keys) - - if err != nil { - scc.multiGetError.Inc() - } else { - scc.keysFound.Add(uint64(len(results))) - scc.multiGetSuccess.Inc() - } - - return results, err -} - -func (scc statsCollectingClient) Increment(key string, delta uint64) (newValue uint64, err error) { - newValue, err = scc.c.Increment(key, delta) - switch err { - case memcache.ErrCacheMiss: - scc.incrementMiss.Inc() - case nil: - scc.incrementSuccess.Inc() - default: - scc.incrementError.Inc() - } - return -} - -func (scc statsCollectingClient) Add(item *memcache.Item) error { - err := scc.c.Add(item) - - switch err { - case memcache.ErrNotStored: - scc.addNotStored.Inc() - case nil: - scc.addSuccess.Inc() - default: - scc.addError.Inc() - } - - return err -} diff --git a/src/redis/cache_impl.go b/src/redis/cache_impl.go index cb4446234..004846677 
100644 --- a/src/redis/cache_impl.go +++ b/src/redis/cache_impl.go @@ -8,16 +8,19 @@ import ( "github.com/envoyproxy/ratelimit/src/server" "github.com/envoyproxy/ratelimit/src/settings" "github.com/envoyproxy/ratelimit/src/utils" + + storage_factory "github.com/envoyproxy/ratelimit/src/storage/factory" + storage_strategy "github.com/envoyproxy/ratelimit/src/storage/strategy" ) func NewRateLimiterCacheImplFromSettings(s settings.Settings, localCache *freecache.Cache, srv server.Server, timeSource utils.TimeSource, jitterRand *rand.Rand, expirationJitterMaxSeconds int64) limiter.RateLimitCache { - var perSecondPool Client + var perSecondPool storage_strategy.StorageStrategy if s.RedisPerSecond { - perSecondPool = NewClientImpl(srv.Scope().Scope("redis_per_second_pool"), s.RedisPerSecondTls, s.RedisPerSecondAuth, + perSecondPool = storage_factory.NewRedis(s.RedisPerSecondTls, s.RedisPerSecondAuth, s.RedisPerSecondType, s.RedisPerSecondUrl, s.RedisPerSecondPoolSize, s.RedisPerSecondPipelineWindow, s.RedisPerSecondPipelineLimit) } - var otherPool Client - otherPool = NewClientImpl(srv.Scope().Scope("redis_pool"), s.RedisTls, s.RedisAuth, s.RedisType, s.RedisUrl, s.RedisPoolSize, + var otherPool storage_strategy.StorageStrategy + otherPool = storage_factory.NewRedis(s.RedisTls, s.RedisAuth, s.RedisType, s.RedisUrl, s.RedisPoolSize, s.RedisPipelineWindow, s.RedisPipelineLimit) return NewFixedRateLimitCacheImpl( diff --git a/src/redis/driver.go b/src/redis/driver.go deleted file mode 100644 index 7ffc0c7b7..000000000 --- a/src/redis/driver.go +++ /dev/null @@ -1,49 +0,0 @@ -package redis - -import "github.com/mediocregopher/radix/v3" - -// Errors that may be raised during config parsing. -type RedisError string - -func (e RedisError) Error() string { - return string(e) -} - -// Interface for a redis client. -type Client interface { - // DoCmd is used to perform a redis command and retrieve a result. - // - // @param rcv supplies receiver for the result. 
- // @param cmd supplies the command to append. - // @param key supplies the key to append. - // @param args supplies the additional arguments. - DoCmd(rcv interface{}, cmd, key string, args ...interface{}) error - - // PipeAppend append a command onto the pipeline queue. - // - // @param pipeline supplies the queue for pending commands. - // @param rcv supplies receiver for the result. - // @param cmd supplies the command to append. - // @param key supplies the key to append. - // @param args supplies the additional arguments. - PipeAppend(pipeline Pipeline, rcv interface{}, cmd, key string, args ...interface{}) Pipeline - - // PipeDo writes multiple commands to a Conn in - // a single write, then reads their responses in a single read. This reduces - // network delay into a single round-trip. - // - // @param pipeline supplies the queue for pending commands. - PipeDo(pipeline Pipeline) error - - // Once Close() is called all future method calls on the Client will return - // an error - Close() error - - // NumActiveConns return number of active connections, used in testing. - NumActiveConns() int - - // ImplicitPipeliningEnabled return true if implicit pipelining is enabled. 
- ImplicitPipeliningEnabled() bool -} - -type Pipeline []radix.CmdAction diff --git a/src/redis/driver_impl.go b/src/redis/driver_impl.go deleted file mode 100644 index 18e213f1b..000000000 --- a/src/redis/driver_impl.go +++ /dev/null @@ -1,164 +0,0 @@ -package redis - -import ( - "crypto/tls" - "fmt" - "strings" - "time" - - "github.com/mediocregopher/radix/v3/trace" - - stats "github.com/lyft/gostats" - "github.com/mediocregopher/radix/v3" - logger "github.com/sirupsen/logrus" -) - -type poolStats struct { - connectionActive stats.Gauge - connectionTotal stats.Counter - connectionClose stats.Counter -} - -func newPoolStats(scope stats.Scope) poolStats { - ret := poolStats{} - ret.connectionActive = scope.NewGauge("cx_active") - ret.connectionTotal = scope.NewCounter("cx_total") - ret.connectionClose = scope.NewCounter("cx_local_close") - return ret -} - -func poolTrace(ps *poolStats) trace.PoolTrace { - return trace.PoolTrace{ - ConnCreated: func(_ trace.PoolConnCreated) { - ps.connectionTotal.Add(1) - ps.connectionActive.Add(1) - }, - ConnClosed: func(_ trace.PoolConnClosed) { - ps.connectionActive.Sub(1) - ps.connectionClose.Add(1) - }, - } -} - -type clientImpl struct { - client radix.Client - stats poolStats - implicitPipelining bool -} - -func checkError(err error) { - if err != nil { - panic(RedisError(err.Error())) - } -} - -func NewClientImpl(scope stats.Scope, useTls bool, auth string, redisType string, url string, poolSize int, - pipelineWindow time.Duration, pipelineLimit int) Client { - logger.Warnf("connecting to redis on %s with pool size %d", url, poolSize) - - df := func(network, addr string) (radix.Conn, error) { - var dialOpts []radix.DialOpt - - var err error - if useTls { - dialOpts = append(dialOpts, radix.DialUseTLS(&tls.Config{})) - } - - if err != nil { - return nil, err - } - if auth != "" { - logger.Warnf("enabling authentication to redis on %s", url) - - dialOpts = append(dialOpts, radix.DialAuthPass(auth)) - } - - return 
radix.Dial(network, addr, dialOpts...) - } - - stats := newPoolStats(scope) - - opts := []radix.PoolOpt{radix.PoolConnFunc(df), radix.PoolWithTrace(poolTrace(&stats))} - - implicitPipelining := true - if pipelineWindow == 0 && pipelineLimit == 0 { - implicitPipelining = false - } else { - opts = append(opts, radix.PoolPipelineWindow(pipelineWindow, pipelineLimit)) - } - logger.Debugf("Implicit pipelining enabled: %v", implicitPipelining) - - poolFunc := func(network, addr string) (radix.Client, error) { - return radix.NewPool(network, addr, poolSize, opts...) - } - - var client radix.Client - var err error - switch strings.ToLower(redisType) { - case "single": - client, err = poolFunc("tcp", url) - case "cluster": - urls := strings.Split(url, ",") - if implicitPipelining == false { - panic(RedisError("Implicit Pipelining must be enabled to work with Redis Cluster Mode. Set values for REDIS_PIPELINE_WINDOW or REDIS_PIPELINE_LIMIT to enable implicit pipelining")) - } - logger.Warnf("Creating cluster with urls %v", urls) - client, err = radix.NewCluster(urls, radix.ClusterPoolFunc(poolFunc)) - case "sentinel": - urls := strings.Split(url, ",") - if len(urls) < 2 { - panic(RedisError("Expected master name and a list of urls for the sentinels, in the format: ,,...,")) - } - client, err = radix.NewSentinel(urls[0], urls[1:], radix.SentinelPoolFunc(poolFunc)) - default: - panic(RedisError("Unrecognized redis type " + redisType)) - } - - checkError(err) - - // Check if connection is good - var pingResponse string - checkError(client.Do(radix.Cmd(&pingResponse, "PING"))) - if pingResponse != "PONG" { - checkError(fmt.Errorf("connecting redis error: %s", pingResponse)) - } - - return &clientImpl{ - client: client, - stats: stats, - implicitPipelining: implicitPipelining, - } -} - -func (c *clientImpl) DoCmd(rcv interface{}, cmd, key string, args ...interface{}) error { - return c.client.Do(radix.FlatCmd(rcv, cmd, key, args...)) -} - -func (c *clientImpl) Close() error { - 
return c.client.Close() -} - -func (c *clientImpl) NumActiveConns() int { - return int(c.stats.connectionActive.Value()) -} - -func (c *clientImpl) PipeAppend(pipeline Pipeline, rcv interface{}, cmd, key string, args ...interface{}) Pipeline { - return append(pipeline, radix.FlatCmd(rcv, cmd, key, args...)) -} - -func (c *clientImpl) PipeDo(pipeline Pipeline) error { - if c.implicitPipelining { - for _, action := range pipeline { - if err := c.client.Do(action); err != nil { - return err - } - } - return nil - } - - return c.client.Do(radix.Pipeline(pipeline...)) -} - -func (c *clientImpl) ImplicitPipeliningEnabled() bool { - return c.implicitPipelining -} diff --git a/src/redis/fixed_cache_impl.go b/src/redis/fixed_cache_impl.go index b2b3d3d24..16e5e85d5 100644 --- a/src/redis/fixed_cache_impl.go +++ b/src/redis/fixed_cache_impl.go @@ -2,29 +2,33 @@ package redis import ( "math/rand" + "sync" "github.com/coocood/freecache" pb "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v3" "github.com/envoyproxy/ratelimit/src/config" "github.com/envoyproxy/ratelimit/src/limiter" + storage_strategy "github.com/envoyproxy/ratelimit/src/storage/strategy" "github.com/envoyproxy/ratelimit/src/utils" logger "github.com/sirupsen/logrus" "golang.org/x/net/context" ) +type RedisError string + +func (e RedisError) Error() string { + return string(e) +} + type fixedRateLimitCacheImpl struct { - client Client + client storage_strategy.StorageStrategy // Optional Client for a dedicated cache of per second limits. // If this client is nil, then the Cache will use the client for all // limits regardless of unit. If this client is not nil, then it // is used for limits that have a SECOND unit. 
- perSecondClient Client + perSecondClient storage_strategy.StorageStrategy baseRateLimiter *limiter.BaseRateLimiter -} - -func pipelineAppend(client Client, pipeline *Pipeline, key string, hitsAddend uint32, result *uint32, expirationSeconds int64) { - *pipeline = client.PipeAppend(*pipeline, result, "INCRBY", key, hitsAddend) - *pipeline = client.PipeAppend(*pipeline, nil, "EXPIRE", key, expirationSeconds) + waitGroup sync.WaitGroup } func (this *fixedRateLimitCacheImpl) DoLimit( @@ -41,8 +45,7 @@ func (this *fixedRateLimitCacheImpl) DoLimit( cacheKeys := this.baseRateLimiter.GenerateCacheKeys(request, limits, hitsAddend) isOverLimitWithLocalCache := make([]bool, len(request.Descriptors)) - results := make([]uint32, len(request.Descriptors)) - var pipeline, perSecondPipeline Pipeline + results := make([]uint64, len(request.Descriptors)) // Now, actually setup the pipeline, skipping empty cache keys. for i, cacheKey := range cacheKeys { @@ -66,47 +69,60 @@ func (this *fixedRateLimitCacheImpl) DoLimit( // Use the perSecondConn if it is not nil and the cacheKey represents a per second Limit. if this.perSecondClient != nil && cacheKey.PerSecond { - if perSecondPipeline == nil { - perSecondPipeline = Pipeline{} + value, err := this.perSecondClient.GetValue(cacheKey.Key) + if err != nil { + logger.Error(err) } - pipelineAppend(this.perSecondClient, &perSecondPipeline, cacheKey.Key, hitsAddend, &results[i], expirationSeconds) + results[i] = value } else { - if pipeline == nil { - pipeline = Pipeline{} + value, err := this.client.GetValue(cacheKey.Key) + if err != nil { + logger.Error(err) } - pipelineAppend(this.client, &pipeline, cacheKey.Key, hitsAddend, &results[i], expirationSeconds) + results[i] = value } } - if pipeline != nil { - checkError(this.client.PipeDo(pipeline)) - } - if perSecondPipeline != nil { - checkError(this.perSecondClient.PipeDo(perSecondPipeline)) - } - // Now fetch the pipeline. 
responseDescriptorStatuses := make([]*pb.RateLimitResponse_DescriptorStatus, len(request.Descriptors)) for i, cacheKey := range cacheKeys { - limitAfterIncrease := results[i] - limitBeforeIncrease := limitAfterIncrease - hitsAddend + limitBeforeIncrease := uint32(results[i]) + limitAfterIncrease := limitBeforeIncrease + hitsAddend limitInfo := limiter.NewRateLimitInfo(limits[i], limitBeforeIncrease, limitAfterIncrease, 0, 0) responseDescriptorStatuses[i] = this.baseRateLimiter.GetResponseDescriptorStatus(cacheKey.Key, limitInfo, isOverLimitWithLocalCache[i], hitsAddend) - } + this.waitGroup.Add(1) + go this.increaseAsync(cacheKeys, isOverLimitWithLocalCache, limits, uint64(hitsAddend)) + return responseDescriptorStatuses } +func (this *fixedRateLimitCacheImpl) increaseAsync(cacheKeys []limiter.CacheKey, isOverLimitWithLocalCache []bool, + limits []*config.RateLimit, hitsAddend uint64) { + defer this.waitGroup.Done() + for i, cacheKey := range cacheKeys { + if cacheKey.Key == "" || isOverLimitWithLocalCache[i] { + continue + } + + if this.perSecondClient != nil && cacheKey.PerSecond { + this.perSecondClient.IncrementValue(cacheKey.Key, hitsAddend) + } else { + this.client.IncrementValue(cacheKey.Key, hitsAddend) + } + } +} + // Flush() is a no-op with redis since quota reads and updates happen synchronously. 
func (this *fixedRateLimitCacheImpl) Flush() {} -func NewFixedRateLimitCacheImpl(client Client, perSecondClient Client, timeSource utils.TimeSource, +func NewFixedRateLimitCacheImpl(client storage_strategy.StorageStrategy, perSecondClient storage_strategy.StorageStrategy, timeSource utils.TimeSource, jitterRand *rand.Rand, expirationJitterMaxSeconds int64, localCache *freecache.Cache, nearLimitRatio float32, cacheKeyPrefix string) limiter.RateLimitCache { return &fixedRateLimitCacheImpl{ client: client, diff --git a/src/storage/factory/memcached_factory.go b/src/storage/factory/memcached_factory.go new file mode 100644 index 000000000..f57248acc --- /dev/null +++ b/src/storage/factory/memcached_factory.go @@ -0,0 +1,21 @@ +package factory + +import ( + "github.com/bradfitz/gomemcache/memcache" + "github.com/envoyproxy/ratelimit/src/storage/service" + "github.com/envoyproxy/ratelimit/src/storage/strategy" +) + +func NewMemcached(servers []string) strategy.StorageStrategy { + client := newMemcachedClient(servers) + return strategy.MemcachedStrategy{ + Client: client, + } +} + +func newMemcachedClient(servers []string) service.MemcachedClientInterface { + client := memcache.New(servers...) 
+ return &service.MemcachedClient{ + Client: client, + } +} diff --git a/src/storage/factory/redis_factory.go b/src/storage/factory/redis_factory.go new file mode 100644 index 000000000..5bda8eac9 --- /dev/null +++ b/src/storage/factory/redis_factory.go @@ -0,0 +1,97 @@ +package factory + +import ( + "crypto/tls" + "fmt" + "strings" + "time" + + logger "github.com/sirupsen/logrus" + + "github.com/envoyproxy/ratelimit/src/storage/service" + "github.com/envoyproxy/ratelimit/src/storage/strategy" + "github.com/envoyproxy/ratelimit/src/storage/utils" + "github.com/mediocregopher/radix/v3" +) + +func NewRedis(useTls bool, auth string, redisType string, url string, poolSize int, + pipelineWindow time.Duration, pipelineLimit int) strategy.StorageStrategy { + client := newRedisClient(useTls, auth, redisType, url, poolSize, pipelineWindow, pipelineLimit) + return strategy.RedisStrategy{ + Client: client, + } +} + +func newRedisClient(useTls bool, auth string, redisType string, url string, poolSize int, pipelineWindow time.Duration, pipelineLimit int) service.RedisClientInterface { + logger.Warnf("connecting to redis on %s with pool size %d", url, poolSize) + + df := func(network, addr string) (radix.Conn, error) { + var dialOpts []radix.DialOpt + + var err error + if useTls { + dialOpts = append(dialOpts, radix.DialUseTLS(&tls.Config{})) + } + + if err != nil { + return nil, err + } + if auth != "" { + logger.Warnf("enabling authentication to redis on %s", url) + + dialOpts = append(dialOpts, radix.DialAuthPass(auth)) + } + + return radix.Dial(network, addr, dialOpts...) 
+ } + + opts := []radix.PoolOpt{radix.PoolConnFunc(df)} + + implicitPipelining := true + if pipelineWindow == 0 && pipelineLimit == 0 { + implicitPipelining = false + } else { + opts = append(opts, radix.PoolPipelineWindow(pipelineWindow, pipelineLimit)) + } + logger.Debugf("Implicit pipelining enabled: %v", implicitPipelining) + + poolFunc := func(network, addr string) (radix.Client, error) { + return radix.NewPool(network, addr, poolSize, opts...) + } + + var client radix.Client + var err error + switch strings.ToLower(redisType) { + case "single": + client, err = poolFunc("tcp", url) + case "cluster": + urls := strings.Split(url, ",") + if implicitPipelining == false { + panic(utils.RedisError("Implicit Pipelining must be enabled to work with Redis Cluster Mode. Set values for REDIS_PIPELINE_WINDOW or REDIS_PIPELINE_LIMIT to enable implicit pipelining")) + } + logger.Warnf("Creating cluster with urls %v", urls) + client, err = radix.NewCluster(urls, radix.ClusterPoolFunc(poolFunc)) + case "sentinel": + urls := strings.Split(url, ",") + if len(urls) < 2 { + panic(utils.RedisError("Expected master name and a list of urls for the sentinels, in the format: ,,...,")) + } + client, err = radix.NewSentinel(urls[0], urls[1:], radix.SentinelPoolFunc(poolFunc)) + default: + panic(utils.RedisError("Unrecognized redis type " + redisType)) + } + + utils.CheckError(err) + + // Check if connection is good + var pingResponse string + utils.CheckError(client.Do(radix.Cmd(&pingResponse, "PING"))) + if pingResponse != "PONG" { + utils.CheckError(fmt.Errorf("connecting redis error: %s", pingResponse)) + } + + return &service.RedisClient{ + Client: client, + ImplicitPipelining: implicitPipelining, + } +} diff --git a/src/storage/service/memcached_client.go b/src/storage/service/memcached_client.go new file mode 100644 index 000000000..b2cc07962 --- /dev/null +++ b/src/storage/service/memcached_client.go @@ -0,0 +1,27 @@ +package service + +import ( + 
"github.com/bradfitz/gomemcache/memcache" +) + +type MemcachedClientInterface interface { + Get(key string) (*memcache.Item, error) + Set(item *memcache.Item) error + Increment(key string, delta uint64) (uint64, error) +} + +type MemcachedClient struct { + Client *memcache.Client +} + +func (m MemcachedClient) Get(key string) (*memcache.Item, error) { + return m.Client.Get(key) +} + +func (m MemcachedClient) Set(item *memcache.Item) error { + return m.Client.Set(item) +} + +func (m MemcachedClient) Increment(key string, delta uint64) (uint64, error) { + return m.Client.Increment(key, delta) +} diff --git a/src/storage/service/redis_client.go b/src/storage/service/redis_client.go new file mode 100644 index 000000000..56ffbde01 --- /dev/null +++ b/src/storage/service/redis_client.go @@ -0,0 +1,18 @@ +package service + +import ( + "github.com/mediocregopher/radix/v3" +) + +type RedisClientInterface interface { + Do(radix.Action) error +} + +type RedisClient struct { + Client radix.Client + ImplicitPipelining bool +} + +func (r RedisClient) Do(cmd radix.Action) error { + return r.Client.Do(cmd) +} diff --git a/src/storage/strategy/memcached_strategy.go b/src/storage/strategy/memcached_strategy.go new file mode 100644 index 000000000..e67cc4f6c --- /dev/null +++ b/src/storage/strategy/memcached_strategy.go @@ -0,0 +1,50 @@ +package strategy + +import ( + "strconv" + + "github.com/bradfitz/gomemcache/memcache" + "github.com/envoyproxy/ratelimit/src/storage/service" +) + +type MemcachedStrategy struct { + Client service.MemcachedClientInterface +} + +func (m MemcachedStrategy) GetValue(key string) (uint64, error) { + item, err := m.Client.Get(key) + if err != nil { + return 0, err + } + + value, err := strconv.ParseUint(string(item.Value), 10, 32) + if err != nil { + return 0, err + } + + return value, nil +} + +func (m MemcachedStrategy) SetValue(key string, value uint64, expirationSeconds uint64) error { + item := &memcache.Item{ + Key: key, + Value: 
[]byte(strconv.FormatUint(value, 10)), + Expiration: int32(expirationSeconds), + } + + err := m.Client.Set(item) + if err != nil { + return err + } + + return nil +} + +func (m MemcachedStrategy) IncrementValue(key string, delta uint64) error { + _, err := m.Client.Increment(key, delta) + if err != nil { + return err + } + + return nil +} diff --git a/src/storage/strategy/redis_strategy.go b/src/storage/strategy/redis_strategy.go new file mode 100644 index 000000000..4200bff14 --- /dev/null +++ b/src/storage/strategy/redis_strategy.go @@ -0,0 +1,43 @@ +package strategy + +import ( + "github.com/envoyproxy/ratelimit/src/storage/service" + "github.com/mediocregopher/radix/v3" +) + +type RedisStrategy struct { + Client service.RedisClientInterface +} + +func (r RedisStrategy) GetValue(key string) (uint64, error) { + var value uint64 + err := r.Client.Do(radix.Cmd(&value, "GET", key)) + if err != nil { + return value, err + } + + return value, nil +} + +func (r RedisStrategy) SetValue(key string, value uint64, expirationSeconds uint64) error { + + err := r.Client.Do(radix.FlatCmd(nil, "SET", key, value)) + if err != nil { + return err + } + + err = r.Client.Do(radix.FlatCmd(nil, "EXPIRE", key, expirationSeconds)) + if err != nil { + return err + } + + return nil +} + +func (r RedisStrategy) IncrementValue(key string, delta uint64) error { + err := r.Client.Do(radix.FlatCmd(nil, "INCRBY", key, delta)) + if err != nil { + return err + } + return nil +} diff --git a/src/storage/strategy/storage_strategy.go b/src/storage/strategy/storage_strategy.go new file mode 100644 index 000000000..9e0306af2 --- /dev/null +++ b/src/storage/strategy/storage_strategy.go @@ -0,0 +1,7 @@ +package strategy + +type StorageStrategy interface { + GetValue(key string) (uint64, error) + SetValue(key string, value uint64, expirationSeconds uint64) error + IncrementValue(key string, delta uint64) error +} diff --git a/src/storage/utils/utils.go b/src/storage/utils/utils.go new file mode 100644 
index 000000000..c6e2c7a8c --- /dev/null +++ b/src/storage/utils/utils.go @@ -0,0 +1,13 @@ +package utils + +type RedisError string + +func (e RedisError) Error() string { + return string(e) +} + +func CheckError(err error) { + if err != nil { + panic(RedisError(err.Error())) + } +} diff --git a/test/memcached/cache_impl_test.go b/test/memcached/cache_impl_test.go deleted file mode 100644 index 1e2ba8d77..000000000 --- a/test/memcached/cache_impl_test.go +++ /dev/null @@ -1,594 +0,0 @@ -// Adapted from test/redis/cache_impl_test.go, with most test cases being the same -// basic idea. TestMemcacheAdd() is unique to the memcache tests, since redis can create a new key -// simply by incrementing it but memcached cannot. In memcache new keys need to be explicitly -// added. -package memcached_test - -import ( - "math/rand" - "strconv" - "testing" - - "github.com/bradfitz/gomemcache/memcache" - "github.com/coocood/freecache" - - pb "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v3" - "github.com/envoyproxy/ratelimit/src/config" - "github.com/envoyproxy/ratelimit/src/limiter" - "github.com/envoyproxy/ratelimit/src/memcached" - "github.com/envoyproxy/ratelimit/src/utils" - stats "github.com/lyft/gostats" - - "github.com/envoyproxy/ratelimit/test/common" - mock_memcached "github.com/envoyproxy/ratelimit/test/mocks/memcached" - mock_utils "github.com/envoyproxy/ratelimit/test/mocks/utils" - "github.com/golang/mock/gomock" - "github.com/stretchr/testify/assert" -) - -func TestMemcached(t *testing.T) { - assert := assert.New(t) - controller := gomock.NewController(t) - defer controller.Finish() - - timeSource := mock_utils.NewMockTimeSource(controller) - client := mock_memcached.NewMockClient(controller) - statsStore := stats.NewStore(stats.NewNullSink(), false) - cache := memcached.NewRateLimitCacheImpl(client, timeSource, nil, 0, nil, statsStore, 0.8, "") - - timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) - 
client.EXPECT().GetMulti([]string{"domain_key_value_1234"}).Return( - getMultiResult(map[string]int{"domain_key_value_1234": 4}), nil, - ) - client.EXPECT().Increment("domain_key_value_1234", uint64(1)).Return(uint64(5), nil) - - request := common.NewRateLimitRequest("domain", [][][2]string{{{"key", "value"}}}, 1) - limits := []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, "key_value", statsStore)} - - assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 5, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, - cache.DoLimit(nil, request, limits)) - assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) - assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) - assert.Equal(uint64(0), limits[0].Stats.NearLimit.Value()) - assert.Equal(uint64(1), limits[0].Stats.WithinLimit.Value()) - - timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) - client.EXPECT().GetMulti([]string{"domain_key2_value2_subkey2_subvalue2_1200"}).Return( - getMultiResult(map[string]int{"domain_key2_value2_subkey2_subvalue2_1200": 10}), nil, - ) - client.EXPECT().Increment("domain_key2_value2_subkey2_subvalue2_1200", uint64(1)).Return(uint64(11), nil) - - request = common.NewRateLimitRequest( - "domain", - [][][2]string{ - {{"key2", "value2"}}, - {{"key2", "value2"}, {"subkey2", "subvalue2"}}, - }, 1) - limits = []*config.RateLimit{ - nil, - config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_MINUTE, "key2_value2_subkey2_subvalue2", statsStore)} - assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: nil, LimitRemaining: 0}, - {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[1].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(limits[1].Limit, timeSource)}}, - cache.DoLimit(nil, request, limits)) - assert.Equal(uint64(1), limits[1].Stats.TotalHits.Value()) - 
assert.Equal(uint64(1), limits[1].Stats.OverLimit.Value()) - assert.Equal(uint64(0), limits[1].Stats.NearLimit.Value()) - assert.Equal(uint64(0), limits[1].Stats.WithinLimit.Value()) - - timeSource.EXPECT().UnixNow().Return(int64(1000000)).MaxTimes(5) - client.EXPECT().GetMulti([]string{ - "domain_key3_value3_997200", - "domain_key3_value3_subkey3_subvalue3_950400", - }).Return( - getMultiResult(map[string]int{ - "domain_key3_value3_997200": 10, - "domain_key3_value3_subkey3_subvalue3_950400": 12}), - nil, - ) - client.EXPECT().Increment("domain_key3_value3_997200", uint64(1)).Return(uint64(11), nil) - client.EXPECT().Increment("domain_key3_value3_subkey3_subvalue3_950400", uint64(1)).Return(uint64(13), nil) - - request = common.NewRateLimitRequest( - "domain", - [][][2]string{ - {{"key3", "value3"}}, - {{"key3", "value3"}, {"subkey3", "subvalue3"}}, - }, 1) - limits = []*config.RateLimit{ - config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_HOUR, "key3_value3", statsStore), - config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_DAY, "key3_value3_subkey3_subvalue3", statsStore)} - assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{ - {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}, - {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[1].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(limits[1].Limit, timeSource)}}, - cache.DoLimit(nil, request, limits)) - assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) - assert.Equal(uint64(1), limits[0].Stats.OverLimit.Value()) - assert.Equal(uint64(0), limits[0].Stats.NearLimit.Value()) - assert.Equal(uint64(0), limits[0].Stats.WithinLimit.Value()) - assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) - assert.Equal(uint64(1), limits[0].Stats.OverLimit.Value()) - assert.Equal(uint64(0), limits[0].Stats.NearLimit.Value()) - assert.Equal(uint64(0), 
limits[0].Stats.WithinLimit.Value()) - - cache.Flush() -} - -func TestMemcachedGetError(t *testing.T) { - assert := assert.New(t) - controller := gomock.NewController(t) - defer controller.Finish() - - timeSource := mock_utils.NewMockTimeSource(controller) - client := mock_memcached.NewMockClient(controller) - statsStore := stats.NewStore(stats.NewNullSink(), false) - cache := memcached.NewRateLimitCacheImpl(client, timeSource, nil, 0, nil, statsStore, 0.8, "") - - timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) - client.EXPECT().GetMulti([]string{"domain_key_value_1234"}).Return( - nil, memcache.ErrNoServers, - ) - client.EXPECT().Increment("domain_key_value_1234", uint64(1)).Return(uint64(5), nil) - - request := common.NewRateLimitRequest("domain", [][][2]string{{{"key", "value"}}}, 1) - limits := []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, "key_value", statsStore)} - - assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 9, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, - cache.DoLimit(nil, request, limits)) - assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) - assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) - assert.Equal(uint64(0), limits[0].Stats.NearLimit.Value()) - assert.Equal(uint64(1), limits[0].Stats.WithinLimit.Value()) - - // No error, but the key is missing - timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) - client.EXPECT().GetMulti([]string{"domain_key_value1_1234"}).Return( - nil, nil, - ) - client.EXPECT().Increment("domain_key_value1_1234", uint64(1)).Return(uint64(5), nil) - - request = common.NewRateLimitRequest("domain", [][][2]string{{{"key", "value1"}}}, 1) - limits = []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, "key_value1", statsStore)} - - assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{{Code: 
pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 9, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, - cache.DoLimit(nil, request, limits)) - assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) - assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) - assert.Equal(uint64(0), limits[0].Stats.NearLimit.Value()) - assert.Equal(uint64(1), limits[0].Stats.WithinLimit.Value()) - - cache.Flush() -} - -func testLocalCacheStats(localCacheStats stats.StatGenerator, statsStore stats.Store, sink *common.TestStatSink, - expectedHitCount int, expectedMissCount int, expectedLookUpCount int, expectedExpiredCount int, - expectedEntryCount int) func(*testing.T) { - return func(t *testing.T) { - localCacheStats.GenerateStats() - statsStore.Flush() - - // Check whether all local_cache related stats are available. - _, ok := sink.Record["averageAccessTime"] - assert.Equal(t, true, ok) - hitCount, ok := sink.Record["hitCount"] - assert.Equal(t, true, ok) - missCount, ok := sink.Record["missCount"] - assert.Equal(t, true, ok) - lookupCount, ok := sink.Record["lookupCount"] - assert.Equal(t, true, ok) - _, ok = sink.Record["overwriteCount"] - assert.Equal(t, true, ok) - _, ok = sink.Record["evacuateCount"] - assert.Equal(t, true, ok) - expiredCount, ok := sink.Record["expiredCount"] - assert.Equal(t, true, ok) - entryCount, ok := sink.Record["entryCount"] - assert.Equal(t, true, ok) - - // Check the correctness of hitCount, missCount, lookupCount, expiredCount and entryCount - assert.Equal(t, expectedHitCount, hitCount.(int)) - assert.Equal(t, expectedMissCount, missCount.(int)) - assert.Equal(t, expectedLookUpCount, lookupCount.(int)) - assert.Equal(t, expectedExpiredCount, expiredCount.(int)) - assert.Equal(t, expectedEntryCount, entryCount.(int)) - - sink.Clear() - } -} - -func TestOverLimitWithLocalCache(t *testing.T) { - assert := assert.New(t) - controller := gomock.NewController(t) - defer controller.Finish() - - timeSource 
:= mock_utils.NewMockTimeSource(controller) - client := mock_memcached.NewMockClient(controller) - localCache := freecache.NewCache(100) - sink := &common.TestStatSink{} - statsStore := stats.NewStore(sink, true) - cache := memcached.NewRateLimitCacheImpl(client, timeSource, nil, 0, localCache, statsStore, 0.8, "") - localCacheStats := limiter.NewLocalCacheStats(localCache, statsStore.Scope("localcache")) - - // Test Near Limit Stats. Under Near Limit Ratio - timeSource.EXPECT().UnixNow().Return(int64(1000000)).MaxTimes(3) - client.EXPECT().GetMulti([]string{"domain_key4_value4_997200"}).Return( - getMultiResult(map[string]int{"domain_key4_value4_997200": 10}), nil, - ) - client.EXPECT().Increment("domain_key4_value4_997200", uint64(1)).Return(uint64(5), nil) - - request := common.NewRateLimitRequest("domain", [][][2]string{{{"key4", "value4"}}}, 1) - - limits := []*config.RateLimit{ - config.NewRateLimit(15, pb.RateLimitResponse_RateLimit_HOUR, "key4_value4", statsStore)} - - assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{ - {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 4, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, - cache.DoLimit(nil, request, limits)) - assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) - assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) - assert.Equal(uint64(0), limits[0].Stats.OverLimitWithLocalCache.Value()) - assert.Equal(uint64(0), limits[0].Stats.NearLimit.Value()) - assert.Equal(uint64(1), limits[0].Stats.WithinLimit.Value()) - - // Check the local cache stats. - testLocalCacheStats(localCacheStats, statsStore, sink, 0, 1, 1, 0, 0) - - // Test Near Limit Stats. 
At Near Limit Ratio, still OK - timeSource.EXPECT().UnixNow().Return(int64(1000000)).MaxTimes(3) - client.EXPECT().GetMulti([]string{"domain_key4_value4_997200"}).Return( - getMultiResult(map[string]int{"domain_key4_value4_997200": 12}), nil, - ) - client.EXPECT().Increment("domain_key4_value4_997200", uint64(1)).Return(uint64(13), nil) - - assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{ - {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 2, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, - cache.DoLimit(nil, request, limits)) - assert.Equal(uint64(2), limits[0].Stats.TotalHits.Value()) - assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) - assert.Equal(uint64(0), limits[0].Stats.OverLimitWithLocalCache.Value()) - assert.Equal(uint64(1), limits[0].Stats.NearLimit.Value()) - assert.Equal(uint64(2), limits[0].Stats.WithinLimit.Value()) - - // Check the local cache stats. - testLocalCacheStats(localCacheStats, statsStore, sink, 0, 2, 2, 0, 0) - - // Test Over limit stats - timeSource.EXPECT().UnixNow().Return(int64(1000000)).MaxTimes(3) - client.EXPECT().GetMulti([]string{"domain_key4_value4_997200"}).Return( - getMultiResult(map[string]int{"domain_key4_value4_997200": 15}), nil, - ) - client.EXPECT().Increment("domain_key4_value4_997200", uint64(1)).Return(uint64(16), nil) - - assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{ - {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, - cache.DoLimit(nil, request, limits)) - assert.Equal(uint64(3), limits[0].Stats.TotalHits.Value()) - assert.Equal(uint64(1), limits[0].Stats.OverLimit.Value()) - assert.Equal(uint64(0), limits[0].Stats.OverLimitWithLocalCache.Value()) - assert.Equal(uint64(1), limits[0].Stats.NearLimit.Value()) - assert.Equal(uint64(2), limits[0].Stats.WithinLimit.Value()) - - // Check the local cache stats. 
- testLocalCacheStats(localCacheStats, statsStore, sink, 0, 2, 3, 0, 1) - - // Test Over limit stats with local cache - timeSource.EXPECT().UnixNow().Return(int64(1000000)).MaxTimes(3) - client.EXPECT().GetMulti([]string{"domain_key4_value4_997200"}).Times(0) - client.EXPECT().Increment("domain_key4_value4_997200", uint64(1)).Times(0) - assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{ - {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, - cache.DoLimit(nil, request, limits)) - assert.Equal(uint64(4), limits[0].Stats.TotalHits.Value()) - assert.Equal(uint64(2), limits[0].Stats.OverLimit.Value()) - assert.Equal(uint64(1), limits[0].Stats.OverLimitWithLocalCache.Value()) - assert.Equal(uint64(1), limits[0].Stats.NearLimit.Value()) - assert.Equal(uint64(2), limits[0].Stats.WithinLimit.Value()) - - // Check the local cache stats. - testLocalCacheStats(localCacheStats, statsStore, sink, 1, 3, 4, 0, 1) - - cache.Flush() -} - -func TestNearLimit(t *testing.T) { - assert := assert.New(t) - controller := gomock.NewController(t) - defer controller.Finish() - - timeSource := mock_utils.NewMockTimeSource(controller) - client := mock_memcached.NewMockClient(controller) - statsStore := stats.NewStore(stats.NewNullSink(), false) - cache := memcached.NewRateLimitCacheImpl(client, timeSource, nil, 0, nil, statsStore, 0.8, "") - - // Test Near Limit Stats. 
Under Near Limit Ratio - timeSource.EXPECT().UnixNow().Return(int64(1000000)).MaxTimes(3) - client.EXPECT().GetMulti([]string{"domain_key4_value4_997200"}).Return( - getMultiResult(map[string]int{"domain_key4_value4_997200": 10}), nil, - ) - client.EXPECT().Increment("domain_key4_value4_997200", uint64(1)).Return(uint64(11), nil) - - request := common.NewRateLimitRequest("domain", [][][2]string{{{"key4", "value4"}}}, 1) - - limits := []*config.RateLimit{ - config.NewRateLimit(15, pb.RateLimitResponse_RateLimit_HOUR, "key4_value4", statsStore)} - - assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{ - {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 4, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, - cache.DoLimit(nil, request, limits)) - assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) - assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) - assert.Equal(uint64(0), limits[0].Stats.NearLimit.Value()) - assert.Equal(uint64(1), limits[0].Stats.WithinLimit.Value()) - - // Test Near Limit Stats. At Near Limit Ratio, still OK - timeSource.EXPECT().UnixNow().Return(int64(1000000)).MaxTimes(3) - client.EXPECT().GetMulti([]string{"domain_key4_value4_997200"}).Return( - getMultiResult(map[string]int{"domain_key4_value4_997200": 12}), nil, - ) - client.EXPECT().Increment("domain_key4_value4_997200", uint64(1)).Return(uint64(13), nil) - - assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{ - {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 2, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, - cache.DoLimit(nil, request, limits)) - assert.Equal(uint64(2), limits[0].Stats.TotalHits.Value()) - assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) - assert.Equal(uint64(1), limits[0].Stats.NearLimit.Value()) - assert.Equal(uint64(2), limits[0].Stats.WithinLimit.Value()) - - // Test Near Limit Stats. 
We went OVER_LIMIT, but the near_limit counter only increases - // when we are near limit, not after we have passed the limit. - timeSource.EXPECT().UnixNow().Return(int64(1000000)).MaxTimes(3) - client.EXPECT().GetMulti([]string{"domain_key4_value4_997200"}).Return( - getMultiResult(map[string]int{"domain_key4_value4_997200": 15}), nil, - ) - client.EXPECT().Increment("domain_key4_value4_997200", uint64(1)).Return(uint64(16), nil) - - assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{ - {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, - cache.DoLimit(nil, request, limits)) - assert.Equal(uint64(3), limits[0].Stats.TotalHits.Value()) - assert.Equal(uint64(1), limits[0].Stats.OverLimit.Value()) - assert.Equal(uint64(1), limits[0].Stats.NearLimit.Value()) - assert.Equal(uint64(2), limits[0].Stats.WithinLimit.Value()) - - // Now test hitsAddend that is greater than 1 - // All of it under limit, under near limit - timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) - client.EXPECT().GetMulti([]string{"domain_key5_value5_1234"}).Return( - getMultiResult(map[string]int{"domain_key5_value5_1234": 2}), nil, - ) - client.EXPECT().Increment("domain_key5_value5_1234", uint64(3)).Return(uint64(5), nil) - - request = common.NewRateLimitRequest("domain", [][][2]string{{{"key5", "value5"}}}, 3) - limits = []*config.RateLimit{config.NewRateLimit(20, pb.RateLimitResponse_RateLimit_SECOND, "key5_value5", statsStore)} - - assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 15, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, - cache.DoLimit(nil, request, limits)) - assert.Equal(uint64(3), limits[0].Stats.TotalHits.Value()) - assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) - assert.Equal(uint64(0), limits[0].Stats.NearLimit.Value()) - 
assert.Equal(uint64(3), limits[0].Stats.WithinLimit.Value()) - - // All of it under limit, some over near limit - timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) - client.EXPECT().GetMulti([]string{"domain_key6_value6_1234"}).Return( - getMultiResult(map[string]int{"domain_key6_value6_1234": 5}), nil, - ) - client.EXPECT().Increment("domain_key6_value6_1234", uint64(2)).Return(uint64(7), nil) - - request = common.NewRateLimitRequest("domain", [][][2]string{{{"key6", "value6"}}}, 2) - limits = []*config.RateLimit{config.NewRateLimit(8, pb.RateLimitResponse_RateLimit_SECOND, "key6_value6", statsStore)} - - assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 1, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, - cache.DoLimit(nil, request, limits)) - assert.Equal(uint64(2), limits[0].Stats.TotalHits.Value()) - assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) - assert.Equal(uint64(1), limits[0].Stats.NearLimit.Value()) - assert.Equal(uint64(2), limits[0].Stats.WithinLimit.Value()) - - // All of it under limit, all of it over near limit - timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) - client.EXPECT().GetMulti([]string{"domain_key7_value7_1234"}).Return( - getMultiResult(map[string]int{"domain_key7_value7_1234": 16}), nil, - ) - client.EXPECT().Increment("domain_key7_value7_1234", uint64(3)).Return(uint64(19), nil) - - request = common.NewRateLimitRequest("domain", [][][2]string{{{"key7", "value7"}}}, 3) - limits = []*config.RateLimit{config.NewRateLimit(20, pb.RateLimitResponse_RateLimit_SECOND, "key7_value7", statsStore)} - - assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 1, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, - cache.DoLimit(nil, request, limits)) - assert.Equal(uint64(3), 
limits[0].Stats.TotalHits.Value()) - assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) - assert.Equal(uint64(3), limits[0].Stats.NearLimit.Value()) - assert.Equal(uint64(3), limits[0].Stats.WithinLimit.Value()) - - // Some of it over limit, all of it over near limit - timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) - client.EXPECT().GetMulti([]string{"domain_key8_value8_1234"}).Return( - getMultiResult(map[string]int{"domain_key8_value8_1234": 19}), nil, - ) - client.EXPECT().Increment("domain_key8_value8_1234", uint64(3)).Return(uint64(22), nil) - - request = common.NewRateLimitRequest("domain", [][][2]string{{{"key8", "value8"}}}, 3) - limits = []*config.RateLimit{config.NewRateLimit(20, pb.RateLimitResponse_RateLimit_SECOND, "key8_value8", statsStore)} - - assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, - cache.DoLimit(nil, request, limits)) - assert.Equal(uint64(3), limits[0].Stats.TotalHits.Value()) - assert.Equal(uint64(2), limits[0].Stats.OverLimit.Value()) - assert.Equal(uint64(1), limits[0].Stats.NearLimit.Value()) - assert.Equal(uint64(0), limits[0].Stats.WithinLimit.Value()) - - // Some of it in all three places - timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) - client.EXPECT().GetMulti([]string{"domain_key9_value9_1234"}).Return( - getMultiResult(map[string]int{"domain_key9_value9_1234": 15}), nil, - ) - client.EXPECT().Increment("domain_key9_value9_1234", uint64(7)).Return(uint64(22), nil) - - request = common.NewRateLimitRequest("domain", [][][2]string{{{"key9", "value9"}}}, 7) - limits = []*config.RateLimit{config.NewRateLimit(20, pb.RateLimitResponse_RateLimit_SECOND, "key9_value9", statsStore)} - - assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, 
DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, - cache.DoLimit(nil, request, limits)) - assert.Equal(uint64(7), limits[0].Stats.TotalHits.Value()) - assert.Equal(uint64(2), limits[0].Stats.OverLimit.Value()) - assert.Equal(uint64(4), limits[0].Stats.NearLimit.Value()) - assert.Equal(uint64(0), limits[0].Stats.WithinLimit.Value()) - - // all of it over limit - timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) - client.EXPECT().GetMulti([]string{"domain_key10_value10_1234"}).Return( - getMultiResult(map[string]int{"domain_key10_value10_1234": 27}), nil, - ) - client.EXPECT().Increment("domain_key10_value10_1234", uint64(3)).Return(uint64(30), nil) - - request = common.NewRateLimitRequest("domain", [][][2]string{{{"key10", "value10"}}}, 3) - limits = []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, "key10_value10", statsStore)} - - assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, - cache.DoLimit(nil, request, limits)) - assert.Equal(uint64(3), limits[0].Stats.TotalHits.Value()) - assert.Equal(uint64(3), limits[0].Stats.OverLimit.Value()) - assert.Equal(uint64(0), limits[0].Stats.NearLimit.Value()) - assert.Equal(uint64(0), limits[0].Stats.WithinLimit.Value()) - - cache.Flush() -} - -func TestMemcacheWithJitter(t *testing.T) { - assert := assert.New(t) - controller := gomock.NewController(t) - defer controller.Finish() - - timeSource := mock_utils.NewMockTimeSource(controller) - client := mock_memcached.NewMockClient(controller) - jitterSource := mock_utils.NewMockJitterRandSource(controller) - statsStore := stats.NewStore(stats.NewNullSink(), false) - cache := memcached.NewRateLimitCacheImpl(client, timeSource, rand.New(jitterSource), 3600, nil, statsStore, 0.8, "") - - 
timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) - jitterSource.EXPECT().Int63().Return(int64(100)) - - // Key is not found in memcache - client.EXPECT().GetMulti([]string{"domain_key_value_1234"}).Return(nil, nil) - // First increment attempt will fail - client.EXPECT().Increment("domain_key_value_1234", uint64(1)).Return( - uint64(0), memcache.ErrCacheMiss) - // Add succeeds - client.EXPECT().Add( - &memcache.Item{ - Key: "domain_key_value_1234", - Value: []byte(strconv.FormatUint(1, 10)), - // 1 second + 100 seconds of jitter - Expiration: int32(101), - }, - ).Return(nil) - - request := common.NewRateLimitRequest("domain", [][][2]string{{{"key", "value"}}}, 1) - limits := []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, "key_value", statsStore)} - - assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 9, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, - cache.DoLimit(nil, request, limits)) - assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) - assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) - assert.Equal(uint64(0), limits[0].Stats.NearLimit.Value()) - assert.Equal(uint64(1), limits[0].Stats.WithinLimit.Value()) - - cache.Flush() -} - -func TestMemcacheAdd(t *testing.T) { - assert := assert.New(t) - controller := gomock.NewController(t) - defer controller.Finish() - - timeSource := mock_utils.NewMockTimeSource(controller) - client := mock_memcached.NewMockClient(controller) - statsStore := stats.NewStore(stats.NewNullSink(), false) - cache := memcached.NewRateLimitCacheImpl(client, timeSource, nil, 0, nil, statsStore, 0.8, "") - - // Test a race condition with the initial add - timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) - - client.EXPECT().GetMulti([]string{"domain_key_value_1234"}).Return(nil, nil) - client.EXPECT().Increment("domain_key_value_1234", uint64(1)).Return( - 
uint64(0), memcache.ErrCacheMiss) - // Add fails, must have been a race condition - client.EXPECT().Add( - &memcache.Item{ - Key: "domain_key_value_1234", - Value: []byte(strconv.FormatUint(1, 10)), - Expiration: int32(1), - }, - ).Return(memcache.ErrNotStored) - // Should work the second time, since some other client added the key. - client.EXPECT().Increment("domain_key_value_1234", uint64(1)).Return( - uint64(2), nil) - - request := common.NewRateLimitRequest("domain", [][][2]string{{{"key", "value"}}}, 1) - limits := []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, "key_value", statsStore)} - - assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 9, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, - cache.DoLimit(nil, request, limits)) - assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) - assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) - assert.Equal(uint64(0), limits[0].Stats.NearLimit.Value()) - assert.Equal(uint64(1), limits[0].Stats.WithinLimit.Value()) - - // A rate limit with 1-minute window - timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) - client.EXPECT().GetMulti([]string{"domain_key2_value2_1200"}).Return(nil, nil) - client.EXPECT().Increment("domain_key2_value2_1200", uint64(1)).Return( - uint64(0), memcache.ErrCacheMiss) - client.EXPECT().Add( - &memcache.Item{ - Key: "domain_key2_value2_1200", - Value: []byte(strconv.FormatUint(1, 10)), - Expiration: int32(60), - }, - ).Return(nil) - - request = common.NewRateLimitRequest("domain", [][][2]string{{{"key2", "value2"}}}, 1) - limits = []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_MINUTE, "key2_value2", statsStore)} - - assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 9, DurationUntilReset: 
utils.CalculateReset(limits[0].Limit, timeSource)}}, - cache.DoLimit(nil, request, limits)) - assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) - assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) - assert.Equal(uint64(0), limits[0].Stats.NearLimit.Value()) - assert.Equal(uint64(1), limits[0].Stats.WithinLimit.Value()) - - cache.Flush() -} - -func getMultiResult(vals map[string]int) map[string]*memcache.Item { - result := make(map[string]*memcache.Item, len(vals)) - for k, v := range vals { - result[k] = &memcache.Item{ - Value: []byte(strconv.Itoa(v)), - } - } - return result -} diff --git a/test/memcached/stats_collecting_client_test.go b/test/memcached/stats_collecting_client_test.go deleted file mode 100644 index 548b93041..000000000 --- a/test/memcached/stats_collecting_client_test.go +++ /dev/null @@ -1,199 +0,0 @@ -package memcached_test - -import ( - "errors" - "testing" - - "github.com/bradfitz/gomemcache/memcache" - "github.com/envoyproxy/ratelimit/src/memcached" - mock_memcached "github.com/envoyproxy/ratelimit/test/mocks/memcached" - "github.com/golang/mock/gomock" - stats "github.com/lyft/gostats" - "github.com/stretchr/testify/assert" -) - -type fakeSink struct { - values map[string]uint64 -} - -func (fs *fakeSink) FlushCounter(name string, value uint64) { - if _, ok := fs.values[name]; ok { - panic(errors.New("fakeSink wasn't cleared before flushing again")) - } - - fs.values[name] = value -} - -func (fs *fakeSink) FlushGauge(name string, value uint64) {} - -func (fs *fakeSink) FlushTimer(name string, value float64) {} - -func (fs *fakeSink) Flush() {} - -func (fs *fakeSink) Reset() { - fs.values = make(map[string]uint64) -} - -func TestStats_MultiGet(t *testing.T) { - fakeSink := &fakeSink{} - fakeSink.Reset() - - assert := assert.New(t) - controller := gomock.NewController(t) - defer controller.Finish() - client := mock_memcached.NewMockClient(controller) - statsStore := stats.NewStore(fakeSink, false) - - sc := 
memcached.CollectStats(client, statsStore) - - returnValue := map[string]*memcache.Item{"foo": nil} - arg := []string{"foo"} - - client.EXPECT().GetMulti(arg).Return(returnValue, nil) - actualReturnValue, err := sc.GetMulti(arg) - statsStore.Flush() - - assert.Equal(returnValue, actualReturnValue) - assert.Nil(err) - assert.Equal(map[string]uint64{ - "keys_found": 1, - "keys_requested": 1, - "multiget.__code=success": 1, - }, fakeSink.values) - - fakeSink.Reset() - returnValue = map[string]*memcache.Item{"foo": nil, "bar": nil} - client.EXPECT().GetMulti(arg).Return(returnValue, nil) - actualReturnValue, err = sc.GetMulti(arg) - statsStore.Flush() - - assert.Equal(returnValue, actualReturnValue) - assert.Nil(err) - assert.Equal(map[string]uint64{ - "keys_found": 2, - "keys_requested": 1, - "multiget.__code=success": 1, - }, fakeSink.values) - - fakeSink.Reset() - returnValue = map[string]*memcache.Item{} - arg = []string{"foo", "bar"} - - client.EXPECT().GetMulti(arg).Return(returnValue, nil) - actualReturnValue, err = sc.GetMulti(arg) - - statsStore.Flush() - assert.Equal(returnValue, actualReturnValue) - assert.Nil(err) - - assert.Equal(map[string]uint64{ - "keys_requested": 2, - "multiget.__code=success": 1, - }, fakeSink.values) - - fakeSink.Reset() - returnValue = map[string]*memcache.Item{"ignored": nil} - arg = []string{"foo"} - returnedErr := errors.New("Random error") - - client.EXPECT().GetMulti(arg).Return(returnValue, returnedErr) - actualReturnValue, err = sc.GetMulti(arg) - - statsStore.Flush() - assert.Equal(returnValue, actualReturnValue) - assert.Equal(returnedErr, err) - - assert.Equal(map[string]uint64{ - "keys_requested": 1, - "multiget.__code=error": 1, - }, fakeSink.values) -} - -func TestStats_Increment(t *testing.T) { - fakeSink := &fakeSink{} - - assert := assert.New(t) - controller := gomock.NewController(t) - defer controller.Finish() - client := mock_memcached.NewMockClient(controller) - statsStore := stats.NewStore(fakeSink, false) - - 
sc := memcached.CollectStats(client, statsStore) - - fakeSink.Reset() - client.EXPECT().Increment("foo", uint64(5)).Return(uint64(6), nil) - newValue, err := sc.Increment("foo", 5) - statsStore.Flush() - - assert.Equal(uint64(6), newValue) - assert.Nil(err) - assert.Equal(map[string]uint64{ - "increment.__code=success": 1, - }, fakeSink.values) - - expectedErr := errors.New("expectedError") - fakeSink.Reset() - client.EXPECT().Increment("foo", uint64(5)).Return(uint64(0), expectedErr) - newValue, err = sc.Increment("foo", 5) - statsStore.Flush() - - assert.Equal(expectedErr, err) - assert.Equal(map[string]uint64{ - "increment.__code=error": 1, - }, fakeSink.values) - - fakeSink.Reset() - client.EXPECT().Increment("foo", uint64(5)).Return(uint64(0), memcache.ErrCacheMiss) - newValue, err = sc.Increment("foo", 5) - statsStore.Flush() - - assert.Equal(memcache.ErrCacheMiss, err) - assert.Equal(map[string]uint64{ - "increment.__code=miss": 1, - }, fakeSink.values) -} - -func TestStats_Add(t *testing.T) { - fakeSink := &fakeSink{} - - assert := assert.New(t) - controller := gomock.NewController(t) - defer controller.Finish() - client := mock_memcached.NewMockClient(controller) - statsStore := stats.NewStore(fakeSink, false) - - sc := memcached.CollectStats(client, statsStore) - item := &memcache.Item{} - - fakeSink.Reset() - client.EXPECT().Add(item).Return(nil) - err := sc.Add(item) - statsStore.Flush() - - assert.Nil(err) - assert.Equal(map[string]uint64{ - "add.__code=success": 1, - }, fakeSink.values) - - expectedErr := errors.New("expected err") - - fakeSink.Reset() - client.EXPECT().Add(item).Return(expectedErr) - err = sc.Add(item) - statsStore.Flush() - - assert.Equal(expectedErr, err) - assert.Equal(map[string]uint64{ - "add.__code=error": 1, - }, fakeSink.values) - - fakeSink.Reset() - client.EXPECT().Add(item).Return(memcache.ErrNotStored) - err = sc.Add(item) - statsStore.Flush() - - assert.Equal(memcache.ErrNotStored, err) - assert.Equal(map[string]uint64{ 
- "add.__code=not_stored": 1, - }, fakeSink.values) -} diff --git a/test/mocks/redis/redis.go b/test/mocks/redis/redis.go deleted file mode 100644 index 032b500dc..000000000 --- a/test/mocks/redis/redis.go +++ /dev/null @@ -1,128 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/envoyproxy/ratelimit/src/redis (interfaces: Client) - -// Package mock_redis is a generated GoMock package. -package mock_redis - -import ( - redis "github.com/envoyproxy/ratelimit/src/redis" - gomock "github.com/golang/mock/gomock" - reflect "reflect" -) - -// MockClient is a mock of Client interface -type MockClient struct { - ctrl *gomock.Controller - recorder *MockClientMockRecorder -} - -// MockClientMockRecorder is the mock recorder for MockClient -type MockClientMockRecorder struct { - mock *MockClient -} - -// NewMockClient creates a new mock instance -func NewMockClient(ctrl *gomock.Controller) *MockClient { - mock := &MockClient{ctrl: ctrl} - mock.recorder = &MockClientMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use -func (m *MockClient) EXPECT() *MockClientMockRecorder { - return m.recorder -} - -// Close mocks base method -func (m *MockClient) Close() error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Close") - ret0, _ := ret[0].(error) - return ret0 -} - -// Close indicates an expected call of Close -func (mr *MockClientMockRecorder) Close() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockClient)(nil).Close)) -} - -// DoCmd mocks base method -func (m *MockClient) DoCmd(arg0 interface{}, arg1, arg2 string, arg3 ...interface{}) error { - m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1, arg2} - for _, a := range arg3 { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "DoCmd", varargs...) 
- ret0, _ := ret[0].(error) - return ret0 -} - -// DoCmd indicates an expected call of DoCmd -func (mr *MockClientMockRecorder) DoCmd(arg0, arg1, arg2 interface{}, arg3 ...interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1, arg2}, arg3...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DoCmd", reflect.TypeOf((*MockClient)(nil).DoCmd), varargs...) -} - -// ImplicitPipeliningEnabled mocks base method -func (m *MockClient) ImplicitPipeliningEnabled() bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ImplicitPipeliningEnabled") - ret0, _ := ret[0].(bool) - return ret0 -} - -// ImplicitPipeliningEnabled indicates an expected call of ImplicitPipeliningEnabled -func (mr *MockClientMockRecorder) ImplicitPipeliningEnabled() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ImplicitPipeliningEnabled", reflect.TypeOf((*MockClient)(nil).ImplicitPipeliningEnabled)) -} - -// NumActiveConns mocks base method -func (m *MockClient) NumActiveConns() int { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "NumActiveConns") - ret0, _ := ret[0].(int) - return ret0 -} - -// NumActiveConns indicates an expected call of NumActiveConns -func (mr *MockClientMockRecorder) NumActiveConns() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NumActiveConns", reflect.TypeOf((*MockClient)(nil).NumActiveConns)) -} - -// PipeAppend mocks base method -func (m *MockClient) PipeAppend(arg0 redis.Pipeline, arg1 interface{}, arg2, arg3 string, arg4 ...interface{}) redis.Pipeline { - m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1, arg2, arg3} - for _, a := range arg4 { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "PipeAppend", varargs...) 
- ret0, _ := ret[0].(redis.Pipeline) - return ret0 -} - -// PipeAppend indicates an expected call of PipeAppend -func (mr *MockClientMockRecorder) PipeAppend(arg0, arg1, arg2, arg3 interface{}, arg4 ...interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1, arg2, arg3}, arg4...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PipeAppend", reflect.TypeOf((*MockClient)(nil).PipeAppend), varargs...) -} - -// PipeDo mocks base method -func (m *MockClient) PipeDo(arg0 redis.Pipeline) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "PipeDo", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// PipeDo indicates an expected call of PipeDo -func (mr *MockClientMockRecorder) PipeDo(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PipeDo", reflect.TypeOf((*MockClient)(nil).PipeDo), arg0) -} diff --git a/test/redis/bench_test.go b/test/redis/bench_test.go deleted file mode 100644 index 6c190ea7b..000000000 --- a/test/redis/bench_test.go +++ /dev/null @@ -1,94 +0,0 @@ -package redis_test - -import ( - "context" - "runtime" - "testing" - "time" - - pb "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v3" - "github.com/envoyproxy/ratelimit/src/config" - "github.com/envoyproxy/ratelimit/src/redis" - "github.com/envoyproxy/ratelimit/src/utils" - stats "github.com/lyft/gostats" - - "math/rand" - - "github.com/envoyproxy/ratelimit/test/common" -) - -func BenchmarkParallelDoLimit(b *testing.B) { - b.Skip("Skip benchmark") - - b.ReportAllocs() - - // See https://github.com/mediocregopher/radix/blob/v3.5.1/bench/bench_test.go#L176 - parallel := runtime.GOMAXPROCS(0) - poolSize := parallel * runtime.GOMAXPROCS(0) - - do := func(b *testing.B, fn func() error) { - b.ResetTimer() - b.SetParallelism(parallel) - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - if err := fn(); err != nil { - b.Fatal(err) - } - } - }) - } - - mkDoLimitBench := 
func(pipelineWindow time.Duration, pipelineLimit int) func(*testing.B) { - return func(b *testing.B) { - statsStore := stats.NewStore(stats.NewNullSink(), false) - client := redis.NewClientImpl(statsStore, false, "", "single", "127.0.0.1:6379", poolSize, pipelineWindow, pipelineLimit) - defer client.Close() - - cache := redis.NewFixedRateLimitCacheImpl(client, nil, utils.NewTimeSourceImpl(), rand.New(utils.NewLockedSource(time.Now().Unix())), 10, nil, 0.8, "") - request := common.NewRateLimitRequest("domain", [][][2]string{{{"key", "value"}}}, 1) - limits := []*config.RateLimit{config.NewRateLimit(1000000000, pb.RateLimitResponse_RateLimit_SECOND, "key_value", statsStore)} - - // wait for the pool to fill up - for { - time.Sleep(50 * time.Millisecond) - if client.NumActiveConns() >= poolSize { - break - } - } - - b.ResetTimer() - - do(b, func() error { - cache.DoLimit(context.Background(), request, limits) - return nil - }) - } - } - - b.Run("no pipeline", mkDoLimitBench(0, 0)) - - b.Run("pipeline 35us 1", mkDoLimitBench(35*time.Microsecond, 1)) - b.Run("pipeline 75us 1", mkDoLimitBench(75*time.Microsecond, 1)) - b.Run("pipeline 150us 1", mkDoLimitBench(150*time.Microsecond, 1)) - b.Run("pipeline 300us 1", mkDoLimitBench(300*time.Microsecond, 1)) - - b.Run("pipeline 35us 2", mkDoLimitBench(35*time.Microsecond, 2)) - b.Run("pipeline 75us 2", mkDoLimitBench(75*time.Microsecond, 2)) - b.Run("pipeline 150us 2", mkDoLimitBench(150*time.Microsecond, 2)) - b.Run("pipeline 300us 2", mkDoLimitBench(300*time.Microsecond, 2)) - - b.Run("pipeline 35us 4", mkDoLimitBench(35*time.Microsecond, 4)) - b.Run("pipeline 75us 4", mkDoLimitBench(75*time.Microsecond, 4)) - b.Run("pipeline 150us 4", mkDoLimitBench(150*time.Microsecond, 4)) - b.Run("pipeline 300us 4", mkDoLimitBench(300*time.Microsecond, 4)) - - b.Run("pipeline 35us 8", mkDoLimitBench(35*time.Microsecond, 8)) - b.Run("pipeline 75us 8", mkDoLimitBench(75*time.Microsecond, 8)) - b.Run("pipeline 150us 8", 
mkDoLimitBench(150*time.Microsecond, 8)) - b.Run("pipeline 300us 8", mkDoLimitBench(300*time.Microsecond, 8)) - - b.Run("pipeline 35us 16", mkDoLimitBench(35*time.Microsecond, 16)) - b.Run("pipeline 75us 16", mkDoLimitBench(75*time.Microsecond, 16)) - b.Run("pipeline 150us 16", mkDoLimitBench(150*time.Microsecond, 16)) - b.Run("pipeline 300us 16", mkDoLimitBench(300*time.Microsecond, 16)) -} diff --git a/test/redis/driver_impl_test.go b/test/redis/driver_impl_test.go deleted file mode 100644 index ab488e239..000000000 --- a/test/redis/driver_impl_test.go +++ /dev/null @@ -1,205 +0,0 @@ -package redis_test - -import ( - "testing" - "time" - - "github.com/alicebob/miniredis/v2" - "github.com/envoyproxy/ratelimit/src/redis" - "github.com/lyft/gostats" - "github.com/stretchr/testify/assert" -) - -func mustNewRedisServer() *miniredis.Miniredis { - srv, err := miniredis.Run() - if err != nil { - panic(err) - } - - return srv -} - -func expectPanicError(t *testing.T, f assert.PanicTestFunc) (result error) { - t.Helper() - defer func() { - panicResult := recover() - assert.NotNil(t, panicResult, "Expected a panic") - result = panicResult.(error) - }() - f() - return -} - -func testNewClientImpl(t *testing.T, pipelineWindow time.Duration, pipelineLimit int) func(t *testing.T) { - return func(t *testing.T) { - redisAuth := "123" - statsStore := stats.NewStore(stats.NewNullSink(), false) - - mkRedisClient := func(auth, addr string) redis.Client { - return redis.NewClientImpl(statsStore, false, auth, "single", addr, 1, pipelineWindow, pipelineLimit) - } - - t.Run("connection refused", func(t *testing.T) { - // It's possible there is a redis server listening on 6379 in ci environment, so - // use a random port. 
- panicErr := expectPanicError(t, func() { mkRedisClient("", "localhost:12345") }) - assert.Contains(t, panicErr.Error(), "connection refused") - }) - - t.Run("ok", func(t *testing.T) { - redisSrv := mustNewRedisServer() - defer redisSrv.Close() - - var client redis.Client - assert.NotPanics(t, func() { - client = mkRedisClient("", redisSrv.Addr()) - }) - assert.NotNil(t, client) - }) - - t.Run("auth fail", func(t *testing.T) { - redisSrv := mustNewRedisServer() - defer redisSrv.Close() - - redisSrv.RequireAuth(redisAuth) - - assert.PanicsWithError(t, "NOAUTH Authentication required.", func() { - mkRedisClient("", redisSrv.Addr()) - }) - }) - - t.Run("auth pass", func(t *testing.T) { - redisSrv := mustNewRedisServer() - defer redisSrv.Close() - - redisSrv.RequireAuth(redisAuth) - - assert.NotPanics(t, func() { - mkRedisClient(redisAuth, redisSrv.Addr()) - }) - }) - - t.Run("ImplicitPipeliningEnabled() return expected value", func(t *testing.T) { - redisSrv := mustNewRedisServer() - defer redisSrv.Close() - - client := mkRedisClient("", redisSrv.Addr()) - - if pipelineWindow == 0 && pipelineLimit == 0 { - assert.False(t, client.ImplicitPipeliningEnabled()) - } else { - assert.True(t, client.ImplicitPipeliningEnabled()) - } - }) - } -} - -func TestNewClientImpl(t *testing.T) { - t.Run("ImplicitPipeliningEnabled", testNewClientImpl(t, 2*time.Millisecond, 2)) - t.Run("ImplicitPipeliningDisabled", testNewClientImpl(t, 0, 0)) -} - -func TestDoCmd(t *testing.T) { - statsStore := stats.NewStore(stats.NewNullSink(), false) - - mkRedisClient := func(addr string) redis.Client { - return redis.NewClientImpl(statsStore, false, "", "single", addr, 1, 0, 0) - } - - t.Run("SETGET ok", func(t *testing.T) { - redisSrv := mustNewRedisServer() - defer redisSrv.Close() - - client := mkRedisClient(redisSrv.Addr()) - var res string - - assert.Nil(t, client.DoCmd(nil, "SET", "foo", "bar")) - assert.Nil(t, client.DoCmd(&res, "GET", "foo")) - assert.Equal(t, "bar", res) - }) - - 
t.Run("INCRBY ok", func(t *testing.T) { - redisSrv := mustNewRedisServer() - defer redisSrv.Close() - - client := mkRedisClient(redisSrv.Addr()) - var res uint32 - hits := uint32(1) - - assert.Nil(t, client.DoCmd(&res, "INCRBY", "a", hits)) - assert.Equal(t, hits, res) - assert.Nil(t, client.DoCmd(&res, "INCRBY", "a", hits)) - assert.Equal(t, uint32(2), res) - }) - - t.Run("connection broken", func(t *testing.T) { - redisSrv := mustNewRedisServer() - client := mkRedisClient(redisSrv.Addr()) - - assert.Nil(t, client.DoCmd(nil, "SET", "foo", "bar")) - - redisSrv.Close() - assert.EqualError(t, client.DoCmd(nil, "GET", "foo"), "EOF") - }) -} - -func testPipeDo(t *testing.T, pipelineWindow time.Duration, pipelineLimit int) func(t *testing.T) { - return func(t *testing.T) { - statsStore := stats.NewStore(stats.NewNullSink(), false) - - mkRedisClient := func(addr string) redis.Client { - return redis.NewClientImpl(statsStore, false, "", "single", addr, 1, pipelineWindow, pipelineLimit) - } - - t.Run("SETGET ok", func(t *testing.T) { - redisSrv := mustNewRedisServer() - defer redisSrv.Close() - - client := mkRedisClient(redisSrv.Addr()) - var res string - - pipeline := redis.Pipeline{} - pipeline = client.PipeAppend(pipeline, nil, "SET", "foo", "bar") - pipeline = client.PipeAppend(pipeline, &res, "GET", "foo") - - assert.Nil(t, client.PipeDo(pipeline)) - assert.Equal(t, "bar", res) - }) - - t.Run("INCRBY ok", func(t *testing.T) { - redisSrv := mustNewRedisServer() - defer redisSrv.Close() - - client := mkRedisClient(redisSrv.Addr()) - var res uint32 - hits := uint32(1) - - assert.Nil(t, client.PipeDo(client.PipeAppend(redis.Pipeline{}, &res, "INCRBY", "a", hits))) - assert.Equal(t, hits, res) - - assert.Nil(t, client.PipeDo(client.PipeAppend(redis.Pipeline{}, &res, "INCRBY", "a", hits))) - assert.Equal(t, uint32(2), res) - }) - - t.Run("connection broken", func(t *testing.T) { - redisSrv := mustNewRedisServer() - client := mkRedisClient(redisSrv.Addr()) - - assert.Nil(t, 
nil, client.PipeDo(client.PipeAppend(redis.Pipeline{}, nil, "SET", "foo", "bar"))) - - redisSrv.Close() - - expectErrContainEOF := func(t *testing.T, err error) { - assert.NotNil(t, err) - assert.Contains(t, err.Error(), "EOF") - } - - expectErrContainEOF(t, client.PipeDo(client.PipeAppend(redis.Pipeline{}, nil, "GET", "foo"))) - }) - } -} - -func TestPipeDo(t *testing.T) { - t.Run("ImplicitPipeliningEnabled", testPipeDo(t, 10*time.Millisecond, 2)) - t.Run("ImplicitPipeliningDisabled", testPipeDo(t, 0, 0)) -} diff --git a/test/redis/fixed_cache_impl_test.go b/test/redis/fixed_cache_impl_test.go deleted file mode 100644 index 65883f4b9..000000000 --- a/test/redis/fixed_cache_impl_test.go +++ /dev/null @@ -1,462 +0,0 @@ -package redis_test - -import ( - "testing" - - "github.com/coocood/freecache" - "github.com/mediocregopher/radix/v3" - - pb "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v3" - "github.com/envoyproxy/ratelimit/src/config" - "github.com/envoyproxy/ratelimit/src/limiter" - "github.com/envoyproxy/ratelimit/src/redis" - "github.com/envoyproxy/ratelimit/src/utils" - stats "github.com/lyft/gostats" - - "math/rand" - - "github.com/envoyproxy/ratelimit/test/common" - mock_redis "github.com/envoyproxy/ratelimit/test/mocks/redis" - mock_utils "github.com/envoyproxy/ratelimit/test/mocks/utils" - "github.com/golang/mock/gomock" - "github.com/stretchr/testify/assert" -) - -func TestRedis(t *testing.T) { - t.Run("WithoutPerSecondRedis", testRedis(false)) - t.Run("WithPerSecondRedis", testRedis(true)) -} - -func pipeAppend(pipeline redis.Pipeline, rcv interface{}, cmd, key string, args ...interface{}) redis.Pipeline { - return append(pipeline, radix.FlatCmd(rcv, cmd, key, args...)) -} - -func testRedis(usePerSecondRedis bool) func(*testing.T) { - return func(t *testing.T) { - assert := assert.New(t) - controller := gomock.NewController(t) - defer controller.Finish() - - client := mock_redis.NewMockClient(controller) - perSecondClient := 
mock_redis.NewMockClient(controller) - timeSource := mock_utils.NewMockTimeSource(controller) - var cache limiter.RateLimitCache - if usePerSecondRedis { - cache = redis.NewFixedRateLimitCacheImpl(client, perSecondClient, timeSource, rand.New(rand.NewSource(1)), 0, nil, 0.8, "") - } else { - cache = redis.NewFixedRateLimitCacheImpl(client, nil, timeSource, rand.New(rand.NewSource(1)), 0, nil, 0.8, "") - } - statsStore := stats.NewStore(stats.NewNullSink(), false) - - timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) - var clientUsed *mock_redis.MockClient - if usePerSecondRedis { - clientUsed = perSecondClient - } else { - clientUsed = client - } - - clientUsed.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key_value_1234", uint32(1)).SetArg(1, uint32(5)).DoAndReturn(pipeAppend) - clientUsed.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "EXPIRE", "domain_key_value_1234", int64(1)).DoAndReturn(pipeAppend) - clientUsed.EXPECT().PipeDo(gomock.Any()).Return(nil) - - request := common.NewRateLimitRequest("domain", [][][2]string{{{"key", "value"}}}, 1) - limits := []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, "key_value", statsStore)} - - assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 5, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, - cache.DoLimit(nil, request, limits)) - assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) - assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) - assert.Equal(uint64(0), limits[0].Stats.NearLimit.Value()) - assert.Equal(uint64(1), limits[0].Stats.WithinLimit.Value()) - - clientUsed = client - timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) - clientUsed.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key2_value2_subkey2_subvalue2_1200", uint32(1)).SetArg(1, uint32(11)).DoAndReturn(pipeAppend) - 
clientUsed.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), - "EXPIRE", "domain_key2_value2_subkey2_subvalue2_1200", int64(60)).DoAndReturn(pipeAppend) - clientUsed.EXPECT().PipeDo(gomock.Any()).Return(nil) - - request = common.NewRateLimitRequest( - "domain", - [][][2]string{ - {{"key2", "value2"}}, - {{"key2", "value2"}, {"subkey2", "subvalue2"}}, - }, 1) - limits = []*config.RateLimit{ - nil, - config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_MINUTE, "key2_value2_subkey2_subvalue2", statsStore)} - assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: nil, LimitRemaining: 0}, - {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[1].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(limits[1].Limit, timeSource)}}, - cache.DoLimit(nil, request, limits)) - assert.Equal(uint64(1), limits[1].Stats.TotalHits.Value()) - assert.Equal(uint64(1), limits[1].Stats.OverLimit.Value()) - assert.Equal(uint64(0), limits[1].Stats.NearLimit.Value()) - assert.Equal(uint64(0), limits[1].Stats.WithinLimit.Value()) - - clientUsed = client - timeSource.EXPECT().UnixNow().Return(int64(1000000)).MaxTimes(5) - clientUsed.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key3_value3_997200", uint32(1)).SetArg(1, uint32(11)).DoAndReturn(pipeAppend) - clientUsed.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), - "EXPIRE", "domain_key3_value3_997200", int64(3600)).DoAndReturn(pipeAppend) - clientUsed.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key3_value3_subkey3_subvalue3_950400", uint32(1)).SetArg(1, uint32(13)).DoAndReturn(pipeAppend) - clientUsed.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), - "EXPIRE", "domain_key3_value3_subkey3_subvalue3_950400", int64(86400)).DoAndReturn(pipeAppend) - clientUsed.EXPECT().PipeDo(gomock.Any()).Return(nil) - - request = common.NewRateLimitRequest( - "domain", - [][][2]string{ - {{"key3", "value3"}}, - {{"key3", "value3"}, {"subkey3", 
"subvalue3"}}, - }, 1) - limits = []*config.RateLimit{ - config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_HOUR, "key3_value3", statsStore), - config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_DAY, "key3_value3_subkey3_subvalue3", statsStore)} - assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{ - {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}, - {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[1].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(limits[1].Limit, timeSource)}}, - cache.DoLimit(nil, request, limits)) - assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) - assert.Equal(uint64(1), limits[0].Stats.OverLimit.Value()) - assert.Equal(uint64(0), limits[0].Stats.NearLimit.Value()) - assert.Equal(uint64(0), limits[0].Stats.WithinLimit.Value()) - assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) - assert.Equal(uint64(1), limits[0].Stats.OverLimit.Value()) - assert.Equal(uint64(0), limits[0].Stats.NearLimit.Value()) - assert.Equal(uint64(0), limits[0].Stats.WithinLimit.Value()) - } -} - -func testLocalCacheStats(localCacheStats stats.StatGenerator, statsStore stats.Store, sink *common.TestStatSink, - expectedHitCount int, expectedMissCount int, expectedLookUpCount int, expectedExpiredCount int, - expectedEntryCount int) func(*testing.T) { - return func(t *testing.T) { - localCacheStats.GenerateStats() - statsStore.Flush() - - // Check whether all local_cache related stats are available. 
- _, ok := sink.Record["averageAccessTime"] - assert.Equal(t, true, ok) - hitCount, ok := sink.Record["hitCount"] - assert.Equal(t, true, ok) - missCount, ok := sink.Record["missCount"] - assert.Equal(t, true, ok) - lookupCount, ok := sink.Record["lookupCount"] - assert.Equal(t, true, ok) - _, ok = sink.Record["overwriteCount"] - assert.Equal(t, true, ok) - _, ok = sink.Record["evacuateCount"] - assert.Equal(t, true, ok) - expiredCount, ok := sink.Record["expiredCount"] - assert.Equal(t, true, ok) - entryCount, ok := sink.Record["entryCount"] - assert.Equal(t, true, ok) - - // Check the correctness of hitCount, missCount, lookupCount, expiredCount and entryCount - assert.Equal(t, expectedHitCount, hitCount.(int)) - assert.Equal(t, expectedMissCount, missCount.(int)) - assert.Equal(t, expectedLookUpCount, lookupCount.(int)) - assert.Equal(t, expectedExpiredCount, expiredCount.(int)) - assert.Equal(t, expectedEntryCount, entryCount.(int)) - - sink.Clear() - } -} - -func TestOverLimitWithLocalCache(t *testing.T) { - assert := assert.New(t) - controller := gomock.NewController(t) - defer controller.Finish() - - client := mock_redis.NewMockClient(controller) - timeSource := mock_utils.NewMockTimeSource(controller) - localCache := freecache.NewCache(100) - cache := redis.NewFixedRateLimitCacheImpl(client, nil, timeSource, rand.New(rand.NewSource(1)), 0, localCache, 0.8, "") - sink := &common.TestStatSink{} - statsStore := stats.NewStore(sink, true) - localCacheStats := limiter.NewLocalCacheStats(localCache, statsStore.Scope("localcache")) - - // Test Near Limit Stats. 
Under Near Limit Ratio - timeSource.EXPECT().UnixNow().Return(int64(1000000)).MaxTimes(3) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key4_value4_997200", uint32(1)).SetArg(1, uint32(11)).DoAndReturn(pipeAppend) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), - "EXPIRE", "domain_key4_value4_997200", int64(3600)).DoAndReturn(pipeAppend) - client.EXPECT().PipeDo(gomock.Any()).Return(nil) - - request := common.NewRateLimitRequest("domain", [][][2]string{{{"key4", "value4"}}}, 1) - - limits := []*config.RateLimit{ - config.NewRateLimit(15, pb.RateLimitResponse_RateLimit_HOUR, "key4_value4", statsStore)} - - assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{ - {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 4, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, - cache.DoLimit(nil, request, limits)) - assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) - assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) - assert.Equal(uint64(0), limits[0].Stats.OverLimitWithLocalCache.Value()) - assert.Equal(uint64(0), limits[0].Stats.NearLimit.Value()) - assert.Equal(uint64(1), limits[0].Stats.WithinLimit.Value()) - - // Check the local cache stats. - testLocalCacheStats(localCacheStats, statsStore, sink, 0, 1, 1, 0, 0) - - // Test Near Limit Stats. 
At Near Limit Ratio, still OK - timeSource.EXPECT().UnixNow().Return(int64(1000000)).MaxTimes(3) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key4_value4_997200", uint32(1)).SetArg(1, uint32(13)).DoAndReturn(pipeAppend) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), - "EXPIRE", "domain_key4_value4_997200", int64(3600)).DoAndReturn(pipeAppend) - client.EXPECT().PipeDo(gomock.Any()).Return(nil) - - assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{ - {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 2, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, - cache.DoLimit(nil, request, limits)) - assert.Equal(uint64(2), limits[0].Stats.TotalHits.Value()) - assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) - assert.Equal(uint64(0), limits[0].Stats.OverLimitWithLocalCache.Value()) - assert.Equal(uint64(1), limits[0].Stats.NearLimit.Value()) - assert.Equal(uint64(2), limits[0].Stats.WithinLimit.Value()) - - // Check the local cache stats. 
- testLocalCacheStats(localCacheStats, statsStore, sink, 0, 2, 2, 0, 0) - - // Test Over limit stats - timeSource.EXPECT().UnixNow().Return(int64(1000000)).MaxTimes(3) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key4_value4_997200", uint32(1)).SetArg(1, uint32(16)).DoAndReturn(pipeAppend) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), - "EXPIRE", "domain_key4_value4_997200", int64(3600)).DoAndReturn(pipeAppend) - client.EXPECT().PipeDo(gomock.Any()).Return(nil) - - assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{ - {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, - cache.DoLimit(nil, request, limits)) - assert.Equal(uint64(3), limits[0].Stats.TotalHits.Value()) - assert.Equal(uint64(1), limits[0].Stats.OverLimit.Value()) - assert.Equal(uint64(0), limits[0].Stats.OverLimitWithLocalCache.Value()) - assert.Equal(uint64(1), limits[0].Stats.NearLimit.Value()) - assert.Equal(uint64(2), limits[0].Stats.WithinLimit.Value()) - - // Check the local cache stats. 
- testLocalCacheStats(localCacheStats, statsStore, sink, 0, 2, 3, 0, 1) - - // Test Over limit stats with local cache - timeSource.EXPECT().UnixNow().Return(int64(1000000)).MaxTimes(3) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key4_value4_997200", uint32(1)).Times(0) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), - "EXPIRE", "domain_key4_value4_997200", int64(3600)).Times(0) - assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{ - {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, - cache.DoLimit(nil, request, limits)) - assert.Equal(uint64(4), limits[0].Stats.TotalHits.Value()) - assert.Equal(uint64(2), limits[0].Stats.OverLimit.Value()) - assert.Equal(uint64(1), limits[0].Stats.OverLimitWithLocalCache.Value()) - assert.Equal(uint64(1), limits[0].Stats.NearLimit.Value()) - assert.Equal(uint64(2), limits[0].Stats.WithinLimit.Value()) - - // Check the local cache stats. - testLocalCacheStats(localCacheStats, statsStore, sink, 1, 3, 4, 0, 1) -} - -func TestNearLimit(t *testing.T) { - assert := assert.New(t) - controller := gomock.NewController(t) - defer controller.Finish() - - client := mock_redis.NewMockClient(controller) - timeSource := mock_utils.NewMockTimeSource(controller) - cache := redis.NewFixedRateLimitCacheImpl(client, nil, timeSource, rand.New(rand.NewSource(1)), 0, nil, 0.8, "") - statsStore := stats.NewStore(stats.NewNullSink(), false) - - // Test Near Limit Stats. 
Under Near Limit Ratio - timeSource.EXPECT().UnixNow().Return(int64(1000000)).MaxTimes(3) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key4_value4_997200", uint32(1)).SetArg(1, uint32(11)).DoAndReturn(pipeAppend) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), - "EXPIRE", "domain_key4_value4_997200", int64(3600)).DoAndReturn(pipeAppend) - client.EXPECT().PipeDo(gomock.Any()).Return(nil) - - request := common.NewRateLimitRequest("domain", [][][2]string{{{"key4", "value4"}}}, 1) - - limits := []*config.RateLimit{ - config.NewRateLimit(15, pb.RateLimitResponse_RateLimit_HOUR, "key4_value4", statsStore)} - - assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{ - {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 4, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, - cache.DoLimit(nil, request, limits)) - assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) - assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) - assert.Equal(uint64(0), limits[0].Stats.NearLimit.Value()) - assert.Equal(uint64(1), limits[0].Stats.WithinLimit.Value()) - - // Test Near Limit Stats. 
At Near Limit Ratio, still OK - timeSource.EXPECT().UnixNow().Return(int64(1000000)).MaxTimes(3) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key4_value4_997200", uint32(1)).SetArg(1, uint32(13)).DoAndReturn(pipeAppend) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), - "EXPIRE", "domain_key4_value4_997200", int64(3600)).DoAndReturn(pipeAppend) - client.EXPECT().PipeDo(gomock.Any()).Return(nil) - - assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{ - {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 2, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, - cache.DoLimit(nil, request, limits)) - assert.Equal(uint64(2), limits[0].Stats.TotalHits.Value()) - assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) - assert.Equal(uint64(1), limits[0].Stats.NearLimit.Value()) - assert.Equal(uint64(2), limits[0].Stats.WithinLimit.Value()) - - // Test Near Limit Stats. We went OVER_LIMIT, but the near_limit counter only increases - // when we are near limit, not after we have passed the limit. 
- timeSource.EXPECT().UnixNow().Return(int64(1000000)).MaxTimes(3) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key4_value4_997200", uint32(1)).SetArg(1, uint32(16)).DoAndReturn(pipeAppend) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), - "EXPIRE", "domain_key4_value4_997200", int64(3600)).DoAndReturn(pipeAppend) - client.EXPECT().PipeDo(gomock.Any()).Return(nil) - - assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{ - {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, - cache.DoLimit(nil, request, limits)) - assert.Equal(uint64(3), limits[0].Stats.TotalHits.Value()) - assert.Equal(uint64(1), limits[0].Stats.OverLimit.Value()) - assert.Equal(uint64(1), limits[0].Stats.NearLimit.Value()) - assert.Equal(uint64(2), limits[0].Stats.WithinLimit.Value()) - - // Now test hitsAddend that is greater than 1 - // All of it under limit, under near limit - timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key5_value5_1234", uint32(3)).SetArg(1, uint32(5)).DoAndReturn(pipeAppend) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "EXPIRE", "domain_key5_value5_1234", int64(1)).DoAndReturn(pipeAppend) - client.EXPECT().PipeDo(gomock.Any()).Return(nil) - - request = common.NewRateLimitRequest("domain", [][][2]string{{{"key5", "value5"}}}, 3) - limits = []*config.RateLimit{config.NewRateLimit(20, pb.RateLimitResponse_RateLimit_SECOND, "key5_value5", statsStore)} - - assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 15, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, - cache.DoLimit(nil, request, limits)) - assert.Equal(uint64(3), limits[0].Stats.TotalHits.Value()) - assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) - 
assert.Equal(uint64(0), limits[0].Stats.NearLimit.Value()) - assert.Equal(uint64(3), limits[0].Stats.WithinLimit.Value()) - - // All of it under limit, some over near limit - timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key6_value6_1234", uint32(2)).SetArg(1, uint32(7)).DoAndReturn(pipeAppend) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "EXPIRE", "domain_key6_value6_1234", int64(1)).DoAndReturn(pipeAppend) - client.EXPECT().PipeDo(gomock.Any()).Return(nil) - - request = common.NewRateLimitRequest("domain", [][][2]string{{{"key6", "value6"}}}, 2) - limits = []*config.RateLimit{config.NewRateLimit(8, pb.RateLimitResponse_RateLimit_SECOND, "key6_value6", statsStore)} - - assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 1, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, - cache.DoLimit(nil, request, limits)) - assert.Equal(uint64(2), limits[0].Stats.TotalHits.Value()) - assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) - assert.Equal(uint64(1), limits[0].Stats.NearLimit.Value()) - assert.Equal(uint64(2), limits[0].Stats.WithinLimit.Value()) - - // All of it under limit, all of it over near limit - timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key7_value7_1234", uint32(3)).SetArg(1, uint32(19)).DoAndReturn(pipeAppend) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "EXPIRE", "domain_key7_value7_1234", int64(1)).DoAndReturn(pipeAppend) - client.EXPECT().PipeDo(gomock.Any()).Return(nil) - - request = common.NewRateLimitRequest("domain", [][][2]string{{{"key7", "value7"}}}, 3) - limits = []*config.RateLimit{config.NewRateLimit(20, pb.RateLimitResponse_RateLimit_SECOND, "key7_value7", statsStore)} - - assert.Equal( - 
[]*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 1, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, - cache.DoLimit(nil, request, limits)) - assert.Equal(uint64(3), limits[0].Stats.TotalHits.Value()) - assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) - assert.Equal(uint64(3), limits[0].Stats.NearLimit.Value()) - assert.Equal(uint64(3), limits[0].Stats.WithinLimit.Value()) - - // Some of it over limit, all of it over near limit - timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key8_value8_1234", uint32(3)).SetArg(1, uint32(22)).DoAndReturn(pipeAppend) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "EXPIRE", "domain_key8_value8_1234", int64(1)).DoAndReturn(pipeAppend) - client.EXPECT().PipeDo(gomock.Any()).Return(nil) - - request = common.NewRateLimitRequest("domain", [][][2]string{{{"key8", "value8"}}}, 3) - limits = []*config.RateLimit{config.NewRateLimit(20, pb.RateLimitResponse_RateLimit_SECOND, "key8_value8", statsStore)} - - assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, - cache.DoLimit(nil, request, limits)) - assert.Equal(uint64(3), limits[0].Stats.TotalHits.Value()) - assert.Equal(uint64(2), limits[0].Stats.OverLimit.Value()) - assert.Equal(uint64(1), limits[0].Stats.NearLimit.Value()) - assert.Equal(uint64(0), limits[0].Stats.WithinLimit.Value()) - - // Some of it in all three places - timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key9_value9_1234", uint32(7)).SetArg(1, uint32(22)).DoAndReturn(pipeAppend) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "EXPIRE", "domain_key9_value9_1234", 
int64(1)).DoAndReturn(pipeAppend) - client.EXPECT().PipeDo(gomock.Any()).Return(nil) - - request = common.NewRateLimitRequest("domain", [][][2]string{{{"key9", "value9"}}}, 7) - limits = []*config.RateLimit{config.NewRateLimit(20, pb.RateLimitResponse_RateLimit_SECOND, "key9_value9", statsStore)} - - assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, - cache.DoLimit(nil, request, limits)) - assert.Equal(uint64(7), limits[0].Stats.TotalHits.Value()) - assert.Equal(uint64(2), limits[0].Stats.OverLimit.Value()) - assert.Equal(uint64(4), limits[0].Stats.NearLimit.Value()) - assert.Equal(uint64(0), limits[0].Stats.WithinLimit.Value()) - - // all of it over limit - timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key10_value10_1234", uint32(3)).SetArg(1, uint32(30)).DoAndReturn(pipeAppend) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "EXPIRE", "domain_key10_value10_1234", int64(1)).DoAndReturn(pipeAppend) - client.EXPECT().PipeDo(gomock.Any()).Return(nil) - - request = common.NewRateLimitRequest("domain", [][][2]string{{{"key10", "value10"}}}, 3) - limits = []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, "key10_value10", statsStore)} - - assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, - cache.DoLimit(nil, request, limits)) - assert.Equal(uint64(3), limits[0].Stats.TotalHits.Value()) - assert.Equal(uint64(3), limits[0].Stats.OverLimit.Value()) - assert.Equal(uint64(0), limits[0].Stats.NearLimit.Value()) - assert.Equal(uint64(0), limits[0].Stats.WithinLimit.Value()) -} - -func TestRedisWithJitter(t 
*testing.T) { - assert := assert.New(t) - controller := gomock.NewController(t) - defer controller.Finish() - - client := mock_redis.NewMockClient(controller) - timeSource := mock_utils.NewMockTimeSource(controller) - jitterSource := mock_utils.NewMockJitterRandSource(controller) - cache := redis.NewFixedRateLimitCacheImpl(client, nil, timeSource, rand.New(jitterSource), 3600, nil, 0.8, "") - statsStore := stats.NewStore(stats.NewNullSink(), false) - - timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) - jitterSource.EXPECT().Int63().Return(int64(100)) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "INCRBY", "domain_key_value_1234", uint32(1)).SetArg(1, uint32(5)).DoAndReturn(pipeAppend) - client.EXPECT().PipeAppend(gomock.Any(), gomock.Any(), "EXPIRE", "domain_key_value_1234", int64(101)).DoAndReturn(pipeAppend) - client.EXPECT().PipeDo(gomock.Any()).Return(nil) - - request := common.NewRateLimitRequest("domain", [][][2]string{{{"key", "value"}}}, 1) - limits := []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, "key_value", statsStore)} - - assert.Equal( - []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 5, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, - cache.DoLimit(nil, request, limits)) - assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) - assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) - assert.Equal(uint64(0), limits[0].Stats.NearLimit.Value()) - assert.Equal(uint64(1), limits[0].Stats.WithinLimit.Value()) -} From db46531717f09cde11485b5a5954343538e924c0 Mon Sep 17 00:00:00 2001 From: zufardhiyaulhaq Date: Mon, 31 May 2021 15:37:21 +0200 Subject: [PATCH 2/8] add storage factory test Signed-off-by: zufardhiyaulhaq --- go.mod | 2 + go.sum | 4 + src/memcached/cache_impl.go | 2 +- src/memcached/fixed_cache_impl.go | 2 +- src/redis/cache_impl.go | 3 +- 
.../storage/service/memcached_client_mock.go | 78 +++++++++++++++++ .../storage/service/redis_client_mock.go | 48 +++++++++++ .../storage/factory/memcached_factory_test.go | 21 +++++ test/storage/factory/redis_factory_test.go | 86 +++++++++++++++++++ 9 files changed, 242 insertions(+), 4 deletions(-) create mode 100644 test/mocks/storage/service/memcached_client_mock.go create mode 100644 test/mocks/storage/service/redis_client_mock.go create mode 100644 test/storage/factory/memcached_factory_test.go create mode 100644 test/storage/factory/redis_factory_test.go diff --git a/go.mod b/go.mod index b36e9e9f9..371885d1f 100644 --- a/go.mod +++ b/go.mod @@ -3,12 +3,14 @@ module github.com/envoyproxy/ratelimit go 1.14 require ( + github.com/alicebob/miniredis v2.5.0+incompatible github.com/alicebob/miniredis/v2 v2.11.4 github.com/bradfitz/gomemcache v0.0.0-20190913173617-a41fca850d0b github.com/cespare/xxhash v1.1.0 // indirect github.com/coocood/freecache v1.1.0 github.com/envoyproxy/go-control-plane v0.9.7 github.com/fsnotify/fsnotify v1.4.7 // indirect + github.com/go-redis/redis v6.15.9+incompatible github.com/golang/mock v1.4.1 github.com/golang/protobuf v1.4.2 github.com/gorilla/mux v1.7.4-0.20191121170500-49c01487a141 diff --git a/go.sum b/go.sum index 7c2ffd741..4f3f4a0d7 100644 --- a/go.sum +++ b/go.sum @@ -4,6 +4,8 @@ github.com/OneOfOne/xxhash v1.2.2 h1:KMrpdQIwFcEqXDklaen+P1axHaj9BSKzvpUUfnHldSE github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/alicebob/gopher-json v0.0.0-20180125190556-5a6b3ba71ee6 h1:45bxf7AZMwWcqkLzDAQugVEwedisr5nRJ1r+7LYnv0U= github.com/alicebob/gopher-json v0.0.0-20180125190556-5a6b3ba71ee6/go.mod h1:SGnFV6hVsYE877CKEZ6tDNTjaSXYUk6QqoIK6PrAtcc= +github.com/alicebob/miniredis v2.5.0+incompatible h1:yBHoLpsyjupjz3NL3MhKMVkR41j82Yjf3KFv7ApYzUI= +github.com/alicebob/miniredis v2.5.0+incompatible/go.mod h1:8HZjEj4yU0dwhYHky+DxYx+6BMjkBbe5ONFIF1MXffk= github.com/alicebob/miniredis/v2 v2.11.4 
h1:GsuyeunTx7EllZBU3/6Ji3dhMQZDpC9rLf1luJ+6M5M= github.com/alicebob/miniredis/v2 v2.11.4/go.mod h1:VL3UDEfAH59bSa7MuHMuFToxkqyHh69s/WUbYlOAuyg= github.com/bradfitz/gomemcache v0.0.0-20190913173617-a41fca850d0b h1:L/QXpzIa3pOvUGt1D1lA5KjYhPBAN/3iWdP7xeFS9F0= @@ -30,6 +32,8 @@ github.com/envoyproxy/protoc-gen-validate v0.1.0 h1:EQciDnbrYxy13PgWoY8AqoxGiPrp github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= github.com/fsnotify/fsnotify v1.4.7 h1:IXs+QLmnXW2CcXuY+8Mzv/fWEsPGWxqefPtCP5CnV9I= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= +github.com/go-redis/redis v6.15.9+incompatible h1:K0pv1D7EQUjfyoMql+r/jZqCLizCGKFlFgcHWWmHQjg= +github.com/go-redis/redis v6.15.9+incompatible/go.mod h1:NAIEuMOZ/fxfXJIrKDQDz8wamY7mA7PouImQ2Jvg6kA= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b h1:VKtxabqXZkF25pY9ekfRL6a582T4P37/31XEstQ5p58= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= diff --git a/src/memcached/cache_impl.go b/src/memcached/cache_impl.go index 90f349a94..2302bf7e0 100644 --- a/src/memcached/cache_impl.go +++ b/src/memcached/cache_impl.go @@ -14,7 +14,7 @@ import ( func NewRateLimitCacheImplFromSettings(s settings.Settings, timeSource utils.TimeSource, jitterRand *rand.Rand, localCache *freecache.Cache, scope stats.Scope) limiter.RateLimitCache { - return NewRateLimitCacheImpl( + return NewFixedRateLimitCacheImpl( storage_factory.NewMemcached(s.MemcacheHostPort), timeSource, jitterRand, diff --git a/src/memcached/fixed_cache_impl.go b/src/memcached/fixed_cache_impl.go index 15f7a8240..29b5303cc 100644 --- a/src/memcached/fixed_cache_impl.go +++ b/src/memcached/fixed_cache_impl.go @@ -156,7 +156,7 @@ func (this *rateLimitMemcacheImpl) Flush() { this.waitGroup.Wait() } -func NewRateLimitCacheImpl(client 
storage_strategy.StorageStrategy, timeSource utils.TimeSource, jitterRand *rand.Rand, +func NewFixedRateLimitCacheImpl(client storage_strategy.StorageStrategy, timeSource utils.TimeSource, jitterRand *rand.Rand, expirationJitterMaxSeconds int64, localCache *freecache.Cache, scope stats.Scope, nearLimitRatio float32, cacheKeyPrefix string) limiter.RateLimitCache { return &rateLimitMemcacheImpl{ client: client, diff --git a/src/redis/cache_impl.go b/src/redis/cache_impl.go index 004846677..eb2f430d0 100644 --- a/src/redis/cache_impl.go +++ b/src/redis/cache_impl.go @@ -19,8 +19,7 @@ func NewRateLimiterCacheImplFromSettings(s settings.Settings, localCache *freeca perSecondPool = storage_factory.NewRedis(s.RedisPerSecondTls, s.RedisPerSecondAuth, s.RedisPerSecondType, s.RedisPerSecondUrl, s.RedisPerSecondPoolSize, s.RedisPerSecondPipelineWindow, s.RedisPerSecondPipelineLimit) } - var otherPool storage_strategy.StorageStrategy - otherPool = storage_factory.NewRedis(s.RedisTls, s.RedisAuth, s.RedisType, s.RedisUrl, s.RedisPoolSize, + otherPool := storage_factory.NewRedis(s.RedisTls, s.RedisAuth, s.RedisType, s.RedisUrl, s.RedisPoolSize, s.RedisPipelineWindow, s.RedisPipelineLimit) return NewFixedRateLimitCacheImpl( diff --git a/test/mocks/storage/service/memcached_client_mock.go b/test/mocks/storage/service/memcached_client_mock.go new file mode 100644 index 000000000..7ffd6253a --- /dev/null +++ b/test/mocks/storage/service/memcached_client_mock.go @@ -0,0 +1,78 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: ./src/storage/service/memcached_client.go + +// Package mock_service is a generated GoMock package. 
+package mock_service + +import ( + memcache "github.com/bradfitz/gomemcache/memcache" + gomock "github.com/golang/mock/gomock" + reflect "reflect" +) + +// MockMemcachedClientInterface is a mock of MemcachedClientInterface interface +type MockMemcachedClientInterface struct { + ctrl *gomock.Controller + recorder *MockMemcachedClientInterfaceMockRecorder +} + +// MockMemcachedClientInterfaceMockRecorder is the mock recorder for MockMemcachedClientInterface +type MockMemcachedClientInterfaceMockRecorder struct { + mock *MockMemcachedClientInterface +} + +// NewMockMemcachedClientInterface creates a new mock instance +func NewMockMemcachedClientInterface(ctrl *gomock.Controller) *MockMemcachedClientInterface { + mock := &MockMemcachedClientInterface{ctrl: ctrl} + mock.recorder = &MockMemcachedClientInterfaceMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockMemcachedClientInterface) EXPECT() *MockMemcachedClientInterfaceMockRecorder { + return m.recorder +} + +// Get mocks base method +func (m *MockMemcachedClientInterface) Get(key string) (*memcache.Item, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Get", key) + ret0, _ := ret[0].(*memcache.Item) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Get indicates an expected call of Get +func (mr *MockMemcachedClientInterfaceMockRecorder) Get(key interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockMemcachedClientInterface)(nil).Get), key) +} + +// Set mocks base method +func (m *MockMemcachedClientInterface) Set(item *memcache.Item) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Set", item) + ret0, _ := ret[0].(error) + return ret0 +} + +// Set indicates an expected call of Set +func (mr *MockMemcachedClientInterfaceMockRecorder) Set(item interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return 
mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Set", reflect.TypeOf((*MockMemcachedClientInterface)(nil).Set), item) +} + +// Increment mocks base method +func (m *MockMemcachedClientInterface) Increment(key string, delta uint64) (uint64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Increment", key, delta) + ret0, _ := ret[0].(uint64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Increment indicates an expected call of Increment +func (mr *MockMemcachedClientInterfaceMockRecorder) Increment(key, delta interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Increment", reflect.TypeOf((*MockMemcachedClientInterface)(nil).Increment), key, delta) +} diff --git a/test/mocks/storage/service/redis_client_mock.go b/test/mocks/storage/service/redis_client_mock.go new file mode 100644 index 000000000..e07cba904 --- /dev/null +++ b/test/mocks/storage/service/redis_client_mock.go @@ -0,0 +1,48 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: ./src/storage/service/redis_client.go + +// Package mock_service is a generated GoMock package. 
+package mock_service + +import ( + gomock "github.com/golang/mock/gomock" + radix "github.com/mediocregopher/radix/v3" + reflect "reflect" +) + +// MockRedisClientInterface is a mock of RedisClientInterface interface +type MockRedisClientInterface struct { + ctrl *gomock.Controller + recorder *MockRedisClientInterfaceMockRecorder +} + +// MockRedisClientInterfaceMockRecorder is the mock recorder for MockRedisClientInterface +type MockRedisClientInterfaceMockRecorder struct { + mock *MockRedisClientInterface +} + +// NewMockRedisClientInterface creates a new mock instance +func NewMockRedisClientInterface(ctrl *gomock.Controller) *MockRedisClientInterface { + mock := &MockRedisClientInterface{ctrl: ctrl} + mock.recorder = &MockRedisClientInterfaceMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockRedisClientInterface) EXPECT() *MockRedisClientInterfaceMockRecorder { + return m.recorder +} + +// Do mocks base method +func (m *MockRedisClientInterface) Do(arg0 radix.Action) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Do", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Do indicates an expected call of Do +func (mr *MockRedisClientInterfaceMockRecorder) Do(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Do", reflect.TypeOf((*MockRedisClientInterface)(nil).Do), arg0) +} diff --git a/test/storage/factory/memcached_factory_test.go b/test/storage/factory/memcached_factory_test.go new file mode 100644 index 000000000..1d4a7b1c7 --- /dev/null +++ b/test/storage/factory/memcached_factory_test.go @@ -0,0 +1,21 @@ +package factory_test + +import ( + "testing" + + "github.com/envoyproxy/ratelimit/src/storage/factory" + "github.com/envoyproxy/ratelimit/src/storage/strategy" + "github.com/stretchr/testify/assert" +) + +func TestNewMemcachedClient(t *testing.T) { + mkMemcachedClient := func(addr []string) 
strategy.StorageStrategy { + return factory.NewMemcached(addr) + } + + t.Run("empty server", func(t *testing.T) { + storage := mkMemcachedClient([]string{}) + _, err := storage.GetValue("test") + assert.Error(t, err) + }) +} diff --git a/test/storage/factory/redis_factory_test.go b/test/storage/factory/redis_factory_test.go new file mode 100644 index 000000000..a3e08931e --- /dev/null +++ b/test/storage/factory/redis_factory_test.go @@ -0,0 +1,86 @@ +package factory_test + +import ( + "testing" + "time" + + "github.com/alicebob/miniredis" + "github.com/envoyproxy/ratelimit/src/storage/factory" + "github.com/envoyproxy/ratelimit/src/storage/strategy" + + "github.com/stretchr/testify/assert" +) + +func mustNewRedisServer() *miniredis.Miniredis { + srv, err := miniredis.Run() + if err != nil { + panic(err) + } + + return srv +} + +func expectPanicError(t *testing.T, f assert.PanicTestFunc) (result error) { + t.Helper() + defer func() { + panicResult := recover() + assert.NotNil(t, panicResult, "Expected a panic") + result = panicResult.(error) + }() + f() + return +} + +func TestNewRedisClient(t *testing.T) { + t.Run("ImplicitPipeliningEnabled", testNewRedisClient(t, 2*time.Millisecond, 2)) + t.Run("ImplicitPipeliningDisabled", testNewRedisClient(t, 0, 0)) +} + +func testNewRedisClient(t *testing.T, pipelineWindow time.Duration, pipelineLimit int) func(t *testing.T) { + return func(t *testing.T) { + redisAuth := "123" + mkRedisClient := func(auth, addr string) strategy.StorageStrategy { + return factory.NewRedis(false, auth, "single", addr, 1, pipelineWindow, pipelineLimit) + } + + t.Run("connection refused", func(t *testing.T) { + // It's possible there is a redis server listening on 6379 in ci environment, so + // use a random port. 
+ panicErr := expectPanicError(t, func() { mkRedisClient("", "localhost:12345") }) + assert.Contains(t, panicErr.Error(), "connection refused") + }) + + t.Run("ok", func(t *testing.T) { + redisSrv := mustNewRedisServer() + defer redisSrv.Close() + + var client strategy.StorageStrategy + assert.NotPanics(t, func() { + client = mkRedisClient("", redisSrv.Addr()) + }) + assert.NotNil(t, client) + }) + + t.Run("auth fail", func(t *testing.T) { + redisSrv := mustNewRedisServer() + defer redisSrv.Close() + + redisSrv.RequireAuth(redisAuth) + + assert.PanicsWithError(t, "NOAUTH Authentication required.", func() { + mkRedisClient("", redisSrv.Addr()) + }) + }) + + t.Run("auth pass", func(t *testing.T) { + redisSrv := mustNewRedisServer() + defer redisSrv.Close() + + redisSrv.RequireAuth(redisAuth) + + assert.NotPanics(t, func() { + mkRedisClient(redisAuth, redisSrv.Addr()) + }) + }) + } +} From c50480acbdc90d7f445594ee741f9423bd69ff43 Mon Sep 17 00:00:00 2001 From: zufardhiyaulhaq Date: Mon, 31 May 2021 17:22:33 +0200 Subject: [PATCH 3/8] add storage strategy unit test Signed-off-by: zufardhiyaulhaq --- .../strategy/memcached_strategy_test.go | 65 +++++++++++++++++ test/storage/strategy/redis_strategy_test.go | 71 +++++++++++++++++++ 2 files changed, 136 insertions(+) create mode 100644 test/storage/strategy/memcached_strategy_test.go create mode 100644 test/storage/strategy/redis_strategy_test.go diff --git a/test/storage/strategy/memcached_strategy_test.go b/test/storage/strategy/memcached_strategy_test.go new file mode 100644 index 000000000..80e292ba9 --- /dev/null +++ b/test/storage/strategy/memcached_strategy_test.go @@ -0,0 +1,65 @@ +package strategy_test + +import ( + "strconv" + "testing" + + "github.com/bradfitz/gomemcache/memcache" + "github.com/envoyproxy/ratelimit/src/storage/strategy" + mock_service "github.com/envoyproxy/ratelimit/test/mocks/storage/service" + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" +) + +func 
TestMemcachedStrategyGetValue(t *testing.T) { + assert := assert.New(t) + controller := gomock.NewController(t) + defer controller.Finish() + + mockMemcachedClient := mock_service.NewMockMemcachedClientInterface(controller) + memcachedStrategy := strategy.MemcachedStrategy{ + Client: mockMemcachedClient, + } + + mockMemcachedClient.EXPECT().Get("key").Return(&memcache.Item{Key: "key", Value: []byte("5")}, nil) + value, err := memcachedStrategy.GetValue("key") + + assert.Equal(value, uint64(5)) + assert.Nil(err) +} + +func TestMemcachedStrategySetValue(t *testing.T) { + assert := assert.New(t) + controller := gomock.NewController(t) + defer controller.Finish() + + mockMemcachedClient := mock_service.NewMockMemcachedClientInterface(controller) + memcachedStrategy := strategy.MemcachedStrategy{ + Client: mockMemcachedClient, + } + + mockMemcachedClient.EXPECT().Set(&memcache.Item{ + Key: "key", + Value: []byte(strconv.FormatUint(uint64(5), 10)), + Expiration: int32(5), + }).Return(nil) + + err := memcachedStrategy.SetValue("key", uint64(5), uint64(5)) + assert.Nil(err) +} + +func TestMemcachedStrategyIncrementValue(t *testing.T) { + assert := assert.New(t) + controller := gomock.NewController(t) + defer controller.Finish() + + mockMemcachedClient := mock_service.NewMockMemcachedClientInterface(controller) + memcachedStrategy := strategy.MemcachedStrategy{ + Client: mockMemcachedClient, + } + + mockMemcachedClient.EXPECT().Increment("key", uint64(1)).Return(uint64(1), nil) + + err := memcachedStrategy.IncrementValue("key", uint64(1)) + assert.Nil(err) +} diff --git a/test/storage/strategy/redis_strategy_test.go b/test/storage/strategy/redis_strategy_test.go new file mode 100644 index 000000000..989f7ffea --- /dev/null +++ b/test/storage/strategy/redis_strategy_test.go @@ -0,0 +1,71 @@ +package strategy_test + +import ( + "testing" + + "github.com/alicebob/miniredis" + "github.com/envoyproxy/ratelimit/src/storage/strategy" + mock_service 
"github.com/envoyproxy/ratelimit/test/mocks/storage/service" + "github.com/golang/mock/gomock" + "github.com/mediocregopher/radix/v3" + "github.com/stretchr/testify/assert" +) + +func mustNewRedisServer() *miniredis.Miniredis { + srv, err := miniredis.Run() + if err != nil { + panic(err) + } + + return srv +} +func TestRedisStrategyGetValue(t *testing.T) { + assert := assert.New(t) + controller := gomock.NewController(t) + defer controller.Finish() + + mockRedisClient := mock_service.NewMockRedisClientInterface(controller) + redisStrategy := strategy.RedisStrategy{ + Client: mockRedisClient, + } + + var value uint64 + mockRedisClient.EXPECT().Do(radix.Cmd(&value, "GET", "key")).Return(nil) + value, err := redisStrategy.GetValue("key") + + assert.Equal(value, uint64(0)) + assert.Nil(err) +} + +func TestRedisStrategySetValue(t *testing.T) { + assert := assert.New(t) + controller := gomock.NewController(t) + defer controller.Finish() + + mockRedisClient := mock_service.NewMockRedisClientInterface(controller) + redisStrategy := strategy.RedisStrategy{ + Client: mockRedisClient, + } + + mockRedisClient.EXPECT().Do(radix.FlatCmd(nil, "SET", "key", uint64(5))).Return(nil) + mockRedisClient.EXPECT().Do(radix.FlatCmd(nil, "EXPIRE", "key", uint64(5))).Return(nil) + + err := redisStrategy.SetValue("key", uint64(5), uint64(5)) + assert.Nil(err) +} + +func TestRedisStrategyIncrementValue(t *testing.T) { + assert := assert.New(t) + controller := gomock.NewController(t) + defer controller.Finish() + + mockRedisClient := mock_service.NewMockRedisClientInterface(controller) + redisStrategy := strategy.RedisStrategy{ + Client: mockRedisClient, + } + + mockRedisClient.EXPECT().Do(radix.FlatCmd(nil, "INCRBY", "key", uint64(1))).Return(nil) + + err := redisStrategy.IncrementValue("key", uint64(1)) + assert.Nil(err) +} From 9ae450ccf565fea1097b497558bd9ff0c39efd5b Mon Sep 17 00:00:00 2001 From: zufardhiyaulhaq Date: Tue, 1 Jun 2021 10:17:43 +0200 Subject: [PATCH 4/8] implement storage 
stats Signed-off-by: zufardhiyaulhaq --- src/memcached/cache_impl.go | 10 +++--- src/memcached/fixed_cache_impl.go | 3 +- src/redis/cache_impl.go | 8 ++--- src/redis/fixed_cache_impl.go | 2 +- src/service_cmd/runner/runner.go | 14 ++++---- src/storage/factory/memcached_factory.go | 9 +++-- src/storage/factory/redis_factory.go | 11 ++++--- src/storage/service/memcached_client.go | 33 +++++++++++++++++-- src/storage/service/memcached_stats.go | 29 ++++++++++++++++ src/storage/service/redis_client.go | 1 + src/storage/service/redis_stats.go | 33 +++++++++++++++++++ .../storage/factory/memcached_factory_test.go | 4 ++- test/storage/factory/redis_factory_test.go | 5 ++- 13 files changed, 130 insertions(+), 32 deletions(-) create mode 100644 src/storage/service/memcached_stats.go create mode 100644 src/storage/service/redis_stats.go diff --git a/src/memcached/cache_impl.go b/src/memcached/cache_impl.go index 2302bf7e0..6aa6af99d 100644 --- a/src/memcached/cache_impl.go +++ b/src/memcached/cache_impl.go @@ -5,22 +5,20 @@ import ( "github.com/coocood/freecache" "github.com/envoyproxy/ratelimit/src/limiter" + "github.com/envoyproxy/ratelimit/src/server" "github.com/envoyproxy/ratelimit/src/settings" "github.com/envoyproxy/ratelimit/src/utils" - stats "github.com/lyft/gostats" storage_factory "github.com/envoyproxy/ratelimit/src/storage/factory" ) -func NewRateLimitCacheImplFromSettings(s settings.Settings, timeSource utils.TimeSource, jitterRand *rand.Rand, - localCache *freecache.Cache, scope stats.Scope) limiter.RateLimitCache { +func NewRateLimiterCacheImplFromSettings(s settings.Settings, localCache *freecache.Cache, srv server.Server, timeSource utils.TimeSource, jitterRand *rand.Rand) limiter.RateLimitCache { return NewFixedRateLimitCacheImpl( - storage_factory.NewMemcached(s.MemcacheHostPort), + storage_factory.NewMemcached(srv.Scope().Scope("memcache"), s.MemcacheHostPort), timeSource, jitterRand, - s.ExpirationJitterMaxSeconds, localCache, - scope, + 
s.ExpirationJitterMaxSeconds, s.NearLimitRatio, s.CacheKeyPrefix, ) diff --git a/src/memcached/fixed_cache_impl.go b/src/memcached/fixed_cache_impl.go index 29b5303cc..6818fa4ca 100644 --- a/src/memcached/fixed_cache_impl.go +++ b/src/memcached/fixed_cache_impl.go @@ -21,7 +21,6 @@ import ( "sync" "github.com/coocood/freecache" - stats "github.com/lyft/gostats" "github.com/bradfitz/gomemcache/memcache" @@ -157,7 +156,7 @@ func (this *rateLimitMemcacheImpl) Flush() { } func NewFixedRateLimitCacheImpl(client storage_strategy.StorageStrategy, timeSource utils.TimeSource, jitterRand *rand.Rand, - expirationJitterMaxSeconds int64, localCache *freecache.Cache, scope stats.Scope, nearLimitRatio float32, cacheKeyPrefix string) limiter.RateLimitCache { + localCache *freecache.Cache, expirationJitterMaxSeconds int64, nearLimitRatio float32, cacheKeyPrefix string) limiter.RateLimitCache { return &rateLimitMemcacheImpl{ client: client, timeSource: timeSource, diff --git a/src/redis/cache_impl.go b/src/redis/cache_impl.go index eb2f430d0..9a141b092 100644 --- a/src/redis/cache_impl.go +++ b/src/redis/cache_impl.go @@ -13,13 +13,13 @@ import ( storage_strategy "github.com/envoyproxy/ratelimit/src/storage/strategy" ) -func NewRateLimiterCacheImplFromSettings(s settings.Settings, localCache *freecache.Cache, srv server.Server, timeSource utils.TimeSource, jitterRand *rand.Rand, expirationJitterMaxSeconds int64) limiter.RateLimitCache { +func NewRateLimiterCacheImplFromSettings(s settings.Settings, localCache *freecache.Cache, srv server.Server, timeSource utils.TimeSource, jitterRand *rand.Rand) limiter.RateLimitCache { var perSecondPool storage_strategy.StorageStrategy if s.RedisPerSecond { - perSecondPool = storage_factory.NewRedis(s.RedisPerSecondTls, s.RedisPerSecondAuth, + perSecondPool = storage_factory.NewRedis(srv.Scope().Scope("redis_per_second_pool"), s.RedisPerSecondTls, s.RedisPerSecondAuth, s.RedisPerSecondType, s.RedisPerSecondUrl, s.RedisPerSecondPoolSize, 
s.RedisPerSecondPipelineWindow, s.RedisPerSecondPipelineLimit) } - otherPool := storage_factory.NewRedis(s.RedisTls, s.RedisAuth, s.RedisType, s.RedisUrl, s.RedisPoolSize, + otherPool := storage_factory.NewRedis(srv.Scope().Scope("redis_pool"), s.RedisTls, s.RedisAuth, s.RedisType, s.RedisUrl, s.RedisPoolSize, s.RedisPipelineWindow, s.RedisPipelineLimit) return NewFixedRateLimitCacheImpl( @@ -27,8 +27,8 @@ func NewRateLimiterCacheImplFromSettings(s settings.Settings, localCache *freeca perSecondPool, timeSource, jitterRand, - expirationJitterMaxSeconds, localCache, + s.ExpirationJitterMaxSeconds, s.NearLimitRatio, s.CacheKeyPrefix, ) diff --git a/src/redis/fixed_cache_impl.go b/src/redis/fixed_cache_impl.go index 16e5e85d5..cd9cf0d6d 100644 --- a/src/redis/fixed_cache_impl.go +++ b/src/redis/fixed_cache_impl.go @@ -123,7 +123,7 @@ func (this *fixedRateLimitCacheImpl) increaseAsync(cacheKeys []limiter.CacheKey, func (this *fixedRateLimitCacheImpl) Flush() {} func NewFixedRateLimitCacheImpl(client storage_strategy.StorageStrategy, perSecondClient storage_strategy.StorageStrategy, timeSource utils.TimeSource, - jitterRand *rand.Rand, expirationJitterMaxSeconds int64, localCache *freecache.Cache, nearLimitRatio float32, cacheKeyPrefix string) limiter.RateLimitCache { + jitterRand *rand.Rand, localCache *freecache.Cache, expirationJitterMaxSeconds int64, nearLimitRatio float32, cacheKeyPrefix string) limiter.RateLimitCache { return &fixedRateLimitCacheImpl{ client: client, perSecondClient: perSecondClient, diff --git a/src/service_cmd/runner/runner.go b/src/service_cmd/runner/runner.go index 589f1c5f0..9b2c658b7 100644 --- a/src/service_cmd/runner/runner.go +++ b/src/service_cmd/runner/runner.go @@ -1,7 +1,6 @@ package runner import ( - "github.com/envoyproxy/ratelimit/src/metrics" "io" "math/rand" "net/http" @@ -9,6 +8,8 @@ import ( "sync" "time" + "github.com/envoyproxy/ratelimit/src/metrics" + stats "github.com/lyft/gostats" "github.com/coocood/freecache" @@ -53,15 
+54,14 @@ func createLimiter(srv server.Server, s settings.Settings, localCache *freecache localCache, srv, utils.NewTimeSourceImpl(), - rand.New(utils.NewLockedSource(time.Now().Unix())), - s.ExpirationJitterMaxSeconds) + rand.New(utils.NewLockedSource(time.Now().Unix()))) case "memcache": - return memcached.NewRateLimitCacheImplFromSettings( + return memcached.NewRateLimiterCacheImplFromSettings( s, - utils.NewTimeSourceImpl(), - rand.New(utils.NewLockedSource(time.Now().Unix())), localCache, - srv.Scope()) + srv, + utils.NewTimeSourceImpl(), + rand.New(utils.NewLockedSource(time.Now().Unix()))) default: logger.Fatalf("Invalid setting for BackendType: %s", s.BackendType) panic("This line should not be reachable") diff --git a/src/storage/factory/memcached_factory.go b/src/storage/factory/memcached_factory.go index f57248acc..9a91c29bb 100644 --- a/src/storage/factory/memcached_factory.go +++ b/src/storage/factory/memcached_factory.go @@ -4,18 +4,21 @@ import ( "github.com/bradfitz/gomemcache/memcache" "github.com/envoyproxy/ratelimit/src/storage/service" "github.com/envoyproxy/ratelimit/src/storage/strategy" + stats "github.com/lyft/gostats" ) -func NewMemcached(servers []string) strategy.StorageStrategy { - client := newMemcachedClient(servers) +func NewMemcached(scope stats.Scope, servers []string) strategy.StorageStrategy { + client := newMemcachedClient(scope, servers) return strategy.MemcachedStrategy{ Client: client, } } -func newMemcachedClient(servers []string) service.MemcachedClientInterface { +func newMemcachedClient(scope stats.Scope, servers []string) service.MemcachedClientInterface { client := memcache.New(servers...) 
+ stats := service.NewMemcachedStats(scope) return &service.MemcachedClient{ Client: client, + Stats: stats, } } diff --git a/src/storage/factory/redis_factory.go b/src/storage/factory/redis_factory.go index 5bda8eac9..2379c27e9 100644 --- a/src/storage/factory/redis_factory.go +++ b/src/storage/factory/redis_factory.go @@ -6,6 +6,7 @@ import ( "strings" "time" + stats "github.com/lyft/gostats" logger "github.com/sirupsen/logrus" "github.com/envoyproxy/ratelimit/src/storage/service" @@ -14,15 +15,15 @@ import ( "github.com/mediocregopher/radix/v3" ) -func NewRedis(useTls bool, auth string, redisType string, url string, poolSize int, +func NewRedis(scope stats.Scope, useTls bool, auth string, redisType string, url string, poolSize int, pipelineWindow time.Duration, pipelineLimit int) strategy.StorageStrategy { - client := newRedisClient(useTls, auth, redisType, url, poolSize, pipelineWindow, pipelineLimit) + client := newRedisClient(scope, useTls, auth, redisType, url, poolSize, pipelineWindow, pipelineLimit) return strategy.RedisStrategy{ Client: client, } } -func newRedisClient(useTls bool, auth string, redisType string, url string, poolSize int, pipelineWindow time.Duration, pipelineLimit int) service.RedisClientInterface { +func newRedisClient(scope stats.Scope, useTls bool, auth string, redisType string, url string, poolSize int, pipelineWindow time.Duration, pipelineLimit int) service.RedisClientInterface { logger.Warnf("connecting to redis on %s with pool size %d", url, poolSize) df := func(network, addr string) (radix.Conn, error) { @@ -45,7 +46,8 @@ func newRedisClient(useTls bool, auth string, redisType string, url string, pool return radix.Dial(network, addr, dialOpts...) 
} - opts := []radix.PoolOpt{radix.PoolConnFunc(df)} + stats := service.NewRedisStats(scope) + opts := []radix.PoolOpt{radix.PoolConnFunc(df), radix.PoolWithTrace(service.PoolTrace(&stats))} implicitPipelining := true if pipelineWindow == 0 && pipelineLimit == 0 { @@ -92,6 +94,7 @@ func newRedisClient(useTls bool, auth string, redisType string, url string, pool return &service.RedisClient{ Client: client, + Stats: stats, ImplicitPipelining: implicitPipelining, } } diff --git a/src/storage/service/memcached_client.go b/src/storage/service/memcached_client.go index b2cc07962..d8740ca6b 100644 --- a/src/storage/service/memcached_client.go +++ b/src/storage/service/memcached_client.go @@ -12,16 +12,43 @@ type MemcachedClientInterface interface { type MemcachedClient struct { Client *memcache.Client + Stats MemcachedStats } func (m MemcachedClient) Get(key string) (*memcache.Item, error) { - return m.Client.Get(key) + m.Stats.keysRequested.Inc() + items, err := m.Client.Get(key) + if err != nil { + m.Stats.GetError.Inc() + } else { + m.Stats.keysFound.Inc() + m.Stats.GetSuccess.Inc() + } + + return items, err } func (m MemcachedClient) Set(item *memcache.Item) error { - return m.Client.Set(item) + err := m.Client.Set(item) + if err != nil { + m.Stats.SetError.Inc() + } else { + m.Stats.SetSuccess.Inc() + } + + return err } func (m MemcachedClient) Increment(key string, delta uint64) (uint64, error) { - return m.Client.Increment(key, delta) + newValue, err := m.Client.Increment(key, delta) + switch err { + case memcache.ErrCacheMiss: + m.Stats.IncrementMiss.Inc() + case nil: + m.Stats.IncrementSuccess.Inc() + default: + m.Stats.IncrementError.Inc() + } + + return newValue, err } diff --git a/src/storage/service/memcached_stats.go b/src/storage/service/memcached_stats.go new file mode 100644 index 000000000..ceb2917c0 --- /dev/null +++ b/src/storage/service/memcached_stats.go @@ -0,0 +1,29 @@ +package service + +import stats "github.com/lyft/gostats" + +type MemcachedStats 
struct { + GetSuccess stats.Counter + GetError stats.Counter + SetSuccess stats.Counter + SetError stats.Counter + IncrementSuccess stats.Counter + IncrementMiss stats.Counter + IncrementError stats.Counter + keysRequested stats.Counter + keysFound stats.Counter +} + +func NewMemcachedStats(scope stats.Scope) MemcachedStats { + return MemcachedStats{ + GetSuccess: scope.NewCounterWithTags("get", map[string]string{"code": "success"}), + GetError: scope.NewCounterWithTags("get", map[string]string{"code": "error"}), + SetSuccess: scope.NewCounterWithTags("set", map[string]string{"code": "success"}), + SetError: scope.NewCounterWithTags("set", map[string]string{"code": "error"}), + IncrementSuccess: scope.NewCounterWithTags("increment", map[string]string{"code": "success"}), + IncrementMiss: scope.NewCounterWithTags("increment", map[string]string{"code": "miss"}), + IncrementError: scope.NewCounterWithTags("increment", map[string]string{"code": "error"}), + keysRequested: scope.NewCounter("keys_requested"), + keysFound: scope.NewCounter("keys_found"), + } +} diff --git a/src/storage/service/redis_client.go b/src/storage/service/redis_client.go index 56ffbde01..6eb248f49 100644 --- a/src/storage/service/redis_client.go +++ b/src/storage/service/redis_client.go @@ -10,6 +10,7 @@ type RedisClientInterface interface { type RedisClient struct { Client radix.Client + Stats RedisStats ImplicitPipelining bool } diff --git a/src/storage/service/redis_stats.go b/src/storage/service/redis_stats.go new file mode 100644 index 000000000..8b5448cd5 --- /dev/null +++ b/src/storage/service/redis_stats.go @@ -0,0 +1,33 @@ +package service + +import ( + stats "github.com/lyft/gostats" + "github.com/mediocregopher/radix/v3/trace" +) + +type RedisStats struct { + connectionActive stats.Gauge + connectionTotal stats.Counter + connectionClose stats.Counter +} + +func PoolTrace(ps *RedisStats) trace.PoolTrace { + return trace.PoolTrace{ + ConnCreated: func(_ trace.PoolConnCreated) { + 
ps.connectionTotal.Add(1) + ps.connectionActive.Add(1) + }, + ConnClosed: func(_ trace.PoolConnClosed) { + ps.connectionActive.Sub(1) + ps.connectionClose.Add(1) + }, + } +} + +func NewRedisStats(scope stats.Scope) RedisStats { + ret := RedisStats{} + ret.connectionActive = scope.NewGauge("cx_active") + ret.connectionTotal = scope.NewCounter("cx_total") + ret.connectionClose = scope.NewCounter("cx_local_close") + return ret +} diff --git a/test/storage/factory/memcached_factory_test.go b/test/storage/factory/memcached_factory_test.go index 1d4a7b1c7..f2908cfc4 100644 --- a/test/storage/factory/memcached_factory_test.go +++ b/test/storage/factory/memcached_factory_test.go @@ -5,12 +5,14 @@ import ( "github.com/envoyproxy/ratelimit/src/storage/factory" "github.com/envoyproxy/ratelimit/src/storage/strategy" + stats "github.com/lyft/gostats" "github.com/stretchr/testify/assert" ) func TestNewMemcachedClient(t *testing.T) { + statsStore := stats.NewStore(stats.NewNullSink(), false) mkMemcachedClient := func(addr []string) strategy.StorageStrategy { - return factory.NewMemcached(addr) + return factory.NewMemcached(statsStore, addr) } t.Run("empty server", func(t *testing.T) { diff --git a/test/storage/factory/redis_factory_test.go b/test/storage/factory/redis_factory_test.go index a3e08931e..bfab7ba40 100644 --- a/test/storage/factory/redis_factory_test.go +++ b/test/storage/factory/redis_factory_test.go @@ -7,6 +7,7 @@ import ( "github.com/alicebob/miniredis" "github.com/envoyproxy/ratelimit/src/storage/factory" "github.com/envoyproxy/ratelimit/src/storage/strategy" + stats "github.com/lyft/gostats" "github.com/stretchr/testify/assert" ) @@ -39,8 +40,10 @@ func TestNewRedisClient(t *testing.T) { func testNewRedisClient(t *testing.T, pipelineWindow time.Duration, pipelineLimit int) func(t *testing.T) { return func(t *testing.T) { redisAuth := "123" + statsStore := stats.NewStore(stats.NewNullSink(), false) + mkRedisClient := func(auth, addr string) 
strategy.StorageStrategy { - return factory.NewRedis(false, auth, "single", addr, 1, pipelineWindow, pipelineLimit) + return factory.NewRedis(statsStore, false, auth, "single", addr, 1, pipelineWindow, pipelineLimit) } t.Run("connection refused", func(t *testing.T) { From 04e81924878b699ea3131abff66b3e4fee05d681 Mon Sep 17 00:00:00 2001 From: zufardhiyaulhaq Date: Tue, 1 Jun 2021 15:36:34 +0200 Subject: [PATCH 5/8] add fixed redis cache implementation test Signed-off-by: zufardhiyaulhaq --- src/memcached/fixed_cache_impl.go | 31 +- src/redis/fixed_cache_impl.go | 68 +++-- src/storage/service/memcached_client.go | 1 + src/storage/service/redis_client.go | 1 + src/storage/strategy/memcached_strategy.go | 4 + src/storage/strategy/redis_strategy.go | 10 + src/storage/strategy/storage_strategy.go | 3 + .../storage/strategy/storage_strategy_mock.go | 90 ++++++ test/redis/fixed_cache_impl_test.go | 282 ++++++++++++++++++ .../storage/factory/memcached_factory_test.go | 3 +- test/storage/factory/redis_factory_test.go | 4 +- .../strategy/memcached_strategy_test.go | 3 +- test/storage/strategy/redis_strategy_test.go | 3 +- 13 files changed, 452 insertions(+), 51 deletions(-) create mode 100644 test/mocks/storage/strategy/storage_strategy_mock.go create mode 100644 test/redis/fixed_cache_impl_test.go diff --git a/src/memcached/fixed_cache_impl.go b/src/memcached/fixed_cache_impl.go index 6818fa4ca..7729861e1 100644 --- a/src/memcached/fixed_cache_impl.go +++ b/src/memcached/fixed_cache_impl.go @@ -20,19 +20,15 @@ import ( "math/rand" "sync" - "github.com/coocood/freecache" - "github.com/bradfitz/gomemcache/memcache" - - logger "github.com/sirupsen/logrus" - - pb "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v3" - + "github.com/coocood/freecache" "github.com/envoyproxy/ratelimit/src/config" "github.com/envoyproxy/ratelimit/src/limiter" "github.com/envoyproxy/ratelimit/src/utils" + pb "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v3" 
storage_strategy "github.com/envoyproxy/ratelimit/src/storage/strategy" + logger "github.com/sirupsen/logrus" ) type rateLimitMemcacheImpl struct { @@ -81,11 +77,6 @@ func (this *rateLimitMemcacheImpl) DoLimit( logger.Debugf("looking up cache key: %s", cacheKey.Key) - expirationSeconds := utils.UnitToDivider(limits[i].Limit.Unit) - if this.baseRateLimiter.ExpirationJitterMaxSeconds > 0 { - expirationSeconds += this.baseRateLimiter.JitterRand.Int63n(this.baseRateLimiter.ExpirationJitterMaxSeconds) - } - // Use the perSecondConn if it is not nil and the cacheKey represents a per second Limit. value, err := this.client.GetValue(cacheKey.Key) if err != nil { @@ -123,18 +114,20 @@ func (this *rateLimitMemcacheImpl) increaseAsync(cacheKeys []limiter.CacheKey, i continue } + expirationSeconds := utils.UnitToDivider(limits[i].Limit.Unit) + if this.expirationJitterMaxSeconds > 0 { + expirationSeconds += this.jitterRand.Int63n(this.expirationJitterMaxSeconds) + } + err := this.client.IncrementValue(cacheKey.Key, hitsAddend) + // if key is not found if err == memcache.ErrCacheMiss { - expirationSeconds := utils.UnitToDivider(limits[i].Limit.Unit) - if this.expirationJitterMaxSeconds > 0 { - expirationSeconds += this.jitterRand.Int63n(this.expirationJitterMaxSeconds) - } - // Need to add instead of increment. + // create key err = this.client.SetValue(cacheKey.Key, hitsAddend, uint64(expirationSeconds)) if err == memcache.ErrNotStored { - // There was a race condition to do this add. We should be able to increment - // now instead. 
+ + // increment the key err := this.client.IncrementValue(cacheKey.Key, hitsAddend) if err != nil { logger.Errorf("Failed to increment key %s after failing to add: %s", cacheKey.Key, err) diff --git a/src/redis/fixed_cache_impl.go b/src/redis/fixed_cache_impl.go index cd9cf0d6d..048ac890a 100644 --- a/src/redis/fixed_cache_impl.go +++ b/src/redis/fixed_cache_impl.go @@ -5,13 +5,14 @@ import ( "sync" "github.com/coocood/freecache" - pb "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v3" "github.com/envoyproxy/ratelimit/src/config" "github.com/envoyproxy/ratelimit/src/limiter" - storage_strategy "github.com/envoyproxy/ratelimit/src/storage/strategy" "github.com/envoyproxy/ratelimit/src/utils" - logger "github.com/sirupsen/logrus" "golang.org/x/net/context" + + pb "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v3" + storage_strategy "github.com/envoyproxy/ratelimit/src/storage/strategy" + logger "github.com/sirupsen/logrus" ) type RedisError string @@ -26,9 +27,11 @@ type fixedRateLimitCacheImpl struct { // If this client is nil, then the Cache will use the client for all // limits regardless of unit. If this client is not nil, then it // is used for limits that have a SECOND unit. 
- perSecondClient storage_strategy.StorageStrategy - baseRateLimiter *limiter.BaseRateLimiter - waitGroup sync.WaitGroup + perSecondClient storage_strategy.StorageStrategy + jitterRand *rand.Rand + expirationJitterMaxSeconds int64 + baseRateLimiter *limiter.BaseRateLimiter + waitGroup sync.WaitGroup } func (this *fixedRateLimitCacheImpl) DoLimit( @@ -62,23 +65,20 @@ func (this *fixedRateLimitCacheImpl) DoLimit( logger.Debugf("looking up cache key: %s", cacheKey.Key) - expirationSeconds := utils.UnitToDivider(limits[i].Limit.Unit) - if this.baseRateLimiter.ExpirationJitterMaxSeconds > 0 { - expirationSeconds += this.baseRateLimiter.JitterRand.Int63n(this.baseRateLimiter.ExpirationJitterMaxSeconds) - } - // Use the perSecondConn if it is not nil and the cacheKey represents a per second Limit. if this.perSecondClient != nil && cacheKey.PerSecond { value, err := this.perSecondClient.GetValue(cacheKey.Key) if err != nil { logger.Error(err) } + results[i] = value } else { value, err := this.client.GetValue(cacheKey.Key) if err != nil { logger.Error(err) } + results[i] = value } } @@ -95,28 +95,40 @@ func (this *fixedRateLimitCacheImpl) DoLimit( responseDescriptorStatuses[i] = this.baseRateLimiter.GetResponseDescriptorStatus(cacheKey.Key, limitInfo, isOverLimitWithLocalCache[i], hitsAddend) - } - this.waitGroup.Add(1) - go this.increaseAsync(cacheKeys, isOverLimitWithLocalCache, limits, uint64(hitsAddend)) - - return responseDescriptorStatuses -} - -func (this *fixedRateLimitCacheImpl) increaseAsync(cacheKeys []limiter.CacheKey, isOverLimitWithLocalCache []bool, - limits []*config.RateLimit, hitsAddend uint64) { - defer this.waitGroup.Done() - for i, cacheKey := range cacheKeys { if cacheKey.Key == "" || isOverLimitWithLocalCache[i] { continue } + expirationSeconds := utils.UnitToDivider(limits[i].Limit.Unit) + if this.expirationJitterMaxSeconds > 0 { + expirationSeconds += this.jitterRand.Int63n(this.expirationJitterMaxSeconds) + } + if this.perSecondClient != nil && 
cacheKey.PerSecond { - this.perSecondClient.IncrementValue(cacheKey.Key, hitsAddend) + err := this.perSecondClient.IncrementValue(cacheKey.Key, uint64(hitsAddend)) + if err != nil { + logger.Error(err) + } + + err = this.perSecondClient.SetExpire(cacheKey.Key, uint64(expirationSeconds)) + if err != nil { + logger.Error(err) + } } else { - this.client.IncrementValue(cacheKey.Key, hitsAddend) + err := this.client.IncrementValue(cacheKey.Key, uint64(hitsAddend)) + if err != nil { + logger.Error(err) + } + + err = this.client.SetExpire(cacheKey.Key, uint64(expirationSeconds)) + if err != nil { + logger.Error(err) + } } } + + return responseDescriptorStatuses } // Flush() is a no-op with redis since quota reads and updates happen synchronously. @@ -125,8 +137,10 @@ func (this *fixedRateLimitCacheImpl) Flush() {} func NewFixedRateLimitCacheImpl(client storage_strategy.StorageStrategy, perSecondClient storage_strategy.StorageStrategy, timeSource utils.TimeSource, jitterRand *rand.Rand, localCache *freecache.Cache, expirationJitterMaxSeconds int64, nearLimitRatio float32, cacheKeyPrefix string) limiter.RateLimitCache { return &fixedRateLimitCacheImpl{ - client: client, - perSecondClient: perSecondClient, - baseRateLimiter: limiter.NewBaseRateLimit(timeSource, jitterRand, expirationJitterMaxSeconds, localCache, nearLimitRatio, cacheKeyPrefix), + client: client, + perSecondClient: perSecondClient, + jitterRand: jitterRand, + expirationJitterMaxSeconds: expirationJitterMaxSeconds, + baseRateLimiter: limiter.NewBaseRateLimit(timeSource, jitterRand, expirationJitterMaxSeconds, localCache, nearLimitRatio, cacheKeyPrefix), } } diff --git a/src/storage/service/memcached_client.go b/src/storage/service/memcached_client.go index d8740ca6b..0346edeb3 100644 --- a/src/storage/service/memcached_client.go +++ b/src/storage/service/memcached_client.go @@ -4,6 +4,7 @@ import ( "github.com/bradfitz/gomemcache/memcache" ) +// Client interface for memcached type MemcachedClientInterface 
interface { Get(key string) (*memcache.Item, error) Set(item *memcache.Item) error diff --git a/src/storage/service/redis_client.go b/src/storage/service/redis_client.go index 6eb248f49..b54ab2d6b 100644 --- a/src/storage/service/redis_client.go +++ b/src/storage/service/redis_client.go @@ -4,6 +4,7 @@ import ( "github.com/mediocregopher/radix/v3" ) +// Client interface for Redis type RedisClientInterface interface { Do(radix.Action) error } diff --git a/src/storage/strategy/memcached_strategy.go b/src/storage/strategy/memcached_strategy.go index e67cc4f6c..deb4cfe7d 100644 --- a/src/storage/strategy/memcached_strategy.go +++ b/src/storage/strategy/memcached_strategy.go @@ -48,3 +48,7 @@ func (m MemcachedStrategy) IncrementValue(key string, delta uint64) error { return nil } + +func (m MemcachedStrategy) SetExpire(key string, expirationSeconds uint64) error { + return nil +} diff --git a/src/storage/strategy/redis_strategy.go b/src/storage/strategy/redis_strategy.go index 4200bff14..3c0f6f0f0 100644 --- a/src/storage/strategy/redis_strategy.go +++ b/src/storage/strategy/redis_strategy.go @@ -39,5 +39,15 @@ func (r RedisStrategy) IncrementValue(key string, delta uint64) error { if err != nil { return err } + + return nil +} + +func (r RedisStrategy) SetExpire(key string, expirationSeconds uint64) error { + err := r.Client.Do(radix.FlatCmd(nil, "EXPIRE", key, expirationSeconds)) + if err != nil { + return err + } + return nil } diff --git a/src/storage/strategy/storage_strategy.go b/src/storage/strategy/storage_strategy.go index 9e0306af2..b45669940 100644 --- a/src/storage/strategy/storage_strategy.go +++ b/src/storage/strategy/storage_strategy.go @@ -1,7 +1,10 @@ package strategy +// Interface to abstract underlying storage like memcached and redis. +// Implements the business level, where we don't care how the underlying storage does it. type StorageStrategy interface { GetValue(key string) (uint64, error) SetValue(key string, value uint64, expirationSeconds uint64) error 
IncrementValue(key string, delta uint64) error + SetExpire(key string, expirationSeconds uint64) error } diff --git a/test/mocks/storage/strategy/storage_strategy_mock.go b/test/mocks/storage/strategy/storage_strategy_mock.go new file mode 100644 index 000000000..d7418059c --- /dev/null +++ b/test/mocks/storage/strategy/storage_strategy_mock.go @@ -0,0 +1,90 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: ./src/storage/strategy/storage_strategy.go + +// Package mock_strategy is a generated GoMock package. +package mock_strategy + +import ( + gomock "github.com/golang/mock/gomock" + reflect "reflect" +) + +// MockStorageStrategy is a mock of StorageStrategy interface +type MockStorageStrategy struct { + ctrl *gomock.Controller + recorder *MockStorageStrategyMockRecorder +} + +// MockStorageStrategyMockRecorder is the mock recorder for MockStorageStrategy +type MockStorageStrategyMockRecorder struct { + mock *MockStorageStrategy +} + +// NewMockStorageStrategy creates a new mock instance +func NewMockStorageStrategy(ctrl *gomock.Controller) *MockStorageStrategy { + mock := &MockStorageStrategy{ctrl: ctrl} + mock.recorder = &MockStorageStrategyMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockStorageStrategy) EXPECT() *MockStorageStrategyMockRecorder { + return m.recorder +} + +// GetValue mocks base method +func (m *MockStorageStrategy) GetValue(key string) (uint64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetValue", key) + ret0, _ := ret[0].(uint64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetValue indicates an expected call of GetValue +func (mr *MockStorageStrategyMockRecorder) GetValue(key interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValue", reflect.TypeOf((*MockStorageStrategy)(nil).GetValue), key) +} + +// SetValue mocks base method +func (m *MockStorageStrategy) SetValue(key 
string, value, expirationSeconds uint64) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetValue", key, value, expirationSeconds) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetValue indicates an expected call of SetValue +func (mr *MockStorageStrategyMockRecorder) SetValue(key, value, expirationSeconds interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetValue", reflect.TypeOf((*MockStorageStrategy)(nil).SetValue), key, value, expirationSeconds) +} + +// IncrementValue mocks base method +func (m *MockStorageStrategy) IncrementValue(key string, delta uint64) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IncrementValue", key, delta) + ret0, _ := ret[0].(error) + return ret0 +} + +// IncrementValue indicates an expected call of IncrementValue +func (mr *MockStorageStrategyMockRecorder) IncrementValue(key, delta interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IncrementValue", reflect.TypeOf((*MockStorageStrategy)(nil).IncrementValue), key, delta) +} + +// SetExpire mocks base method +func (m *MockStorageStrategy) SetExpire(key string, expirationSeconds uint64) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetExpire", key, expirationSeconds) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetExpire indicates an expected call of SetExpire +func (mr *MockStorageStrategyMockRecorder) SetExpire(key, expirationSeconds interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetExpire", reflect.TypeOf((*MockStorageStrategy)(nil).SetExpire), key, expirationSeconds) +} diff --git a/test/redis/fixed_cache_impl_test.go b/test/redis/fixed_cache_impl_test.go new file mode 100644 index 000000000..070d66c93 --- /dev/null +++ b/test/redis/fixed_cache_impl_test.go @@ -0,0 +1,282 @@ +package redis_test + +import ( + "math/rand" + "testing" + + "github.com/coocood/freecache" + 
"github.com/envoyproxy/ratelimit/src/config" + "github.com/envoyproxy/ratelimit/src/limiter" + "github.com/envoyproxy/ratelimit/src/redis" + "github.com/envoyproxy/ratelimit/src/utils" + "github.com/envoyproxy/ratelimit/test/common" + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + + pb "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v3" + mock_strategy "github.com/envoyproxy/ratelimit/test/mocks/storage/strategy" + mock_utils "github.com/envoyproxy/ratelimit/test/mocks/utils" + stats "github.com/lyft/gostats" +) + +func TestRedis(t *testing.T) { + t.Run("WithoutPerSecondRedis", testRedis(false)) + t.Run("WithPerSecondRedis", testRedis(true)) +} + +func testRedis(usePerSecondRedis bool) func(*testing.T) { + return func(t *testing.T) { + assert := assert.New(t) + controller := gomock.NewController(t) + defer controller.Finish() + + client := mock_strategy.NewMockStorageStrategy(controller) + perSecondClient := mock_strategy.NewMockStorageStrategy(controller) + timeSource := mock_utils.NewMockTimeSource(controller) + + var cache limiter.RateLimitCache + if usePerSecondRedis { + cache = redis.NewFixedRateLimitCacheImpl(client, perSecondClient, timeSource, rand.New(rand.NewSource(1)), nil, 0, 0.8, "") + } else { + cache = redis.NewFixedRateLimitCacheImpl(client, nil, timeSource, rand.New(rand.NewSource(1)), nil, 0, 0.8, "") + } + statsStore := stats.NewStore(stats.NewNullSink(), false) + + var clientUsed *mock_strategy.MockStorageStrategy + if usePerSecondRedis { + clientUsed = perSecondClient + } else { + clientUsed = client + } + timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) + clientUsed.EXPECT().GetValue("domain_key_value_1234").Return(uint64(4), nil).MaxTimes(1) + clientUsed.EXPECT().IncrementValue("domain_key_value_1234", uint64(1)).MaxTimes(1) + clientUsed.EXPECT().SetExpire("domain_key_value_1234", uint64(1)).MaxTimes(1) + + request := common.NewRateLimitRequest("domain", [][][2]string{{{"key", "value"}}}, 
1) + limits := []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, "key_value", statsStore)} + + assert.Equal( + []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 5, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, + cache.DoLimit(nil, request, limits)) + assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) + assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) + assert.Equal(uint64(0), limits[0].Stats.NearLimit.Value()) + assert.Equal(uint64(1), limits[0].Stats.WithinLimit.Value()) + + clientUsed = client + timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) + clientUsed.EXPECT().GetValue("domain_key2_value2_subkey2_subvalue2_1200").Return(uint64(10), nil).MaxTimes(1) + clientUsed.EXPECT().IncrementValue("domain_key2_value2_subkey2_subvalue2_1200", uint64(1)).MaxTimes(1) + clientUsed.EXPECT().SetExpire("domain_key2_value2_subkey2_subvalue2_1200", uint64(60)).MaxTimes(1) + + request = common.NewRateLimitRequest( + "domain", + [][][2]string{ + {{"key2", "value2"}}, + {{"key2", "value2"}, {"subkey2", "subvalue2"}}, + }, 1) + limits = []*config.RateLimit{ + nil, + config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_MINUTE, "key2_value2_subkey2_subvalue2", statsStore)} + assert.Equal( + []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: nil, LimitRemaining: 0}, + {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[1].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(limits[1].Limit, timeSource)}}, + cache.DoLimit(nil, request, limits)) + assert.Equal(uint64(1), limits[1].Stats.TotalHits.Value()) + assert.Equal(uint64(1), limits[1].Stats.OverLimit.Value()) + assert.Equal(uint64(0), limits[1].Stats.NearLimit.Value()) + assert.Equal(uint64(0), limits[1].Stats.WithinLimit.Value()) + + clientUsed = client + 
timeSource.EXPECT().UnixNow().Return(int64(1000000)).MaxTimes(5) + + clientUsed.EXPECT().GetValue("domain_key3_value3_997200").Return(uint64(10), nil).MaxTimes(1) + clientUsed.EXPECT().GetValue("domain_key3_value3_subkey3_subvalue3_950400").Return(uint64(12), nil).MaxTimes(1) + clientUsed.EXPECT().IncrementValue("domain_key3_value3_997200", uint64(1)).MaxTimes(1) + clientUsed.EXPECT().IncrementValue("domain_key3_value3_subkey3_subvalue3_950400", uint64(1)).MaxTimes(1) + clientUsed.EXPECT().SetExpire("domain_key3_value3_997200", uint64(3600)).MaxTimes(1) + clientUsed.EXPECT().SetExpire("domain_key3_value3_subkey3_subvalue3_950400", uint64(86400)).MaxTimes(1) + + request = common.NewRateLimitRequest( + "domain", + [][][2]string{ + {{"key3", "value3"}}, + {{"key3", "value3"}, {"subkey3", "subvalue3"}}, + }, 1) + limits = []*config.RateLimit{ + config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_HOUR, "key3_value3", statsStore), + config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_DAY, "key3_value3_subkey3_subvalue3", statsStore)} + assert.Equal( + []*pb.RateLimitResponse_DescriptorStatus{ + {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}, + {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[1].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(limits[1].Limit, timeSource)}}, + cache.DoLimit(nil, request, limits)) + assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) + assert.Equal(uint64(1), limits[0].Stats.OverLimit.Value()) + assert.Equal(uint64(0), limits[0].Stats.NearLimit.Value()) + assert.Equal(uint64(0), limits[0].Stats.WithinLimit.Value()) + assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) + assert.Equal(uint64(1), limits[0].Stats.OverLimit.Value()) + assert.Equal(uint64(0), limits[0].Stats.NearLimit.Value()) + assert.Equal(uint64(0), limits[0].Stats.WithinLimit.Value()) + } +} + +func 
testLocalCacheStats(localCacheStats stats.StatGenerator, statsStore stats.Store, sink *common.TestStatSink, + expectedHitCount int, expectedMissCount int, expectedLookUpCount int, expectedExpiredCount int, + expectedEntryCount int) func(*testing.T) { + return func(t *testing.T) { + localCacheStats.GenerateStats() + statsStore.Flush() + + // Check whether all local_cache related stats are available. + _, ok := sink.Record["averageAccessTime"] + assert.Equal(t, true, ok) + hitCount, ok := sink.Record["hitCount"] + assert.Equal(t, true, ok) + missCount, ok := sink.Record["missCount"] + assert.Equal(t, true, ok) + lookupCount, ok := sink.Record["lookupCount"] + assert.Equal(t, true, ok) + _, ok = sink.Record["overwriteCount"] + assert.Equal(t, true, ok) + _, ok = sink.Record["evacuateCount"] + assert.Equal(t, true, ok) + expiredCount, ok := sink.Record["expiredCount"] + assert.Equal(t, true, ok) + entryCount, ok := sink.Record["entryCount"] + assert.Equal(t, true, ok) + + // Check the correctness of hitCount, missCount, lookupCount, expiredCount and entryCount + assert.Equal(t, expectedHitCount, hitCount.(int)) + assert.Equal(t, expectedMissCount, missCount.(int)) + assert.Equal(t, expectedLookUpCount, lookupCount.(int)) + assert.Equal(t, expectedExpiredCount, expiredCount.(int)) + assert.Equal(t, expectedEntryCount, entryCount.(int)) + + sink.Clear() + } +} + +func TestOverLimitWithLocalCache(t *testing.T) { + assert := assert.New(t) + controller := gomock.NewController(t) + defer controller.Finish() + + client := mock_strategy.NewMockStorageStrategy(controller) + timeSource := mock_utils.NewMockTimeSource(controller) + localCache := freecache.NewCache(100) + cache := redis.NewFixedRateLimitCacheImpl(client, nil, timeSource, rand.New(rand.NewSource(1)), localCache, 0, 0.8, "") + sink := &common.TestStatSink{} + statsStore := stats.NewStore(sink, true) + localCacheStats := limiter.NewLocalCacheStats(localCache, statsStore.Scope("localcache")) + + // Test Near Limit 
Stats. Under Near Limit Ratio + timeSource.EXPECT().UnixNow().Return(int64(1000000)).MaxTimes(3) + client.EXPECT().GetValue("domain_key4_value4_997200").Return(uint64(10), nil).MaxTimes(1) + client.EXPECT().IncrementValue("domain_key4_value4_997200", uint64(1)).MaxTimes(1) + client.EXPECT().SetExpire("domain_key4_value4_997200", uint64(3600)).MaxTimes(1) + + request := common.NewRateLimitRequest("domain", [][][2]string{{{"key4", "value4"}}}, 1) + limits := []*config.RateLimit{ + config.NewRateLimit(15, pb.RateLimitResponse_RateLimit_HOUR, "key4_value4", statsStore)} + + assert.Equal( + []*pb.RateLimitResponse_DescriptorStatus{ + {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 4, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, + cache.DoLimit(nil, request, limits)) + assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) + assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) + assert.Equal(uint64(0), limits[0].Stats.OverLimitWithLocalCache.Value()) + assert.Equal(uint64(0), limits[0].Stats.NearLimit.Value()) + assert.Equal(uint64(1), limits[0].Stats.WithinLimit.Value()) + + // Check the local cache stats. + testLocalCacheStats(localCacheStats, statsStore, sink, 0, 1, 1, 0, 0) + + // Test Near Limit Stats. 
At Near Limit Ratio, still OK + timeSource.EXPECT().UnixNow().Return(int64(1000000)).MaxTimes(3) + client.EXPECT().GetValue("domain_key4_value4_997200").Return(uint64(12), nil).MaxTimes(1) + client.EXPECT().IncrementValue("domain_key4_value4_997200", uint64(1)).MaxTimes(1) + client.EXPECT().SetExpire("domain_key4_value4_997200", uint64(3600)).MaxTimes(1) + + assert.Equal( + []*pb.RateLimitResponse_DescriptorStatus{ + {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 2, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, + cache.DoLimit(nil, request, limits)) + assert.Equal(uint64(2), limits[0].Stats.TotalHits.Value()) + assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) + assert.Equal(uint64(0), limits[0].Stats.OverLimitWithLocalCache.Value()) + assert.Equal(uint64(1), limits[0].Stats.NearLimit.Value()) + assert.Equal(uint64(2), limits[0].Stats.WithinLimit.Value()) + + // Check the local cache stats. + testLocalCacheStats(localCacheStats, statsStore, sink, 0, 2, 2, 0, 0) + + // Test Over limit stats + timeSource.EXPECT().UnixNow().Return(int64(1000000)).MaxTimes(3) + client.EXPECT().GetValue("domain_key4_value4_997200").Return(uint64(15), nil).MaxTimes(1) + client.EXPECT().IncrementValue("domain_key4_value4_997200", uint64(1)).MaxTimes(1) + client.EXPECT().SetExpire("domain_key4_value4_997200", uint64(3600)).MaxTimes(1) + + assert.Equal( + []*pb.RateLimitResponse_DescriptorStatus{ + {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, + cache.DoLimit(nil, request, limits)) + assert.Equal(uint64(3), limits[0].Stats.TotalHits.Value()) + assert.Equal(uint64(1), limits[0].Stats.OverLimit.Value()) + assert.Equal(uint64(0), limits[0].Stats.OverLimitWithLocalCache.Value()) + assert.Equal(uint64(1), limits[0].Stats.NearLimit.Value()) + assert.Equal(uint64(2), limits[0].Stats.WithinLimit.Value()) + + // Check 
the local cache stats. + testLocalCacheStats(localCacheStats, statsStore, sink, 0, 2, 3, 0, 1) + + // Test Over limit stats with local cache + timeSource.EXPECT().UnixNow().Return(int64(1000000)).MaxTimes(3) + client.EXPECT().IncrementValue("domain_key4_value4_997200", uint64(1)).MaxTimes(1) + client.EXPECT().SetExpire("domain_key4_value4_997200", uint64(3600)).MaxTimes(1) + + assert.Equal( + []*pb.RateLimitResponse_DescriptorStatus{ + {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, + cache.DoLimit(nil, request, limits)) + assert.Equal(uint64(4), limits[0].Stats.TotalHits.Value()) + assert.Equal(uint64(2), limits[0].Stats.OverLimit.Value()) + assert.Equal(uint64(1), limits[0].Stats.OverLimitWithLocalCache.Value()) + assert.Equal(uint64(1), limits[0].Stats.NearLimit.Value()) + assert.Equal(uint64(2), limits[0].Stats.WithinLimit.Value()) + + // Check the local cache stats. + testLocalCacheStats(localCacheStats, statsStore, sink, 1, 3, 4, 0, 1) +} + +func TestRedisWithJitter(t *testing.T) { + assert := assert.New(t) + controller := gomock.NewController(t) + defer controller.Finish() + + client := mock_strategy.NewMockStorageStrategy(controller) + timeSource := mock_utils.NewMockTimeSource(controller) + jitterSource := mock_utils.NewMockJitterRandSource(controller) + cache := redis.NewFixedRateLimitCacheImpl(client, nil, timeSource, rand.New(jitterSource), nil, 3600, 0.8, "") + statsStore := stats.NewStore(stats.NewNullSink(), false) + + timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) + jitterSource.EXPECT().Int63().Return(int64(100)) + client.EXPECT().GetValue("domain_key_value_1234").Return(uint64(4), nil).MaxTimes(1) + client.EXPECT().IncrementValue("domain_key_value_1234", uint64(1)).MaxTimes(1) + client.EXPECT().SetExpire("domain_key_value_1234", uint64(101)).MaxTimes(1) + + request := common.NewRateLimitRequest("domain", 
[][][2]string{{{"key", "value"}}}, 1) + limits := []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, "key_value", statsStore)} + + assert.Equal( + []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 5, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, + cache.DoLimit(nil, request, limits)) + assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) + assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) + assert.Equal(uint64(0), limits[0].Stats.NearLimit.Value()) + assert.Equal(uint64(1), limits[0].Stats.WithinLimit.Value()) +} diff --git a/test/storage/factory/memcached_factory_test.go b/test/storage/factory/memcached_factory_test.go index f2908cfc4..0609a5640 100644 --- a/test/storage/factory/memcached_factory_test.go +++ b/test/storage/factory/memcached_factory_test.go @@ -5,8 +5,9 @@ import ( "github.com/envoyproxy/ratelimit/src/storage/factory" "github.com/envoyproxy/ratelimit/src/storage/strategy" - stats "github.com/lyft/gostats" "github.com/stretchr/testify/assert" + + stats "github.com/lyft/gostats" ) func TestNewMemcachedClient(t *testing.T) { diff --git a/test/storage/factory/redis_factory_test.go b/test/storage/factory/redis_factory_test.go index bfab7ba40..7073ab65b 100644 --- a/test/storage/factory/redis_factory_test.go +++ b/test/storage/factory/redis_factory_test.go @@ -7,9 +7,9 @@ import ( "github.com/alicebob/miniredis" "github.com/envoyproxy/ratelimit/src/storage/factory" "github.com/envoyproxy/ratelimit/src/storage/strategy" - stats "github.com/lyft/gostats" - "github.com/stretchr/testify/assert" + + stats "github.com/lyft/gostats" ) func mustNewRedisServer() *miniredis.Miniredis { diff --git a/test/storage/strategy/memcached_strategy_test.go b/test/storage/strategy/memcached_strategy_test.go index 80e292ba9..563f96031 100644 --- a/test/storage/strategy/memcached_strategy_test.go +++ 
b/test/storage/strategy/memcached_strategy_test.go @@ -6,9 +6,10 @@ import ( "github.com/bradfitz/gomemcache/memcache" "github.com/envoyproxy/ratelimit/src/storage/strategy" - mock_service "github.com/envoyproxy/ratelimit/test/mocks/storage/service" "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" + + mock_service "github.com/envoyproxy/ratelimit/test/mocks/storage/service" ) func TestMemcachedStrategyGetValue(t *testing.T) { diff --git a/test/storage/strategy/redis_strategy_test.go b/test/storage/strategy/redis_strategy_test.go index 989f7ffea..d1ba3d124 100644 --- a/test/storage/strategy/redis_strategy_test.go +++ b/test/storage/strategy/redis_strategy_test.go @@ -5,10 +5,11 @@ import ( "github.com/alicebob/miniredis" "github.com/envoyproxy/ratelimit/src/storage/strategy" - mock_service "github.com/envoyproxy/ratelimit/test/mocks/storage/service" "github.com/golang/mock/gomock" "github.com/mediocregopher/radix/v3" "github.com/stretchr/testify/assert" + + mock_service "github.com/envoyproxy/ratelimit/test/mocks/storage/service" ) func mustNewRedisServer() *miniredis.Miniredis { From 9f01384697f10c44f06607ffcfb45f45a542a2d6 Mon Sep 17 00:00:00 2001 From: zufardhiyaulhaq Date: Sat, 5 Jun 2021 05:12:59 +0200 Subject: [PATCH 6/8] add fixed memcached cache implementation test Signed-off-by: zufardhiyaulhaq --- test/memcached/fixed_cache_impl_test.go | 259 ++++++++++++++++++++++++ 1 file changed, 259 insertions(+) create mode 100644 test/memcached/fixed_cache_impl_test.go diff --git a/test/memcached/fixed_cache_impl_test.go b/test/memcached/fixed_cache_impl_test.go new file mode 100644 index 000000000..db50a7042 --- /dev/null +++ b/test/memcached/fixed_cache_impl_test.go @@ -0,0 +1,259 @@ +package memcached_test + +import ( + "math/rand" + "testing" + + "github.com/coocood/freecache" + "github.com/envoyproxy/ratelimit/src/config" + "github.com/envoyproxy/ratelimit/src/limiter" + "github.com/envoyproxy/ratelimit/src/memcached" + 
"github.com/envoyproxy/ratelimit/src/utils" + "github.com/envoyproxy/ratelimit/test/common" + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + + pb "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v3" + mock_strategy "github.com/envoyproxy/ratelimit/test/mocks/storage/strategy" + mock_utils "github.com/envoyproxy/ratelimit/test/mocks/utils" + stats "github.com/lyft/gostats" +) + +func TestMemcached(t *testing.T) { + assert := assert.New(t) + controller := gomock.NewController(t) + defer controller.Finish() + + client := mock_strategy.NewMockStorageStrategy(controller) + timeSource := mock_utils.NewMockTimeSource(controller) + + cache := memcached.NewFixedRateLimitCacheImpl(client, timeSource, rand.New(rand.NewSource(1)), nil, 0, 0.8, "") + statsStore := stats.NewStore(stats.NewNullSink(), false) + + timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) + client.EXPECT().GetValue("domain_key_value_1234").Return(uint64(4), nil).MaxTimes(1) + client.EXPECT().IncrementValue("domain_key_value_1234", uint64(1)).MaxTimes(1) + client.EXPECT().SetExpire("domain_key_value_1234", uint64(1)).MaxTimes(1) + + request := common.NewRateLimitRequest("domain", [][][2]string{{{"key", "value"}}}, 1) + limits := []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, "key_value", statsStore)} + + assert.Equal( + []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 5, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, + cache.DoLimit(nil, request, limits)) + assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) + assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) + assert.Equal(uint64(0), limits[0].Stats.NearLimit.Value()) + assert.Equal(uint64(1), limits[0].Stats.WithinLimit.Value()) + + timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) + 
client.EXPECT().GetValue("domain_key2_value2_subkey2_subvalue2_1200").Return(uint64(10), nil).MaxTimes(1) + client.EXPECT().IncrementValue("domain_key2_value2_subkey2_subvalue2_1200", uint64(1)).MaxTimes(1) + client.EXPECT().SetExpire("domain_key2_value2_subkey2_subvalue2_1200", uint64(60)).MaxTimes(1) + + request = common.NewRateLimitRequest( + "domain", + [][][2]string{ + {{"key2", "value2"}}, + {{"key2", "value2"}, {"subkey2", "subvalue2"}}, + }, 1) + limits = []*config.RateLimit{ + nil, + config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_MINUTE, "key2_value2_subkey2_subvalue2", statsStore)} + assert.Equal( + []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: nil, LimitRemaining: 0}, + {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[1].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(limits[1].Limit, timeSource)}}, + cache.DoLimit(nil, request, limits)) + assert.Equal(uint64(1), limits[1].Stats.TotalHits.Value()) + assert.Equal(uint64(1), limits[1].Stats.OverLimit.Value()) + assert.Equal(uint64(0), limits[1].Stats.NearLimit.Value()) + assert.Equal(uint64(0), limits[1].Stats.WithinLimit.Value()) + + timeSource.EXPECT().UnixNow().Return(int64(1000000)).MaxTimes(5) + client.EXPECT().GetValue("domain_key3_value3_997200").Return(uint64(10), nil).MaxTimes(1) + client.EXPECT().GetValue("domain_key3_value3_subkey3_subvalue3_950400").Return(uint64(12), nil).MaxTimes(1) + client.EXPECT().IncrementValue("domain_key3_value3_997200", uint64(1)).MaxTimes(1) + client.EXPECT().IncrementValue("domain_key3_value3_subkey3_subvalue3_950400", uint64(1)).MaxTimes(1) + client.EXPECT().SetExpire("domain_key3_value3_997200", uint64(3600)).MaxTimes(1) + client.EXPECT().SetExpire("domain_key3_value3_subkey3_subvalue3_950400", uint64(86400)).MaxTimes(1) + + request = common.NewRateLimitRequest( + "domain", + [][][2]string{ + {{"key3", "value3"}}, + {{"key3", "value3"}, {"subkey3", "subvalue3"}}, + }, 1) + limits = 
[]*config.RateLimit{ + config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_HOUR, "key3_value3", statsStore), + config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_DAY, "key3_value3_subkey3_subvalue3", statsStore)} + assert.Equal( + []*pb.RateLimitResponse_DescriptorStatus{ + {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}, + {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[1].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(limits[1].Limit, timeSource)}}, + cache.DoLimit(nil, request, limits)) + assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) + assert.Equal(uint64(1), limits[0].Stats.OverLimit.Value()) + assert.Equal(uint64(0), limits[0].Stats.NearLimit.Value()) + assert.Equal(uint64(0), limits[0].Stats.WithinLimit.Value()) + assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) + assert.Equal(uint64(1), limits[0].Stats.OverLimit.Value()) + assert.Equal(uint64(0), limits[0].Stats.NearLimit.Value()) + assert.Equal(uint64(0), limits[0].Stats.WithinLimit.Value()) +} + +func testLocalCacheStats(localCacheStats stats.StatGenerator, statsStore stats.Store, sink *common.TestStatSink, + expectedHitCount int, expectedMissCount int, expectedLookUpCount int, expectedExpiredCount int, + expectedEntryCount int) func(*testing.T) { + return func(t *testing.T) { + localCacheStats.GenerateStats() + statsStore.Flush() + + // Check whether all local_cache related stats are available. 
+ _, ok := sink.Record["averageAccessTime"] + assert.Equal(t, true, ok) + hitCount, ok := sink.Record["hitCount"] + assert.Equal(t, true, ok) + missCount, ok := sink.Record["missCount"] + assert.Equal(t, true, ok) + lookupCount, ok := sink.Record["lookupCount"] + assert.Equal(t, true, ok) + _, ok = sink.Record["overwriteCount"] + assert.Equal(t, true, ok) + _, ok = sink.Record["evacuateCount"] + assert.Equal(t, true, ok) + expiredCount, ok := sink.Record["expiredCount"] + assert.Equal(t, true, ok) + entryCount, ok := sink.Record["entryCount"] + assert.Equal(t, true, ok) + + // Check the correctness of hitCount, missCount, lookupCount, expiredCount and entryCount + assert.Equal(t, expectedHitCount, hitCount.(int)) + assert.Equal(t, expectedMissCount, missCount.(int)) + assert.Equal(t, expectedLookUpCount, lookupCount.(int)) + assert.Equal(t, expectedExpiredCount, expiredCount.(int)) + assert.Equal(t, expectedEntryCount, entryCount.(int)) + + sink.Clear() + } +} + +func TestOverLimitWithLocalCache(t *testing.T) { + assert := assert.New(t) + controller := gomock.NewController(t) + defer controller.Finish() + + client := mock_strategy.NewMockStorageStrategy(controller) + timeSource := mock_utils.NewMockTimeSource(controller) + localCache := freecache.NewCache(100) + cache := memcached.NewFixedRateLimitCacheImpl(client, timeSource, rand.New(rand.NewSource(1)), localCache, 0, 0.8, "") + sink := &common.TestStatSink{} + statsStore := stats.NewStore(sink, true) + localCacheStats := limiter.NewLocalCacheStats(localCache, statsStore.Scope("localcache")) + + // Test Near Limit Stats. 
Under Near Limit Ratio + timeSource.EXPECT().UnixNow().Return(int64(1000000)).MaxTimes(3) + client.EXPECT().GetValue("domain_key4_value4_997200").Return(uint64(10), nil).MaxTimes(1) + client.EXPECT().IncrementValue("domain_key4_value4_997200", uint64(1)).MaxTimes(1) + client.EXPECT().SetExpire("domain_key4_value4_997200", uint64(3600)).MaxTimes(1) + + request := common.NewRateLimitRequest("domain", [][][2]string{{{"key4", "value4"}}}, 1) + limits := []*config.RateLimit{ + config.NewRateLimit(15, pb.RateLimitResponse_RateLimit_HOUR, "key4_value4", statsStore)} + + assert.Equal( + []*pb.RateLimitResponse_DescriptorStatus{ + {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 4, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, + cache.DoLimit(nil, request, limits)) + assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) + assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) + assert.Equal(uint64(0), limits[0].Stats.OverLimitWithLocalCache.Value()) + assert.Equal(uint64(0), limits[0].Stats.NearLimit.Value()) + assert.Equal(uint64(1), limits[0].Stats.WithinLimit.Value()) + + // Check the local cache stats. + testLocalCacheStats(localCacheStats, statsStore, sink, 0, 1, 1, 0, 0) + + // Test Near Limit Stats. 
At Near Limit Ratio, still OK + timeSource.EXPECT().UnixNow().Return(int64(1000000)).MaxTimes(3) + client.EXPECT().GetValue("domain_key4_value4_997200").Return(uint64(12), nil).MaxTimes(1) + client.EXPECT().IncrementValue("domain_key4_value4_997200", uint64(1)).MaxTimes(1) + client.EXPECT().SetExpire("domain_key4_value4_997200", uint64(3600)).MaxTimes(1) + + assert.Equal( + []*pb.RateLimitResponse_DescriptorStatus{ + {Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 2, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, + cache.DoLimit(nil, request, limits)) + assert.Equal(uint64(2), limits[0].Stats.TotalHits.Value()) + assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) + assert.Equal(uint64(0), limits[0].Stats.OverLimitWithLocalCache.Value()) + assert.Equal(uint64(1), limits[0].Stats.NearLimit.Value()) + assert.Equal(uint64(2), limits[0].Stats.WithinLimit.Value()) + + // Check the local cache stats. + testLocalCacheStats(localCacheStats, statsStore, sink, 0, 2, 2, 0, 0) + + // Test Over limit stats + timeSource.EXPECT().UnixNow().Return(int64(1000000)).MaxTimes(3) + client.EXPECT().GetValue("domain_key4_value4_997200").Return(uint64(15), nil).MaxTimes(1) + client.EXPECT().IncrementValue("domain_key4_value4_997200", uint64(1)).MaxTimes(1) + client.EXPECT().SetExpire("domain_key4_value4_997200", uint64(3600)).MaxTimes(1) + + assert.Equal( + []*pb.RateLimitResponse_DescriptorStatus{ + {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, + cache.DoLimit(nil, request, limits)) + assert.Equal(uint64(3), limits[0].Stats.TotalHits.Value()) + assert.Equal(uint64(1), limits[0].Stats.OverLimit.Value()) + assert.Equal(uint64(0), limits[0].Stats.OverLimitWithLocalCache.Value()) + assert.Equal(uint64(1), limits[0].Stats.NearLimit.Value()) + assert.Equal(uint64(2), limits[0].Stats.WithinLimit.Value()) + + // Check 
the local cache stats. + testLocalCacheStats(localCacheStats, statsStore, sink, 0, 2, 3, 0, 1) + + // Test Over limit stats with local cache + timeSource.EXPECT().UnixNow().Return(int64(1000000)).MaxTimes(3) + client.EXPECT().IncrementValue("domain_key4_value4_997200", uint64(1)).MaxTimes(1) + client.EXPECT().SetExpire("domain_key4_value4_997200", uint64(3600)).MaxTimes(1) + + assert.Equal( + []*pb.RateLimitResponse_DescriptorStatus{ + {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, + cache.DoLimit(nil, request, limits)) + assert.Equal(uint64(4), limits[0].Stats.TotalHits.Value()) + assert.Equal(uint64(2), limits[0].Stats.OverLimit.Value()) + assert.Equal(uint64(1), limits[0].Stats.OverLimitWithLocalCache.Value()) + assert.Equal(uint64(1), limits[0].Stats.NearLimit.Value()) + assert.Equal(uint64(2), limits[0].Stats.WithinLimit.Value()) + + // Check the local cache stats. + testLocalCacheStats(localCacheStats, statsStore, sink, 1, 3, 4, 0, 1) +} + +func TestMemcachedWithJitter(t *testing.T) { + assert := assert.New(t) + controller := gomock.NewController(t) + defer controller.Finish() + + client := mock_strategy.NewMockStorageStrategy(controller) + timeSource := mock_utils.NewMockTimeSource(controller) + jitterSource := mock_utils.NewMockJitterRandSource(controller) + cache := memcached.NewFixedRateLimitCacheImpl(client, timeSource, rand.New(jitterSource), nil, 3600, 0.8, "") + statsStore := stats.NewStore(stats.NewNullSink(), false) + + timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) + client.EXPECT().GetValue("domain_key_value_1234").Return(uint64(4), nil).MaxTimes(1) + client.EXPECT().IncrementValue("domain_key_value_1234", uint64(1)).MaxTimes(1) + client.EXPECT().SetExpire("domain_key_value_1234", uint64(101)).MaxTimes(1) + + request := common.NewRateLimitRequest("domain", [][][2]string{{{"key", "value"}}}, 1) + limits := 
[]*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, "key_value", statsStore)} + + assert.Equal( + []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: limits[0].Limit, LimitRemaining: 5, DurationUntilReset: utils.CalculateReset(limits[0].Limit, timeSource)}}, + cache.DoLimit(nil, request, limits)) + assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) + assert.Equal(uint64(0), limits[0].Stats.OverLimit.Value()) + assert.Equal(uint64(0), limits[0].Stats.NearLimit.Value()) + assert.Equal(uint64(1), limits[0].Stats.WithinLimit.Value()) +} From 2f7bd253b6c55aa896e311beaad96286d74cb7c2 Mon Sep 17 00:00:00 2001 From: zufardhiyaulhaq Date: Sat, 5 Jun 2021 12:12:17 +0200 Subject: [PATCH 7/8] fix format and memcached test Signed-off-by: zufardhiyaulhaq --- test/memcached/fixed_cache_impl_test.go | 1 + test/redis/fixed_cache_impl_test.go | 2 +- test/storage/factory/memcached_factory_test.go | 1 - 3 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/memcached/fixed_cache_impl_test.go b/test/memcached/fixed_cache_impl_test.go index c855a11bf..4c910a235 100644 --- a/test/memcached/fixed_cache_impl_test.go +++ b/test/memcached/fixed_cache_impl_test.go @@ -411,6 +411,7 @@ func TestMemcachedWithJitter(t *testing.T) { cache := memcached.NewFixedRateLimitCacheImpl(client, timeSource, rand.New(jitterSource), nil, 3600, 0.8, "", sm) timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) + jitterSource.EXPECT().Int63().Return(int64(100)).MaxTimes(1) client.EXPECT().GetValue("domain_key_value_1234").Return(uint64(4), nil).MaxTimes(1) client.EXPECT().IncrementValue("domain_key_value_1234", uint64(1)).MaxTimes(1) client.EXPECT().SetExpire("domain_key_value_1234", uint64(101)).MaxTimes(1) diff --git a/test/redis/fixed_cache_impl_test.go b/test/redis/fixed_cache_impl_test.go index 563add1c3..62d71210d 100644 --- a/test/redis/fixed_cache_impl_test.go +++ b/test/redis/fixed_cache_impl_test.go @@ -434,7 
+434,7 @@ func TestRedisWithJitter(t *testing.T) { cache := redis.NewFixedRateLimitCacheImpl(client, nil, timeSource, rand.New(jitterSource), nil, 3600, 0.8, "", sm) timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) - jitterSource.EXPECT().Int63().Return(int64(100)) + jitterSource.EXPECT().Int63().Return(int64(100)).MaxTimes(1) client.EXPECT().GetValue("domain_key_value_1234").Return(uint64(4), nil).MaxTimes(1) client.EXPECT().IncrementValue("domain_key_value_1234", uint64(1)).MaxTimes(1) client.EXPECT().SetExpire("domain_key_value_1234", uint64(101)).MaxTimes(1) diff --git a/test/storage/factory/memcached_factory_test.go b/test/storage/factory/memcached_factory_test.go index 4a36b8aae..4fb9b3846 100644 --- a/test/storage/factory/memcached_factory_test.go +++ b/test/storage/factory/memcached_factory_test.go @@ -30,7 +30,6 @@ func TestNewRateLimitCacheImplFromSettingsWhenSrvCannotBeResolved(t *testing.T) }) } - func TestNewRateLimitCacheImplFromSettingsWhenHostAndPortAndSrvAreBothSet(t *testing.T) { statsStore := stats.NewStore(stats.NewNullSink(), false) assert.Panics(t, func() { From bfa70cfeaef7e2087405edee48c73d7f1e71df27 Mon Sep 17 00:00:00 2001 From: zufardhiyaulhaq Date: Sat, 5 Jun 2021 12:57:15 +0200 Subject: [PATCH 8/8] remove unused code in fixed memcached cache implementation Signed-off-by: zufardhiyaulhaq --- src/memcached/fixed_cache_impl.go | 8 +------- src/storage/factory/memcached_factory.go | 5 +++-- src/storage/service/memcached_client.go | 7 ------- src/storage/utils/utils.go | 6 ++++++ 4 files changed, 10 insertions(+), 16 deletions(-) diff --git a/src/memcached/fixed_cache_impl.go b/src/memcached/fixed_cache_impl.go index ebb4f378a..ffaa7fc3e 100644 --- a/src/memcached/fixed_cache_impl.go +++ b/src/memcached/fixed_cache_impl.go @@ -35,13 +35,10 @@ import ( type rateLimitMemcacheImpl struct { client storage_strategy.StorageStrategy - timeSource utils.TimeSource jitterRand *rand.Rand expirationJitterMaxSeconds int64 - localCache 
*freecache.Cache - waitGroup sync.WaitGroup - nearLimitRatio float32 baseRateLimiter *limiter.BaseRateLimiter + waitGroup sync.WaitGroup } var AutoFlushForIntegrationTests bool = false @@ -188,11 +185,8 @@ func NewFixedRateLimitCacheImpl(client storage_strategy.StorageStrategy, timeSou localCache *freecache.Cache, expirationJitterMaxSeconds int64, nearLimitRatio float32, cacheKeyPrefix string, statsManager stats.Manager) limiter.RateLimitCache { return &rateLimitMemcacheImpl{ client: client, - timeSource: timeSource, jitterRand: jitterRand, expirationJitterMaxSeconds: expirationJitterMaxSeconds, - localCache: localCache, - nearLimitRatio: nearLimitRatio, baseRateLimiter: limiter.NewBaseRateLimit(timeSource, jitterRand, expirationJitterMaxSeconds, localCache, nearLimitRatio, cacheKeyPrefix, statsManager), } } diff --git a/src/storage/factory/memcached_factory.go b/src/storage/factory/memcached_factory.go index 7d838ab3b..41ac4fb35 100644 --- a/src/storage/factory/memcached_factory.go +++ b/src/storage/factory/memcached_factory.go @@ -10,6 +10,7 @@ import ( "github.com/bradfitz/gomemcache/memcache" "github.com/envoyproxy/ratelimit/src/storage/service" "github.com/envoyproxy/ratelimit/src/storage/strategy" + "github.com/envoyproxy/ratelimit/src/storage/utils" stats "github.com/lyft/gostats" logger "github.com/sirupsen/logrus" @@ -19,7 +20,7 @@ func NewMemcached(scope stats.Scope, hosts []string, srv string, srvRefresh time var client service.MemcachedClientInterface if srv != "" && len(hosts) > 0 { - panic(service.MemcacheError("Both MEMCADHE_HOST_PORT and MEMCACHE_SRV are set")) + panic(utils.MemcacheError("Both MEMCADHE_HOST_PORT and MEMCACHE_SRV are set")) } if srv != "" { @@ -49,7 +50,7 @@ func newMemcachedClientFromSrv(scope stats.Scope, srv string, srvRefresh time.Du if err != nil { errorText := "Unable to fetch servers from SRV" logger.Errorf(errorText) - panic(service.MemcacheError(errorText)) + panic(utils.MemcacheError(errorText)) } if srvRefresh > 0 { diff 
--git a/src/storage/service/memcached_client.go b/src/storage/service/memcached_client.go index 9c79e6a51..0346edeb3 100644 --- a/src/storage/service/memcached_client.go +++ b/src/storage/service/memcached_client.go @@ -11,13 +11,6 @@ type MemcachedClientInterface interface { Increment(key string, delta uint64) (uint64, error) } -// Errors that may be raised during config parsing. -type MemcacheError string - -func (e MemcacheError) Error() string { - return string(e) -} - type MemcachedClient struct { Client *memcache.Client Stats MemcachedStats diff --git a/src/storage/utils/utils.go b/src/storage/utils/utils.go index c6e2c7a8c..205c9c38d 100644 --- a/src/storage/utils/utils.go +++ b/src/storage/utils/utils.go @@ -11,3 +11,9 @@ func CheckError(err error) { panic(RedisError(err.Error())) } } + +type MemcacheError string + +func (e MemcacheError) Error() string { + return string(e) +}